├── requirements.txt ├── pong ├── __init__.py ├── paddle.py ├── ball.py └── game.py ├── best.pickle ├── neat-checkpoint-0 ├── neat-checkpoint-1 ├── neat-checkpoint-2 ├── neat-checkpoint-3 ├── neat-checkpoint-4 ├── neat-checkpoint-5 ├── neat-checkpoint-6 ├── neat-checkpoint-7 ├── neat-checkpoint-8 ├── neat-checkpoint-9 ├── neat-checkpoint-10 ├── neat-checkpoint-11 ├── neat-checkpoint-12 ├── neat-checkpoint-13 ├── neat-checkpoint-14 ├── neat-checkpoint-15 ├── neat-checkpoint-16 ├── neat-checkpoint-17 ├── neat-checkpoint-18 ├── neat-checkpoint-19 ├── neat-checkpoint-20 ├── neat-checkpoint-21 ├── neat-checkpoint-22 ├── neat-checkpoint-23 ├── neat-checkpoint-24 ├── neat-checkpoint-25 ├── neat-checkpoint-26 ├── README.md ├── config.txt ├── .gitignore ├── tutorial.py └── main.py /requirements.txt: -------------------------------------------------------------------------------- 1 | neat-python 2 | pygame 3 | -------------------------------------------------------------------------------- /pong/__init__.py: -------------------------------------------------------------------------------- 1 | from .game import Game 2 | -------------------------------------------------------------------------------- /best.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliezzahn/ping-pong-ai/HEAD/best.pickle -------------------------------------------------------------------------------- /neat-checkpoint-0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliezzahn/ping-pong-ai/HEAD/neat-checkpoint-0 -------------------------------------------------------------------------------- /neat-checkpoint-1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliezzahn/ping-pong-ai/HEAD/neat-checkpoint-1 -------------------------------------------------------------------------------- /neat-checkpoint-2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliezzahn/ping-pong-ai/HEAD/neat-checkpoint-2 -------------------------------------------------------------------------------- /neat-checkpoint-3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliezzahn/ping-pong-ai/HEAD/neat-checkpoint-3 -------------------------------------------------------------------------------- /neat-checkpoint-4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliezzahn/ping-pong-ai/HEAD/neat-checkpoint-4 -------------------------------------------------------------------------------- /neat-checkpoint-5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliezzahn/ping-pong-ai/HEAD/neat-checkpoint-5 -------------------------------------------------------------------------------- /neat-checkpoint-6: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliezzahn/ping-pong-ai/HEAD/neat-checkpoint-6 -------------------------------------------------------------------------------- /neat-checkpoint-7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliezzahn/ping-pong-ai/HEAD/neat-checkpoint-7 -------------------------------------------------------------------------------- /neat-checkpoint-8: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliezzahn/ping-pong-ai/HEAD/neat-checkpoint-8 -------------------------------------------------------------------------------- /neat-checkpoint-9: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliezzahn/ping-pong-ai/HEAD/neat-checkpoint-9 -------------------------------------------------------------------------------- /neat-checkpoint-10: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliezzahn/ping-pong-ai/HEAD/neat-checkpoint-10 -------------------------------------------------------------------------------- /neat-checkpoint-11: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliezzahn/ping-pong-ai/HEAD/neat-checkpoint-11 -------------------------------------------------------------------------------- /neat-checkpoint-12: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliezzahn/ping-pong-ai/HEAD/neat-checkpoint-12 -------------------------------------------------------------------------------- /neat-checkpoint-13: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliezzahn/ping-pong-ai/HEAD/neat-checkpoint-13 -------------------------------------------------------------------------------- /neat-checkpoint-14: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliezzahn/ping-pong-ai/HEAD/neat-checkpoint-14 -------------------------------------------------------------------------------- /neat-checkpoint-15: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliezzahn/ping-pong-ai/HEAD/neat-checkpoint-15 -------------------------------------------------------------------------------- /neat-checkpoint-16: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliezzahn/ping-pong-ai/HEAD/neat-checkpoint-16 -------------------------------------------------------------------------------- /neat-checkpoint-17: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliezzahn/ping-pong-ai/HEAD/neat-checkpoint-17 -------------------------------------------------------------------------------- /neat-checkpoint-18: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliezzahn/ping-pong-ai/HEAD/neat-checkpoint-18 -------------------------------------------------------------------------------- /neat-checkpoint-19: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliezzahn/ping-pong-ai/HEAD/neat-checkpoint-19 -------------------------------------------------------------------------------- /neat-checkpoint-20: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliezzahn/ping-pong-ai/HEAD/neat-checkpoint-20 -------------------------------------------------------------------------------- /neat-checkpoint-21: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliezzahn/ping-pong-ai/HEAD/neat-checkpoint-21 -------------------------------------------------------------------------------- /neat-checkpoint-22: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliezzahn/ping-pong-ai/HEAD/neat-checkpoint-22 -------------------------------------------------------------------------------- /neat-checkpoint-23: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliezzahn/ping-pong-ai/HEAD/neat-checkpoint-23 -------------------------------------------------------------------------------- /neat-checkpoint-24: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliezzahn/ping-pong-ai/HEAD/neat-checkpoint-24 -------------------------------------------------------------------------------- /neat-checkpoint-25: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliezzahn/ping-pong-ai/HEAD/neat-checkpoint-25 -------------------------------------------------------------------------------- /neat-checkpoint-26: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliezzahn/ping-pong-ai/HEAD/neat-checkpoint-26 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NEAT-Pong-Python 2 | Using the NEAT algorithm to train an AI to play pong in Python! 3 | -------------------------------------------------------------------------------- /pong/paddle.py: -------------------------------------------------------------------------------- 1 | import pygame 2 | 3 | 4 | class Paddle: 5 | VEL = 4 6 | WIDTH = 20 7 | HEIGHT = 100 8 | 9 | def __init__(self, x, y): 10 | self.x = self.original_x = x 11 | self.y = self.original_y = y 12 | 13 | def draw(self, win): 14 | pygame.draw.rect( 15 | win, (255, 255, 255), (self.x, self.y, self.WIDTH, self.HEIGHT)) 16 | 17 | def move(self, up=True): 18 | if up: 19 | self.y -= self.VEL 20 | else: 21 | self.y += self.VEL 22 | 23 | def reset(self): 24 | self.x = self.original_x 25 | self.y = self.original_y 26 | -------------------------------------------------------------------------------- /pong/ball.py: -------------------------------------------------------------------------------- 1 | import pygame 2 | import math 3 | import random 4 | 5 | 6 | class Ball: 7 | MAX_VEL = 5 8 | RADIUS = 7 9 | 10 | def __init__(self, x, y): 11 | self.x = self.original_x = x 12 | self.y = self.original_y = y 13 | 14 | angle = self._get_random_angle(-30, 30, [0]) 15 | pos = 1 if random.random() < 0.5 else -1 16 | 17 | self.x_vel = pos * abs(math.cos(angle) * self.MAX_VEL) 18 | self.y_vel = math.sin(angle) * self.MAX_VEL 19 | 20 | def _get_random_angle(self, min_angle, max_angle, excluded): 21 | angle = 0 22 | while angle in excluded: 23 | angle = math.radians(random.randrange(min_angle, max_angle)) 24 | 25 | return angle 26 | 27 | def draw(self, win): 28 | pygame.draw.circle(win, (255, 255, 255), (self.x, self.y), self.RADIUS) 29 | 30 | def move(self): 31 | self.x += self.x_vel 32 | self.y += self.y_vel 33 | 34 | def reset(self): 35 | self.x = self.original_x 36 | self.y = self.original_y 37 | 38 | angle = self._get_random_angle(-30, 30, [0]) 39 | x_vel = abs(math.cos(angle) * self.MAX_VEL) 40 | y_vel = math.sin(angle) * self.MAX_VEL 41 | 42 | self.y_vel = y_vel 43 | self.x_vel *= -1 44 | -------------------------------------------------------------------------------- /config.txt: -------------------------------------------------------------------------------- 1 | [NEAT] 2 | fitness_criterion = max 3 | fitness_threshold = 400 4 | pop_size = 50 5 | reset_on_extinction = False 6 | 7 | [DefaultStagnation] 8 | species_fitness_func = max 9 | max_stagnation = 20 10 | species_elitism = 2 11 | 12 | [DefaultReproduction] 13 | elitism = 2 14 | survival_threshold = 0.2 15 | 16 | [DefaultGenome] 17 | # node activation options 18 | activation_default = relu 19 | activation_mutate_rate = 1.0 20 | activation_options = relu 21 | 22 | # node aggregation options 23 | aggregation_default = sum 24 | aggregation_mutate_rate = 0.0 25 | aggregation_options = sum 26 | 27 | # node bias options 28 | bias_init_mean = 3.0 29 | bias_init_stdev = 1.0 30 | bias_max_value = 30.0 31 | bias_min_value = -30.0 32 | bias_mutate_power = 0.5 33 | bias_mutate_rate = 0.7 34 | bias_replace_rate = 0.1 35 | 36 | # genome compatibility options 37 | compatibility_disjoint_coefficient = 1.0 38 | compatibility_weight_coefficient = 0.5 39 | 40 | # connection add/remove rates 41 | conn_add_prob = 0.5 42 | conn_delete_prob = 0.5 43 | 44 | # connection enable options 45 | enabled_default = True 46 | enabled_mutate_rate = 0.01 47 | 48 | feed_forward = True 49 | initial_connection = full_direct 50 | 51 | # node add/remove rates 52 | node_add_prob = 0.2 53 | node_delete_prob = 0.2 54 | 55 | # network parameters 56 | num_hidden = 2 57 | num_inputs = 3 58 | num_outputs = 3 59 | 60 | # node response options 61 | response_init_mean = 1.0 62 | response_init_stdev = 0.0 63 | response_max_value = 30.0 64 | response_min_value = -30.0 65 | response_mutate_power = 0.0 66 | response_mutate_rate = 0.0 67 | response_replace_rate = 0.0 68 | 69 | # connection weight options 70 | weight_init_mean = 0.0 71 | weight_init_stdev = 1.0 72 | weight_max_value = 30 73 | weight_min_value = -30 74 | weight_mutate_power = 0.5 75 | weight_mutate_rate = 0.8 76 | weight_replace_rate = 0.1 77 | 78 | [DefaultSpeciesSet] 79 | compatibility_threshold = 3.0 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/python 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=python 3 | 4 | ### Python ### 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/#use-with-ide 114 | .pdm.toml 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 164 | #.idea/ 165 | 166 | ### Python Patch ### 167 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 168 | poetry.toml 169 | 170 | # ruff 171 | .ruff_cache/ 172 | 173 | # LSP config files 174 | pyrightconfig.json 175 | 176 | # End of https://www.toptal.com/developers/gitignore/api/python -------------------------------------------------------------------------------- /tutorial.py: -------------------------------------------------------------------------------- 1 | import pygame 2 | from pong import Game 3 | import neat 4 | import os 5 | import pickle 6 | 7 | 8 | class PongGame: 9 | def __init__(self, window, width, height): 10 | self.game = Game(window, width, height) 11 | self.left_paddle = self.game.left_paddle 12 | self.right_paddle = self.game.right_paddle 13 | self.ball = self.game.ball 14 | 15 | def test_ai(self, genome, config): 16 | net = neat.nn.FeedForwardNetwork.create(genome, config) 17 | 18 | run = True 19 | clock = pygame.time.Clock() 20 | while run: 21 | clock.tick(60) 22 | for event in pygame.event.get(): 23 | if event.type == pygame.QUIT: 24 | run = False 25 | break 26 | 27 | keys = pygame.key.get_pressed() 28 | if keys[pygame.K_w]: 29 | self.game.move_paddle(left=True, up=True) 30 | if keys[pygame.K_s]: 31 | self.game.move_paddle(left=True, up=False) 32 | 33 | output = net.activate( 34 | (self.right_paddle.y, self.ball.y, abs(self.right_paddle.x - self.ball.x))) 35 | decision = output.index(max(output)) 36 | 37 | if decision == 0: 38 | pass 39 | elif decision == 1: 40 | self.game.move_paddle(left=False, up=True) 41 | else: 42 | self.game.move_paddle(left=False, up=False) 43 | 44 | game_info = self.game.loop() 45 | self.game.draw(True, False) 46 | pygame.display.update() 47 | 48 | pygame.quit() 49 | 50 | def train_ai(self, genome1, genome2, config): 51 | net1 = neat.nn.FeedForwardNetwork.create(genome1, config) 52 | net2 = neat.nn.FeedForwardNetwork.create(genome2, config) 53 | 54 | run = True 55 | while run: 56 | for event in pygame.event.get(): 57 | if event.type == pygame.QUIT: 58 | quit() 59 | 60 | output1 = net1.activate( 61 | (self.left_paddle.y, self.ball.y, abs(self.left_paddle.x - self.ball.x))) 62 | decision1 = output1.index(max(output1)) 63 | 64 | if decision1 == 0: 65 | pass 66 | elif decision1 == 1: 67 | self.game.move_paddle(left=True, up=True) 68 | else: 69 | self.game.move_paddle(left=True, up=False) 70 | 71 | output2 = net2.activate( 72 | (self.right_paddle.y, self.ball.y, abs(self.right_paddle.x - self.ball.x))) 73 | decision2 = output2.index(max(output2)) 74 | 75 | if decision2 == 0: 76 | pass 77 | elif decision2 == 1: 78 | self.game.move_paddle(left=False, up=True) 79 | else: 80 | self.game.move_paddle(left=False, up=False) 81 | 82 | game_info = game.loop() 83 | 84 | game.draw(draw_score=False, draw_hits=True) 85 | pygame.display.update() 86 | 87 | if game_info.left_score >= 1 or game_info.right_score >= 1 or game_info.left_hits > 50: 88 | self.calculate_fitness(genome1, genome2, game_info) 89 | break 90 | 91 | def calculate_fitness(self, genome1, genome2, game_info): 92 | genome1.fitness += game_info.left_hits 93 | genome2.fitness += game_info.right_hits 94 | 95 | 96 | def eval_genomes(genomes, config): 97 | width, height = 700, 500 98 | window = pygame.display.set_mode((width, height)) 99 | 100 | for i, (genome_id1, genome1) in enumerate(genomes): 101 | if i == len(genomes) - 1: 102 | break 103 | genome1.fitness = 0 104 | for genome_id2, genome2 in genomes[i+1:]: 105 | genome2.fitness = 0 if genome2.fitness == None else genome2.fitness 106 | game = PongGame(window, width, height) 107 | game.train_ai(genome1, genome2, config) 108 | 109 | 110 | def run_neat(config): 111 | p = neat.Checkpointer.restore_checkpoint('neat-checkpoint-7') 112 | #p = neat.Population(config) 113 | p.add_reporter(neat.StdOutReporter(True)) 114 | stats = neat.StatisticsReporter() 115 | p.add_reporter(stats) 116 | p.add_reporter(neat.Checkpointer(1)) 117 | 118 | winner = p.run(eval_genomes, 1) 119 | with open("best.pickle", "wb") as f: 120 | pickle.dump(winner, f) 121 | 122 | 123 | def test_ai(config): 124 | width, height = 700, 500 125 | window = pygame.display.set_mode((width, height)) 126 | 127 | with open("best.pickle", "rb") as f: 128 | winner = pickle.load(f) 129 | 130 | game = PongGame(window, width, height) 131 | game.test_ai(winner, config) 132 | 133 | 134 | if __name__ == "__main__": 135 | local_dir = os.path.dirname(__file__) 136 | config_path = os.path.join(local_dir, "config.txt") 137 | 138 | config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction, 139 | neat.DefaultSpeciesSet, neat.DefaultStagnation, 140 | config_path) 141 | # run_neat(config) 142 | test_ai(config) 143 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # https://neat-python.readthedocs.io/en/latest/xor_example.html 2 | from pong import Game 3 | import pygame 4 | import neat 5 | import os 6 | import time 7 | import pickle 8 | 9 | 10 | class PongGame: 11 | def __init__(self, window, width, height): 12 | self.game = Game(window, width, height) 13 | self.ball = self.game.ball 14 | self.left_paddle = self.game.left_paddle 15 | self.right_paddle = self.game.right_paddle 16 | 17 | def test_ai(self, net): 18 | """ 19 | Test the AI against a human player by passing a NEAT neural network 20 | """ 21 | clock = pygame.time.Clock() 22 | run = True 23 | while run: 24 | clock.tick(60) 25 | game_info = self.game.loop() 26 | 27 | for event in pygame.event.get(): 28 | if event.type == pygame.QUIT: 29 | run = False 30 | break 31 | 32 | output = net.activate( 33 | ( 34 | self.right_paddle.y, 35 | abs(self.right_paddle.x - self.ball.x), 36 | self.ball.y, 37 | ) 38 | ) 39 | decision = output.index(max(output)) 40 | 41 | if decision == 1: # AI moves up 42 | self.game.move_paddle(left=False, up=True) 43 | elif decision == 2: # AI moves down 44 | self.game.move_paddle(left=False, up=False) 45 | 46 | keys = pygame.key.get_pressed() 47 | if keys[pygame.K_w]: 48 | self.game.move_paddle(left=True, up=True) 49 | elif keys[pygame.K_s]: 50 | self.game.move_paddle(left=True, up=False) 51 | 52 | self.game.draw(draw_score=True) 53 | pygame.display.update() 54 | 55 | def train_ai(self, genome1, genome2, config, draw=False): 56 | """ 57 | Train the AI by passing two NEAT neural networks and the NEAt config object. 58 | These AI's will play against each other to determine their fitness. 59 | """ 60 | run = True 61 | start_time = time.time() 62 | 63 | net1 = neat.nn.FeedForwardNetwork.create(genome1, config) 64 | net2 = neat.nn.FeedForwardNetwork.create(genome2, config) 65 | self.genome1 = genome1 66 | self.genome2 = genome2 67 | 68 | max_hits = 50 69 | 70 | while run: 71 | for event in pygame.event.get(): 72 | if event.type == pygame.QUIT: 73 | return True 74 | 75 | game_info = self.game.loop() 76 | 77 | self.move_ai_paddles(net1, net2) 78 | 79 | if draw: 80 | self.game.draw(draw_score=False, draw_hits=True) 81 | 82 | pygame.display.update() 83 | 84 | duration = time.time() - start_time 85 | if ( 86 | game_info.left_score == 1 87 | or game_info.right_score == 1 88 | or game_info.left_hits >= max_hits 89 | ): 90 | self.calculate_fitness(game_info, duration) 91 | break 92 | 93 | return False 94 | 95 | def move_ai_paddles(self, net1, net2): 96 | """ 97 | Determine where to move the left and the right paddle based on the two 98 | neural networks that control them. 99 | """ 100 | players = [ 101 | (self.genome1, net1, self.left_paddle, True), 102 | (self.genome2, net2, self.right_paddle, False), 103 | ] 104 | for genome, net, paddle, left in players: 105 | output = net.activate((paddle.y, abs(paddle.x - self.ball.x), self.ball.y)) 106 | decision = output.index(max(output)) 107 | 108 | valid = True 109 | if decision == 0: # Don't move 110 | genome.fitness -= 0.01 # we want to discourage this 111 | elif decision == 1: # Move up 112 | valid = self.game.move_paddle(left=left, up=True) 113 | else: # Move down 114 | valid = self.game.move_paddle(left=left, up=False) 115 | 116 | if ( 117 | not valid 118 | ): # If the movement makes the paddle go off the screen punish the AI 119 | genome.fitness -= 1 120 | 121 | def calculate_fitness(self, game_info, duration): 122 | self.genome1.fitness += game_info.left_hits + duration 123 | self.genome2.fitness += game_info.right_hits + duration 124 | 125 | 126 | def eval_genomes(genomes, config): 127 | """ 128 | Run each genome against eachother one time to determine the fitness. 129 | """ 130 | width, height = 700, 500 131 | win = pygame.display.set_mode((width, height)) 132 | pygame.display.set_caption("Pong") 133 | 134 | for i, (genome_id1, genome1) in enumerate(genomes): 135 | print(round(i / len(genomes) * 100), end=" ") 136 | genome1.fitness = 0 137 | for genome_id2, genome2 in genomes[min(i + 1, len(genomes) - 1) :]: 138 | genome2.fitness = 0 if genome2.fitness == None else genome2.fitness 139 | pong = PongGame(win, width, height) 140 | 141 | force_quit = pong.train_ai(genome1, genome2, config, draw=True) 142 | if force_quit: 143 | quit() 144 | 145 | 146 | def run_neat(config): 147 | # p = neat.Checkpointer.restore_checkpoint('neat-checkpoint-85') 148 | p = neat.Population(config) 149 | p.add_reporter(neat.StdOutReporter(True)) 150 | stats = neat.StatisticsReporter() 151 | p.add_reporter(stats) 152 | p.add_reporter(neat.Checkpointer(1)) 153 | 154 | winner = p.run(eval_genomes, 50) 155 | with open("best.pickle", "wb") as f: 156 | pickle.dump(winner, f) 157 | 158 | 159 | def test_best_network(config): 160 | with open("best.pickle", "rb") as f: 161 | winner = pickle.load(f) 162 | winner_net = neat.nn.FeedForwardNetwork.create(winner, config) 163 | 164 | width, height = 700, 500 165 | win = pygame.display.set_mode((width, height)) 166 | pygame.display.set_caption("Pong") 167 | pong = PongGame(win, width, height) 168 | pong.test_ai(winner_net) 169 | 170 | 171 | if __name__ == "__main__": 172 | local_dir = os.path.dirname(__file__) 173 | config_path = os.path.join(local_dir, "config.txt") 174 | 175 | config = neat.Config( 176 | neat.DefaultGenome, 177 | neat.DefaultReproduction, 178 | neat.DefaultSpeciesSet, 179 | neat.DefaultStagnation, 180 | config_path, 181 | ) 182 | 183 | run_neat(config) 184 | test_best_network(config) 185 | -------------------------------------------------------------------------------- /pong/game.py: -------------------------------------------------------------------------------- 1 | from .paddle import Paddle 2 | from .ball import Ball 3 | import pygame 4 | import random 5 | pygame.init() 6 | 7 | 8 | class GameInformation: 9 | def __init__(self, left_hits, right_hits, left_score, right_score): 10 | self.left_hits = left_hits 11 | self.right_hits = right_hits 12 | self.left_score = left_score 13 | self.right_score = right_score 14 | 15 | 16 | class Game: 17 | """ 18 | To use this class simply initialize and instance and call the .loop() method 19 | inside of a pygame event loop (i.e while loop). Inside of your event loop 20 | you can call the .draw() and .move_paddle() methods according to your use case. 21 | Use the information returned from .loop() to determine when to end the game by calling 22 | .reset(). 23 | """ 24 | SCORE_FONT = pygame.font.SysFont("comicsans", 50) 25 | WHITE = (255, 255, 255) 26 | BLACK = (0, 0, 0) 27 | RED = (255, 0, 0) 28 | 29 | def __init__(self, window, window_width, window_height): 30 | self.window_width = window_width 31 | self.window_height = window_height 32 | 33 | self.left_paddle = Paddle( 34 | 10, self.window_height // 2 - Paddle.HEIGHT // 2) 35 | self.right_paddle = Paddle( 36 | self.window_width - 10 - Paddle.WIDTH, self.window_height // 2 - Paddle.HEIGHT//2) 37 | self.ball = Ball(self.window_width // 2, self.window_height // 2) 38 | 39 | self.left_score = 0 40 | self.right_score = 0 41 | self.left_hits = 0 42 | self.right_hits = 0 43 | self.window = window 44 | 45 | def _draw_score(self): 46 | left_score_text = self.SCORE_FONT.render( 47 | f"{self.left_score}", 1, self.WHITE) 48 | right_score_text = self.SCORE_FONT.render( 49 | f"{self.right_score}", 1, self.WHITE) 50 | self.window.blit(left_score_text, (self.window_width // 51 | 4 - left_score_text.get_width()//2, 20)) 52 | self.window.blit(right_score_text, (self.window_width * (3/4) - 53 | right_score_text.get_width()//2, 20)) 54 | 55 | def _draw_hits(self): 56 | hits_text = self.SCORE_FONT.render( 57 | f"{self.left_hits + self.right_hits}", 1, self.RED) 58 | self.window.blit(hits_text, (self.window_width // 59 | 2 - hits_text.get_width()//2, 10)) 60 | 61 | def _draw_divider(self): 62 | for i in range(10, self.window_height, self.window_height//20): 63 | if i % 2 == 1: 64 | continue 65 | pygame.draw.rect( 66 | self.window, self.WHITE, (self.window_width//2 - 5, i, 10, self.window_height//20)) 67 | 68 | def _handle_collision(self): 69 | ball = self.ball 70 | left_paddle = self.left_paddle 71 | right_paddle = self.right_paddle 72 | 73 | if ball.y + ball.RADIUS >= self.window_height: 74 | ball.y_vel *= -1 75 | elif ball.y - ball.RADIUS <= 0: 76 | ball.y_vel *= -1 77 | 78 | if ball.x_vel < 0: 79 | if ball.y >= left_paddle.y and ball.y <= left_paddle.y + Paddle.HEIGHT: 80 | if ball.x - ball.RADIUS <= left_paddle.x + Paddle.WIDTH: 81 | ball.x_vel *= -1 82 | 83 | middle_y = left_paddle.y + Paddle.HEIGHT / 2 84 | difference_in_y = middle_y - ball.y 85 | reduction_factor = (Paddle.HEIGHT / 2) / ball.MAX_VEL 86 | y_vel = difference_in_y / reduction_factor 87 | ball.y_vel = -1 * y_vel 88 | self.left_hits += 1 89 | 90 | else: 91 | if ball.y >= right_paddle.y and ball.y <= right_paddle.y + Paddle.HEIGHT: 92 | if ball.x + ball.RADIUS >= right_paddle.x: 93 | ball.x_vel *= -1 94 | 95 | middle_y = right_paddle.y + Paddle.HEIGHT / 2 96 | difference_in_y = middle_y - ball.y 97 | reduction_factor = (Paddle.HEIGHT / 2) / ball.MAX_VEL 98 | y_vel = difference_in_y / reduction_factor 99 | ball.y_vel = -1 * y_vel 100 | self.right_hits += 1 101 | 102 | def draw(self, draw_score=True, draw_hits=False): 103 | self.window.fill(self.BLACK) 104 | 105 | self._draw_divider() 106 | 107 | if draw_score: 108 | self._draw_score() 109 | 110 | if draw_hits: 111 | self._draw_hits() 112 | 113 | for paddle in [self.left_paddle, self.right_paddle]: 114 | paddle.draw(self.window) 115 | 116 | self.ball.draw(self.window) 117 | 118 | def move_paddle(self, left=True, up=True): 119 | """ 120 | Move the left or right paddle. 121 | 122 | :returns: boolean indicating if paddle movement is valid. 123 | Movement is invalid if it causes paddle to go 124 | off the screen 125 | """ 126 | if left: 127 | if up and self.left_paddle.y - Paddle.VEL < 0: 128 | return False 129 | if not up and self.left_paddle.y + Paddle.HEIGHT > self.window_height: 130 | return False 131 | self.left_paddle.move(up) 132 | else: 133 | if up and self.right_paddle.y - Paddle.VEL < 0: 134 | return False 135 | if not up and self.right_paddle.y + Paddle.HEIGHT > self.window_height: 136 | return False 137 | self.right_paddle.move(up) 138 | 139 | return True 140 | 141 | def loop(self): 142 | """ 143 | Executes a single game loop. 144 | 145 | :returns: GameInformation instance stating score 146 | and hits of each paddle. 147 | """ 148 | self.ball.move() 149 | self._handle_collision() 150 | 151 | if self.ball.x < 0: 152 | self.ball.reset() 153 | self.right_score += 1 154 | elif self.ball.x > self.window_width: 155 | self.ball.reset() 156 | self.left_score += 1 157 | 158 | game_info = GameInformation( 159 | self.left_hits, self.right_hits, self.left_score, self.right_score) 160 | 161 | return game_info 162 | 163 | def reset(self): 164 | """Resets the entire game.""" 165 | self.ball.reset() 166 | self.left_paddle.reset() 167 | self.right_paddle.reset() 168 | self.left_score = 0 169 | self.right_score = 0 170 | self.left_hits = 0 171 | self.right_hits = 0 172 | --------------------------------------------------------------------------------