├── LICENSE ├── README.md ├── ch01CoinGame.ipynb ├── ch02TicTacToe.ipynb ├── ch03ConnectFour.ipynb ├── ch04MiniMax.ipynb ├── ch05DepthPruning.ipynb ├── ch06AlphaBetaPruning.ipynb ├── ch07PositionEvaluationFunction.ipynb ├── ch08MonteCarloTreeSearch.ipynb ├── ch09DeepLearningCoinGame.ipynb ├── ch10PolicyNetworksTicTacToe.ipynb ├── ch11PolicyNetworkConnectFour.ipynb ├── ch12TabularQ-LearningCoinGame.ipynb ├── ch13DeepReinforcementLearningCoin.ipynb ├── ch14VectorizationTicTacToe.ipynb ├── ch15ValueNetworkConnectFour.ipynb ├── ch16AlphaGoCoinGame.ipynb ├── ch17AlphaGoTicTacToe.ipynb ├── ch18AlphaGoTuning.ipynb ├── ch19ACcoinAlphaZero.ipynb ├── ch20AlphaZeroTicTacToe.ipynb ├── ch21AlphaZeroUnsolvedGames.ipynb ├── files ├── CONNzero4.h5 ├── PG_coin.h5 ├── PG_conn.h5 ├── ac_coin.h5 ├── coin_Qs.csv ├── fast_coin.h5 ├── fast_ttt.h5 ├── pg_ttt.h5 ├── policy_conn.h5 ├── strong_coin.h5 ├── strong_ttt.h5 ├── value_coin.h5 ├── value_conn.h5 ├── value_ttt.h5 └── zero_ttt4.h5 ├── tmp.ipynb └── utils ├── __init__.py ├── cash.png ├── ch01util.py ├── ch02util.py ├── ch03util.py ├── ch04util.py ├── ch05util.py ├── ch06util.py ├── ch07util.py ├── ch08util.py ├── ch16util.py ├── ch17util.py ├── ch19util.py ├── ch20util.py ├── coin_env.py ├── coin_simple_env.py ├── conn_env.py ├── conn_simple_env.py ├── ttt_env.py ├── ttt_new_env.py ├── ttt_simple_env.py └── ttt_think3.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Mark Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AlphaGo Simplified: Ruled-Based A.I. and Deep Learning in Everyday Games 2 |
3 | The whole is greater than the sum of its parts. -- Aristotle 4 |
5 | 6 | 7 | 8 | May 11, 1997, was a watershed moment in the history of artificial intelligence (AI): the IBM supercomputer Chess engine, Deep Blue, beat the world Chess champion, Garry Kasparov. It was the first time a machine had triumphed over a human player in a Chess tournament. The news garnered much media attention. But here's the twist: Deep Blue was no machine learning (ML) marvel; it was powered by traditional rule-based AI, a far cry from the modern AI we know today. 9 | 10 | Fast forward 19 years to the electrifying showdown of 2016, where Google DeepMind's AlphaGo took on Go champion Lee Sedol. Once again, AI seized the spotlight, creating a frenzy in the media. But this time, it was a new kind of AI, driven by the incredible force of ML, particularly deep reinforcement learning, that captured the world's imagination. The strategies it employed were a game-changer, quite literally. But that's not all! The real secret to AlphaGo’s success is the fact that it merged deep reinforcement learning with rule-based AI such as Monte Carlo Tree Search (MCTS) to supercharge the game strategies (the whole is indeed greater than the sum of its parts in this case). 11 | 12 | Now, you may be wondering: What is ML, and how does it relate to AI? Why is deep learning all the rage in today's tech-driven world? This book has the answers. In these pages, you'll unravel the inner workings of traditional rule-based AI and modern ML. I'll show you how to apply these concepts to three simple yet captivating games: Last Coin Standing, Tic Tac Toe, and Connect Four. You'll dive into the exciting world of artificial intelligence through the thrilling stories of Deep Blue and AlphaGo, two groundbreaking moments that rewrote the history of AI. 13 | 14 | For readers who are unfamiliar with these games, Last Coin Standing (the coin game from now on) is a game in which two players take turns removing one or two coins at a time from a pile of 21 coins. The player who takes the last coin is the winner. In Tic Tac Toe, two players take turns marking a cell with an X or O in a three-by-three grid. The first player to connect three Xs or Os in a row horizontally, vertically, or diagonally wins. Connect Four pits two players against each other in a quest to form a direct line of four game pieces of the same color. 15 | 16 | Why these three games, you ask? Well, they're simple to grasp and perfect for exploring the world of rule-based AI, deep reinforcement learning (e.g., policy gradients or the actor-critic method), and other AI techniques. There is no need for you to get bogged down in complex game rules. In contrast, games like Chess and Go require a deep well of domain knowledge to devise effective strategies. Deep Blue's algorithm, for instance, uses thousands of rules to evaluate board positions in Chess. The game of Go involves complicated rules like “no self-capturing”, “komi” (compensation for the first mover's advantage), and the tricky concept of "ko" (avoiding creating a previous board position). What's more, applying rule-based AI and deep learning to these three games is fast and easy and doesn't require costly computational resources. As you'll discover in this book, you can train intelligent game strategies on a regular computer in minutes or hours. As a result, the trained AI provides perfect solutions for the first two games, while Connect Four gets super-human performance. In contrast, Chess and Go require supercomputing facilities. 
Deep Blue blitzed through hundreds of millions of board positions each second, and AlphaGo gobbled up the processing power of 1920 CPUs and 280 GPUs. The average reader doesn't have access to such supercomputing might, which is why we choose everyday simple games to make AI learning accessible to everyone with a regular computer. 17 | 18 | By immersing yourself in these three captivating games, you'll grasp the essence of rule-based AI, from the MiniMax algorithm to alpha-beta pruning and the exhilarating Monte Carlo Tree Search. Afterward, we'll venture into the realm of ML, specifically deep reinforcement learning – the secret sauce behind AlphaGo's victory. DeepMind's brilliant minds used the policy-gradient method and crafted two deep neural networks to train game strategies. You’ll learn how to combine modern ML techniques with rule-based AI to achieve superhuman performance, just like the DeepMind team did in the epic AlphaGo showdown. Get ready for an AI journey that's not just educational but downright thrilling! 19 | 20 | This book is divided into four parts. Part I provides an introduction to the three games and outlines how to develop strategies using rule-based AI techniques like MiniMax and MCTS. Part II delves into deep learning and its application to the three games. Part III explores the fundamentals of reinforcement learning and demonstrates how to enhance game strategies through self-play deep reinforcement learning. Finally, in Part IV, we integrate rule-based AI with deep reinforcement learning to construct AlphaGo (and its successor, AlphaGo Zero) algorithms for the three games. 21 | 22 | Here’s an overview of the book: 23 | 24 | Part I: Rule-Based A.I. 25 | 26 | Chapter 1: Rule-Based AI in the Coin Game 27 | 28 | In this chapter, we introduce the first of the three games featured in this book: the coin game. You'll be guided through the rules of the game and learn how to create a rule-based AI that can achieve a 100\% win rate against any opponent when playing in the second position. 29 | 30 | Chapter 2: Look-Ahead Search in Tic Tac Toe 31 | 32 | In this chapter, you'll delve into the rules of the second game: Tic Tac Toe. You'll learn how to develop strategies that enable the AI to plan up to three moves ahead. To think one step ahead, the AI evaluates all possible next moves to see if any of them immediately result in a win. If so, the AI takes that move. Thinking two steps ahead involves the AI trying to block the opponent from winning in their next turn. By thinking three steps ahead, the AI chooses the path that is most likely to lead to victory after three moves. In many cases, thinking three steps ahead ensures that the AI can secure a win within three moves. 33 | 34 | Chapter 3: Planning Three Steps Ahead in Connect Four 35 | 36 | In this chapter, you'll learn how to play Connect Four in the local game environment and develop rule-based AI strategies for the game. These AI players will think up to three moves ahead. When the AI looks one step ahead, it evaluates all possible next moves and chooses one that could lead to an immediate win. However, thinking two steps ahead in Connect Four introduces more complexity. The AI must determine whether its next move will block the opponent's win or inadvertently help them win. Our strategy for this scenario is twofold: if the AI's move blocks the opponent, it proceeds with that move; otherwise, it avoids a move that helps the opponent win in two moves. When thinking three steps ahead, the AI follows the path that is most likely to lead to victory after three moves. 37 | 38 |
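As a concrete illustration of this kind of hard-coded look-ahead, here is a minimal, self-contained Tic Tac Toe sketch that wins when it can and blocks when it must. It is an illustration only, not the book's code or game environment; the board encoding and the names `WIN_LINES`, `wins`, and `look_ahead_move` are invented for this example.

```python
WIN_LINES = [(0, 1, 2), (3, 4, 5), (6, 7, 8),   # rows
             (0, 3, 6), (1, 4, 7), (2, 5, 8),   # columns
             (0, 4, 8), (2, 4, 6)]              # diagonals

def wins(board, move, player):
    """Return True if placing `player` at `move` completes a line."""
    trial = board.copy()
    trial[move] = player
    return any(all(trial[i] == player for i in line) for line in WIN_LINES)

def look_ahead_move(board, me, opponent):
    """One step ahead: take an immediate win.
    Two steps ahead: block the opponent's immediate win.
    Otherwise fall back to the first empty cell."""
    empties = [i for i, v in enumerate(board) if v == " "]
    for m in empties:          # think one step ahead
        if wins(board, m, me):
            return m
    for m in empties:          # think two steps ahead
        if wins(board, m, opponent):
            return m
    return empties[0]          # naive fallback

# toy position: X completes the top row by playing cell 2
board = ["X", "X", " ",
         "O", "O", " ",
         " ", " ", " "]
print(look_ahead_move(board, "X", "O"))  # prints 2
```

Chapters 2 and 3 build this same win-or-block logic on top of the book's own game environments and extend it to a third step of look-ahead.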
Chapter 4: Recursion and MiniMax Tree Search 39 | 40 | Up to this point, you have been hardcoding look-ahead search techniques to devise game strategies. However, extending this approach beyond three steps becomes increasingly cumbersome and prone to errors. To address this, you'll learn a more efficient method for conducting look-ahead searches: the MiniMax tree search. You'll learn how to implement the MiniMax tree search in the coin game by employing recursion, which involves calling a function within itself. This recursive approach enables the MiniMax agent to search ahead until the end of the game. As a result, the MiniMax agent is able to solve the coin game. 41 | 42 | Chapter 5: Depth Pruning in MiniMax 43 | 44 | You'll begin by developing a MiniMax agent for Tic Tac Toe using recursion. The MiniMax tree search method explores all potential future paths and solves the game. However, in more complex games like Connect Four, Chess, or Go, the MiniMax algorithm is unable to exhaustively explore all possible future paths in a reasonable time. To overcome this limitation, the algorithm employs a technique called depth pruning. This approach involves searching a specific number of moves ahead and then stopping the search. You'll learn how to create a MiniMax agent with depth pruning and apply it to Connect Four. 45 | 46 | Chapter 6: Alpha-Beta Pruning 47 | 48 | Alpha-beta pruning enhances the look-ahead search algorithm by allowing it to bypass branches that cannot impact the final outcome of the game. This optimization significantly reduces the time it takes for the MiniMax agent to decide on a move, enabling it to explore more steps ahead and develop smarter moves within a given time frame. To apply alpha-beta pruning in a game, the agent keeps track of two values: alpha and beta, which represent the best possible outcomes for the two players, respectively. Whenever $alpha \geq -beta$, or equivalently $beta \geq -alpha$, the agent stops searching a branch. You'll implement alpha-beta pruning in both Tic Tac Toe and Connect Four. By incorporating this technique, the time required for the MiniMax agent to decide on a move is reduced by up to 97\%, making the algorithm much more efficient. 49 | 50 | Chapter 7: Position Evaluation in MiniMax 51 | 52 | Depth pruning enables a MiniMax agent to develop intelligent (though not perfect) strategies quickly in complex games. In earlier chapters, when the depth of a tree search reaches zero and the game is not yet over, the algorithm assumes a tied game. However, in many real-world games, we often have a good estimate of the outcome based on heuristics, even when the game is not over. In this chapter, you'll learn the concept of the position evaluation function and apply it to the Connect Four game. When the depth reaches zero in a tree search and the game is not over, you'll use an evaluation function to assess the value of the game state. This evaluation function provides a more accurate assessment of the game state, enabling the MiniMax agent to make more intelligent moves. 53 | 54 | Chapter 8: Monte Carlo Tree Search 55 | 56 | MiniMax tree search, augmented by powerful position evaluation functions, helped Deep Blue beat the world Chess champion Garry Kasparov in 1997. While in games such as Chess, position evaluation functions are relatively accurate, in other games such as Go, evaluating positions is more challenging. 
In such scenarios, researchers usually use Monte Carlo Tree Search (MCTS). The idea behind MCTS is to roll out random games starting from the current game state and see what the average game outcome is. After rolling out, say 1000, games from this point, if Player 1 wins 99 percent of the time, the current game state must favor Player 1 over Player 2. To select the best next move, the MCTS algorithm uses the Upper Confidence Bounds for Trees (UCT) method. This chapter breaks down the process into four steps: selection, expansion, simulation, and backpropagation. Readers learn to implement a generic MCTS algorithm that can be applied to the three games in this book: the coin game, Tic Tac Toe, and Connect Four. 57 | 58 | Part II: Deep Learning 59 | 60 | Chapter 9: Deep Learning in the Coin Game 61 | 62 | In this chapter, you'll learn what deep learning is and how it's related to ML and AI. Deep learning is a type of ML method that's based on artificial neural networks. A neural network is a computational model inspired by the structure of neural networks in the human brain. It's designed to recognize patterns in data, and it contains layers of interconnected nodes, or neurons. You'll learn to use deep neural networks to design game strategies for the coin game in this chapter. In particular, you'll build and train two networks, a fast policy network and a strong policy network, to be used in the AlphaGo algorithm later in the book. 63 | 64 | Chapter 10: Policy Networks in Tic Tac Toe 65 | 66 | In this chapter, you'll train a fast policy network and a strong policy network in Tic Tac Toe. The two networks contain convolutional layers, which treat game boards as multi-dimensional objects and extract spatial features from them. Convolutional layers greatly improve the predictive power of neural networks and this, in turn, leads to more intelligent game strategies. To generate expert moves in Tic Tac Toe, you'll use the MiniMax algorithm with alpha-beta pruning from Chapter 6. You’ll use game board positions as inputs and expert moves as targets to train the two policy neural networks. The two trained policy networks will be used in the AlphaGo algorithm later in this book. 67 | 68 | Chapter 11: A Policy Network in Connect Four 69 | 70 | In this chapter, you'll train a policy network in Connect Four so that it can be used in the AlphaGo algorithm later in the book. To generate expert moves, you'll first design an agent who chooses moves by using the MCTS algorithm half of the time and a MiniMax algorithm with alpha-beta pruning the other 50\% of the time. You let the above agent play against itself for 10,000 games. In each game, the winner's moves are considered expert moves while the loser's moves are discarded. You'll then create a deep neural network and use the self-play game experience data to train it to predict expert moves. 71 | 72 | Part III: Reinforcement Learning 73 | 74 | Chapter 12: Tabular Q-Learning in the Coin Game 75 | 76 | You'll learn the basics of reinforcement learning in this chapter. You’ll use tabular Q-learning to solve the coin game. Along the way, you'll learn the concepts of dynamic programming and the Bellman equation. You'll also learn how to train and use Q-tables in tabular Q-learning. 77 | 78 | Chapter 13: Self-Play Deep Reinforcement Learning 79 | 80 | You'll learn to use self-play deep reinforcement learning to further train the strong policy network for the coin game. 
You'll learn what a policy is and how to implement the policy gradient method to train the agent in deep reinforcement learning. A policy is a decision rule that tells the agent what actions to take in a given game state. In the policy gradients method, the agent engages in numerous game sessions to learn the optimal policy. The agent bases its actions on the model's predictions, observes the resulting rewards, and adjusts the model parameters to align predicted action probabilities with desired probabilities. You'll also use the game experience data from self-play to train a value network that predicts game outcomes based on current game states. 81 | 82 | Chapter 14: Vectorization to Speed Up Deep Reinforcement Learning 83 | 84 | You'll apply self-play deep reinforcement learning to Tic Tac Toe in this chapter. You'll learn to manage several challenges in the process. Firstly, unlike the coin game, where illegal moves are non-existent as players simply choose to remove one or two coins per turn, Tic Tac Toe has a decreasing number of legal moves as the game advances. You'll learn to assign negative rewards to illegal moves so the trained agent chooses only legal moves. Secondly, the computational demands of training for Tic Tac Toe are significantly higher than for the coin game. You'll learn to use vectorization to speed up the training process. Finally, this chapter will guide you through the process of encoding the game board in a player-independent manner so that the current player's game pieces are represented as 1 and the opponent's as $-1$. This allows the use of the same neural network for training both players. With the above challenges properly handled, you'll implement the policy gradient method in Tic Tac Toe to train a policy network. You'll also train a value network based on the game experience data. 85 | 86 | Chapter 15: A Value Network in Connect Four 87 | 88 | You'll apply the policy gradient method to Connect Four in this chapter. You'll learn to handle illegal moves by assigning a reward of $-1$ every time the agent makes an illegal move. You'll also use vectorization to speed up training. Self-play deep reinforcement learning is used to further train a policy network in Connect Four based on the policy network from Chapter 11. Readers also use the game experience data from self-play to train a value network: the network will predict the game outcome based on the board position. 89 | 90 | Part IV: AlphaGo Algorithms 91 | 92 | Chapter 16: Implementing AlphaGo in the Coin Game 93 | 94 | In this chapter, you'll implement the AlphaGo algorithm in the coin game. MCTS constitutes the backbone of AlphaGo's decision-making process: it is used to find the most promising moves by building a search tree. You'll learn to use three deep neural networks you developed in earlier chapters. The strengthened policy network (from Chapter 13) selects child nodes so that the most valuable child nodes are selected to roll out games. The fast policy network developed in Chapter 9 helps in narrowing down the selection of moves to consider in game rollouts. The value network from Chapter 13 evaluates board positions and predicts the winner of the game from that position so that the agent doesn't have to play out the entire game in rollouts. When moving second, the AlphaGo agent beats the rule-based AI player in 100\% of games. 95 | 96 |
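To make the combination of MCTS and the networks concrete, here is a minimal, self-contained sketch of the kind of selection rule such an agent can use. It is an illustration only, not the code from the book's chapters; the names `PUCTNode`, `select_child`, and `c_puct`, as well as the toy numbers, are invented for this example.

```python
import math

class PUCTNode:
    """Tree node holding MCTS statistics for one candidate move."""
    def __init__(self, prior):
        self.prior = prior        # move probability suggested by a policy network
        self.visits = 0           # number of times this node has been visited
        self.total_value = 0.0    # sum of value estimates backed up to this node

    def q(self):
        # average value of the node; zero if it has never been visited
        return self.total_value / self.visits if self.visits else 0.0

def select_child(children, c_puct=1.5):
    """Pick the child that maximizes Q + U, where the U term favors
    moves with a high prior that have been visited only rarely."""
    total_visits = sum(node.visits for node in children.values())
    def score(item):
        move, node = item
        u = c_puct * node.prior * math.sqrt(total_visits + 1) / (1 + node.visits)
        return node.q() + u
    return max(children.items(), key=score)

# toy usage with made-up priors and statistics
children = {1: PUCTNode(0.7), 2: PUCTNode(0.3)}
children[1].visits, children[1].total_value = 10, 4.0
children[2].visits, children[2].total_value = 2, 1.5
best_move, _ = select_child(children)
print(best_move)  # the move the tree search would explore next
```

In the chapters themselves, the strong policy network supplies the move priors, the fast policy network narrows down the moves considered in rollouts, and the value network evaluates positions so the agent doesn't have to play rollouts to the end.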
Chapter 17: AlphaGo in Tic Tac Toe and Connect Four 97 | 98 | In this chapter, you'll create an AlphaGo agent featuring two main enhancements. Firstly, it will be versatile and capable of handling two games, Tic Tac Toe and Connect Four. Secondly, the agent's game simulation strategy includes a choice between random moves and those suggested by a fast policy network. MCTS remains the core of the agent's decision process, involving selection, expansion, simulation, and backpropagation. Three deep neural networks, introduced in previous chapters, will enhance the tree search. You'll evaluate the AlphaGo algorithm's effectiveness in Tic Tac Toe. Against the perfect rule-based AI from Chapter 6, the AlphaGo agent consistently draws, indicating its ability to solve the game. 99 | 100 | Chapter 18: Hyperparameter Tuning in AlphaGo 101 | 102 | Unlike agents in the coin game or Tic Tac Toe, the AlphaGo agent in Connect Four does not fully solve the game. This chapter, therefore, focuses on identifying the optimal combination of hyperparameters in AlphaGo that yields the most effective game strategy. Grid search is a common approach for hyperparameter tuning. This process involves experimenting with different permutations of hyperparameters in the model to determine empirically which combination offers the best performance. You'll learn to fine-tune four key hyperparameters for the AlphaGo agent in Connect Four in this chapter. The optimized AlphaGo agent can defeat an AI that plans four steps ahead. 103 | 104 | Chapter 19: The Actor-Critic Method and AlphaZero 105 | 106 | In 2017, the DeepMind team introduced an advanced version of AlphaGo, named AlphaGo Zero (which we'll refer to as AlphaZero in this book because we apply the algorithm to various games beyond Go). AlphaZero's training relied exclusively on deep reinforcement learning, without any human-derived strategies or domain-specific knowledge, except for the basic rules of the game. It learned through self-play from scratch. AlphaZero utilizes a single neural network with two outputs: a policy network for predicting the next move and a value network for forecasting game outcomes. In this chapter, you'll learn this advanced deep reinforcement learning strategy known as the actor-critic method, applying it specifically to the coin game. You'll then integrate both the policy and value networks from the actor-critic method with MCTS for making decisions in actual games, mirroring the approach used in the AlphaGo algorithm in Chapter 16. The AlphaZero agent developed in this chapter, if playing as Player 2, can beat the AlphaGo agent developed earlier in Chapter 16. 107 | 108 | Chapter 20: Iterative Self-Play and AlphaZero in Tic Tac Toe 109 | 110 | You'll learn to construct an AlphaZero agent for both Tic Tac Toe and Connect Four in this chapter. You'll then implement the AlphaZero algorithm in Tic Tac Toe by integrating a policy gradient network with MCTS. To train the model, you'll start a policy gradient network from scratch and initialize it with random weights. During training, the policy gradient agent competes against a more advanced version of itself: the AlphaZero agent. As training progresses, both agents gradually improve their performance. This dynamic scenario presents a unique challenge, as the agent effectively faces a moving target. To address this challenge, an iterative self-play approach is used. Initially, you'll keep constant the weights of the copy of the policy gradient network used by the AlphaZero agent, while updating the weights of the policy gradient network that is being trained. 
After an iteration of training, the weights in the policy gradient network used by the AlphaZero agent will be updated. This process is repeated in successive iterations until the AlphaZero agent perfects its gameplay. 111 | 112 | Chapter 21: AlphaZero in Unsolved Games 113 | 114 | In the previous two chapters, you have used rule-based AI to periodically evaluate the AlphaZero agent to gauge its performance and decide when to stop training. Even though rule-based AI was not used in the training process directly, it was used for testing purposes to monitor the agent's performance. In unsolved games, no game-solving algorithm can be used as the benchmark. How should we test the performance of AlphaZero and decide when to stop training in such cases? In this chapter, you'll treat Connect Four as an unsolved game. To test the performance of AlphaZero and decide when to stop training, an earlier version of AlphaZero is used as the benchmark. When AlphaZero outperforms an earlier version of itself by a certain margin, a training iteration is complete. You'll then update the parameters in the older version of AlphaZero and restart the training process. You'll train the AlphaZero model for several iterations so that the AlphaZero agent becomes increasingly stronger. The trained AlphaZero agent consistently outperforms its predecessor, AlphaGo, in Connect Four! 115 | 116 | All Python programs, along with answers to some end-of-the-chapter questions, are 117 | provided in the GitHub repository https://github.com/markhliu/AlphaGoSimplified. 118 | 119 | 120 | -------------------------------------------------------------------------------- /ch04MiniMax.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "653089c9", 6 | "metadata": {}, 7 | "source": [ 8 | "# Chapter 4: Recursion and MiniMax Tree Search\n", 9 | "\n", 10 | "\n", 11 | "\n", 12 | "***\n", 13 | "*“In another thirty years people will laugh at anyone who tries to invent a language without closures, just as they'll laugh now at anyone who tries to invent a language without recursion.”*\n", 14 | "\n", 15 | "-- Mark Jason Dominus, in 2005\n", 16 | "***\n", 17 | "\n", 18 | "\n", 19 | "\n", 20 | "What you'll learn in this chapter:\n", 21 | "\n", 22 | "* The logic behind MiniMax tree search\n", 23 | "* Understanding recursion and applying it in MiniMax tree search\n", 24 | "* Implementing MiniMax tree search in the coin game\n", 25 | "* Testing the effectiveness of the MiniMax agent\n", 26 | "\n", 27 | "In Chapters 2 and 3, you learned how to use look-ahead search to design\n", 28 | "intelligent game strategies in Tic Tac Toe and Connect Four. However,\n", 29 | "the search process was hard coded. If we look ahead beyond three steps,\n", 30 | "the coding becomes tedious and error-prone. You may wonder if there is a\n", 31 | "systematic and more efficient way of conducting look-ahead search. The\n", 32 | "answer is yes: MiniMax tree search does exactly that. The MiniMax\n", 33 | "algorithm is a decision rule in artificial intelligence and game theory.\n", 34 | "The algorithm assumes that each player in the game makes the best\n", 35 | "possible decisions at each step. Further, each player knows that other\n", 36 | "players make fully rational decisions as well, and so on.\n", 37 | "\n", 38 | "In this chapter, you'll learn to implement MiniMax tree search in the\n", 39 | "coin game. 
Specifically, you'll use recursion to call a function inside\n",
    "the function itself. This creates a chain of repeated calls: the statements in\n",
    "the function are executed again and again until a certain stopping\n",
    "condition is met. The recursive algorithm allows the MiniMax agent to search ahead\n",
    "until the end of the game.\n",
    "\n",
    "You'll create a MiniMax agent in the coin game by using the game environment that we developed in\n",
    "Chapter 1. The algorithm makes hypothetical future moves and exhausts all\n",
    "possible future game paths. The algorithm then uses backward induction\n",
    "to calculate the best move in each step of the game. The MiniMax agent\n",
    "solves the coin game and plays perfectly: it always wins when it plays\n",
    "second. The MiniMax agent makes moves very quickly as well: each move\n",
    "takes a fraction of a second.\n",
    "\n",
    "After this chapter, you'll understand the logic behind MiniMax tree\n",
    "search and be able to design game strategies for any game based on the\n",
    "algorithm. You'll apply the algorithm to Tic Tac Toe and Connect Four as\n",
    "well in the next few chapters and find ways to overcome or mitigate\n",
    "drawbacks associated with MiniMax tree search."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "282ad876",
   "metadata": {},
   "source": [
    "# 1. Introducing MiniMax and Recursion\n",
    "This section introduces MiniMax tree search and explains the concept of recursion in programming languages. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79c437e8",
   "metadata": {},
   "source": [
    "## 1.1. What is MiniMax Tree Search?"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de2d7767",
   "metadata": {},
   "source": [
    "## 1.2. Backward Induction and the Solution to MiniMax"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cbfacacd",
   "metadata": {},
   "source": [
    "## 1.3. What is Recursion? \n",
    "Recursion is the calling of a function inside the function itself. We'll use recursion to implement MiniMax tree search in this book. Below, I'll show you one example of recursion. \n",
    "\n",
    "Suppose you want to create a clock to tell time. 
The normal approach is as follows:" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 1, 99 | "id": "efd621db", 100 | "metadata": {}, 101 | "outputs": [ 102 | { 103 | "name": "stdout", 104 | "output_type": "stream", 105 | "text": [ 106 | "The current time is 19:41:45\n" 107 | ] 108 | } 109 | ], 110 | "source": [ 111 | "import time\n", 112 | "\n", 113 | "def clock():\n", 114 | " time_now=time.strftime(\"%H:%M:%S\")\n", 115 | " print(f\"The current time is {time_now}\") \n", 116 | "clock()" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 2, 122 | "id": "66d5e175", 123 | "metadata": {}, 124 | "outputs": [ 125 | { 126 | "name": "stdout", 127 | "output_type": "stream", 128 | "text": [ 129 | "The current time is 19:45:46\n", 130 | "The current time is 19:45:47\n", 131 | "The current time is 19:45:48\n", 132 | "The current time is 19:45:49\n", 133 | "The current time is 19:45:50\n", 134 | "The current time is 19:45:51\n", 135 | "The current time is 19:45:52\n", 136 | "The current time is 19:45:53\n", 137 | "The current time is 19:45:54\n", 138 | "The current time is 19:45:55\n" 139 | ] 140 | } 141 | ], 142 | "source": [ 143 | "start=time.time()\n", 144 | "def clock():\n", 145 | " time_now=time.strftime(\"%H:%M:%S\")\n", 146 | " print(f\"The current time is {time_now}\") \n", 147 | " time.sleep(1)\n", 148 | " if time.time()-start<=10:\n", 149 | " clock()\n", 150 | "clock() " 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "id": "ff3fe384", 156 | "metadata": {}, 157 | "source": [ 158 | "In the above cell, we call the *clock()* function in the function itself, unless more than ten seconds have passed. As a result, the function tells time for ten consecutive seconds. " 159 | ] 160 | }, 161 | { 162 | "cell_type": "markdown", 163 | "id": "4db31da9", 164 | "metadata": {}, 165 | "source": [ 166 | "# 2. MiniMax Tree Search in the Coin Game" 167 | ] 168 | }, 169 | { 170 | "cell_type": "markdown", 171 | "id": "7252a608", 172 | "metadata": {}, 173 | "source": [ 174 | "We'll use a simplified version of the self-made coin game environment from Chapter 1 to speed up MiniMax tree search. Specifically, the module is saved as *coin_simple_env.py* in the folder *utils* in the book's GitHub repository https://github.com/markhliu/AlphaGoSimplified. Download the file and save it in the folder /Desktop/ags/utils/ on your computer. The file *coin_simple_env.py* is the same as *coin_env.py* that we used in Chapter 1, except that we have deleted the graphical game window functionality. As a result, you cannot use the render() method in the simplified coin game environment. We use the simplified coin game environment to make the MiniMax agent make moves faster. \n", 175 | "\n", 176 | "First, let's define a couple of functions that the MiniMax algorithm will use. " 177 | ] 178 | }, 179 | { 180 | "cell_type": "markdown", 181 | "id": "52ddc025", 182 | "metadata": {}, 183 | "source": [ 184 | "## 2.1. 
The *MiniMax()* Function "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fb8839e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "from random import choice\n",
    "\n",
    "def MiniMax(env):\n",
    "    # create a list to store winning moves\n",
    "    wins=[]\n",
    "    # iterate through all possible next moves\n",
    "    for m in env.validinputs:\n",
    "        # make a hypothetical move and see what happens\n",
    "        env_copy=deepcopy(env)\n",
    "        new_state, reward, done, info = env_copy.step(m) \n",
    "        # if move m leads to a win now, take it\n",
    "        if done and reward==1:\n",
    "            return m \n",
    "        # see what's the best response from the opponent\n",
    "        opponent_payoff=maximized_payoff(env_copy,reward,done) \n",
    "        # Opponent's payoff is the opposite of your payoff\n",
    "        my_payoff=-opponent_payoff \n",
    "        if my_payoff==1:\n",
    "            wins.append(m)\n",
    "    # pick a winning move if there is any \n",
    "    if len(wins)>0:\n",
    "        return choice(wins)\n",
    "    # otherwise randomly pick\n",
    "    return env.sample()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eee28a4e",
   "metadata": {},
   "source": [
    "## 2.2. The *maximized_payoff()* Function \n",
    "Next, we'll define the *maximized_payoff()* function in the local module *ch04util*. The function produces the best possible outcome for the next player in the next step of the game. Note that this function applies to any player at any stage of the game, so we don't need to define one function for Player 1 and another function for Player 2."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a847d4fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "def maximized_payoff(env, reward, done):\n",
    "    # if the game has ended after the previous player's move\n",
    "    if done:\n",
    "        return -1\n",
    "    # otherwise, search for action to maximize payoff\n",
    "    best_payoff=-2\n",
    "    # iterate through all possible next moves\n",
    "    for m in env.validinputs:\n",
    "        env_copy=deepcopy(env)\n",
    "        new_state,reward,done,info=env_copy.step(m) \n",
    "        # what's the opponent's response\n",
    "        opponent_payoff=maximized_payoff(env_copy, reward, done)\n",
    "        # opponent's payoff is the opposite of your payoff\n",
    "        my_payoff=-opponent_payoff \n",
    "        # update your best payoff \n",
    "        if my_payoff>best_payoff: \n",
    "            best_payoff=my_payoff\n",
    "    return best_payoff"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d963a4e3",
   "metadata": {},
   "source": [
    "## 2.3. Human versus MiniMax in the Coin Game\n",
    "Next, you'll play a game against the MiniMax algorithm. We'll let the MiniMax agent move second and see if it can win the game. 
" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": 5, 269 | "id": "15b9c0d7", 270 | "metadata": {}, 271 | "outputs": [ 272 | { 273 | "name": "stdout", 274 | "output_type": "stream", 275 | "text": [ 276 | "there are 21 coins in the pile\n", 277 | "Player 1, what's your move (1 or 2)?2\n", 278 | "Player 1 has chosen action=2\n", 279 | "there are 19 coins in the pile\n", 280 | "time lapse = 0.31630 seconds\n", 281 | "Player 2 has chosen action=1\n", 282 | "there are 18 coins in the pile\n", 283 | "Player 1, what's your move (1 or 2)?1\n", 284 | "Player 1 has chosen action=1\n", 285 | "there are 17 coins in the pile\n", 286 | "time lapse = 0.12109 seconds\n", 287 | "Player 2 has chosen action=2\n", 288 | "there are 15 coins in the pile\n", 289 | "Player 1, what's your move (1 or 2)?2\n", 290 | "Player 1 has chosen action=2\n", 291 | "there are 13 coins in the pile\n", 292 | "time lapse = 0.02154 seconds\n", 293 | "Player 2 has chosen action=1\n", 294 | "there are 12 coins in the pile\n", 295 | "Player 1, what's your move (1 or 2)?1\n", 296 | "Player 1 has chosen action=1\n", 297 | "there are 11 coins in the pile\n", 298 | "time lapse = 0.00852 seconds\n", 299 | "Player 2 has chosen action=2\n", 300 | "there are 9 coins in the pile\n", 301 | "Player 1, what's your move (1 or 2)?2\n", 302 | "Player 1 has chosen action=2\n", 303 | "there are 7 coins in the pile\n", 304 | "time lapse = 0.00100 seconds\n", 305 | "Player 2 has chosen action=1\n", 306 | "there are 6 coins in the pile\n", 307 | "Player 1, what's your move (1 or 2)?1\n", 308 | "Player 1 has chosen action=1\n", 309 | "there are 5 coins in the pile\n", 310 | "time lapse = 0.00000 seconds\n", 311 | "Player 2 has chosen action=2\n", 312 | "there are 3 coins in the pile\n", 313 | "Player 1, what's your move (1 or 2)?2\n", 314 | "Player 1 has chosen action=2\n", 315 | "there are 1 coins in the pile\n", 316 | "time lapse = 0.00000 seconds\n", 317 | "Player 2 has chosen action=1\n", 318 | "there are 0 coins in the pile\n", 319 | "Player 2 has won!\n" 320 | ] 321 | } 322 | ], 323 | "source": [ 324 | "from utils.coin_simple_env import coin_game\n", 325 | "from utils.ch04util import MiniMax\n", 326 | "\n", 327 | "# Initiate the game environment\n", 328 | "env=coin_game()\n", 329 | "state=env.reset() \n", 330 | "# Play a full game\n", 331 | "while True:\n", 332 | " print(f\"there are {state} coins in the pile\") \n", 333 | " action=input(\"Player 1, what's your move (1 or 2)?\")\n", 334 | " print(f\"Player 1 has chosen action={action}\") \n", 335 | " state, reward, done, info=env.step(action)\n", 336 | " if done:\n", 337 | " print(f\"there are {state} coins in the pile\")\n", 338 | " print(f\"Player 1 has won!\") \n", 339 | " break\n", 340 | " print(f\"there are {state} coins in the pile\") \n", 341 | " start=time.time()\n", 342 | " action=MiniMax(env)\n", 343 | " print(f\"time lapse = {time.time()-start:.5f} seconds\") \n", 344 | " print(f\"Player 2 has chosen action={action}\") \n", 345 | " state, reward, done, info=env.step(action)\n", 346 | " if done:\n", 347 | " print(f\"there are {state} coins in the pile\")\n", 348 | " print(f\"Player 2 has won!\") \n", 349 | " break" 350 | ] 351 | }, 352 | { 353 | "cell_type": "markdown", 354 | "id": "c6ece520", 355 | "metadata": {}, 356 | "source": [ 357 | "# 3. Effectiveness of MiniMax in the Coin Game\n", 358 | "Next, we’ll test how often the MiniMax Algorithm wins against the rule-based AI game strategy that we developed in Chapter 1. 
We'll first let the MiniMax agent play against random moves. We'll then test the MiniMax agent against the rule-based AI. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5cbe6000",
   "metadata": {},
   "source": [
    "## 3.1. MiniMax versus Random Moves in the Coin Game"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6ebc9a63",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the MiniMax algorithm won 96 games\n",
      "the MiniMax algorithm lost 4 games\n"
     ]
    }
   ],
   "source": [
    "from utils.ch01util import random_player, one_coin_game\n",
    "\n",
    "env=coin_game()\n",
    "results=[]\n",
    "for i in range(100):\n",
    "    # MiniMax moves first \n",
    "    result=one_coin_game(MiniMax,random_player)\n",
    "    # record game outcome\n",
    "    results.append(result)\n",
    "# count how many times MiniMax has won\n",
    "wins=results.count(1)\n",
    "print(f\"the MiniMax algorithm won {wins} games\")\n",
    "# count how many times MiniMax has lost\n",
    "losses=results.count(-1)\n",
    "print(f\"the MiniMax algorithm lost {losses} games\") "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9da020e9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the MiniMax algorithm won 100 games\n",
      "the MiniMax algorithm lost 0 games\n"
     ]
    }
   ],
   "source": [
    "results=[]\n",
    "for i in range(100):\n",
    "    # MiniMax moves second\n",
    "    result=one_coin_game(random_player,MiniMax)\n",
    "    # record negative game outcome\n",
    "    results.append(-result)\n",
    "# count how many times MiniMax has won\n",
    "wins=results.count(1)\n",
    "print(f\"the MiniMax algorithm won {wins} games\")\n",
    "# count how many times MiniMax has lost\n",
    "losses=results.count(-1)\n",
    "print(f\"the MiniMax algorithm lost {losses} games\") "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61143017",
   "metadata": {},
   "source": [
    "## 3.2. 
MiniMax versus Rule-Based AI in the Coin Game" 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": 8, 443 | "id": "a2328736", 444 | "metadata": {}, 445 | "outputs": [ 446 | { 447 | "name": "stdout", 448 | "output_type": "stream", 449 | "text": [ 450 | "the MiniMax algorithm won 0 games\n", 451 | "the MiniMax algorithm lost 100 games\n" 452 | ] 453 | } 454 | ], 455 | "source": [ 456 | "from utils.ch01util import rule_based_AI\n", 457 | "\n", 458 | "env=coin_game()\n", 459 | "results=[]\n", 460 | "for i in range(100):\n", 461 | " # MiniMax moves first \n", 462 | " result=one_coin_game(MiniMax,rule_based_AI)\n", 463 | " # record game outcome\n", 464 | " results.append(result)\n", 465 | "# count how many times MiniMax has won\n", 466 | "wins=results.count(1)\n", 467 | "print(f\"the MiniMax algorithm won {wins} games\")\n", 468 | "# count how many times MiniMax has lost\n", 469 | "losses=results.count(-1)\n", 470 | "print(f\"the MiniMax algorithm lost {losses} games\") " 471 | ] 472 | }, 473 | { 474 | "cell_type": "code", 475 | "execution_count": 9, 476 | "id": "811ad71e", 477 | "metadata": {}, 478 | "outputs": [ 479 | { 480 | "name": "stdout", 481 | "output_type": "stream", 482 | "text": [ 483 | "the MiniMax algorithm won 100 games\n", 484 | "the MiniMax algorithm lost 0 games\n" 485 | ] 486 | } 487 | ], 488 | "source": [ 489 | "results=[]\n", 490 | "for i in range(100):\n", 491 | " # MiniMax moves second\n", 492 | " result=one_coin_game(rule_based_AI,MiniMax)\n", 493 | " # record negative game outcome\n", 494 | " results.append(-result)\n", 495 | "# count how many times MiniMax has won\n", 496 | "wins=results.count(1)\n", 497 | "print(f\"the MiniMax algorithm won {wins} games\")\n", 498 | "# count how many times MiniMax has lost\n", 499 | "losses=results.count(-1)\n", 500 | "print(f\"the MiniMax algorithm lost {losses} games\") " 501 | ] 502 | } 503 | ], 504 | "metadata": { 505 | "kernelspec": { 506 | "display_name": "Python 3 (ipykernel)", 507 | "language": "python", 508 | "name": "python3" 509 | }, 510 | "language_info": { 511 | "codemirror_mode": { 512 | "name": "ipython", 513 | "version": 3 514 | }, 515 | "file_extension": ".py", 516 | "mimetype": "text/x-python", 517 | "name": "python", 518 | "nbconvert_exporter": "python", 519 | "pygments_lexer": "ipython3", 520 | "version": "3.9.12" 521 | } 522 | }, 523 | "nbformat": 4, 524 | "nbformat_minor": 5 525 | } 526 | -------------------------------------------------------------------------------- /ch06AlphaBetaPruning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "653089c9", 6 | "metadata": {}, 7 | "source": [ 8 | "# Chapter 6: Alpha-Beta Pruning\n", 9 | "\n", 10 | "\n", 11 | "\n", 12 | "***\n", 13 | "*“Art is the elimination of the unnecessary.”*\n", 14 | "\n", 15 | "-- Pablo Picasso\n", 16 | "***\n", 17 | "\n", 18 | "\n", 19 | "\n", 20 | "What you'll learn in this chapter:\n", 21 | "\n", 22 | "* The logic behind alpha-beta pruning\n", 23 | "* Implementing alpha-beta pruning in Tic Tac Toe and Connect Four\n", 24 | "* Calculating time saved by the alpha-beta pruning agent\n", 25 | "* Verifying that alpha-beta pruning won’t affect game outcomes\n", 26 | "\n", 27 | "As you have seen in Chapter 5, depth pruning makes the MiniMax algorithm possible in complicated games such as Connect Four, Chess, and Go. In this chapter, you'll use another method to improve the MiniMax algorithm and make it more efficient. 
Specifically, alpha-beta pruning allows us to skip certain branches that cannot possibly influence the final game outcome. Doing so significantly reduces the amount of time for the MiniMax agent to come up with a move.\n",
    "\n",
    "To implement alpha-beta pruning in a game, we keep track of two numbers: alpha and beta, the best outcomes so far for Players 1 and 2, respectively. Whenever we have $alpha \geq -beta$, or equivalently $beta \geq -alpha$, the MiniMax algorithm stops searching a branch. \n",
    "\n",
    "We implement alpha-beta pruning in both Tic Tac Toe and Connect Four in this chapter. We show that the outcomes are the same with and without alpha-beta pruning. We also show that alpha-beta pruning saves a significant amount of time for the player to find the best moves. For example, in Tic Tac Toe, the amount of time for the MiniMax agent to come up with the first move decreases from 34 seconds without alpha-beta pruning to 1.06 seconds with alpha-beta pruning, a 97% reduction in the amount of time the MiniMax agent needs to come up with a move. In Connect Four, we find that on average, the time spent on a move drops from 0.15 seconds to 0.05 seconds after we add alpha-beta pruning when we limit the depth to three. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "282ad876",
   "metadata": {},
   "source": [
    "# 1. What is Alpha-Beta Pruning?"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5cbe6000",
   "metadata": {},
   "source": [
    "# 2. Alpha-Beta Pruning in Tic Tac Toe"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7710c0ce",
   "metadata": {},
   "source": [
    "## 2.1. The maximized_payoff_ttt() Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6b94f675",
   "metadata": {},
   "outputs": [],
   "source": [
    "def maximized_payoff_ttt(env,reward,done,alpha,beta):\n",
    "    # if game ended after previous player's move\n",
    "    if done:\n",
    "        # if it's not a tie\n",
    "        if reward!=0:\n",
    "            return -1\n",
    "        else:\n",
    "            return 0\n",
    "    # set initial alpha and beta to -2\n",
    "    if alpha==None:\n",
    "        alpha=-2\n",
    "    if beta==None:\n",
    "        beta=-2\n",
    "    if env.turn==\"X\":\n",
    "        best_payoff = alpha\n",
    "    if env.turn==\"O\":\n",
    "        best_payoff = beta \n",
    "    # iterate through all possible moves\n",
    "    for m in env.validinputs:\n",
    "        env_copy=deepcopy(env)\n",
    "        state,reward,done,info=env_copy.step(m) \n",
    "        # If I make this move, what's the opponent's response?\n",
    "        opponent_payoff=maximized_payoff_ttt(env_copy,\\\n",
    "                                reward,done,alpha,beta)\n",
    "        # Opponent's payoff is the opposite of your payoff\n",
    "        my_payoff=-opponent_payoff \n",
    "        if my_payoff > best_payoff: \n",
    "            best_payoff = my_payoff\n",
    "            if env.turn==\"X\":\n",
    "                alpha=best_payoff\n",
    "            if env.turn==\"O\":\n",
    "                beta=best_payoff \n",
    "        # skip the rest of the branch \n",
    "        if alpha>=-beta:\n",
    "            break \n",
    "    return best_payoff "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db6cead4",
   "metadata": {},
   "source": [
    "## 2.2. 
The MiniMax_ab() Function" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 2, 114 | "id": "06ef98d9", 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "def MiniMax_ab(env):\n", 119 | " wins=[]\n", 120 | " ties=[]\n", 121 | " losses=[] \n", 122 | " # iterate through all possible next moves\n", 123 | " for m in env.validinputs:\n", 124 | " # make a hypothetical move and see what happens\n", 125 | " env_copy=deepcopy(env)\n", 126 | " state,reward,done,info=env_copy.step(m) \n", 127 | " # If player X wins right away with move m, take it.\n", 128 | " if done and reward!=0:\n", 129 | " return m \n", 130 | " # See what's the best response from the opponent\n", 131 | " opponent_payoff=maximized_payoff_ttt(env_copy,\\\n", 132 | " reward,done,-2,-2) \n", 133 | " # Opponent's payoff is the opposite of your payoff\n", 134 | " my_payoff=-opponent_payoff \n", 135 | " if my_payoff==1:\n", 136 | " wins.append(m)\n", 137 | " elif my_payoff==0:\n", 138 | " ties.append(m)\n", 139 | " else:\n", 140 | " losses.append(m)\n", 141 | " # pick winning moves if there is any \n", 142 | " if len(wins)>0:\n", 143 | " return choice(wins)\n", 144 | " # otherwise pick tying moves\n", 145 | " elif len(ties)>0:\n", 146 | " return choice(ties)\n", 147 | " return env.sample() " 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "id": "0838d4e3", 153 | "metadata": {}, 154 | "source": [ 155 | "## 2.3. Time Saved by Alpha-Beta Pruning\n" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 3, 161 | "id": "20703e88", 162 | "metadata": {}, 163 | "outputs": [ 164 | { 165 | "name": "stdout", 166 | "output_type": "stream", 167 | "text": [ 168 | "Player X has chosen action=1\n", 169 | "It took the agent 1.0604541301727295 seconds\n", 170 | "Current state is \n", 171 | "[[0 0 0]\n", 172 | " [0 0 0]\n", 173 | " [1 0 0]]\n", 174 | "Player O, what's your move?\n", 175 | "9\n", 176 | "Player O has chosen action=9\n", 177 | "Current state is \n", 178 | "[[ 0 0 -1]\n", 179 | " [ 0 0 0]\n", 180 | " [ 1 0 0]]\n", 181 | "Player X has chosen action=7\n", 182 | "It took the agent 0.06767606735229492 seconds\n", 183 | "Current state is \n", 184 | "[[ 1 0 -1]\n", 185 | " [ 0 0 0]\n", 186 | " [ 1 0 0]]\n", 187 | "Player O, what's your move?\n", 188 | "4\n", 189 | "Player O has chosen action=4\n", 190 | "Current state is \n", 191 | "[[ 1 0 -1]\n", 192 | " [-1 0 0]\n", 193 | " [ 1 0 0]]\n", 194 | "Player X has chosen action=3\n", 195 | "It took the agent 0.006998538970947266 seconds\n", 196 | "Current state is \n", 197 | "[[ 1 0 -1]\n", 198 | " [-1 0 0]\n", 199 | " [ 1 0 1]]\n", 200 | "Player O, what's your move?\n", 201 | "5\n", 202 | "Player O has chosen action=5\n", 203 | "Current state is \n", 204 | "[[ 1 0 -1]\n", 205 | " [-1 -1 0]\n", 206 | " [ 1 0 1]]\n", 207 | "Player X has chosen action=2\n", 208 | "It took the agent 0.0 seconds\n", 209 | "Current state is \n", 210 | "[[ 1 0 -1]\n", 211 | " [-1 -1 0]\n", 212 | " [ 1 1 1]]\n", 213 | "Player X has won!\n" 214 | ] 215 | } 216 | ], 217 | "source": [ 218 | "from utils.ch06util import MiniMax_ab\n", 219 | "from utils.ttt_simple_env import ttt\n", 220 | "import time\n", 221 | "\n", 222 | "# Initiate the game environment\n", 223 | "env=ttt()\n", 224 | "state=env.reset() \n", 225 | "# Play a full game manually\n", 226 | "while True:\n", 227 | " # Mesure how long it takes to come up with a move\n", 228 | " start=time.time()\n", 229 | " action=MiniMax_ab(env)\n", 230 | " end=time.time()\n", 231 | " print(f\"Player X has 
chosen action={action}\") \n", 232 | " print(f\"It took the agent {end-start} seconds\") \n", 233 | " state, reward, done, info = env.step(action)\n", 234 | " print(f\"Current state is \\n{state.reshape(3,3)[::-1]}\")\n", 235 | " if done:\n", 236 | " if reward==1:\n", 237 | " print(f\"Player X has won!\") \n", 238 | " else:\n", 239 | " print(\"Game over, it's a tie!\")\n", 240 | " break \n", 241 | " action = input(\"Player O, what's your move?\\n\")\n", 242 | " print(f\"Player O has chosen action={action}\") \n", 243 | " state, reward, done, info = env.step(int(action))\n", 244 | " print(f\"Current state is \\n{state.reshape(3,3)[::-1]}\")\n", 245 | " if done:\n", 246 | " print(f\"Player O has won!\") \n", 247 | " break" 248 | ] 249 | }, 250 | { 251 | "cell_type": "markdown", 252 | "id": "975a7d1a", 253 | "metadata": {}, 254 | "source": [ 255 | "It took only 1.06 seconds for the MiniMax agent to make the first move, instead of 34 seconds when alpha-beta pruning is not used. That's a huge improvement on the efficiency of the algorithm without affecting the effectiveness of the agent. " 256 | ] 257 | }, 258 | { 259 | "cell_type": "markdown", 260 | "id": "58df6cd7", 261 | "metadata": {}, 262 | "source": [ 263 | "# 3. Test MiniMax with Alpha-Beta Pruning\n" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": 4, 269 | "id": "0c3d78a2", 270 | "metadata": {}, 271 | "outputs": [], 272 | "source": [ 273 | "from utils.ch05util import MiniMax_X,MiniMax_O\n", 274 | "from utils.ch02util import one_ttt_game \n", 275 | "\n", 276 | "results=[]\n", 277 | "for i in range(10):\n", 278 | " # MiniMax with pruning moves first if i is an even number\n", 279 | " if i%2==0:\n", 280 | " result=one_ttt_game(MiniMax_ab,MiniMax_O)\n", 281 | " # record game outcome\n", 282 | " results.append(result)\n", 283 | " # MiniMax with pruning moves second if i is an odd number\n", 284 | " else:\n", 285 | " result=one_ttt_game(MiniMax_X,MiniMax_ab)\n", 286 | " # record negative game outcome\n", 287 | " results.append(-result)" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": 5, 293 | "id": "787481ac", 294 | "metadata": {}, 295 | "outputs": [ 296 | { 297 | "name": "stdout", 298 | "output_type": "stream", 299 | "text": [ 300 | "MiniMax with pruning won 0 games\n", 301 | "MiniMax with pruning lost 0 games\n", 302 | "the game was tied 10 times\n" 303 | ] 304 | } 305 | ], 306 | "source": [ 307 | "# count how many times MiniMax with pruning won\n", 308 | "wins=results.count(1)\n", 309 | "print(f\"MiniMax with pruning won {wins} games\")\n", 310 | "# count how many times MiniMax with pruning lost\n", 311 | "losses=results.count(-1)\n", 312 | "print(f\"MiniMax with pruning lost {losses} games\")\n", 313 | "# count tie games\n", 314 | "ties=results.count(0)\n", 315 | "print(f\"the game was tied {ties} times\") " 316 | ] 317 | }, 318 | { 319 | "cell_type": "markdown", 320 | "id": "44a3674b", 321 | "metadata": {}, 322 | "source": [ 323 | "# 4. Alpha-Beta Pruning in Connect Four\n" 324 | ] 325 | }, 326 | { 327 | "cell_type": "markdown", 328 | "id": "3daff7f0", 329 | "metadata": {}, 330 | "source": [ 331 | "## 4.1. 
Add Alpha-Beta Pruning in Connect Four\n" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": 6, 337 | "id": "9eecbe91", 338 | "metadata": {}, 339 | "outputs": [], 340 | "source": [ 341 | "def max_payoff_conn(env,reward,done,depth,alpha,beta):\n", 342 | " # if the game has ended after the previous player's move\n", 343 | " if done:\n", 344 | " # if it's not a tie\n", 345 | " if reward!=0:\n", 346 | " return -1\n", 347 | " else:\n", 348 | " return 0\n", 349 | " # If the maximum depth is reached, assume tie game\n", 350 | " if depth==0:\n", 351 | " return 0 \n", 352 | " if alpha==None:\n", 353 | " alpha=-2\n", 354 | " if beta==None:\n", 355 | " beta=-2\n", 356 | " if env.turn==\"red\":\n", 357 | " best_payoff = alpha\n", 358 | " if env.turn==\"yellow\":\n", 359 | " best_payoff = beta \n", 360 | " # iterate through all possible moves\n", 361 | " for m in env.validinputs:\n", 362 | " env_copy=deepcopy(env)\n", 363 | " state,reward,done,info=env_copy.step(m) \n", 364 | " # If I make this move, what's the opponent's response?\n", 365 | " opponent_payoff=max_payoff_conn(env_copy,\\\n", 366 | " reward,done,depth-1,alpha,beta)\n", 367 | " # Opponent's payoff is the opposite of your payoff\n", 368 | " my_payoff=-opponent_payoff \n", 369 | " if my_payoff > best_payoff: \n", 370 | " best_payoff = my_payoff\n", 371 | " if env.turn==\"red\":\n", 372 | " alpha=best_payoff\n", 373 | " if env.turn==\"yellow\":\n", 374 | " beta=best_payoff \n", 375 | " # Skip the rest of the branch\n", 376 | " if alpha>=-beta:\n", 377 | " break \n", 378 | " return best_payoff " 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": 7, 384 | "id": "4f1d6271", 385 | "metadata": {}, 386 | "outputs": [], 387 | "source": [ 388 | "def MiniMax_conn(env,depth=3):\n", 389 | " wins=[]\n", 390 | " ties=[]\n", 391 | " losses=[] \n", 392 | " # iterate through all possible next moves\n", 393 | " for m in env.validinputs:\n", 394 | " # make a hypothetical move and see what happens\n", 395 | " env_copy=deepcopy(env)\n", 396 | " state,reward,done,info=env_copy.step(m) \n", 397 | " # If player X wins right away with move m, take it.\n", 398 | " if done and reward!=0:\n", 399 | " return m \n", 400 | " # See what's the best response from the opponent\n", 401 | " opponent_payoff=max_payoff_conn(env_copy,\\\n", 402 | " reward,done,depth,-2,-2) \n", 403 | " # Opponent's payoff is the opposite of your payoff\n", 404 | " my_payoff=-opponent_payoff \n", 405 | " if my_payoff==1:\n", 406 | " wins.append(m)\n", 407 | " elif my_payoff==0:\n", 408 | " ties.append(m)\n", 409 | " else:\n", 410 | " losses.append(m)\n", 411 | " # pick winning moves if there is any \n", 412 | " if len(wins)>0:\n", 413 | " return choice(wins)\n", 414 | " # otherwise pick tying moves\n", 415 | " elif len(ties)>0:\n", 416 | " return choice(ties)\n", 417 | " return env.sample() " 418 | ] 419 | }, 420 | { 421 | "cell_type": "markdown", 422 | "id": "3cf67f75", 423 | "metadata": {}, 424 | "source": [ 425 | "## 4.2. 
Time Saved due to Alpha-Beta Pruning in Connect Four\n" 426 | ] 427 | }, 428 | { 429 | "cell_type": "code", 430 | "execution_count": 8, 431 | "id": "d0982762", 432 | "metadata": {}, 433 | "outputs": [ 434 | { 435 | "name": "stdout", 436 | "output_type": "stream", 437 | "text": [ 438 | "The red player has chosen action=6\n", 439 | "It took the agent 0.034066200256347656 seconds\n", 440 | "Current state is \n", 441 | "[[0 0 0 0 0 0 0]\n", 442 | " [0 0 0 0 0 0 0]\n", 443 | " [0 0 0 0 0 0 0]\n", 444 | " [0 0 0 0 0 0 0]\n", 445 | " [0 0 0 0 0 0 0]\n", 446 | " [0 0 0 0 0 1 0]]\n", 447 | "Player yellow, what's your move?\n", 448 | "1\n", 449 | "Player yellow has chosen action=1\n", 450 | "Current state is \n", 451 | "[[ 0 0 0 0 0 0 0]\n", 452 | " [ 0 0 0 0 0 0 0]\n", 453 | " [ 0 0 0 0 0 0 0]\n", 454 | " [ 0 0 0 0 0 0 0]\n", 455 | " [ 0 0 0 0 0 0 0]\n", 456 | " [-1 0 0 0 0 1 0]]\n", 457 | "The red player has chosen action=1\n", 458 | "It took the agent 0.03681349754333496 seconds\n", 459 | "Current state is \n", 460 | "[[ 0 0 0 0 0 0 0]\n", 461 | " [ 0 0 0 0 0 0 0]\n", 462 | " [ 0 0 0 0 0 0 0]\n", 463 | " [ 0 0 0 0 0 0 0]\n", 464 | " [ 1 0 0 0 0 0 0]\n", 465 | " [-1 0 0 0 0 1 0]]\n", 466 | "Player yellow, what's your move?\n", 467 | "2\n", 468 | "Player yellow has chosen action=2\n", 469 | "Current state is \n", 470 | "[[ 0 0 0 0 0 0 0]\n", 471 | " [ 0 0 0 0 0 0 0]\n", 472 | " [ 0 0 0 0 0 0 0]\n", 473 | " [ 0 0 0 0 0 0 0]\n", 474 | " [ 1 0 0 0 0 0 0]\n", 475 | " [-1 -1 0 0 0 1 0]]\n", 476 | "The red player has chosen action=5\n", 477 | "It took the agent 0.047269582748413086 seconds\n", 478 | "Current state is \n", 479 | "[[ 0 0 0 0 0 0 0]\n", 480 | " [ 0 0 0 0 0 0 0]\n", 481 | " [ 0 0 0 0 0 0 0]\n", 482 | " [ 0 0 0 0 0 0 0]\n", 483 | " [ 1 0 0 0 0 0 0]\n", 484 | " [-1 -1 0 0 1 1 0]]\n", 485 | "Player yellow, what's your move?\n", 486 | "1\n", 487 | "Player yellow has chosen action=1\n", 488 | "Current state is \n", 489 | "[[ 0 0 0 0 0 0 0]\n", 490 | " [ 0 0 0 0 0 0 0]\n", 491 | " [ 0 0 0 0 0 0 0]\n", 492 | " [-1 0 0 0 0 0 0]\n", 493 | " [ 1 0 0 0 0 0 0]\n", 494 | " [-1 -1 0 0 1 1 0]]\n", 495 | "The red player has chosen action=4\n", 496 | "It took the agent 0.049561500549316406 seconds\n", 497 | "Current state is \n", 498 | "[[ 0 0 0 0 0 0 0]\n", 499 | " [ 0 0 0 0 0 0 0]\n", 500 | " [ 0 0 0 0 0 0 0]\n", 501 | " [-1 0 0 0 0 0 0]\n", 502 | " [ 1 0 0 0 0 0 0]\n", 503 | " [-1 -1 0 1 1 1 0]]\n", 504 | "Player yellow, what's your move?\n", 505 | "3\n", 506 | "Player yellow has chosen action=3\n", 507 | "Current state is \n", 508 | "[[ 0 0 0 0 0 0 0]\n", 509 | " [ 0 0 0 0 0 0 0]\n", 510 | " [ 0 0 0 0 0 0 0]\n", 511 | " [-1 0 0 0 0 0 0]\n", 512 | " [ 1 0 0 0 0 0 0]\n", 513 | " [-1 -1 -1 1 1 1 0]]\n", 514 | "The red player has chosen action=7\n", 515 | "It took the agent 0.05922341346740723 seconds\n", 516 | "Current state is \n", 517 | "[[ 0 0 0 0 0 0 0]\n", 518 | " [ 0 0 0 0 0 0 0]\n", 519 | " [ 0 0 0 0 0 0 0]\n", 520 | " [-1 0 0 0 0 0 0]\n", 521 | " [ 1 0 0 0 0 0 0]\n", 522 | " [-1 -1 -1 1 1 1 1]]\n", 523 | "The red player has won!\n" 524 | ] 525 | } 526 | ], 527 | "source": [ 528 | "from utils.ch06util import MiniMax_conn\n", 529 | "from utils.conn_env import conn\n", 530 | "import time\n", 531 | "\n", 532 | "# Initiate the game environment\n", 533 | "env=conn()\n", 534 | "state=env.reset() \n", 535 | "# Play a full game manually\n", 536 | "while True:\n", 537 | " # Mesure how long it takes to come up with a move\n", 538 | " start=time.time()\n", 539 | " action=MiniMax_conn(env,depth=3)\n", 540 | " 
end=time.time()\n", 541 | " print(f\"The red player has chosen action={action}\") \n", 542 | " print(f\"It took the agent {end-start} seconds\") \n", 543 | " state, reward, done, info = env.step(action)\n", 544 | " print(f\"Current state is \\n{state.T[::-1]}\")\n", 545 | " if done:\n", 546 | " if reward==1:\n", 547 | " print(f\"The red player has won!\") \n", 548 | " else:\n", 549 | " print(\"Game over, it's a tie!\")\n", 550 | " break \n", 551 | " action=input(\"Player yellow, what's your move?\\n\")\n", 552 | " print(f\"Player yellow has chosen action={action}\") \n", 553 | " state, reward, done, info = env.step(int(action))\n", 554 | " print(f\"Current state is \\n{state.T[::-1]}\")\n", 555 | " if done:\n", 556 | " if reward==-1:\n", 557 | " print(f\"The yellow player has won!\") \n", 558 | " else:\n", 559 | " print(\"Game over, it's a tie!\")\n", 560 | " break " 561 | ] 562 | }, 563 | { 564 | "cell_type": "markdown", 565 | "id": "8f64596e", 566 | "metadata": {}, 567 | "source": [ 568 | "## 4.3. Effectiveness of Alpha-Beta Pruning in Connect Four\n" 569 | ] 570 | }, 571 | { 572 | "cell_type": "code", 573 | "execution_count": 9, 574 | "id": "c97c5efc", 575 | "metadata": {}, 576 | "outputs": [], 577 | "source": [ 578 | "from utils.ch05util import MiniMax_depth\n", 579 | "from utils.ch03util import one_conn_game \n", 580 | "\n", 581 | "results=[]\n", 582 | "for i in range(100):\n", 583 | " # MiniMax with pruning moves first if i is even \n", 584 | " if i%2==0:\n", 585 | " result=one_conn_game(MiniMax_conn,MiniMax_depth)\n", 586 | " # record game outcome\n", 587 | " results.append(result)\n", 588 | " # MiniMax with pruning moves second if i is odd \n", 589 | " else:\n", 590 | " result=one_conn_game(MiniMax_depth,MiniMax_conn)\n", 591 | " # record negative game outcome\n", 592 | " results.append(-result)" 593 | ] 594 | }, 595 | { 596 | "cell_type": "code", 597 | "execution_count": 10, 598 | "id": "7d97135d", 599 | "metadata": {}, 600 | "outputs": [ 601 | { 602 | "name": "stdout", 603 | "output_type": "stream", 604 | "text": [ 605 | "MiniMax with alpha-beta pruning won 41 games\n", 606 | "MiniMax with alpha-beta pruning lost 39 games\n", 607 | "the game was tied 20 times\n" 608 | ] 609 | } 610 | ], 611 | "source": [ 612 | "# count how many times MiniMax with alpha-beta pruning won\n", 613 | "wins=results.count(1)\n", 614 | "print(f\"MiniMax with alpha-beta pruning won {wins} games\")\n", 615 | "# count how many times MiniMax with pruning lost\n", 616 | "losses=results.count(-1)\n", 617 | "print(f\"MiniMax with alpha-beta pruning lost {losses} games\")\n", 618 | "# count tie games\n", 619 | "ties=results.count(0)\n", 620 | "print(f\"the game was tied {ties} times\") " 621 | ] 622 | }, 623 | { 624 | "cell_type": "markdown", 625 | "id": "35ec2255", 626 | "metadata": {}, 627 | "source": [ 628 | "The above results show that the MiniMax agent with alpha-beta pruning has won 41 times and lost 39 times. This shows that the MiniMax agent with alpha-beta pruning is as intelligent as the agent without alpha-beta pruning. Note that since the outcomes are random, you may get results showing that the MiniMax agent with alpha-beta pruning has lost more often than it has won. If that happens, run the above two cells again and see if the results change. 
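One way to judge whether a 41-to-39 split like the one above really means the two agents are equally strong is to look at the per-game margin together with a rough confidence interval. Here is a minimal sketch, assuming the same `one_conn_game`, `MiniMax_conn`, and `MiniMax_depth` helpers imported in the cells above are available in the `utils` modules; the confidence-interval arithmetic is just a standard normal approximation added for illustration and is not part of the chapter's utility files.

```python
import math

from utils.ch03util import one_conn_game
from utils.ch05util import MiniMax_depth
from utils.ch06util import MiniMax_conn

n_games = 100   # increase for a tighter interval (each game takes a while)
outcomes = []
for i in range(n_games):
    # alternate who moves first, exactly as in the cells above
    if i % 2 == 0:
        outcomes.append(one_conn_game(MiniMax_conn, MiniMax_depth))
    else:
        outcomes.append(-one_conn_game(MiniMax_depth, MiniMax_conn))

wins = outcomes.count(1)
losses = outcomes.count(-1)
ties = outcomes.count(0)
# per-game margin: +1 for a win, -1 for a loss, 0 for a tie
margin = (wins - losses) / n_games
# sample variance of the per-game outcome, then the standard error of the mean
variance = (wins + losses) / n_games - margin ** 2
se = math.sqrt(max(variance, 0.0) / n_games)
print(f"wins={wins}, losses={losses}, ties={ties}")
print(f"margin per game = {margin:.3f} +/- {1.96 * se:.3f} (approx. 95% CI)")
```

With only 100 games the interval is wide, so a handful of extra wins or losses in either direction is exactly what you would expect from two equally intelligent agents.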
" 629 | ] 630 | } 631 | ], 632 | "metadata": { 633 | "kernelspec": { 634 | "display_name": "Python 3 (ipykernel)", 635 | "language": "python", 636 | "name": "python3" 637 | }, 638 | "language_info": { 639 | "codemirror_mode": { 640 | "name": "ipython", 641 | "version": 3 642 | }, 643 | "file_extension": ".py", 644 | "mimetype": "text/x-python", 645 | "name": "python", 646 | "nbconvert_exporter": "python", 647 | "pygments_lexer": "ipython3", 648 | "version": "3.9.12" 649 | } 650 | }, 651 | "nbformat": 4, 652 | "nbformat_minor": 5 653 | } 654 | -------------------------------------------------------------------------------- /ch09DeepLearningCoinGame.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "653089c9", 6 | "metadata": {}, 7 | "source": [ 8 | "# Chapter 9: Deep Learning in the Coin Game\n", 9 | "\n", 10 | "\n", 11 | "\n", 12 | "***\n", 13 | "*“My CPU is a neural net processor, a learning computer. The more contact I have with humans, the more I learn.”*\n", 14 | "\n", 15 | "-- The Terminator, in Terminator 2: Judgement Day\n", 16 | "***\n", 17 | "\n", 18 | "\n", 19 | "\n", 20 | "What you'll learn in this chapter:\n", 21 | "\n", 22 | "* The architecture of a neural network\n", 23 | "* How deep learning is related to machine learning and artificial intelligence\n", 24 | "* Steps and Components of the AlphaGo algorithm\n", 25 | "* Building and training a fast policy network and a strong policy network in the coin game\n", 26 | "* Implementing an MCTS game strategy with policy rollouts\n", 27 | "\n", 28 | "Starting from this chapter, you’ll learn a new AI\n", 29 | "paradigm: machine learning (ML). Instead of hard\n", 30 | "coding in the rules, ML algorithms take in input-output pairs and figure out the\n", 31 | "relation between the inputs (which we call features) and outputs (the labels). One\n", 32 | "field of ML, deep learning, has attracted much attention recently. The algorithm used\n", 33 | "by AlphaGo is based on deep reinforcement learning, which is a combination of deep\n", 34 | "learning and reinforcement learning (a type of ML we’ll cover later in this book). In\n", 35 | "this chapter, you’ll learn what deep learning is and how it’s related to AI and ML.\n", 36 | "\n", 37 | "Deep learning is a type of ML method that’s based on artificial neural networks. A\n", 38 | "neural network is a computational model inspired by the structure of neural networks\n", 39 | "in the human brain. It’s designed to recognize patterns in data, and it contains layers\n", 40 | "of interconnected nodes, or neurons. In this chapter, you’ll learn to use deep neural\n", 41 | "networks to design game strategies for the coin game. In particular, you’ll follow the\n", 42 | "steps in AlphaGo and create two policy networks. We’ll use these networks later in\n", 43 | "the book to create an AlphaGo agent to play the coin game.\n", 44 | "\n", 45 | "Specifically, the AlphaGo algorithm follows the following steps. We first gather a\n", 46 | "large number of games played by Go experts and use deep learning to train two\n", 47 | "policy networks to predict the moves of the Go experts: a fast policy network and a\n", 48 | "strong policy network. In the second step, we use self-play deep reinforcement learning\n", 49 | "to further train and improve the strong policy network. At the same time, we train a\n", 50 | "value network to predict game outcomes by using the game experience data from the\n", 51 | "self-plays. 
Finally, we design a game strategy based on an improved version of MCTS. Instead of using the upper confidence bounds for trees (UCT) formula to select the\n", 52 | "next move, AlphaGo uses a combination of the UCT formula, the improved strong\n", 53 | "policy network, and the value network. Further, instead of randomly selecting moves\n", 54 | "in game rollouts, AlphaGo uses the fast policy network to roll out games.\n", 55 | "\n", 56 | "In this chapter, you’ll implement the first step in the AlphaGo algorithm in the coin\n", 57 | "game. Specifically, you’ll use the rule-based AI we developed in Chapter 1 to generate\n", 58 | "expert moves.We then create two neural networks and use the generated expert moves\n", 59 | "to train the two networks to predict moves. You’ll then implement policy rollouts in\n", 60 | "MCTS, where games are played based on the probability distribution from the fast\n", 61 | "policy network, leading to a more intelligent MCTS agent compared to the traditional\n", 62 | "one." 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "id": "48598568", 68 | "metadata": {}, 69 | "source": [ 70 | "# 1. Deep Learning, ML, and AI" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "id": "282ad876", 76 | "metadata": {}, 77 | "source": [ 78 | "# 2. What Are Neural Networks?" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "id": "9bcbdf00", 84 | "metadata": {}, 85 | "source": [ 86 | "# 3. Two Policy Networks in the Coin Game\n", 87 | "# 4. Train Two Networks in the Coin game" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 1, 93 | "id": "fc2f8827", 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "import numpy as np\n", 98 | "import random\n", 99 | "\n", 100 | "def expert(env):\n", 101 | " if env.state%3 != 0:\n", 102 | " move = env.state%3\n", 103 | " else:\n", 104 | " move = random.choice([1,2])\n", 105 | " return move \n", 106 | "\n", 107 | "def non_expert(env):\n", 108 | " if env.state%3 != 0 and np.random.rand()<0.5:\n", 109 | " move = env.state%3\n", 110 | " else:\n", 111 | " move = random.choice([1,2])\n", 112 | " return move " 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 2, 118 | "id": "6db4ea45", 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "name": "stdout", 123 | "output_type": "stream", 124 | "text": [ 125 | "[(20, 2), (17, 2), (13, 1), (11, 2), (8, 2), (5, 2), (1, 1)]\n" 126 | ] 127 | } 128 | ], 129 | "source": [ 130 | "from utils.coin_simple_env import coin_game\n", 131 | "import time\n", 132 | "\n", 133 | "# Initiate the game environment\n", 134 | "env=coin_game()\n", 135 | "# Define the one_game() function\n", 136 | "def one_game(episode):\n", 137 | " history=[]\n", 138 | " state=env.reset() \n", 139 | " # The nonexpert moves firsts half the time\n", 140 | " if episode%2==0:\n", 141 | " action=non_expert(env)\n", 142 | " state,reward,done,_=env.step(action)\n", 143 | " while True: \n", 144 | " action=expert(env) \n", 145 | " history.append((state,action))\n", 146 | " state,reward,done,_=env.step(action)\n", 147 | " if done:\n", 148 | " break\n", 149 | " action=non_expert(env)\n", 150 | " state,reward,done,_=env.step(action) \n", 151 | " if done:\n", 152 | " break\n", 153 | " return history\n", 154 | "\n", 155 | "# Simulate one game and print out results\n", 156 | "history=one_game(0)\n", 157 | "print(history) " 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 3, 163 | "id": "8a6a600a", 164 | "metadata": {}, 165 | "outputs": [], 166 | 
"source": [ 167 | "# simulate the game 10000 times \n", 168 | "results = [] \n", 169 | "for episode in range(10000):\n", 170 | " history=one_game(episode)\n", 171 | " results+=history " 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": 4, 177 | "id": "88782a8d", 178 | "metadata": {}, 179 | "outputs": [ 180 | { 181 | "name": "stdout", 182 | "output_type": "stream", 183 | "text": [ 184 | "[(20, 2), (17, 2), (14, 2), (11, 2), (8, 2), (5, 2), (2, 2), (21, 1), (18, 1), (15, 1)]\n" 185 | ] 186 | } 187 | ], 188 | "source": [ 189 | "import pickle\n", 190 | "# save the simulation data on your computer\n", 191 | "with open('files/games_coin.p', 'wb') as fp:\n", 192 | " pickle.dump(results,fp)\n", 193 | "# read the data and print out the first 10 observations \n", 194 | "with open('files/games_coin.p', 'rb') as fp:\n", 195 | " games = pickle.load(fp)\n", 196 | "print(games[:10])" 197 | ] 198 | }, 199 | { 200 | "cell_type": "markdown", 201 | "id": "4bd8eed1", 202 | "metadata": {}, 203 | "source": [ 204 | "## 4.2. Create Two Neural Networks\n" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": 5, 210 | "id": "1579eddb", 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "from tensorflow.keras.utils import to_categorical\n", 215 | "from tensorflow.keras.layers import Dense\n", 216 | "from tensorflow.keras.models import Sequential\n", 217 | "\n", 218 | "fast_model = Sequential()\n", 219 | "fast_model.add(Dense(units=32,activation=\"relu\",\n", 220 | " input_shape=(22,)))\n", 221 | "fast_model.add(Dense(2, activation='softmax'))\n", 222 | "fast_model.compile(loss='categorical_crossentropy',\n", 223 | " optimizer='adam', \n", 224 | " metrics=['accuracy'])" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 6, 230 | "id": "d4b6a220", 231 | "metadata": {}, 232 | "outputs": [], 233 | "source": [ 234 | "strong_model = Sequential()\n", 235 | "strong_model.add(Dense(units=64,activation=\"relu\",\n", 236 | " input_shape=(22,)))\n", 237 | "strong_model.add(Dense(32, activation=\"relu\"))\n", 238 | "strong_model.add(Dense(16, activation=\"relu\"))\n", 239 | "strong_model.add(Dense(2, activation='softmax'))\n", 240 | "strong_model.compile(loss='categorical_crossentropy',\n", 241 | " optimizer='adam', \n", 242 | " metrics=['accuracy'])" 243 | ] 244 | }, 245 | { 246 | "cell_type": "markdown", 247 | "id": "441286c7", 248 | "metadata": {}, 249 | "source": [ 250 | "## 4.3. Train the Neural Networks\n" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 7, 256 | "id": "96e2ab86", 257 | "metadata": {}, 258 | "outputs": [ 259 | { 260 | "name": "stdout", 261 | "output_type": "stream", 262 | "text": [ 263 | "[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 264 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]\n" 265 | ] 266 | } 267 | ], 268 | "source": [ 269 | "states=[20,1]\n", 270 | "one_hot=to_categorical(states,22)\n", 271 | "print(one_hot)" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": 8, 277 | "id": "16d6ea3b", 278 | "metadata": {}, 279 | "outputs": [ 280 | { 281 | "name": "stdout", 282 | "output_type": "stream", 283 | "text": [ 284 | "[[1. 0.]\n", 285 | " [0. 
1.]]\n" 286 | ] 287 | } 288 | ], 289 | "source": [ 290 | "actions=[1,2]\n", 291 | "# change actions 1 and 2 to 0 and 1.\n", 292 | "actions=np.array(actions)-1\n", 293 | "# change actions to one-hot actions\n", 294 | "one_hot_actions=to_categorical(actions,2)\n", 295 | "print(one_hot_actions)" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 9, 301 | "id": "31cfc31b", 302 | "metadata": {}, 303 | "outputs": [], 304 | "source": [ 305 | "with open('files/games_coin.p','rb') as fp:\n", 306 | " games=pickle.load(fp)\n", 307 | "\n", 308 | "states = []\n", 309 | "actions = []\n", 310 | "for x in games:\n", 311 | " state=to_categorical(x[0],22)\n", 312 | " action=to_categorical(x[1]-1,2)\n", 313 | " states.append(state)\n", 314 | " actions.append(action)\n", 315 | "\n", 316 | "X = np.array(states).reshape((-1, 22))\n", 317 | "y = np.array(actions).reshape((-1, 2))" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": 10, 323 | "id": "a3fc541a", 324 | "metadata": {}, 325 | "outputs": [], 326 | "source": [ 327 | "# Train the models for 25 epochs\n", 328 | "fast_model.fit(X, y, epochs=25, verbose=1)\n", 329 | "fast_model.save('files/fast_coin.h5')" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": 11, 335 | "id": "bf390256", 336 | "metadata": {}, 337 | "outputs": [], 338 | "source": [ 339 | "strong_model.fit(X, y, epochs=25, verbose=1)\n", 340 | "strong_model.save('files/strong_coin.h5')" 341 | ] 342 | }, 343 | { 344 | "cell_type": "markdown", 345 | "id": "14805aac", 346 | "metadata": {}, 347 | "source": [ 348 | "# 5. MCTS with Policy Rollouts in the Coin Game\n" 349 | ] 350 | }, 351 | { 352 | "cell_type": "markdown", 353 | "id": "a934805a", 354 | "metadata": {}, 355 | "source": [ 356 | "## 5.1. Policy-Based MCTS in the Coin Game\n", 357 | " " 358 | ] 359 | }, 360 | { 361 | "cell_type": "markdown", 362 | "id": "207914a0", 363 | "metadata": {}, 364 | "source": [ 365 | "Go to book's GitHub repository to download the file *ch09util.py* and place it in the folder /Desktop/ags/utils/ on your computer. 
In the file, we define a *DL_stochastic()* function as follows:" 366 | ] 367 | }, 368 | { 369 | "cell_type": "markdown", 370 | "id": "0275ee6b", 371 | "metadata": {}, 372 | "source": [ 373 | "```python\n", 374 | "def onehot_encoder(state):\n", 375 | " onehot=np.zeros((1,22))\n", 376 | " onehot[0,state]=1\n", 377 | " return onehot\n", 378 | "\n", 379 | "def DL_stochastic(env): \n", 380 | " state = env.state\n", 381 | " onehot_state = onehot_encoder(state)\n", 382 | " action_probs = model(onehot_state)\n", 383 | " return np.random.choice([1,2], \n", 384 | " p=np.squeeze(action_probs))\n", 385 | "```" 386 | ] 387 | }, 388 | { 389 | "cell_type": "markdown", 390 | "id": "db74ec75", 391 | "metadata": {}, 392 | "source": [ 393 | "```python\n", 394 | "def policy_simulate(env_copy,done,reward,model):\n", 395 | " # if the game has already ended\n", 396 | " if done==True:\n", 397 | " return reward\n", 398 | " while True:\n", 399 | " move=DL_stochastic(env_copy,model)\n", 400 | " state,reward,done,info=env_copy.step(move)\n", 401 | " if done==True:\n", 402 | " return reward\n", 403 | "```" 404 | ] 405 | }, 406 | { 407 | "cell_type": "markdown", 408 | "id": "0f4b749c", 409 | "metadata": {}, 410 | "source": [ 411 | "```python\n", 412 | "def policy_mcts_coin(env,model,num_rollouts=100,temperature=1.4):\n", 413 | " # if there is only one valid move left, take it\n", 414 | " if len(env.validinputs)==1:\n", 415 | " return env.validinputs[0]\n", 416 | " # create three dictionaries counts, wins, losses\n", 417 | " counts={}\n", 418 | " wins={}\n", 419 | " losses={}\n", 420 | " for move in env.validinputs:\n", 421 | " counts[move]=0\n", 422 | " wins[move]=0\n", 423 | " losses[move]=0\n", 424 | " # roll out games\n", 425 | " for _ in range(num_rollouts):\n", 426 | " # selection\n", 427 | " move=select(env,counts,wins,losses,temperature)\n", 428 | " # expansion\n", 429 | " env_copy, done, reward=expand(env,move)\n", 430 | " # simulation\n", 431 | " reward=policy_simulate(env_copy,done,reward,model)\n", 432 | " # backpropagate\n", 433 | " counts,wins,losses=backpropagate(\\\n", 434 | " env,move,reward,counts,wins,losses)\n", 435 | " # make the move\n", 436 | " return next_move(counts,wins,losses)\n", 437 | "```" 438 | ] 439 | }, 440 | { 441 | "cell_type": "markdown", 442 | "id": "d4535cf7", 443 | "metadata": {}, 444 | "source": [ 445 | "## 5.2. 
The Effectiveness of the Policy MCTS Agent\n" 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": 12, 451 | "id": "f3fa6933", 452 | "metadata": {}, 453 | "outputs": [], 454 | "source": [ 455 | "from utils.ch08util import mcts\n", 456 | "from utils.ch09util import policy_mcts_coin\n", 457 | "\n", 458 | "env=coin_game()\n", 459 | "results=[]\n", 460 | "for i in range(100):\n", 461 | " state=env.reset() \n", 462 | " # Half the time, the UCT MCTS agent moves first\n", 463 | " if i%2==0:\n", 464 | " action=mcts(env,num_rollouts=100)\n", 465 | " state, reward, done, info=env.step(action)\n", 466 | " while True:\n", 467 | " action=policy_mcts_coin(env,model,num_rollouts=100) \n", 468 | " state, reward, done, info=env.step(action)\n", 469 | " if done:\n", 470 | " # result is 1 if the policy MCTS agent wins\n", 471 | " results.append(1) \n", 472 | " break \n", 473 | " action=mcts(env,num_rollouts=100)\n", 474 | " state, reward, done, info=env.step(action)\n", 475 | " if done:\n", 476 | " # result is -1 if the policy MCTS agent loses\n", 477 | " results.append(-1) \n", 478 | " break " 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": 13, 484 | "id": "e6e7fd07", 485 | "metadata": {}, 486 | "outputs": [ 487 | { 488 | "name": "stdout", 489 | "output_type": "stream", 490 | "text": [ 491 | "the policy MCTS agent has won 100 games\n", 492 | "the policy MCTS agent has lost 0 games\n" 493 | ] 494 | } 495 | ], 496 | "source": [ 497 | "wins=results.count(1)\n", 498 | "print(f\"the policy MCTS agent has won {wins} games\")\n", 499 | "losses=results.count(-1)\n", 500 | "print(f\"the policy MCTS agent has lost {losses} games\") " 501 | ] 502 | } 503 | ], 504 | "metadata": { 505 | "kernelspec": { 506 | "display_name": "Python 3 (ipykernel)", 507 | "language": "python", 508 | "name": "python3" 509 | }, 510 | "language_info": { 511 | "codemirror_mode": { 512 | "name": "ipython", 513 | "version": 3 514 | }, 515 | "file_extension": ".py", 516 | "mimetype": "text/x-python", 517 | "name": "python", 518 | "nbconvert_exporter": "python", 519 | "pygments_lexer": "ipython3", 520 | "version": "3.9.12" 521 | } 522 | }, 523 | "nbformat": 4, 524 | "nbformat_minor": 5 525 | } 526 | -------------------------------------------------------------------------------- /ch16AlphaGoCoinGame.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "653089c9", 6 | "metadata": {}, 7 | "source": [ 8 | "# Chaper 16: Implement AlphaGo in the Coin Game\n", 9 | "\n", 10 | "\n", 11 | "\n", 12 | "***\n", 13 | "***“I thought AlphaGo was based on probability calculation and that it was merely a machine. But when I saw this move, I changed my mind. Surely, AlphaGo is creative.”***\n", 14 | "\n", 15 | "-- Lee Sedol, winner of 18 world Go titles\n", 16 | "***\n", 17 | "\n", 18 | "\n", 19 | "\n", 20 | "\n", 21 | "The AlphaGo algorithm combines combines deep reinforcement learning (namely, the policy gradient method) with traditional rule-based AI (specifically, Monte Carlo Tree Search or MCTS) to generate intelligent game strategies in Go. Now that you have learned both deep reinforcement learning and MCTS, you’ll learn to combine them together as the DeepMind team did and apply the algorithm\n", 22 | "to the three games in this book: the coin game, Tic Tac Toe, and Connect Four.\n", 23 | "\n", 24 | "In this chapter, you’ll implement AlphaGo in the coin game. 
First, we’ll go over the AlphaGo algorithm and see how it brings various pieces together to create a powerful AI player. After that, you’ll apply the same logic to the coin game to create your very own AlphaGo agent.\n", 25 | "\n", 26 | "MCTS is the core of AlphaGo’s decision-making process. MCTS is used to find the most promising moves by building a search tree. Each node in the tree represents a game position, and branches represent possible moves. The search through this tree is guided by statistical analysis of the moves. AlphaGo also uses three deep neural networks - a fast policy network, a strengthened policy network through self-play deep reinforcement learning, and a value network. The fast policy network helps in narrowing down the selection of moves to consider in game rollouts. It’s trained on expert games and learns to predict their moves. This network is used to guide the tree search to more promising paths. The strengthened policy network is used to select child nodes in game rollouts so that the most valuable child nodes are selected to roll out games. The value network evaluates board positions and predicts the winner of the game from that position. It’s crucial for looking ahead and evaluating future\n", 27 | "positions without having to play out the entire game.\n", 28 | "\n", 29 | "AlphaGo’s success lies in the effective combination of traditional tree search methods with powerful machine-learning techniques. This allowed it to tackle the immense complexity of Go, a game with more possible positions than atoms in the observable universe.\n", 30 | "\n", 31 | "The three games we consider in this book are much simpler than the game of Go. Nonetheless, we’ll mimic the strategies used by the DeepMind team and implement the AlphaGo algorithm in these games. Along the way, you’ll pick valuable skills in both rule-based AI and cutting-edge developments in deep learning.\n", 32 | "\n", 33 | "To implement AlphaGo in the coin game, we’ll utilize the skills and the trained networks earlier in this book. Specifically, we’ll use the skills you learned about MCTS from Chapter 8. However, instead of rolling out games with random moves, you’ll roll out games by letting both players choose moves based on the fast policy network we trained in Chapter 9. More intelligent moves in rollouts lead to more informative game outcomes. Further, instead of playing out the entire game during rollouts, you’ll use the value network we trained in Chapter 13 to evaluate game states after playing out games after a fixed number of moves (that is, after a certain depth). Finally, to select a child node to roll out games in MCTS, instead of using the UCT formula from Chapter 8, you’ll use a weighted average of the rollout values and the prior probabilities recommended by the strengthened policy network from Chapter 13.\n", 34 | "\n", 35 | "After creating the AlphaGo agent in the coin game, you’ll test it against both random moves and the perfect rule-based AI player from Chapter 1. You’ll see that when moving second, the AlphaGo agent beats the ruled-based AI player in all ten games. When moving first, the AlphaGo agent wins all ten games against random moves. This shows that the AlphaGo algorithm is as strong as any possible game strategy we could have designed for the coin game." 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "id": "282ad876", 41 | "metadata": {}, 42 | "source": [ 43 | "# 1. 
The AlphaGo Architecture \n" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "id": "39a933d3", 49 | "metadata": {}, 50 | "source": [ 51 | "# 2. AlphaGo in the Coin Game\n", 52 | "\n", 53 | "\n", 54 | "## 2.1. Select the Best Child Node and Expand the Game Tree\n" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "id": "c6d4ea2c", 60 | "metadata": {}, 61 | "source": [ 62 | "```python\n", 63 | "from copy import deepcopy\n", 64 | "import numpy as np\n", 65 | "from tensorflow import keras\n", 66 | "\n", 67 | "# Load the trained fast policy network from Chapter 9\n", 68 | "fast_net=keras.models.load_model(\"files/fast_coin.h5\")\n", 69 | "# Load the strengthend strong net from Chapter 13\n", 70 | "PG_net=keras.models.load_model(\"files/PG_coin.h5\")\n", 71 | "# Load the trained value network from Chapter 13\n", 72 | "value_net=keras.models.load_model(\"files/value_coin.h5\")\n", 73 | "```" 74 | ] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "id": "65d0d349", 79 | "metadata": {}, 80 | "source": [ 81 | "```python\n", 82 | "def onehot_encoder(state):\n", 83 | " onehot=np.zeros((1,22))\n", 84 | " onehot[0,state]=1\n", 85 | " return onehot\n", 86 | "\n", 87 | "def best_move_fast_net(env):\n", 88 | " state = env.state\n", 89 | " onehot_state = onehot_encoder(state)\n", 90 | " action_probs = fast_net(onehot_state)\n", 91 | " return np.random.choice([1,2],\n", 92 | " p=np.squeeze(action_probs))\n", 93 | "```" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "id": "25e0b0e2", 99 | "metadata": {}, 100 | "source": [ 101 | "```python\n", 102 | "def select(priors,env,results,weight): \n", 103 | " # weighted average of priors and rollout_value\n", 104 | " scores={}\n", 105 | " for k,v in results.items():\n", 106 | " # rollout_value for each next move\n", 107 | " if len(v)==0:\n", 108 | " vi=0\n", 109 | " else:\n", 110 | " vi=sum(v)/len(v)\n", 111 | " # scale the prior by (1+N(L))\n", 112 | " prior=priors[0][k-1]/(1+len(v))\n", 113 | " # calculate weighted average\n", 114 | " scores[k]=weight*prior+(1-weight)*vi\n", 115 | " # select child node based on the weighted average \n", 116 | " return max(scores,key=scores.get) \n", 117 | "```" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "id": "74702e59", 123 | "metadata": {}, 124 | "source": [ 125 | "```python\n", 126 | "# expand the game tree by taking a hypothetical move\n", 127 | "def expand(env,move):\n", 128 | " env_copy=deepcopy(env)\n", 129 | " state,reward,done,info=env_copy.step(move)\n", 130 | " return env_copy, done, reward\n", 131 | "```" 132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "id": "e8c4f9e0", 137 | "metadata": {}, 138 | "source": [ 139 | "## 2.2. 
Roll Out A Game and Backpropagate" 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "id": "c33efb1c", 145 | "metadata": {}, 146 | "source": [ 147 | "```python\n", 148 | "# roll out a game till terminal state or depth reached\n", 149 | "def simulate(env_copy,done,reward,depth):\n", 150 | " # if the game has already ended\n", 151 | " if done==True:\n", 152 | " return reward\n", 153 | " # select moves based on fast policy network\n", 154 | " for _ in range(depth):\n", 155 | " move=best_move_fast_net(env_copy)\n", 156 | " state,reward,done,info=env_copy.step(move)\n", 157 | " # if terminal state is reached, returns outcome\n", 158 | " if done==True:\n", 159 | " return reward\n", 160 | " # depth reached but game not over, evaluate\n", 161 | " onehot_state=onehot_encoder(state)\n", 162 | " # use the trained value network to evaluate\n", 163 | " ps=value_net.predict(onehot_state,verbose=0)\n", 164 | " # output is prob(1 wins)-prob(2 wins)\n", 165 | " reward=ps[0][1]-ps[0][0] \n", 166 | " return reward\n", 167 | "```" 168 | ] 169 | }, 170 | { 171 | "cell_type": "markdown", 172 | "id": "dff8ab61", 173 | "metadata": {}, 174 | "source": [ 175 | "```python\n", 176 | "def backpropagate(env,move,reward,results):\n", 177 | " # if current player is Player 1, update results\n", 178 | " if env.turn==1:\n", 179 | " results[move].append(reward)\n", 180 | " # if current player is Player 2, multiply outcome with -1\n", 181 | " elif env.turn==2:\n", 182 | " results[move].append(-reward) \n", 183 | " return results\n", 184 | "```" 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "id": "88f81d44", 190 | "metadata": {}, 191 | "source": [ 192 | "## 2.3 Create An AlphaGo Agent in the Coin Game\n" 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "id": "fb5f027c", 198 | "metadata": {}, 199 | "source": [ 200 | "```python\n", 201 | "def alphago_coin(env,weight,depth,num_rollouts=100):\n", 202 | " # if there is only one valid move left, take it\n", 203 | " if len(env.validinputs)==1:\n", 204 | " return env.validinputs[0]\n", 205 | " # get the prior from the PG policy network\n", 206 | " priors = PG_net(onehot_encoder(env.state)) \n", 207 | " # create a dictionary results\n", 208 | " results={}\n", 209 | " for move in env.validinputs:\n", 210 | " results[move]=[]\n", 211 | " # roll out games\n", 212 | " for _ in range(num_rollouts):\n", 213 | " # select\n", 214 | " move=select(priors,env,results,weight)\n", 215 | " # expand\n", 216 | " env_copy, done, reward=expand(env,move)\n", 217 | " # simulate\n", 218 | " reward=simulate(env_copy,done,reward,depth)\n", 219 | " # backpropagate\n", 220 | " results=backpropagate(env,move,reward,results)\n", 221 | " # select the most visited child node\n", 222 | " visits={k:len(v) for k,v in results.items()}\n", 223 | " return max(visits,key=visits.get)\n", 224 | "```" 225 | ] 226 | }, 227 | { 228 | "cell_type": "markdown", 229 | "id": "202a5d06", 230 | "metadata": {}, 231 | "source": [ 232 | "# 3. Test the AlphaGo Algorithm in the Coin Game\n" 233 | ] 234 | }, 235 | { 236 | "cell_type": "markdown", 237 | "id": "b57ebd81", 238 | "metadata": {}, 239 | "source": [ 240 | "## 3.1. 
When the AlphaGo Agent Moves Second" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 1, 246 | "id": "3ab56e2c", 247 | "metadata": {}, 248 | "outputs": [], 249 | "source": [ 250 | "import random\n", 251 | "\n", 252 | "def rule_based_AI(env):\n", 253 | " if env.state%3 != 0:\n", 254 | " move = env.state%3\n", 255 | " else:\n", 256 | " move = random.choice([1,2])\n", 257 | " return move " 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 2, 263 | "id": "f806366f", 264 | "metadata": {}, 265 | "outputs": [ 266 | { 267 | "name": "stdout", 268 | "output_type": "stream", 269 | "text": [ 270 | "WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually.\n", 271 | "AlphaGo wins!\n", 272 | "AlphaGo wins!\n", 273 | "AlphaGo wins!\n", 274 | "AlphaGo wins!\n", 275 | "AlphaGo wins!\n", 276 | "AlphaGo wins!\n", 277 | "AlphaGo wins!\n", 278 | "AlphaGo wins!\n", 279 | "AlphaGo wins!\n", 280 | "AlphaGo wins!\n" 281 | ] 282 | } 283 | ], 284 | "source": [ 285 | "from utils.coin_simple_env import coin_game\n", 286 | "from utils.ch16util import alphago_coin\n", 287 | "\n", 288 | "# initiate game environment\n", 289 | "env=coin_game()\n", 290 | "# test ten games\n", 291 | "for i in range(10):\n", 292 | " state=env.reset() \n", 293 | " while True: \n", 294 | " # The rule-based AI player moves first\n", 295 | " action=rule_based_AI(env)\n", 296 | " state,reward,done,_=env.step(action) \n", 297 | " if done:\n", 298 | " # print out the winner\n", 299 | " print(\"Rule-based AI wins!\")\n", 300 | " break \n", 301 | " # The AlphaGo agent moves second\n", 302 | " action=alphago_coin(env,0.9,10,num_rollouts=100) \n", 303 | " state,reward,done,_=env.step(action)\n", 304 | " if done:\n", 305 | " # print out the winner\n", 306 | " print(\"AlphaGo wins!\")\n", 307 | " break" 308 | ] 309 | }, 310 | { 311 | "cell_type": "markdown", 312 | "id": "13c2649a", 313 | "metadata": {}, 314 | "source": [ 315 | "## 3.2. Against Random Moves\n" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": 3, 321 | "id": "13cfa062", 322 | "metadata": {}, 323 | "outputs": [ 324 | { 325 | "name": "stdout", 326 | "output_type": "stream", 327 | "text": [ 328 | "AlphaGo wins!\n", 329 | "AlphaGo wins!\n", 330 | "AlphaGo wins!\n", 331 | "AlphaGo wins!\n", 332 | "AlphaGo wins!\n", 333 | "AlphaGo wins!\n", 334 | "AlphaGo wins!\n", 335 | "AlphaGo wins!\n", 336 | "AlphaGo wins!\n", 337 | "AlphaGo wins!\n" 338 | ] 339 | } 340 | ], 341 | "source": [ 342 | "# test ten games against random moves\n", 343 | "for i in range(10):\n", 344 | " state=env.reset() \n", 345 | " while True: \n", 346 | " # The AlphaGo agent moves first\n", 347 | " action=alphago_coin(env,0.9,10,num_rollouts=100) \n", 348 | " state,reward,done,_=env.step(action)\n", 349 | " if done:\n", 350 | " print(\"AlphaGo wins!\")\n", 351 | " break\n", 352 | " # The random player moves second\n", 353 | " action=random.choice(env.validinputs)\n", 354 | " state,reward,done,_=env.step(action) \n", 355 | " if done:\n", 356 | " print(\"AlphaGo loses!\")\n", 357 | " break" 358 | ] 359 | }, 360 | { 361 | "cell_type": "markdown", 362 | "id": "e53dc8fc", 363 | "metadata": {}, 364 | "source": [ 365 | "# 4. 
Redundancy in the AlphaGo Algorithm\n" 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": 4, 371 | "id": "91bdca04", 372 | "metadata": {}, 373 | "outputs": [ 374 | { 375 | "name": "stdout", 376 | "output_type": "stream", 377 | "text": [ 378 | "Without value network, AlphaGo wins!\n", 379 | "Without value network, AlphaGo wins!\n", 380 | "Without value network, AlphaGo wins!\n", 381 | "Without value network, AlphaGo wins!\n", 382 | "Without value network, AlphaGo wins!\n", 383 | "Without value network, AlphaGo wins!\n", 384 | "Without value network, AlphaGo wins!\n", 385 | "Without value network, AlphaGo wins!\n", 386 | "Without value network, AlphaGo wins!\n", 387 | "Without value network, AlphaGo wins!\n" 388 | ] 389 | } 390 | ], 391 | "source": [ 392 | "from utils.coin_simple_env import coin_game\n", 393 | "from utils.ch16util import alphago_coin\n", 394 | "\n", 395 | "# initiate game environment\n", 396 | "env=coin_game()\n", 397 | "# test ten games\n", 398 | "for i in range(10):\n", 399 | " state=env.reset() \n", 400 | " while True: \n", 401 | " # The rule-based AI player moves first\n", 402 | " action=rule_based_AI(env)\n", 403 | " state,reward,done,_=env.step(action) \n", 404 | " if done:\n", 405 | " # print out the winner\n", 406 | " print(\"Without value network, rule-based AI wins!\")\n", 407 | " break \n", 408 | " # The AlphaGo agent moves second\n", 409 | " action=alphago_coin(env,0.8,22,num_rollouts=100) \n", 410 | " state,reward,done,_=env.step(action)\n", 411 | " if done:\n", 412 | " # print out the winner\n", 413 | " print(\"Without value network, AlphaGo wins!\")\n", 414 | " break" 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "execution_count": 5, 420 | "id": "e5d97e1a", 421 | "metadata": {}, 422 | "outputs": [ 423 | { 424 | "name": "stdout", 425 | "output_type": "stream", 426 | "text": [ 427 | "Without value network, AlphaGo wins!\n", 428 | "Without value network, AlphaGo wins!\n", 429 | "Without value network, AlphaGo wins!\n", 430 | "Without value network, AlphaGo wins!\n", 431 | "Without value network, AlphaGo wins!\n", 432 | "Without value network, AlphaGo wins!\n", 433 | "Without value network, AlphaGo wins!\n", 434 | "Without value network, AlphaGo wins!\n", 435 | "Without value network, AlphaGo wins!\n", 436 | "Without value network, AlphaGo wins!\n" 437 | ] 438 | } 439 | ], 440 | "source": [ 441 | "# test ten games against random moves\n", 442 | "for i in range(10):\n", 443 | " state=env.reset() \n", 444 | " while True: \n", 445 | " # The AlphaGo agent moves first\n", 446 | " action=alphago_coin(env,0.8,22,num_rollouts=100) \n", 447 | " state,reward,done,_=env.step(action)\n", 448 | " if done:\n", 449 | " print(\"Without value network, AlphaGo wins!\")\n", 450 | " break\n", 451 | " # The random player moves second\n", 452 | " action=random.choice(env.validinputs)\n", 453 | " state,reward,done,_=env.step(action) \n", 454 | " if done:\n", 455 | " print(\"Without value network, AlphaGo loses!\")\n", 456 | " break" 457 | ] 458 | } 459 | ], 460 | "metadata": { 461 | "kernelspec": { 462 | "display_name": "Python 3 (ipykernel)", 463 | "language": "python", 464 | "name": "python3" 465 | }, 466 | "language_info": { 467 | "codemirror_mode": { 468 | "name": "ipython", 469 | "version": 3 470 | }, 471 | "file_extension": ".py", 472 | "mimetype": "text/x-python", 473 | "name": "python", 474 | "nbconvert_exporter": "python", 475 | "pygments_lexer": "ipython3", 476 | "version": "3.9.12" 477 | } 478 | }, 479 | "nbformat": 4, 480 | 
"nbformat_minor": 5 481 | } 482 | -------------------------------------------------------------------------------- /ch17AlphaGoTicTacToe.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "de3a7f85", 6 | "metadata": {}, 7 | "source": [ 8 | "# Chaper 17: AlphaGo in Tic Tac Toe and Connect Four\n", 9 | "\n", 10 | "\n", 11 | "\n", 12 | "***\n", 13 | "***“We also introduce a new search algorithm that combines Monte Carlo simulation\n", 14 | "with value and policy networks. Using this search algorithm, our program AlphaGo\n", 15 | "achieved a 99.8% winning rate against other Go programs.”***\n", 16 | "\n", 17 | "-- Google DeepMind Team, Nature (2016) \n", 18 | "***\n", 19 | "\n", 20 | "\n", 21 | "\n", 22 | "\n", 23 | "In Chapter 16, you have learned the basic architecture of the AlphaGo algorithm, which combines Monte Carlo tree search with deep neural networks, as outlined in the open quote in an article published in the journal Nature in 2016. Specifically, you have implemented a basic version of the AlphaGo algorithm for the coin game, combining deep reinforcement learning with conventional rule-based AI.\n", 24 | "\n", 25 | "In this chapter, you’ll expand this approach, creating an AlphaGo agent adaptable to various games.\n", 26 | "The AlphaGo agent you create features two main enhancements. Firstly, it will be versatile, capable of handling two games, Tic Tac Toe and Connect Four. This flexibility reduces code redundancy and simplifies the application of the AlphaGo algorithm to a broader range of games. Secondly, the agent’s game simulation strategy includes a choice between random moves and those suggested by the fast policy network. This decision involves a trade-off: the network’s moves provide more insight but require\n", 27 | "more processing time due to the complex neural network. In contrast, random moves accelerate game simulations, enabling more game rollouts within a given time frame and potentially smarter move selection in actual gameplay.\n", 28 | "\n", 29 | "Monte Carlo Tree Search (MCTS) remains the core of the agent’s decision process, involving selection, expansion, simulation, and backpropagation, as outlined in Chapter 8. However, now three deep neural networks, introduced in previous chapters, will enhance the tree search. During real games, a large number of simulations will start from the current game state. Each simulation involves choosing a child node for expansion, not based on the upper confidence bounds applied to trees (UCT) formula\n", 30 | "from Chapter 8, but on a weighted average of each node’s rollout value and its prior probability from the trained policy-gradient network. Additionally, players can choose between random moves or those from the fast policy network for rollouts. Instead of playing to a terminal state, the game state will be evaluated at a certain depth using the trained value networks, allowing more simulations within the time limit.\n", 31 | "\n", 32 | "You’ll evaluate the AlphaGo algorithm’s effectiveness in Tic Tac Toe. Against the perfect rule-based AI from Chapter 6, the AlphaGo agent consistently draws, indicating its ability to solve the game.\n", 33 | "Given Tic Tac Toe’s simplicity compared to Chess or Go, we’ll also explore an AlphaGo version without the value network, rolling out games to their end. Another variant will use random moves instead of those from the fast policy network. 
Both versions will be shown to effectively solve Tic Tac Toe, setting the stage for Chapter 20, where we implement AlphaZero in Tic Tac Toe, omitting the value network and relying solely on one policy network." 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "id": "39a933d3", 39 | "metadata": {}, 40 | "source": [ 41 | "# 1. An AlphaGo Algorithm for Multiple Games\n" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "id": "d44908b0", 47 | "metadata": {}, 48 | "source": [ 49 | "## 1.1. Functions to Select and Expand\n", 50 | "In the local module *ch17util*, we first define a *select()* function to select a child node to expand the game tree, as follows:" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "id": "063509d0", 56 | "metadata": {}, 57 | "source": [ 58 | "```python\n", 59 | "def select(priors,env,results,weight): \n", 60 | " # weighted average of priors and rollout_value\n", 61 | " scores={}\n", 62 | " for k,v in results.items():\n", 63 | " # rollout_value for each next move\n", 64 | " if len(v)==0:\n", 65 | " vi=0\n", 66 | " else:\n", 67 | " vi=sum(v)/len(v)\n", 68 | " # scale the prior by (1+N(L))\n", 69 | " prior=priors[0][k-1]/(1+len(v))\n", 70 | " # calculate weighted average\n", 71 | " scores[k]=weight*prior+(1-weight)*vi\n", 72 | " # select child node based on the weighted average \n", 73 | " return max(scores,key=scores.get) \n", 74 | "```" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "id": "74702e59", 80 | "metadata": {}, 81 | "source": [ 82 | "```python\n", 83 | "# expand the game tree by taking a hypothetical move\n", 84 | "def expand(env,move):\n", 85 | " env_copy=deepcopy(env)\n", 86 | " state,reward,done,info=env_copy.step(move)\n", 87 | " return env_copy, done, reward\n", 88 | "```" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "id": "465fac83", 94 | "metadata": {}, 95 | "source": [ 96 | "## 1.2 Functions to simulate and backpropagate\n" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "id": "598b3ade", 102 | "metadata": {}, 103 | "source": [ 104 | "```python\n", 105 | "def best_move_fast_net(env, fast_net):\n", 106 | " # priors from the policy network\n", 107 | " if env.turn==\"X\":\n", 108 | " state = env.state.reshape(-1,3,3,1)\n", 109 | " action_probs=fast_net(state)\n", 110 | " elif env.turn==\"O\":\n", 111 | " state = env.state.reshape(-1,3,3,1)\n", 112 | " action_probs=fast_net(-state)\n", 113 | " elif env.turn==\"red\":\n", 114 | " state = env.state.reshape(-1,7,6,1)\n", 115 | " action_probs=fast_net(state)\n", 116 | " elif env.turn==\"yellow\":\n", 117 | " state = env.state.reshape(-1,7,6,1)\n", 118 | " action_probs=fast_net(-state) \n", 119 | " ps=[]\n", 120 | " for a in sorted(env.validinputs):\n", 121 | " ps.append(np.squeeze(action_probs)[a-1])\n", 122 | " ps=np.array(ps)\n", 123 | " return np.random.choice(sorted(env.validinputs),\n", 124 | " p=ps/ps.sum())\n", 125 | "```" 126 | ] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "id": "32135136", 131 | "metadata": {}, 132 | "source": [ 133 | "```python\n", 134 | "# roll out a game till terminal state or depth reached\n", 135 | "def simulate(env_copy,done,reward,depth,value_net,\n", 136 | " fast_net, policy_rollout=True):\n", 137 | " # if the game has already ended\n", 138 | " if done==True:\n", 139 | " return reward\n", 140 | " # select moves based on fast policy network\n", 141 | " for _ in range(depth):\n", 142 | " if policy_rollout:\n", 143 | " move=best_move_fast_net(env_copy, fast_net)\n", 144 | " else:\n", 145 | " 
move=env_copy.sample()\n", 146 | " state,reward,done,info=env_copy.step(move)\n", 147 | " # if terminal state is reached, returns outcome\n", 148 | " if done==True:\n", 149 | " return reward\n", 150 | " # depth reached but game not over, evaluate\n", 151 | " if env_copy.turn==\"X\":\n", 152 | " state=state.reshape(-1,3,3,1)\n", 153 | " ps=value_net.predict(state,verbose=0)\n", 154 | " # reward is prob(X win) - prob(O win)\n", 155 | " reward=ps[0][1]-ps[0][2] \n", 156 | " elif env_copy.turn==\"O\":\n", 157 | " state=state.reshape(-1,3,3,1)\n", 158 | " ps=value_net.predict(-state,verbose=0)\n", 159 | " # reward is prob(X win) - prob(O win)\n", 160 | " reward=ps[0][2]-ps[0][1] \n", 161 | " elif env_copy.turn==\"red\":\n", 162 | " state=state.reshape(-1,7,6,1)\n", 163 | " ps=value_net.predict(state,verbose=0)\n", 164 | " # reward is prob(red win) - prob(yellow win)\n", 165 | " reward=ps[0][1]-ps[0][2] \n", 166 | " elif env_copy.turn==\"yellow\":\n", 167 | " state=state.reshape(-1,7,6,1)\n", 168 | " ps=value_net.predict(-state,verbose=0)\n", 169 | " # reward is prob(red win) - prob(yellow win)\n", 170 | " reward=ps[0][2]-ps[0][1] \n", 171 | " return reward\n", 172 | "```" 173 | ] 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "id": "dff8ab61", 178 | "metadata": {}, 179 | "source": [ 180 | "```python\n", 181 | "def backpropagate(env,move,reward,results):\n", 182 | " # update results\n", 183 | " if env.turn==\"X\" or env.turn==\"red\":\n", 184 | " results[move].append(reward)\n", 185 | " # if current player is player 2, multiply outcome with -1\n", 186 | " if env.turn==\"O\" or env.turn==\"yellow\":\n", 187 | " results[move].append(-reward) \n", 188 | " return results\n", 189 | "```" 190 | ] 191 | }, 192 | { 193 | "cell_type": "markdown", 194 | "id": "88f81d44", 195 | "metadata": {}, 196 | "source": [ 197 | "## 1.3 An AlphaGo Agent for Tic Tac Toe and Connect Four\n" 198 | ] 199 | }, 200 | { 201 | "cell_type": "markdown", 202 | "id": "fb5f027c", 203 | "metadata": {}, 204 | "source": [ 205 | "```python\n", 206 | "def alphago(env,weight,depth,PG_net,value_net,\n", 207 | " fast_net, policy_rollout=True,num_rollouts=100):\n", 208 | " # if there is only one valid move left, take it\n", 209 | " if len(env.validinputs)==1:\n", 210 | " return env.validinputs[0]\n", 211 | " # get the prior from the PG policy network\n", 212 | " if env.turn==\"X\" or env.turn==\"O\":\n", 213 | " state = env.state.reshape(-1,9)\n", 214 | " conv_state = state.reshape(-1,3,3,1)\n", 215 | " if env.turn==\"X\":\n", 216 | " priors = PG_net([state,conv_state])\n", 217 | " elif env.turn==\"O\":\n", 218 | " priors = PG_net([-state,-conv_state]) \n", 219 | " if env.turn==\"red\" or env.turn==\"yellow\":\n", 220 | " state = env.state.reshape(-1,42)\n", 221 | " conv_state = state.reshape(-1,7,6,1)\n", 222 | " if env.turn==\"red\":\n", 223 | " priors = PG_net([state,conv_state])\n", 224 | " elif env.turn==\"yellow\":\n", 225 | " priors = PG_net([-state,-conv_state]) \n", 226 | " # create a dictionary results\n", 227 | " results={}\n", 228 | " for move in env.validinputs:\n", 229 | " results[move]=[]\n", 230 | " # roll out games\n", 231 | " for _ in range(num_rollouts):\n", 232 | " # select\n", 233 | " move=select(priors,env,results,weight)\n", 234 | " # expand\n", 235 | " env_copy, done, reward=expand(env,move)\n", 236 | " # simulate\n", 237 | " reward=simulate(env_copy,done,reward,depth,value_net,\n", 238 | " fast_net, policy_rollout)\n", 239 | " # backpropagate\n", 240 | " results=backpropagate(env,move,reward,results)\n", 241 
| " # select the most visited child node\n", 242 | " visits={k:len(v) for k,v in results.items()}\n", 243 | " return max(visits,key=visits.get)\n", 244 | "```" 245 | ] 246 | }, 247 | { 248 | "cell_type": "markdown", 249 | "id": "837d70a2", 250 | "metadata": {}, 251 | "source": [ 252 | "# 2. Test the AlphaGo Agent in Tic Tac Toe" 253 | ] 254 | }, 255 | { 256 | "cell_type": "markdown", 257 | "id": "b57ebd81", 258 | "metadata": {}, 259 | "source": [ 260 | "## 2.1. The Opponent in Tic Tac Toe Games" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": 1, 266 | "id": "3ab56e2c", 267 | "metadata": {}, 268 | "outputs": [], 269 | "source": [ 270 | "from utils.ch06util import MiniMax_ab\n", 271 | "\n", 272 | "def rule_based_AI(env):\n", 273 | " move = MiniMax_ab(env)\n", 274 | " return move " 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": 2, 280 | "id": "39fab8d6", 281 | "metadata": {}, 282 | "outputs": [ 283 | { 284 | "name": "stdout", 285 | "output_type": "stream", 286 | "text": [ 287 | "WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually.\n" 288 | ] 289 | } 290 | ], 291 | "source": [ 292 | "from copy import deepcopy\n", 293 | "import numpy as np\n", 294 | "from tensorflow import keras\n", 295 | "\n", 296 | "# Load the trained fast policy network from Chapter 10\n", 297 | "fast_net=keras.models.load_model(\"files/fast_ttt.h5\")\n", 298 | "# Load the policy gradient network from Chapter 14\n", 299 | "PG_net=keras.models.load_model(\"files/pg_ttt.h5\")\n", 300 | "# Load the trained value network from Chapter 14\n", 301 | "value_net=keras.models.load_model(\"files/value_ttt.h5\")" 302 | ] 303 | }, 304 | { 305 | "cell_type": "markdown", 306 | "id": "5ef56cd1", 307 | "metadata": {}, 308 | "source": [ 309 | "## 2.2. 
AlphaGo vs Rule-Based AI in Tic Tac Toe \n" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": 3, 315 | "id": "f806366f", 316 | "metadata": {}, 317 | "outputs": [ 318 | { 319 | "name": "stdout", 320 | "output_type": "stream", 321 | "text": [ 322 | "The game is tied!\n", 323 | "The game is tied!\n", 324 | "The game is tied!\n", 325 | "The game is tied!\n", 326 | "The game is tied!\n", 327 | "The game is tied!\n", 328 | "The game is tied!\n", 329 | "The game is tied!\n", 330 | "The game is tied!\n", 331 | "The game is tied!\n" 332 | ] 333 | } 334 | ], 335 | "source": [ 336 | "from utils.ttt_simple_env import ttt\n", 337 | "from utils.ch17util import alphago\n", 338 | "\n", 339 | "weight=0.75\n", 340 | "depth=5\n", 341 | "# initiate game environment\n", 342 | "env=ttt()\n", 343 | "# test ten games\n", 344 | "for i in range(10):\n", 345 | " state=env.reset() \n", 346 | " while True: \n", 347 | " # AlphaGo moves first\n", 348 | " action=alphago(env,weight,depth,PG_net,value_net,\n", 349 | " fast_net, policy_rollout=True,num_rollouts=100)\n", 350 | " state,reward,done,_=env.step(action) \n", 351 | " if done:\n", 352 | " if reward==0:\n", 353 | " print(\"The game is tied!\")\n", 354 | " else:\n", 355 | " print(\"AlphaGo wins!\")\n", 356 | " break \n", 357 | " # move recommended by rule-based AI\n", 358 | " action=rule_based_AI(env) \n", 359 | " state,reward,done,_=env.step(action)\n", 360 | " if done:\n", 361 | " print(\"Rule-based AI wins!\")\n", 362 | " break " 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": 4, 368 | "id": "13cfa062", 369 | "metadata": {}, 370 | "outputs": [ 371 | { 372 | "name": "stdout", 373 | "output_type": "stream", 374 | "text": [ 375 | "The game is tied!\n", 376 | "The game is tied!\n", 377 | "The game is tied!\n", 378 | "The game is tied!\n", 379 | "The game is tied!\n", 380 | "The game is tied!\n", 381 | "The game is tied!\n", 382 | "The game is tied!\n", 383 | "The game is tied!\n", 384 | "The game is tied!\n" 385 | ] 386 | } 387 | ], 388 | "source": [ 389 | "# test ten games\n", 390 | "for i in range(10):\n", 391 | " state=env.reset() \n", 392 | " while True: \n", 393 | " # Rule-based AI moves first\n", 394 | " action=rule_based_AI(env)\n", 395 | " state,reward,done,_=env.step(action) \n", 396 | " if done:\n", 397 | " if reward==0:\n", 398 | " print(\"The game is tied!\")\n", 399 | " else:\n", 400 | " print(\"Rule-based AI wins!\")\n", 401 | " break \n", 402 | " # AlphaGo moves second\n", 403 | " action=alphago(env,weight,depth,PG_net,value_net,\n", 404 | " fast_net, policy_rollout=True,num_rollouts=100) \n", 405 | " state,reward,done,_=env.step(action)\n", 406 | " if done:\n", 407 | " print(\"AlphaGo wins!\")\n", 408 | " break " 409 | ] 410 | }, 411 | { 412 | "cell_type": "markdown", 413 | "id": "e53dc8fc", 414 | "metadata": {}, 415 | "source": [ 416 | "# 3. 
Redundancy in AlphaGo\n" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": 5, 422 | "id": "91bdca04", 423 | "metadata": {}, 424 | "outputs": [], 425 | "source": [ 426 | "weight=0.8\n", 427 | "depth=10\n", 428 | "# create a list to record game outcome\n", 429 | "results=[]\n", 430 | "# test ten games\n", 431 | "for i in range(20):\n", 432 | " state=env.reset() \n", 433 | " if i%2==0:\n", 434 | " # Ruled-based AI moves\n", 435 | " action=rule_based_AI(env)\n", 436 | " state,reward,done,_=env.step(action) \n", 437 | " while True: \n", 438 | " # AlphaGo moves \n", 439 | " action=alphago(env,weight,depth,PG_net,value_net,\n", 440 | " fast_net, policy_rollout=True,num_rollouts=100) \n", 441 | " state,reward,done,_=env.step(action)\n", 442 | " if done:\n", 443 | " results.append(abs(reward))\n", 444 | " break \n", 445 | " # Rule-based AI moves\n", 446 | " action=rule_based_AI(env)\n", 447 | " state,reward,done,_=env.step(action) \n", 448 | " if done:\n", 449 | " results.append(-abs(reward))\n", 450 | " break " 451 | ] 452 | }, 453 | { 454 | "cell_type": "code", 455 | "execution_count": 6, 456 | "id": "8974da15", 457 | "metadata": {}, 458 | "outputs": [ 459 | { 460 | "name": "stdout", 461 | "output_type": "stream", 462 | "text": [ 463 | "AlphaGo won 0 games\n", 464 | "AlphaGo lost 0 games\n", 465 | "the game was tied 20 times\n" 466 | ] 467 | } 468 | ], 469 | "source": [ 470 | "# count how many times AlphaGo won\n", 471 | "wins=results.count(1)\n", 472 | "print(f\"AlphaGo won {wins} games\")\n", 473 | "# count how many times AlphaGo lost\n", 474 | "losses=results.count(-1)\n", 475 | "print(f\"AlphaGo lost {losses} games\")\n", 476 | "# count tie games\n", 477 | "ties=results.count(0)\n", 478 | "print(f\"the game was tied {ties} times\")" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": 7, 484 | "id": "e5d97e1a", 485 | "metadata": {}, 486 | "outputs": [], 487 | "source": [ 488 | "weight=0.9\n", 489 | "depth=10\n", 490 | "# create a list to record game outcome\n", 491 | "results=[]\n", 492 | "# test ten games\n", 493 | "for i in range(20):\n", 494 | " state=env.reset() \n", 495 | " if i%2==0:\n", 496 | " # Ruled-based AI moves\n", 497 | " action=rule_based_AI(env)\n", 498 | " state,reward,done,_=env.step(action) \n", 499 | " while True: \n", 500 | " # AlphaGo moves, setting policy_rollout=False \n", 501 | " action=alphago(env,weight,depth,PG_net,value_net,\n", 502 | " fast_net, policy_rollout=False,num_rollouts=100) \n", 503 | " state,reward,done,_=env.step(action)\n", 504 | " if done:\n", 505 | " results.append(abs(reward))\n", 506 | " break \n", 507 | " # Rule-based AI moves\n", 508 | " action=rule_based_AI(env)\n", 509 | " state,reward,done,_=env.step(action) \n", 510 | " if done:\n", 511 | " results.append(-abs(reward))\n", 512 | " break " 513 | ] 514 | }, 515 | { 516 | "cell_type": "code", 517 | "execution_count": 8, 518 | "id": "397463d5", 519 | "metadata": {}, 520 | "outputs": [ 521 | { 522 | "name": "stdout", 523 | "output_type": "stream", 524 | "text": [ 525 | "AlphaGo won 0 games\n", 526 | "AlphaGo lost 0 games\n", 527 | "the game was tied 20 times\n" 528 | ] 529 | } 530 | ], 531 | "source": [ 532 | "# count how many times AlphaGo won\n", 533 | "wins=results.count(1)\n", 534 | "print(f\"AlphaGo won {wins} games\")\n", 535 | "# count how many times AlphaGo lost\n", 536 | "losses=results.count(-1)\n", 537 | "print(f\"AlphaGo lost {losses} games\")\n", 538 | "# count tie games\n", 539 | "ties=results.count(0)\n", 540 | "print(f\"the game 
was tied {ties} times\")" 541 | ] 542 | } 543 | ], 544 | "metadata": { 545 | "kernelspec": { 546 | "display_name": "Python 3 (ipykernel)", 547 | "language": "python", 548 | "name": "python3" 549 | }, 550 | "language_info": { 551 | "codemirror_mode": { 552 | "name": "ipython", 553 | "version": 3 554 | }, 555 | "file_extension": ".py", 556 | "mimetype": "text/x-python", 557 | "name": "python", 558 | "nbconvert_exporter": "python", 559 | "pygments_lexer": "ipython3", 560 | "version": "3.9.12" 561 | } 562 | }, 563 | "nbformat": 4, 564 | "nbformat_minor": 5 565 | } 566 | -------------------------------------------------------------------------------- /files/CONNzero4.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markhliu/AlphaGoSimplified/9dee6f43230a2dcd324fc700a85fd69fcfe76082/files/CONNzero4.h5 -------------------------------------------------------------------------------- /files/PG_coin.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markhliu/AlphaGoSimplified/9dee6f43230a2dcd324fc700a85fd69fcfe76082/files/PG_coin.h5 -------------------------------------------------------------------------------- /files/PG_conn.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markhliu/AlphaGoSimplified/9dee6f43230a2dcd324fc700a85fd69fcfe76082/files/PG_conn.h5 -------------------------------------------------------------------------------- /files/ac_coin.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markhliu/AlphaGoSimplified/9dee6f43230a2dcd324fc700a85fd69fcfe76082/files/ac_coin.h5 -------------------------------------------------------------------------------- /files/coin_Qs.csv: -------------------------------------------------------------------------------- 1 | 0.000000000000000000e+00,0.000000000000000000e+00 2 | 1.000000000000000000e+00,1.000000000000000000e+00 3 | -1.000000000000000000e+00,1.000000000000000000e+00 4 | -1.000000000000000000e+00,-1.000000000000000000e+00 5 | 9.499999999999999556e-01,-1.000000000000000000e+00 6 | -8.803999999999999604e-01,9.499999999999999556e-01 7 | -9.499999999999999556e-01,-9.499999999999999556e-01 8 | 9.024999999999999689e-01,-9.210000000000000409e-01 9 | -8.678000000000000158e-01,9.024999999999999689e-01 10 | -9.024999999999999689e-01,-9.024999999999999689e-01 11 | 8.572999999999999510e-01,-8.944999999999999618e-01 12 | -8.367999999999999883e-01,8.572999999999999510e-01 13 | -8.574000000000000510e-01,-8.574000000000000510e-01 14 | 8.144000000000000128e-01,-8.540999999999999703e-01 15 | -8.104000000000000092e-01,8.144000000000000128e-01 16 | -8.145000000000000018e-01,-8.145000000000000018e-01 17 | 7.737000000000000544e-01,-8.142000000000000348e-01 18 | -7.728000000000000425e-01,7.737000000000000544e-01 19 | -7.700000000000000178e-01,-7.701000000000000068e-01 20 | 7.349999999999999867e-01,-7.737000000000000544e-01 21 | -7.278000000000000025e-01,7.349999999999999867e-01 22 | 0.000000000000000000e+00,0.000000000000000000e+00 23 | -------------------------------------------------------------------------------- /files/fast_coin.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markhliu/AlphaGoSimplified/9dee6f43230a2dcd324fc700a85fd69fcfe76082/files/fast_coin.h5 
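A note on files/coin_Qs.csv above: the 22-row, two-column table stores the tabular Q-values learned in ch12TabularQ-LearningCoinGame.ipynb. The sketch below shows one way such a table could be loaded and used as a greedy player; the row/column interpretation (row i holds the Q-values of the state with i coins left, for removing one or two coins) and the helper name greedy_coin_move are assumptions for illustration, not code from the book.

import numpy as np
from utils.coin_simple_env import coin_game

# read the 22x2 Q-table; row i is assumed to correspond to the state
# "i coins left", and the two columns to the actions "take 1"/"take 2"
Qs = np.genfromtxt("files/coin_Qs.csv", delimiter=",")

def greedy_coin_move(env):
    # hypothetical helper: among the currently valid moves,
    # pick the one with the highest tabular Q-value
    return max(env.validinputs, key=lambda a: Qs[env.state][a - 1])

# quick usage check from the opening position (21 coins left)
env = coin_game()
env.reset()
print(greedy_coin_move(env))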
-------------------------------------------------------------------------------- /files/fast_ttt.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markhliu/AlphaGoSimplified/9dee6f43230a2dcd324fc700a85fd69fcfe76082/files/fast_ttt.h5 -------------------------------------------------------------------------------- /files/pg_ttt.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markhliu/AlphaGoSimplified/9dee6f43230a2dcd324fc700a85fd69fcfe76082/files/pg_ttt.h5 -------------------------------------------------------------------------------- /files/policy_conn.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markhliu/AlphaGoSimplified/9dee6f43230a2dcd324fc700a85fd69fcfe76082/files/policy_conn.h5 -------------------------------------------------------------------------------- /files/strong_coin.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markhliu/AlphaGoSimplified/9dee6f43230a2dcd324fc700a85fd69fcfe76082/files/strong_coin.h5 -------------------------------------------------------------------------------- /files/strong_ttt.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markhliu/AlphaGoSimplified/9dee6f43230a2dcd324fc700a85fd69fcfe76082/files/strong_ttt.h5 -------------------------------------------------------------------------------- /files/value_coin.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markhliu/AlphaGoSimplified/9dee6f43230a2dcd324fc700a85fd69fcfe76082/files/value_coin.h5 -------------------------------------------------------------------------------- /files/value_conn.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markhliu/AlphaGoSimplified/9dee6f43230a2dcd324fc700a85fd69fcfe76082/files/value_conn.h5 -------------------------------------------------------------------------------- /files/value_ttt.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markhliu/AlphaGoSimplified/9dee6f43230a2dcd324fc700a85fd69fcfe76082/files/value_ttt.h5 -------------------------------------------------------------------------------- /files/zero_ttt4.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markhliu/AlphaGoSimplified/9dee6f43230a2dcd324fc700a85fd69fcfe76082/files/zero_ttt4.h5 -------------------------------------------------------------------------------- /tmp.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "71d53d56", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [] 10 | } 11 | ], 12 | "metadata": { 13 | "kernelspec": { 14 | "display_name": "Python 3 (ipykernel)", 15 | "language": "python", 16 | "name": "python3" 17 | }, 18 | "language_info": { 19 | "codemirror_mode": { 20 | "name": "ipython", 21 | "version": 3 22 | }, 23 | "file_extension": ".py", 24 | "mimetype": "text/x-python", 25 | "name": "python", 26 | "nbconvert_exporter": "python", 27 | "pygments_lexer": "ipython3", 28 | "version": "3.11.3" 29 | } 30 | }, 31 | "nbformat": 4, 32 
| "nbformat_minor": 5 33 | } 34 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markhliu/AlphaGoSimplified/9dee6f43230a2dcd324fc700a85fd69fcfe76082/utils/__init__.py -------------------------------------------------------------------------------- /utils/cash.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markhliu/AlphaGoSimplified/9dee6f43230a2dcd324fc700a85fd69fcfe76082/utils/cash.png -------------------------------------------------------------------------------- /utils/ch01util.py: -------------------------------------------------------------------------------- 1 | import random 2 | from utils.coin_env import coin_game 3 | 4 | 5 | def rule_based_AI(env): 6 | if env.state%3 != 0: 7 | move = env.state%3 8 | else: 9 | move = env.sample() 10 | return move 11 | 12 | 13 | def random_player(env): 14 | move = env.sample() 15 | return move 16 | 17 | 18 | # Define the one_coin_game() function 19 | def one_coin_game(player1, player2): 20 | env = coin_game() 21 | env.reset() 22 | while True: 23 | action = player1(env) 24 | new_state, reward, done, info = env.step(action) 25 | if done: 26 | break 27 | action = player2(env) 28 | new_state, reward, done, info = env.step(action) 29 | if done: 30 | break 31 | return reward -------------------------------------------------------------------------------- /utils/ch02util.py: -------------------------------------------------------------------------------- 1 | from utils.ttt_env import ttt 2 | from copy import deepcopy 3 | 4 | def ttt_think1(env): 5 | # iterate through all possible next moves 6 | for m in env.validinputs: 7 | # make a hypothetical move 8 | env_copy=deepcopy(env) 9 | state,reward,done,_=env_copy.step(m) 10 | # if reward is 1 or -1, current player wins 11 | if done and abs(reward)==1: 12 | # take the winning move 13 | return m 14 | # otherwise, randomly select a move 15 | return env.sample() 16 | 17 | def ttt_random(env): 18 | move = env.sample() 19 | return move 20 | 21 | def ttt_manual(env): 22 | print(f"game state is \n{env.state.reshape(3,3)[::-1]}") 23 | while True: 24 | move = input(f"Player {env.turn}, what's your move?") 25 | try: 26 | move = int(move) 27 | except: 28 | print("the move must be a number") 29 | if move in env.validinputs: 30 | return move 31 | else: 32 | print("please enter a valid move") 33 | 34 | # Define the one_ttt_game() function 35 | def one_ttt_game(player1, player2): 36 | env = ttt() 37 | env.reset() 38 | while True: 39 | action = player1(env) 40 | state, reward, done, _ = env.step(action) 41 | if done: 42 | break 43 | action = player2(env) 44 | state, reward, done, _ = env.step(action) 45 | if done: 46 | break 47 | return reward 48 | 49 | 50 | def ttt_think2(env): 51 | # iterate through all possible next moves 52 | for m in env.validinputs: 53 | # make a hypothetical move 54 | env_copy=deepcopy(env) 55 | state,reward,done,_=env_copy.step(m) 56 | # if reward is 1 or -1, current player wins 57 | if done and abs(reward)==1: 58 | # take the winning move 59 | return m 60 | # otherwise, look two moves ahead 61 | for m1 in env.validinputs: 62 | for m2 in env.validinputs: 63 | if m1!=m2: 64 | env_copy=deepcopy(env) 65 | s,r,done,_=env_copy.step(m1) 66 | s,r,done,_=env_copy.step(m2) 67 | # block opponent's winning move 68 | if done and r!=0: 69 | return m2 70 | # otherwise, return a 
random move 71 | return env.sample() 72 | 73 | 74 | def ttt_think3(env): 75 | # iterate through all possible next moves 76 | for m in env.validinputs: 77 | # make a hypothetical move 78 | env_copy=deepcopy(env) 79 | state,reward,done,_=env_copy.step(m) 80 | # if reward is 1 or -1, current player wins 81 | if done and abs(reward)==1: 82 | # take the winning move 83 | return m 84 | # otherwise, look two moves ahead 85 | for m1 in env.validinputs: 86 | for m2 in env.validinputs: 87 | if m1!=m2: 88 | env_copy=deepcopy(env) 89 | s,r,done,_=env_copy.step(m1) 90 | s,r,done,_=env_copy.step(m2) 91 | # block opponent's winning move 92 | if done and r!=0: 93 | return m2 94 | # look three steps ahead 95 | w3=[] 96 | for m1 in env.validinputs: 97 | for m2 in env.validinputs: 98 | for m3 in env.validinputs: 99 | if m1!=m2 and m1!=m3 and m2!=m3: 100 | env_copy=deepcopy(env) 101 | s,r,done,_=env_copy.step(m1) 102 | s,r,done,_=env_copy.step(m2) 103 | s,r,done,_=env_copy.step(m3) 104 | if done and r!=0: 105 | w3.append(m1) 106 | # Choose the most frequent winner 107 | if len(w3)>0: 108 | return max(w3,key=w3.count) 109 | # Return random move otherwise 110 | return env.sample() 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | -------------------------------------------------------------------------------- /utils/ch03util.py: -------------------------------------------------------------------------------- 1 | from utils.conn_env import conn 2 | from copy import deepcopy 3 | import random 4 | 5 | def conn_think1(env): 6 | # iterate through all possible next moves 7 | for m in env.validinputs: 8 | # make a hypothetical move 9 | env_copy=deepcopy(env) 10 | state,reward,done,_=env_copy.step(m) 11 | # take the winning move 12 | if done and reward!=0: 13 | return m 14 | return env.sample() 15 | 16 | def conn_random(env): 17 | move = env.sample() 18 | return move 19 | 20 | def conn_manual(env): 21 | print(f"game state is \n{env.state.T[::-1]}") 22 | while True: 23 | move=input(f"{env.turn} player, enter your move:") 24 | try: 25 | move=int(move) 26 | except: 27 | print("the move must be a number") 28 | if move in env.validinputs: 29 | return move 30 | else: 31 | print("please enter a valid move") 32 | 33 | 34 | # Define the one_conn_game() function 35 | def one_conn_game(player1, player2): 36 | env = conn() 37 | env.reset() 38 | while True: 39 | action = player1(env) 40 | state, reward, done, _ = env.step(action) 41 | if done: 42 | break 43 | action = player2(env) 44 | state, reward, done, _ = env.step(action) 45 | if done: 46 | break 47 | return reward 48 | 49 | 50 | def to_avoid(env): 51 | toavoid=[] 52 | # look for ones you should avoid 53 | for m in env.validinputs: 54 | if len(env.occupied[m-1])<=4: 55 | env_copy=deepcopy(env) 56 | s,r,done,_=env_copy.step(m) 57 | s,r,done,_=env_copy.step(m) 58 | if done and r==-1: 59 | toavoid.append(m) 60 | return toavoid 61 | 62 | 63 | 64 | def conn_think2(env): 65 | # iterate through all possible next moves 66 | for m in env.validinputs: 67 | # make a hypothetical move 68 | env_copy=deepcopy(env) 69 | state,reward,done,_=env_copy.step(m) 70 | # take the winning move 71 | if done and reward!=0: 72 | return m 73 | # otherwise, look two moves ahead 74 | # look for ones you should block 75 | for m1 in env.validinputs: 76 | for m2 in env.validinputs: 77 | if m1!=m2: 78 | env_copy=deepcopy(env) 79 | s,r,done,_=env_copy.step(m1) 80 | s,r,done,_=env_copy.step(m2) 81 | # block your opponent's winning move 82 | if done and r!=0: 83 
| return m2 84 | # look for ones you should avoid 85 | toavoid=to_avoid(env) 86 | if len(toavoid)>0: 87 | leftovers=[i for i in env.validinputs if i not in toavoid] 88 | if len(leftovers)>0: 89 | return random.choice(leftovers) 90 | # otherwise, return a random move 91 | return env.sample() 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | def conn_think3(env): 103 | # if there is only one valid move left 104 | if len(env.validinputs)==1: 105 | return env.validinputs[0] 106 | # take column 4 if it's empty 107 | if len(env.occupied[3])==0: 108 | return 4 109 | # iterate through all possible next moves 110 | for m in env.validinputs: 111 | # make a hypothetical move 112 | env_copy=deepcopy(env) 113 | state,reward,done,_=env_copy.step(m) 114 | # take the winning move 115 | if done and reward!=0: 116 | return m 117 | # otherwise, look two moves ahead 118 | # look for ones you should block 119 | for m1 in env.validinputs: 120 | for m2 in env.validinputs: 121 | if m1!=m2: 122 | env_copy=deepcopy(env) 123 | s,r,done,_=env_copy.step(m1) 124 | s,r,done,_=env_copy.step(m2) 125 | # block your opponent's winning move 126 | if done and r!=0: 127 | return m2 128 | # look for ones you should avoid 129 | toavoid=to_avoid(env) 130 | if len(toavoid)>0: 131 | leftovers=[i for i in env.validinputs if i not in toavoid] 132 | if len(leftovers)>0: 133 | return random.choice(leftovers) 134 | # look three steps ahead 135 | w3=[] 136 | for m1 in env.validinputs: 137 | for m2 in env.validinputs: 138 | for m3 in env.validinputs: 139 | try: 140 | env_copy=deepcopy(env) 141 | s,r,done,_=env_copy.step(m1) 142 | s,r,done,_=env_copy.step(m2) 143 | s,r,done,_=env_copy.step(m3) 144 | if done and r!=0: 145 | w3.append(m1) 146 | except: 147 | pass 148 | # Choose the most frequent winner 149 | if len(w3)>0: 150 | return max(w3,key=w3.count) 151 | # otherwise, return a random move 152 | return env.sample() 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | -------------------------------------------------------------------------------- /utils/ch04util.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from random import choice 3 | 4 | def MiniMax(env): 5 | # create a list to store winning moves 6 | wins=[] 7 | # iterate through all possible next moves 8 | for m in env.validinputs: 9 | # make a hypothetical move and see what happens 10 | env_copy=deepcopy(env) 11 | new_state, reward, done, info = env_copy.step(m) 12 | # if move m lead to a win now, take it 13 | if done and reward==1: 14 | return m 15 | # see what's the best response from the opponent 16 | opponent_payoff=maximized_payoff(env_copy,reward,done) 17 | # Opponent's payoff is the opposite of your payoff 18 | my_payoff=-opponent_payoff 19 | if my_payoff==1: 20 | wins.append(m) 21 | # pick winning moves if there is any 22 | if len(wins)>0: 23 | return choice(wins) 24 | # otherwise randomly pick 25 | return env.sample() 26 | 27 | 28 | 29 | def maximized_payoff(env, reward, done): 30 | # if the game has ended after the previous player's move 31 | if done: 32 | return -1 33 | # otherwise, search for action to maximize payoff 34 | best_payoff=-2 35 | # iterate through all possible next moves 36 | for m in env.validinputs: 37 | env_copy=deepcopy(env) 38 | new_state,reward,done,info=env_copy.step(m) 39 | # what's the opponent's response 40 | opponent_payoff=maximized_payoff(env_copy, reward, done) 41 | # opponent's payoff is the opposite of your payoff 42 | 
my_payoff=-opponent_payoff 43 | # update your best payoff 44 | if my_payoff>best_payoff: 45 | best_payoff=my_payoff 46 | return best_payoff 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /utils/ch05util.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from random import choice 3 | 4 | 5 | def MiniMax_X(env): 6 | wins=[] 7 | ties=[] 8 | losses=[] 9 | # iterate through all possible next moves 10 | for m in env.validinputs: 11 | # make a hypothetical move and see what happens 12 | env_copy=deepcopy(env) 13 | state,reward,done,info=env_copy.step(m) 14 | # If player X wins right away with move m, take it. 15 | if done and reward==1: 16 | return m 17 | # See what's the best response from the opponent 18 | opponent_payoff=maximized_payoff(env_copy,reward,done) 19 | # Opponent's payoff is the opposite of your payoff 20 | my_payoff=-opponent_payoff 21 | if my_payoff==1: 22 | wins.append(m) 23 | elif my_payoff==0: 24 | ties.append(m) 25 | else: 26 | losses.append(m) 27 | # pick winning moves if there is any 28 | if len(wins)>0: 29 | return choice(wins) 30 | # otherwise pick tying moves 31 | elif len(ties)>0: 32 | return choice(ties) 33 | return env.sample() 34 | 35 | 36 | def maximized_payoff(env,reward,done): 37 | # if the game has ended after the previous player's move 38 | if done: 39 | # if it's not a tie 40 | if reward!=0: 41 | return -1 42 | else: 43 | return 0 44 | # Otherwise, search for action to maximize payoff 45 | best_payoff=-2 46 | # iterate through all possible moves 47 | for m in env.validinputs: 48 | env_copy=deepcopy(env) 49 | state,reward,done,info=env_copy.step(m) 50 | # If I make this move, what's the opponent's response? 51 | opponent_payoff=maximized_payoff(env_copy,reward,done) 52 | # Opponent's payoff is the opposite of your payoff 53 | my_payoff=-opponent_payoff 54 | # update your best payoff 55 | if my_payoff>best_payoff: 56 | best_payoff=my_payoff 57 | return best_payoff 58 | 59 | def MiniMax_O(env): 60 | wins=[] 61 | ties=[] 62 | losses=[] 63 | # iterate through all possible next moves 64 | for m in env.validinputs: 65 | # make a hypothetical move and see what happens 66 | env_copy=deepcopy(env) 67 | state,reward,done,info=env_copy.step(m) 68 | # If player O wins right away with move m, take it. 
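        # (note: the environment scores a finished game from player X's
        #  viewpoint, so the reward==-1 test below detects a win by O)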
69 | if done and reward==-1: 70 | return m 71 | # See what's the best response from the opponent 72 | opponent_payoff=maximized_payoff(env_copy,reward,done) 73 | # Opponent's payoff is the opposite of your payoff 74 | my_payoff=-opponent_payoff 75 | if my_payoff==1: 76 | wins.append(m) 77 | elif my_payoff==0: 78 | ties.append(m) 79 | else: 80 | losses.append(m) 81 | # pick winning moves if there is any 82 | if len(wins)>0: 83 | return choice(wins) 84 | # otherwise pick tying moves 85 | elif len(ties)>0: 86 | return choice(ties) 87 | return env.sample() 88 | 89 | 90 | def max_payoff(env,reward,done,depth): 91 | # if the game has ended after the previous player's move 92 | if done: 93 | # if it's not a tie 94 | if reward!=0: 95 | return -1 96 | else: 97 | return 0 98 | # If the maximum depth is reached, assume tie game 99 | if depth==0: 100 | return 0 101 | # Otherwise, search for action to maximize payoff 102 | best_payoff=-2 103 | # iterate through all possible moves 104 | for m in env.validinputs: 105 | env_copy=deepcopy(env) 106 | state,reward,done,info=env_copy.step(m) 107 | # If I make this move, what's the opponent's response? 108 | opponent_payoff=max_payoff(env_copy,reward,done,depth-1) 109 | # Opponent's payoff is the opposite of your payoff 110 | my_payoff=-opponent_payoff 111 | # update your best payoff 112 | if my_payoff>best_payoff: 113 | best_payoff=my_payoff 114 | return best_payoff 115 | 116 | def MiniMax_depth(env,depth=3): 117 | wins=[] 118 | ties=[] 119 | losses=[] 120 | # iterate through all possible next moves 121 | for m in env.validinputs: 122 | # make a hypothetical move and see what happens 123 | env_copy=deepcopy(env) 124 | state,reward,done,info=env_copy.step(m) 125 | if done and reward!=0: 126 | return m 127 | # See what's the best response from the opponent 128 | opponent_payoff=max_payoff(env_copy,reward,done,depth) 129 | # Opponent's payoff is the opposite of your payoff 130 | my_payoff=-opponent_payoff 131 | if my_payoff==1: 132 | wins.append(m) 133 | elif my_payoff==0: 134 | ties.append(m) 135 | else: 136 | losses.append(m) 137 | # pick winning moves if there is any 138 | if len(wins)>0: 139 | return choice(wins) 140 | # otherwise pick tying moves 141 | elif len(ties)>0: 142 | return choice(ties) 143 | return env.sample() 144 | 145 | -------------------------------------------------------------------------------- /utils/ch06util.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from random import choice 3 | 4 | 5 | def maximized_payoff_ttt(env,reward,done,alpha,beta): 6 | # if the game has ended after the previous player's move 7 | if done: 8 | # if it's not a tie 9 | if reward!=0: 10 | return -1 11 | else: 12 | return 0 13 | if alpha==None: 14 | alpha=-2 15 | if beta==None: 16 | beta=-2 17 | if env.turn=="X": 18 | best_payoff = alpha 19 | if env.turn=="O": 20 | best_payoff = beta 21 | # iterate through all possible moves 22 | for m in env.validinputs: 23 | env_copy=deepcopy(env) 24 | state,reward,done,info=env_copy.step(m) 25 | # If I make this move, what's the opponent's response? 
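        # (alpha and beta are passed down the recursion so that the
        #  alpha>=-beta test further below can prune branches that
        #  cannot affect the final choice)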
26 | opponent_payoff=maximized_payoff_ttt(env_copy,\ 27 | reward,done,alpha,beta) 28 | # Opponent's payoff is the opposite of your payoff 29 | my_payoff=-opponent_payoff 30 | if my_payoff > best_payoff: 31 | best_payoff = my_payoff 32 | if env.turn=="X": 33 | alpha=best_payoff 34 | if env.turn=="O": 35 | beta=best_payoff 36 | # skip the rest of the branch 37 | if alpha>=-beta: 38 | break 39 | return best_payoff 40 | 41 | 42 | def MiniMax_ab(env): 43 | wins=[] 44 | ties=[] 45 | losses=[] 46 | # iterate through all possible next moves 47 | for m in env.validinputs: 48 | # make a hypothetical move and see what happens 49 | env_copy=deepcopy(env) 50 | state,reward,done,info=env_copy.step(m) 51 | # If player X wins right away with move m, take it. 52 | if done and reward!=0: 53 | return m 54 | # See what's the best response from the opponent 55 | opponent_payoff=maximized_payoff_ttt(env_copy,\ 56 | reward,done,-2,-2) 57 | # Opponent's payoff is the opposite of your payoff 58 | my_payoff=-opponent_payoff 59 | if my_payoff==1: 60 | wins.append(m) 61 | elif my_payoff==0: 62 | ties.append(m) 63 | else: 64 | losses.append(m) 65 | # pick winning moves if there is any 66 | if len(wins)>0: 67 | return choice(wins) 68 | # otherwise pick tying moves 69 | elif len(ties)>0: 70 | return choice(ties) 71 | return env.sample() 72 | 73 | 74 | def max_payoff_conn(env,reward,done,depth,alpha,beta): 75 | # if the game has ended after the previous player's move 76 | if done: 77 | # if it's not a tie 78 | if reward!=0: 79 | return -1 80 | else: 81 | return 0 82 | # If the maximum depth is reached, assume tie game 83 | if depth==0: 84 | return 0 85 | if alpha==None: 86 | alpha=-2 87 | if beta==None: 88 | beta=-2 89 | if env.turn=="red": 90 | best_payoff = alpha 91 | if env.turn=="yellow": 92 | best_payoff = beta 93 | # iterate through all possible moves 94 | for m in env.validinputs: 95 | env_copy=deepcopy(env) 96 | state,reward,done,info=env_copy.step(m) 97 | # If I make this move, what's the opponent's response? 98 | opponent_payoff=max_payoff_conn(env_copy,\ 99 | reward,done,depth-1,alpha,beta) 100 | # Opponent's payoff is the opposite of your payoff 101 | my_payoff=-opponent_payoff 102 | if my_payoff > best_payoff: 103 | best_payoff = my_payoff 104 | if env.turn=="red": 105 | alpha=best_payoff 106 | if env.turn=="yellow": 107 | beta=best_payoff 108 | # Skip the rest of the branch 109 | if alpha>=-beta: 110 | break 111 | return best_payoff 112 | 113 | def MiniMax_conn(env,depth=3): 114 | wins=[] 115 | ties=[] 116 | losses=[] 117 | # iterate through all possible next moves 118 | for m in env.validinputs: 119 | # make a hypothetical move and see what happens 120 | env_copy=deepcopy(env) 121 | state,reward,done,info=env_copy.step(m) 122 | # If player X wins right away with move m, take it. 
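        # (the two players here are red and yellow; the reward!=0 test
        #  below means move m wins immediately for whichever side moves)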
123 | if done and reward!=0: 124 | return m 125 | # See what's the best response from the opponent 126 | opponent_payoff=max_payoff_conn(env_copy,\ 127 | reward,done,depth,-2,-2) 128 | # Opponent's payoff is the opposite of your payoff 129 | my_payoff=-opponent_payoff 130 | if my_payoff==1: 131 | wins.append(m) 132 | elif my_payoff==0: 133 | ties.append(m) 134 | else: 135 | losses.append(m) 136 | # pick winning moves if there is any 137 | if len(wins)>0: 138 | return choice(wins) 139 | # otherwise pick tying moves 140 | elif len(ties)>0: 141 | return choice(ties) 142 | return env.sample() 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | -------------------------------------------------------------------------------- /utils/ch07util.py: -------------------------------------------------------------------------------- 1 | from random import choice 2 | from copy import deepcopy 3 | 4 | 5 | def position_eval(env,model): 6 | # obtain the current state, reshape it 7 | state=env.state.reshape(-1,7,6,1) 8 | pred=model.predict(state,verbose=0) 9 | # prob(win)-prob(opponent wins) 10 | evaluation=pred[0][1]-pred[0][2] 11 | return evaluation 12 | 13 | 14 | def eval_payoff_conn(env,model,reward,done,depth,alpha,beta): 15 | # if the game has ended after the previous player's move 16 | if done: 17 | # if it's not a tie 18 | if reward!=0: 19 | return -1 20 | else: 21 | return 0 22 | # If the maximum depth is reached, assume tie game 23 | if depth==0: 24 | if env.turn=="red": 25 | return position_eval(env,model) 26 | else: 27 | return -position_eval(env,model) 28 | if alpha==None: 29 | alpha=-2 30 | if beta==None: 31 | beta=-2 32 | if env.turn=="red": 33 | best_payoff = alpha 34 | if env.turn=="yellow": 35 | best_payoff = beta 36 | # iterate through all possible moves 37 | for m in env.validinputs: 38 | env_copy=deepcopy(env) 39 | state,reward,done,info=env_copy.step(m) 40 | # If I make this move, what's the opponent's response? 41 | opponent_payoff=eval_payoff_conn(env_copy,model,\ 42 | reward,done,depth-1,alpha,beta) 43 | # Opponent's payoff is the opposite of your payoff 44 | my_payoff=-opponent_payoff 45 | if my_payoff > best_payoff: 46 | best_payoff = my_payoff 47 | if env.turn=="red": 48 | alpha=best_payoff 49 | if env.turn=="yellow": 50 | beta=best_payoff 51 | if alpha>=-beta: 52 | break 53 | return best_payoff 54 | 55 | 56 | 57 | def MiniMax_conn_eval(env,model,depth=3): 58 | values={} 59 | # iterate through all possible next moves 60 | for m in env.validinputs: 61 | # make a hypothetical move and see what happens 62 | env_copy=deepcopy(env) 63 | state,reward,done,info=env_copy.step(m) 64 | # If current player wins with m, take it. 
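        # (unlike max_payoff_conn in ch06util.py, which assumes a tie when
        #  the depth limit is reached, eval_payoff_conn scores such
        #  positions with the position_eval() network instead)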
65 | if done and reward!=0: 66 | return m 67 | # See what's the best response from the opponent 68 | opponent_payoff=eval_payoff_conn(env_copy,\ 69 | model,reward,done,depth,-2,-2) 70 | # Opponent's payoff is the opposite of your payoff 71 | my_payoff=-opponent_payoff 72 | values[m]=my_payoff 73 | # pick the move with the highest value 74 | best_move=max(values,key=values.get) 75 | return best_move 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | -------------------------------------------------------------------------------- /utils/ch08util.py: -------------------------------------------------------------------------------- 1 | import random 2 | from copy import deepcopy 3 | 4 | def simulate_a_game(env,counts,wins,losses): 5 | env_copy=deepcopy(env) 6 | actions=[] 7 | # play a full game 8 | while True: 9 | #randomly select a next move 10 | move=random.choice(env_copy.validinputs) 11 | actions.append(deepcopy(move)) 12 | state,reward,done,info=env_copy.step(move) 13 | if done: 14 | counts[actions[0]] += 1 15 | if (reward==1 and env.turn==1) or \ 16 | (reward==-1 and env.turn==2): 17 | wins[actions[0]] += 1 18 | if (reward==-1 and env.turn==1) or \ 19 | (reward==1 and env.turn==2): 20 | losses[actions[0]] += 1 21 | break 22 | return counts, wins, losses 23 | 24 | 25 | def best_move(counts,wins,losses): 26 | # See which action is most promising 27 | scores={} 28 | for k,v in counts.items(): 29 | if v==0: 30 | scores[k]=0 31 | else: 32 | scores[k]=(wins.get(k,0)-losses.get(k,0))/v 33 | return max(scores,key=scores.get) 34 | 35 | 36 | def naive_mcts(env, num_rollouts=10000): 37 | if len(env.validinputs)==1: 38 | return env.validinputs[0] 39 | counts={} 40 | wins={} 41 | losses={} 42 | for move in env.validinputs: 43 | counts[move]=0 44 | wins[move]=0 45 | losses[move]=0 46 | # roll out games 47 | for _ in range(num_rollouts): 48 | counts,wins,losses=simulate_a_game(env,counts,\ 49 | wins, losses) 50 | return best_move(counts,wins,losses) 51 | 52 | 53 | 54 | 55 | from math import sqrt, log 56 | 57 | def select(env,counts,wins,losses,temperature): 58 | # calculate the uct score for all next moves 59 | scores={} 60 | # the ones not visited get the priority 61 | for k in env.validinputs: 62 | if counts[k]==0: 63 | return k 64 | # total number of simulations conducted 65 | N=sum([v for k,v in counts.items()]) 66 | for k,v in counts.items(): 67 | if v==0: 68 | scores[k]=0 69 | else: 70 | # vi for each next move 71 | vi=(wins.get(k,0)-losses.get(k,0))/v 72 | exploration=temperature*sqrt(log(N)/counts[k]) 73 | scores[k]=vi+exploration 74 | # Select the next move with the highest UCT score 75 | return max(scores,key=scores.get) 76 | 77 | 78 | 79 | def expand(env,move): 80 | env_copy=deepcopy(env) 81 | state,reward,done,info=env_copy.step(move) 82 | return env_copy, done, reward 83 | 84 | 85 | 86 | def simulate(env_copy,done,reward): 87 | # if the game has already ended 88 | if done==True: 89 | return reward 90 | while True: 91 | move=env_copy.sample() 92 | state,reward,done,info=env_copy.step(move) 93 | if done==True: 94 | return reward 95 | 96 | 97 | 98 | def backpropagate(env,move,reward,counts,wins,losses): 99 | # add 1 to the total game counts 100 | counts[move]=counts.get(move,0)+1 101 | # if the current player wins 102 | if reward==1 and (env.turn==1 or \ 103 | env.turn=="X" or env.turn=="red"): 104 | wins[move]=wins.get(move,0)+1 105 | if reward==-1 and (env.turn==2 or \ 106 | env.turn=="O" or env.turn=="yellow"): 107 | wins[move]=wins.get(move,0)+1 108 | if reward==-1 and (env.turn==1 or \ 109 
| env.turn=="X" or env.turn=="red"): 110 | losses[move]=losses.get(move,0)+1 111 | if reward==1 and (env.turn==2 or \ 112 | env.turn=="O" or env.turn=="yellow"): 113 | losses[move]=losses.get(move,0)+1 114 | return counts,wins,losses 115 | 116 | def next_move(counts,wins,losses): 117 | # See which action is most promising 118 | scores={} 119 | for k,v in counts.items(): 120 | if v==0: 121 | scores[k]=0 122 | else: 123 | scores[k]=(wins.get(k,0)-losses.get(k,0))/v 124 | return max(scores,key=scores.get) 125 | 126 | 127 | def mcts(env,num_rollouts=100,temperature=1.4): 128 | # if there is only one valid move left, take it 129 | if len(env.validinputs)==1: 130 | return env.validinputs[0] 131 | # create three dictionaries counts, wins, losses 132 | counts={} 133 | wins={} 134 | losses={} 135 | for move in env.validinputs: 136 | counts[move]=0 137 | wins[move]=0 138 | losses[move]=0 139 | # roll out games 140 | for _ in range(num_rollouts): 141 | # selection 142 | move=select(env,counts,wins,losses,temperature) 143 | # expansion 144 | env_copy, done, reward=expand(env,move) 145 | # simulation 146 | reward=simulate(env_copy,done,reward) 147 | # backpropagate 148 | counts,wins,losses=backpropagate(\ 149 | env,move,reward,counts,wins,losses) 150 | # make the move 151 | return next_move(counts,wins,losses) 152 | 153 | 154 | 155 | 156 | 157 | -------------------------------------------------------------------------------- /utils/ch16util.py: -------------------------------------------------------------------------------- 1 | 2 | from copy import deepcopy 3 | import numpy as np 4 | from tensorflow import keras 5 | 6 | 7 | # Load the trained fast policy network from Chapter 9 8 | fast_net=keras.models.load_model("files/fast_coin.h5") 9 | # Load the strengthend strong net from Chapter 13 10 | PG_net=keras.models.load_model("files/PG_coin.h5") 11 | # Load the trained value network from Chapter 13 12 | value_net=keras.models.load_model("files/value_coin.h5") 13 | 14 | 15 | 16 | 17 | def onehot_encoder(state): 18 | onehot=np.zeros((1,22)) 19 | onehot[0,state]=1 20 | return onehot 21 | 22 | def best_move_fast_net(env): 23 | state = env.state 24 | onehot_state = onehot_encoder(state) 25 | action_probs = fast_net(onehot_state) 26 | return np.random.choice([1,2], 27 | p=np.squeeze(action_probs)) 28 | 29 | 30 | def select(priors,env,results,weight): 31 | # weighted average of priors and rollout_value 32 | scores={} 33 | for k,v in results.items(): 34 | # rollout_value for each next move 35 | if len(v)==0: 36 | vi=0 37 | else: 38 | vi=sum(v)/len(v) 39 | # scale the prior by (1+N(L)) 40 | prior=priors[0][k-1]/(1+len(v)) 41 | # calculate weighted average 42 | scores[k]=weight*prior+(1-weight)*vi 43 | # select child node based on the weighted average 44 | return max(scores,key=scores.get) 45 | 46 | # expand game tree by selecting child node 47 | def expand(env,move): 48 | env_copy=deepcopy(env) 49 | state,reward,done,info=env_copy.step(move) 50 | return env_copy, done, reward 51 | 52 | # roll out a game till terminal state or depth reached 53 | def simulate(env_copy,done,reward,depth): 54 | # if the game has already ended 55 | if done==True: 56 | return reward 57 | # select moves based on fast policy network 58 | for _ in range(depth): 59 | move=best_move_fast_net(env_copy) 60 | state,reward,done,info=env_copy.step(move) 61 | # if terminal state is reached, returns outcome 62 | if done==True: 63 | return reward 64 | # depth reached but game not over, evaluate 65 | onehot_state=onehot_encoder(state) 66 | # use the 
trained value network to evaluate 67 | ps=value_net.predict(onehot_state,verbose=0) 68 | # output is prob(1 wins)-prob(2 wins) 69 | reward=ps[0][1]-ps[0][0] 70 | return reward 71 | 72 | def backpropagate(env,move,reward,results): 73 | # if the current player is Player 1, update results 74 | if env.turn==1: 75 | results[move].append(reward) 76 | # if the current player is Player 2, multiply outcome with -1 77 | elif env.turn==2: 78 | results[move].append(-reward) 79 | return results 80 | 81 | def alphago_coin(env,weight,depth,num_rollouts=100): 82 | # if there is only one valid move left, take it 83 | if len(env.validinputs)==1: 84 | return env.validinputs[0] 85 | # get the prior from the PG policy network 86 | priors = PG_net(onehot_encoder(env.state)) 87 | # create a dictionary results 88 | results={} 89 | for move in env.validinputs: 90 | results[move]=[] 91 | # roll out games 92 | for _ in range(num_rollouts): 93 | # select 94 | move=select(priors,env,results,weight) 95 | # expand 96 | env_copy, done, reward=expand(env,move) 97 | # simulate 98 | reward=simulate(env_copy,done,reward,depth) 99 | # backpropagate 100 | results=backpropagate(env,move,reward,results) 101 | # select the most visited child node 102 | visits={k:len(v) for k,v in results.items()} 103 | return max(visits,key=visits.get) 104 | 105 | -------------------------------------------------------------------------------- /utils/ch17util.py: -------------------------------------------------------------------------------- 1 | 2 | from copy import deepcopy 3 | import numpy as np 4 | 5 | 6 | 7 | 8 | def best_move_fast_net(env, fast_net): 9 | # priors from the policy network 10 | if env.turn=="X": 11 | state = env.state.reshape(-1,3,3,1) 12 | action_probs=fast_net(state) 13 | elif env.turn=="O": 14 | state = env.state.reshape(-1,3,3,1) 15 | action_probs=fast_net(-state) 16 | elif env.turn=="red": 17 | state = env.state.reshape(-1,7,6,1) 18 | action_probs=fast_net(state) 19 | elif env.turn=="yellow": 20 | state = env.state.reshape(-1,7,6,1) 21 | action_probs=fast_net(-state) 22 | ps=[] 23 | for a in sorted(env.validinputs): 24 | ps.append(np.squeeze(action_probs)[a-1]) 25 | ps=np.array(ps) 26 | return np.random.choice(sorted(env.validinputs), 27 | p=ps/ps.sum()) 28 | 29 | 30 | def select(priors,env,results,weight): 31 | # weighted average of priors and rollout_value 32 | scores={} 33 | for k,v in results.items(): 34 | # rollout_value for each next move 35 | if len(v)==0: 36 | vi=0 37 | else: 38 | vi=sum(v)/len(v) 39 | # scale the prior by (1+N(L)) 40 | prior=priors[0][k-1]/(1+len(v)) 41 | # calculate weighted average 42 | scores[k]=weight*prior+(1-weight)*vi 43 | # select child node based on the weighted average 44 | return max(scores,key=scores.get) 45 | 46 | # expand game tree by selecting child node 47 | def expand(env,move): 48 | env_copy=deepcopy(env) 49 | state,reward,done,info=env_copy.step(move) 50 | return env_copy, done, reward 51 | 52 | # roll out a game till terminal state or depth reached 53 | def simulate(env_copy,done,reward,depth,value_net, 54 | fast_net, policy_rollout=True): 55 | # if the game has already ended 56 | if done==True: 57 | return reward 58 | # select moves based on fast policy network 59 | for _ in range(depth): 60 | if policy_rollout: 61 | try: 62 | move=best_move_fast_net(env_copy, fast_net) 63 | except:move=env_copy.sample() 64 | else: 65 | move=env_copy.sample() 66 | state,reward,done,info=env_copy.step(move) 67 | # if terminal state is reached, returns outcome 68 | if done==True: 69 | return 
reward 70 | # depth reached but game not over, evaluate 71 | if env_copy.turn=="X": 72 | state=state.reshape(-1,3,3,1) 73 | ps=value_net.predict(state,verbose=0) 74 | # reward is prob(X win) - prob(O win) 75 | reward=ps[0][1]-ps[0][2] 76 | elif env_copy.turn=="O": 77 | state=state.reshape(-1,3,3,1) 78 | ps=value_net.predict(-state,verbose=0) 79 | # reward is prob(X win) - prob(O win) 80 | reward=ps[0][2]-ps[0][1] 81 | elif env_copy.turn=="red": 82 | state=state.reshape(-1,7,6,1) 83 | ps=value_net.predict(state,verbose=0) 84 | # reward is prob(red win) - prob(yellow win) 85 | reward=ps[0][1]-ps[0][2] 86 | elif env_copy.turn=="yellow": 87 | state=state.reshape(-1,7,6,1) 88 | ps=value_net.predict(-state,verbose=0) 89 | # reward is prob(red win) - prob(yellow win) 90 | reward=ps[0][2]-ps[0][1] 91 | return reward 92 | 93 | def backpropagate(env,move,reward,results): 94 | # update results 95 | if env.turn=="X" or env.turn=="red": 96 | results[move].append(reward) 97 | # if current player is player 2, multiply outcome with -1 98 | if env.turn=="O" or env.turn=="yellow": 99 | results[move].append(-reward) 100 | return results 101 | 102 | def alphago(env,weight,depth,PG_net,value_net, 103 | fast_net, policy_rollout=True,num_rollouts=100): 104 | # if there is only one valid move left, take it 105 | if len(env.validinputs)==1: 106 | return env.validinputs[0] 107 | # get the prior from the PG policy network 108 | if env.turn=="X" or env.turn=="O": 109 | state = env.state.reshape(-1,9) 110 | conv_state = state.reshape(-1,3,3,1) 111 | if env.turn=="X": 112 | priors = PG_net([state,conv_state]) 113 | elif env.turn=="O": 114 | priors = PG_net([-state,-conv_state]) 115 | if env.turn=="red" or env.turn=="yellow": 116 | state = env.state.reshape(-1,42) 117 | conv_state = state.reshape(-1,7,6,1) 118 | if env.turn=="red": 119 | priors = PG_net([state,conv_state]) 120 | elif env.turn=="yellow": 121 | priors = PG_net([-state,-conv_state]) 122 | # create a dictionary results 123 | results={} 124 | for move in env.validinputs: 125 | results[move]=[] 126 | # roll out games 127 | for _ in range(num_rollouts): 128 | # select 129 | move=select(priors,env,results,weight) 130 | # expand 131 | env_copy, done, reward=expand(env,move) 132 | # simulate 133 | reward=simulate(env_copy,done,reward,depth,value_net, 134 | fast_net, policy_rollout) 135 | # backpropagate 136 | results=backpropagate(env,move,reward,results) 137 | # select the most visited child node 138 | visits={k:len(v) for k,v in results.items()} 139 | return max(visits,key=visits.get) 140 | 141 | 142 | 143 | 144 | -------------------------------------------------------------------------------- /utils/ch19util.py: -------------------------------------------------------------------------------- 1 | 2 | from copy import deepcopy 3 | import numpy as np 4 | from tensorflow import keras 5 | 6 | # Load the trained actor critic model 7 | model=keras.models.load_model("files/ac_coin.h5") 8 | 9 | def onehot_encoder(state): 10 | onehot=np.zeros((1,22)) 11 | onehot[0,state]=1 12 | return onehot 13 | 14 | 15 | def select(priors,env,results,weight): 16 | # weighted average of priors and rollout_value 17 | scores={} 18 | for k,v in results.items(): 19 | # rollout_value for each next move 20 | if len(v)==0: 21 | vi=0 22 | else: 23 | vi=sum(v)/len(v) 24 | # scale the prior by (1+N(L)) 25 | prior=priors[0][k-1]/(1+len(v)) 26 | # calculate weighted average 27 | scores[k]=weight*prior+(1-weight)*vi 28 | # select child node based on the weighted average 29 | return 
max(scores,key=scores.get) 30 | 31 | # expand game tree by selecting child node 32 | def expand(env,move): 33 | env_copy=deepcopy(env) 34 | state,reward,done,info=env_copy.step(move) 35 | return env_copy, done, reward 36 | 37 | # roll out a game till terminal state or depth reached 38 | def simulate(env_copy,done,reward,depth): 39 | # if the game has already ended 40 | if done==True: 41 | return reward 42 | for _ in range(depth): 43 | move=env_copy.sample() 44 | state,reward,done,info=env_copy.step(move) 45 | # if terminal state is reached, returns outcome 46 | if done==True: 47 | return reward 48 | # depth reached but game not over, evaluate 49 | onehot_state=onehot_encoder(state) 50 | # use the trained actor critic model to evaluate 51 | action_probs, critic_value = model(onehot_state) 52 | # output is predicted game outcome 53 | return critic_value[0,0] 54 | 55 | def backpropagate(env,move,reward,results): 56 | # if the current player is Player 1, update results 57 | if env.turn==1: 58 | results[move].append(reward) 59 | # if the current player is Player 2, multiply outcome with -1 60 | elif env.turn==2: 61 | results[move].append(-reward) 62 | return results 63 | 64 | def alphazero_coin(env,weight,depth,num_rollouts=100): 65 | # if there is only one valid move left, take it 66 | if len(env.validinputs)==1: 67 | return env.validinputs[0] 68 | # get the prior from the AC policy network 69 | onehot_state = onehot_encoder(env.state) 70 | priors, _ = model(onehot_state) 71 | # create a dictionary results 72 | results={} 73 | for move in env.validinputs: 74 | results[move]=[] 75 | # roll out games 76 | for _ in range(num_rollouts): 77 | # select 78 | move=select(priors,env,results,weight) 79 | # expand 80 | env_copy, done, reward=expand(env,move) 81 | # simulate 82 | reward=simulate(env_copy,done,reward,depth) 83 | # backpropagate 84 | results=backpropagate(env,move,reward,results) 85 | # select the most visited child node 86 | visits={k:len(v) for k,v in results.items()} 87 | return max(visits,key=visits.get) 88 | 89 | -------------------------------------------------------------------------------- /utils/ch20util.py: -------------------------------------------------------------------------------- 1 | 2 | from copy import deepcopy 3 | import numpy as np 4 | from tensorflow import keras 5 | from math import sqrt, log, inf 6 | 7 | def select(priors,env,results,weight): 8 | # weighted average of priors and rollout_value 9 | scores={} 10 | for k,v in results.items(): 11 | # rollout_value for each next move 12 | if len(v)==0: 13 | vi=0 14 | else: 15 | vi=sum(v)/len(v); 16 | # scale the prior by (1+N(L)) 17 | prior=priors[0][k-1]/(1+len(v)) 18 | # calculate weighted average 19 | scores[k]=weight*prior+(1-weight)*vi 20 | # select child node based on the weighted average 21 | return max(scores,key=scores.get) 22 | 23 | # expand game tree by selecting child node 24 | def expand(env,move): 25 | env_copy=deepcopy(env) 26 | state,reward,done,info=env_copy.step(move) 27 | return env_copy, done, reward 28 | 29 | # roll out a game till terminal state or depth reached 30 | def simulate(env_copy,done,reward): 31 | # if the game has already ended 32 | if done==True: 33 | return reward 34 | # select moves based on fast policy network 35 | while True: 36 | move=env_copy.sample() 37 | state,reward,done,info=env_copy.step(move) 38 | # if terminal state is reached, returns outcome 39 | if done==True: 40 | return reward 41 | 42 | def backpropagate(env,move,reward,results): 43 | # update results 44 | if 
env.turn=="X" or env.turn=="red": 45 | results[move].append(reward) 46 | # if current player is player 2, 47 | # multiply outcome with -1 48 | if env.turn=="O" or env.turn=="yellow": 49 | results[move].append(-reward) 50 | return results 51 | 52 | def alphazero(env,weight,PG_net,num_rollouts=100): 53 | # if there is only one valid move left, take it 54 | if len(env.validinputs)==1: 55 | return env.validinputs[0] 56 | # get the prior from the PG policy network 57 | if env.turn=="X" or env.turn=="O": 58 | state = env.state.reshape(-1,9) 59 | conv_state = state.reshape(-1,3,3,1) 60 | if env.turn=="X": 61 | priors = PG_net([state,conv_state]) 62 | elif env.turn=="O": 63 | priors = PG_net([-state,-conv_state]) 64 | if env.turn=="red" or env.turn=="yellow": 65 | state = env.state.reshape(-1,42) 66 | conv_state = state.reshape(-1,7,6,1) 67 | if env.turn=="red": 68 | priors = PG_net([state,conv_state]) 69 | elif env.turn=="yellow": 70 | priors = PG_net([-state,-conv_state]) 71 | # create a dictionary results 72 | results={} 73 | for move in env.validinputs: 74 | results[move]=[] 75 | # roll out games 76 | for _ in range(num_rollouts): 77 | # select 78 | move=select(priors,env,results,weight) 79 | # expand 80 | env_copy, done, reward=expand(env,move) 81 | # simulate 82 | reward=simulate(env_copy,done,reward) 83 | # backpropagate 84 | results=backpropagate(env,move,reward,results) 85 | # select the most visited child node 86 | visits={k:len(v) for k,v in results.items()} 87 | return max(visits,key=visits.get) 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | -------------------------------------------------------------------------------- /utils/coin_env.py: -------------------------------------------------------------------------------- 1 | import turtle as t 2 | from random import choice 3 | from tkinter import messagebox 4 | from tkinter import PhotoImage 5 | import time 6 | import numpy as np 7 | import random 8 | 9 | # Define an action_space helper class 10 | class action_space: 11 | def __init__(self, n): 12 | self.n = n 13 | 14 | # Define an obervation_space helper class 15 | class observation_space: 16 | def __init__(self, n): 17 | self.shape = (n,) 18 | 19 | class coin_game(): 20 | def __init__(self): 21 | # use the helper action_space class 22 | self.action_space=action_space(2) 23 | # use the helper observation_space class 24 | self.observation_space=observation_space(1) 25 | self.info="" 26 | self.showboard=False 27 | # set the game to the initial state 28 | self.reset() 29 | 30 | def reset(self): 31 | # The player 1 moves first 32 | self.turn = 1 33 | # Count how many coins left 34 | self.state = 21 35 | # Create a list of valid moves 36 | self.validinputs = [1, 2] if self.state>2 else [1] 37 | # Whether the game is over 38 | self.done=False 39 | self.reward=0 40 | self.move=0 41 | return self.state 42 | 43 | # sample() function: returns a random move 44 | def sample(self): 45 | return random.choice(self.validinputs) 46 | 47 | # step() function: make a move and update state 48 | def step(self, inp): 49 | # choice of the move must be 1 or 2 50 | assert int(inp)==1 or int(inp)==2 51 | if self.state==1: 52 | self.move=1 53 | else: 54 | self.move=int(inp) 55 | # update the state 56 | self.state -= self.move 57 | # check if the player has won the game 58 | if self.state == 0: 59 | self.done=True 60 | # reward is 1 if player 1 won; -1 otherwise 61 | self.reward=2*(self.turn==1)-1 62 | else: 63 | # Give the turn to the other player 64 | if self.turn == 1: 65 | self.turn = 2 66 | else: 67 | 
self.turn = 1 68 | return self.state, self.reward, self.done, self.info 69 | 70 | def display_board(self): 71 | # Set up the screen 72 | try: 73 | t.setup(600,500,10,70) 74 | except t.Terminator: 75 | t.setup(600,500,10,70) 76 | t.tracer(False) 77 | t.hideturtle() 78 | t.bgcolor("lavender") 79 | t.title("Last Coin Standing") 80 | # Create a second turtle to show coins left 81 | self.left = t.Turtle() 82 | self.left.up() 83 | self.left.hideturtle() 84 | self.left.goto(-120,-200) 85 | self.left.write(f"Coins left: {self.state}", font = ('Arial',30,'normal')) 86 | self.left.up() 87 | self.left.goto(-200,150) 88 | self.left.write(f"Player {self.turn}'s turn to move",font = ('Arial',30,'normal')) 89 | # Load a picture of the coin 90 | coin = PhotoImage(file = "utils/cash.png").subsample(10,10) 91 | t.addshape("coin", t.Shape("image", coin)) 92 | # Create 21 coin on screen 93 | self.coins = [0]*21 94 | for i in range(21): 95 | self.coins[i] = t.Turtle('coin') 96 | self.coins[i].up() 97 | self.coins[i].goto(-150+50*(i//3),50-50*(i%3)) 98 | t.update() 99 | 100 | def render(self): 101 | if self.showboard==False: 102 | self.display_board() 103 | self.showboard=True 104 | # Update the number of coins left 105 | self.left.clear() 106 | self.left.goto(-120,-200) 107 | self.left.write(f"Coins left: {self.state}", font = ('Arial',30,'normal')) 108 | self.left.goto(-200,150) 109 | self.left.write(f"Player {self.turn}'s turn to move",font = ('Arial',30,'normal')) 110 | 111 | # Remove a coin 112 | if self.move==1: 113 | self.coins[self.state].hideturtle() 114 | if self.move==2: 115 | self.coins[self.state+1].hideturtle() 116 | self.coins[self.state].hideturtle() 117 | t.update() 118 | 119 | 120 | def close(self): 121 | time.sleep(1) 122 | try: 123 | t.bye() 124 | except t.Terminator: 125 | print('exit turtle') 126 | 127 | 128 | 129 | -------------------------------------------------------------------------------- /utils/coin_simple_env.py: -------------------------------------------------------------------------------- 1 | import turtle as t 2 | from random import choice 3 | from tkinter import messagebox 4 | from tkinter import PhotoImage 5 | import time 6 | import numpy as np 7 | import random 8 | 9 | # Define an action_space helper class 10 | class action_space: 11 | def __init__(self, n): 12 | self.n = n 13 | 14 | # Define an obervation_space helper class 15 | class observation_space: 16 | def __init__(self, n): 17 | self.shape = (n,) 18 | 19 | class coin_game(): 20 | def __init__(self): 21 | # use the helper action_space class 22 | self.action_space=action_space(2) 23 | # use the helper observation_space class 24 | self.observation_space=observation_space(1) 25 | self.info="" 26 | # set the game to the initial state 27 | self.reset() 28 | 29 | def reset(self): 30 | # The player 1 moves first 31 | self.turn = 1 32 | # Count how many coins left 33 | self.state = 21 34 | # Create a list of valid moves 35 | self.validinputs = [1, 2] if self.state>2 else [1] 36 | # Whether the game is over 37 | self.done=False 38 | self.reward=0 39 | self.move=0 40 | return self.state 41 | 42 | # sample() function: returns a random move 43 | def sample(self): 44 | return random.choice(self.validinputs) 45 | 46 | # step() function: make a move and update state 47 | def step(self, inp): 48 | # choice of the move must be 1 or 2 49 | assert int(inp)==1 or int(inp)==2 50 | if self.state==1: 51 | self.move=1 52 | else: 53 | self.move=int(inp) 54 | # update the state 55 | self.state -= self.move 56 | # check if the player has won 
the game 57 | if self.state == 0: 58 | self.done=True 59 | # reward is 1 if player 1 won; -1 otherwise 60 | self.reward=2*(self.turn==1)-1 61 | else: 62 | # Give the turn to the other player 63 | if self.turn == 1: 64 | self.turn = 2 65 | else: 66 | self.turn = 1 67 | return self.state, self.reward, self.done, self.info 68 | 69 | 70 | -------------------------------------------------------------------------------- /utils/conn_env.py: -------------------------------------------------------------------------------- 1 | import turtle as t 2 | from random import choice 3 | import numpy as np 4 | import time 5 | 6 | # Define an action_space helper class 7 | class action_space: 8 | def __init__(self, n): 9 | self.n = n 10 | 11 | # Define an obervation_space helper class 12 | class observation_space: 13 | def __init__(self, row, col): 14 | self.shape = (row, col) 15 | 16 | class conn(): 17 | def __init__(self): 18 | # use the helper action_space class 19 | self.action_space=action_space(7) 20 | # use the helper observation_space class 21 | self.observation_space=observation_space(7,6) 22 | self.info="" 23 | # The x-coordinates of the center of the 7 columns 24 | self.xs = [-300,-200,-100,0,100,200,300] 25 | # The y-coordinates of the center of the 6 rows 26 | self.ys = [-250,-150,-50,50,150,250] 27 | self.showboard=False 28 | self.game_piece=None 29 | self.reset() 30 | 31 | # sample() function: returns a random move 32 | def sample(self): 33 | return choice(self.validinputs) 34 | 35 | def reset(self): 36 | # The X player moves first 37 | self.turn = "red" 38 | # Create a list of valid moves 39 | self.validinputs = [1,2,3,4,5,6,7] 40 | # Create a list of lists to track game pieces 41 | self.occupied = [list(),list(),list(),list(),list(),list(),list()] 42 | # Tracking the state 43 | self.state=np.array([[0,0,0,0,0,0], 44 | [0,0,0,0,0,0], 45 | [0,0,0,0,0,0], 46 | [0,0,0,0,0,0], 47 | [0,0,0,0,0,0], 48 | [0,0,0,0,0,0], 49 | [0,0,0,0,0,0]]) 50 | self.done=False 51 | self.reward=0 52 | return self.state 53 | 54 | # step() function: place piece on board and update state 55 | def step(self, inp): 56 | # Remember the current game piece 57 | self.game_piece=[inp-1, len(self.occupied[inp-1]), self.turn] 58 | # update the state: red is 1 and yellow is -1 59 | self.state[inp-1][len(self.occupied[inp-1])]=2*(self.turn=="red")-1 60 | # Add the move to the occupied list 61 | self.occupied[inp-1].append(self.turn) 62 | 63 | # Update the list of valid moves 64 | if len(self.occupied[inp-1]) == 6 and inp in self.validinputs: 65 | self.validinputs.remove(inp) 66 | # check if the player has won the game 67 | if self.win_game(inp) == True: 68 | self.done=True 69 | # reward is 1 if red won; -1 if yellow won 70 | self.reward=2*(self.turn=="red")-1 71 | self.validinputs=[] 72 | # If all cellls are occupied and no winner, it's a tie 73 | elif len(self.validinputs) == 0: 74 | self.done=True 75 | self.reward=0 76 | else: 77 | # Give the turn to the other player 78 | if self.turn == "red": 79 | self.turn = "yellow" 80 | else: 81 | self.turn = "red" 82 | return self.state, self.reward, self.done, self.info 83 | 84 | # Determine if a player has won the game 85 | # Define a horizontal4() function to check connecting 4 horizontally 86 | def horizontal4(self, x, y): 87 | win = False 88 | for dif in (-3, -2, -1, 0): 89 | try: 90 | if self.occupied[x+dif][y] == self.turn\ 91 | and self.occupied[x+dif+1][y] == self.turn\ 92 | and self.occupied[x+dif+2][y] == self.turn\ 93 | and self.occupied[x+dif+3][y] == self.turn\ 94 | and x+dif 
>= 0: 95 | win = True 96 | except IndexError: 97 | pass 98 | return win 99 | # Define a vertical4() function to check connecting 4 vertically 100 | def vertical4(self, x, y): 101 | win = False 102 | try: 103 | if self.occupied[x][y] == self.turn\ 104 | and self.occupied[x][y-1] == self.turn\ 105 | and self.occupied[x][y-2] == self.turn\ 106 | and self.occupied[x][y-3] == self.turn\ 107 | and y-3 >= 0: 108 | win = True 109 | except IndexError: 110 | pass 111 | return win 112 | # Define a forward4() function to check connecting 4 diagonally in / shape 113 | def forward4(self, x, y): 114 | win = False 115 | for dif in (-3, -2, -1, 0): 116 | try: 117 | if self.occupied[x+dif][y+dif] == self.turn\ 118 | and self.occupied[x+dif+1][y+dif+1] == self.turn\ 119 | and self.occupied[x+dif+2][y+dif+2] == self.turn\ 120 | and self.occupied[x+dif+3][y+dif+3] == self.turn\ 121 | and x+dif >= 0 and y+dif >= 0: 122 | win = True 123 | except IndexError: 124 | pass 125 | return win 126 | # Define a back4() function to check connecting 4 diagonally in \ shape 127 | def back4(self, x, y): 128 | win = False 129 | for dif in (-3, -2, -1, 0): 130 | try: 131 | if self.occupied[x+dif][y-dif] == self.turn\ 132 | and self.occupied[x+dif+1][y-dif-1] == self.turn\ 133 | and self.occupied[x+dif+2][y-dif-2] == self.turn\ 134 | and self.occupied[x+dif+3][y-dif-3] == self.turn\ 135 | and x+dif >= 0 and y-dif-3 >= 0: 136 | win = True 137 | except IndexError: 138 | pass 139 | return win 140 | 141 | # Define a win_game() function to check if someone wins the game 142 | def win_game(self, inp): 143 | win = False 144 | x = inp-1 145 | y = len(self.occupied[inp-1])-1 146 | # Check all winning possibilities 147 | if self.vertical4(x,y)==True: 148 | win = True 149 | if self.horizontal4(x,y)==True: 150 | win = True 151 | if self.forward4(x,y)==True: 152 | win = True 153 | if self.back4(x,y)==True: 154 | win = True 155 | return win 156 | 157 | def display_board(self): 158 | # Set up the screen 159 | try: 160 | t.setup(730,680, 10, 70) 161 | except: 162 | t.setup(730,680, 10, 70) 163 | t.hideturtle() 164 | t.tracer(False) 165 | t.title("Connect Four in Turtle Graphics") 166 | # Draw frame 167 | t.pensize(5) 168 | t.up() 169 | t.goto(-350,-300) 170 | t.down() 171 | t.begin_fill() 172 | t.color("black", "blue") 173 | t.forward(700) 174 | t.left(90) 175 | t.forward(600) 176 | t.left(90) 177 | t.forward(700) 178 | t.left(90) 179 | t.forward(600) 180 | t.left(90) 181 | t.end_fill() 182 | t.up() 183 | # Write column numbers on the board 184 | colnum = 1 185 | for x in range(-300, 350, 100): 186 | t.goto(x,300) 187 | t.write(colnum,font = ('Arial',20,'normal')) 188 | t.goto(x,-330) 189 | t.write(colnum,font = ('Arial',20,'normal')) 190 | colnum += 1 191 | # Show white cells 192 | for col in range(7): 193 | for row in range(6): 194 | t.up() 195 | t.goto(self.xs[col],self.ys[row]) 196 | t.dot(80,"white") 197 | t.update() 198 | # Create a second turtle to show disc falling 199 | self.fall = t.Turtle() 200 | self.fall.up() 201 | self.fall.hideturtle() 202 | 203 | def render(self): 204 | if self.showboard==False: 205 | self.display_board() 206 | self.showboard=True 207 | 208 | if self.game_piece is not None: 209 | # Show the disc fall from the top 210 | col, row, c = self.game_piece 211 | if row<7: 212 | for i in range(6,row+1,-1): 213 | self.fall.goto(self.xs[col],self.ys[i-1]) 214 | self.fall.dot(80,c) 215 | t.update() 216 | time.sleep(0.05) 217 | self.fall.clear() 218 | # Go to the cell and place a dot of the player's color 219 | t.up() 220 | 
t.goto(self.xs[col],self.ys[row]) 221 | t.dot(80,c) 222 | t.update() 223 | def close(self): 224 | time.sleep(1) 225 | try: 226 | t.bye() 227 | except t.Terminator: 228 | print('exit turtle') 229 | 230 | 231 | 232 | 233 | 234 | 235 | -------------------------------------------------------------------------------- /utils/conn_simple_env.py: -------------------------------------------------------------------------------- 1 | import turtle as t 2 | from random import choice 3 | import numpy as np 4 | import time 5 | 6 | # Define an action_space helper class 7 | class action_space: 8 | def __init__(self, n): 9 | self.n = n 10 | 11 | # Define an observation_space helper class 12 | class observation_space: 13 | def __init__(self, row, col): 14 | self.shape = (row, col) 15 | 16 | class conn(): 17 | def __init__(self): 18 | # use the helper action_space class 19 | self.action_space=action_space(7) 20 | # use the helper observation_space class 21 | self.observation_space=observation_space(7,6) 22 | self.info="" 23 | # The x-coordinates of the center of the 7 columns 24 | self.xs = [-300,-200,-100,0,100,200,300] 25 | # The y-coordinates of the center of the 6 rows 26 | self.ys = [-250,-150,-50,50,150,250] 27 | self.game_piece=None 28 | self.reset() 29 | 30 | # sample() function: returns a random move 31 | def sample(self): 32 | return choice(self.validinputs) 33 | 34 | def reset(self): 35 | # The red player moves first 36 | self.turn = "red" 37 | # Create a list of valid moves 38 | self.validinputs = [1,2,3,4,5,6,7] 39 | # Create a list of lists to track game pieces 40 | self.occupied = [list(),list(),list(),list(),list(),list(),list()] 41 | # Tracking the state 42 | self.state=np.array([[0,0,0,0,0,0], 43 | [0,0,0,0,0,0], 44 | [0,0,0,0,0,0], 45 | [0,0,0,0,0,0], 46 | [0,0,0,0,0,0], 47 | [0,0,0,0,0,0], 48 | [0,0,0,0,0,0]]) 49 | self.done=False 50 | self.reward=0 51 | return self.state 52 | 53 | # step() function: place piece on board and update state 54 | def step(self, inp): 55 | # Remember the current game piece 56 | self.game_piece=[inp-1, len(self.occupied[inp-1]), self.turn] 57 | # update the state: red is 1 and yellow is -1 58 | self.state[inp-1][len(self.occupied[inp-1])]=2*(self.turn=="red")-1 59 | # Add the move to the occupied list 60 | self.occupied[inp-1].append(self.turn) 61 | 62 | # Update the list of valid moves 63 | if len(self.occupied[inp-1]) == 6 and inp in self.validinputs: 64 | self.validinputs.remove(inp) 65 | # check if the player has won the game 66 | if self.win_game(inp) == True: 67 | self.done=True 68 | # reward is 1 if red won; -1 if yellow won 69 | self.reward=2*(self.turn=="red")-1 70 | self.validinputs=[] 71 | # If all cells are occupied and no winner, it's a tie 72 | elif len(self.validinputs) == 0: 73 | self.done=True 74 | self.reward=0 75 | else: 76 | # Give the turn to the other player 77 | if self.turn == "red": 78 | self.turn = "yellow" 79 | else: 80 | self.turn = "red" 81 | return self.state, self.reward, self.done, self.info 82 | 83 | # Determine if a player has won the game 84 | # Define a horizontal4() function to check connecting 4 horizontally 85 | def horizontal4(self, x, y): 86 | win = False 87 | for dif in (-3, -2, -1, 0): 88 | try: 89 | if self.occupied[x+dif][y] == self.turn\ 90 | and self.occupied[x+dif+1][y] == self.turn\ 91 | and self.occupied[x+dif+2][y] == self.turn\ 92 | and self.occupied[x+dif+3][y] == self.turn\ 93 | and x+dif >= 0: 94 | win = True 95 | except IndexError: 96 | pass 97 | return win 98 | # Define a vertical4() function to check
connecting 4 vertically 99 | def vertical4(self, x, y): 100 | win = False 101 | try: 102 | if self.occupied[x][y] == self.turn\ 103 | and self.occupied[x][y-1] == self.turn\ 104 | and self.occupied[x][y-2] == self.turn\ 105 | and self.occupied[x][y-3] == self.turn\ 106 | and y-3 >= 0: 107 | win = True 108 | except IndexError: 109 | pass 110 | return win 111 | # Define a forward4() function to check connecting 4 diagonally in / shape 112 | def forward4(self, x, y): 113 | win = False 114 | for dif in (-3, -2, -1, 0): 115 | try: 116 | if self.occupied[x+dif][y+dif] == self.turn\ 117 | and self.occupied[x+dif+1][y+dif+1] == self.turn\ 118 | and self.occupied[x+dif+2][y+dif+2] == self.turn\ 119 | and self.occupied[x+dif+3][y+dif+3] == self.turn\ 120 | and x+dif >= 0 and y+dif >= 0: 121 | win = True 122 | except IndexError: 123 | pass 124 | return win 125 | # Define a back4() function to check connecting 4 diagonally in \ shape 126 | def back4(self, x, y): 127 | win = False 128 | for dif in (-3, -2, -1, 0): 129 | try: 130 | if self.occupied[x+dif][y-dif] == self.turn\ 131 | and self.occupied[x+dif+1][y-dif-1] == self.turn\ 132 | and self.occupied[x+dif+2][y-dif-2] == self.turn\ 133 | and self.occupied[x+dif+3][y-dif-3] == self.turn\ 134 | and x+dif >= 0 and y-dif-3 >= 0: 135 | win = True 136 | except IndexError: 137 | pass 138 | return win 139 | 140 | # Define a win_game() function to check if someone wins the game 141 | def win_game(self, inp): 142 | win = False 143 | x = inp-1 144 | y = len(self.occupied[inp-1])-1 145 | # Check all winning possibilities 146 | if self.vertical4(x,y)==True: 147 | win = True 148 | if self.horizontal4(x,y)==True: 149 | win = True 150 | if self.forward4(x,y)==True: 151 | win = True 152 | if self.back4(x,y)==True: 153 | win = True 154 | return win 155 | 156 | -------------------------------------------------------------------------------- /utils/ttt_env.py: -------------------------------------------------------------------------------- 1 | import turtle as t 2 | import numpy as np 3 | import time 4 | import random 5 | 6 | # Define an action_space helper class 7 | class action_space: 8 | def __init__(self, n): 9 | self.n = n 10 | 11 | # Define an obervation_space helper class 12 | class observation_space: 13 | def __init__(self, n): 14 | self.shape = (n,) 15 | 16 | class ttt(): 17 | def __init__(self): 18 | # use the helper action_space class 19 | self.action_space=action_space(9) 20 | # use the helper observation_space class 21 | self.observation_space=observation_space(9) 22 | self.info="" 23 | self.showboard=False 24 | # Create a dictionary to map cell number to coordinates 25 | self.cellcenter = {1:(-200,-200), 2:(0,-200), 3:(200,-200), 26 | 4:(-200,0), 5:(0,0), 6:(200,0), 27 | 7:(-200,200), 8:(0,200), 9:(200,200)} 28 | # set the game to the initial state 29 | self.reset() 30 | 31 | # sample() function: returns a random move 32 | def sample(self): 33 | return random.choice(self.validinputs) 34 | 35 | def reset(self): 36 | # The X player moves first 37 | self.turn = "X" 38 | # Count how many rounds played 39 | self.rounds = 1 40 | # Create a list of valid moves 41 | self.validinputs = list(self.cellcenter.keys()) 42 | # Create a dictionary of moves made by each player 43 | self.occupied = {"X":[],"O":[]} 44 | # Tracking the state 45 | self.state=np.array([0,0,0,0,0,0,0,0,0]) 46 | self.done=False 47 | self.reward=0 48 | return self.state 49 | 50 | # step() function: place piece on board and update state 51 | def step(self, inp): 52 | # Add the move to the occupied 
list 53 | self.occupied[self.turn].append(inp) 54 | # update the state: X is 1 and O is -1 55 | self.state[int(inp)-1]=2*(self.turn=="X")-1 56 | # Disallow the move in future rounds 57 | self.validinputs.remove(inp) 58 | # check if the player has won the game 59 | if self.win_game() == True: 60 | self.done=True 61 | # reward is 1 if X won; -1 if O won 62 | self.reward=2*(self.turn=="X")-1 63 | self.validinputs=[] 64 | # If all cellls are occupied and no winner, it's a tie 65 | elif self.rounds == 9: 66 | self.done=True 67 | self.reward=0 68 | self.validinputs=[] 69 | else: 70 | # Counting rounds 71 | self.rounds += 1 72 | # Give the turn to the other player 73 | if self.turn == "X": 74 | self.turn = "O" 75 | else: 76 | self.turn = "X" 77 | return self.state, self.reward, self.done, self.info 78 | 79 | # Determine if a player has won the game 80 | def win_game(self): 81 | lst = self.occupied[self.turn] 82 | if 1 in lst and 2 in lst and 3 in lst: 83 | return True 84 | elif 4 in lst and 5 in lst and 6 in lst: 85 | return True 86 | elif 7 in lst and 8 in lst and 9 in lst: 87 | return True 88 | elif 1 in lst and 4 in lst and 7 in lst: 89 | return True 90 | elif 2 in lst and 5 in lst and 8 in lst: 91 | return True 92 | elif 3 in lst and 6 in lst and 9 in lst: 93 | return True 94 | elif 1 in lst and 5 in lst and 9 in lst: 95 | return True 96 | elif 3 in lst and 5 in lst and 7 in lst: 97 | return True 98 | else: 99 | return False 100 | 101 | 102 | def display_board(self): 103 | # Set up the screen 104 | try: 105 | t.setup(630,630,10,70) 106 | except t.Terminator: 107 | t.setup(630,630,10,70) 108 | t.tracer(False) 109 | t.hideturtle() 110 | t.bgcolor("azure") 111 | t.title("Tic-Tac-Toe in Turtle Graphics") 112 | # Draw horizontal lines and vertical lines 113 | t.pensize(5) 114 | t.color('blue') 115 | for i in (-300,-100,100,300): 116 | t.up() 117 | t.goto(i,-300) 118 | t.down() 119 | t.goto(i,300) 120 | t.up() 121 | t.goto(-300,i) 122 | t.down() 123 | t.goto(300,i) 124 | t.up() 125 | # Write down the cell number 126 | t.color('red') 127 | for cell, center in list(self.cellcenter.items()): 128 | t.goto(center[0]-80,center[1]-80) 129 | t.write(cell,font = ('Arial',20,'normal')) 130 | 131 | def render(self): 132 | if self.showboard==False: 133 | self.display_board() 134 | self.showboard=True 135 | # Place X or O in occupied cells 136 | t.color('light gray') 137 | if len(self.occupied["X"])>0: 138 | for x in self.occupied["X"]: 139 | xy=self.cellcenter[x] 140 | t.up() 141 | t.goto(xy[0]-60,xy[1]-60) 142 | t.down() 143 | t.goto(xy[0]+60,xy[1]+60) 144 | t.up() 145 | t.goto(xy[0]-60,xy[1]+60) 146 | t.down() 147 | t.goto(xy[0]+60,xy[1]-60) 148 | t.up() 149 | t.update() 150 | if len(self.occupied["O"])>0: 151 | for o in self.occupied["O"]: 152 | t.up() 153 | t.goto(self.cellcenter[o]) 154 | t.dot(160,"light gray") 155 | t.update() 156 | 157 | def close(self): 158 | time.sleep(1) 159 | try: 160 | t.bye() 161 | except t.Terminator: 162 | print('exit turtle') 163 | 164 | 165 | 166 | 167 | 168 | 169 | -------------------------------------------------------------------------------- /utils/ttt_new_env.py: -------------------------------------------------------------------------------- 1 | import turtle as t 2 | from random import choice 3 | import numpy as np 4 | import time 5 | 6 | # Define an action_space helper class 7 | class action_space: 8 | def __init__(self, n): 9 | self.n = n 10 | def sample(self): 11 | num = np.random.choice(range(self.n)) 12 | # covert to 1 to 9 13 | action = 1+num 14 | return action 
15 | 16 | # Define an obervation_space helper class 17 | class observation_space: 18 | def __init__(self, n): 19 | self.shape = (n,) 20 | 21 | class ttt(): 22 | def __init__(self): 23 | # use the helper action_space class 24 | self.action_space=action_space(9) 25 | # use the helper observation_space class 26 | self.observation_space=observation_space(9) 27 | self.info="" 28 | # Create a dictionary to map cell number to coordinates 29 | self.cellcenter = {1:(-200,-200), 2:(0,-200), 3:(200,-200), 30 | 4:(-200,0), 5:(0,0), 6:(200,0), 31 | 7:(-200,200), 8:(0,200), 9:(200,200)} 32 | 33 | 34 | def reset(self): 35 | # The X player moves first 36 | self.turn = "X" 37 | # Count how many rounds played 38 | self.rounds = 1 39 | # Create a list of valid moves 40 | self.validinputs = list(self.cellcenter.keys()) 41 | # Create a dictionary of moves made by each player 42 | self.occupied = {"X":[],"O":[]} 43 | # Tracking the state 44 | self.state=np.array([0,0,0,0,0,0,0,0,0]) 45 | self.done=False 46 | self.reward=0 47 | return self.state 48 | 49 | # step() function: place piece on board and update state 50 | def step(self, inp): 51 | # Add the move to the occupied list 52 | self.occupied[self.turn].append(inp) 53 | # update the state: X is 1 and O is -1 54 | self.state[int(inp)-1]=2*(self.turn=="X")-1 55 | # Disallow the move in future rounds 56 | self.validinputs.remove(inp) 57 | # check if the player has won the game 58 | if self.win_game() == True: 59 | self.done=True 60 | # reward is 1 if X won; -1 if O won 61 | self.reward=2*(self.turn=="X")-1 62 | self.validinputs=[] 63 | # If all cellls are occupied and no winner, it's a tie 64 | elif self.rounds == 9: 65 | self.done=True 66 | self.reward=0 67 | self.validinputs=[] 68 | # Counting rounds 69 | self.rounds += 1 70 | # Give the turn to the other player 71 | if self.turn == "X": 72 | self.turn = "O" 73 | else: 74 | self.turn = "X" 75 | return self.state, self.reward, self.done, self.info 76 | 77 | # Determine if a player has won the game 78 | def win_game(self): 79 | lst = self.occupied[self.turn] 80 | if 1 in lst and 2 in lst and 3 in lst: 81 | return True 82 | elif 4 in lst and 5 in lst and 6 in lst: 83 | return True 84 | elif 7 in lst and 8 in lst and 9 in lst: 85 | return True 86 | elif 1 in lst and 4 in lst and 7 in lst: 87 | return True 88 | elif 2 in lst and 5 in lst and 8 in lst: 89 | return True 90 | elif 3 in lst and 6 in lst and 9 in lst: 91 | return True 92 | elif 1 in lst and 5 in lst and 9 in lst: 93 | return True 94 | elif 3 in lst and 5 in lst and 7 in lst: 95 | return True 96 | else: 97 | return False 98 | 99 | 100 | -------------------------------------------------------------------------------- /utils/ttt_simple_env.py: -------------------------------------------------------------------------------- 1 | import turtle as t 2 | import numpy as np 3 | import time 4 | import random 5 | 6 | # Define an action_space helper class 7 | class action_space: 8 | def __init__(self, n): 9 | self.n = n 10 | 11 | # Define an obervation_space helper class 12 | class observation_space: 13 | def __init__(self, n): 14 | self.shape = (n,) 15 | 16 | class ttt(): 17 | def __init__(self): 18 | # use the helper action_space class 19 | self.action_space=action_space(9) 20 | # use the helper observation_space class 21 | self.observation_space=observation_space(9) 22 | self.info="" 23 | # Create a dictionary to map cell number to coordinates 24 | self.cellcenter = {1:(-200,-200), 2:(0,-200), 3:(200,-200), 25 | 4:(-200,0), 5:(0,0), 6:(200,0), 26 | 7:(-200,200), 
8:(0,200), 9:(200,200)} 27 | # set the game to the initial state 28 | self.reset() 29 | 30 | # sample() function: returns a random move 31 | def sample(self): 32 | return random.choice(self.validinputs) 33 | 34 | def reset(self): 35 | # The X player moves first 36 | self.turn = "X" 37 | # Count how many rounds played 38 | self.rounds = 1 39 | # Create a list of valid moves 40 | self.validinputs = list(self.cellcenter.keys()) 41 | # Create a dictionary of moves made by each player 42 | self.occupied = {"X":[],"O":[]} 43 | # Tracking the state 44 | self.state=np.array([0,0,0,0,0,0,0,0,0]) 45 | self.done=False 46 | self.reward=0 47 | return self.state 48 | 49 | # step() function: place piece on board and update state 50 | def step(self, inp): 51 | # Add the move to the occupied list 52 | self.occupied[self.turn].append(inp) 53 | # update the state: X is 1 and O is -1 54 | self.state[int(inp)-1]=2*(self.turn=="X")-1 55 | # Disallow the move in future rounds 56 | self.validinputs.remove(inp) 57 | # check if the player has won the game 58 | if self.win_game() == True: 59 | self.done=True 60 | # reward is 1 if X won; -1 if O won 61 | self.reward=2*(self.turn=="X")-1 62 | self.validinputs=[] 63 | # If all cellls are occupied and no winner, it's a tie 64 | elif self.rounds == 9: 65 | self.done=True 66 | self.reward=0 67 | self.validinputs=[] 68 | else: 69 | # Counting rounds 70 | self.rounds += 1 71 | # Give the turn to the other player 72 | if self.turn == "X": 73 | self.turn = "O" 74 | else: 75 | self.turn = "X" 76 | return self.state, self.reward, self.done, self.info 77 | 78 | # Determine if a player has won the game 79 | def win_game(self): 80 | lst = self.occupied[self.turn] 81 | if 1 in lst and 2 in lst and 3 in lst: 82 | return True 83 | elif 4 in lst and 5 in lst and 6 in lst: 84 | return True 85 | elif 7 in lst and 8 in lst and 9 in lst: 86 | return True 87 | elif 1 in lst and 4 in lst and 7 in lst: 88 | return True 89 | elif 2 in lst and 5 in lst and 8 in lst: 90 | return True 91 | elif 3 in lst and 6 in lst and 9 in lst: 92 | return True 93 | elif 1 in lst and 5 in lst and 9 in lst: 94 | return True 95 | elif 3 in lst and 5 in lst and 7 in lst: 96 | return True 97 | else: 98 | return False 99 | 100 | 101 | -------------------------------------------------------------------------------- /utils/ttt_think3.py: -------------------------------------------------------------------------------- 1 | 2 | import random 3 | from copy import deepcopy 4 | 5 | 6 | # moves that leads to X winning now 7 | def check_Xwin1(env): 8 | for m in env.validinputs: 9 | env_copy=deepcopy(env) 10 | state, reward, done, info = env_copy.step(m) 11 | # Winning move for X 12 | if done and reward==1: 13 | return m 14 | return None 15 | # moves leading to O winning next 16 | def check_Owin2(env): 17 | for m1 in env.validinputs: 18 | for m2 in env.validinputs: 19 | if m1!=m2: 20 | env_copy=deepcopy(env) 21 | s,r,done,_=env_copy.step(m1) 22 | s,r,done,_=env_copy.step(m2) 23 | if done and r==-1: 24 | return m2 25 | return None 26 | # look 3 steps ahead 27 | def X_think3(env): 28 | # If there is only one move left, take it 29 | if len(env.validinputs) == 1: 30 | return env.validinputs[0] 31 | # Otherwise, see if there is a winning move 32 | winner=check_Xwin1(env) 33 | if winner is not None: 34 | return winner 35 | # Otherwise, see if there is a losing move 36 | loser=check_Owin2(env) 37 | if loser is not None: 38 | return loser 39 | # If only two moves left, randomly choose 40 | if len(env.validinputs)<=2: 41 | return 
random.choice(env.validinputs) 42 | # Otherwise, look ahead 43 | w3=[] 44 | for m1 in env.validinputs: 45 | for m2 in env.validinputs: 46 | for m3 in env.validinputs: 47 | if m1!=m2 and m1!=m3 and m2!=m3: 48 | env_copy=deepcopy(env) 49 | s,r,done,_=env_copy.step(m1) 50 | s,r,done,_=env_copy.step(m2) 51 | s,r,done,_=env_copy.step(m3) 52 | if done and r==1: 53 | w3.append(m1) 54 | # Choose the most frequent winner 55 | if len(w3)>0: 56 | return max(set(w3),key=w3.count) 57 | # Take random move otherwise 58 | return random.choice(env.validinputs) 59 | 60 | 61 | 62 | 63 | # moves that lead to O winning now 64 | def check_Owin1(env): 65 | for m in env.validinputs: 66 | env_copy=deepcopy(env) 67 | state, reward, done, info = env_copy.step(m) 68 | if done and reward==-1: 69 | return m 70 | return None 71 | # moves leading to X winning next 72 | def check_Xwin2(env): 73 | for m1 in env.validinputs: 74 | for m2 in env.validinputs: 75 | if m1!=m2: 76 | env_copy=deepcopy(env) 77 | s,r,done,_=env_copy.step(m1) 78 | s,r,done,_=env_copy.step(m2) 79 | if done and r==1: 80 | return m2 81 | return None 82 | # look 3 steps ahead 83 | def O_think3(env): 84 | if len(env.validinputs) == 1: 85 | return env.validinputs[0] 86 | winner=check_Owin1(env) 87 | if winner is not None: 88 | return winner 89 | # Otherwise, see if there is a losing move 90 | loser=check_Xwin2(env) 91 | if loser is not None: 92 | return loser 93 | # If only two moves left, randomly choose 94 | if len(env.validinputs)<=2: 95 | return random.choice(env.validinputs) 96 | # Otherwise, look ahead 97 | w3=[] 98 | for m1 in env.validinputs: 99 | for m2 in env.validinputs: 100 | for m3 in env.validinputs: 101 | if m1!=m2 and m1!=m3 and m2!=m3: 102 | env_copy=deepcopy(env) 103 | s,r,done,_=env_copy.step(m1) 104 | s,r,done,_=env_copy.step(m2) 105 | s,r,done,_=env_copy.step(m3) 106 | if done and r==-1: 107 | w3.append(m1) 108 | # Choose the most frequent winner 109 | if len(w3)>0: 110 | return max(set(w3),key=w3.count) 111 | # Take random move otherwise 112 | return random.choice(env.validinputs) 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | --------------------------------------------------------------------------------
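The environment classes above share a Gym-style reset()/step()/sample() interface. As a quick smoke test, here is a minimal usage sketch (not a file in this repository): it plays one game of Last Coin Standing between two random players with coin_game from utils/coin_simple_env.py, and one Tic Tac Toe game between the rule-based look-ahead agents X_think3() and O_think3() from utils/ttt_think3.py. It assumes you run it from the repository root so that the utils package is importable, and that tkinter is available, since these modules import turtle at load time.

# Illustrative sketch only; run from the repository root.
from utils.coin_simple_env import coin_game
from utils.ttt_simple_env import ttt
from utils.ttt_think3 import X_think3, O_think3

# Last Coin Standing: two players pick random moves until the pile is empty
coin_env = coin_game()
state = coin_env.reset()
while not coin_env.done:
    state, reward, done, info = coin_env.step(coin_env.sample())
print("coin game outcome:", coin_env.reward)  # 1 if Player 1 took the last coin, -1 otherwise

# Tic Tac Toe: the three-step look-ahead agents play each other
ttt_env = ttt()
state = ttt_env.reset()
while not ttt_env.done:
    move = X_think3(ttt_env) if ttt_env.turn == "X" else O_think3(ttt_env)
    state, reward, done, info = ttt_env.step(move)
print("tic tac toe outcome:", ttt_env.reward)  # 1: X wins, -1: O wins, 0: tie

Because the think-ahead agents only rely on env.validinputs, env.step(), and a deepcopy of the environment, the same pattern should work with the other ttt environments shown above as well.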