├── img ├── cnn.gif └── eval │ ├── graphs.png │ └── Mean Queue Length, 800 cars.png ├── model ├── dqn_cnn_final.pt └── dqn_fc_final.pt ├── config.yml ├── train.py ├── src ├── cfg │ ├── sumo_config.sumocfg │ └── environment.net.xml ├── test.py ├── memory.py ├── data_storage.py ├── dqn.py ├── training.py ├── generator.py └── env.py ├── LICENSE ├── .gitignore └── README.md /img/cnn.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/suessmann/intelligent_traffic_lights/HEAD/img/cnn.gif -------------------------------------------------------------------------------- /img/eval/graphs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/suessmann/intelligent_traffic_lights/HEAD/img/eval/graphs.png -------------------------------------------------------------------------------- /model/dqn_cnn_final.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/suessmann/intelligent_traffic_lights/HEAD/model/dqn_cnn_final.pt -------------------------------------------------------------------------------- /model/dqn_fc_final.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/suessmann/intelligent_traffic_lights/HEAD/model/dqn_fc_final.pt -------------------------------------------------------------------------------- /img/eval/Mean Queue Length, 800 cars.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/suessmann/intelligent_traffic_lights/HEAD/img/eval/Mean Queue Length, 800 cars.png -------------------------------------------------------------------------------- /config.yml: -------------------------------------------------------------------------------- 1 | # learning params 2 | learning_rate: 0.00001 3 | gamma: 0.95 4 | buffer_limit: 500000 5 | 
batch_size: 128 6 | sim_len: 4500 7 | mem_refill: 1000 8 | epochs: 1600 9 | n_cars: 1000 10 | weights_path: '' 11 | 12 | # SUMO params 13 | sumoBinary: "" 14 | sumoCmd: "" 15 | sumoTools: "" 16 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | from src.training import training 4 | 5 | parser = argparse.ArgumentParser(description='Start training') 6 | parser.add_argument("-c", "--config", type=str, help="Path to the config file") 7 | 8 | if __name__ == '__main__': 9 | args = vars(parser.parse_args()) 10 | 11 | with open(args['config'], 'r') as ymfile: 12 | cfg = yaml.full_load(ymfile) 13 | 14 | print(cfg) 15 | training(cfg) 16 | 17 | -------------------------------------------------------------------------------- /src/cfg/sumo_config.sumocfg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /src/test.py: -------------------------------------------------------------------------------- 1 | from dqn import DQNetwork 2 | from env import SumoIntersection 3 | import torch 4 | 5 | BATCH_SIZE = 32 6 | SIM_LEN = 4500 7 | WEIGHTS_PATH = '/Users/suess_mann/wd/tcqdrl/tca/saved_model/dqn_1246.pt' 8 | 9 | sumoBinary = "/usr/local/opt/sumo/share/sumo/bin/sumo-gui" 10 | sumoCmd = "/Users/suess_mann/wd/tcqdrl/tca/src/cfg/sumo_config.sumocfg" 11 | 12 | 13 | if __name__ == '__main__': 14 | q = DQNetwork() 15 | q.eval() 16 | try: 17 | q.load_state_dict(torch.load(WEIGHTS_PATH)) 18 | except FileNotFoundError: 19 | print('No model weights found') 20 | 21 | env = SumoIntersection(sumoBinary, sumoCmd, SIM_LEN, 1600) 22 | 23 | state, _, _, _ = env.step(0) 24 | done = False 25 | 26 | while not done: 27 | a = q.predict(state.as_tuple, 0.00) 28 | 
29 | s_prime, r, done, info = env.step(a) 30 | print(r) 31 | s = s_prime 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 suessmann 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /src/memory.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import numpy as np 4 | from src.data_storage import StoreState 5 | from collections import deque 6 | 7 | 8 | class DQNBuffer: 9 | def __init__(self, max_n, pos): 10 | self._buffer = deque(maxlen=max_n) 11 | self._buffer_positive = deque(maxlen=max_n) 12 | self.data = StoreState 13 | self.pos = pos 14 | 15 | def add(self, sample): 16 | self._buffer.append(sample) 17 | 18 | def add_positive(self, sample): 19 | self._buffer_positive.append(sample) 20 | 21 | def sample(self, n): 22 | mini_batch = random.sample(self._buffer, int(np.ceil(n * (1-self.pos)))) + random.sample(self._buffer_positive, int(np.floor(n * self.pos))) 23 | trans = StoreState(*zip(*mini_batch)) # unzipping 24 | 25 | state, state_prime = trans.concat() 26 | a = torch.tensor(trans.action).unsqueeze(1) 27 | r = torch.tensor(trans.reward).unsqueeze(1) 28 | done_mask = torch.tensor(trans.done_mask).unsqueeze(1) 29 | 30 | return state, a, r, state_prime, done_mask 31 | 32 | def refill(self): 33 | self._buffer.clear() 34 | 35 | @property 36 | def size(self): 37 | return len(self._buffer) 38 | -------------------------------------------------------------------------------- /src/data_storage.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import torch 3 | 4 | 5 | @dataclass 6 | class StoreState: 7 | """ 8 | Class for storing state variables 9 | 10 | *n_prefix* : current state 11 | p_* : prime state 12 | """ 13 | position: torch.Tensor = torch.zeros((1, 4, 4, 16)) 14 | speed: torch.Tensor = torch.zeros((1, 4, 4, 16)) 15 | tl: torch.Tensor = torch.zeros((1, 1, 4)) 16 | p_position: torch.Tensor = torch.zeros((1, 4, 4, 16)) 17 | p_speed: torch.Tensor = torch.zeros((1, 4, 4, 16)) 18 | p_tl: torch.Tensor = 
torch.zeros((1, 1, 4)) 19 | 20 | action: 'typing.Any' = 0 21 | reward: 'typing.Any' = 0 22 | done_mask: float = 1.0 23 | 24 | def concat(self): 25 | self.position = torch.cat(self.position) 26 | self.speed = torch.cat(self.speed) 27 | self.tl = torch.cat(self.tl) 28 | self.p_position = torch.cat(self.p_position) 29 | self.p_speed = torch.cat(self.p_speed) 30 | self.p_tl = torch.cat(self.p_tl) 31 | 32 | return (self.position, self.speed, self.tl), \ 33 | (self.p_position, self.p_speed, self.p_tl) 34 | 35 | @property 36 | def as_tuple(self): 37 | return (self.position, self.speed, self.tl) 38 | 39 | def swap(self): 40 | self.position = self.p_position 41 | self.speed = self.p_speed 42 | self.tl = self.p_tl 43 | 44 | self.p_position = torch.zeros((1, 4, 4, 16)) # depth, rows, cols 45 | self.p_speed = torch.zeros((1, 4, 4, 16)) 46 | self.p_tl = torch.zeros((1, 1, 4)) 47 | 48 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 2 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 3 | 4 | # User-specific stuff 5 | .idea/**/workspace.xml 6 | .idea/**/tasks.xml 7 | .idea/**/usage.statistics.xml 8 | .idea/**/dictionaries 9 | .idea/**/shelf 10 | .idea/ 11 | 12 | # Generated files 13 | .idea/**/contentModel.xml 14 | 15 | # Sensitive or high-churn files 16 | .idea/**/dataSources/ 17 | .idea/**/dataSources.ids 18 | .idea/**/dataSources.local.xml 19 | .idea/**/sqlDataSources.xml 20 | .idea/**/dynamic.xml 21 | .idea/**/uiDesigner.xml 22 | .idea/**/dbnavigator.xml 23 | 24 | # Gradle 25 | .idea/**/gradle.xml 26 | .idea/**/libraries 27 | 28 | # Gradle and Maven with auto-import 29 | # When using Gradle or Maven with auto-import, you should exclude module files, 30 | # since they will be recreated, and may cause 
churn. Uncomment if using 31 | # auto-import. 32 | # .idea/artifacts 33 | # .idea/compiler.xml 34 | # .idea/jarRepositories.xml 35 | # .idea/modules.xml 36 | # .idea/*.iml 37 | # .idea/modules 38 | # *.iml 39 | # *.ipr 40 | 41 | # CMake 42 | cmake-build-*/ 43 | 44 | # Mongo Explorer plugin 45 | .idea/**/mongoSettings.xml 46 | 47 | # File-based project format 48 | *.iws 49 | 50 | # IntelliJ 51 | out/ 52 | 53 | # mpeltonen/sbt-idea plugin 54 | .idea_modules/ 55 | 56 | # JIRA plugin 57 | atlassian-ide-plugin.xml 58 | 59 | # Cursive Clojure plugin 60 | .idea/replstate.xml 61 | 62 | # Crashlytics plugin (for Android Studio and IntelliJ) 63 | com_crashlytics_export_strings.xml 64 | crashlytics.properties 65 | crashlytics-build.properties 66 | fabric.properties 67 | 68 | # Editor-based Rest Client 69 | .idea/httpRequests 70 | 71 | # Android studio 3.1+ serialized cache file 72 | .idea/caches/build_file_checksums.ser 73 | 74 | .ipynb_checkpoints 75 | src/cfg/ 76 | 77 | .DS_Store 78 | __pycache__/ 79 | 80 | src/logs 81 | -------------------------------------------------------------------------------- /src/dqn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import random 5 | 6 | 7 | class DQNetwork(nn.Module): 8 | 9 | def __init__(self): 10 | super().__init__() 11 | 12 | self.features1 = nn.Sequential( 13 | nn.Conv2d(1, 16, 4, stride=2, padding=1), 14 | nn.ReLU(), 15 | nn.Conv2d(16, 32, 2, stride=1), 16 | nn.ReLU() 17 | ) 18 | self.features2 = nn.Sequential( 19 | nn.Conv2d(1, 16, 4, stride=2, padding=1), 20 | nn.ReLU(), 21 | nn.Conv2d(16, 32, 2, stride=1), 22 | nn.ReLU() 23 | ) 24 | # self.features3 = nn.Linear(4, 4) 25 | 26 | self.linear_relu1 = nn.Sequential( 27 | nn.Linear(32*7*7 + 32*7*7 + 4, 128), 28 | nn.ReLU(), 29 | # nn.Dropout(0.2), 30 | nn.Linear(128, 64), 31 | nn.ReLU(), 32 | # nn.Dropout(0.2), 33 | nn.Linear(64, 32), 34 | nn.ReLU(), 35 | # 
nn.Dropout(0.2) 36 | ) 37 | 38 | self.classifier = nn.Sequential( 39 | nn.Linear(32, 4), 40 | # nn.Softmax(dim=-1) 41 | # nn.ReLU() 42 | ) 43 | 44 | 45 | def forward(self, state): 46 | x1, x2, x3 = state 47 | 48 | x1 = self.features1(x1.view(x1.size(0), 1, -1, 16)) 49 | x2 = self.features2(x2.view(x1.size(0), 1, -1, 16)) 50 | # x3 = self.features3(x3) 51 | 52 | x1 = x1.view(x1.size(0), -1) 53 | x2 = x2.view(x2.size(0), -1) 54 | x3 = x3.view(x3.size(0), -1) 55 | 56 | x = torch.cat((x1, x2, x3), dim=1) 57 | 58 | x = self.linear_relu1(x) 59 | x = self.classifier(x) 60 | 61 | return x 62 | 63 | def predict(self, state, eps): 64 | prob = random.random() 65 | if prob < eps: 66 | return random.randint(0, 3) 67 | else: 68 | act = self.forward(state) 69 | return act.argmax().item() 70 | 71 | 72 | class FCQNetwork(nn.Module): 73 | 74 | def __init__(self): 75 | super().__init__() 76 | 77 | self.features1 = nn.Sequential( 78 | nn.Linear(256, 256), 79 | nn.ReLU(), 80 | nn.Dropout(0.2), 81 | nn.Linear(256, 128), 82 | nn.ReLU() 83 | ) 84 | self.features2 = nn.Sequential( 85 | nn.Linear(256, 256), 86 | nn.ReLU(), 87 | nn.Dropout(0.2), 88 | nn.Linear(256, 128), 89 | nn.ReLU() 90 | ) 91 | self.linear_q = nn.Linear(128 + 128 + 4, 4) 92 | 93 | def forward(self, state): 94 | x1, x2, x3 = state 95 | 96 | x1 = self.features1(x1.view(x1.size(0), 1, 1, -1)) 97 | x2 = self.features2(x2.view(x1.size(0), 1, 1, -1)) 98 | # x3 = self.features3(x3) 99 | 100 | x1 = x1.view(x1.size(0), -1) 101 | x2 = x2.view(x2.size(0), -1) 102 | x3 = x3.view(x3.size(0), -1) 103 | 104 | x = torch.cat((x1, x2, x3), dim=1) 105 | 106 | x = self.linear_q(x) 107 | 108 | return x 109 | 110 | def predict(self, state, eps): 111 | prob = random.random() 112 | if prob < eps: 113 | return random.randint(0, 3) 114 | else: 115 | act = self.forward(state) 116 | return act.argmax().item() 117 | -------------------------------------------------------------------------------- /src/training.py: 
-------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | from src.dqn import DQNetwork, FCQNetwork 4 | from src.env import SumoIntersection 5 | from src.memory import DQNBuffer 6 | from src.data_storage import StoreState 7 | 8 | import torch 9 | import torch.nn as nn 10 | import numpy as np 11 | from tqdm import tqdm 12 | 13 | 14 | def weights_init(m): 15 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 16 | torch.nn.init.xavier_uniform(m.weight.data) 17 | nn.init.constant(m.bias.data, 0) 18 | 19 | def soft_update(local_model, target_model, tau): 20 | for target_param, local_param in zip(target_model.parameters(), 21 | local_model.parameters()): 22 | target_param.data.copy_(tau*local_param.data + (1-tau)*target_param.data) 23 | 24 | def train_net(q, q_target, memory, optimizer, batch_size, gamma): 25 | state, a, r, state_prime, done_mask = memory.sample(batch_size) 26 | 27 | q_out = q(state) 28 | q_a = q_out.gather(1, a) 29 | max_q_prime = q_target(state_prime).max(1)[0].unsqueeze(1) 30 | target = r + gamma * max_q_prime * done_mask 31 | criterion = nn.MSELoss() 32 | loss = criterion(q_a, target) 33 | mse = loss.item() 34 | 35 | optimizer.zero_grad() 36 | loss.backward() 37 | for param in q.parameters(): 38 | param.grad.data.clamp_(-1, 1) 39 | optimizer.step() 40 | 41 | def training(config): 42 | learning_rate = config['learning_rate'] 43 | gamma = config['gamma'] 44 | buffer_limit = config['buffer_limit'] 45 | batch_size = config['batch_size'] 46 | sim_len = config['sim_len'] 47 | mem_refill = config['mem_refill'] 48 | epochs = config['epochs'] 49 | n_cars = config['n_cars'] 50 | weights_path = config['weights_path'] 51 | 52 | sumoBinary = config['sumoBinary'] 53 | sumoCmd = config['sumoCmd'] 54 | sys.path.append(os.path.join(config['sumoTools'])) 55 | 56 | q = DQNetwork() 57 | 58 | try: 59 | q.load_state_dict(torch.load(weights_path)) 60 | except FileNotFoundError: 61 | q.apply(weights_init) 62 | 
print('No model weights found, initializing xavier_uniform') 63 | 64 | q_target = DQNetwork() 65 | q_target.load_state_dict(q.state_dict()) 66 | 67 | memory = DQNBuffer(buffer_limit, 0.) 68 | 69 | env = SumoIntersection(sumoBinary, sumoCmd, sim_len, n_cars) 70 | optimizer = torch.optim.RMSprop(q.parameters(), lr=learning_rate) 71 | 72 | state = StoreState() 73 | min_wait = float('inf') 74 | total_steps=0 75 | pbar = tqdm(total=epochs) 76 | 77 | for epoch in np.arange(epochs): 78 | eps = max(0.01, 0.07 - 0.01*(epoch/200)) 79 | step = 0 80 | 81 | if epoch != 0: 82 | env.reset() 83 | pbar.set_description(f'EPOCH: {epoch}, mean reward: {info[6]}, mean waiting time: {info[2]}') 84 | 85 | state, _, _, _ = env.step(0) 86 | done = False 87 | 88 | done_mask = 1.0 89 | mse = 0 90 | while not done: 91 | with torch.no_grad(): 92 | a = q.predict(state.as_tuple, eps) 93 | state, r, done, info = env.step(a) # storing prime state 94 | 95 | if done: 96 | done_mask = 0.0 97 | 98 | if r != 0 or step > 60: 99 | memory.add((state.position, state.speed, 100 | state.tl, state.p_position, 101 | state.p_speed, state.p_tl, 102 | a, r, done_mask)) 103 | 104 | state.swap() # state = state_prime 105 | 106 | if memory.size > batch_size: 107 | train_net(q, q_target, memory, optimizer, batch_size, gamma) 108 | 109 | step += 1 110 | total_steps += 1 111 | 112 | # clear memory each mem_refill epoch 113 | if epoch%mem_refill == 0 and epoch != 0: 114 | memory.refill() 115 | pbar.update(1) 116 | 117 | # soft update weights at the end of epoch 118 | soft_update(q, q_target, 0.01) 119 | 120 | if info[4]/info[5] < min_wait and epoch > 50: 121 | torch.save(q.state_dict(), f'../model/dqn_{epoch}.pt') 122 | min_wait = info[4]/info[5] 123 | if epoch == 1599: 124 | torch.save(q.state_dict(), f'../model/dqn_{epoch}_final.pt') 125 | 126 | print('finished training') 127 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | Status: Archived (no further updates in repo) 2 | 3 | # Traffic Signal Control with Reinforcement Learning 4 | This repository contains my project on intelligent traffic light control. It relies on a Deep Q-Network algorithm 5 | with a few convolutional layers. 6 | 7 | Below you will find a detailed explanation of the files in this repository, along with some comments on the model, the final results and the work still to be done. 8 | 9 | A preview of the final result: 10 | ![sumo_dqn](./img/cnn.gif) 11 | *CNN-DQN, 1800 cars*
[YouTube video with comparison](https://www.youtube.com/watch?v=yIhDWvMrWFo) 12 | --- 13 | 14 | # Contents 15 | * `src/data_storage.py` – a dataclass for storing environment data in one place 16 | * `src/dqn.py` – the PyTorch DQN models 17 | * `src/env.py` – an API for communicating with the simulator 18 | * `src/generator.py` – traffic generation resides here 19 | * `src/memory.py` – the replay memory used by the model resides here 20 | * `src/test.py` – currently a fairly useless file 21 | * `src/training.py` – training happens here 22 | * `src/cfg/` – a directory with SUMO configs. 23 | 24 | If you wish to train your own model, edit `config.yml`, then run `python train.py -c config.yml`. 25 | 26 | # Model and Environment 27 | To simulate the intersection, Simulation of Urban MObility ([SUMO](https://www.eclipse.org/sumo/)) is used. To 28 | communicate with Python, the TraCI API, which ships as one of SUMO's tools, is used. 29 | 30 | To perform Reinforcement Learning, the Deep Q-learning algorithm was chosen. Before the agent acts, it 31 | first observes the situation on the road by constructing two matrices. The first matrix divides the 32 | road into 16 segments of 7 meters each and records whether there is a car in each segment. The second matrix records the 33 | speed of the car in that segment. Those two matrices are fed into a CNN, whose output is concatenated with the traffic 34 | signal state vector and passed through fully-connected layers that output one Q-value for each of the 4 possible actions. 35 | 36 | Since DQN training is notoriously unstable, two methods were used to stabilize its convergence. First, a Replay Memory 37 | was implemented, which stores past transitions and allows training on random mini-batches of them. The second is the 38 | concept of a target Q-network. It avoids chasing a "moving target" during training: two networks are initialized 39 | with the same weights, and while the online network is optimized, the target network is kept fixed and only 40 | soft-updated (Polyak averaging with τ = 0.01) at the end of every epoch.
41 | 42 | The final models are kept in the `model` directory. Two models, a convolutional and a fully-connected one, were 43 | trained, both for 1600 epochs of 4500 simulation steps each (1.25 simulated hours per epoch). Both were trained on CPU, which took around 24 44 | hours for each of them. 45 | 46 | # Results 47 | For testing purposes, three scenarios were generated with 800, 1300 and 1800 cars. They represent low, medium and 48 | peak-hour road load, respectively. The results for each are presented in the tables below. 49 | 50 | | Metric (mean, std, n_cars=800) | Queue cars | Cumulative Waiting Time sec | 51 | |:-:|:-:|:-:| 52 | | CNN-DQN | (1.99, 2.27) | (4.26e04, 2.76e04) | 53 | | FC-DQN | (47.25, 18.04) | (1.23e06, 1.42e06) | 54 | | 90sec. | (87.34, 58.42) | (1.09e07, 1.11e07) | 55 | 56 | | Metric (mean, std, n_cars=1300) | Queue cars | Cumulative Waiting Time sec | 57 | |:------------------------------:|:--------------:|:---------------------------:| 58 | | CNN-DQN | (5.7, 9.1) | (3.56e05, 1.89e05) | 59 | | FC-DQN | (100.66, 34.3) | (6.38e07, 7.02e07) | 60 | | 90sec. | (144.44, 91.7) | (1.87e07, 1.91e07) | 61 | 62 | | Metric (mean, std, n_cars=1800) | Queue cars | Cumulative Waiting Time sec | 63 | |:-:|:-:|:-:| 64 | | CNN-DQN | (26.41, 42.58) | (3.98e06, 2.29e06) | 65 | | FC-DQN | (174.89, 68.83) | (1.21e08,1.33e08) | 66 | | 90sec. | (235.51, 159.42) | (2.78e07, 3.01e07) | 67 | 68 | ![800_que](./img/eval/graphs.png) 69 | 70 | At peak load, a vehicle waits in the queue for 117 seconds on average (CNN). That is 20% better than a fixed 71 | 90-second cycle. As can be seen, FC-DQN cannot handle medium road load, whereas CNN-DQN operates under 72 | high load with no problem. 73 | 74 | # References 75 | Genders, Wade, and Saiedeh Razavi. "Using a deep reinforcement learning agent for traffic signal control." arXiv preprint arXiv:1611.01142 (2017).
76 | 77 | Wade Genders & Saiedeh Razavi (2019): Asynchronous n-step Q- learning adaptive traffic signal control, Journal of Intelligent Transportation Systems, DOI: 10.1080/15472450.2018.1491003 78 | -------------------------------------------------------------------------------- /src/generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.stats 3 | import math 4 | 5 | 6 | def generate(max_steps, n_cars): 7 | # np.random.seed(seed) # make tests reproducible 8 | n_cars = np.minimum(int(scipy.stats.invweibull.rvs(1, loc=n_cars, scale=12)), 1512) 9 | # the generation of cars is distributed according to a weibull distribution 10 | timings = np.random.weibull(2, n_cars) 11 | timings = np.sort(timings) 12 | 13 | # reshape the distribution to fit the interval 0:max_steps 14 | car_gen_steps = [] 15 | min_old = math.floor(timings[1]) 16 | max_old = math.ceil(timings[-1]) 17 | min_new = 0 18 | max_new = max_steps 19 | for value in timings: 20 | car_gen_steps = np.append(car_gen_steps, 21 | ((max_new - min_new) / (max_old - min_old)) * (value - max_old) + max_new) 22 | 23 | car_gen_steps = np.rint(car_gen_steps) # round every value to int -> effective steps when a car will be generated 24 | np.insert(car_gen_steps, 0, 0) 25 | # produce the file for cars generation, one car per line 26 | with open("./cfg/episode_routes.rou.xml", "r+") as routes: 27 | print(""" 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | """, file=routes) 41 | 42 | for car_counter, step in enumerate(car_gen_steps): 43 | straight_or_turn = np.random.uniform() 44 | if straight_or_turn < 0.75: # choose direction: straight or turn - 75% of times the car goes straight 45 | route_straight = np.random.randint(1, 5) # choose a random source & destination 46 | if route_straight == 1: 47 | print( 48 | ' ' % ( 49 | car_counter, step), file=routes) 50 | elif route_straight == 2: 51 | print( 52 | ' ' % ( 53 | car_counter, 
step), file=routes) 54 | elif route_straight == 3: 55 | print( 56 | ' ' % ( 57 | car_counter, step), file=routes) 58 | else: 59 | print( 60 | ' ' % ( 61 | car_counter, step), file=routes) 62 | else: # car that turn -25% of the time the car turns 63 | route_turn = np.random.randint(1, 9) # choose random source source & destination 64 | if route_turn == 1: 65 | print( 66 | ' ' % ( 67 | car_counter, step), file=routes) 68 | elif route_turn == 2: 69 | print( 70 | ' ' % ( 71 | car_counter, step), file=routes) 72 | elif route_turn == 3: 73 | print( 74 | ' ' % ( 75 | car_counter, step), file=routes) 76 | elif route_turn == 4: 77 | print( 78 | ' ' % ( 79 | car_counter, step), file=routes) 80 | elif route_turn == 5: 81 | print( 82 | ' ' % ( 83 | car_counter, step), file=routes) 84 | elif route_turn == 6: 85 | print( 86 | ' ' % ( 87 | car_counter, step), file=routes) 88 | elif route_turn == 7: 89 | print( 90 | ' ' % ( 91 | car_counter, step), file=routes) 92 | elif route_turn == 8: 93 | print( 94 | ' ' % ( 95 | car_counter, step), file=routes) 96 | 97 | print("", file=routes) 98 | return n_cars 99 | -------------------------------------------------------------------------------- /src/env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys, os 3 | from collections import deque 4 | from src.generator import generate 5 | from src.data_storage import StoreState 6 | 7 | if 'SUMO_HOME' in os.environ: 8 | tools = os.path.join(os.environ['SUMO_HOME'], 'tools') 9 | sys.path.append(tools) 10 | else: 11 | sys.exit("please declare environment variable 'SUMO_HOME'") 12 | 13 | import traci 14 | 15 | PHASE_NS_GREEN = 0 16 | PHASE_NS_YELLOW = 1 17 | PHASE_NSL_GREEN = 2 18 | PHASE_NSL_YELLOW = 3 19 | PHASE_EW_GREEN = 4 20 | PHASE_EW_YELLOW = 5 21 | PHASE_EWL_GREEN = 6 22 | PHASE_EWL_YELLOW = 7 23 | 24 | 25 | class SumoIntersection: 26 | def __init__(self, path_bin, path_cfg, max_steps, n_cars): 27 | self.n_cars_out = 
generate(max_steps, n_cars) 28 | self.sumoCmd = [path_bin, "-c", path_cfg, '--waiting-time-memory', "4500", '--no-step-log', 'true', '-W', 'true'] 29 | traci.start(self.sumoCmd) 30 | self.traci_api = traci.getConnection() 31 | 32 | self.path_cfg = path_cfg 33 | self.max_steps = max_steps 34 | self.time = 0 35 | self.old_phase = 0 36 | self.waiting_time = 0 37 | self.queue = 0 38 | 39 | self.done = False 40 | self.r_w = 0 41 | self.r_q = 0 42 | self.n_cars = n_cars 43 | 44 | self.r_epoch = 0 45 | self.w_epoch = 0 46 | self.q_epoch = 0 47 | 48 | 49 | self.waiting_time_float_av = deque(maxlen=max_steps) 50 | self.queue_float_av = deque(maxlen=max_steps) 51 | self.waiting_time_av = [] 52 | self.queue_av = [] 53 | self.r_av = [] 54 | 55 | self.data = StoreState() 56 | 57 | def _get_info(self): 58 | # self._get_time() 59 | self._get_length() 60 | 61 | self.waiting_time_float_av.append(self.waiting_time) 62 | self.queue_float_av.append(self.queue) 63 | # self.waiting_time_av.append(self.waiting_time) 64 | # self.queue_av.append(self.queue) 65 | 66 | self.w_epoch += self.waiting_time 67 | self.q_epoch += self.queue 68 | 69 | return [np.mean(self.waiting_time_float_av), \ 70 | np.mean(self.queue_float_av), \ 71 | self.w_epoch/self.time, \ 72 | self.q_epoch/self.time, 73 | self.w_epoch, 74 | self.q_epoch] 75 | 76 | def _tl_control(self, phase): 77 | if (self.time + 5) >= (self.max_steps - 10): 78 | self.done = True 79 | return 80 | 81 | #self.data.tl = torch.zeros((1, 1, 4)) 82 | 83 | if phase != self.old_phase: 84 | yellow_phase_code = self.old_phase * 2 + 1 85 | self.traci_api.trafficlight.setPhase("TL", yellow_phase_code) 86 | 87 | if self.traci_api.simulation.getMinExpectedNumber() == 0: 88 | self.done = True 89 | return 90 | 91 | self.traci_api.simulationStep(self.time + 5.) 
92 | self.time += 5 93 | 94 | self.data.p_tl[0, 0, phase] = 1 95 | 96 | if phase == 0: 97 | self.traci_api.trafficlight.setPhase("TL", PHASE_NS_GREEN) 98 | elif phase == 1: 99 | self.traci_api.trafficlight.setPhase("TL", PHASE_NSL_GREEN) 100 | elif phase == 2: 101 | self.traci_api.trafficlight.setPhase("TL", PHASE_EW_GREEN) 102 | elif phase == 3: 103 | self.traci_api.trafficlight.setPhase("TL", PHASE_EWL_GREEN) 104 | 105 | self.traci_api.simulationStep(self.time + 2.) 106 | if self.traci_api.simulation.getMinExpectedNumber() == 0: 107 | self.done = True 108 | return 109 | 110 | self.time += 2 111 | 112 | def _get_length(self): 113 | halt_N = self.traci_api.edge.getLastStepHaltingNumber("N2TL") 114 | halt_S = self.traci_api.edge.getLastStepHaltingNumber("S2TL") 115 | halt_E = self.traci_api.edge.getLastStepHaltingNumber("E2TL") 116 | halt_W = self.traci_api.edge.getLastStepHaltingNumber("W2TL") 117 | self.queue = halt_N + halt_S + halt_E + halt_W 118 | 119 | def _get_time(self, car_id): 120 | self.waiting_time += self.traci_api.vehicle.getAccumulatedWaitingTime(car_id) 121 | 122 | def _get_state(self): 123 | car_list = self.traci_api.vehicle.getIDList() 124 | lanes_list = np.array(['W2TL_0', 'W2TL_1', 'W2TL_2', 'W2TL_3', 125 | 'E2TL_0', 'E2TL_1', 'E2TL_2', 'E2TL_3', 126 | 'N2TL_0', 'N2TL_1', 'N2TL_2', 'N2TL_3', 127 | 'S2TL_0', 'S2TL_1', 'S2TL_2', 'S2TL_3']) 128 | 129 | for car_id in car_list: 130 | lane_id = self.traci_api.vehicle.getLaneID(car_id) 131 | 132 | 133 | if lane_id not in lanes_list: 134 | continue 135 | 136 | self._get_time(car_id) 137 | 138 | lane_pos = self.traci_api.vehicle.getLanePosition(car_id) 139 | car_speed = self.traci_api.vehicle.getSpeed(car_id) / 13.89 140 | lane_pos = 750 - lane_pos 141 | 142 | if lane_pos > 112: 143 | continue 144 | 145 | # distance in meters from the traffic light -> mapping into cells 146 | for i, dist in enumerate(np.arange(7, 113, 7)): 147 | if lane_pos < dist: 148 | lane_cell = i 149 | break 150 | 151 | pos = 
np.where(np.array(lanes_list) == lane_id)[0][0] 152 | 153 | r = int(lane_id[-1]) 154 | c = lane_cell 155 | 156 | for i, clown in enumerate(np.arange(4, 17, 4)): 157 | if pos < clown: 158 | d = i 159 | break 160 | 161 | self.data.p_position[0, d, r, c] = 1 # depth, row, col 162 | self.data.p_speed[0, d, r, c] = car_speed 163 | 164 | def _get_reward(self): 165 | if self.time < 4: 166 | return 0 167 | 168 | self.r_w = 0.85 * self.waiting_time_float_av[-2] - self.waiting_time_float_av[-1] 169 | 170 | # if self.r_w < 0: 171 | # self.r_w *= 1.5 172 | # 173 | # self.r_av.append(self.r_w) 174 | 175 | return self.r_w 176 | 177 | def reset(self): 178 | self.n_cars_out = generate(self.max_steps, self.n_cars) 179 | 180 | self.time = 0 181 | self.old_phase = 0 182 | self.done = False 183 | self.r_epoch = 0 184 | self.w_epoch = 0 185 | self.q_epoch = 0 186 | 187 | self.traci_api.load(['-c', self.path_cfg, '--waiting-time-memory', "4500", '--no-step-log', 'true', '-W', 'true']) 188 | self.traci_api.simulationStep() 189 | 190 | 191 | def step(self, a): 192 | self._tl_control(a) # self.traci_api.steps are here 193 | 194 | self.old_phase = a 195 | self.waiting_time = 0 196 | 197 | self._get_state() 198 | info = self._get_info() # [0] av time, [1]: av queue 199 | 200 | self.r = self._get_reward() 201 | 202 | self.r_av.append(self.r) 203 | self.r_epoch += self.r 204 | 205 | info.append(self.r_epoch/self.time) 206 | info.append(self.r_epoch) 207 | info.append(self.time) 208 | if self.done: 209 | info.append(self.n_cars_out) 210 | 211 | return self.data, self.r, self.done, info 212 | 213 | -------------------------------------------------------------------------------- /src/cfg/environment.net.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 
68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | --------------------------------------------------------------------------------