├── .gitignore ├── BrainDQN.py ├── README.md ├── assets ├── audio │ ├── die.ogg │ ├── die.wav │ ├── hit.ogg │ ├── hit.wav │ ├── point.ogg │ ├── point.wav │ ├── swoosh.ogg │ ├── swoosh.wav │ ├── wing.ogg │ └── wing.wav └── sprites │ ├── 0.png │ ├── 1.png │ ├── 2.png │ ├── 3.png │ ├── 4.png │ ├── 5.png │ ├── 6.png │ ├── 7.png │ ├── 8.png │ ├── 9.png │ ├── background-black.png │ ├── base.png │ ├── pipe-green.png │ ├── redbird-downflap.png │ ├── redbird-midflap.png │ └── redbird-upflap.png ├── game ├── flappy_bird_utils.py └── wrapped_flappy_bird.py ├── main.py ├── misc.py ├── play.sh └── train.sh /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | back/ 3 | *.tar 4 | 5 | -------------------------------------------------------------------------------- /BrainDQN.py: -------------------------------------------------------------------------------- 1 | # ----------------------------- 2 | # Author: Flood Sung 3 | # Date: 2016.3.21 4 | # ============================= 5 | # Modified by xmfbit, 2017.4 6 | # ----------------------------- 7 | 8 | import random 9 | from collections import deque 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable 14 | 15 | ACTIONS = 2 # total available action number for the game: UP and DO NOTHING 16 | 17 | class BrainDQN(nn.Module): 18 | empty_frame = np.zeros((128, 72), dtype=np.float32) 19 | empty_state = np.stack((empty_frame, empty_frame, empty_frame, empty_frame), axis=0) 20 | 21 | def __init__(self, epsilon, mem_size, cuda): 22 | """Initialization 23 | 24 | epsilon: initial epsilon for exploration 25 | mem_size: memory size for experience replay 26 | cuda: use cuda or not 27 | """ 28 | super(BrainDQN, self).__init__() 29 | self.train = None 30 | # init replay memory 31 | self.replay_memory = deque() 32 | # init some parameters 33 | self.time_step = 0 34 | self.epsilon = epsilon 35 | self.actions = ACTIONS 36 | self.mem_size = mem_size 37 | self.use_cuda = cuda 38 | # init Q network 39 | self.createQNetwork() 40 | 41 | def createQNetwork(self): 42 | """ Create dqn, invoked by `__init__` 43 | 44 | model structure: conv->conv->fc->fc 45 | change it to your new design 46 | """ 47 | self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4, padding=2) 48 | self.relu1 = nn.ReLU(inplace=True) 49 | self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1) 50 | self.relu2 = nn.ReLU(inplace=True) 51 | self.map_size = (64, 16, 9) 52 | self.fc1 = nn.Linear(self.map_size[0]*self.map_size[1]*self.map_size[2], 256) 53 | self.relu3 = nn.ReLU(inplace=True) 54 | self.fc2 = nn.Linear(256, self.actions) 55 | 56 | def get_q_value(self, o): 57 | """Get Q value estimation w.r.t. current observation `o` 58 | 59 | o -- current observation 60 | """ 61 | # get Q estimation 62 | out = self.conv1(o) 63 | out = self.relu1(out) 64 | out = self.conv2(out) 65 | out = self.relu2(out) 66 | out = out.view(out.size()[0], -1) 67 | out = self.fc1(out) 68 | out = self.relu3(out) 69 | out = self.fc2(out) 70 | return out 71 | 72 | def forward(self, o): 73 | """Forward procedure to get MSE loss 74 | 75 | o -- current observation 76 | """ 77 | # get Q(s,a;\theta) 78 | q = self.get_q_value(o) 79 | return q 80 | 81 | def set_train(self): 82 | """Set phase TRAIN 83 | """ 84 | self.train = True 85 | 86 | def set_eval(self): 87 | """Set phase EVALUATION 88 | """ 89 | self.train = False 90 | 91 | def set_initial_state(self, state=None): 92 | """Set initial state 93 | 94 | state: initial state. 
if None, use `BrainDQN.empty_state`
 95 |         """
 96 |         if state is None:
 97 |             self.current_state = BrainDQN.empty_state
 98 |         else:
 99 |             self.current_state = state
100 | 
101 | 
102 |     def store_transition(self, o_next, action, reward, terminal):
103 |         """Store transition (\phi_t, a_t, r_t, \phi_{t+1})
104 | 
105 |         o_next: next observation, \phi_{t+1}
106 |         action: action, a_t
107 |         reward: reward, r_t
108 |         terminal: terminal(\phi_{t+1})
109 |         """
110 |         next_state = np.append(self.current_state[1:,:,:], o_next.reshape((1,)+o_next.shape), axis=0)
111 |         self.replay_memory.append((self.current_state, action, reward, next_state, terminal))
112 |         if len(self.replay_memory) > self.mem_size:
113 |             self.replay_memory.popleft()
114 |         if not terminal:
115 |             self.current_state = next_state
116 |         else:
117 |             self.set_initial_state()
118 | 
119 |     def get_action_randomly(self):
120 |         """Get an action randomly
121 |         """
122 |         action = np.zeros(self.actions, dtype=np.float32)
123 |         #action_index = random.randrange(self.actions)
124 |         action_index = 0 if random.random() < 0.8 else 1
125 |         action[action_index] = 1
126 |         return action
127 | 
128 |     def get_optim_action(self):
129 |         """Get the optimal action based on the current state
130 |         """
131 |         state = self.current_state
132 |         state_var = Variable(torch.from_numpy(state), volatile=True).unsqueeze(0)
133 |         if self.use_cuda:
134 |             state_var = state_var.cuda()
135 |         q_value = self.forward(state_var)
136 |         _, action_index = torch.max(q_value, dim=1)
137 |         action_index = action_index.data[0][0]
138 |         action = np.zeros(self.actions, dtype=np.float32)
139 |         action[action_index] = 1
140 |         return action
141 | 
142 |     def get_action(self):
143 |         """Get action w.r.t. the current state
144 |         """
145 |         if self.train and random.random() <= self.epsilon:
146 |             return self.get_action_randomly()
147 |         return self.get_optim_action()
148 | 
149 |     def increase_time_step(self, time_step=1):
150 |         """Increase time step"""
151 |         self.time_step += time_step
152 | 
--------------------------------------------------------------------------------
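A quick sanity check of the hard-coded `map_size` above — a minimal sketch, not part of the repository, assuming a recent PyTorch where plain tensors can be fed to the network (on the old 0.x API the input would additionally need to be wrapped in a `Variable`):

```
import torch
from BrainDQN import BrainDQN

# build the network on CPU; the epsilon/mem_size values are arbitrary here
net = BrainDQN(epsilon=1.0, mem_size=100, cuda=False)

# one stacked observation: 4 grayscale frames of 128x72 pixels
x = torch.from_numpy(BrainDQN.empty_state).unsqueeze(0)  # shape (1, 4, 128, 72)
q = net.get_q_value(x)

# conv1 (k=8, s=4, p=2): 128x72 -> 32x18; conv2 (k=4, s=2, p=1): 32x18 -> 16x9,
# so the flattened feature size is 64 * 16 * 9 = 9216, matching fc1's input size
print(q.shape)  # expected: (1, 2), one Q-value per action (DO NOTHING / UP)
```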
/README.md:
--------------------------------------------------------------------------------
 1 | ## Flappy Bird With DQN
 2 | 
 3 | DQN (Deep Q-Network) is a reinforcement learning method first proposed by DeepMind in NIPS 2013 ([paper on arXiv](https://arxiv.org/pdf/1312.5602v1.pdf)). Its input is raw pixels and its output is a value function estimating future rewards. Using experience replay, the authors stabilized the training of the network.
 4 | 
 5 | This demo uses DQN to train a convolutional neural network to play the Flappy Bird game. It is a practice project from when I was learning reinforcement learning, and it partly reuses [songrotek's code](https://github.com/songrotek/DRL-FlappyBird), especially the game engine and the basic idea. Thanks for sharing, and thanks to the spirit and community of open source.
 6 | 
 7 | A video of the demo can be found on [YouTube](https://youtu.be/h4jEdF_roXU), or on [优酷 Youku](http://v.youku.com/v_show/id_XMjcwOTcwMjYzMg==.html?spm=a2hzp.8253869.0.0&from=y1.7-2#paction) if you don't have access to YouTube.
 8 | 
 9 | ### DQN implemented in PyTorch
10 | 
11 | PyTorch is an elegant deep learning framework published by Facebook. The neural network and the training/testing procedure are implemented in PyTorch, so you need to install PyTorch to run this demo. In addition, the `pygame` package is required by the game engine.
12 | 
13 | ### How to run the demo
14 | 
15 | #### Play the game with a pretrained model
16 | 
17 | To start, you can play the game with a model I have pretrained. Download it from [Google Drive](https://drive.google.com/file/d/0B98MUaCGMMG0em1uQzkzYmt3U00/view?usp=sharing) (or from [Baidu Netdisk](https://pan.baidu.com/s/1pKOpRqr) if Google Drive is not accessible) and use the following commands to play the game. Make sure the pretrained model file is in the root directory of this project.
18 | 
19 | ```
20 | chmod +x play.sh
21 | ./play.sh
22 | ```
23 | 
24 | For more detailed information about the program's arguments, run `python main.py --help` or refer to the code in `main.py`.
25 | 
26 | #### Train DQN
27 | 
28 | You can use the following commands to train the model from scratch, or to fine-tune it if a pretrained weight file is provided.
29 | 
30 | ```
31 | chmod +x train.sh
32 | ./train.sh # please see `main.py` for detailed info about the variables
33 | ```
34 | 
35 | Some tips for training:
36 | 
37 | - Do not set `memory_size` too large (or too small). The right value depends on the memory available on your machine.
38 | 
39 | - Training takes a long time. I fine-tuned the model several times and manually adjusted the `epsilon` of ϵ-greedy exploration each time.
40 | 
41 | - When choosing an action randomly during training, 'DO NOTHING' is preferred over 'UP' (probability 0.8 vs. 0.2). I think this accelerates convergence. See the `get_action_randomly` method in `BrainDQN.py` for details.
42 | 
43 | ### Disclaimer
44 | 
45 | This work is based on the repos [songrotek/DRL-FlappyBird](https://github.com/songrotek/DRL-FlappyBird) and [yenchenlin1994/DeepLearningFlappyBird](https://github.com/yenchenlin1994/DeepLearningFlappyBird.git). Thanks to both authors!
46 | 
--------------------------------------------------------------------------------
/assets/audio/die.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xmfbit/DQN-FlappyBird/3f84cf341e9ef531e193721f9446f6434076d243/assets/audio/die.ogg
--------------------------------------------------------------------------------
/assets/audio/die.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xmfbit/DQN-FlappyBird/3f84cf341e9ef531e193721f9446f6434076d243/assets/audio/die.wav
--------------------------------------------------------------------------------
/assets/audio/hit.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xmfbit/DQN-FlappyBird/3f84cf341e9ef531e193721f9446f6434076d243/assets/audio/hit.ogg
--------------------------------------------------------------------------------
/assets/audio/hit.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xmfbit/DQN-FlappyBird/3f84cf341e9ef531e193721f9446f6434076d243/assets/audio/hit.wav
--------------------------------------------------------------------------------
/assets/audio/point.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xmfbit/DQN-FlappyBird/3f84cf341e9ef531e193721f9446f6434076d243/assets/audio/point.ogg
--------------------------------------------------------------------------------
/assets/audio/point.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xmfbit/DQN-FlappyBird/3f84cf341e9ef531e193721f9446f6434076d243/assets/audio/point.wav
--------------------------------------------------------------------------------
/assets/audio/swoosh.ogg:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmfbit/DQN-FlappyBird/3f84cf341e9ef531e193721f9446f6434076d243/assets/audio/swoosh.ogg -------------------------------------------------------------------------------- /assets/audio/swoosh.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmfbit/DQN-FlappyBird/3f84cf341e9ef531e193721f9446f6434076d243/assets/audio/swoosh.wav -------------------------------------------------------------------------------- /assets/audio/wing.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmfbit/DQN-FlappyBird/3f84cf341e9ef531e193721f9446f6434076d243/assets/audio/wing.ogg -------------------------------------------------------------------------------- /assets/audio/wing.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmfbit/DQN-FlappyBird/3f84cf341e9ef531e193721f9446f6434076d243/assets/audio/wing.wav -------------------------------------------------------------------------------- /assets/sprites/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmfbit/DQN-FlappyBird/3f84cf341e9ef531e193721f9446f6434076d243/assets/sprites/0.png -------------------------------------------------------------------------------- /assets/sprites/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmfbit/DQN-FlappyBird/3f84cf341e9ef531e193721f9446f6434076d243/assets/sprites/1.png -------------------------------------------------------------------------------- /assets/sprites/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmfbit/DQN-FlappyBird/3f84cf341e9ef531e193721f9446f6434076d243/assets/sprites/2.png -------------------------------------------------------------------------------- /assets/sprites/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmfbit/DQN-FlappyBird/3f84cf341e9ef531e193721f9446f6434076d243/assets/sprites/3.png -------------------------------------------------------------------------------- /assets/sprites/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmfbit/DQN-FlappyBird/3f84cf341e9ef531e193721f9446f6434076d243/assets/sprites/4.png -------------------------------------------------------------------------------- /assets/sprites/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmfbit/DQN-FlappyBird/3f84cf341e9ef531e193721f9446f6434076d243/assets/sprites/5.png -------------------------------------------------------------------------------- /assets/sprites/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmfbit/DQN-FlappyBird/3f84cf341e9ef531e193721f9446f6434076d243/assets/sprites/6.png -------------------------------------------------------------------------------- /assets/sprites/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmfbit/DQN-FlappyBird/3f84cf341e9ef531e193721f9446f6434076d243/assets/sprites/7.png 
-------------------------------------------------------------------------------- /assets/sprites/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmfbit/DQN-FlappyBird/3f84cf341e9ef531e193721f9446f6434076d243/assets/sprites/8.png -------------------------------------------------------------------------------- /assets/sprites/9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmfbit/DQN-FlappyBird/3f84cf341e9ef531e193721f9446f6434076d243/assets/sprites/9.png -------------------------------------------------------------------------------- /assets/sprites/background-black.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmfbit/DQN-FlappyBird/3f84cf341e9ef531e193721f9446f6434076d243/assets/sprites/background-black.png -------------------------------------------------------------------------------- /assets/sprites/base.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmfbit/DQN-FlappyBird/3f84cf341e9ef531e193721f9446f6434076d243/assets/sprites/base.png -------------------------------------------------------------------------------- /assets/sprites/pipe-green.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmfbit/DQN-FlappyBird/3f84cf341e9ef531e193721f9446f6434076d243/assets/sprites/pipe-green.png -------------------------------------------------------------------------------- /assets/sprites/redbird-downflap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmfbit/DQN-FlappyBird/3f84cf341e9ef531e193721f9446f6434076d243/assets/sprites/redbird-downflap.png -------------------------------------------------------------------------------- /assets/sprites/redbird-midflap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmfbit/DQN-FlappyBird/3f84cf341e9ef531e193721f9446f6434076d243/assets/sprites/redbird-midflap.png -------------------------------------------------------------------------------- /assets/sprites/redbird-upflap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmfbit/DQN-FlappyBird/3f84cf341e9ef531e193721f9446f6434076d243/assets/sprites/redbird-upflap.png -------------------------------------------------------------------------------- /game/flappy_bird_utils.py: -------------------------------------------------------------------------------- 1 | import pygame 2 | import sys 3 | def load(): 4 | # path of player with different states 5 | PLAYER_PATH = ( 6 | 'assets/sprites/redbird-upflap.png', 7 | 'assets/sprites/redbird-midflap.png', 8 | 'assets/sprites/redbird-downflap.png' 9 | ) 10 | 11 | # path of background 12 | BACKGROUND_PATH = 'assets/sprites/background-black.png' 13 | 14 | # path of pipe 15 | PIPE_PATH = 'assets/sprites/pipe-green.png' 16 | 17 | IMAGES, SOUNDS, HITMASKS = {}, {}, {} 18 | 19 | # numbers sprites for score display 20 | IMAGES['numbers'] = ( 21 | pygame.image.load('assets/sprites/0.png').convert_alpha(), 22 | pygame.image.load('assets/sprites/1.png').convert_alpha(), 23 | pygame.image.load('assets/sprites/2.png').convert_alpha(), 24 | pygame.image.load('assets/sprites/3.png').convert_alpha(), 25 | 
pygame.image.load('assets/sprites/4.png').convert_alpha(), 26 | pygame.image.load('assets/sprites/5.png').convert_alpha(), 27 | pygame.image.load('assets/sprites/6.png').convert_alpha(), 28 | pygame.image.load('assets/sprites/7.png').convert_alpha(), 29 | pygame.image.load('assets/sprites/8.png').convert_alpha(), 30 | pygame.image.load('assets/sprites/9.png').convert_alpha() 31 | ) 32 | 33 | # base (ground) sprite 34 | IMAGES['base'] = pygame.image.load('assets/sprites/base.png').convert_alpha() 35 | 36 | # sounds 37 | if 'win' in sys.platform: 38 | soundExt = '.wav' 39 | else: 40 | soundExt = '.ogg' 41 | 42 | SOUNDS['die'] = pygame.mixer.Sound('assets/audio/die' + soundExt) 43 | SOUNDS['hit'] = pygame.mixer.Sound('assets/audio/hit' + soundExt) 44 | SOUNDS['point'] = pygame.mixer.Sound('assets/audio/point' + soundExt) 45 | SOUNDS['swoosh'] = pygame.mixer.Sound('assets/audio/swoosh' + soundExt) 46 | SOUNDS['wing'] = pygame.mixer.Sound('assets/audio/wing' + soundExt) 47 | 48 | # select random background sprites 49 | IMAGES['background'] = pygame.image.load(BACKGROUND_PATH).convert() 50 | 51 | # select random player sprites 52 | IMAGES['player'] = ( 53 | pygame.image.load(PLAYER_PATH[0]).convert_alpha(), 54 | pygame.image.load(PLAYER_PATH[1]).convert_alpha(), 55 | pygame.image.load(PLAYER_PATH[2]).convert_alpha(), 56 | ) 57 | 58 | # select random pipe sprites 59 | IMAGES['pipe'] = ( 60 | pygame.transform.rotate( 61 | pygame.image.load(PIPE_PATH).convert_alpha(), 180), 62 | pygame.image.load(PIPE_PATH).convert_alpha(), 63 | ) 64 | 65 | # hismask for pipes 66 | HITMASKS['pipe'] = ( 67 | getHitmask(IMAGES['pipe'][0]), 68 | getHitmask(IMAGES['pipe'][1]), 69 | ) 70 | 71 | # hitmask for player 72 | HITMASKS['player'] = ( 73 | getHitmask(IMAGES['player'][0]), 74 | getHitmask(IMAGES['player'][1]), 75 | getHitmask(IMAGES['player'][2]), 76 | ) 77 | 78 | return IMAGES, SOUNDS, HITMASKS 79 | 80 | def getHitmask(image): 81 | """returns a hitmask using an image's alpha.""" 82 | mask = [] 83 | for x in range(image.get_width()): 84 | mask.append([]) 85 | for y in range(image.get_height()): 86 | mask[x].append(bool(image.get_at((x,y))[3])) 87 | return mask 88 | -------------------------------------------------------------------------------- /game/wrapped_flappy_bird.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | import random 4 | import pygame 5 | import flappy_bird_utils 6 | import pygame.surfarray as surfarray 7 | from pygame.locals import * 8 | from itertools import cycle 9 | 10 | FPS = 30 11 | SCREENWIDTH = 288 12 | SCREENHEIGHT = 512 13 | 14 | pygame.init() 15 | FPSCLOCK = pygame.time.Clock() 16 | SCREEN = pygame.display.set_mode((SCREENWIDTH, SCREENHEIGHT)) 17 | pygame.display.set_caption('Flappy Bird') 18 | 19 | IMAGES, SOUNDS, HITMASKS = flappy_bird_utils.load() 20 | PIPEGAPSIZE = 100 # gap between upper and lower part of pipe 21 | BASEY = SCREENHEIGHT * 0.79 22 | 23 | PLAYER_WIDTH = IMAGES['player'][0].get_width() 24 | PLAYER_HEIGHT = IMAGES['player'][0].get_height() 25 | PIPE_WIDTH = IMAGES['pipe'][0].get_width() 26 | PIPE_HEIGHT = IMAGES['pipe'][0].get_height() 27 | BACKGROUND_WIDTH = IMAGES['background'].get_width() 28 | 29 | PLAYER_INDEX_GEN = cycle([0, 1, 2, 1]) 30 | 31 | 32 | class GameState: 33 | def __init__(self): 34 | self.score = self.playerIndex = self.loopIter = 0 35 | self.playerx = int(SCREENWIDTH * 0.2) 36 | self.playery = int((SCREENHEIGHT - PLAYER_HEIGHT) / 2) 37 | self.basex = 0 38 | self.baseShift = 
IMAGES['base'].get_width() - BACKGROUND_WIDTH 39 | 40 | newPipe1 = getRandomPipe() 41 | newPipe2 = getRandomPipe() 42 | self.upperPipes = [ 43 | {'x': SCREENWIDTH, 'y': newPipe1[0]['y']}, 44 | {'x': SCREENWIDTH + (SCREENWIDTH / 2), 'y': newPipe2[0]['y']}, 45 | ] 46 | self.lowerPipes = [ 47 | {'x': SCREENWIDTH, 'y': newPipe1[1]['y']}, 48 | {'x': SCREENWIDTH + (SCREENWIDTH / 2), 'y': newPipe2[1]['y']}, 49 | ] 50 | 51 | # player velocity, max velocity, downward accleration, accleration on flap 52 | self.pipeVelX = -4 53 | self.playerVelY = 0 # player's velocity along Y, default same as playerFlapped 54 | self.playerMaxVelY = 10 # max vel along Y, max descend speed 55 | self.playerMinVelY = -8 # min vel along Y, max ascend speed 56 | self.playerAccY = 1 # players downward accleration 57 | self.playerFlapAcc = -7 # players speed on flapping 58 | self.playerFlapped = False # True when player flaps 59 | 60 | def frame_step(self, input_actions): 61 | pygame.event.pump() 62 | 63 | reward = 0.1 64 | terminal = False 65 | 66 | if sum(input_actions) != 1: 67 | raise ValueError('Multiple input actions!') 68 | 69 | # input_actions[0] == 1: do nothing 70 | # input_actions[1] == 1: flap the bird 71 | if input_actions[1] == 1: 72 | if self.playery > -2 * PLAYER_HEIGHT: 73 | self.playerVelY = self.playerFlapAcc 74 | self.playerFlapped = True 75 | #SOUNDS['wing'].play() 76 | 77 | # check for score 78 | playerMidPos = self.playerx + PLAYER_WIDTH / 2 79 | for pipe in self.upperPipes: 80 | pipeMidPos = pipe['x'] + PIPE_WIDTH / 2 81 | if pipeMidPos <= playerMidPos < pipeMidPos + 4: 82 | self.score += 1 83 | #SOUNDS['point'].play() 84 | reward = 1 85 | 86 | # playerIndex basex change 87 | if (self.loopIter + 1) % 3 == 0: 88 | self.playerIndex = PLAYER_INDEX_GEN.next() 89 | self.loopIter = (self.loopIter + 1) % 30 90 | self.basex = -((-self.basex + 100) % self.baseShift) 91 | 92 | # player's movement 93 | if self.playerVelY < self.playerMaxVelY and not self.playerFlapped: 94 | self.playerVelY += self.playerAccY 95 | if self.playerFlapped: 96 | self.playerFlapped = False 97 | self.playery += min(self.playerVelY, BASEY - self.playery - PLAYER_HEIGHT) 98 | if self.playery < 0: 99 | self.playery = 0 100 | 101 | # move pipes to left 102 | for uPipe, lPipe in zip(self.upperPipes, self.lowerPipes): 103 | uPipe['x'] += self.pipeVelX 104 | lPipe['x'] += self.pipeVelX 105 | 106 | # add new pipe when first pipe is about to touch left of screen 107 | if 0 < self.upperPipes[0]['x'] < 5: 108 | newPipe = getRandomPipe() 109 | self.upperPipes.append(newPipe[0]) 110 | self.lowerPipes.append(newPipe[1]) 111 | 112 | # remove first pipe if its out of the screen 113 | if self.upperPipes[0]['x'] < -PIPE_WIDTH: 114 | self.upperPipes.pop(0) 115 | self.lowerPipes.pop(0) 116 | 117 | # check if crash here 118 | isCrash= checkCrash({'x': self.playerx, 'y': self.playery, 119 | 'index': self.playerIndex}, 120 | self.upperPipes, self.lowerPipes) 121 | if isCrash: 122 | #SOUNDS['hit'].play() 123 | #SOUNDS['die'].play() 124 | terminal = True 125 | self.__init__() 126 | reward = -5 127 | 128 | # draw sprites 129 | SCREEN.blit(IMAGES['background'], (0,0)) 130 | 131 | for uPipe, lPipe in zip(self.upperPipes, self.lowerPipes): 132 | SCREEN.blit(IMAGES['pipe'][0], (uPipe['x'], uPipe['y'])) 133 | SCREEN.blit(IMAGES['pipe'][1], (lPipe['x'], lPipe['y'])) 134 | 135 | SCREEN.blit(IMAGES['base'], (self.basex, BASEY)) 136 | # print score so player overlaps the score 137 | # showScore(self.score) 138 | SCREEN.blit(IMAGES['player'][self.playerIndex], 139 | 
(self.playerx, self.playery)) 140 | 141 | image_data = pygame.surfarray.array3d(pygame.display.get_surface()) 142 | pygame.display.update() 143 | FPSCLOCK.tick(FPS) 144 | #print self.upperPipes[0]['y'] + PIPE_HEIGHT - int(BASEY * 0.2) 145 | return image_data, reward, terminal 146 | 147 | def getRandomPipe(): 148 | """returns a randomly generated pipe""" 149 | # y of gap between upper and lower pipe 150 | gapYs = [20, 30, 40, 50, 60, 70, 80, 90] 151 | index = random.randint(0, len(gapYs)-1) 152 | gapY = gapYs[index] 153 | 154 | gapY += int(BASEY * 0.2) 155 | pipeX = SCREENWIDTH + 10 156 | 157 | return [ 158 | {'x': pipeX, 'y': gapY - PIPE_HEIGHT}, # upper pipe 159 | {'x': pipeX, 'y': gapY + PIPEGAPSIZE}, # lower pipe 160 | ] 161 | 162 | 163 | def showScore(score): 164 | """displays score in center of screen""" 165 | scoreDigits = [int(x) for x in list(str(score))] 166 | totalWidth = 0 # total width of all numbers to be printed 167 | 168 | for digit in scoreDigits: 169 | totalWidth += IMAGES['numbers'][digit].get_width() 170 | 171 | Xoffset = (SCREENWIDTH - totalWidth) / 2 172 | 173 | for digit in scoreDigits: 174 | SCREEN.blit(IMAGES['numbers'][digit], (Xoffset, SCREENHEIGHT * 0.1)) 175 | Xoffset += IMAGES['numbers'][digit].get_width() 176 | 177 | 178 | def checkCrash(player, upperPipes, lowerPipes): 179 | """returns True if player collders with base or pipes.""" 180 | pi = player['index'] 181 | player['w'] = IMAGES['player'][0].get_width() 182 | player['h'] = IMAGES['player'][0].get_height() 183 | 184 | # if player crashes into ground 185 | if player['y'] + player['h'] >= BASEY - 1: 186 | return True 187 | else: 188 | 189 | playerRect = pygame.Rect(player['x'], player['y'], 190 | player['w'], player['h']) 191 | 192 | for uPipe, lPipe in zip(upperPipes, lowerPipes): 193 | # upper and lower pipe rects 194 | uPipeRect = pygame.Rect(uPipe['x'], uPipe['y'], PIPE_WIDTH, PIPE_HEIGHT) 195 | lPipeRect = pygame.Rect(lPipe['x'], lPipe['y'], PIPE_WIDTH, PIPE_HEIGHT) 196 | 197 | # player and upper/lower pipe hitmasks 198 | pHitMask = HITMASKS['player'][pi] 199 | uHitmask = HITMASKS['pipe'][0] 200 | lHitmask = HITMASKS['pipe'][1] 201 | 202 | # if bird collided with upipe or lpipe 203 | uCollide = pixelCollision(playerRect, uPipeRect, pHitMask, uHitmask) 204 | lCollide = pixelCollision(playerRect, lPipeRect, pHitMask, lHitmask) 205 | 206 | if uCollide or lCollide: 207 | return True 208 | 209 | return False 210 | 211 | def pixelCollision(rect1, rect2, hitmask1, hitmask2): 212 | """Checks if two objects collide and not just their rects""" 213 | rect = rect1.clip(rect2) 214 | 215 | if rect.width == 0 or rect.height == 0: 216 | return False 217 | 218 | x1, y1 = rect.x - rect1.x, rect.y - rect1.y 219 | x2, y2 = rect.x - rect2.x, rect.y - rect2.y 220 | 221 | for x in xrange(rect.width): 222 | for y in xrange(rect.height): 223 | if hitmask1[x1+x][y1+y] and hitmask2[x2+x][y2+y]: 224 | return True 225 | return False 226 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | from misc import * 4 | from BrainDQN import * 5 | import torch.cuda 6 | 7 | parser = argparse.ArgumentParser(description='DQN demo for flappy bird') 8 | 9 | parser.add_argument('--train', action='store_true', default=False, 10 | help='If set true, train the model; otherwise, play game with pretrained model') 11 | parser.add_argument('--cuda', action='store_true', default=False, 12 | 
help='If set true, with cuda enabled; otherwise, with CPU only') 13 | parser.add_argument('--lr', type=float, help='learning rate', default=0.0001) 14 | parser.add_argument('--gamma', type=float, 15 | help='discount rate', default=0.99) 16 | parser.add_argument('--batch_size', type=int, 17 | help='batch size', default=32) 18 | parser.add_argument('--memory_size', type=int, 19 | help='memory size for experience replay', default=5000) 20 | parser.add_argument('--init_e', type=float, 21 | help='initial epsilon for epsilon-greedy exploration', 22 | default=1.0) 23 | parser.add_argument('--final_e', type=float, 24 | help='final epsilon for epsilon-greedy exploration', 25 | default=0.1) 26 | parser.add_argument('--observation', type=int, 27 | help='random observation number in the beginning before training', 28 | default=100) 29 | parser.add_argument('--exploration', type=int, 30 | help='number of exploration using epsilon-greedy policy', 31 | default=10000) 32 | parser.add_argument('--max_episode', type=int, 33 | help='maximum episode of training', 34 | default=20000) 35 | parser.add_argument('--weight', type=str, 36 | help='weight file name for finetunig(Optional)', default='') 37 | parser.add_argument('--save_checkpoint_freq', type=int, 38 | help='episode interval to save checkpoint', default=2000) 39 | 40 | if __name__ == '__main__': 41 | args = parser.parse_args() 42 | if args.cuda and not torch.cuda.is_available(): 43 | print 'CUDA is not availale, maybe you should not set --cuda' 44 | sys.exit(1) 45 | if not args.train and args.weight == '': 46 | print 'When test, a pretrained weight model file should be given' 47 | sys.exit(1) 48 | if args.cuda: 49 | print 'With GPU support!' 50 | if args.train: 51 | model = BrainDQN(epsilon=args.init_e, mem_size=args.memory_size, cuda=args.cuda) 52 | resume = not args.weight == '' 53 | train_dqn(model, args, resume) 54 | else: 55 | play_game(args.weight, args.cuda, True) 56 | -------------------------------------------------------------------------------- /misc.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("game/") 3 | import wrapped_flappy_bird as game 4 | from BrainDQN import * 5 | import shutil 6 | import numpy as np 7 | import random 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | from torch.autograd import Variable 12 | 13 | import PIL.Image as Image 14 | 15 | IMAGE_SIZE = (72, 128) 16 | 17 | def preprocess(frame): 18 | """Do preprocessing: resize and binarize. 19 | 20 | Downsampling to 128x72 size and convert to grayscale 21 | frame -- input frame, rgb image with 512x288 size 22 | """ 23 | im = Image.fromarray(frame).resize(IMAGE_SIZE).convert(mode='L') 24 | out = np.asarray(im).astype(np.float32) 25 | out[out <= 1.] = 0.0 26 | out[out > 1.] = 1.0 27 | return out 28 | 29 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 30 | """Save checkpoint model to disk 31 | 32 | state -- checkpoint state: model weight and other info 33 | binding by user 34 | is_best -- if the checkpoint is the best. 
If it is, then
 35 |                save as the best model
 36 |     """
 37 |     torch.save(state, filename)
 38 |     if is_best:
 39 |         shutil.copyfile(filename, 'model_best.pth.tar')
 40 | 
 41 | def load_checkpoint(filename, model):
 42 |     """Load a previous checkpoint model
 43 | 
 44 |     filename -- model file name
 45 |     model -- DQN model
 46 |     """
 47 |     try:
 48 |         checkpoint = torch.load(filename)
 49 |     except:
 50 |         # load weight saved on gpu device to cpu device
 51 |         # see https://discuss.pytorch.org/t/on-a-cpu-device-how-to-load-checkpoint-saved-on-gpu-device/349/3
 52 |         checkpoint = torch.load(filename, map_location=lambda storage, loc:storage)
 53 |     episode = checkpoint['episode']
 54 |     epsilon = checkpoint['epsilon']
 55 |     print 'pretrained episode = {}'.format(episode)
 56 |     print 'pretrained epsilon = {}'.format(epsilon)
 57 |     model.load_state_dict(checkpoint['state_dict'])
 58 |     time_step = checkpoint.get('best_time_step', None)
 59 |     if time_step is None:
 60 |         time_step = checkpoint['time_step']
 61 |     print 'pretrained time step = {}'.format(time_step)
 62 |     return episode, epsilon, time_step
 63 | 
 64 | def train_dqn(model, options, resume):
 65 |     """Train DQN
 66 | 
 67 |     model -- DQN model
 68 |     options -- training options parsed in `main.py` (learning rate,
 69 |                batch size, discount, exploration schedule, etc.)
 70 |     resume -- if True, resume training from the checkpoint given by
 71 |               `options.weight`
 72 |     """
 73 |     best_time_step = 0.
 74 |     if resume:
 75 |         if options.weight == '':
 76 |             print 'when resume, you should give weight file name.'
 77 |             return
 78 |         print 'load previous model weight: {}'.format(options.weight)
 79 |         _, _, best_time_step = load_checkpoint(options.weight, model)
 80 | 
 81 |     flappyBird = game.GameState()
 82 |     optimizer = optim.RMSprop(model.parameters(), lr=options.lr)
 83 |     criterion = nn.MSELoss()
 84 | 
 85 |     action = [1, 0]
 86 |     o, r, terminal = flappyBird.frame_step(action)
 87 |     o = preprocess(o)
 88 |     model.set_initial_state()
 89 | 
 90 |     if options.cuda:
 91 |         model = model.cuda()
 92 |     # in the first `observation` time steps, we don't train the model
 93 |     for i in xrange(options.observation):
 94 |         action = model.get_action_randomly()
 95 |         o, r, terminal = flappyBird.frame_step(action)
 96 |         o = preprocess(o)
 97 |         model.store_transition(o, action, r, terminal)
 98 |     # start training
 99 |     for episode in xrange(options.max_episode):
100 |         model.time_step = 0
101 |         model.set_train()
102 |         total_reward = 0.
103 |         # begin an episode!
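        # Each pass through the loop below performs one environment step followed by
        # one DQN update: sample a random minibatch from replay memory, build the
        # target y = r for terminal transitions and y = r + gamma * max_a' Q(s', a')
        # otherwise, and regress Q(s, a) of the chosen action toward y with MSE loss.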
104 |         while True:
105 |             optimizer.zero_grad()
106 |             action = model.get_action()
107 |             o_next, r, terminal = flappyBird.frame_step(action)
108 |             total_reward += options.gamma**model.time_step * r
109 |             o_next = preprocess(o_next)
110 |             model.store_transition(o_next, action, r, terminal)
111 |             model.increase_time_step()
112 |             # Step 1: obtain random minibatch from replay memory
113 |             minibatch = random.sample(model.replay_memory, options.batch_size)
114 |             state_batch = np.array([data[0] for data in minibatch])
115 |             action_batch = np.array([data[1] for data in minibatch])
116 |             reward_batch = np.array([data[2] for data in minibatch])
117 |             next_state_batch = np.array([data[3] for data in minibatch])
118 |             state_batch_var = Variable(torch.from_numpy(state_batch))
119 |             next_state_batch_var = Variable(torch.from_numpy(next_state_batch),
120 |                                             volatile=True)
121 |             if options.cuda:
122 |                 state_batch_var = state_batch_var.cuda()
123 |                 next_state_batch_var = next_state_batch_var.cuda()
124 |             # Step 2: calculate y
125 |             q_value_next = model.forward(next_state_batch_var)
126 | 
127 |             q_value = model.forward(state_batch_var)
128 | 
129 |             y = reward_batch.astype(np.float32)
130 |             max_q, _ = torch.max(q_value_next, dim=1)
131 | 
132 |             for i in xrange(options.batch_size):
133 |                 if not minibatch[i][4]:
134 |                     y[i] += options.gamma*max_q.data[i][0]
135 | 
136 |             y = Variable(torch.from_numpy(y))
137 |             action_batch_var = Variable(torch.from_numpy(action_batch))
138 |             if options.cuda:
139 |                 y = y.cuda()
140 |                 action_batch_var = action_batch_var.cuda()
141 |             q_value = torch.sum(torch.mul(action_batch_var, q_value), dim=1)
142 | 
143 |             loss = criterion(q_value, y)
144 |             loss.backward()
145 | 
146 |             optimizer.step()
147 |             # when the bird dies, the episode ends
148 |             if terminal:
149 |                 break
150 | 
151 |         print 'episode: {}, epsilon: {:.4f}, max time step: {}, total reward: {:.6f}'.format(
152 |                 episode, model.epsilon, model.time_step, total_reward)
153 | 
154 |         if model.epsilon > options.final_e:
155 |             delta = (options.init_e - options.final_e)/options.exploration
156 |             model.epsilon -= delta
157 | 
158 |         if episode % 100 == 0:
159 |             ave_time = test_dqn(model, episode)
160 | 
161 |         if ave_time > best_time_step:
162 |             best_time_step = ave_time
163 |             save_checkpoint({
164 |                 'episode': episode,
165 |                 'epsilon': model.epsilon,
166 |                 'state_dict': model.state_dict(),
167 |                 'best_time_step': best_time_step,
168 |                 }, True, 'checkpoint-episode-%d.pth.tar' %episode)
169 |         elif episode % options.save_checkpoint_freq == 0:
170 |             save_checkpoint({
171 |                 'episode': episode,
172 |                 'epsilon': model.epsilon,
173 |                 'state_dict': model.state_dict(),
174 |                 'time_step': ave_time,
175 |                 }, False, 'checkpoint-episode-%d.pth.tar' %episode)
176 |         else:
177 |             continue
178 |         print 'save checkpoint, episode={}, ave time step={:.2f}'.format(
179 |                 episode, ave_time)
180 | 
181 | def test_dqn(model, episode):
182 |     """Test the behavior of the DQN during training
183 | 
184 |     model -- dqn model
185 |     episode -- current training episode
186 |     """
187 |     model.set_eval()
188 |     ave_time = 0.
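    # Evaluation: play 5 episodes greedily (no epsilon-greedy exploration) and
    # return the average number of survived time steps; train_dqn uses this score
    # to decide whether the current weights are the best so far.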
189 | for test_case in xrange(5): 190 | model.time_step = 0 191 | flappyBird = game.GameState() 192 | o, r, terminal = flappyBird.frame_step([1, 0]) 193 | o = preprocess(o) 194 | model.set_initial_state() 195 | while True: 196 | action = model.get_optim_action() 197 | o, r, terminal = flappyBird.frame_step(action) 198 | if terminal: 199 | break 200 | o = preprocess(o) 201 | model.current_state = np.append(model.current_state[1:,:,:], o.reshape((1,)+o.shape), axis=0) 202 | model.increase_time_step() 203 | ave_time += model.time_step 204 | ave_time /= 5 205 | print 'testing: episode: {}, average time: {}'.format(episode, ave_time) 206 | return ave_time 207 | 208 | 209 | def play_game(model_file_name, cuda=False, best=True): 210 | """Play flappy bird with pretrained dqn model 211 | 212 | weight -- model file name containing weight of dqn 213 | best -- if the model is best or not 214 | """ 215 | print 'load pretrained model file: ' + model_file_name 216 | model = BrainDQN(epsilon=0., mem_size=0, cuda=cuda) 217 | load_checkpoint(model_file_name, model) 218 | 219 | model.set_eval() 220 | bird_game = game.GameState() 221 | model.set_initial_state() 222 | if cuda: 223 | model = model.cuda() 224 | while True: 225 | action = model.get_optim_action() 226 | o, r, terminal = bird_game.frame_step(action) 227 | if terminal: 228 | break 229 | o = preprocess(o) 230 | 231 | model.current_state = np.append(model.current_state[1:,:,:], o.reshape((1,)+o.shape), axis=0) 232 | 233 | model.increase_time_step() 234 | print 'total time step is {}'.format(model.time_step) 235 | -------------------------------------------------------------------------------- /play.sh: -------------------------------------------------------------------------------- 1 | # play flappy bird with pretrained model 2 | # you canchange `model_best.pth.tar` to your pretrained model file name 3 | python main.py\ 4 | --weight model_best.pth.tar\ 5 | --cuda # uncomment when gpu is available 6 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | lr=0.0001 2 | gamma=0.99 3 | batch_size=32 4 | mem_size=5000 5 | initial_epsilon=1. 6 | final_epsilon=0.1 7 | observation=100 8 | exploration=50000 9 | max_episode=100000 10 | # for fine tuning, uncomment this 11 | #weight=model_best.pth.tar 12 | 13 | python main.py --train\ 14 | --cuda\ 15 | --lr=$lr\ 16 | --gamma=$gamma\ 17 | --batch_size=$batch_size\ 18 | --memory_size=$mem_size\ 19 | --init_e=$initial_epsilon\ 20 | --final_e=$final_epsilon\ 21 | --observation=$observation\ 22 | --exploration=$exploration\ 23 | --max_episode=$max_episode 24 | #--weight=$weight # for fine tuning, uncomment this 25 | 26 | --------------------------------------------------------------------------------
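For reference, the ϵ-greedy schedule implied by the settings in `train.sh`: after every training episode, `train_dqn` in `misc.py` lowers ϵ by `(init_e - final_e) / exploration` as long as ϵ is still above `final_e`. A minimal standalone sketch of that schedule (not part of the repository):

```
# epsilon annealing as configured by train.sh (init_e=1.0, final_e=0.1, exploration=50000)
init_e, final_e, exploration = 1.0, 0.1, 50000

delta = (init_e - final_e) / exploration  # 1.8e-5 subtracted per episode
epsilon = init_e
for episode in range(exploration):
    if epsilon > final_e:
        epsilon -= delta
# after `exploration` episodes, epsilon has annealed from 1.0 down to ~0.1
```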