├── DrivingScenarioEnv.m ├── OpenTrafficLab ├── +drivingBehavior │ ├── gippsDriverModel.m │ └── intelligentDriverModel.m ├── +trafficControl │ ├── TrafficController.m │ └── TrafficLight.m ├── @DrivingStrategy │ └── DrivingStrategy.m ├── @Node │ ├── Node.m │ └── roadDistanceToCenterAndLeft.m ├── DrivingStrategyRL.m ├── createDrivingScenario.m ├── createTJunctionNetwork.m ├── createTJunctionScenario.m └── createVehiclesForTJunction.m ├── README.md ├── SECURITY.md ├── checkCollision.m ├── createDQN.m ├── createTrainOpts.m ├── license ├── master.m ├── masterLiveScript.mlx ├── observationSpace1.m ├── observationSpace2.m ├── obtainReward.m ├── savedAgents ├── TjunctionDQNAgentDesign1.mat ├── TjunctionDQNAgentDesign2.mat ├── TjunctionDQNAgentDesign2Training.png ├── TjunctionDQNAgentDesign3.mat └── TjunctionDQNAgentDesign3Training.png ├── savedFigures ├── TjunctionRLControl1.gif ├── TjunctionRLControl2.gif ├── TjunctionRLcontrol.gif └── workflow.png ├── savedTestExperience ├── RLTrafficControlDesign1.fig ├── RLTrafficControlDesign2.fig ├── comparision.fig ├── comparision.png ├── comparison2.png ├── createTestPlot.m ├── fixTimeControl.mat ├── phaseDesign1exp.mat ├── phaseDesign2exp.mat ├── phaseDesign3exp.mat └── rewardforEachTest.mat ├── savedVideos ├── RLTrained.gif ├── RLlearningStage.gif └── RLlearningStage_linkedin.jpg ├── signalPhaseDesign1.m ├── signalPhaseDesign2.m └── signalPhaseDesign3.m /DrivingScenarioEnv.m: -------------------------------------------------------------------------------- 1 | classdef DrivingScenarioEnv < rl.env.MATLABEnvironment 2 | % Copyright 2020 The MathWorks, Inc. 3 | %MYENVIRONMENT: Template for defining custom environment in MATLAB. 
4 | 5 | % parameters for simulation environment 6 | properties 7 | scenario 8 | network 9 | traffic 10 | cars 11 | state 12 | driver 13 | InjectionRate 14 | TurnRatio 15 | N = 3 % number of road 16 | phaseDuration = 50 % time duration for each of the phase 17 | T 18 | end 19 | 20 | % simulation doesn't have yellow light 21 | % manually set up clearning phase here if needed 22 | properties 23 | clearingPhase = false 24 | clearingPhaseTime = 0 25 | TrafficSignalDesign 26 | ObservationSpaceDesign 27 | end 28 | 29 | % parameter for reward definition 30 | properties 31 | rewardForPass = 0 32 | vehicleEnterJunction % keep record of cars pass the intersection 33 | hitPenalty = 20 34 | penaltyForFreqSwitch = 1 35 | safeDistance = 2.25 % check collision 36 | slowSpeedThreshold = 3.5 % check whether car is waiting 37 | end 38 | 39 | properties 40 | recordVid = false 41 | vid 42 | end 43 | 44 | properties 45 | discrete_action = [0 1 2]; 46 | dim =10; 47 | end 48 | 49 | properties(Access = protected) 50 | IsDone = false 51 | end 52 | 53 | %% Necessary Methods 54 | methods 55 | function this = DrivingScenarioEnv() 56 | % Initialize Observation settings 57 | ObservationInfo = rlNumericSpec([10, 1]); % # of state 58 | ObservationInfo.Name = 'real-time traffic information'; 59 | ObservationInfo.Description = ''; 60 | 61 | % Initialize action settings 62 | ActionInfo = rlFiniteSetSpec([0 1 2]); % three phases 63 | ActionInfo.Name = 'traffic signal phases'; 64 | 65 | % The following line implements built-in functions of the RL environment 66 | this = this@rl.env.MATLABEnvironment(ObservationInfo,ActionInfo); 67 | end 68 | 69 | function [state, Reward,IsDone,LoggedSignals] = step(this, Action) 70 | Action = getForce(this, Action); 71 | % update the action 72 | pre_phase = this.traffic.IsOpen; 73 | if this.TrafficSignalDesign == 1 74 | cur_phase = signalPhaseDesign1(Action); 75 | elseif this.TrafficSignalDesign == 2 76 | cur_phase = signalPhaseDesign2(Action); 77 | elseif 
this.TrafficSignalDesign == 3 78 | cur_phase = signalPhaseDesign3(Action); 79 | end 80 | 81 | % Reward: penalty for signal phase switch 82 | changed = ~isequal(pre_phase, cur_phase); 83 | Reward = this.penaltyForFreqSwitch * (1 - changed); 84 | 85 | % (yellow light time)add clearing phase when signal phase switch 86 | if changed && this.clearingPhase 87 | for i = 1:this.clearingPhaseTime 88 | this.traffic.IsOpen = [0, 0, 0, 0, 0, 0]; 89 | advance(this.scenario); 90 | this.T = this.T + this.scenario.SampleTime; 91 | notifyEnvUpdated(this); 92 | % check terminal condition 93 | IsHit = checkCollision(this); 94 | Reward = Reward - IsHit * this.hitPenalty; 95 | this.IsDone = IsHit || this.T+0.5 >= this.scenario.StopTime; 96 | if this.IsDone 97 | break 98 | end 99 | end 100 | end 101 | 102 | % (green light time)simulate the signal phase based on the action by RL 103 | this.traffic.IsOpen = cur_phase; 104 | if ~this.IsDone 105 | for i = 1:this.phaseDuration 106 | % update traffic state 107 | advance(this.scenario); 108 | this.T = this.T + this.scenario.SampleTime; 109 | % update visulization 110 | notifyEnvUpdated(this); 111 | % check terminal condition 112 | IsHit = checkCollision(this); 113 | Reward = Reward - IsHit * this.hitPenalty; 114 | this.IsDone = IsHit || this.T+0.5 >= this.scenario.StopTime; 115 | if this.IsDone 116 | break 117 | end 118 | % obtain reward 119 | Reward = Reward + obtainReward(this, cur_phase); 120 | end 121 | end 122 | if this.ObservationSpaceDesign == 1 123 | state = observationSpace1(this, Action); 124 | else 125 | state = observationSpace2(this, Action); 126 | end 127 | this.state = state; 128 | IsDone = this.IsDone; 129 | LoggedSignals = []; 130 | end 131 | 132 | 133 | function InitialState = reset(this) 134 | % flag for record simulation 135 | this.recordVid = false; 136 | % Initialize scenario 137 | this.scenario = createTJunctionScenario(); 138 | this.scenario.StopTime = 100; 139 | this.scenario.SampleTime = 0.05; 140 | this.T = 0; 141 | 
% initialize network 142 | this.network = createTJunctionNetwork(this.scenario); 143 | this.traffic = trafficControl.TrafficController(this.network(7:12)); 144 | % car parameters 145 | this.InjectionRate = [250, 250, 250]; % veh/hour 146 | this.TurnRatio = [50, 50]; 147 | this.cars = createVehiclesForTJunction(this.scenario, this.network, this.InjectionRate, this.TurnRatio); 148 | this.vehicleEnterJunction = []; 149 | % obtain state from traffic and network 150 | if this.ObservationSpaceDesign == 1 151 | InitialState = observationSpace1(this, 0); 152 | else 153 | InitialState = observationSpace2(this, 0); 154 | end 155 | % visulization 156 | notifyEnvUpdated(this); 157 | end 158 | end 159 | 160 | methods 161 | function force = getForce(this,action) 162 | if ~ismember(action,this.ActionInfo.Elements) 163 | error('Action must be integer from 1 to numAction'); 164 | end 165 | force = action; 166 | end 167 | % update the action info based on max force 168 | function updateActionInfo(this) 169 | this.ActionInfo.Elements = this.discrete_action; 170 | end 171 | end 172 | 173 | methods (Access = protected) 174 | function envUpdatedCallback(this) 175 | if this.T == 0 176 | close all; 177 | plot(this.scenario) 178 | set(gcf,'Visible','On'); 179 | if this.recordVid 180 | this.vid = VideoWriter('baseRLlearningProcess33'); 181 | this.vid.FrameRate=20; 182 | open(this.vid) 183 | end 184 | end 185 | if this.recordVid 186 | frame = getframe(gcf); 187 | writeVideo(this.vid,frame); 188 | end 189 | this.traffic.plotOpenPaths() 190 | drawnow 191 | end 192 | end 193 | end -------------------------------------------------------------------------------- /OpenTrafficLab/+drivingBehavior/gippsDriverModel.m: -------------------------------------------------------------------------------- 1 | function [acc,v_b,v_a] = gippsDriverModel(spacing,speed,speedDiff,varargin) 2 | 3 | desiredSpeed = 10; %Desired speed 4 | maxSpeed = 20; %max. speed 5 | minSpeed = 0; % min. speed 6 | minAcc = -3; %min. 
acceleration 7 | maxAcc = 3; %max. acceleration 8 | minAccEstimate = minAcc; 9 | reactionTime = 0.8; 10 | S0 = 2; 11 | 12 | v_now = speed; 13 | v_infront = speed+speedDiff; 14 | 15 | v_a = v_now + 2.5*maxAcc*reactionTime*(1-v_now/desiredSpeed)*sqrt(0.025+v_now/desiredSpeed); 16 | v_b = minAcc*reactionTime + sqrt((minAcc*reactionTime)^2-minAcc*(2*(spacing-S0)-v_now*reactionTime-v_infront^2/minAccEstimate)); 17 | 18 | v_b = max(v_b,minSpeed); 19 | v_a = min(v_a,maxSpeed); 20 | v_new = min([v_a,v_b]); 21 | 22 | acc = (v_new-v_now)/reactionTime; 23 | 24 | end 25 | 26 | -------------------------------------------------------------------------------- /OpenTrafficLab/+drivingBehavior/intelligentDriverModel.m: -------------------------------------------------------------------------------- 1 | function [acc] = intelligentDriverModel(spacing,speed,speedDiff,varargin) 2 | % Parameters 3 | % NAME REALISTIC BOUNDS DEFAULT UNITS 4 | % desiredSpeed [0,11] 33 m/s 5 | % safeHeadway [1,3] 1.6 s 6 | % accelMax [0.5,2] 0.73 m/s^2 7 | % decelConf [0.5,2] 1.67 m/s^2 8 | % beta 4 9 | % minJamSpacing [0,5] 2 m 10 | % nonLinJamSpacing [0,5] 3 m 11 | 12 | % Need to program input parse to pass new parameters 13 | desiredSpeed = 10; 14 | safeHeadway = 1.6; 15 | accelMax = 0.73; 16 | decelConf = 1.67; 17 | beta = 4; 18 | minJamSpacing = 2; 19 | nonLinJamSpacing = 0; 20 | 21 | desiredSpacing = minJamSpacing + nonLinJamSpacing*sqrt(speed/desiredSpeed)... 22 | +speed*safeHeadway-speed*speedDiff/2/sqrt(accelMax*decelConf); 23 | 24 | acc = accelMax*(1-(speed/desiredSpeed)^beta-(desiredSpacing/spacing)^2); 25 | 26 | if acc<-10 || acc>3 || isnan(acc) 27 | acc; 28 | end 29 | end 30 | 31 | -------------------------------------------------------------------------------- /OpenTrafficLab/+trafficControl/TrafficController.m: -------------------------------------------------------------------------------- 1 | classdef TrafficController < driving.scenario.MotionStrategy... 
2 | & driving.scenario.mixin.PropertiesInitializableInConstructor 3 | 4 | 5 | properties 6 | Scenario 7 | Nodes = Node.empty % List of nodes the traffic controller manages 8 | IsOpen % Boolean list indicating wether the node can be entered 9 | PlotHandles = plot3([],[],[]); 10 | end 11 | 12 | methods 13 | function obj = TrafficController(nodes,varargin) 14 | 15 | obj@driving.scenario.MotionStrategy(nodes(1).Scenario.actor); 16 | obj@driving.scenario.mixin.PropertiesInitializableInConstructor(varargin{:}); 17 | 18 | obj.EgoActor.MotionStrategy = obj; 19 | obj.EgoActor.IsVisible = false; 20 | %obj.EgoActor.Position = nodes(1).getRoadCenterFromStation(10); 21 | 22 | obj.Scenario = nodes(1).Scenario; 23 | obj.Nodes = nodes; 24 | obj.IsOpen = false(size(nodes)); 25 | 26 | end 27 | 28 | function set.Nodes(obj,nodes) 29 | for node = nodes 30 | if ~any(obj.Nodes==node) 31 | obj.Nodes(end+1)=node; 32 | node.TrafficController=obj; 33 | end 34 | end 35 | end 36 | 37 | function running = move(obj,SimulationTime) 38 | 39 | running = true; 40 | end 41 | 42 | function running = restart(obj,inputArg) 43 | %METHOD1 Summary of this method goes here 44 | % Detailed explanation goes here 45 | outputArg = obj.Property1 + inputArg; 46 | end 47 | end 48 | 49 | methods 50 | function state = getNodeState(obj,node) 51 | state = obj.IsOpen(obj.Nodes==node); 52 | end 53 | end 54 | 55 | methods % plot methods 56 | function plotOpenPaths(obj,ax) 57 | green = [0.4660 0.6740 0.1880]; 58 | red = [0.6350 0.0780 0.1840]; 59 | yellow = [0.9290 0.6940 0.1250]; 60 | if nargin<2 61 | ax=gca; 62 | end 63 | hold on 64 | if isempty(obj.PlotHandles) 65 | for node=obj.Nodes 66 | obj.PlotHandles(end+1) = plot3(ax,node.Mapping(:,2),node.Mapping(:,3),node.Mapping(:,4)+10); 67 | end 68 | end 69 | for idx = 1:length(obj.Nodes) 70 | p = obj.PlotHandles(idx); 71 | node = obj.Nodes(idx); 72 | p.XData = node.Mapping(:,2); 73 | p.YData = node.Mapping(:,3); 74 | p.ZData = node.Mapping(:,4)+0; 75 | 76 | p.LineWidth = 2; 
77 | if obj.IsOpen(idx)==true 78 | p.Color = [green,1]; 79 | else 80 | p.Color = [red,0.2]; 81 | end 82 | end 83 | end 84 | end 85 | end 86 | 87 | -------------------------------------------------------------------------------- /OpenTrafficLab/+trafficControl/TrafficLight.m: -------------------------------------------------------------------------------- 1 | classdef TrafficLight < trafficControl.TrafficController 2 | 3 | properties 4 | Cliques 5 | Cycle 6 | Phase 7 | end 8 | 9 | methods 10 | function obj = TrafficLight(nodes,varargin) 11 | obj@trafficControl.TrafficController(nodes,varargin{:}); 12 | end 13 | 14 | function running = move(obj,SimulationTime) 15 | obj.IsOpen = false(size(obj.Nodes)); 16 | t = mod(SimulationTime,obj.Cycle(end)); 17 | phase = discretize(t,obj.Cycle); 18 | numPhases = max(obj.Cliques); 19 | for i =1:numPhases 20 | if phase==i 21 | obj.IsOpen(obj.Cliques==i)=true; 22 | end 23 | end 24 | running = true; 25 | end 26 | end 27 | end 28 | 29 | -------------------------------------------------------------------------------- /OpenTrafficLab/@DrivingStrategy/DrivingStrategy.m: -------------------------------------------------------------------------------- 1 | classdef DrivingStrategy < driving.scenario.MotionStrategy... 2 | & driving.scenario.mixin.PropertiesInitializableInConstructor 3 | %% Properties 4 | 5 | properties 6 | Scenario 7 | % 8 | Data = struct('Time',[],... 9 | 'Station',[],... 10 | 'Speed',[],... 11 | 'Node',Node.empty,... 12 | 'UDStates',[],... 
13 | 'Position',[]); 14 | NextNode = Node.empty % Vehicle's planned path 15 | UDStates = [] % User defined states 16 | 17 | % Parameters 18 | DesiredSpeed = 10; % [m/s] 19 | ReactionTime = 0.8; % [s] 20 | 21 | % 22 | CarFollowingModel = 'Gipps'; % Options: 'IDM','Gipps' 23 | 24 | % Flags 25 | StoreData = true; % Set to True to store time series data in Data structure 26 | StaticLaneKeeping = true; % % Set to false to implement user define lateral control 27 | end 28 | 29 | properties (Dependent,Access = protected) 30 | % Vehicle States 31 | Position(3,1) double {mustBeReal, mustBeFinite} 32 | Velocity(3,1) double {mustBeReal, mustBeFinite} 33 | ForwardVector(3,1)double {mustBeReal, mustBeFinite} 34 | Speed 35 | end 36 | 37 | properties (Access = protected) 38 | % Input States 39 | Acceleration(1,1) double {mustBeReal, mustBeFinite} = 0 40 | AngularAcceleration(1,1) double {mustBeReal, mustBeFinite} = 0 41 | end 42 | properties (Access = protected) 43 | % Position Dependent Variables 44 | Node = Node.empty; 45 | Station = []; 46 | % Environment Dependent Variables 47 | IsLeader = false; 48 | Leader = driving.scenario.Vehicle.empty; 49 | LeaderSpacing = nan; 50 | end 51 | 52 | %% Constructor 53 | methods 54 | function obj = DrivingStrategy(egoActor,varargin) 55 | obj@driving.scenario.MotionStrategy(egoActor); 56 | obj@driving.scenario.mixin.PropertiesInitializableInConstructor(varargin{:}); 57 | egoActor.MotionStrategy = obj; 58 | obj.Scenario = egoActor.Scenario; 59 | obj.Speed = obj.DesiredSpeed; 60 | end 61 | end 62 | %% Set and Get methods 63 | methods 64 | % Position 65 | function set.Position(obj,pos) 66 | obj.EgoActor.Position = pos; 67 | end 68 | function pos = get.Position(obj) 69 | pos = nan(length(obj),3); 70 | for idx = 1:length(obj) 71 | pos(idx,:) = obj(idx).EgoActor.Position; 72 | end 73 | end 74 | % Velocity 75 | function set.Velocity(obj,vel) 76 | obj.EgoActor.Velocity = vel; 77 | end 78 | function vel = get.Velocity(obj) 79 | vel = 
nan(length(obj),3); 80 | for idx = 1:length(obj) 81 | vel(idx,:) = obj(idx).EgoActor.Velocity; 82 | end 83 | end 84 | % Forward Vector 85 | function set.ForwardVector(obj,dir) 86 | if all(size(dir)==[3,1]) 87 | dir=dir'; 88 | end 89 | obj.EgoActor.ForwardVector = dir; 90 | end 91 | function dir = get.ForwardVector(obj) 92 | dir = nan(length(obj),3); 93 | for idx = 1:length(obj) 94 | dir(idx,:) = obj(idx).EgoActor.ForwardVector; 95 | end 96 | end 97 | % Speed 98 | function s = get.Speed(obj) 99 | s = obj.ForwardVector*obj.Velocity'; 100 | end 101 | function set.Speed(obj,speed) 102 | obj.EgoActor.Velocity = obj.EgoActor.ForwardVector*speed; 103 | end 104 | 105 | % Node 106 | function set.Node(obj,node) 107 | if node == obj.Node 108 | return 109 | end 110 | 111 | if ~isempty(obj.Node) 112 | removeVehicle(obj.Node,obj.EgoActor); 113 | if ~isempty(node) 114 | addVehicle(node,obj.EgoActor); 115 | end 116 | obj.Node = node; 117 | else 118 | if ~isempty(node) 119 | addVehicle(node,obj.EgoActor); 120 | end 121 | obj.Node = node; 122 | end 123 | 124 | 125 | end 126 | 127 | end 128 | %% Public Access Methods 129 | methods (Access = public) 130 | function s = getStationDistance(obj,time) 131 | if nargin<2 %If no time is given assume current sim time 132 | time = obj(1).Scenario.SimulationTime; 133 | end 134 | dt = obj(1).Scenario.SampleTime; 135 | for idx = 1:length(obj) 136 | tIdx = round((time-obj(idx).Data.Time(1))/dt)+1; 137 | if tIdx>0 && tIdx0 && tIdx0 && tIdx0 && tIdx0 && tIdx2 237 | obj.Speed = speed; 238 | end 239 | obj.Station = 0; 240 | updateUDStates(obj); 241 | addData(obj,t) 242 | obj.EgoActor.IsVisible=true; 243 | 244 | end 245 | 246 | function addData(obj,t) 247 | if ~isempty(obj.Data.Time) 248 | if t~=obj.Data.Time(end)+obj.Scenario.SampleTime 249 | error('Can only add data to the following time step') 250 | end 251 | end 252 | 253 | obj.Data.Time(end+1) = t; 254 | obj.Data.Station(end+1) = obj.Station; 255 | obj.Data.Node(end+1) = obj.Node; 256 | 
obj.Data.Speed(end+1) = obj.Speed; 257 | obj.Data.Position(end+1,:) = obj.Position; 258 | if ~isempty(obj.UDStates) 259 | obj.Data.UDStates(end+1,:) = obj.UDStates; 260 | end 261 | 262 | % Delete first entry if store data flag is false 263 | if ~obj.StoreData && length(obj.Data.Time)>2 264 | obj.Data.Time(1) = []; 265 | obj.Data.Station(1) = []; 266 | obj.Data.Node(1) = []; 267 | obj.Data.Speed(1) = []; 268 | obj.Data.Position(1,:) = []; 269 | if ~isempty(obj.UDStates) 270 | obj.Data.UDStates(1,:) = []; 271 | end 272 | 273 | end 274 | end 275 | 276 | function orientEgoActor(obj,direction,offset) 277 | if nargin == 1 278 | [s,direction,offset]= obj.getLaneInformation(); 279 | end 280 | obj.ForwardVector = [direction,0]; 281 | obj.Position = obj.Position + [offset,0]; 282 | end 283 | 284 | function [leader,leaderSpacing] = getLeader(obj,tNow) 285 | % Get other actors in the node and their stations 286 | actors = getVehiclesInSegment(obj); 287 | drivers = [actors.MotionStrategy]; 288 | selfIdx = find(drivers == obj); 289 | stations = [drivers.Station]; 290 | deltaStations = stations - stations(selfIdx); 291 | deltaStations(deltaStations<0.1)=inf; 292 | [leaderSpacing,idx]=min(deltaStations); 293 | if isinf(leaderSpacing) 294 | leaderSpacing = nan; 295 | leader = driving.scenario.Vehicle.empty; 296 | if ~isempty(obj.NextNode) 297 | [sLeader,leader] = obj.NextNode(1).getTrailingVehicleStation(tNow); 298 | if ~isempty(leader) 299 | leaderSpacing = sLeader-leader.Length +(obj.getSegmentLength()-stations(selfIdx)); 300 | end 301 | end 302 | 303 | else 304 | leader = actors(idx); 305 | leaderSpacing = leaderSpacing-leader.Length; 306 | end 307 | 308 | end 309 | 310 | function actors = getVehiclesInSegment(obj) 311 | actors = obj.Node.getActiveVehicles(); 312 | end 313 | 314 | function state = getNextNodeState(obj) 315 | state = 1; 316 | if ~isempty(obj.NextNode) 317 | if ~isempty(obj.NextNode(1).TrafficController) 318 | state = obj.NextNode(1).getNodeState(); 319 | end 320 
| end 321 | end 322 | end 323 | %% Nominal Driving Logic 324 | methods 325 | function mode = determineDrivingMode(obj) 326 | % This function decides which mode(s) the driver is following, 327 | % and the creates and outputs the function handle to that mode 328 | 329 | segLength = obj.getSegmentLength(); 330 | leader = obj.Leader; 331 | leaderSpacing = obj.LeaderSpacing; 332 | distToEnd = segLength-obj.Station; 333 | state = obj.getNextNodeState(); 334 | isLeader = isempty(leader); 335 | leaderIsPastRoadEnd = leaderSpacing>distToEnd; 336 | 337 | % Determine the mode 338 | if state 339 | if isLeader 340 | mode = 'ApproachingGreenLight'; 341 | else 342 | mode = 'CarFollowing'; 343 | end 344 | else 345 | if isLeader 346 | mode = 'ApproachingRedLight'; 347 | else 348 | if leaderIsPastRoadEnd 349 | mode = 'ApproachingRedLight'; 350 | else 351 | mode = 'CarFollowing'; 352 | end 353 | end 354 | end 355 | end 356 | 357 | function inputs = determineDrivingInputs(obj,mode) 358 | 359 | segLength = obj.getSegmentLength(); 360 | leader = obj.Leader; 361 | distToEnd = segLength-obj.Station; 362 | 363 | switch mode 364 | case 'CarFollowing' 365 | delVel = leader.MotionStrategy.Speed-obj.Speed; 366 | leaderSpacing = obj.LeaderSpacing; 367 | case 'ApproachingRedLight' 368 | delVel = 0 - obj.Speed; 369 | leaderSpacing = distToEnd; 370 | case 'ApproachingGreenLight' 371 | delVel = 0; 372 | leaderSpacing = inf; 373 | end 374 | 375 | inputs = [carFollowing(obj,leaderSpacing,obj.Speed,delVel),0]; 376 | end 377 | 378 | function acc = carFollowing(obj,spacing,speed,speedDiff) 379 | if strcmp(obj.CarFollowingModel,'IDM') 380 | acc = drivingBehavior.intelligentDriverModel(spacing,speed,speedDiff); 381 | end 382 | 383 | if strcmp(obj.CarFollowingModel,'Gipps') 384 | if mod(obj.Scenario.SimulationTime,obj.ReactionTime)=car.ExitTime || tNow0)&&(tNow - car.EntryTime1 426 | speed = min(v_b,v_a); 427 | injectVehicle(obj,tNow,speed); 428 | else 429 | car.EntryTime = car.EntryTime+dt; 430 | running 
= true; 431 | return; 432 | end 433 | %% 434 | %c = discretize(s-5,[-inf,5,10,inf]); 435 | % switch c 436 | % case 1 % 437 | % car.EntryTime = car.EntryTime+dt; 438 | % running = true; 439 | % return; 440 | % case 2 441 | % if ~isempty(leader) 442 | % speed = leader.MotionStrategy.getSpeed(tNow); 443 | % else 444 | % speed=obj.Speed; 445 | % end 446 | % injectVehicle(obj,tNow); 447 | % case 3 448 | % injectVehicle(obj,tNow,obj.Speed); 449 | % 450 | % end 451 | end 452 | 453 | %% Set vehicle's state and variables to tNow 454 | % State 455 | obj.Position = getPosition(obj,tNow); 456 | obj.Speed = getSpeed(obj,tNow); 457 | % State dependent variables 458 | obj.Station = getStationDistance(obj,tNow); 459 | obj.Node = getNode(obj,tNow); 460 | 461 | % Environment Dependent Variables 462 | [obj.Leader,obj.LeaderSpacing]=getLeader(obj,tNow); 463 | if isempty(obj.Leader) 464 | obj.IsLeader = true; 465 | end 466 | %% Determine Driving Mode 467 | mode = determineDrivingMode(obj); 468 | 469 | %% Get Inputs 470 | inputs = determineDrivingInputs(obj,mode); 471 | 472 | obj.Acceleration = inputs(1); 473 | obj.AngularAcceleration = inputs(2); 474 | 475 | %% Integrate 476 | 477 | % Position and Velocity 478 | obj.Position = obj.Position + dt*car.Velocity; 479 | obj.Speed = obj.Speed + dt*obj.Acceleration; 480 | 481 | % Forward Vector 482 | %obj.ForwardVector = ... 
to be implemented 483 | 484 | %% Update state dependent variables 485 | % Get new station distance, node and lane information 486 | [station, direction, offset] = getLaneInformation(obj); 487 | 488 | if station > getSegmentLength(obj) % Check if the veh entered a new node 489 | goToNextNode(obj,tNext) 490 | if isempty(obj.Node) 491 | % The vehicle finished its path 492 | running = false; 493 | return 494 | else 495 | % Get new staion andlane information 496 | [station, direction, offset] = getLaneInformation(obj); 497 | end 498 | 499 | end 500 | obj.Station = station; 501 | updateUDStates(obj); 502 | 503 | if obj.StaticLaneKeeping %Orient veh along lane if lateral control is deactivated 504 | obj.orientEgoActor(direction,offset); 505 | end 506 | 507 | addData(obj,tNext); 508 | running = true; 509 | end 510 | 511 | function restart(obj) 512 | % Needs to be programmed 513 | 514 | end 515 | end 516 | end 517 | 518 | -------------------------------------------------------------------------------- /OpenTrafficLab/@Node/Node.m: -------------------------------------------------------------------------------- 1 | classdef Node < handle 2 | %UNTITLED2 Summary of this class goes here 3 | % Detailed explanation goes here 4 | 5 | properties 6 | 7 | ConnectsTo = Node.empty % Node it spills into 8 | ConnectsFrom = Node.empty % Node that feeds it 9 | SharesRoadWith = Node.empty % Road or Junction it is a part of 10 | InjectionRate 11 | TurnRatio 12 | 13 | end 14 | 15 | properties (SetAccess = protected) 16 | Scenario 17 | RoadSegment 18 | Lane 19 | Length 20 | Mapping 21 | Vehicles = driving.scenario.Vehicle.empty 22 | end 23 | 24 | properties (SetAccess = {?trafficControl.TrafficController}) 25 | TrafficController 26 | end 27 | 28 | %% Constructor and Setup methods 29 | methods 30 | function obj = Node(scenario,rs,lane) 31 | obj.Scenario = scenario; 32 | obj.RoadSegment = rs; 33 | obj.Lane = lane; 34 | obj.setMapping; 35 | end 36 | 37 | function setMapping(obj) 38 | rs = 
obj.RoadSegment; 39 | xr = 0:0.1:rs.hcd(end); 40 | map = nan(length(xr),6); 41 | if obj.Lane == -1 42 | xr = fliplr(xr); 43 | end 44 | for idx = 1:length(xr) 45 | 46 | [center, left, kappa, dkappa, dx_y] = obj.roadDistanceToCenterAndLeft(xr(idx), rs.hcd, rs.hl, rs.hip, rs.course, rs.k0, rs.k1, rs.vpp, rs.bpp); 47 | left = left*(-obj.Lane); 48 | dx_y = dx_y*(obj.Lane); 49 | center = center + left*3.65/2; 50 | 51 | if obj.Lane == -1 52 | map(idx,:) = [xr(end+1-idx),center,real(dx_y),imag(dx_y)]; 53 | else 54 | map(idx,:) = [xr(idx),center,real(dx_y),imag(dx_y)]; 55 | end 56 | end 57 | %map = sortrows(map); 58 | obj.Mapping = map; 59 | obj.Length = max(xr); 60 | end 61 | 62 | function set.ConnectsTo(this,those) 63 | for that=those 64 | if ~any(this.ConnectsTo==that) 65 | this.ConnectsTo(end+1) = that; 66 | end 67 | if ~any(that.ConnectsFrom==this) 68 | that.ConnectsFrom(end+1) = this; 69 | end 70 | end 71 | end 72 | 73 | function set.SharesRoadWith(this,those) 74 | for that = those 75 | if ~any(this.SharesRoadWith==that) 76 | this.SharesRoadWith(end+1) = that; 77 | end 78 | if ~any(that.SharesRoadWith==this) 79 | that.SharesRoadWith(end+1) = this; 80 | end 81 | end 82 | end 83 | 84 | end 85 | %% Public facing methods 86 | methods 87 | 88 | function [dist,forwardVector,offsetVector] = getStationDistance(obj,pos) 89 | % Given a position, this function returns the station distance 90 | % along the road, along with the direction of the road, and 91 | % the distance vector from the point to the road. 
92 | distanceToPoints = sqrt(sum((obj.Mapping(:,2:3)-pos).^2,2)); 93 | [~,idx] = min(distanceToPoints); 94 | sampleToPosVector = pos-obj.Mapping(idx,2:3); 95 | forwardVector = obj.Mapping(idx,5:6); 96 | dist = obj.Mapping(idx,1)+dot(forwardVector,sampleToPosVector); 97 | sideVector = cross([forwardVector,0],[0,0,1]); 98 | offsetVector = -dot(sideVector(1:2),sampleToPosVector)*sideVector(1:2); 99 | end 100 | 101 | function [center,indexes] = getRoadCenterFromStation(obj,stations) 102 | center = zeros(length(stations),3); 103 | indexes = zeros(size(stations)); 104 | for idx = 1:length(stations) 105 | s = stations(idx); 106 | dist = (obj.Mapping(:,1)-s).^2; 107 | [~,ii] = min(dist); 108 | center(idx,:) = obj.Mapping(ii,2:4); 109 | indexes(idx)= ii; 110 | end 111 | end 112 | 113 | function length = getRoadSegmentLength(obj) 114 | length = obj.Length; 115 | end 116 | 117 | function added = addVehicle(obj,vehicle) 118 | if ~any(obj.Vehicles==vehicle) 119 | obj.Vehicles(end+1)=vehicle; 120 | added = true; 121 | else 122 | added = true; 123 | %error('Vehicle is already in the node'); 124 | end 125 | 126 | end 127 | 128 | function removed = removeVehicle(obj,vehicle) 129 | if any(obj.Vehicles==vehicle) 130 | obj.Vehicles(obj.Vehicles==vehicle)=[]; 131 | removed = true; 132 | else 133 | error('Vehicle is not in the node') 134 | end 135 | 136 | end 137 | 138 | function actors = getActiveVehicles(obj) 139 | actors = obj.Vehicles; 140 | end 141 | 142 | function [s,veh] = getTrailingVehicleStation(obj,time) 143 | if nargin<2 %If no time is given assume current sim time 144 | time = obj.Scenario.SimulationTime; 145 | end 146 | drivers = [obj.Vehicles.MotionStrategy]; 147 | if isempty(drivers) 148 | s = obj.getRoadSegmentLength(); 149 | veh = driving.scenario.Vehicle.empty; 150 | return 151 | end 152 | [s,idx] = min(getStationDistance(drivers,time)); 153 | veh = drivers(idx).EgoActor; 154 | end 155 | 156 | function s = getLeadingVehicleStation(obj,time) 157 | if nargin<2 %If no 
time is given assusme current sim time 158 | time = obj.Scenario.SimulationTime; 159 | end 160 | drivers = [obj.Vehicles.MotionStrategy]; 161 | if isempty(drivers) 162 | s = 0; 163 | return 164 | end 165 | s = max(getStationDistance(drivers,time)); 166 | end 167 | 168 | function plotPath(obj,ax) 169 | if nargin<2 170 | ax=gca; 171 | end 172 | for node=obj 173 | plot3(ax,node.Mapping(:,2),node.Mapping(:,3),node.Mapping(:,4)+10) 174 | end 175 | end 176 | 177 | function state = getNodeState(obj) 178 | state = obj.TrafficController.getNodeState(obj); 179 | end 180 | 181 | function clearVehicles(obj) 182 | for idx = 1:length(obj) 183 | obj(idx).Vehicles = []; 184 | end 185 | end 186 | end 187 | %% Helper methods 188 | methods(Static) 189 | [c, l, k, d, dx_y] = roadDistanceToCenterAndLeft(xr, hcd, hl, hip, course, k0, k1, vpp, bpp) 190 | end 191 | end 192 | 193 | -------------------------------------------------------------------------------- /OpenTrafficLab/@Node/roadDistanceToCenterAndLeft.m: -------------------------------------------------------------------------------- 1 | function [center, left, kappa, dkappa, dx_y] = roadDistanceToCenterAndLeft(xr, hcd, hl, hip, course, k0, k1, vpp, bpp) 2 | 3 | % Copyright 2017 The MathWorks, Inc. 4 | 5 | % find index/indices into table 6 | idx = discretize(xr, hcd); 7 | % If nan values come at end of the indices, it replace the maximum value at end of 8 | % the indices. 9 | if isnan(idx(end)) 10 | idx(isnan(idx)) = idx(find(isnan(idx)==0, 1, 'last' )); 11 | end 12 | % fetch clothoid segment at index and initial position. 13 | dkappa = (k1(idx)-k0(idx))./hl(idx); 14 | kappa0 = k0(idx); 15 | theta = course(idx); 16 | p0 = hip(idx); 17 | 18 | % get length and curvature into clothoid segment 19 | l = xr-hcd(idx); 20 | kappa = kappa0 + l.*dkappa; 21 | 22 | % get corresponding points in complex plane and derivative w.r.t. 
road length 23 | x_y = matlabshared.tracking.internal.scenario.fresnelg2(l, dkappa, kappa0, theta); 24 | dx_y = matlabshared.tracking.internal.scenario.dfresnelg2(l, dkappa, kappa0, theta); 25 | 26 | % get elevation and derivative w.r.t road length 27 | zp = ppval(vpp, xr); 28 | 29 | % get banking angles 30 | bank = ppval(bpp, xr); 31 | 32 | % assemble the 3D positions of the road centers. This corresponds to (xr, 0) in road coordinates. 33 | center = [real(x_y+p0) imag(x_y+p0) zp]; 34 | 35 | % assemble unit tangent to xy in xy plane (neglecting derivative of elevation) 36 | forward = [real(dx_y) imag(dx_y) zeros(length(x_y),1)]; 37 | forward = forward ./ sqrt(sum(forward.^2,2)); 38 | up = repmat([0 0 1],length(x_y),1); 39 | left = cross(up,forward,2); 40 | left = left ./ sqrt(sum(left.^2,2)); 41 | 42 | % apply bank angles 43 | left = [left(:,1).*cos(bank) left(:,2).*cos(bank) sin(bank)]; -------------------------------------------------------------------------------- /OpenTrafficLab/DrivingStrategyRL.m: -------------------------------------------------------------------------------- 1 | classdef DrivingStrategyRL < DrivingStrategy 2 | % Copyright 2020 The MathWorks, Inc. 
3 | % DrivingStrategy specialized in reinforcement learning setting 4 | 5 | methods 6 | function obj = DrivingStrategyRL(egoActor,varargin) 7 | obj@DrivingStrategy(egoActor,varargin{:}); 8 | end 9 | 10 | function mode = determineDrivingMode(obj) 11 | % This function decides which mode(s) the driver is following, 12 | % and the creates and outputs the function handle to that mode 13 | 14 | segLength = obj.getSegmentLength(); 15 | leader = obj.Leader; 16 | leaderSpacing = obj.LeaderSpacing; 17 | distToEnd = segLength-obj.Station; 18 | state = obj.getNextNodeState(); 19 | isLeader = isempty(leader); 20 | leaderIsPastRoadEnd = leaderSpacing>distToEnd; 21 | timeToInt = distToEnd/obj.Speed; 22 | minHeadway = 0; 23 | % Determine the mode 24 | if isLeader 25 | if state 26 | mode = 'ApproachingGreenLight'; 27 | elseif timeToInt>minHeadway 28 | mode = 'ApproachingRedLight'; 29 | else 30 | mode = 'ApproachingGreenLight'; 31 | end 32 | else 33 | if leaderIsPastRoadEnd 34 | if state 35 | mode = 'CarFollowing'; 36 | elseif timeToInt>minHeadway 37 | mode = 'ApproachingRedLight'; 38 | else 39 | mode = 'ApproachingGreenLight'; 40 | end 41 | else 42 | mode = 'CarFollowing'; 43 | end 44 | end 45 | 46 | 47 | end 48 | 49 | function acc = carFollowing(obj,spacing,speed,speedDiff) 50 | acc = drivingBehavior.intelligentDriverModel(spacing,speed,speedDiff); 51 | aMin = (0-obj.Speed)/obj.Scenario.SampleTime; 52 | acc = max(acc,aMin); 53 | end 54 | 55 | end 56 | end 57 | 58 | -------------------------------------------------------------------------------- /OpenTrafficLab/createDrivingScenario.m: -------------------------------------------------------------------------------- 1 | function [scenario] = createDrivingScenario() 2 | % Copyright 2020 The MathWorks, Inc. 3 | % createDrivingScenario Returns the drivingScenario defined in the Designer 4 | 5 | % Generated by MATLAB(R) 9.9 (R2020b) and Automated Driving Toolbox 3.2 (R2020b). 
6 | % Generated on: 17-Jun-2020 13:50:15 7 | 8 | % Construct a drivingScenario object. 9 | scenario = drivingScenario; 10 | 11 | % Add all road segments 12 | roadCenters = [19.8 -19.2 0; 13 | 42 4.8 0; 14 | 30.2 5.8 0; 15 | 6.7 18.9 0; 16 | -3.8 -10.8 0; 17 | 7.7 -25.6 0]; 18 | laneSpecification = lanespec(2); 19 | road(scenario, roadCenters, 'Lanes', laneSpecification, 'Name', 'Road'); 20 | 21 | % Add the ego vehicle 22 | 23 | 24 | -------------------------------------------------------------------------------- /OpenTrafficLab/createTJunctionNetwork.m: -------------------------------------------------------------------------------- 1 | function [net] = createTJunctionNetwork(scenario) 2 | % Copyright 2020 The MathWorks, Inc. 3 | % Create the network of Node objects for the T-junction scenario. 4 | % net(1:3): road segments 1-3, direction -1 (connect into the junction); net(4:6): road segments 1-3, direction +1 (reached from the junction); net(7:12): junction connector segments 4-9. 5 | for i =1:3 6 | net(i) = Node(scenario,scenario.RoadSegments(i),-1); 7 | end 8 | 9 | for i =1:3 10 | net(3+i) = Node(scenario,scenario.RoadSegments(i),1); 11 | end 12 | 13 | for i = 1:6 14 | rs = scenario.RoadSegments(i+3); 15 | net(i+6) = Node(scenario,rs,1); 16 | end 17 | % NOTE(review): repeated ConnectsTo assignments below presumably append (Node setter not visible); createVehiclesForTJunction indexes ConnectsTo(j) for two turns 18 | net(1).ConnectsTo = net(7); 19 | net(1).ConnectsTo = net(8); 20 | net(1).SharesRoadWith = net(4); 21 | 22 | net(2).ConnectsTo = net(11); 23 | net(2).ConnectsTo = net(12); 24 | net(2).SharesRoadWith = net(5); 25 | 26 | net(3).ConnectsTo = net(9); 27 | net(3).ConnectsTo = net(10); 28 | net(3).SharesRoadWith = net(6); 29 | 30 | net(10).ConnectsTo = net(4); 31 | net(11).ConnectsTo = net(4); 32 | net(10).SharesRoadWith = net([7:9,11,12]); 33 | net(11).SharesRoadWith = net([7:10,12]); % every other junction node, excluding net(11) itself 34 | 35 | net(8).ConnectsTo = net(5); 36 | net(9).ConnectsTo = net(5); 37 | net(8).SharesRoadWith = net([7,9:12]); 38 | net(9).SharesRoadWith = net([7,8,10:12]); 39 | 40 | 41 | net(7).ConnectsTo = net(6); 42 | net(12).ConnectsTo = net(6); 43 | net(7).SharesRoadWith = net([8:12]); 44 | net(12).SharesRoadWith = net([7:11]); 45 | 46 | 47 | end 48 | 49 | --------------------------------------------------------------------------------
/OpenTrafficLab/createTJunctionScenario.m: -------------------------------------------------------------------------------- 1 | function [scenario, egoVehicle,net] = createTJunctionScenario() 2 | % Copyright 2020 The MathWorks, Inc. 3 | % createTJunctionScenario Returns the T-junction drivingScenario: three approach roads (length roadLength) and six junction connector roads 4 | % Outputs: scenario - the drivingScenario; egoVehicle - [] (no ego vehicle is added); net - Node network from createTJunctionNetwork, built only when requested 5 | % Generated by MATLAB(R) 9.9 (R2020b) and Automated Driving Toolbox 3.2 (R2020b). 6 | % Generated on: 16-Jun-2020 13:01:51 7 | 8 | % Construct a drivingScenario object. 9 | scenario = drivingScenario; 10 | roadLength = [50 50 50]; 11 | % Add all road segments 12 | roadCenters = [-401.4941 2185.809 0; 13 | -414.7725 2170.853 0]; 14 | roadDirection = diff(roadCenters)./norm(diff(roadCenters)); 15 | roadCenters(2,:) = roadCenters(1,:)+roadLength(1)*roadDirection; 16 | marking = [laneMarking('Unmarked') 17 | laneMarking('Unmarked') 18 | laneMarking('Solid', 'Width', 0.13) 19 | laneMarking('Dashed', 'Width', 0.13) 20 | laneMarking('Solid', 'Width', 0.13) 21 | laneMarking('Unmarked') 22 | laneMarking('Unmarked')]; 23 | laneSpecification = lanespec([3 3], 'Width', [1.5 0.15 3.65 3.65 0.15 1.5], 'Marking', marking); 24 | road(scenario, roadCenters, 'Lanes', laneSpecification); 25 | 26 | roadCenters = [-402.2 2206.4 0; 27 | -417.7083 2220.263 0]; 28 | roadDirection = diff(roadCenters)./norm(diff(roadCenters)); 29 | roadCenters(2,:) = roadCenters(1,:)+roadLength(2)*roadDirection; 30 | marking = [laneMarking('Unmarked') 31 | laneMarking('Unmarked') 32 | laneMarking('Solid', 'Width', 0.13) 33 | laneMarking('Dashed', 'Width', 0.13) 34 | laneMarking('Solid', 'Width', 0.13) 35 | laneMarking('Unmarked') 36 | laneMarking('Unmarked')]; 37 | laneSpecification = lanespec([3 3], 'Width', [1.5 0.15 3.65 3.65 0.15 1.5], 'Marking', marking); 38 | road(scenario, roadCenters, 'Lanes', laneSpecification); 39 | 40 | roadCenters = [-381.8 2208.1 0; 41 | -368.298 2223.199 0]; 42 | roadDirection = diff(roadCenters)./norm(diff(roadCenters)); 43 | roadCenters(2,:) =
roadCenters(1,:)+roadLength(3)*roadDirection; 44 | marking = [laneMarking('Unmarked') 45 | laneMarking('Unmarked') 46 | laneMarking('Solid', 'Width', 0.13) 47 | laneMarking('Dashed', 'Width', 0.13) 48 | laneMarking('Solid', 'Width', 0.13) 49 | laneMarking('Unmarked') 50 | laneMarking('Unmarked')]; 51 | laneSpecification = lanespec([3 3], 'Width', [1.5 0.15 3.65 3.65 0.15 1.5], 'Marking', marking); 52 | road(scenario, roadCenters, 'Lanes', laneSpecification); 53 | 54 | %rg = driving.scenario.RoadGroup('Name', 'T junction'); 55 | roadCenters = [-401.4941 2185.809 0; 56 | -381.5764 2208.243 0]; 57 | marking = [laneMarking('Unmarked') 58 | laneMarking('Unmarked') 59 | laneMarking('Solid', 'Width', 0.13) 60 | laneMarking('Dashed', 'Width', 0.13) 61 | laneMarking('Solid', 'Width', 0.13) 62 | laneMarking('Unmarked') 63 | laneMarking('Unmarked')]; 64 | laneSpecification = lanespec([3 3], 'Width', [1.5 0.15 3.65 3.65 0.15 1.5], 'Marking', marking); 65 | road(scenario, roadCenters, 'Lanes', laneSpecification); 66 | 67 | roadCenters = [-401.4941 2185.809 0; 68 | -400.5184 2186.908 0; 69 | -398.6231 2189.23 0; 70 | -399.4963 2203.927 0; 71 | -401.6534 2206.009 0; 72 | -402.7523 2206.984 0]; 73 | marking = [laneMarking('Unmarked') 74 | laneMarking('Unmarked') 75 | laneMarking('Solid', 'Width', 0.13) 76 | laneMarking('Dashed', 'Width', 0.13) 77 | laneMarking('Solid', 'Width', 0.13) 78 | laneMarking('Unmarked') 79 | laneMarking('Unmarked')]; 80 | laneSpecification = lanespec([3 3], 'Width', [1.5 0.15 3.65 3.65 0.15 1.5], 'Marking', marking); 81 | road(scenario, roadCenters, 'Lanes', laneSpecification); 82 | 83 | roadCenters = [-381.5764 2208.243 0; 84 | -382.5521 2207.144 0; 85 | -384.6341 2204.987 0; 86 | -399.3306 2204.114 0; 87 | -401.6534 2206.009 0; 88 | -402.7523 2206.984 0]; 89 | marking = [laneMarking('Unmarked') 90 | laneMarking('Unmarked') 91 | laneMarking('Solid', 'Width', 0.13) 92 | laneMarking('Dashed', 'Width', 0.13) 93 | laneMarking('Solid', 'Width', 0.13) 94 |
laneMarking('Unmarked') 95 | laneMarking('Unmarked')]; 96 | laneSpecification = lanespec([3 3], 'Width', [1.5 0.15 3.65 3.65 0.15 1.5], 'Marking', marking); 97 | road(scenario, roadCenters, 'Lanes', laneSpecification); 98 | 99 | roadCenters = [-381.5764 2208.243 0; 100 | -401.4941 2185.809 0]; 101 | marking = [laneMarking('Unmarked') 102 | laneMarking('Unmarked') 103 | laneMarking('Solid', 'Width', 0.13) 104 | laneMarking('Dashed', 'Width', 0.13) 105 | laneMarking('Solid', 'Width', 0.13) 106 | laneMarking('Unmarked') 107 | laneMarking('Unmarked')]; 108 | laneSpecification = lanespec([3 3], 'Width', [1.5 0.15 3.65 3.65 0.15 1.5], 'Marking', marking); 109 | road(scenario, roadCenters, 'Lanes', laneSpecification); 110 | 111 | roadCenters = [-402.7523 2206.984 0; 112 | -401.6534 2206.009 0; 113 | -399.4963 2203.927 0; 114 | -398.6231 2189.23 0; 115 | -400.5184 2186.908 0; 116 | -401.4941 2185.809 0]; 117 | marking = [laneMarking('Unmarked') 118 | laneMarking('Unmarked') 119 | laneMarking('Solid', 'Width', 0.13) 120 | laneMarking('Dashed', 'Width', 0.13) 121 | laneMarking('Solid', 'Width', 0.13) 122 | laneMarking('Unmarked') 123 | laneMarking('Unmarked')]; 124 | laneSpecification = lanespec([3 3], 'Width', [1.5 0.15 3.65 3.65 0.15 1.5], 'Marking', marking); 125 | road(scenario, roadCenters, 'Lanes', laneSpecification); 126 | 127 | roadCenters = [-402.7523 2206.984 0; 128 | -401.6534 2206.009 0; 129 | -399.3306 2204.114 0; 130 | -384.6341 2204.987 0; 131 | -382.5521 2207.144 0; 132 | -381.5764 2208.243 0]; 133 | marking = [laneMarking('Unmarked') 134 | laneMarking('Unmarked') 135 | laneMarking('Solid', 'Width', 0.13) 136 | laneMarking('Dashed', 'Width', 0.13) 137 | laneMarking('Solid', 'Width', 0.13) 138 | laneMarking('Unmarked') 139 | laneMarking('Unmarked')]; 140 | laneSpecification = lanespec([3 3], 'Width', [1.5 0.15 3.65 3.65 0.15 1.5], 'Marking', marking); 141 | road(scenario, roadCenters, 'Lanes', laneSpecification); 142 | 143 | % No ego vehicle is added to this scenario; return an empty placeholder so callers asking for it do not error 144 | egoVehicle = []; 145 | if nargout > 2 146 | net = createTJunctionNetwork(scenario); 147 | end
-------------------------------------------------------------------------------- /OpenTrafficLab/createVehiclesForTJunction.m: -------------------------------------------------------------------------------- 1 | function cars = createVehiclesForTJunction(s,net,InjectionRate,TurnRatio) 2 | % Copyright 2020 The MathWorks, Inc. 3 | % rng(33); 4 | numEntryRoads = 3; % NOTE(review): not used below; the loop hard-codes 1:3 5 | numTurns = 2; % NOTE(review): not used below 6 | 7 | [s,net,InjectionRate,TurnRatio] = checkInputs(s,net,InjectionRate,TurnRatio); 8 | 9 | cars = driving.scenario.Vehicle.empty; 10 | 11 | for i=1:3 12 | entryTimes = generatePoissonEntryTimes(s.SimulationTime,s.StopTime,InjectionRate(i)); 13 | 14 | for entryTime =entryTimes 15 | j = discretize(rand, [0 cumsum(TurnRatio)]); % turn j drawn with probability TurnRatio(j) when sum(TurnRatio) == 1; j is NaN if rand exceeds the last edge 16 | path = [net(i), net(i).ConnectsTo(j), net(i).ConnectsTo(j).ConnectsTo(1)]; % entry road -> junction connector -> exit road 17 | pos = path(1).getRoadCenterFromStation(0); 18 | [station,direction,offset]=path(1).getStationDistance(pos(1:2)); % only direction is used below 19 | car = vehicle(s,'Position',pos,'EntryTime',entryTime,'Velocity',[10,0,0]); 20 | car.ForwardVector = [direction,0]; 21 | % ms = DrivingStrategy(car,'NextNode',path); 22 | ms = DrivingStrategyRL(car,'NextNode',path,'DesiredSpeed',10); % NOTE(review): ms is not used afterwards; presumably the constructor attaches itself to car 23 | cars(end+1)=car; 24 | end 25 | end 26 | 27 | end 28 | 29 | function entryTimes = generatePoissonEntryTimes(tMin,tMax,mu) 30 | 31 | t=tMin; 32 | minHeadway = 1; 33 | entryTimes = []; 34 | while tTM 34 | - [Automated Driving Toolbox](https://www.mathworks.com/products/automated-driving.html)TM 35 | - [Parallel Computing Toolbox](https://www.mathworks.com/products/parallel-computing.html)TM 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Reporting Security Vulnerabilities 2 | 3 | If you believe you have discovered a security vulnerability, please report it to 4 | [security@mathworks.com](mailto:security@mathworks.com).
Please see 5 | [MathWorks Vulnerability Disclosure Policy for Security Researchers](https://www.mathworks.com/company/aboutus/policies_statements/vulnerability-disclosure-policy.html) 6 | for additional information. 7 | -------------------------------------------------------------------------------- /checkCollision.m: -------------------------------------------------------------------------------- 1 | function IsHit = checkCollision(this) 2 | % Copyright 2020 The MathWorks, Inc. 3 | driver = []; 4 | network = this.network; 5 | IsHit = false; 6 | % IsHit is true when two cars on the junction nodes are closer than this.safeDistance 7 | % collect the drivers (MotionStrategy) of all cars on network nodes 7 onward (the junction connectors) 8 | for i = 7:size(network, 2) 9 | if ~isempty(network(i).Vehicles) 10 | driver = [driver, network(i).Vehicles.MotionStrategy]; 11 | end 12 | end 13 | % a collision is any pair of those cars closer than this.safeDistance 14 | if length(driver) > 1 15 | dist = pdist(driver.getPosition); % pairwise distances between rows; assumes getPosition returns one row per driver 16 | if min(dist) < this.safeDistance 17 | IsHit = true; 18 | end 19 | end 20 | % debug output (disabled) 21 | % if IsHit 22 | % disp('OMG Car Collision!!') 23 | % disp(driver.getPosition) 24 | % end 25 | 26 | -------------------------------------------------------------------------------- /createDQN.m: -------------------------------------------------------------------------------- 1 | function agent = createDQN(obsInfo, actInfo) 2 | % Copyright 2020 The MathWorks, Inc.
3 | % Q-network: observation vector -> two 256-unit ReLU layers -> one Q-value per discrete action 4 | layers = [ 5 | imageInputLayer([obsInfo.Dimension(1) 1 1],"Name","observations","Normalization","none") 6 | fullyConnectedLayer(256,"Name","obs_fc1") 7 | reluLayer("Name","obs_relu1") 8 | fullyConnectedLayer(256,"Name","obs_fc2") 9 | reluLayer("Name","obs_relu2") 10 | fullyConnectedLayer(numel(actInfo.Elements),"Name","Q") % sized from the action set (3 phases for designs 1-2, 4 for design 3) 11 | ]; 12 | lgraph = layerGraph(layers); 13 | % visualization 14 | figure 15 | plot(lgraph) 16 | 17 | % critic options 18 | criticOpts = rlRepresentationOptions('LearnRate',5e-03,'GradientThreshold',1); 19 | criticOpts.Optimizer = 'sgdm'; 20 | criticOpts.UseDevice = 'cpu'; 21 | % create critic function 22 | critic = rlQValueRepresentation(lgraph,obsInfo,actInfo,... 23 | 'Observation',{'observations'},criticOpts); 24 | % agent options 25 | agentOpts = rlDQNAgentOptions; 26 | agentOpts.EpsilonGreedyExploration.EpsilonDecay = 1e-4; 27 | agentOpts.DiscountFactor = 0.99; 28 | agentOpts.TargetUpdateFrequency = 1; 29 | % create agent 30 | agent = rlDQNAgent(critic, agentOpts); 31 | end -------------------------------------------------------------------------------- /createTrainOpts.m: -------------------------------------------------------------------------------- 1 | function opts = createTrainOpts() 2 | % Copyright 2020 The MathWorks, Inc.
3 | opts = rlTrainingOptions; 4 | % Training options for the DQN traffic-signal agent 5 | opts.MaxEpisodes = 2000; 6 | opts.MaxStepsPerEpisode = 10000; % NOTE(review): master.m uses 1000 steps per episode 7 | opts.StopTrainingCriteria = "AverageReward"; 8 | opts.StopTrainingValue = 550; 9 | opts.ScoreAveragingWindowLength = 5; % stop once the 5-episode average reward reaches 550 10 | 11 | opts.SaveAgentCriteria = "EpisodeReward"; 12 | opts.SaveAgentValue = 800; % save agents whose episode reward reaches 800 13 | opts.SaveAgentDirectory = "savedAgents"; 14 | 15 | opts.UseParallel = true; % requires Parallel Computing Toolbox 16 | opts.ParallelizationOptions.Mode = "async"; 17 | opts.ParallelizationOptions.DataToSendFromWorkers = "experiences"; 18 | opts.ParallelizationOptions.StepsUntilDataIsSent = 30; 19 | opts.ParallelizationOptions.WorkerRandomSeeds = -1; 20 | 21 | opts.Verbose = false; 22 | opts.Plots = "training-progress"; 23 | -------------------------------------------------------------------------------- /license: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020, The MathWorks, Inc. 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 3. In all cases, the software is, and all modifications and derivatives of the software shall be, licensed to you solely for use in conjunction with MathWorks products and service offerings. 9 | 10 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 11 | -------------------------------------------------------------------------------- /master.m: -------------------------------------------------------------------------------- 1 | % Copyright 2020 The MathWorks, Inc. 2 | % Author: xiangxuezhao@gmail.com 3 | % Last modified: 08-27-2020 4 | % add traffic simulator: OpenTrafficLab to path 5 | folderName = fullfile(cd, 'OpenTrafficLab'); 6 | addpath(folderName) 7 | %% Step 1: create RL environment from Matlab template 8 | env = DrivingScenarioEnv; 9 | %% Step 2: specify traffic problem formulation 10 | % specify action space 11 | env.TrafficSignalDesign = 1; % or 2, 3 12 | % the dimension of signal phase is 3 for design 1 and 2, 4 for design 3, as 13 | % shown in the figure 14 | SignalPhaseDim = 3; 15 | env.phaseDuration = 50; 16 | env.clearingPhase = true; 17 | env.clearingPhaseTime = 5; 18 | % specify observation space 19 | env.ObservationSpaceDesign = 1; % or 2 20 | % specify reward parameter 21 | % The car's speed below the threshold will be treated as waiting at the 22 | % intersection 23 | env.slowSpeedThreshold = 3.5; % read by obtainReward 24 | % Add penalty for frequent/unnecessary signal phase switch 25 | env.penaltyForFreqSwitch = 1; 26 | % parameter for car collision 27 | env.hitPenalty = 20; 28 | env.safeDistance = 2.25; 29 | % reward for car pass the intersection 30 | env.rewardForPass = 10; 31 | 32 | % obtain observation and action info 33 | obsInfo = getObservationInfo(env); 34 | actInfo = getActionInfo(env); 35 | %% Step 3: create
DQN agent 36 | % policy representation by Neural Network 37 | layers = [ 38 | imageInputLayer([obsInfo.Dimension(1) 1 1],"Name","observations","Normalization","none") 39 | fullyConnectedLayer(256,"Name","obs_fc1") 40 | reluLayer("Name","obs_relu1") 41 | fullyConnectedLayer(256,"Name","obs_fc2") 42 | reluLayer("Name","obs_relu2") 43 | fullyConnectedLayer(SignalPhaseDim,"Name","Q") % one output per action; must equal the number of elements in actInfo 44 | ]; 45 | lgraph = layerGraph(layers); 46 | 47 | figure 48 | plot(lgraph) 49 | 50 | % critic options 51 | criticOpts = rlRepresentationOptions('LearnRate',5e-03,'GradientThreshold',1); 52 | criticOpts.Optimizer = 'sgdm'; 53 | criticOpts.UseDevice = 'cpu'; 54 | % create critic function 55 | critic = rlQValueRepresentation(lgraph,obsInfo,actInfo,... 56 | 'Observation',{'observations'},criticOpts); 57 | 58 | % agent options 59 | agentOpts = rlDQNAgentOptions; 60 | agentOpts.EpsilonGreedyExploration.EpsilonDecay = 1e-4; 61 | agentOpts.DiscountFactor = 0.99; 62 | agentOpts.TargetUpdateFrequency = 1; 63 | 64 | % create agent 65 | agent = rlDQNAgent(critic, agentOpts); 66 | %% Step 4: train agent 67 | % specify training option 68 | trainOpts = rlTrainingOptions; 69 | 70 | trainOpts.MaxEpisodes = 2000; 71 | trainOpts.MaxStepsPerEpisode = 1000; 72 | trainOpts.StopTrainingCriteria = "AverageReward"; 73 | trainOpts.StopTrainingValue = 550; 74 | trainOpts.ScoreAveragingWindowLength = 5; 75 | 76 | trainOpts.SaveAgentCriteria = "EpisodeReward"; 77 | trainOpts.SaveAgentValue = 800; 78 | trainOpts.SaveAgentDirectory = "savedAgents"; 79 | 80 | trainOpts.UseParallel = false; 81 | trainOpts.ParallelizationOptions.Mode = "async"; 82 | trainOpts.ParallelizationOptions.DataToSendFromWorkers = "experiences"; 83 | trainOpts.ParallelizationOptions.StepsUntilDataIsSent = 30; 84 | trainOpts.ParallelizationOptions.WorkerRandomSeeds = -1; 85 | 86 | trainOpts.Verbose = false; 87 | trainOpts.Plots = "training-progress"; 88 | 89 | % train agent or load existing trained agent 90 | doTraining = false; 91 | 92 | if
doTraining 93 | % Train the agent. 94 | trainingInfo = train(agent,env,trainOpts); 95 | else 96 | % Load the pretrained agent for the example. 97 | folderName = cd; % current folder (cd only queries it here); agents are loaded from its savedAgents subfolder 98 | folderName = fullfile(folderName, 'savedAgents'); 99 | filename = strcat('TjunctionDQNAgentDesign', num2str(env.TrafficSignalDesign), '.mat'); 100 | file = fullfile(folderName, filename); 101 | load(file); % NOTE(review): presumably defines `agent`, replacing the untrained one used in Step 5; confirm against the .mat contents 102 | end 103 | %% Step 5: simulate agent 104 | simOpts = rlSimulationOptions('MaxSteps',1000); 105 | experience = sim(env,agent,simOpts); 106 | totalReward = sum(experience.Reward); 107 | -------------------------------------------------------------------------------- /masterLiveScript.mlx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/rl-agent-based-traffic-control/0f354091fd4e182698f02488e38de3bace8bbf4b/masterLiveScript.mlx -------------------------------------------------------------------------------- /observationSpace1.m: -------------------------------------------------------------------------------- 1 | function state = observationSpace1(env, curPhase) 2 | % Copyright 2020 The MathWorks, Inc.
3 | % Observation space 1: (3*N + 1; 10 for N = 3 roads) 4 | % Each road: 5 | % - head car distance to intersection 6 | % - head car velocity 7 | % - number of cars 8 | % Current signal phase (first element) 9 | 10 | % current signal phase 11 | state = [curPhase]; 12 | % observation of each of the road 13 | for i = 1:env.N 14 | % get obs from each road 15 | if isempty(env.network(i).Vehicles) 16 | % if no car 17 | obs = [env.network(i).Length, 0, 0]; 18 | else 19 | % find the head car 20 | car = headCar(env.network(i)); 21 | % head car distance 22 | dist = env.network(i).Length - car.getStationDistance; 23 | % head car velocity 24 | vel = car.getSpeed; 25 | % number of cars 26 | numCars = length(env.network(i).Vehicles); 27 | obs = [dist, vel, numCars]; 28 | end 29 | % merge observations from each road 30 | state = [state, obs]; 31 | end 32 | 33 | % set up the lower and upper limit for distance, velocity and number of 34 | % cars: one triple per road (env.N roads), preceded by the signal phase bound 35 | carLow = [0, 0, 0]; 36 | carHigh = [50, 15, 10]; 37 | lowLimit = [0 repmat(carLow, 1, env.N)]; 38 | highLimit = [2 repmat(carHigh, 1, env.N)]; 39 | 40 | % normalize the observation 41 | state = (state - lowLimit) ./ (highLimit - lowLimit); 42 | end 43 | 44 | function headcar = headCar(net) 45 | % return the driver (MotionStrategy) furthest along this node, i.e. the head car 46 | cars = [net.Vehicles.MotionStrategy]; 47 | % find the head car index 48 | index = headCarIndex(net, cars); 49 | % head car 50 | headcar = cars(index); 51 | end 52 | 53 | function Index = headCarIndex(net, cars) 54 | Index = 1; 55 | if length(net.Vehicles) > 1 56 | travelDistance = cars.getStationDistance; 57 | [~, Index] = max(travelDistance); 58 | end 59 | end 60 | -------------------------------------------------------------------------------- /observationSpace2.m: -------------------------------------------------------------------------------- 1 | function state = observationSpace2(env, curPhase) 2 | % Copyright 2020 The MathWorks, Inc.
3 | % Observation space 2: (4*N + 1; 13 for N = 3 roads) 4 | % Each road: 5 | % - head car distance to intersection 6 | % - head car velocity 7 | % - number of cars 8 | % - head car intention: right (1) or left (-1); 0 when the road is empty 9 | % Current signal phase (first element) 10 | 11 | % current signal phase 12 | state = [curPhase]; 13 | % observation of each of the road 14 | for i = 1:env.N 15 | % get obs from each road 16 | if isempty(env.network(i).Vehicles) 17 | % if no car 18 | obs = [env.network(i).Length, 0, 0, 0]; 19 | else 20 | % find the head car 21 | car = headCar(env.network(i)); 22 | % head car distance 23 | dist = env.network(i).Length - car.getStationDistance; 24 | % head car velocity 25 | vel = car.getSpeed; 26 | % head car intention: right or left 27 | NextNode = car.NextNode(1); % NOTE(review): assumes NextNode(1) is the junction node the car enters next; confirm 28 | intent = -1; % default is turn left 29 | % update if it turns right 30 | if i == 1 31 | if NextNode == env.network(7) 32 | intent = 1; 33 | end 34 | elseif i == 2 35 | if NextNode == env.network(11) 36 | intent = 1; 37 | end 38 | elseif i == 3 39 | if NextNode == env.network(9) 40 | intent = 1; 41 | end 42 | end 43 | % number of cars 44 | numCars = length(env.network(i).Vehicles); 45 | obs = [dist, vel, numCars, intent]; 46 | end 47 | % merge observations from each road 48 | state = [state, obs]; 49 | end 50 | 51 | % set up the lower and upper limit for distance, velocity, number of 52 | % cars and intention: one group per road (env.N roads), preceded by the signal phase bound 53 | carLow = [0, 0, 0, -1]; 54 | carHigh = [50, 15, 10, 1]; 55 | lowLimit = [0 repmat(carLow, 1, env.N)]; 56 | highLimit = [2 repmat(carHigh, 1, env.N)]; 57 | 58 | % normalize the observation 59 | state = (state - lowLimit) ./ (highLimit - lowLimit); 60 | end 61 | 62 | function headcar = headCar(net) 63 | % return the driver (MotionStrategy) furthest along this node, i.e. the head car 64 | cars = [net.Vehicles.MotionStrategy]; 65 | % find the head car index 66 | index = headCarIndex(net, cars); 67 | % head car 68 | headcar = cars(index); 69 | end 70 | 71 | function Index = headCarIndex(net, cars) 72 | Index = 1; 73 | if length(net.Vehicles) > 1 74 | travelDistance =
cars.getStationDistance; 75 | [~, Index] = max(travelDistance); 76 | end 77 | end -------------------------------------------------------------------------------- /obtainReward.m: -------------------------------------------------------------------------------- 1 | function reward = obtainReward(this, phase) 2 | % Copyright 2020 The MathWorks, Inc. 3 | % Step reward: for each approach road 1..this.N, -SampleTime per car slower than slowSpeedThreshold and +0.01*speed per car; plus rewardForPass for every car seen on a junction node (7 onward) for the first time, which is then recorded in this.vehicleEnterJunction. phase is not used. 4 | reward = 0; 5 | for i = 1:this.N 6 | % drivers of the cars on road i only: reset for every road so an empty road does not re-count the previous road's cars 7 | driver = []; 8 | if ~isempty(this.network(i).Vehicles) 9 | driver = [this.network(i).Vehicles.MotionStrategy]; 10 | end 11 | if isempty(driver) 12 | continue 13 | end 14 | % component 1: judge by the distance to intersection 15 | % distance_to_intersection = network(i).Length - driver.getStationDistance; 16 | % reward = reward - sum(distance_to_intersection < this.thresholdDistance); 17 | 18 | speed = driver.getSpeed; 19 | % component 2: waiting time - cars slower than slowSpeedThreshold count as delayed for one sample 20 | reward = reward - sum(speed < this.slowSpeedThreshold) * this.scenario.SampleTime; 21 | % component 3: reward higher speeds 22 | reward = reward + sum(speed) * 0.01; 23 | end 24 | % component 4: reward each car the first time it is seen inside the intersection 25 | for i = 7:numel(this.network) 26 | vehicles = this.network(i).Vehicles; 27 | for j = 1:length(vehicles) 28 | if ~ismember(vehicles(j), this.vehicleEnterJunction) 29 | reward = reward + this.rewardForPass; 30 | this.vehicleEnterJunction = [this.vehicleEnterJunction vehicles(j)]; 31 | end 32 | end 33 | end -------------------------------------------------------------------------------- /savedAgents/TjunctionDQNAgentDesign1.mat: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/rl-agent-based-traffic-control/0f354091fd4e182698f02488e38de3bace8bbf4b/savedAgents/TjunctionDQNAgentDesign1.mat -------------------------------------------------------------------------------- /savedAgents/TjunctionDQNAgentDesign2.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/rl-agent-based-traffic-control/0f354091fd4e182698f02488e38de3bace8bbf4b/savedAgents/TjunctionDQNAgentDesign2.mat -------------------------------------------------------------------------------- /savedAgents/TjunctionDQNAgentDesign2Training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/rl-agent-based-traffic-control/0f354091fd4e182698f02488e38de3bace8bbf4b/savedAgents/TjunctionDQNAgentDesign2Training.png -------------------------------------------------------------------------------- /savedAgents/TjunctionDQNAgentDesign3.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/rl-agent-based-traffic-control/0f354091fd4e182698f02488e38de3bace8bbf4b/savedAgents/TjunctionDQNAgentDesign3.mat -------------------------------------------------------------------------------- /savedAgents/TjunctionDQNAgentDesign3Training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/rl-agent-based-traffic-control/0f354091fd4e182698f02488e38de3bace8bbf4b/savedAgents/TjunctionDQNAgentDesign3Training.png -------------------------------------------------------------------------------- /savedFigures/TjunctionRLControl1.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/matlab-deep-learning/rl-agent-based-traffic-control/0f354091fd4e182698f02488e38de3bace8bbf4b/savedFigures/TjunctionRLControl1.gif -------------------------------------------------------------------------------- /savedFigures/TjunctionRLControl2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/rl-agent-based-traffic-control/0f354091fd4e182698f02488e38de3bace8bbf4b/savedFigures/TjunctionRLControl2.gif -------------------------------------------------------------------------------- /savedFigures/TjunctionRLcontrol.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/rl-agent-based-traffic-control/0f354091fd4e182698f02488e38de3bace8bbf4b/savedFigures/TjunctionRLcontrol.gif -------------------------------------------------------------------------------- /savedFigures/workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/rl-agent-based-traffic-control/0f354091fd4e182698f02488e38de3bace8bbf4b/savedFigures/workflow.png -------------------------------------------------------------------------------- /savedTestExperience/RLTrafficControlDesign1.fig: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/rl-agent-based-traffic-control/0f354091fd4e182698f02488e38de3bace8bbf4b/savedTestExperience/RLTrafficControlDesign1.fig -------------------------------------------------------------------------------- /savedTestExperience/RLTrafficControlDesign2.fig: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/matlab-deep-learning/rl-agent-based-traffic-control/0f354091fd4e182698f02488e38de3bace8bbf4b/savedTestExperience/RLTrafficControlDesign2.fig -------------------------------------------------------------------------------- /savedTestExperience/comparision.fig: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/rl-agent-based-traffic-control/0f354091fd4e182698f02488e38de3bace8bbf4b/savedTestExperience/comparision.fig -------------------------------------------------------------------------------- /savedTestExperience/comparision.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/rl-agent-based-traffic-control/0f354091fd4e182698f02488e38de3bace8bbf4b/savedTestExperience/comparision.png -------------------------------------------------------------------------------- /savedTestExperience/comparison2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/rl-agent-based-traffic-control/0f354091fd4e182698f02488e38de3bace8bbf4b/savedTestExperience/comparison2.png -------------------------------------------------------------------------------- /savedTestExperience/createTestPlot.m: -------------------------------------------------------------------------------- 1 | % Copyright 2020 The MathWorks, Inc. 
2 | % load simulated result; the .mat file is expected to define `data` (rows 1-3: RL signal designs 1-3, row 4: fixed-time control; at least 40 columns) 3 | folderName = fileparts(mfilename('fullpath')); % this script's folder, so the plot works from any current folder 4 | fileName = 'rewardforEachTest.mat'; 5 | file = fullfile(folderName, fileName); 6 | load(file) 7 | % calculate cumulative reward for each traffic control (running sum along each row) 8 | cumReward = cumsum(data(1:4, 1:40), 2); 9 | % visualization 10 | figure 11 | hold on 12 | plot(1:40, cumReward(1:3, :), "LineWidth", 1.5) 13 | plot(1:40, cumReward(4, :), 'k-', "LineWidth", 1.5) 14 | legend('RL traffic signal design 1', 'RL traffic signal design 2','RL traffic signal design 3','Fixed time signal control') 15 | xlabel('Time') 16 | ylabel('Cumulative reward') 17 | title('Traffic Performance Comparison') -------------------------------------------------------------------------------- /savedTestExperience/fixTimeControl.mat: --------------------------------------------------------------------------------
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/rl-agent-based-traffic-control/0f354091fd4e182698f02488e38de3bace8bbf4b/savedTestExperience/phaseDesign3exp.mat -------------------------------------------------------------------------------- /savedTestExperience/rewardforEachTest.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/rl-agent-based-traffic-control/0f354091fd4e182698f02488e38de3bace8bbf4b/savedTestExperience/rewardforEachTest.mat -------------------------------------------------------------------------------- /savedVideos/RLTrained.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/rl-agent-based-traffic-control/0f354091fd4e182698f02488e38de3bace8bbf4b/savedVideos/RLTrained.gif -------------------------------------------------------------------------------- /savedVideos/RLlearningStage.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/rl-agent-based-traffic-control/0f354091fd4e182698f02488e38de3bace8bbf4b/savedVideos/RLlearningStage.gif -------------------------------------------------------------------------------- /savedVideos/RLlearningStage_linkedin.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/rl-agent-based-traffic-control/0f354091fd4e182698f02488e38de3bace8bbf4b/savedVideos/RLlearningStage_linkedin.jpg -------------------------------------------------------------------------------- /signalPhaseDesign1.m: --------------------------------------------------------------------------------
function phase = signalPhaseDesign1(action)
% Copyright 2020 The MathWorks, Inc.
% signal phase design 1: three phases, each phase has two lanes
%
% SIGNALPHASEDESIGN1 Map a discrete RL action to a signal phase.
%   phase = signalPhaseDesign1(action) returns a 1x6 row vector with one
%   entry per lane; 1 marks a lane opened by the phase (presumably matched
%   against traffic.IsOpen in DrivingScenarioEnv.step - confirm).
%   ACTION must be 0, 1 or 2. Any other value raises
%   'signalPhaseDesign1:invalidAction' rather than leaving PHASE unassigned.
switch action
    case 0
        phase = [0, 0, 1, 1, 0, 0];
    case 1
        phase = [1, 1, 0, 0, 0, 0];
    case 2
        phase = [0, 0, 0, 0, 1, 1];
    otherwise
        error('signalPhaseDesign1:invalidAction', ...
            'Action must be 0, 1 or 2 for signal phase design 1.');
end
end
-------------------------------------------------------------------------------- /signalPhaseDesign2.m: --------------------------------------------------------------------------------
function phase = signalPhaseDesign2(action)
% Copyright 2020 The MathWorks, Inc.
% signal phase design 2: three phases, each phase has three lanes
%
% SIGNALPHASEDESIGN2 Map a discrete RL action to a signal phase.
%   phase = signalPhaseDesign2(action) returns a 1x6 row vector with one
%   entry per lane; 1 marks a lane opened by the phase.
%   ACTION must be 0, 1 or 2. Any other value raises
%   'signalPhaseDesign2:invalidAction' rather than leaving PHASE unassigned.
switch action
    case 0
        phase = [1, 0, 1, 1, 0, 0];
    case 1
        phase = [1, 1, 0, 0, 1, 0];
    case 2
        phase = [0, 0, 1, 0, 1, 1];
    otherwise
        error('signalPhaseDesign2:invalidAction', ...
            'Action must be 0, 1 or 2 for signal phase design 2.');
end
end
-------------------------------------------------------------------------------- /signalPhaseDesign3.m: --------------------------------------------------------------------------------
function phase = signalPhaseDesign3(action)
% Copyright 2020 The MathWorks, Inc.
% signal phase design 3: four phases; phase 0 opens three lanes and
% phases 1-3 open two lanes each
%
% SIGNALPHASEDESIGN3 Map a discrete RL action to a signal phase.
%   phase = signalPhaseDesign3(action) returns a 1x6 row vector with one
%   entry per lane; 1 marks a lane opened by the phase.
%   ACTION must be 0, 1, 2 or 3. Any other value raises
%   'signalPhaseDesign3:invalidAction' rather than leaving PHASE unassigned.
%   NOTE(review): this design needs an action space that includes 3.
switch action
    case 0
        phase = [1, 0, 1, 0, 1, 0];
    case 1
        phase = [1, 0, 0, 1, 0, 0];
    case 2
        phase = [0, 1, 0, 0, 1, 0];
    case 3
        phase = [0, 0, 1, 0, 0, 1];
    otherwise
        error('signalPhaseDesign3:invalidAction', ...
            'Action must be 0, 1, 2 or 3 for signal phase design 3.');
end
end
--------------------------------------------------------------------------------