├── .gitignore ├── .gitmodules ├── Building a Pong AI.pptx.pdf ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── common ├── __init__.py └── half_pong_player.py ├── examples ├── 1_random_half_pong_player.py ├── 2_random_with_base_half_pong_player.py ├── 3_mlp_half_pong_player.py ├── 4_tensorflow_q_learning.py ├── 4_theano_q_learning.py ├── 5_mlp_q_learning_half_pong_player.py ├── 6_conv_net_half_pong │ ├── checkpoint │ ├── network-1260000 │ ├── network-1260000.meta │ ├── network-1270000 │ ├── network-1270000.meta │ ├── network-1280000 │ ├── network-1280000.meta │ ├── network-1290000 │ ├── network-1290000.meta │ ├── network-1300000 │ └── network-1300000.meta └── 6_conv_net_half_pong_player.py └── resources └── __init__.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Pycharm 10 | .idea 11 | 12 | # Distribution / packaging 13 | .Python 14 | env/ 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *,cover 49 | .hypothesis/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | 58 | # Sphinx documentation 59 | docs/_build/ 60 | 61 | # PyBuilder 62 | target/ 63 | 64 | #Ipython Notebook 65 | .ipynb_checkpoints 66 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "resources/PyGamePlayer"] 2 | path = resources/PyGamePlayer 3 | url = https://github.com/DanielSlater/PyGamePlayer 4 | -------------------------------------------------------------------------------- /Building a Pong AI.pptx.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DanielSlater/PyDataLondon2016/274af1fcf41dcfcd6062445f4baa149e30b625b6/Building a Pong AI.pptx.pdf -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:latest 2 | RUN apt-get update && apt-get -y install python-dev build-essential git x11-apps 3 | RUN apt-get install -y python-setuptools python-pip python-pygame python-matplotlib python-numpy python-scipy 4 | RUN pip install cv2 5 | RUN pip install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl 6 | RUN git clone https://github.com/Theano/Theano.git 7 | RUN cd ./Theano && python setup.py develop 8 | COPY ./ /opt/PyDataLondon2016 9 | ENV PYTHONPATH /opt/PyDataLondon2016/ 10 | WORKDIR /opt/PyDataLondon2016 11 | RUN git submodule init && git submodule update 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The 
MIT License (MIT)
2 | 
3 | Copyright (c) 2016 Daniel Slater
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | bootstrap:
2 | 	brew cask install xquartz
3 | 	brew install socat
4 | build:
5 | 	docker build -t pydatalondon2016 .
6 | run:
7 | 	socat TCP-LISTEN:6000,reuseaddr,fork UNIX-CLIENT:\"$$DISPLAY\" &
8 | 	docker run --net=host -it -e DISPLAY=`ipconfig getifaddr en0`:0 pydatalondon2016 /bin/bash
9 | 
10 | all: bootstrap build run
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Building a Pong playing AI
2 | 
3 | This repository contains the resources needed for the tutorial, Building a Pong playing AI in just 1 hour (plus 4 days training time). The full video for the tutorial is on YouTube [here](https://www.youtube.com/watch?v=n8NdT_3y9oY).
4 | 
5 | ### Installation Guide for OS X
6 | 
7 | Tested on a MacBook Pro (late 2013) with El Capitan; unsure whether GPU support works.
8 | 
9 | Requirements:
10 | * [Homebrew](http://brew.sh/)
11 | * [Miniconda](http://conda.pydata.org/miniconda.html)
12 | 
13 | 
14 | Install some image libraries and an X11 framework for macOS:
15 | 
16 | ```sh
17 | brew install sdl_image
18 | brew install Caskroom/cask/xquartz
19 | ```
20 | 
21 | Clone the repo:
22 | 
23 | ```
24 | git clone git@github.com:DanielSlater/PyDataLondon2016.git
25 | cd PyDataLondon2016/
26 | ```
27 | 
28 | Create a virtual environment for Python 2:
29 | 
30 | ```
31 | conda create --name pong-ai-27 python=2
32 | source activate pong-ai-27
33 | ```
34 | 
35 | Install the listed dependencies plus `opencv`:
36 | 
37 | ```sh
38 | conda install matplotlib numpy opencv
39 | ```
40 | 
41 | Install `tensorflow` and `pygame`:
42 | 
43 | ```sh
44 | conda install -c https://conda.anaconda.org/jjhelmus tensorflow
45 | conda install -c https://conda.binstar.org/quasiben pygame
46 | ```
47 | 
48 | 
49 | Initialize submodules:
50 | 
51 | ```
52 | git submodule init
53 | git submodule update
54 | ```
55 | 
56 | Symlink `resources` and `common` in the `examples` folder:
57 | 
58 | ```
59 | cd examples/
60 | ln -s ../resources/
61 | ln -s ../common/
62 | ```
63 | 
64 | 
65 | Run an example:
66 | 
67 | ```
68 | python 1_random_half_pong_player.py
69 | ```
70 | 
71 | ### Linux Nvidia GPU Installation Guide
72 | 
73 | * [Python 2](https://www.python.org/downloads/)
74 | * [PyGame](http://www.pygame.org/download.shtml)
75 | * [TensorFlow](https://www.tensorflow.org/versions/r0.8/get_started/os_setup.html#download-and-setup)
76 | * [Matplotlib](http://matplotlib.org/users/installing.html)
77 | 
78 | TensorFlow's GPU support requires an Nvidia GPU, and at the time of writing TensorFlow only runs on Linux/Mac, so if you don't have these Theano is an option (see below). The examples are all in TensorFlow, but they translate very easily to Theano, and there is an example Q-learning Theano implementation that can be extended to work with Pong.
79 | 
80 | ## Windows / non-Nvidia GPU
81 | 
82 | #### [Python](https://www.python.org/downloads/)
83 | Either 2 or 3 is fine.
84 | #### [PyGame](http://www.pygame.org/download.shtml)
85 | Download whichever version matches the version of Python you plan on using.
86 | #### [Matplotlib](http://matplotlib.org/users/installing.html)
87 | Match the version of Python you are using.
88 | 
89 | ### [Theano Installation Guide for Windows](http://deeplearning.net/software/theano/install.html)
90 | 
91 | Download Anaconda and install the required packages:
92 | 
93 | ```
94 | conda install mingw libpython numpy
95 | ```
96 | 
97 | Clone the Theano repo:
98 | 
99 | ```
100 | git clone https://github.com/Theano/Theano.git
101 | ```
102 | 
103 | Install the Theano package:
104 | 
105 | ```
106 | cd Theano
107 | python setup.py develop
108 | ```
109 | 
110 | ### Docker environment alternative
111 | #### Docker build
112 | Have a look at the Makefile; essentially it sets up an XQuartz environment exposed to a Docker container along with the required dependencies.
113 | `make all` should in theory launch you into an environment capable of running the examples straight away.
114 | 
115 | ## Resources
116 | 
117 | #### [PyGame Player](https://github.com/DanielSlater/PyGamePlayer/blob/master/pygame_player.py)
118 | Used for running reinforcement learning agents against PyGame games.
119 | 
120 | #### [PyGame Pong](https://github.com/DanielSlater/PyGamePlayer/blob/master/games/pong.py)
121 | A PyGame implementation of Pong.
122 | 
123 | #### [PyGame Half Pong](https://github.com/DanielSlater/PyGamePlayer/tree/master/games)
124 | Even pong can be hard if you're just a machine.
125 | Half pong is a simplified version of pong, if you can believe it.
126 | The score and other bits of noise are removed from the game.
127 | There is only one paddle, and the screen is only 80x80 pixels, which speeds up training and removes the need to downsize the screen.
128 | 
129 | 
130 | ## Examples
131 | 
132 | * [Random Half Pong player](https://github.com/DanielSlater/PyDataLondon2016/blob/master/examples/1_random_half_pong_player.py)
133 | * [Random With Base Half Pong player](https://github.com/DanielSlater/PyDataLondon2016/blob/master/examples/2_random_with_base_half_pong_player.py)
134 | * [MLP Half Pong player](https://github.com/DanielSlater/PyDataLondon2016/blob/master/examples/3_mlp_half_pong_player.py)
135 | * [TensorFlow Q learning](https://github.com/DanielSlater/PyDataLondon2016/blob/master/examples/4_tensorflow_q_learning.py)
136 | * [Theano Q learning](https://github.com/DanielSlater/PyDataLondon2016/blob/master/examples/4_theano_q_learning.py)
137 | * [MLP Q learning Half Pong player](https://github.com/DanielSlater/PyDataLondon2016/blob/master/examples/5_mlp_q_learning_half_pong_player.py)
138 | * [Convolutional network Half Pong player](https://github.com/DanielSlater/PyDataLondon2016/blob/master/examples/6_conv_net_half_pong_player.py)
139 | 
--------------------------------------------------------------------------------
/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielSlater/PyDataLondon2016/274af1fcf41dcfcd6062445f4baa149e30b625b6/common/__init__.py
--------------------------------------------------------------------------------
/common/half_pong_player.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 | 
3 | from pygame.constants import K_DOWN
4 | from pygame.constants import K_UP
5 | 
6 | from resources.PyGamePlayer.games.half_pong import run
7 | from resources.PyGamePlayer.pygame_player import PyGamePlayer
8 | 
9 | 
10 | class HalfPongPlayer(PyGamePlayer):
11 |     SCREEN_WIDTH = 80
12 |     SCREEN_HEIGHT = 80
13 |     CUMULATIVE_SCORE_LEN = 1000
14 |     ACTIONS_COUNT = 3
15 | 
16 |     def __init__(self, **kwargs):
17 |         """
18 |         Base class for Half Pong players: tracks the recent score history and maps action indices to key presses
19 |         """
20 |         super(HalfPongPlayer, self).__init__(**kwargs)
21 |         self._last_score = 0
22 |         self._score_history = deque()
23 | 
24 |     def get_feedback(self):
25 |         from resources.PyGamePlayer.games.half_pong import score
26 | 
27 |         # get the difference in scores between this and the last frame
28 |         score_change = score - self._last_score
29 |         self._last_score = score
30 |         self._score_history.append(score_change)
31 | 
32 |         if len(self._score_history) > self.CUMULATIVE_SCORE_LEN:
33 |             self._score_history.popleft()
34 | 
35 |         return float(score_change), score_change == -1
36 | 
37 |     def score(self):
38 |         return sum(self._score_history)/float(self.CUMULATIVE_SCORE_LEN)
39 | 
40 |     @staticmethod
41 |     def action_index_to_key(action_index):
42 |         if action_index == 0:
43 |             return [K_DOWN]
44 |         elif action_index == 1:
45 |             return []
46 |         else:
47 |             return [K_UP]
48 | 
49 |     def start(self):
50 |         super(HalfPongPlayer, self).start()
51 | 
52 |         run(screen_width=self.SCREEN_WIDTH, screen_height=self.SCREEN_HEIGHT)
53 | 
54 | 
55 | if __name__ == '__main__':
56 |     player = HalfPongPlayer()
57 |     player.start()
--------------------------------------------------------------------------------
/examples/1_random_half_pong_player.py:
-------------------------------------------------------------------------------- 1 | import random 2 | 3 | from pygame.constants import K_DOWN, K_UP 4 | 5 | from resources.PyGamePlayer.games.half_pong import run 6 | from resources.PyGamePlayer.pygame_player import PyGamePlayer 7 | 8 | 9 | class RandomHalfPongPlayer(PyGamePlayer): 10 | def __init__(self): 11 | """ 12 | Plays Half Pong by choosing moves randomly 13 | """ 14 | super(RandomHalfPongPlayer, self).__init__(run_real_time=True) 15 | self._last_score = 0 16 | 17 | def get_keys_pressed(self, screen_array, feedback, terminal): 18 | action_index = random.randrange(3) 19 | 20 | if action_index == 0: 21 | return [K_DOWN] 22 | elif action_index == 1: 23 | return [] 24 | else: 25 | return [K_UP] 26 | 27 | def get_feedback(self): 28 | from resources.PyGamePlayer.games.half_pong import score 29 | 30 | # get the difference in scores between this and the last frame 31 | score_change = score - self._last_score 32 | self._last_score = score 33 | 34 | if score_change != 0: 35 | print("Reward: %s" % score_change) 36 | 37 | return float(score_change), score_change == -1 38 | 39 | def start(self): 40 | super(RandomHalfPongPlayer, self).start() 41 | 42 | run(screen_width=640, screen_height=480) 43 | 44 | 45 | if __name__ == '__main__': 46 | player = RandomHalfPongPlayer() 47 | player.start() 48 | -------------------------------------------------------------------------------- /examples/2_random_with_base_half_pong_player.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from common.half_pong_player import HalfPongPlayer 4 | 5 | 6 | class RandomHalfPongPlayer(HalfPongPlayer): 7 | """ 8 | Same as 1_random half pong player except with most code moved to a base class that will be shared with other 9 | examples 10 | """ 11 | 12 | def get_keys_pressed(self, screen_array, feedback, terminal): 13 | if feedback != 0: 14 | print self.score() 15 | 16 | action_index = random.randrange(3) 17 | 18 | return HalfPongPlayer.action_index_to_key(action_index) 19 | 20 | 21 | if __name__ == '__main__': 22 | player = RandomHalfPongPlayer() 23 | player.start() 24 | -------------------------------------------------------------------------------- /examples/3_mlp_half_pong_player.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | from common.half_pong_player import HalfPongPlayer 6 | 7 | 8 | class MLPHalfPongPlayer(HalfPongPlayer): 9 | def __init__(self): 10 | """ 11 | Neural network attached to pong, no way to train it yet 12 | """ 13 | super(MLPHalfPongPlayer, self).__init__(run_real_time=True, force_game_fps=6) 14 | 15 | self._input_layer, self._output_layer = self._create_network() 16 | 17 | init = tf.initialize_all_variables() 18 | self._session = tf.Session() 19 | self._session.run(init) 20 | 21 | def _create_network(self): 22 | input_layer = tf.placeholder("float", [self.SCREEN_WIDTH, self.SCREEN_HEIGHT]) 23 | 24 | feed_forward_weights_1 = tf.Variable(tf.truncated_normal([self.SCREEN_WIDTH*self.SCREEN_HEIGHT, 256], stddev=0.01)) 25 | feed_forward_bias_1 = tf.Variable(tf.constant(0.01, shape=[256])) 26 | 27 | feed_forward_weights_2 = tf.Variable(tf.truncated_normal([256, self.ACTIONS_COUNT], stddev=0.01)) 28 | feed_forward_bias_2 = tf.Variable(tf.constant(0.01, shape=[self.ACTIONS_COUNT])) 29 | 30 | flattened_input = tf.reshape(input_layer, shape=(1, self.SCREEN_WIDTH*self.SCREEN_HEIGHT,)) 31 | 32 | 
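        # note: the 80x80 binary screen has just been flattened into a single row
        # vector of SCREEN_WIDTH * SCREEN_HEIGHT = 6400 values (batch size 1);
        # the two dense layers below map it through 256 ReLU units to one output
        # value per action (ACTIONS_COUNT = 3)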
hidden_layer = tf.nn.relu( 33 | tf.matmul(flattened_input, feed_forward_weights_1) + feed_forward_bias_1) 34 | 35 | output_layer = tf.matmul(hidden_layer, feed_forward_weights_2) + feed_forward_bias_2 36 | 37 | return input_layer, output_layer 38 | 39 | def get_keys_pressed(self, screen_array, feedback, terminal): 40 | # images will be black or white 41 | _, binary_image = cv2.threshold(cv2.cvtColor(screen_array, cv2.COLOR_BGR2GRAY), 1, 255, 42 | cv2.THRESH_BINARY) 43 | 44 | output = self._session.run(self._output_layer, feed_dict={self._input_layer: binary_image}) 45 | action = np.argmax(output) 46 | 47 | # How do we train???? 48 | 49 | return self.action_index_to_key(action) 50 | 51 | 52 | if __name__ == '__main__': 53 | player = MLPHalfPongPlayer() 54 | player.start() 55 | -------------------------------------------------------------------------------- /examples/4_tensorflow_q_learning.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | # we will create a set of states, the agent get a reward for getting to the 5th one(4 in zero based array). 5 | # the agent can go forward or backward by one state with wrapping(so if you go back from the 1st state you go to the 6 | # end). 7 | states = [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0] 8 | NUM_STATES = len(states) 9 | NUM_ACTIONS = 2 10 | FUTURE_REWARD_DISCOUNT = 0.5 11 | 12 | 13 | def hot_one_state(index): 14 | array = np.zeros(NUM_STATES) 15 | array[index] = 1. 16 | return array 17 | 18 | # The None here is for batch training 19 | session = tf.Session() 20 | state = tf.placeholder("float", [None, NUM_STATES]) 21 | targets = tf.placeholder("float", [None, NUM_ACTIONS]) 22 | 23 | hidden_weights = tf.Variable(tf.constant(0., shape=[NUM_STATES, NUM_ACTIONS])) 24 | 25 | output = tf.matmul(state, hidden_weights) 26 | 27 | loss = tf.reduce_mean(tf.square(output - targets)) 28 | train_operation = tf.train.AdamOptimizer(0.1).minimize(loss) 29 | 30 | session.run(tf.initialize_all_variables()) 31 | 32 | for i in range(50): 33 | state_batch = [] 34 | rewards_batch = [] 35 | 36 | # create a batch of states 37 | for state_index in range(NUM_STATES): 38 | state_batch.append(hot_one_state(state_index)) 39 | 40 | minus_action_index = (state_index - 1) % NUM_STATES 41 | plus_action_index = (state_index + 1) % NUM_STATES 42 | 43 | minus_action_state_reward = session.run(output, feed_dict={state: [hot_one_state(minus_action_index)]}) 44 | plus_action_state_reward = session.run(output, feed_dict={state: [hot_one_state(plus_action_index)]}) 45 | 46 | # these action rewards are the results of the Q function for this state and the actions minus or plus 47 | action_rewards = [states[minus_action_index] + FUTURE_REWARD_DISCOUNT * np.max(minus_action_state_reward), 48 | states[plus_action_index] + FUTURE_REWARD_DISCOUNT * np.max(plus_action_state_reward)] 49 | rewards_batch.append(action_rewards) 50 | 51 | session.run(train_operation, feed_dict={ 52 | state: state_batch, 53 | targets: rewards_batch}) 54 | 55 | print([states[x] + np.max(session.run(output, feed_dict={state: [hot_one_state(x)]})) 56 | for x in range(NUM_STATES)]) 57 | -------------------------------------------------------------------------------- /examples/4_theano_q_learning.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import theano.tensor as T 3 | import numpy as np 4 | 5 | 6 | # we will create a set of states, the agent get a reward for 
getting to the 5th one(4 in zero based array). 7 | # the agent can go forward or backward by one state with wrapping(so if you go back from the 1st state you go to the 8 | # end). 9 | states = [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0] 10 | NUM_STATES = len(states) 11 | NUM_ACTIONS = 2 12 | FUTURE_REWARD_DISCOUNT = 0.5 13 | LEARNING_RATE = 0.1 14 | 15 | 16 | def hot_one_state(index): 17 | array = np.zeros(NUM_STATES) 18 | array[index] = 1. 19 | return array.reshape(array.shape[0], 1) # Theano is sad if the shape looks like (10,) rather than (10,1) 20 | 21 | 22 | state, targets = T.dmatrices('state', 'targets') 23 | hidden_weights = theano.shared(value=np.zeros((NUM_ACTIONS, NUM_STATES)), name='hidden_weights') 24 | 25 | output_fn = T.dot(hidden_weights, state) 26 | output = theano.function([state], output_fn) 27 | 28 | states_input = T.dmatrix('states_input') 29 | loss_fn = T.mean((T.dot(hidden_weights, states_input) - targets) ** 2) 30 | 31 | gradient = T.grad(cost=loss_fn, wrt=hidden_weights) 32 | 33 | train_model = theano.function( 34 | inputs=[states_input, targets], 35 | outputs=loss_fn, 36 | updates=[[hidden_weights, hidden_weights - LEARNING_RATE * gradient]], 37 | allow_input_downcast=True 38 | ) 39 | 40 | for i in range(50): 41 | state_batch = [] 42 | rewards_batch = [] 43 | 44 | # create a batch of states 45 | for state_index in range(NUM_STATES): 46 | state_batch.append(hot_one_state(state_index)[:,0]) 47 | 48 | minus_action_index = (state_index - 1) % NUM_STATES 49 | plus_action_index = (state_index + 1) % NUM_STATES 50 | 51 | minus_action_state_reward = output(state=hot_one_state(minus_action_index)) 52 | plus_action_state_reward = output(state=hot_one_state(plus_action_index)) 53 | 54 | # these action rewards are the results of the Q function for this state and the actions minus or plus 55 | action_rewards = [states[minus_action_index] + FUTURE_REWARD_DISCOUNT * np.max(minus_action_state_reward), 56 | states[plus_action_index] + FUTURE_REWARD_DISCOUNT * np.max(plus_action_state_reward)] 57 | rewards_batch.append(action_rewards) 58 | 59 | train_model(state_batch, np.array(rewards_batch).T) 60 | print([states[x] + np.max(output(hot_one_state(x))) for x in range(NUM_STATES)]) -------------------------------------------------------------------------------- /examples/5_mlp_q_learning_half_pong_player.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from collections import deque 4 | 5 | import cv2 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | from common.half_pong_player import HalfPongPlayer 10 | 11 | 12 | class MLPQLearningHalfPongPlayer(HalfPongPlayer): 13 | FUTURE_REWARD_DISCOUNT = 0.99 # decay rate of past observations 14 | OBSERVATION_STEPS = 50000. # time steps to observe before training 15 | EXPLORE_STEPS = 300000. 
# frames over which to anneal epsilon 16 | INITIAL_RANDOM_ACTION_PROB = 1.0 # starting chance of an action being random 17 | FINAL_RANDOM_ACTION_PROB = 0.05 # final chance of an action being random 18 | MEMORY_SIZE = 200000 # number of observations to remember 19 | MINI_BATCH_SIZE = 100 # size of mini batches 20 | OBS_LAST_STATE_INDEX, OBS_ACTION_INDEX, OBS_REWARD_INDEX, OBS_CURRENT_STATE_INDEX, OBS_TERMINAL_INDEX = range(5) 21 | LEARN_RATE = 0.00001 22 | SAVE_EVERY_X_STEPS = 5000 23 | SCREEN_WIDTH = 40 24 | SCREEN_HEIGHT = 40 25 | STATE_FRAMES = 4 26 | 27 | def __init__(self, checkpoint_path="5_mlp_q_learning_half_pong", playback_mode=False, verbose_logging=True): 28 | """ 29 | MLP now training using Q-learning 30 | """ 31 | self._playback_mode = playback_mode 32 | super(MLPQLearningHalfPongPlayer, self).__init__(run_real_time=self._playback_mode, force_game_fps=6) 33 | self.verbose_logging = verbose_logging 34 | self._checkpoint_path = checkpoint_path 35 | 36 | self._probability_of_random_action = self.INITIAL_RANDOM_ACTION_PROB 37 | self._observations = deque() 38 | self._time = 0 39 | self._last_action = None 40 | self._last_state = None 41 | 42 | self._input_layer, self._output_layer = self._create_network() 43 | 44 | self._actions = tf.placeholder("float", [None, self.ACTIONS_COUNT], name="actions") 45 | self._target = tf.placeholder("float", [None], name="target") 46 | 47 | readout_action = tf.reduce_sum(tf.mul(self._output_layer, self._actions), reduction_indices=1) 48 | 49 | cost = tf.reduce_mean(tf.square(self._target - readout_action)) 50 | self._train_operation = tf.train.AdamOptimizer(self.LEARN_RATE).minimize(cost) 51 | 52 | init = tf.initialize_all_variables() 53 | self._session = tf.Session() 54 | self._session.run(init) 55 | 56 | if not os.path.exists(self._checkpoint_path): 57 | os.mkdir(self._checkpoint_path) 58 | self._saver = tf.train.Saver() 59 | checkpoint = tf.train.get_checkpoint_state(self._checkpoint_path) 60 | 61 | if checkpoint and checkpoint.model_checkpoint_path: 62 | self._saver.restore(self._session, checkpoint.model_checkpoint_path) 63 | print("Loaded checkpoints %s" % checkpoint.model_checkpoint_path) 64 | elif playback_mode: 65 | raise Exception("Could not load checkpoints for playback") 66 | 67 | def _create_network(self): 68 | input_layer = tf.placeholder("float", [None, self.SCREEN_WIDTH * self.SCREEN_HEIGHT * self.STATE_FRAMES], 69 | name="input_layer") 70 | 71 | feed_forward_weights_1 = tf.Variable( 72 | tf.truncated_normal([self.SCREEN_WIDTH * self.SCREEN_HEIGHT * self.STATE_FRAMES, 256], stddev=0.01)) 73 | feed_forward_bias_1 = tf.Variable(tf.constant(0.01, shape=[256])) 74 | 75 | feed_forward_weights_2 = tf.Variable(tf.truncated_normal([256, self.ACTIONS_COUNT], stddev=0.01)) 76 | feed_forward_bias_2 = tf.Variable(tf.constant(0.01, shape=[self.ACTIONS_COUNT])) 77 | 78 | hidden_layer = tf.nn.relu( 79 | tf.matmul(input_layer, feed_forward_weights_1) + feed_forward_bias_1, name="hidden_activations") 80 | 81 | output_layer = tf.matmul(hidden_layer, feed_forward_weights_2) + feed_forward_bias_2 82 | 83 | return input_layer, output_layer 84 | 85 | def get_keys_pressed(self, screen_array, reward, terminal): 86 | # images will be black or white 87 | _, binary_image = cv2.threshold(cv2.cvtColor(screen_array, cv2.COLOR_BGR2GRAY), 1, 255, 88 | cv2.THRESH_BINARY) 89 | 90 | binary_image = np.reshape(binary_image, (self.SCREEN_WIDTH * self.SCREEN_HEIGHT,)) 91 | 92 | # first frame must be handled differently 93 | if self._last_state is None: 94 | 
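            # no frame history yet, so the initial state is built by repeating this
            # frame STATE_FRAMES times, giving a flat vector of
            # SCREEN_WIDTH * SCREEN_HEIGHT * STATE_FRAMES = 40 * 40 * 4 = 6400 values
            # to match the network's input layer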
self._last_state = np.concatenate(tuple(binary_image for _ in range(self.STATE_FRAMES)), axis=0) 95 | random_action = random.randrange(self.ACTIONS_COUNT) 96 | 97 | self._last_action = np.zeros([self.ACTIONS_COUNT]) 98 | self._last_action[random_action] = 1. 99 | 100 | return self.action_index_to_key(random_action) 101 | 102 | binary_image = np.append(self._last_state[self.SCREEN_WIDTH * self.SCREEN_HEIGHT:], binary_image, axis=0) 103 | 104 | self._observations.append((self._last_state, self._last_action, reward, binary_image, terminal)) 105 | 106 | if len(self._observations) > self.MEMORY_SIZE: 107 | self._observations.popleft() 108 | 109 | if len(self._observations) > self.OBSERVATION_STEPS: 110 | self._train() 111 | self._time += 1 112 | 113 | # gradually reduce the probability of a random actionself. 114 | if self._probability_of_random_action > self.FINAL_RANDOM_ACTION_PROB \ 115 | and len(self._observations) > self.OBSERVATION_STEPS: 116 | self._probability_of_random_action -= \ 117 | (self.INITIAL_RANDOM_ACTION_PROB - self.FINAL_RANDOM_ACTION_PROB) / self.EXPLORE_STEPS 118 | 119 | print("Time: %s random_action_prob: %s reward %s scores %s" % 120 | (self._time, self._probability_of_random_action, reward, 121 | self.score())) 122 | 123 | action = self._choose_next_action(binary_image) 124 | self._last_state = binary_image 125 | 126 | self._last_action = np.zeros([self.ACTIONS_COUNT]) 127 | self._last_action[action] = 1. 128 | return self.action_index_to_key(action) 129 | 130 | def _choose_next_action(self, binary_image): 131 | if (not self._playback_mode) and (random.random() <= self._probability_of_random_action): 132 | return random.randrange(self.ACTIONS_COUNT) 133 | else: 134 | # let the net choose our action 135 | output = self._session.run(self._output_layer, feed_dict={self._input_layer: [binary_image]}) 136 | return np.argmax(output) 137 | 138 | def _train(self): 139 | # sample a mini_batch to train on 140 | mini_batch = random.sample(self._observations, self.MINI_BATCH_SIZE) 141 | # get the batch variables 142 | previous_states = [d[self.OBS_LAST_STATE_INDEX] for d in mini_batch] 143 | actions = [d[self.OBS_ACTION_INDEX] for d in mini_batch] 144 | rewards = [d[self.OBS_REWARD_INDEX] for d in mini_batch] 145 | current_states = [d[self.OBS_CURRENT_STATE_INDEX] for d in mini_batch] 146 | agents_expected_reward = [] 147 | # this gives us the agents expected reward for each action we might 148 | agents_reward_per_action = self._session.run(self._output_layer, feed_dict={self._input_layer: current_states}) 149 | for i in range(len(mini_batch)): 150 | if mini_batch[i][self.OBS_TERMINAL_INDEX]: 151 | # this was a terminal frame so there is no future reward... 
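                # Q-learning target: for a terminal transition the target is just the
                # observed reward r; otherwise (below) it is
                # r + FUTURE_REWARD_DISCOUNT * max over actions of Q(next_state, action),
                # using the Q-values in agents_reward_per_action computed above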
152 | agents_expected_reward.append(rewards[i]) 153 | else: 154 | agents_expected_reward.append( 155 | rewards[i] + self.FUTURE_REWARD_DISCOUNT * np.max(agents_reward_per_action[i])) 156 | 157 | # learn that these actions in these states lead to this reward 158 | self._session.run(self._train_operation, feed_dict={ 159 | self._input_layer: previous_states, 160 | self._actions: actions, 161 | self._target: agents_expected_reward}) 162 | 163 | # save checkpoints for later 164 | if self._time % self.SAVE_EVERY_X_STEPS == 0: 165 | self._saver.save(self._session, self._checkpoint_path + '/network', global_step=self._time) 166 | 167 | 168 | if __name__ == '__main__': 169 | player = MLPQLearningHalfPongPlayer() 170 | player.start() 171 | -------------------------------------------------------------------------------- /examples/6_conv_net_half_pong/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "network-1300000" 2 | all_model_checkpoint_paths: "network-1260000" 3 | all_model_checkpoint_paths: "network-1270000" 4 | all_model_checkpoint_paths: "network-1280000" 5 | all_model_checkpoint_paths: "network-1290000" 6 | all_model_checkpoint_paths: "network-1300000" 7 | -------------------------------------------------------------------------------- /examples/6_conv_net_half_pong/network-1260000: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DanielSlater/PyDataLondon2016/274af1fcf41dcfcd6062445f4baa149e30b625b6/examples/6_conv_net_half_pong/network-1260000 -------------------------------------------------------------------------------- /examples/6_conv_net_half_pong/network-1260000.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DanielSlater/PyDataLondon2016/274af1fcf41dcfcd6062445f4baa149e30b625b6/examples/6_conv_net_half_pong/network-1260000.meta -------------------------------------------------------------------------------- /examples/6_conv_net_half_pong/network-1270000: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DanielSlater/PyDataLondon2016/274af1fcf41dcfcd6062445f4baa149e30b625b6/examples/6_conv_net_half_pong/network-1270000 -------------------------------------------------------------------------------- /examples/6_conv_net_half_pong/network-1270000.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DanielSlater/PyDataLondon2016/274af1fcf41dcfcd6062445f4baa149e30b625b6/examples/6_conv_net_half_pong/network-1270000.meta -------------------------------------------------------------------------------- /examples/6_conv_net_half_pong/network-1280000: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DanielSlater/PyDataLondon2016/274af1fcf41dcfcd6062445f4baa149e30b625b6/examples/6_conv_net_half_pong/network-1280000 -------------------------------------------------------------------------------- /examples/6_conv_net_half_pong/network-1280000.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DanielSlater/PyDataLondon2016/274af1fcf41dcfcd6062445f4baa149e30b625b6/examples/6_conv_net_half_pong/network-1280000.meta -------------------------------------------------------------------------------- 
/examples/6_conv_net_half_pong/network-1290000: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DanielSlater/PyDataLondon2016/274af1fcf41dcfcd6062445f4baa149e30b625b6/examples/6_conv_net_half_pong/network-1290000 -------------------------------------------------------------------------------- /examples/6_conv_net_half_pong/network-1290000.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DanielSlater/PyDataLondon2016/274af1fcf41dcfcd6062445f4baa149e30b625b6/examples/6_conv_net_half_pong/network-1290000.meta -------------------------------------------------------------------------------- /examples/6_conv_net_half_pong/network-1300000: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DanielSlater/PyDataLondon2016/274af1fcf41dcfcd6062445f4baa149e30b625b6/examples/6_conv_net_half_pong/network-1300000 -------------------------------------------------------------------------------- /examples/6_conv_net_half_pong/network-1300000.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DanielSlater/PyDataLondon2016/274af1fcf41dcfcd6062445f4baa149e30b625b6/examples/6_conv_net_half_pong/network-1300000.meta -------------------------------------------------------------------------------- /examples/6_conv_net_half_pong_player.py: -------------------------------------------------------------------------------- 1 | # This is heavily based off https://github.com/asrivat1/DeepLearningVideoGames 2 | import os 3 | import random 4 | from collections import deque 5 | 6 | import cv2 7 | import numpy as np 8 | import tensorflow as tf 9 | 10 | from common.half_pong_player import HalfPongPlayer 11 | 12 | 13 | class ConvNetHalfPongPlayer(HalfPongPlayer): 14 | ACTIONS_COUNT = 3 # number of valid actions. In this case up, still and down 15 | FUTURE_REWARD_DISCOUNT = 0.99 # decay rate of past observations 16 | OBSERVATION_STEPS = 50000. # time steps to observe before training 17 | EXPLORE_STEPS = 500000. 
# frames over which to anneal epsilon 18 | INITIAL_RANDOM_ACTION_PROB = 1.0 # starting chance of an action being random 19 | FINAL_RANDOM_ACTION_PROB = 0.05 # final chance of an action being random 20 | MEMORY_SIZE = 500000 # number of observations to remember 21 | MINI_BATCH_SIZE = 200 # size of mini batches 22 | STATE_FRAMES = 4 # number of frames to store in the state 23 | OBS_LAST_STATE_INDEX, OBS_ACTION_INDEX, OBS_REWARD_INDEX, OBS_CURRENT_STATE_INDEX, OBS_TERMINAL_INDEX = range(5) 24 | SAVE_EVERY_X_STEPS = 5000 25 | LEARN_RATE = 1e-6 26 | SCREEN_WIDTH = 40 27 | SCREEN_HEIGHT = 40 28 | 29 | def __init__(self, checkpoint_path="6_conv_net_half_pong", playback_mode=False, verbose_logging=True): 30 | """ 31 | Example of deep q network for pong 32 | 33 | :param checkpoint_path: directory to store checkpoints in 34 | :type checkpoint_path: str 35 | :param playback_mode: if true games runs in real time mode and demos itself running 36 | :type playback_mode: bool 37 | :param verbose_logging: If true then extra log information is printed to std out 38 | :type verbose_logging: bool 39 | """ 40 | self._playback_mode = playback_mode 41 | super(ConvNetHalfPongPlayer, self).__init__(force_game_fps=8, run_real_time=playback_mode) 42 | self.verbose_logging = verbose_logging 43 | self._checkpoint_path = checkpoint_path 44 | self._session = tf.Session() 45 | self._input_layer, self._output_layer = self._create_network() 46 | 47 | self._action = tf.placeholder("float", [None, self.ACTIONS_COUNT]) 48 | self._target = tf.placeholder("float", [None]) 49 | 50 | readout_action = tf.reduce_sum(tf.mul(self._output_layer, self._action), reduction_indices=1) 51 | 52 | cost = tf.reduce_mean(tf.square(self._target - readout_action)) 53 | self._train_operation = tf.train.AdamOptimizer(self.LEARN_RATE).minimize(cost) 54 | 55 | self._observations = deque() 56 | 57 | # set the first action to do nothing 58 | self._last_action = np.zeros(self.ACTIONS_COUNT) 59 | self._last_action[1] = 1 60 | 61 | self._last_state = None 62 | self._probability_of_random_action = self.INITIAL_RANDOM_ACTION_PROB 63 | self._time = 0 64 | 65 | self._session.run(tf.initialize_all_variables()) 66 | 67 | if not os.path.exists(self._checkpoint_path): 68 | os.mkdir(self._checkpoint_path) 69 | self._saver = tf.train.Saver() 70 | checkpoint = tf.train.get_checkpoint_state(self._checkpoint_path) 71 | 72 | if checkpoint and checkpoint.model_checkpoint_path: 73 | self._saver.restore(self._session, checkpoint.model_checkpoint_path) 74 | print("Loaded checkpoints %s" % checkpoint.model_checkpoint_path) 75 | elif playback_mode: 76 | raise Exception("Could not load checkpoints for playback") 77 | 78 | def get_keys_pressed(self, screen_array, reward, terminal): 79 | # images will be black or white 80 | ret, binary_image = cv2.threshold(cv2.cvtColor(screen_array, cv2.COLOR_BGR2GRAY), 1, 255, 81 | cv2.THRESH_BINARY) 82 | 83 | # first frame must be handled differently 84 | if self._last_state is None: 85 | # the _last_state will contain the image data from the last self.STATE_FRAMES frames 86 | self._last_state = np.stack(tuple(binary_image for _ in range(self.STATE_FRAMES)), axis=2) 87 | return self.action_index_to_key(1) 88 | 89 | binary_image = np.reshape(binary_image, 90 | (self.SCREEN_WIDTH, self.SCREEN_HEIGHT, 1)) 91 | current_state = np.append(self._last_state[:, :, 1:], binary_image, axis=2) 92 | 93 | if not self._playback_mode: 94 | # store the transition in previous_observations 95 | self._observations.append((self._last_state, self._last_action, 
reward, current_state, terminal)) 96 | 97 | if len(self._observations) > self.MEMORY_SIZE: 98 | self._observations.popleft() 99 | 100 | # only train if done observing 101 | if len(self._observations) > self.OBSERVATION_STEPS: 102 | self._train() 103 | self._time += 1 104 | 105 | # update the old values 106 | self._last_state = current_state 107 | 108 | action = self._choose_next_action() 109 | 110 | if not self._playback_mode: 111 | # gradually reduce the probability of a random actionself. 112 | if self._probability_of_random_action > self.FINAL_RANDOM_ACTION_PROB \ 113 | and len(self._observations) > self.OBSERVATION_STEPS: 114 | self._probability_of_random_action -= \ 115 | (self.INITIAL_RANDOM_ACTION_PROB - self.FINAL_RANDOM_ACTION_PROB) / self.EXPLORE_STEPS 116 | 117 | print("Time: %s random_action_prob: %s reward %s scores differential %s" % 118 | (self._time, self._probability_of_random_action, reward, 119 | self.score())) 120 | 121 | self._last_action = np.zeros((3,)) 122 | self._last_action[action] = 1. 123 | 124 | return HalfPongPlayer.action_index_to_key(action) 125 | 126 | def _choose_next_action(self): 127 | if (not self._playback_mode) and (random.random() <= self._probability_of_random_action): 128 | # choose an action randomly 129 | action_index = random.randrange(self.ACTIONS_COUNT) 130 | else: 131 | # choose an action given our last state 132 | readout_t = self._session.run(self._output_layer, feed_dict={self._input_layer: [self._last_state]})[0] 133 | if self.verbose_logging: 134 | print("Action Q-Values are %s" % readout_t) 135 | action_index = np.argmax(readout_t) 136 | 137 | return action_index 138 | 139 | def _train(self): 140 | # sample a mini_batch to train on 141 | mini_batch = random.sample(self._observations, self.MINI_BATCH_SIZE) 142 | # get the batch variables 143 | previous_states = [d[self.OBS_LAST_STATE_INDEX] for d in mini_batch] 144 | actions = [d[self.OBS_ACTION_INDEX] for d in mini_batch] 145 | rewards = [d[self.OBS_REWARD_INDEX] for d in mini_batch] 146 | current_states = [d[self.OBS_CURRENT_STATE_INDEX] for d in mini_batch] 147 | agents_expected_reward = [] 148 | # this gives us the agents expected reward for each action we might take 149 | agents_reward_per_action = self._session.run(self._output_layer, feed_dict={self._input_layer: current_states}) 150 | for i in range(len(mini_batch)): 151 | if mini_batch[i][self.OBS_TERMINAL_INDEX]: 152 | # this was a terminal frame so there is no future reward... 
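                # as in the MLP player, the target is the plain reward for terminal
                # frames and the one-step Q-learning target otherwise; for example, a
                # transition with reward 0 and a best next-state Q-value of 0.5 gets
                # a target of 0 + 0.99 * 0.5 = 0.495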
153 | agents_expected_reward.append(rewards[i]) 154 | else: 155 | agents_expected_reward.append( 156 | rewards[i] + self.FUTURE_REWARD_DISCOUNT * np.max(agents_reward_per_action[i])) 157 | 158 | # learn that these actions in these states lead to this reward 159 | self._session.run(self._train_operation, feed_dict={ 160 | self._input_layer: previous_states, 161 | self._action: actions, 162 | self._target: agents_expected_reward}) 163 | 164 | # save checkpoints for later 165 | if self._time % self.SAVE_EVERY_X_STEPS == 0: 166 | self._saver.save(self._session, self._checkpoint_path + '/network', global_step=self._time) 167 | 168 | def _create_network(self): 169 | # network weights 170 | convolution_weights_1 = tf.Variable(tf.truncated_normal([8, 8, self.STATE_FRAMES, 32], stddev=0.01)) 171 | convolution_bias_1 = tf.Variable(tf.constant(0.01, shape=[32])) 172 | 173 | convolution_weights_2 = tf.Variable(tf.truncated_normal([4, 4, 32, 64], stddev=0.01)) 174 | convolution_bias_2 = tf.Variable(tf.constant(0.01, shape=[64])) 175 | 176 | feed_forward_weights_1 = tf.Variable(tf.truncated_normal([256, 256], stddev=0.01)) 177 | feed_forward_bias_1 = tf.Variable(tf.constant(0.01, shape=[256])) 178 | 179 | feed_forward_weights_2 = tf.Variable(tf.truncated_normal([256, self.ACTIONS_COUNT], stddev=0.01)) 180 | feed_forward_bias_2 = tf.Variable(tf.constant(0.01, shape=[self.ACTIONS_COUNT])) 181 | 182 | input_layer = tf.placeholder("float", [None, self.SCREEN_WIDTH, self.SCREEN_HEIGHT, 183 | self.STATE_FRAMES]) 184 | 185 | hidden_convolutional_layer_1 = tf.nn.relu( 186 | tf.nn.conv2d(input_layer, convolution_weights_1, strides=[1, 4, 4, 1], padding="SAME") + convolution_bias_1) 187 | 188 | hidden_max_pooling_layer_1 = tf.nn.max_pool(hidden_convolutional_layer_1, ksize=[1, 2, 2, 1], 189 | strides=[1, 2, 2, 1], padding="SAME") 190 | 191 | hidden_convolutional_layer_2 = tf.nn.relu( 192 | tf.nn.conv2d(hidden_max_pooling_layer_1, convolution_weights_2, strides=[1, 2, 2, 1], 193 | padding="SAME") + convolution_bias_2) 194 | 195 | hidden_max_pooling_layer_2 = tf.nn.max_pool(hidden_convolutional_layer_2, ksize=[1, 2, 2, 1], 196 | strides=[1, 2, 2, 1], padding="SAME") 197 | 198 | hidden_convolutional_layer_3_flat = tf.reshape(hidden_max_pooling_layer_2, [-1, 256]) 199 | 200 | final_hidden_activations = tf.nn.relu( 201 | tf.matmul(hidden_convolutional_layer_3_flat, feed_forward_weights_1) + feed_forward_bias_1) 202 | 203 | output_layer = tf.matmul(final_hidden_activations, feed_forward_weights_2) + feed_forward_bias_2 204 | 205 | return input_layer, output_layer 206 | 207 | 208 | if __name__ == '__main__': 209 | player = ConvNetHalfPongPlayer() 210 | player.start() 211 | -------------------------------------------------------------------------------- /resources/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DanielSlater/PyDataLondon2016/274af1fcf41dcfcd6062445f4baa149e30b625b6/resources/__init__.py --------------------------------------------------------------------------------