├── .gitignore
├── README.md
├── pom.xml
├── resources
└── maze.txt
└── src
├── main
└── java
│ └── com
│ └── technobium
│ └── rl
│ └── QLearning.java
└── test
└── java
└── com
└── technobium
└── rl
└── AppTest.java
/.gitignore:
--------------------------------------------------------------------------------
1 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
2 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
3 |
4 | # User-specific stuff:
5 | .idea/workspace.xml
6 | .idea/tasks.xml
7 |
8 | # Sensitive or high-churn files:
9 | .idea/dataSources.ids
10 | .idea/dataSources.xml
11 | .idea/dataSources.local.xml
12 | .idea/sqlDataSources.xml
13 | .idea/dynamic.xml
14 | .idea/uiDesigner.xml
15 |
16 | # Gradle:
17 | .idea/gradle.xml
18 | .idea/libraries
19 |
20 | # Mongo Explorer plugin:
21 | .idea/mongoSettings.xml
22 |
23 | ## File-based project format:
24 | *.iws
25 |
26 | ## Plugin-specific files:
27 |
28 | # IntelliJ
29 | /out/
30 |
31 | # mpeltonen/sbt-idea plugin
32 | .idea_modules/
33 |
34 | # JIRA plugin
35 | atlassian-ide-plugin.xml
36 |
37 | # Crashlytics plugin (for Android Studio and IntelliJ)
38 | com_crashlytics_export_strings.xml
39 | crashlytics.properties
40 | crashlytics-build.properties
41 | fabric.properties
42 | /q-learning-java.iws
43 | /q-learning-java.ipr
44 | /q-learning-java.iml
45 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # q-learning-java
2 | Reinforcement learning - Q-learning in Java. The code is explained here: http://technobium.com/reinforcement-learning-q-learning-java/
3 |
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
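1 | <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
2 |          xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
3 |   <modelVersion>4.0.0</modelVersion>
4 |   <groupId>com.technobium.rl</groupId>
5 |   <artifactId>q-learning-java</artifactId>
6 |   <packaging>jar</packaging>
7 |   <version>1.0-SNAPSHOT</version>
8 |   <name>q-learning-java</name>
9 |   <url>http://maven.apache.org</url>
10 |   <dependencies>
11 |     <dependency>
12 |       <groupId>junit</groupId>
13 |       <artifactId>junit</artifactId>
14 |       <version>4.13.1</version>
15 |       <scope>test</scope>
16 |     </dependency>
17 |   </dependencies>
18 |   <build>
19 |     <plugins>
20 |       <plugin>
21 |         <groupId>org.apache.maven.plugins</groupId>
22 |         <artifactId>maven-compiler-plugin</artifactId>
23 |         <version>3.7.0</version>
24 |         <configuration>
25 |           <source>1.8</source>
26 |           <target>1.8</target>
27 |         </configuration>
28 |       </plugin>
29 |     </plugins>
30 |   </build>
31 | </project>
32 |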
--------------------------------------------------------------------------------
/resources/maze.txt:
--------------------------------------------------------------------------------
1 | 0 0 0
2 | X 0 0
3 | 0 0 F
--------------------------------------------------------------------------------
/src/main/java/com/technobium/rl/QLearning.java:
--------------------------------------------------------------------------------
1 | package com.technobium.rl;
2 |
3 | import java.io.File;
4 | import java.io.FileInputStream;
5 | import java.io.IOException;
6 | import java.util.ArrayList;
7 | import java.util.Random;
8 |
/**
 * Tabular Q-learning on a small grid maze loaded from {@code resources/maze.txt}.
 *
 * <p>The maze file contains one character per cell, row by row: {@code '0'} for a free cell,
 * {@code 'X'} for a wall and {@code 'F'} for the final (goal) cell. Any other character
 * (spaces, line breaks) is ignored. Cells are numbered row-major, so the state index of the
 * cell at row {@code i}, column {@code j} is {@code i * mazeWidth + j}.
 *
 * <p>An "action" is identified by the state it leads to: {@code R[s][t]} and {@code Q[s][t]}
 * describe the move from state {@code s} to the adjacent state {@code t}.
 */
public class QLearning {

    /** Value stored in {@code R} when there is no direct move between two states. */
    private static final int NO_TRANSITION = -1;

    /** Row/column offsets of the four possible moves: left, right, up, down. */
    private static final int[][] MOVES = {{0, -1}, {0, 1}, {-1, 0}, {1, 0}};

    /** Number of training episodes run by {@link #calculateQ()}. */
    private static final int TRAIN_CYCLES = 1000;

    private final double alpha = 0.1; // Learning rate
    private final double gamma = 0.9; // Discount factor: 0 favours immediate reward, 1 favours distant reward

    private final int mazeWidth = 3;
    private final int mazeHeight = 3;
    private final int statesCount = mazeHeight * mazeWidth;

    private final int reward = 100;  // Reward for moving into the final cell
    private final int penalty = -10; // Reward for moving into a wall cell

    private char[][] maze; // Maze read from file
    private int[][] R;     // Reward lookup, indexed [fromState][toState]
    private double[][] Q;  // Learned action values, indexed [fromState][toState]

    /**
     * Loads the maze, trains the Q matrix and prints the Q matrix and the resulting policy.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        QLearning ql = new QLearning();

        ql.init();
        ql.calculateQ();
        ql.printQ();
        ql.printPolicy();
    }

    /**
     * Reads the maze from {@code resources/maze.txt}, builds the reward matrix {@code R},
     * seeds {@code Q} from it and prints {@code R}.
     *
     * @throws java.io.UncheckedIOException if the maze file cannot be read
     * @throws IllegalStateException if the file does not describe exactly
     *     {@code mazeHeight * mazeWidth} cells or contains no final cell
     */
    public void init() {
        File file = new File("resources/maze.txt");

        R = new int[statesCount][statesCount];
        Q = new double[statesCount][statesCount];
        maze = new char[mazeHeight][mazeWidth];

        try (FileInputStream fis = new FileInputStream(file)) {
            readMaze(fis);
        } catch (IOException e) {
            // Continuing with an empty maze would leave R all zeros: every move would look
            // legal, no state would be final, and calculateQ() would never terminate.
            throw new java.io.UncheckedIOException("Cannot read maze file " + file.getPath(), e);
        }

        buildRewardMatrix();
        initializeQ();
        printR(R);
    }

    /** Fills {@code maze} row by row from the stream, skipping every character that is not a cell. */
    private void readMaze(FileInputStream in) throws IOException {
        int cells = 0;
        boolean hasFinalCell = false;
        int content;

        while ((content = in.read()) != -1) {
            char c = (char) content;
            if (c != '0' && c != 'F' && c != 'X') {
                continue;
            }
            if (cells == statesCount) {
                throw new IllegalStateException(
                        "Maze file describes more than " + statesCount + " cells");
            }
            maze[cells / mazeWidth][cells % mazeWidth] = c;
            hasFinalCell |= (c == 'F');
            cells++;
        }

        if (cells < statesCount) {
            throw new IllegalStateException(
                    "Maze file describes " + cells + " cells, expected " + statesCount);
        }
        if (!hasFinalCell) {
            // Training episodes only end on a final cell.
            throw new IllegalStateException("Maze file contains no final cell 'F'");
        }
    }

    /** Populates {@code R} with the reward of every single-step move in the maze. */
    private void buildRewardMatrix() {
        for (int k = 0; k < statesCount; k++) {
            for (int s = 0; s < statesCount; s++) {
                R[k][s] = NO_TRANSITION;
            }

            // Translate the state index into maze coordinates.
            int row = k / mazeWidth;
            int col = k % mazeWidth;

            // Only the final state has no outgoing moves. Wall cells are ordinary states:
            // they can be entered (at the cost of the penalty) and left again.
            if (maze[row][col] == 'F') {
                continue;
            }

            for (int[] move : MOVES) {
                int targetRow = row + move[0];
                int targetCol = col + move[1];
                boolean insideMaze = targetRow >= 0 && targetRow < mazeHeight
                        && targetCol >= 0 && targetCol < mazeWidth;
                if (insideMaze) {
                    R[k][targetRow * mazeWidth + targetCol] =
                            rewardForEntering(maze[targetRow][targetCol]);
                }
            }
        }
    }

    /** Returns the immediate reward for stepping onto a cell of the given kind. */
    private int rewardForEntering(char cell) {
        if (cell == '0') {
            return 0;
        }
        if (cell == 'F') {
            return reward;
        }
        return penalty;
    }

    /** Sets every Q value to the corresponding R value. */
    void initializeQ() {
        for (int i = 0; i < statesCount; i++) {
            for (int j = 0; j < statesCount; j++) {
                Q[i][j] = (double) R[i][j];
            }
        }
    }

    /**
     * Prints a state-by-state matrix to standard output. Used for debugging.
     *
     * @param matrix matrix with at least {@code statesCount} rows and columns
     */
    void printR(int[][] matrix) {
        System.out.printf("%25s", "States: ");
        for (int i = 0; i < statesCount; i++) {
            System.out.printf("%4s", i);
        }
        System.out.println();

        for (int i = 0; i < statesCount; i++) {
            System.out.print("Possible states from " + i + " :[");
            for (int j = 0; j < statesCount; j++) {
                System.out.printf("%4s", matrix[i][j]);
            }
            System.out.println("]");
        }
    }

    /**
     * Trains {@code Q} by running random-walk episodes, each starting in a random state and
     * ending when the final state is reached.
     */
    void calculateQ() {
        Random rand = new Random();

        for (int i = 0; i < TRAIN_CYCLES; i++) {
            // Select random initial state
            int crtState = rand.nextInt(statesCount);

            while (!isFinalState(crtState)) {
                int[] actionsFromCurrentState = possibleActionsFromState(crtState);

                // Pick a random action from the ones possible
                int index = rand.nextInt(actionsFromCurrentState.length);
                int nextState = actionsFromCurrentState[index];

                // Q(s,a) = Q(s,a) + alpha * (R(s,a) + gamma * max_a' Q(s',a') - Q(s,a))
                double q = Q[crtState][nextState];
                double maxQ = maxQ(nextState);
                int r = R[crtState][nextState];

                Q[crtState][nextState] = q + alpha * (r + gamma * maxQ - q);

                crtState = nextState;
            }
        }
    }

    /**
     * @param state state index
     * @return {@code true} if the maze cell of this state is the final cell {@code 'F'}
     */
    boolean isFinalState(int state) {
        int i = state / mazeWidth;
        int j = state - i * mazeWidth;

        return maze[i][j] == 'F';
    }

    /**
     * @param state state index
     * @return the states reachable from {@code state} in one move, in ascending order;
     *     empty for the final state
     */
    int[] possibleActionsFromState(int state) {
        ArrayList<Integer> result = new ArrayList<>();
        for (int i = 0; i < statesCount; i++) {
            if (R[state][i] != NO_TRANSITION) {
                result.add(i);
            }
        }

        return result.stream().mapToInt(Integer::intValue).toArray();
    }

    /**
     * @param nextState state index
     * @return the largest Q value over all moves available from {@code nextState}, or
     *     {@code 0} when no move is available (the final state has no future reward)
     */
    double maxQ(int nextState) {
        int[] actionsFromState = possibleActionsFromState(nextState);
        if (actionsFromState.length == 0) {
            return 0.0;
        }

        double maxValue = Double.NEGATIVE_INFINITY;
        for (int nextAction : actionsFromState) {
            double value = Q[nextState][nextAction];
            if (value > maxValue) {
                maxValue = value;
            }
        }
        return maxValue;
    }

    /** Prints, for every state, the state the learned policy moves to. */
    void printPolicy() {
        System.out.println("\nPrint policy");
        for (int i = 0; i < statesCount; i++) {
            System.out.println("From state " + i + " goto state " + getPolicyFromState(i));
        }
    }

    /**
     * @param state state index
     * @return the reachable state with the highest Q value (the first one on ties), or
     *     {@code state} itself when no move is available
     */
    int getPolicyFromState(int state) {
        int[] actionsFromState = possibleActionsFromState(state);

        // Double.MIN_VALUE is the smallest positive double, so it cannot be used as a
        // lower bound here: Q values may be zero or negative.
        double maxValue = Double.NEGATIVE_INFINITY;
        int policyGotoState = state;

        for (int nextState : actionsFromState) {
            double value = Q[state][nextState];

            if (value > maxValue) {
                maxValue = value;
                policyGotoState = nextState;
            }
        }
        return policyGotoState;
    }

    /** Prints the full Q matrix to standard output. */
    void printQ() {
        System.out.println("Q matrix");
        for (int i = 0; i < Q.length; i++) {
            System.out.print("From state " + i + ": ");
            for (int j = 0; j < Q[i].length; j++) {
                System.out.printf("%6.2f ", (Q[i][j]));
            }
            System.out.println();
        }
    }
}
259 |
--------------------------------------------------------------------------------
/src/test/java/com/technobium/rl/AppTest.java:
--------------------------------------------------------------------------------
1 | package com.technobium.rl;
2 |
3 | import junit.framework.Test;
4 | import junit.framework.TestCase;
5 | import junit.framework.TestSuite;
6 |
7 | /**
8 | * Unit test for simple App.
9 | */
10 | public class AppTest
11 | extends TestCase {
12 | /**
13 | * Create the test case
14 | *
15 | * @param testName name of the test case
16 | */
17 | public AppTest(String testName) {
18 | super(testName);
19 | }
20 |
21 | /**
22 | * @return the suite of tests being tested
23 | */
24 | public static Test suite() {
25 | return new TestSuite(AppTest.class);
26 | }
27 |
28 | /**
29 | * Rigourous Test :-)
30 | */
31 | public void testApp() {
32 | assertTrue(true);
33 | }
34 | }
35 |
--------------------------------------------------------------------------------