├── .gitignore ├── .scalafmt.conf ├── README.md ├── build-site.sh ├── build.sbt ├── data └── pacman │ └── Q.json ├── gridworld.html ├── index.html ├── pacman.html ├── polecart-human.html ├── polecart-qlearning.html ├── project ├── build.properties └── plugins.sbt └── src └── main └── scala └── rl ├── core ├── ActionResult.scala ├── AgentBehaviour.scala ├── Environment.scala ├── QLearning.scala ├── StateConversion.scala └── package.scala ├── gridworld ├── core │ └── GridworldProblem.scala └── ui │ └── GridworldUI.scala ├── pacman ├── core │ └── PacmanProblem.scala ├── training │ ├── PacmanTraining.scala │ └── QKeyValue.scala └── ui │ └── PacmanUI.scala └── polecart ├── core └── PoleBalancingProblem.scala └── ui ├── HumanUI.scala └── QLearningUI.scala /.gitignore: -------------------------------------------------------------------------------- 1 | pacman-training/ 2 | -------------------------------------------------------------------------------- /.scalafmt.conf: -------------------------------------------------------------------------------- 1 | align = true 2 | maxColumn = 100 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Reinforcement Learning in Scala 2 | 3 | This repo contains the source code for the demos to accompany my talk 4 | 'Reinforcement Learning in Scala'. 5 | 6 | The slides are available 7 | [here](https://slides.com/cb372/reinforcement-learning-in-scala). 8 | 9 | The demos are available[here](https://cb372.github.io/rl-in-scala/). 10 | 11 | ## Running locally 12 | 13 | The demos are implemented using Scala.js, so first you need to build the 14 | JavaScript: 15 | 16 | ``` 17 | $ sbt fastOptJS 18 | ``` 19 | 20 | Next, start a simple web server of your choice. I use the Python one: 21 | 22 | ``` 23 | $ python -m SimpleHTTPServer 24 | Serving HTTP on 0.0.0.0 port 8000 ... 25 | ``` 26 | 27 | Finally open the site in your browser: 28 | 29 | ``` 30 | $ open localhost:8000 31 | ``` 32 | 33 | ## Pacman training 34 | 35 | If you'd like to try your hand at making the Pacman agent smarter, the expected 36 | workflow looks something like this: 37 | 38 | 1. Update 39 | [PacmanProblem.scala](src/main/scala/rl/pacman/core/PacmanProblem.scala) to 40 | improve the agent's state space, making it a more efficient learner. 41 | 42 | 2. Run the training harness: 43 | 44 | ``` 45 | $ sbt run 46 | ``` 47 | 48 | This will make the agent play a very large number of games of Pacman. It 49 | will run forever. Every 1 million time steps it will print out some stats to 50 | give an indicator of the agent's learning progress. Every five million time 51 | steps it will write the agent's Q-values to a JSON file in the 52 | `pacman-training` directory. 53 | 54 | 3. Once you have Q-values you are happy with, copy the JSON file to 55 | `data/pacman/Q.json`, overwriting the existing file. 56 | 57 | 4. Follow the steps above for running locally. Open the Pacman UI in your 58 | browser and watch your trained agent show those ghosts who's boss! 59 | 60 | ### Hints 61 | 62 | If you make your state space too large, you'll have a number of problems: 63 | 64 | * Your JSON file will probably be huge enough to crash your browser when the UI 65 | tries to load it. 66 | 67 | * The agent will learn very slowly because it needs to explore so many states. 68 | 69 | So the trick is to find a way of encoding enough information about the game 70 | state without the number of states exploding. 
e.g. if you were to track the 71 | exact locations of Pacman and both ghosts, you already have 65 x 65 x 65 = 72 | 274,675 states to deal with. 73 | 74 | Your state encoding should also make sense when combined with the reward 75 | function. For example, the environment gives a reward when Pacman eats food, so 76 | intuitively the state should track food in some way. 77 | 78 | If your agent is struggling to win games, you could try: 79 | 80 | * Making the ghosts move more randomly by reducing their `smartMoveProb` 81 | 82 | * Making a smaller grid, maybe with only one ghost 83 | -------------------------------------------------------------------------------- /build-site.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | sbt clean fullOptJS 6 | 7 | echo "Copying reinforcement-learning-in-scala-opt.js" 8 | mkdir -p ../rl-in-scala/js 9 | cp target/scala-2.12/reinforcement-learning-in-scala-opt.js ../rl-in-scala/js 10 | 11 | for file in *.html; do 12 | echo "Copying $file" 13 | sed -e "s/target\/scala-2.12\/reinforcement-learning-in-scala-fastopt.js/js\/reinforcement-learning-in-scala-opt.js/" $file > ../rl-in-scala/$file 14 | done 15 | 16 | echo "Copying Pacman data dir" 17 | mkdir -p ../rl-in-scala/data 18 | cp -R data/pacman ../rl-in-scala/data 19 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | scalaVersion := "2.12.6" 2 | enablePlugins(ScalaJSPlugin) 3 | 4 | //scalaJSUseMainModuleInitializer := true 5 | libraryDependencies ++= Seq( 6 | "org.scala-js" %%% "scalajs-dom" % "0.9.6", 7 | "io.circe" %%% "circe-generic" % "0.10.1", 8 | "io.circe" %%% "circe-parser" % "0.10.1" 9 | ) 10 | -------------------------------------------------------------------------------- /data/pacman/Q.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /gridworld.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Gridworld 5 | 6 | 10 | 15 | 16 | 17 | 37 | 38 |
39 | 40 |
41 |

Grid

42 | 43 | 44 | 45 |
46 |
Controls
47 |
48 | 49 | 50 | 51 |
52 |
53 |
54 | 55 |
56 |

Rules

* Moving from cell A jumps the agent to A' with a reward of +10; moving from cell B jumps it to B' with a reward of +5.
* Trying to move off the edge of the grid gives a reward of -1 and the agent stays put; every other move gives 0.
66 | 67 |
68 |

Q(s, a)

69 |
70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 |
107 |
108 |
109 | 110 |
111 |

Policy

112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 |
150 |
151 | 152 |
153 | 154 | 155 | 164 | 165 | -------------------------------------------------------------------------------- /index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Reinforcement Learning in Scala 5 | 6 | 10 | 13 | 16 | 17 | 18 | 38 | 39 |
40 |

Reinforcement Learning in Scala

41 | 42 |

This site contains the demos for my 'Reinforcement Learning in Scala' talk.

43 | 44 |

Links

45 | 46 |

The slides for the talk are available here.

47 | 48 |

The source code for all the demos is available on GitHub.

49 | 50 |

Demos

51 | 52 |

53 | There are 3 demos, all of which use the same RL algorithm known as Q-learning. 54 |

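All of them share the one-step update implemented in QLearning.scala. A minimal sketch of that backup, with an illustrative helper name and signature of my own:

```scala
// One-step Q-learning backup, as implemented in QLearning.scala:
//   Q(s, a) <- Q(s, a) + α * (r + γ * max_a' Q(s', a') - Q(s, a))
def qBackup(currentQ: Double, reward: Double, maxNextQ: Double,
            alpha: Double, gamma: Double): Double =
  currentQ + alpha * (reward + gamma * maxNextQ - currentQ)
```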
55 | 56 | 61 | 62 |
63 |
Gridworld
This is a continuous (non-episodic) problem with very simple rules:

* The agent occupies one square of a 5x5 grid and can move up, down, left or right.
* Any move from the special square A jumps the agent to A' and gives a reward of +10.
* Any move from the special square B jumps the agent to B' and gives a reward of +5.
* Trying to move off the edge of the grid gives a reward of -1 and the agent stays where it is.
* Every other move gives a reward of 0.

72 | Of course, the optimal policy is to always move towards A in order to pick up the reward of 10. 73 | If you run the demo, you should see the agent gradually learn this policy. 74 |

75 |

It may get stuck in a suboptimal policy (i.e. preferring the B cell) for a while, but it is guaranteed to eventually converge on the optimal policy. This is because the agent constantly explores the state space using the ε-greedy algorithm.

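A minimal sketch of that ε-greedy choice, written as a standalone helper (the full version, which also learns from the result, lives in QLearning.scala):

```scala
import scala.util.Random

// ε-greedy selection over one state's action values (assumes the map is non-empty):
// with probability ε explore by picking any action at random, otherwise exploit by
// picking randomly among the actions whose current estimates are joint-highest.
def epsilonGreedy[A](actionValues: Map[A, Double], epsilon: Double): A =
  if (Random.nextDouble() < epsilon)
    Random.shuffle(actionValues.keys.toList).head
  else {
    val best = actionValues.values.max
    Random.shuffle(actionValues.collect { case (a, v) if v == best => a }.toList).head
  }
```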
80 |

The big table under the grid shows the agent's current Q(s, a) for all state-action pairs. This is the agent's estimate of how valuable it is to take action a in state s.

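Under the hood this table is a nested map (see QLearning.scala); reading one cell is just a lookup, sketched here with an illustrative helper and example location:

```scala
import rl.gridworld.core.GridworldProblem.{AgentLocation, Move}

// Q is a Map[State, Map[Action, Double]]; a missing entry just means the agent
// has not formed an estimate yet, and is treated as 0.0 (the demo's initial value).
def lookupQ(Q: Map[AgentLocation, Map[Move, Double]],
            state: AgentLocation,
            action: Move): Double =
  Q.getOrElse(state, Map.empty).getOrElse(action, 0.0)

// e.g. the current estimate for moving Up out of the top-left cell:
// lookupQ(Q, AgentLocation(0, 0), Move.Up)
```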
84 |

85 | The smaller table shows the same information summarised as a policy. 86 | In other words, for a given state, what action(s) the agent currently believes to be the best. 87 |

88 |
89 | 90 |
91 |
Pole balancing
This episodic problem is a classic in RL literature.

92 | 93 |

94 | At every time step the agent must push the cart either to the left or the right. 95 | The goal is to stop the pole from toppling too far either to the left or the right, 96 | whilst also ensuring the cart does not crash into the walls. 97 |

98 | 99 |

The rules are as follows:

* At every time step the agent must push the cart to the left or to the right with a fixed force; doing nothing is not an option.
* The episode ends in failure if the pole leans more than 12 degrees from vertical.
* It also ends in failure if the cart moves more than 2.4 metres from the centre, i.e. hits a wall.
* The reward is 0 at every time step, and -1 when the episode ends in failure.

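In code, the failure condition boils down to two thresholds; a condensed version of the check in PoleBalancingProblem.scala (the helper name is mine):

```scala
// Condensed from PoleBalancingProblem.isTerminal: the episode ends in failure when
// the cart leaves the track (|x| > 2.4 m) or the pole leans more than 12 degrees.
def isFailure(cartPosition: Double, poleAngleRadians: Double): Boolean =
  math.abs(cartPosition) > 2.4 || math.toDegrees(math.abs(poleAngleRadians)) > 12.0
```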
107 | It's fascinating to see how quickly the agent learns, especially bearing in mind: 108 |

109 |
    110 |
1. Q-learning is a model-free algorithm, so the agent has no idea of the problem it's solving. It doesn't know anything about poles, carts, angular velocities, and so on. All it knows is that it has to pick one of two actions at every time step.
2. The amount of feedback from the environment is very small. All the agent gets is a negative reward at the end of the episode.
120 | 121 |

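It also helps that the agent never sees the raw physics, only the coarse buckets defined by RoughPoleCartState in PoleBalancingProblem.scala, so the table it has to fill in is tiny:

```scala
// A quick count of the agent's world after bucketing (RoughPoleCartState):
// 3 cart positions x 3 cart velocities x 6 pole angle buckets x 3 pole velocities.
val roughStates  = 3 * 3 * 6 * 3   // 162 states
val tableEntries = roughStates * 2 // 324 Q-values (two actions per state)
```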
122 | To get a feel for the problem, you might want to try it yourself first. 123 | Use the Left and Right arrow keys on your keyboard to move the cart. 124 |

125 | 126 |

Next you can watch the agent learn. Use the buttons to run through a single time step, a single episode or continuously.

130 |
131 | 132 |
133 |
Pacman
This one is an exercise for the reader.

134 | 135 |

136 | The demo shows a very "dumb" agent. Its state space is enormous, so it has no chance of doing any meaningful learning. 137 |

138 | 139 |

140 | See if you can improve the agent by redesigning its state space and putting it through some training. 141 |

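The full GameState distinguishes every combination of exact Pacman and ghost positions plus the set of food still on the board (each of the 55 dots eaten or not), so it is hopeless to learn over directly. One hypothetical way to shrink it, purely for illustration (the CompactState name, the features and the thresholds below are made up, not part of the repo; the real work is the TODO in PacmanProblem.scala):

```scala
import rl.core.StateConversion
import rl.pacman.core.PacmanProblem._

object CompactEncoding {

  // Illustrative only: a coarse agent state built from a few features rather
  // than exact coordinates. None of these names or thresholds are in the repo.
  case class CompactState(
      foodDirection: Option[Move], // a move that takes Pacman closer to the nearest food
      ghostNearby: Boolean,        // is either ghost within 3 squares (Manhattan distance)?
      chasingGhosts: Boolean       // are the ghosts currently edible?
  )

  private def dist(a: Location, b: Location): Int =
    math.abs(a.x - b.x) + math.abs(a.y - b.y)

  val compactConversion: StateConversion[GameState, CompactState] = { gameState =>
    val pacman = gameState.pacman

    val foodDirection =
      if (gameState.food.isEmpty) None
      else {
        val nearest = gameState.food.minBy(dist(pacman, _))
        allActions.find(move => dist(pacman.move(move), nearest) < dist(pacman, nearest))
      }

    val ghostNearby =
      dist(pacman, gameState.ghost1) <= 3 || dist(pacman, gameState.ghost2) <= 3

    CompactState(foodDirection, ghostNearby, gameState.mode.chasingGhosts)
  }
}
```

An encoding in this spirit has only a handful of distinct states, so the training harness can visit each of them many times.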
142 | 143 |

144 | Take a look at the README for more details. 145 |

146 |
147 | 148 |
149 | 150 | 153 | 154 | 155 | 156 | 157 | -------------------------------------------------------------------------------- /pacman.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Pacman 5 | 6 | 10 | 11 | 12 | 32 | 33 |
34 |
35 | 36 | 37 |
38 | 39 | 40 | 47 | 48 | -------------------------------------------------------------------------------- /polecart-human.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Pole balancing - try it yourself 5 | 6 | 10 | 11 | 12 | 32 | 33 |
34 |
35 |
36 |

Pole balancing - try it yourself

37 | 38 | 39 | 40 |
41 |
42 |
43 |
44 | 45 | 46 | 53 | 54 | -------------------------------------------------------------------------------- /polecart-qlearning.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Pole balancing - watch it learn 5 | 6 | 10 | 13 | 16 | 37 | 38 | 39 | 59 | 60 |
61 |
62 |
63 |

Pole balancing - watch it learn

64 | 65 | 66 | 67 |
68 |
Controls
69 |
70 | 71 | 72 | 73 | 74 |
75 |
76 | 77 |
78 |
79 |
80 | 81 |
82 |
83 |

Q(s, a)

84 | 85 |
Tab names indicate cart position.
86 | 87 | 92 | 93 |
94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 |
Each tab contains the same Q(s, a) table; its cells are empty in the HTML and are filled in by the demo as the agent learns.
Columns: Cart velocity (Fast left / Slow / Fast right), each subdivided by Pole velocity (Fast left / Slow / Fast right).
Rows: Pole angle (Very left, Quite left, Slightly left, Slightly right, Quite right, Very right).
436 |
437 |
438 |
439 |
440 | 441 | 442 | 453 | 456 | 457 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=1.2.3 2 | -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("org.scala-js" % "sbt-scalajs" % "0.6.25") 2 | addSbtPlugin("com.geirsson" % "sbt-scalafmt" % "1.5.1") 3 | -------------------------------------------------------------------------------- /src/main/scala/rl/core/ActionResult.scala: -------------------------------------------------------------------------------- 1 | package rl.core 2 | 3 | /** 4 | * The results of the agent taking an action: 5 | * it receives a reward and ends up in a new state. 6 | */ 7 | case class ActionResult[State](reward: Reward, nextState: State) 8 | -------------------------------------------------------------------------------- /src/main/scala/rl/core/AgentBehaviour.scala: -------------------------------------------------------------------------------- 1 | package rl.core 2 | 3 | trait AgentBehaviour[AgentData, State, Action] { 4 | 5 | /** 6 | * Given an agent and the current state, asks the agent to choose the next action. 7 | * 8 | * Returns two things: 9 | * 10 | * 1. the action that the agent chose 11 | * 2. a function that, given the results of taking the action, 12 | * uses it to improve the agent's policy and thus returns a new version of the agent 13 | */ 14 | def chooseAction(agentData: AgentData, 15 | state: State, 16 | validActions: List[Action]): (Action, ActionResult[State] => AgentData) 17 | 18 | } 19 | -------------------------------------------------------------------------------- /src/main/scala/rl/core/Environment.scala: -------------------------------------------------------------------------------- 1 | package rl.core 2 | 3 | trait Environment[State, Action] { 4 | 5 | /** 6 | * Given the current state, what are the legal actions the agent can take? 7 | */ 8 | def possibleActions(currentState: State): List[Action] 9 | 10 | /** 11 | * Given the current state and the action chosen by the agent, 12 | * what state does the agent move into and what reward does it get? 13 | * 14 | * Things to note: 15 | * - The reward might be positive, negative or zero. 16 | * - The next state might be the same as the current state. 17 | * - Both the state transition function and the reward function may be stochastic, 18 | * meaning they follow some probability distribution and do not always 19 | * give the same output for a given input. 20 | */ 21 | def step(currentState: State, actionTaken: Action): (State, Reward) 22 | 23 | /** 24 | * Is the given state terminal or not? 25 | * For continuous (non-episodic) problems, this will always be false. 
26 | */ 27 | def isTerminal(state: State): Boolean 28 | 29 | } 30 | -------------------------------------------------------------------------------- /src/main/scala/rl/core/QLearning.scala: -------------------------------------------------------------------------------- 1 | package rl.core 2 | 3 | import scala.util.Random 4 | 5 | case class QLearning[State, Action]( 6 | α: Double, // step size, 0.0 ≦ α ≦ 1.0, controls how much the agent updates its action-value function Q(s, a) 7 | γ: Double, // discount rate, 0.0 ≦ γ ≦ 1.0, controls how much the one-step backup affects Q(s, a) 8 | ε: Double, // 0.0 ≦ ε ≦ 1.0, probability of choosing a random action 9 | Q: Map[State, Map[Action, Double]] // the estimated action-value function Q(s, a) 10 | ) 11 | 12 | object QLearning { 13 | 14 | implicit def agentBehaviour[State, Action] 15 | : AgentBehaviour[QLearning[State, Action], State, Action] = 16 | new AgentBehaviour[QLearning[State, Action], State, Action] { 17 | 18 | def chooseAction( 19 | agentData: QLearning[State, Action], 20 | state: State, 21 | validActions: List[Action]): (Action, ActionResult[State] => QLearning[State, Action]) = { 22 | // Get Q(s, {a}), or initialise it arbitrarily to 0 for all actions if not initialised yet 23 | val actionValues = agentData.Q.getOrElse(state, validActions.map(_ -> 0.0).toMap) 24 | 25 | // choose the next action 26 | val (chosenAction, currentActionValue) = epsilonGreedy(actionValues, agentData.ε) 27 | 28 | // learn! 29 | val updateStateActionValue: ActionResult[State] => QLearning[State, Action] = { 30 | actionResult => 31 | val nextStateActionValues = 32 | agentData.Q.getOrElse(actionResult.nextState, validActions.map(_ -> 0.0).toMap) 33 | val maxNextStateActionValue = 34 | nextStateActionValues.values.fold(Double.MinValue)(_ max _) 35 | 36 | // Q(s_t, a_t) <- Q(s_t, a_t) + α (r_t+1 + γ max_a Q(s_t+1, a) - Q(s_t, a_t)) 37 | val updatedActionValue = 38 | currentActionValue + agentData.α * (actionResult.reward + agentData.γ * maxNextStateActionValue - currentActionValue) 39 | 40 | val updatedActionValues = actionValues + (chosenAction -> updatedActionValue) 41 | val updatedQ = agentData.Q + (state -> updatedActionValues) 42 | 43 | agentData.copy(Q = updatedQ) 44 | } 45 | 46 | (chosenAction, updateStateActionValue) 47 | } 48 | 49 | /* 50 | ε-greedy: choose one of the actions with the highest value most of the time (i.e. exploit) 51 | but choose an action randomly some of the time (i.e. explore) 52 | */ 53 | private def epsilonGreedy(actionValues: Map[Action, Double], ε: Double): (Action, Double) = { 54 | if (Random.nextDouble() < ε) { 55 | Random.shuffle(actionValues.toList).head 56 | } else { 57 | val sorted = actionValues.toList.sortBy(_._2).reverse 58 | val maxValue = sorted.head._2 59 | Random.shuffle(sorted.takeWhile(_._2 == maxValue)).head 60 | } 61 | } 62 | 63 | } 64 | 65 | } 66 | -------------------------------------------------------------------------------- /src/main/scala/rl/core/StateConversion.scala: -------------------------------------------------------------------------------- 1 | package rl.core 2 | 3 | trait StateConversion[EnvState, AgentState] { 4 | 5 | /** 6 | * Convert from the "true", complete state as known by the environment, 7 | * into a simplified state that we give to the agent. 8 | * 9 | * This is a chance to do things: 10 | * 11 | * 1. If the problem includes any constraints that say the agent should have incomplete 12 | * knowledge of the environment, we can encode that here. 13 | * 14 | * 2. 
We can discard some information in order to reduce the agent's state space, 15 | * e.g. by bucketing a large number of environment states into a single agent state. 16 | */ 17 | def convertState(envState: EnvState): AgentState 18 | 19 | } 20 | -------------------------------------------------------------------------------- /src/main/scala/rl/core/package.scala: -------------------------------------------------------------------------------- 1 | package rl 2 | 3 | package object core { 4 | 5 | // In reinforcement learning the reward is always numeric 6 | type Reward = Double 7 | 8 | } 9 | -------------------------------------------------------------------------------- /src/main/scala/rl/gridworld/core/GridworldProblem.scala: -------------------------------------------------------------------------------- 1 | package rl.gridworld.core 2 | 3 | import rl.core.{Environment, Reward} 4 | 5 | object GridworldProblem { 6 | 7 | // Note: x and y range from 0 to 4, not 1 to 5 8 | case class AgentLocation(x: Int, y: Int) 9 | 10 | sealed trait Move 11 | object Move { 12 | case object Up extends Move 13 | case object Down extends Move 14 | case object Left extends Move 15 | case object Right extends Move 16 | } 17 | 18 | val allActions: List[Move] = List(Move.Up, Move.Down, Move.Left, Move.Right) 19 | 20 | implicit val environment: Environment[AgentLocation, Move] = 21 | new Environment[AgentLocation, Move] { 22 | 23 | override def possibleActions(currentState: AgentLocation): List[Move] = 24 | GridworldProblem.allActions 25 | 26 | override def step(currentLocation: AgentLocation, 27 | actionTaken: Move): (AgentLocation, Reward) = currentLocation match { 28 | case AgentLocation(1, 0) => 29 | // special cell A: regardless of action, jump to A' and receive 10 reward 30 | (AgentLocation(1, 4), 10.0) 31 | case AgentLocation(3, 0) => 32 | // special cell B: regardless of action, jump to B' and receive 5 reward 33 | (AgentLocation(3, 2), 5.0) 34 | case AgentLocation(x, y) if wouldLeaveBoard(x, y, actionTaken) => 35 | // negative reward for trying to leave the edge of the board 36 | (currentLocation, -1.0) 37 | case AgentLocation(x, y) => 38 | val newLocation = actionTaken match { 39 | case Move.Up => AgentLocation(x, y - 1) 40 | case Move.Down => AgentLocation(x, y + 1) 41 | case Move.Left => AgentLocation(x - 1, y) 42 | case Move.Right => AgentLocation(x + 1, y) 43 | } 44 | // zero reward in all other cases 45 | (newLocation, 0.0) 46 | } 47 | 48 | private def wouldLeaveBoard(x: Int, y: Int, move: Move): Boolean = 49 | (x == 0 && move == Move.Left) || 50 | (x == 4 && move == Move.Right) || 51 | (y == 0 && move == Move.Up) || 52 | (y == 4 && move == Move.Down) 53 | 54 | override def isTerminal(state: AgentLocation): Boolean = 55 | false // this is a continuous (non-episodic) problem 56 | 57 | } 58 | 59 | } 60 | -------------------------------------------------------------------------------- /src/main/scala/rl/gridworld/ui/GridworldUI.scala: -------------------------------------------------------------------------------- 1 | package rl.gridworld.ui 2 | 3 | import org.scalajs.dom 4 | import org.scalajs.dom.html.{Button, Canvas} 5 | import rl.core.{ActionResult, AgentBehaviour, Environment, QLearning} 6 | import rl.gridworld.core.GridworldProblem 7 | import rl.gridworld.core.GridworldProblem.{AgentLocation, Move} 8 | 9 | import scala.scalajs.js.annotation.{JSExport, JSExportTopLevel} 10 | import scala.util.Random 11 | 12 | @JSExportTopLevel("GridworldUI") 13 | object GridworldUI { 14 | 15 | sealed trait UIState 16 
| case object Idle extends UIState 17 | case object Stepping extends UIState 18 | case object Running extends UIState 19 | 20 | private val initialState: AgentLocation = 21 | AgentLocation(Random.nextInt(5), Random.nextInt(5)) 22 | 23 | /* 24 | Note: because this is a continuous (non-episodic) problem, 25 | we use discounting, i.e. we set γ to less than 1. 26 | 27 | This ensures that Q-values will converge to a fixed value 28 | even though the agent continues moving around the grid 29 | and accruing rewards forever. 30 | */ 31 | private val initialAgentData: QLearning[AgentLocation, Move] = 32 | QLearning(α = 0.9, γ = 0.9, ε = 0.4, Q = Map.empty) 33 | 34 | private val env: Environment[AgentLocation, Move] = implicitly 35 | private val agentBehaviour: AgentBehaviour[QLearning[AgentLocation, Move], AgentLocation, Move] = 36 | implicitly 37 | 38 | @JSExport 39 | def main(document: dom.Document, 40 | canvas: Canvas, 41 | stepButton: Button, 42 | runButton: Button, 43 | pauseButton: Button): Unit = { 44 | var uiState: UIState = Idle 45 | 46 | var agentData = initialAgentData 47 | var currentState = initialState 48 | 49 | def step(): Unit = { 50 | val (nextAction, updateAgent) = 51 | agentBehaviour.chooseAction(agentData, currentState, GridworldProblem.allActions) 52 | val (nextState, reward) = env.step(currentState, nextAction) 53 | 54 | agentData = updateAgent(ActionResult(reward, nextState)) 55 | currentState = nextState 56 | 57 | updateUI(document, canvas, agentData, currentState) 58 | } 59 | 60 | def tick(): Unit = uiState match { 61 | case Idle => 62 | updateUI(document, canvas, agentData, currentState) 63 | 64 | case Stepping => 65 | step() 66 | uiState = Idle 67 | 68 | case Running => 69 | step() 70 | } 71 | 72 | stepButton.onclick = _ => uiState = Stepping 73 | runButton.onclick = _ => uiState = Running 74 | pauseButton.onclick = _ => uiState = Idle 75 | 76 | dom.window.setInterval(() => tick(), 150) 77 | } 78 | 79 | private def updateUI(document: dom.Document, 80 | canvas: Canvas, 81 | agentData: QLearning[AgentLocation, Move], 82 | agentLocation: AgentLocation): Unit = { 83 | val ctx = canvas 84 | .getContext("2d") 85 | .asInstanceOf[dom.CanvasRenderingContext2D] 86 | 87 | val cellWidth = canvas.width / 5 88 | val cellHeight = canvas.height / 5 89 | 90 | ctx.clearRect(0, 0, canvas.width, canvas.height) 91 | 92 | ctx.fillStyle = "black" 93 | ctx.lineWidth = 1 94 | ctx.font = "30px arial" 95 | 96 | // draw the grid 97 | for (i <- 0 until 5) { 98 | for (j <- 0 until 5) { 99 | ctx.strokeRect(i * cellWidth, j * cellHeight, cellWidth, cellHeight) 100 | } 101 | } 102 | 103 | // draw the annotations on the special cells 104 | ctx.fillText("A", cellWidth + 10, 30) 105 | ctx.fillText("B", 3 * cellWidth + 10, 30) 106 | ctx.fillText("A'", cellWidth + 10, 4 * cellHeight + 30) 107 | ctx.fillText("B'", 3 * cellWidth + 10, 2 * cellHeight + 30) 108 | drawArrow(ctx, cellWidth + 20, 50, 4 * cellHeight - 10, "+10") 109 | drawArrow(ctx, 3 * cellWidth + 20, 50, 2 * cellHeight - 10, "+5") 110 | 111 | // draw the agent's red dot 112 | ctx.fillStyle = "red" 113 | ctx.beginPath() 114 | ctx.arc((agentLocation.x + 0.5) * cellWidth, 115 | (agentLocation.y + 0.5) * cellHeight, 116 | 0.2 * cellWidth, 117 | 0, 118 | 2 * Math.PI) 119 | ctx.fill() 120 | ctx.closePath() 121 | 122 | updateTables(document, agentData.Q) 123 | } 124 | 125 | private def drawArrow(ctx: dom.CanvasRenderingContext2D, 126 | x: Int, 127 | fromY: Int, 128 | toY: Int, 129 | text: String): Unit = { 130 | val headLength = 10 131 | 132 | 
ctx.beginPath() 133 | ctx.lineWidth = 2 134 | ctx.moveTo(x, fromY) 135 | ctx.lineTo(x, toY) 136 | ctx.lineTo(x - headLength * Math.cos(Math.PI / 3), toY - headLength * Math.sin(Math.PI / 3)) 137 | ctx.moveTo(x, toY) 138 | ctx.lineTo(x - headLength * Math.cos(2 * Math.PI / 3), 139 | toY - headLength * Math.sin(2 * Math.PI / 3)) 140 | ctx.stroke() 141 | ctx.closePath() 142 | 143 | ctx.fillText(text, x + 5, (toY + fromY) / 2 + 5) 144 | } 145 | 146 | private def updateTables(document: dom.Document, 147 | Q: Map[AgentLocation, Map[Move, Double]]): Unit = { 148 | for { 149 | x <- 0 to 4 150 | y <- 0 to 4 151 | } { 152 | val actionValues = Q.getOrElse(AgentLocation(x, y), Map.empty) 153 | 154 | val Qtext = { 155 | GridworldProblem.allActions 156 | .map { move => 157 | val paddedMove = move.toString.padTo(5, ' ').replaceAllLiterally(" ", " ") 158 | val actionValue = actionValues.getOrElse(move, 0.0) 159 | f"$paddedMove: $actionValue%.4f" 160 | } 161 | .mkString("
") 162 | } 163 | 164 | document.getElementById(s"Q_${x}_$y").innerHTML = Qtext 165 | 166 | val policyText = { 167 | val descendingActionValues = actionValues.groupBy(_._2).toList.sortBy(_._1).reverse 168 | if (descendingActionValues.length < 2) { 169 | "??" 170 | } else { 171 | descendingActionValues.head._2.map(_._1.toString.head).toList.sorted.mkString 172 | } 173 | } 174 | 175 | document.getElementById(s"policy_${x}_$y").innerHTML = policyText 176 | } 177 | } 178 | 179 | } 180 | -------------------------------------------------------------------------------- /src/main/scala/rl/pacman/core/PacmanProblem.scala: -------------------------------------------------------------------------------- 1 | package rl.pacman.core 2 | 3 | import rl.core._ 4 | 5 | import scala.util.Random 6 | 7 | object PacmanProblem { 8 | 9 | // Note: x ranges from 0 to 19, y ranges from 0 to 6 10 | case class Location(x: Int, y: Int) { 11 | 12 | def move(move: Move): Location = move match { 13 | case Move.Left => Location(x - 1, y) 14 | case Move.Right => Location(x + 1, y) 15 | case Move.Up => Location(x, y - 1) 16 | case Move.Down => Location(x, y + 1) 17 | } 18 | 19 | } 20 | 21 | // convenience method for constructing a Location 22 | private def xy(x: Int, y: Int) = Location(x, y) 23 | 24 | // the current game mode: are the ghosts chasing Pacman or vice versa? 25 | sealed trait Mode { def chasingGhosts: Boolean } 26 | object Mode { 27 | case object Normal extends Mode { val chasingGhosts = false } 28 | case class ChaseGhosts(timeRemaining: Int) extends Mode { val chasingGhosts = true } 29 | } 30 | 31 | /* 32 | The complete state of the game: 33 | - the location of each ghost 34 | - Pacman's location 35 | - the locations of all remaining food 36 | - the locations of all remaining pills 37 | - the current game mode 38 | */ 39 | case class GameState( 40 | ghost1: Location, 41 | ghost2: Location, 42 | pacman: Location, 43 | food: Set[Location], 44 | pills: Set[Location], 45 | mode: Mode 46 | ) 47 | 48 | // the actions that the agent can take to move Pacman 49 | sealed trait Move 50 | object Move { 51 | case object Up extends Move 52 | case object Down extends Move 53 | case object Left extends Move 54 | case object Right extends Move 55 | } 56 | 57 | val allActions: List[Move] = List(Move.Up, Move.Down, Move.Left, Move.Right) 58 | 59 | /* 60 | We use the following "smallClassic" grid: 61 | 62 | 0123456789 63 | 0%%%%%%%%%%%%%%%%%%%% 64 | 1%......%G G%......% 65 | 2%.%%...%% %%...%%.% 66 | 3%.%o.%........%.o%.% 67 | 4%.%%.%.%%%%%%.%.%%.% 68 | 5%........P.........% 69 | 6%%%%%%%%%%%%%%%%%%%% 70 | 71 | % = wall 72 | . 
= food 73 | o = pill 74 | G = ghost start location 75 | P = Pacman start location 76 | */ 77 | 78 | val walls: Set[Location] = 79 | // format: off 80 | List.tabulate(20)(xy(_, 0)).toSet ++ // top wall 81 | Set(xy(0, 1), xy(7, 1), xy(12, 1), xy(19, 1)) ++ 82 | Set(xy(0, 2), xy(2, 2), xy(3, 2), xy(7, 2), xy(8, 2), xy(11, 2), xy(12, 2), xy(16, 2), xy(17, 2), xy(19, 2)) ++ 83 | Set(xy(0, 3), xy(2, 3), xy(5, 3), xy(14, 3), xy(17, 3), xy(19, 3)) ++ 84 | Set(xy(0, 4), xy(2, 4), xy(3, 4), xy(5, 4), xy(7, 4), xy(8, 4), xy(9, 4), xy(10, 4), xy(11, 4), xy(12, 4), xy(14, 4), xy(16, 4), xy(17, 4), xy(19, 4)) ++ 85 | Set(xy(0, 5), xy(19, 5)) ++ 86 | List.tabulate(20)(xy(_, 6)).toSet // bottom wall 87 | // format: on 88 | 89 | private val initialGhost1 = xy(8, 1) 90 | private val initialGhost2 = xy(11, 1) 91 | private val initialPacman = xy(9, 5) 92 | private val initialPills = Set(xy(3, 3), xy(16, 3)) 93 | 94 | private val initialFood: Set[Location] = 95 | // format: off 96 | Set(xy(1, 1), xy(2, 1), xy(3, 1), xy(4, 1), xy(5, 1), xy(6, 1), xy(13, 1), xy(14, 1), xy(15, 1), xy(16, 1), xy(17, 1), xy(18, 1)) ++ 97 | Set(xy(1, 2), xy(4, 2), xy(5, 2), xy(6, 2), xy(13, 2), xy(14, 2), xy(15, 2), xy(18, 2)) ++ 98 | Set(xy(1, 3), xy(4, 3), xy(6, 3), xy(7, 3), xy(8, 3), xy(9, 3), xy(10, 3), xy(11, 3), xy(12, 3), xy(13, 3), xy(15, 3), xy(18, 3)) ++ 99 | Set(xy(1, 4), xy(4, 4), xy(6, 4), xy(13, 4), xy(15, 4), xy(18, 4)) ++ 100 | Set(xy(1, 5), xy(2, 5), xy(3, 5), xy(4, 5), xy(5, 5), xy(6, 5), xy(7, 5), xy(8, 5), xy(10, 5), xy(11, 5), xy(12, 5), xy(13, 5), xy(14, 5), xy(15, 5), xy(16, 5), xy(17, 5), xy(18, 5)) 101 | // format: on 102 | 103 | val initialState: GameState = GameState( 104 | ghost1 = initialGhost1, 105 | ghost2 = initialGhost2, 106 | pacman = initialPacman, 107 | food = initialFood, 108 | pills = initialPills, 109 | mode = Mode.Normal 110 | ) 111 | 112 | implicit val environment: Environment[GameState, Move] = 113 | new Environment[GameState, Move] { 114 | 115 | override def possibleActions(currentState: GameState): List[Move] = 116 | allActions.filterNot(move => walls.contains(currentState.pacman.move(move))) 117 | 118 | override def step(currentState: GameState, actionTaken: Move): (GameState, Reward) = { 119 | // Calculate Pacman's new location, based on actionTaken and adjacent walls 120 | val nextPacmanLocation = updatePacmanLocation(currentState.pacman, actionTaken) 121 | 122 | // Calculate ghosts' new locations, based on their current locations and directions 123 | val nextGhost1 = updateGhost(currentState.ghost1, nextPacmanLocation, currentState.mode) 124 | val nextGhost2 = updateGhost(currentState.ghost2, nextPacmanLocation, currentState.mode) 125 | 126 | // Check if Pacman ate some food by moving to his new location 127 | val (ateFood, updatedFoodLocations) = { 128 | if (currentState.food.contains(nextPacmanLocation)) 129 | (true, currentState.food - nextPacmanLocation) 130 | else 131 | (false, currentState.food) 132 | } 133 | 134 | val (atePill, updatedPillLocations) = { 135 | if (currentState.pills.contains(nextPacmanLocation)) 136 | (true, currentState.pills - nextPacmanLocation) 137 | else 138 | (false, currentState.pills) 139 | } 140 | 141 | // If current mode is ChaseGhosts, decrement its timer. If it reaches zero, switch back to Normal. 
142 | val updatedMode = { 143 | if (atePill) 144 | Mode.ChaseGhosts(timeRemaining = 40) 145 | else 146 | currentState.mode match { 147 | case Mode.Normal => Mode.Normal 148 | case Mode.ChaseGhosts(0) => Mode.Normal 149 | case Mode.ChaseGhosts(t) => Mode.ChaseGhosts(t - 1) 150 | } 151 | } 152 | 153 | // Check if Pacman caught any ghosts 154 | val pacmanTouchingGhost1 = nextPacmanLocation == nextGhost1 155 | val updatedGhost1 = 156 | if (pacmanTouchingGhost1 && updatedMode.chasingGhosts) 157 | initialGhost1 158 | else 159 | nextGhost1 160 | 161 | val pacmanTouchingGhost2 = nextPacmanLocation == nextGhost2 162 | val updatedGhost2 = 163 | if (pacmanTouchingGhost2 && updatedMode.chasingGhosts) 164 | initialGhost2 165 | else 166 | nextGhost2 167 | 168 | val pacmanTouchingAGhost = pacmanTouchingGhost1 || pacmanTouchingGhost2 169 | val pacmanCaughtByGhost = pacmanTouchingAGhost && !updatedMode.chasingGhosts 170 | val pacmanCaughtAGhost = pacmanTouchingAGhost && updatedMode.chasingGhosts 171 | 172 | val nextState = GameState( 173 | ghost1 = updatedGhost1, 174 | ghost2 = updatedGhost2, 175 | pacman = nextPacmanLocation, 176 | food = updatedFoodLocations, 177 | pills = updatedPillLocations, 178 | mode = updatedMode 179 | ) 180 | 181 | val reward = { 182 | if (pacmanCaughtByGhost) 183 | -100.0 184 | else if (ateFood) 185 | 1.0 186 | else if (atePill) 187 | 10.0 188 | else if (pacmanCaughtAGhost) 189 | 50.0 190 | else 191 | 0.0 192 | } 193 | 194 | (nextState, reward) 195 | } 196 | 197 | override def isTerminal(state: GameState): Boolean = 198 | state.food.isEmpty || isGameOver(state) 199 | 200 | private def isGameOver(state: GameState): Boolean = { 201 | val pacmanTouchingGhost = state.pacman == state.ghost1 || state.pacman == state.ghost2 202 | pacmanTouchingGhost && !state.mode.chasingGhosts 203 | } 204 | 205 | private def updatePacmanLocation(pacman: Location, move: Move): Location = { 206 | val next = pacman.move(move) 207 | if (walls.contains(next)) 208 | // can't move into a wall, so stay where you are 209 | pacman 210 | else 211 | next 212 | } 213 | 214 | private def updateGhost(ghost: Location, pacman: Location, mode: Mode): Location = { 215 | if (ghost == pacman && !mode.chasingGhosts) { 216 | // if you've caught Pacman, stay where you are! 217 | ghost 218 | } else { 219 | val smartMoveProb = 0.7 220 | 221 | val validPositions = allActions.map(ghost.move).filterNot(walls.contains) 222 | 223 | if (Random.nextDouble() < smartMoveProb) { 224 | // make a "smart" move, i.e. either chase Pacman or run away from him depending on the game mode 225 | val sortedByDistance = validPositions 226 | .map(location => (location, manhattanDist(location, pacman))) 227 | .sortBy { 228 | case (_, distance) => 229 | if (mode.chasingGhosts) 230 | distance * -1 // the further from Pacman the better 231 | else 232 | distance // the closer the better 233 | } 234 | val bestDistance = sortedByDistance.head._2 235 | val bestPositions = sortedByDistance.takeWhile(_._2 == bestDistance) 236 | Random.shuffle(bestPositions).head._1 237 | } else { 238 | Random.shuffle(validPositions).head 239 | } 240 | } 241 | } 242 | 243 | } 244 | 245 | /* 246 | The ghosts use Manhattan distance when chasing Pacman. 247 | You might find it handy for your Pacman agent as well. 248 | */ 249 | private def manhattanDist(from: Location, to: Location): Int = 250 | Math.abs(from.x - to.x) + Math.abs(from.y - to.y) 251 | 252 | /* 253 | TODO: Define a suitable agent state, and the conversion from `GameState` to `AgentState`. 
254 | 255 | The trick is to find a way of encoding enough information about the game state 256 | without the number of states exploding. 257 | e.g. if you were to track the exact locations of Pacman and both ghosts, 258 | you already have 65 x 65 x 65 = 274,675 states to deal with. 259 | 260 | Your state encoding should also make sense when combined with the reward function. 261 | For example, the environment gives a reward when Pacman eats food, so intuitively 262 | the state should track food in some way. 263 | */ 264 | //case class AgentState(...) 265 | type AgentState = GameState 266 | 267 | implicit val stateConversion: StateConversion[GameState, AgentState] = { gameState => 268 | gameState 269 | } 270 | 271 | } 272 | -------------------------------------------------------------------------------- /src/main/scala/rl/pacman/training/PacmanTraining.scala: -------------------------------------------------------------------------------- 1 | package rl.pacman.training 2 | import java.nio.file.{Files, Paths} 3 | import java.time.Instant 4 | 5 | import rl.core._ 6 | import rl.pacman.core.PacmanProblem.{AgentState, GameState, Move, initialState} 7 | 8 | import scala.collection.mutable 9 | import scala.scalajs.niocharset.StandardCharsets 10 | 11 | /** 12 | * This is a training harness for your Pacman agent. 13 | * It makes the agent play a lot of episodes very quickly. 14 | * 15 | * Every million time steps it will print out some stats about the agent's progress. 16 | * 17 | * Every 5 million time steps it will save the agent's Q-values to a JSON file 18 | * in a format suitable for loading by the Pacman UI. 19 | * 20 | * If all goes well, you should see your Pacman agent start to 21 | * win more and more games as its training proceeds. 22 | */ 23 | object PacmanTraining extends App { 24 | 25 | // TODO: feel free to tweak α, γ and ε as you see fit 26 | private val initialAgentData: QLearning[AgentState, Move] = 27 | QLearning(α = 0.9, γ = 1.0, ε = 0.5, Q = Map.empty) 28 | 29 | private val env: Environment[GameState, Move] = implicitly 30 | private val stateConversion: StateConversion[GameState, AgentState] = implicitly 31 | private val agentBehaviour: AgentBehaviour[QLearning[AgentState, Move], AgentState, Move] = 32 | implicitly 33 | 34 | private var t: Long = 0 35 | private var episodeLength = 0 36 | private var longestEpisode = 0 37 | private var wins: Long = 0 38 | private var losses: Long = 0 39 | private val recentResults: mutable.Queue[Boolean] = new mutable.Queue[Boolean]() 40 | private val MaxQueueSize = 10000 41 | 42 | private var agentData = initialAgentData 43 | private var gameState: GameState = initialState 44 | 45 | val trainingDir = Paths.get(s"pacman-training/${Instant.now()}") 46 | Files.createDirectories(trainingDir) 47 | Files.write( 48 | trainingDir.resolve("parameters.txt"), 49 | s"alpha = ${initialAgentData.α}, gamma = ${initialAgentData.γ}, epsilon = ${initialAgentData.ε}" 50 | .getBytes(StandardCharsets.UTF_8) 51 | ) 52 | 53 | private def step(): Unit = { 54 | val currentState = stateConversion.convertState(gameState) 55 | val possibleActions = env.possibleActions(gameState) 56 | val (nextAction, updateAgent) = 57 | agentBehaviour.chooseAction(agentData, currentState, possibleActions) 58 | val (nextState, reward) = env.step(gameState, nextAction) 59 | 60 | agentData = updateAgent(ActionResult(reward, stateConversion.convertState(nextState))) 61 | gameState = nextState 62 | 63 | episodeLength += 1 64 | t += 1 65 | 66 | if (env.isTerminal(gameState)) { 67 | 
longestEpisode = longestEpisode max episodeLength 68 | 69 | val won = gameState.food.isEmpty 70 | if (won) { 71 | wins += 1 72 | } else { 73 | losses += 1 74 | } 75 | recentResults.enqueue(won) 76 | if (recentResults.size > MaxQueueSize) { 77 | recentResults.dequeue() 78 | } 79 | 80 | gameState = initialState 81 | episodeLength = 0 82 | } 83 | } 84 | 85 | private def report(): Unit = { 86 | println(s"t = $t") 87 | println(s"Completed episodes = ${wins + losses}") 88 | println(s"Wins = $wins") 89 | println(s"Losses = $losses") 90 | println(s"Longest episode so far = $longestEpisode") 91 | println(s"Won ${recentResults.count(identity)} of the last 10,000 games") 92 | println(s"State space size = ${agentData.Q.size}") 93 | println() 94 | 95 | if (t % 5000000 == 0) { 96 | saveQValues() 97 | } 98 | } 99 | 100 | private def saveQValues(): Unit = { 101 | print("Saving Q values to file... ") 102 | val list: List[QKeyValue] = agentData.Q.map { case (k, v) => QKeyValue(k, v) }.toList 103 | 104 | import io.circe.syntax._ 105 | val json = list.asJson 106 | 107 | Files.write( 108 | trainingDir.resolve(s"Q-after-$t-steps.json"), 109 | json.noSpaces.getBytes(StandardCharsets.UTF_8) 110 | ) 111 | println("Done.") 112 | } 113 | 114 | while (true) { 115 | step() 116 | 117 | if (t % 1000000 == 0) { 118 | report() 119 | } 120 | } 121 | 122 | } 123 | -------------------------------------------------------------------------------- /src/main/scala/rl/pacman/training/QKeyValue.scala: -------------------------------------------------------------------------------- 1 | package rl.pacman.training 2 | 3 | import io.circe.{Decoder, Encoder, KeyDecoder, KeyEncoder} 4 | import io.circe.generic.auto._ 5 | import io.circe.generic.semiauto._ 6 | import rl.pacman.core.PacmanProblem.{AgentState, Move} 7 | 8 | /* 9 | This is just an artifact of the way we encode the Q-values as JSON. 10 | Q is a Map[AgentState, Map[Move, Double]], so it has non-String keys. 11 | When we write it to the JSON file we turn it into a List[(AgentState, Map[Move, Double])]. 
12 | */ 13 | case class QKeyValue(key: AgentState, value: Map[Move, Double]) 14 | 15 | object QKeyValue { 16 | 17 | implicit val moveEncoder: KeyEncoder[Move] = (move: Move) => move.toString 18 | implicit val moveDecoder: KeyDecoder[Move] = { 19 | case "Left" => Some(Move.Left) 20 | case "Right" => Some(Move.Right) 21 | case "Up" => Some(Move.Up) 22 | case "Down" => Some(Move.Down) 23 | case _ => None 24 | } 25 | 26 | implicit val encoder: Encoder[QKeyValue] = deriveEncoder 27 | implicit val decoder: Decoder[QKeyValue] = deriveDecoder 28 | 29 | } 30 | -------------------------------------------------------------------------------- /src/main/scala/rl/pacman/ui/PacmanUI.scala: -------------------------------------------------------------------------------- 1 | package rl.pacman.ui 2 | 3 | import org.scalajs.dom 4 | import org.scalajs.dom.html 5 | import rl.core._ 6 | import rl.pacman.core.PacmanProblem._ 7 | import rl.pacman.training.QKeyValue 8 | 9 | import scala.concurrent.Future 10 | import scala.concurrent.ExecutionContext.Implicits.global 11 | import scala.scalajs.js.annotation.{JSExport, JSExportTopLevel} 12 | import scala.util.{Failure, Success} 13 | 14 | @JSExportTopLevel("PacmanUI") 15 | object PacmanUI { 16 | 17 | sealed trait UIState 18 | case object Running extends UIState 19 | case class GameOver(flashesLeft: Int) extends UIState 20 | 21 | private val env: Environment[GameState, Move] = implicitly 22 | private val stateConversion: StateConversion[GameState, AgentState] = implicitly 23 | private val agentBehaviour: AgentBehaviour[QLearning[AgentState, Move], AgentState, Move] = 24 | implicitly 25 | 26 | @JSExport 27 | def main(document: dom.Document, canvas: html.Canvas, info: html.Div): Unit = { 28 | loadQ(info) onComplete { 29 | case Failure(e) => 30 | info.innerHTML = "Failed to load Q data." 31 | println(e) 32 | case Success(q) => 33 | info.innerHTML = "Loaded Q data." 34 | 35 | val initialAgentData: QLearning[AgentState, Move] = 36 | QLearning(α = 0.1, γ = 0.9, ε = 0.1, Q = q) 37 | 38 | var agentData = initialAgentData 39 | var gameState: GameState = initialState 40 | var lastAction: Move = Move.Left 41 | var uiState: UIState = Running 42 | 43 | def step(): Unit = uiState match { 44 | case Running => 45 | val currentState = stateConversion.convertState(gameState) 46 | val possibleActions = env.possibleActions(gameState) 47 | val (chosenAction, updateAgent) = 48 | agentBehaviour.chooseAction(agentData, currentState, possibleActions) 49 | val (nextState, reward) = env.step(gameState, chosenAction) 50 | 51 | agentData = updateAgent(ActionResult(reward, stateConversion.convertState(nextState))) 52 | gameState = nextState 53 | lastAction = chosenAction 54 | 55 | drawGame(canvas, gameState, lastAction) 56 | 57 | if (env.isTerminal(gameState)) { 58 | uiState = GameOver(flashesLeft = 5) 59 | } 60 | case GameOver(0) => 61 | gameState = initialState 62 | uiState = Running 63 | case GameOver(n) => 64 | if (n % 2 == 0) 65 | drawGame(canvas, gameState, lastAction) 66 | else 67 | drawBlankCanvas(canvas) 68 | uiState = GameOver(n - 1) 69 | } 70 | 71 | dom.window.setInterval(() => step(), 500) 72 | } 73 | 74 | } 75 | 76 | private def loadQ(info: html.Div): Future[Map[AgentState, Map[Move, Double]]] = { 77 | info.innerHTML = "Downloading Q-values JSON file..." 78 | 79 | dom.ext.Ajax.get("data/pacman/Q.json").flatMap { r => 80 | info.innerHTML = "Parsing Q-values JSON file..." 
81 | 82 | import io.circe.parser._ 83 | val either = for { 84 | json <- parse(r.responseText) 85 | decoded <- json.as[List[QKeyValue]] 86 | } yield { 87 | decoded.map(kv => (kv.key, kv.value)).toMap 88 | } 89 | 90 | Future.fromTry(either.toTry) 91 | } 92 | } 93 | 94 | private def drawGame(canvas: html.Canvas, state: GameState, actionTaken: Move): Unit = { 95 | val ctx = canvas.getContext("2d").asInstanceOf[dom.CanvasRenderingContext2D] 96 | val pixelSize = 50 97 | val pixelCentre = pixelSize / 2 98 | 99 | def drawEye(ghost: Location, xOffset: Int): Unit = { 100 | ctx.save() 101 | 102 | ctx.translate(ghost.x * pixelSize + pixelCentre + xOffset, 103 | ghost.y * pixelSize + pixelCentre - 10) 104 | ctx.scale(1.0, 1.5) 105 | 106 | ctx.beginPath() 107 | ctx.arc(0.0, 0.0, 4, 0.0, Math.PI * 2.0) 108 | ctx.fillStyle = "white" 109 | ctx.fill() 110 | ctx.closePath() 111 | 112 | ctx.beginPath() 113 | ctx.arc(0.0, 0.0, 2, 0.0, Math.PI * 2.0) 114 | ctx.fillStyle = "black" 115 | ctx.fill() 116 | ctx.closePath() 117 | 118 | ctx.restore() 119 | } 120 | 121 | def drawGhost(ghost: Location, colour: String): Unit = { 122 | // body 123 | ctx.beginPath() 124 | ctx.fillStyle = colour 125 | ctx.arc(ghost.x * pixelSize + pixelCentre, 126 | ghost.y * pixelSize + pixelCentre, 127 | 20, 128 | 0.0, 129 | Math.PI * 2.0) 130 | ctx.fill() 131 | ctx.closePath() 132 | 133 | // left eye 134 | drawEye(ghost, -10) 135 | 136 | // right eye 137 | drawEye(ghost, 10) 138 | } 139 | 140 | ctx.fillStyle = "black" 141 | ctx.fillRect(0, 0, 1000, 350) 142 | 143 | // draw walls 144 | for (wall <- walls) { 145 | ctx.fillStyle = "blue" 146 | ctx.fillRect(wall.x * pixelSize + 10, wall.y * pixelSize + 10, pixelSize - 20, pixelSize - 20) 147 | } 148 | 149 | // draw food 150 | for (food <- state.food) { 151 | ctx.beginPath() 152 | ctx.fillStyle = "yellow" 153 | ctx.arc(food.x * pixelSize + pixelCentre, 154 | food.y * pixelSize + pixelCentre, 155 | 5, 156 | 0.0, 157 | Math.PI * 2.0) 158 | ctx.fill() 159 | ctx.closePath() 160 | } 161 | 162 | // draw pills 163 | for (pill <- state.pills) { 164 | ctx.beginPath() 165 | ctx.fillStyle = "yellow" 166 | ctx.arc(pill.x * pixelSize + pixelCentre, 167 | pill.y * pixelSize + pixelCentre, 168 | 15, 169 | 0.0, 170 | Math.PI * 2.0) 171 | ctx.fill() 172 | ctx.closePath() 173 | } 174 | 175 | // draw ghosts 176 | state.mode match { 177 | case Mode.Normal => 178 | drawGhost(state.ghost1, "green") 179 | drawGhost(state.ghost2, "red") 180 | case Mode.ChaseGhosts(_) => 181 | drawGhost(state.ghost1, "cyan") 182 | drawGhost(state.ghost2, "cyan") 183 | } 184 | 185 | // draw pacman 186 | val (startAngle, endAngle) = actionTaken match { 187 | case Move.Left => (Math.PI * -0.8, Math.PI * 0.8) 188 | case Move.Up => (Math.PI * -0.3, Math.PI * 1.3) 189 | case Move.Right => (Math.PI * 0.2, Math.PI * 1.8) 190 | case Move.Down => (Math.PI * 0.7, Math.PI * 2.3) 191 | } 192 | ctx.beginPath() 193 | ctx.fillStyle = "yellow" 194 | ctx.arc(state.pacman.x * pixelSize + pixelCentre, 195 | state.pacman.y * pixelSize + pixelCentre, 196 | 20.0, 197 | startAngle, 198 | endAngle) 199 | ctx.lineTo(state.pacman.x * pixelSize + pixelCentre, state.pacman.y * pixelSize + pixelCentre) 200 | ctx.closePath() 201 | ctx.fill() 202 | } 203 | 204 | private def drawBlankCanvas(canvas: html.Canvas): Unit = { 205 | val ctx = canvas.getContext("2d").asInstanceOf[dom.CanvasRenderingContext2D] 206 | ctx.fillStyle = "black" 207 | ctx.fillRect(0, 0, 1000, 350) 208 | } 209 | 210 | } 211 | 
-------------------------------------------------------------------------------- /src/main/scala/rl/polecart/core/PoleBalancingProblem.scala: -------------------------------------------------------------------------------- 1 | package rl.polecart.core 2 | 3 | import rl.core.{Environment, Reward, StateConversion} 4 | import java.lang.Math._ 5 | 6 | object PoleBalancingProblem { 7 | 8 | case class PoleCartState( 9 | cartPosition: Double, // metres from start position (middle of cart, -ve is left, +ve is right) 10 | cartVelocity: Double, // m/s 11 | poleAngle: Double, // radians, angle from vertical 12 | poleVelocity: Double // radians/second, angular velocity 13 | ) { 14 | override def toString: String = 15 | s"""Pole-cart: 16 | |x = $cartPosition 17 | |x' = $cartVelocity 18 | |θ = ${toDegrees(poleAngle)} 19 | |θ' = ${toDegrees(poleVelocity)} 20 | """.stripMargin 21 | } 22 | 23 | sealed trait PushCart 24 | object PushCart { 25 | case object Left extends PushCart 26 | case object Right extends PushCart 27 | } 28 | 29 | val allActions: List[PushCart] = List(PushCart.Left, PushCart.Right) 30 | 31 | implicit val environment: Environment[PoleCartState, PushCart] = 32 | new Environment[PoleCartState, PushCart] { 33 | 34 | override def possibleActions(currentState: PoleCartState): List[PushCart] = 35 | PoleBalancingProblem.allActions 36 | 37 | override def step(currentState: PoleCartState, 38 | actionTaken: PushCart): (PoleCartState, Reward) = { 39 | /* 40 | First we use non-linear differential equations to calculate the double derivatives 41 | x'' and θ'' of the cart position (x) and pole angle (θ) at time t, 42 | given x, θ, x' and θ' at time t. 43 | 44 | See the appendix of the paper "Neuronlike Adaptive Elements That Can Solve Difficult Learning Problems" 45 | (Barto, Sutton and Anderson, 1983) for the details of the differential equations. 46 | 47 | Once we have x'' and θ'' at time t, we use Euler's method (with a time step of 0.02 seconds) 48 | to estimate x' and θ' at time t+1: 49 | 50 | x'(t+1) = x'(t) + 0.02 * x''(t) 51 | θ'(t+1) = θ'(t) + 0.02 * θ''(t) 52 | 53 | We also use Euler's method to estimate x and θ at time t+1 given x, θ, x' and θ' at time t: 54 | 55 | x(t+1) = x(t) + 0.02 * x'(t) 56 | θ(t+1) = θ(t) + 0.02 * θ'(t) 57 | 58 | This gives us the new state. 59 | The reward is simple: 0 if non-terminal, -1 if terminal. 
60 | */ 61 | 62 | // The values for these constants are also taken from the Barton, Sutton and Anderson paper 63 | val g = -9.8 // m/s^2, acceleration due to gravity 64 | val m_c = 1.0 // kg, mass of cart 65 | val m = 0.1 // kg, mass of pole 66 | val l = 0.5 // m, half-pole length 67 | val μ_c = 0.0005 // coefficient of friction of cart on track 68 | val μ_p = 0.000002 // coefficient of friction of pole on cart 69 | val F = actionTaken match { 70 | case PushCart.Left => 71 | -10.0 // Newtons, force applied to cart's centre of mass 72 | case PushCart.Right => 73 | 10.0 74 | } 75 | 76 | val x_t = currentState.cartPosition 77 | val `x'_t` = currentState.cartVelocity 78 | val θ_t = currentState.poleAngle 79 | val `θ'_t` = currentState.poleVelocity 80 | 81 | val h = 0.02 // seconds, time step 82 | 83 | val `θ''_t` = 84 | (g * sin(θ_t) + cos(θ_t) * (-F - m * l * `θ'_t` * `θ'_t` * sin(θ_t) + μ_c * signum( 85 | `x'_t`)) - ((μ_p * `θ'_t`) / (m * l))) / 86 | (l * (4.0 / 3.0 - (m * cos(θ_t) * cos(θ_t)) / (m_c + m))) 87 | 88 | val `x''_t` = 89 | (F + m * l * (`θ'_t` * `θ'_t` * sin(θ_t) - `θ''_t` * cos(θ_t)) - μ_c * signum(`x'_t`)) / 90 | (m_c + m) 91 | 92 | val `x'_t+1` = `x'_t` + h * `x''_t` 93 | val `θ'_t+1` = `θ'_t` + h * `θ''_t` 94 | 95 | val `x_t+1` = `x_t` + h * `x'_t` 96 | val `θ_t+1` = `θ_t` + h * `θ'_t` 97 | 98 | val nextState = PoleCartState( 99 | cartPosition = `x_t+1`, 100 | cartVelocity = `x'_t+1`, 101 | poleAngle = `θ_t+1`, 102 | poleVelocity = `θ'_t+1` 103 | ) 104 | val reward = if (isTerminal(nextState)) -1.0 else 0.0 105 | 106 | (nextState, reward) 107 | } 108 | 109 | // Episode ends in failure if pole topples too far or cart hits either of the walls 110 | override def isTerminal(state: PoleCartState): Boolean = { 111 | val absPosition = Math.abs(state.cartPosition) 112 | val absAngleDegrees = toDegrees(Math.abs(state.poleAngle)) 113 | absPosition > 2.4 || absAngleDegrees > 12 114 | } 115 | 116 | } 117 | 118 | sealed trait RoughCartPosition 119 | object RoughCartPosition { 120 | case object Left extends RoughCartPosition // x ≦ -0.8 (metres) 121 | case object Middle extends RoughCartPosition // -0.8 < x ≦ 0.8 122 | case object Right extends RoughCartPosition // x > 0.8 123 | } 124 | 125 | sealed trait RoughCartVelocity 126 | object RoughCartVelocity { 127 | case object FastLeft extends RoughCartVelocity // -0.5 ≦ x (m/s) 128 | case object Slow extends RoughCartVelocity // -0.5 < x ≦ 0.5 129 | case object FastRight extends RoughCartVelocity // x > 0.5 130 | } 131 | 132 | sealed trait RoughPoleAngle 133 | object RoughPoleAngle { 134 | case object VeryLeft extends RoughPoleAngle // x ≦ -6 (degrees) 135 | case object QuiteLeft extends RoughPoleAngle // -6 < x ≦ -1 136 | case object SlightlyLeft extends RoughPoleAngle // -1 < x ≦ 0 137 | case object SlightlyRight extends RoughPoleAngle // 0 < x ≦ 1 138 | case object QuiteRight extends RoughPoleAngle // 1 < x ≦ 6 139 | case object VeryRight extends RoughPoleAngle // x > 6 140 | } 141 | 142 | sealed trait RoughPoleVelocity 143 | object RoughPoleVelocity { 144 | case object FastLeft extends RoughPoleVelocity // -50 ≦ x (degrees/second) 145 | case object Slow extends RoughPoleVelocity // -50 < x ≦ 50 146 | case object FastRight extends RoughPoleVelocity // x > 50 147 | } 148 | 149 | case class RoughPoleCartState( 150 | cartPosition: RoughCartPosition, 151 | cartVelocity: RoughCartVelocity, 152 | poleAngle: RoughPoleAngle, 153 | poleVelocity: RoughPoleVelocity 154 | ) 155 | 156 | implicit val stateConversion: StateConversion[PoleCartState, 
RoughPoleCartState] = { 157 | envState: PoleCartState => 158 | val roughCartPosition = envState.cartPosition match { 159 | case x if x <= -0.8 => RoughCartPosition.Left 160 | case x if x > -0.8 && x <= 0.8 => RoughCartPosition.Middle 161 | case _ => RoughCartPosition.Right 162 | } 163 | 164 | val roughCartVelocity = envState.cartVelocity match { 165 | case x if x <= -0.5 => RoughCartVelocity.FastLeft 166 | case x if x > -0.5 && x <= 0.5 => RoughCartVelocity.Slow 167 | case _ => RoughCartVelocity.FastRight 168 | } 169 | 170 | val roughPoleAngle = toDegrees(envState.poleAngle) match { 171 | case x if x <= -6.0 => RoughPoleAngle.VeryLeft 172 | case x if x > -6.0 && x <= -1.0 => RoughPoleAngle.QuiteLeft 173 | case x if x > -1.0 && x <= 0.0 => RoughPoleAngle.SlightlyLeft 174 | case x if x > 0.0 && x <= 1.0 => RoughPoleAngle.SlightlyRight 175 | case x if x > 1.0 && x <= 6.0 => RoughPoleAngle.QuiteRight 176 | case _ => RoughPoleAngle.VeryRight 177 | } 178 | 179 | val roughPoleVelocity = toDegrees(envState.poleVelocity) match { 180 | case x if x <= -50.0 => RoughPoleVelocity.FastLeft 181 | case x if x > -50.0 && x <= 50.0 => RoughPoleVelocity.Slow 182 | case _ => RoughPoleVelocity.FastRight 183 | } 184 | 185 | RoughPoleCartState( 186 | roughCartPosition, 187 | roughCartVelocity, 188 | roughPoleAngle, 189 | roughPoleVelocity 190 | ) 191 | } 192 | 193 | } 194 | -------------------------------------------------------------------------------- /src/main/scala/rl/polecart/ui/HumanUI.scala: -------------------------------------------------------------------------------- 1 | package rl.polecart.ui 2 | 3 | import org.scalajs.dom 4 | import org.scalajs.dom.html.Canvas 5 | import rl.polecart.core.PoleBalancingProblem 6 | 7 | import scala.scalajs.js.annotation.{JSExport, JSExportTopLevel} 8 | 9 | @JSExportTopLevel("PolecartHumanUI") 10 | object HumanUI { 11 | 12 | sealed trait UIState 13 | case object Idle extends UIState 14 | case object Running extends UIState 15 | 16 | private val initialPoleCartState: PoleBalancingProblem.PoleCartState = 17 | PoleBalancingProblem.PoleCartState(0.0, 0.0, 0.0, 0.0) 18 | 19 | @JSExport 20 | def main(window: dom.Window, canvas: Canvas, infoLabel: dom.Element): Unit = { 21 | var uiState: UIState = Idle 22 | 23 | var poleCartState: PoleBalancingProblem.PoleCartState = initialPoleCartState 24 | var currentAction: PoleBalancingProblem.PushCart = 25 | PoleBalancingProblem.PushCart.Left 26 | var timeElapsed = 0.0 27 | var maxTimeElapsed = 0.0 28 | 29 | def tick(): Unit = { 30 | clear(canvas) 31 | 32 | uiState match { 33 | case Idle => 34 | drawCart(canvas, poleCartState, timeElapsed) 35 | case Running => 36 | timeElapsed += 0.02 37 | poleCartState = PoleBalancingProblem.environment 38 | .step(poleCartState, currentAction) 39 | ._1 40 | drawCart(canvas, poleCartState, timeElapsed) 41 | if (PoleBalancingProblem.environment.isTerminal(poleCartState)) { 42 | failed() 43 | } 44 | } 45 | } 46 | 47 | def running(): Unit = { 48 | infoLabel.textContent = "" 49 | poleCartState = initialPoleCartState 50 | uiState = Running 51 | } 52 | 53 | def failed(): Unit = { 54 | maxTimeElapsed = maxTimeElapsed max timeElapsed 55 | infoLabel.textContent = 56 | f"FAILED! You lasted $timeElapsed%.2f seconds. Your record is $maxTimeElapsed%.2f seconds. 
Press ← or → to try again" 57 | timeElapsed = 0.0 58 | uiState = Idle 59 | } 60 | 61 | window.onkeydown = { event => 62 | event.key match { 63 | case "ArrowLeft" => 64 | currentAction = PoleBalancingProblem.PushCart.Left 65 | running() 66 | case "ArrowRight" => 67 | currentAction = PoleBalancingProblem.PushCart.Right 68 | running() 69 | case other => 70 | // ignore 71 | } 72 | } 73 | 74 | infoLabel.textContent = "Press ← or → to start" 75 | 76 | dom.window.setInterval(() => tick(), 20) 77 | } 78 | 79 | private def clear(canvas: Canvas): Unit = { 80 | val ctx = canvas.getContext("2d").asInstanceOf[dom.CanvasRenderingContext2D] 81 | 82 | // clear the canvas 83 | ctx.fillStyle = "white" 84 | ctx.clearRect(0, 0, canvas.width, canvas.height) 85 | 86 | // border 87 | ctx.lineWidth = 3 88 | ctx.strokeStyle = "black" 89 | ctx.fillStyle = "black" 90 | ctx.strokeRect(0, 0, canvas.width, canvas.height) 91 | 92 | // walls 93 | val wallWidth = 50 94 | val wallHeight = 50 95 | val wallTop = canvas.height - wallHeight 96 | ctx.fillRect(0, wallTop, wallWidth, wallHeight) 97 | ctx.fillRect(canvas.width - wallWidth, wallTop, wallWidth, wallHeight) 98 | } 99 | 100 | private def drawCart(canvas: Canvas, 101 | state: PoleBalancingProblem.PoleCartState, 102 | timeElapsed: Double): Unit = { 103 | val ctx = canvas.getContext("2d").asInstanceOf[dom.CanvasRenderingContext2D] 104 | 105 | val cartTopY = canvas.height - 50 106 | val cartWidth = 70 107 | val cartHeight = 30 108 | val cartMiddleX = 325 + (state.cartPosition * 100) 109 | val cartLeftX = cartMiddleX - cartWidth / 2 110 | 111 | val wheelY = canvas.height - 10 112 | val leftWheelX = cartMiddleX - 20 113 | val rightWheelX = cartMiddleX + 20 114 | 115 | // cart 116 | ctx.fillStyle = "blue" 117 | ctx.fillRect(cartLeftX, cartTopY, cartWidth, cartHeight) 118 | 119 | // left wheel 120 | ctx.beginPath() 121 | ctx.fillStyle = "blue" 122 | ctx.arc(leftWheelX, wheelY, 10.0, 0.0, 2 * Math.PI) 123 | ctx.fill() 124 | ctx.closePath() 125 | 126 | // right wheel 127 | ctx.beginPath() 128 | ctx.fillStyle = "blue" 129 | ctx.arc(rightWheelX, wheelY, 10.0, 0.0, 2 * Math.PI) 130 | ctx.fill() 131 | ctx.closePath() 132 | 133 | // pole 134 | val poleX = cartMiddleX 135 | val poleBottomY = cartTopY - 5 136 | val poleTopY = poleBottomY - 50 137 | 138 | ctx.beginPath() 139 | 140 | ctx.translate(poleX, poleBottomY) 141 | ctx.rotate(state.poleAngle) 142 | ctx.translate(-poleX, -poleBottomY) 143 | 144 | ctx.strokeStyle = "green" 145 | ctx.moveTo(poleX, poleBottomY) 146 | ctx.lineTo(poleX, poleTopY) 147 | ctx.lineWidth = 6 148 | ctx.stroke() 149 | 150 | ctx.closePath() 151 | 152 | // reset transform 153 | ctx.setTransform(1, 0, 0, 1, 0, 0) 154 | 155 | // time counter in top-left corner 156 | ctx.beginPath() 157 | ctx.strokeStyle = "black" 158 | ctx.lineWidth = 1 159 | ctx.strokeText(f"t = $timeElapsed%.2f", 10, 20) 160 | ctx.closePath() 161 | } 162 | 163 | } 164 | -------------------------------------------------------------------------------- /src/main/scala/rl/polecart/ui/QLearningUI.scala: -------------------------------------------------------------------------------- 1 | package rl.polecart.ui 2 | 3 | import org.scalajs.dom 4 | import org.scalajs.dom.html.{Canvas, Button} 5 | import rl.core._ 6 | import rl.polecart.core.PoleBalancingProblem 7 | import rl.polecart.core.PoleBalancingProblem._ 8 | 9 | import scala.scalajs.js.annotation.{JSExport, JSExportTopLevel} 10 | 11 | @JSExportTopLevel("PolecartQLearningUI") 12 | object QLearningUI { 13 | 14 | sealed trait UIState 15 | case 
object Idle extends UIState 16 | case object Stepping extends UIState 17 | case object RunningEpisode extends UIState 18 | case object RunningForever extends UIState 19 | 20 | private val initialPoleCartState: PoleCartState = 21 | PoleBalancingProblem.PoleCartState(0.0, 0.0, 0.0, 0.0) 22 | 23 | private val initialAgentData: QLearning[RoughPoleCartState, PushCart] = 24 | QLearning(α = 0.1, γ = 1.0, ε = 0.1, Q = Map.empty) 25 | 26 | private val env: Environment[PoleCartState, PushCart] = implicitly 27 | private val stateConversion: StateConversion[PoleCartState, RoughPoleCartState] = implicitly 28 | private val agentBehaviour 29 | : AgentBehaviour[QLearning[RoughPoleCartState, PushCart], RoughPoleCartState, PushCart] = 30 | implicitly 31 | 32 | @JSExport 33 | def main(document: dom.Document, 34 | canvas: Canvas, 35 | infoLabel: dom.Element, 36 | stepButton: Button, 37 | runEpisodeButton: Button, 38 | runForeverButton: Button, 39 | pauseButton: Button): Unit = { 40 | var uiState: UIState = Idle 41 | 42 | var agentData = initialAgentData 43 | var poleCartState = initialPoleCartState 44 | var timeElapsed = 0.0 45 | var maxTimeElapsed = 0.0 46 | var episodeCount = 1 47 | 48 | def step(): Unit = { 49 | timeElapsed += 0.02 50 | 51 | val currentState = stateConversion.convertState(poleCartState) 52 | val (nextAction, updateAgent) = 53 | agentBehaviour.chooseAction(agentData, currentState, allActions) 54 | val (nextState, reward) = env.step(poleCartState, nextAction) 55 | 56 | agentData = updateAgent(ActionResult(reward, stateConversion.convertState(nextState))) 57 | poleCartState = nextState 58 | 59 | drawCart(canvas, poleCartState, episodeCount, timeElapsed) 60 | updateTable(document, agentData.Q) 61 | } 62 | 63 | def endOfEpisode(): Unit = { 64 | maxTimeElapsed = maxTimeElapsed max timeElapsed 65 | timeElapsed = 0.0 66 | episodeCount += 1 67 | poleCartState = initialPoleCartState 68 | infoLabel.textContent = f"Longest episode so far: $maxTimeElapsed%.2f seconds" 69 | } 70 | 71 | def tick(): Unit = { 72 | clear(canvas) 73 | 74 | uiState match { 75 | case Idle => 76 | drawCart(canvas, poleCartState, episodeCount, timeElapsed) 77 | 78 | case Stepping => 79 | step() 80 | if (env.isTerminal(poleCartState)) { 81 | endOfEpisode() 82 | } 83 | uiState = Idle 84 | 85 | case RunningEpisode => 86 | step() 87 | if (env.isTerminal(poleCartState)) { 88 | endOfEpisode() 89 | uiState = Idle 90 | } 91 | 92 | case RunningForever => 93 | step() 94 | if (env.isTerminal(poleCartState)) { 95 | endOfEpisode() 96 | } 97 | } 98 | } 99 | 100 | stepButton.onclick = _ => uiState = Stepping 101 | runEpisodeButton.onclick = _ => uiState = RunningEpisode 102 | runForeverButton.onclick = _ => uiState = RunningForever 103 | pauseButton.onclick = _ => uiState = Idle 104 | 105 | dom.window.setInterval(() => tick(), 20) 106 | } 107 | 108 | private def clear(canvas: Canvas): Unit = { 109 | val ctx = canvas.getContext("2d").asInstanceOf[dom.CanvasRenderingContext2D] 110 | 111 | // clear the canvas 112 | ctx.fillStyle = "white" 113 | ctx.clearRect(0, 0, canvas.width, canvas.height) 114 | 115 | // border 116 | ctx.lineWidth = 3 117 | ctx.strokeStyle = "black" 118 | ctx.fillStyle = "black" 119 | ctx.strokeRect(0, 0, canvas.width, canvas.height) 120 | 121 | // walls 122 | val wallWidth = 50 123 | val wallHeight = 50 124 | val wallTop = canvas.height - wallHeight 125 | ctx.fillRect(0, wallTop, wallWidth, wallHeight) 126 | ctx.fillRect(canvas.width - wallWidth, wallTop, wallWidth, wallHeight) 127 | } 128 | 129 | private def drawCart(canvas: 
Canvas, 130 | state: PoleBalancingProblem.PoleCartState, 131 | episodeCount: Int, 132 | timeElapsed: Double): Unit = { 133 | val ctx = canvas.getContext("2d").asInstanceOf[dom.CanvasRenderingContext2D] 134 | 135 | val cartTopY = canvas.height - 50 136 | val cartWidth = 70 137 | val cartHeight = 30 138 | val cartMiddleX = 325 + (state.cartPosition * 100) 139 | val cartLeftX = cartMiddleX - cartWidth / 2 140 | 141 | val wheelY = canvas.height - 10 142 | val leftWheelX = cartMiddleX - 20 143 | val rightWheelX = cartMiddleX + 20 144 | 145 | // cart 146 | ctx.fillStyle = "blue" 147 | ctx.fillRect(cartLeftX, cartTopY, cartWidth, cartHeight) 148 | 149 | // left wheel 150 | ctx.beginPath() 151 | ctx.fillStyle = "blue" 152 | ctx.arc(leftWheelX, wheelY, 10.0, 0.0, 2 * Math.PI) 153 | ctx.fill() 154 | ctx.closePath() 155 | 156 | // right wheel 157 | ctx.beginPath() 158 | ctx.fillStyle = "blue" 159 | ctx.arc(rightWheelX, wheelY, 10.0, 0.0, 2 * Math.PI) 160 | ctx.fill() 161 | ctx.closePath() 162 | 163 | // pole 164 | val poleX = cartMiddleX 165 | val poleBottomY = cartTopY - 5 166 | val poleTopY = poleBottomY - 50 167 | 168 | ctx.beginPath() 169 | 170 | ctx.translate(poleX, poleBottomY) 171 | ctx.rotate(state.poleAngle) 172 | ctx.translate(-poleX, -poleBottomY) 173 | 174 | ctx.strokeStyle = "green" 175 | ctx.moveTo(poleX, poleBottomY) 176 | ctx.lineTo(poleX, poleTopY) 177 | ctx.lineWidth = 6 178 | ctx.stroke() 179 | 180 | ctx.closePath() 181 | 182 | // reset transform 183 | ctx.setTransform(1, 0, 0, 1, 0, 0) 184 | 185 | // Time counter in top-left corner 186 | ctx.beginPath() 187 | ctx.strokeStyle = "black" 188 | ctx.lineWidth = 1 189 | ctx.strokeText(s"Episode $episodeCount", 10, 20) 190 | ctx.strokeText(f"t = $timeElapsed%.2f", 10, 40) 191 | ctx.closePath() 192 | } 193 | 194 | private def updateTable( 195 | document: dom.Document, 196 | Q: Map[PoleBalancingProblem.RoughPoleCartState, Map[PoleBalancingProblem.PushCart, Double]]) 197 | : Unit = { 198 | def lower(x: Any): String = x.toString.toLowerCase.replaceAllLiterally(" ", "") 199 | 200 | for { 201 | cartPos <- List(RoughCartPosition.Left, RoughCartPosition.Middle, RoughCartPosition.Right) 202 | cartVel <- List(RoughCartVelocity.FastLeft, 203 | RoughCartVelocity.Slow, 204 | RoughCartVelocity.FastRight) 205 | poleVel <- List(RoughPoleVelocity.FastLeft, 206 | RoughPoleVelocity.Slow, 207 | RoughPoleVelocity.FastRight) 208 | poleAngle <- List( 209 | RoughPoleAngle.VeryLeft, 210 | RoughPoleAngle.QuiteLeft, 211 | RoughPoleAngle.SlightlyLeft, 212 | RoughPoleAngle.SlightlyRight, 213 | RoughPoleAngle.QuiteRight, 214 | RoughPoleAngle.VeryRight 215 | ) 216 | } { 217 | val actionValues = 218 | Q.getOrElse(RoughPoleCartState(cartPos, cartVel, poleAngle, poleVel), Map.empty) 219 | val text = (actionValues.get(PushCart.Left), actionValues.get(PushCart.Right)) match { 220 | case (Some(l), Some(r)) => f"L: $l%.4f
R: $r%.4f" 221 | case _ => "" 222 | } 223 | 224 | val id = s"${lower(cartPos)}_${lower(cartVel)}_${lower(poleVel)}_${lower(poleAngle)}" 225 | document.getElementById(id).innerHTML = text 226 | } 227 | } 228 | 229 | } 230 | --------------------------------------------------------------------------------