├── BrainDQN_Nature.py ├── HighSpeedRacing.py ├── README.md ├── SaveExpertAction.py ├── assets ├── audio │ ├── die.ogg │ ├── die.wav │ ├── hit.ogg │ ├── hit.wav │ ├── point.ogg │ ├── point.wav │ ├── swoosh.ogg │ ├── swoosh.wav │ ├── wing.ogg │ └── wing.wav └── sprites │ ├── 0.png │ ├── 1.png │ ├── 2.png │ ├── 3.png │ ├── 4.png │ ├── 5.png │ ├── 6.png │ ├── 7.png │ ├── 8.png │ ├── 9.png │ ├── background.png │ ├── base - 副本.png │ ├── base.png │ ├── leftLane.png │ ├── obscaleCar.png │ ├── pipe-green - 副本.png │ ├── rightLane.png │ ├── straight.png │ └── 捕获.PNG └── game ├── HighSpeedRacingGame.py ├── __pycache__ ├── HighSpeedRacingGame.cpython-35.pyc ├── flappy_bird_utils.cpython-35.pyc ├── utils.cpython-35.pyc └── wrapped_flappy_bird.cpython-35.pyc └── utils.py /BrainDQN_Nature.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import random 4 | from collections import deque 5 | 6 | # Hyper Parameters: 7 | FRAME_PER_ACTION = 1 8 | GAMMA = 0.99 # decay rate of past observations 9 | EXPLORE = 30000. 
#训练中最大交互次数 frames over which to anneal epsilon 10 | #FINAL_EPSILON = 0#0.001 # final value of epsilon 11 | #INITIAL_EPSILON = 0#0.01 # starting value of epsilon 12 | FINAL_EPSILON = 0.001#0.001 # final value of epsilon 13 | INITIAL_EPSILON = 0.001#0.01 # starting value of epsilon 14 | #INITIAL_EPSILON = 0.001#0.01 # starting value of epsilon 15 | INITIAL_EPSILON = 0.3#0.01 # starting value of epsilon 16 | REPLAY_MEMORY = 50000 # 可存储的最大状态转移信息条数 17 | #BATCH_SIZE = 64 # size of minibatch 18 | #UPDATE_TIME = 200 19 | #BATCH_SIZE = 32 # size of minibatch 20 | #UPDATE_TIME = 100 21 | BATCH_SIZE = 256*2 # size of minibatch 22 | OBSERVE = BATCH_SIZE + 10 # 交互OBSERVE次后用神经网络逼近值函数UPDATE_TIME = 1000 23 | UPDATE_TIME = 1000 24 | alpha = 1e-5 25 | try: 26 | tf.mul 27 | except: 28 | # For new version of tensorflow 29 | # tf.mul has been removed in new version of tensorflow 30 | # Using tf.multiply to replace tf.mul 31 | tf.mul = tf.multiply 32 | 33 | class BrainDQN: 34 | 35 | def __init__(self,actions, imgDim): 36 | self.imgDim = imgDim 37 | # init replay memory 38 | self.replayMemory = deque() 39 | # init some parameters 40 | self.timeStepLast = self.timeStep = 0 41 | 42 | self.epsilon = INITIAL_EPSILON 43 | self.actions = actions 44 | # init Q network 45 | self.stateInput,self.QValue,self.W1,self.b1,self.W2,self.b2 = self.createQNetwork() 46 | # init Target Q Network self.QValueT用于记录用计算TD时用到的下一个状态的值函数 47 | self.stateInputT,self.QValueT,self.W1T,self.b1T,self.W2T,self.b2T = self.createQNetwork() 48 | self.loss_temp = 0 49 | self.copyTargetQNetworkOperation = [self.W1T.assign(self.W1),self.b1T.assign(self.b1),self.W2T.assign(self.W2),self.b2T.assign(self.b2)] 50 | self.createTrainingMethod() 51 | 52 | # saving and loading networks 53 | self.saver = tf.train.Saver() 54 | self.session = tf.InteractiveSession() 55 | self.session.run(tf.initialize_all_variables()) 56 | checkpoint = tf.train.get_checkpoint_state("saved_networks") 57 | if checkpoint and 
checkpoint.model_checkpoint_path: 58 | self.saver.restore(self.session, checkpoint.model_checkpoint_path) 59 | print ("Successfully loaded:", checkpoint.model_checkpoint_path) 60 | a = checkpoint.model_checkpoint_path 61 | self.timeStepLast = int(a.split('-')[-1]) 62 | else: 63 | print ("Could not find old network weights") 64 | 65 | 66 | def createQNetwork(self): 67 | # # network weights 68 | self.single_image_units = (20 + 10)*3*5 69 | # self.single_units = self.single_image_units + self.actions 70 | self.single_units = self.single_image_units 71 | self.in_units = self.single_units*3 + self.single_image_units 72 | in_units = self.in_units 73 | h1_units = max(10,int(in_units/2)) 74 | o_units = self.actions 75 | W1 = tf.Variable(tf.truncated_normal([in_units, h1_units], stddev=0.1)) 76 | b1 = tf.Variable(tf.zeros([h1_units])) 77 | W2 = tf.Variable(tf.zeros([h1_units, o_units])) 78 | b2 = tf.Variable(tf.zeros([o_units])) 79 | 80 | stateInput = tf.placeholder(tf.float32, [None, in_units]) 81 | hidden1 = tf.nn.relu(tf.matmul(stateInput, W1) + b1) 82 | # hidden1_drop = tf.nn.dropout(hidden1, keep_prob) 83 | #y = tf.nn.softmax(tf.matmul(hidden1_drop, W2) + b2) 84 | QValue = tf.matmul(hidden1, W2) + b2 85 | return stateInput,QValue,W1,b1,W2,b2 86 | 87 | def copyTargetQNetwork(self): 88 | self.session.run(self.copyTargetQNetworkOperation) 89 | 90 | def createTrainingMethod(self): 91 | self.actionInput = tf.placeholder("float",[None,self.actions]) 92 | self.yInput = tf.placeholder("float", [None]) 93 | Q_Action = tf.reduce_sum(tf.mul(self.QValue, self.actionInput), reduction_indices = 1) 94 | self.cost = tf.reduce_mean(tf.square(self.yInput - Q_Action)) 95 | # self.trainStep = tf.train.AdagradOptimizer(1e-6).minimize(self.cost) 96 | self.trainStep = tf.train.AdamOptimizer(alpha).minimize(self.cost) 97 | 98 | 99 | def trainQNetwork(self): 100 | # Step 1: obtain random minibatch from replay memory 101 | minibatch = random.sample(self.replayMemory,BATCH_SIZE) 102 | 
state_batch = [data[0] for data in minibatch] 103 | action_batch = [data[1] for data in minibatch] 104 | reward_batch = [data[2] for data in minibatch] 105 | nextState_batch = [data[3] for data in minibatch] 106 | 107 | # Step 2: calculate y 108 | y_batch = [] 109 | # print('nextState_batch:',nextState_batch) 110 | # print('nextState_batch[0]:',nextState_batch[0]) 111 | QValue_batch = self.QValueT.eval(feed_dict={self.stateInputT:nextState_batch}) 112 | for i in range(0,BATCH_SIZE): 113 | terminal = minibatch[i][4] 114 | if terminal: 115 | y_batch.append(reward_batch[i]) 116 | else: 117 | y_batch.append(reward_batch[i] + GAMMA * np.max(QValue_batch[i])) 118 | 119 | self.loss_temp, _ = self.session.run([self.cost, self.trainStep],feed_dict={self.yInput : y_batch, self.actionInput : action_batch, self.stateInput : state_batch}) 120 | 121 | # save network every 100000 iteration 122 | if self.timeStep % 3000 == 0: 123 | self.saver.save(self.session, 'saved_networks/' + 'network' + '-dqn', global_step = self.timeStep) 124 | if self.timeStep % UPDATE_TIME == 0: 125 | self.copyTargetQNetwork() 126 | 127 | def setPerception(self,nextObservation,action,reward,terminal): 128 | newState = np.hstack((self.currentState[self.single_image_units:],nextObservation[0])) 129 | self.replayMemory.append((self.currentState,action,reward,newState,terminal)) 130 | if len(self.replayMemory) > REPLAY_MEMORY: 131 | self.replayMemory.popleft() 132 | if self.timeStep > OBSERVE: 133 | # Train the network 134 | self.trainQNetwork() 135 | 136 | # print info 137 | state = "" 138 | if self.timeStep <= OBSERVE: 139 | state = "observe" 140 | elif self.timeStep > OBSERVE and self.timeStep <= OBSERVE + EXPLORE: 141 | state = "explore" 142 | else: 143 | state = "train" 144 | print ("TIMESTEP", self.timeStep, "/ STATE", state, "/ EPSILON %0.3f" %(self.epsilon), "loss: %0.5f" %(self.loss_temp)) 145 | self.currentState = newState 146 | self.timeStep += 1 147 | 148 | def getAction(self): 149 | # 
print("self.currentState:",self.currentState) 150 | QValue = self.QValue.eval(feed_dict= {self.stateInput:[self.currentState]})[0] 151 | action = np.zeros(self.actions) 152 | action_index = 1 153 | if self.timeStep % FRAME_PER_ACTION == 0: 154 | if random.random() <= self.epsilon: 155 | action_index = random.randrange(self.actions) 156 | action[action_index] = 1 157 | else: 158 | action_index = np.argmax(QValue) 159 | action[action_index] = 1 160 | else: 161 | action[action_index] = 1 # do nothing 162 | # change episilon 163 | if self.epsilon > FINAL_EPSILON and self.timeStep > OBSERVE: 164 | self.epsilon -= (INITIAL_EPSILON - FINAL_EPSILON)/EXPLORE 165 | return action 166 | 167 | def setInitState(self,observation,action0): 168 | # self.currentState = np.hstack((observation[0], action0, observation[0], action0, observation[0], action0, observation[0])) 169 | self.currentState = np.hstack((observation[0], observation[0], observation[0], observation[0])) 170 | 171 | def weight_variable(self,shape): 172 | initial = tf.truncated_normal(shape, stddev = 0.01) 173 | return tf.Variable(initial) 174 | 175 | def bias_variable(self,shape): 176 | initial = tf.constant(0.01, shape = shape) 177 | return tf.Variable(initial) 178 | 179 | def conv2d(self,x, W, stride): 180 | return tf.nn.conv2d(x, W, strides = [1, stride, stride, 1], padding = "SAME") 181 | 182 | def max_pool_2x2(self,x): 183 | return tf.nn.max_pool(x, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1], padding = "SAME") 184 | 185 | -------------------------------------------------------------------------------- /HighSpeedRacing.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import sys 3 | sys.path.append("game/") 4 | import HighSpeedRacingGame as game 5 | from BrainDQN_Nature import BrainDQN 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | import time 9 | imgDim = [80*1,80*1] 10 | # preprocess raw image to 80*80 gray image 11 | def 
preprocess(observation): 12 | observation = cv2.cvtColor(cv2.resize(observation, (imgDim[0], imgDim[1])), cv2.COLOR_BGR2GRAY) 13 | ret, observation = cv2.threshold(observation,1,255,cv2.THRESH_BINARY) 14 | return np.reshape(observation,(imgDim[0],imgDim[1],1)) 15 | 16 | def HighSpeedRacing(): 17 | # Step 1: init BrainDQN 18 | actions = 5 19 | brain = BrainDQN(actions, imgDim) 20 | # Step 2: init Flappy Bird Game 21 | flappyBird = game.GameState() 22 | # Step 3: play game 23 | # Step 3.1: obtain init state 24 | action0 = np.array([0,1,0,0,0]) # do nothing 25 | observation0, reward0, terminal = flappyBird.frame_step(action0) 26 | print(observation0) 27 | # print('observation0 1:',observation0) 28 | # observation0 = cv2.cvtColor(cv2.resize(observation0, (imgDim[0],imgDim[1])), cv2.COLOR_BGR2GRAY) 29 | # ret, observation0 = cv2.threshold(observation0,1,255,cv2.THRESH_BINARY) 30 | brain.setInitState(observation0,action0) #将observation0复制4份放进BrainDQN的属性self.currentState中 31 | 32 | # isUseExpertData = False 33 | ## isUseExpertData = True 34 | # if(isUseExpertData == True): 35 | # filename = "./expertData/observation" 36 | # actInd = 0 37 | # observation0 = np.load(filename + str(actInd) + ".npy") 38 | # plt.imshow(observation0) 39 | # # # Step 3.2: run the game 40 | # # while 1!= 0: 41 | # for _ in range(1): 42 | # actInd = 0 43 | # for actInd in range(1,2073): 44 | # actInd += 1 45 | # action = np.load(filename + "action" + str(actInd) + ".npy") 46 | # reward = np.load(filename + "reward" + str(actInd) + ".npy") 47 | # terminal = np.load(filename + "terminal" + str(actInd) + ".npy") 48 | # nextObservation = np.load(filename + str(actInd) + ".npy") 49 | # plt.imshow(nextObservation) 50 | # nextObservation = preprocess(nextObservation) 51 | # brain.setPerception(nextObservation,action,reward,terminal) 52 | loss=[] 53 | plt.figure() 54 | ind = 0 55 | # Step 3.2: run the game 56 | while 1!= 0: 57 | # time.sleep(0.1) 58 | action= brain.getAction() 59 | 
loss.append(brain.loss_temp) 60 | ind += 1 61 | if ind%500==499: 62 | plt.plot(loss) 63 | plt.show() 64 | nextObservation,reward,terminal = flappyBird.frame_step(action) 65 | # nextObservation = preprocess(nextObservation) 66 | brain.setPerception(nextObservation,action,reward,terminal) 67 | 68 | def main(): 69 | HighSpeedRacing() 70 | 71 | if __name__ == '__main__': 72 | main() 73 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HighSpeedRacing 2 | achieve DeepTraffic(MIT 6.S094: Deep Learning for Self-Driving Cars) by Tensorflow and pygame 3 | -------------------------------------------------------------------------------- /SaveExpertAction.py: -------------------------------------------------------------------------------- 1 | # ------------------------- 2 | # Project: Deep Q-Learning on Flappy Bird 3 | # Author: Flood Sung 4 | # Date: 2016.3.21 5 | # ------------------------- 6 | 7 | import cv2 8 | import sys 9 | sys.path.append("game/") 10 | import wrapped_flappy_bird as game 11 | from BrainDQN_Nature import BrainDQN 12 | import numpy as np 13 | imgDim = [80*1,80*1] 14 | # preprocess raw image to 80*80 gray image 15 | def preprocess(observation): 16 | observation = cv2.cvtColor(cv2.resize(observation, (imgDim[0], imgDim[1])), cv2.COLOR_BGR2GRAY) 17 | ret, observation = cv2.threshold(observation,1,255,cv2.THRESH_BINARY) 18 | return np.reshape(observation,(imgDim[0],imgDim[1],1)) 19 | 20 | def playFlappyBird(): 21 | # Step 1: init BrainDQN 22 | actions = 5 23 | brain = BrainDQN(actions, imgDim) 24 | # Step 2: init Flappy Bird Game 25 | flappyBird = game.GameState() 26 | # Step 3: play game 27 | # Step 3.1: obtain init state 28 | action0 = np.array([0,1,0,0,0]) # do nothing 29 | observation0, reward0, terminal = flappyBird.frame_step(action0) 30 | observation0 = cv2.cvtColor(cv2.resize(observation0, (imgDim[0],imgDim[1])), 
cv2.COLOR_BGR2GRAY) 31 | ret, observation0 = cv2.threshold(observation0,1,255,cv2.THRESH_BINARY) 32 | brain.setInitState(observation0) #将observation0复制4份放进BrainDQN的属性self.currentState中 33 | 34 | filename = "./expertData/observation" 35 | ''' 36 | 第1次试验用,其他全注释 37 | ''' 38 | # actInd = 0 39 | # np.save(filename + str(actInd), observation0) 40 | 41 | actInd = 1045 #上次记录的最后一个数字 42 | # Step 3.2: run the game 43 | while 1!= 0: 44 | # action = brain.getAction()5 45 | act = 0 46 | while(act not in [2,5,8,6,4]): 47 | act = input("Please intput your action:") 48 | if(act == ''): continue 49 | act = int(act) 50 | 51 | if(act == 2): action = np.array([1,0,0,0,0]) 52 | if(act == 5): action = np.array([0,1,0,0,0]) 53 | if(act == 8): action = np.array([0,0,1,0,0]) 54 | if(act == 6): action = np.array([0,0,0,1,0]) 55 | if(act == 4): action = np.array([0,0,0,0,1]) 56 | 57 | # ExpertAct.append(action.tolist()) 58 | nextObservation,reward,terminal = flappyBird.frame_step(action) 59 | actInd += 1 60 | np.save(filename + str(actInd), nextObservation) 61 | np.save(filename + "action" + str(actInd), action) 62 | np.save(filename + "reward" + str(actInd), reward) 63 | np.save(filename + "terminal" + str(actInd), terminal) 64 | 65 | nextObservation = preprocess(nextObservation) 66 | #brain.setPerception(nextObservation,action,reward,terminal) 67 | 68 | def main(): 69 | playFlappyBird() 70 | 71 | if __name__ == '__main__': 72 | main() -------------------------------------------------------------------------------- /assets/audio/die.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/assets/audio/die.ogg -------------------------------------------------------------------------------- /assets/audio/die.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/assets/audio/die.wav -------------------------------------------------------------------------------- /assets/audio/hit.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/assets/audio/hit.ogg -------------------------------------------------------------------------------- /assets/audio/hit.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/assets/audio/hit.wav -------------------------------------------------------------------------------- /assets/audio/point.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/assets/audio/point.ogg -------------------------------------------------------------------------------- /assets/audio/point.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/assets/audio/point.wav -------------------------------------------------------------------------------- /assets/audio/swoosh.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/assets/audio/swoosh.ogg -------------------------------------------------------------------------------- /assets/audio/swoosh.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/assets/audio/swoosh.wav 
-------------------------------------------------------------------------------- /assets/audio/wing.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/assets/audio/wing.ogg -------------------------------------------------------------------------------- /assets/audio/wing.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/assets/audio/wing.wav -------------------------------------------------------------------------------- /assets/sprites/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/assets/sprites/0.png -------------------------------------------------------------------------------- /assets/sprites/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/assets/sprites/1.png -------------------------------------------------------------------------------- /assets/sprites/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/assets/sprites/2.png -------------------------------------------------------------------------------- /assets/sprites/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/assets/sprites/3.png -------------------------------------------------------------------------------- /assets/sprites/4.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/assets/sprites/4.png -------------------------------------------------------------------------------- /assets/sprites/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/assets/sprites/5.png -------------------------------------------------------------------------------- /assets/sprites/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/assets/sprites/6.png -------------------------------------------------------------------------------- /assets/sprites/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/assets/sprites/7.png -------------------------------------------------------------------------------- /assets/sprites/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/assets/sprites/8.png -------------------------------------------------------------------------------- /assets/sprites/9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/assets/sprites/9.png -------------------------------------------------------------------------------- /assets/sprites/background.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/assets/sprites/background.png -------------------------------------------------------------------------------- /assets/sprites/base - 副本.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/assets/sprites/base - 副本.png -------------------------------------------------------------------------------- /assets/sprites/base.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/assets/sprites/base.png -------------------------------------------------------------------------------- /assets/sprites/leftLane.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/assets/sprites/leftLane.png -------------------------------------------------------------------------------- /assets/sprites/obscaleCar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/assets/sprites/obscaleCar.png -------------------------------------------------------------------------------- /assets/sprites/pipe-green - 副本.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/assets/sprites/pipe-green - 副本.png -------------------------------------------------------------------------------- /assets/sprites/rightLane.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/assets/sprites/rightLane.png -------------------------------------------------------------------------------- /assets/sprites/straight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/assets/sprites/straight.png -------------------------------------------------------------------------------- /assets/sprites/捕获.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/assets/sprites/捕获.PNG -------------------------------------------------------------------------------- /game/HighSpeedRacingGame.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | import random 4 | import pygame 5 | import utils 6 | import pygame.surfarray as surfarray 7 | from pygame.locals import * 8 | from itertools import cycle 9 | 10 | FPS = 30 11 | patchNumX = 120 12 | STREETNUM = 7 13 | carWidth = 18 14 | carLength = 36 15 | safeForwardPatches = 3 16 | PatcheWidth = 6 17 | PatcheHeight = 6 18 | safeLen = safeForwardPatches*PatcheWidth 19 | SCREENWIDTH = patchNumX*PatcheWidth 20 | SCREENHEIGHT = STREETNUM*(PatcheHeight*2 + carWidth) 21 | SCREENHEIGTHWithWord = SCREENHEIGHT + 50 22 | LANESIDE = 1 23 | STREETPATCHNUM = 5 24 | FORWARDDECT = 20 25 | BACKWARDDECT = 10 26 | 27 | MinObstacleVelX = 2 28 | MaxObstacleVelX = 10 29 | MinPlayerVelX = 1 30 | MaxPlayerVelX = 15 31 | pygame.init() 32 | FPSCLOCK = pygame.time.Clock() 33 | #SCREEN = pygame.display.set_mode((SCREENWIDTH, SCREENHEIGHT)) 34 | SCREEN = pygame.display.set_mode((SCREENWIDTH, SCREENHEIGTHWithWord)) 35 | pygame.display.set_caption('High speed racing') 36 | 37 | IMAGES, SOUNDS, HITMASKS = 
utils.load() 38 | #p = 0.4 39 | p = 0.05 40 | p = 0.2 41 | PLAYER_LENGTH = IMAGES['player'][0].get_width() 42 | PLAYER_HEIGHT = IMAGES['player'][0].get_height() 43 | OBSCALE_LENGTH = IMAGES['pipe'][0].get_width() 44 | OBSCALE_HEIGHT = IMAGES['pipe'][0].get_height() 45 | BACKGROUND_WIDTH = IMAGES['background'].get_width() 46 | BASEY = SCREENHEIGHT * 1.0 47 | PLAYER_INDEX_GEN = cycle([0, 1, 2, 1]) 48 | 49 | centerliney = [(ind+0.5)*SCREENHEIGHT/STREETNUM for ind in range(STREETNUM)] 50 | 51 | class GameState: 52 | def __init__(self): 53 | self.score = self.playerIndex = self.loopIter = 0 54 | self.basex = 0 55 | self.baseShift = IMAGES['base'].get_width() - BACKGROUND_WIDTH 56 | self.totalRunLength = 0 57 | self.totalTime = 0 58 | ''' 59 | 初始player_car位置设置 60 | ''' 61 | self.initialVecx = 4 62 | self.playerVelY = 0 # player's velocity along Y, default same as playerFlapped 63 | self.playerVelX = self.initialVecx # player's velocity along Y, default same as playerFlapped 64 | # self.playerVelY = SCREENHEIGHT/STREETNUM # player's velocity along Y, default same as playerFlapped 65 | self.playerx = int(SCREENWIDTH * 0.15) 66 | self.playery = int((SCREENHEIGHT - PLAYER_HEIGHT) / 2) 67 | 68 | ''' 69 | 初始obscale_car个数、位置设置 70 | ''' 71 | self.streetLine = [[] for _ in range(STREETNUM)] 72 | # self.StreetVelX = [8,6,5,6,5,5,8] 73 | self.StreetVelX = [3,5,8,3,6,5,4] 74 | self.carVec = [[] for _ in range(STREETNUM)] 75 | carNums = [2,3,2,2,3,3,2] 76 | for streetInd in range(STREETNUM): 77 | self.streetLine[streetInd] = [] 78 | self.carVec[streetInd] = np.tile([self.StreetVelX[streetInd]], carNums[streetInd]).astype(np.float) 79 | x = ((np.arange(carNums[streetInd]) + 0.5 + 0.2*(np.random.rand(carNums[streetInd]) - 1))*SCREENWIDTH/carNums[streetInd]).astype(np.int16) 80 | if(streetInd == 3): 81 | x = (self.playerx + 2*carLength) + ((np.arange(carNums[streetInd]) + 0.5*(np.random.rand(carNums[streetInd])))*(SCREENWIDTH - (self.playerx + 
2*carLength))/carNums[streetInd]).astype(np.int16) 82 | x.sort() 83 | for indCar in range(carNums[streetInd]): 84 | self.streetLine[streetInd].append({'x': x[indCar], 'y': centerliney[streetInd] - OBSCALE_HEIGHT/2}) 85 | 86 | def frame_step(self, input_actions): 87 | pygame.event.pump() 88 | terminal = False 89 | if sum(input_actions) != 1: 90 | raise ValueError('Multiple input actions!') 91 | 92 | ''' 93 | 动作设置 94 | ''' 95 | if input_actions.argmax() == 0:#向上加速 96 | # self.playerVelY = 5 97 | self.playerVelY = PatcheHeight 98 | reward = 0 99 | if input_actions.argmax() == 1:#do nothing 100 | self.playerVelX += 0 101 | self.playerVelY = 0 102 | reward = 0.1 103 | if input_actions.argmax() == 2:#向上加速 104 | # self.playerVelY = -5 105 | self.playerVelY = -PatcheHeight 106 | reward = 0 107 | if input_actions.argmax() == 3:#向前加速 108 | self.playerVelX += 1 109 | if input_actions.argmax() == 4:#向后加速 110 | self.playerVelX -= 1 111 | 112 | ''' 113 | 防止playerCar撞上前面的obscaleCar 114 | ''' 115 | # indStart = int(self.playery/(SCREENHEIGHT/STREETNUM)) 116 | # indEnd = int((self.playery + PLAYER_HEIGHT)/(SCREENHEIGHT/STREETNUM)) 117 | ## print('indStart:',indStart,'indEnd:',indEnd) 118 | # for ind in range(indStart,indEnd+1): 119 | # print('ind:',ind) 120 | # n = len(self.streetLine[ind]) 121 | ## indFront = -1 122 | ## xFront = -1 123 | # for i2 in range(n): 124 | # if self.streetLine[ind][(i2)%n]['x'] > self.playerx + OBSCALE_LENGTH + safeLen : 125 | # d = self.streetLine[ind][(i2)%n]['x'] - (self.playerx + OBSCALE_LENGTH + safeLen) # 距离前车的距离 126 | # else: 127 | # d = SCREENWIDTH + self.streetLine[ind][(i2)%n]['x'] + OBSCALE_LENGTH + safeLen - self.playerx 128 | # if self.playerVelX > d: 129 | # print('sdfsfsfdsdfsdfsdf,ind:',ind) 130 | # self.playerVelX = self.carVec[ind][(i2)%n] 131 | # print("self.playerVelX:",self.playerVelX) 132 | 133 | self.playerVelX = max(self.playerVelX, MinPlayerVelX) 134 | self.playerVelX = min(self.playerVelX, MaxPlayerVelX) 135 | self.baseVelX = 
-self.playerVelX 136 | self.totalRunLength += self.playerVelX 137 | self.playery += self.playerVelY 138 | ''' 139 | playerCar加速减速奖励 140 | ''' 141 | if input_actions.argmax() == 3:#向前加速 142 | # reward = 0.2*(self.playerVelX - 0) 143 | # reward = 0.1*(self.playerVelX - 4) 144 | reward = 0.1*(self.playerVelX - MinPlayerVelX) 145 | if input_actions.argmax() == 4:#向后加速 146 | reward = 0.1 147 | # reward = 0.05*(self.playerVelX - MinPlayerVelX) 148 | #离中心线越近奖励越多 149 | # reward = 0.1 + 0.5*abs((int(self.playery + PLAYER_HEIGHT/2))%(int(SCREENHEIGHT/STREETNUM)) - int(SCREENHEIGHT/STREETNUM/2)) + 0.3*abs(self.playerVelX - self.initialVecx) 150 | # reward = 0.1 + 0.5*abs((int(self.playery + PLAYER_HEIGHT/2))%(int(SCREENHEIGHT/STREETNUM)) - int(SCREENHEIGHT/STREETNUM/2)) + 10*(self.playerVelX - self.initialVecx) 151 | # reward = 0.1*(self.playerVelX - self.initialVecx) 152 | # if (abs((int(self.playery + PLAYER_HEIGHT/2))%(int(SCREENHEIGHT/STREETNUM)) - int(SCREENHEIGHT/STREETNUM/2))) <= 5 : 153 | # reward += 0.1 154 | 155 | ''' 156 | 超车数量 157 | ''' 158 | playerFront = self.playerx + PLAYER_LENGTH 159 | for ind in range(STREETNUM): 160 | for obscale in self.streetLine[ind]: 161 | obscaleCar = obscale['x'] + OBSCALE_LENGTH 162 | # if obscaleCar <= playerFront <= obscaleCar + 1.1*self.StreetVelX[ind]: 163 | # if playerFront <= obscaleCar and playerFront + self.StreetVelX[ind] > obscaleCar: 164 | if playerFront <= obscaleCar and playerFront + self.playerVelX > obscaleCar: 165 | self.score += 1 166 | #SOUNDS['point'].play() 167 | # reward += 1 168 | self.totalTime += 1 169 | showStr = [u" Vec:%3.2fKm/h AverVec:%3.2f Km/h" %(10*self.playerVelX, 10*self.totalRunLength/self.totalTime), 170 | u"Mileage:%5.2fKm CarPassed:%d" %(self.totalRunLength*10/3600, self.score)] 171 | 172 | ''' 173 | 车道线移动速度 174 | ''' 175 | # self.basex = -((-self.basex + 100) % self.baseShift) 176 | self.basex = self.basex + self.baseVelX 177 | if(abs(self.basex + self.baseShift) < abs(max(self.StreetVelX)*2)): 
self.basex = 0 178 | 179 | ''' 180 | obscaleCar移动速度 181 | ''' 182 | for ind in range(STREETNUM): 183 | n = len(self.streetLine[ind]) 184 | for i2 in range(n): 185 | if self.streetLine[ind][(i2+1)%n]['x'] > self.streetLine[ind][i2]['x'] + OBSCALE_LENGTH + safeLen: 186 | d = self.streetLine[ind][(i2+1)%n]['x'] - (self.streetLine[ind][i2]['x'] + OBSCALE_LENGTH + safeLen) # 距离前车的距离 187 | else: 188 | d = SCREENWIDTH + self.streetLine[ind][(i2+1)%n]['x'] + OBSCALE_LENGTH + safeLen - self.streetLine[ind][i2]['x'] 189 | if self.carVec[ind][i2] < d: 190 | if np.random.rand() > p: 191 | # self.carVec[ind][i2] += 0.1 192 | self.carVec[ind][i2] += 0.5 193 | else: 194 | # self.carVec[ind][i2] -= 0.1 195 | self.carVec[ind][i2] -= 0.5 196 | else: 197 | # print(d) 198 | self.carVec[ind][i2] = self.carVec[ind][(i2+1)%n] 199 | 200 | for ind in range(STREETNUM): 201 | for indCarVeec in range(len(self.streetLine[ind])): 202 | self.carVec[ind][indCarVeec] = min(self.carVec[ind][indCarVeec], MaxObstacleVelX) 203 | self.carVec[ind][indCarVeec] = max(self.carVec[ind][indCarVeec], MinObstacleVelX) 204 | 205 | ''' 206 | 限制obscaleCar速度以免撞上playerCar 207 | ''' 208 | ind = int(self.playery/(SCREENHEIGHT/STREETNUM)) 209 | n = len(self.streetLine[ind]) 210 | for i2 in range(n): 211 | if self.playerx > self.streetLine[ind][(i2)%n]['x'] + OBSCALE_LENGTH + safeLen: 212 | d = self.playerx - (self.streetLine[ind][(i2)%n]['x'] + OBSCALE_LENGTH + safeLen) # 距离前车的距离 213 | else: 214 | d = SCREENWIDTH + self.playerx + OBSCALE_LENGTH + safeLen - self.streetLine[ind][(i2)%n]['x'] 215 | if self.carVec[ind][(i2)%n] > d: 216 | self.carVec[ind][(i2)%n] = self.playerVelX 217 | 218 | ''' 219 | obscaleCar在屏幕中的位置 220 | ''' 221 | for ind in range(STREETNUM): 222 | for indCarVeec in range(len(self.streetLine[ind])): 223 | self.streetLine[ind][indCarVeec]['x'] += self.carVec[ind][indCarVeec] - self.playerVelX 224 | 225 | for ind in range(STREETNUM): 226 | if self.streetLine[ind][0]['x'] < -OBSCALE_LENGTH: 227 | temp = 
self.streetLine[ind].pop(0) 228 | temp['x'] += SCREENWIDTH+ OBSCALE_LENGTH 229 | self.streetLine[ind].append(temp) 230 | if self.streetLine[ind][-1]['x'] > SCREENWIDTH*1: 231 | temp = self.streetLine[ind].pop(-1) 232 | temp['x'] -= SCREENWIDTH + OBSCALE_LENGTH 233 | self.streetLine[ind].insert(0, temp) 234 | 235 | # playerIndex basex change 236 | if (self.loopIter + 1) % 3 == 0: 237 | self.playerIndex = next(PLAYER_INDEX_GEN) 238 | self.loopIter = (self.loopIter + 1) % 30 239 | # check if crash here 240 | isCrash= checkCrash({'x': self.playerx, 'y': self.playery,'index': self.playerIndex},self.streetLine) 241 | if isCrash: 242 | #SOUNDS['hit'].play() 243 | #SOUNDS['die'].play() 244 | terminal = True 245 | self.__init__() 246 | reward = -3 247 | 248 | # draw sprites 249 | SCREEN.blit(IMAGES['background'], (0,0)) 250 | 251 | for ind in range(STREETNUM): 252 | for obscaleCar in self.streetLine[ind]: 253 | SCREEN.blit(IMAGES['pipe'][0], (obscaleCar['x'], obscaleCar['y'])) 254 | 255 | for ind in range(STREETNUM-1): 256 | SCREEN.blit(IMAGES['base'], (self.basex, (ind+1)*SCREENHEIGHT/STREETNUM)) 257 | # print score so player overlaps the score 258 | showScore(self.score, showStr) 259 | SCREEN.blit(IMAGES['player'][self.playerIndex], (self.playerx, self.playery)) 260 | 261 | totalPatches = np.zeros((STREETNUM*STREETPATCHNUM, patchNumX)) 262 | for streetInd in range(STREETNUM): 263 | for car in self.streetLine[streetInd]: 264 | carLocxInd = int(car['x']/PatcheWidth) 265 | totalPatches[(streetInd*STREETPATCHNUM+1):(streetInd*STREETPATCHNUM+4), carLocxInd:int((car['x'] + carLength)/PatcheWidth + 1)] = 1 266 | yInd = int(self.playery/(PatcheHeight)) 267 | xInd = int(self.playerx/(PatcheWidth)) 268 | totalPatches[yInd:int((self.playery+carWidth)/PatcheHeight), xInd:int((self.playerx+carLength)/PatcheWidth + 1)] = 2 269 | a = np.vstack((np.ones((STREETNUM*STREETPATCHNUM, patchNumX)),totalPatches)) 270 | totalPatches = np.vstack((a,np.ones((STREETNUM*STREETPATCHNUM, patchNumX)))) 
271 | xInd0 = int(STREETNUM*STREETPATCHNUM + yInd - LANESIDE*STREETPATCHNUM - STREETPATCHNUM/2 + 1) 272 | xInd1 = int(STREETNUM*STREETPATCHNUM + yInd + 1 + LANESIDE*STREETPATCHNUM + STREETPATCHNUM/2 ) 273 | yInd0 = int(xInd + 4 - BACKWARDDECT) 274 | yInd1 = int(xInd + 4 + FORWARDDECT) 275 | image_data = totalPatches[xInd0:xInd1, yInd0:yInd1] 276 | image_data = (image_data).reshape(1, image_data.shape[0]*image_data.shape[1]) 277 | pygame.display.update() 278 | FPSCLOCK.tick(FPS) 279 | return image_data, reward, terminal 280 | 281 | def getRandomPipe(): 282 | """returns a randomly generated pipe""" 283 | # y of gap between upper and lower pipe 284 | pipeX = SCREENWIDTH + 10 285 | 286 | return [ 287 | # {'x': pipeX, 'y': centerliney[int(random.random()*(STREETNUM))] - OBSCALE_HEIGHT/2}, # upper pipe 288 | {'x': pipeX, 'y': centerliney[int(random.random()*(STREETNUM))] - OBSCALE_HEIGHT/2} # lower pipe 289 | ] 290 | 291 | def showScore(score, showStr): 292 | """displays score in center of screen""" 293 | # scoreDigits = [int(x) for x in list(str(score))] 294 | # totalWidth = 0 # total width of all numbers to be printed 295 | # for digit in scoreDigits: 296 | # totalWidth += IMAGES['numbers'][digit].get_width() 297 | # Xoffset = (SCREENWIDTH - totalWidth) / 2 298 | # for digit in scoreDigits: 299 | # SCREEN.blit(IMAGES['numbers'][digit], (Xoffset, SCREENHEIGHT * 0.1)) 300 | # Xoffset += IMAGES['numbers'][digit].get_width() 301 | 302 | #获取系统字体,并设置文字大小 303 | cur_font = pygame.font.SysFont("Times New Roman", 25) 304 | #设置是否加粗属性 305 | cur_font.set_bold(False) 306 | #设置是否斜体属性 307 | cur_font.set_italic(False) 308 | #设置文字内容 309 | # text = u"car passed:" 310 | text_fmt1 = cur_font.render(showStr[0], 1, (0, 0, 0)) 311 | text_fmt2 = cur_font.render(showStr[1], 1, (0, 0, 0)) 312 | #绘制文字 313 | SCREEN.blit(text_fmt1, (20, SCREENHEIGHT + 0)) 314 | SCREEN.blit(text_fmt2, (20, SCREENHEIGHT + 25)) 315 | 316 | def checkCrash(player, streetLine): 317 | """returns True if player collders 
with base or pipes.""" 318 | pi = player['index'] 319 | player['w'] = IMAGES['player'][0].get_width() 320 | player['h'] = IMAGES['player'][0].get_height() 321 | # if player crashes into ground 322 | if player['y'] + player['h'] >= BASEY - 1: 323 | return True 324 | elif player['y'] <= 1: 325 | return True 326 | else: 327 | playerRect = pygame.Rect(player['x'], player['y'], 328 | player['w'], player['h']) 329 | for ind in range(STREETNUM): 330 | for uPipe in streetLine[ind]: 331 | # upper and lower pipe rects 332 | uPipeRect = pygame.Rect(uPipe['x'], uPipe['y'], OBSCALE_LENGTH, OBSCALE_HEIGHT) 333 | # player and upper/lower pipe hitmasks 334 | pHitMask = HITMASKS['player'][pi] 335 | uHitmask = HITMASKS['pipe'][0] 336 | # if bird collided with upipe or lpipe 337 | uCollide = pixelCollision(playerRect, uPipeRect, pHitMask, uHitmask) 338 | if uCollide:return True 339 | return False 340 | 341 | def isChangeStreet(player, streetLine): 342 | """returns True if player collders with base or pipes.""" 343 | pi = player['index'] 344 | player['w'] = IMAGES['player'][0].get_width() 345 | player['h'] = IMAGES['player'][0].get_height() 346 | # if player crashes into ground 347 | if player['y'] + player['h'] >= BASEY - 1: 348 | return True 349 | elif player['y'] <= 1: 350 | return True 351 | else: 352 | playerRect = pygame.Rect(player['x'], player['y'], 353 | player['w'], player['h']) 354 | for ind in range(STREETNUM): 355 | for uPipe in streetLine[ind]: 356 | # upper and lower pipe rects 357 | uPipeRect = pygame.Rect(uPipe['x'], uPipe['y'], OBSCALE_LENGTH, OBSCALE_HEIGHT) 358 | # player and upper/lower pipe hitmasks 359 | pHitMask = HITMASKS['player'][pi] 360 | uHitmask = HITMASKS['pipe'][0] 361 | # if bird collided with upipe or lpipe 362 | uCollide = pixelCollision(playerRect, uPipeRect, pHitMask, uHitmask) 363 | if uCollide:return True 364 | return False 365 | 366 | def pixelCollision(rect1, rect2, hitmask1, hitmask2): 367 | """Checks if two objects collide and not just their 
rects""" 368 | rect = rect1.clip(rect2) 369 | 370 | if rect.width == 0 or rect.height == 0: 371 | return False 372 | 373 | x1, y1 = rect.x - rect1.x, rect.y - rect1.y 374 | x2, y2 = rect.x - rect2.x, rect.y - rect2.y 375 | 376 | for x in range(rect.width): 377 | for y in range(rect.height): 378 | if hitmask1[x1+x][y1+y] and hitmask2[x2+x][y2+y]: 379 | return True 380 | return False 381 | -------------------------------------------------------------------------------- /game/__pycache__/HighSpeedRacingGame.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/game/__pycache__/HighSpeedRacingGame.cpython-35.pyc -------------------------------------------------------------------------------- /game/__pycache__/flappy_bird_utils.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/game/__pycache__/flappy_bird_utils.cpython-35.pyc -------------------------------------------------------------------------------- /game/__pycache__/utils.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/game/__pycache__/utils.cpython-35.pyc -------------------------------------------------------------------------------- /game/__pycache__/wrapped_flappy_bird.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drilistbox/HighSpeedRacing/d838644a3349e723322915ec46bb06c244707eb6/game/__pycache__/wrapped_flappy_bird.cpython-35.pyc -------------------------------------------------------------------------------- /game/utils.py: 
"""Asset-loading helpers for the HighSpeedRacing game."""
import pygame
import sys


def load():
    """Load the sprites, sounds and collision hitmasks used by the game.

    All paths are relative ('assets/...'), so this must be called with the
    repository root as the current working directory.  Images go through
    ``convert()``/``convert_alpha()``, which pygame only permits after a
    display mode has been set, and ``pygame.mixer.Sound`` requires an
    initialised mixer.

    Returns:
        tuple: ``(IMAGES, SOUNDS, HITMASKS)``, three dicts:

        * ``IMAGES['numbers']``: digit sprites, ``IMAGES['numbers'][d]`` is d.
        * ``IMAGES['base']``, ``IMAGES['background']``: single surfaces.
        * ``IMAGES['player']``: (left-lane, straight, right-lane) car sprites.
        * ``IMAGES['pipe']``: (obstacle car rotated 180 degrees, obstacle car).
        * ``SOUNDS``: keys 'die', 'hit', 'point', 'swoosh', 'wing'.
        * ``HITMASKS['pipe']`` / ``HITMASKS['player']``: masks parallel to the
          matching ``IMAGES`` tuples (see ``getHitmask``).
    """
    # Player car sprites. The order matters because the game loop selects a
    # sprite by index (playerIndex).
    PLAYER_PATH = (
        'assets/sprites/leftLane.png',
        'assets/sprites/straight.png',
        'assets/sprites/rightLane.png',
    )
    BACKGROUND_PATH = 'assets/sprites/background.png'
    # The obstacle is another car. The name "pipe" is inherited from the
    # Flappy Bird code this game was adapted from.
    PIPE_PATH = 'assets/sprites/obscaleCar.png'

    IMAGES, SOUNDS, HITMASKS = {}, {}, {}

    # Digit sprites for score display (0.png .. 9.png).
    IMAGES['numbers'] = tuple(
        pygame.image.load('assets/sprites/{}.png'.format(digit)).convert_alpha()
        for digit in range(10)
    )

    # "Base" sprite; the game blits it between adjacent lanes.
    IMAGES['base'] = pygame.image.load('assets/sprites/base.png').convert_alpha()

    # Both .wav and .ogg variants ship in assets/audio: use WAV on Windows
    # (sys.platform == 'win32') and OGG everywhere else.  A substring test
    # ('win' in sys.platform) would wrongly match 'darwin' and 'cygwin' too.
    soundExt = '.wav' if sys.platform.startswith('win') else '.ogg'
    for name in ('die', 'hit', 'point', 'swoosh', 'wing'):
        SOUNDS[name] = pygame.mixer.Sound('assets/audio/' + name + soundExt)

    # Background is loaded with convert(), i.e. without per-pixel alpha.
    IMAGES['background'] = pygame.image.load(BACKGROUND_PATH).convert()

    IMAGES['player'] = tuple(
        pygame.image.load(path).convert_alpha() for path in PLAYER_PATH
    )

    # Load the obstacle image once.  Index 0 is the 180-degree rotation (the
    # variant the game draws and checks collisions against); index 1 is the
    # unrotated image.  A 180-degree rotation adds no padding.
    obstacle = pygame.image.load(PIPE_PATH).convert_alpha()
    IMAGES['pipe'] = (pygame.transform.rotate(obstacle, 180), obstacle)

    # Per-pixel collision masks, parallel to the IMAGES tuples above.
    HITMASKS['pipe'] = tuple(getHitmask(img) for img in IMAGES['pipe'])
    HITMASKS['player'] = tuple(getHitmask(img) for img in IMAGES['player'])

    return IMAGES, SOUNDS, HITMASKS


def getHitmask(image):
    """Return a column-major collision mask built from *image*'s alpha.

    ``mask[x][y]`` is True when the pixel at (x, y) has non-zero alpha, i.e.
    it is not fully transparent.
    """
    width, height = image.get_width(), image.get_height()
    return [[bool(image.get_at((x, y))[3]) for y in range(height)]
            for x in range(width)]