├── img
├── cnn.gif
└── eval
│ ├── graphs.png
│ └── Mean Queue Length, 800 cars.png
├── model
├── dqn_cnn_final.pt
└── dqn_fc_final.pt
├── config.yml
├── train.py
├── src
├── cfg
│ ├── sumo_config.sumocfg
│ └── environment.net.xml
├── test.py
├── memory.py
├── data_storage.py
├── dqn.py
├── training.py
├── generator.py
└── env.py
├── LICENSE
├── .gitignore
└── README.md
/img/cnn.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/suessmann/intelligent_traffic_lights/HEAD/img/cnn.gif
--------------------------------------------------------------------------------
/img/eval/graphs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/suessmann/intelligent_traffic_lights/HEAD/img/eval/graphs.png
--------------------------------------------------------------------------------
/model/dqn_cnn_final.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/suessmann/intelligent_traffic_lights/HEAD/model/dqn_cnn_final.pt
--------------------------------------------------------------------------------
/model/dqn_fc_final.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/suessmann/intelligent_traffic_lights/HEAD/model/dqn_fc_final.pt
--------------------------------------------------------------------------------
/img/eval/Mean Queue Length, 800 cars.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/suessmann/intelligent_traffic_lights/HEAD/img/eval/Mean Queue Length, 800 cars.png
--------------------------------------------------------------------------------
/config.yml:
--------------------------------------------------------------------------------
1 | # learning params
2 | learning_rate: 0.00001
3 | gamma: 0.95
4 | buffer_limit: 500000
5 | batch_size: 128
6 | sim_len: 4500
7 | mem_refill: 1000
8 | epochs: 1600
9 | n_cars: 1000
10 | weights_path: ''
11 |
12 | # SUMO params
13 | sumoBinary: ""
14 | sumoCmd: ""
15 | sumoTools: ""
16 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import yaml
import argparse
from src.training import training

parser = argparse.ArgumentParser(description='Start training')
# required=True: without a config the script used to fail later with a
# confusing TypeError from open(None).
parser.add_argument("-c", "--config", type=str, required=True,
                    help="Path to the YAML config file")


def main():
    """Parse the command line, load the YAML config and start training."""
    args = vars(parser.parse_args())

    with open(args['config'], 'r') as ymfile:
        # The config only holds plain scalars/strings, so safe_load suffices
        # and never constructs arbitrary Python objects.
        cfg = yaml.safe_load(ymfile)

    print(cfg)
    training(cfg)


if __name__ == '__main__':
    main()
16 |
17 |
--------------------------------------------------------------------------------
/src/cfg/sumo_config.sumocfg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
--------------------------------------------------------------------------------
/src/test.py:
--------------------------------------------------------------------------------
from dqn import DQNetwork
from env import SumoIntersection
import torch

BATCH_SIZE = 32
SIM_LEN = 4500
WEIGHTS_PATH = '/Users/suess_mann/wd/tcqdrl/tca/saved_model/dqn_1246.pt'

sumoBinary = "/usr/local/opt/sumo/share/sumo/bin/sumo-gui"
sumoCmd = "/Users/suess_mann/wd/tcqdrl/tca/src/cfg/sumo_config.sumocfg"


def main():
    """Run one greedy (eps = 0) episode with a trained DQNetwork, printing each reward."""
    q = DQNetwork()
    q.eval()
    try:
        # map_location lets weights saved on a GPU be evaluated on a CPU-only machine.
        q.load_state_dict(torch.load(WEIGHTS_PATH, map_location='cpu'))
    except FileNotFoundError:
        print('No model weights found')

    env = SumoIntersection(sumoBinary, sumoCmd, SIM_LEN, 1600)

    state, _, _, _ = env.step(0)
    # env.step writes each new observation into the p_* (prime) fields.
    # swap() promotes it to the current state and clears p_*; without it the
    # network only ever sees the all-zero initial state while p_* keeps
    # accumulating every past observation.
    state.swap()
    done = False

    with torch.no_grad():  # pure inference: no autograd graph needed
        while not done:
            a = q.predict(state.as_tuple, 0.00)
            state, r, done, _ = env.step(a)
            print(r)
            state.swap()


if __name__ == '__main__':
    main()
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 suessmann
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/memory.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | import numpy as np
4 | from src.data_storage import StoreState
5 | from collections import deque
6 |
7 |
class DQNBuffer:
    """Experience-replay buffer with an optional second ("positive") pool.

    Args:
        max_n: maximum number of transitions kept in each pool; once a pool
            is full the oldest entries are dropped (deque maxlen).
        pos: fraction of every sampled mini-batch drawn from the positive
            pool (0 disables it).
    """

    def __init__(self, max_n, pos):
        self._buffer = deque(maxlen=max_n)
        self._buffer_positive = deque(maxlen=max_n)
        self.data = StoreState  # not used by this class; kept for compatibility
        self.pos = pos

    def add(self, sample):
        """Append one transition tuple to the main pool."""
        self._buffer.append(sample)

    def add_positive(self, sample):
        """Append one transition tuple to the positive pool."""
        self._buffer_positive.append(sample)

    @staticmethod
    def _split_sizes(n, pos):
        """Return (n_main, n_positive); the two counts always add up to n."""
        n_positive = int(np.floor(n * pos))
        # Deriving n_main by subtraction avoids ceil/floor float-rounding
        # combinations that could yield n + 1 samples.
        return n - n_positive, n_positive

    def sample(self, n):
        """Draw a random mini-batch of n transitions and stack it into tensors.

        Returns:
            (state, a, r, state_prime, done_mask): state and state_prime are
            the tuples returned by StoreState.concat(); a, r and done_mask
            are column tensors with one row per sampled transition.

        Raises:
            ValueError: (from random.sample) if a pool holds fewer
                transitions than requested from it.
        """
        n_main, n_positive = self._split_sizes(n, self.pos)
        mini_batch = (random.sample(self._buffer, n_main)
                      + random.sample(self._buffer_positive, n_positive))
        # Transpose the list of transition tuples into one tuple per field.
        trans = StoreState(*zip(*mini_batch))

        state, state_prime = trans.concat()
        a = torch.tensor(trans.action).unsqueeze(1)
        r = torch.tensor(trans.reward).unsqueeze(1)
        done_mask = torch.tensor(trans.done_mask).unsqueeze(1)

        return state, a, r, state_prime, done_mask

    def refill(self):
        """Empty the main pool (the positive pool is left untouched)."""
        self._buffer.clear()

    @property
    def size(self):
        """Number of transitions currently held in the main pool."""
        return len(self._buffer)
38 |
--------------------------------------------------------------------------------
/src/data_storage.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | import torch
3 |
4 |
from dataclasses import field

# Observation tensor shapes: (batch, depth, rows, cols) for the position /
# speed grids and (batch, 1, phase) for the traffic-light one-hot vector.
_GRID_SHAPE = (1, 4, 4, 16)
_TL_SHAPE = (1, 1, 4)


def _zeros_grid():
    """Return a new all-zero position/speed grid."""
    return torch.zeros(_GRID_SHAPE)


def _zeros_tl():
    """Return a new all-zero traffic-light vector."""
    return torch.zeros(_TL_SHAPE)


@dataclass
class StoreState:
    """
    Class for storing state variables

    *n_prefix* : current state
    p_* : prime state

    Every instance gets its own zero tensors via ``default_factory``. A plain
    tensor default would be a single object shared by all instances, so the
    in-place writes made into ``p_*`` would leak into every other instance.
    When built from a batch (``StoreState(*zip(*transitions))``) each field
    holds a tuple instead, which ``concat`` stacks along dim 0.
    """
    position: torch.Tensor = field(default_factory=_zeros_grid)
    speed: torch.Tensor = field(default_factory=_zeros_grid)
    tl: torch.Tensor = field(default_factory=_zeros_tl)
    p_position: torch.Tensor = field(default_factory=_zeros_grid)
    p_speed: torch.Tensor = field(default_factory=_zeros_grid)
    p_tl: torch.Tensor = field(default_factory=_zeros_tl)

    action: 'typing.Any' = 0
    reward: 'typing.Any' = 0
    done_mask: float = 1.0

    def concat(self):
        """Concatenate the batched field tuples in place.

        Returns:
            ((position, speed, tl), (p_position, p_speed, p_tl))
        """
        self.position = torch.cat(self.position)
        self.speed = torch.cat(self.speed)
        self.tl = torch.cat(self.tl)
        self.p_position = torch.cat(self.p_position)
        self.p_speed = torch.cat(self.p_speed)
        self.p_tl = torch.cat(self.p_tl)

        return (self.position, self.speed, self.tl), \
            (self.p_position, self.p_speed, self.p_tl)

    @property
    def as_tuple(self):
        """Current state as the (position, speed, tl) tuple the networks consume."""
        return (self.position, self.speed, self.tl)

    def swap(self):
        """Promote the prime state to the current state and reset p_* to zeros."""
        self.position = self.p_position
        self.speed = self.p_speed
        self.tl = self.p_tl

        self.p_position = _zeros_grid()
        self.p_speed = _zeros_grid()
        self.p_tl = _zeros_tl()
47 |
48 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
2 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
3 |
4 | # User-specific stuff
5 | .idea/**/workspace.xml
6 | .idea/**/tasks.xml
7 | .idea/**/usage.statistics.xml
8 | .idea/**/dictionaries
9 | .idea/**/shelf
10 | .idea/
11 |
12 | # Generated files
13 | .idea/**/contentModel.xml
14 |
15 | # Sensitive or high-churn files
16 | .idea/**/dataSources/
17 | .idea/**/dataSources.ids
18 | .idea/**/dataSources.local.xml
19 | .idea/**/sqlDataSources.xml
20 | .idea/**/dynamic.xml
21 | .idea/**/uiDesigner.xml
22 | .idea/**/dbnavigator.xml
23 |
24 | # Gradle
25 | .idea/**/gradle.xml
26 | .idea/**/libraries
27 |
28 | # Gradle and Maven with auto-import
29 | # When using Gradle or Maven with auto-import, you should exclude module files,
30 | # since they will be recreated, and may cause churn. Uncomment if using
31 | # auto-import.
32 | # .idea/artifacts
33 | # .idea/compiler.xml
34 | # .idea/jarRepositories.xml
35 | # .idea/modules.xml
36 | # .idea/*.iml
37 | # .idea/modules
38 | # *.iml
39 | # *.ipr
40 |
41 | # CMake
42 | cmake-build-*/
43 |
44 | # Mongo Explorer plugin
45 | .idea/**/mongoSettings.xml
46 |
47 | # File-based project format
48 | *.iws
49 |
50 | # IntelliJ
51 | out/
52 |
53 | # mpeltonen/sbt-idea plugin
54 | .idea_modules/
55 |
56 | # JIRA plugin
57 | atlassian-ide-plugin.xml
58 |
59 | # Cursive Clojure plugin
60 | .idea/replstate.xml
61 |
62 | # Crashlytics plugin (for Android Studio and IntelliJ)
63 | com_crashlytics_export_strings.xml
64 | crashlytics.properties
65 | crashlytics-build.properties
66 | fabric.properties
67 |
68 | # Editor-based Rest Client
69 | .idea/httpRequests
70 |
71 | # Android studio 3.1+ serialized cache file
72 | .idea/caches/build_file_checksums.ser
73 |
74 | .ipynb_checkpoints
75 | src/cfg/
76 |
77 | .DS_Store
78 | __pycache__/
79 |
80 | src/logs
81 |
--------------------------------------------------------------------------------
/src/dqn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import random
5 |
6 |
class DQNetwork(nn.Module):
    """Convolutional Q-network mapping an intersection state to 4 Q-values.

    The state is a tuple ``(position, speed, tl)``. ``position`` and ``speed``
    are each viewed as a single-channel image with 16 columns and passed
    through their own conv stack; the flattened conv features are
    concatenated with the flattened ``tl`` vector and fed to an MLP head.
    The first Linear layer's ``32*7*7`` input size assumes 16x16 images
    (e.g. a (batch, 4, 4, 16) grid) and 4 ``tl`` values per sample.
    """

    def __init__(self):
        super().__init__()

        # 1x16x16 -> 16x8x8 -> 32x7x7
        self.features1 = nn.Sequential(
            nn.Conv2d(1, 16, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 2, stride=1),
            nn.ReLU()
        )
        self.features2 = nn.Sequential(
            nn.Conv2d(1, 16, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 2, stride=1),
            nn.ReLU()
        )

        self.linear_relu1 = nn.Sequential(
            nn.Linear(32*7*7 + 32*7*7 + 4, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
        )

        self.classifier = nn.Sequential(
            nn.Linear(32, 4),
        )

    def forward(self, state):
        """Return Q-values of shape (batch, 4) for a ``(position, speed, tl)`` tuple."""
        position, speed, tl = state
        batch = position.size(0)

        pos_feat = self.features1(position.view(batch, 1, -1, 16))
        speed_feat = self.features2(speed.view(batch, 1, -1, 16))

        x = torch.cat((pos_feat.view(batch, -1),
                       speed_feat.view(batch, -1),
                       tl.view(batch, -1)), dim=1)

        x = self.linear_relu1(x)
        return self.classifier(x)

    def predict(self, state, eps):
        """Epsilon-greedy action for a single state.

        With probability ``eps`` returns a uniformly random action in 0..3,
        otherwise the argmax of the Q-values. The forward pass runs without
        autograd, so callers need not wrap it in ``torch.no_grad()``.
        """
        if random.random() < eps:
            return random.randint(0, 3)
        with torch.no_grad():
            q_values = self.forward(state)
        return q_values.argmax().item()
70 |
71 |
class FCQNetwork(nn.Module):
    """Fully-connected Q-network with the same interface as DQNetwork.

    ``position`` and ``speed`` are flattened (256 values per sample, e.g. a
    4x4x16 grid) and each passed through its own MLP; the resulting features
    are concatenated with the flattened ``tl`` vector and mapped to 4 Q-values.
    """

    def __init__(self):
        super().__init__()

        self.features1 = nn.Sequential(
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU()
        )
        self.features2 = nn.Sequential(
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU()
        )
        self.linear_q = nn.Linear(128 + 128 + 4, 4)

    def forward(self, state):
        """Return Q-values of shape (batch, 4) for a ``(position, speed, tl)`` tuple."""
        position, speed, tl = state
        batch = position.size(0)

        # Linear layers act on the last dimension, so a flat (batch, 256)
        # view gives the same result as the former (batch, 1, 1, 256) one.
        pos_feat = self.features1(position.view(batch, -1))
        speed_feat = self.features2(speed.view(batch, -1))

        x = torch.cat((pos_feat, speed_feat, tl.view(batch, -1)), dim=1)
        return self.linear_q(x)

    def predict(self, state, eps):
        """Epsilon-greedy action for a single state.

        With probability ``eps`` returns a uniformly random action in 0..3,
        otherwise the argmax of the Q-values (computed without autograd).
        Dropout is active unless the module is in eval mode.
        """
        if random.random() < eps:
            return random.randint(0, 3)
        with torch.no_grad():
            q_values = self.forward(state)
        return q_values.argmax().item()
117 |
--------------------------------------------------------------------------------
/src/training.py:
--------------------------------------------------------------------------------
1 | import sys, os
2 |
3 | from src.dqn import DQNetwork, FCQNetwork
4 | from src.env import SumoIntersection
5 | from src.memory import DQNBuffer
6 | from src.data_storage import StoreState
7 |
8 | import torch
9 | import torch.nn as nn
10 | import numpy as np
11 | from tqdm import tqdm
12 |
13 |
def weights_init(m):
    """Xavier-uniform initialise Conv2d/Linear weights and zero their biases.

    Meant for ``module.apply(weights_init)``; other layer types are left as is.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        # Trailing-underscore (in-place) initialisers; the old non-underscore
        # names are deprecated aliases in PyTorch.
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:  # layers built with bias=False have no bias
            nn.init.constant_(m.bias, 0)
18 |
def soft_update(local_model, target_model, tau):
    """Move the target network's weights towards the local network's, in place.

    Every target parameter becomes ``tau * local + (1 - tau) * target``:
    ``tau = 1`` copies the local weights, ``tau = 0`` leaves the target as is.
    """
    param_pairs = zip(target_model.parameters(), local_model.parameters())
    for tgt, src in param_pairs:
        blended = tau * src.data + (1 - tau) * tgt.data
        tgt.data.copy_(blended)
23 |
def train_net(q, q_target, memory, optimizer, batch_size, gamma):
    """Run one DQN optimisation step on a mini-batch sampled from ``memory``.

    The loss is the MSE between Q(s, a) and the one-step target
    ``r + gamma * max_a' Q_target(s', a') * done_mask``. Gradients of ``q``
    are clipped element-wise to [-1, 1] before ``optimizer.step()``.

    Returns:
        The loss value as a Python float.
    """
    state, a, r, state_prime, done_mask = memory.sample(batch_size)

    q_out = q(state)
    q_a = q_out.gather(1, a)  # Q-value of the action actually taken
    # q_target is not optimised: evaluating it without autograd avoids
    # building its graph and accumulating .grad on its parameters during
    # loss.backward(). The gradients w.r.t. q are unchanged.
    with torch.no_grad():
        max_q_prime = q_target(state_prime).max(1)[0].unsqueeze(1)
    target = r + gamma * max_q_prime * done_mask
    criterion = nn.MSELoss()
    loss = criterion(q_a, target)

    optimizer.zero_grad()
    loss.backward()
    for param in q.parameters():
        if param.grad is not None:  # parameters unused in forward have no grad
            param.grad.data.clamp_(-1, 1)
    optimizer.step()

    return loss.item()
40 |
def training(config):
    """Train a convolutional DQN on the SUMO intersection.

    Args:
        config: dict loaded from config.yml. Keys read: learning_rate, gamma,
            buffer_limit, batch_size, sim_len, mem_refill, epochs, n_cars,
            weights_path, sumoBinary, sumoCmd, sumoTools, and the optional
            model_dir (default '../model', resolved against the working
            directory).

    Checkpoints written to model_dir: ``dqn_<epoch>.pt`` whenever the epoch's
    cumulative waiting time / cumulative queue ratio improves (after epoch
    50), and ``dqn_<epoch>_final.pt`` after the last epoch.
    """
    learning_rate = config['learning_rate']
    gamma = config['gamma']
    buffer_limit = config['buffer_limit']
    batch_size = config['batch_size']
    sim_len = config['sim_len']
    mem_refill = config['mem_refill']
    epochs = config['epochs']
    n_cars = config['n_cars']
    weights_path = config['weights_path']
    model_dir = config.get('model_dir', '../model')

    sumoBinary = config['sumoBinary']
    sumoCmd = config['sumoCmd']
    # NOTE(review): src.env already imports traci when this module is
    # imported, so this append probably comes too late to matter - confirm.
    sys.path.append(os.path.join(config['sumoTools']))

    q = DQNetwork()

    try:
        q.load_state_dict(torch.load(weights_path))
    except FileNotFoundError:
        q.apply(weights_init)
        print('No model weights found, initializing xavier_uniform')

    # The target network starts as an exact copy of the online network.
    q_target = DQNetwork()
    q_target.load_state_dict(q.state_dict())

    memory = DQNBuffer(buffer_limit, 0.)

    env = SumoIntersection(sumoBinary, sumoCmd, sim_len, n_cars)
    optimizer = torch.optim.RMSprop(q.parameters(), lr=learning_rate)

    state = StoreState()
    min_wait = float('inf')
    pbar = tqdm(total=epochs)

    for epoch in np.arange(epochs):
        # epsilon decays linearly from 0.07 down to a floor of 0.01 (epoch 1200)
        eps = max(0.01, 0.07 - 0.01*(epoch/200))
        step = 0

        if epoch != 0:
            env.reset()
            # info still holds the last step's info list of the previous epoch
            pbar.set_description(f'EPOCH: {epoch}, mean reward: {info[6]}, mean waiting time: {info[2]}')

        state, _, _, _ = env.step(0)
        # env.step wrote the first observation into the p_* fields. Promote it
        # to the current state and clear p_*: otherwise the first action is
        # chosen from the previous epoch's last state and the next
        # observation is written on top of this one.
        state.swap()
        done = False

        done_mask = 1.0
        while not done:
            with torch.no_grad():
                a = q.predict(state.as_tuple, eps)
            state, r, done, info = env.step(a)  # fills the prime (p_*) fields

            if done:
                done_mask = 0.0

            # during the first 60 steps only transitions with a non-zero reward are stored
            if r != 0 or step > 60:
                memory.add((state.position, state.speed,
                            state.tl, state.p_position,
                            state.p_speed, state.p_tl,
                            a, r, done_mask))

            state.swap()  # state = state_prime

            if memory.size > batch_size:
                train_net(q, q_target, memory, optimizer, batch_size, gamma)

            step += 1

        # clear memory every mem_refill epochs
        if epoch % mem_refill == 0 and epoch != 0:
            memory.refill()
        pbar.update(1)

        # soft update weights at the end of epoch
        soft_update(q, q_target, 0.01)

        # info[4] / info[5]: the epoch's cumulative waiting time divided by its
        # cumulative queue length (skipped if no car ever queued)
        if epoch > 50 and info[5] != 0 and info[4]/info[5] < min_wait:
            torch.save(q.state_dict(), os.path.join(model_dir, f'dqn_{epoch}.pt'))
            min_wait = info[4]/info[5]
        if epoch == epochs - 1:
            torch.save(q.state_dict(), os.path.join(model_dir, f'dqn_{epoch}_final.pt'))

    pbar.close()
    print('finished training')
127 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Status: Archived (no further updates in repo)
2 |
3 | # Traffic Signal Control with Reinforcement Learning
4 | This repository contains my project on intelligent traffic light control. It relies on a Deep Q-network algorithm
5 | with a few convolutional layers.
6 |
7 | Below you will find a detailed explanation of the files in this repository, as well as some comments on the model, the final results and the work still to be done.
8 |
9 | A preview of the final result:
10 | 
11 | *CNN-DQN, 1800 cars*
[Youtube video with comparison](https://www.youtube.com/watch?v=yIhDWvMrWFo)
12 | ---
13 |
14 | # Contents
15 | * `src/data_storage.py` – a dataclass for storing environment data in one place
16 | * `src/dqn.py` – the PyTorch DQN models (convolutional `DQNetwork` and fully-connected `FCQNetwork`)
17 | * `src/env.py` – an API to communicate with simulator
18 | * `src/generator.py` – traffic generation resides here
19 | * `src/memory.py` – the model has replay memory implemented, it resides here
20 | * `src/test.py` – a minimal evaluation script (with hard-coded paths) that runs a trained model greedily and prints the rewards
21 | * `src/training.py` – training happens here
22 | * `src/cfg/` – a directory with SUMO configs.
23 |
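24 | If you wish to train your own model, set the SUMO paths (`sumoBinary`, `sumoCmd`, `sumoTools`) in `config.yml`, make sure the `SUMO_HOME` environment variable is set, then run `cd src && python ../train.py -c ../config.yml`. The command is launched from inside `src/` because `src/generator.py` writes `./cfg/episode_routes.rou.xml` and checkpoints are saved to `../model/`, both relative to the working directory.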
25 |
26 | # Model and Environment
27 | To simulate the intersection, Simulation of Urban MObility ([SUMO](https://www.eclipse.org/sumo/)) is used. To
28 | communicate with Python, TraCI API is provided as one of the tools of SUMO.
29 |
30 | To perform Reinforcement Learning, the Deep Q-learning algorithm was chosen. Before the agent acts, it first observes
31 | the situation on the road by constructing two matrices. The first matrix divides the last 112 meters of each incoming
32 | lane into 16 segments of 7 meters each and checks whether there is a car in a segment. The second matrix stores the
33 | normalized speed of the car in that segment. These two matrices are fed into the convolutional network, then concatenated
34 | with a vector encoding the traffic signal state, and the result is mapped to Q-values for the four possible actions.
35 |
36 | Since DQN training is notoriously unstable, two methods were used to stabilize its convergence. Firstly, Replay Memory
37 | was implemented, which stores past transitions and allows training on random mini-batches drawn from them. The second
38 | is the concept of a target Q-network, which avoids chasing a "moving target" during training: two networks are
39 | initialized with the same weights, and while the online network is optimized, the target network is not trained
40 | directly but is softly updated towards it (τ = 0.01) at the end of every epoch.
41 |
42 | The final models are kept under the `model` directory. Two models, a convolutional and a fully-connected one, were
43 | trained, both for 1600 epochs of 4500 simulation steps each (1.25 hours of simulated traffic). Those models were trained on CPU and it took around 24
44 | hours for each of them.
45 |
46 | # Results
47 | For testing purposes, three scenarios were generated with 800, 1300 and 1800 cars. They represent three possible states
48 | of the road: low load, medium load and peak hour, respectively. The results for each are presented in the tables below.
49 |
50 | | Metric (mean, std, n_cars=800) | Queue cars | Cumulative Waiting Time sec |
51 | |:-:|:-:|:-:|
52 | | CNN-DQN | (1.99, 2.27) | (4.26e04, 2.76e04) |
53 | | FC-DQN | (47.25, 18.04) | (1.23e06, 1.42e06) |
54 | | 90sec. | (87.34, 58.42) | (1.09e07, 1.11e07) |
55 |
56 | | Metric (mean, std, n_cars=1300) | Queue cars | Cumulative Waiting Time sec |
57 | |:------------------------------:|:--------------:|:---------------------------:|
58 | | CNN-DQN | (5.7, 9.1) | (3.56e05, 1.89e05) |
59 | | FC-DQN | (100.66, 34.3) | (6.38e07, 7.02e07) |
60 | | 90sec. | (144.44, 91.7) | (1.87e07, 1.91e07) |
61 |
62 | | Metric (mean, std, n_cars=1800) | Queue cars | Cumulative Waiting Time sec |
63 | |:-:|:-:|:-:|
64 | | CNN-DQN | (26.41, 42.58) | (3.98e06, 2.29e06) |
65 | | FC-DQN | (174.89, 68.83) | (1.21e08,1.33e08) |
66 | | 90sec. | (235.51, 159.42) | (2.78e07, 3.01e07) |
67 |
68 | 
69 |
70 | When the load is at peak, a vehicle waits in the queue for 117 seconds on average (CNN). That is 20% better than a steady
71 | 90 sec cycle. As can be seen, FC-DQN cannot cope with medium road load, whereas CNN-DQN manages to operate under
72 | high load with no problem.
73 |
74 | # References
75 | Genders, Wade, and Saiedeh Razavi. "Using a deep reinforcement learning agent for traffic signal control." arXiv preprint arXiv:1611.01142 (2017).
76 |
77 | Wade Genders & Saiedeh Razavi (2019): Asynchronous n-step Q- learning adaptive traffic signal control, Journal of Intelligent Transportation Systems, DOI: 10.1080/15472450.2018.1491003
78 |
--------------------------------------------------------------------------------
/src/generator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.stats
3 | import math
4 |
5 |
def generate(max_steps, n_cars, max_cars=1512):
    """Write a random SUMO route file for one episode; return the car count.

    The number of cars is drawn from scipy's inverse Weibull distribution
    located at ``n_cars`` (scale 12) and capped at ``max_cars``. Departure
    steps follow a sorted Weibull(2) sample linearly rescaled onto
    ``[0, max_steps]`` and rounded. Each car goes straight with probability
    0.75 (one of 4 routes) and turns otherwise (one of 8 routes).

    Args:
        max_steps: episode length in simulation steps.
        n_cars: location parameter of the car-count distribution.
        max_cars: upper bound on the number of generated cars.

    Returns:
        The number of cars written to ``./cfg/episode_routes.rou.xml``.
    """
    # np.random.seed(seed) # make tests reproducible
    n_cars = np.minimum(int(scipy.stats.invweibull.rvs(1, loc=n_cars, scale=12)), max_cars)
    # the generation of cars is distributed according to a weibull distribution
    timings = np.sort(np.random.weibull(2, n_cars))

    # Linearly rescale the timings onto the interval 0:max_steps. Vectorised:
    # growing the array with np.append per car was quadratic in n_cars.
    min_old = math.floor(timings[0])  # smallest timing (timings is sorted)
    max_old = math.ceil(timings[-1])
    min_new = 0
    max_new = max_steps
    car_gen_steps = ((max_new - min_new) / (max_old - min_old)) * (timings - max_old) + max_new

    car_gen_steps = np.rint(car_gen_steps)  # round every value to int -> effective steps when a car will be generated

    # produce the file for cars generation, one car per line.
    # "w" truncates the file; "r+" overwrote it in place and left the tail of
    # a longer previous episode behind, corrupting the XML.
    # NOTE(review): in this copy the header and every vehicle template below
    # read as whitespace only; ' ' has no %-placeholders, so the % formatting
    # raises TypeError. The original XML (<routes>, <vehicle .../>) text
    # appears to have been stripped and must be restored.
    with open("./cfg/episode_routes.rou.xml", "w") as routes:
        print("""












""", file=routes)

        for car_counter, step in enumerate(car_gen_steps):
            straight_or_turn = np.random.uniform()
            if straight_or_turn < 0.75:  # choose direction: straight or turn - 75% of times the car goes straight
                route_straight = np.random.randint(1, 5)  # choose a random source & destination
                if route_straight == 1:
                    print(
                        ' ' % (
                            car_counter, step), file=routes)
                elif route_straight == 2:
                    print(
                        ' ' % (
                            car_counter, step), file=routes)
                elif route_straight == 3:
                    print(
                        ' ' % (
                            car_counter, step), file=routes)
                else:
                    print(
                        ' ' % (
                            car_counter, step), file=routes)
            else:  # car that turn -25% of the time the car turns
                route_turn = np.random.randint(1, 9)  # choose random source source & destination
                if route_turn == 1:
                    print(
                        ' ' % (
                            car_counter, step), file=routes)
                elif route_turn == 2:
                    print(
                        ' ' % (
                            car_counter, step), file=routes)
                elif route_turn == 3:
                    print(
                        ' ' % (
                            car_counter, step), file=routes)
                elif route_turn == 4:
                    print(
                        ' ' % (
                            car_counter, step), file=routes)
                elif route_turn == 5:
                    print(
                        ' ' % (
                            car_counter, step), file=routes)
                elif route_turn == 6:
                    print(
                        ' ' % (
                            car_counter, step), file=routes)
                elif route_turn == 7:
                    print(
                        ' ' % (
                            car_counter, step), file=routes)
                elif route_turn == 8:
                    print(
                        ' ' % (
                            car_counter, step), file=routes)

        print("", file=routes)
    return n_cars
99 |
--------------------------------------------------------------------------------
/src/env.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import sys, os
3 | from collections import deque
4 | from src.generator import generate
5 | from src.data_storage import StoreState
6 |
# SUMO ships its Python bindings (traci) under $SUMO_HOME/tools; that
# directory must be on sys.path before traci can be imported.
if 'SUMO_HOME' not in os.environ:
    sys.exit("please declare environment variable 'SUMO_HOME'")
tools = os.path.join(os.environ['SUMO_HOME'], 'tools')
sys.path.append(tools)

import traci

# Phase indices of the "TL" traffic-light program: each green phase has an
# even index and its yellow phase the following odd index.
PHASE_NS_GREEN = 0
PHASE_NS_YELLOW = 1
PHASE_NSL_GREEN = 2
PHASE_NSL_YELLOW = 3
PHASE_EW_GREEN = 4
PHASE_EW_YELLOW = 5
PHASE_EWL_GREEN = 6
PHASE_EWL_YELLOW = 7
23 |
24 |
class SumoIntersection:
    """Single-intersection SUMO environment controlled through TraCI.

    An action is an index 0..3 selecting the green phase PHASE_NS_GREEN,
    PHASE_NSL_GREEN, PHASE_EW_GREEN or PHASE_EWL_GREEN of traffic light "TL".
    ``step`` advances the simulation, writes the new observation into the
    ``p_*`` (prime) fields of ``self.data`` and returns
    ``(data, reward, done, info)``.
    """

    # Observed incoming lanes, grouped by approach (W, E, N, S), 4 lanes each;
    # the group index is the depth axis of the position/speed grids.
    _LANES = ('W2TL_0', 'W2TL_1', 'W2TL_2', 'W2TL_3',
              'E2TL_0', 'E2TL_1', 'E2TL_2', 'E2TL_3',
              'N2TL_0', 'N2TL_1', 'N2TL_2', 'N2TL_3',
              'S2TL_0', 'S2TL_1', 'S2TL_2', 'S2TL_3')
    _LANE_INDEX = dict(zip(_LANES, range(len(_LANES))))

    # SUMO options shared by traci.start (in __init__) and load (in reset).
    _SUMO_OPTIONS = ('--waiting-time-memory', "4500", '--no-step-log', 'true', '-W', 'true')

    def __init__(self, path_bin, path_cfg, max_steps, n_cars):
        """Generate the first episode's routes and start SUMO.

        Args:
            path_bin: path to the sumo / sumo-gui binary.
            path_cfg: path to the .sumocfg file.
            max_steps: episode length in simulated seconds.
            n_cars: car-count parameter forwarded to ``generate``.
        """
        self.n_cars_out = generate(max_steps, n_cars)
        self.sumoCmd = [path_bin, "-c", path_cfg, *self._SUMO_OPTIONS]
        traci.start(self.sumoCmd)
        self.traci_api = traci.getConnection()

        self.path_cfg = path_cfg
        self.max_steps = max_steps
        self.time = 0
        self.old_phase = 0
        self.waiting_time = 0
        self.queue = 0

        self.done = False
        self.r = 0
        self.r_w = 0
        self.r_q = 0
        self.n_cars = n_cars

        self.r_epoch = 0
        self.w_epoch = 0
        self.q_epoch = 0

        # sliding windows over the last max_steps per-step values
        self.waiting_time_float_av = deque(maxlen=max_steps)
        self.queue_float_av = deque(maxlen=max_steps)
        self.waiting_time_av = []
        self.queue_av = []
        self.r_av = []

        self.data = StoreState()

    def _get_info(self):
        """Update the waiting-time/queue statistics and return the info list.

        Returns:
            [moving-average waiting time, moving-average queue,
             w_epoch / time, q_epoch / time, w_epoch, q_epoch]
            (``step`` appends further entries).
        """
        self._get_length()

        self.waiting_time_float_av.append(self.waiting_time)
        self.queue_float_av.append(self.queue)

        self.w_epoch += self.waiting_time
        self.q_epoch += self.queue

        return [np.mean(self.waiting_time_float_av),
                np.mean(self.queue_float_av),
                self.w_epoch/self.time,
                self.q_epoch/self.time,
                self.w_epoch,
                self.q_epoch]

    def _tl_control(self, phase):
        """Apply action ``phase`` to traffic light "TL" and advance the simulation.

        If the phase changes, the yellow phase of the previous green
        (old_phase * 2 + 1) is run for 5 s first; then the selected green is
        set and the simulation advances 2 s. Sets ``self.done`` when the next
        step would get too close to max_steps or no vehicles are expected.
        """
        if (self.time + 5) >= (self.max_steps - 10):
            self.done = True
            return

        if phase != self.old_phase:
            yellow_phase_code = self.old_phase * 2 + 1
            self.traci_api.trafficlight.setPhase("TL", yellow_phase_code)

            if self.traci_api.simulation.getMinExpectedNumber() == 0:
                self.done = True
                return

            self.traci_api.simulationStep(self.time + 5.)
            self.time += 5

        self.data.p_tl[0, 0, phase] = 1  # one-hot of the chosen phase

        if phase == 0:
            self.traci_api.trafficlight.setPhase("TL", PHASE_NS_GREEN)
        elif phase == 1:
            self.traci_api.trafficlight.setPhase("TL", PHASE_NSL_GREEN)
        elif phase == 2:
            self.traci_api.trafficlight.setPhase("TL", PHASE_EW_GREEN)
        elif phase == 3:
            self.traci_api.trafficlight.setPhase("TL", PHASE_EWL_GREEN)

        self.traci_api.simulationStep(self.time + 2.)
        if self.traci_api.simulation.getMinExpectedNumber() == 0:
            self.done = True
            return

        self.time += 2

    def _get_length(self):
        """Set self.queue to the number of halting vehicles on the four incoming edges."""
        halt_N = self.traci_api.edge.getLastStepHaltingNumber("N2TL")
        halt_S = self.traci_api.edge.getLastStepHaltingNumber("S2TL")
        halt_E = self.traci_api.edge.getLastStepHaltingNumber("E2TL")
        halt_W = self.traci_api.edge.getLastStepHaltingNumber("W2TL")
        self.queue = halt_N + halt_S + halt_E + halt_W

    def _get_time(self, car_id):
        """Add ``car_id``'s accumulated waiting time to self.waiting_time."""
        self.waiting_time += self.traci_api.vehicle.getAccumulatedWaitingTime(car_id)

    def _get_state(self):
        """Write the cars near the junction into self.data.p_position / p_speed.

        Grid index is [0, approach, lane number, cell]: the approach follows
        the W, E, N, S grouping of _LANES and the cell is the 7 m segment
        counted from the stop line (16 cells, 112 m). Also accumulates the
        waiting time of every car on an observed lane.
        """
        for car_id in self.traci_api.vehicle.getIDList():
            lane_id = self.traci_api.vehicle.getLaneID(car_id)
            lane_idx = self._LANE_INDEX.get(lane_id)
            if lane_idx is None:  # not on one of the observed incoming lanes
                continue

            self._get_time(car_id)

            lane_pos = self.traci_api.vehicle.getLanePosition(car_id)
            car_speed = self.traci_api.vehicle.getSpeed(car_id) / 13.89
            # distance to the stop line; assumes 750 m incoming lanes - TODO confirm
            lane_pos = 750 - lane_pos

            if lane_pos > 112:  # beyond the observed 16 x 7 m cells
                continue

            # 7 m cells, clamped so that exactly 112 m lands in the last cell
            # (the old threshold loop left lane_cell unset/stale there).
            lane_cell = min(max(int(lane_pos // 7), 0), 15)
            depth = lane_idx // 4
            row = int(lane_id[-1])

            self.data.p_position[0, depth, row, lane_cell] = 1  # depth, row, col
            self.data.p_speed[0, depth, row, lane_cell] = car_speed

    def _get_reward(self):
        """Return 0.85 * previous step's waiting time - current step's waiting time.

        Returns 0 while fewer than two per-step values are available or
        before simulated second 4.
        """
        if self.time < 4 or len(self.waiting_time_float_av) < 2:
            return 0

        self.r_w = 0.85 * self.waiting_time_float_av[-2] - self.waiting_time_float_av[-1]
        return self.r_w

    def reset(self):
        """Generate new routes, zero the episode counters and reload SUMO."""
        self.n_cars_out = generate(self.max_steps, self.n_cars)

        self.time = 0
        self.old_phase = 0
        self.done = False
        self.r_epoch = 0
        self.w_epoch = 0
        self.q_epoch = 0

        self.traci_api.load(['-c', self.path_cfg, *self._SUMO_OPTIONS])
        self.traci_api.simulationStep()

    def step(self, a):
        """Apply action ``a`` and return ``(data, reward, done, info)``.

        info: [0] moving-average waiting time, [1] moving-average queue,
        [2] w_epoch/time, [3] q_epoch/time, [4] w_epoch, [5] q_epoch,
        [6] r_epoch/time, [7] r_epoch, [8] time, plus [9] n_cars_out when done.
        """
        self._tl_control(a)  # all simulation steps happen here

        self.old_phase = a
        self.waiting_time = 0

        self._get_state()
        info = self._get_info()

        self.r = self._get_reward()

        self.r_av.append(self.r)
        self.r_epoch += self.r

        info.append(self.r_epoch/self.time)
        info.append(self.r_epoch)
        info.append(self.time)
        if self.done:
            info.append(self.n_cars_out)

        return self.data, self.r, self.done, info
212 |
213 |
--------------------------------------------------------------------------------
/src/cfg/environment.net.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
194 |
195 |
196 |
197 |
198 |
199 |
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 |
--------------------------------------------------------------------------------