├── CoLight_agent.py ├── GNN_data_loader ├── PeMS-M.zip ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.opt-1.pyc │ ├── __init__.cpython-36.pyc │ ├── data_utils.cpython-36.opt-1.pyc │ ├── data_utils.cpython-36.pyc │ └── math_utils.cpython-36.pyc ├── data_utils.py └── math_utils.py ├── GNN_dataset ├── V-128.csv ├── V.csv ├── W-16.csv └── W.csv ├── GNN_models ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── base_model.cpython-36.pyc │ ├── layers.cpython-36.pyc │ ├── tester.cpython-36.pyc │ └── trainer.cpython-36.pyc ├── base_model.py ├── layers.py ├── tester.py └── trainer.py ├── GNN_output ├── models │ ├── STGCN-1250.data-00000-of-00001 │ ├── STGCN-1250.index │ ├── STGCN-1250.meta │ ├── STGCN-2500.data-00000-of-00001 │ ├── STGCN-2500.index │ ├── STGCN-2500.meta │ ├── STGCN-3750.data-00000-of-00001 │ ├── STGCN-3750.index │ ├── STGCN-3750.meta │ ├── STGCN-5000.data-00000-of-00001 │ ├── STGCN-5000.index │ ├── STGCN-5000.meta │ └── checkpoint └── tensorboard │ └── train │ └── events.out.tfevents.1590335967.zip ├── GNN_utils ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.opt-1.pyc │ ├── __init__.cpython-36.pyc │ ├── math_graph.cpython-36.opt-1.pyc │ ├── math_graph.cpython-36.pyc │ ├── math_utils.cpython-36.opt-1.pyc │ └── math_utils.cpython-36.pyc ├── math_graph.py └── math_utils.py ├── README.md ├── agent.py ├── anon_env.py ├── config.py ├── construct_sample.py ├── data ├── Hangzhou │ ├── hangzhou.json │ └── roadnet_4_4.json ├── NewYork │ ├── newyork.json │ └── roadnet_16_3.json ├── synthetic │ ├── roadnet_3_3.json │ └── synthetic.json └── syntheticheavy │ ├── roadnet_3_3.json │ └── synthetic.json ├── generator.py ├── gnnmodule.py ├── lit_agent.py ├── model_pool.py ├── model_test.py ├── network_agent.py ├── pipeline.py ├── run_baseline.py ├── run_baseline_batch.py ├── run_batch.py ├── runexp.py ├── script.py ├── simple_dqn_agent.py ├── simple_dqn_one_agent.py ├── simulator.dockerfile ├── summary_multi_anon.py └── updater.py 
/GNN_data_loader/PeMS-M.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangf622/GPLight/6010c74e4267ed1e8512e3b50737369ff0d3a0cf/GNN_data_loader/PeMS-M.zip -------------------------------------------------------------------------------- /GNN_data_loader/__init__.py: -------------------------------------------------------------------------------- 1 | # @Time : Jan. 10, 2019 15:24 2 | # @Author : Veritas YIN 3 | # @FileName : __init__.py 4 | # @Version : 1.0 5 | # @IDE : PyCharm 6 | # @Github : https://github.com/VeritasYin/Project_Orion 7 | -------------------------------------------------------------------------------- /GNN_data_loader/__pycache__/__init__.cpython-36.opt-1.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangf622/GPLight/6010c74e4267ed1e8512e3b50737369ff0d3a0cf/GNN_data_loader/__pycache__/__init__.cpython-36.opt-1.pyc -------------------------------------------------------------------------------- /GNN_data_loader/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangf622/GPLight/6010c74e4267ed1e8512e3b50737369ff0d3a0cf/GNN_data_loader/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /GNN_data_loader/__pycache__/data_utils.cpython-36.opt-1.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangf622/GPLight/6010c74e4267ed1e8512e3b50737369ff0d3a0cf/GNN_data_loader/__pycache__/data_utils.cpython-36.opt-1.pyc -------------------------------------------------------------------------------- /GNN_data_loader/__pycache__/data_utils.cpython-36.pyc: -------------------------------------------------------------------------------- 
# https://raw.githubusercontent.com/wangf622/GPLight/6010c74e4267ed1e8512e3b50737369ff0d3a0cf/GNN_data_loader/__pycache__/data_utils.cpython-36.pyc
# --------------------------------------------------------------------------------
# /GNN_data_loader/__pycache__/math_utils.cpython-36.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/wangf622/GPLight/6010c74e4267ed1e8512e3b50737369ff0d3a0cf/GNN_data_loader/__pycache__/math_utils.cpython-36.pyc
# --------------------------------------------------------------------------------
# /GNN_data_loader/data_utils.py:
# --------------------------------------------------------------------------------
# @Time : Jan. 10, 2019 15:26
# @Author : Veritas YIN
# @FileName : data_utils.py
# @Version : 1.0
# @IDE : PyCharm
# @Github : https://github.com/VeritasYin/Project_Orion

from GNN_utils.math_utils import z_score

import numpy as np
import pandas as pd


class Dataset(object):
    '''
    Container for normalized dataset splits plus the statistics used to normalize them.
    :param data: dict, maps a split name ('train', 'val', 'test') to its np.ndarray.
    :param stats: dict, normalization statistics with the keys 'mean' and 'std'.
    '''

    def __init__(self, data, stats):
        self.__data = data
        self.mean = stats['mean']
        self.std = stats['std']

    def get_data(self, type):
        '''Return the (normalized) array stored under the split name `type`.'''
        return self.__data[type]

    def get_stats(self):
        '''Return the normalization statistics as a dict with keys 'mean' and 'std'.'''
        return {'mean': self.mean, 'std': self.std}

    def get_len(self, type):
        '''Return the number of samples (length of the first axis) of split `type`.'''
        return len(self.__data[type])

    def z_inverse(self, type):
        '''Return split `type` mapped back to the original scale (x * std + mean).'''
        return self.__data[type] * self.std + self.mean


def seq_gen(len_seq, data_seq, offset, n_frame, n_route, day_slot, C_0):
    '''
    Generate data in the form of standard sequence units: every window of
    n_frame consecutive time slots that fits completely inside a single day.
    :param len_seq: int, the number of days in the target sequence.
    :param data_seq: np.ndarray, source time series; each row is one time slot and
        must hold n_route * C_0 columns (it is reshaped to [n_frame, n_route, C_0]).
    :param offset: int, index of the first day of this dataset type (train / val / test).
    :param n_frame: int, the number of frames within a standard sequence unit,
        which contains n_his = 12 and n_pred = 9 (3 /15 min, 6 /30 min & 9 /45 min).
    :param n_route: int, the number of routes in the graph.
    :param day_slot: int, the number of time slots per day, controlled by the time window (5 min as default).
    :param C_0: int, the size of input channel.
    :return: np.ndarray, [len_seq * n_slot, n_frame, n_route, C_0], where n_slot = day_slot - n_frame + 1.
    '''
    # Number of complete n_frame windows that fit inside one day
    # (e.g. 288 - 21 + 1 = 268 with the default day_slot / n_frame).
    n_slot = day_slot - n_frame + 1

    tmp_seq = np.zeros((len_seq * n_slot, n_frame, n_route, C_0))

    for i in range(len_seq):
        # Windows never cross a day boundary: they start at the beginning of day (i + offset).
        day_start = (i + offset) * day_slot
        for j in range(n_slot):
            sta = day_start + j
            end = sta + n_frame
            tmp_seq[i * n_slot + j, :, :, :] = np.reshape(data_seq[sta:end, :], [n_frame, n_route, C_0])
    return tmp_seq


def data_gen(file_path, data_config, n_route, n_frame=21, day_slot=288):
    '''
    Source file load and dataset generation.
    :param file_path: str, the file path of the (header-less) CSV data source.
    :param data_config: tuple, (n_train, n_val, n_test, channel): the number of days in the
        train / validation / test datasets and the size of the input channel.
    :param n_route: int, the number of routes in the graph.
    :param n_frame: int, the number of frames within a standard sequence unit,
        which contains n_his = 12 and n_pred = 9 (3 /15 min, 6 /30 min & 9 /45 min).
    :param day_slot: int, the number of time slots per day, controlled by the time window (5 min as default).
    :return: Dataset, z-score normalized 'train', 'val' and 'test' data with the training-set stats.
    :raises FileNotFoundError: if file_path does not exist.
    '''
    n_train, n_val, n_test, channel = data_config

    try:
        data_seq = pd.read_csv(file_path, header=None).values
    except FileNotFoundError:
        print(f'ERROR: input file was not found in {file_path}.')
        # Without data there is nothing to build; continuing would only fail
        # later with a confusing UnboundLocalError on `data_seq`.
        raise

    # Days are laid out consecutively: train first, then validation, then test.
    seq_train = seq_gen(n_train, data_seq, 0, n_frame, n_route, day_slot, channel)
    seq_val = seq_gen(n_val, data_seq, n_train, n_frame, n_route, day_slot, channel)
    seq_test = seq_gen(n_test, data_seq, n_train + n_val, n_frame, n_route, day_slot, channel)

    # x_stats: dict, the stats for the train dataset, including the value of mean and standard deviation.
    x_stats = {'mean': np.mean(seq_train), 'std': np.std(seq_train)}

    # x_train, x_val, x_test: np.array, [sample_size, n_frame, n_route, channel_size].
    # Every split is normalized with the *training* statistics only.
    x_train = z_score(seq_train, x_stats['mean'], x_stats['std'])
    x_val = z_score(seq_val, x_stats['mean'], x_stats['std'])
    x_test = z_score(seq_test, x_stats['mean'], x_stats['std'])

    x_data = {'train': x_train, 'val': x_val, 'test': x_test}
    return Dataset(x_data, x_stats)


def gen_batch(inputs, batch_size, dynamic_batch=False, shuffle=False):
    '''
    Data iterator in batch.
    :param inputs: np.ndarray, [len_seq, n_frame, n_route, C_0], standard sequence units.
    :param batch_size: int, the size of batch.
    :param dynamic_batch: bool, whether to keep a shorter last batch (True) or drop it (False)
        when fewer than batch_size samples remain.
    :param shuffle: bool, whether to shuffle the samples before batching.
    :yield: np.ndarray, one batch of samples taken from inputs.
    '''
    len_inputs = len(inputs)

    if shuffle:
        idx = np.arange(len_inputs)
        np.random.shuffle(idx)

    for start_idx in range(0, len_inputs, batch_size):
        end_idx = start_idx + batch_size
        if end_idx > len_inputs:
            # Incomplete trailing batch: either shrink it or drop it.
            if not dynamic_batch:
                break
            end_idx = len_inputs
        slide = idx[start_idx:end_idx] if shuffle else slice(start_idx, end_idx)
        yield inputs[slide]


if __name__ == '__main__':
    # Ad-hoc smoke test against a local copy of the PeMSD7 velocity CSV.
    file_path = r'E:\ex\GNN\STGCN_IJCAI-18\dataset\\PeMSD7_V_228.csv'
    n_train, n_val, n_test = 34, 5, 5
    n_route = 228
    n_frame = 21
    day_slot = 288
    # Size of the input channel; seq_gen needs it to reshape each window.
    channel = 1

    # generate training, validation and test data
    try:
        data_seq = pd.read_csv(file_path, header=None).values
        # For PeMSD7_V_228.csv: (12672, 228)
    except FileNotFoundError:
        print(f'ERROR: input file was not found in {file_path}.')
        raise

    seq_train = seq_gen(n_train, data_seq, 0, n_frame, n_route, day_slot, channel)
    # (9112, 21, 228, 1)
    seq_val = seq_gen(n_val, data_seq, n_train, n_frame, n_route, day_slot, channel)
    # (1340, 21, 228, 1)
    seq_test = seq_gen(n_test, data_seq, n_train + n_val, n_frame, n_route, day_slot, channel)
    # (1340, 21, 228, 1)

    # x_stats: dict, the stats for the train dataset, including the value of mean and standard deviation.
    x_stats = {'mean': np.mean(seq_train), 'std': np.std(seq_train)}
    # For PeMSD7_V_228.csv: {'mean': 58.499780775101506, 'std': 13.728615941912187}

    # x_train, x_val, x_test: np.array, [sample_size, n_frame, n_route, channel_size].
    x_train = z_score(seq_train, x_stats['mean'], x_stats['std'])
    x_val = z_score(seq_val, x_stats['mean'], x_stats['std'])
    x_test = z_score(seq_test, x_stats['mean'], x_stats['std'])

    x_data = {'train': x_train, 'val': x_val, 'test': x_test}
    dataset = Dataset(x_data, x_stats)

# --------------------------------------------------------------------------------
# /GNN_data_loader/math_utils.py:
# --------------------------------------------------------------------------------
# @Time : Jan. 10, 2019 15:15
# @Author : Veritas YIN
# @FileName : math_utils.py
# @Version : 1.0
# @IDE : PyCharm
# @Github : https://github.com/VeritasYin/Project_Orion

import numpy as np


def z_score(x, mean, std):
    r'''
    Z-score normalization function: $z = (X - \mu) / \sigma $,
    where z is the z-score, X is the value of the element,
    $\mu$ is the population mean, and $\sigma$ is the standard deviation.
    :param x: np.ndarray, input array to be normalized.
    :param mean: float, the value of mean.
    :param std: float, the value of standard deviation.
    :return: np.ndarray, z-score normalized array.
    '''
    return (x - mean) / std


def z_inverse(x, mean, std):
    '''
    The inverse of function z_score().
    :param x: np.ndarray, input to be recovered.
    :param mean: float, the value of mean.
    :param std: float, the value of standard deviation.
    :return: np.ndarray, z-score inverse array.
    '''
    return x * std + mean


def MAPE(v, v_):
    '''
    Mean absolute percentage error.
    :param v: np.ndarray or int, ground truth.
    :param v_: np.ndarray or int, prediction.
    :return: float, MAPE averaged over all elements of the input (as a fraction, not a percentage).
    '''
    # The 1e-5 term guards against division by zero for zero-valued ground truth.
    return np.mean(np.abs(v_ - v) / (v + 1e-5))


def RMSE(v, v_):
    '''
    Root mean squared error.
    :param v: np.ndarray or int, ground truth.
    :param v_: np.ndarray or int, prediction.
    :return: float, RMSE over all elements of the input.
    '''
    return np.sqrt(np.mean((v_ - v) ** 2))


def MAE(v, v_):
    '''
    Mean absolute error.
    :param v: np.ndarray or int, ground truth.
    :param v_: np.ndarray or int, prediction.
    :return: float, MAE averaged over all elements of the input.
    '''
    return np.mean(np.abs(v_ - v))


def evaluation(y, y_, x_stats):
    '''
    Evaluation function: interface to calculate MAPE, MAE and RMSE between ground truth and prediction.
    Extended version: multi-step prediction is handled by calling itself once per step.
    :param y: np.ndarray, ground truth (z-score normalized).
    :param y_: np.ndarray, prediction (z-score normalized); 3-D for a single step, otherwise
        indexed by time step along its first axis.
    :param x_stats: dict, paras of z-scores (mean & std).
    :return: np.ndarray, [MAPE, MAE, RMSE] for the single-step case; for the multi-step case
        the per-step triples concatenated in step order.
    '''
    dim = len(y_.shape)

    if dim == 3:
        # single_step case: metrics are computed on the de-normalized values.
        v = z_inverse(y, x_stats['mean'], x_stats['std'])
        v_ = z_inverse(y_, x_stats['mean'], x_stats['std'])
        return np.array([MAPE(v, v_), MAE(v, v_), RMSE(v, v_)])

    # multi_step case: swap the first two axes of the ground truth so it is
    # indexed by time step like y_ (y -> [time_step, batch_size, n_route, 1]).
    y = np.swapaxes(y, 0, 1)
    tmp_list = [evaluation(y[i], y_[i], x_stats) for i in range(y_.shape[0])]
    return np.concatenate(tmp_list, axis=-1)

# --------------------------------------------------------------------------------
# /GNN_dataset/W-16.csv:
# --------------------------------------------------------------------------------
0.000000000000000000e+00,6.000000000000000000e+02,1.200000000000000000e+03,1.800000000000000000e+03,8.000000000000000000e+02,1.400000000000000000e+03,2.000000000000000000e+03,2.600000000000000000e+03,1.600000000000000000e+03,2.200000000000000000e+03,2.800000000000000000e+03,3.400000000000000000e+03,2.400000000000000000e+03,3.000000000000000000e+03,3.600000000000000000e+03,4.200000000000000000e+03 2 | 6.000000000000000000e+02,0.000000000000000000e+00,6.000000000000000000e+02,1.200000000000000000e+03,1.400000000000000000e+03,8.000000000000000000e+02,1.400000000000000000e+03,2.000000000000000000e+03,2.200000000000000000e+03,1.600000000000000000e+03,2.200000000000000000e+03,2.800000000000000000e+03,3.000000000000000000e+03,2.400000000000000000e+03,3.000000000000000000e+03,3.600000000000000000e+03 3 | 1.200000000000000000e+03,6.000000000000000000e+02,0.000000000000000000e+00,6.000000000000000000e+02,2.000000000000000000e+03,1.400000000000000000e+03,8.000000000000000000e+02,1.400000000000000000e+03,2.800000000000000000e+03,2.200000000000000000e+03,1.600000000000000000e+03,2.200000000000000000e+03,3.600000000000000000e+03,3.000000000000000000e+03,2.400000000000000000e+03,3.000000000000000000e+03 4 | 1.800000000000000000e+03,1.200000000000000000e+03,6.000000000000000000e+02,0.000000000000000000e+00,2.600000000000000000e+03,2.000000000000000000e+03,1.400000000000000000e+03,8.000000000000000000e+02,3.400000000000000000e+03,2.800000000000000000e+03,2.200000000000000000e+03,1.600000000000000000e+03,4.200000000000000000e+03,3.600000000000000000e+03,3.000000000000000000e+03,2.400000000000000000e+03 5 | 
8.000000000000000000e+02,1.400000000000000000e+03,2.000000000000000000e+03,2.600000000000000000e+03,0.000000000000000000e+00,6.000000000000000000e+02,1.200000000000000000e+03,1.800000000000000000e+03,8.000000000000000000e+02,1.400000000000000000e+03,2.000000000000000000e+03,2.600000000000000000e+03,1.600000000000000000e+03,2.200000000000000000e+03,2.800000000000000000e+03,3.400000000000000000e+03 6 | 1.400000000000000000e+03,8.000000000000000000e+02,1.400000000000000000e+03,2.000000000000000000e+03,6.000000000000000000e+02,0.000000000000000000e+00,6.000000000000000000e+02,1.200000000000000000e+03,1.400000000000000000e+03,8.000000000000000000e+02,1.400000000000000000e+03,2.000000000000000000e+03,2.200000000000000000e+03,1.600000000000000000e+03,2.200000000000000000e+03,2.800000000000000000e+03 7 | 2.000000000000000000e+03,1.400000000000000000e+03,8.000000000000000000e+02,1.400000000000000000e+03,1.200000000000000000e+03,6.000000000000000000e+02,0.000000000000000000e+00,6.000000000000000000e+02,2.000000000000000000e+03,1.400000000000000000e+03,8.000000000000000000e+02,1.400000000000000000e+03,2.800000000000000000e+03,2.200000000000000000e+03,1.600000000000000000e+03,2.200000000000000000e+03 8 | 2.600000000000000000e+03,2.000000000000000000e+03,1.400000000000000000e+03,8.000000000000000000e+02,1.800000000000000000e+03,1.200000000000000000e+03,6.000000000000000000e+02,0.000000000000000000e+00,2.600000000000000000e+03,2.000000000000000000e+03,1.400000000000000000e+03,8.000000000000000000e+02,3.400000000000000000e+03,2.800000000000000000e+03,2.200000000000000000e+03,1.600000000000000000e+03 9 | 
1.600000000000000000e+03,2.200000000000000000e+03,2.800000000000000000e+03,3.400000000000000000e+03,8.000000000000000000e+02,1.400000000000000000e+03,2.000000000000000000e+03,2.600000000000000000e+03,0.000000000000000000e+00,6.000000000000000000e+02,1.200000000000000000e+03,1.800000000000000000e+03,8.000000000000000000e+02,1.400000000000000000e+03,2.000000000000000000e+03,2.600000000000000000e+03 10 | 2.200000000000000000e+03,1.600000000000000000e+03,2.200000000000000000e+03,2.800000000000000000e+03,1.400000000000000000e+03,8.000000000000000000e+02,1.400000000000000000e+03,2.000000000000000000e+03,6.000000000000000000e+02,0.000000000000000000e+00,6.000000000000000000e+02,1.200000000000000000e+03,1.400000000000000000e+03,8.000000000000000000e+02,1.400000000000000000e+03,2.000000000000000000e+03 11 | 2.800000000000000000e+03,2.200000000000000000e+03,1.600000000000000000e+03,2.200000000000000000e+03,2.000000000000000000e+03,1.400000000000000000e+03,8.000000000000000000e+02,1.400000000000000000e+03,1.200000000000000000e+03,6.000000000000000000e+02,0.000000000000000000e+00,6.000000000000000000e+02,2.000000000000000000e+03,1.400000000000000000e+03,8.000000000000000000e+02,1.400000000000000000e+03 12 | 3.400000000000000000e+03,2.800000000000000000e+03,2.200000000000000000e+03,1.600000000000000000e+03,2.600000000000000000e+03,2.000000000000000000e+03,1.400000000000000000e+03,8.000000000000000000e+02,1.800000000000000000e+03,1.200000000000000000e+03,6.000000000000000000e+02,0.000000000000000000e+00,2.600000000000000000e+03,2.000000000000000000e+03,1.400000000000000000e+03,8.000000000000000000e+02 13 | 
2.400000000000000000e+03,3.000000000000000000e+03,3.600000000000000000e+03,4.200000000000000000e+03,1.600000000000000000e+03,2.200000000000000000e+03,2.800000000000000000e+03,3.400000000000000000e+03,8.000000000000000000e+02,1.400000000000000000e+03,2.000000000000000000e+03,2.600000000000000000e+03,0.000000000000000000e+00,6.000000000000000000e+02,1.200000000000000000e+03,1.800000000000000000e+03 14 | 3.000000000000000000e+03,2.400000000000000000e+03,3.000000000000000000e+03,3.600000000000000000e+03,2.200000000000000000e+03,1.600000000000000000e+03,2.200000000000000000e+03,2.800000000000000000e+03,1.400000000000000000e+03,8.000000000000000000e+02,1.400000000000000000e+03,2.000000000000000000e+03,6.000000000000000000e+02,0.000000000000000000e+00,6.000000000000000000e+02,1.200000000000000000e+03 15 | 3.600000000000000000e+03,3.000000000000000000e+03,2.400000000000000000e+03,3.000000000000000000e+03,2.800000000000000000e+03,2.200000000000000000e+03,1.600000000000000000e+03,2.200000000000000000e+03,2.000000000000000000e+03,1.400000000000000000e+03,8.000000000000000000e+02,1.400000000000000000e+03,1.200000000000000000e+03,6.000000000000000000e+02,0.000000000000000000e+00,6.000000000000000000e+02 16 | 4.200000000000000000e+03,3.600000000000000000e+03,3.000000000000000000e+03,2.400000000000000000e+03,3.400000000000000000e+03,2.800000000000000000e+03,2.200000000000000000e+03,1.600000000000000000e+03,2.600000000000000000e+03,2.000000000000000000e+03,1.400000000000000000e+03,8.000000000000000000e+02,1.800000000000000000e+03,1.200000000000000000e+03,6.000000000000000000e+02,0.000000000000000000e+00 17 | -------------------------------------------------------------------------------- /GNN_dataset/W.csv: -------------------------------------------------------------------------------- 1 | 1.000000000000000000e+00 2 | -------------------------------------------------------------------------------- /GNN_models/__init__.py: 
-------------------------------------------------------------------------------- 1 | # @Time : Jan. 10, 2019 17:49 2 | # @Author : Veritas YIN 3 | # @FileName : __init__.py 4 | # @Version : 1.0 5 | # @IDE : PyCharm 6 | # @Github : https://github.com/VeritasYin/Project_Orion 7 | -------------------------------------------------------------------------------- /GNN_models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangf622/GPLight/6010c74e4267ed1e8512e3b50737369ff0d3a0cf/GNN_models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /GNN_models/__pycache__/base_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangf622/GPLight/6010c74e4267ed1e8512e3b50737369ff0d3a0cf/GNN_models/__pycache__/base_model.cpython-36.pyc -------------------------------------------------------------------------------- /GNN_models/__pycache__/layers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangf622/GPLight/6010c74e4267ed1e8512e3b50737369ff0d3a0cf/GNN_models/__pycache__/layers.cpython-36.pyc -------------------------------------------------------------------------------- /GNN_models/__pycache__/tester.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangf622/GPLight/6010c74e4267ed1e8512e3b50737369ff0d3a0cf/GNN_models/__pycache__/tester.cpython-36.pyc -------------------------------------------------------------------------------- /GNN_models/__pycache__/trainer.cpython-36.pyc: -------------------------------------------------------------------------------- 
# https://raw.githubusercontent.com/wangf622/GPLight/6010c74e4267ed1e8512e3b50737369ff0d3a0cf/GNN_models/__pycache__/trainer.cpython-36.pyc
# --------------------------------------------------------------------------------
# /GNN_models/base_model.py:
# --------------------------------------------------------------------------------
# @Time : Jan. 12, 2019 19:01
# @Author : Veritas YIN
# @FileName : base_model.py
# @Version : 1.0
# @IDE : PyCharm
# @Github : https://github.com/VeritasYin/Project_Orion

try:
    # In this repository the layers live in the `GNN_models` package; its
    # output_layer() takes the extra `out_channel` argument used below.
    from GNN_models.layers import *
except ModuleNotFoundError as err:
    # Only fall back to the upstream STGCN package name when `GNN_models`
    # itself is missing; any other missing module (e.g. tensorflow) is re-raised.
    if err.name not in ('GNN_models', 'GNN_models.layers'):
        raise
    from models.layers import *
from os.path import join as pjoin
import tensorflow as tf


def build_model(inputs, n_his, Ks, Kt, blocks, keep_prob):
    '''
    Build the base model: a stack of ST-Conv blocks followed by the output layer.
    :param inputs: placeholder, [batch_size, n_his + 1, n_route, C_0]; frames 0..n_his-1 are
        the history window and frame n_his is the prediction target.
    :param n_his: int, size of historical records for training.
    :param Ks: int, kernel size of spatial convolution.
    :param Kt: int, kernel size of temporal convolution.
    :param blocks: list, channel configs of st_conv blocks, e.g. [[1, 32, 64], [64, 32, 128]].
    :param keep_prob: placeholder, keep probability of dropout.
    :return: (train_loss, single_pred): scalar L2 training loss and the single-step
        prediction tensor [batch_size, n_route, blocks[0][0]].
    :raises ValueError: if the blocks consume the whole history window (Ko <= 1).
    '''
    # History window fed to the network.
    x = inputs[:, 0:n_his, :, :]

    # Ko: number of time steps left after the ST-Conv blocks; it is used as the
    # temporal kernel size of the output layer. Each block contains two temporal
    # convolutions with 'VALID' padding, and each of them shortens the time
    # axis by Kt - 1 (the spatial convolution keeps it unchanged).
    Ko = n_his
    for i, channels in enumerate(blocks):
        x = st_conv_block(x, Ks, Kt, channels, i, keep_prob, act_func='GLU')
        Ko -= 2 * (Kt - 1)

    # Output Layer
    if Ko > 1:
        y = output_layer(x, Ko, 'output_layer', blocks[0][0])
    else:
        raise ValueError(f'ERROR: kernel size Ko must be greater than 1, but received "{Ko}".')

    # Baseline loss of a "copy the last observed frame" predictor, kept for reference.
    tf.add_to_collection(name='copy_loss',
                         value=tf.nn.l2_loss(inputs[:, n_his - 1:n_his, :, :] - inputs[:, n_his:n_his + 1, :, :]))

    # y: [batch_size, 1, n_route, blocks[0][0]], compared against the target frame.
    train_loss = tf.nn.l2_loss(y - inputs[:, n_his:n_his + 1, :, :])
    single_pred = y[:, 0, :, :]
    tf.add_to_collection(name='y_pred', value=single_pred)
    return train_loss, single_pred


def model_save(sess, global_steps, model_name, save_path='./output/models/'):
    '''
    Save the checkpoint of trained model.
    :param sess: tf.Session().
    :param global_steps: tensor, record the global step of training in epochs.
    :param model_name: str, the name of saved model.
    :param save_path: str, the path of saved model.
    :return: None.
    '''
    saver = tf.train.Saver(max_to_keep=3)
    prefix_path = saver.save(sess, pjoin(save_path, model_name), global_step=global_steps)
    print(f'<< Saving model to {prefix_path} ...')

# --------------------------------------------------------------------------------
# /GNN_models/layers.py:
# --------------------------------------------------------------------------------
# @Time : Jan. 12, 2019 17:45
# @Author : Veritas YIN
# @FileName : layers.py
# @Version : 1.0
# @IDE : PyCharm
# @Github : https://github.com/VeritasYin/Project_Orion

import tensorflow as tf


def gconv(x, theta, Ks, c_in, c_out):
    '''
    Spectral-based graph convolution function.
    The graph kernel is the first tensor of the 'graph_kernel' collection, which is
    expected to be added elsewhere before this graph is built.
    :param x: tensor, [batch_size, n_route, c_in].
    :param theta: tensor, [Ks*c_in, c_out], trainable kernel parameters.
    :param Ks: int, kernel size of graph convolution.
    :param c_in: int, size of input channel.
    :param c_out: int, size of output channel.
    :return: tensor, [batch_size, n_route, c_out].
    '''
    # graph kernel: tensor, [n_route, Ks*n_route]
    kernel = tf.get_collection('graph_kernel')[0]
    n = tf.shape(kernel)[0]
    # x -> [batch_size, c_in, n_route] -> [batch_size*c_in, n_route]
    x_tmp = tf.reshape(tf.transpose(x, [0, 2, 1]), [-1, n])
    # x_mul = x_tmp * ker -> [batch_size*c_in, Ks*n_route] -> [batch_size, c_in, Ks, n_route]
    x_mul = tf.reshape(tf.matmul(x_tmp, kernel), [-1, c_in, Ks, n])
    # x_ker -> [batch_size, n_route, c_in, K_s] -> [batch_size*n_route, c_in*Ks]
    x_ker = tf.reshape(tf.transpose(x_mul, [0, 3, 1, 2]), [-1, c_in * Ks])
    # x_gconv -> [batch_size*n_route, c_out] -> [batch_size, n_route, c_out]
    x_gconv = tf.reshape(tf.matmul(x_ker, theta), [-1, n, c_out])
    return x_gconv


def layer_norm(x, scope):
    '''
    Layer normalization over the (n_route, channel) axes with learnable scale and shift.
    :param x: tensor, [batch_size, time_step, n_route, channel].
    :param scope: str, variable scope.
    :return: tensor, [batch_size, time_step, n_route, channel].
    '''
    _, _, N, C = x.get_shape().as_list()
    # tf.nn.moments returns (mean, variance); `sigma` therefore holds the variance.
    mu, sigma = tf.nn.moments(x, axes=[2, 3], keep_dims=True)

    with tf.variable_scope(scope):
        gamma = tf.get_variable('gamma', initializer=tf.ones([1, 1, N, C]))
        beta = tf.get_variable('beta', initializer=tf.zeros([1, 1, N, C]))
        _x = (x - mu) / tf.sqrt(sigma + 1e-6) * gamma + beta
    return _x


def temporal_conv_layer(x, Kt, c_in, c_out, act_func='relu'):
    '''
    Temporal convolution layer.
    :param x: tensor, [batch_size, time_step, n_route, c_in].
    :param Kt: int, kernel size of temporal convolution.
    :param c_in: int, size of input channel.
    :param c_out: int, size of output channel.
    :param act_func: str, activation function: 'GLU', 'linear', 'sigmoid' or 'relu'.
    :return: tensor, [batch_size, time_step-Kt+1, n_route, c_out].
    :raises ValueError: for an unknown act_func.
    '''
    _, T, n, _ = x.get_shape().as_list()

    if c_in > c_out:
        # bottleneck down-sampling of the residual path with a 1x1 convolution
        w_input = tf.get_variable('wt_input', shape=[1, 1, c_in, c_out], dtype=tf.float32)
        tf.add_to_collection(name='weight_decay', value=tf.nn.l2_loss(w_input))
        x_input = tf.nn.conv2d(x, w_input, strides=[1, 1, 1, 1], padding='SAME')
    elif c_in < c_out:
        # if the size of input channel is less than the output,
        # padding x to the same size of output channel.
        # Note, _.get_shape() cannot convert a partially known TensorShape to a Tensor.
        x_input = tf.concat([x, tf.zeros([tf.shape(x)[0], T, n, c_out - c_in])], axis=3)
    else:
        x_input = x

    # keep the original input for residual connection, trimmed to the
    # time_step-Kt+1 frames that survive the 'VALID' convolution below.
    x_input = x_input[:, Kt - 1:T, :, :]

    if act_func == 'GLU':
        # gated linear unit: the first c_out channels are the values, the last c_out the gates.
        wt = tf.get_variable(name='wt', shape=[Kt, 1, c_in, 2 * c_out], dtype=tf.float32)
        tf.add_to_collection(name='weight_decay', value=tf.nn.l2_loss(wt))
        bt = tf.get_variable(name='bt', initializer=tf.zeros([2 * c_out]), dtype=tf.float32)
        x_conv = tf.nn.conv2d(x, wt, strides=[1, 1, 1, 1], padding='VALID') + bt
        return (x_conv[:, :, :, 0:c_out] + x_input) * tf.nn.sigmoid(x_conv[:, :, :, -c_out:])
    else:
        wt = tf.get_variable(name='wt', shape=[Kt, 1, c_in, c_out], dtype=tf.float32)
        tf.add_to_collection(name='weight_decay', value=tf.nn.l2_loss(wt))
        bt = tf.get_variable(name='bt', initializer=tf.zeros([c_out]), dtype=tf.float32)
        x_conv = tf.nn.conv2d(x, wt, strides=[1, 1, 1, 1], padding='VALID') + bt
        if act_func == 'linear':
            return x_conv
        elif act_func == 'sigmoid':
            return tf.nn.sigmoid(x_conv)
        elif act_func == 'relu':
            return tf.nn.relu(x_conv + x_input)
        else:
            raise ValueError(f'ERROR: activation function "{act_func}" is not defined.')


def spatio_conv_layer(x, Ks, c_in, c_out):
    '''
    Spatial graph convolution layer with a residual connection and ReLU.
    :param x: tensor, [batch_size, time_step, n_route, c_in].
    :param Ks: int, kernel size of spatial convolution.
    :param c_in: int, size of input channel.
    :param c_out: int, size of output channel.
    :return: tensor, [batch_size, time_step, n_route, c_out].
    '''
    _, T, n, _ = x.get_shape().as_list()

    if c_in > c_out:
        # bottleneck down-sampling
        w_input = tf.get_variable('ws_input', shape=[1, 1, c_in, c_out], dtype=tf.float32)
        tf.add_to_collection(name='weight_decay', value=tf.nn.l2_loss(w_input))
        x_input = tf.nn.conv2d(x, w_input, strides=[1, 1, 1, 1], padding='SAME')
    elif c_in < c_out:
        # if the size of input channel is less than the output,
        # padding x to the same size of output channel.
        # Note, _.get_shape() cannot convert a partially known TensorShape to a Tensor.
        x_input = tf.concat([x, tf.zeros([tf.shape(x)[0], T, n, c_out - c_in])], axis=3)
    else:
        x_input = x

    ws = tf.get_variable(name='ws', shape=[Ks * c_in, c_out], dtype=tf.float32)
    tf.add_to_collection(name='weight_decay', value=tf.nn.l2_loss(ws))
    variable_summaries(ws, 'theta')
    bs = tf.get_variable(name='bs', initializer=tf.zeros([c_out]), dtype=tf.float32)
    # x -> [batch_size*time_step, n_route, c_in] -> [batch_size*time_step, n_route, c_out]
    x_gconv = gconv(tf.reshape(x, [-1, n, c_in]), ws, Ks, c_in, c_out) + bs
    # x_g -> [batch_size, time_step, n_route, c_out]
    x_gc = tf.reshape(x_gconv, [-1, T, n, c_out])
    return tf.nn.relu(x_gc[:, :, :, 0:c_out] + x_input)


def st_conv_block(x, Ks, Kt, channels, scope, keep_prob, act_func='GLU'):
    '''
    Spatio-temporal convolutional block, which contains two temporal gated convolution layers
    and one spatial graph convolution layer in the middle, followed by layer norm and dropout.
    :param x: tensor, [batch_size, time_step, n_route, c_in].
    :param Ks: int, kernel size of spatial convolution.
    :param Kt: int, kernel size of temporal convolution.
    :param channels: list, channel configs of a single st_conv block: [c_in, c_hidden, c_out].
    :param scope: str, variable scope.
    :param keep_prob: placeholder, prob of dropout.
    :param act_func: str, activation function of the first temporal layer.
    :return: tensor, [batch_size, time_step - 2*(Kt-1), n_route, c_out].
    '''
    c_si, c_t, c_oo = channels

    with tf.variable_scope(f'stn_block_{scope}_in'):
        x_s = temporal_conv_layer(x, Kt, c_si, c_t, act_func=act_func)
        x_t = spatio_conv_layer(x_s, Ks, c_t, c_t)
    with tf.variable_scope(f'stn_block_{scope}_out'):
        x_o = temporal_conv_layer(x_t, Kt, c_t, c_oo)
    x_ln = layer_norm(x_o, f'layer_norm_{scope}')
    return tf.nn.dropout(x_ln, keep_prob)


def fully_con_layer(x, n, channel, scope, out_channel):
    '''
    Fully connected layer: maps `channel` channels to `out_channel` channels per node
    (a 1x1 convolution plus a per-node bias).
    :param x: tensor, [batch_size, 1, n_route, channel].
    :param n: int, number of route / size of graph.
    :param channel: int, channel size of input x.
    :param scope: str, variable scope.
    :param out_channel: int, channel size of the output.
    :return: tensor, [batch_size, 1, n_route, out_channel].
    '''
    w = tf.get_variable(name=f'w_{scope}', shape=[1, 1, channel, out_channel], dtype=tf.float32)
    tf.add_to_collection(name='weight_decay', value=tf.nn.l2_loss(w))
    b = tf.get_variable(name=f'b_{scope}', initializer=tf.zeros([n, out_channel]), dtype=tf.float32)
    return tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME') + b


def output_layer(x, T, scope, out_channel, act_func='GLU'):
    '''
    Output layer: temporal convolution layers attached to one fully connected layer,
    which map outputs of the last st_conv block to a single-step prediction.
    :param x: tensor, [batch_size, time_step, n_route, channel].
    :param T: int, kernel size of temporal convolution (should equal time_step so that
        exactly one step remains).
    :param scope: str, variable scope.
    :param out_channel: int, channel size of the prediction.
    :param act_func: str, activation function.
    :return: tensor, [batch_size, 1, n_route, out_channel].
    '''
    _, _, n, channel = x.get_shape().as_list()

    # maps multi-steps to one.
    with tf.variable_scope(f'{scope}_in'):
        x_i = temporal_conv_layer(x, T, channel, channel, act_func=act_func)
    x_ln = layer_norm(x_i, f'layer_norm_{scope}')
    with tf.variable_scope(f'{scope}_out'):
        x_o = temporal_conv_layer(x_ln, 1, channel, channel, act_func='sigmoid')
    # maps multi-channels to out_channel.
    x_fc = fully_con_layer(x_o, n, channel, scope, out_channel)
    return x_fc


def variable_summaries(var, v_name):
    '''
    Attach summaries to a Tensor (for TensorBoard visualization).
    Ref: https://zhuanlan.zhihu.com/p/33178205
    :param var: tf.Variable().
    :param v_name: str, name of the variable.
    '''
    with tf.name_scope('summaries'):
        mean = tf.reduce_mean(var)
        tf.summary.scalar(f'mean_{v_name}', mean)

        with tf.name_scope(f'stddev_{v_name}'):
            stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
        tf.summary.scalar(f'stddev_{v_name}', stddev)

        tf.summary.scalar(f'max_{v_name}', tf.reduce_max(var))
        tf.summary.scalar(f'min_{v_name}', tf.reduce_min(var))

        tf.summary.histogram(f'histogram_{v_name}', var)

# --------------------------------------------------------------------------------
# /GNN_models/tester.py:
# --------------------------------------------------------------------------------
# @Time : Jan.
# 10, 2019 17:52
# @Author : Veritas YIN
# @FileName : tester.py
# @Version : 1.0
# @IDE : PyCharm
# @Github : https://github.com/VeritasYin/Project_Orion

from data_loader.data_utils import gen_batch
from utils.math_utils import evaluation
from os.path import join as pjoin

import tensorflow as tf
import numpy as np
import time


def multi_pred(sess, y_pred, seq, batch_size, n_his, n_pred, step_idx, dynamic_batch=True):
    '''
    Multi_prediction function: rolls the model forward autoregressively for n_pred steps,
    feeding each prediction back in as the newest history frame.
    :param sess: tf.Session().
    :param y_pred: placeholder, or a list of tensors whose first element is the prediction.
    :param seq: np.ndarray, [len_seq, n_frame, n_route, C_0].
    :param batch_size: int, the size of batch.
    :param n_his: int, size of historical records for training.
    :param n_pred: int, the length of prediction.
    :param step_idx: int or list, index for prediction slice.
    :param dynamic_batch: bool, whether changes the batch size in the last one if its length is less than the default.
    :return y_ : tensor, 'sep' [len_inputs, n_route, 1]; 'merge' [step_idx, len_inputs, n_route, 1].
            len_ : int, the length of prediction.
    '''
    pred_list = []
    for batch in gen_batch(seq, min(batch_size, len(seq)), dynamic_batch=dynamic_batch):
        # Note: use np.copy() to avoid the modification of source data.
        test_seq = np.copy(batch[:, 0:n_his + 1, :, :])
        step_list = []
        for _ in range(n_pred):
            pred = sess.run(y_pred,
                            feed_dict={'data_input:0': test_seq, 'keep_prob:0': 1.0})
            # sess.run on a list of fetches returns a list; keep the first output.
            if isinstance(pred, list):
                pred = np.array(pred[0])
            # Slide the history window one frame to the left and put the new
            # prediction in the newest history slot (index n_his - 1).
            test_seq[:, 0:n_his - 1, :, :] = test_seq[:, 1:n_his, :, :]
            test_seq[:, n_his - 1, :, :] = pred
            step_list.append(pred)
        pred_list.append(step_list)
    # pred_array -> [n_pred, batch_size, n_route, C_0)
    pred_array = np.concatenate(pred_list, axis=1)
    return pred_array[step_idx], pred_array.shape[1]


def model_inference(sess, pred, inputs, batch_size, n_his, n_pred, step_idx, min_va_val, min_val):
    '''
    Model inference function.
    :param sess: tf.Session().
    :param pred: placeholder.
    :param inputs: instance of class Dataset, data source for inference.
    :param batch_size: int, the size of batch.
    :param n_his: int, the length of historical records for training.
    :param n_pred: int, the length of prediction.
    :param step_idx: int or list, index for prediction slice.
    :param min_va_val: np.ndarray, metric values on validation set (updated in place).
    :param min_val: np.ndarray, metric values on test set.
    :return: tuple (min_va_val, min_val). The test metrics are recomputed only when at
             least one validation metric improved.
    :raises ValueError: if n_his + n_pred exceeds the number of frames per sample.
    '''
    x_val, x_test, x_stats = inputs.get_data('val'), inputs.get_data('test'), inputs.get_stats()

    if n_his + n_pred > x_val.shape[1]:
        raise ValueError(f'ERROR: the value of n_pred "{n_pred}" exceeds the length limit.')

    y_val, len_val = multi_pred(sess, pred, x_val, batch_size, n_his, n_pred, step_idx)
    evl_val = evaluation(x_val[0:len_val, step_idx + n_his, :, :], y_val, x_stats)

    # chks: indicator that reflects the relationship of values between evl_val and min_va_val.
    chks = evl_val < min_va_val
    # update the metric on test set, if model's performance got improved on the validation.
    if np.any(chks):
        min_va_val[chks] = evl_val[chks]
        y_pred, len_pred = multi_pred(sess, pred, x_test, batch_size, n_his, n_pred, step_idx)
        evl_pred = evaluation(x_test[0:len_pred, step_idx + n_his, :, :], y_pred, x_stats)
        min_val = evl_pred
    return min_va_val, min_val


def model_test(inputs, batch_size, n_his, n_pred, inf_mode, load_path='./output/models/'):
    '''
    Load and test saved model from the checkpoint.
    :param inputs: instance of class Dataset, data source for test.
    :param batch_size: int, the size of batch.
    :param n_his: int, the length of historical records for training.
    :param n_pred: int, the length of prediction.
    :param inf_mode: str, test mode - 'merge / multi-step test' or 'separate / single-step test'.
    :param load_path: str, the path of loaded model.
    :raises FileNotFoundError: if no checkpoint is found in load_path.
    :raises ValueError: if inf_mode is neither 'sep' nor 'merge'.
    '''
    start_time = time.time()
    ckpt_state = tf.train.get_checkpoint_state(load_path)
    # get_checkpoint_state returns None when the directory has no 'checkpoint' file.
    if ckpt_state is None or not ckpt_state.model_checkpoint_path:
        raise FileNotFoundError(f'ERROR: no checkpoint was found in "{load_path}".')
    model_path = ckpt_state.model_checkpoint_path

    test_graph = tf.Graph()

    with test_graph.as_default():
        saver = tf.train.import_meta_graph(pjoin(f'{model_path}.meta'))

    with tf.Session(graph=test_graph) as test_sess:
        saver.restore(test_sess, tf.train.latest_checkpoint(load_path))
        print(f'>> Loading saved model from {model_path} ...')

        pred = test_graph.get_collection('y_pred')

        if inf_mode == 'sep':
            # for inference mode 'sep', the type of step index is int.
            step_idx = n_pred - 1
            tmp_idx = [step_idx]
        elif inf_mode == 'merge':
            # for inference mode 'merge', the type of step index is np.ndarray.
            step_idx = tmp_idx = np.arange(3, n_pred + 1, 3) - 1
        else:
            raise ValueError(f'ERROR: test mode "{inf_mode}" is not defined.')

        x_test, x_stats = inputs.get_data('test'), inputs.get_stats()

        y_test, len_test = multi_pred(test_sess, pred, x_test, batch_size, n_his, n_pred, step_idx)
        evl = evaluation(x_test[0:len_test, step_idx + n_his, :, :], y_test, x_stats)
        print(evl,tmp_idx,type(y_test))

        # evl holds [MAPE, MAE, RMSE] for each evaluated step, concatenated in order,
        # so the k-th evaluated step lives at evl[3k:3k+3]. Slicing by the step index
        # itself (evl[ix-2:ix+1]) only works for 'merge' and overruns evl in 'sep'.
        for k, ix in enumerate(tmp_idx):
            te = evl[3 * k:3 * k + 3]
            print(f'Time Step {ix + 1}: MAPE {te[0]:7.3%}; MAE {te[1]:4.3f}; RMSE {te[2]:6.3f}.')
        print(f'Model Test Time {time.time() - start_time:.3f}s')
    print('Testing model finished!')
# --------------------------------------------------------------------------------
# /GNN_models/trainer.py:
# --------------------------------------------------------------------------------
# @Time : Jan. 13, 2019 20:16
# @Author : Veritas YIN
# @FileName : trainer.py
# @Version : 1.0
# @IDE : PyCharm
# @Github : https://github.com/VeritasYin/Project_Orion

from data_loader.data_utils import gen_batch
from models.tester import model_inference
from models.base_model import build_model, model_save
from os.path import join as pjoin

import tensorflow as tf
import numpy as np
import time


def _log(message):
    '''Print message and append it as one line to OUTPUT.txt in the working directory.'''
    print(message)
    with open(r'OUTPUT.txt', 'a') as record_file:
        record_file.write(f'{message}\n')


def model_train(inputs, blocks, args, sum_path='./output/tensorboard'):
    '''
    Train the base model.
    :param inputs: instance of class Dataset, data source for training.
    :param blocks: list, channel configs of st_conv blocks.
    :param args: instance of class argparse, args for training.
    :param sum_path: str, directory for TensorBoard summaries (written under its 'train' subdirectory).
    :raises ValueError: if args.opt or args.inf_mode is not supported.
    '''
    n, n_his, n_pred = args.n_route, args.n_his, args.n_pred
    Ks, Kt = args.ks, args.kt
    batch_size, epoch, inf_mode, opt = args.batch_size, args.epoch, args.inf_mode, args.opt

    # Placeholder for model training
    x = tf.placeholder(tf.float32, [None, n_his + 1, n, args.channel], name='data_input')
    keep_prob = tf.placeholder(tf.float32, name='keep_prob')

    # Define model loss
    train_loss, pred = build_model(x, n_his, Ks, Kt, blocks, keep_prob)
    tf.summary.scalar('train_loss', train_loss)
    copy_loss = tf.add_n(tf.get_collection('copy_loss'))
    tf.summary.scalar('copy_loss', copy_loss)

    # Learning rate settings
    global_steps = tf.Variable(0, trainable=False)

    len_train = inputs.get_len('train')
    # Number of mini-batches per epoch (the last one may be partial). Integer ceil keeps
    # it an int, so the summary step and decay_steps below are integers as well.
    epoch_step = (len_train + batch_size - 1) // batch_size
    # Learning rate decay with rate 0.7 every 5 epochs.
    lr = tf.train.exponential_decay(args.lr, global_steps, decay_steps=5 * epoch_step, decay_rate=0.7, staircase=True)
    tf.summary.scalar('learning_rate', lr)
    step_op = tf.assign_add(global_steps, 1)
    with tf.control_dependencies([step_op]):
        if opt == 'RMSProp':
            train_op = tf.train.RMSPropOptimizer(lr).minimize(train_loss)
        elif opt == 'ADAM':
            train_op = tf.train.AdamOptimizer(lr).minimize(train_loss)
        else:
            raise ValueError(f'ERROR: optimizer "{opt}" is not defined.')

    merged = tf.summary.merge_all()

    with tf.Session() as sess:
        writer = tf.summary.FileWriter(pjoin(sum_path, 'train'), sess.graph)
        sess.run(tf.global_variables_initializer())

        if inf_mode == 'sep':
            # for inference mode 'sep', the type of step index is int.
            step_idx = n_pred - 1
            tmp_idx = [step_idx]
            n_eval_steps = 1
        elif inf_mode == 'merge':
            # for inference mode 'merge', the type of step index is np.ndarray.
            step_idx = tmp_idx = np.arange(3, n_pred + 1, 3) - 1
            n_eval_steps = len(step_idx)
        else:
            raise ValueError(f'ERROR: test mode "{inf_mode}" is not defined.')
        # Best-so-far [MAPE, MAE, RMSE] per evaluated step: validation (min_va_val) and the
        # matching test metrics (min_val). Kept as two distinct arrays because
        # model_inference updates min_va_val in place.
        min_va_val = np.array([4e1, 1e5, 1e5] * n_eval_steps)
        min_val = min_va_val.copy()

        for i in range(epoch):
            start_time = time.time()
            for j, x_batch in enumerate(
                    gen_batch(inputs.get_data('train'), batch_size, dynamic_batch=True, shuffle=True)):
                feed = {x: x_batch[:, 0:n_his + 1, :, :], keep_prob: 1.0}
                summary, _ = sess.run([merged, train_op], feed_dict=feed)
                writer.add_summary(summary, i * epoch_step + j)
                if j % 50 == 0:
                    loss_value = sess.run([train_loss, copy_loss], feed_dict=feed)
                    _log(f'Epoch {i:2d}, Step {j:3d}: [{loss_value[0]:.3f}, {loss_value[1]:.3f}]')
            _log(f'Epoch {i:2d} Training Time {time.time() - start_time:.3f}s')

            start_time = time.time()
            min_va_val, min_val = \
                model_inference(sess, pred, inputs, batch_size, n_his, n_pred, step_idx, min_va_val, min_val)

            # Metrics of the k-th evaluated step live at [3k:3k+3] (see math_utils.evaluation).
            for k, ix in enumerate(tmp_idx):
                va, te = min_va_val[3 * k:3 * k + 3], min_val[3 * k:3 * k + 3]
                print(f'Time Step {ix + 1}: '
                      f'MAPE {va[0]:7.3%}, {te[0]:7.3%}; '
                      f'MAE {va[1]:4.3f}, {te[1]:4.3f}; '
                      f'RMSE {va[2]:6.3f}, {te[2]:6.3f}.')
            _log(f'Epoch {i:2d} Inference Time {time.time() - start_time:.3f}s')

            if (i + 1) % args.save == 0:
                model_save(sess, global_steps, 'STGCN')
        writer.close()
    print('Training model finished!')
# --------------------------------------------------------------------------------
# /GNN_output/models/STGCN-1250.data-00000-of-00001:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/wangf622/GPLight/6010c74e4267ed1e8512e3b50737369ff0d3a0cf/GNN_output/models/STGCN-1250.data-00000-of-00001
# --------------------------------------------------------------------------------
# /GNN_output/models/STGCN-1250.index:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/wangf622/GPLight/6010c74e4267ed1e8512e3b50737369ff0d3a0cf/GNN_output/models/STGCN-1250.index
# --------------------------------------------------------------------------------
# /GNN_output/models/STGCN-1250.meta:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/wangf622/GPLight/6010c74e4267ed1e8512e3b50737369ff0d3a0cf/GNN_output/models/STGCN-1250.meta
# --------------------------------------------------------------------------------
# /GNN_output/models/STGCN-2500.data-00000-of-00001:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/wangf622/GPLight/6010c74e4267ed1e8512e3b50737369ff0d3a0cf/GNN_output/models/STGCN-2500.data-00000-of-00001
# --------------------------------------------------------------------------------
# /GNN_output/models/STGCN-2500.index:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/wangf622/GPLight/6010c74e4267ed1e8512e3b50737369ff0d3a0cf/GNN_output/models/STGCN-2500.index
# --------------------------------------------------------------------------------
# /GNN_output/models/STGCN-2500.meta:
# --------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangf622/GPLight/6010c74e4267ed1e8512e3b50737369ff0d3a0cf/GNN_output/models/STGCN-2500.meta -------------------------------------------------------------------------------- /GNN_output/models/STGCN-3750.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangf622/GPLight/6010c74e4267ed1e8512e3b50737369ff0d3a0cf/GNN_output/models/STGCN-3750.data-00000-of-00001 -------------------------------------------------------------------------------- /GNN_output/models/STGCN-3750.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangf622/GPLight/6010c74e4267ed1e8512e3b50737369ff0d3a0cf/GNN_output/models/STGCN-3750.index -------------------------------------------------------------------------------- /GNN_output/models/STGCN-3750.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangf622/GPLight/6010c74e4267ed1e8512e3b50737369ff0d3a0cf/GNN_output/models/STGCN-3750.meta -------------------------------------------------------------------------------- /GNN_output/models/STGCN-5000.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangf622/GPLight/6010c74e4267ed1e8512e3b50737369ff0d3a0cf/GNN_output/models/STGCN-5000.data-00000-of-00001 -------------------------------------------------------------------------------- /GNN_output/models/STGCN-5000.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangf622/GPLight/6010c74e4267ed1e8512e3b50737369ff0d3a0cf/GNN_output/models/STGCN-5000.index -------------------------------------------------------------------------------- /GNN_output/models/STGCN-5000.meta: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangf622/GPLight/6010c74e4267ed1e8512e3b50737369ff0d3a0cf/GNN_output/models/STGCN-5000.meta -------------------------------------------------------------------------------- /GNN_output/models/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "STGCN-5000" 2 | all_model_checkpoint_paths: "STGCN-5000" 3 | -------------------------------------------------------------------------------- /GNN_output/tensorboard/train/events.out.tfevents.1590335967.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangf622/GPLight/6010c74e4267ed1e8512e3b50737369ff0d3a0cf/GNN_output/tensorboard/train/events.out.tfevents.1590335967.zip -------------------------------------------------------------------------------- /GNN_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # @Time : Jan. 
10, 2019 15:24 2 | # @Author : Veritas YIN 3 | # @FileName : __init__.py 4 | # @Version : 1.0 5 | # @IDE : PyCharm 6 | # @Github : https://github.com/VeritasYin/Project_Orion 7 | -------------------------------------------------------------------------------- /GNN_utils/__pycache__/__init__.cpython-36.opt-1.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangf622/GPLight/6010c74e4267ed1e8512e3b50737369ff0d3a0cf/GNN_utils/__pycache__/__init__.cpython-36.opt-1.pyc -------------------------------------------------------------------------------- /GNN_utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangf622/GPLight/6010c74e4267ed1e8512e3b50737369ff0d3a0cf/GNN_utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /GNN_utils/__pycache__/math_graph.cpython-36.opt-1.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangf622/GPLight/6010c74e4267ed1e8512e3b50737369ff0d3a0cf/GNN_utils/__pycache__/math_graph.cpython-36.opt-1.pyc -------------------------------------------------------------------------------- /GNN_utils/__pycache__/math_graph.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangf622/GPLight/6010c74e4267ed1e8512e3b50737369ff0d3a0cf/GNN_utils/__pycache__/math_graph.cpython-36.pyc -------------------------------------------------------------------------------- /GNN_utils/__pycache__/math_utils.cpython-36.opt-1.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangf622/GPLight/6010c74e4267ed1e8512e3b50737369ff0d3a0cf/GNN_utils/__pycache__/math_utils.cpython-36.opt-1.pyc 
# --------------------------------------------------------------------------------
# /GNN_utils/__pycache__/math_utils.cpython-36.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/wangf622/GPLight/6010c74e4267ed1e8512e3b50737369ff0d3a0cf/GNN_utils/__pycache__/math_utils.cpython-36.pyc
# --------------------------------------------------------------------------------
# /GNN_utils/math_graph.py:
# --------------------------------------------------------------------------------
# @Time : Jan. 10, 2019 15:21
# @Author : Veritas YIN
# @FileName : math_graph.py
# @Version : 1.0
# @IDE : PyCharm
# @Github : https://github.com/VeritasYin/Project_Orion

import numpy as np
import pandas as pd
from scipy.sparse.linalg import eigs


def scaled_laplacian(W):
    '''
    Normalized graph Laplacian function, rescaled for Chebyshev approximation.
    Builds L = D - W, normalizes L[i, j] by sqrt(d[i] * d[j]) wherever both degrees are
    positive (other entries are left as-is), and returns 2 * L / lambda_max - I.
    :param W: np.ndarray, [n_route, n_route], weighted adjacency matrix of G.
    :return: np.matrix, [n_route, n_route].
    '''
    # Work in float: with an integer (e.g. 0/1) adjacency matrix, L = -W would be an
    # integer array and every normalized entry would be silently truncated.
    W = np.asarray(W, dtype=float)
    n = W.shape[0]
    # d -> diagonal degree matrix
    d = np.sum(W, axis=1)
    # L -> graph Laplacian
    L = -W
    L[np.diag_indices_from(L)] = d
    # Vectorized symmetric normalization, restricted to pairs with positive degrees.
    positive = d > 0
    both_positive = np.outer(positive, positive)
    denom = np.ones_like(L)
    denom[both_positive] = np.sqrt(np.outer(d, d)[both_positive])
    L = L / denom
    # lambda_max \approx 2.0, the largest eigenvalues of L.
    lambda_max = eigs(L, k=1, which='LR')[0][0].real
    return np.asmatrix(2 * L / lambda_max - np.identity(n))


def cheb_poly_approx(L, Ks, n):
    '''
    Chebyshev polynomials approximation function.
    Stacks T_0(L), ..., T_{Ks-1}(L) side by side, using T_k = 2 L T_{k-1} - T_{k-2}.
    :param L: np.matrix or np.ndarray, [n_route, n_route], graph Laplacian.
    :param Ks: int, kernel size of spatial convolution.
    :param n: int, number of routes / size of graph.
    :return: np.ndarray, [n_route, Ks*n_route].
    :raises ValueError: if Ks < 1.
    '''
    # Plain arrays with `@` make the recurrence a true matrix product for both
    # np.matrix and np.ndarray input (`*` is element-wise on ndarrays).
    L = np.asarray(L)
    L0, L1 = np.identity(n), np.copy(L)

    if Ks > 1:
        L_list = [L0, L1]
        for _ in range(Ks - 2):
            Ln = 2 * L @ L1 - L0
            L_list.append(Ln)
            L0, L1 = L1, Ln
        # L_lsit [Ks, n*n], Lk [n, Ks*n]
        return np.concatenate(L_list, axis=-1)
    elif Ks == 1:
        return np.asarray(L0)
    else:
        raise ValueError(f'ERROR: the size of spatial kernel must be at least 1, but received "{Ks}".')


def first_approx(W, n):
    '''
    1st-order approximation function.
    Returns I + D^{-1/2} (W + I) D^{-1/2}, where D is the degree matrix of W + I.
    :param W: np.ndarray, [n_route, n_route], weighted adjacency matrix of G.
    :param n: int, number of routes / size of graph.
    :return: np.matrix, [n_route, n_route].
    '''
    A = W + np.identity(n)
    d = np.sum(A, axis=1)
    # np.asmatrix instead of np.mat (the alias was removed in NumPy 2.0); `*` below is
    # therefore matrix multiplication.
    sinvD = np.sqrt(np.asmatrix(np.diag(d)).I)
    # refer to Eq.5
    return np.asmatrix(np.identity(n) + sinvD * A * sinvD)


def weight_matrix(file_path, sigma2=0.1, epsilon=0.5, scaling=True):
    '''
    Load weight matrix function.
    :param file_path: str, the path of saved weight matrix file.
    :param sigma2: float, scalar of matrix W.
    :param epsilon: float, thresholds to control the sparsity of matrix W.
    :param scaling: bool, whether applies numerical scaling on W.
    :return: np.ndarray, [n_route, n_route].
    :raises FileNotFoundError: if file_path does not exist.
    '''
    try:
        W = pd.read_csv(file_path, header=None).values
    except FileNotFoundError:
        print(f'ERROR: input file was not found in {file_path}.')
        # Nothing below can run without W; re-raise instead of failing later with NameError.
        raise

    # check whether W is a 0/1 matrix.
    if set(np.unique(W)) == {0, 1}:
        print('The input graph is a 0/1 matrix; set "scaling" to False.')
        scaling = False

    if scaling:
        n = W.shape[0]
        W = W / 10000.
        # element-wise square; W_mask = all-ones minus identity (zeroes the diagonal)
        W2, W_mask = W * W, np.ones([n, n]) - np.identity(n)
        # refer to Eq.10: Gaussian kernel, thresholded at epsilon.
        kernel = np.exp(-W2 / sigma2)
        return kernel * (kernel >= epsilon) * W_mask
    else:
        return W
# --------------------------------------------------------------------------------
# /GNN_utils/math_utils.py:
# --------------------------------------------------------------------------------
# @Time : Jan. 10, 2019 15:15
# @Author : Veritas YIN
# @FileName : math_utils.py
# @Version : 1.0
# @IDE : PyCharm
# @Github : https://github.com/VeritasYin/Project_Orion

import numpy as np


def z_score(x, mean, std):
    r'''
    Z-score normalization function: $z = (X - \mu) / \sigma $,
    where z is the z-score, X is the value of the element,
    $\mu$ is the population mean, and $\sigma$ is the standard deviation.
    :param x: np.ndarray, input array to be normalized.
    :param mean: float, the value of mean.
    :param std: float, the value of standard deviation.
    :return: np.ndarray, z-score normalized array.
    '''
    return (x - mean) / std


def z_inverse(x, mean, std):
    '''
    The inverse of function z_score().
    :param x: np.ndarray, input to be recovered.
    :param mean: float, the value of mean.
    :param std: float, the value of standard deviation.
    :return: np.ndarray, z-score inverse array.
    '''
    return x * std + mean


def MAPE(v, v_):
    '''
    Mean absolute percentage error (1e-5 is added to the denominator to avoid division by zero).
    :param v: np.ndarray or int, ground truth.
    :param v_: np.ndarray or int, prediction.
    :return: float, MAPE averages on all elements of input.
    '''
    return np.mean(np.abs(v_ - v) / (v + 1e-5))


def RMSE(v, v_):
    '''
    Root mean squared error.
    :param v: np.ndarray or int, ground truth.
    :param v_: np.ndarray or int, prediction.
    :return: float, RMSE averages on all elements of input.
    '''
    return np.sqrt(np.mean((v_ - v) ** 2))


def MAE(v, v_):
    '''
    Mean absolute error.
    :param v: np.ndarray or int, ground truth.
    :param v_: np.ndarray or int, prediction.
    :return: float, MAE averages on all elements of input.
    '''
    return np.mean(np.abs(v_ - v))


def evaluation(y, y_, x_stats):
    '''
    Evaluation function: interface to calculate MAPE, MAE and RMSE between ground truth and prediction.
    Extended version: multi-step prediction can be calculated by self-calling.
    :param y: np.ndarray or int, ground truth.
    :param y_: np.ndarray or int, prediction.
    :param x_stats: dict, paras of z-scores (mean & std).
    :return: np.ndarray, averaged metric values: [MAPE, MAE, RMSE] for a 3-D prediction,
             or those triples concatenated step by step for a 4-D prediction.
    '''
    dim = len(y_.shape)

    if dim == 3:
        # single_step case
        v = z_inverse(y, x_stats['mean'], x_stats['std'])
        v_ = z_inverse(y_, x_stats['mean'], x_stats['std'])
        return np.array([MAPE(v, v_), MAE(v, v_), RMSE(v, v_)])
    else:
        # multi_step case
        tmp_list = []
        # y -> [time_step, batch_size, n_route, 1]
        y = np.swapaxes(y, 0, 1)
        # recursively call
        for i in range(y_.shape[0]):
            tmp_res = evaluation(y[i], y_[i], x_stats)
            tmp_list.append(tmp_res)
        return np.concatenate(tmp_list, axis=-1)
# --------------------------------------------------------------------------------
# /README.md:
# --------------------------------------------------------------------------------
# # GPLight
#
# GPLight (See the [paper](https://arxiv.org/abs/2009.14627) in arXiv) is a deep reinforcement learning agent for multi-intersection traffic signal control with GNN predicting future traffic.
#
# We first predict the future traffic throughput using the [STGCN](https://github.com/VeritasYin/STGCN_IJCAI-18) and then control the traffic signal based on the prediction and the current traffic condition.
The code follows a structure similar to [CoLight](https://github.com/wingsweihua/colight). 6 | -------------------------------------------------------------------------------- /agent.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import shutil 4 | 5 | class Agent(object): 6 | 7 | def __init__(self, dic_agent_conf, dic_traffic_env_conf, dic_path, intersection_id="0"): 8 | 9 | self.dic_agent_conf = dic_agent_conf 10 | self.dic_traffic_env_conf = dic_traffic_env_conf 11 | self.dic_path = dic_path 12 | self.intersection_id = intersection_id 13 | 14 | def choose_action(self): 15 | 16 | raise NotImplementedError 17 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # parameters and paths 2 | from baseline.deeplight_agent import DeeplightAgent 3 | from baseline.fixedtime_agent import FixedtimeAgent 4 | from baseline.fixedtimeoffset_agent import FixedtimeOffsetAgent 5 | from baseline.formula_agent import FormulaAgent 6 | from baseline.maxpressure_agent import MaxPressureAgent 7 | from simple_dqn_agent import SimpleDQNAgent 8 | from baseline.random_agent import RandomAgent 9 | from baseline.sotl_agent import SOTLAgent 10 | from CoLight_agent import CoLightAgent 11 | # from gcn_agent import GCNAgent 12 | from simple_dqn_one_agent import SimpleDQNOneAgent 13 | from lit_agent import LitAgent 14 | # from sumo_env import SumoEnv 15 | from anon_env import AnonEnv 16 | from baseline.sliding_formula_agent import SlidingFormulaAgent 17 | 18 | DIC_SLIDINGFORMULA_AGENT_CONF = { 19 | "DAY_TIME": 3600, 20 | "UPDATE_PERIOD": 300, 21 | "FIXED_TIME": [30, 30], 22 | "ROUND_UP": 5, 23 | "PHASE_TO_LANE": [[0, 1], [3, 4], [6, 7], [9, 10]], 24 | "MIN_PHASE_TIME": 5, 25 | "TRAFFIC_FILE": [ "cross.2phases_rou01_equal_450.xml" ] 26 | } 27 | 28 | DIC_SOTL_AGENT_CONF = { 29 | "PHI": 5, 30 | "MIN_GREEN_VEC":
3, 31 | "MAX_RED_VEC": 6, 32 | } 33 | 34 | DIC_EXP_CONF = { 35 | "RUN_COUNTS": 3600, 36 | "TRAFFIC_FILE": [ 37 | "cross.2phases_rou01_equal_450.xml" 38 | ], 39 | "MODEL_NAME": "SimpleDQN", 40 | "NUM_ROUNDS": 200, 41 | "NUM_GENERATORS": 3, 42 | "LIST_MODEL": 43 | ["Fixedtime", "SOTL", "Deeplight", "SimpleDQN"], 44 | "LIST_MODEL_NEED_TO_UPDATE": 45 | ["Deeplight", "SimpleDQN", "CoLight","GCN", "SimpleDQNOne","Lit"], 46 | "MODEL_POOL": False, 47 | "NUM_BEST_MODEL": 3, 48 | "PRETRAIN": True, 49 | "PRETRAIN_MODEL_NAME": "Random", 50 | "PRETRAIN_NUM_ROUNDS": 10, 51 | "PRETRAIN_NUM_GENERATORS": 10, 52 | "AGGREGATE": False, 53 | "DEBUG": False, 54 | "EARLY_STOP": False, 55 | 56 | "MULTI_TRAFFIC": False, 57 | "MULTI_RANDOM": False, 58 | } 59 | 60 | DIC_LIT_AGENT_CONF = { 61 | "LEARNING_RATE": 0.001, 62 | "LR_DECAY": 1, 63 | "MIN_LR": 0.0001, 64 | "UPDATE_PERIOD": 300, 65 | "SAMPLE_SIZE": 300, 66 | "SAMPLE_SIZE_PRETRAIN": 3000, 67 | "BATCH_SIZE": 20, 68 | "EPOCHS": 100, 69 | "EPOCHS_PRETRAIN": 500, 70 | "SEPARATE_MEMORY": True, 71 | "PRIORITY_SAMPLING": False, 72 | "UPDATE_Q_BAR_FREQ": 5, 73 | "GAMMA": 0.8, 74 | "GAMMA_PRETRAIN": 0, 75 | "MAX_MEMORY_LEN": 1000, 76 | "PATIENCE": 10, 77 | "PHASE_SELECTOR": True, 78 | "KEEP_OLD_MEMORY": 0, 79 | "DDQN": False, 80 | "D_DENSE": 20, 81 | "EPSILON": 0.8, 82 | "EPSILON_DECAY": 0.95, 83 | "MIN_EPSILON": 0.2, 84 | "LOSS_FUNCTION": "mean_squared_error", 85 | "NORMAL_FACTOR": 20, 86 | "EARLY_STOP_LOSS": "val_loss", 87 | "DROPOUT_RATE": 0 88 | } 89 | 90 | 91 | DIC_FIXEDTIME_AGENT_CONF = { 92 | "FIXED_TIME": [ 93 | 30, 94 | 30 95 | ], 96 | } 97 | 98 | DIC_SOTL_AGENT_CONF = { 99 | "PHI": 5, 100 | "MIN_GREEN_VEC": 3, 101 | "MAX_RED_VEC": 6, 102 | } 103 | 104 | DIC_RANDOM_AGENT_CONF = { 105 | 106 | } 107 | 108 | DIC_FORMULA_AGENT_CONF = { 109 | "DAY_TIME": 3600, 110 | "UPDATE_PERIOD": 3600, 111 | "FIXED_TIME": [30, 30], 112 | "ROUND_UP": 5, 113 | "PHASE_TO_LANE": [[0, 1], [2, 3]], 114 | "MIN_PHASE_TIME": 5, 115 | "TRAFFIC_FILE": [ 116 | 
"cross.2phases_rou01_equal_450.xml" 117 | ], 118 | } 119 | 120 | dic_traffic_env_conf = { 121 | "ACTION_PATTERN": "set", 122 | "NUM_INTERSECTIONS": 1, 123 | "MIN_ACTION_TIME": 10, 124 | "YELLOW_TIME": 5, 125 | "ALL_RED_TIME": 0, 126 | "NUM_PHASES": 2, 127 | "NUM_LANES": 1, 128 | "ACTION_DIM": 2, 129 | "MEASURE_TIME": 10, 130 | "IF_GUI": True, 131 | "DEBUG": False, 132 | 133 | "INTERVAL": 1, 134 | "THREADNUM": 8, 135 | "SAVEREPLAY": True, 136 | "RLTRAFFICLIGHT": True, 137 | 138 | "DIC_FEATURE_DIM": dict( 139 | D_LANE_QUEUE_LENGTH=(4,), 140 | D_LANE_NUM_VEHICLE=(4,), 141 | 142 | D_COMING_VEHICLE = (4,), 143 | D_LEAVING_VEHICLE = (4,), 144 | 145 | D_LANE_NUM_VEHICLE_BEEN_STOPPED_THRES1=(4,), 146 | D_CUR_PHASE=(1,), 147 | D_NEXT_PHASE=(1,), 148 | D_TIME_THIS_PHASE=(1,), 149 | D_TERMINAL=(1,), 150 | D_LANE_SUM_WAITING_TIME=(4,), 151 | D_VEHICLE_POSITION_IMG=(4, 60,), 152 | D_VEHICLE_SPEED_IMG=(4, 60,), 153 | D_VEHICLE_WAITING_TIME_IMG=(4, 60,), 154 | 155 | D_PRESSURE=(1,), 156 | 157 | D_ADJACENCY_MATRIX=(2,) 158 | ), 159 | 160 | "LIST_STATE_FEATURE": [ 161 | "cur_phase", 162 | "time_this_phase", 163 | "vehicle_position_img", 164 | "vehicle_speed_img", 165 | "vehicle_acceleration_img", 166 | "vehicle_waiting_time_img", 167 | "lane_num_vehicle", 168 | "lane_num_vehicle_been_stopped_thres01", 169 | "lane_num_vehicle_been_stopped_thres1", 170 | "lane_queue_length", 171 | "lane_num_vehicle_left", 172 | "lane_sum_duration_vehicle_left", 173 | "lane_sum_waiting_time", 174 | "terminal", 175 | 176 | "coming_vehicle", 177 | "leaving_vehicle", 178 | "pressure", 179 | 180 | "adjacency_matrix", 181 | "adjacency_matrix_lane" 182 | 183 | ], 184 | 185 | "DIC_REWARD_INFO": { 186 | "flickering": 0, 187 | "sum_lane_queue_length": 0, 188 | "sum_lane_wait_time": 0, 189 | "sum_lane_num_vehicle_left": 0, 190 | "sum_duration_vehicle_left": 0, 191 | "sum_num_vehicle_been_stopped_thres01": 0, 192 | "sum_num_vehicle_been_stopped_thres1": -0.25, 193 | "pressure": 0, 194 | }, 195 | 196 | 
"LANE_NUM": { 197 | "LEFT": 1, 198 | "RIGHT": 1, 199 | "STRAIGHT": 1 200 | }, 201 | 202 | "PHASE": [ 203 | 'WSES', 204 | 'NSSS', 205 | 'WLEL', 206 | 'NLSL', 207 | # 'WSWL', 208 | # 'ESEL', 209 | # 'NSNL', 210 | # 'SSSL', 211 | ], 212 | 213 | } 214 | 215 | _LS = {"LEFT": 1, 216 | "RIGHT": 0, 217 | "STRAIGHT": 1 218 | } 219 | _S = { 220 | "LEFT": 0, 221 | "RIGHT": 0, 222 | "STRAIGHT": 1 223 | } 224 | _LSR = { 225 | "LEFT": 1, 226 | "RIGHT": 1, 227 | "STRAIGHT": 1 228 | } 229 | 230 | DIC_DEEPLIGHT_AGENT_CONF = { 231 | "LEARNING_RATE": 0.001, 232 | "UPDATE_PERIOD": 300, 233 | "SAMPLE_SIZE": 300, 234 | "SAMPLE_SIZE_PRETRAIN": 3000, 235 | "BATCH_SIZE": 20, 236 | "EPOCHS": 100, 237 | "EPOCHS_PRETRAIN": 500, 238 | "SEPARATE_MEMORY": True, 239 | "PRIORITY_SAMPLING": False, 240 | "UPDATE_Q_BAR_FREQ": 5, 241 | "GAMMA": 0.8, 242 | "GAMMA_PRETRAIN": 0, 243 | "MAX_MEMORY_LEN": 1000, 244 | "PATIENCE": 10, 245 | "PHASE_SELECTOR": True, 246 | "KEEP_OLD_MEMORY": 0, 247 | "DDQN": False, 248 | "D_DENSE": 20, 249 | "EPSILON": 0.8, 250 | "EPSILON_DECAY": 0.95, 251 | "MIN_EPSILON": 0.2, 252 | "LOSS_FUNCTION": "mean_squared_error", 253 | "NORMAL_FACTOR": 20, 254 | 255 | } 256 | 257 | DIC_SIMPLEDQN_AGENT_CONF = { 258 | "LEARNING_RATE": 0.001, 259 | "SAMPLE_SIZE": 1000, 260 | "BATCH_SIZE": 20, 261 | "EPOCHS": 100, 262 | "UPDATE_Q_BAR_FREQ": 5, 263 | "UPDATE_Q_BAR_EVERY_C_ROUND": False, 264 | "GAMMA": 0.8, 265 | "MAX_MEMORY_LEN": 10000, 266 | "PATIENCE": 10, 267 | "D_DENSE": 20, 268 | "N_LAYER": 2, 269 | "EPSILON": 0.8, 270 | "EPSILON_DECAY": 0.95, 271 | "MIN_EPSILON": 0.2, 272 | "LOSS_FUNCTION": "mean_squared_error", 273 | "SEPARATE_MEMORY": False, 274 | "NORMAL_FACTOR": 20, 275 | "TRAFFIC_FILE": "cross.2phases_rou01_equal_450.xml", 276 | } 277 | 278 | DIC_SIMPLEDQNONE_AGENT_CONF = { 279 | "LEARNING_RATE": 0.001, 280 | "SAMPLE_SIZE": 1000, 281 | "BATCH_SIZE": 20, 282 | "EPOCHS": 100, 283 | "UPDATE_Q_BAR_FREQ": 5, 284 | "UPDATE_Q_BAR_EVERY_C_ROUND": False, 285 | "GAMMA": 0.8, 286 | 
"MAX_MEMORY_LEN": 10000, 287 | "PATIENCE": 10, 288 | "D_DENSE": 20, 289 | "N_LAYER": 2, 290 | "EPSILON": 0.8, 291 | "EPSILON_DECAY": 0.95, 292 | "MIN_EPSILON": 0.2, 293 | "LOSS_FUNCTION": "mean_squared_error", 294 | "SEPARATE_MEMORY": False, 295 | "NORMAL_FACTOR": 20, 296 | "TRAFFIC_FILE": "cross.2phases_rou01_equal_450.xml", 297 | } 298 | 299 | DIC_MAXPRESSURE_AGENT_CONF = { 300 | 100: [3, 3], 301 | 200: [3, 3], 302 | 300: [7, 7], 303 | 400: [12, 12], 304 | 500: [15, 15], 305 | 600: [28, 28], 306 | 700: [53, 53] 307 | } 308 | 309 | DIC_COLIGHT_AGENT_CONF = { 310 | "CNN_layers":[[32,32]],#,[32,32],[32,32],[32,32]], 311 | "att_regularization":False, 312 | "rularization_rate":0.03, 313 | "LEARNING_RATE": 0.001, 314 | "SAMPLE_SIZE": 1000, 315 | "BATCH_SIZE": 20, 316 | "EPOCHS": 100, 317 | "UPDATE_Q_BAR_FREQ": 5, 318 | "UPDATE_Q_BAR_EVERY_C_ROUND": False, 319 | "GAMMA": 0.8, 320 | "MAX_MEMORY_LEN": 10000, 321 | "PATIENCE": 10, 322 | "D_DENSE": 20, 323 | "N_LAYER": 2, 324 | #special care for pretrain 325 | "EPSILON": 0.8, 326 | "EPSILON_DECAY": 0.95, 327 | "MIN_EPSILON": 0.2, 328 | 329 | "LOSS_FUNCTION": "mean_squared_error", 330 | "SEPARATE_MEMORY": False, 331 | "NORMAL_FACTOR": 20, 332 | "TRAFFIC_FILE": "cross.2phases_rou01_equal_450.xml", 333 | } 334 | 335 | """ 336 | CONNECTED: 337 | - True->intersection level 338 | - EXTERNAL: 339 | True->exclude lanes from the same intersection 340 | False->all lanes from neighboring intersections 341 | - False->lane level 342 | - IN_OUT_MODE: 343 | - IN: all lanes which have flows into the target lane 344 | - OUT: all lanes which have flow out of the target lane 345 | - IN_OUT: all the alnes that are connected to the target lane 346 | """ 347 | DIC_COLIGHT_SIGNAL_AGENT_CONF = { 348 | "CONNECTED":True, 349 | "EXTERNAL":True, 350 | "IN_OUT_MODE":["IN","OUT","IN_OUT"][0], 351 | # "traffic_name":'Atlanta',#'LA' or 'Atlanta' 352 | "LEARNING_RATE": 0.001, 353 | "SAMPLE_SIZE": 1000, 354 | "BATCH_SIZE": 20, 355 | "EPOCHS": 100, 356 | 
"UPDATE_Q_BAR_FREQ": 5, 357 | "UPDATE_Q_BAR_EVERY_C_ROUND": False, 358 | "GAMMA": 0.8, 359 | "MAX_MEMORY_LEN": 10000, 360 | "PATIENCE": 10, 361 | "D_DENSE": 20, 362 | "N_LAYER": 2, 363 | #special care for pretrain 364 | "EPSILON": 0.8, 365 | "EPSILON_DECAY": 0.95, 366 | "MIN_EPSILON": 0.2, 367 | 368 | "LOSS_FUNCTION": "mean_squared_error", 369 | "SEPARATE_MEMORY": False, 370 | "NORMAL_FACTOR": 20, 371 | "TRAFFIC_FILE": "cross.2phases_rou01_equal_450.xml", 372 | } 373 | 374 | DIC_GATREGU_AGENT_CONF = { 375 | "rularization_rate":0.03, 376 | "LEARNING_RATE": 0.001, 377 | "SAMPLE_SIZE": 1000, 378 | "BATCH_SIZE": 20, 379 | "EPOCHS": 100, 380 | "UPDATE_Q_BAR_FREQ": 5, 381 | "UPDATE_Q_BAR_EVERY_C_ROUND": False, 382 | "GAMMA": 0.8, 383 | "MAX_MEMORY_LEN": 10000, 384 | "PATIENCE": 10, 385 | "D_DENSE": 20, 386 | "N_LAYER": 2, 387 | #special care for pretrain 388 | "EPSILON": 0.8, 389 | "EPSILON_DECAY": 0.95, 390 | "MIN_EPSILON": 0.2, 391 | 392 | "LOSS_FUNCTION": "mean_squared_error", 393 | "SEPARATE_MEMORY": False, 394 | "NORMAL_FACTOR": 20, 395 | "TRAFFIC_FILE": "cross.2phases_rou01_equal_450.xml", 396 | } 397 | 398 | DIC_STGAT_AGENT_CONF = { 399 | "PRIORITY": False, 400 | "nan_code":True, 401 | "att_regularization":False, 402 | "rularization_rate":0.03, 403 | "LEARNING_RATE": 0.001, 404 | "SAMPLE_SIZE": 1000, 405 | "BATCH_SIZE": 20, 406 | "EPOCHS": 100, 407 | "UPDATE_Q_BAR_FREQ": 5, 408 | "UPDATE_Q_BAR_EVERY_C_ROUND": False, 409 | "GAMMA": 0.8, 410 | "MAX_MEMORY_LEN": 10000, 411 | "PATIENCE": 10, 412 | "D_DENSE": 20, 413 | "N_LAYER": 2, 414 | #special care for pretrain 415 | "EPSILON": 0.8, 416 | "EPSILON_DECAY": 0.95, 417 | "MIN_EPSILON": 0.2, 418 | 419 | "LOSS_FUNCTION": "mean_squared_error", 420 | "SEPARATE_MEMORY": False, 421 | "NORMAL_FACTOR": 20, 422 | "TRAFFIC_FILE": "cross.2phases_rou01_equal_450.xml", 423 | } 424 | 425 | DIC_GCN_AGENT_CONF = { 426 | "LEARNING_RATE": 0.001, 427 | "SAMPLE_SIZE": 1000, 428 | "BATCH_SIZE": 20, 429 | "EPOCHS": 100, 430 | 
"UPDATE_Q_BAR_FREQ": 5, 431 | "UPDATE_Q_BAR_EVERY_C_ROUND": False, 432 | "GAMMA": 0.8, 433 | "MAX_MEMORY_LEN": 10000, 434 | "PATIENCE": 10, 435 | "D_DENSE": [20,20], 436 | "N_LAYER": 2, 437 | "DROPOUT":0, 438 | "EPSILON": 0.8, 439 | "EPSILON_DECAY": 0.95, 440 | "MIN_EPSILON": 0.2, 441 | "LOSS_FUNCTION": "mean_squared_error", 442 | "SEPARATE_MEMORY": False, 443 | "NORMAL_FACTOR": 20, 444 | "TRAFFIC_FILE": "cross.2phases_rou01_equal_450.xml", 445 | } 446 | 447 | DIC_PATH = { 448 | "PATH_TO_MODEL": "model/default", 449 | "PATH_TO_WORK_DIRECTORY": "records/default", 450 | "PATH_TO_DATA": "data/template", 451 | "PATH_TO_PRETRAIN_MODEL": "model/default", 452 | "PATH_TO_PRETRAIN_WORK_DIRECTORY": "records/default", 453 | "PATH_TO_PRETRAIN_DATA": "data/template", 454 | "PATH_TO_AGGREGATE_SAMPLES": "records/initial", 455 | "PATH_TO_ERROR": "errors/default" 456 | } 457 | 458 | DIC_AGENTS = { 459 | "Deeplight": DeeplightAgent, 460 | "Fixedtime": FixedtimeAgent, 461 | "FixedtimeOffset": FixedtimeOffsetAgent, 462 | "SimpleDQN": SimpleDQNAgent, 463 | "Formula": FormulaAgent, 464 | "Random": RandomAgent, 465 | 'CoLight': CoLightAgent, 466 | # 'GCN': GCNAgent, 467 | 'SimpleDQNOne': SimpleDQNOneAgent, 468 | "MaxPressure": MaxPressureAgent, 469 | "Lit":LitAgent, 470 | 'SOTL': SOTLAgent, 471 | 'SlidingFormula': SlidingFormulaAgent 472 | } 473 | 474 | DIC_ENVS = { 475 | "sumo": AnonEnv, #deprecated 476 | # "sumo": SumoEnv, 477 | "anon": AnonEnv 478 | } 479 | 480 | DIC_FIXEDTIMEOFFSET_AGENT_CONF = { 481 | 100: [3, 3], 482 | 200: [3, 3], 483 | 300: [7, 7], 484 | 400: [12, 12], 485 | 500: [15, 15], 486 | 600: [28, 28], 487 | 700: [53, 53] 488 | } 489 | 490 | DIC_FIXEDTIME = { 491 | 100: [3, 3], 492 | 200: [3, 3], 493 | 300: [7, 7], 494 | 400: [12, 12], 495 | 500: [15, 15], 496 | 600: [28, 28], 497 | 700: [53, 53] 498 | } 499 | 500 | # min_action_time: 1 501 | DIC_FIXEDTIME_NEW_SUMO = { 502 | 100: [4, 4], 503 | 200: [4, 4], 504 | 300: [5, 5], 505 | 400: [13, 13], 506 | 500: [9, 9], 507 | 
600: [15, 15], 508 | 700: [22, 22] 509 | } 510 | 511 | DIC_FIXEDTIME_MULTI_PHASE = { 512 | 100: [4 for _ in range(4)], 513 | 200: [4 for _ in range(4)], 514 | 300: [5 for _ in range(4)], 515 | 400: [8 for _ in range(4)], 516 | 500: [9 for _ in range(4)], 517 | 600: [13 for _ in range(4)], 518 | 700: [23 for _ in range(4)], 519 | 750: [30 for _ in range(4)] 520 | } 521 | 522 | #DIC_FIXEDTIME_MULTI_PHASE = { 523 | # 100: [2 for _ in range(4)], 524 | # 200: [2 for _ in range(4)], 525 | # 300: [1 for _ in range(4)], 526 | # 400: [8 for _ in range(4)], 527 | # 500: [16 for _ in range(4)], 528 | # 600: [28 for _ in range(4)], 529 | # 700: [55 for _ in range(4)] 530 | #} 531 | 532 | #DIC_FIXEDTIME_MULTI_PHASE = { 533 | # 100: [5, 5, 5, 5], 534 | # 200: [5, 5, 5, 5], 535 | # 300: [5, 5, 5, 5], 536 | # 400: [15, 15, 15, 15], 537 | # 500: [20, 20, 20, 20], 538 | # 600: [35, 35, 35, 35], 539 | # 700: [50, 50, 50, 50] 540 | #} -------------------------------------------------------------------------------- /construct_sample.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | from multiprocessing import Pool, Process 4 | import os 5 | import traceback 6 | import pandas as pd 7 | 8 | class ConstructSample: 9 | 10 | def __init__(self, path_to_samples, cnt_round, dic_traffic_env_conf): 11 | self.parent_dir = path_to_samples 12 | self.path_to_samples = path_to_samples + "/round_" + str(cnt_round) 13 | #'.../train_round/round_1' 14 | self.cnt_round = cnt_round 15 | self.dic_traffic_env_conf = dic_traffic_env_conf 16 | self.logging_data_list_per_gen = None 17 | self.hidden_states_list = None 18 | self.samples = [] 19 | self.samples_all_intersection = [None]*self.dic_traffic_env_conf['NUM_INTERSECTIONS'] 20 | 21 | def load_data(self, folder, i): 22 | 23 | try: 24 | f_logging_data = open(os.path.join(self.path_to_samples, folder, "inter_{0}.pkl".format(i)), "rb") 25 | logging_data = 
pickle.load(f_logging_data) 26 | f_logging_data.close() 27 | return 1, logging_data 28 | 29 | except Exception as e: 30 | print("Error occurs when making samples for inter {0}".format(i)) 31 | print('traceback.format_exc():\n%s' % traceback.format_exc()) 32 | return 0, None 33 | 34 | def load_data_for_system(self, folder): 35 | ''' 36 | Load data for all intersections in one folder 37 | :param folder: 38 | :return: a list of logging data of one intersection for one folder 39 | ''' 40 | self.logging_data_list_per_gen = [] 41 | # load settings 42 | print("Load data for system in ", folder) 43 | self.measure_time = self.dic_traffic_env_conf["MEASURE_TIME"] 44 | self.interval = self.dic_traffic_env_conf["MIN_ACTION_TIME"] 45 | 46 | for i in range(self.dic_traffic_env_conf['NUM_INTERSECTIONS']): 47 | pass_code, logging_data = self.load_data(folder, i) 48 | if pass_code == 0: 49 | return 0 50 | self.logging_data_list_per_gen.append(logging_data) 51 | return 1 52 | 53 | def load_hidden_state_for_system(self, folder): 54 | print("loading hidden states: {0}".format(os.path.join(self.path_to_samples, folder, "hidden_states.pkl"))) 55 | # load settings 56 | if self.hidden_states_list is None: 57 | self.hidden_states_list = [] 58 | 59 | try: 60 | f_hidden_state_data = open(os.path.join(self.path_to_samples, folder, "hidden_states.pkl"), "rb") 61 | hidden_state_data = pickle.load(f_hidden_state_data) # hidden state_data is a list of numpy array 62 | # print(hidden_state_data) 63 | print(len(hidden_state_data)) 64 | hidden_state_data_h_c = np.stack(hidden_state_data, axis=2) 65 | hidden_state_data_h_c = pd.Series(list(hidden_state_data_h_c)) 66 | next_hidden_state_data_h_c = hidden_state_data_h_c.shift(-1) 67 | hidden_state_data_h_c_with_next = pd.concat([hidden_state_data_h_c,next_hidden_state_data_h_c], axis=1) 68 | hidden_state_data_h_c_with_next.columns = ['cur_hidden','next_hidden'] 69 | self.hidden_states_list.append(hidden_state_data_h_c_with_next[:-1].values) 70 | return 
1 71 | except Exception as e: 72 | print("Error occurs when loading hidden states in ", folder) 73 | print('traceback.format_exc():\n%s' % traceback.format_exc()) 74 | return 0 75 | 76 | 77 | 78 | def construct_state(self,features,time,i): 79 | ''' 80 | 81 | :param features: 82 | :param time: 83 | :param i: intersection id 84 | :return: 85 | ''' 86 | 87 | state = self.logging_data_list_per_gen[i][time] 88 | assert time == state["time"] 89 | if self.dic_traffic_env_conf["BINARY_PHASE_EXPANSION"]: 90 | state_after_selection = {} 91 | for key, value in state["state"].items(): 92 | if key in features: 93 | if "cur_phase" in key: 94 | state_after_selection[key] = self.dic_traffic_env_conf['PHASE'][self.dic_traffic_env_conf['SIMULATOR_TYPE']][value[0]] 95 | else: 96 | state_after_selection[key] = value 97 | else: 98 | state_after_selection = {key: value for key, value in state["state"].items() if key in features} 99 | # print(state_after_selection) 100 | return state_after_selection 101 | 102 | 103 | def _construct_state_process(self, features, time, state, i): 104 | assert time == state["time"] 105 | if self.dic_traffic_env_conf["BINARY_PHASE_EXPANSION"]: 106 | state_after_selection = {} 107 | for key, value in state["state"].items(): 108 | if key in features: 109 | if "cur_phase" in key: 110 | state_after_selection[key] = self.dic_traffic_env_conf['PHASE'][self.dic_traffic_env_conf['SIMULATOR_TYPE']][value[0]] 111 | else: 112 | state_after_selection[key] = value 113 | else: 114 | state_after_selection = {key: value for key, value in state["state"].items() if key in features} 115 | return state_after_selection, i 116 | 117 | 118 | def get_reward_from_features(self, rs): 119 | reward = {} 120 | reward['reward'] = float(rs['reward']) 121 | return reward 122 | 123 | def cal_reward(self, rs, rewards_components): 124 | r = rs['reward'] 125 | return r 126 | 127 | def construct_reward(self,rewards_components,time, i,phase_duration): 128 | #change self.measure_time -->p 
hase_duration 129 | 130 | rs = self.logging_data_list_per_gen[i][time + phase_duration - 1] 131 | assert time + phase_duration - 1 == rs["time"] 132 | rs = self.get_reward_from_features(rs['state']) 133 | r_instant = self.cal_reward(rs, rewards_components) 134 | 135 | # average 136 | list_r = [] 137 | for t in range(time, time + phase_duration): 138 | #print("t is ", t) 139 | rs = self.logging_data_list_per_gen[i][t] 140 | assert t == rs["time"] 141 | rs = self.get_reward_from_features(rs['state']) 142 | r = self.cal_reward(rs, rewards_components) 143 | list_r.append(r) 144 | r_average = np.average(list_r) 145 | 146 | return r_instant, r_average 147 | 148 | def judge_action(self,time,i): 149 | if self.logging_data_list_per_gen[i][time]['action'] == -1: 150 | raise ValueError 151 | else: 152 | return self.logging_data_list_per_gen[i][time]['action'] 153 | 154 | 155 | def make_reward(self, folder, i): 156 | ''' 157 | make reward for one folder and one intersection, 158 | add the samples of one intersection into the list.samples_all_intersection[i] 159 | :param i: intersection id 160 | :return: 161 | ''' 162 | if self.samples_all_intersection[i] is None: 163 | self.samples_all_intersection[i] = [] 164 | 165 | if i % 100 == 0: 166 | print("make reward for inter {0} in folder {1}".format(i, folder)) 167 | 168 | list_samples = [] 169 | 170 | 171 | with open(os.path.join(self.path_to_samples, folder, "phase_time.txt"), "r") as phase_time_file: 172 | phase_time_list = phase_time_file.read() 173 | phase_time_list = phase_time_list.split('\n') 174 | phase_time_list = [int(i) for i in phase_time_list] 175 | phase_time_list_acc = [int(sum(phase_time_list[:i])) for i in range(len(phase_time_list)+1)] 176 | #[ 10,10,20] 177 | #[0,10,20,40,...,3400 or 3500], 3600 or 3700] 178 | 179 | 180 | try: 181 | total_time = int(self.logging_data_list_per_gen[i][-1]['time'] + 1)#3599+1=3600 182 | # construct samples 183 | time_count = 0 184 | #for time in range(0, total_time - 
self.measure_time + 1, self.interval): 185 | #3600-10+1=3591 186 | #0,10,20,...,3590 187 | # self.interval --> phase_duration 188 | 189 | for time_index,time in enumerate(phase_time_list_acc[:-1]): 190 | phase_duration = phase_time_list[time_index] 191 | 192 | state = self.construct_state(self.dic_traffic_env_conf["LIST_STATE_FEATURE"], time, i) 193 | reward_instant, reward_average = self.construct_reward(self.dic_traffic_env_conf["DIC_REWARD_INFO"], 194 | time, i,phase_duration) 195 | action = self.judge_action(time, i) 196 | 197 | if time + phase_duration == total_time: 198 | next_state = self.construct_state(self.dic_traffic_env_conf["LIST_STATE_FEATURE"], 199 | time + phase_duration - 1, i) 200 | 201 | else: 202 | next_state = self.construct_state(self.dic_traffic_env_conf["LIST_STATE_FEATURE"], 203 | time + phase_duration, i) 204 | sample = [state, action, next_state, reward_average, reward_instant, time, 205 | folder+"-"+"round_{0}".format(self.cnt_round)] 206 | list_samples.append(sample) 207 | 208 | 209 | # list_samples = self.evaluate_sample(list_samples) 210 | self.samples_all_intersection[i].extend(list_samples) 211 | return 1 212 | except Exception as e: 213 | print("Error occurs when making rewards in generator {0} for intersection {1}".format(folder, i)) 214 | print('traceback.format_exc():\n%s' % traceback.format_exc()) 215 | return 0 216 | 217 | 218 | def make_reward_for_system(self): 219 | ''' 220 | Iterate all the generator folders, and load all the logging data for all intersections for that folder 221 | At last, save all the logging data for that intersection [all the generators] 222 | :return: 223 | ''' 224 | for folder in os.listdir(self.path_to_samples): 225 | print(folder) 226 | if "generator" not in folder: 227 | continue 228 | 229 | if not self.evaluate_sample(folder) or not self.load_data_for_system(folder): 230 | continue 231 | 232 | for i in range(self.dic_traffic_env_conf['NUM_INTERSECTIONS']): 233 | pass_code = 
self.make_reward(folder, i) 234 | if pass_code == 0: 235 | continue 236 | 237 | for i in range(self.dic_traffic_env_conf['NUM_INTERSECTIONS']): 238 | self.dump_sample(self.samples_all_intersection[i],"inter_{0}".format(i)) 239 | 240 | 241 | def dump_hidden_states(self, folder): 242 | total_hidden_states = np.vstack(self.hidden_states_list) 243 | print("dump_hidden_states shape:",total_hidden_states.shape) 244 | if folder == "": 245 | with open(os.path.join(self.parent_dir, "total_hidden_states.pkl"),"ab+") as f: 246 | pickle.dump(total_hidden_states, f, -1) 247 | elif "inter" in folder: 248 | with open(os.path.join(self.parent_dir, "total_hidden_states_{0}.pkl".format(folder)),"ab+") as f: 249 | pickle.dump(total_hidden_states, f, -1) 250 | else: 251 | with open(os.path.join(self.path_to_samples, folder, "hidden_states_{0}.pkl".format(folder)),'wb') as f: 252 | pickle.dump(total_hidden_states, f, -1) 253 | 254 | 255 | # def evaluate_sample(self,list_samples): 256 | # return list_samples 257 | 258 | 259 | def evaluate_sample(self, generator_folder): 260 | return True 261 | print("Evaluate samples") 262 | list_files = os.listdir(os.path.join(self.path_to_samples, generator_folder, "")) 263 | df = [] 264 | # print(list_files) 265 | for file in list_files: 266 | if ".csv" not in file: 267 | continue 268 | data = pd.read_csv(os.path.join(self.path_to_samples, generator_folder, file)) 269 | df.append(data) 270 | df = pd.concat(df) 271 | num_vehicles = len(df['Unnamed: 0'].unique()) -len(df[df['leave_time'].isna()]['leave_time']) 272 | if num_vehicles < self.dic_traffic_env_conf['VOLUME']* self.dic_traffic_env_conf['NUM_ROW'] and self.cnt_round > 40: # Todo Heuristic 273 | print("Dumpping samples from ",generator_folder) 274 | return False 275 | else: 276 | return True 277 | 278 | 279 | def dump_sample(self, samples, folder): 280 | if folder == "": 281 | with open(os.path.join(self.parent_dir, "total_samples.pkl"),"ab+") as f: 282 | pickle.dump(samples, f, -1) 283 | elif 
"inter" in folder: 284 | with open(os.path.join(self.parent_dir, "total_samples_{0}.pkl".format(folder)),"ab+") as f: 285 | pickle.dump(samples, f, -1) 286 | else: 287 | with open(os.path.join(self.path_to_samples, folder, "samples_{0}.pkl".format(folder)),'wb') as f: 288 | pickle.dump(samples, f, -1) 289 | 290 | 291 | if __name__=="__main__": 292 | path_to_samples = "/Users/Wingslet/PycharmProjects/RLSignal/records/test/anon_3_3_test/train_round" 293 | generator_folder = "generator_0" 294 | 295 | dic_traffic_env_conf = { 296 | 297 | "NUM_INTERSECTIONS": 9, 298 | "ACTION_PATTERN": "set", 299 | "MEASURE_TIME": 10, 300 | "MIN_ACTION_TIME": 10, 301 | "DEBUG": False, 302 | "BINARY_PHASE_EXPANSION": True, 303 | "FAST_COMPUTE": True, 304 | 305 | "NEIGHBOR": False, 306 | "MODEL_NAME": "STGAT", 307 | "SIMULATOR_TYPE": "anon", 308 | 309 | 310 | 311 | "SAVEREPLAY": False, 312 | "NUM_ROW": 3, 313 | "NUM_COL": 3, 314 | 315 | "VOLUME": 300, 316 | "ROADNET_FILE": "roadnet_{0}.json".format("3_3"), 317 | 318 | "LIST_STATE_FEATURE": [ 319 | "cur_phase", 320 | # "time_this_phase", 321 | # "vehicle_position_img", 322 | # "vehicle_speed_img", 323 | # "vehicle_acceleration_img", 324 | # "vehicle_waiting_time_img", 325 | "lane_num_vehicle", 326 | # "lane_num_vehicle_been_stopped_thres01", 327 | # "lane_num_vehicle_been_stopped_thres1", 328 | # "lane_queue_length", 329 | # "lane_num_vehicle_left", 330 | # "lane_sum_duration_vehicle_left", 331 | # "lane_sum_waiting_time", 332 | # "terminal", 333 | # "coming_vehicle", 334 | # "leaving_vehicle", 335 | # "pressure" 336 | 337 | # "adjacency_matrix", 338 | # "lane_queue_length", 339 | ], 340 | 341 | "DIC_FEATURE_DIM": dict( 342 | D_LANE_QUEUE_LENGTH=(4,), 343 | D_LANE_NUM_VEHICLE=(4,), 344 | 345 | D_COMING_VEHICLE = (12,), 346 | D_LEAVING_VEHICLE = (12,), 347 | 348 | D_LANE_NUM_VEHICLE_BEEN_STOPPED_THRES1=(4,), 349 | D_CUR_PHASE=(1,), 350 | D_NEXT_PHASE=(1,), 351 | D_TIME_THIS_PHASE=(1,), 352 | D_TERMINAL=(1,), 353 | 
D_LANE_SUM_WAITING_TIME=(4,), 354 | D_VEHICLE_POSITION_IMG=(4, 60,), 355 | D_VEHICLE_SPEED_IMG=(4, 60,), 356 | D_VEHICLE_WAITING_TIME_IMG=(4, 60,), 357 | 358 | D_PRESSURE=(1,), 359 | 360 | D_ADJACENCY_MATRIX=(2,), 361 | 362 | ), 363 | 364 | "DIC_REWARD_INFO": { 365 | "flickering": 0, 366 | "sum_lane_queue_length": 0, 367 | "sum_lane_wait_time": 0, 368 | "sum_lane_num_vehicle_left": 0, 369 | "sum_duration_vehicle_left": 0, 370 | "sum_num_vehicle_been_stopped_thres01": 0, 371 | "sum_num_vehicle_been_stopped_thres1": -0.25, 372 | "pressure": 0 # -0.25 373 | }, 374 | 375 | "LANE_NUM": { 376 | "LEFT": 1, 377 | "RIGHT": 1, 378 | "STRAIGHT": 1 379 | }, 380 | 381 | "PHASE": { 382 | "sumo": { 383 | 0: [0, 1, 0, 1, 0, 0, 0, 0],# 'WSES', 384 | 1: [0, 0, 0, 0, 0, 1, 0, 1],# 'NSSS', 385 | 2: [1, 0, 1, 0, 0, 0, 0, 0],# 'WLEL', 386 | 3: [0, 0, 0, 0, 1, 0, 1, 0]# 'NLSL', 387 | }, 388 | "anon": { 389 | # 0: [0, 0, 0, 0, 0, 0, 0, 0], 390 | 1: [0, 1, 0, 1, 0, 0, 0, 0],# 'WSES', 391 | 2: [0, 0, 0, 0, 0, 1, 0, 1],# 'NSSS', 392 | 3: [1, 0, 1, 0, 0, 0, 0, 0],# 'WLEL', 393 | 4: [0, 0, 0, 0, 1, 0, 1, 0]# 'NLSL', 394 | # 'WSWL', 395 | # 'ESEL', 396 | # 'WSES', 397 | # 'NSSS', 398 | # 'NSNL', 399 | # 'SSSL', 400 | }, 401 | } 402 | } 403 | 404 | cs = ConstructSample(path_to_samples, 0, dic_traffic_env_conf) 405 | cs.make_reward_for_system() 406 | 407 | -------------------------------------------------------------------------------- /data/synthetic/synthetic.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "vehicle": { 4 | "length": 5.0, 5 | "width": 2.0, 6 | "maxPosAcc": 2.0, 7 | "maxNegAcc": 4.5, 8 | "usualPosAcc": 2.0, 9 | "usualNegAcc": 4.5, 10 | "minGap": 2.5, 11 | "maxSpeed": 11.111, 12 | "headwayTime": 2 13 | }, 14 | "route": [ 15 | "road_0_1_0", 16 | "road_1_1_0", 17 | "road_2_1_0", 18 | "road_3_1_0" 19 | ], 20 | "interval": 20.0, 21 | "startTime": 0, 22 | "endTime": 14400 23 | }, 24 | { 25 | "vehicle": { 26 | "length": 5.0, 27 | 
"width": 2.0, 28 | "maxPosAcc": 2.0, 29 | "maxNegAcc": 4.5, 30 | "usualPosAcc": 2.0, 31 | "usualNegAcc": 4.5, 32 | "minGap": 2.5, 33 | "maxSpeed": 11.111, 34 | "headwayTime": 2 35 | }, 36 | "route": [ 37 | "road_4_1_2", 38 | "road_3_1_2", 39 | "road_2_1_2", 40 | "road_1_1_2" 41 | ], 42 | "interval": 20.0, 43 | "startTime": 0, 44 | "endTime": 14400 45 | }, 46 | { 47 | "vehicle": { 48 | "length": 5.0, 49 | "width": 2.0, 50 | "maxPosAcc": 2.0, 51 | "maxNegAcc": 4.5, 52 | "usualPosAcc": 2.0, 53 | "usualNegAcc": 4.5, 54 | "minGap": 2.5, 55 | "maxSpeed": 11.111, 56 | "headwayTime": 2 57 | }, 58 | "route": [ 59 | "road_0_2_0", 60 | "road_1_2_0", 61 | "road_2_2_0", 62 | "road_3_2_0" 63 | ], 64 | "interval": 20.0, 65 | "startTime": 0, 66 | "endTime": 14400 67 | }, 68 | { 69 | "vehicle": { 70 | "length": 5.0, 71 | "width": 2.0, 72 | "maxPosAcc": 2.0, 73 | "maxNegAcc": 4.5, 74 | "usualPosAcc": 2.0, 75 | "usualNegAcc": 4.5, 76 | "minGap": 2.5, 77 | "maxSpeed": 11.111, 78 | "headwayTime": 2 79 | }, 80 | "route": [ 81 | "road_4_2_2", 82 | "road_3_2_2", 83 | "road_2_2_2", 84 | "road_1_2_2" 85 | ], 86 | "interval": 20.0, 87 | "startTime": 0, 88 | "endTime": 14400 89 | }, 90 | { 91 | "vehicle": { 92 | "length": 5.0, 93 | "width": 2.0, 94 | "maxPosAcc": 2.0, 95 | "maxNegAcc": 4.5, 96 | "usualPosAcc": 2.0, 97 | "usualNegAcc": 4.5, 98 | "minGap": 2.5, 99 | "maxSpeed": 11.111, 100 | "headwayTime": 2 101 | }, 102 | "route": [ 103 | "road_0_3_0", 104 | "road_1_3_0", 105 | "road_2_3_0", 106 | "road_3_3_0" 107 | ], 108 | "interval": 20.0, 109 | "startTime": 0, 110 | "endTime": 14400 111 | }, 112 | { 113 | "vehicle": { 114 | "length": 5.0, 115 | "width": 2.0, 116 | "maxPosAcc": 2.0, 117 | "maxNegAcc": 4.5, 118 | "usualPosAcc": 2.0, 119 | "usualNegAcc": 4.5, 120 | "minGap": 2.5, 121 | "maxSpeed": 11.111, 122 | "headwayTime": 2 123 | }, 124 | "route": [ 125 | "road_4_3_2", 126 | "road_3_3_2", 127 | "road_2_3_2", 128 | "road_1_3_2" 129 | ], 130 | "interval": 20.0, 131 | "startTime": 0, 132 | 
"endTime": 14400 133 | }, 134 | { 135 | "vehicle": { 136 | "length": 5.0, 137 | "width": 2.0, 138 | "maxPosAcc": 2.0, 139 | "maxNegAcc": 4.5, 140 | "usualPosAcc": 2.0, 141 | "usualNegAcc": 4.5, 142 | "minGap": 2.5, 143 | "maxSpeed": 11.111, 144 | "headwayTime": 2 145 | }, 146 | "route": [ 147 | "road_1_0_1", 148 | "road_1_1_1", 149 | "road_1_2_1", 150 | "road_1_3_1" 151 | ], 152 | "interval": 20.0, 153 | "startTime": 0, 154 | "endTime": 14400 155 | }, 156 | { 157 | "vehicle": { 158 | "length": 5.0, 159 | "width": 2.0, 160 | "maxPosAcc": 2.0, 161 | "maxNegAcc": 4.5, 162 | "usualPosAcc": 2.0, 163 | "usualNegAcc": 4.5, 164 | "minGap": 2.5, 165 | "maxSpeed": 11.111, 166 | "headwayTime": 2 167 | }, 168 | "route": [ 169 | "road_1_4_3", 170 | "road_1_3_3", 171 | "road_1_2_3", 172 | "road_1_1_3" 173 | ], 174 | "interval": 20.0, 175 | "startTime": 0, 176 | "endTime": 14400 177 | }, 178 | { 179 | "vehicle": { 180 | "length": 5.0, 181 | "width": 2.0, 182 | "maxPosAcc": 2.0, 183 | "maxNegAcc": 4.5, 184 | "usualPosAcc": 2.0, 185 | "usualNegAcc": 4.5, 186 | "minGap": 2.5, 187 | "maxSpeed": 11.111, 188 | "headwayTime": 2 189 | }, 190 | "route": [ 191 | "road_2_0_1", 192 | "road_2_1_1", 193 | "road_2_2_1", 194 | "road_2_3_1" 195 | ], 196 | "interval": 20.0, 197 | "startTime": 0, 198 | "endTime": 14400 199 | }, 200 | { 201 | "vehicle": { 202 | "length": 5.0, 203 | "width": 2.0, 204 | "maxPosAcc": 2.0, 205 | "maxNegAcc": 4.5, 206 | "usualPosAcc": 2.0, 207 | "usualNegAcc": 4.5, 208 | "minGap": 2.5, 209 | "maxSpeed": 11.111, 210 | "headwayTime": 2 211 | }, 212 | "route": [ 213 | "road_2_4_3", 214 | "road_2_3_3", 215 | "road_2_2_3", 216 | "road_2_1_3" 217 | ], 218 | "interval": 20.0, 219 | "startTime": 0, 220 | "endTime": 14400 221 | }, 222 | { 223 | "vehicle": { 224 | "length": 5.0, 225 | "width": 2.0, 226 | "maxPosAcc": 2.0, 227 | "maxNegAcc": 4.5, 228 | "usualPosAcc": 2.0, 229 | "usualNegAcc": 4.5, 230 | "minGap": 2.5, 231 | "maxSpeed": 11.111, 232 | "headwayTime": 2 233 | }, 234 | 
"route": [ 235 | "road_3_0_1", 236 | "road_3_1_1", 237 | "road_3_2_1", 238 | "road_3_3_1" 239 | ], 240 | "interval": 20.0, 241 | "startTime": 0, 242 | "endTime": 14400 243 | }, 244 | { 245 | "vehicle": { 246 | "length": 5.0, 247 | "width": 2.0, 248 | "maxPosAcc": 2.0, 249 | "maxNegAcc": 4.5, 250 | "usualPosAcc": 2.0, 251 | "usualNegAcc": 4.5, 252 | "minGap": 2.5, 253 | "maxSpeed": 11.111, 254 | "headwayTime": 2 255 | }, 256 | "route": [ 257 | "road_3_4_3", 258 | "road_3_3_3", 259 | "road_3_2_3", 260 | "road_3_1_3" 261 | ], 262 | "interval": 20.0, 263 | "startTime": 0, 264 | "endTime": 14400 265 | } 266 | ] -------------------------------------------------------------------------------- /data/syntheticheavy/synthetic.json: -------------------------------------------------------------------------------- 1 | [{"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_0_1_0", "road_1_1_0", "road_2_1_0", "road_3_1_0"], "interval": 10.0, "startTime": 0, "endTime": 899}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_0_1_0", "road_1_1_0", "road_2_1_0", "road_3_1_0"], "interval": 2.0, "startTime": 900, "endTime": 1799}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_0_1_0", "road_1_1_0", "road_2_1_0", "road_3_1_0"], "interval": 10.0, "startTime": 1800, "endTime": 2699}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_0_1_0", "road_1_1_0", "road_2_1_0", "road_3_1_0"], "interval": 10.0, "startTime": 2700, "endTime": 3600}, {"vehicle": 
{"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_4_1_2", "road_3_1_2", "road_2_1_2", "road_1_1_2"], "interval": 10.0, "startTime": 0, "endTime": 899}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_4_1_2", "road_3_1_2", "road_2_1_2", "road_1_1_2"], "interval": 2.0, "startTime": 900, "endTime": 1799}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_4_1_2", "road_3_1_2", "road_2_1_2", "road_1_1_2"], "interval": 10.0, "startTime": 1800, "endTime": 2699}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_4_1_2", "road_3_1_2", "road_2_1_2", "road_1_1_2"], "interval": 10.0, "startTime": 2700, "endTime": 3600}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_0_2_0", "road_1_2_0", "road_2_2_0", "road_3_2_0"], "interval": 10.0, "startTime": 0, "endTime": 899}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_0_2_0", "road_1_2_0", "road_2_2_0", "road_3_2_0"], "interval": 2.0, "startTime": 900, "endTime": 1799}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_0_2_0", "road_1_2_0", "road_2_2_0", "road_3_2_0"], "interval": 10.0, 
"startTime": 1800, "endTime": 2699}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_0_2_0", "road_1_2_0", "road_2_2_0", "road_3_2_0"], "interval": 10.0, "startTime": 2700, "endTime": 3600}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_4_2_2", "road_3_2_2", "road_2_2_2", "road_1_2_2"], "interval": 10.0, "startTime": 0, "endTime": 899}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_4_2_2", "road_3_2_2", "road_2_2_2", "road_1_2_2"], "interval": 2.0, "startTime": 900, "endTime": 1799}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_4_2_2", "road_3_2_2", "road_2_2_2", "road_1_2_2"], "interval": 10.0, "startTime": 1800, "endTime": 2699}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_4_2_2", "road_3_2_2", "road_2_2_2", "road_1_2_2"], "interval": 10.0, "startTime": 2700, "endTime": 3600}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_0_3_0", "road_1_3_0", "road_2_3_0", "road_3_3_0"], "interval": 10.0, "startTime": 0, "endTime": 899}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_0_3_0", 
"road_1_3_0", "road_2_3_0", "road_3_3_0"], "interval": 2.0, "startTime": 900, "endTime": 1799}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_0_3_0", "road_1_3_0", "road_2_3_0", "road_3_3_0"], "interval": 10.0, "startTime": 1800, "endTime": 2699}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_0_3_0", "road_1_3_0", "road_2_3_0", "road_3_3_0"], "interval": 10.0, "startTime": 2700, "endTime": 3600}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_4_3_2", "road_3_3_2", "road_2_3_2", "road_1_3_2"], "interval": 10.0, "startTime": 0, "endTime": 899}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_4_3_2", "road_3_3_2", "road_2_3_2", "road_1_3_2"], "interval": 2.0, "startTime": 900, "endTime": 1799}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_4_3_2", "road_3_3_2", "road_2_3_2", "road_1_3_2"], "interval": 10.0, "startTime": 1800, "endTime": 2699}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_4_3_2", "road_3_3_2", "road_2_3_2", "road_1_3_2"], "interval": 10.0, "startTime": 2700, "endTime": 3600}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 
11.111, "headwayTime": 2}, "route": ["road_1_0_1", "road_1_1_1", "road_1_2_1", "road_1_3_1"], "interval": 10.0, "startTime": 0, "endTime": 899}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_1_0_1", "road_1_1_1", "road_1_2_1", "road_1_3_1"], "interval": 10.0, "startTime": 900, "endTime": 1799}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_1_0_1", "road_1_1_1", "road_1_2_1", "road_1_3_1"], "interval": 10.0, "startTime": 1800, "endTime": 2699}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_1_0_1", "road_1_1_1", "road_1_2_1", "road_1_3_1"], "interval": 2.0, "startTime": 2700, "endTime": 3600}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_1_4_3", "road_1_3_3", "road_1_2_3", "road_1_1_3"], "interval": 10.0, "startTime": 0, "endTime": 899}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_1_4_3", "road_1_3_3", "road_1_2_3", "road_1_1_3"], "interval": 10.0, "startTime": 900, "endTime": 1799}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_1_4_3", "road_1_3_3", "road_1_2_3", "road_1_1_3"], "interval": 10.0, "startTime": 1800, "endTime": 2699}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, 
"usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_1_4_3", "road_1_3_3", "road_1_2_3", "road_1_1_3"], "interval": 2.0, "startTime": 2700, "endTime": 3600}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_2_0_1", "road_2_1_1", "road_2_2_1", "road_2_3_1"], "interval": 10.0, "startTime": 0, "endTime": 899}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_2_0_1", "road_2_1_1", "road_2_2_1", "road_2_3_1"], "interval": 10.0, "startTime": 900, "endTime": 1799}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_2_0_1", "road_2_1_1", "road_2_2_1", "road_2_3_1"], "interval": 10.0, "startTime": 1800, "endTime": 2699}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_2_0_1", "road_2_1_1", "road_2_2_1", "road_2_3_1"], "interval": 2.0, "startTime": 2700, "endTime": 3600}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_2_4_3", "road_2_3_3", "road_2_2_3", "road_2_1_3"], "interval": 10.0, "startTime": 0, "endTime": 899}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_2_4_3", "road_2_3_3", "road_2_2_3", "road_2_1_3"], "interval": 10.0, "startTime": 900, "endTime": 1799}, {"vehicle": {"length": 5.0, "width": 2.0, 
"maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_2_4_3", "road_2_3_3", "road_2_2_3", "road_2_1_3"], "interval": 10.0, "startTime": 1800, "endTime": 2699}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_2_4_3", "road_2_3_3", "road_2_2_3", "road_2_1_3"], "interval": 2.0, "startTime": 2700, "endTime": 3600}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_3_0_1", "road_3_1_1", "road_3_2_1", "road_3_3_1"], "interval": 10.0, "startTime": 0, "endTime": 899}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_3_0_1", "road_3_1_1", "road_3_2_1", "road_3_3_1"], "interval": 10.0, "startTime": 900, "endTime": 1799}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_3_0_1", "road_3_1_1", "road_3_2_1", "road_3_3_1"], "interval": 10.0, "startTime": 1800, "endTime": 2699}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_3_0_1", "road_3_1_1", "road_3_2_1", "road_3_3_1"], "interval": 2.0, "startTime": 2700, "endTime": 3600}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_3_4_3", "road_3_3_3", "road_3_2_3", "road_3_1_3"], "interval": 10.0, "startTime": 0, "endTime": 
899}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_3_4_3", "road_3_3_3", "road_3_2_3", "road_3_1_3"], "interval": 10.0, "startTime": 900, "endTime": 1799}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_3_4_3", "road_3_3_3", "road_3_2_3", "road_3_1_3"], "interval": 10.0, "startTime": 1800, "endTime": 2699}, {"vehicle": {"length": 5.0, "width": 2.0, "maxPosAcc": 2.0, "maxNegAcc": 4.5, "usualPosAcc": 2.0, "usualNegAcc": 4.5, "minGap": 2.5, "maxSpeed": 11.111, "headwayTime": 2}, "route": ["road_3_4_3", "road_3_3_3", "road_3_2_3", "road_3_1_3"], "interval": 2.0, "startTime": 2700, "endTime": 3600}] -------------------------------------------------------------------------------- /generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | from config import DIC_AGENTS, DIC_ENVS 4 | import time 5 | import sys 6 | from multiprocessing import Process, Pool 7 | from gnnmodule import GNN 8 | class Generator: 9 | def __init__(self, cnt_round, cnt_gen, dic_path, dic_exp_conf, dic_agent_conf, dic_traffic_env_conf, best_round=None): 10 | 11 | self.cnt_round = cnt_round 12 | self.cnt_gen = cnt_gen 13 | self.dic_exp_conf = dic_exp_conf 14 | self.dic_path = dic_path 15 | self.dic_agent_conf = copy.deepcopy(dic_agent_conf) 16 | self.dic_traffic_env_conf = dic_traffic_env_conf 17 | self.agents = [None]*dic_traffic_env_conf['NUM_AGENTS'] 18 | 19 | if self.dic_exp_conf["PRETRAIN"]: 20 | self.path_to_log = os.path.join(self.dic_path["PATH_TO_PRETRAIN_WORK_DIRECTORY"], "train_round", 21 | "round_" + str(self.cnt_round), "generator_" + str(self.cnt_gen)) 22 | else: 23 | self.path_to_log = os.path.join(self.dic_path["PATH_TO_WORK_DIRECTORY"], 
"train_round", "round_"+str(self.cnt_round), "generator_"+str(self.cnt_gen)) 24 | if not os.path.exists(self.path_to_log): 25 | os.makedirs(self.path_to_log) 26 | 27 | self.env = DIC_ENVS[dic_traffic_env_conf["SIMULATOR_TYPE"]]( 28 | path_to_log = self.path_to_log, 29 | path_to_work_directory = self.dic_path["PATH_TO_WORK_DIRECTORY"], 30 | dic_traffic_env_conf = self.dic_traffic_env_conf) 31 | self.env.reset() 32 | 33 | # every generator's output 34 | # generator for pretraining 35 | # Todo pretrain with intersection_id 36 | if self.dic_exp_conf["PRETRAIN"]: 37 | 38 | self.agent_name = self.dic_exp_conf["PRETRAIN_MODEL_NAME"] 39 | self.agent = DIC_AGENTS[self.agent_name]( 40 | dic_agent_conf=self.dic_agent_conf, 41 | dic_traffic_env_conf=self.dic_traffic_env_conf, 42 | dic_path=self.dic_path, 43 | cnt_round=self.cnt_round, 44 | best_round=best_round, 45 | ) 46 | 47 | else: 48 | 49 | start_time = time.time() 50 | 51 | for i in range(dic_traffic_env_conf['NUM_AGENTS']): 52 | agent_name = self.dic_exp_conf["MODEL_NAME"] 53 | #the CoLight_Signal needs to know the lane adj in advance, from environment's intersection list 54 | if agent_name=='CoLight_Signal': 55 | agent = DIC_AGENTS[agent_name]( 56 | dic_agent_conf=self.dic_agent_conf, 57 | dic_traffic_env_conf=self.dic_traffic_env_conf, 58 | dic_path=self.dic_path, 59 | cnt_round=self.cnt_round, 60 | best_round=best_round, 61 | inter_info=self.env.list_intersection, 62 | intersection_id=str(i) 63 | ) 64 | else: 65 | agent = DIC_AGENTS[agent_name]( 66 | dic_agent_conf=self.dic_agent_conf, 67 | dic_traffic_env_conf=self.dic_traffic_env_conf, 68 | dic_path=self.dic_path, 69 | cnt_round=self.cnt_round, 70 | best_round=best_round, 71 | intersection_id=str(i) 72 | ) 73 | self.agents[i] = agent 74 | print("Create intersection agent time: ", time.time()-start_time) 75 | 76 | 77 | 78 | 79 | def cal_phasetime(self,k=10,min_t=10,max_t=20,yellow_time=5,l=5.0,delta_l=2.5,v=11.111,a=2.0): 80 | if k>0: 81 | k1 = 
int(1+v**2/(2*a*(l+delta_l))) 82 | if k<=k1: 83 | phasetime = (2*(k-1)*(l+delta_l)/a)**0.5 84 | else: 85 | t1 = v/a 86 | x1 = v**2/(2*a) 87 | x2 = (k-1)*(l+delta_l) - x1 88 | t2 = x2/v 89 | phasetime = t1+t2 90 | else: 91 | phasetime=0 92 | phasetime = phasetime+yellow_time 93 | import math 94 | phasetime = math.ceil(phasetime) 95 | if phasetime>max_t: 96 | phasetime = max_t 97 | if phasetimepred_time[0]: 141 | pred_time = pred_time[1:] 142 | pred_num = gnn.get_maxpred(self.env) 143 | max_greentime = int(pred_num/20*20) 144 | max_greentime = min(59,max(20,max_greentime)) 145 | 146 | 147 | pressure = self.env._get_pressure() 148 | p_2 = [0]*self.dic_traffic_env_conf["NUM_INTERSECTIONS"] 149 | 150 | for inter_i in range(self.dic_traffic_env_conf["NUM_INTERSECTIONS"]): 151 | p_2[inter_i] = pressure[inter_i][action_list[inter_i]] 152 | 153 | ave_pressure = sum(p_2)/len(p_2) 154 | phase_time = self.cal_phasetime(k = int(ave_pressure),max_t = max_greentime) 155 | 156 | phase_time_list.append(phase_time) 157 | 158 | next_state, reward, done, _ = self.env.step(action_list,phase_time) 159 | 160 | print('gen -',self.cnt_gen,round(pred_num,2),max_greentime,round(ave_pressure,2),phase_time,self.env.get_current_time()-self.dic_traffic_env_conf["MIN_ACTION_TIME"], time.time()-step_start_time) 161 | state = next_state 162 | step_num += phase_time 163 | running_time = time.time() - running_start_time 164 | with open(self.path_to_log+'/phase_time.txt','w') as phase_time_file: 165 | phase_time_file.write('\n'.join([str(phase_time) for phase_time in phase_time_list])) 166 | 167 | log_start_time = time.time() 168 | print("start logging") 169 | self.env.bulk_log_multi_process() 170 | log_time = time.time() - log_start_time 171 | 172 | self.env.end_sumo() 173 | print("reset_env_time: ", reset_env_time) 174 | print("running_time: ", running_time) 175 | print("log_time: ", log_time) 176 | -------------------------------------------------------------------------------- /gnnmodule.py: 
-------------------------------------------------------------------------------- 1 | from GNN_data_loader.data_utils import gen_batch 2 | from GNN_utils.math_utils import evaluation 3 | from os.path import join as pjoin 4 | 5 | import tensorflow as tf 6 | import numpy as np 7 | import time 8 | from GNN_utils.math_graph import * 9 | from GNN_data_loader.data_utils import * 10 | 11 | class GNN(): 12 | def __init__(self,W_file,V_file): 13 | 14 | import os 15 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 16 | from os.path import join as pjoin 17 | import tensorflow as tf 18 | config = tf.ConfigProto() 19 | config.gpu_options.allow_growth = True 20 | args = self.set_args() 21 | 22 | self.n, self.n_his, self.n_pred = args.n_route, args.n_his, args.n_pred 23 | #228,12,9 24 | Ks, Kt = args.ks, args.kt 25 | #3,3 26 | blocks = [[args.channel, 32, 64], [64, 32, 128]] 27 | # Load wighted adjacency matrix W 28 | W = weight_matrix(W_file) 29 | 30 | L = scaled_laplacian(W) 31 | Lk = cheb_poly_approx(L, Ks, self.n) 32 | tf.add_to_collection(name='graph_kernel', value=tf.cast(tf.constant(Lk), tf.float32)) 33 | 34 | # Data Preprocessing 35 | n_train, n_val, n_test = 34, 5, 5 36 | PeMS = data_gen(V_file, (n_train, n_val, n_test,args.channel), 37 | self.n, self.n_his + self.n_pred, 48) 38 | 39 | inputs, batch_size, inf_mode = PeMS, PeMS.get_len('test'), args.inf_mode 40 | 41 | load_path='GNN_output/models/' 42 | model_path = tf.train.get_checkpoint_state(load_path).model_checkpoint_path 43 | test_graph = tf.Graph() 44 | with test_graph.as_default(): 45 | saver = tf.train.import_meta_graph(pjoin(f'{model_path}.meta')) 46 | 47 | self.sess = tf.Session(graph=test_graph) 48 | 49 | saver.restore(self.sess, tf.train.latest_checkpoint(load_path)) 50 | print(f'>> Loading saved model from {model_path} ...') 51 | #./output/models/STGCN-9150 52 | 53 | self.pred = test_graph.get_collection('y_pred') 54 | 55 | self.step_idx = self.n_pred - 1 56 | 57 | 58 | 59 | def set_args(self): 60 | import 
argparse 61 | parser = argparse.ArgumentParser() 62 | parser.add_argument('--n_route', type=int, default=16) 63 | parser.add_argument('--n_his', type=int, default=10) 64 | parser.add_argument('--n_pred', type=int, default=3) 65 | parser.add_argument('--batch_size', type=int, default=50) 66 | parser.add_argument('--epoch', type=int, default=11) 67 | parser.add_argument('--save', type=int, default=10) 68 | parser.add_argument('--ks', type=int, default=3) 69 | parser.add_argument('--kt', type=int, default=3) 70 | parser.add_argument('--lr', type=float, default=1e-3) 71 | parser.add_argument('--opt', type=str, default='RMSProp') 72 | parser.add_argument('--graph', type=str, default='default') 73 | parser.add_argument('--inf_mode', type=str, default='merge') 74 | parser.add_argument('--channel', type=str, default=8) 75 | args = parser.parse_args() 76 | return args 77 | 78 | def single_pred(self,sess, y_pred, seq, n_his, n_pred, step_idx, dynamic_batch=True): 79 | ''' 80 | seq n_his+n_pred 81 | ''' 82 | pred_list = [] 83 | i = seq 84 | 85 | test_seq = np.copy(i[:, 0:self.n_his + 1, :, :]) 86 | step_list = [] 87 | for j in range(n_pred): 88 | pred = sess.run(y_pred, 89 | feed_dict={'data_input:0': test_seq, 'keep_prob:0': 1.0}) 90 | 91 | if isinstance(pred, list): 92 | pred = np.array(pred[0]) 93 | #print(pred.shape,test_seq.shape) 94 | test_seq[:, 0:n_his - 1, :, :] = test_seq[:, 1:n_his, :, :] 95 | test_seq[:, n_his - 1, :, :] = pred 96 | step_list.append(pred) 97 | pred_list.append(step_list) 98 | # pred_array -> [n_pred, batch_size, n_route, C_0) 99 | pred_array = np.concatenate(pred_list, axis=1) 100 | return pred_array[step_idx], pred_array.shape[1] 101 | 102 | def get_predict(self,x_test): 103 | #x_test 长度 应该是 n_his+n_pred 104 | x_stats = {'mean': np.mean(x_test), 'std': np.std(x_test)} 105 | x_test = (x_test - x_stats['mean'])/x_stats['std'] 106 | x_test = np.array([x_test]) 107 | 108 | y_test, len_test = 
self.single_pred(self.sess,self.pred,x_test,self.n_his,self.n_pred,self.step_idx) 109 | 110 | y_test = y_test*x_stats['std']+x_stats['mean'] 111 | return y_test[0] 112 | 113 | def get_history(self,env,time_duration,single_his): 114 | keep_lane = [0,2,3,5,6,8,9,11] 115 | #0,1,2,3,4,5,6,7 116 | #0 3 117 | #6 9 118 | #2 5 119 | #8 11 120 | ##0,2;4,6;1,3;5,7 121 | print(len(env.list_inter_log)) 122 | time_now = int(env.get_current_time()) 123 | his = [] 124 | for T in range(time_now-time_duration,time_now,single_his): 125 | his_period = [] 126 | for t in range(T,T+single_his): 127 | if env.list_inter_log[0][t]['action']<0: 128 | continue 129 | his_t = [] 130 | for inter in env.list_inter_log: 131 | his_t.append([inter[t]['state']['lane_num_vehicle'][l] for l in keep_lane]) 132 | his_t = np.array(his_t) #(n,8) 133 | his_period.append(his_t) 134 | his_period = np.array(his_period) #(*,n,8) 135 | his.append(np.max(his_period,axis = 0)) #sum --> (n,8) 136 | his = np.array(his)#(T,n,8) 137 | 138 | a,b,c = his.shape 139 | his_add = np.zeros((a+1,b,c)) 140 | his_add[:a,:,:] = his 141 | his_add[a,:,:] = his[-1,:,:] 142 | return his_add 143 | 144 | def get_lanenum(self,env,action_num=4,time_duration = 600,single_his = 60): 145 | his = self.get_history(env,time_duration,single_his) 146 | pred = self.get_predict(his) # (n,8) 147 | #pred = np.average(his[:-1,:,:],axis=0) 148 | action_2_lane = [0,4,1,5] 149 | lane_num = [] 150 | for inter in range(env.dic_traffic_env_conf["NUM_INTERSECTIONS"]): 151 | l=[] 152 | y = pred[inter] 153 | for a in range(action_num): 154 | press = (y[action_2_lane[a]]+y[action_2_lane[a]+2])/2 155 | l.append(press) 156 | lane_num.append(l) 157 | return lane_num 158 | def get_maxpred(self,env,action_num=4,time_duration = 600,single_his = 60): 159 | lane_num = self.get_lanenum(env,action_num=4,time_duration = 600,single_his = 60) 160 | return np.max(np.array(lane_num)) 161 | 162 | 163 | def addnew(x,num): 164 | ''' 165 | 在每两个数之间新加num个 166 | 167 | ''' 168 | 
169 | num = num+1 170 | x_new = [] 171 | for i in range(len(x)-1): 172 | 173 | dif = x[i+1] - x[i] 174 | x_new.extend([x[i]+dif/num*k for k in range(num)]) 175 | return x_new 176 | 177 | 178 | if __name__ == "__main__": 179 | pass 180 | #gnn = GNN('W.csv','V.csv') 181 | #y = gnn.get_predict(np.random.uniform(size=(11,1,8))) 182 | #至少是 n_his+1 183 | -------------------------------------------------------------------------------- /lit_agent.py: -------------------------------------------------------------------------------- 1 | 2 | import pickle 3 | from network_agent import NetworkAgent, Selector 4 | 5 | import numpy as np 6 | from keras.layers import Input, Multiply, Add 7 | from keras.models import Model 8 | from keras.optimizers import RMSprop 9 | 10 | from keras.layers.merge import concatenate 11 | 12 | 13 | class LitAgent(NetworkAgent): 14 | def build_network(self): 15 | 16 | '''Initialize a Q network''' 17 | dic_input_node = {} 18 | for feature_name in self.dic_traffic_env_conf["LIST_STATE_FEATURE"]: 19 | if "phase" in feature_name and self.dic_traffic_env_conf["BINARY_PHASE_EXPANSION"]: 20 | _shape = (self.dic_traffic_env_conf["DIC_FEATURE_DIM"]["D_" + feature_name.upper()][0],) 21 | d_phase_encoding = _shape[0] 22 | elif "phase" in feature_name and not self.dic_traffic_env_conf["BINARY_PHASE_EXPANSION"]: 23 | _shape = self.dic_traffic_env_conf["DIC_FEATURE_DIM"]["D_"+feature_name.upper()] 24 | d_phase_encoding = _shape[0] 25 | else: 26 | _shape = (self.dic_traffic_env_conf["DIC_FEATURE_DIM"]["D_" + feature_name.upper()][0] * self.num_lanes,) 27 | print("_shape", _shape, feature_name) 28 | dic_input_node[feature_name] = Input(shape=_shape, 29 | name="input_" + feature_name) 30 | 31 | # add cnn to image features 32 | dic_flatten_node = {} 33 | for feature_name in self.dic_traffic_env_conf["LIST_STATE_FEATURE"]: 34 | if len(self.dic_traffic_env_conf["DIC_FEATURE_DIM"]["D_"+feature_name.upper()]) > 1: 35 | dic_flatten_node[feature_name] = 
self._cnn_network_structure(dic_input_node[feature_name]) 36 | else: 37 | dic_flatten_node[feature_name] = dic_input_node[feature_name] 38 | 39 | # concatenate features 40 | list_all_flatten_feature = [] 41 | for feature_name in self.dic_traffic_env_conf["LIST_STATE_FEATURE"]: 42 | list_all_flatten_feature.append(dic_flatten_node[feature_name]) 43 | all_flatten_feature = concatenate(list_all_flatten_feature, axis=1, name="all_flatten_feature") 44 | 45 | # shared dense layer 46 | shared_dense = self._shared_network_structure(all_flatten_feature, self.dic_agent_conf["D_DENSE"]) 47 | 48 | # build phase selector layer 49 | if "cur_phase" in self.dic_traffic_env_conf["LIST_STATE_FEATURE"] and self.dic_agent_conf["PHASE_SELECTOR"]: 50 | list_selected_q_values = [] 51 | for phase_id in range(1, self.num_phases+1): 52 | if self.dic_traffic_env_conf["BINARY_PHASE_EXPANSION"]: 53 | # print('d_phase_encoding:',d_phase_encoding)#d_phase_encoding: 96 54 | if d_phase_encoding == 4: 55 | phase_expansion = self.dic_traffic_env_conf["phase_expansion_4_lane"][phase_id] 56 | elif d_phase_encoding == 8: 57 | phase_expansion = self.dic_traffic_env_conf["phase_expansion"][phase_id] 58 | else: 59 | raise NotImplementedError 60 | else: 61 | phase_expansion = phase_id 62 | locals()["q_values_{0}".format(phase_id)] = self._separate_network_structure( 63 | shared_dense, self.dic_agent_conf["D_DENSE"], self.num_actions, memo=phase_id) 64 | locals()["selector_{0}".format(phase_id)] = Selector( 65 | phase_expansion, d_phase_encoding=d_phase_encoding, d_action=self.num_actions, 66 | name="selector_{0}".format(phase_id))(dic_input_node["cur_phase"]) 67 | locals()["q_values_{0}_selected".format(phase_id)] = Multiply(name="multiply_{0}".format(phase_id))( 68 | [locals()["q_values_{0}".format(phase_id)], 69 | locals()["selector_{0}".format(phase_id)]] 70 | ) 71 | list_selected_q_values.append(locals()["q_values_{0}_selected".format(phase_id)]) 72 | q_values = Add()(list_selected_q_values) 73 | else: 
74 | q_values = self._separate_network_structure(shared_dense, self.dic_agent_conf["D_DENSE"], self.num_actions) 75 | 76 | network = Model(inputs=[dic_input_node[feature_name] 77 | for feature_name in self.dic_traffic_env_conf["LIST_STATE_FEATURE"]], 78 | outputs=q_values) 79 | network.compile(optimizer=RMSprop(lr=self.dic_agent_conf["LEARNING_RATE"]), 80 | loss="mean_squared_error") 81 | network.summary() 82 | 83 | return network 84 | -------------------------------------------------------------------------------- /model_pool.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pandas as pd 4 | import random 5 | import numpy as np 6 | import pickle 7 | from math import isnan 8 | from config import DIC_AGENTS, DIC_ENVS 9 | 10 | validation_set = [ 11 | "synthetic-over-WE254-EW221-NS671-SN747-1893.xml", 12 | # "synthetic-over-WE484-EW484-NS700-SN649-2317.xml", 13 | # "synthetic-over-WE495-EW511-NS634-SN736-2376.xml", 14 | "synthetic-over-WE499-EW450-NS502-SN447-1898.xml", 15 | "synthetic-over-WE510-EW445-NS489-SN524-1968.xml", 16 | "synthetic-under-WE221-EW300-NS509-SN524-1554.xml", 17 | "synthetic-under-WE239-EW262-NS690-SN637-1828.xml", 18 | # "synthetic-under-WE240-EW277-NS509-SN544-1570.xml", 19 | # "synthetic-under-WE247-EW279-NS232-SN242-1000.xml", 20 | # "synthetic-under-WE259-EW228-NS265-SN271-1023.xml" 21 | ] 22 | 23 | DIC_MIN_DURATION = { 24 | 200: 26, 25 | 300: 26, 26 | 350: 27, 27 | 400: 28, 28 | 450: 29, 29 | 500: 30, 30 | 550: 34, 31 | 600: 38, 32 | 650: 40 33 | } 34 | 35 | 36 | def get_traffic_volume(file_name, run_cnt): 37 | scale = run_cnt / 3600 # run_cnt > traffic_time, no exact scale 38 | if "synthetic" in file_name: 39 | sta = file_name.rfind("-") + 1 40 | print(file_name, int(int(file_name[sta:-4]) * scale)) 41 | return int(int(file_name[sta:-4]) * scale) 42 | elif "cross" in file_name: 43 | sta = file_name.find("equal_") + len("equal_") 44 | end = file_name.find(".xml") 45 
| return int(int(file_name[sta:end]) * scale * 4) # lane_num = 4 46 | 47 | 48 | class ModelPool(): 49 | def __init__(self, dic_path, dic_exp_conf): 50 | self.dic_path = dic_path 51 | self.exp_conf = dic_exp_conf 52 | 53 | self.num_best_model = self.exp_conf["NUM_BEST_MODEL"] 54 | 55 | if os.path.exists(os.path.join(self.dic_path["PATH_TO_WORK_DIRECTORY"], "best_model.pkl")): 56 | self.best_model_pool = pickle.load( 57 | open(os.path.join(self.dic_path["PATH_TO_WORK_DIRECTORY"], "best_model.pkl"), "rb")) 58 | else: 59 | self.best_model_pool = [] 60 | 61 | def single_test(self, cnt_round): 62 | print("Start testing model pool") 63 | 64 | records_dir = self.dic_path["PATH_TO_WORK_DIRECTORY"] 65 | # run_cnt = 360 66 | if_gui = False 67 | 68 | nan_thres = 80 69 | 70 | dic_agent_conf = json.load(open(os.path.join(records_dir, "agent.conf"), "r")) 71 | dic_exp_conf = json.load(open(os.path.join(records_dir, "exp.conf"), "r")) 72 | dic_traffic_env_conf = json.load(open(os.path.join(records_dir, "traffic_env.conf"), "r")) 73 | 74 | # dic_exp_conf["RUN_COUNTS"] = run_cnt 75 | run_cnt = dic_exp_conf["RUN_COUNTS"] 76 | dic_traffic_env_conf["IF_GUI"] = if_gui 77 | 78 | # dump dic_exp_conf 79 | if os.path.exists(os.path.join(records_dir, "test_exp.conf")): 80 | json.dump(dic_exp_conf, open(os.path.join(records_dir, "test_exp.conf"), "w")) 81 | 82 | if dic_exp_conf["MODEL_NAME"] in dic_exp_conf["LIST_MODEL_NEED_TO_UPDATE"]: 83 | dic_agent_conf["EPSILON"] = 0 # dic_agent_conf["EPSILON"] # + 0.1*cnt_gen 84 | dic_agent_conf["MIN_EPSILON"] = 0 85 | agent_name = dic_exp_conf["MODEL_NAME"] 86 | agent = DIC_AGENTS[agent_name]( 87 | dic_agent_conf=dic_agent_conf, 88 | dic_traffic_env_conf=dic_traffic_env_conf, 89 | dic_path=self.dic_path, 90 | cnt_round=0, # useless 91 | ) 92 | # try: 93 | if 1: 94 | # test 95 | agent.load_network("round_{0}".format(cnt_round)) 96 | 97 | path_to_log = os.path.join(self.dic_path["PATH_TO_WORK_DIRECTORY"], "test_round", 98 | "round_{0}".format(cnt_round)) 
99 | if not os.path.exists(path_to_log): 100 | os.makedirs(path_to_log) 101 | env = DIC_ENVS[dic_traffic_env_conf["SIMULATOR_TYPE"]]( 102 | path_to_log=path_to_log, 103 | path_to_work_directory=self.dic_path["PATH_TO_WORK_DIRECTORY"], 104 | dic_traffic_env_conf=dic_traffic_env_conf) 105 | 106 | done = False 107 | state = env.reset() 108 | step_num = 0 109 | 110 | while not done and step_num < int(dic_exp_conf["RUN_COUNTS"] / dic_traffic_env_conf["MIN_ACTION_TIME"]): 111 | action_list = [] 112 | for one_state in state: 113 | action = agent.choose_action(step_num, one_state) 114 | 115 | action_list.append(action) 116 | 117 | next_state, reward, done, _ = env.step(action_list) 118 | 119 | state = next_state 120 | step_num += 1 121 | env.bulk_log() 122 | env.end_sumo() 123 | 124 | # summary items (duration) from csv 125 | df_vehicle_inter_0 = pd.read_csv(os.path.join(path_to_log, "vehicle_inter_0.csv"), 126 | sep=',', header=0, dtype={0: str, 1: float, 2: float}, 127 | names=["vehicle_id", "enter_time", "leave_time"]) 128 | 129 | duration = df_vehicle_inter_0["leave_time"].values - df_vehicle_inter_0["enter_time"].values 130 | 131 | dur = np.mean([time for time in duration if not isnan(time)]) 132 | real_traffic_vol = 0 133 | nan_num = 0 134 | for time in duration: 135 | if not isnan(time): 136 | real_traffic_vol += 1 137 | else: 138 | nan_num += 1 139 | 140 | traffic_vol = get_traffic_volume(dic_exp_conf["TRAFFIC_FILE"][0], run_cnt) 141 | 142 | print(nan_num, nan_thres, self.best_model_pool) 143 | if nan_num < nan_thres: 144 | cnt = 0 145 | for i in range(len(self.best_model_pool)): 146 | if self.best_model_pool[i][1] > dur: 147 | break 148 | cnt += 1 149 | 150 | self.best_model_pool.insert(cnt, [cnt_round, dur]) 151 | 152 | num_max = min(len(self.best_model_pool), self.exp_conf["NUM_BEST_MODEL"]) 153 | self.best_model_pool = self.best_model_pool[:num_max] 154 | 155 | # log best models through rounds 156 | print(self.best_model_pool) 157 | f = 
open(os.path.join(self.dic_path["PATH_TO_WORK_DIRECTORY"], "best_model_pool.log"), "a") 158 | f.write("round: %d " % cnt_round) 159 | for i in range(len(self.best_model_pool)): 160 | f.write("id: %d, duration: %f, " % (self.best_model_pool[i][0], self.best_model_pool[i][1])) 161 | f.write("\n") 162 | f.close() 163 | print("model pool ends") 164 | # except: 165 | # print("fail to test model:%s"%model_round) 166 | # pass 167 | 168 | def model_compare(self, cnt_round): 169 | print("Start testing model pool") 170 | 171 | records_dir = self.dic_path["PATH_TO_WORK_DIRECTORY"] 172 | # run_cnt = 360 173 | if_gui = False 174 | 175 | nan_thres = 80 176 | 177 | dic_agent_conf = json.load(open(os.path.join(records_dir, "agent.conf"), "r")) 178 | dic_exp_conf = json.load(open(os.path.join(records_dir, "exp.conf"), "r")) 179 | dic_sumo_env_conf = json.load(open(os.path.join(records_dir, "sumo_env.conf"), "r")) 180 | 181 | # dic_exp_conf["RUN_COUNTS"] = run_cnt 182 | run_cnt = dic_exp_conf["RUN_COUNTS"] 183 | dic_sumo_env_conf["IF_GUI"] = if_gui 184 | 185 | # dump dic_exp_conf 186 | if os.path.exists(os.path.join(records_dir, "test_exp.conf")): 187 | json.dump(dic_exp_conf, open(os.path.join(records_dir, "test_exp.conf"), "w")) 188 | 189 | 190 | # try: 191 | path_to_log = os.path.join(records_dir, "test_round", "round_%d"%cnt_round) 192 | if 1: 193 | # summary items (duration) from csv 194 | df_vehicle_inter_0 = pd.read_csv(os.path.join(path_to_log, "vehicle_inter_0.csv"), 195 | sep=',', header=0, dtype={0: str, 1: float, 2: float}, 196 | names=["vehicle_id", "enter_time", "leave_time"]) 197 | 198 | duration = df_vehicle_inter_0["leave_time"].values - df_vehicle_inter_0["enter_time"].values 199 | 200 | dur = np.mean([time for time in duration if not isnan(time)]) 201 | real_traffic_vol = 0 202 | nan_num = 0 203 | for time in duration: 204 | if not isnan(time): 205 | real_traffic_vol += 1 206 | else: 207 | nan_num += 1 208 | 209 | traffic_vol = 
get_traffic_volume(dic_exp_conf["TRAFFIC_FILE"][0], run_cnt) 210 | 211 | print(nan_num, nan_thres, self.best_model_pool) 212 | if nan_num < nan_thres: 213 | cnt = 0 214 | for i in range(len(self.best_model_pool)): 215 | if self.best_model_pool[i][1] > dur: 216 | break 217 | cnt += 1 218 | 219 | self.best_model_pool.insert(cnt, [cnt_round, dur]) 220 | 221 | num_max = min(len(self.best_model_pool), self.exp_conf["NUM_BEST_MODEL"]) 222 | self.best_model_pool = self.best_model_pool[:num_max] 223 | 224 | # log best models through rounds 225 | print(self.best_model_pool) 226 | f = open(os.path.join(self.dic_path["PATH_TO_WORK_DIRECTORY"], "best_model_pool.log"), "a") 227 | f.write("round: %d " % cnt_round) 228 | for i in range(len(self.best_model_pool)): 229 | f.write("id: %d, duration: %f, " % (self.best_model_pool[i][0], self.best_model_pool[i][1])) 230 | f.write("\n") 231 | f.close() 232 | print("model pool ends") 233 | 234 | def get(self): 235 | if not self.best_model_pool: 236 | return 237 | else: 238 | ind = random.randint(0, len(self.best_model_pool) - 1) 239 | return self.best_model_pool[ind][0] 240 | 241 | def dump_model_pool(self): 242 | if self.best_model_pool: 243 | pickle.dump(self.best_model_pool, 244 | open(os.path.join(self.dic_path["PATH_TO_WORK_DIRECTORY"], "best_model.pkl"), "wb")) 245 | -------------------------------------------------------------------------------- /model_test.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pickle 4 | from config import DIC_AGENTS, DIC_ENVS 5 | from copy import deepcopy 6 | from gnnmodule import GNN 7 | 8 | def check_all_workers_working(list_cur_p): 9 | for i in range(len(list_cur_p)): 10 | if not list_cur_p[i].is_alive(): 11 | return i 12 | 13 | return -1 14 | 15 | 16 | def downsample(path_to_log, i): 17 | path_to_pkl = os.path.join(path_to_log, "inter_{0}.pkl".format(i)) 18 | with open(path_to_pkl, "rb") as f_logging_data: 19 | 
logging_data = pickle.load(f_logging_data) 20 | subset_data = logging_data[::10] 21 | os.remove(path_to_pkl) 22 | with open(path_to_pkl, "wb") as f_subset: 23 | pickle.dump(subset_data, f_subset) 24 | 25 | def downsample_for_system(path_to_log,dic_traffic_env_conf): 26 | for i in range(dic_traffic_env_conf['NUM_INTERSECTIONS']): 27 | downsample(path_to_log,i) 28 | 29 | def cal_phasetime(k=10,min_t=10,max_t=20,yellow_time=5,l=5.0,delta_l=2.5,v=11.111,a=2.0): 30 | if k>0: 31 | k1 = int(1+v**2/(2*a*(l+delta_l))) 32 | if k<=k1: 33 | phasetime = (2*(k-1)*(l+delta_l)/a)**0.5 34 | else: 35 | t1 = v/a 36 | x1 = v**2/(2*a) 37 | x2 = (k-1)*(l+delta_l) - x1 38 | t2 = x2/v 39 | phasetime = t1+t2 40 | else: 41 | phasetime=0 42 | phasetime = phasetime+yellow_time 43 | import math 44 | phasetime = math.ceil(phasetime) 45 | if phasetime>max_t: 46 | phasetime = max_t 47 | if phasetimepred_time[0]: 168 | pred_time = pred_time[1:] 169 | pred_num = gnn.get_maxpred(env) 170 | max_greentime = int(pred_num/20*20) 171 | max_greentime = min(59,max(20,max_greentime)) 172 | 173 | pressure = env._get_pressure() 174 | p_2 = [0]*dic_traffic_env_conf["NUM_INTERSECTIONS"] 175 | for inter_i in range(dic_traffic_env_conf["NUM_INTERSECTIONS"]): 176 | p_2[inter_i] = pressure[inter_i][action_list[inter_i]] 177 | 178 | ave_pressure = sum(p_2)/len(p_2) 179 | phase_time = cal_phasetime(k = int(ave_pressure),max_t = max_greentime) 180 | 181 | phase_time_list.append(phase_time) 182 | 183 | next_state, reward, done, _ = env.step(action_list,phase_time) 184 | 185 | state = next_state 186 | step_num += phase_time 187 | print('test -',cnt_round,round(pred_num,2),max_greentime,round(ave_pressure,2),phase_time,step_num,time.time()-START_TIME) 188 | # print('bulk_log_multi_process') 189 | with open(path_to_log+'/phase_time.txt','w') as phase_time_file: 190 | phase_time_file.write('\n'.join([str(phase_time) for phase_time in phase_time_list])) 191 | env.bulk_log_multi_process() 192 | 
env.log_attention(attention_dict) 193 | 194 | env.end_sumo() 195 | if not dic_exp_conf["DEBUG"]: 196 | path_to_log = os.path.join(dic_path["PATH_TO_WORK_DIRECTORY"], "test_round", 197 | model_round) 198 | # print("downsample", path_to_log) 199 | downsample_for_system(path_to_log, dic_traffic_env_conf) 200 | # print("end down") 201 | 202 | except: 203 | import traceback 204 | error_dir = model_dir.replace("model", "errors") 205 | if not os.path.exists(error_dir): 206 | os.makedirs(error_dir) 207 | f = open(os.path.join(error_dir, "error_info.txt"), "a") 208 | f.write("round_%d fail to test model\n"%cnt_round) 209 | print('traceback.format_exc():\n%s' % traceback.format_exc()) 210 | f.close() 211 | -------------------------------------------------------------------------------- /network_agent.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from keras.layers import Input, Dense, Conv2D, Flatten, BatchNormalization, Activation, Multiply, Add 4 | from keras.models import Model, model_from_json, load_model 5 | from keras.optimizers import RMSprop 6 | from keras.layers.core import Dropout 7 | from keras.layers.pooling import MaxPooling2D 8 | from keras import backend as K 9 | import random 10 | from keras.engine.topology import Layer 11 | import os 12 | from keras.callbacks import EarlyStopping, TensorBoard 13 | 14 | from agent import Agent 15 | 16 | class Selector(Layer): 17 | 18 | def __init__(self, select, d_phase_encoding, d_action, **kwargs): 19 | super(Selector, self).__init__(**kwargs) 20 | self.select = select 21 | self.d_phase_encoding = d_phase_encoding 22 | self.d_action = d_action 23 | self.select_neuron = K.constant(value=self.select, shape=(1, self.d_phase_encoding)) 24 | 25 | def build(self, input_shape): 26 | # Create a trainable weight variable for this layer. 27 | super(Selector, self).build(input_shape) # Be sure to call this somewhere! 
28 | 29 | def call(self, x): 30 | batch_size = K.shape(x)[0] 31 | constant = K.tile(self.select_neuron, (batch_size, 1)) 32 | return K.min(K.cast(K.equal(x, constant), dtype="float32"), axis=-1, keepdims=True) 33 | 34 | def get_config(self): 35 | config = {"select": self.select, "d_phase_encoding": self.d_phase_encoding, "d_action": self.d_action} 36 | base_config = super(Selector, self).get_config() 37 | return dict(list(base_config.items()) + list(config.items())) 38 | 39 | def compute_output_shape(self, input_shape): 40 | batch_size = input_shape[0] 41 | return (batch_size, self.d_action) 42 | 43 | 44 | def conv2d_bn(input_layer, index_layer, 45 | filters=16, 46 | kernel_size=(3, 3), 47 | strides=(1, 1)): 48 | """Utility function to apply conv + BN. 49 | # Arguments 50 | x: input tensor. 51 | filters: filters in `Conv2D`. 52 | num_row: height of the convolution kernel. 53 | num_col: width of the convolution kernel. 54 | padding: padding mode in `Conv2D`. 55 | strides: strides in `Conv2D`. 56 | name: name of the ops; will become `name + '_conv'` 57 | for the convolution and `name + '_bn'` for the 58 | batch norm layer. 59 | # Returns 60 | Output tensor after applying `Conv2D` and `BatchNormalization`. 
61 | """ 62 | if K.image_data_format() == 'channels_first': 63 | bn_axis = 1 64 | else: 65 | bn_axis = 3 66 | conv = Conv2D(filters=filters, 67 | kernel_size=kernel_size, 68 | strides=strides, 69 | padding='same', 70 | use_bias=False, 71 | name="conv{0}".format(index_layer))(input_layer) 72 | bn = BatchNormalization(axis=bn_axis, scale=False, name="bn{0}".format(index_layer))(conv) 73 | act = Activation('relu', name="act{0}".format(index_layer))(bn) 74 | pooling = MaxPooling2D(pool_size=2)(act) 75 | x = Dropout(0.3)(pooling) 76 | return x 77 | 78 | 79 | class NetworkAgent(Agent): 80 | def __init__(self, dic_agent_conf, dic_traffic_env_conf, dic_path, cnt_round, 81 | best_round=None, bar_round=None, intersection_id="0"): 82 | 83 | super(NetworkAgent, self).__init__( 84 | dic_agent_conf, dic_traffic_env_conf, dic_path, intersection_id=intersection_id) 85 | 86 | # ===== check num actions == num phases ============ 87 | 88 | #self.num_actions = self.dic_sumo_env_conf["ACTION_DIM"] 89 | #self.num_phases = self.dic_sumo_env_conf["NUM_PHASES"] 90 | self.num_actions = len(dic_traffic_env_conf["PHASE"][dic_traffic_env_conf['SIMULATOR_TYPE']]) 91 | self.num_phases = len(dic_traffic_env_conf["PHASE"][dic_traffic_env_conf['SIMULATOR_TYPE']]) 92 | self.num_lanes = np.sum(np.array(list(self.dic_traffic_env_conf["LANE_NUM"].values()))) 93 | 94 | self.memory = self.build_memory() 95 | 96 | if cnt_round == 0: 97 | # initialization 98 | if os.listdir(self.dic_path["PATH_TO_MODEL"]): 99 | if self.dic_traffic_env_conf['ONE_MODEL']: 100 | self.load_network("round_0") 101 | else: 102 | self.load_network("round_0_inter_{0}".format(intersection_id)) 103 | else: 104 | self.q_network = self.build_network() 105 | #self.load_network(self.dic_agent_conf["TRAFFIC_FILE"], file_path=self.dic_path["PATH_TO_PRETRAIN_MODEL"]) 106 | self.q_network_bar = self.build_network_from_copy(self.q_network) 107 | else: 108 | try: 109 | if best_round: 110 | # TODO add "ONE_MODEL" 111 | # use model pool 112 | 
self.load_network("round_{0}_inter_{1}".format(best_round,self.intersection_id)) 113 | 114 | if bar_round and bar_round != best_round and cnt_round > 10: 115 | # load q_bar network from model pool 116 | self.load_network_bar("round_{0}_inter_{1}".format(bar_round,self.intersection_id)) 117 | else: 118 | if "UPDATE_Q_BAR_EVERY_C_ROUND" in self.dic_agent_conf: 119 | if self.dic_agent_conf["UPDATE_Q_BAR_EVERY_C_ROUND"]: 120 | self.load_network_bar("round_{0}".format( 121 | max((best_round - 1) // self.dic_agent_conf["UPDATE_Q_BAR_FREQ"] * self.dic_agent_conf["UPDATE_Q_BAR_FREQ"], 0), 122 | self.intersection_id)) 123 | else: 124 | self.load_network_bar("round_{0}_inter_{1}".format( 125 | max(best_round - self.dic_agent_conf["UPDATE_Q_BAR_FREQ"], 0), 126 | self.intersection_id)) 127 | else: 128 | self.load_network_bar("round_{0}_inter_{1}".format( 129 | max(best_round - self.dic_agent_conf["UPDATE_Q_BAR_FREQ"], 0), self.intersection_id)) 130 | 131 | else: 132 | # not use model pool 133 | if self.dic_traffic_env_conf['ONE_MODEL']: 134 | self.load_network("round_{0}".format(cnt_round-1, self.intersection_id)) 135 | if "UPDATE_Q_BAR_EVERY_C_ROUND" in self.dic_agent_conf: 136 | if self.dic_agent_conf["UPDATE_Q_BAR_EVERY_C_ROUND"]: 137 | self.load_network_bar("round_{0}".format( 138 | max((cnt_round - 1) // self.dic_agent_conf["UPDATE_Q_BAR_FREQ"] * self.dic_agent_conf["UPDATE_Q_BAR_FREQ"], 0))) 139 | else: 140 | self.load_network_bar("round_{0}".format( 141 | max(cnt_round - self.dic_agent_conf["UPDATE_Q_BAR_FREQ"], 0))) 142 | else: 143 | self.load_network_bar("round_{0}".format( 144 | max(cnt_round - self.dic_agent_conf["UPDATE_Q_BAR_FREQ"], 0))) 145 | 146 | else: 147 | self.load_network("round_{0}_inter_{1}".format(cnt_round-1, self.intersection_id)) 148 | 149 | if "UPDATE_Q_BAR_EVERY_C_ROUND" in self.dic_agent_conf: 150 | if self.dic_agent_conf["UPDATE_Q_BAR_EVERY_C_ROUND"]: 151 | self.load_network_bar("round_{0}_inter_{1}".format( 152 | max((cnt_round - 1) // 
self.dic_agent_conf["UPDATE_Q_BAR_FREQ"] * self.dic_agent_conf["UPDATE_Q_BAR_FREQ"], 0), 153 | self.intersection_id)) 154 | else: 155 | self.load_network_bar("round_{0}_inter_{1}".format( 156 | max(cnt_round - self.dic_agent_conf["UPDATE_Q_BAR_FREQ"], 0), 157 | self.intersection_id)) 158 | else: 159 | self.load_network_bar("round_{0}_inter_{1}".format( 160 | max(cnt_round - self.dic_agent_conf["UPDATE_Q_BAR_FREQ"], 0), self.intersection_id)) 161 | except: 162 | print("fail to load network, current round: {0}".format(cnt_round)) 163 | 164 | # decay the epsilon 165 | 166 | decayed_epsilon = self.dic_agent_conf["EPSILON"] * pow(self.dic_agent_conf["EPSILON_DECAY"], cnt_round) 167 | self.dic_agent_conf["EPSILON"] = max(decayed_epsilon, self.dic_agent_conf["MIN_EPSILON"]) 168 | 169 | @staticmethod 170 | def _unison_shuffled_copies(Xs, Y, sample_weight): 171 | p = np.random.permutation(len(Y)) 172 | new_Xs = [] 173 | for x in Xs: 174 | assert len(x) == len(Y) 175 | new_Xs.append(x[p]) 176 | return new_Xs, Y[p], sample_weight[p] 177 | 178 | @staticmethod 179 | def _cnn_network_structure(img_features): 180 | conv1 = conv2d_bn(img_features, 1, filters=32, kernel_size=(8, 8), strides=(4, 4)) 181 | conv2 = conv2d_bn(conv1, 2, filters=16, kernel_size=(4, 4), strides=(2, 2)) 182 | img_flatten = Flatten()(conv2) 183 | return img_flatten 184 | 185 | @staticmethod 186 | def _shared_network_structure(state_features, dense_d): 187 | hidden_1 = Dense(dense_d, activation="sigmoid", name="hidden_shared_1")(state_features) 188 | return hidden_1 189 | 190 | @staticmethod 191 | def _separate_network_structure(state_features, dense_d, num_actions, memo=""): 192 | hidden_1 = Dense(dense_d, activation="sigmoid", name="hidden_separate_branch_{0}_1".format(memo))(state_features) 193 | q_values = Dense(num_actions, activation="linear", name="q_values_separate_branch_{0}".format(memo))(hidden_1) 194 | return q_values 195 | 196 | def load_network(self, file_name, file_path=None): 197 | if 
file_path == None: 198 | file_path = self.dic_path["PATH_TO_MODEL"] 199 | self.q_network = load_model(os.path.join(file_path, "%s.h5" % file_name), custom_objects={"Selector": Selector}) 200 | print("succeed in loading model %s"%file_name) 201 | 202 | def load_network_bar(self, file_name, file_path=None): 203 | if file_path == None: 204 | file_path = self.dic_path["PATH_TO_MODEL"] 205 | self.q_network_bar = load_model(os.path.join(file_path, "%s.h5" % file_name), custom_objects={"Selector": Selector}) 206 | print("succeed in loading model %s"%file_name) 207 | 208 | def save_network(self, file_name): 209 | self.q_network.save(os.path.join(self.dic_path["PATH_TO_MODEL"], "%s.h5" % file_name)) 210 | 211 | def save_network_bar(self, file_name): 212 | self.q_network_bar.save(os.path.join(self.dic_path["PATH_TO_MODEL"], "%s.h5" % file_name)) 213 | 214 | def build_network(self): 215 | 216 | raise NotImplementedError 217 | 218 | def build_memory(self): 219 | 220 | return [] 221 | 222 | def build_network_from_copy(self, network_copy): 223 | 224 | '''Initialize a Q network from a copy''' 225 | 226 | network_structure = network_copy.to_json() 227 | network_weights = network_copy.get_weights() 228 | network = model_from_json(network_structure, custom_objects={"Selector": Selector}) 229 | network.set_weights(network_weights) 230 | network.compile(optimizer=RMSprop(lr=self.dic_agent_conf["LEARNING_RATE"]), 231 | loss=self.dic_agent_conf["LOSS_FUNCTION"]) 232 | return network 233 | 234 | 235 | def prepare_Xs_Y(self, memory, dic_exp_conf): 236 | 237 | 238 | ind_end = len(memory) 239 | print("memory size before forget: {0}".format(ind_end)) 240 | # use all the samples to pretrain, i.e., without forgetting 241 | if dic_exp_conf["PRETRAIN"] or dic_exp_conf["AGGREGATE"]: 242 | sample_slice = memory 243 | # forget 244 | else: 245 | ind_sta = max(0, ind_end - self.dic_agent_conf["MAX_MEMORY_LEN"]) 246 | memory_after_forget = memory[ind_sta: ind_end] 247 | print("memory size after 
forget:", len(memory_after_forget)) 248 | 249 | # sample the memory 250 | sample_size = min(self.dic_agent_conf["SAMPLE_SIZE"], len(memory_after_forget)) 251 | sample_slice = random.sample(memory_after_forget, sample_size) 252 | print("memory samples number:", sample_size) 253 | 254 | dic_state_feature_arrays = {} 255 | for feature_name in self.dic_traffic_env_conf["LIST_STATE_FEATURE"]: 256 | dic_state_feature_arrays[feature_name] = [] 257 | Y = [] 258 | 259 | for i in range(len(sample_slice)): 260 | state, action, next_state, reward, instant_reward, _, _ = sample_slice[i] 261 | # print(state) 262 | 263 | for feature_name in self.dic_traffic_env_conf["LIST_STATE_FEATURE"]: 264 | dic_state_feature_arrays[feature_name].append(state[feature_name]) 265 | 266 | _state = [] 267 | _next_state = [] 268 | for feature_name in self.dic_traffic_env_conf["LIST_STATE_FEATURE"]: 269 | _state.append([state[feature_name]]) 270 | _next_state.append([next_state[feature_name]]) 271 | target = self.q_network.predict(_state) 272 | 273 | next_state_qvalues = self.q_network_bar.predict(_next_state) 274 | 275 | if self.dic_agent_conf["LOSS_FUNCTION"] == "mean_squared_error": 276 | final_target = np.copy(target[0]) 277 | final_target[action] = reward / self.dic_agent_conf["NORMAL_FACTOR"] + self.dic_agent_conf["GAMMA"] * \ 278 | np.max(next_state_qvalues[0]) 279 | elif self.dic_agent_conf["LOSS_FUNCTION"] == "categorical_crossentropy": 280 | raise NotImplementedError 281 | 282 | Y.append(final_target) 283 | 284 | self.Xs = [np.array(dic_state_feature_arrays[feature_name]) for feature_name in 285 | self.dic_traffic_env_conf["LIST_STATE_FEATURE"]] 286 | self.Y = np.array(Y) 287 | 288 | 289 | def convert_state_to_input(self, s): 290 | if self.dic_traffic_env_conf["BINARY_PHASE_EXPANSION"]: 291 | inputs = [] 292 | # print(s) 293 | for feature in self.dic_traffic_env_conf["LIST_STATE_FEATURE"]: 294 | if "cur_phase" in feature: 295 | inputs.append(np.array([self.dic_traffic_env_conf['PHASE'] 296 
| [self.dic_traffic_env_conf['SIMULATOR_TYPE']][s[feature][0]]])) 297 | else: 298 | inputs.append(np.array([s[feature]])) 299 | return inputs 300 | else: 301 | return [np.array([s[feature]]) for feature in self.dic_traffic_env_conf["LIST_STATE_FEATURE"]] 302 | 303 | 304 | def choose_action(self, count, state): 305 | 306 | ''' choose the best action for current state ''' 307 | state_input = self.convert_state_to_input(state) 308 | # print("----------State Input",state_input) 309 | q_values = self.q_network.predict(state_input) 310 | if random.random() <= self.dic_agent_conf["EPSILON"]: # continue explore new Random Action 311 | action = random.randrange(len(q_values[0])) 312 | else: # exploitation 313 | action = np.argmax(q_values[0]) 314 | 315 | return action 316 | 317 | def train_network(self, dic_exp_conf): 318 | 319 | if dic_exp_conf["PRETRAIN"] or dic_exp_conf["AGGREGATE"]: 320 | epochs = 1000 321 | else: 322 | epochs = self.dic_agent_conf["EPOCHS"] 323 | batch_size = min(self.dic_agent_conf["BATCH_SIZE"], len(self.Y)) 324 | 325 | early_stopping = EarlyStopping( 326 | monitor='val_loss', patience=self.dic_agent_conf["PATIENCE"], verbose=0, mode='min') 327 | 328 | hist = self.q_network.fit(self.Xs, self.Y, batch_size=batch_size, epochs=epochs, 329 | shuffle=False, 330 | verbose=2, validation_split=0.3, callbacks=[early_stopping]) 331 | -------------------------------------------------------------------------------- /run_baseline.py: -------------------------------------------------------------------------------- 1 | import config 2 | import copy 3 | from baseline.oneline import OneLine 4 | import os 5 | import time 6 | from multiprocessing import Process 7 | import sys 8 | from script import get_traffic_volume 9 | import argparse 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--memo", type=str, default='0509_evening_MaxPressure_NewYork_28_7_triple') 14 | parser.add_argument("-all", action="store_true", 
default=False) 15 | parser.add_argument("--env", type=int, default=1) 16 | parser.add_argument("--gui", type=bool, default=0) 17 | parser.add_argument("--road_net", type=str, default='28_7') 18 | parser.add_argument("--volume", type=str, default='newyork')#'300' 19 | parser.add_argument("--ratio", type=str, default='real_triple')#'0.3_uni' 20 | parser.add_argument("--model", type=str, default="MaxPressure")#'SlidingFormula','Fixedtime','MaxPressure' 21 | parser.add_argument("--count",type=int, default=3600) 22 | parser.add_argument("--lane", type=int, default=3) 23 | return parser.parse_args() 24 | 25 | def merge(dic_tmp, dic_to_change): 26 | dic_result = copy.deepcopy(dic_tmp) 27 | dic_result.update(dic_to_change) 28 | 29 | return dic_result 30 | 31 | def check_all_workers_working(list_cur_p): 32 | for i in range(len(list_cur_p)): 33 | if not list_cur_p[i].is_alive(): 34 | return i 35 | return -1 36 | 37 | def oneline_wrapper(dic_exp_conf, dic_agent_conf, dic_traffic_env_conf, dic_path): 38 | oneline = OneLine(dic_exp_conf=merge(config.DIC_EXP_CONF, dic_exp_conf), 39 | dic_agent_conf=merge(getattr(config, "DIC_{0}_AGENT_CONF".format(dic_exp_conf["MODEL_NAME"].upper())), 40 | dic_agent_conf), 41 | dic_traffic_env_conf=merge(config.dic_traffic_env_conf, dic_traffic_env_conf), 42 | dic_path=merge(config.DIC_PATH, dic_path) 43 | ) 44 | oneline.train() 45 | return 46 | 47 | def main(memo, rall, road_net, env, gui, volume, ratio, model,count,lane): 48 | NUM_COL = int(road_net.split('_')[0]) 49 | NUM_ROW = int(road_net.split('_')[1]) 50 | num_intersections = NUM_ROW * NUM_COL 51 | 52 | ENVIRONMENT = ["sumo", "anon"][env] 53 | 54 | 55 | if rall: 56 | traffic_file_list = [ENVIRONMENT+"_"+road_net+"_%d_%s" %(v,r) for v in [300,500,700] for r in [0.3,0.6]] 57 | else: 58 | traffic_file_list=["{0}_{1}_{2}_{3}".format(ENVIRONMENT, road_net, volume, ratio)] 59 | 60 | if env: 61 | traffic_file_list = [i+ ".json" for i in traffic_file_list ] 62 | else: 63 | traffic_file_list = [i+ 
".xml" for i in traffic_file_list ] 64 | 65 | process_list = [] 66 | multi_process = True 67 | n_workers = 60 68 | 69 | # !! important para !! 70 | # best cycle length sets the min_action_time 1 71 | # normal formula 1 72 | min_action_time = 10 73 | 74 | for traffic_file in traffic_file_list: 75 | #for cl in range(1, 60): 76 | cl = 30 77 | model_name = model 78 | 79 | # automatically choose the n_segments 80 | if "cross" in traffic_file: 81 | n_segments = 1 82 | else: 83 | n_segments = 12 84 | dic_exp_conf_extra = { 85 | "RUN_COUNTS": count, 86 | "MODEL_NAME": model_name, 87 | "TRAFFIC_FILE": [traffic_file], 88 | "ROADNET_FILE": "roadnet_{0}.json".format(road_net) 89 | } 90 | # 91 | if volume=='jinan' or volume=='hangzhou' or volume=='newyork' or volume=='LA' or volume=='georgia' or 'chacha' in volume: 92 | new_volume=volume 93 | volume=300 94 | if 'chacha' in new_volume: 95 | volume=int(new_volume.split('_')[1]) 96 | else: 97 | new_volume='uniform' 98 | if 'anon' not in traffic_file: 99 | volume = int(traffic_file.split('_')[3]) 100 | else: 101 | volume = get_traffic_volume(traffic_file) 102 | 103 | ''' 104 | ## deeplight agent conf 105 | dic_agent_conf_extra = { 106 | "EPOCHS": 100, 107 | "SAMPLE_SIZE": 300, 108 | "MAX_MEMORY_LEN": 1000, 109 | "PHASE_SELECTOR": False, 110 | "SEPARATE_MEMORY": False, 111 | "EPSILON": 0 112 | 113 | } 114 | ''' 115 | 116 | 117 | dic_traffic_env_conf_extra = { 118 | "USE_LANE_ADJACENCY": True, 119 | "NEIGHBOR": False, 120 | "NUM_AGENTS": num_intersections, 121 | "NUM_INTERSECTIONS": num_intersections, 122 | "ACTION_PATTERN": "set", 123 | "MIN_ACTION_TIME": 20, 124 | "MEASURE_TIME": 10, 125 | "IF_GUI": gui, 126 | "DEBUG": False, 127 | "TOP_K_ADJACENCY": 4, 128 | "SIMULATOR_TYPE": ENVIRONMENT, 129 | "BINARY_PHASE_EXPANSION": True, 130 | "FAST_COMPUTE": False, 131 | "ADJACENCY_BY_CONNECTION_OR_GEO":False, 132 | "TOP_K_ADJACENCY_LANE": 5, 133 | "MODEL_NAME":model_name, 134 | 135 | 136 | "SAVEREPLAY": False, 137 | "NUM_ROW": NUM_ROW, 138 
| "NUM_COL": NUM_COL, 139 | 140 | "TRAFFIC_FILE": traffic_file, 141 | "ROADNET_FILE": "roadnet_{0}.json".format(road_net), 142 | 143 | "LIST_STATE_FEATURE": [ 144 | "cur_phase", 145 | "time_this_phase", 146 | # "vehicle_position_img", 147 | # "vehicle_speed_img", 148 | # "vehicle_acceleration_img", 149 | # "vehicle_waiting_time_img", 150 | "lane_num_vehicle", 151 | # "lane_num_vehicle_been_stopped_thres01", 152 | # "lane_num_vehicle_been_stopped_thres1", 153 | # "lane_queue_length", 154 | # "lane_num_vehicle_left", 155 | # "lane_sum_duration_vehicle_left", 156 | # "lane_sum_waiting_time", 157 | # "terminal", 158 | # "lane_coming_vehicle", 159 | "coming_vehicle", 160 | "leaving_vehicle", 161 | # "pressure", 162 | 163 | # "adjacency_matrix", 164 | # "lane_queue_length" 165 | ], 166 | 167 | "DIC_FEATURE_DIM": dict( 168 | D_LANE_QUEUE_LENGTH=(4,), 169 | D_LANE_NUM_VEHICLE=(4,), 170 | 171 | D_COMING_VEHICLE = (4,), 172 | D_LEAVING_VEHICLE = (4,), 173 | 174 | D_LANE_NUM_VEHICLE_BEEN_STOPPED_THRES1=(4,), 175 | D_CUR_PHASE=(1,), 176 | D_NEXT_PHASE=(1,), 177 | D_TIME_THIS_PHASE=(1,), 178 | D_TERMINAL=(1,), 179 | D_LANE_SUM_WAITING_TIME=(4,), 180 | D_VEHICLE_POSITION_IMG=(4, 60,), 181 | D_VEHICLE_SPEED_IMG=(4, 60,), 182 | D_VEHICLE_WAITING_TIME_IMG=(4, 60,), 183 | 184 | D_PRESSURE=(1,), 185 | 186 | D_ADJACENCY_MATRIX=(2,) 187 | ), 188 | 189 | "DIC_REWARD_INFO": { 190 | "flickering": 0, 191 | "sum_lane_queue_length": 0, 192 | "sum_lane_wait_time": 0, 193 | "sum_lane_num_vehicle_left": 0, 194 | "sum_duration_vehicle_left": 0, 195 | "sum_num_vehicle_been_stopped_thres01": 0, 196 | "sum_num_vehicle_been_stopped_thres1": -0.25, 197 | "pressure": 0 # -0.25 198 | }, 199 | 200 | "LANE_NUM": { 201 | "LEFT": 1, 202 | "RIGHT": 1, 203 | "STRAIGHT": 1 204 | }, 205 | 206 | "PHASE": { 207 | "sumo": { 208 | 0: [0, 1, 0, 1, 0, 0, 0, 0],# 'WSES', 209 | 1: [0, 0, 0, 0, 0, 1, 0, 1],# 'NSSS', 210 | 2: [1, 0, 1, 0, 0, 0, 0, 0],# 'WLEL', 211 | 3: [0, 0, 0, 0, 1, 0, 1, 0]# 'NLSL', 212 | }, 213 | 
"anon": { 214 | # 0: [0, 0, 0, 0, 0, 0, 0, 0], 215 | 1: [0, 1, 0, 1, 0, 0, 0, 0],# 'WSES', 216 | 2: [0, 0, 0, 0, 0, 1, 0, 1],# 'NSSS', 217 | 3: [1, 0, 1, 0, 0, 0, 0, 0],# 'WLEL', 218 | 4: [0, 0, 0, 0, 1, 0, 1, 0]# 'NLSL', 219 | # 'WSWL', 220 | # 'ESEL', 221 | # 'WSES', 222 | # 'NSSS', 223 | # 'NSNL', 224 | # 'SSSL', 225 | }, 226 | } 227 | } 228 | 229 | 230 | 231 | ## ==================== multi_phase ==================== 232 | 233 | # if dic_traffic_env_conf_extra["LANE_NUM"] == config._LS: 234 | if lane == 2: 235 | template = "template_ls" 236 | fixed_time = [cl, cl, cl, cl] 237 | fixed_time_dic = config.DIC_FIXEDTIME_MULTI_PHASE 238 | elif lane == 1: 239 | template = "template_s" 240 | fixed_time = [cl, cl] 241 | fixed_time_dic = config.DIC_FIXEDTIME 242 | dic_traffic_env_conf_extra["PHASE"] = { 243 | "sumo": { 244 | 0: [0, 1, 0, 1, 0, 0, 0, 0], # 'WSES', 245 | 1: [0, 0, 0, 0, 0, 1, 0, 1], # 'NSSS', 246 | }, 247 | "anon": { 248 | 1: [0, 1, 0, 1, 0, 0, 0, 0], # 'WSES', 249 | 2: [0, 0, 0, 0, 0, 1, 0, 1], # 'NSSS', 250 | }, 251 | } 252 | dic_traffic_env_conf_extra["LANE_NUM"] = { 253 | "LEFT": 0, 254 | "RIGHT": 0, 255 | "STRAIGHT": 1 256 | } 257 | elif lane == 3: 258 | if new_volume=='jinan': 259 | template='Jinan' 260 | elif new_volume=='hangzhou': 261 | template='Hangzhou' 262 | elif new_volume=='newyork': 263 | template='NewYork' 264 | elif new_volume=='LA': 265 | template='LA' 266 | elif new_volume=='georgia': 267 | template='Georgia' 268 | elif 'chacha' in new_volume: 269 | template='Chacha' 270 | else: 271 | template = "template_lsr" 272 | fixed_time_dic = config.DIC_FIXEDTIME_MULTI_PHASE 273 | else: 274 | raise ValueError 275 | lit_template = "lit_" + template 276 | 277 | ## fixedtime agent conf 278 | 279 | dic_agent_conf_extra = { 280 | "DAY_TIME": 3600, 281 | "UPDATE_PERIOD": 3600 / n_segments, # if synthetic, update_period: 3600/12 282 | "FIXED_TIME": fixed_time_dic[volume], 283 | "ROUND_UP": 1, 284 | "PHASE_TO_LANE": [[0, 1], [2, 3]], 285 | 
"MIN_PHASE_TIME": 1, 286 | "TRAFFIC_FILE": [traffic_file], 287 | } 288 | dic_traffic_env_conf_extra["NUM_AGENTS"] = dic_traffic_env_conf_extra["NUM_INTERSECTIONS"] 289 | 290 | prefix_intersections = str(road_net) 291 | dic_path_extra = { 292 | "PATH_TO_MODEL": os.path.join("model", memo, traffic_file+"_"+time.strftime('%m_%d_%H_%M_%S', time.localtime(time.time()))+"_%d"%cl), 293 | "PATH_TO_WORK_DIRECTORY": os.path.join("records", memo, traffic_file+"_"+time.strftime('%m_%d_%H_%M_%S', time.localtime(time.time()))+"_%d"%cl), 294 | "PATH_TO_DATA": os.path.join("data", template, prefix_intersections) 295 | } 296 | 297 | if multi_process: 298 | process_list.append(Process(target=oneline_wrapper, 299 | args=(dic_exp_conf_extra, dic_agent_conf_extra, 300 | dic_traffic_env_conf_extra, dic_path_extra)) 301 | ) 302 | else: 303 | oneline_wrapper(dic_exp_conf_extra, dic_agent_conf_extra, dic_traffic_env_conf_extra, dic_path_extra) 304 | 305 | if multi_process: 306 | i = 0 307 | list_cur_p = [] 308 | for p in process_list: 309 | if len(list_cur_p) < n_workers: 310 | print(i) 311 | p.start() 312 | list_cur_p.append(p) 313 | i += 1 314 | if len(list_cur_p) < n_workers: 315 | continue 316 | 317 | idle = check_all_workers_working(list_cur_p) 318 | 319 | while idle == -1: 320 | time.sleep(1) 321 | idle = check_all_workers_working( 322 | list_cur_p) 323 | del list_cur_p[idle] 324 | 325 | for p in list_cur_p: 326 | p.join() 327 | 328 | if __name__ == "__main__": 329 | start=time.time() 330 | args = parse_args() 331 | # memo = "test" 332 | # args.g = True 333 | main(args.memo, args.all, args.road_net, args.env, args.gui, args.volume, args.ratio,args.model,args.count,args.lane) 334 | end=time.time() 335 | duration=end-start 336 | print('%.2f seconds, %.2f minutes'%(duration,duration/60)) 337 | 338 | # def main(memo=None, baseline, road_net, env, gui, volume, ratio, model): 339 | 340 | """ 341 | time log for 3600s 342 | Fixedtime: 289.47 seconds, 4.82 min 343 | MaxPressure: 271.12 
seconds, 4.51 min 344 | """ -------------------------------------------------------------------------------- /run_baseline_batch.py: -------------------------------------------------------------------------------- 1 | import run_baseline 2 | import summary 3 | 4 | memo = "multi_phase/sumo/fixedtime_no_change_lane_optimal" 5 | run_baseline.main(memo) 6 | print("********************** run_baseline ends **********************") 7 | summary.summary_detail_baseline(memo) 8 | print("********************** run_baseline ends **********************") -------------------------------------------------------------------------------- /run_batch.py: -------------------------------------------------------------------------------- 1 | import runexp 2 | import testexp 3 | import summary 4 | 5 | 6 | memo = "multi_phase/sumo/pipeline" 7 | runexp.main(memo) 8 | print("****************************** runexp ends (generate, train, test)!! ******************************") 9 | summary.main(memo) 10 | print("****************************** summary_detail ends ******************************") -------------------------------------------------------------------------------- /runexp.py: -------------------------------------------------------------------------------- 1 | import config 2 | import copy 3 | from pipeline import Pipeline 4 | import os 5 | import time 6 | from multiprocessing import Process 7 | import argparse 8 | import os 9 | import matplotlib 10 | # matplotlib.use('TkAgg') 11 | 12 | from script import get_traffic_volume 13 | 14 | multi_process = True 15 | TOP_K_ADJACENCY=-1 16 | TOP_K_ADJACENCY_LANE=-1 17 | PRETRAIN=False 18 | NUM_ROUNDS=100 19 | EARLY_STOP=False 20 | NEIGHBOR=False 21 | SAVEREPLAY=False 22 | ADJACENCY_BY_CONNECTION_OR_GEO=False 23 | hangzhou_archive=True 24 | ANON_PHASE_REPRE=[] 25 | 26 | def parse_args(): 27 | parser = argparse.ArgumentParser() 28 | # The file folder to create/log in 29 | parser.add_argument("--memo", type=str, 
default='0515_afternoon_Colight_6_6_bi')#1_3,2_2,3_3,4_4 30 | parser.add_argument("--env", type=int, default=1) #env=1 means you will run CityFlow 31 | parser.add_argument("--gui", type=bool, default=False) 32 | parser.add_argument("--road_net", type=str, default='6_6')#'1_2') # which road net you are going to run 33 | parser.add_argument("--volume", type=str, default='300')#'300' 34 | parser.add_argument("--suffix", type=str, default="0.3_bi")#0.3 35 | 36 | global hangzhou_archive 37 | hangzhou_archive=False 38 | global TOP_K_ADJACENCY 39 | TOP_K_ADJACENCY=5 40 | global TOP_K_ADJACENCY_LANE 41 | TOP_K_ADJACENCY_LANE=5 42 | global NUM_ROUNDS 43 | NUM_ROUNDS=100 44 | global EARLY_STOP 45 | EARLY_STOP=False 46 | global NEIGHBOR 47 | # TAKE CARE 48 | NEIGHBOR=False 49 | global SAVEREPLAY # if you want to relay your simulation, set it to be True 50 | SAVEREPLAY=False 51 | global ADJACENCY_BY_CONNECTION_OR_GEO 52 | # TAKE CARE 53 | ADJACENCY_BY_CONNECTION_OR_GEO=False 54 | 55 | #modify:TOP_K_ADJACENCY in line 154 56 | global PRETRAIN 57 | PRETRAIN=False 58 | parser.add_argument("--mod", type=str, default='CoLight')#SimpleDQN,SimpleDQNOne,GCN,CoLight,Lit 59 | parser.add_argument("--cnt",type=int, default=3600)#3600 60 | parser.add_argument("--gen",type=int, default=4)#4 61 | 62 | parser.add_argument("-all", action="store_true", default=False) 63 | parser.add_argument("--workers",type=int, default=7) 64 | parser.add_argument("--onemodel",type=bool, default=False) 65 | 66 | parser.add_argument("--visible_gpu", type=str, default="-1") 67 | global ANON_PHASE_REPRE 68 | tt=parser.parse_args() 69 | if 'CoLight_Signal' in tt.mod: 70 | #12dim 71 | ANON_PHASE_REPRE={ 72 | # 0: [0, 0, 0, 0, 0, 0, 0, 0], 73 | 1: [0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1],# 'WSES', 74 | 2: [0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1],# 'NSSS', 75 | 3: [1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1],# 'WLEL', 76 | 4: [0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1]# 'NLSL', 77 | } 78 | else: 79 | #12dim 80 | ANON_PHASE_REPRE={ 81 | 1: 
[0, 1, 0, 1, 0, 0, 0, 0], 82 | 2: [0, 0, 0, 0, 0, 1, 0, 1], 83 | 3: [1, 0, 1, 0, 0, 0, 0, 0], 84 | 4: [0, 0, 0, 0, 1, 0, 1, 0] 85 | } 86 | print('agent_name:%s',tt.mod) 87 | print('ANON_PHASE_REPRE:',ANON_PHASE_REPRE) 88 | 89 | 90 | return parser.parse_args() 91 | 92 | 93 | def memo_rename(traffic_file_list): 94 | new_name = "" 95 | for traffic_file in traffic_file_list: 96 | if "synthetic" in traffic_file: 97 | sta = traffic_file.rfind("-") + 1 98 | print(traffic_file, int(traffic_file[sta:-4])) 99 | new_name = new_name + "syn" + traffic_file[sta:-4] + "_" 100 | elif "cross" in traffic_file: 101 | sta = traffic_file.find("equal_") + len("equal_") 102 | end = traffic_file.find(".xml") 103 | new_name = new_name + "uniform" + traffic_file[sta:end] + "_" 104 | elif "flow" in traffic_file: 105 | new_name = traffic_file[:-4] 106 | new_name = new_name[:-1] 107 | return new_name 108 | 109 | def merge(dic_tmp, dic_to_change): 110 | dic_result = copy.deepcopy(dic_tmp) 111 | dic_result.update(dic_to_change) 112 | 113 | return dic_result 114 | 115 | def check_all_workers_working(list_cur_p): 116 | for i in range(len(list_cur_p)): 117 | if not list_cur_p[i].is_alive(): 118 | return i 119 | 120 | return -1 121 | 122 | def pipeline_wrapper(dic_exp_conf, dic_agent_conf, dic_traffic_env_conf, dic_path): 123 | ppl = Pipeline(dic_exp_conf=dic_exp_conf, # experiment config 124 | dic_agent_conf=dic_agent_conf, # RL agent config 125 | dic_traffic_env_conf=dic_traffic_env_conf, # the simolation configuration 126 | dic_path=dic_path # where should I save the logs? 
127 | ) 128 | global multi_process 129 | ppl.run(multi_process=multi_process) 130 | 131 | print("pipeline_wrapper end") 132 | return 133 | 134 | 135 | 136 | def main(memo, traffic_file_list,roadnet_file,num_row,num_col,path_to_data,workers): 137 | num_intersections = num_row * num_col 138 | print('num_intersections:',num_intersections) 139 | 140 | ENVIRONMENT = "anon" 141 | 142 | process_list = [] 143 | n_workers = workers #len(traffic_file_list) 144 | multi_process = True 145 | 146 | global PRETRAIN 147 | global NUM_ROUNDS 148 | global EARLY_STOP 149 | for traffic_file in traffic_file_list: 150 | global TOP_K_ADJACENCY 151 | global TOP_K_ADJACENCY_LANE 152 | global NEIGHBOR 153 | global SAVEREPLAY 154 | global ADJACENCY_BY_CONNECTION_OR_GEO 155 | global ANON_PHASE_REPRE 156 | global hangzhou_archive 157 | deploy_dic_traffic_env_conf = { 158 | "NUM_INTERSECTIONS": num_intersections, 159 | "NUM_ROW": num_row, 160 | "NUM_COL": num_col, 161 | "TRAFFIC_FILE": traffic_file, 162 | "ROADNET_FILE": roadnet_file, 163 | "ACTION_PATTERN": "set", 164 | "MIN_ACTION_TIME": 10, 165 | "YELLOW_TIME": 5, 166 | "ALL_RED_TIME": 0, 167 | "NUM_PHASES": 2, 168 | "NUM_LANES": 1, 169 | "ACTION_DIM": 2, 170 | "MEASURE_TIME": 10, 171 | "IF_GUI": false, 172 | "DEBUG": false, 173 | "INTERVAL": 1, 174 | "THREADNUM": 8, 175 | "SAVEREPLAY": false, 176 | "RLTRAFFICLIGHT": true, 177 | "DIC_FEATURE_DIM": { 178 | "D_LANE_QUEUE_LENGTH": [4], 179 | "D_LANE_NUM_VEHICLE": [4], 180 | "D_COMING_VEHICLE": [12], 181 | "D_LEAVING_VEHICLE": [12], 182 | "D_LANE_NUM_VEHICLE_BEEN_STOPPED_THRES1": [4], 183 | "D_CUR_PHASE": [8], 184 | "D_NEXT_PHASE": [1], 185 | "D_TIME_THIS_PHASE": [1], 186 | "D_TERMINAL": [1], 187 | "D_LANE_SUM_WAITING_TIME": [4], 188 | "D_VEHICLE_POSITION_IMG": [4,60], 189 | "D_VEHICLE_SPEED_IMG": [4,60], 190 | "D_VEHICLE_WAITING_TIME_IMG": [4,60], 191 | "D_PRESSURE": [1], 192 | "D_ADJACENCY_MATRIX": [5], 193 | "D_ADJACENCY_MATRIX_LANE": [5], 194 | "D_CUR_PHASE_0": [1], 195 | 
"D_LANE_NUM_VEHICLE_0": [4], 196 | "D_CUR_PHASE_1": [1], 197 | "D_LANE_NUM_VEHICLE_1": [4], 198 | "D_CUR_PHASE_2": [1], 199 | "D_LANE_NUM_VEHICLE_2": [4], 200 | "D_CUR_PHASE_3": [1], 201 | "D_LANE_NUM_VEHICLE_3": [4] 202 | }, 203 | "LIST_STATE_FEATURE": ["cur_phase","lane_num_vehicle","adjacency_matrix","adjacency_matrix_lane"], 204 | "DIC_REWARD_INFO": { 205 | "flickering": 0, 206 | "sum_lane_queue_length": 0, 207 | "sum_lane_wait_time": 0, 208 | "sum_lane_num_vehicle_left": 0, 209 | "sum_duration_vehicle_left": 0, 210 | "sum_num_vehicle_been_stopped_thres01": 0, 211 | "sum_num_vehicle_been_stopped_thres1": -0.25, 212 | "pressure": 0 213 | }, 214 | "LANE_NUM": {"LEFT": 1,"RIGHT": 1,"STRAIGHT": 1}, 215 | "PHASE": { 216 | "anon": { 217 | "1": [0,1,0,1,0,0,0,0], 218 | "2": [0,0,0,0,0,1,0,1], 219 | "3": [1,0,1,0,0,0,0,0], 220 | "4": [0,0,0,0,1,0,1,0] 221 | } 222 | }, 223 | "USE_LANE_ADJACENCY": true, 224 | "ONE_MODEL": false, 225 | "NUM_AGENTS": 1, 226 | "TOP_K_ADJACENCY": 5, 227 | "ADJACENCY_BY_CONNECTION_OR_GEO": false, 228 | "TOP_K_ADJACENCY_LANE": 5, 229 | "SIMULATOR_TYPE": "anon", 230 | "BINARY_PHASE_EXPANSION": true, 231 | "FAST_COMPUTE": true, 232 | "NEIGHBOR": false, 233 | "MODEL_NAME": "CoLight", 234 | "VOLUME": "300", 235 | "phase_expansion": { 236 | "1": [0,1,0,1,0,0,0,0], 237 | "2": [0,0,0,0,0,1,0,1], 238 | "3": [1,0,1,0,0,0,0,0], 239 | "4": [0,0,0,0,1,0,1,0], 240 | "5": [1,1,0,0,0,0,0,0], 241 | "6": [0,0,1,1,0,0,0,0], 242 | "7": [0,0,0,0,0,0,1,1], 243 | "8": [0,0,0,0,1,1,0,0] 244 | }, 245 | "phase_expansion_4_lane": { 246 | "1": [1,1,0,0], 247 | "2": [0,0,1,1] 248 | } 249 | } 250 | 251 | deploy_dic_exp_conf = { 252 | "TRAFFIC_FILE": [traffic_file], 253 | "ROADNET_FILE": roadnet_file, 254 | "RUN_COUNTS": 3600, 255 | "MODEL_NAME": "CoLight", 256 | "NUM_ROUNDS": 100, 257 | "NUM_GENERATORS": 4, 258 | "LIST_MODEL": ["Fixedtime","SOTL","Deeplight","SimpleDQN"], 259 | "LIST_MODEL_NEED_TO_UPDATE": ["Deeplight","SimpleDQN","CoLight","GCN","SimpleDQNOne","Lit"], 
260 | "MODEL_POOL": false, 261 | "NUM_BEST_MODEL": 3, 262 | "PRETRAIN": false, 263 | "PRETRAIN_MODEL_NAME": "CoLight", 264 | "PRETRAIN_NUM_ROUNDS": 0, 265 | "PRETRAIN_NUM_GENERATORS": 15, 266 | "AGGREGATE": false, 267 | "DEBUG": false, 268 | "EARLY_STOP": false, 269 | "MULTI_TRAFFIC": false, 270 | "MULTI_RANDOM": false 271 | } 272 | deploy_dic_agent_conf = { 273 | "TRAFFIC_FILE": traffic_file, 274 | "CNN_layers": [[32,32]], 275 | "att_regularization": false, 276 | "rularization_rate": 0.03, 277 | "LEARNING_RATE": 0.001, 278 | "SAMPLE_SIZE": 1000, 279 | "BATCH_SIZE": 20, 280 | "EPOCHS": 100, 281 | "UPDATE_Q_BAR_FREQ": 5, 282 | "UPDATE_Q_BAR_EVERY_C_ROUND": false, 283 | "GAMMA": 0.8, 284 | "MAX_MEMORY_LEN": 10000, 285 | "PATIENCE": 10, 286 | "D_DENSE": 20, 287 | "N_LAYER": 2, 288 | "EPSILON": 0.8, 289 | "EPSILON_DECAY": 0.95, 290 | "MIN_EPSILON": 0.2, 291 | "LOSS_FUNCTION": "mean_squared_error", 292 | "SEPARATE_MEMORY": false, 293 | "NORMAL_FACTOR": 20, 294 | } 295 | 296 | deploy_dic_path = { 297 | "PATH_TO_MODEL": os.path.join("model", memo, traffic_file + "_" + time.strftime('%m_%d_%H_%M_%S', time.localtime(time.time()))), 298 | "PATH_TO_WORK_DIRECTORY": os.path.join("records", memo, traffic_file + "_" + time.strftime('%m_%d_%H_%M_%S', time.localtime(time.time()))), 299 | "PATH_TO_DATA": path_to_data, 300 | "PATH_TO_PRETRAIN_MODEL": os.path.join("model", "initial", traffic_file), 301 | "PATH_TO_PRETRAIN_WORK_DIRECTORY": os.path.join("records", "initial", traffic_file), 302 | "PATH_TO_ERROR": os.path.join("errors", memo) 303 | "PATH_TO_PRETRAIN_DATA": "data/template", 304 | "PATH_TO_AGGREGATE_SAMPLES": "records/initial" 305 | } 306 | 307 | deploy_dic_path["PATH_TO_DATA"] = 'data/Hangzhou/4_4' 308 | 309 | 310 | if multi_process: 311 | ppl = Process(target=pipeline_wrapper, 312 | args=(deploy_dic_exp_conf, 313 | deploy_dic_agent_conf, 314 | deploy_dic_traffic_env_conf, 315 | deploy_dic_path)) 316 | process_list.append(ppl) 317 | else: 318 | 
pipeline_wrapper(dic_exp_conf=deploy_dic_exp_conf, 319 | dic_agent_conf=deploy_dic_agent_conf, 320 | dic_traffic_env_conf=deploy_dic_traffic_env_conf, 321 | dic_path=deploy_dic_path) 322 | 323 | if multi_process: 324 | for i in range(0, len(process_list), n_workers): 325 | i_max = min(len(process_list), i + n_workers) 326 | for j in range(i, i_max): 327 | print(j) 328 | print("start_traffic") 329 | process_list[j].start() 330 | print("after_traffic") 331 | for k in range(i, i_max): 332 | print("traffic to join", k) 333 | process_list[k].join() 334 | print("traffic finish join", k) 335 | 336 | 337 | return memo 338 | 339 | 340 | if __name__ == "__main__": 341 | args = parse_args() 342 | #memo = "multi_phase/optimal_search_new/new_headway_anon" 343 | 344 | os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_gpu 345 | main(args.memo, ['hangzhou.json'],'roadnet_4_4.json',4,4,'data/Hangzhou',args.workers) -------------------------------------------------------------------------------- /script.py: -------------------------------------------------------------------------------- 1 | # collect the common function 2 | # import script 3 | 4 | def get_traffic_volume(traffic_file): 5 | # only support "cross" and "synthetic" 6 | if "cross" in traffic_file: 7 | sta = traffic_file.find("equal_") + len("equal_") 8 | end = traffic_file.find(".xml") 9 | return int(traffic_file[sta:end]) 10 | elif "synthetic" in traffic_file: 11 | traffic_file_list = traffic_file.split("-") 12 | volume_list = [] 13 | for i in range(2, 6): 14 | volume_list.append(int(traffic_file_list[i][2:])) 15 | 16 | vol = min(max(volume_list[0:2]), max(volume_list[2:])) 17 | 18 | return int(vol/100)*100 19 | elif "anon" in traffic_file: 20 | return int(traffic_file.split('_')[3].split('.')[0]) -------------------------------------------------------------------------------- /simple_dqn_agent.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from keras 
import backend as K 4 | from keras.layers import Input, Dense, Conv2D, Flatten, BatchNormalization, Activation, Multiply, Add 5 | from keras.models import Model, model_from_json, load_model 6 | from keras.optimizers import RMSprop 7 | from keras.callbacks import EarlyStopping, TensorBoard 8 | from keras.layers.merge import concatenate, add 9 | import random 10 | import os 11 | import pickle 12 | 13 | from network_agent import NetworkAgent, conv2d_bn, Selector 14 | import json 15 | 16 | 17 | class SimpleDQNAgent(NetworkAgent): 18 | 19 | def build_network(self): 20 | 21 | '''Initialize a Q network''' 22 | 23 | # initialize feature node 24 | dic_input_node = {} 25 | for feature_name in self.dic_traffic_env_conf["LIST_STATE_FEATURE"]: 26 | if "phase" in feature_name or "adjacency" in feature_name or "pressure" in feature_name: 27 | _shape = self.dic_traffic_env_conf["DIC_FEATURE_DIM"]["D_"+feature_name.upper()] 28 | else: 29 | _shape = (self.dic_traffic_env_conf["DIC_FEATURE_DIM"]["D_"+feature_name.upper()][0]*self.num_lanes,) 30 | # _shape = (self.dic_traffic_env_conf["DIC_FEATURE_DIM"]["D_"+feature_name.upper()]) 31 | dic_input_node[feature_name] = Input(shape=_shape, 32 | name="input_"+feature_name) 33 | 34 | # add cnn to image features 35 | dic_flatten_node = {} 36 | for feature_name in self.dic_traffic_env_conf["LIST_STATE_FEATURE"]: 37 | if len(self.dic_traffic_env_conf["DIC_FEATURE_DIM"]["D_"+feature_name.upper()]) > 1: 38 | dic_flatten_node[feature_name] = Flatten()(dic_input_node[feature_name]) 39 | else: 40 | dic_flatten_node[feature_name] = dic_input_node[feature_name] 41 | 42 | # concatenate features 43 | list_all_flatten_feature = [] 44 | for feature_name in self.dic_traffic_env_conf["LIST_STATE_FEATURE"]: 45 | list_all_flatten_feature.append(dic_flatten_node[feature_name]) 46 | all_flatten_feature = concatenate(list_all_flatten_feature, axis=1, name="all_flatten_feature") 47 | 48 | # shared dense layer, N_LAYER 49 | locals()["dense_0"] = 
Dense(self.dic_agent_conf["D_DENSE"], activation="relu", name="dense_0")(all_flatten_feature) 50 | for i in range(1, self.dic_agent_conf["N_LAYER"]): 51 | locals()["dense_%d"%i] = Dense(self.dic_agent_conf["D_DENSE"], activation="relu", name="dense_%d"%i)(locals()["dense_%d"%(i-1)]) 52 | q_values = Dense(self.num_actions, activation="linear", name="q_values")(locals()["dense_%d"%(self.dic_agent_conf["N_LAYER"]-1)]) 53 | network = Model(inputs=[dic_input_node[feature_name] 54 | for feature_name in self.dic_traffic_env_conf["LIST_STATE_FEATURE"]], 55 | outputs=q_values) 56 | network.compile(optimizer=RMSprop(lr=self.dic_agent_conf["LEARNING_RATE"]), 57 | loss=self.dic_agent_conf["LOSS_FUNCTION"]) 58 | network.summary() 59 | 60 | return network 61 | -------------------------------------------------------------------------------- /simple_dqn_one_agent.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from keras import backend as K 4 | from keras.layers import Input, Dense, Conv2D, Flatten, BatchNormalization, Activation, Multiply, Add 5 | from keras.models import Model, model_from_json, load_model 6 | from keras.optimizers import RMSprop 7 | from keras.callbacks import EarlyStopping, TensorBoard 8 | from keras.layers.merge import concatenate, add 9 | import random 10 | import os 11 | import pickle 12 | 13 | from network_agent import NetworkAgent, conv2d_bn, Selector 14 | import json 15 | 16 | 17 | class SimpleDQNOneAgent(NetworkAgent): 18 | 19 | def build_network(self): 20 | 21 | '''Initialize a Q network''' 22 | 23 | # initialize feature node 24 | dic_input_node = {} 25 | for feature_name in self.dic_traffic_env_conf["LIST_STATE_FEATURE"]: 26 | if "phase" in feature_name or "adjacency" in feature_name or "pressure" in feature_name: 27 | _shape = self.dic_traffic_env_conf["DIC_FEATURE_DIM"]["D_"+feature_name.upper()] 28 | else: 29 | _shape = 
(self.dic_traffic_env_conf["DIC_FEATURE_DIM"]["D_"+feature_name.upper()][0]*self.num_lanes,) 30 | # _shape = (self.dic_traffic_env_conf["DIC_FEATURE_DIM"]["D_"+feature_name.upper()]) 31 | dic_input_node[feature_name] = Input(shape=_shape, 32 | name="input_"+feature_name) 33 | 34 | # add cnn to image features 35 | dic_flatten_node = {} 36 | for feature_name in self.dic_traffic_env_conf["LIST_STATE_FEATURE"]: 37 | if len(self.dic_traffic_env_conf["DIC_FEATURE_DIM"]["D_"+feature_name.upper()]) > 1: 38 | dic_flatten_node[feature_name] = Flatten()(dic_input_node[feature_name]) 39 | else: 40 | dic_flatten_node[feature_name] = dic_input_node[feature_name] 41 | 42 | # concatenate features 43 | list_all_flatten_feature = [] 44 | for feature_name in self.dic_traffic_env_conf["LIST_STATE_FEATURE"]: 45 | list_all_flatten_feature.append(dic_flatten_node[feature_name]) 46 | all_flatten_feature = concatenate(list_all_flatten_feature, axis=1, name="all_flatten_feature") 47 | 48 | # shared dense layer, N_LAYER 49 | locals()["dense_0"] = Dense(self.dic_agent_conf["D_DENSE"], activation="relu", name="dense_0")(all_flatten_feature) 50 | for i in range(1, self.dic_agent_conf["N_LAYER"]): 51 | locals()["dense_%d"%i] = Dense(self.dic_agent_conf["D_DENSE"], activation="relu", name="dense_%d"%i)(locals()["dense_%d"%(i-1)]) 52 | q_values = Dense(self.num_actions, activation="linear", name="q_values")(locals()["dense_%d"%(self.dic_agent_conf["N_LAYER"]-1)]) 53 | network = Model(inputs=[dic_input_node[feature_name] 54 | for feature_name in self.dic_traffic_env_conf["LIST_STATE_FEATURE"]], 55 | outputs=q_values) 56 | network.compile(optimizer=RMSprop(lr=self.dic_agent_conf["LEARNING_RATE"]), 57 | loss=self.dic_agent_conf["LOSS_FUNCTION"]) 58 | network.summary() 59 | 60 | return network 61 | 62 | def convert_state_to_input(self, s): 63 | if self.dic_traffic_env_conf["BINARY_PHASE_EXPANSION"]: 64 | inputs = [] 65 | for feature in self.dic_traffic_env_conf["LIST_STATE_FEATURE"]: 66 | if feature 
== "cur_phase": 67 | inputs.append(np.array([self.dic_traffic_env_conf['PHASE'] 68 | [self.dic_traffic_env_conf['SIMULATOR_TYPE']][s[feature][0]]])) 69 | else: 70 | inputs.append(np.array([s[feature]])) 71 | return inputs 72 | else: 73 | return [np.array([s[feature]]) for feature in self.dic_traffic_env_conf["LIST_STATE_FEATURE"]] 74 | 75 | 76 | 77 | def choose_action(self, count, states): 78 | ''' 79 | choose the best action for current state 80 | -input: state:[[state inter1],[state inter1]] 81 | -output: act: [#agents,num_actions] 82 | ''' 83 | dic_state_feature_arrays = {} # {feature1: [inter1, inter2,..], feature2: [inter1, inter 2...]} 84 | for feature_name in self.dic_traffic_env_conf["LIST_STATE_FEATURE"]: 85 | dic_state_feature_arrays[feature_name] = [] 86 | 87 | 88 | for s in states: 89 | for feature_name in self.dic_traffic_env_conf["LIST_STATE_FEATURE"]: 90 | # print(s[feature_name]) 91 | if "cur_phase" in feature_name: 92 | dic_state_feature_arrays[feature_name].append(np.array(self.dic_traffic_env_conf['PHASE'][self.dic_traffic_env_conf['SIMULATOR_TYPE']][s[feature_name][0]])) 93 | else: 94 | dic_state_feature_arrays[feature_name].append(np.array(s[feature_name])) 95 | 96 | state_input = [np.array(dic_state_feature_arrays[feature_name]) for feature_name in 97 | self.dic_traffic_env_conf["LIST_STATE_FEATURE"]] 98 | 99 | # print("----------State Input",state_input) 100 | # print(dic_state_feature_arrays) 101 | 102 | q_values = self.q_network.predict(state_input) 103 | if random.random() <= self.dic_agent_conf["EPSILON"]: # continue explore new Random Action 104 | action = np.random.randint(len(q_values[0]), size=len(q_values)) 105 | else: # exploitation 106 | action = np.argmax(q_values, axis=1) 107 | 108 | return action 109 | 110 | def choose_action_separate(self, count, states): 111 | ''' 112 | choose the best action for current state 113 | -input: state:[[state inter1],[state inter1]] 114 | -output: act: [#agents,num_actions] 115 | ''' 116 | 117 | 
actions = [] 118 | 119 | for state in states: 120 | action = self.choose_action_i(count, state) 121 | actions.append(action) 122 | 123 | return actions 124 | 125 | def choose_action_i(self, count, state): 126 | 127 | ''' choose the best action for current state ''' 128 | state_input = self.convert_state_to_input(state) 129 | q_values = self.q_network.predict(state_input) 130 | if random.random() <= self.dic_agent_conf["EPSILON"]: # continue explore new Random Action 131 | action = random.randrange(len(q_values[0])) 132 | else: # exploitation 133 | action = np.argmax(q_values[0]) 134 | 135 | return action 136 | -------------------------------------------------------------------------------- /simulator.dockerfile: -------------------------------------------------------------------------------- 1 | FROM cityflowproject/cityflow:latest 2 | 3 | RUN apt-get update --fix-missing && apt-get install -y wget bzip2 ca-certificates \ 4 | libglib2.0-0 libxext6 libsm6 libxrender1 \ 5 | git mercurial subversion 6 | 7 | RUN apt-get install -y software-properties-common && \ 8 | add-apt-repository ppa:sumo/stable && \ 9 | apt-get update && \ 10 | apt-get install -y sumo sumo-tools sumo-doc && \ 11 | pip install msgpack && \ 12 | pip install tensorflow==1.11 && \ 13 | pip install keras==2.1 && \ 14 | pip install networkx && \ 15 | pip install pandas && \ 16 | pip install matplotlib 17 | 18 | --------------------------------------------------------------------------------