├── BrainDQN_NIPS.py ├── BrainDQN_Nature.py ├── FlappyBirdDQN.py ├── README.md ├── assets ├── audio │ ├── die.ogg │ ├── die.wav │ ├── hit.ogg │ ├── hit.wav │ ├── point.ogg │ ├── point.wav │ ├── swoosh.ogg │ ├── swoosh.wav │ ├── wing.ogg │ └── wing.wav └── sprites │ ├── 0.png │ ├── 1.png │ ├── 2.png │ ├── 3.png │ ├── 4.png │ ├── 5.png │ ├── 6.png │ ├── 7.png │ ├── 8.png │ ├── 9.png │ ├── background-black.png │ ├── base.png │ ├── pipe-green.png │ ├── redbird-downflap.png │ ├── redbird-midflap.png │ └── redbird-upflap.png ├── game ├── flappy_bird_utils.py └── wrapped_flappy_bird.py └── saved_networks ├── checkpoint ├── network-dqn-50000 └── network-dqn-50000.meta /BrainDQN_NIPS.py: -------------------------------------------------------------------------------- 1 | # ----------------------------- 2 | # File: Deep Q-Learning Algorithm 3 | # Author: Flood Sung 4 | # Date: 2016.3.21 5 | # ----------------------------- 6 | 7 | import tensorflow as tf 8 | import numpy as np 9 | import random 10 | from collections import deque 11 | 12 | # Hyper Parameters: 13 | FRAME_PER_ACTION = 1 14 | GAMMA = 0.99 # decay rate of past observations 15 | OBSERVE = 100. # timesteps to observe before training 16 | EXPLORE = 150000. # frames over which to anneal epsilon 17 | FINAL_EPSILON = 0.0 # final value of epsilon 18 | INITIAL_EPSILON = 0.9 # starting value of epsilon 19 | REPLAY_MEMORY = 50000 # number of previous transitions to remember 20 | BATCH_SIZE = 32 # size of minibatch 21 | 22 | class BrainDQN: 23 | 24 | def __init__(self,actions): 25 | # init replay memory 26 | self.replayMemory = deque() 27 | # init some parameters 28 | self.timeStep = 0 29 | self.epsilon = INITIAL_EPSILON 30 | self.actions = actions 31 | # init Q network 32 | self.createQNetwork() 33 | 34 | def createQNetwork(self): 35 | # network weights 36 | W_conv1 = self.weight_variable([8,8,4,32]) 37 | b_conv1 = self.bias_variable([32]) 38 | 39 | W_conv2 = self.weight_variable([4,4,32,64]) 40 | b_conv2 = self.bias_variable([64]) 41 | 42 | W_conv3 = self.weight_variable([3,3,64,64]) 43 | b_conv3 = self.bias_variable([64]) 44 | 45 | W_fc1 = self.weight_variable([1600,512]) 46 | b_fc1 = self.bias_variable([512]) 47 | 48 | W_fc2 = self.weight_variable([512,self.actions]) 49 | b_fc2 = self.bias_variable([self.actions]) 50 | 51 | # input layer 52 | 53 | self.stateInput = tf.placeholder("float",[None,80,80,4]) 54 | 55 | # hidden layers 56 | h_conv1 = tf.nn.relu(self.conv2d(self.stateInput,W_conv1,4) + b_conv1) 57 | h_pool1 = self.max_pool_2x2(h_conv1) 58 | 59 | h_conv2 = tf.nn.relu(self.conv2d(h_pool1,W_conv2,2) + b_conv2) 60 | 61 | h_conv3 = tf.nn.relu(self.conv2d(h_conv2,W_conv3,1) + b_conv3) 62 | 63 | h_conv3_flat = tf.reshape(h_conv3,[-1,1600]) 64 | h_fc1 = tf.nn.relu(tf.matmul(h_conv3_flat,W_fc1) + b_fc1) 65 | 66 | # Q Value layer 67 | self.QValue = tf.matmul(h_fc1,W_fc2) + b_fc2 68 | 69 | self.actionInput = tf.placeholder("float",[None,self.actions]) 70 | self.yInput = tf.placeholder("float", [None]) 71 | Q_action = tf.reduce_sum(tf.mul(self.QValue, self.actionInput), reduction_indices = 1) 72 | self.cost = tf.reduce_mean(tf.square(self.yInput - Q_action)) 73 | self.trainStep = tf.train.AdamOptimizer(1e-6).minimize(self.cost) 74 | 75 | # saving and loading networks 76 | self.saver = tf.train.Saver() 77 | self.session = tf.InteractiveSession() 78 | self.session.run(tf.initialize_all_variables()) 79 | checkpoint = tf.train.get_checkpoint_state("saved_networks") 80 | if checkpoint and checkpoint.model_checkpoint_path: 81 | 
self.saver.restore(self.session, checkpoint.model_checkpoint_path) 82 | print ("Successfully loaded:", checkpoint.model_checkpoint_path) 83 | else: 84 | print ("Could not find old network weights") 85 | 86 | def trainQNetwork(self): 87 | # Step 1: obtain random minibatch from replay memory 88 | minibatch = random.sample(self.replayMemory,BATCH_SIZE) 89 | state_batch = [data[0] for data in minibatch] 90 | action_batch = [data[1] for data in minibatch] 91 | reward_batch = [data[2] for data in minibatch] 92 | nextState_batch = [data[3] for data in minibatch] 93 | 94 | # Step 2: calculate y 95 | y_batch = [] 96 | QValue_batch = self.QValue.eval(feed_dict={self.stateInput:nextState_batch}) 97 | for i in range(0,BATCH_SIZE): 98 | terminal = minibatch[i][4] 99 | if terminal: 100 | y_batch.append(reward_batch[i]) 101 | else: 102 | y_batch.append(reward_batch[i] + GAMMA * np.max(QValue_batch[i])) 103 | 104 | self.trainStep.run(feed_dict={ 105 | self.yInput : y_batch, 106 | self.actionInput : action_batch, 107 | self.stateInput : state_batch 108 | }) 109 | 110 | # save network every 100000 iteration 111 | if self.timeStep % 10000 == 0: 112 | self.saver.save(self.session, 'saved_networks/' + 'network' + '-dqn', global_step = self.timeStep) 113 | 114 | 115 | def setPerception(self,nextObservation,action,reward,terminal): 116 | #newState = np.append(nextObservation,self.currentState[:,:,1:],axis = 2) 117 | newState = np.append(self.currentState[:,:,1:],nextObservation,axis = 2) 118 | self.replayMemory.append((self.currentState,action,reward,newState,terminal)) 119 | if len(self.replayMemory) > REPLAY_MEMORY: 120 | self.replayMemory.popleft() 121 | if self.timeStep > OBSERVE: 122 | # Train the network 123 | self.trainQNetwork() 124 | 125 | self.currentState = newState 126 | self.timeStep += 1 127 | 128 | def getAction(self): 129 | QValue = self.QValue.eval(feed_dict= {self.stateInput:[self.currentState]})[0] 130 | action = np.zeros(self.actions) 131 | action_index = 0 132 | if self.timeStep % FRAME_PER_ACTION == 0: 133 | if random.random() <= self.epsilon: 134 | action_index = random.randrange(self.actions) 135 | action[action_index] = 1 136 | else: 137 | action_index = np.argmax(QValue) 138 | action[action_index] = 1 139 | else: 140 | action[0] = 1 # do nothing 141 | 142 | # change episilon 143 | if self.epsilon > FINAL_EPSILON and self.timeStep > OBSERVE: 144 | self.epsilon -= (INITIAL_EPSILON - FINAL_EPSILON)/EXPLORE 145 | 146 | return action 147 | 148 | def setInitState(self,observation): 149 | self.currentState = np.stack((observation, observation, observation, observation), axis = 2) 150 | 151 | def weight_variable(self,shape): 152 | initial = tf.truncated_normal(shape, stddev = 0.01) 153 | return tf.Variable(initial) 154 | 155 | def bias_variable(self,shape): 156 | initial = tf.constant(0.01, shape = shape) 157 | return tf.Variable(initial) 158 | 159 | def conv2d(self,x, W, stride): 160 | return tf.nn.conv2d(x, W, strides = [1, stride, stride, 1], padding = "SAME") 161 | 162 | def max_pool_2x2(self,x): 163 | return tf.nn.max_pool(x, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1], padding = "SAME") 164 | 165 | -------------------------------------------------------------------------------- /BrainDQN_Nature.py: -------------------------------------------------------------------------------- 1 | # ----------------------------- 2 | # File: Deep Q-Learning Algorithm 3 | # Author: Flood Sung 4 | # Date: 2016.3.21 5 | # ----------------------------- 6 | 7 | import tensorflow as tf 8 | import numpy as np 9 
| import random 10 | from collections import deque 11 | 12 | # Hyper Parameters: 13 | FRAME_PER_ACTION = 1 14 | GAMMA = 0.99 # decay rate of past observations 15 | OBSERVE = 100. # timesteps to observe before training 16 | EXPLORE = 200000. # frames over which to anneal epsilon 17 | FINAL_EPSILON = 0#0.001 # final value of epsilon 18 | INITIAL_EPSILON = 0#0.01 # starting value of epsilon 19 | REPLAY_MEMORY = 50000 # number of previous transitions to remember 20 | BATCH_SIZE = 32 # size of minibatch 21 | UPDATE_TIME = 100 22 | 23 | try: 24 | tf.mul 25 | except: 26 | # For new version of tensorflow 27 | # tf.mul has been removed in new version of tensorflow 28 | # Using tf.multiply to replace tf.mul 29 | tf.mul = tf.multiply 30 | 31 | class BrainDQN: 32 | 33 | def __init__(self,actions): 34 | # init replay memory 35 | self.replayMemory = deque() 36 | # init some parameters 37 | self.timeStep = 0 38 | self.epsilon = INITIAL_EPSILON 39 | self.actions = actions 40 | # init Q network 41 | self.stateInput,self.QValue,self.W_conv1,self.b_conv1,self.W_conv2,self.b_conv2,self.W_conv3,self.b_conv3,self.W_fc1,self.b_fc1,self.W_fc2,self.b_fc2 = self.createQNetwork() 42 | 43 | # init Target Q Network 44 | self.stateInputT,self.QValueT,self.W_conv1T,self.b_conv1T,self.W_conv2T,self.b_conv2T,self.W_conv3T,self.b_conv3T,self.W_fc1T,self.b_fc1T,self.W_fc2T,self.b_fc2T = self.createQNetwork() 45 | 46 | self.copyTargetQNetworkOperation = [self.W_conv1T.assign(self.W_conv1),self.b_conv1T.assign(self.b_conv1),self.W_conv2T.assign(self.W_conv2),self.b_conv2T.assign(self.b_conv2),self.W_conv3T.assign(self.W_conv3),self.b_conv3T.assign(self.b_conv3),self.W_fc1T.assign(self.W_fc1),self.b_fc1T.assign(self.b_fc1),self.W_fc2T.assign(self.W_fc2),self.b_fc2T.assign(self.b_fc2)] 47 | 48 | self.createTrainingMethod() 49 | 50 | # saving and loading networks 51 | self.saver = tf.train.Saver() 52 | self.session = tf.InteractiveSession() 53 | self.session.run(tf.initialize_all_variables()) 54 | checkpoint = tf.train.get_checkpoint_state("saved_networks") 55 | if checkpoint and checkpoint.model_checkpoint_path: 56 | self.saver.restore(self.session, checkpoint.model_checkpoint_path) 57 | print ("Successfully loaded:", checkpoint.model_checkpoint_path) 58 | else: 59 | print ("Could not find old network weights") 60 | 61 | 62 | def createQNetwork(self): 63 | # network weights 64 | W_conv1 = self.weight_variable([8,8,4,32]) 65 | b_conv1 = self.bias_variable([32]) 66 | 67 | W_conv2 = self.weight_variable([4,4,32,64]) 68 | b_conv2 = self.bias_variable([64]) 69 | 70 | W_conv3 = self.weight_variable([3,3,64,64]) 71 | b_conv3 = self.bias_variable([64]) 72 | 73 | W_fc1 = self.weight_variable([1600,512]) 74 | b_fc1 = self.bias_variable([512]) 75 | 76 | W_fc2 = self.weight_variable([512,self.actions]) 77 | b_fc2 = self.bias_variable([self.actions]) 78 | 79 | # input layer 80 | 81 | stateInput = tf.placeholder("float",[None,80,80,4]) 82 | 83 | # hidden layers 84 | h_conv1 = tf.nn.relu(self.conv2d(stateInput,W_conv1,4) + b_conv1) 85 | h_pool1 = self.max_pool_2x2(h_conv1) 86 | 87 | h_conv2 = tf.nn.relu(self.conv2d(h_pool1,W_conv2,2) + b_conv2) 88 | 89 | h_conv3 = tf.nn.relu(self.conv2d(h_conv2,W_conv3,1) + b_conv3) 90 | 91 | h_conv3_flat = tf.reshape(h_conv3,[-1,1600]) 92 | h_fc1 = tf.nn.relu(tf.matmul(h_conv3_flat,W_fc1) + b_fc1) 93 | 94 | # Q Value layer 95 | QValue = tf.matmul(h_fc1,W_fc2) + b_fc2 96 | 97 | return stateInput,QValue,W_conv1,b_conv1,W_conv2,b_conv2,W_conv3,b_conv3,W_fc1,b_fc1,W_fc2,b_fc2 98 | 99 | def 
copyTargetQNetwork(self): 100 | self.session.run(self.copyTargetQNetworkOperation) 101 | 102 | def createTrainingMethod(self): 103 | self.actionInput = tf.placeholder("float",[None,self.actions]) 104 | self.yInput = tf.placeholder("float", [None]) 105 | Q_Action = tf.reduce_sum(tf.mul(self.QValue, self.actionInput), reduction_indices = 1) 106 | self.cost = tf.reduce_mean(tf.square(self.yInput - Q_Action)) 107 | self.trainStep = tf.train.AdamOptimizer(1e-6).minimize(self.cost) 108 | 109 | 110 | def trainQNetwork(self): 111 | 112 | 113 | # Step 1: obtain random minibatch from replay memory 114 | minibatch = random.sample(self.replayMemory,BATCH_SIZE) 115 | state_batch = [data[0] for data in minibatch] 116 | action_batch = [data[1] for data in minibatch] 117 | reward_batch = [data[2] for data in minibatch] 118 | nextState_batch = [data[3] for data in minibatch] 119 | 120 | # Step 2: calculate y 121 | y_batch = [] 122 | QValue_batch = self.QValueT.eval(feed_dict={self.stateInputT:nextState_batch}) 123 | for i in range(0,BATCH_SIZE): 124 | terminal = minibatch[i][4] 125 | if terminal: 126 | y_batch.append(reward_batch[i]) 127 | else: 128 | y_batch.append(reward_batch[i] + GAMMA * np.max(QValue_batch[i])) 129 | 130 | self.trainStep.run(feed_dict={ 131 | self.yInput : y_batch, 132 | self.actionInput : action_batch, 133 | self.stateInput : state_batch 134 | }) 135 | 136 | # save network every 100000 iteration 137 | if self.timeStep % 10000 == 0: 138 | self.saver.save(self.session, 'saved_networks/' + 'network' + '-dqn', global_step = self.timeStep) 139 | 140 | if self.timeStep % UPDATE_TIME == 0: 141 | self.copyTargetQNetwork() 142 | 143 | 144 | def setPerception(self,nextObservation,action,reward,terminal): 145 | #newState = np.append(nextObservation,self.currentState[:,:,1:],axis = 2) 146 | newState = np.append(self.currentState[:,:,1:],nextObservation,axis = 2) 147 | self.replayMemory.append((self.currentState,action,reward,newState,terminal)) 148 | if len(self.replayMemory) > REPLAY_MEMORY: 149 | self.replayMemory.popleft() 150 | if self.timeStep > OBSERVE: 151 | # Train the network 152 | self.trainQNetwork() 153 | 154 | # print info 155 | state = "" 156 | if self.timeStep <= OBSERVE: 157 | state = "observe" 158 | elif self.timeStep > OBSERVE and self.timeStep <= OBSERVE + EXPLORE: 159 | state = "explore" 160 | else: 161 | state = "train" 162 | 163 | print ("TIMESTEP", self.timeStep, "/ STATE", state, \ 164 | "/ EPSILON", self.epsilon) 165 | 166 | self.currentState = newState 167 | self.timeStep += 1 168 | 169 | def getAction(self): 170 | QValue = self.QValue.eval(feed_dict= {self.stateInput:[self.currentState]})[0] 171 | action = np.zeros(self.actions) 172 | action_index = 0 173 | if self.timeStep % FRAME_PER_ACTION == 0: 174 | if random.random() <= self.epsilon: 175 | action_index = random.randrange(self.actions) 176 | action[action_index] = 1 177 | else: 178 | action_index = np.argmax(QValue) 179 | action[action_index] = 1 180 | else: 181 | action[0] = 1 # do nothing 182 | 183 | # change episilon 184 | if self.epsilon > FINAL_EPSILON and self.timeStep > OBSERVE: 185 | self.epsilon -= (INITIAL_EPSILON - FINAL_EPSILON)/EXPLORE 186 | 187 | return action 188 | 189 | def setInitState(self,observation): 190 | self.currentState = np.stack((observation, observation, observation, observation), axis = 2) 191 | 192 | def weight_variable(self,shape): 193 | initial = tf.truncated_normal(shape, stddev = 0.01) 194 | return tf.Variable(initial) 195 | 196 | def bias_variable(self,shape): 197 | initial = 
tf.constant(0.01, shape = shape)
198 |         return tf.Variable(initial)
199 | 
200 |     def conv2d(self,x, W, stride):
201 |         return tf.nn.conv2d(x, W, strides = [1, stride, stride, 1], padding = "SAME")
202 | 
203 |     def max_pool_2x2(self,x):
204 |         return tf.nn.max_pool(x, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1], padding = "SAME")
205 | 
206 | 
--------------------------------------------------------------------------------
/FlappyBirdDQN.py:
--------------------------------------------------------------------------------
1 | # -------------------------
2 | # Project: Deep Q-Learning on Flappy Bird
3 | # Author: Flood Sung
4 | # Date: 2016.3.21
5 | # -------------------------
6 | 
7 | import cv2
8 | import sys
9 | sys.path.append("game/")
10 | import wrapped_flappy_bird as game
11 | from BrainDQN_Nature import BrainDQN
12 | import numpy as np
13 | 
14 | # preprocess raw image to 80*80 gray image
15 | def preprocess(observation):
16 |     observation = cv2.cvtColor(cv2.resize(observation, (80, 80)), cv2.COLOR_BGR2GRAY)
17 |     ret, observation = cv2.threshold(observation,1,255,cv2.THRESH_BINARY)
18 |     return np.reshape(observation,(80,80,1))
19 | 
20 | def playFlappyBird():
21 |     # Step 1: init BrainDQN
22 |     actions = 2
23 |     brain = BrainDQN(actions)
24 |     # Step 2: init Flappy Bird Game
25 |     flappyBird = game.GameState()
26 |     # Step 3: play game
27 |     # Step 3.1: obtain init state
28 |     action0 = np.array([1,0])  # do nothing
29 |     observation0, reward0, terminal = flappyBird.frame_step(action0)
30 |     observation0 = cv2.cvtColor(cv2.resize(observation0, (80, 80)), cv2.COLOR_BGR2GRAY)
31 |     ret, observation0 = cv2.threshold(observation0,1,255,cv2.THRESH_BINARY)
32 |     brain.setInitState(observation0)
33 | 
34 |     # Step 3.2: run the game
35 |     while True:
36 |         action = brain.getAction()
37 |         nextObservation,reward,terminal = flappyBird.frame_step(action)
38 |         nextObservation = preprocess(nextObservation)
39 |         brain.setPerception(nextObservation,action,reward,terminal)
40 | 
41 | def main():
42 |     playFlappyBird()
43 | 
44 | if __name__ == '__main__':
45 |     main()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | ## Playing Flappy Bird Using Deep Reinforcement Learning (Based on Deep Q-Learning, DQN)
3 | 
4 | ## Includes the NIPS 2013 version and the Nature version of DQN
5 | 
6 | 
7 | I rewrote the code from another repo and made it much simpler, so it is easier to understand DeepMind's Deep Q Network algorithm.
8 | 
9 | The DQN code itself is only about 160 lines long.
10 | 
11 | To run the code, just type `python FlappyBirdDQN.py`.
12 | 
13 | Since the DQN code is a self-contained class, you can use it to play other games.
14 | 
15 | 
16 | ## About the code
17 | 
18 | As in any reinforcement learning problem, we need to take in observations and output actions, and the 'brain' does the processing in between.
19 | 
20 | Therefore, you can easily understand the BrainDQN code (BrainDQN_Nature.py / BrainDQN_NIPS.py). There are three interfaces, illustrated by the sketch below the list:
21 | 
22 | 1. setInitState(observation) for initialization
23 | 2. getAction()
24 | 3. setPerception(nextObservation,action,reward,terminal)
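As a quick illustration of how these three calls fit together, here is a minimal sketch of a training loop for a generic game. The `MyGame` wrapper and its `frame_step(action)` method returning `(frame, reward, terminal)` are hypothetical placeholders for whatever environment you plug in; only the `BrainDQN` class and the 80x80 preprocessing come from this repo.

```python
# Minimal sketch (not part of the repo): driving BrainDQN from a generic game.
# "MyGame" and its frame_step(action) -> (frame, reward, terminal) API are
# hypothetical stand-ins for your own environment wrapper.
import cv2
import numpy as np
from BrainDQN_Nature import BrainDQN

def to_gray_80x80(frame):
    # same idea as preprocess() in FlappyBirdDQN.py: resize, grayscale, binarize
    gray = cv2.cvtColor(cv2.resize(frame, (80, 80)), cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
    return np.reshape(binary, (80, 80, 1))

brain = BrainDQN(actions=2)   # one one-hot slot per possible action
env = MyGame()                # hypothetical environment wrapper

# 1. setInitState: seed the 4-frame state stack with the first observation
frame, _, _ = env.frame_step(np.array([1, 0]))   # [1, 0] = "do nothing"
brain.setInitState(to_gray_80x80(frame)[:, :, 0])

while True:
    action = brain.getAction()                        # 2. epsilon-greedy one-hot action
    frame, reward, terminal = env.frame_step(action)
    # 3. setPerception: store the transition in replay memory and train
    brain.setPerception(to_gray_80x80(frame), action, reward, terminal)
```

FlappyBirdDQN.py wires this same loop to the bundled `wrapped_flappy_bird.GameState`, whose `frame_step` already returns exactly this `(image_data, reward, terminal)` triple.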
25 | 
26 | The game interface just needs to take the chosen action as input and return the next observation, the reward, and a terminal flag.
27 | 
28 | 
29 | ## Disclaimer
30 | This work is based on the repo: [yenchenlin1994/DeepLearningFlappyBird](https://github.com/yenchenlin1994/DeepLearningFlappyBird.git)
31 | 
32 | 
--------------------------------------------------------------------------------
/assets/audio/die.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/floodsung/DRL-FlappyBird/f4fb85de8cb74db8ea21aa12a822d6ab52f3f60b/assets/audio/die.ogg
--------------------------------------------------------------------------------
/assets/audio/die.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/floodsung/DRL-FlappyBird/f4fb85de8cb74db8ea21aa12a822d6ab52f3f60b/assets/audio/die.wav
--------------------------------------------------------------------------------
/assets/audio/hit.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/floodsung/DRL-FlappyBird/f4fb85de8cb74db8ea21aa12a822d6ab52f3f60b/assets/audio/hit.ogg
--------------------------------------------------------------------------------
/assets/audio/hit.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/floodsung/DRL-FlappyBird/f4fb85de8cb74db8ea21aa12a822d6ab52f3f60b/assets/audio/hit.wav
--------------------------------------------------------------------------------
/assets/audio/point.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/floodsung/DRL-FlappyBird/f4fb85de8cb74db8ea21aa12a822d6ab52f3f60b/assets/audio/point.ogg
--------------------------------------------------------------------------------
/assets/audio/point.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/floodsung/DRL-FlappyBird/f4fb85de8cb74db8ea21aa12a822d6ab52f3f60b/assets/audio/point.wav
--------------------------------------------------------------------------------
/assets/audio/swoosh.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/floodsung/DRL-FlappyBird/f4fb85de8cb74db8ea21aa12a822d6ab52f3f60b/assets/audio/swoosh.ogg
--------------------------------------------------------------------------------
/assets/audio/swoosh.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/floodsung/DRL-FlappyBird/f4fb85de8cb74db8ea21aa12a822d6ab52f3f60b/assets/audio/swoosh.wav
--------------------------------------------------------------------------------
/assets/audio/wing.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/floodsung/DRL-FlappyBird/f4fb85de8cb74db8ea21aa12a822d6ab52f3f60b/assets/audio/wing.ogg
--------------------------------------------------------------------------------
/assets/audio/wing.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/floodsung/DRL-FlappyBird/f4fb85de8cb74db8ea21aa12a822d6ab52f3f60b/assets/audio/wing.wav
--------------------------------------------------------------------------------
/assets/sprites/0.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/floodsung/DRL-FlappyBird/f4fb85de8cb74db8ea21aa12a822d6ab52f3f60b/assets/sprites/0.png -------------------------------------------------------------------------------- /assets/sprites/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floodsung/DRL-FlappyBird/f4fb85de8cb74db8ea21aa12a822d6ab52f3f60b/assets/sprites/1.png -------------------------------------------------------------------------------- /assets/sprites/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floodsung/DRL-FlappyBird/f4fb85de8cb74db8ea21aa12a822d6ab52f3f60b/assets/sprites/2.png -------------------------------------------------------------------------------- /assets/sprites/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floodsung/DRL-FlappyBird/f4fb85de8cb74db8ea21aa12a822d6ab52f3f60b/assets/sprites/3.png -------------------------------------------------------------------------------- /assets/sprites/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floodsung/DRL-FlappyBird/f4fb85de8cb74db8ea21aa12a822d6ab52f3f60b/assets/sprites/4.png -------------------------------------------------------------------------------- /assets/sprites/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floodsung/DRL-FlappyBird/f4fb85de8cb74db8ea21aa12a822d6ab52f3f60b/assets/sprites/5.png -------------------------------------------------------------------------------- /assets/sprites/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floodsung/DRL-FlappyBird/f4fb85de8cb74db8ea21aa12a822d6ab52f3f60b/assets/sprites/6.png -------------------------------------------------------------------------------- /assets/sprites/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floodsung/DRL-FlappyBird/f4fb85de8cb74db8ea21aa12a822d6ab52f3f60b/assets/sprites/7.png -------------------------------------------------------------------------------- /assets/sprites/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floodsung/DRL-FlappyBird/f4fb85de8cb74db8ea21aa12a822d6ab52f3f60b/assets/sprites/8.png -------------------------------------------------------------------------------- /assets/sprites/9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floodsung/DRL-FlappyBird/f4fb85de8cb74db8ea21aa12a822d6ab52f3f60b/assets/sprites/9.png -------------------------------------------------------------------------------- /assets/sprites/background-black.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floodsung/DRL-FlappyBird/f4fb85de8cb74db8ea21aa12a822d6ab52f3f60b/assets/sprites/background-black.png -------------------------------------------------------------------------------- /assets/sprites/base.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/floodsung/DRL-FlappyBird/f4fb85de8cb74db8ea21aa12a822d6ab52f3f60b/assets/sprites/base.png -------------------------------------------------------------------------------- /assets/sprites/pipe-green.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floodsung/DRL-FlappyBird/f4fb85de8cb74db8ea21aa12a822d6ab52f3f60b/assets/sprites/pipe-green.png -------------------------------------------------------------------------------- /assets/sprites/redbird-downflap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floodsung/DRL-FlappyBird/f4fb85de8cb74db8ea21aa12a822d6ab52f3f60b/assets/sprites/redbird-downflap.png -------------------------------------------------------------------------------- /assets/sprites/redbird-midflap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floodsung/DRL-FlappyBird/f4fb85de8cb74db8ea21aa12a822d6ab52f3f60b/assets/sprites/redbird-midflap.png -------------------------------------------------------------------------------- /assets/sprites/redbird-upflap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floodsung/DRL-FlappyBird/f4fb85de8cb74db8ea21aa12a822d6ab52f3f60b/assets/sprites/redbird-upflap.png -------------------------------------------------------------------------------- /game/flappy_bird_utils.py: -------------------------------------------------------------------------------- 1 | import pygame 2 | import sys 3 | def load(): 4 | # path of player with different states 5 | PLAYER_PATH = ( 6 | 'assets/sprites/redbird-upflap.png', 7 | 'assets/sprites/redbird-midflap.png', 8 | 'assets/sprites/redbird-downflap.png' 9 | ) 10 | 11 | # path of background 12 | BACKGROUND_PATH = 'assets/sprites/background-black.png' 13 | 14 | # path of pipe 15 | PIPE_PATH = 'assets/sprites/pipe-green.png' 16 | 17 | IMAGES, SOUNDS, HITMASKS = {}, {}, {} 18 | 19 | # numbers sprites for score display 20 | IMAGES['numbers'] = ( 21 | pygame.image.load('assets/sprites/0.png').convert_alpha(), 22 | pygame.image.load('assets/sprites/1.png').convert_alpha(), 23 | pygame.image.load('assets/sprites/2.png').convert_alpha(), 24 | pygame.image.load('assets/sprites/3.png').convert_alpha(), 25 | pygame.image.load('assets/sprites/4.png').convert_alpha(), 26 | pygame.image.load('assets/sprites/5.png').convert_alpha(), 27 | pygame.image.load('assets/sprites/6.png').convert_alpha(), 28 | pygame.image.load('assets/sprites/7.png').convert_alpha(), 29 | pygame.image.load('assets/sprites/8.png').convert_alpha(), 30 | pygame.image.load('assets/sprites/9.png').convert_alpha() 31 | ) 32 | 33 | # base (ground) sprite 34 | IMAGES['base'] = pygame.image.load('assets/sprites/base.png').convert_alpha() 35 | 36 | # sounds 37 | if 'win' in sys.platform: 38 | soundExt = '.wav' 39 | else: 40 | soundExt = '.ogg' 41 | 42 | SOUNDS['die'] = pygame.mixer.Sound('assets/audio/die' + soundExt) 43 | SOUNDS['hit'] = pygame.mixer.Sound('assets/audio/hit' + soundExt) 44 | SOUNDS['point'] = pygame.mixer.Sound('assets/audio/point' + soundExt) 45 | SOUNDS['swoosh'] = pygame.mixer.Sound('assets/audio/swoosh' + soundExt) 46 | SOUNDS['wing'] = pygame.mixer.Sound('assets/audio/wing' + soundExt) 47 | 48 | # select random background sprites 49 | IMAGES['background'] = pygame.image.load(BACKGROUND_PATH).convert() 50 | 51 | # select random 
player sprites 52 | IMAGES['player'] = ( 53 | pygame.image.load(PLAYER_PATH[0]).convert_alpha(), 54 | pygame.image.load(PLAYER_PATH[1]).convert_alpha(), 55 | pygame.image.load(PLAYER_PATH[2]).convert_alpha(), 56 | ) 57 | 58 | # select random pipe sprites 59 | IMAGES['pipe'] = ( 60 | pygame.transform.rotate( 61 | pygame.image.load(PIPE_PATH).convert_alpha(), 180), 62 | pygame.image.load(PIPE_PATH).convert_alpha(), 63 | ) 64 | 65 | # hismask for pipes 66 | HITMASKS['pipe'] = ( 67 | getHitmask(IMAGES['pipe'][0]), 68 | getHitmask(IMAGES['pipe'][1]), 69 | ) 70 | 71 | # hitmask for player 72 | HITMASKS['player'] = ( 73 | getHitmask(IMAGES['player'][0]), 74 | getHitmask(IMAGES['player'][1]), 75 | getHitmask(IMAGES['player'][2]), 76 | ) 77 | 78 | return IMAGES, SOUNDS, HITMASKS 79 | 80 | def getHitmask(image): 81 | """returns a hitmask using an image's alpha.""" 82 | mask = [] 83 | for x in range(image.get_width()): 84 | mask.append([]) 85 | for y in range(image.get_height()): 86 | mask[x].append(bool(image.get_at((x,y))[3])) 87 | return mask 88 | -------------------------------------------------------------------------------- /game/wrapped_flappy_bird.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | import random 4 | import pygame 5 | import flappy_bird_utils 6 | import pygame.surfarray as surfarray 7 | from pygame.locals import * 8 | from itertools import cycle 9 | 10 | FPS = 30 11 | SCREENWIDTH = 288 12 | SCREENHEIGHT = 512 13 | 14 | pygame.init() 15 | FPSCLOCK = pygame.time.Clock() 16 | SCREEN = pygame.display.set_mode((SCREENWIDTH, SCREENHEIGHT)) 17 | pygame.display.set_caption('Flappy Bird') 18 | 19 | IMAGES, SOUNDS, HITMASKS = flappy_bird_utils.load() 20 | PIPEGAPSIZE = 100 # gap between upper and lower part of pipe 21 | BASEY = SCREENHEIGHT * 0.79 22 | 23 | PLAYER_WIDTH = IMAGES['player'][0].get_width() 24 | PLAYER_HEIGHT = IMAGES['player'][0].get_height() 25 | PIPE_WIDTH = IMAGES['pipe'][0].get_width() 26 | PIPE_HEIGHT = IMAGES['pipe'][0].get_height() 27 | BACKGROUND_WIDTH = IMAGES['background'].get_width() 28 | 29 | PLAYER_INDEX_GEN = cycle([0, 1, 2, 1]) 30 | 31 | 32 | class GameState: 33 | def __init__(self): 34 | self.score = self.playerIndex = self.loopIter = 0 35 | self.playerx = int(SCREENWIDTH * 0.2) 36 | self.playery = int((SCREENHEIGHT - PLAYER_HEIGHT) / 2) 37 | self.basex = 0 38 | self.baseShift = IMAGES['base'].get_width() - BACKGROUND_WIDTH 39 | 40 | newPipe1 = getRandomPipe() 41 | newPipe2 = getRandomPipe() 42 | self.upperPipes = [ 43 | {'x': SCREENWIDTH, 'y': newPipe1[0]['y']}, 44 | {'x': SCREENWIDTH + (SCREENWIDTH / 2), 'y': newPipe2[0]['y']}, 45 | ] 46 | self.lowerPipes = [ 47 | {'x': SCREENWIDTH, 'y': newPipe1[1]['y']}, 48 | {'x': SCREENWIDTH + (SCREENWIDTH / 2), 'y': newPipe2[1]['y']}, 49 | ] 50 | 51 | # player velocity, max velocity, downward accleration, accleration on flap 52 | self.pipeVelX = -4 53 | self.playerVelY = 0 # player's velocity along Y, default same as playerFlapped 54 | self.playerMaxVelY = 10 # max vel along Y, max descend speed 55 | self.playerMinVelY = -8 # min vel along Y, max ascend speed 56 | self.playerAccY = 1 # players downward accleration 57 | self.playerFlapAcc = -7 # players speed on flapping 58 | self.playerFlapped = False # True when player flaps 59 | 60 | def frame_step(self, input_actions): 61 | pygame.event.pump() 62 | 63 | reward = 0.1 64 | terminal = False 65 | 66 | if sum(input_actions) != 1: 67 | raise ValueError('Multiple input actions!') 68 | 69 | # 
input_actions[0] == 1: do nothing 70 | # input_actions[1] == 1: flap the bird 71 | if input_actions[1] == 1: 72 | if self.playery > -2 * PLAYER_HEIGHT: 73 | self.playerVelY = self.playerFlapAcc 74 | self.playerFlapped = True 75 | #SOUNDS['wing'].play() 76 | 77 | # check for score 78 | playerMidPos = self.playerx + PLAYER_WIDTH / 2 79 | for pipe in self.upperPipes: 80 | pipeMidPos = pipe['x'] + PIPE_WIDTH / 2 81 | if pipeMidPos <= playerMidPos < pipeMidPos + 4: 82 | self.score += 1 83 | #SOUNDS['point'].play() 84 | reward = 1 85 | 86 | # playerIndex basex change 87 | if (self.loopIter + 1) % 3 == 0: 88 | self.playerIndex = next(PLAYER_INDEX_GEN) 89 | self.loopIter = (self.loopIter + 1) % 30 90 | self.basex = -((-self.basex + 100) % self.baseShift) 91 | 92 | # player's movement 93 | if self.playerVelY < self.playerMaxVelY and not self.playerFlapped: 94 | self.playerVelY += self.playerAccY 95 | if self.playerFlapped: 96 | self.playerFlapped = False 97 | self.playery += min(self.playerVelY, BASEY - self.playery - PLAYER_HEIGHT) 98 | if self.playery < 0: 99 | self.playery = 0 100 | 101 | # move pipes to left 102 | for uPipe, lPipe in zip(self.upperPipes, self.lowerPipes): 103 | uPipe['x'] += self.pipeVelX 104 | lPipe['x'] += self.pipeVelX 105 | 106 | # add new pipe when first pipe is about to touch left of screen 107 | if 0 < self.upperPipes[0]['x'] < 5: 108 | newPipe = getRandomPipe() 109 | self.upperPipes.append(newPipe[0]) 110 | self.lowerPipes.append(newPipe[1]) 111 | 112 | # remove first pipe if its out of the screen 113 | if self.upperPipes[0]['x'] < -PIPE_WIDTH: 114 | self.upperPipes.pop(0) 115 | self.lowerPipes.pop(0) 116 | 117 | # check if crash here 118 | isCrash= checkCrash({'x': self.playerx, 'y': self.playery, 119 | 'index': self.playerIndex}, 120 | self.upperPipes, self.lowerPipes) 121 | if isCrash: 122 | #SOUNDS['hit'].play() 123 | #SOUNDS['die'].play() 124 | terminal = True 125 | self.__init__() 126 | reward = -1 127 | 128 | # draw sprites 129 | SCREEN.blit(IMAGES['background'], (0,0)) 130 | 131 | for uPipe, lPipe in zip(self.upperPipes, self.lowerPipes): 132 | SCREEN.blit(IMAGES['pipe'][0], (uPipe['x'], uPipe['y'])) 133 | SCREEN.blit(IMAGES['pipe'][1], (lPipe['x'], lPipe['y'])) 134 | 135 | SCREEN.blit(IMAGES['base'], (self.basex, BASEY)) 136 | # print score so player overlaps the score 137 | # showScore(self.score) 138 | SCREEN.blit(IMAGES['player'][self.playerIndex], 139 | (self.playerx, self.playery)) 140 | 141 | image_data = pygame.surfarray.array3d(pygame.display.get_surface()) 142 | pygame.display.update() 143 | FPSCLOCK.tick(FPS) 144 | #print self.upperPipes[0]['y'] + PIPE_HEIGHT - int(BASEY * 0.2) 145 | return image_data, reward, terminal 146 | 147 | def getRandomPipe(): 148 | """returns a randomly generated pipe""" 149 | # y of gap between upper and lower pipe 150 | gapYs = [20, 30, 40, 50, 60, 70, 80, 90] 151 | index = random.randint(0, len(gapYs)-1) 152 | gapY = gapYs[index] 153 | 154 | gapY += int(BASEY * 0.2) 155 | pipeX = SCREENWIDTH + 10 156 | 157 | return [ 158 | {'x': pipeX, 'y': gapY - PIPE_HEIGHT}, # upper pipe 159 | {'x': pipeX, 'y': gapY + PIPEGAPSIZE}, # lower pipe 160 | ] 161 | 162 | 163 | def showScore(score): 164 | """displays score in center of screen""" 165 | scoreDigits = [int(x) for x in list(str(score))] 166 | totalWidth = 0 # total width of all numbers to be printed 167 | 168 | for digit in scoreDigits: 169 | totalWidth += IMAGES['numbers'][digit].get_width() 170 | 171 | Xoffset = (SCREENWIDTH - totalWidth) / 2 172 | 173 | for digit in scoreDigits: 
174 | SCREEN.blit(IMAGES['numbers'][digit], (Xoffset, SCREENHEIGHT * 0.1)) 175 | Xoffset += IMAGES['numbers'][digit].get_width() 176 | 177 | 178 | def checkCrash(player, upperPipes, lowerPipes): 179 | """returns True if player collders with base or pipes.""" 180 | pi = player['index'] 181 | player['w'] = IMAGES['player'][0].get_width() 182 | player['h'] = IMAGES['player'][0].get_height() 183 | 184 | # if player crashes into ground 185 | if player['y'] + player['h'] >= BASEY - 1: 186 | return True 187 | else: 188 | 189 | playerRect = pygame.Rect(player['x'], player['y'], 190 | player['w'], player['h']) 191 | 192 | for uPipe, lPipe in zip(upperPipes, lowerPipes): 193 | # upper and lower pipe rects 194 | uPipeRect = pygame.Rect(uPipe['x'], uPipe['y'], PIPE_WIDTH, PIPE_HEIGHT) 195 | lPipeRect = pygame.Rect(lPipe['x'], lPipe['y'], PIPE_WIDTH, PIPE_HEIGHT) 196 | 197 | # player and upper/lower pipe hitmasks 198 | pHitMask = HITMASKS['player'][pi] 199 | uHitmask = HITMASKS['pipe'][0] 200 | lHitmask = HITMASKS['pipe'][1] 201 | 202 | # if bird collided with upipe or lpipe 203 | uCollide = pixelCollision(playerRect, uPipeRect, pHitMask, uHitmask) 204 | lCollide = pixelCollision(playerRect, lPipeRect, pHitMask, lHitmask) 205 | 206 | if uCollide or lCollide: 207 | return True 208 | 209 | return False 210 | 211 | def pixelCollision(rect1, rect2, hitmask1, hitmask2): 212 | """Checks if two objects collide and not just their rects""" 213 | rect = rect1.clip(rect2) 214 | 215 | if rect.width == 0 or rect.height == 0: 216 | return False 217 | 218 | x1, y1 = rect.x - rect1.x, rect.y - rect1.y 219 | x2, y2 = rect.x - rect2.x, rect.y - rect2.y 220 | 221 | for x in range(rect.width): 222 | for y in range(rect.height): 223 | if hitmask1[x1+x][y1+y] and hitmask2[x2+x][y2+y]: 224 | return True 225 | return False 226 | -------------------------------------------------------------------------------- /saved_networks/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "network-dqn-50000" 2 | all_model_checkpoint_paths: "network-dqn-10000" 3 | all_model_checkpoint_paths: "network-dqn-20000" 4 | all_model_checkpoint_paths: "network-dqn-30000" 5 | all_model_checkpoint_paths: "network-dqn-40000" 6 | all_model_checkpoint_paths: "network-dqn-50000" 7 | -------------------------------------------------------------------------------- /saved_networks/network-dqn-50000: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floodsung/DRL-FlappyBird/f4fb85de8cb74db8ea21aa12a822d6ab52f3f60b/saved_networks/network-dqn-50000 -------------------------------------------------------------------------------- /saved_networks/network-dqn-50000.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floodsung/DRL-FlappyBird/f4fb85de8cb74db8ea21aa12a822d6ab52f3f60b/saved_networks/network-dqn-50000.meta --------------------------------------------------------------------------------