├── .gitignore ├── README.md ├── game ├── __init__.py ├── assets │ ├── audio │ │ ├── die.ogg │ │ ├── die.wav │ │ ├── hit.ogg │ │ ├── hit.wav │ │ ├── point.ogg │ │ ├── point.wav │ │ ├── swoosh.ogg │ │ ├── swoosh.wav │ │ ├── wing.ogg │ │ └── wing.wav │ └── sprites │ │ ├── 0.png │ │ ├── 1.png │ │ ├── 2.png │ │ ├── 3.png │ │ ├── 4.png │ │ ├── 5.png │ │ ├── 6.png │ │ ├── 7.png │ │ ├── 8.png │ │ ├── 9.png │ │ ├── background-black.png │ │ ├── base.png │ │ ├── pipe-green.png │ │ ├── redbird-downflap.png │ │ ├── redbird-midflap.png │ │ └── redbird-upflap.png ├── flappy_bird_utils.py └── wrapped_flappy_bird.py ├── q_game.py ├── requirements.txt ├── test_game.py └── weights.h5 /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FlappyBirdTensorflow 2 | 3 | 项目介绍:[强化学习Deep Q-Network自动玩flappy bird](https://yuerblog.cc/2021/01/26/%e5%bc%ba%e5%8c%96%e5%ad%a6%e4%b9%a0deep-q-network%e8%87%aa%e5%8a%a8%e7%8e%a9flappy-bird/) 4 | 5 | ## 项目运行 6 | 7 | 基于python3.8+tensorflow2.4+pygame实现,最好利用conda初始化一个新的python环境。 8 | 9 | 安装依赖(最好-i指定走阿里云pip镜像): 10 | 11 | ``` 12 | pip install -r requirements.txt 13 | ``` 14 | 15 | ## 体验效果 16 | 17 | weights.h5是我训练好的模型,可以直接体验效果: 18 | 19 | ``` 20 | python q_game.py --model-only 21 | ``` 22 | 23 | ## 自己训练 24 | 25 | 删除weights.h5文件,启动训练: 26 | 27 | ``` 28 | python q_game.py 29 | ``` 30 | 31 | 模型训练需要6小时+才能收敛稳定,最好睡觉前运行,睡醒后观察效果,模型会自动保存到weights.h5。 -------------------------------------------------------------------------------- /game/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owenliang/FlappyBirdTensorflow/4b95ca33262b32ceee6ac91ce74ad185ebe0953a/game/__init__.py -------------------------------------------------------------------------------- /game/assets/audio/die.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owenliang/FlappyBirdTensorflow/4b95ca33262b32ceee6ac91ce74ad185ebe0953a/game/assets/audio/die.ogg -------------------------------------------------------------------------------- /game/assets/audio/die.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owenliang/FlappyBirdTensorflow/4b95ca33262b32ceee6ac91ce74ad185ebe0953a/game/assets/audio/die.wav -------------------------------------------------------------------------------- /game/assets/audio/hit.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owenliang/FlappyBirdTensorflow/4b95ca33262b32ceee6ac91ce74ad185ebe0953a/game/assets/audio/hit.ogg -------------------------------------------------------------------------------- /game/assets/audio/hit.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owenliang/FlappyBirdTensorflow/4b95ca33262b32ceee6ac91ce74ad185ebe0953a/game/assets/audio/hit.wav -------------------------------------------------------------------------------- /game/assets/audio/point.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owenliang/FlappyBirdTensorflow/4b95ca33262b32ceee6ac91ce74ad185ebe0953a/game/assets/audio/point.ogg -------------------------------------------------------------------------------- /game/assets/audio/point.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owenliang/FlappyBirdTensorflow/4b95ca33262b32ceee6ac91ce74ad185ebe0953a/game/assets/audio/point.wav -------------------------------------------------------------------------------- /game/assets/audio/swoosh.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owenliang/FlappyBirdTensorflow/4b95ca33262b32ceee6ac91ce74ad185ebe0953a/game/assets/audio/swoosh.ogg -------------------------------------------------------------------------------- /game/assets/audio/swoosh.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owenliang/FlappyBirdTensorflow/4b95ca33262b32ceee6ac91ce74ad185ebe0953a/game/assets/audio/swoosh.wav -------------------------------------------------------------------------------- /game/assets/audio/wing.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owenliang/FlappyBirdTensorflow/4b95ca33262b32ceee6ac91ce74ad185ebe0953a/game/assets/audio/wing.ogg -------------------------------------------------------------------------------- /game/assets/audio/wing.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owenliang/FlappyBirdTensorflow/4b95ca33262b32ceee6ac91ce74ad185ebe0953a/game/assets/audio/wing.wav -------------------------------------------------------------------------------- /game/assets/sprites/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owenliang/FlappyBirdTensorflow/4b95ca33262b32ceee6ac91ce74ad185ebe0953a/game/assets/sprites/0.png -------------------------------------------------------------------------------- /game/assets/sprites/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owenliang/FlappyBirdTensorflow/4b95ca33262b32ceee6ac91ce74ad185ebe0953a/game/assets/sprites/1.png -------------------------------------------------------------------------------- /game/assets/sprites/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owenliang/FlappyBirdTensorflow/4b95ca33262b32ceee6ac91ce74ad185ebe0953a/game/assets/sprites/2.png -------------------------------------------------------------------------------- /game/assets/sprites/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owenliang/FlappyBirdTensorflow/4b95ca33262b32ceee6ac91ce74ad185ebe0953a/game/assets/sprites/3.png -------------------------------------------------------------------------------- /game/assets/sprites/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owenliang/FlappyBirdTensorflow/4b95ca33262b32ceee6ac91ce74ad185ebe0953a/game/assets/sprites/4.png -------------------------------------------------------------------------------- /game/assets/sprites/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owenliang/FlappyBirdTensorflow/4b95ca33262b32ceee6ac91ce74ad185ebe0953a/game/assets/sprites/5.png -------------------------------------------------------------------------------- /game/assets/sprites/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owenliang/FlappyBirdTensorflow/4b95ca33262b32ceee6ac91ce74ad185ebe0953a/game/assets/sprites/6.png -------------------------------------------------------------------------------- /game/assets/sprites/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owenliang/FlappyBirdTensorflow/4b95ca33262b32ceee6ac91ce74ad185ebe0953a/game/assets/sprites/7.png -------------------------------------------------------------------------------- /game/assets/sprites/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owenliang/FlappyBirdTensorflow/4b95ca33262b32ceee6ac91ce74ad185ebe0953a/game/assets/sprites/8.png -------------------------------------------------------------------------------- /game/assets/sprites/9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owenliang/FlappyBirdTensorflow/4b95ca33262b32ceee6ac91ce74ad185ebe0953a/game/assets/sprites/9.png -------------------------------------------------------------------------------- /game/assets/sprites/background-black.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owenliang/FlappyBirdTensorflow/4b95ca33262b32ceee6ac91ce74ad185ebe0953a/game/assets/sprites/background-black.png -------------------------------------------------------------------------------- /game/assets/sprites/base.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owenliang/FlappyBirdTensorflow/4b95ca33262b32ceee6ac91ce74ad185ebe0953a/game/assets/sprites/base.png -------------------------------------------------------------------------------- /game/assets/sprites/pipe-green.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owenliang/FlappyBirdTensorflow/4b95ca33262b32ceee6ac91ce74ad185ebe0953a/game/assets/sprites/pipe-green.png -------------------------------------------------------------------------------- /game/assets/sprites/redbird-downflap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owenliang/FlappyBirdTensorflow/4b95ca33262b32ceee6ac91ce74ad185ebe0953a/game/assets/sprites/redbird-downflap.png -------------------------------------------------------------------------------- /game/assets/sprites/redbird-midflap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owenliang/FlappyBirdTensorflow/4b95ca33262b32ceee6ac91ce74ad185ebe0953a/game/assets/sprites/redbird-midflap.png -------------------------------------------------------------------------------- /game/assets/sprites/redbird-upflap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owenliang/FlappyBirdTensorflow/4b95ca33262b32ceee6ac91ce74ad185ebe0953a/game/assets/sprites/redbird-upflap.png -------------------------------------------------------------------------------- /game/flappy_bird_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | 游戏素材加载 3 | """ 4 | import pygame 5 | import sys 6 | import os 7 | 8 | assets_dir = os.path.dirname(__file__) 9 | 10 | def load(): 11 | # 小鸟挥动翅膀的3个造型 12 | PLAYER_PATH = ( 13 | assets_dir + '/assets/sprites/redbird-upflap.png', 14 | assets_dir + '/assets/sprites/redbird-midflap.png', 15 | assets_dir + '/assets/sprites/redbird-downflap.png' 16 | ) 17 | 18 | # 游戏背景图,纯黑色是为了训练降低干扰 19 | BACKGROUND_PATH = assets_dir + '/assets/sprites/background-black.png' 20 | 21 | # 水管图片 22 | PIPE_PATH = assets_dir + '/assets/sprites/pipe-green.png' 23 | 24 | IMAGES, SOUNDS, HITMASKS = {}, {}, {} 25 | 26 | # 加载数字0~9的图片,类型是Surface图像 27 | IMAGES['numbers'] = ( 28 | pygame.image.load(assets_dir + '/assets/sprites/0.png').convert_alpha(), 29 | pygame.image.load(assets_dir + '/assets/sprites/1.png').convert_alpha(), 30 | pygame.image.load(assets_dir + '/assets/sprites/2.png').convert_alpha(), 31 | pygame.image.load(assets_dir + '/assets/sprites/3.png').convert_alpha(), # convert/conver_alpha是为了将图片转成绘制用的像素格式,提高绘制效率 32 | pygame.image.load(assets_dir + '/assets/sprites/4.png').convert_alpha(), 33 | pygame.image.load(assets_dir + '/assets/sprites/5.png').convert_alpha(), 34 | pygame.image.load(assets_dir + '/assets/sprites/6.png').convert_alpha(), 35 | pygame.image.load(assets_dir + '/assets/sprites/7.png').convert_alpha(), 36 | pygame.image.load(assets_dir + '/assets/sprites/8.png').convert_alpha(), 37 | pygame.image.load(assets_dir + '/assets/sprites/9.png').convert_alpha() 38 | ) 39 | 40 | # 地面图片 41 | IMAGES['base'] = pygame.image.load(assets_dir + '/assets/sprites/base.png').convert_alpha() 42 | 43 | # 根据操作系统加载不同格式的声音文件 44 | if 'win' in sys.platform: 45 | soundExt = '.wav' 46 | else: 47 | soundExt = '.ogg' 48 | 49 | # 各种Sound对象 50 | SOUNDS['die'] = pygame.mixer.Sound(assets_dir + '/assets/audio/die' + soundExt) 51 | SOUNDS['hit'] = pygame.mixer.Sound(assets_dir + '/assets/audio/hit' + soundExt) 52 | SOUNDS['point'] = pygame.mixer.Sound(assets_dir + '/assets/audio/point' + soundExt) 53 | SOUNDS['swoosh'] = pygame.mixer.Sound(assets_dir + '/assets/audio/swoosh' + soundExt) 54 | SOUNDS['wing'] = pygame.mixer.Sound(assets_dir + '/assets/audio/wing' + soundExt) 55 | 56 | # 加载背景图片 57 | IMAGES['background'] = pygame.image.load(BACKGROUND_PATH).convert() 58 | 59 | # 加载小鸟的3个姿态 60 | IMAGES['player'] = ( 61 | pygame.image.load(PLAYER_PATH[0]).convert_alpha(), 62 | pygame.image.load(PLAYER_PATH[1]).convert_alpha(), 63 | pygame.image.load(PLAYER_PATH[2]).convert_alpha(), 64 | ) 65 | 66 | # 加载水管图片,并反转180度产生上方的水管图片 67 | IMAGES['pipe'] = ( 68 | pygame.transform.rotate( 69 | pygame.image.load(PIPE_PATH).convert_alpha(), 180), 70 | pygame.image.load(PIPE_PATH).convert_alpha(), 71 | ) 72 | 73 | # 计算水管图片的bool掩码 74 | HITMASKS['pipe'] = ( 75 | getHitmask(IMAGES['pipe'][0]), 76 | getHitmask(IMAGES['pipe'][1]), 77 | ) 78 | 79 | # 生成小鸟图片的bool掩码 80 | HITMASKS['player'] = ( 81 | getHitmask(IMAGES['player'][0]), 82 | getHitmask(IMAGES['player'][1]), 83 | getHitmask(IMAGES['player'][2]), 84 | ) 85 | 86 | return IMAGES, SOUNDS, HITMASKS 87 | 88 | # 生成图片的bool掩码矩阵,true表示对应像素位置不是透明的部分 89 | def getHitmask(image): 90 | """returns a hitmask using an image's alpha.""" 91 | mask = [] 92 | for x in range(image.get_width()): 93 | mask.append([]) 94 | for y in range(image.get_height()): 95 | mask[x].append(bool(image.get_at((x,y))[3])) # 像素点是RGBA,例如:(83, 56, 70, 255),最后是透明度(0是透明,255是不透明) 96 | return mask 97 | -------------------------------------------------------------------------------- /game/wrapped_flappy_bird.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | import random 4 | import pygame 5 | from . import flappy_bird_utils 6 | import pygame.surfarray as surfarray 7 | from pygame.locals import * 8 | from itertools import cycle 9 | 10 | # 屏幕宽*高 11 | FPS = 30 12 | SCREENWIDTH = 288 13 | SCREENHEIGHT = 512 14 | 15 | # 初始化游戏 16 | pygame.init() 17 | FPSCLOCK = pygame.time.Clock() # FPS限速器 18 | SCREEN = pygame.display.set_mode((SCREENWIDTH, SCREENHEIGHT)) # 宽*高 19 | pygame.display.set_caption('Flappy Bird') # 标题 20 | 21 | # 加载素材 22 | IMAGES, SOUNDS, HITMASKS = flappy_bird_utils.load() 23 | 24 | PIPEGAPSIZE = 100 # 上下水管之间的距离是固定的100像素 25 | BASEY = SCREENHEIGHT * 0.79 # 地面图片的y坐标 26 | 27 | # 小鸟图片的宽*高 28 | PLAYER_WIDTH = IMAGES['player'][0].get_width() 29 | PLAYER_HEIGHT = IMAGES['player'][0].get_height() 30 | # 水管图片的宽*高 31 | PIPE_WIDTH = IMAGES['pipe'][0].get_width() 32 | PIPE_HEIGHT = IMAGES['pipe'][0].get_height() 33 | 34 | # 背景图片的宽 35 | BACKGROUND_WIDTH = IMAGES['background'].get_width() 36 | 37 | # 小鸟图片动画播放顺序 38 | PLAYER_INDEX_GEN = cycle([0, 1, 2, 1]) 39 | 40 | # Flappy bird游戏类 41 | class GameState: 42 | def __init__(self): 43 | self.score = 0 44 | self.playerIndex = 0 45 | self.loopIter = 0 46 | 47 | # 玩家初始坐标 48 | self.playerx = int(SCREENWIDTH * 0.2) 49 | self.playery = int((SCREENHEIGHT - PLAYER_HEIGHT) / 2) 50 | 51 | # 地面图片需要跑马灯效果,它比屏幕宽一点,每帧向左移动,当要耗尽时重新回到右边,如此往复 52 | self.basex = 0 # 地面图片的x坐标 53 | self.baseShift = IMAGES['base'].get_width() - BACKGROUND_WIDTH # 地面图片比屏幕宽度长多少像素,就是它可以移动的距离 54 | 55 | newPipe1 = getRandomPipe() # 生成一对上下管子 56 | newPipe2 = getRandomPipe() # 再生成一对上下管子 57 | 58 | # 上面2根管子,都放到屏幕右侧之外,x相邻半个屏幕距离 59 | self.upperPipes = [ 60 | {'x': SCREENWIDTH, 'y': newPipe1[0]['y']}, 61 | {'x': SCREENWIDTH + (SCREENWIDTH / 2), 'y': newPipe2[0]['y']}, 62 | ] 63 | # 下面2根管子,都放到屏幕右侧之外,x相邻半个屏幕距离 64 | self.lowerPipes = [ 65 | {'x': SCREENWIDTH, 'y': newPipe1[1]['y']}, 66 | {'x': SCREENWIDTH + (SCREENWIDTH / 2), 'y': newPipe2[1]['y']}, 67 | ] 68 | 69 | # 水管的水平移动速度,每次x-4实现向左移动 70 | self.pipeVelX = -4 71 | 72 | # 小鸟Y方向速度 73 | self.playerVelY = 0 74 | # 小鸟Y方向重力加速度,每帧作用域playerVelY,令其Y速度向下加大 75 | self.playerAccY = 1 76 | # 点击后,小鸟Y方向速度重置为-9,也就是开始向上移动 77 | self.playerFlapAcc = -9 78 | 79 | # 小鸟Y方向速度限制 80 | self.playerMaxVelY = 10 # Y向下最大速度10 81 | 82 | # 执行一次操作,返回操作后的画面、本次操作的奖励(活着+0.1,死了-1,飞过水管+1)、游戏是否结束 83 | def frame_step(self, input_actions): 84 | # 给pygame对积累的事件做一下默认处理 85 | pygame.event.pump() 86 | 87 | # 活着就奖励0.1分 88 | reward = 0.01 89 | # 是否死了 90 | terminal = False 91 | 92 | # 必须传有效的action,[1,0]表示不点击,[0,1]表示点击,全传0是不对的 93 | if sum(input_actions) != 1: 94 | raise ValueError('Multiple input actions!') 95 | 96 | # 每3帧换一次小鸟造型图片,loopIter统计经过了多少帧 97 | if (self.loopIter + 1) % 3 == 0: 98 | self.playerIndex = next(PLAYER_INDEX_GEN) 99 | self.loopIter += 1 100 | 101 | # 让地面向左移动,游戏开始的时候地面x=0,逐步减小x 102 | if self.basex + self.pipeVelX <= -self.baseShift: 103 | self.basex = 0 104 | else: # 图片即将滚动耗尽,重置x坐标 105 | self.basex += self.pipeVelX 106 | 107 | # 点击了屏幕 108 | if input_actions[1] == 1: 109 | self.playerVelY = self.playerFlapAcc # 将小鸟y方向速度重置为-9,也就是向上移动 110 | #SOUNDS['wing'].play() # 播放扇翅膀的声音 111 | elif self.playerVelY < self.playerMaxVelY: # 没点击屏幕并且没达到最大掉落速度,继续施加重力加速度 112 | self.playerVelY += self.playerAccY 113 | 114 | # 将速度施加到小鸟的y坐标上 115 | self.playery += self.playerVelY 116 | if self.playery < 0: # 撞到上边缘不算死 117 | self.playery = 0 # 限制它别飞出去 118 | elif self.playery + PLAYER_HEIGHT >= BASEY: # 小鸟碰到地面 119 | self.playery = BASEY - PLAYER_HEIGHT # 限制它别穿地 120 | 121 | # 让上下水管都向左移动一次 122 | for uPipe, lPipe in zip(self.upperPipes, self.lowerPipes): 123 | uPipe['x'] += self.pipeVelX 124 | lPipe['x'] += self.pipeVelX 125 | 126 | # 判断小鸟是否穿过了一排水管,因为上下水管x一样,只需要用上排水管判断 127 | playerMidPos = self.playerx + PLAYER_WIDTH / 2 # 小鸟中心的x坐标(这个是固定值,小鸟实际不会动,是水管在动) 128 | for pipe in self.upperPipes: # 检查与上排水管的关系 129 | pipeMidPos = pipe['x'] + PIPE_WIDTH / 2 # 水管中心的x坐标 130 | if pipeMidPos <= playerMidPos < pipeMidPos + abs(self.pipeVelX): # 小鸟x坐标刚刚飞过了水管x中心(4是水管的移动速度) 131 | self.score += 1 # 游戏得分+1 132 | #SOUNDS['point'].play() 133 | reward = 100 # 产生强化学习的动作奖励10分 134 | 135 | # 最左侧水管马上离开屏幕,生成新水管 136 | if 0 < self.upperPipes[0]['x'] < 5: 137 | newPipe = getRandomPipe() 138 | self.upperPipes.append(newPipe[0]) 139 | self.lowerPipes.append(newPipe[1]) 140 | 141 | # 最左侧水管彻底离开屏幕,删除它的上下2根水管 142 | if self.upperPipes[0]['x'] < -PIPE_WIDTH: 143 | self.upperPipes.pop(0) 144 | self.lowerPipes.pop(0) 145 | 146 | # 检查小鸟是否碰到水管 147 | isCrash= checkCrash({'x': self.playerx, 'y': self.playery, 'index': self.playerIndex}, self.upperPipes, self.lowerPipes) 148 | if isCrash: # 死掉了 149 | #SOUNDS['hit'].play() 150 | #SOUNDS['die'].play() 151 | reward = -10 # 负向激励分 152 | terminal = True # 本次操作导致游戏结束了 153 | 154 | ##### 进入重绘 ####### 155 | 156 | # 贴背景图 157 | SCREEN.blit(IMAGES['background'], (0,0)) 158 | # 画水管 159 | for uPipe, lPipe in zip(self.upperPipes, self.lowerPipes): 160 | SCREEN.blit(IMAGES['pipe'][0], (uPipe['x'], uPipe['y'])) 161 | SCREEN.blit(IMAGES['pipe'][1], (lPipe['x'], lPipe['y'])) 162 | # 画地面 163 | SCREEN.blit(IMAGES['base'], (self.basex, BASEY)) 164 | # 画得分(训练时候别打开,造成干扰了) 165 | #showScore(self.score) 166 | # 画小鸟 167 | SCREEN.blit(IMAGES['player'][self.playerIndex], (self.playerx, self.playery)) 168 | # 重绘 169 | pygame.display.update() 170 | # 留存游戏画面(截图是列优先存储的,需要转行行优先存储) 171 | # https://stackoverflow.com/questions/34673424/how-to-get-numpy-array-of-rgb-colors-from-pygame-surface 172 | image_data = pygame.surfarray.array3d(pygame.display.get_surface()).swapaxes(0,1) 173 | # 死亡则重置游戏状态 174 | if terminal: 175 | self.__init__() 176 | # 控制FPS 177 | FPSCLOCK.tick(FPS) 178 | return image_data, reward, terminal 179 | 180 | # 生成一对水管,放到屏幕外面 181 | def getRandomPipe(): 182 | gapY = random.randint(70, 140) 183 | 184 | # 注:每一对水管的缝隙高度都是一样的PIPEGAPSIZE,gayY决定的是缝隙的上边缘y坐标 185 | pipeX = SCREENWIDTH + 10 # 水管出现在屏幕右侧之外 186 | 187 | return [ 188 | {'x': pipeX, 'y': gapY - PIPE_HEIGHT}, # 计算上面水管图片的y坐标,就是缝隙上边缘y减去水管本身高度 189 | {'x': pipeX, 'y': gapY + PIPEGAPSIZE}, # 计算下面水管图片的y坐标,就是缝隙上边缘y加上缝隙本身高度 190 | ] 191 | 192 | # 检查小鸟是否碰到水管或者地面(天花板不算) 193 | def checkCrash(player, upperPipes, lowerPipes): 194 | pi = player['index'] # 小鸟的第几张图片 195 | 196 | # 图片的宽*高 197 | player['w'] = IMAGES['player'][pi].get_width() 198 | player['h'] = IMAGES['player'][pi].get_height() 199 | 200 | # 小鸟碰到了地面 201 | if player['y'] + player['h'] >= BASEY - 1: 202 | return True 203 | else: # 小鸟与水管进行碰撞检测 204 | # 小鸟图片的矩形区域 205 | playerRect = pygame.Rect(player['x'], player['y'], player['w'], player['h']) 206 | 207 | # 每一对水管 208 | for uPipe, lPipe in zip(upperPipes, lowerPipes): 209 | # 上面水管的矩形 210 | uPipeRect = pygame.Rect(uPipe['x'], uPipe['y'], PIPE_WIDTH, PIPE_HEIGHT) 211 | # 下面水管的矩形 212 | lPipeRect = pygame.Rect(lPipe['x'], lPipe['y'], PIPE_WIDTH, PIPE_HEIGHT) 213 | 214 | # 小鸟图片的非透明像素掩码 215 | pHitMask = HITMASKS['player'][pi] 216 | # 上水管的非透明像素掩码 217 | uHitmask = HITMASKS['pipe'][0] 218 | # 下水管的非透明像素掩码 219 | lHitmask = HITMASKS['pipe'][1] 220 | 221 | # 检测小鸟与上面水管的碰撞 222 | uCollide = pixelCollision(playerRect, uPipeRect, pHitMask, uHitmask) 223 | # 检测小鸟与下面水管的碰撞 224 | lCollide = pixelCollision(playerRect, lPipeRect, pHitMask, lHitmask) 225 | 226 | if uCollide or lCollide: 227 | return True 228 | return False 229 | 230 | 231 | # 2个矩形区域的碰撞检测 232 | def pixelCollision(rect1, rect2, hitmask1, hitmask2): 233 | # 求2个矩形相交的矩形区域 234 | rect = rect1.clip(rect2) 235 | 236 | # 相交面积为0 237 | if rect.width == 0 or rect.height == 0: 238 | return False 239 | 240 | # 相交矩形x,y相对于2个矩形左上角的距离 241 | x1, y1 = rect.x - rect1.x, rect.y - rect1.y 242 | x2, y2 = rect.x - rect2.x, rect.y - rect2.y 243 | 244 | # 检查相交矩形内的每个点,是否在2个矩形内同时是非透明点,那么就碰撞了 245 | for x in range(rect.width): 246 | for y in range(rect.height): 247 | if hitmask1[x1+x][y1+y] and hitmask2[x2+x][y2+y]: 248 | return True 249 | return False 250 | 251 | # 展示得分,传入一个整数得分 252 | def showScore(score): 253 | # 转成单个数字的列表 254 | scoreDigits = [int(x) for x in list(str(score))] 255 | 256 | # 计算展示所有数字要占多少像素宽度 257 | totalWidth = 0 258 | for digit in scoreDigits: 259 | totalWidth += IMAGES['numbers'][digit].get_width() 260 | 261 | # 计算绘制起始x坐标 262 | Xoffset = (SCREENWIDTH - totalWidth) / 2 263 | 264 | # 逐个数字绘制 265 | for digit in scoreDigits: 266 | SCREEN.blit(IMAGES['numbers'][digit], (Xoffset, 20)) # y坐标贴近屏幕上边缘 267 | Xoffset += IMAGES['numbers'][digit].get_width() # 移动绘制x坐标 -------------------------------------------------------------------------------- /q_game.py: -------------------------------------------------------------------------------- 1 | """ 2 | 强化学习q learning flappy bird 3 | """ 4 | from game.wrapped_flappy_bird import GameState 5 | import time 6 | import numpy as np 7 | import skimage.color 8 | import skimage.transform 9 | import skimage.exposure 10 | import tensorflow as tf 11 | import random 12 | import argparse 13 | 14 | # 命令行参数 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--model-only", help="加载已有模型,不随机探索,仍旧训练", action='store_true') 17 | args = parser.parse_args() 18 | 19 | # 测试用代码 20 | def _test_save_img(img): 21 | # 把每一帧图片存储到文件里,调试用 22 | from PIL import Image 23 | im = Image.fromarray((img*255).astype(np.uint8), mode='L') # 图片已经被处理为0~1之间的亮度值,所以*255取整数变灰度展示 24 | im.save('./img.jpg') 25 | 26 | # 构建卷积神经网络 27 | def build_model(): 28 | # 卷积神经网络:https://blog.csdn.net/FontThrone/article/details/76652753 29 | model = tf.keras.models.Sequential([ 30 | tf.keras.layers.Input(shape=(80,80,4)), 31 | tf.keras.layers.Conv2D(filters=32, kernel_size=(8, 8), padding='same',strides=4, activation='relu'), 32 | tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same'), 33 | tf.keras.layers.Conv2D(filters=64, kernel_size=(4, 4), padding='same',strides=2, activation='relu'), 34 | tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same'), 35 | tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), padding='same',strides=1, activation='relu'), 36 | tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same'), 37 | tf.keras.layers.Flatten(), 38 | tf.keras.layers.Dense(256, activation='relu'), 39 | tf.keras.layers.Dense(2), # 对应2个action未来总回报预期 40 | ]) 41 | model.compile(loss='mse', optimizer='adam') 42 | 43 | # 尝试加载之前保存的模型参数 44 | try: 45 | model.load_weights('./weights.h5') 46 | print('加载模型成功...................') 47 | except: 48 | pass 49 | return model 50 | 51 | # 创建游戏 52 | game = GameState() 53 | # 卷积模型 54 | model = build_model() 55 | 56 | # 执行1帧游戏 57 | def run_one_frame(action): 58 | global game 59 | # image_data:执行动作后的图像(288*512*3的RGB三维数组) 60 | # reward:本次动作的奖励 61 | # terminal:游戏是否失败 62 | img, reward, terminal = game.frame_step(action) 63 | # RGB转灰度图 64 | img = skimage.color.rgb2gray(img) 65 | # 压缩到80*80的图片(根据RGB算出来的亮度,其数值很小) 66 | img = skimage.transform.resize(img, (80,80)) 67 | # 把亮度标准化到0~1之间,用作模型输入 68 | img = skimage.exposure.rescale_intensity(img, out_range=(0,1)) 69 | return img,reward,terminal 70 | 71 | # 强化学习初始化状态 72 | def reset_stat(): 73 | # 执行第一帧,不点击 74 | img_t,_,_ = run_one_frame([1,0]) 75 | # 卷积网络的输入是连续的4帧游戏画面,对于首帧只能重复4遍 76 | stat_t = np.stack([img_t] * 4, axis=2) 77 | return stat_t 78 | 79 | # 初始状态 80 | stat_t = reset_stat() 81 | # 训练样本 82 | transitions = [] 83 | 84 | # 时刻 85 | t = 0 86 | 87 | # 随机探索的概率控制 88 | INIT_EPSILON = 0.1 89 | FINAL_EPSILON = 0.005 90 | EPSLION_DELTA = 1e-6 91 | # 最大留存样本个数 92 | TRANS_CAP = 20000 93 | # 至少有多少样本才训练 94 | TRANS_SIZE_FIT = 10000 95 | # 训练集大小 96 | BATCH_SIZE = 32 97 | # 未来激励折扣 98 | GAMMA = 0.99 99 | 100 | # 随机探索概率 101 | if args.model_only: # 不随机探索(极低概率) 102 | epsilon = FINAL_EPSILON 103 | else: 104 | epsilon = INIT_EPSILON 105 | 106 | # 打印一些进度信息 107 | rand_flap =0 # 随机点击次数 108 | rand_noflap = 0 # 随机不点击次数 109 | model_flap=0 # 模型点击次数 110 | model_noflap=0 # 模型不点击次数 111 | model_train_times = 0 # 模型训练次数 112 | 113 | # 游戏启动 114 | while True: 115 | # 动作 116 | action_t = [0,0] 117 | 118 | action_type = '随机' 119 | 120 | # 随着学习,降低随机探索的概率,让模型趋于稳定 121 | if (t <= TRANS_SIZE_FIT and not args.model_only) or random.random() <= epsilon: 122 | n = random.random() 123 | if n <= 0.95: 124 | action_index = 0 125 | rand_noflap+=1 126 | else: 127 | action_index = 1 128 | rand_flap+=1 129 | #print('[随机探索] t时刻进行随机动作探索...') 130 | else: # 模型预测2个操作的未来累计回报 131 | action_type = '经验' 132 | Q_t = model.predict(np.expand_dims(stat_t, axis=0))[0] 133 | action_index = np.argmax(Q_t) # 回报最大的action下标 134 | if action_index==0: 135 | model_noflap+=1 136 | else: 137 | model_flap+=1 138 | #print('[已有经验] 预测t时刻2个动作的未来总回报 -- 不点击:{} 点击:{}'.format(Q_t[0], Q_t[1])) 139 | 140 | action_t[action_index] = 1 141 | #print('时刻t将执行的动作为{}'.format(action_t)) 142 | 143 | # 执行当前动作,返回操作后的图片、本次激励、游戏是否结束 144 | img_t1, reward, terminal = run_one_frame(action_t) 145 | _test_save_img(img_t1) 146 | img_t1 = img_t1.reshape((80,80,1)) # 增加通道维度,因为我们要最近4帧作为4通道图片,用作卷积模型输入 147 | stat_t1 = np.append(stat_t[:,:,1:], img_t1, axis=2) # 80*80*4,淘汰当前的第0通道,添加最新t1时刻到第3通道 148 | 149 | # 收集训练样本(保留有限的) 150 | transitions.append({ 151 | 'stat_t': stat_t, # t时刻状态 152 | 'stat_t1': stat_t1, # t1时刻状态 153 | 'reward': reward, # 本次动作的激励得分 154 | 'terminal': terminal, # 执行动作后游戏是否结束(ps: 结束意味着没有未来激励了) 155 | 'action_index': action_index, # 执行了什么动作(0:不点击,1:点击) 156 | }) 157 | if len(transitions) > TRANS_CAP: 158 | transitions.pop(0) 159 | 160 | # 游戏结束则重置stat_t 161 | if terminal: 162 | stat_t = reset_stat() 163 | #print('死了!!!!!!! 状态t重置为初始帧...') 164 | else: # 否则切为新的状态 165 | stat_t = stat_t1 166 | #print('没死~~~ 状态t切换为状态t1...') 167 | 168 | # 过了观察期,开始训练 169 | if t >= TRANS_SIZE_FIT and t % 10 == 0: 170 | minibatch = random.sample(transitions, BATCH_SIZE) 171 | # 模型训练的输入:t时刻的状态(最近4帧图片) 172 | inputs_t = np.concatenate([tran['stat_t'].reshape((1,80,80,4)) for tran in minibatch]) 173 | #print('inputs_t shape', inputs_t.shape) 174 | ###################################################### 175 | # 模型训练的输出:t时刻的未来总激励(Q_t = reward+gamma*Q_t1) 176 | # 1,让模型预测t时刻2种action的未来总激励 177 | Q_t = model.predict(inputs_t, batch_size=len(minibatch)) 178 | # 2,让模型预测t1时刻2种action的未来总激励 179 | input_t1 = np.concatenate([tran['stat_t1'].reshape((1,80,80,4)) for tran in minibatch]) 180 | Q_t1 = model.predict(input_t1, batch_size=len(minibatch)) 181 | # 3,保留t1时刻2个action中最大的未来总激励 182 | Q_t1_max = [max(q) for q in Q_t1] 183 | # 4,t时刻进行action_index动作得到真实激励 184 | reward_t = [tran['reward'] for tran in minibatch] 185 | # 5,t时刻进行了什么action 186 | action_index_t = [tran['action_index'] for tran in minibatch] 187 | # 6,t1时刻是否死掉了 188 | terminal = [tran['terminal'] for tran in minibatch] 189 | # 7,修正训练的目标Q_t=reward+gamma*Q_t1 190 | # (t时刻action_index的未来总激励=action_index真实激励+t1时刻预测的最大未来总激励) 191 | for i in range(len(minibatch)): 192 | if terminal[i]: 193 | Q_t[i][action_index_t[i]] = reward_t[i] # 因为t1时刻已经死了,所以没有t1之后的累计激励 194 | else: 195 | Q_t[i][action_index_t[i]] = reward_t[i] + GAMMA*Q_t1_max[i] # Q_t=reward+Q_t1 196 | # print('Q_t shape', Q_t.shape) 197 | # 训练一波 198 | #print(inputs_t) 199 | #print(Q_t) 200 | model.fit(inputs_t, Q_t, batch_size=len(minibatch)) 201 | model_train_times += 1 202 | # 训练1次则降低些许的随机探索概率 203 | if epsilon > FINAL_EPSILON: 204 | epsilon -= EPSLION_DELTA 205 | 206 | # 每5000次batch保存一次模型权重(不适用saved_model,后续加载只会加载权重,模型结构还是程序构造,因为这样可以保持keras model的api) 207 | if model_train_times % 5000 == 0: 208 | model.save_weights('./weights.h5') 209 | 210 | ###################################################### 211 | if t % 100 == 0: 212 | print('总帧数:{} 剩余探索概率:{}% 累计训练次数:{} 累计随机点:{} 累计随机不点:{} 累计模型点:{} 累计模型不点:{} 训练集:{} '.format( 213 | t, round(epsilon * 100, 4), model_train_times, rand_flap, rand_noflap, model_flap, model_noflap, 214 | len(transitions))) 215 | t = t + 1 216 | #time.sleep(1) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owenliang/FlappyBirdTensorflow/4b95ca33262b32ceee6ac91ce74ad185ebe0953a/requirements.txt -------------------------------------------------------------------------------- /test_game.py: -------------------------------------------------------------------------------- 1 | """ 2 | 演示pygame制作的flappy bird如何逐帧调用执行 3 | """ 4 | from game.wrapped_flappy_bird import GameState 5 | from random import random 6 | import time 7 | 8 | # 创建游戏 9 | game = GameState() 10 | 11 | # 游戏启动 12 | while True: 13 | r = random() 14 | if r <= 0.92: # 92%的概率不点击屏幕 15 | game.frame_step([1,0]) # 动作:[1,0] 表示不点击 16 | else: # 8%的概率点击屏幕 17 | game.frame_step([0,1]) # 动作:[0,1] 表示点击 -------------------------------------------------------------------------------- /weights.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owenliang/FlappyBirdTensorflow/4b95ca33262b32ceee6ac91ce74ad185ebe0953a/weights.h5 --------------------------------------------------------------------------------