├── .idea ├── MyTraci.iml └── modules.xml ├── LICENSE.txt ├── README.md ├── analyze-excel.ipynb ├── constants ├── constants-grid.json └── constants.json ├── data ├── 1_1_intersections.net.xml ├── 2_2_intersections.net.xml ├── 3_3_intersections.net.xml └── 4_4_intersections.net.xml ├── data_collector.py ├── environments ├── environment.py └── intersections.py ├── images ├── 1_1-grid.png ├── 2_2-grid.png ├── marl.png ├── ppo.png └── sumo.png ├── main.py ├── models ├── cycle_rule_set.py ├── ppo_model.py ├── rule_set.py └── timer_rule_set.py ├── run-data.xlsx ├── run_agent_parallel.py ├── utils ├── env_phases.py ├── grid_search.py ├── net_scrape.py ├── random_search.py └── utils.py ├── vis_agent.py └── workers ├── ppo_worker.py ├── rule_worker.py └── worker.py /.idea/MyTraci.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Maxwell Brenner 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Distributed PPO for Traffic Light Control with Multi-Agent RL 2 | Uses a distributed version of the deep reinforcement learning algorithm [PPO](https://arxiv.org/abs/1707.06347) to control a grid of traffic lights for optimized traffic flow through the system. The traffic environment is implemented in the realistic traffic simulator [SUMO](https://sumo.dlr.de/docs/index.html). Multi-agent RL (MARL) is implemented with each traffic light acting as a single agent. 3 | 4 | ## SUMO / TraCI 5 | SUMO (**S**imulation of **U**rban **MO**bility) is a continuous road traffic simulation. [TraCI](https://sumo.dlr.de/docs/TraCI.html) (**Tra**ffic **C**ontrol **I**nterface) connects to a SUMO simulation from a programming language (in this case Python) to allow for feeding inputs and receiving outputs.
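For reference, here is a minimal, generic sketch of driving a SUMO simulation through TraCI from Python (the `example.sumocfg` file name is a placeholder, not something shipped with this repo; the equivalent calls here are wrapped inside `environments/environment.py` and `environments/intersections.py`):

```python
import traci
from sumolib import checkBinary

# Launch SUMO headless ('sumo-gui' instead of 'sumo' opens the GUI) with a placeholder config.
traci.start([checkBinary('sumo'), '-c', 'example.sumocfg'])

# Step the simulation while vehicles are still expected, feeding inputs
# (traffic light phases) and reading outputs (current phases) each step.
while traci.simulation.getMinExpectedNumber() > 0:
    traci.simulationStep()
    for tls_id in traci.trafficlight.getIDList():
        current_phase = traci.trafficlight.getPhase(tls_id)
        # ...decide whether to switch and call traci.trafficlight.setPhase(tls_id, new_phase)

traci.close()
```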
6 | 7 | ![SUMO picture](/images/sumo.png) 8 | 9 | The environments implemented for this problem are grids where each intersection is controlled by a traffic light. At any given time either the NS cars or the EW cars can go, so each intersection has 2 possible configurations. Cars spawn at the edges and then have a predefined destination edge where they despawn. 10 | 11 | ## Models 12 | ### PPO 13 | [Proximal Policy Optimization](https://openai.com/blog/openai-baselines-ppo/) (PPO) is a policy-gradient-based reinforcement learning algorithm created by OpenAI. It is efficient, fairly simple, and tends to be the go-to algorithm for RL nowadays. There are a lot of great tutorials and code on PPO ([this](https://medium.com/@jonathan_hui/rl-proximal-policy-optimization-ppo-explained-77f014ec3f12), [this](https://github.com/ShangtongZhang/DeepRL/blob/master/deep_rl/agent/PPO_agent.py) and many more). 14 | 15 | ![PPO code](/images/ppo.png) 16 | 17 | ### DPPO 18 | #### DISCLAIMER: The DPPO implementation here is incorrect. It does not properly aggregate the gradients during training. 19 | 20 | Distributed algorithms use multiple processes to speed up existing algorithms such as PPO. There aren't as many simple resources on DPPO, but I used a few different sources noted in my code, such as [this repo](https://github.com/alexis-jacq/Pytorch-DPPO). I first implemented single-agent RL, which means that in a single environment there is only one agent. In this app's case, that means all traffic lights are controlled by one agent. However, it also means that as the grid size increases, the action space grows exponentially. 21 | 22 | ### MARL 23 | ![1x1 grid](/images/1_1-grid.png) 24 | 25 | For example, the action space for a single intersection is 2, as either the NS light can be green or the EW light can be green. 26 | 27 | ![2x2 grid](/images/2_2-grid.png) 28 | 29 | The number of actions for a 2x2 grid is 2^4 = 16. For example, if 1 means NS is green and 0 means EW is green, then 1011 in binary (11 in decimal) would mean that 3 of the 4 intersections are NS green. This becomes a problem as the grid gets even larger. 30 | 31 | ![MARL](/images/marl.png) 32 | 33 | [Cooperative MARL](https://arxiv.org/abs/1908.03963) is a way to fix this ["curse of dimensionality"](https://en.wikipedia.org/wiki/Curse_of_dimensionality) problem. With MARL there are multiple agents in the environment, and in this case each agent controls a single intersection. So now an agent only has 2 possible actions no matter how big the grid gets! MARL also helps with inputs: instead of a single agent needing to be trained on the states of, say, 4 intersections (for a 2x2 grid), each agent only has to deal with its own. MARL is a great tool in cases where your problem can run into scaling issues. 34 | 35 | In the case of this repo, I use independent MARL, which means the agents do not directly communicate. However, the actor and critic share parameters across all agents. One trick for better cooperation is to share certain info across agents (other than just weights); reward and states are two popular items to share. This [post](https://bair.berkeley.edu/blog/2018/12/12/rllib/) by Berkeley goes into this more. 36 | 37 | ## How to Run this 38 | ### Dependencies 39 | * numpy 40 | * traci 41 | * sumolib 42 | * scipy 43 | * pytorch 44 | * pandas 45 | 46 | ### Running 47 | You can alter `constants.json` or `constants-grid.json` in /constants to change different hyperparameters.
In `main.py` you can run experiments with `run_normal` (runs multiple experiments using `constants.json`), `run_random_search` (runs a random search on `constants-grid.json`) or `run_grid_search` (runs a grid search on `constants-grid.json`). You can save and load models, and you can also visualize a model by running `vis_agent.py` and setting `run(load_model_file=)` to the model file. The 4 environments implemented are 1x1, 2x2, 3x3 and 4x4 grids. 48 | 49 | `shape` is the grid shape, `rush_hour` can be set to true for the 2x2 grid (which adds a kind of rush-hour spawning probability distribution), and `uniform_generation_probability` is the spawn rate for cars when `rush_hour` is false. 50 | ``` 51 | "environment": { 52 | "shape": [4, 4], 53 | "rush_hour": false, 54 | "uniform_generation_probability": 0.06 55 | }, 56 | ``` 57 | 58 | Change `num_workers` based on how many processes you want active for the distributed part of DPPO. 59 | ``` 60 | "parallel":{ 61 | "num_workers": 8 62 | } 63 | ``` 64 | Finally, you can change `agent_type` to `rule` if you want to run a simple rule-based agent (which just switches each light after a set amount of time), and you can set `single_agent` to true to not use MARL. 65 | 66 | ``` 67 | "agent": { 68 | "agent_type": "ppo", 69 | "single_agent": false 70 | }, 71 | ``` 72 | -------------------------------------------------------------------------------- /analyze-excel.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import pandas as pd" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "metadata": { 18 | "collapsed": true 19 | }, 20 | "outputs": [], 21 | "source": [ 22 | "data = pd.read_excel('run-data.xlsx')" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 3, 28 | "metadata": {}, 29 | "outputs": [ 30 | { 31 | "data": { 32 | "text/plain": [ 33 | "(12, 33)" 34 | ] 35 | }, 36 | "execution_count": 3, 37 | "metadata": {}, 38 | "output_type": "execute_result" 39 | } 40 | ], 41 | "source": [ 42 | "data.shape" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 4, 48 | "metadata": { 49 | "collapsed": true 50 | }, 51 | "outputs": [], 52 | "source": [ 53 | "# data.drop(len(data)-1, inplace=True)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 5, 59 | "metadata": { 60 | "collapsed": true 61 | }, 62 | "outputs": [], 63 | "source": [ 64 | "# data.to_excel('../run-data.xlsx', index=False)" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 6, 70 | "metadata": {}, 71 | "outputs": [ 72 | { 73 | "data": { 74 | "text/html": [ 75 | "
\n", 76 | "\n", 89 | "\n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | "
rolloutwaitingTimetimerewEp_num_train_rolloutsEp_rollout_lengthEp_eval_freqEp_eval_num_epsEp_max_ep_stepsEp_generation_ep_steps...P_value_loss_coefP_hidden_layer_sizeR_rule_setR_rule_set_paramsEn_shapeEn_rush_hourEn_uniform_generation_probabilityM_reward_interpolationM_state_interpolationP_num_workers
0250513.50916738.6214791139.5337783002048502512501000...132cycle{'cycle_length': 20, 'NS_mult': 1.0, 'EW_mult'...[3, 3]False0.05008
1300511.36333340.8137511141.9401113002048502512501000...132cycle{'cycle_length': 20, 'NS_mult': 1.0, 'EW_mult'...[3, 3]False0.05008
2308491.54583338.0349521145.1475563002048502512501000...132cycle{'cycle_length': 20, 'NS_mult': 1.0, 'EW_mult'...[3, 3]False0.05008
3250520.53083338.9719261138.7324443002048502512501000...132cycle{'cycle_length': 20, 'NS_mult': 1.0, 'EW_mult'...[3, 3]False0.05008
4250488.76333339.4283851146.7114443002048502512501000...132cycle{'cycle_length': 20, 'NS_mult': 1.0, 'EW_mult'...[3, 3]False0.05008
\n", 239 | "

5 rows × 33 columns

\n", 240 | "
" 241 | ], 242 | "text/plain": [ 243 | " rollout waitingTime time rew Ep_num_train_rollouts \\\n", 244 | "0 250 513.509167 38.621479 1139.533778 300 \n", 245 | "1 300 511.363333 40.813751 1141.940111 300 \n", 246 | "2 308 491.545833 38.034952 1145.147556 300 \n", 247 | "3 250 520.530833 38.971926 1138.732444 300 \n", 248 | "4 250 488.763333 39.428385 1146.711444 300 \n", 249 | "\n", 250 | " Ep_rollout_length Ep_eval_freq Ep_eval_num_eps Ep_max_ep_steps \\\n", 251 | "0 2048 50 25 1250 \n", 252 | "1 2048 50 25 1250 \n", 253 | "2 2048 50 25 1250 \n", 254 | "3 2048 50 25 1250 \n", 255 | "4 2048 50 25 1250 \n", 256 | "\n", 257 | " Ep_generation_ep_steps ... P_value_loss_coef P_hidden_layer_size \\\n", 258 | "0 1000 ... 1 32 \n", 259 | "1 1000 ... 1 32 \n", 260 | "2 1000 ... 1 32 \n", 261 | "3 1000 ... 1 32 \n", 262 | "4 1000 ... 1 32 \n", 263 | "\n", 264 | " R_rule_set R_rule_set_params En_shape \\\n", 265 | "0 cycle {'cycle_length': 20, 'NS_mult': 1.0, 'EW_mult'... [3, 3] \n", 266 | "1 cycle {'cycle_length': 20, 'NS_mult': 1.0, 'EW_mult'... [3, 3] \n", 267 | "2 cycle {'cycle_length': 20, 'NS_mult': 1.0, 'EW_mult'... [3, 3] \n", 268 | "3 cycle {'cycle_length': 20, 'NS_mult': 1.0, 'EW_mult'... [3, 3] \n", 269 | "4 cycle {'cycle_length': 20, 'NS_mult': 1.0, 'EW_mult'... [3, 3] \n", 270 | "\n", 271 | " En_rush_hour En_uniform_generation_probability M_reward_interpolation \\\n", 272 | "0 False 0.05 0 \n", 273 | "1 False 0.05 0 \n", 274 | "2 False 0.05 0 \n", 275 | "3 False 0.05 0 \n", 276 | "4 False 0.05 0 \n", 277 | "\n", 278 | " M_state_interpolation P_num_workers \n", 279 | "0 0 8 \n", 280 | "1 0 8 \n", 281 | "2 0 8 \n", 282 | "3 0 8 \n", 283 | "4 0 8 \n", 284 | "\n", 285 | "[5 rows x 33 columns]" 286 | ] 287 | }, 288 | "execution_count": 6, 289 | "metadata": {}, 290 | "output_type": "execute_result" 291 | } 292 | ], 293 | "source": [ 294 | "data.head()" 295 | ] 296 | }, 297 | { 298 | "cell_type": "markdown", 299 | "metadata": {}, 300 | "source": [ 301 | "# Notes\n", 302 | "- rollout_length: at least 1024 (2048 is better)\n", 303 | "- gae_tau: at most 0.9 (maybe lower)\n", 304 | "- entropy_weight: at least 0.1 (test higher)\n", 305 | "- minibatch_size: at least 128 (maybe higher)\n", 306 | "- optimization_epochs: 5 but 10 is very close\n", 307 | "- ppo_ratio_clip: 0.3 but 0.2 is close\n", 308 | "- clip_grads: true\n", 309 | "- value_loss_coef: 1.0\n", 310 | "- reward_interpolation: 0 \n", 311 | "- state_interpolation: anything\n", 312 | "- warmup_ep_steps: anything (0)\n", 313 | "\n", 314 | "\n", 315 | "- NS_mult: 1.0 or 1.5\n", 316 | "- phase_end_offset: any (0 - 100)\n", 317 | "- multi intersection cycle length: 20\n", 318 | "- (best combo 1.5, 100, 20)" 319 | ] 320 | }, 321 | { 322 | "cell_type": "code", 323 | "execution_count": 8, 324 | "metadata": {}, 325 | "outputs": [ 326 | { 327 | "data": { 328 | "text/plain": [ 329 | "505.47277777777776" 330 | ] 331 | }, 332 | "execution_count": 8, 333 | "metadata": {}, 334 | "output_type": "execute_result" 335 | } 336 | ], 337 | "source": [ 338 | "data[data['Ep_warmup_ep_steps'] == 0]['waitingTime'].mean()" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": 9, 344 | "metadata": { 345 | "scrolled": false 346 | }, 347 | "outputs": [ 348 | { 349 | "data": { 350 | "text/plain": [ 351 | "509.30361111111114" 352 | ] 353 | }, 354 | "execution_count": 9, 355 | "metadata": {}, 356 | "output_type": "execute_result" 357 | } 358 | ], 359 | "source": [ 360 | "data[data['Ep_warmup_ep_steps'] == 50]['waitingTime'].mean()" 361 | ] 362 | }, 
363 | { 364 | "cell_type": "code", 365 | "execution_count": 10, 366 | "metadata": {}, 367 | "outputs": [ 368 | { 369 | "data": { 370 | "text/plain": [ 371 | "510.13444444444445" 372 | ] 373 | }, 374 | "execution_count": 10, 375 | "metadata": {}, 376 | "output_type": "execute_result" 377 | } 378 | ], 379 | "source": [ 380 | "data[data['Ep_warmup_ep_steps'] == 100]['waitingTime'].mean()" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": 11, 386 | "metadata": {}, 387 | "outputs": [ 388 | { 389 | "data": { 390 | "text/plain": [ 391 | "520.2569444444445" 392 | ] 393 | }, 394 | "execution_count": 11, 395 | "metadata": {}, 396 | "output_type": "execute_result" 397 | } 398 | ], 399 | "source": [ 400 | "data[data['Ep_warmup_ep_steps'] == 200]['waitingTime'].mean()" 401 | ] 402 | }, 403 | { 404 | "cell_type": "code", 405 | "execution_count": 11, 406 | "metadata": {}, 407 | "outputs": [ 408 | { 409 | "data": { 410 | "text/plain": [ 411 | "533.9041666666666" 412 | ] 413 | }, 414 | "execution_count": 11, 415 | "metadata": {}, 416 | "output_type": "execute_result" 417 | } 418 | ], 419 | "source": [ 420 | "data[data['M_state_interpolation'] == 1]['waitingTime'].mean()" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": null, 426 | "metadata": { 427 | "collapsed": true 428 | }, 429 | "outputs": [], 430 | "source": [] 431 | } 432 | ], 433 | "metadata": { 434 | "kernelspec": { 435 | "display_name": "Python 3", 436 | "language": "python", 437 | "name": "python3" 438 | }, 439 | "language_info": { 440 | "codemirror_mode": { 441 | "name": "ipython", 442 | "version": 3 443 | }, 444 | "file_extension": ".py", 445 | "mimetype": "text/x-python", 446 | "name": "python", 447 | "nbconvert_exporter": "python", 448 | "pygments_lexer": "ipython3", 449 | "version": "3.6.8" 450 | } 451 | }, 452 | "nbformat": 4, 453 | "nbformat_minor": 2 454 | } 455 | -------------------------------------------------------------------------------- /constants/constants-grid.json: -------------------------------------------------------------------------------- 1 | { 2 | "episode":{ 3 | "num_train_rollouts": [300], 4 | "rollout_length": [2048], 5 | "eval_freq": [50], 6 | "eval_num_eps": [25], 7 | "max_ep_steps": [1250], 8 | "generation_ep_steps": [1000], 9 | "warmup_ep_steps": [0], 10 | "test_num_eps": [50] 11 | }, 12 | "agent": { 13 | "agent_type": ["rule"], 14 | "single_agent": [false] 15 | }, 16 | "ppo":{ 17 | "gae_tau": [0.85], 18 | "entropy_weight": [0.2], 19 | "minibatch_size": [256], 20 | "optimization_epochs": [5], 21 | "ppo_ratio_clip": [0.3], 22 | "discount": [0.99], 23 | "learning_rate": [1e-4], 24 | "clip_grads": [true], 25 | "gradient_clip": [1.0], 26 | "value_loss_coef": [1.0], 27 | "hidden_layer_size": [32] 28 | }, 29 | "rule": { 30 | "rule_set": ["cycle"], 31 | "rule_set_params": [{"cycle_length": 40, "NS_mult": 1.0, "EW_mult": 1.0, "phase_end_offset": 100}] 32 | }, 33 | "environment": { 34 | "shape": [[4, 4]], 35 | "rush_hour": [false], 36 | "uniform_generation_probability": [0.06] 37 | }, 38 | "multiagent": { 39 | "reward_interpolation": [0], 40 | "state_interpolation": [0] 41 | }, 42 | "parallel":{ 43 | "num_workers": [8] 44 | } 45 | } -------------------------------------------------------------------------------- /constants/constants.json: -------------------------------------------------------------------------------- 1 | { 2 | "episode":{ 3 | "num_train_rollouts": 300, 4 | "rollout_length": 2048, 5 | "eval_freq": 50, 6 | "eval_num_eps": 25, 7 | "max_ep_steps": 1250, 8 
| "generation_ep_steps": 1000, 9 | "warmup_ep_steps": 0, 10 | "test_num_eps": 50 11 | }, 12 | "agent": { 13 | "agent_type": "ppo", 14 | "single_agent": false 15 | }, 16 | "ppo":{ 17 | "gae_tau": 0.85, 18 | "entropy_weight": 0.2, 19 | "minibatch_size": 256, 20 | "optimization_epochs": 5, 21 | "ppo_ratio_clip": 0.3, 22 | "discount": 0.99, 23 | "learning_rate": 1e-4, 24 | "clip_grads": true, 25 | "gradient_clip": 1.0, 26 | "value_loss_coef": 1.0, 27 | "hidden_layer_size": 32 28 | }, 29 | "rule": { 30 | "rule_set": "cycle", 31 | "rule_set_params": {"cycle_length": 40, "NS_mult": 1.0, "EW_mult": 1.0, "phase_end_offset": 100} 32 | }, 33 | "environment": { 34 | "shape": [4, 4], 35 | "rush_hour": false, 36 | "uniform_generation_probability": 0.06 37 | }, 38 | "multiagent": { 39 | "reward_interpolation": 0, 40 | "state_interpolation": 0 41 | }, 42 | "parallel":{ 43 | "num_workers": 8 44 | } 45 | } -------------------------------------------------------------------------------- /data/1_1_intersections.net.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | -------------------------------------------------------------------------------- /data/2_2_intersections.net.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | 285 | 286 | 
287 | 288 | 289 | 290 | 291 | 292 | 293 | 294 | 295 | 296 | 297 | 298 | 299 | 300 | 301 | 302 | 303 | 304 | 305 | 306 | 307 | 308 | 309 | 310 | 311 | 312 | 313 | 314 | 315 | 316 | 317 | 318 | 319 | 320 | 321 | 322 | 323 | 324 | 325 | 326 | 327 | 328 | 329 | 330 | 331 | 332 | 333 | 334 | 335 | 336 | 337 | 338 | 339 | 340 | 341 | 342 | 343 | 344 | 345 | 346 | 347 | 348 | 349 | 350 | 351 | 352 | 353 | 354 | 355 | 356 | 357 | 358 | 359 | 360 | 361 | 362 | 363 | 364 | 365 | 366 | 367 | 368 | 369 | 370 | 371 | 372 | 373 | 374 | 375 | 376 | 377 | 378 | 379 | 380 | 381 | 382 | 383 | 384 | 385 | 386 | 387 | 388 | 389 | 390 | 391 | 392 | 393 | 394 | 395 | 396 | 397 | 398 | 399 | 400 | 401 | 402 | 403 | 404 | 405 | 406 | 407 | 408 | 409 | 410 | 411 | 412 | 413 | 414 | 415 | 416 | 417 | 418 | 419 | 420 | 421 | 422 | 423 | 424 | 425 | 426 | 427 | 428 | 429 | 430 | 431 | 432 | 433 | 434 | 435 | 436 | 437 | 438 | 439 | 440 | 441 | 442 | 443 | 444 | 445 | 446 | 447 | 448 | 449 | 450 | 451 | 452 | 453 | 454 | 455 | 456 | 457 | 458 | 459 | 460 | 461 | 462 | 463 | 464 | 465 | 466 | 467 | 468 | 469 | 470 | 471 | 472 | 473 | 474 | 475 | 476 | 477 | 478 | 479 | 480 | 481 | 482 | 483 | 484 | 485 | 486 | 487 | 488 | 489 | 490 | 491 | 492 | 493 | 494 | 495 | 496 | 497 | 498 | 499 | 500 | 501 | 502 | 503 | 504 | 505 | 506 | 507 | 508 | 509 | 510 | 511 | 512 | 513 | 514 | 515 | 516 | 517 | 518 | 519 | 520 | 521 | 522 | 523 | 524 | 525 | 526 | 527 | -------------------------------------------------------------------------------- /data_collector.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Manager, Lock 2 | import pandas as pd 3 | import os.path 4 | from os import path 5 | from collections import OrderedDict 6 | import time 7 | import torch 8 | from copy import deepcopy 9 | 10 | 11 | # What data is possible to be collected 12 | POSSIBLE_DATA = ['time', 13 | 'rew', 14 | 'waitingTime'] 15 | 16 | # Keys that must be added 17 | _EVAL_DEF_KEYS = ['rollout'] 18 | _TEST_DEF_KEYS = [] 19 | 20 | _GLOBAL_DATA = {'time': {'header': 'total time taken(m)', 'init_value': 0, 'add_type': 'set'}} 21 | _EVAL_DATA = {'rew': {'header': 'eval_best_rew', 'init_value': float('-inf'), 'add_type': 'set_on_greater'}, 22 | 'waitingTime': {'header': 'eval_best_waiting_time', 'init_value': float('inf'), 'add_type': 'set_on_less'}, 23 | 'rollout': {'header': 'eval_train_rollout', 'init_value': float('inf'), 'add_type': 'set_on_less'}} 24 | _TEST_DATA = {'rew': {'header': 'test_avg_rew', 'init_value': [], 'add_type': 'append'}, 25 | 'waitingTime': {'header': 'test_avg_waiting_time', 'init_value': [], 'add_type': 'append'}} 26 | 27 | class DataCollector: 28 | # Will make a df or add to an existing df (if a file already exists in the path) 29 | def __init__(self, data_keys, mvp_key, constants, eval_or_test, df_path, overwrite, verbose): 30 | assert len(set(data_keys)) == len(data_keys), 'There are duplicate keys in data_keys: {}'.format(data_keys) 31 | # self.manager = Manager() 32 | self.data_keys = deepcopy(data_keys) 33 | self.eval_or_test = eval_or_test 34 | if self.eval_or_test == 'eval': 35 | assert mvp_key in self.data_keys, 'Need MVP key in the set of sent in data keys: {} not in {}'.format(mvp_key, self.data_keys) 36 | if self.data_keys[0] != mvp_key: 37 | self.data_keys.remove(mvp_key) 38 | self.data_keys.insert(0, mvp_key) 39 | self._add_def_keys() # add def keys 40 | assert len(set(data_keys)) == len(data_keys), 'There are duplicate keys in data_keys: {}'.format( 41 | 
self.data_keys) 42 | self.mvp_key = mvp_key 43 | self.constants = constants 44 | self._make_data_dict() 45 | self.df_path = df_path 46 | # If overwite then make a new one no matter if it already exists, other wise check the path 47 | self.df = self._check_if_df_exists() if not overwrite else self._make_df() 48 | self.verbose = verbose 49 | self.save_model_path = None 50 | # self.lock = Lock() 51 | 52 | def set_save_model_path(self, save_model_path): 53 | self.save_model_path = save_model_path 54 | 55 | # Check to make sure def keys are in the new_data 56 | def _check_def_keys(self, new_data): 57 | li = _EVAL_DEF_KEYS if self.eval_or_test == 'eval' else _TEST_DEF_KEYS 58 | for k in li: 59 | assert k in new_data, 'Didnt return the def key {} with new data for collection'.format(k) 60 | 61 | # This is for collecting info at the end of an ep (or chunk of episodes), check MVP value if eval to see if I should 62 | # add this set of data, because during eval mode I only save data from eps where the MVP value is better than before 63 | def collect_ep(self, new_data, model_state=None, ep_count=None): 64 | self._check_def_keys(new_data) 65 | new_best = False 66 | set_values = False 67 | if self.eval_or_test == 'eval': 68 | assert self.mvp_key in new_data, 'MVP key must be in data for episode collection.' 69 | if self.info_dict[self.mvp_key]['add_type'] == 'set_on_greater' and new_data[self.mvp_key] > self.data[self.mvp_key]: 70 | new_best = True 71 | set_values = True 72 | elif self.info_dict[self.mvp_key]['add_type'] == 'set_on_less' and new_data[self.mvp_key] < self.data[self.mvp_key]: 73 | new_best = True 74 | set_values = True 75 | # Save if new_best 76 | if new_best and model_state: 77 | self._save_model(model_state) 78 | else: 79 | set_values = True 80 | 81 | if set_values: 82 | for key in new_data: self._add_value(key, new_data[key]) 83 | 84 | self._print(new_data, new_best, ep_count) 85 | 86 | # Whenever (only if eval) a new best is reached 87 | def _save_model(self, model_state): 88 | if self.save_model_path: 89 | torch.save(model_state, self.save_model_path) 90 | 91 | def _print(self, new_data, new_best, ep_count): 92 | if not self.verbose: return 93 | prepend = '' 94 | if self.eval_or_test == 'eval': 95 | prepend = 'NEW BEST ' if new_best else '' 96 | prepend += 'on rollout: {} '.format(new_data['rollout']) 97 | if ep_count is not None: 98 | prepend = 'Current test episode: {} '.format(ep_count) 99 | all_print = prepend 100 | for k in list(self.data.keys()): # So its in the correct order 101 | if k in new_data and k != 'rollout': 102 | all_print += '{}: {:.2f} '.format(k, new_data[k]) 103 | print(all_print) 104 | 105 | def print_summary(self, exp1, exp2=None): 106 | if exp2 == None: 107 | all_print = 'Experiment {} Summary: '.format(exp1) 108 | else: 109 | all_print = 'Experiment {}.{} Summary: '.format(exp1, exp2) 110 | for k, v in list(self.data.items()): 111 | if k not in _EVAL_DEF_KEYS + _TEST_DEF_KEYS + POSSIBLE_DATA: 112 | continue 113 | if isinstance(v, float): 114 | all_print += '{}: {:.2f} '.format(k, v) 115 | else: 116 | all_print += '{}: {} '.format(k, v) 117 | print(all_print) 118 | 119 | def _add_value(self, k, v): 120 | if k not in self.data: return 121 | add_type = self.info_dict[k]['add_type'] 122 | # with self.lock: 123 | if add_type == 'append': 124 | # print() 125 | self.data[k].append(v) 126 | elif add_type == 'set_on_greater' or add_type == 'set_on_less': self.data[k] = v 127 | else: assert 1 == 0, add_type 128 | 129 | # print(self.data[k]) 130 | 131 | def 
_add_def_keys(self): 132 | li = _EVAL_DEF_KEYS if self.eval_or_test == 'eval' else _TEST_DEF_KEYS 133 | for k in li[::-1]: 134 | self.data_keys.insert(0, k) 135 | 136 | def _check_if_df_exists(self): 137 | # If it eists then return it 138 | if path.exists(self.df_path): 139 | return pd.read_excel(self.df_path) 140 | else: 141 | return self._make_df() 142 | 143 | def _make_df(self): 144 | df = pd.DataFrame(columns=list(self.data.keys())) 145 | df.to_excel(self.df_path, index=False, header=list(self.data.keys())) 146 | return df 147 | 148 | def _add_hyp_param_dict(self): 149 | def add_hyp_param_dict(append_letter, dic): 150 | for k, v in list(dic.items()): 151 | self.data[append_letter + '_' + k] = v 152 | add_hyp_param_dict('Ep', self.constants['episode']) 153 | add_hyp_param_dict('A', self.constants['agent']) 154 | add_hyp_param_dict('P', self.constants['ppo']) 155 | add_hyp_param_dict('R', self.constants['rule']) 156 | add_hyp_param_dict('En', self.constants['environment']) 157 | add_hyp_param_dict('M', self.constants['multiagent']) 158 | add_hyp_param_dict('P', self.constants['parallel']) 159 | return self.data 160 | 161 | def _append_to_df(self): 162 | assert self.df is not None, 'At this point df should not be none.' 163 | dat = {k: v for k, v in list(self.data.items())} 164 | self.df = self.df.append(dat, ignore_index=True) 165 | 166 | def _write_to_excel(self): 167 | self.df.to_excel(self.df_path, index=False) 168 | 169 | # Because one data collector is made for all runs (ie a new one is NOT made per experiment) need to refresh 170 | def _refresh_data_store(self): 171 | self._make_data_dict() 172 | 173 | # This is just for getting the mean for array objects in the data dict 174 | def process_data(self): 175 | for k, v in list(self.data.items()): 176 | if k in self.info_dict and self.info_dict[k]['add_type'] == 'append': 177 | self.data[k] = sum(v) / len(v) 178 | 179 | # Signals to this object that it can add to df and and refresh the data store (call when threads are done) 180 | def done_with_experiment(self): 181 | self._append_to_df() 182 | self._write_to_excel() 183 | self._refresh_data_store() 184 | 185 | def _make_data_dict(self): 186 | self.info_dict = _GLOBAL_DATA.copy() 187 | manager = Manager() 188 | self.data = manager.dict() 189 | # Add values to collect 190 | for key in self.data_keys: 191 | # Check eval or test first then global then error 192 | if self.eval_or_test == 'eval' and key in _EVAL_DATA: 193 | self.data[key] = _EVAL_DATA[key]['init_value'] 194 | self.info_dict[key] = _EVAL_DATA[key].copy() 195 | elif self.eval_or_test == 'test' and key in _TEST_DATA: 196 | self.data[key] = _TEST_DATA[key]['init_value'] 197 | self.info_dict[key] = _TEST_DATA[key].copy() 198 | elif key in _GLOBAL_DATA: 199 | self.data[key] = _GLOBAL_DATA[key]['init_value'] 200 | self.info_dict[key] = _GLOBAL_DATA[key].copy() 201 | else: 202 | raise ValueError('Data collector was sent {} which isnt data that can be collected'.format(key)) 203 | if self.data[key] == []: self.data[key] = manager.list() 204 | # Add hyp param values 205 | self._add_hyp_param_dict() 206 | return self.data, self.info_dict 207 | 208 | def start_timer(self): 209 | self.start_time = time.time() 210 | 211 | def end_timer(self, printIt): 212 | end_time = time.time() 213 | total_time = (end_time - self.start_time) / 60. 
214 | if printIt: 215 | print(' - Time taken (m): {:.2f} - '.format(total_time)) 216 | if 'time' in self.data: 217 | self.data['time'] = total_time 218 | -------------------------------------------------------------------------------- /environments/environment.py: -------------------------------------------------------------------------------- 1 | import traci 2 | import numpy as np 3 | from utils.net_scrape import * 4 | 5 | 6 | # Base class 7 | class Environment: 8 | def __init__(self, constants, device, agent_ID, eval_agent, net_path, vis=False): 9 | self.constants = constants 10 | self.device = device 11 | self.agent_ID = agent_ID 12 | self.eval_agent = eval_agent 13 | # For sumo connection 14 | self.conn_label = 'label_' + str(self.agent_ID) 15 | self.net_path = net_path 16 | self.vis = vis 17 | self.phases = None 18 | self.agent_type = constants['agent']['agent_type'] 19 | self.single_agent = constants['agent']['single_agent'] 20 | self.intersections = get_intersections(net_path) 21 | # for adding to state 22 | self.intersections_index = {intersection: i for i, intersection in enumerate(self.intersections)} 23 | # for state discounting/interpolation 24 | self.neighborhoods, self.max_num_neighbors = get_intersection_neighborhoods(net_path) 25 | 26 | def _make_state(self): 27 | if self.agent_type == 'rule': return {} 28 | if self.single_agent: return [] 29 | else: 30 | if self.constants['multiagent']['state_interpolation'] == 0: 31 | return [[] for _ in range(len(self.intersections))] 32 | else: return {} # if state disc is 0 then return [[], [], ...] ow dict 33 | 34 | def _add_to_state(self, state, value, key, intersection): 35 | if self.agent_type == 'rule': 36 | if intersection: 37 | if intersection not in state: state[intersection] = {} 38 | state[intersection][key] = value 39 | else: 40 | state[key] = value # global props for 41 | else: 42 | if self.single_agent: 43 | if isinstance(value, list): 44 | state.extend(value) 45 | else: 46 | state.append(value) 47 | else: 48 | if self.constants['multiagent']['state_interpolation'] == 0: 49 | if isinstance(value, list): 50 | state[self.intersections_index[intersection]].extend(value) 51 | else: 52 | state[self.intersections_index[intersection]].append(value) 53 | else: 54 | if intersection not in state: state[intersection] = [] 55 | if isinstance(value, list): 56 | state[intersection].extend(value) 57 | else: 58 | state[intersection].append(value) 59 | 60 | def _process_state(self, state): 61 | if not self.agent_type == 'rule': 62 | if self.single_agent: 63 | return np.expand_dims(np.array(state), axis=0) 64 | else: 65 | return np.array(state) 66 | return state 67 | 68 | def _open_connection(self): 69 | raise NotImplementedError 70 | 71 | def _close_connection(self): 72 | self.connection.close() 73 | del traci.main._connections[self.conn_label] 74 | 75 | def reset(self): 76 | # If there is a conn to close, then close it 77 | if self.conn_label in traci.main._connections: 78 | self._close_connection() 79 | # Start a new one 80 | self._open_connection() 81 | return self.get_state() 82 | 83 | # if action is a single number then convert to array of binary 84 | # and convert to dict 85 | # assume a is numpy array 86 | def _process_action(self, a): 87 | action = a.copy() 88 | if self.single_agent and self.agent_type != 'rule': 89 | action = '{0:0b}'.format(a[0]) 90 | action = action.zfill(len(self.intersections)) 91 | action = [int(c) for c in action] 92 | return {intersection: action[i] for i, intersection in enumerate(self.intersections)} 93 
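    # Worked example (illustrative; names of the intersections are hypothetical): with 4
    # intersections and a single non-rule agent, an action a = np.array([11]) is expanded by
    # _process_action above as '{0:0b}'.format(11) -> '1011', then zfill(4) -> '1011', and
    # finally mapped to {intersection_0: 1, intersection_1: 0, intersection_2: 1, intersection_3: 1}
    # in the order of self.intersections. Downstream, IntersectionsEnv._execute_action treats a
    # 1 as "switch phase" and a 0 as "keep the current phase".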
| 94 | def step(self, a, ep_step, get_global_reward, def_agent=False): 95 | action = self._process_action(a) 96 | if not def_agent: 97 | self._execute_action(action) 98 | self.connection.simulationStep() 99 | s_ = self.get_state() 100 | r = self.get_reward(get_global_reward) 101 | # Check if done (and if so reset) 102 | done = False 103 | if self.connection.simulation.getMinExpectedNumber() <= 0 or ep_step >= self.constants['episode']['max_ep_steps']: 104 | # Just close the conn without restarting if eval agent 105 | if self.eval_agent: 106 | self._close_connection() 107 | else: 108 | s_ = self.reset() 109 | done = True 110 | return s_, r, done 111 | 112 | def get_state(self): 113 | raise NotImplementedError 114 | 115 | def get_reward(self, get_global): 116 | raise NotImplementedError 117 | 118 | def _execute_action(self, action): 119 | raise NotImplementedError 120 | 121 | def _generate_configfile(self): 122 | raise NotImplementedError 123 | 124 | def _generate_routefile(self): 125 | raise NotImplementedError 126 | 127 | def _generate_addfile(self): 128 | raise NotImplementedError 129 | 130 | -------------------------------------------------------------------------------- /environments/intersections.py: -------------------------------------------------------------------------------- 1 | from numpy.random import choice 2 | import random 3 | from sumolib import checkBinary 4 | import traci 5 | import numpy as np 6 | from environments.environment import Environment 7 | from collections import OrderedDict 8 | from utils.env_phases import get_phases, get_current_phase_probs 9 | from utils.net_scrape import * 10 | 11 | 12 | ''' 13 | Assumptions: 14 | 1. Single lane roads 15 | 2. Grid like layout, where gen or rem nodes only have one possible edge in and out 16 | ''' 17 | 18 | # per agent 19 | PER_AGENT_STATE_SIZE = 6 20 | GLOBAL_STATE_SIZE = 1 21 | # per agent 22 | ACTION_SIZE = 2 23 | 24 | VEH_LENGTH = 5 25 | VEH_MIN_GAP = 2.5 26 | DET_LENGTH_IN_CARS = 20 27 | 28 | 29 | def get_rel_net_path(phase_id): 30 | return phase_id.replace('_rush_hour', '') + '.net.xml' 31 | 32 | def get_env_name(constants): 33 | shape = constants['environment']['shape'] 34 | return '{}_{}_intersections'.format(shape[0], shape[1]) 35 | 36 | class IntersectionsEnv(Environment): 37 | def __init__(self, constants, device, agent_ID, eval_agent, net_path, vis=False): 38 | super(IntersectionsEnv, self).__init__(constants, device, agent_ID, eval_agent, net_path, vis) 39 | # For file names 40 | self.env_name = get_env_name(constants) 41 | self.phases = get_phases(constants['environment'], net_path) 42 | self.node_edge_dic = get_node_edge_dict(self.net_path) 43 | self._generate_addfile() 44 | # Calc intersection distances for reward calc 45 | self.distances = get_cartesian_intersection_distances(net_path) 46 | 47 | def _open_connection(self): 48 | self._generate_routefile() 49 | sumoB = checkBinary('sumo' if not self.vis else 'sumo-gui') 50 | # Need to edit .sumocfg to have the right route file 51 | self._generate_configfile() 52 | traci.start([sumoB, "-c", "data/{}_{}.sumocfg".format(self.env_name, self.agent_ID)], label=self.conn_label) 53 | self.connection = traci.getConnection(self.conn_label) 54 | 55 | def _get_sim_step(self, normalize): 56 | sim_step = self.connection.simulation.getTime() 57 | if normalize: sim_step /= (self.constants['episode']['max_ep_steps'] / 10.) 
# Normalized between 0 and 10 58 | return sim_step 59 | 60 | # todo: Work with normalization 61 | def get_state(self): 62 | # State is made of the jam length for each detector, current phase of each intersection, elapsed time 63 | # for the current phase of each intersection and the current ep step 64 | state = self._make_state() 65 | # Normalize if not rule 66 | normalize = True if self.agent_type != 'rule' else False 67 | # Get sim step 68 | sim_step = self._get_sim_step(normalize) 69 | for intersection, dets in list(self.intersection_dets.items()): 70 | # Jam length - WARNING: on sim step 0 this seems to have an issue so just return all zeros 71 | jam_length = [self.connection.lanearea.getJamLengthVehicle(det) for det in dets] if self.connection.simulation.getTime() != 0 else [0] * len(dets) 72 | self._add_to_state(state, jam_length, key='jam_length', intersection=intersection) 73 | # Current phase 74 | curr_phase = self.connection.trafficlight.getPhase(intersection) 75 | self._add_to_state(state, curr_phase, key='curr_phase', intersection=intersection) 76 | # Elapsed time of current phase 77 | elapsed_phase_time = self.connection.trafficlight.getPhaseDuration(intersection) - \ 78 | (self.connection.trafficlight.getNextSwitch(intersection) - 79 | self.connection.simulation.getTime()) 80 | if normalize: elapsed_phase_time /= 10. # Slight normalization 81 | self._add_to_state(state, elapsed_phase_time, key='elapsed_phase_time', intersection=intersection) 82 | # Add global param of current sim step 83 | if not self.single_agent and self.agent_type != 'rule': 84 | self._add_to_state(state, sim_step, key='sim_step', intersection=intersection) 85 | # DeMorgan's law of above 86 | if self.single_agent or self.agent_type == 'rule': 87 | self._add_to_state(state, sim_step, key='sim_step', intersection=None) 88 | 89 | # Don't interpolate if single agent or agent type is rule or multiagent but state disc is 0 90 | if self.single_agent or self.agent_type == 'rule' or self.constants['multiagent']['state_interpolation'] == 0: 91 | return self._process_state(state) 92 | 93 | state_size = PER_AGENT_STATE_SIZE + GLOBAL_STATE_SIZE 94 | final_state = [] 95 | for intersection in self.intersections: 96 | neighborhood = self.neighborhoods[intersection] 97 | # Add the intersection state itself 98 | intersection_state = state[intersection] 99 | final_state.append(np.zeros(shape=(state_size * self.max_num_neighbors,))) 100 | # Slice in this intersection's state not discounted 101 | final_state[-1][:state_size] = np.array(intersection_state) 102 | # Then its discounted neighbors 103 | for n, neighbor in enumerate(neighborhood): 104 | assert neighbor != intersection 105 | extension = self.constants['multiagent']['state_interpolation'] * np.array(state[neighbor]) 106 | range_start = (n + 1) * state_size 107 | range_end = range_start + state_size 108 | final_state[-1][range_start:range_end] = extension 109 | state = self._process_state(final_state) 110 | return state 111 | 112 | # Allows for interpolation between local and global reward given a reward disc. 
factor 113 | # get_global is used to signal returning the global rew as a single value for the eval runs, ow an array is returned 114 | def get_reward(self, get_global): 115 | reward_interpolation = self.constants['multiagent']['reward_interpolation'] 116 | # Get local rewards for each intersection 117 | local_rewards = {} 118 | for intersection in self.intersections: 119 | dets = self.intersection_dets[intersection] 120 | dets_rew = sum([self.connection.lanearea.getJamLengthVehicle(det) for det in dets]) 121 | dets_rew = (len(dets) * DET_LENGTH_IN_CARS) - 2 * dets_rew 122 | # todo: test performance of normalizing by len of dets NOT all dets (make sure to remove assertian below) 123 | dets_rew /= (len(self.all_dets) * DET_LENGTH_IN_CARS) 124 | local_rewards[intersection] = dets_rew 125 | # If getting global then return the sum (singe value) 126 | if get_global: 127 | ret = sum([local_rewards[i] for i in self.intersections]) 128 | assert -1.001 <= ret <= 1.001 129 | return ret 130 | # if single intersection 131 | if len(self.intersections) == 1: 132 | ret = list(local_rewards.values())[0] 133 | assert -1.001 <= ret <= 1.001 134 | return np.array([ret]) 135 | # If single agent 136 | if self.single_agent: 137 | ret = sum([local_rewards[i] for i in self.intersections]) 138 | assert -1.001 <= ret <= 1.001 139 | return np.array([ret]) 140 | # Disc edge cases 141 | if reward_interpolation == 0.: # Local 142 | ret = np.array([r for r in list(local_rewards.values())]) 143 | return ret 144 | if reward_interpolation == 1.: # global 145 | gr = sum([local_rewards[i] for i in self.intersections]) 146 | ret = np.array([gr] * len(self.intersections)) 147 | return ret 148 | # O.w. interpolation 149 | arr = [] 150 | for intersection in self.intersections: 151 | dists = self.distances[intersection] 152 | max_dist = max([d for d in list(dists.values())]) 153 | local_rew = 0. 154 | for inner_int in self.intersections: 155 | d = dists[inner_int] 156 | r = local_rewards[inner_int] 157 | local_rew += pow(reward_interpolation, (d / max_dist)) * r 158 | arr.append(local_rew) 159 | return np.array(arr) 160 | 161 | # Switch 162 | # action: {"intersectionNW": 0 or 1, .... 
} 163 | def _execute_action(self, action): 164 | # dont allow ANY switching if in yellow phase (ie in process of switching) 165 | # Loop through digits, one means switch, zero means stay 166 | for intersection in self.intersections: 167 | value = action[intersection] 168 | currPhase = self.connection.trafficlight.getPhase(intersection) 169 | if currPhase == 1 or currPhase == 3: # Yellow, pass 170 | continue 171 | if value == 0: # do nothing 172 | continue 173 | else: # switch 174 | newPhase = currPhase + 1 175 | self.connection.trafficlight.setPhase(intersection, newPhase) 176 | 177 | def _generate_configfile(self): 178 | with open('data/{}_{}.sumocfg'.format(self.env_name, self.agent_ID), 'w') as config: 179 | print(""" 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | """.format(get_rel_net_path(self.env_name), self.env_name, self.agent_ID, self.env_name, self.agent_ID), file=config) 199 | 200 | def _add_vehicle(self, t, node_probs, routes_string): 201 | rem_nodes = [] 202 | rem_probs = [] 203 | for k, v in list(node_probs.items()): 204 | rem_nodes.append(k) 205 | rem_probs.append(v['rem']) 206 | 207 | # Pick removal func 208 | def pick_rem_edge(gen_node): 209 | while True: 210 | chosen = choice(rem_nodes, p=rem_probs) 211 | if chosen != gen_node: return chosen 212 | 213 | # Loop through all gen edges and see if a veh generates 214 | for gen_k, dic in list(node_probs.items()): 215 | gen_prob = dic['gen'] 216 | if random.random() < gen_prob: 217 | # It does generate a veh. so pick a removal edge 218 | rem_k = pick_rem_edge(gen_k) 219 | route_id = gen_k + '___' + rem_k 220 | gen_edge = self.node_edge_dic[gen_k]['gen'] 221 | rem_edge = self.node_edge_dic[rem_k]['rem'] 222 | routes_string += ' '.format(route_id, t, gen_edge, rem_edge, t) 223 | return routes_string 224 | 225 | def _generate_routefile(self): 226 | routes_string = \ 227 | """ 228 | 229 | 230 | """.format(VEH_LENGTH, VEH_MIN_GAP) 231 | # Add the vehicles 232 | for t in range(self.constants['episode']['generation_ep_steps']): 233 | routes_string = self._add_vehicle(t, get_current_phase_probs(t, self.phases, self.constants['episode']['generation_ep_steps']), routes_string) 234 | routes_string += '' 235 | # Output 236 | with open("data/{}_{}.rou.xml".format(self.env_name, self.agent_ID), "w") as routes: 237 | print(routes_string, file=routes) 238 | 239 | def _generate_addfile(self): 240 | self.all_dets = [] # For reward function 241 | self.intersection_dets = OrderedDict({k: [] for k in self.intersections}) # For state 242 | add_string = '' 243 | # Loop through the net file to get all edges that go to an intersection 244 | tree = ET.parse(self.net_path) 245 | root = tree.getroot() 246 | for c in root.iter('edge'): 247 | id = c.attrib['id'] 248 | # If function key in attib then continue or if not going to intersection 249 | if 'function' in c.attrib: 250 | continue 251 | if not 'intersection' in c.attrib['to']: 252 | continue 253 | length = float(c[0].attrib['length']) # SINGLE LANE ONLY 254 | pos = length - (DET_LENGTH_IN_CARS * (VEH_LENGTH + VEH_MIN_GAP)) 255 | det_id = 'DET+++'+id 256 | self.all_dets.append(det_id) 257 | self.intersection_dets[c.attrib['to']].append(det_id) 258 | add_string += ' ' \ 260 | ''.format(det_id, id, pos, length, self.env_name) 261 | add_string += \ 262 | """ 263 | 264 | 265 | """.format(self.agent_ID) 266 | with open("data/{}_{}.add.xml".format(self.env_name, self.agent_ID), "w") as add: 267 | print(add_string, file=add) 268 | 
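# Illustrative sketch (simplified, not part of the repo): the distance-weighted blending used in
# IntersectionsEnv.get_reward above. Each intersection's reward is a weighted sum of all local
# rewards: a factor of 0 keeps only the intersection's own reward, a factor of 1 gives every
# intersection the global sum, and values in between weight nearer intersections more heavily.
def interpolated_rewards(local_rewards, distances, factor):
    # local_rewards: {intersection: reward}; distances: {intersection: {other_intersection: distance}},
    # where each intersection is assumed to appear in its own distance dict with distance 0.
    rewards = {}
    for intersection, dists in distances.items():
        max_dist = max(dists.values())
        rewards[intersection] = sum(factor ** (d / max_dist) * local_rewards[other]
                                    for other, d in dists.items())
    return rewards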
-------------------------------------------------------------------------------- /images/1_1-grid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxbrenner-ai/Multi-Agent-Distributed-PPO-Traffc-light-control/b2919a50237a2047cad033c2613549e7ad19f305/images/1_1-grid.png -------------------------------------------------------------------------------- /images/2_2-grid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxbrenner-ai/Multi-Agent-Distributed-PPO-Traffc-light-control/b2919a50237a2047cad033c2613549e7ad19f305/images/2_2-grid.png -------------------------------------------------------------------------------- /images/marl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxbrenner-ai/Multi-Agent-Distributed-PPO-Traffc-light-control/b2919a50237a2047cad033c2613549e7ad19f305/images/marl.png -------------------------------------------------------------------------------- /images/ppo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxbrenner-ai/Multi-Agent-Distributed-PPO-Traffc-light-control/b2919a50237a2047cad033c2613549e7ad19f305/images/ppo.png -------------------------------------------------------------------------------- /images/sumo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxbrenner-ai/Multi-Agent-Distributed-PPO-Traffc-light-control/b2919a50237a2047cad033c2613549e7ad19f305/images/sumo.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import time 2 | from utils.utils import * 3 | from utils.random_search import grid_choices_random 4 | from utils.grid_search import grid_choices, get_num_grid_choices 5 | from run_agent_parallel import train_PPO, test_rule_based, test_PPO 6 | import sys 7 | import os 8 | from data_collector import DataCollector, POSSIBLE_DATA 9 | 10 | 11 | def run_grid_search(verbose, num_repeat_experiment, df_path=None, overwrite=True, data_to_collect=POSSIBLE_DATA, 12 | MVP_key='waitingTime', save_model=True): 13 | grid = load_constants('constants/constants-grid.json') 14 | 15 | # Make the grid choice generator 16 | bases, num_choices = get_num_grid_choices(grid) 17 | grid_choice_gen = grid_choices(grid, bases) 18 | for diff_experiment, constants in enumerate(grid_choice_gen): 19 | data_collector_obj = DataCollector(data_to_collect, MVP_key, constants, 20 | 'test' if constants['agent']['agent_type'] == 'rule' else 'eval', 21 | df_path, overwrite if diff_experiment == 0 else False, verbose) 22 | 23 | for same_experiment in range(num_repeat_experiment): 24 | print(' --- Running experiment {}.{} / {}.{} --- '.format(diff_experiment+1, same_experiment+1, 25 | num_choices, num_repeat_experiment)) 26 | if save_model: data_collector_obj.set_save_model_path( 27 | 'models/saved_models/grid_{}-{}.pt'.format(diff_experiment + 1, same_experiment + 1)) 28 | run_experiment(diff_experiment+1, same_experiment+1, constants, data_collector_obj) 29 | 30 | 31 | def run_random_search(verbose, num_diff_experiments, num_repeat_experiment, allow_duplicates=False, df_path=None, 32 | overwrite=True, data_to_collect=POSSIBLE_DATA, MVP_key='waitingTime', save_model=True): 33 | grid = 
load_constants('constants/constants-grid.json') 34 | 35 | if not allow_duplicates: 36 | _, num_choices = get_num_grid_choices(grid) 37 | num_diff_experiments = min(num_choices, num_diff_experiments) 38 | # Make grid choice generator 39 | grid_choice_gen = grid_choices_random(grid, num_diff_experiments) 40 | for diff_experiment, constants in enumerate(grid_choice_gen): 41 | data_collector_obj = DataCollector(data_to_collect, MVP_key, constants, 42 | 'test' if constants['agent']['agent_type'] == 'rule' else 'eval', 43 | df_path, overwrite if diff_experiment == 0 else False, verbose) 44 | 45 | for same_experiment in range(num_repeat_experiment): 46 | print(' --- Running experiment {}.{} / {}.{} --- '.format(diff_experiment+1, same_experiment+1, 47 | num_diff_experiments, num_repeat_experiment)) 48 | if save_model: data_collector_obj.set_save_model_path('models/saved_models/random_{}-{}.pt'. 49 | format(diff_experiment+1, same_experiment+1)) 50 | run_experiment(diff_experiment+1, same_experiment+1, constants, data_collector_obj) 51 | 52 | 53 | def run_normal(verbose, num_experiments=1, df_path=None, overwrite=True, data_to_collect=POSSIBLE_DATA, 54 | MVP_key='waitingTime', save_model=True, load_model_file=None): 55 | # if loading, then dont save 56 | if load_model_file: 57 | save_model = False 58 | 59 | if not df_path: 60 | df_path = 'run-data.xlsx' # def. path 61 | 62 | # Load constants 63 | constants = load_constants('constants/constants.json') 64 | data_collector_obj = DataCollector(data_to_collect, MVP_key, constants, 65 | 'test' if constants['agent']['agent_type'] == 'rule' or load_model_file else 'eval', 66 | df_path, overwrite, verbose) 67 | 68 | loaded_model = None 69 | if load_model_file: 70 | loaded_model = torch.load('models/saved_models/' + load_model_file) 71 | 72 | for exp in range(num_experiments): 73 | print(' --- Running experiment {} / {} --- '.format(exp + 1, num_experiments)) 74 | if save_model: data_collector_obj.set_save_model_path('models/saved_models/normal_{}.pt'.format(exp+1)) 75 | run_experiment(exp+1, None, constants, data_collector_obj, loaded_model=loaded_model) 76 | 77 | 78 | def run_experiment(exp1, exp2, constants, data_collector_obj, loaded_model=None): 79 | data_collector_obj.start_timer() 80 | 81 | if loaded_model: 82 | test_PPO(constants, device, data_collector_obj, loaded_model) 83 | elif constants['agent']['agent_type'] == 'ppo': 84 | train_PPO(constants, device, data_collector_obj) 85 | else: 86 | assert constants['agent']['agent_type'] == 'rule' 87 | test_rule_based(constants, device, data_collector_obj) 88 | 89 | # Save and Refresh the data_collector 90 | data_collector_obj.end_timer(printIt=True) 91 | data_collector_obj.process_data() 92 | data_collector_obj.print_summary(exp1, exp2) 93 | data_collector_obj.done_with_experiment() 94 | 95 | 96 | if __name__ == '__main__': 97 | # we need to import python modules from the $SUMO_HOME/tools directory 98 | if 'SUMO_HOME' in os.environ: 99 | tools = os.path.join(os.environ['SUMO_HOME'], 'tools') 100 | sys.path.append(tools) 101 | else: 102 | sys.exit("please declare environment variable 'SUMO_HOME'") 103 | 104 | device = torch.device('cpu') 105 | 106 | # df_path = 'run-data.xlsx' 107 | 108 | # print('Num cores: {}'.format(mp.cpu_count())) 109 | 110 | run_normal(verbose=False, num_experiments=3, df_path='run-data.xlsx', overwrite=True, 111 | data_to_collect=POSSIBLE_DATA, MVP_key='waitingTime', save_model=True, load_model_file=None) 112 | 113 | # run_random_search(verbose=False, num_diff_experiments=800, 
num_repeat_experiment=3, allow_duplicates=False, 114 | # df_path='run-data.xlsx', overwrite=True, data_to_collect=POSSIBLE_DATA, MVP_key='waitingTime', 115 | # save_model=True) 116 | 117 | # run_grid_search(verbose=False, num_repeat_experiment=3, df_path='run-data.xlsx', overwrite=True, 118 | # data_to_collect=POSSIBLE_DATA, MVP_key='waitingTime', save_model=True) 119 | 120 | -------------------------------------------------------------------------------- /models/cycle_rule_set.py: -------------------------------------------------------------------------------- 1 | from scipy.sparse.csgraph import shortest_path 2 | import numpy as np 3 | from utils.net_scrape import * 4 | from models.rule_set import RuleSet 5 | from random import sample 6 | 7 | 8 | ''' 9 | 1. Dont want the light lengths to be less than the yellow phase (3 seconds) 10 | 2. Assume NS starting config for lights 11 | ''' 12 | 13 | class CycleRuleSet(RuleSet): 14 | def __init__(self, params, net_path, constants): 15 | super(CycleRuleSet, self).__init__(params, net_path) 16 | assert params['phases'] 17 | assert 'cycle_length' in params 18 | self.net_path = net_path 19 | self.gen_time = constants['episode']['generation_ep_steps'] 20 | self.NS_mult = params['NS_mult'] 21 | self.EW_mult = params['EW_mult'] 22 | self.phase_end_offset = params['phase_end_offset'] 23 | self.cycle_length = params['cycle_length'] 24 | self.intersections = get_intersections(self.net_path) 25 | # [{'start_step': t, 'light_lengths': {intersection: {NS: #, EW: #}...}}...] 26 | self.phases = self._get_light_lengths(params['phases']) 27 | self.reset() 28 | 29 | def reset(self): 30 | self.curr_phase_indx = -1 # Keep track of current phase 31 | # Keep track of current light for each inetrsection as well as current time with that light 32 | self.info = {intersection: {'curr_light': 'NS', 'curr_time': 0} for intersection in self.intersections} 33 | 34 | def __call__(self, state): 35 | old_lights = [] 36 | new_lights = [] 37 | t = state['sim_step'] 38 | next_phase = False 39 | if self.curr_phase_indx < len(self.phases)-1 and self.phases[self.curr_phase_indx + 1]['start_step'] == t: 40 | next_phase = True 41 | self.curr_phase_indx += 1 42 | 43 | for intersection in self.intersections: 44 | if self.info[intersection]['curr_light'] == 'EW': 45 | assert state[intersection]['curr_phase'] == 1 or state[intersection]['curr_phase'] == 2 46 | elif self.info[intersection]['curr_light'] == 'NS': 47 | assert state[intersection]['curr_phase'] == 0 or state[intersection]['curr_phase'] == 3 48 | # First get old lights to compare with new lights to see if switch is needed 49 | old_lights.append(self.info[intersection]['curr_light']) 50 | # If the next phase start step is equal to curr time then switch to next phase 51 | if next_phase: 52 | # Select random starting light if not yellow light (1 or 3) 53 | if state[intersection]['curr_phase'] == 0 or state[intersection]['curr_phase'] == 2: 54 | choice = sample(['NS', 'EW'], 1)[0] 55 | else: 56 | if state[intersection]['curr_phase'] == 1: 57 | choice = 'EW' 58 | else: 59 | choice = 'NS' 60 | self.info[intersection]['curr_light'] = choice 61 | self.info[intersection]['curr_time'] = 0 62 | else: # o.w. 
continue in the current phase; switch when the current light's time is up 63 | light_lengths = self.phases[self.curr_phase_indx]['light_lengths'] 64 | intersection_light_lengths = light_lengths[intersection] 65 | curr_light = self.info[intersection]['curr_light'] 66 | curr_time = self.info[intersection]['curr_time'] 67 | # If the light's current time has reached its length then switch 68 | if curr_time >= intersection_light_lengths[curr_light]: 69 | self.info[intersection]['curr_light'] = 'NS' if curr_light == 'EW' else 'EW' 70 | self.info[intersection]['curr_time'] = 0 71 | else: 72 | self.info[intersection]['curr_time'] += 1 73 | # Get new light 74 | new_lights.append(self.info[intersection]['curr_light']) 75 | # Compare old lights with new lights; 1 signals that the intersection should switch 76 | switches = [0 if x == y else 1 for x, y in zip(old_lights, new_lights)] 77 | return np.array(switches) 78 | 79 | def _get_light_lengths(self, phases): 80 | NODES_arr, NODES = get_node_arr_and_dict(self.net_path) 81 | N = len(list(NODES)) 82 | assert N == len(set(list(NODES.keys()))) 83 | 84 | # Matrix of edge lengths for Dijkstra's shortest paths 85 | M = np.zeros(shape=(N, N)) 86 | 87 | # Need dict of edge lengths 88 | edge_lengths = get_edge_length_dict(self.net_path, jitter=False) 89 | 90 | # Loop through nodes and fill M using incLanes attrib 91 | # Also make dict of positions for nodes so I can work with phases 92 | node_locations = {} 93 | tree = ET.parse(self.net_path) 94 | root = tree.getroot() 95 | for c in root.iter('junction'): 96 | id = c.attrib['id'] 97 | # Ignore if type is internal 98 | if c.attrib['type'] == 'internal': 99 | continue 100 | # Add location for future use 101 | node_locations[id] = (c.attrib['x'], c.attrib['y']) 102 | # Current col 103 | col = NODES[id] 104 | # Get the edges to this node through the incLanes attrib (remove _0 for the first lane) 105 | incLanes = c.attrib['incLanes'] 106 | # Space delimited 107 | incLanes_list = incLanes.split(' ') 108 | for lane in incLanes_list: 109 | in_edge = lane.split('_0')[0] 110 | in_node = in_edge.split('___')[0] 111 | row = NODES[in_node] 112 | length = edge_lengths[in_edge] 113 | M[row, col] = length 114 | 115 | # Get the shortest path matrix 116 | _, preds = shortest_path(M, directed=True, return_predecessors=True) 117 | 118 | # Fill dict of gen{i}___rem{j} with all nodes on the shortest route 119 | # Todo: Important assumption, "intersection" is ONLY in the name/id of an intersection 120 | shortest_routes = {} 121 | for i in range(N): 122 | gen = NODES_arr[i] 123 | if 'intersection' in gen: continue 124 | for j in range(N): 125 | if i == j: 126 | continue 127 | rem = NODES_arr[j] 128 | if 'intersection' in rem: continue 129 | route = [rem] 130 | while preds[i, j] != i: # via 131 | j = preds[i, j] 132 | route.append(NODES_arr[int(j)]) 133 | route.append(gen) 134 | shortest_routes['{}___{}'.format(rem, gen)] = route 135 | 136 | def get_weightings(probs): 137 | weights = {intersection: {'NS': 0., 'EW': 0.} for intersection in self.intersections} 138 | for intersection, phases in list(weights.items()): 139 | # Collect all routes that include the intersection 140 | for v in list(shortest_routes.values()): 141 | if intersection in v: 142 | int_indx = v.index(intersection) 143 | # Route weight is gen + rem probs 144 | route_weight = probs[v[0]]['gen'] + probs[v[-1]]['rem'] 145 | # Detect if NS or EW 146 | # If the x-coords of the previous node and the intersection are equal, the route runs north-south through it, so weight NS 147 | if node_locations[v[int_indx - 1]][0] == node_locations[intersection][0]: 148 | phases['NS'] += self.NS_mult * route_weight 149 | 
else: 150 | assert node_locations[v[int_indx - 1]][1] == node_locations[intersection][ 151 | 1], 'xs not equal so ys should be.' 152 | phases['EW'] += self.EW_mult * route_weight 153 | return weights 154 | 155 | def convert_weights_to_lengths(weights, cycle_length=10): 156 | ret_dic = {} 157 | for intersection, dic in list(weights.items()): 158 | total = dic['NS'] + dic['EW'] 159 | NS = round((dic['NS'] / total) * cycle_length) 160 | EW = round((dic['EW'] / total) * cycle_length) 161 | assert NS >= 3 and EW >= 3, 'Assuming yellow phases are 3 make sure these lengths are not less.' 162 | ret_dic[intersection] = {'NS': NS, 'EW': EW} 163 | return ret_dic 164 | 165 | # For each phase get weights and cycle lengths 166 | new_phases = [] 167 | curr_t = 0 168 | for p, phase in enumerate(phases): 169 | weights = get_weightings(phase['probs']) 170 | light_lengths = convert_weights_to_lengths(weights, self.cycle_length) 171 | phase_length = phase['duration'] * self.gen_time 172 | # Want to offset the end of each phase, except for the first phase which needs to start at the beginning 173 | offset = self.phase_end_offset 174 | if p == 0: 175 | offset = 0 176 | new_phases.append({'start_step': round(curr_t + offset), 'light_lengths': light_lengths}) 177 | curr_t += round(phase_length) 178 | return new_phases 179 | -------------------------------------------------------------------------------- /models/ppo_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | from utils.utils import plot_grad_flow, layer_init_filter 5 | 6 | 7 | # Simple one layer 8 | class ModelBody(nn.Module): 9 | def __init__(self, input_size, hidden_size): 10 | super(ModelBody, self).__init__() 11 | self.name = 'model_body' 12 | self.model = nn.Sequential( 13 | nn.Linear(input_size, hidden_size), 14 | nn.ReLU() 15 | ) 16 | self.model.apply(layer_init_filter) 17 | 18 | def forward(self, states): 19 | hidden_states = self.model(states) 20 | return hidden_states 21 | 22 | 23 | class ActorModel(nn.Module): 24 | def __init__(self, hidden_size, action_size): 25 | super(ActorModel, self).__init__() 26 | self.name = 'actor' 27 | self.model = nn.Sequential( 28 | nn.Linear(hidden_size, action_size) 29 | ) 30 | self.model.apply(layer_init_filter) 31 | 32 | def forward(self, hidden_states): 33 | outputs = self.model(hidden_states) 34 | return outputs 35 | 36 | 37 | class CriticModel(nn.Module): 38 | def __init__(self, hidden_size): 39 | super(CriticModel, self).__init__() 40 | self.name = 'critic' 41 | self.model = nn.Sequential( 42 | nn.Linear(hidden_size, 1) 43 | ) 44 | self.model.apply(layer_init_filter) 45 | 46 | def forward(self, hidden_states): 47 | outputs = self.model(hidden_states) 48 | return outputs 49 | 50 | 51 | class NN_Model(nn.Module): 52 | def __init__(self, state_size, action_size, hidden_layer_size, device): 53 | super(NN_Model, self).__init__() 54 | self.body_model = ModelBody(state_size, hidden_layer_size).to(device) 55 | self.actor_model = ActorModel(hidden_layer_size, action_size).to(device) 56 | self.critic_model = CriticModel(hidden_layer_size).to(device) 57 | 58 | self.models = [self.body_model, self.actor_model, self.critic_model] 59 | 60 | def forward(self, states, actions=None): 61 | hidden_states = self.body_model(states) 62 | v = self.critic_model(hidden_states) 63 | logits = self.actor_model(hidden_states) 64 | dist = torch.distributions.Categorical(logits=logits) 65 | if actions is None: 66 | actions = 
dist.sample() 67 | log_prob = dist.log_prob(actions).unsqueeze(-1) 68 | entropy = dist.entropy().unsqueeze(-1) 69 | return {'a': actions, 70 | 'log_pi_a': log_prob, 71 | 'ent': entropy, 72 | 'v': v} 73 | -------------------------------------------------------------------------------- /models/rule_set.py: -------------------------------------------------------------------------------- 1 | class RuleSet: 2 | def __init__(self, params, net_path=None, constants=None): 3 | self.net_path = net_path 4 | 5 | def reset(self): 6 | raise NotImplementedError 7 | 8 | def __call__(self, state): 9 | raise NotImplementedError 10 | -------------------------------------------------------------------------------- /models/timer_rule_set.py: -------------------------------------------------------------------------------- 1 | from models.rule_set import RuleSet 2 | import numpy as np 3 | 4 | 5 | class TimerRuleSet(RuleSet): 6 | def __init__(self, params, net_path=None, constants=None): 7 | super(TimerRuleSet, self).__init__(params) 8 | assert 'length' in params, 'Length not in params' 9 | self.timer_length = params['length'] 10 | self.offset = params['offset'] if 'offset' in params else None 11 | self.start_phase = params['start_phase'] if 'start_phase' in params else None 12 | assert isinstance(self.timer_length, list) 13 | self.reset() 14 | 15 | def reset(self): 16 | # Start at offset if given, else zeros 17 | self.current_phase_timer = [x for x in self.offset] if self.offset else [0] * len(self.timer_length) 18 | self.just_reset = True 19 | 20 | def __call__(self, state): 21 | # If just reset then send switch for start phases of 1 22 | if self.start_phase and self.just_reset: 23 | self.just_reset = False 24 | return np.array(self.start_phase) 25 | li = [] 26 | for i in range(len(self.timer_length)): 27 | # Can ignore the actual env state 28 | if self.current_phase_timer[i] >= self.timer_length[i]: 29 | self.current_phase_timer[i] = 0 30 | li.append(1) # signals to the env to switch the actual phase 31 | continue 32 | self.current_phase_timer[i] += 1 33 | li.append(0) 34 | return np.array(li) 35 | -------------------------------------------------------------------------------- /run-data.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxbrenner-ai/Multi-Agent-Distributed-PPO-Traffc-light-control/b2919a50237a2047cad033c2613549e7ad19f305/run-data.xlsx -------------------------------------------------------------------------------- /run_agent_parallel.py: -------------------------------------------------------------------------------- 1 | from models.ppo_model import NN_Model 2 | from utils.utils import * 3 | from environments.intersections import IntersectionsEnv, PER_AGENT_STATE_SIZE, GLOBAL_STATE_SIZE, ACTION_SIZE 4 | from copy import deepcopy 5 | from workers.ppo_worker import PPOWorker 6 | from workers.rule_worker import RuleBasedWorker 7 | from math import pow 8 | from utils.net_scrape import get_intersection_neighborhoods 9 | 10 | 11 | # target for multiprocess 12 | def train_worker(id, shared_NN, data_collector, optimizer, rollout_counter, constants, device, max_neighborhood_size): 13 | s_a = get_state_action_size(PER_AGENT_STATE_SIZE, GLOBAL_STATE_SIZE, ACTION_SIZE, max_neighborhood_size, constants) 14 | env = IntersectionsEnv(constants, device, id, False, get_net_path(constants)) 15 | # Assumes PPO worker 16 | local_NN = NN_Model(s_a['s'], s_a['a'], constants['ppo']['hidden_layer_size'], device).to(device) 17 | worker = 
PPOWorker(constants, device, env, None, shared_NN, local_NN, optimizer, id) 18 | train_step = 0 19 | while rollout_counter.get() < constants['episode']['num_train_rollouts'] + 1: 20 | worker.train_rollout(train_step) 21 | rollout_counter.increment() 22 | # Kill connection to sumo server 23 | worker.env.connection.close() 24 | # print('...Training worker {} done'.format(id)) 25 | 26 | 27 | # target for multiprocess 28 | def eval_worker(id, shared_NN, data_collector, rollout_counter, constants, device, max_neighborhood_size): 29 | s_a = get_state_action_size(PER_AGENT_STATE_SIZE, GLOBAL_STATE_SIZE, ACTION_SIZE, max_neighborhood_size, constants) 30 | env = IntersectionsEnv(constants, device, id, True, get_net_path(constants)) 31 | # Assumes PPO worker 32 | local_NN = NN_Model(s_a['s'], s_a['a'], constants['ppo']['hidden_layer_size'], device).to(device) 33 | worker = PPOWorker(constants, device, env, data_collector, shared_NN, local_NN, None, id) 34 | last_eval = 0 35 | while True: 36 | curr_r = rollout_counter.get() 37 | if curr_r % constants['episode']['eval_freq'] == 0 and last_eval != curr_r: 38 | last_eval = curr_r 39 | worker.eval_episodes(curr_r, model_state=worker.NN.state_dict()) 40 | # End the eval agent 41 | if curr_r >= constants['episode']['num_train_rollouts'] + 1: 42 | break 43 | # Eval at end 44 | worker.eval_episodes(curr_r, model_state=worker.NN.state_dict()) 45 | # Kill connection to sumo server 46 | worker.env.connection.close() 47 | # print('...Eval worker {} done'.format(id)) 48 | 49 | 50 | # target for multiprocess 51 | def test_worker(id, ep_counter, constants, device, worker=None, data_collector=None, shared_NN=None, max_neighborhood_size=None): 52 | # assume PPO agent if agent=None 53 | if not worker: 54 | s_a = get_state_action_size(PER_AGENT_STATE_SIZE, GLOBAL_STATE_SIZE, ACTION_SIZE, max_neighborhood_size, constants) 55 | env = IntersectionsEnv(constants, device, id, True, get_net_path(constants)) 56 | local_NN = NN_Model(s_a['s'], s_a['a'], constants['ppo']['hidden_layer_size'], device).to(device) 57 | worker = PPOWorker(constants, device, env, data_collector, 58 | shared_NN, local_NN, None, id) 59 | while ep_counter.get() < constants['episode']['test_num_eps']: 60 | worker.eval_episodes(None, ep_count=ep_counter.get()) 61 | ep_counter.increment(constants['episode']['eval_num_eps']) 62 | # Kill connection to sumo server 63 | worker.env.connection.close() 64 | # print('...Testing agent {} done'.format(id)) 65 | 66 | 67 | # ====================================================================================================================== 68 | 69 | 70 | def train_PPO(constants, device, data_collector): 71 | _, max_neighborhood_size = get_intersection_neighborhoods(get_net_path(constants)) 72 | s_a = get_state_action_size(PER_AGENT_STATE_SIZE, GLOBAL_STATE_SIZE, ACTION_SIZE, max_neighborhood_size, constants) 73 | shared_NN = NN_Model(s_a['s'], s_a['a'], constants['ppo']['hidden_layer_size'], device).to(device) 74 | shared_NN.share_memory() 75 | optimizer = torch.optim.Adam(shared_NN.parameters(), constants['ppo']['learning_rate']) 76 | rollout_counter = Counter() # To keep track of all the rollouts amongst agents 77 | processes = [] 78 | # Run eval agent 79 | id = 'eval_0' 80 | p = mp.Process(target=eval_worker, args=(id, shared_NN, data_collector, rollout_counter, constants, device, max_neighborhood_size)) 81 | p.start() 82 | processes.append(p) 83 | # Run training agents 84 | for i in range(constants['parallel']['num_workers']): 85 | id = 'train_'+str(i) 86 | p = 
mp.Process(target=train_worker, args=(id, shared_NN, data_collector, optimizer, rollout_counter, constants, device, max_neighborhood_size)) 87 | p.start() 88 | processes.append(p) 89 | for p in processes: 90 | p.join() 91 | 92 | 93 | def test_PPO(constants, device, data_collector, loaded_model): 94 | _, max_neighborhood_size = get_intersection_neighborhoods(get_net_path(constants)) 95 | s_a = get_state_action_size(PER_AGENT_STATE_SIZE, GLOBAL_STATE_SIZE, ACTION_SIZE, max_neighborhood_size, constants) 96 | shared_NN = NN_Model(s_a['s'], s_a['a'], constants['ppo']['hidden_layer_size'], device).to(device) 97 | shared_NN.load_state_dict(loaded_model) 98 | shared_NN.share_memory() 99 | ep_counter = Counter() # Eps across all agents 100 | processes = [] 101 | for i in range(constants['parallel']['num_workers']): 102 | id = 'test_'+str(i) 103 | p = mp.Process(target=test_worker, args=(id, ep_counter, constants, device, None, data_collector, shared_NN, max_neighborhood_size)) 104 | p.start() 105 | processes.append(p) 106 | for p in processes: 107 | p.join() 108 | 109 | 110 | # verbose means test prints at end of each batch of eps 111 | def test_rule_based(constants, device, data_collector): 112 | _, max_neighborhood_size = get_intersection_neighborhoods(get_net_path(constants)) 113 | # Check rule set 114 | rule_set_class = get_rule_set_class(constants['rule']['rule_set']) 115 | ep_counter = Counter() # Eps across all agents 116 | processes = [] 117 | for i in range(constants['parallel']['num_workers']): 118 | id = 'test_'+str(i) 119 | env = IntersectionsEnv(constants, device, id, True, get_net_path(constants)) 120 | rule_set_params = deepcopy(constants['rule']['rule_set_params']) 121 | rule_set_params['phases'] = env.phases 122 | worker = RuleBasedWorker(constants, device, env, rule_set_class(rule_set_params, get_net_path(constants), constants), data_collector, id) 123 | p = mp.Process(target=test_worker, args=(id, ep_counter, constants, device, worker, None, None, max_neighborhood_size)) 124 | p.start() 125 | processes.append(p) 126 | for p in processes: 127 | p.join() 128 | -------------------------------------------------------------------------------- /utils/env_phases.py: -------------------------------------------------------------------------------- 1 | from utils.net_scrape import * 2 | 3 | 4 | def get_phases(environment_C, net_path): 5 | if environment_C['rush_hour']: 6 | if environment_C['shape'] == [2, 2]: return two_two_intersections_rush_hour() 7 | else: raise AssertionError('Wrong env given for rush hour.') 8 | else: 9 | return uniform_intersections(environment_C['uniform_generation_probability'], net_path) 10 | 11 | # Normalize all rem probs for each phase to sum to 1 12 | def normalize(dic): 13 | total = sum([v['rem'] for v in list(dic.values())]) 14 | for v in list(dic.values()): 15 | v['rem'] /= total 16 | return dic 17 | 18 | def uniform_intersections(gen, net_path): 19 | node_arr = get_non_intersections(net_path) 20 | probs = {} 21 | for node in node_arr: 22 | probs[node] = {'gen': gen, 'rem': 1} 23 | phase = [{'duration': 1.0, 'probs': normalize(probs)}] 24 | return phase 25 | 26 | def two_two_intersections_rush_hour(): 27 | high_rush_hour_prob_gen = 0.2 # When an edge is generating a bunch of vehicles for rush hour 28 | low_rush_hour_prob_gen = 0.01 # When an edge is on the removal side of rush hour 29 | side_high_rush_hour_prob_gen = 0.05 # When a side edge is generating for rush hour 30 | side_low_rush_hour_prob_gen = 0.01 # When a side edge is removing for rush hour 31 | 
non_rush_hour_prob_gen = 0.1 # For all edges when it isn't rush hour 32 | 33 | high_rush_hour_prob_rem = 0.5 34 | non_rush_hour_prob_rem = 0.1 35 | side_high_rush_hour_prob_rem = 0.2 36 | 37 | first_rush_hour_probs = {'northLeft': {'gen': low_rush_hour_prob_gen, 'rem': high_rush_hour_prob_rem}, 38 | 'northRight': {'gen': low_rush_hour_prob_gen, 'rem': high_rush_hour_prob_rem}, 39 | 'southLeft': {'gen': high_rush_hour_prob_gen, 'rem': non_rush_hour_prob_rem}, 40 | 'southRight': {'gen': high_rush_hour_prob_gen, 'rem': non_rush_hour_prob_rem}, 41 | 'westTop': {'gen': side_low_rush_hour_prob_gen, 'rem': side_high_rush_hour_prob_rem}, 42 | 'eastTop': {'gen': side_low_rush_hour_prob_gen, 'rem': side_high_rush_hour_prob_rem}, 43 | 'westBottom': {'gen': side_high_rush_hour_prob_gen, 'rem': non_rush_hour_prob_rem}, 44 | 'eastBottom': {'gen': side_high_rush_hour_prob_gen, 'rem': non_rush_hour_prob_rem}} 45 | 46 | non_rush_hour_probs = {'northLeft': {'gen': non_rush_hour_prob_gen, 'rem': non_rush_hour_prob_rem}, 47 | 'northRight': {'gen': non_rush_hour_prob_gen, 'rem': non_rush_hour_prob_rem}, 48 | 'southLeft': {'gen': non_rush_hour_prob_gen, 'rem': non_rush_hour_prob_rem}, 49 | 'southRight': {'gen': non_rush_hour_prob_gen, 'rem': non_rush_hour_prob_rem}, 50 | 'westTop': {'gen': non_rush_hour_prob_gen, 'rem': non_rush_hour_prob_rem}, 51 | 'eastTop': {'gen': non_rush_hour_prob_gen, 'rem': non_rush_hour_prob_rem}, 52 | 'westBottom': {'gen': non_rush_hour_prob_gen, 'rem': non_rush_hour_prob_rem}, 53 | 'eastBottom': {'gen': non_rush_hour_prob_gen, 'rem': non_rush_hour_prob_rem}} 54 | 55 | second_rush_hour_probs = {'northLeft': {'gen': high_rush_hour_prob_gen, 'rem': non_rush_hour_prob_rem}, 56 | 'northRight': {'gen': high_rush_hour_prob_gen, 'rem': non_rush_hour_prob_rem}, 57 | 'southLeft': {'gen': low_rush_hour_prob_gen, 'rem': high_rush_hour_prob_rem}, 58 | 'southRight': {'gen': low_rush_hour_prob_gen, 'rem': high_rush_hour_prob_rem}, 59 | 'westTop': {'gen': side_high_rush_hour_prob_gen, 'rem': non_rush_hour_prob_rem}, 60 | 'eastTop': {'gen': side_high_rush_hour_prob_gen, 'rem': non_rush_hour_prob_rem}, 61 | 'westBottom': {'gen': side_low_rush_hour_prob_gen, 'rem': side_high_rush_hour_prob_rem}, 62 | 'eastBottom': {'gen': side_low_rush_hour_prob_gen, 'rem': side_high_rush_hour_prob_rem}} 63 | 64 | # Phase durations are a proportion of the total generation time 65 | phases = [{'duration': 0.25, 'probs': normalize(first_rush_hour_probs)}, 66 | {'duration': 0.5, 'probs': normalize(non_rush_hour_probs)}, 67 | {'duration': 0.25, 'probs': normalize(second_rush_hour_probs)}] 68 | 69 | assert sum([p['duration'] for p in phases]) == 1. 
70 | 71 | return phases 72 | 73 | 74 | # Given t, figure out which phase it's in 75 | def get_current_phase_probs(t, phases, gen_time): 76 | curr_start = 0 77 | for p in phases: 78 | dur = p['duration'] * gen_time 79 | if curr_start <= t < dur + curr_start: 80 | return p['probs'] 81 | curr_start += dur 82 | raise AssertionError('t={} does not fall within any phase'.format(t)) 83 | -------------------------------------------------------------------------------- /utils/grid_search.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def get_num_grid_choices(grid): 5 | bases = [] 6 | for k, v in list(grid.items()): 7 | for v_in in list(v.values()): 8 | bases.append(len(v_in)) 9 | bases = np.array(bases) 10 | return bases, np.prod(bases) 11 | 12 | def grid_choices(grid, bases): 13 | # Grab choices generator 14 | increment = _next_choice(bases) 15 | # Unflatten 16 | for choices in increment: 17 | index = 0 18 | constants = {} 19 | for k, v in list(grid.items()): 20 | constants[k] = {} 21 | for k_in, v_in in list(v.items()): 22 | constants[k][k_in] = v_in[choices[index]] 23 | index += 1 24 | yield constants 25 | 26 | def _next_choice(bases): 27 | ret = np.zeros(shape=(bases.shape[0],), dtype=int) # np.int is deprecated/removed in recent NumPy versions 28 | yield ret 29 | while not np.array_equal(ret, bases-1): 30 | i = 0 31 | overflow = True 32 | while overflow: 33 | overflow = ((ret[i] + 1) % bases[i]) == 0 34 | if not overflow: ret[i] += 1 35 | else: ret[i] = 0; i += 1 36 | yield ret 37 | -------------------------------------------------------------------------------- /utils/net_scrape.py: -------------------------------------------------------------------------------- 1 | import xml.etree.ElementTree as ET 2 | import random 3 | import numpy as np 4 | 5 | 6 | # Since tie breaks for shortest paths result in deterministic behavior, optional jitter can make the routing nondeterministic 7 | def get_edge_length_dict(filepath, jitter=False): 8 | edges = {} 9 | # Parse through .net.xml file for lengths 10 | tree = ET.parse(filepath) 11 | root = tree.getroot() 12 | for c in root.iter('edge'): 13 | id = c.attrib['id'] 14 | # If 'function' key in attrib then skip (internal edge) 15 | if 'function' in c.attrib: 16 | continue 17 | length = float(c[0].attrib['length']) 18 | if jitter: 19 | length += random.randint(-1, 1) 20 | edges[id] = length # ONLY WORKS FOR SINGLE LANE ROADS 21 | return edges 22 | 23 | 24 | # Make a dict {external node: {'gen': gen edge for node, 'rem': rem edge for node} ...
} for external nodes 25 | # For adding trips 26 | def get_node_edge_dict(filepath): 27 | dic = {} 28 | tree = ET.parse(filepath) 29 | root = tree.getroot() 30 | for c in root.iter('junction'): 31 | node = c.attrib['id'] 32 | if c.attrib['type'] == 'internal' or 'intersection' in node: 33 | continue 34 | incLane = c.attrib['incLanes'].split('_0')[0] 35 | dic[node] = {'gen': incLane.split('___')[1] + '___' + incLane.split('___')[0], 'rem': incLane} 36 | return dic 37 | 38 | 39 | def get_intersections(filepath): 40 | arr = [] 41 | tree = ET.parse(filepath) 42 | root = tree.getroot() 43 | for c in root.iter('junction'): 44 | node = c.attrib['id'] 45 | if c.attrib['type'] == 'internal' or not 'intersection' in node: 46 | continue 47 | arr.append(node) 48 | return arr 49 | 50 | 51 | # Returns a arr of the nodes, and a dict where the key is a node and value is the index in the arr 52 | def get_node_arr_and_dict(filepath): 53 | arr = [] 54 | dic = {} 55 | i = 0 56 | tree = ET.parse(filepath) 57 | root = tree.getroot() 58 | for c in root.iter('junction'): 59 | node = c.attrib['id'] 60 | if c.attrib['type'] == 'internal': 61 | continue 62 | arr.append(node) 63 | dic[node] = i 64 | i += 1 65 | return arr, dic 66 | 67 | 68 | # For calc. global/localized discounted rewards 69 | # {intA: {intA: 0, intB: 5, ...} ...} 70 | def get_cartesian_intersection_distances(filepath): 71 | # First gather the positions 72 | tree = ET.parse(filepath) 73 | root = tree.getroot() 74 | intersection_positions = {} 75 | for c in root.iter('junction'): 76 | node = c.attrib['id'] 77 | # If not intersection then continue 78 | if c.attrib['type'] == 'internal' or not 'intersection' in node: 79 | continue 80 | x = float(c.attrib['x']) 81 | y = float(c.attrib['y']) 82 | intersection_positions[node] = np.array([x, y]) 83 | # Second calc. 
euclidean distances 84 | distances = {} 85 | for outer_k, outer_v in list(intersection_positions.items()): 86 | distances[outer_k] = {} 87 | for inner_k, inner_v in list(intersection_positions.items()): 88 | dist = np.linalg.norm(outer_v - inner_v) 89 | distances[outer_k][inner_k] = dist 90 | return distances 91 | 92 | 93 | def get_non_intersections(filepath): 94 | tree = ET.parse(filepath) 95 | root = tree.getroot() 96 | arr = [] 97 | for c in root.iter('junction'): 98 | node = c.attrib['id'] 99 | if c.attrib['type'] == 'internal' or 'intersection' in node: 100 | continue 101 | arr.append(node) 102 | return arr 103 | 104 | 105 | # for state interoplation 106 | # {intA: [intB...]...} -> dict of each intersection of arrays of the intersections in its neighborhood 107 | def get_intersection_neighborhoods(filepath): 108 | tree = ET.parse(filepath) 109 | root = tree.getroot() 110 | dic = {} 111 | for c in root.iter('junction'): 112 | node = c.attrib['id'] 113 | if c.attrib['type'] == 'internal' or 'intersection' not in node: 114 | continue 115 | dic[node] = [] 116 | # parse through incoming lanes 117 | incLanes = c.attrib['incLanes'] 118 | incLanes_list = incLanes.split(' ') 119 | for lane in incLanes_list: 120 | lane_split = lane.split('___')[0] 121 | if 'intersection' in lane_split: 122 | dic[node].append(lane_split) 123 | # returns the dic of neighbors and the max number of neighbors ( + 1) 124 | return dic, max([len(v) for v in list(dic.values())]) + 1 # + 1 for including itself 125 | -------------------------------------------------------------------------------- /utils/random_search.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | 4 | def grid_choices_random(grid, num_runs): 5 | # Unflatten 6 | possibilities = [] 7 | for v in list(grid.values()): 8 | for v_in in list(v.values()): 9 | possibilities.append(v_in) 10 | incrementer = _next_choice(possibilities, num_runs) 11 | 12 | for choices in incrementer: 13 | constants = {} 14 | index = 0 15 | for k, v in list(grid.items()): 16 | constants[k] = {} 17 | for k_in in list(v.keys()): 18 | constants[k][k_in] = choices[index] 19 | index += 1 20 | yield constants 21 | 22 | def _get_random_choice(possibilities): 23 | choice = [] 24 | for const in possibilities: 25 | choice.append(random.sample(const, 1)[0]) 26 | return choice 27 | 28 | def _next_choice(possibilities, num_runs): 29 | storage = [] 30 | index = 0 31 | while index < num_runs: 32 | choice = _get_random_choice(possibilities) 33 | while choice in storage: 34 | choice = _get_random_choice(possibilities) 35 | storage.append(choice) 36 | index += 1 37 | yield choice 38 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from collections import OrderedDict 3 | import pickle 4 | import numpy as np 5 | import networkx as nx 6 | from sklearn.metrics.pairwise import cosine_similarity 7 | import pandas as pd 8 | import json 9 | import torch.nn as nn 10 | import torch 11 | import random 12 | import torch.multiprocessing as mp 13 | from models.timer_rule_set import TimerRuleSet 14 | from models.cycle_rule_set import CycleRuleSet 15 | 16 | 17 | # Acknowledge: Ilya Kostrikov (https://github.com/ikostrikov) 18 | # This assigns this agents grads to the shared grads at the very start 19 | def ensure_shared_grads(model, shared_model): 20 | for param, shared_param in 
zip(model.parameters(), 21 | shared_model.parameters()): 22 | if shared_param.grad is not None: 23 | return 24 | shared_param._grad = param.grad 25 | 26 | 27 | # Acknowledge: Alexis Jacq (https://github.com/alexis-jacq) 28 | class Counter: 29 | """enable the chief to access worker's total number of updates""" 30 | 31 | def __init__(self): 32 | self.val = mp.Value("i", 0) 33 | self.lock = mp.Lock() 34 | 35 | def get(self): 36 | # used by chief 37 | with self.lock: 38 | return self.val.value 39 | 40 | def increment(self, amt=1): 41 | # used by workers 42 | with self.lock: 43 | self.val.value += amt 44 | 45 | def reset(self): 46 | # used by chief 47 | with self.lock: 48 | self.val.value = 0 49 | 50 | 51 | # Works for both single constants and lists for grid search 52 | def load_constants(filepath): 53 | with open(filepath) as f: 54 | data = json.load(f) 55 | return data 56 | 57 | 58 | def refresh_excel(filepath, excel_header): 59 | df = pd.DataFrame(columns=excel_header) 60 | df.to_excel(filepath, index=False, header=excel_header) 61 | 62 | 63 | # Acknowledge: Shangtong Zhang (https://github.com/ShangtongZhang) 64 | class Storage: 65 | def __init__(self, size, keys=None): 66 | if keys is None: 67 | keys = [] 68 | keys = keys + ['s', 'a', 'r', 'm', 69 | 'v', 'q', 'pi', 'log_pi', 'ent', 70 | 'adv', 'ret', 'q_a', 'log_pi_a', 71 | 'mean'] 72 | self.keys = keys 73 | self.size = size 74 | self.reset() 75 | 76 | def add(self, data): 77 | for k, v in data.items(): 78 | if k not in self.keys: 79 | self.keys.append(k) 80 | setattr(self, k, []) 81 | getattr(self, k).append(v) 82 | 83 | def placeholder(self): 84 | for k in self.keys: 85 | v = getattr(self, k) 86 | if len(v) == 0: 87 | setattr(self, k, [None] * self.size) 88 | 89 | def reset(self): 90 | for key in self.keys: 91 | setattr(self, key, []) 92 | 93 | def cat(self, keys): 94 | data = [getattr(self, k)[:self.size] for k in keys] 95 | return map(lambda x: torch.cat(x, dim=0), data) 96 | 97 | 98 | # Acknowledge: Shangtong Zhang (https://github.com/ShangtongZhang) 99 | def tensor(x, device): 100 | if isinstance(x, torch.Tensor): 101 | return x 102 | x = np.asarray(x, dtype=float) # np.float is deprecated/removed in recent NumPy versions 103 | x = torch.tensor(x, device=device, dtype=torch.float32) 104 | return x 105 | 106 | 107 | # Acknowledge: Shangtong Zhang (https://github.com/ShangtongZhang) 108 | def random_sample(indices, batch_size): 109 | indices = np.asarray(np.random.permutation(indices)) 110 | batches = indices[:len(indices) // batch_size * batch_size].reshape(-1, batch_size) 111 | for batch in batches: 112 | yield batch 113 | r = len(indices) % batch_size 114 | if r: 115 | yield indices[-r:] 116 | 117 | 118 | def layer_init_filter(m): 119 | if type(m) == nn.Linear: 120 | nn.init.xavier_uniform_(m.weight) 121 | m.bias.data.fill_(0) 122 | 123 | 124 | def layer_init(layer, w_scale=1.0): 125 | nn.init.xavier_uniform_(layer.weight.data) 126 | layer.weight.data.mul_(w_scale) 127 | nn.init.constant_(layer.bias.data, 0) 128 | return layer 129 | 130 | 131 | def plot_grad_flow(layers, ave_grads, max_grads): 132 | plt.plot(ave_grads, alpha=0.3, color="b") 133 | plt.hlines(0, 0, len(ave_grads) + 1, linewidth=1, color="k") 134 | plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical") 135 | plt.xlim(left=0, right=len(ave_grads)) 136 | plt.xlabel("Layers") 137 | plt.ylabel("average gradient") 138 | plt.title("Gradient flow") 139 | plt.grid(True) 140 | 141 | 142 | def get_rule_set_class(id): 143 | if id == 'timer': 144 | return TimerRuleSet 145 | if id == 'cycle': 146 | return CycleRuleSet 147 | raise 
AssertionError('Ruleset not in possible sets!') 148 | 149 | 150 | # For traffic probability based on episode step (NOT USED) 151 | class PiecewiseLinearFunction: 152 | def __init__(self, points): 153 | self.points = points 154 | self.linear_funcs = [] 155 | for p in range(len(points)-1): 156 | if p % 2 == 1: 157 | continue 158 | self.linear_funcs.append(LinearFunction(points[p], points[p+1])) 159 | 160 | def get_output(self, x): 161 | for lf in self.linear_funcs: 162 | output = lf.func(x) 163 | if output is not None: 164 | return output 165 | # If got here something wrong with the lfs sent in 166 | assert True is False, 'Something wrong with lfs sent in.' 167 | 168 | def visualize(self): 169 | xs = [p[0] for p in self.points] 170 | ys = [p[1] for p in self.points] 171 | plt.plot(xs, ys) 172 | plt.show() 173 | 174 | 175 | class LinearFunction: 176 | def __init__(self, p1, p2): 177 | x0, y0 = p1 178 | x1, y1 = p2 179 | domain = (min(x0, x1), max(x0, x1)) 180 | m = (y1 - y0) / (x1 - x0) 181 | b = y0 - m * x0 182 | self.func = lambda x: m * x + b if domain[0] <= x < domain[1] else None 183 | 184 | 185 | def get_net_path(constants): 186 | shape = constants['environment']['shape'] 187 | file = 'data/' 188 | file += '{}_{}_'.format(shape[0], shape[1]) 189 | file += 'intersections.net.xml' 190 | return file 191 | 192 | 193 | def get_state_action_size(PER_AGENT_STATE_SIZE, GLOBAL_STATE_SIZE, ACTION_SIZE, max_neighborhood_size, constants): 194 | def get_num_agents(): 195 | shape = constants['environment']['shape'] 196 | return int(shape[0] * shape[1]) 197 | single_agent = constants['agent']['single_agent'] 198 | # if single agent then make sure to scale up the state size and action size 199 | if single_agent: 200 | return {'s': int((get_num_agents() * PER_AGENT_STATE_SIZE) + GLOBAL_STATE_SIZE), 'a': int(pow(ACTION_SIZE, get_num_agents()))} 201 | else: 202 | # Multi-agent: 203 | # 1) If state disc is 0 then state size is PER_AGENT_STATE_SIZE + GLOBAL_STATE_SIZE 204 | if constants['multiagent']['state_interpolation'] == 0: 205 | return {'s': int(PER_AGENT_STATE_SIZE + GLOBAL_STATE_SIZE), 'a': int(ACTION_SIZE)} 206 | # 2) O.w. 
need to calc the biggest possible neighborhood 207 | else: 208 | return {'s': int((PER_AGENT_STATE_SIZE + GLOBAL_STATE_SIZE) * max_neighborhood_size), 'a': int(ACTION_SIZE)} 209 | -------------------------------------------------------------------------------- /vis_agent.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from workers.ppo_worker import PPOWorker 4 | from workers.rule_worker import RuleBasedWorker 5 | from models.ppo_model import NN_Model 6 | from utils.utils import * 7 | from environments.intersections import IntersectionsEnv, PER_AGENT_STATE_SIZE, GLOBAL_STATE_SIZE, ACTION_SIZE 8 | from copy import deepcopy 9 | from utils.net_scrape import * 10 | 11 | 12 | def test_worker(worker): 13 | worker.eval_episode({}) 14 | # Kill connection to sumo server 15 | worker.env.connection.close() 16 | 17 | def test_PPO_agent(constants, device, loaded_model): 18 | _, max_neighborhood_size = get_intersection_neighborhoods(get_net_path(constants)) 19 | s_a = get_state_action_size(PER_AGENT_STATE_SIZE, GLOBAL_STATE_SIZE, ACTION_SIZE, max_neighborhood_size, constants) 20 | env = IntersectionsEnv(constants, device, 'vis_ppo', True, get_net_path(constants), vis=True) 21 | local_NN = NN_Model(s_a['s'], s_a['a'], constants['ppo']['hidden_layer_size'], device).to(device) 22 | local_NN.load_state_dict(loaded_model) 23 | worker = PPOWorker(constants, device, env, None, None, local_NN, None, 'ppo', dont_reset=True) 24 | test_worker(worker) 25 | 26 | # verbose means test prints at end of each batch of eps 27 | def test_rule_based_agent(constants, device): 28 | env = IntersectionsEnv(constants, device, 'vis_rule_based', True, get_net_path(constants), vis=True) 29 | # Check rule set 30 | rule_set_class = get_rule_set_class(constants['rule']['rule_set']) 31 | rule_set_params = deepcopy(constants['rule']['rule_set_params']) 32 | rule_set_params['phases'] = env.phases 33 | worker = RuleBasedWorker(constants, device, env, rule_set_class(rule_set_params, get_net_path(constants), constants), None, 'rule_based') 34 | test_worker(worker) 35 | 36 | def run(load_model_file=None): 37 | # Load constants 38 | constants = load_constants('constants/constants.json') 39 | 40 | loaded_model = None 41 | if load_model_file: 42 | loaded_model = torch.load('models/saved_models/' + load_model_file) 43 | 44 | if loaded_model: 45 | test_PPO_agent(constants, device, loaded_model) 46 | else: 47 | assert constants['agent']['agent_type'] == 'rule' 48 | test_rule_based_agent(constants, device) 49 | 50 | if __name__ == '__main__': 51 | # we need to import python modules from the $SUMO_HOME/tools directory 52 | if 'SUMO_HOME' in os.environ: 53 | tools = os.path.join(os.environ['SUMO_HOME'], 'tools') 54 | sys.path.append(tools) 55 | else: 56 | sys.exit("please declare environment variable 'SUMO_HOME'") 57 | 58 | device = torch.device('cpu') 59 | 60 | run(load_model_file='grid_1-1.pt') 61 | -------------------------------------------------------------------------------- /workers/ppo_worker.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from utils.utils import Storage, tensor, random_sample, ensure_shared_grads 6 | from workers.worker import Worker 7 | 8 | 9 | # Code adapted from: Shangtong Zhang (https://github.com/ShangtongZhang) 10 | class PPOWorker(Worker): 11 | def __init__(self, constants, device, env, data_collector, shared_NN, local_NN, 
optimizer, id, dont_reset=False): 12 | super(PPOWorker, self).__init__(constants, device, env, id, data_collector) 13 | self.NN = local_NN 14 | self.shared_NN = shared_NN 15 | if not dont_reset: # for the vis agent script this messes things up 16 | self.state = self.env.reset() 17 | self.ep_step = 0 18 | self.opt = optimizer 19 | self.num_agents = len(env.intersections) if not self.constants['agent']['single_agent'] else 1 20 | self.NN.eval() 21 | 22 | def _reset(self): 23 | pass 24 | 25 | def _get_prediction(self, states, actions=None, ep_step=None): 26 | return self.NN(tensor(states, self.device), actions) 27 | 28 | def _get_action(self, prediction): 29 | return prediction['a'].cpu().numpy() 30 | 31 | def _copy_shared_model_to_local(self): 32 | self.NN.load_state_dict(self.shared_NN.state_dict()) 33 | 34 | def _stack(self, val): 35 | assert not isinstance(val, list) 36 | return np.stack([val] * self.num_agents) 37 | 38 | def train_rollout(self, total_step): 39 | storage = Storage(self.constants['episode']['rollout_length']) 40 | state = np.copy(self.state) 41 | step_times = [] 42 | # Sync. 43 | self._copy_shared_model_to_local() 44 | rollout_amt = 0 # keeps track of the amt in the current rollout storage 45 | while rollout_amt < self.constants['episode']['rollout_length']: 46 | start_step_time = time.time() 47 | prediction = self._get_prediction(state) 48 | action = self._get_action(prediction) 49 | next_state, reward, done = self.env.step(action, self.ep_step, get_global_reward=False) 50 | self.ep_step += 1 51 | if done: 52 | # Sync local model with shared model at start of each ep 53 | self._copy_shared_model_to_local() 54 | self.ep_step = 0 55 | 56 | # This is to stabilize learning, since the first some amt of states wont represent the env very well 57 | # since it will be more empty than normal 58 | if self.ep_step > self.constants['episode']['warmup_ep_steps']: 59 | storage.add(prediction) 60 | storage.add({'r': tensor(reward, self.device).unsqueeze(-1), 61 | 'm': tensor(self._stack(1 - done), self.device).unsqueeze(-1), 62 | 's': tensor(state, self.device)}) 63 | rollout_amt += 1 64 | state = np.copy(next_state) 65 | 66 | total_step += 1 67 | 68 | end_step_time = time.time() 69 | step_times.append(end_step_time - start_step_time) 70 | 71 | self.state = np.copy(state) 72 | 73 | prediction = self._get_prediction(state) 74 | storage.add(prediction) 75 | storage.placeholder() 76 | 77 | advantages = tensor(np.zeros((self.num_agents, 1)), self.device) 78 | returns = prediction['v'].detach() 79 | for i in reversed(range(self.constants['episode']['rollout_length'])): 80 | # Disc. Return 81 | returns = storage.r[i] + self.constants['ppo']['discount'] * storage.m[i] * returns 82 | # GAE 83 | td_error = storage.r[i] + self.constants['ppo']['discount'] * storage.m[i] * storage.v[i + 1] - storage.v[i] 84 | advantages = advantages * self.constants['ppo']['gae_tau'] * self.constants['ppo']['discount'] * storage.m[i] + td_error 85 | storage.adv[i] = advantages.detach() 86 | storage.ret[i] = returns.detach() 87 | 88 | 89 | states, actions, log_probs_old, returns, advantages = storage.cat(['s', 'a', 'log_pi_a', 'ret', 'adv']) 90 | actions = actions.detach() 91 | log_probs_old = log_probs_old.detach() 92 | advantages = (advantages - advantages.mean()) / advantages.std() 93 | 94 | # Train 95 | self.NN.train() 96 | batch_times = [] 97 | train_pred_times = [] 98 | for _ in range(self.constants['ppo']['optimization_epochs']): 99 | # Sync. 
at start of each epoch 100 | self._copy_shared_model_to_local() 101 | sampler = random_sample(np.arange(states.size(0)), self.constants['ppo']['minibatch_size']) 102 | for batch_indices in sampler: 103 | start_batch_time = time.time() 104 | 105 | batch_indices = tensor(batch_indices, self.device).long() 106 | 107 | # Important note: these are tensors but don't carry a grad 108 | sampled_states = states[batch_indices] 109 | sampled_actions = actions[batch_indices] 110 | sampled_log_probs_old = log_probs_old[batch_indices] 111 | sampled_returns = returns[batch_indices] 112 | sampled_advantages = advantages[batch_indices] 113 | 114 | start_pred_time = time.time() 115 | prediction = self._get_prediction(sampled_states, sampled_actions) 116 | end_pred_time = time.time() 117 | train_pred_times.append(end_pred_time - start_pred_time) 118 | 119 | # Calc. Loss 120 | ratio = (prediction['log_pi_a'] - sampled_log_probs_old).exp() 121 | 122 | obj = ratio * sampled_advantages 123 | obj_clipped = ratio.clamp(1.0 - self.constants['ppo']['ppo_ratio_clip'], 124 | 1.0 + self.constants['ppo']['ppo_ratio_clip']) * sampled_advantages 125 | 126 | # policy loss and value loss are scalars 127 | policy_loss = -torch.min(obj, obj_clipped).mean() - self.constants['ppo']['entropy_weight'] * prediction['ent'].mean() 128 | 129 | value_loss = self.constants['ppo']['value_loss_coef'] * (sampled_returns - prediction['v']).pow(2).mean() 130 | 131 | self.opt.zero_grad() 132 | (policy_loss + value_loss).backward() 133 | if self.constants['ppo']['clip_grads']: 134 | nn.utils.clip_grad_norm_(self.NN.parameters(), self.constants['ppo']['gradient_clip']) 135 | ensure_shared_grads(self.NN, self.shared_NN) 136 | self.opt.step() 137 | end_batch_time = time.time() 138 | batch_times.append(end_batch_time - start_batch_time) 139 | self.NN.eval() 140 | return total_step 141 | -------------------------------------------------------------------------------- /workers/rule_worker.py: -------------------------------------------------------------------------------- 1 | from workers.worker import Worker 2 | 3 | 4 | class RuleBasedWorker(Worker): 5 | def __init__(self, constants, device, env, rule_set, data_collector, id): 6 | super(RuleBasedWorker, self).__init__(constants, device, env, id, data_collector) 7 | self.rule_set = rule_set 8 | 9 | def _reset(self): 10 | self.rule_set.reset() 11 | 12 | def _get_prediction(self, states, actions=None, ep_step=None): 13 | assert actions is None # Rule sets don't take actions as input 14 | assert isinstance(states, dict) 15 | return {'a': self.rule_set(states)} 16 | 17 | def _get_action(self, prediction): 18 | return prediction['a'] 19 | 20 | def _copy_shared_model_to_local(self): 21 | pass 22 | -------------------------------------------------------------------------------- /workers/worker.py: -------------------------------------------------------------------------------- 1 | import xml.etree.ElementTree as ET 2 | import numpy as np 3 | from collections import defaultdict 4 | from copy import deepcopy 5 | 6 | 7 | # Parent - abstract 8 | class Worker: 9 | def __init__(self, constants, device, env, id, data_collector): 10 | self.constants = constants 11 | self.device = device 12 | self.env = env 13 | self.id = id 14 | self.data_collector = data_collector 15 | 16 | def _reset(self): 17 | raise NotImplementedError 18 | 19 | def _get_prediction(self, states, actions=None, ep_step=None): 20 | raise NotImplementedError 21 | 22 | def _get_action(self, prediction): 23 | raise NotImplementedError 24 | 25 | def 
_copy_shared_model_to_local(self): 26 | raise NotImplementedError 27 | 28 | # This method gets the ep results i really care about (gets averages) 29 | def _read_edge_results(self, file, results): 30 | tree = ET.parse(file) 31 | root = tree.getroot() 32 | for c in root.iter('edge'): 33 | for k, v in list(c.attrib.items()): 34 | if k == 'id': continue 35 | results[k].append(float(v)) 36 | return results 37 | 38 | # ep_count only for test 39 | def eval_episode(self, results): 40 | ep_rew = 0 41 | step = 0 42 | state = self.env.reset() 43 | self._reset() 44 | while True: 45 | prediction = self._get_prediction(state, ep_step=step) 46 | action = self._get_action(prediction) 47 | next_state, reward, done = self.env.step(action, step, get_global_reward=True) 48 | ep_rew += reward 49 | if done: 50 | break 51 | if not isinstance(state, dict): 52 | state = np.copy(next_state) 53 | else: 54 | state = deepcopy(next_state) 55 | step += 1 56 | results = self._read_edge_results('data/edgeData_{}.out.xml'.format(self.id), results) 57 | results['rew'].append(ep_rew) 58 | return results 59 | 60 | def eval_episodes(self, current_rollout, model_state=None, ep_count=None): 61 | self._copy_shared_model_to_local() 62 | results = defaultdict(list) 63 | for ep in range(self.constants['episode']['eval_num_eps']): 64 | results = self.eval_episode(results) 65 | # So I could add the data from each ep to the data collector but I like the smoothing effect this has 66 | if current_rollout: 67 | results['rollout'] = [current_rollout] 68 | results = {k: sum(v) / len(v) for k, v in list(results.items())} 69 | self.data_collector.collect_ep(results, model_state, ep_count+self.constants['episode']['eval_num_eps'] if ep_count is not None else None) 70 | --------------------------------------------------------------------------------
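A note on configuration: `constants/constants.json` and `constants/constants-grid.json` are referenced throughout but their contents are not included in this listing. The sketch below is only a minimal illustration assembled from the keys that the files shown above actually read (`constants['agent']`, `constants['multiagent']`, `constants['environment']`, `constants['episode']`, `constants['ppo']`, `constants['rule']`, and `constants['parallel']`); the values are illustrative placeholders, not the repository's shipped defaults, and the environment code (not shown here) may require additional entries.

```json
{
  "agent": {"agent_type": "ppo", "single_agent": false},
  "multiagent": {"state_interpolation": 1},
  "environment": {"shape": [2, 2], "rush_hour": false, "uniform_generation_probability": 0.05},
  "episode": {
    "num_train_rollouts": 100,
    "rollout_length": 128,
    "warmup_ep_steps": 0,
    "generation_ep_steps": 1000,
    "eval_freq": 10,
    "eval_num_eps": 1,
    "test_num_eps": 10
  },
  "ppo": {
    "hidden_layer_size": 64,
    "learning_rate": 0.001,
    "discount": 0.99,
    "gae_tau": 0.95,
    "optimization_epochs": 4,
    "minibatch_size": 64,
    "ppo_ratio_clip": 0.2,
    "entropy_weight": 0.01,
    "value_loss_coef": 0.5,
    "clip_grads": true,
    "gradient_clip": 0.5
  },
  "rule": {
    "rule_set": "cycle",
    "rule_set_params": {"cycle_length": 10, "NS_mult": 1.0, "EW_mult": 1.0, "phase_end_offset": 0}
  },
  "parallel": {"num_workers": 4}
}
```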