├── .gitignore ├── README.md ├── docs └── tictactoe_ddqn_prb.png ├── env ├── __init__.py ├── base.py ├── human_player.py ├── punish_illegal_moves.py └── random_player.py ├── play ├── agent_vs_human.py ├── agent_vs_random.py └── game.py ├── policies ├── a2c-cnn ├── dqn-cnn └── dqn-cnn-v1 ├── requirements.txt └── train ├── a2c.py └── dqn.py /.gitignore: -------------------------------------------------------------------------------- 1 | .venv 2 | logs/ 3 | runs/ 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | venv/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | 136 | # pytype static type analyzer 137 | .pytype/ 138 | 139 | # Cython debug symbols 140 | cython_debug/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tic-tac-toe RL Agent 2 | 3 | This is a Tic-tac-toe game player trained with Reinforcement Learning. 
4 |
5 | The player depends on a simple [backend][1] and a reinforcement learning [library][2] that I made. It starts out knowing nothing about the game and gains knowledge by playing against an automated opponent that plays randomly.
6 |
7 | Player X is Player 1 and Player O is Player 2; I use the two names interchangeably. Player X is the agent and Player O is whatever opponent you want it to be.
8 |
9 | My aim in creating this agent is both to understand what is going on under the hood of Reinforcement Learning and to train an agent that plays optimally.
10 |
11 | Two aspects of Tic-tac-toe make it a good candidate:
12 | - It lives in a relatively small space (a 3x3 board) as opposed to Chess (8x8) or Go (19x19).
13 | - It is a solved game, meaning that optimal strategies exist and can be encoded as hard rules.
14 |
15 | It is therefore reasonable to expect Reinforcement Learning to learn an optimal strategy, getting better and better as it observes and learns from experience.
16 |
17 | ## Design
18 |
19 | I have used a [DQN algorithm][4] with the following enhancements:
20 | - Double DQN: the value of the next state is computed for the action selected by the current online network rather than the action that maximizes the target network's output, which reduces Q-value overestimation.
21 | - Prioritized Experience Replay. [[paper]][5] [[blog post]][6]
22 |
23 | Prioritized experience replay in particular accelerated learning significantly.
24 |
25 | For the neural network that predicts Q-values, I have used a Convolutional Neural Network, since the board has spatial structure (two X's next to each other mean something different from two X's far apart).
26 |
27 | The environment is based on OpenAI Gym and has the following reward structure (a usage sketch follows the list):
28 | - Win: +1
29 | - Lose: -1
30 | - Draw: 0.5
31 | - Step: -0.1
32 |
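As a minimal sketch of how these rewards map onto the environment (based on `env/__init__.py` and `env/base.py`; the values shown are the constructor defaults, so passing `rewards` explicitly is only for illustration):

```python
import gym
import env  # running env/__init__.py registers the TicTacToe environments

# TicTacToeRandomPlayer-v0 subclasses TicTacToeEnv with a randomly playing player O.
ttt = gym.make(
    "TicTacToeRandomPlayer-v0",
    rewards=dict(pos_ep=1, neg_ep=-1, draw=0.5, step=-0.1),  # win / lose / draw / per-step
)

obs = ttt.reset()                    # flattened 3x3 board: 0 = empty, 1 = X, 2 = O
obs, reward, done, _ = ttt.step(4)   # action index 4 marks the centre cell (1,1)
```
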
33 | ## Training and Performance
34 |
35 | One metric that measures the stability of the agent's wins is the standard deviation of rewards over evaluation episodes. For example, when I evaluate over 1000 episodes I might get an average reward of 0.75 (meaning the agent is doing great and mostly winning), but the average alone hides outliers. The standard deviation of the evaluation rewards indicates how far individual episodes stray from that mean. During training, the standard deviation shrank over time, which means fewer outliers and more consistent wins.
36 |
37 | Epsilon was decayed over time from an initial value near 1 (explore almost all the time) to a floor of 0.1 (act greedily 90% of the time and explore 10% of the time). This encourages exploration at the beginning, to gather experience, and then gradually favors greedy actions over random exploratory ones.
38 |
39 | ![Training](/docs/tictactoe_ddqn_prb.png)
40 | *An agent trained for 34 minutes on roughly 100k games of experience.*
41 |
42 | The agent's performance is currently far from optimal. It does prioritize blocking the opponent from winning in certain states, but in analogous yet different states it prioritizes its own win over blocking, failing to notice that the opponent is one move away from winning the game. Since every intermediate step carries the same reward, this is almost certainly a problem with the learned state values.
43 |
44 | There are a few improvements that I can make that should bring it closer to optimality. All improvements but the last one are from the [Rainbow DQN paper][3], namely:
45 | - [ ] Adding noise to the neural network (Noisy Nets) to further increase exploration.
46 | - [ ] Dueling DQN.
47 | - [ ] Periodically save the trained agent and use it *as an opponent* instead of playing against a random player all the time.
48 |
49 | > Any suggestions or improvements are welcome. Please create a GitHub issue or a pull request if you have something to contribute.
50 |
51 | ## Defining Player 2 policy
52 |
53 | The environment `TicTacToeEnv` defined in `env/base.py` is meant to be subclassed so you can define the behavior of player 2:
54 |
55 | ```python
56 | @abc.abstractmethod
57 | def player2_policy(self):
58 |     """
59 |     Define player 2 (non-agent) policy
60 |     """
61 | ```
62 |
63 | For example, this is an environment that makes player 2 play randomly:
64 |
65 | ```python
66 | class Env(TicTacToeEnv):
67 |     def player2_policy(self):
68 |         random_action = random.choice(self._board.empty_cells)
69 |         self._player2.mark(*random_action)
70 | ```
71 |
72 | ## Play
73 |
74 | To play the game, run the scripts in `play/` as modules from the repository root:
75 |
76 | ### `agent_vs_random.py`
77 |
78 | Plays a trained agent against a random player for `num_games` games (default 1000). Player X is the trained agent and Player O is the random player:
79 |
80 | ```console
81 | $ python -m play.agent_vs_random -algorithm dqn -net_type cnn -policy dqn-cnn -num_games 1000
82 | Game #1000 - Iteration duration: 0.0016622543334960938
83 |
84 | X Won: 750
85 | O Won: 172
86 | Draw: 78
87 | Win percentage: 75.0
88 | Win+Draw percentage: 82.8
89 | Loss percentage: 17.2
90 | ```
91 |
92 | ### `agent_vs_human.py`
93 |
94 | Allows you to play against your trained agent. Player X (1) is the trained agent and Player O (2) is you. You can specify the first player with the `-fp` flag (example: `-fp=2`; default is player 1).
95 |
96 | If you specify `-debug`, the neural network's final layer values are printed, showing you what the agent thinks of each action on the board.
97 |
98 | Cell coordinates range from (0,0) to (2,2):
99 |
100 | ```
101 | $ python -m play.agent_vs_human -algorithm dqn -net_type cnn -policy dqn-cnn -fp 2 -debug
102 | Enter cell coordinates (e.g. 1,2): 1,1
103 | Player 2: Marking 1 1
104 |    |   |
105 | ------------
106 |    | O |
107 | ------------
108 |    |   |
109 | ------------
110 |
111 | action distribution:
112 | tensor([[0.7591, 0.7115, 0.8669],
113 |         [0.4905, -inf, 0.5601],
114 |         [0.7362, 0.5985, 0.7299]], grad_fn=)
115 | action, max(action_dist): 2, 0.866879940032959
116 |
117 | Player 1: Marking 0 2
118 |    |   | X
119 | ------------
120 |    | O |
121 | ------------
122 |    |   |
123 | ------------
124 |
125 | Enter cell coordinates (e.g. 1,2): 0,0
126 | Player 2: Marking 0 0
127 |  O |   | X
128 | ------------
129 |    | O |
130 | ------------
131 |    |   |
132 | ------------
133 |
134 | action distribution:
135 | tensor([[ -inf, 0.6825, -inf],
136 |         [0.5657, -inf, 0.6621],
137 |         [0.4554, 0.5861, 0.8073]], grad_fn=)
138 | action, max(action_dist): 8, 0.8072749376296997
139 |
140 | Player 1: Marking 2 2
141 |  O |   | X
142 | ------------
143 |    | O |
144 | ------------
145 |    |   | X
146 | ------------
147 | Enter cell coordinates (e.g. 1,2):
148 | ```
149 |
150 | and so on.
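If you prefer to drive a trained agent from your own script rather than the command-line entry points above, the `Game` helper in `play/game.py` wires the pieces together. A minimal sketch, run from the repository root (the `dqn-cnn` policy name refers to the corresponding file under `policies/`):

```python
from play.game import Game

game = Game(verbose=True)                   # verbose passes verbose flags to the board and players
game.load_algorithm("dqn")                  # imports train.dqn
game.load_net("cnn")                        # selects the CNNNet class from that module
game.load_env("TicTacToeRandomPlayer-v0")   # player O plays randomly
game.load_model("dqn-cnn")                  # loads policies/dqn-cnn

game.play()                                 # runs a single game to completion
print(game.env._board.player_won)           # winning player, or a falsy value on a draw
```
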
151 | 152 | 153 | [1]: https://github.com/abstractpaper/tictactoe 154 | [2]: https://github.com/abstractpaper/prop 155 | [3]: https://arxiv.org/abs/1710.02298 156 | [4]: https://en.wikipedia.org/wiki/Q-learning#Deep_Q-learning 157 | [5]: https://arxiv.org/abs/1511.05952 158 | [6]: https://danieltakeshi.github.io/2019/07/14/per/ -------------------------------------------------------------------------------- /docs/tictactoe_ddqn_prb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alfoudari/tictactoe-pytorch/db3a8ad6eb70aaa024ad141fe9e0ed430f22507d/docs/tictactoe_ddqn_prb.png -------------------------------------------------------------------------------- /env/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import registry, register, make, spec 2 | 3 | # register all environments here so gym.make() can get them. 4 | 5 | register( 6 | id='TicTacToeBase-v0', 7 | entry_point='env.base:TicTacToeEnv', 8 | reward_threshold=0.7, 9 | ) 10 | 11 | register( 12 | id='TicTacToeHumanPlayer-v0', 13 | entry_point='env.human_player:TicTacToeEnv', 14 | reward_threshold=0.7, 15 | ) 16 | 17 | register( 18 | id='TicTacToeRandomPlayer-v0', 19 | entry_point='env.random_player:TicTacToeEnv', 20 | reward_threshold=0.7, 21 | ) 22 | 23 | register( 24 | id='TicTacToePunishIllegalMoves-v0', 25 | entry_point='env.punish_illegal_moves:TicTacToeEnv', 26 | reward_threshold=0.7, 27 | ) -------------------------------------------------------------------------------- /env/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import random 3 | import gym 4 | import numpy as np 5 | import warnings 6 | 7 | from itertools import chain 8 | from enum import Enum 9 | from gym import spaces 10 | from gym.utils import seeding 11 | from tictactoe import engine as tictactoe 12 | 13 | # https://stackoverflow.com/questions/40659212/futurewarning-elementwise-comparison-failed-returning-scalar-but-in-the-futur 14 | warnings.simplefilter(action='ignore', category=FutureWarning) 15 | 16 | class Player(Enum): 17 | X = 1 18 | O = 2 19 | 20 | class TicTacToeEnv(gym.Env): 21 | environment_name = "TicTacToe Environment" 22 | 23 | def __init__(self, 24 | player1_verbose=False, 25 | player2_verbose=False, 26 | board_verbose=False, 27 | first_player=None, 28 | rewards=dict( 29 | pos_ep=1, 30 | neg_ep=-1, 31 | draw=0.5, 32 | step=-0.1, 33 | ), 34 | thresholds=dict( 35 | win_rate=0.9, 36 | draw_rate=0.1, 37 | )): 38 | # spaces 39 | self.action_space = spaces.Discrete(9) 40 | self.observation_space = spaces.Tuple( 41 | (spaces.Discrete(3),spaces.Discrete(3)) 42 | ) 43 | 44 | # game 45 | self._board = tictactoe.Board(verbose=board_verbose) 46 | self._player1 = tictactoe.Player(board=self._board, side=1, verbose=player1_verbose) # X 47 | self._player2 = tictactoe.Player(board=self._board, side=2, verbose=player2_verbose) # O 48 | self._first_player = first_player 49 | self._done = False 50 | 51 | # rewards 52 | self.rewards = rewards 53 | 54 | # thresholds 55 | self.thresholds = thresholds 56 | s = sum([n for k,n in self.thresholds.items()]) 57 | if s != 1: 58 | raise Exception(f"thresholds must equal to 1; got {s} instead.") 59 | 60 | # stats 61 | self.stats = dict( 62 | games_played=0 63 | ) 64 | 65 | # do stuff 66 | self.seed() 67 | 68 | def step(self, action): 69 | if self._done: 70 | # The last action ended the episode. Start a new episode. 
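# Note: reset() returns only the initial observation (see reset() below), not a
# (state, reward, done, info) tuple, so this branch does not follow the usual
# step() return signature.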
71 | return self.reset() 72 | 73 | # e.g. 74 | # 0 -> (0,0) 75 | # 2 -> (0,2) 76 | # 4 -> (1,1) 77 | coordinates = (int(action/3), int(action % 3)) 78 | 79 | # Player 1 80 | self._player1.mark(*coordinates) 81 | 82 | if self._board.player_won or not self._board.empty_cells: 83 | self._done = True 84 | 85 | # Player 2 86 | if not self._done and self._board.empty_cells: 87 | self.player2_policy() 88 | if self._board.player_won or not self._board.empty_cells: 89 | self._done = True 90 | 91 | if self._done: 92 | if self._board.player_won == self._player1: 93 | reward = self.rewards['pos_ep'] 94 | elif self._board.player_won == self._player2: 95 | reward = self.rewards['neg_ep'] 96 | else: 97 | # draw 98 | reward = self.rewards['draw'] 99 | 100 | # adjust stats 101 | self.stats['games_played'] += 1 102 | else: 103 | reward = self.rewards['step'] 104 | 105 | return (self.state, reward, self._done, dict()) 106 | 107 | def reset(self): 108 | # reset everything 109 | self._board.reset() 110 | self._first_move() 111 | self._done = False 112 | # return an initial observation 113 | return self.state 114 | 115 | def render(self): 116 | pass 117 | 118 | def close(self): 119 | pass 120 | 121 | def seed(self, seed=1): 122 | random.seed(seed) # fixed seed 123 | 124 | @abc.abstractmethod 125 | def player2_policy(self): 126 | """ 127 | Define player 2 (non-agent) policy 128 | """ 129 | 130 | @property 131 | def legal_actions(self): 132 | cells = self._board.empty_cells 133 | actions = [x*3+y for (x,y) in cells] 134 | return actions 135 | 136 | @property 137 | def state(self): 138 | # flatten state 139 | board = self._board.board 140 | return list(chain.from_iterable(board)) 141 | 142 | @property 143 | def observation_space_n(self): 144 | return self.observation_space[0].n * self.observation_space[1].n 145 | 146 | @property 147 | def action_space_n(self): 148 | return self.action_space.n 149 | 150 | @property 151 | def performance_threshold(self): 152 | # ideal moves accumulated rewards 153 | win = 2 * self.rewards['step'] + self.rewards['pos_ep'] 154 | # play till the end and draw 155 | first_move_draw = 4 * self.rewards['step'] + self.rewards['draw'] # 4 moves + final move 156 | second_move_draw = 3 * self.rewards['step'] + self.rewards['draw'] # 3 moves + final move 157 | 158 | # win some % of the time, draw some % of the time 159 | weighed_rewards = win * self.thresholds['win_rate'] + first_move_draw * self.thresholds['draw_rate']/2 + second_move_draw * self.thresholds['draw_rate']/2 160 | 161 | return weighed_rewards 162 | 163 | def _first_move(self): 164 | if self._first_player is None: 165 | # random 166 | self._first_player = random.choice(list(Player)) 167 | 168 | if self._first_player == Player.O: 169 | # Let O (player 2) play first 170 | self.player2_policy() -------------------------------------------------------------------------------- /env/human_player.py: -------------------------------------------------------------------------------- 1 | from .base import TicTacToeEnv as BaseEnv 2 | 3 | class TicTacToeEnv(BaseEnv): 4 | def player2_policy(self): 5 | while True: 6 | action = input("Enter cell coordinates (e.g. 
1,2): ") 7 | try: 8 | action = tuple(map(int, action.split(","))) 9 | except: 10 | print("Invalid input") 11 | continue 12 | 13 | if action in self._board.empty_cells: 14 | break 15 | else: 16 | print("Illegal move") 17 | 18 | self._player2.mark(*action) -------------------------------------------------------------------------------- /env/punish_illegal_moves.py: -------------------------------------------------------------------------------- 1 | from .base import TicTacToeEnv as BaseEnv 2 | 3 | class TicTacToeEnv(BaseEnv): 4 | def step(self, action): 5 | if self._done: 6 | # The last action ended the episode. Start a new episode. 7 | return self.reset() 8 | 9 | # e.g. 10 | # 0 -> (0,0) 11 | # 2 -> (0,2) 12 | # 4 -> (1,1) 13 | coordinates = (int(action/3), int(action % 3)) 14 | 15 | # illegal move 16 | illegal_move = coordinates not in self._board.empty_cells 17 | 18 | if not illegal_move: 19 | # Player 1 20 | self._player1.mark(*coordinates) 21 | 22 | if self._board.player_won or not self._board.empty_cells: 23 | self._done = True 24 | 25 | # Player 2 26 | if not self._done and self._board.empty_cells: 27 | self.player2_policy() 28 | if self._board.player_won or not self._board.empty_cells: 29 | self._done = True 30 | 31 | self._board_state_to_int() 32 | 33 | if self._done: 34 | # Win: reward is 1 + # of empty cells remaining, this is to 35 | # encourage efficient strategies (the ones that win faster). 36 | # Loss: rewards -1 - # of empty cells; the agent chances to 37 | # win decrease the longer it plays so early losses are punished. 38 | # Draw: 0 39 | if self._board.player_won == self._player1: 40 | reward = 1+len(self._board.empty_cells) 41 | elif self._board.player_won == self._player2: 42 | reward = -1-len(self._board.empty_cells) 43 | else: 44 | # draw 45 | reward = 1 46 | 47 | return (self._state, reward, True, dict()) 48 | else: 49 | if illegal_move: 50 | # end episode 51 | return (self._state, -5, True, dict()) 52 | else: 53 | return (self._state, 1, False, dict()) -------------------------------------------------------------------------------- /env/random_player.py: -------------------------------------------------------------------------------- 1 | from .base import TicTacToeEnv as BaseEnv 2 | import random 3 | 4 | class TicTacToeEnv(BaseEnv): 5 | def player2_policy(self): 6 | random_action = random.choice(self._board.empty_cells) 7 | self._player2.mark(*random_action) -------------------------------------------------------------------------------- /play/agent_vs_human.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | from .game import Game 4 | 5 | if __name__ == '__main__': 6 | # flags 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('-fp', action="store", dest="fp", type=int, default=1, choices=[1,2]) 9 | parser.add_argument('-algorithm', action="store", dest="algorithm", type=str, required=True, choices=["dqn","a2c"]) 10 | parser.add_argument('-net_type', action="store", dest="net_type", type=str, required=True, choices=["fc","cnn"]) 11 | parser.add_argument('-policy', action="store", dest="policy", type=str, required=True) 12 | parser.add_argument('-debug', action="store_true", dest="debug") 13 | args = parser.parse_args() 14 | 15 | game = Game(verbose=True, debug=args.debug, first_player=args.fp) 16 | game.load_algorithm(args.algorithm) 17 | game.load_net(args.net_type) 18 | game.load_env("TicTacToeHumanPlayer-v0") 19 | game.load_model(args.policy) 20 | 21 | # start playing 22 | 
game.play() 23 | 24 | board = game.env._board 25 | if board.player_won: 26 | if board.player_won.side == "X": 27 | print("AI won!") 28 | elif board.player_won.side == "O": 29 | print("You won!") 30 | else: 31 | print("Draw.") -------------------------------------------------------------------------------- /play/agent_vs_random.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import argparse 4 | from env.base import Player 5 | from .game import Game 6 | 7 | def readable_board(board): 8 | matrix = np.asarray(board) 9 | # map numbers to letters 10 | v = np.vectorize(lambda x: dict(zip([0, 1, 2], [" ", "X", "O"]))[x]) 11 | matrix = v(matrix) 12 | 13 | readable = [] 14 | for row in matrix: 15 | row_string = f" {row[0]} | {row[1]} | {row[2]} " 16 | readable.append(row_string) 17 | readable.append("-" * len(row_string)) 18 | return readable 19 | 20 | if __name__ == '__main__': 21 | # flags 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('-algorithm', action="store", dest="algorithm", type=str, required=True, choices=["dqn","a2c"]) 24 | parser.add_argument('-net_type', action="store", dest="net_type", type=str, required=True, choices=["fc","cnn"]) 25 | parser.add_argument('-policy', action="store", dest="policy", type=str, default="") 26 | parser.add_argument('-num_games', action="store", dest="num_games", type=int, default=1000) 27 | args = parser.parse_args() 28 | 29 | game = Game() 30 | game.load_algorithm(args.algorithm) 31 | game.load_net(args.net_type) 32 | game.load_env("TicTacToeRandomPlayer-v0") 33 | game.load_model(args.policy) 34 | 35 | x_won = o_won = draw = 0 36 | won_games = [] 37 | lost_games = [] 38 | for n in range(args.num_games): 39 | start = time.time() 40 | 41 | game.env.seed(start) 42 | game.play() 43 | board = game.env._board 44 | 45 | if board.player_won: 46 | if board.player_won.side == Player.X.value: 47 | x_won += 1 48 | won_games.append((n,readable_board(board.board))) 49 | elif board.player_won.side == Player.O.value: 50 | o_won += 1 51 | if board.board not in [b for _, b in lost_games]: 52 | lost_games.append((n,board.board)) 53 | else: 54 | draw += 1 55 | 56 | end = time.time() 57 | 58 | print(f"Game #{n+1} - Iteration duration: {end - start}", end="\r", flush=True) 59 | 60 | print("\n") 61 | print(f"X Won: {x_won}") 62 | print(f"O Won: {o_won}") 63 | print(f"Draw: {draw}") 64 | print(f"Win percentage: {x_won / args.num_games * 100}") 65 | print(f"Win+Draw percentage: {(x_won + draw) / args.num_games * 100}") 66 | print(f"Loss percentage: {o_won / args.num_games * 100}") 67 | 68 | # print("Lost games:") 69 | # for idx, (n, board) in enumerate(lost_games): 70 | # print(f"#{n}") 71 | # for row in readable_board(board): 72 | # print(row) 73 | 74 | # if idx > 10: 75 | # break 76 | 77 | # print("Won game: ") 78 | # for row in won_games[0][1]: 79 | # print(row) 80 | -------------------------------------------------------------------------------- /play/game.py: -------------------------------------------------------------------------------- 1 | 2 | import importlib 3 | import time 4 | import sys 5 | import torch 6 | import gym 7 | from env.base import Player 8 | 9 | # mapping between algorithms and their module paths 10 | algorithm_module_paths = dict( 11 | dqn = "train.dqn", 12 | a2c = "train.a2c", 13 | ) 14 | 15 | # mapping between nets and their classes 16 | net_type_classes = dict( 17 | fc = "FCNet", 18 | cnn = "CNNNet", 19 | ) 20 | 21 | class Game: 22 | def __init__(self, 
verbose=False, debug=False, first_player=None): 23 | self.algorithm = None 24 | self.algorithm_module = None 25 | self.net = None 26 | self.env = None 27 | self.model = None 28 | self.verbose = verbose 29 | self.debug = debug 30 | self.first_player = first_player 31 | 32 | def load_algorithm(self, algorithm, module_paths=algorithm_module_paths): 33 | try: 34 | self.algorithm = algorithm 35 | self.algorithm_module = importlib.import_module(module_paths[algorithm]) 36 | except Exception as e: 37 | sys.exit(e) 38 | 39 | def load_net(self, net_type): 40 | try: 41 | self.net = getattr(self.algorithm_module, net_type_classes[net_type]) 42 | except Exception as e: 43 | sys.exit(e) 44 | 45 | def load_env(self, env): 46 | kwargs = dict() 47 | 48 | kwargs["first_player"] = Player(self.first_player) if self.first_player else None 49 | if self.verbose: 50 | kwargs["player1_verbose"] = True 51 | kwargs["player2_verbose"] = True 52 | kwargs["board_verbose"] = True 53 | 54 | self.env = gym.make(env, **kwargs) 55 | 56 | # seed 57 | self.env.seed(time.time()) 58 | 59 | def load_model(self, policy): 60 | self.model = self.net(obs_size=9, n_actions=9) 61 | self.model.load_state_dict(torch.load(f"policies/{policy}")) 62 | self.model.eval() 63 | 64 | def play(self): 65 | # environment 66 | obs = self.env.reset() 67 | 68 | while True: 69 | x = torch.Tensor(obs).reshape(self.env.observation_space_n) 70 | mask = torch.zeros(self.env.action_space_n).index_fill(0, torch.LongTensor(self.env.legal_actions), 1) 71 | 72 | if self.algorithm == "dqn": 73 | y = self.model(x, mask) 74 | elif self.algorithm == "a2c": 75 | y, _ = self.model(x, mask) 76 | 77 | action = torch.argmax(y).item() 78 | 79 | if self.debug: 80 | print(f"action distribution:\n{y.view(3,3)}") 81 | print(f"action, max(action_dist): {action}, {torch.max(y)}\n") 82 | 83 | obs, reward, done, _ = self.env.step(action) 84 | 85 | if done: 86 | break -------------------------------------------------------------------------------- /policies/a2c-cnn: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alfoudari/tictactoe-pytorch/db3a8ad6eb70aaa024ad141fe9e0ed430f22507d/policies/a2c-cnn -------------------------------------------------------------------------------- /policies/dqn-cnn: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alfoudari/tictactoe-pytorch/db3a8ad6eb70aaa024ad141fe9e0ed430f22507d/policies/dqn-cnn -------------------------------------------------------------------------------- /policies/dqn-cnn-v1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alfoudari/tictactoe-pytorch/db3a8ad6eb70aaa024ad141fe9e0ed430f22507d/policies/dqn-cnn-v1 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | numpy 3 | gym 4 | tensorboard 5 | tensorboardX 6 | xando 7 | pytest 8 | scipy 9 | rlprop -------------------------------------------------------------------------------- /train/a2c.py: -------------------------------------------------------------------------------- 1 | # Advantage Actor Critic (A2C) 2 | 3 | import random 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | import argparse 8 | import sys 9 | import torch 10 | import gym 11 | import env # module init needs to run 12 | from 
prop.algorithms.a2c import Agent 13 | from prop.net.feed_forward import FeedForward 14 | 15 | class CNNNet(FeedForward): 16 | """ compute action probability distribution and state value """ 17 | def __init__(self, obs_size, n_actions): 18 | # model is initiated in parent class, set params early. 19 | self.obs_size = obs_size 20 | self.n_actions = n_actions 21 | super(CNNNet, self).__init__() 22 | 23 | def model(self): 24 | common = nn.Sequential( 25 | nn.Conv2d(in_channels=3, out_channels=256, kernel_size=3, stride=1, padding=1), # 3 channels -> `out_channels` different kernels/feature maps 26 | nn.ReLU(), # negative numbers -> 0 27 | nn.MaxPool2d(kernel_size=2, stride=1), # deformation invariance; subtle changes are captured 28 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=2, stride=1, padding=1), 29 | nn.ReLU(), 30 | nn.MaxPool2d(kernel_size=2, stride=1), 31 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=2, stride=1, padding=1), 32 | nn.ReLU(), 33 | nn.MaxPool2d(kernel_size=2, stride=1), 34 | nn.Flatten(), 35 | nn.Linear(256 * 4, 64), 36 | nn.ReLU(), 37 | ) 38 | self.actor_head = nn.Linear(64, self.n_actions) 39 | self.critic_head = nn.Linear(64, 1) 40 | return common 41 | 42 | def forward(self, x, mask): 43 | # transfrom from one tensor of shape (9) into 3 tensors of shape (3,3) each 44 | empty = torch.zeros(x.size()).masked_scatter_((x == 0), torch.ones(x.size())).view(-1, 3, 3) 45 | player1 = torch.zeros(x.size()).masked_scatter_((x == 1), torch.ones(x.size())).view(-1, 3, 3) 46 | player2 = torch.zeros(x.size()).masked_scatter_((x == 2), torch.ones(x.size())).view(-1, 3, 3) 47 | cnn_input = torch.stack((empty, player1, player2), dim=1) 48 | 49 | # shared layers among actor and critic 50 | common = self.net(cnn_input) 51 | 52 | # actor layer 53 | actions = self.actor_head(common) 54 | if mask is not None: 55 | actions = self.mask_actions(actions, mask) 56 | action_dist = F.softmax(actions, dim=1) 57 | 58 | # critic layer 59 | value = self.critic_head(common) 60 | 61 | return action_dist, value 62 | 63 | if __name__ == "__main__": 64 | # flags 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument('-device', action="store", dest="device", type=str) 67 | args = parser.parse_args() 68 | 69 | # device 70 | if args.device == None: 71 | device = "cuda" if torch.cuda.is_available() else "cpu" 72 | else: 73 | if args.device not in ["cpu", "cuda"]: 74 | sys.exit('device must be "cpu" or "cuda"') 75 | device = args.device 76 | 77 | device = torch.device(device) 78 | print(f"using: {device}") 79 | 80 | # setup env and agent and start training 81 | env = gym.make('TicTacToeRandomPlayer-v0') 82 | agent = Agent( 83 | env=env, 84 | net=CNNNet, 85 | name="a2c-cnn", 86 | learning_rate=3e-5, 87 | optimizer=optim.Adam, 88 | discount=1, 89 | dev=device) 90 | agent.train() 91 | -------------------------------------------------------------------------------- /train/dqn.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import argparse 5 | import sys 6 | import torch 7 | import gym 8 | import env # module init needs to run 9 | from prop.algorithms.dqn import Agent 10 | from prop.net.feed_forward import FeedForward 11 | 12 | class FCNet(FeedForward): 13 | def __init__(self, obs_size, n_actions): 14 | # model is initiated in parent class, set params early. 
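# (prop's FeedForward.__init__ presumably calls self.model(), which uses
#  self.obs_size and self.n_actions, hence both are assigned before super().__init__().)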
15 | self.obs_size = obs_size 16 | self.n_actions = n_actions 17 | super(FCNet, self).__init__() 18 | 19 | def model(self): 20 | # observations -> hidden layer with relu activation -> actions 21 | return nn.Sequential( 22 | nn.Linear(self.obs_size, 64), 23 | nn.ReLU(), 24 | nn.Linear(64, 64), 25 | nn.ReLU(), 26 | nn.Linear(64, 64), 27 | nn.ReLU(), 28 | nn.Linear(64, 64), 29 | nn.ReLU(), 30 | nn.Linear(64, self.n_actions) 31 | ) 32 | 33 | class CNNNet(FeedForward): 34 | def __init__(self, obs_size, n_actions): 35 | # model is initiated in parent class, set params early. 36 | self.obs_size = obs_size 37 | self.n_actions = n_actions 38 | super(CNNNet, self).__init__() 39 | 40 | def model(self): 41 | return nn.Sequential( 42 | # convolution 1 43 | nn.Conv2d(in_channels=3, out_channels=512, kernel_size=3, stride=1, padding=2), # 3 channels -> `out_channels` different kernels/feature maps 44 | nn.ReLU(), # negative numbers -> 0 45 | nn.MaxPool2d(kernel_size=5, stride=1), # deformation invariance; subtle changes are captured 46 | # flatten 47 | nn.Flatten(), 48 | nn.Linear(512, 256), 49 | nn.ReLU(), 50 | nn.Linear(256, self.n_actions) 51 | ) 52 | 53 | def forward(self, x, mask=[]): 54 | # transfrom from one tensor of shape (9) into 3 tensors of shape (3,3) each 55 | empty = torch.zeros(x.size()).masked_scatter_((x == 0), torch.ones(x.size())).view(-1, 3, 3) 56 | player1 = torch.zeros(x.size()).masked_scatter_((x == 1), torch.ones(x.size())).view(-1, 3, 3) 57 | player2 = torch.zeros(x.size()).masked_scatter_((x == 2), torch.ones(x.size())).view(-1, 3, 3) 58 | cnn_input = torch.stack((empty, player1, player2), dim=1) 59 | return super(CNNNet, self).forward(cnn_input, mask) 60 | 61 | if __name__ == "__main__": 62 | # flags 63 | parser = argparse.ArgumentParser() 64 | parser.add_argument('-device', action="store", dest="device", type=str, default="cpu", choices=["cpu","cuda"]) 65 | args = parser.parse_args() 66 | 67 | if args.device == "cuda" and not torch.cuda.is_available(): 68 | print("cuda is not available") 69 | args.device = "cpu" 70 | 71 | device = torch.device(args.device) 72 | print(f"using device: {device}") 73 | 74 | env = gym.make('TicTacToeRandomPlayer-v0', 75 | thresholds=dict( 76 | win_rate=0.92, 77 | draw_rate=0.08 78 | )) 79 | env.spec.reward_threshold = env.performance_threshold 80 | print(f"performance threshold: {env.performance_threshold}") 81 | 82 | agent = Agent( 83 | env=env, 84 | net=CNNNet, 85 | name="dqn-cnn", 86 | learning_rate=1e-5, 87 | batch_size=128, 88 | optimizer=optim.Adam, 89 | loss_cutoff=0.02, 90 | max_std_dev=0.09, 91 | epsilon_decay=3000, 92 | double=True, 93 | target_net_update=500, 94 | eval_every=500, 95 | dev=device) 96 | agent.train() 97 | 98 | print(f"#### stats ####") 99 | print(f"games played: {env.stats['games_played']}") --------------------------------------------------------------------------------