├── __init__.py
├── images
│   └── TensorFlowPlayCatchGame.gif
├── README.md
├── LICENSE
├── PlayCatchGame.py
└── TrainCatchGame.py

/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/images/TensorFlowPlayCatchGame.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/solaris33/CatchGame-QLearningExample-TensorFlow/HEAD/images/TensorFlowPlayCatchGame.gif
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CatchGame - Simple Q-Learning example in TensorFlow
![TensorFlowPlayCatch](https://github.com/solaris33/CatchGame-QLearningExample-TensorFlow/blob/master/images/TensorFlowPlayCatchGame.gif)

A simple catch-game DQN agent implemented in TensorFlow.

The original code was written in Torch7 and Keras.

Keras code (written by Eder Santana): [here](https://gist.github.com/EderSantana/c7222daa328f0e885093)

Torch7 code (written by SeanNaren): [here](https://github.com/SeanNaren/TorchQLearningExample)


## Dependencies

[TensorFlow](https://www.tensorflow.org/versions/r0.10/get_started/os_setup.html)


## How to run

To train a model, run the TrainCatchGame.py script.
```
python TrainCatchGame.py
```


## Play and Visualization

Play and visualization are implemented in IPython.

To run, type into a terminal:

```
ipython notebook
```

then go to this directory and run the PlayCatchGame.py script in a notebook cell (e.g. with `%run PlayCatchGame.py`).
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 SOLARIS

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/PlayCatchGame.py:
--------------------------------------------------------------------------------
# To run this code you must use IPython. You can also use the .ipynb file in IPython notebook mode.
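#
# A minimal sketch of one way to launch it, assuming the checkpoint written by
# TrainCatchGame.py (model.ckpt) already exists in this directory:
#
#   $ ipython notebook
#   # then, inside a notebook cell:
#   %run PlayCatchGame.py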
%matplotlib
%matplotlib inline

from TrainCatchGame import CatchEnvironment, X, W1, b1, input_layer, W2, b2, hidden_layer, W3, b3, output_layer, Y, cost, optimizer
from IPython import display
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import pylab as pl
import time
import tensorflow as tf
import math
import os


gridSize = 10  # The size of the grid that the agent is going to play the game on.
maxGames = 100
env = CatchEnvironment(gridSize)
winCount = 0
loseCount = 0
numberOfGames = 0

ground = 1
plot = pl.figure(figsize=(12,12))
axis = plot.add_subplot(111, aspect='equal')
axis.set_xlim([-1, 12])
axis.set_ylim([0, 12])

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

def drawState(fruitRow, fruitColumn, basket):
  global gridSize
  # The column is the x axis.
  fruitX = fruitColumn
  # Invert matrix-style points to coordinates.
  fruitY = (gridSize - fruitRow + 1)
  statusTitle = "Wins: " + str(winCount) + " Losses: " + str(loseCount) + " TotalGame: " + str(numberOfGames)
  axis.set_title(statusTitle, fontsize=30)
  for p in [
    patches.Rectangle(
      ((ground - 1), (ground)), 11, 10,
      facecolor="#000000"  # Black background
    ),
    patches.Rectangle(
      (basket - 1, ground), 2, 0.5,
      facecolor="#FF0000"  # Red basket
    ),
    patches.Rectangle(
      (fruitX - 0.5, fruitY - 0.5), 1, 1,
      facecolor="#FF0000"  # Red fruit
    ),
  ]:
    axis.add_patch(p)
  display.clear_output(wait=True)
  display.display(pl.gcf())


with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, os.getcwd()+"/model.ckpt")
  print('saved model is loaded!')

  while (numberOfGames < maxGames):
    numberOfGames = numberOfGames + 1

    # The initial state of the environment.
    isGameOver = False
    fruitRow, fruitColumn, basket = env.reset()
    currentState = env.observe()
    drawState(fruitRow, fruitColumn, basket)

    while (isGameOver != True):
      # Forward the current state through the network.
      q = sess.run(output_layer, feed_dict={X: currentState})
      # Find the max index (the chosen action).
      index = q.argmax()
      action = index + 1
      nextState, reward, gameOver, stateInfo = env.act(action)
      fruitRow = stateInfo[0]
      fruitColumn = stateInfo[1]
      basket = stateInfo[2]

      # Count game results.
      if (reward == 1):
        winCount = winCount + 1
      elif (reward == -1):
        loseCount = loseCount + 1

      currentState = nextState
      isGameOver = gameOver
      drawState(fruitRow, fruitColumn, basket)
      time.sleep(0.4)

  display.clear_output(wait=True)
--------------------------------------------------------------------------------
/TrainCatchGame.py:
--------------------------------------------------------------------------------
"""
TensorFlow translation of the Torch7 example found here (written by SeanNaren):
https://github.com/SeanNaren/TorchQLearningExample

Original Keras example found here (written by Eder Santana):
https://gist.github.com/EderSantana/c7222daa328f0e885093#file-qlearn-py-L164

The agent plays a game of catch. Fruits drop from the sky and the agent can choose the actions
left/stay/right to catch the fruit before it reaches the ground.
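
Training, in short (a summary of the loop implemented below): actions are chosen
epsilon-greedily, transitions (state, action, reward, nextState, gameOver) are stored
in a replay memory, and the network is trained with SGD on the mean squared error
between its predicted Q-values and the targets
reward + discount * max_a' Q(nextState, a'), or just reward when the game ends.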
"""

import tensorflow as tf
import numpy as np
import random
import math
import os

# Parameters
epsilon = 1  # The probability of choosing a random action (in training). This decays as iterations increase. (0 to 1)
epsilonMinimumValue = 0.001  # The minimum value we want epsilon to reach in training. (0 to 1)
nbActions = 3  # The number of actions. Since we only have left/stay/right, that means 3 actions.
epoch = 1001  # The number of games we want the system to run for.
hiddenSize = 100  # Number of neurons in the hidden layers.
maxMemory = 500  # How large the memory should be (where it stores its past experiences).
batchSize = 50  # The mini-batch size for training. Samples are randomly taken from memory until the mini-batch size is reached.
gridSize = 10  # The size of the grid that the agent is going to play the game on.
nbStates = gridSize * gridSize  # We eventually flatten to a 1d tensor to feed the network.
discount = 0.9  # The discount is used to force the network to choose states that lead to the reward quicker (0 to 1).
learningRate = 0.2  # Learning rate for Stochastic Gradient Descent (our optimizer).

# Create the base model.
X = tf.placeholder(tf.float32, [None, nbStates])
W1 = tf.Variable(tf.truncated_normal([nbStates, hiddenSize], stddev=1.0 / math.sqrt(float(nbStates))))
b1 = tf.Variable(tf.truncated_normal([hiddenSize], stddev=0.01))
input_layer = tf.nn.relu(tf.matmul(X, W1) + b1)
W2 = tf.Variable(tf.truncated_normal([hiddenSize, hiddenSize], stddev=1.0 / math.sqrt(float(hiddenSize))))
b2 = tf.Variable(tf.truncated_normal([hiddenSize], stddev=0.01))
hidden_layer = tf.nn.relu(tf.matmul(input_layer, W2) + b2)
W3 = tf.Variable(tf.truncated_normal([hiddenSize, nbActions], stddev=1.0 / math.sqrt(float(hiddenSize))))
b3 = tf.Variable(tf.truncated_normal([nbActions], stddev=0.01))
output_layer = tf.matmul(hidden_layer, W3) + b3

# True labels
Y = tf.placeholder(tf.float32, [None, nbActions])

# Mean squared error cost function
cost = tf.reduce_sum(tf.square(Y - output_layer)) / (2 * batchSize)

# Stochastic Gradient Descent optimizer
optimizer = tf.train.GradientDescentOptimizer(learningRate).minimize(cost)


# Helper function: chooses a random float between the two boundaries.
def randf(s, e):
  return (float(random.randrange(0, (e - s) * 9999)) / 10000) + s


# The environment: handles interactions and contains the state of the environment.
class CatchEnvironment():
  def __init__(self, gridSize):
    self.gridSize = gridSize
    self.nbStates = self.gridSize * self.gridSize
    self.state = np.empty(3, dtype=np.uint8)

  # Returns the state of the environment.
  def observe(self):
    canvas = self.drawState()
    canvas = np.reshape(canvas, (-1, self.nbStates))
    return canvas

  def drawState(self):
    canvas = np.zeros((self.gridSize, self.gridSize))
    canvas[self.state[0]-1, self.state[1]-1] = 1  # Draw the fruit.
    # Draw the basket. The basket occupies its own position plus the two adjacent cells.
    canvas[self.gridSize-1, self.state[2] - 1 - 1] = 1
    canvas[self.gridSize-1, self.state[2] - 1] = 1
    canvas[self.gridSize-1, self.state[2] - 1 + 1] = 1
    return canvas

  # Resets the environment. Randomly initialise the fruit position (always at the top to begin with) and the bucket.
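  # The state is stored as [fruitRow, fruitColumn, basketCenter], using 1-based grid coordinates.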
  def reset(self):
    initialFruitColumn = random.randrange(1, self.gridSize + 1)
    initialBucketPosition = random.randrange(2, self.gridSize + 1 - 1)
    self.state = np.array([1, initialFruitColumn, initialBucketPosition])
    return self.getState()

  def getState(self):
    stateInfo = self.state
    fruit_row = stateInfo[0]
    fruit_col = stateInfo[1]
    basket = stateInfo[2]
    return fruit_row, fruit_col, basket

  # Returns the reward that the agent has gained for being in the current environment state.
  def getReward(self):
    fruitRow, fruitColumn, basket = self.getState()
    if (fruitRow == self.gridSize - 1):  # If the fruit has reached the bottom.
      if (abs(fruitColumn - basket) <= 1):  # Check if the basket caught the fruit.
        return 1
      else:
        return -1
    else:
      return 0

  def isGameOver(self):
    if (self.state[0] == self.gridSize - 1):
      return True
    else:
      return False

  def updateState(self, action):
    if (action == 1):
      action = -1
    elif (action == 2):
      action = 0
    else:
      action = 1
    fruitRow, fruitColumn, basket = self.getState()
    newBasket = min(max(2, basket + action), self.gridSize - 1)  # The min/max prevents the basket from moving out of the grid.
    fruitRow = fruitRow + 1  # The fruit falls by 1 with every action.
    self.state = np.array([fruitRow, fruitColumn, newBasket])

  # Action can be 1 (move left), 2 (stay) or 3 (move right).
  def act(self, action):
    self.updateState(action)
    reward = self.getReward()
    gameOver = self.isGameOver()
    return self.observe(), reward, gameOver, self.getState()  # For the purpose of the visualization, the state is also returned.


# The memory: handles the internal memory to which we add the experiences that occur based on the agent's actions,
# and creates batches of experiences based on the mini-batch size for training.
class ReplayMemory:
  def __init__(self, gridSize, maxMemory, discount):
    self.maxMemory = maxMemory
    self.gridSize = gridSize
    self.nbStates = self.gridSize * self.gridSize
    self.discount = discount
    canvas = np.zeros((self.gridSize, self.gridSize))
    canvas = np.reshape(canvas, (-1, self.nbStates))
    self.inputState = np.empty((self.maxMemory, self.nbStates), dtype=np.float32)
    self.actions = np.zeros(self.maxMemory, dtype=np.uint8)
    self.nextState = np.empty((self.maxMemory, self.nbStates), dtype=np.float32)
    self.gameOver = np.empty(self.maxMemory, dtype=bool)
    self.rewards = np.empty(self.maxMemory, dtype=np.int8)
    self.count = 0
    self.current = 0

  # Appends the experience to the memory.
  def remember(self, currentState, action, reward, nextState, gameOver):
    self.actions[self.current] = action
    self.rewards[self.current] = reward
    self.inputState[self.current, ...] = currentState
    self.nextState[self.current, ...] = nextState
    self.gameOver[self.current] = gameOver
    self.count = max(self.count, self.current + 1)
    self.current = (self.current + 1) % self.maxMemory

  def getBatch(self, model, batchSize, nbActions, nbStates, sess, X):

    # We check whether we have enough memory entries to make an entire batch; if not we create the biggest
    # batch we can (at the beginning of training we will not have enough experience to fill a batch).
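    # For each sampled transition, the training target is the network's current prediction for the stored
    # state, with the chosen action's entry replaced by the Q-learning target: the reward if the transition
    # ended the game, otherwise reward + discount * max_a' Q(nextState, a').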
    memoryLength = self.count
    chosenBatchSize = min(batchSize, memoryLength)

    inputs = np.zeros((chosenBatchSize, nbStates))
    targets = np.zeros((chosenBatchSize, nbActions))

    # Fill the inputs and targets up.
    for i in xrange(chosenBatchSize):
      if memoryLength == 1:
        memoryLength = 2
      # Choose a random memory experience to add to the batch.
      randomIndex = random.randrange(1, memoryLength)
      current_inputState = np.reshape(self.inputState[randomIndex], (1, nbStates))

      target = sess.run(model, feed_dict={X: current_inputState})

      current_nextState = np.reshape(self.nextState[randomIndex], (1, nbStates))
      current_outputs = sess.run(model, feed_dict={X: current_nextState})

      # Gives us Q_sa, the max Q-value for the next state.
      nextStateMaxQ = np.amax(current_outputs)
      if (self.gameOver[randomIndex] == True):
        target[0, [self.actions[randomIndex]-1]] = self.rewards[randomIndex]
      else:
        # reward + discount(gamma) * max_a' Q(s', a')
        # We set the Q-value for the chosen action to r + gamma * max_a' Q(s', a'). The rest stay the same,
        # which gives an error of 0 for those outputs.
        target[0, [self.actions[randomIndex]-1]] = self.rewards[randomIndex] + self.discount * nextStateMaxQ

      # Update the inputs and targets.
      inputs[i] = current_inputState
      targets[i] = target

    return inputs, targets


def main(_):
  print("Training new model")

  # Define the environment.
  env = CatchEnvironment(gridSize)

  # Define the replay memory.
  memory = ReplayMemory(gridSize, maxMemory, discount)

  # Add ops to save and restore all the variables.
  saver = tf.train.Saver()

  winCount = 0
  with tf.Session() as sess:
    tf.initialize_all_variables().run()

    for i in xrange(epoch):
      # Initialize the environment.
      err = 0
      env.reset()

      isGameOver = False

      # The initial state of the environment.
      currentState = env.observe()

      while (isGameOver != True):
        action = -9999  # Action initialization.
        # Decide whether we should choose a random action or an action from the policy network.
        global epsilon
        if (randf(0, 1) <= epsilon):
          action = random.randrange(1, nbActions+1)
        else:
          # Forward the current state through the network.
          q = sess.run(output_layer, feed_dict={X: currentState})
          # Find the max index (the chosen action).
          index = q.argmax()
          action = index + 1

        # Decay epsilon by multiplying it by 0.999, not allowing it to go below a certain threshold.
        if (epsilon > epsilonMinimumValue):
          epsilon = epsilon * 0.999

        nextState, reward, gameOver, stateInfo = env.act(action)

        if (reward == 1):
          winCount = winCount + 1

        memory.remember(currentState, action, reward, nextState, gameOver)

        # Update the current state and whether the game is over.
        currentState = nextState
        isGameOver = gameOver

        # We get a batch of training data to train the model.
        inputs, targets = memory.getBatch(output_layer, batchSize, nbActions, nbStates, sess, X)

        # Train the network, which returns the error.
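        # One SGD step on the sampled mini-batch; `loss` is the mean squared error between the predicted
        # Q-values and the targets, accumulated into `err` for this epoch.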
        _, loss = sess.run([optimizer, cost], feed_dict={X: inputs, Y: targets})
        err = err + loss

      print("Epoch " + str(i) + ": err = " + str(err) + ": Win count = " + str(winCount) + " Win ratio = " + str(float(winCount)/float(i+1)*100))

    # Save the variables to disk.
    save_path = saver.save(sess, os.getcwd()+"/model.ckpt")
    print("Model saved in file: %s" % save_path)


if __name__ == '__main__':
  tf.app.run()

--------------------------------------------------------------------------------