├── .gitignore ├── README.md ├── pom.xml ├── resources └── maze.txt └── src ├── main └── java │ └── com │ └── technobium │ └── rl │ └── QLearning.java └── test └── java └── com └── technobium └── rl └── AppTest.java /.gitignore: -------------------------------------------------------------------------------- 1 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 2 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 3 | 4 | # User-specific stuff: 5 | .idea/workspace.xml 6 | .idea/tasks.xml 7 | 8 | # Sensitive or high-churn files: 9 | .idea/dataSources.ids 10 | .idea/dataSources.xml 11 | .idea/dataSources.local.xml 12 | .idea/sqlDataSources.xml 13 | .idea/dynamic.xml 14 | .idea/uiDesigner.xml 15 | 16 | # Gradle: 17 | .idea/gradle.xml 18 | .idea/libraries 19 | 20 | # Mongo Explorer plugin: 21 | .idea/mongoSettings.xml 22 | 23 | ## File-based project format: 24 | *.iws 25 | 26 | ## Plugin-specific files: 27 | 28 | # IntelliJ 29 | /out/ 30 | 31 | # mpeltonen/sbt-idea plugin 32 | .idea_modules/ 33 | 34 | # JIRA plugin 35 | atlassian-ide-plugin.xml 36 | 37 | # Crashlytics plugin (for Android Studio and IntelliJ) 38 | com_crashlytics_export_strings.xml 39 | crashlytics.properties 40 | crashlytics-build.properties 41 | fabric.properties 42 | /q-learning-java.iws 43 | /q-learning-java.ipr 44 | /q-learning-java.iml 45 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # q-learning-java 2 | Reinforcement learning - Q learning in Java. 
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * Tabular Q-learning on a fixed 3x3 grid maze.
 *
 * <p>Each maze cell is a state, numbered row by row ({@code state = row * mazeWidth + col}).
 * An "action" is identified by the neighbouring state it moves to. Cells are encoded as:
 * <ul>
 *   <li>{@code '0'} - free cell, entering it yields a reward of 0;</li>
 *   <li>{@code 'F'} - final cell, entering it yields {@code reward}; no moves leave it;</li>
 *   <li>{@code 'X'} - penalty cell, entering it yields {@code penalty}; it can still be left.</li>
 * </ul>
 *
 * <p>Typical use: {@link #init()} (or {@link #init(char[][])}), then {@code calculateQ()},
 * then {@code printQ()} / {@code printPolicy()}.
 */
public class QLearning {

    /** Location of the maze description read by {@link #init()}. */
    private static final String MAZE_FILE = "resources/maze.txt";

    private static final char FREE_CELL = '0';
    private static final char FINAL_CELL = 'F';
    private static final char PENALTY_CELL = 'X';

    /** Marker stored in {@code R} for "no move exists between these two states". */
    private static final int NO_TRANSITION = -1;

    /** Number of training episodes run by {@code calculateQ()}. */
    private static final int TRAIN_CYCLES = 1000;

    /** Row/column offsets of the four moves: left, right, up, down. */
    private static final int[][] MOVES = {{0, -1}, {0, 1}, {-1, 0}, {1, 0}};

    private final double alpha = 0.1; // Learning rate
    private final double gamma = 0.9; // Discount: 0 favours immediate reward, 1 favours distant reward

    private final int mazeWidth = 3;
    private final int mazeHeight = 3;
    private final int statesCount = mazeHeight * mazeWidth;

    private final int reward = 100;
    private final int penalty = -10;

    private char[][] maze; // Maze cells, indexed [row][col]
    private int[][] R;     // Reward lookup, indexed [fromState][toState]
    private double[][] Q;  // Learned action values, indexed [fromState][toState]

    /**
     * Loads the maze from {@code resources/maze.txt}, trains, and prints the Q matrix and policy.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        QLearning ql = new QLearning();

        ql.init();
        ql.calculateQ();
        ql.printQ();
        ql.printPolicy();
    }

    /**
     * Initialises the learner from {@code resources/maze.txt} (relative to the working directory).
     *
     * @throws UncheckedIOException     if the file cannot be read
     * @throws IllegalStateException    if the file does not contain exactly
     *                                  {@code mazeHeight * mazeWidth} cells
     * @throws IllegalArgumentException if the maze has no final cell
     */
    public void init() {
        init(readMaze(new File(MAZE_FILE)));
    }

    /**
     * Initialises the learner from an in-memory maze: builds the reward matrix, seeds Q with the
     * rewards and prints the reward matrix.
     *
     * @param grid maze cells indexed {@code [row][col]}; must be {@code mazeHeight} x
     *             {@code mazeWidth}, contain only {@code '0'}, {@code 'F'} and {@code 'X'}, and
     *             contain at least one {@code 'F'} (otherwise a training episode could never end)
     * @throws IllegalArgumentException if {@code grid} violates any of the above
     */
    public void init(char[][] grid) {
        if (grid == null || grid.length != mazeHeight) {
            throw new IllegalArgumentException("Maze must have " + mazeHeight + " rows");
        }

        char[][] copy = new char[mazeHeight][mazeWidth];
        boolean hasFinalCell = false;
        for (int row = 0; row < mazeHeight; row++) {
            if (grid[row] == null || grid[row].length != mazeWidth) {
                throw new IllegalArgumentException(
                        "Maze row " + row + " must have " + mazeWidth + " cells");
            }
            for (int col = 0; col < mazeWidth; col++) {
                char cell = grid[row][col];
                if (!isMazeCell(cell)) {
                    throw new IllegalArgumentException(
                            "Unknown maze cell '" + cell + "' at row " + row + ", column " + col);
                }
                hasFinalCell |= cell == FINAL_CELL;
                copy[row][col] = cell;
            }
        }
        if (!hasFinalCell) {
            throw new IllegalArgumentException("Maze must contain at least one final cell 'F'");
        }

        maze = copy;
        R = new int[statesCount][statesCount];
        Q = new double[statesCount][statesCount];

        buildRewards();
        initializeQ();
        printR(R);
    }

    /** Returns whether {@code c} is one of the three recognised maze cell characters. */
    private static boolean isMazeCell(char c) {
        return c == FREE_CELL || c == FINAL_CELL || c == PENALTY_CELL;
    }

    /**
     * Reads the maze cells from {@code file}. Every byte that is not a cell character
     * (spaces, line breaks, ...) is skipped; cells fill the grid row by row.
     */
    private char[][] readMaze(File file) {
        char[][] grid = new char[mazeHeight][mazeWidth];
        int cells = 0;

        try (FileInputStream fis = new FileInputStream(file)) {
            int content;
            while ((content = fis.read()) != -1) {
                char c = (char) content;
                if (!isMazeCell(c)) {
                    continue;
                }
                if (cells == statesCount) {
                    throw new IllegalStateException(
                            "Maze file " + file + " holds more than " + statesCount + " cells");
                }
                grid[cells / mazeWidth][cells % mazeWidth] = c;
                cells++;
            }
        } catch (IOException e) {
            // Continuing with an empty maze would leave no final state and make training loop forever.
            throw new UncheckedIOException("Cannot read maze file " + file, e);
        }

        if (cells != statesCount) {
            throw new IllegalStateException(
                    "Maze file " + file + " holds " + cells + " cells, expected " + statesCount);
        }
        return grid;
    }

    /**
     * Fills {@code R}: every entry starts as "no transition"; then, for each non-final state, each
     * in-bounds neighbour gets the reward of the cell being entered.
     */
    private void buildRewards() {
        for (int state = 0; state < statesCount; state++) {
            for (int target = 0; target < statesCount; target++) {
                R[state][target] = NO_TRANSITION;
            }

            int row = state / mazeWidth;
            int col = state % mazeWidth;
            if (maze[row][col] == FINAL_CELL) {
                continue; // No moves leave a final cell
            }

            for (int[] move : MOVES) {
                int nextRow = row + move[0];
                int nextCol = col + move[1];
                boolean insideMaze = nextRow >= 0 && nextRow < mazeHeight
                        && nextCol >= 0 && nextCol < mazeWidth;
                if (insideMaze) {
                    R[state][nextRow * mazeWidth + nextCol] = rewardFor(maze[nextRow][nextCol]);
                }
            }
        }
    }

    /** Returns the immediate reward for entering a cell of the given kind. */
    private int rewardFor(char cell) {
        if (cell == FREE_CELL) {
            return 0;
        }
        if (cell == FINAL_CELL) {
            return reward;
        }
        return penalty;
    }

    /** Sets every Q value to the matching R value. */
    void initializeQ() {
        for (int i = 0; i < statesCount; i++) {
            for (int j = 0; j < statesCount; j++) {
                Q[i][j] = (double) R[i][j];
            }
        }
    }

    /**
     * Prints a states-by-states integer matrix with a header row of state numbers. Used for debug.
     *
     * @param matrix matrix to print; must be at least {@code statesCount} x {@code statesCount}
     */
    void printR(int[][] matrix) {
        System.out.printf("%25s", "States: ");
        for (int i = 0; i < statesCount; i++) {
            System.out.printf("%4s", i);
        }
        System.out.println();

        for (int i = 0; i < statesCount; i++) {
            System.out.print("Possible states from " + i + " :[");
            for (int j = 0; j < statesCount; j++) {
                System.out.printf("%4s", matrix[i][j]);
            }
            System.out.println("]");
        }
    }

    /**
     * Trains Q over a fixed number of episodes. Each episode starts in a random state and takes
     * uniformly random moves until a final state is reached, applying after every move:
     * {@code Q(s,a) = Q(s,a) + alpha * (R(s,a) + gamma * max Q(s',.) - Q(s,a))}.
     */
    void calculateQ() {
        Random rand = new Random();

        for (int cycle = 0; cycle < TRAIN_CYCLES; cycle++) {
            int crtState = rand.nextInt(statesCount);

            while (!isFinalState(crtState)) {
                int[] actionsFromCurrentState = possibleActionsFromState(crtState);

                // Pick a random action from the ones possible
                int nextState = actionsFromCurrentState[rand.nextInt(actionsFromCurrentState.length)];

                double q = Q[crtState][nextState];
                double futureValue = maxQ(nextState);
                int r = R[crtState][nextState];

                Q[crtState][nextState] = q + alpha * (r + gamma * futureValue - q);

                crtState = nextState;
            }
        }
    }

    /**
     * @param state state index in {@code [0, statesCount)}
     * @return whether the maze cell of {@code state} is a final cell
     */
    boolean isFinalState(int state) {
        return maze[state / mazeWidth][state % mazeWidth] == FINAL_CELL;
    }

    /**
     * @param state state index in {@code [0, statesCount)}
     * @return the states reachable from {@code state} in one move, in ascending order;
     *         empty for a final state
     */
    int[] possibleActionsFromState(int state) {
        List<Integer> result = new ArrayList<>();
        for (int i = 0; i < statesCount; i++) {
            if (R[state][i] != NO_TRANSITION) {
                result.add(i);
            }
        }

        return result.stream().mapToInt(Integer::intValue).toArray();
    }

    /**
     * @param nextState state index in {@code [0, statesCount)}
     * @return the largest Q value over the moves available from {@code nextState}, or {@code 0}
     *         when there are none (a final state has no future reward to discount)
     */
    double maxQ(int nextState) {
        int[] actionsFromState = possibleActionsFromState(nextState);
        if (actionsFromState.length == 0) {
            return 0.0;
        }

        // Start below every representable value so negative Q values are compared correctly.
        double maxValue = Double.NEGATIVE_INFINITY;
        for (int nextAction : actionsFromState) {
            maxValue = Math.max(maxValue, Q[nextState][nextAction]);
        }
        return maxValue;
    }

    /** Prints, for every state, the state the learned policy moves to. */
    void printPolicy() {
        System.out.println("\nPrint policy");
        for (int i = 0; i < statesCount; i++) {
            System.out.println("From state " + i + " goto state " + getPolicyFromState(i));
        }
    }

    /**
     * @param state state index in {@code [0, statesCount)}
     * @return the reachable state with the highest Q value (the first one on ties), or
     *         {@code state} itself when no move is available
     */
    int getPolicyFromState(int state) {
        // Double.MIN_VALUE is the smallest POSITIVE double and would hide every Q value <= 0.
        double maxValue = Double.NEGATIVE_INFINITY;
        int policyGotoState = state;

        for (int nextState : possibleActionsFromState(state)) {
            double value = Q[state][nextState];

            if (value > maxValue) {
                maxValue = value;
                policyGotoState = nextState;
            }
        }
        return policyGotoState;
    }

    /** Prints the full Q matrix, one row per source state. */
    void printQ() {
        System.out.println("Q matrix");
        for (int i = 0; i < Q.length; i++) {
            System.out.print("From state " + i + ":  ");
            for (int j = 0; j < Q[i].length; j++) {
                System.out.printf("%6.2f ", (Q[i][j]));
            }
            System.out.println();
        }
    }
}