├── .gitignore
├── README.md
├── assets
│   ├── audio
│   │   ├── die.ogg
│   │   ├── die.wav
│   │   ├── hit.ogg
│   │   ├── hit.wav
│   │   ├── point.ogg
│   │   ├── point.wav
│   │   ├── swoosh.ogg
│   │   ├── swoosh.wav
│   │   ├── wing.ogg
│   │   └── wing.wav
│   └── sprites
│       ├── 0.png
│       ├── 1.png
│       ├── 2.png
│       ├── 3.png
│       ├── 4.png
│       ├── 5.png
│       ├── 6.png
│       ├── 7.png
│       ├── 8.png
│       ├── 9.png
│       ├── background-black.png
│       ├── base.png
│       ├── pipe-green.png
│       ├── redbird-downflap.png
│       ├── redbird-midflap.png
│       └── redbird-upflap.png
├── deep_q_network.py
├── game
│   ├── flappy_bird_utils.py
│   └── wrapped_flappy_bird.py
├── images
│   ├── flappy_bird_demp.gif
│   ├── network.png
│   └── preprocess.png
├── logs_bird
│   ├── hidden.txt
│   └── readout.txt
└── saved_networks
    ├── bird-dqn-2880000
    ├── bird-dqn-2880000.meta
    ├── bird-dqn-2890000
    ├── bird-dqn-2890000.meta
    ├── bird-dqn-2900000
    ├── bird-dqn-2900000.meta
    ├── bird-dqn-2910000
    ├── bird-dqn-2910000.meta
    ├── bird-dqn-2920000
    ├── bird-dqn-2920000.meta
    ├── checkpoint
    └── pretrained_model
        └── bird-dqn-policy
/.gitignore:
--------------------------------------------------------------------------------
1 | # ignore all pyc files.
2 | *.pyc
3 |
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Using Deep Q-Network to Learn How To Play Flappy Bird
2 |
3 | <img src="./images/flappy_bird_demp.gif">
4 |
5 | 7 mins version: [DQN for flappy bird](https://www.youtube.com/watch?v=THhUXIhjkCM)
6 |
7 | ## Overview
8 | This project follows the Deep Q-Learning algorithm described in Playing Atari with Deep Reinforcement Learning [2] and shows that this learning algorithm can be generalized to the notorious Flappy Bird.
9 |
10 | ## Installation Dependencies:
11 | * Python 2.7 or 3
12 | * TensorFlow 0.7
13 | * pygame
14 | * OpenCV-Python
15 |
16 | ## How to Run?
17 | ```
18 | git clone https://github.com/yenchenlin1994/DeepLearningFlappyBird.git
19 | cd DeepLearningFlappyBird
20 | python deep_q_network.py
21 | ```
22 |
23 | ## What is Deep Q-Network?
24 | It is a convolutional neural network, trained with a variant of Q-learning, whose input is raw pixels and whose output is a value function estimating future rewards.
25 |
26 | For those who are interested in deep reinforcement learning, I highly recommend reading the following post:
27 |
28 | [Demystifying Deep Reinforcement Learning](http://www.nervanasys.com/demystifying-deep-reinforcement-learning/)
29 |
30 | ## Deep Q-Network Algorithm
31 |
32 | The pseudo-code for the Deep Q Learning algorithm, as given in [1], can be found below:
33 |
34 | ```
35 | Initialize replay memory D to size N
36 | Initialize action-value function Q with random weights
37 | for episode = 1, M do
38 | Initialize state s_1
39 | for t = 1, T do
40 | With probability ϵ select random action a_t
41 | otherwise select a_t=max_a Q(s_t,a; θ_i)
42 | Execute action a_t in emulator and observe r_t and s_(t+1)
43 | Store transition (s_t,a_t,r_t,s_(t+1)) in D
44 | Sample a minibatch of transitions (s_j,a_j,r_j,s_(j+1)) from D
45 | Set y_j:=
46 | r_j for terminal s_(j+1)
47 | r_j+γ*max_(a^' ) Q(s_(j+1),a'; θ_i) for non-terminal s_(j+1)
48 | Perform a gradient step on (y_j-Q(s_j,a_j; θ_i))^2 with respect to θ
49 | end for
50 | end for
51 | ```
52 |
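As a concrete illustration of the target y_j above, here is a minimal, framework-agnostic numpy sketch (the helper name `td_targets` and its arguments are illustrative, not part of this repo's code; `q_next` is assumed to hold Q(s_(j+1), ·; θ_i) for a sampled minibatch):

```python
import numpy as np

GAMMA = 0.99  # discount factor, same value as in deep_q_network.py

def td_targets(rewards, q_next, terminal):
    """Return y_j for each transition in a minibatch:
    r_j if s_(j+1) is terminal, else r_j + GAMMA * max_a' Q(s_(j+1), a')."""
    rewards = np.asarray(rewards, dtype=np.float32)
    terminal = np.asarray(terminal, dtype=bool)
    bootstrap = GAMMA * np.asarray(q_next).max(axis=1)
    return np.where(terminal, rewards, rewards + bootstrap)
```
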
53 | ## Experiments
54 |
55 | #### Environment
56 | Since the deep Q-network is trained on the raw pixel values observed from the game screen at each time step, [3] finds that removing the background that appears in the original game makes it converge faster. This process can be visualized in the following figure:
57 |
58 | <img src="./images/preprocess.png">
59 |
60 | #### Network Architecture
61 | Following [1], I first preprocess the game screens with the following steps (a short sketch follows the list):
62 |
63 | 1. Convert image to grayscale
64 | 2. Resize image to 80x80
65 | 3. Stack last 4 frames to produce an 80x80x4 input array for network
66 |
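A minimal sketch of these steps with OpenCV and numpy, mirroring the preprocessing done in `deep_q_network.py` (which additionally binarizes the grayscale frame; the helper name `preprocess` is mine, not from the repo):

```python
import cv2
import numpy as np

def preprocess(frame, prev_stack=None):
    """Grayscale, resize to 80x80, binarize, and keep a stack of the last 4 frames."""
    x = cv2.cvtColor(cv2.resize(frame, (80, 80)), cv2.COLOR_BGR2GRAY)
    _, x = cv2.threshold(x, 1, 255, cv2.THRESH_BINARY)
    if prev_stack is None:                       # first frame: repeat it 4 times
        return np.stack((x, x, x, x), axis=2)
    x = np.reshape(x, (80, 80, 1))
    return np.append(x, prev_stack[:, :, :3], axis=2)  # newest frame first
```
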
67 | The architecture of the network is shown in the figure below. The first layer convolves the input image with an 8x8x4x32 kernel at a stride size of 4. The output is then put through a 2x2 max pooling layer. The second layer convolves with a 4x4x32x64 kernel at a stride of 2. We then max pool again. The third layer convolves with a 3x3x64x64 kernel at a stride of 1. We then max pool one more time. The last hidden layer consists of 256 fully connected ReLU nodes.
68 |
69 | <img src="./images/network.png">
70 |
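As a quick sanity check on the flattened size used in `deep_q_network.py` (note that, unlike the description above, the code applies max pooling only after the first convolution and uses a 512-unit fully connected layer), the spatial dimensions work out as follows:

```python
import math

def same_out(n, stride):
    """Output size of a SAME-padded convolution or pooling layer."""
    return math.ceil(n / stride)

n = 80                 # input frame is 80x80x4
n = same_out(n, 4)     # conv1, stride 4  -> 20x20x32
n = same_out(n, 2)     # 2x2 max pool     -> 10x10x32
n = same_out(n, 2)     # conv2, stride 2  -> 5x5x64
n = same_out(n, 1)     # conv3, stride 1  -> 5x5x64
print(n * n * 64)      # 1600, the size fed into the fully connected layer
```
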
71 | The final output layer has the same dimensionality as the number of valid actions in the game, where the 0th index always corresponds to doing nothing. The values at this output layer represent the Q function of each valid action given the input state. At each time step, the network selects an action with an ϵ-greedy policy: with probability ϵ it takes a random action, and otherwise it performs the action with the highest Q value.
72 |
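A minimal sketch of this ϵ-greedy selection, following what the training loop in `deep_q_network.py` does (the helper `select_action` is illustrative):

```python
import random
import numpy as np

def select_action(q_values, epsilon, n_actions=2):
    """With probability epsilon pick a random action, otherwise the greedy one."""
    if random.random() <= epsilon:
        return random.randrange(n_actions)   # explore
    return int(np.argmax(q_values))          # exploit: action with the highest Q value
```
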
73 |
74 | #### Training
75 | At first, I initialize all weight matrices randomly using a normal distribution with a standard deviation of 0.01, then set up the replay memory with a maximum size of 50,000 experiences.
76 |
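A sketch of such a capped replay memory (the training code uses a plain `deque` and calls `popleft()` manually once the limit is exceeded; using `maxlen` here is an equivalent shortcut):

```python
from collections import deque

REPLAY_MEMORY = 50000                  # maximum number of stored transitions

replay = deque(maxlen=REPLAY_MEMORY)   # oldest transitions are dropped automatically

def remember(s_t, a_t, r_t, s_t1, terminal):
    """Store one transition (state, action, reward, next state, terminal flag)."""
    replay.append((s_t, a_t, r_t, s_t1, terminal))
```
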
77 | I start training by choosing actions uniformly at random for the first 10,000 time steps, without updating the network weights. This allows the system to populate the replay memory before training begins.
78 |
79 | Note that unlike [1], which initializes ϵ = 1, I linearly anneal ϵ from 0.1 to 0.0001 over the course of the next 3,000,000 frames. The reason is that in our game the agent can choose an action every 0.03 s (FPS = 30), so a high ϵ makes it **flap** too much, keeping it at the top of the screen until it clumsily bumps into a pipe. This makes the Q function converge relatively slowly, since it only starts to explore other situations once ϵ is low.
80 | However, in other games, initializing ϵ to 1 is more reasonable.
81 |
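The annealing itself is a single linear step per frame once the observation phase ends, as in the training loop (constants shown with the values recommended in the FAQ for training from scratch):

```python
INITIAL_EPSILON = 0.1      # starting value of epsilon
FINAL_EPSILON = 0.0001     # final value of epsilon
EXPLORE = 3000000.         # frames over which epsilon is annealed

def anneal(epsilon):
    """Decrease epsilon linearly until it reaches FINAL_EPSILON."""
    if epsilon > FINAL_EPSILON:
        epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE
    return max(epsilon, FINAL_EPSILON)
```
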
82 | During training, at each time step the network samples a minibatch of size 32 from the replay memory and performs a gradient step on the loss function described above, using the Adam optimizer with a learning rate of 0.000001. After annealing finishes, the network continues to train indefinitely, with ϵ fixed at 0.0001.
83 |
84 | ## FAQ
85 |
86 | #### Checkpoint not found
87 | Change [first line of `saved_networks/checkpoint`](https://github.com/yenchenlin1994/DeepLearningFlappyBird/blob/master/saved_networks/checkpoint#L1) to
88 |
89 | `model_checkpoint_path: "saved_networks/bird-dqn-2920000"`
90 |
91 | #### How to reproduce?
92 | 1. Comment out [these lines](https://github.com/yenchenlin1994/DeepLearningFlappyBird/blob/master/deep_q_network.py#L108-L112)
93 |
94 | 2. Modify `deep_q_network.py`'s parameters as follows:
95 | ```python
96 | OBSERVE = 10000
97 | EXPLORE = 3000000
98 | FINAL_EPSILON = 0.0001
99 | INITIAL_EPSILON = 0.1
100 | ```
101 |
102 | ## References
103 |
104 | [1] Volodymyr Mnih, Koray Kavukcuoglu, David Silver, Andrei A. Rusu, Joel Veness, Marc G. Bellemare, Alex Graves, Martin Riedmiller, Andreas K. Fidjeland, Georg Ostrovski, Stig Petersen, Charles Beattie, Amir Sadik, Ioannis Antonoglou, Helen King, Dharshan Kumaran, Daan Wierstra, Shane Legg, and Demis Hassabis. **Human-level Control through Deep Reinforcement Learning**. Nature, 518(7540):529–533, 2015.
105 |
106 | [2] Volodymyr Mnih, Koray Kavukcuoglu, David Silver, Alex Graves, Ioannis Antonoglou, Daan Wierstra, and Martin Riedmiller. **Playing Atari with Deep Reinforcement Learning**. NIPS Deep Learning Workshop, 2013.
107 |
108 | [3] Kevin Chen. **Deep Reinforcement Learning for Flappy Bird** [Report](http://cs229.stanford.edu/proj2015/362_report.pdf) | [Youtube result](https://youtu.be/9WKBzTUsPKc)
109 |
110 | ## Disclaimer
111 | This work is based heavily on the following repos:
112 |
113 | 1. [sourabhv/FlapPyBird](https://github.com/sourabhv/FlapPyBird)
114 | 2. [asrivat1/DeepLearningVideoGames](https://github.com/asrivat1/DeepLearningVideoGames)
115 |
116 |
--------------------------------------------------------------------------------
/assets/audio/die.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/assets/audio/die.ogg
--------------------------------------------------------------------------------
/assets/audio/die.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/assets/audio/die.wav
--------------------------------------------------------------------------------
/assets/audio/hit.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/assets/audio/hit.ogg
--------------------------------------------------------------------------------
/assets/audio/hit.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/assets/audio/hit.wav
--------------------------------------------------------------------------------
/assets/audio/point.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/assets/audio/point.ogg
--------------------------------------------------------------------------------
/assets/audio/point.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/assets/audio/point.wav
--------------------------------------------------------------------------------
/assets/audio/swoosh.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/assets/audio/swoosh.ogg
--------------------------------------------------------------------------------
/assets/audio/swoosh.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/assets/audio/swoosh.wav
--------------------------------------------------------------------------------
/assets/audio/wing.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/assets/audio/wing.ogg
--------------------------------------------------------------------------------
/assets/audio/wing.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/assets/audio/wing.wav
--------------------------------------------------------------------------------
/assets/sprites/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/assets/sprites/0.png
--------------------------------------------------------------------------------
/assets/sprites/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/assets/sprites/1.png
--------------------------------------------------------------------------------
/assets/sprites/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/assets/sprites/2.png
--------------------------------------------------------------------------------
/assets/sprites/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/assets/sprites/3.png
--------------------------------------------------------------------------------
/assets/sprites/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/assets/sprites/4.png
--------------------------------------------------------------------------------
/assets/sprites/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/assets/sprites/5.png
--------------------------------------------------------------------------------
/assets/sprites/6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/assets/sprites/6.png
--------------------------------------------------------------------------------
/assets/sprites/7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/assets/sprites/7.png
--------------------------------------------------------------------------------
/assets/sprites/8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/assets/sprites/8.png
--------------------------------------------------------------------------------
/assets/sprites/9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/assets/sprites/9.png
--------------------------------------------------------------------------------
/assets/sprites/background-black.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/assets/sprites/background-black.png
--------------------------------------------------------------------------------
/assets/sprites/base.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/assets/sprites/base.png
--------------------------------------------------------------------------------
/assets/sprites/pipe-green.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/assets/sprites/pipe-green.png
--------------------------------------------------------------------------------
/assets/sprites/redbird-downflap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/assets/sprites/redbird-downflap.png
--------------------------------------------------------------------------------
/assets/sprites/redbird-midflap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/assets/sprites/redbird-midflap.png
--------------------------------------------------------------------------------
/assets/sprites/redbird-upflap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/assets/sprites/redbird-upflap.png
--------------------------------------------------------------------------------
/deep_q_network.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | from __future__ import print_function
3 |
4 | import tensorflow as tf
5 | import cv2
6 | import sys
7 | sys.path.append("game/")
8 | import wrapped_flappy_bird as game
9 | import random
10 | import numpy as np
11 | from collections import deque
12 |
13 | GAME = 'bird' # the name of the game being played for log files
14 | ACTIONS = 2 # number of valid actions
15 | GAMMA = 0.99 # discount factor for future rewards
16 | OBSERVE = 100000. # timesteps to observe before training
17 | EXPLORE = 2000000. # frames over which to anneal epsilon
18 | FINAL_EPSILON = 0.0001 # final value of epsilon
19 | INITIAL_EPSILON = 0.0001 # starting value of epsilon
20 | REPLAY_MEMORY = 50000 # number of previous transitions to remember
21 | BATCH = 32 # size of minibatch
22 | FRAME_PER_ACTION = 1
23 |
24 | def weight_variable(shape):
25 | initial = tf.truncated_normal(shape, stddev = 0.01)
26 | return tf.Variable(initial)
27 |
28 | def bias_variable(shape):
29 | initial = tf.constant(0.01, shape = shape)
30 | return tf.Variable(initial)
31 |
32 | def conv2d(x, W, stride):
33 | return tf.nn.conv2d(x, W, strides = [1, stride, stride, 1], padding = "SAME")
34 |
35 | def max_pool_2x2(x):
36 | return tf.nn.max_pool(x, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1], padding = "SAME")
37 |
38 | def createNetwork():
39 | # network weights
40 | W_conv1 = weight_variable([8, 8, 4, 32])
41 | b_conv1 = bias_variable([32])
42 |
43 | W_conv2 = weight_variable([4, 4, 32, 64])
44 | b_conv2 = bias_variable([64])
45 |
46 | W_conv3 = weight_variable([3, 3, 64, 64])
47 | b_conv3 = bias_variable([64])
48 |
49 | W_fc1 = weight_variable([1600, 512])
50 | b_fc1 = bias_variable([512])
51 |
52 | W_fc2 = weight_variable([512, ACTIONS])
53 | b_fc2 = bias_variable([ACTIONS])
54 |
55 | # input layer
56 | s = tf.placeholder("float", [None, 80, 80, 4])
57 |
58 | # hidden layers
59 | h_conv1 = tf.nn.relu(conv2d(s, W_conv1, 4) + b_conv1)
60 | h_pool1 = max_pool_2x2(h_conv1)
61 |
62 | h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2, 2) + b_conv2)
63 | #h_pool2 = max_pool_2x2(h_conv2)
64 |
65 | h_conv3 = tf.nn.relu(conv2d(h_conv2, W_conv3, 1) + b_conv3)
66 | #h_pool3 = max_pool_2x2(h_conv3)
67 |
68 | #h_pool3_flat = tf.reshape(h_pool3, [-1, 256])
69 | h_conv3_flat = tf.reshape(h_conv3, [-1, 1600])
70 |
71 | h_fc1 = tf.nn.relu(tf.matmul(h_conv3_flat, W_fc1) + b_fc1)
72 |
73 | # readout layer
74 | readout = tf.matmul(h_fc1, W_fc2) + b_fc2
75 |
76 | return s, readout, h_fc1
77 |
78 | def trainNetwork(s, readout, h_fc1, sess):
79 | # define the cost function
80 | a = tf.placeholder("float", [None, ACTIONS])
81 | y = tf.placeholder("float", [None])
82 | readout_action = tf.reduce_sum(tf.mul(readout, a), reduction_indices=1)
83 | cost = tf.reduce_mean(tf.square(y - readout_action))
84 | train_step = tf.train.AdamOptimizer(1e-6).minimize(cost)
85 |
86 | # open up a game state to communicate with emulator
87 | game_state = game.GameState()
88 |
89 | # store the previous observations in replay memory
90 | D = deque()
91 |
92 | # printing
93 | a_file = open("logs_" + GAME + "/readout.txt", 'w')
94 | h_file = open("logs_" + GAME + "/hidden.txt", 'w')
95 |
96 | # get the first state by doing nothing and preprocess the image to 80x80x4
97 | do_nothing = np.zeros(ACTIONS)
98 | do_nothing[0] = 1
99 | x_t, r_0, terminal = game_state.frame_step(do_nothing)
100 | x_t = cv2.cvtColor(cv2.resize(x_t, (80, 80)), cv2.COLOR_BGR2GRAY)
101 | ret, x_t = cv2.threshold(x_t,1,255,cv2.THRESH_BINARY)
102 | s_t = np.stack((x_t, x_t, x_t, x_t), axis=2)
103 |
104 | # saving and loading networks
105 | saver = tf.train.Saver()
106 | sess.run(tf.initialize_all_variables())
107 | checkpoint = tf.train.get_checkpoint_state("saved_networks")
108 | if checkpoint and checkpoint.model_checkpoint_path:
109 | saver.restore(sess, checkpoint.model_checkpoint_path)
110 | print("Successfully loaded:", checkpoint.model_checkpoint_path)
111 | else:
112 | print("Could not find old network weights")
113 |
114 | # start training
115 | epsilon = INITIAL_EPSILON
116 | t = 0
117 | while "flappy bird" != "angry bird":
118 | # choose an action epsilon greedily
119 | readout_t = readout.eval(feed_dict={s : [s_t]})[0]
120 | a_t = np.zeros([ACTIONS])
121 | action_index = 0
122 | if t % FRAME_PER_ACTION == 0:
123 | if random.random() <= epsilon:
124 | print("----------Random Action----------")
125 | action_index = random.randrange(ACTIONS)
126 | a_t[random.randrange(ACTIONS)] = 1
127 | else:
128 | action_index = np.argmax(readout_t)
129 | a_t[action_index] = 1
130 | else:
131 | a_t[0] = 1 # do nothing
132 |
133 | # scale down epsilon
134 | if epsilon > FINAL_EPSILON and t > OBSERVE:
135 | epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE
136 |
137 | # run the selected action and observe next state and reward
138 | x_t1_colored, r_t, terminal = game_state.frame_step(a_t)
139 | x_t1 = cv2.cvtColor(cv2.resize(x_t1_colored, (80, 80)), cv2.COLOR_BGR2GRAY)
140 | ret, x_t1 = cv2.threshold(x_t1, 1, 255, cv2.THRESH_BINARY)
141 | x_t1 = np.reshape(x_t1, (80, 80, 1))
142 | #s_t1 = np.append(x_t1, s_t[:,:,1:], axis = 2)
143 | s_t1 = np.append(x_t1, s_t[:, :, :3], axis=2)
144 |
145 | # store the transition in D
146 | D.append((s_t, a_t, r_t, s_t1, terminal))
147 | if len(D) > REPLAY_MEMORY:
148 | D.popleft()
149 |
150 | # only train if done observing
151 | if t > OBSERVE:
152 | # sample a minibatch to train on
153 | minibatch = random.sample(D, BATCH)
154 |
155 | # get the batch variables
156 | s_j_batch = [d[0] for d in minibatch]
157 | a_batch = [d[1] for d in minibatch]
158 | r_batch = [d[2] for d in minibatch]
159 | s_j1_batch = [d[3] for d in minibatch]
160 |
161 | y_batch = []
162 | readout_j1_batch = readout.eval(feed_dict = {s : s_j1_batch})
163 | for i in range(0, len(minibatch)):
164 | terminal = minibatch[i][4]
165 | # if terminal, only equals reward
166 | if terminal:
167 | y_batch.append(r_batch[i])
168 | else:
169 | y_batch.append(r_batch[i] + GAMMA * np.max(readout_j1_batch[i]))
170 |
171 | # perform gradient step
172 | train_step.run(feed_dict = {
173 | y : y_batch,
174 | a : a_batch,
175 | s : s_j_batch}
176 | )
177 |
178 | # update the old values
179 | s_t = s_t1
180 | t += 1
181 |
182 | # save progress every 10000 iterations
183 | if t % 10000 == 0:
184 | saver.save(sess, 'saved_networks/' + GAME + '-dqn', global_step = t)
185 |
186 | # print info
187 | state = ""
188 | if t <= OBSERVE:
189 | state = "observe"
190 | elif t > OBSERVE and t <= OBSERVE + EXPLORE:
191 | state = "explore"
192 | else:
193 | state = "train"
194 |
195 | print("TIMESTEP", t, "/ STATE", state, \
196 | "/ EPSILON", epsilon, "/ ACTION", action_index, "/ REWARD", r_t, \
197 | "/ Q_MAX %e" % np.max(readout_t))
198 | # write info to files
199 | '''
200 | if t % 10000 <= 100:
201 | a_file.write(",".join([str(x) for x in readout_t]) + '\n')
202 | h_file.write(",".join([str(x) for x in h_fc1.eval(feed_dict={s:[s_t]})[0]]) + '\n')
203 | cv2.imwrite("logs_tetris/frame" + str(t) + ".png", x_t1)
204 | '''
205 |
206 | def playGame():
207 | sess = tf.InteractiveSession()
208 | s, readout, h_fc1 = createNetwork()
209 | trainNetwork(s, readout, h_fc1, sess)
210 |
211 | def main():
212 | playGame()
213 |
214 | if __name__ == "__main__":
215 | main()
216 |
--------------------------------------------------------------------------------
/game/flappy_bird_utils.py:
--------------------------------------------------------------------------------
1 | import pygame
2 | import sys
3 | def load():
4 | # path of player with different states
5 | PLAYER_PATH = (
6 | 'assets/sprites/redbird-upflap.png',
7 | 'assets/sprites/redbird-midflap.png',
8 | 'assets/sprites/redbird-downflap.png'
9 | )
10 |
11 | # path of background
12 | BACKGROUND_PATH = 'assets/sprites/background-black.png'
13 |
14 | # path of pipe
15 | PIPE_PATH = 'assets/sprites/pipe-green.png'
16 |
17 | IMAGES, SOUNDS, HITMASKS = {}, {}, {}
18 |
19 | # numbers sprites for score display
20 | IMAGES['numbers'] = (
21 | pygame.image.load('assets/sprites/0.png').convert_alpha(),
22 | pygame.image.load('assets/sprites/1.png').convert_alpha(),
23 | pygame.image.load('assets/sprites/2.png').convert_alpha(),
24 | pygame.image.load('assets/sprites/3.png').convert_alpha(),
25 | pygame.image.load('assets/sprites/4.png').convert_alpha(),
26 | pygame.image.load('assets/sprites/5.png').convert_alpha(),
27 | pygame.image.load('assets/sprites/6.png').convert_alpha(),
28 | pygame.image.load('assets/sprites/7.png').convert_alpha(),
29 | pygame.image.load('assets/sprites/8.png').convert_alpha(),
30 | pygame.image.load('assets/sprites/9.png').convert_alpha()
31 | )
32 |
33 | # base (ground) sprite
34 | IMAGES['base'] = pygame.image.load('assets/sprites/base.png').convert_alpha()
35 |
36 | # sounds
37 |     if sys.platform.startswith('win'):  # 'win' in sys.platform would also match 'darwin' (macOS)
38 | soundExt = '.wav'
39 | else:
40 | soundExt = '.ogg'
41 |
42 | SOUNDS['die'] = pygame.mixer.Sound('assets/audio/die' + soundExt)
43 | SOUNDS['hit'] = pygame.mixer.Sound('assets/audio/hit' + soundExt)
44 | SOUNDS['point'] = pygame.mixer.Sound('assets/audio/point' + soundExt)
45 | SOUNDS['swoosh'] = pygame.mixer.Sound('assets/audio/swoosh' + soundExt)
46 | SOUNDS['wing'] = pygame.mixer.Sound('assets/audio/wing' + soundExt)
47 |
48 |     # background sprite (fixed black background)
49 | IMAGES['background'] = pygame.image.load(BACKGROUND_PATH).convert()
50 |
51 |     # player sprites (red bird)
52 | IMAGES['player'] = (
53 | pygame.image.load(PLAYER_PATH[0]).convert_alpha(),
54 | pygame.image.load(PLAYER_PATH[1]).convert_alpha(),
55 | pygame.image.load(PLAYER_PATH[2]).convert_alpha(),
56 | )
57 |
58 |     # pipe sprites (the upper pipe is the lower sprite rotated 180 degrees)
59 | IMAGES['pipe'] = (
60 | pygame.transform.rotate(
61 | pygame.image.load(PIPE_PATH).convert_alpha(), 180),
62 | pygame.image.load(PIPE_PATH).convert_alpha(),
63 | )
64 |
65 |     # hitmask for pipes
66 | HITMASKS['pipe'] = (
67 | getHitmask(IMAGES['pipe'][0]),
68 | getHitmask(IMAGES['pipe'][1]),
69 | )
70 |
71 | # hitmask for player
72 | HITMASKS['player'] = (
73 | getHitmask(IMAGES['player'][0]),
74 | getHitmask(IMAGES['player'][1]),
75 | getHitmask(IMAGES['player'][2]),
76 | )
77 |
78 | return IMAGES, SOUNDS, HITMASKS
79 |
80 | def getHitmask(image):
81 | """returns a hitmask using an image's alpha."""
82 | mask = []
83 | for x in range(image.get_width()):
84 | mask.append([])
85 | for y in range(image.get_height()):
86 | mask[x].append(bool(image.get_at((x,y))[3]))
87 | return mask
88 |
--------------------------------------------------------------------------------
/game/wrapped_flappy_bird.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import sys
3 | import random
4 | import pygame
5 | import flappy_bird_utils
6 | import pygame.surfarray as surfarray
7 | from pygame.locals import *
8 | from itertools import cycle
9 |
10 | FPS = 30
11 | SCREENWIDTH = 288
12 | SCREENHEIGHT = 512
13 |
14 | pygame.init()
15 | FPSCLOCK = pygame.time.Clock()
16 | SCREEN = pygame.display.set_mode((SCREENWIDTH, SCREENHEIGHT))
17 | pygame.display.set_caption('Flappy Bird')
18 |
19 | IMAGES, SOUNDS, HITMASKS = flappy_bird_utils.load()
20 | PIPEGAPSIZE = 100 # gap between upper and lower part of pipe
21 | BASEY = SCREENHEIGHT * 0.79
22 |
23 | PLAYER_WIDTH = IMAGES['player'][0].get_width()
24 | PLAYER_HEIGHT = IMAGES['player'][0].get_height()
25 | PIPE_WIDTH = IMAGES['pipe'][0].get_width()
26 | PIPE_HEIGHT = IMAGES['pipe'][0].get_height()
27 | BACKGROUND_WIDTH = IMAGES['background'].get_width()
28 |
29 | PLAYER_INDEX_GEN = cycle([0, 1, 2, 1])
30 |
31 |
32 | class GameState:
33 | def __init__(self):
34 | self.score = self.playerIndex = self.loopIter = 0
35 | self.playerx = int(SCREENWIDTH * 0.2)
36 | self.playery = int((SCREENHEIGHT - PLAYER_HEIGHT) / 2)
37 | self.basex = 0
38 | self.baseShift = IMAGES['base'].get_width() - BACKGROUND_WIDTH
39 |
40 | newPipe1 = getRandomPipe()
41 | newPipe2 = getRandomPipe()
42 | self.upperPipes = [
43 | {'x': SCREENWIDTH, 'y': newPipe1[0]['y']},
44 | {'x': SCREENWIDTH + (SCREENWIDTH / 2), 'y': newPipe2[0]['y']},
45 | ]
46 | self.lowerPipes = [
47 | {'x': SCREENWIDTH, 'y': newPipe1[1]['y']},
48 | {'x': SCREENWIDTH + (SCREENWIDTH / 2), 'y': newPipe2[1]['y']},
49 | ]
50 |
51 |         # player velocity, max velocity, downward acceleration, acceleration on flap
52 | self.pipeVelX = -4
53 | self.playerVelY = 0 # player's velocity along Y, default same as playerFlapped
54 | self.playerMaxVelY = 10 # max vel along Y, max descend speed
55 | self.playerMinVelY = -8 # min vel along Y, max ascend speed
56 |         self.playerAccY    =   1   # player's downward acceleration
57 |         self.playerFlapAcc =  -9   # player's speed on flapping
58 | self.playerFlapped = False # True when player flaps
59 |
60 | def frame_step(self, input_actions):
61 | pygame.event.pump()
62 |
63 | reward = 0.1
64 | terminal = False
65 |
66 | if sum(input_actions) != 1:
67 | raise ValueError('Multiple input actions!')
68 |
69 | # input_actions[0] == 1: do nothing
70 | # input_actions[1] == 1: flap the bird
71 | if input_actions[1] == 1:
72 | if self.playery > -2 * PLAYER_HEIGHT:
73 | self.playerVelY = self.playerFlapAcc
74 | self.playerFlapped = True
75 | #SOUNDS['wing'].play()
76 |
77 | # check for score
78 | playerMidPos = self.playerx + PLAYER_WIDTH / 2
79 | for pipe in self.upperPipes:
80 | pipeMidPos = pipe['x'] + PIPE_WIDTH / 2
81 | if pipeMidPos <= playerMidPos < pipeMidPos + 4:
82 | self.score += 1
83 | #SOUNDS['point'].play()
84 | reward = 1
85 |
86 | # playerIndex basex change
87 | if (self.loopIter + 1) % 3 == 0:
88 | self.playerIndex = next(PLAYER_INDEX_GEN)
89 | self.loopIter = (self.loopIter + 1) % 30
90 | self.basex = -((-self.basex + 100) % self.baseShift)
91 |
92 | # player's movement
93 | if self.playerVelY < self.playerMaxVelY and not self.playerFlapped:
94 | self.playerVelY += self.playerAccY
95 | if self.playerFlapped:
96 | self.playerFlapped = False
97 | self.playery += min(self.playerVelY, BASEY - self.playery - PLAYER_HEIGHT)
98 | if self.playery < 0:
99 | self.playery = 0
100 |
101 | # move pipes to left
102 | for uPipe, lPipe in zip(self.upperPipes, self.lowerPipes):
103 | uPipe['x'] += self.pipeVelX
104 | lPipe['x'] += self.pipeVelX
105 |
106 | # add new pipe when first pipe is about to touch left of screen
107 | if 0 < self.upperPipes[0]['x'] < 5:
108 | newPipe = getRandomPipe()
109 | self.upperPipes.append(newPipe[0])
110 | self.lowerPipes.append(newPipe[1])
111 |
112 |         # remove the first pipe when it's out of the screen
113 | if self.upperPipes[0]['x'] < -PIPE_WIDTH:
114 | self.upperPipes.pop(0)
115 | self.lowerPipes.pop(0)
116 |
117 | # check if crash here
118 |         isCrash = checkCrash({'x': self.playerx, 'y': self.playery,
119 | 'index': self.playerIndex},
120 | self.upperPipes, self.lowerPipes)
121 | if isCrash:
122 | #SOUNDS['hit'].play()
123 | #SOUNDS['die'].play()
124 | terminal = True
125 | self.__init__()
126 | reward = -1
127 |
128 | # draw sprites
129 | SCREEN.blit(IMAGES['background'], (0,0))
130 |
131 | for uPipe, lPipe in zip(self.upperPipes, self.lowerPipes):
132 | SCREEN.blit(IMAGES['pipe'][0], (uPipe['x'], uPipe['y']))
133 | SCREEN.blit(IMAGES['pipe'][1], (lPipe['x'], lPipe['y']))
134 |
135 | SCREEN.blit(IMAGES['base'], (self.basex, BASEY))
136 | # print score so player overlaps the score
137 | # showScore(self.score)
138 | SCREEN.blit(IMAGES['player'][self.playerIndex],
139 | (self.playerx, self.playery))
140 |
141 | image_data = pygame.surfarray.array3d(pygame.display.get_surface())
142 | pygame.display.update()
143 | FPSCLOCK.tick(FPS)
144 | #print self.upperPipes[0]['y'] + PIPE_HEIGHT - int(BASEY * 0.2)
145 | return image_data, reward, terminal
146 |
147 | def getRandomPipe():
148 | """returns a randomly generated pipe"""
149 | # y of gap between upper and lower pipe
150 | gapYs = [20, 30, 40, 50, 60, 70, 80, 90]
151 | index = random.randint(0, len(gapYs)-1)
152 | gapY = gapYs[index]
153 |
154 | gapY += int(BASEY * 0.2)
155 | pipeX = SCREENWIDTH + 10
156 |
157 | return [
158 | {'x': pipeX, 'y': gapY - PIPE_HEIGHT}, # upper pipe
159 | {'x': pipeX, 'y': gapY + PIPEGAPSIZE}, # lower pipe
160 | ]
161 |
162 |
163 | def showScore(score):
164 | """displays score in center of screen"""
165 | scoreDigits = [int(x) for x in list(str(score))]
166 | totalWidth = 0 # total width of all numbers to be printed
167 |
168 | for digit in scoreDigits:
169 | totalWidth += IMAGES['numbers'][digit].get_width()
170 |
171 | Xoffset = (SCREENWIDTH - totalWidth) / 2
172 |
173 | for digit in scoreDigits:
174 | SCREEN.blit(IMAGES['numbers'][digit], (Xoffset, SCREENHEIGHT * 0.1))
175 | Xoffset += IMAGES['numbers'][digit].get_width()
176 |
177 |
178 | def checkCrash(player, upperPipes, lowerPipes):
179 |     """returns True if player collides with base or pipes."""
180 | pi = player['index']
181 | player['w'] = IMAGES['player'][0].get_width()
182 | player['h'] = IMAGES['player'][0].get_height()
183 |
184 | # if player crashes into ground
185 | if player['y'] + player['h'] >= BASEY - 1:
186 | return True
187 | else:
188 |
189 | playerRect = pygame.Rect(player['x'], player['y'],
190 | player['w'], player['h'])
191 |
192 | for uPipe, lPipe in zip(upperPipes, lowerPipes):
193 | # upper and lower pipe rects
194 | uPipeRect = pygame.Rect(uPipe['x'], uPipe['y'], PIPE_WIDTH, PIPE_HEIGHT)
195 | lPipeRect = pygame.Rect(lPipe['x'], lPipe['y'], PIPE_WIDTH, PIPE_HEIGHT)
196 |
197 | # player and upper/lower pipe hitmasks
198 | pHitMask = HITMASKS['player'][pi]
199 | uHitmask = HITMASKS['pipe'][0]
200 | lHitmask = HITMASKS['pipe'][1]
201 |
202 | # if bird collided with upipe or lpipe
203 | uCollide = pixelCollision(playerRect, uPipeRect, pHitMask, uHitmask)
204 | lCollide = pixelCollision(playerRect, lPipeRect, pHitMask, lHitmask)
205 |
206 | if uCollide or lCollide:
207 | return True
208 |
209 | return False
210 |
211 | def pixelCollision(rect1, rect2, hitmask1, hitmask2):
212 | """Checks if two objects collide and not just their rects"""
213 | rect = rect1.clip(rect2)
214 |
215 | if rect.width == 0 or rect.height == 0:
216 | return False
217 |
218 | x1, y1 = rect.x - rect1.x, rect.y - rect1.y
219 | x2, y2 = rect.x - rect2.x, rect.y - rect2.y
220 |
221 | for x in range(rect.width):
222 | for y in range(rect.height):
223 | if hitmask1[x1+x][y1+y] and hitmask2[x2+x][y2+y]:
224 | return True
225 | return False
226 |
--------------------------------------------------------------------------------
/images/flappy_bird_demp.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/images/flappy_bird_demp.gif
--------------------------------------------------------------------------------
/images/network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/images/network.png
--------------------------------------------------------------------------------
/images/preprocess.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/images/preprocess.png
--------------------------------------------------------------------------------
/logs_bird/hidden.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/logs_bird/hidden.txt
--------------------------------------------------------------------------------
/logs_bird/readout.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/logs_bird/readout.txt
--------------------------------------------------------------------------------
/saved_networks/bird-dqn-2880000:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/saved_networks/bird-dqn-2880000
--------------------------------------------------------------------------------
/saved_networks/bird-dqn-2880000.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/saved_networks/bird-dqn-2880000.meta
--------------------------------------------------------------------------------
/saved_networks/bird-dqn-2890000:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/saved_networks/bird-dqn-2890000
--------------------------------------------------------------------------------
/saved_networks/bird-dqn-2890000.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/saved_networks/bird-dqn-2890000.meta
--------------------------------------------------------------------------------
/saved_networks/bird-dqn-2900000:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/saved_networks/bird-dqn-2900000
--------------------------------------------------------------------------------
/saved_networks/bird-dqn-2900000.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/saved_networks/bird-dqn-2900000.meta
--------------------------------------------------------------------------------
/saved_networks/bird-dqn-2910000:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/saved_networks/bird-dqn-2910000
--------------------------------------------------------------------------------
/saved_networks/bird-dqn-2910000.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/saved_networks/bird-dqn-2910000.meta
--------------------------------------------------------------------------------
/saved_networks/bird-dqn-2920000:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/saved_networks/bird-dqn-2920000
--------------------------------------------------------------------------------
/saved_networks/bird-dqn-2920000.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/saved_networks/bird-dqn-2920000.meta
--------------------------------------------------------------------------------
/saved_networks/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "bird-dqn-2920000"
2 | all_model_checkpoint_paths: "bird-dqn-2880000"
3 | all_model_checkpoint_paths: "bird-dqn-2890000"
4 | all_model_checkpoint_paths: "bird-dqn-2900000"
5 | all_model_checkpoint_paths: "bird-dqn-2910000"
6 | all_model_checkpoint_paths: "bird-dqn-2920000"
7 |
--------------------------------------------------------------------------------
/saved_networks/pretrained_model/bird-dqn-policy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepLearningProjects/DeepLearningFlappyBird/5ff8ff654ddcfb83f4efeddfaca29fed6e69fa1e/saved_networks/pretrained_model/bird-dqn-policy
--------------------------------------------------------------------------------