├── .idea ├── Learn-to-pass-crossroad-DQN.iml ├── Robotics-Navigation-Deep-Reinforcement-Learning.iml ├── misc.xml ├── modules.xml ├── vcs.xml └── workspace.xml ├── README.md ├── basic ├── car.rou.xml ├── cfg.sumocfg.xml ├── edges.edg.xml ├── nodes.nod.xml └── road.net.xml ├── crossroad ├── Env.py ├── Model.py ├── crossroad.net.xml ├── crossroad.rou.xml ├── crossroad.sumocfg └── main.py └── programImage.png /.idea/Learn-to-pass-crossroad-DQN.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 13 | -------------------------------------------------------------------------------- /.idea/Robotics-Navigation-Deep-Reinforcement-Learning.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | sp 52 | 53 | 54 | 55 | 57 | 58 | 64 | 65 | 66 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 
# -*- coding: utf-8 -*-
"""SUMO/TraCI environment in which an ego vehicle learns to cross a crossroad.

Every ``print`` call below takes a single pre-formatted argument, so the
module prints the same text under Python 2 and Python 3.
"""

import os, sys
import subprocess
import traci
import random
import math
import numpy as np
import cv2

# Simulation step length in seconds. Passed to SUMO via "--step-length" and
# used by Intersaction.step() to turn step counts into durations, so the two
# values cannot drift apart.
STEP_LENGTH = 0.2

# SUMO connection settings, kept as a module-level name (``Env.config``).
config = dict()
config['PORT'] = 8813
#config['SumoBinary'] = "/usr/bin/sumo-gui" # for linux
#config['SumoBinary'] = "/usr/local/bin/sumo-gui" # for Mac
config['SumoBinary'] = "C:\\Program Files (x86)\\Sumo\\bin\\sumo-gui"


class action_space:
    """Minimal discrete action space with a gym-like ``sample()``."""

    def __init__(self, action_num):
        """Create a space of ``action_num`` discrete actions ``0 .. action_num-1``.

        Raises:
            ValueError: if ``action_num`` is zero, negative or empty. Previously
                only a message was printed and ``sample()`` then failed with an
                AttributeError because ``action_idx`` was never set.
        """
        if not action_num or action_num < 0:
            raise ValueError("action_num must be a positive number of actions")
        self.action_idx = np.arange(action_num)

    def sample(self):
        """Return a uniformly random action index in ``[0, action_num)``."""
        return np.random.randint(0, len(self.action_idx))


class Intersaction:
    """Crossroad environment driven through TraCI.

    Launches ``sumo-gui`` on ``crossroad.sumocfg`` and connects to it on
    ``config['PORT']``. Background traffic is pre-scheduled on the
    "left-right" route, and the ego vehicle ``"Audi L3"`` is controlled on
    the "cross" route. Observations are single-channel uint8 occupancy
    images of shape
    ``(Audi_viewrange_axis1 * unit, 2 * Audi_viewrange_axis0 * unit)``,
    which is 150 x 500 with the values set in ``__init__``.
    """

    def __init__(self):
        self.sumo_process = subprocess.Popen(
            [config['SumoBinary'], "-c", "crossroad.sumocfg",
             "--remote-port", str(config['PORT']),
             "--collision.check-junctions", "true",
             "--collision.action", "teleport",
             "--step-length", str(STEP_LENGTH)],
            stdout=sys.stdout, stderr=sys.stderr)

        # Private TraCI domain objects (presumably an older TraCI release that
        # exposes traci._vehicle / _simulation / _vehicletype).
        self.vehicle_domain = traci._vehicle.VehicleDomain()
        self.simu_domain = traci._simulation.SimulationDomain()
        self.vehicletype_domain = traci._vehicletype.VehicleTypeDomain()

        # Ego route configuration, passed as strings to addFull().
        self.departPos = "790"
        self.arrivalPos = "20"
        self.DrivingRoute = "cross"

        self.cur_state = None
        self.cur_reward = None

        # Observation window in SUMO position units (presumably metres):
        # +/- axis0 around the ego vehicle along x, 0..axis1 ahead along y.
        self.Audi_viewrange_axis0 = 50  # one-side
        self.Audi_viewrange_axis1 = 30
        # Pixels per position unit in the rendered observation.
        self.unit = 5

        self.done = False
        self.accident = False
        self.successful = False  # previously only created by reset()/state()
        self.action_space = action_space(5)  # need more consideration

        self.goal_speed = 0

        # Debug switches; both default to the original always-on behaviour.
        self.verbose = True  # print vehicle lists / bounding boxes each state()
        self.render = True   # show every observation in an OpenCV window

        traci.init(config['PORT'])
        self.c_vid = "Audi L3"
        self.generate_flow(self.vehicle_domain)

    def _add_ego(self):
        """Queue the ego vehicle for insertion on its route."""
        self.vehicle_domain.addFull(self.c_vid, self.DrivingRoute, typeID="sporty",
                                    departPos=self.departPos, arrivalPos=self.arrivalPos)

    def reset(self):
        """Start a new episode and return its first observation.

        Removes the ego vehicle if it is still in the simulation and re-adds it
        with speed mode 0 and speed 0. Then advances one simulation step,
        renders the observation and clears the episode flags.
        """
        if self.c_vid in self.vehicle_domain.getIDList():
            self.vehicle_domain.remove(self.c_vid)

        self._add_ego()
        self.vehicle_domain.setSpeedMode(self.c_vid, 0)
        self.vehicle_domain.setSpeed(self.c_vid, 0)
        self.goal_speed = 0

        traci.simulationStep()
        self.cur_state = self.state(self.vehicle_domain)

        # Cleared only after state() has run, because state() reads self.accident.
        self.done = False
        self.accident = False
        self.successful = False
        return self.cur_state

    @staticmethod
    def _clip_box(x_lo, x_hi, y_lo, y_hi, left, right, bottom, top):
        """Intersect a box with the view window.

        Returns ``[x0, x1, y0, y1]`` for the overlapping part, or ``None`` when
        the box lies entirely outside the window.
        """
        x0, x1 = max(x_lo, left), min(x_hi, right)
        y0, y1 = max(y_lo, bottom), min(y_hi, top)
        if x0 < x1 and y0 < y1:
            return [x0, x1, y0, y1]
        return None

    def state(self, vehicle_domain):
        """Render the occupancy image around the ego vehicle.

        The window spans ``Audi_viewrange_axis0`` on either side of the ego
        vehicle along x and ``Audi_viewrange_axis1`` ahead of it along y.
        Every other vehicle is drawn as a filled axis-aligned rectangle (value
        255), using its type's length along x and width along y. Vehicle
        heading is ignored.

        If the ego vehicle is missing and no accident was recorded, it is
        treated as having arrived: ``successful`` is set and it is re-added.

        Returns:
            np.ndarray of uint8 with shape
            ``(Audi_viewrange_axis1 * unit, 2 * Audi_viewrange_axis0 * unit)``.
        """
        vehicle_list = self.vehicle_domain.getIDList()
        if self.verbose:
            print(vehicle_list)

        if self.c_vid not in vehicle_list and not self.accident:
            self.successful = True
            # NOTE(review): the getPosition() call below may fail if TraCI only
            # reports a freshly added vehicle after the next simulation step.
            self._add_ego()

        ego_x, ego_y = vehicle_domain.getPosition(self.c_vid)
        left = ego_x - self.Audi_viewrange_axis0
        right = ego_x + self.Audi_viewrange_axis0
        bottom = ego_y
        top = ego_y + self.Audi_viewrange_axis1

        # vid -> [x0, x1, y0, y1], each box clipped to the view window.
        valid_vehicles = dict()
        for vid in vehicle_list:
            x, y = vehicle_domain.getPosition(vid)
            if self.verbose:
                print("vid:%s, vid.x:%d, vid.y:%d" % (vid, x, y))
            if vid == self.c_vid:
                continue
            type_id = vehicle_domain.getTypeID(vid)
            half_len = self.vehicletype_domain.getLength(type_id) / 2
            half_wid = self.vehicletype_domain.getWidth(type_id) / 2
            box = self._clip_box(x - half_len, x + half_len, y - half_wid, y + half_wid,
                                 left, right, bottom, top)
            if box is not None:
                valid_vehicles[vid] = box

        if self.verbose:
            print("valid_vehicles: %s" % (valid_vehicles,))

        height = self.Audi_viewrange_axis1 * self.unit
        width = self.Audi_viewrange_axis0 * 2 * self.unit
        img = np.zeros((height, width), np.uint8)

        # Shift world coordinates so the ego vehicle maps to
        # (Audi_viewrange_axis0, 0), the bottom-centre of the window, then
        # scale to pixels.
        shift_x = self.Audi_viewrange_axis0 - ego_x
        shift_y = 0 - ego_y
        for x0, x1, y0, y1 in valid_vehicles.values():
            top_left = (int(round(self.unit * (x0 + shift_x))),
                        int(round(self.unit * (y1 + shift_y))))
            bottom_right = (int(round(self.unit * (x1 + shift_x))),
                            int(round(self.unit * (y0 + shift_y))))
            if self.verbose:
                print("top_left: %s" % (top_left,))
                print("bottom_right: %s" % (bottom_right,))
            # OpenCV puts the origin at the top-left corner, so flip the y axis.
            top_left = (top_left[0], height - top_left[1])
            bottom_right = (bottom_right[0], height - bottom_right[1])
            cv2.rectangle(img, bottom_right, top_left, 255, -1)

        if self.render:
            cv2.imshow("image", img)
            cv2.waitKey(100)
            cv2.destroyAllWindows()

        return img

    def generate_flow(self, vehicle_domain):
        """Pre-schedule background traffic on the "left-right" route.

        For each second ``i`` in ``range(10000)``, one vehicle of each type
        is added with the probability listed below. Each vehicle has speed
        mode 0 and departs at position "750". A private RNG seeded with 42
        makes the flow reproducible without reseeding the global ``random``
        module.
        """
        rng = random.Random(42)
        N = 10000  # number of departure seconds to schedule
        # (vehicle type, per-second insertion probability), drawn in this order.
        flows = (("normal", 1. / 5), ("sporty", 1. / 6), ("trailer", 1. / 15))
        vehNr = 0
        for i in range(N):
            for type_id, prob in flows:
                if rng.uniform(0, 1) < prob:
                    vid = "%s_%i" % (type_id, vehNr)
                    vehicle_domain.addFull(vid, "left-right", typeID=type_id,
                                           depart="%i" % i, departPos="750")
                    vehicle_domain.setSpeedMode(vid, 0)
                    vehNr += 1

    def step(self, action_idx):
        """Apply one action and advance the simulation.

        Actions 0-3 hold the ego vehicle at speed 0 for ``2 ** action_idx``
        simulation steps (1, 2, 4 or 8). Action 4 runs one step in which
        ``slowDown`` raises the ego speed by ``2 * duration``.

        Returns:
            ``(next_state, reward, done, accident, steps)``. On a collision
            (reward -10) or arrival (reward 1) the environment is reset, so
            ``next_state`` is the first observation of the new episode.
            ``done``/``accident`` still report how the finished episode
            ended. Previously ``reset()`` cleared both before they were
            returned, so they were always False.
        """
        # NOTE(review): durations are handed to TraCI in integer milliseconds
        # (getCurrentTime/simulationStep/slowDown), as the original code did;
        # confirm against the installed TraCI version.
        if action_idx != 4:
            steps = 2 ** action_idx  # was `2 ^ action_idx` (XOR): action 2 gave 0 steps
            duration = steps * STEP_LENGTH  # seconds
            self.vehicle_domain.setSpeed(self.c_vid, 0)
        else:
            steps = 1
            duration = steps * STEP_LENGTH  # seconds
            self.vehicle_domain.slowDown(self.c_vid,
                                         self.vehicle_domain.getSpeed(self.c_vid) + 2 * duration,
                                         int(round(duration * 1000)))

        traci.simulationStep(self.simu_domain.getCurrentTime() + int(round(duration * 1000)))

        teleport_list = self.simu_domain.getEndingTeleportIDList()
        print("teleport list: %s" % (teleport_list,))
        arrived_list = self.simu_domain.getArrivedIDList()
        print("arrived list: %s" % (arrived_list,))

        done = False
        accident = False
        if len(teleport_list):
            # NOTE(review): any teleport counts, including collisions between
            # background vehicles; consider `self.c_vid in teleport_list`.
            print("collisions happened!")
            self.accident = True  # read by state() during the reset below
            accident = True
            self.cur_reward = -10
            next_state = self.reset()
        elif self.c_vid in arrived_list:
            print("successful arrival!")
            self.done = True
            done = True
            self.cur_reward = 1
            next_state = self.reset()
        else:
            self.cur_reward = -0.01
            next_state = self.state(self.vehicle_domain)

        self.cur_state = next_state
        return next_state, self.cur_reward, done, accident, steps
import os, sys
import subprocess
import traci
import random
import math
from collections import deque
import numpy as np

from keras.models import Sequential
from keras.layers import Dense, Activation,Flatten, Input, merge, Lambda


from keras.initializers import normal, identity
from keras import optimizers


class DQN_for_CROSS:
    """Deep Q-network agent for the crossroad environment.

    The Q-network maps a flattened observation to one Q-value per action.
    ``act`` selects actions epsilon-greedily, and transitions are kept in a
    bounded replay buffer (``memory``).
    """

    def __init__(self, env, memory_size=10000):
        """Build the agent and its Q-network.

        Args:
            env: environment whose ``action_space.sample()`` is used for
                exploration.
            memory_size: maximum number of transitions kept in ``memory``;
                the oldest are discarded first. Pass ``None`` for an
                unbounded buffer.
        """
        self.env = env
        # Replay buffer of (state, action, reward, next_state, done) tuples.
        # Bounded because, with the crossroad env, every stored observation is
        # a 150*500 uint8 array (~75 KB, two per transition). A plain list
        # grows for the entire training run. deque supports len() and integer
        # indexing, which is how the training script samples from it.
        self.memory = deque(maxlen=memory_size)

        self.action_size = 5  # must match the env's number of actions
        self.input_dim = 150*500*1  # flattened observation size
        self.mini_batch_size = 50
        self.gamma = 0.9  # discount factor
        self.learning_rate = 0.0001
        self.epsilon = 1.0  # exploration rate, decayed by the caller
        self.epsilon_decay = 0.995
        self.epsilon_min = 0.05
        self.create_model()

    def create_model(self):
        """Build and compile the Q-network (3 x 100 ReLU layers, linear output, MSE, RMSprop)."""
        model = Sequential()
        model.add(Dense(100, input_dim=self.input_dim, activation='relu'))
        model.add(Dense(100, activation='relu'))
        model.add(Dense(100, activation='relu'))
        model.add(Dense(self.action_size, activation='linear'))
        model.compile(loss='mse', optimizer=optimizers.RMSprop(lr=self.learning_rate))
        self.model = model

    def remember(self, state, action, reward, next_state, done):
        """Store one transition; ``done`` should be True for terminal transitions."""
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        """Pick an action epsilon-greedily.

        Args:
            state: observation batch of shape ``(1, input_dim)``.

        Returns:
            A random action with probability ``epsilon``, otherwise the
            arg-max of the predicted Q-values.
        """
        if np.random.rand() <= self.epsilon:
            return self.env.action_space.sample()
        act_values = self.model.predict(state)
        return np.argmax(act_values[0])  # returns action
"""Train a DQN agent to cross the SUMO crossroad (see Env.py / Model.py)."""

import os, sys
import subprocess
import traci
import random
import math
import numpy as np
from Env import Intersaction
from Model import DQN_for_CROSS
import cv2
import matplotlib.pyplot as plt


def _replay(agent, batch_size):
    """Fit ``agent.model`` once on a random minibatch and return its loss.

    Targets follow standard DQN: ``r`` for terminal transitions, otherwise
    ``r + gamma * max_a' Q(s', a')``. Only the taken action's Q-value is
    changed. All samples are predicted and fitted in one batch, instead of
    one predict/predict/fit round per sample.
    """
    batch_idx = np.random.choice(len(agent.memory), batch_size, replace=False)
    batch = [agent.memory[i] for i in batch_idx]

    states = np.vstack([t[0] for t in batch])
    actions = np.array([t[1] for t in batch], dtype=int)
    rewards = np.array([t[2] for t in batch], dtype=np.float64)
    next_states = np.vstack([t[3] for t in batch])
    terminal = np.array([t[4] for t in batch], dtype=bool)

    # No labels exist up front: the network's own predictions serve as targets
    # for every action except the one that was taken.
    targets = agent.model.predict(states)
    next_q = agent.model.predict(next_states).max(axis=1)
    targets[np.arange(batch_size), actions] = np.where(
        terminal, rewards, rewards + agent.gamma * next_q)

    history = agent.model.fit(states, targets, epochs=1, batch_size=batch_size)
    return history.history["loss"][0]


def main():
    """Run the training loop, always closing the TraCI connection on exit."""
    Intersac = Intersaction()
    try:
        DQN = DQN_for_CROSS(Intersac)
        state_dim = DQN.input_dim
        episode_num = 100000
        max_timestep = 100  # simulation steps per episode

        for ep_idx in range(episode_num):
            cur_state = np.reshape(Intersac.reset(), [1, state_dim])
            n = 0
            c_action_idx = -1
            while n <= max_timestep:
                # Once action 4 (accelerate) is chosen, the agent stays
                # committed to it for the rest of the episode.
                if c_action_idx != 4:
                    c_action_idx = DQN.act(cur_state)

                next_state, c_reward, done, accident, steps = Intersac.step(c_action_idx)
                print("next state's shape: %s" % (next_state.shape,))

                next_state = np.reshape(next_state, [1, state_dim])
                # A collision also ends the episode. Env.step resets on
                # termination, so next_state then belongs to a new episode
                # and the target must not bootstrap from it.
                DQN.remember(cur_state, c_action_idx, c_reward, next_state, done or accident)

                cur_state = next_state

                if done or accident:
                    print("one episode come to end!")
                    break

                n += steps

            # --------------------------- start replay training -------------------------#
            batch_size = min(DQN.mini_batch_size, len(DQN.memory))
            loss = _replay(DQN, batch_size)
            print("trainning loss at Episode %d, step %d : %.2f" % (ep_idx, n, loss))
            if DQN.epsilon > DQN.epsilon_min:
                DQN.epsilon *= DQN.epsilon_decay
    finally:
        traci.close()


if __name__ == '__main__':
    main()

# Repository image (programImage.png):
# https://raw.githubusercontent.com/xubo92/robotics-navigation-deep-reinforcement-learning/1cbd4f545d9a99d4229ea2ba20d79a4950e9e556/programImage.png