├── logo.png ├── gui_example.png ├── requirements.txt ├── .gitignore ├── LICENSE ├── td_gammon ├── gnubg │ ├── bridge.py │ └── gnubg_backgammon.py ├── main.py ├── agents.py ├── model.py ├── utils.py └── web_gui │ ├── gui.py │ └── index.html └── README.md /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dellalibera/td-gammon/HEAD/logo.png -------------------------------------------------------------------------------- /gui_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dellalibera/td-gammon/HEAD/gui_example.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gym 2 | torch>=1.2.0 3 | torchvision>=0.4.0 4 | tb-nightly 5 | requests -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # https://github.com/github/gitignore/blob/master/Python.gitignore 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | 32 | # Jupyter Notebook 33 | .ipynb_checkpoints 34 | 35 | # IPython 36 | profile_default/ 37 | ipython_config.py 38 | 39 | # pyenv 40 | .python-version 41 | 42 | # Environments 43 | .env 44 | .venv 45 | env/ 46 | venv/ 47 | ENV/ 48 | env.bak/ 49 | venv.bak/ 50 | 51 | .idea/ 52 | 53 | # https://github.com/github/gitignore/blob/master/Global/macOS.gitignore 54 | # General 55 | .DS_Store 56 | .AppleDouble 57 | .LSOverride 58 | 59 | # Added 60 | *.tar 61 | saved_models/ 62 | runs/ 63 | _*/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Alessio Della Libera 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /td_gammon/gnubg/bridge.py: -------------------------------------------------------------------------------- 1 | # THIS FILE SHOULD BE RUN ON THE SAME MACHINE WHERE gnubg IS INSTALLED. 2 | # IT USES PYTHON 2.7 3 | 4 | import gnubg 5 | from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer 6 | # from http.server import BaseHTTPRequestHandler, HTTPServer 7 | import json 8 | 9 | try: 10 | from urllib.parse import urlparse, parse_qs 11 | except ImportError: 12 | from urlparse import urlparse, parse_qs 13 | 14 | 15 | class Handler(BaseHTTPRequestHandler): 16 | 17 | def _set_headers(self, response=200): 18 | self.send_response(response) 19 | self.send_header('Content-type', 'text/html') 20 | self.end_headers() 21 | 22 | def do_POST(self): 23 | response = {'board': [], 'last_move': [], 'info': []} 24 | post_data = self.rfile.read(int(self.headers['Content-Length'])).decode('utf-8') 25 | data = parse_qs(post_data) 26 | 27 | command = data['command'][0] 28 | print(command) 29 | 30 | prev_game = gnubg.match(0)['games'][-1]['game'] if gnubg.match(0) else [] 31 | 32 | gnubg.command(command) 33 | 34 | # check if the game is started/exists (handle the case the command executed is set at the beginning) 35 | if gnubg.match(0): 36 | # get the board after the execution of a move 37 | response['board'] = gnubg.board() 38 | 39 | # get the last games 40 | games = gnubg.match(0)['games'][-1] 41 | 42 | # get the last game 43 | game = games['game'][-1] 44 | 45 | # save the state of the game before and after having executed a command 46 | response['last_move'] = [prev_game, game] 47 | 48 | # save the info al all games played so far 49 | for idx, g in enumerate(gnubg.match(0)['games']): 50 | info = g['info'] 51 | 52 | response['info'].append( 53 | { 54 | 'winner': info['winner'], 55 | 'n_moves': len(g['game']), 56 | 'resigned': info['resigned'] if 'resigned' in info else None 57 | } 58 | ) 59 | 60 | self._set_headers() 61 | self.wfile.write(json.dumps(response)) 62 | 63 | def do_GET(self): 64 | parsed = urlparse(self.path) 65 | path = parsed.path 66 | 67 | if self.path: 68 | self._set_headers() 69 | self.wfile.write(bytes("Hello! 
Welcome to Backgammon WebGUI")) 70 | 71 | 72 | def run(host, server_class=HTTPServer, handler_class=Handler, port=8001): 73 | server_address = (host, port) 74 | httpd = server_class(server_address, handler_class) 75 | print('Starting httpd ({}:{})...'.format(host, port)) 76 | httpd.serve_forever() 77 | 78 | 79 | if __name__ == "__main__": 80 | HOST = 'localhost' # <-- YOUR HOST HERE 81 | PORT = 8001 # <-- YOUR PORT HERE 82 | run(host=HOST, port=PORT) 83 | -------------------------------------------------------------------------------- /td_gammon/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import utils 3 | 4 | 5 | def formatter(prog): 6 | return argparse.ArgumentDefaultsHelpFormatter(prog, max_help_position=100, width=180) 7 | 8 | 9 | if __name__ == '__main__': 10 | parser = argparse.ArgumentParser(description='TD-Gammon', formatter_class=lambda prog: formatter(prog)) 11 | subparsers = parser.add_subparsers(help='Train TD-Network | Evaluate Agent(s) | Web GUI | Plot Wins') 12 | 13 | parser_train = subparsers.add_parser('train', help='Train TD-Network', formatter_class=lambda prog: formatter(prog)) 14 | parser_train.add_argument('--save_path', help='Save directory location', type=str, default=None) 15 | parser_train.add_argument('--save_step', help='Save the model every n episodes/games', type=int, default=0) 16 | parser_train.add_argument('--episodes', help='Number of episodes/games', type=int, default=200000) 17 | parser_train.add_argument('--init_weights', help='Init Weights with zeros', action='store_true') 18 | parser_train.add_argument('--lr', help='Learning rate', type=float, default=1e-4) 19 | parser_train.add_argument('--hidden_units', help='Hidden units', type=int, default=40) 20 | parser_train.add_argument('--lamda', help='Credit assignment parameter', type=float, default=0.7) 21 | parser_train.add_argument('--model', help='Directory location to the model to be restored', type=str, default=None) 22 | parser_train.add_argument('--name', help='Name of the experiment', type=str, default='exp1') 23 | parser_train.add_argument('--type', help='Model type', choices=['cnn', 'nn'], type=str, default='nn') 24 | parser_train.add_argument('--seed', help='Seed used to reproduce results', type=int, default=123) 25 | 26 | parser_train.set_defaults(func=utils.args_train) 27 | 28 | parser_evaluate = subparsers.add_parser('evaluate', help='Evaluate Agent(s)', formatter_class=lambda prog: formatter(prog)) 29 | parser_evaluate.add_argument('--model_agent0', help='Saved model used by the agent0 (WHITE)', required=True, type=str) 30 | parser_evaluate.add_argument('--model_agent1', help='Saved model used by the agent1 (BLACK)', required=False, type=str) 31 | parser_evaluate.add_argument('--type', help='Model type used by the agents', choices=['cnn', 'nn'], type=str, default='nn') 32 | parser_evaluate.add_argument('--hidden_units_agent0', help='Hidden Units of the model used by the agent0 (WHITE)', required=False, type=int, default=40) 33 | parser_evaluate.add_argument('--hidden_units_agent1', help='Hidden Units of the model used by the agent1 (BLACK)', required=False, type=int, default=40) 34 | parser_evaluate.add_argument('--episodes', help='Number of episodes/games', default=20, required=False, type=int) 35 | 36 | subparsers_gnubg = parser_evaluate.add_subparsers(help='Parameters for gnubg interface') 37 | parser_gnubg = subparsers_gnubg.add_parser('vs_gnubg', help='Evaluate agent0 against gnubg', formatter_class=lambda prog: 
formatter(prog)) 38 | parser_gnubg.add_argument('--host', help='Host running gnubg', type=str, required=True) 39 | parser_gnubg.add_argument('--port', help='Port listening for gnubg commands', type=int, required=True) 40 | parser_gnubg.add_argument('--difficulty', help='Difficulty level', choices=['beginner', 'intermediate', 'advanced', 'world_class'], type=str, required=False, default='beginner') 41 | 42 | parser_gnubg.set_defaults(func=utils.args_gnubg) 43 | parser_evaluate.set_defaults(func=utils.args_evaluate) 44 | 45 | parser_gui = subparsers.add_parser('gui', help='Start Web GUI', formatter_class=lambda prog: formatter(prog)) 46 | parser_gui.add_argument('--host', help='Host running the web gui', default='localhost') 47 | parser_gui.add_argument('--port', help='Port listening for command', default=8002, type=int) 48 | parser_gui.add_argument('--model', help='Model used by the AI opponent', required=True, type=str) 49 | parser_gui.add_argument('--hidden_units', help='Hidden units of the model loaded', required=False, type=int, default=40) 50 | parser_gui.add_argument('--type', help='Model type', choices=['cnn', 'nn'], type=str, default='nn') 51 | 52 | parser_gui.set_defaults(func=utils.args_gui) 53 | 54 | parser_plot = subparsers.add_parser('plot', help='Plot the performance (wins)', formatter_class=lambda prog: formatter(prog)) 55 | parser_plot.add_argument('--save_path', help='Directory where the model are saved', type=str, required=True) 56 | parser_plot.add_argument('--hidden_units', help='Hidden units of the model(s) loaded', type=int, default=40) 57 | parser_plot.add_argument('--episodes', help='Number of episodes/games against a single opponent', default=20, type=int) 58 | parser_plot.add_argument('--opponent', help='Opponent(s) agent(s) (delimited by comma) - "random" and/or "gnubg"', default='random', type=str) 59 | parser_plot.add_argument('--host', help='Host running gnubg (if gnubg in --opponent)', type=str) 60 | parser_plot.add_argument('--port', help='Port listening for gnubg commands (if gnubg in --opponent)', type=int) 61 | parser_plot.add_argument('--difficulty', help='Difficulty level(s) (delimited by comma)', type=str, default="beginner,intermediate,advanced,world_class") 62 | parser_plot.add_argument('--dst', help='Save directory location', type=str, default='myexp') 63 | parser_plot.add_argument('--type', help='Model type', choices=['cnn', 'nn'], type=str, default='nn') 64 | 65 | parser_plot.set_defaults(func=lambda args: utils.args_plot(args, parser)) 66 | 67 | args = parser.parse_args() 68 | args.func(args) 69 | -------------------------------------------------------------------------------- /td_gammon/agents.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | from itertools import count 4 | from random import randint, choice 5 | 6 | import numpy as np 7 | from gym_backgammon.envs.backgammon import WHITE, BLACK, COLORS 8 | 9 | random.seed(0) 10 | 11 | 12 | # AGENT ============================================================================================ 13 | 14 | 15 | class Agent: 16 | def __init__(self, color): 17 | self.color = color 18 | self.name = 'Agent({})'.format(COLORS[color]) 19 | 20 | def roll_dice(self): 21 | return (-randint(1, 6), -randint(1, 6)) if self.color == WHITE else (randint(1, 6), randint(1, 6)) 22 | 23 | def choose_best_action(self, actions, env): 24 | raise NotImplementedError 25 | 26 | 27 | # RANDOM AGENT 
======================================================================================= 28 | 29 | 30 | class RandomAgent(Agent): 31 | def __init__(self, color): 32 | super().__init__(color) 33 | self.name = 'RandomAgent({})'.format(COLORS[color]) 34 | 35 | def choose_best_action(self, actions, env): 36 | return choice(list(actions)) if actions else None 37 | 38 | 39 | # HUMAN AGENT ======================================================================================= 40 | 41 | 42 | class HumanAgent(Agent): 43 | def __init__(self, color): 44 | super().__init__(color) 45 | self.name = 'HumanAgent({})'.format(COLORS[color]) 46 | 47 | def choose_best_action(self, actions=None, env=None): 48 | pass 49 | 50 | 51 | # TD-GAMMON AGENT ===================================================================================== 52 | 53 | 54 | class TDAgent(Agent): 55 | def __init__(self, color, net): 56 | super().__init__(color) 57 | self.net = net 58 | self.name = 'TDAgent({})'.format(COLORS[color]) 59 | 60 | def choose_best_action(self, actions, env): 61 | best_action = None 62 | 63 | if actions: 64 | values = [0.0] * len(actions) 65 | tmp_counter = env.counter 66 | env.counter = 0 67 | state = env.game.save_state() 68 | 69 | # Iterate over all the legal moves and pick the best action 70 | for i, action in enumerate(actions): 71 | observation, reward, done, info = env.step(action) 72 | values[i] = self.net(observation) 73 | 74 | # restore the board and other variables (undo the action) 75 | env.game.restore_state(state) 76 | 77 | # practical-issues-in-temporal-difference-learning, pag.3 78 | # ... the network's output P_t is an estimate of White's probability of winning from board position x_t. 79 | # ... the move which is selected at each time step is the move which maximizes P_t when White is to play and minimizes P_t when Black is to play. 
80 | best_action_index = int(np.argmax(values)) if self.color == WHITE else int(np.argmin(values)) 81 | best_action = list(actions)[best_action_index] 82 | env.counter = tmp_counter 83 | 84 | return best_action 85 | 86 | 87 | # TD-GAMMON AGENT (play against gnubg) ================================================================ 88 | 89 | 90 | class TDAgentGNU(TDAgent): 91 | 92 | def __init__(self, color, net, gnubg_interface): 93 | super().__init__(color, net) 94 | self.gnubg_interface = gnubg_interface 95 | 96 | def roll_dice(self): 97 | gnubg = self.gnubg_interface.send_command("roll") 98 | return self.handle_opponent_move(gnubg) 99 | 100 | def choose_best_action(self, actions, env): 101 | best_action = None 102 | 103 | if actions: 104 | game = env.game 105 | values = [0.0] * len(actions) 106 | state = game.save_state() 107 | 108 | for i, action in enumerate(actions): 109 | game.execute_play(self.color, action) 110 | opponent = game.get_opponent(self.color) 111 | observation = game.get_board_features(opponent) if env.model_type == 'nn' else env.render(mode='state_pixels') 112 | values[i] = self.net(observation) 113 | game.restore_state(state) 114 | 115 | best_action_index = int(np.argmax(values)) if self.color == WHITE else int(np.argmin(values)) 116 | best_action = list(actions)[best_action_index] 117 | 118 | return best_action 119 | 120 | def handle_opponent_move(self, gnubg): 121 | # Once I roll the dice, 2 possible situations can happen: 122 | # 1) I can move (the value gnubg.roll is not None) 123 | # 2) I cannot move, so my opponent rolls the dice and makes its move, and eventually ask for doubling, so I have to roll the dice again 124 | 125 | # One way to distinguish between the above cases, is to check the color of the player that performs the last move in gnubg: 126 | # - if the player's color is the same as the TD Agent, it means I can send the 'move' command (no other moves have been performed after the 'roll' command) - case 1); 127 | # - if the player's color is not the same as the TD Agent, this means that the last move performed after the 'roll' is not of the TD agent - case 2) 128 | previous_agent = gnubg.agent 129 | if previous_agent == self.color: # case 1) 130 | return gnubg 131 | else: # case 2) 132 | while previous_agent != self.color and gnubg.winner is None: 133 | # check if my opponent asks for doubling 134 | if gnubg.double: 135 | # default action if the opponent asks for doubling is 'take' 136 | gnubg = self.gnubg_interface.send_command("take") 137 | else: 138 | gnubg = self.gnubg_interface.send_command("roll") 139 | previous_agent = gnubg.agent 140 | return gnubg 141 | 142 | 143 | def evaluate_agents(agents, env, n_episodes): 144 | wins = {WHITE: 0, BLACK: 0} 145 | 146 | for episode in range(n_episodes): 147 | 148 | agent_color, first_roll, observation = env.reset() 149 | agent = agents[agent_color] 150 | 151 | t = time.time() 152 | 153 | for i in count(): 154 | 155 | if first_roll: 156 | roll = first_roll 157 | first_roll = None 158 | else: 159 | roll = agent.roll_dice() 160 | 161 | actions = env.get_valid_actions(roll) 162 | action = agent.choose_best_action(actions, env) 163 | observation_next, reward, done, winner = env.step(action) 164 | 165 | if done: 166 | if winner is not None: 167 | wins[agent.color] += 1 168 | tot = wins[WHITE] + wins[BLACK] 169 | tot = tot if tot > 0 else 1 170 | 171 | print("EVAL => Game={:<6d} | Winner={} | after {:<4} plays || Wins: {}={:<6}({:<5.1f}%) | {}={:<6}({:<5.1f}%) | Duration={:<.3f} sec".format(episode + 1, winner, i, 172 
| agents[WHITE].name, wins[WHITE], (wins[WHITE] / tot) * 100, 173 | agents[BLACK].name, wins[BLACK], (wins[BLACK] / tot) * 100, time.time() - t)) 174 | break 175 | 176 | agent_color = env.get_opponent_agent() 177 | agent = agents[agent_color] 178 | 179 | observation = observation_next 180 | return wins 181 | -------------------------------------------------------------------------------- /td_gammon/model.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import random 3 | import time 4 | from itertools import count 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | 10 | from agents import TDAgent, RandomAgent, evaluate_agents 11 | from gym_backgammon.envs.backgammon import WHITE, BLACK 12 | 13 | torch.set_default_tensor_type('torch.DoubleTensor') 14 | 15 | 16 | class BaseModel(nn.Module): 17 | def __init__(self, lr, lamda, seed=123): 18 | super(BaseModel, self).__init__() 19 | self.lr = lr 20 | self.lamda = lamda # trace-decay parameter 21 | self.start_episode = 0 22 | 23 | self.eligibility_traces = None 24 | self.optimizer = None 25 | 26 | torch.manual_seed(seed) 27 | random.seed(seed) 28 | 29 | def update_weights(self, p, p_next): 30 | raise NotImplementedError 31 | 32 | def forward(self, x): 33 | raise NotImplementedError 34 | 35 | def init_weights(self): 36 | raise NotImplementedError 37 | 38 | def init_eligibility_traces(self): 39 | self.eligibility_traces = [torch.zeros(weights.shape, requires_grad=False) for weights in list(self.parameters())] 40 | 41 | def checkpoint(self, checkpoint_path, step, name_experiment): 42 | path = checkpoint_path + "/{}_{}_{}.tar".format(name_experiment, datetime.datetime.now().strftime('%Y%m%d_%H%M_%S_%f'), step + 1) 43 | torch.save({'step': step + 1, 'model_state_dict': self.state_dict(), 'eligibility': self.eligibility_traces if self.eligibility_traces else []}, path) 44 | print("\nCheckpoint saved: {}".format(path)) 45 | 46 | def load(self, checkpoint_path, optimizer=None, eligibility_traces=None): 47 | checkpoint = torch.load(checkpoint_path) 48 | self.start_episode = checkpoint['step'] 49 | 50 | self.load_state_dict(checkpoint['model_state_dict']) 51 | 52 | if eligibility_traces is not None: 53 | self.eligibility_traces = checkpoint['eligibility'] 54 | 55 | if optimizer is not None: 56 | self.optimizer.load_state_dict(checkpoint['optimizer']) 57 | 58 | def train_agent(self, env, n_episodes, save_path=None, eligibility=False, save_step=0, name_experiment=''): 59 | start_episode = self.start_episode 60 | n_episodes += start_episode 61 | 62 | wins = {WHITE: 0, BLACK: 0} 63 | network = self 64 | 65 | agents = {WHITE: TDAgent(WHITE, net=network), BLACK: TDAgent(BLACK, net=network)} 66 | 67 | durations = [] 68 | steps = 0 69 | start_training = time.time() 70 | 71 | for episode in range(start_episode, n_episodes): 72 | 73 | if eligibility: 74 | self.init_eligibility_traces() 75 | 76 | agent_color, first_roll, observation = env.reset() 77 | agent = agents[agent_color] 78 | 79 | t = time.time() 80 | 81 | for i in count(): 82 | if first_roll: 83 | roll = first_roll 84 | first_roll = None 85 | else: 86 | roll = agent.roll_dice() 87 | 88 | p = self(observation) 89 | 90 | actions = env.get_valid_actions(roll) 91 | action = agent.choose_best_action(actions, env) 92 | observation_next, reward, done, winner = env.step(action) 93 | p_next = self(observation_next) 94 | 95 | if done: 96 | if winner is not None: 97 | loss = self.update_weights(p, reward) 98 | 99 | wins[agent.color] += 1 
100 | 101 | tot = sum(wins.values()) 102 | tot = tot if tot > 0 else 1 103 | 104 | print("Game={:<6d} | Winner={} | after {:<4} plays || Wins: {}={:<6}({:<5.1f}%) | {}={:<6}({:<5.1f}%) | Duration={:<.3f} sec".format(episode + 1, winner, i, 105 | agents[WHITE].name, wins[WHITE], (wins[WHITE] / tot) * 100, 106 | agents[BLACK].name, wins[BLACK], (wins[BLACK] / tot) * 100, time.time() - t)) 107 | 108 | durations.append(time.time() - t) 109 | steps += i 110 | break 111 | else: 112 | loss = self.update_weights(p, p_next) 113 | 114 | agent_color = env.get_opponent_agent() 115 | agent = agents[agent_color] 116 | 117 | observation = observation_next 118 | 119 | if save_path and save_step > 0 and episode > 0 and (episode + 1) % save_step == 0: 120 | self.checkpoint(checkpoint_path=save_path, step=episode, name_experiment=name_experiment) 121 | agents_to_evaluate = {WHITE: TDAgent(WHITE, net=network), BLACK: RandomAgent(BLACK)} 122 | evaluate_agents(agents_to_evaluate, env, n_episodes=20) 123 | print() 124 | 125 | print("\nAverage duration per game: {} seconds".format(round(sum(durations) / n_episodes, 3))) 126 | print("Average game length: {} plays | Total Duration: {}".format(round(steps / n_episodes, 2), datetime.timedelta(seconds=int(time.time() - start_training)))) 127 | 128 | if save_path: 129 | self.checkpoint(checkpoint_path=save_path, step=n_episodes - 1, name_experiment=name_experiment) 130 | 131 | with open('{}/comments.txt'.format(save_path), 'a') as file: 132 | file.write("Average duration per game: {} seconds".format(round(sum(durations) / n_episodes, 3))) 133 | file.write("\nAverage game length: {} plays | Total Duration: {}".format(round(steps / n_episodes, 2), datetime.timedelta(seconds=int(time.time() - start_training)))) 134 | 135 | env.close() 136 | 137 | 138 | class TDGammonCNN(BaseModel): 139 | def __init__(self, lr, seed=123, output_units=1): 140 | super(TDGammonCNN, self).__init__(lr, seed=seed, lamda=0.7) 141 | 142 | self.loss_fn = torch.nn.MSELoss(reduction='sum') 143 | 144 | self.conv1 = nn.Sequential( 145 | nn.Conv2d(in_channels=1, out_channels=32, kernel_size=8, stride=4), # CHANNEL it was 3 146 | nn.BatchNorm2d(32), 147 | nn.ReLU() 148 | ) 149 | 150 | self.conv2 = nn.Sequential( 151 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2), 152 | nn.BatchNorm2d(64), 153 | nn.ReLU() 154 | ) 155 | 156 | self.conv3 = nn.Sequential( 157 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1), 158 | nn.BatchNorm2d(64), 159 | nn.ReLU() 160 | ) 161 | 162 | self.hidden = nn.Sequential( 163 | nn.Linear(64 * 8 * 8, 80), 164 | nn.Sigmoid() 165 | ) 166 | 167 | self.output = nn.Sequential( 168 | nn.Linear(80, output_units), 169 | nn.Sigmoid() 170 | ) 171 | 172 | self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) 173 | 174 | def init_weights(self): 175 | pass 176 | 177 | def forward(self, x): 178 | # https://stackoverflow.com/questions/12201577/how-can-i-convert-an-rgb-image-into-grayscale-in-python 179 | x = np.dot(x[..., :3], [0.2989, 0.5870, 0.1140]) 180 | x = x[np.newaxis, :] 181 | x = torch.from_numpy(np.array(x)) 182 | x = x.unsqueeze(0) 183 | x = x.type(torch.DoubleTensor) 184 | 185 | x = self.conv1(x) 186 | x = self.conv2(x) 187 | x = self.conv3(x) 188 | x = x.view(-1, 64 * 8 * 8) 189 | x = x.reshape(-1) 190 | x = self.hidden(x) 191 | x = self.output(x) 192 | return x 193 | 194 | def update_weights(self, p, p_next): 195 | 196 | if isinstance(p_next, int): 197 | p_next = torch.tensor([p_next], dtype=torch.float64) 198 | 199 | loss = 
self.loss_fn(p_next, p) 200 | self.optimizer.zero_grad() 201 | loss.backward() 202 | self.optimizer.step() 203 | return loss 204 | 205 | 206 | class TDGammon(BaseModel): 207 | def __init__(self, hidden_units, lr, lamda, init_weights, seed=123, input_units=198, output_units=1): 208 | super(TDGammon, self).__init__(lr, lamda, seed=seed) 209 | 210 | self.hidden = nn.Sequential( 211 | nn.Linear(input_units, hidden_units), 212 | nn.Sigmoid() 213 | ) 214 | 215 | # self.hidden2 = nn.Sequential( 216 | # nn.Linear(hidden_units, hidden_units), 217 | # nn.Sigmoid() 218 | # ) 219 | 220 | # self.hidden3 = nn.Sequential( 221 | # nn.Linear(hidden_units, hidden_units), 222 | # nn.Sigmoid() 223 | # ) 224 | 225 | self.output = nn.Sequential( 226 | nn.Linear(hidden_units, output_units), 227 | nn.Sigmoid() 228 | ) 229 | 230 | if init_weights: 231 | self.init_weights() 232 | 233 | def init_weights(self): 234 | for p in self.parameters(): 235 | nn.init.zeros_(p) 236 | 237 | def forward(self, x): 238 | x = torch.from_numpy(np.array(x)) 239 | x = self.hidden(x) 240 | # x = self.hidden2(x) 241 | # x = self.hidden3(x) 242 | x = self.output(x) 243 | return x 244 | 245 | def update_weights(self, p, p_next): 246 | # reset the gradients 247 | self.zero_grad() 248 | 249 | # compute the derivative of p w.r.t. the parameters 250 | p.backward() 251 | 252 | with torch.no_grad(): 253 | 254 | td_error = p_next - p 255 | 256 | # get the parameters of the model 257 | parameters = list(self.parameters()) 258 | 259 | for i, weights in enumerate(parameters): 260 | 261 | # z <- gamma * lambda * z + (grad w w.r.t P_t) 262 | self.eligibility_traces[i] = self.lamda * self.eligibility_traces[i] + weights.grad 263 | 264 | # w <- w + alpha * td_error * z 265 | new_weights = weights + self.lr * td_error * self.eligibility_traces[i] 266 | weights.copy_(new_weights) 267 | 268 | return td_error 269 | -------------------------------------------------------------------------------- /td_gammon/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gym 3 | import sys 4 | from agents import TDAgent, HumanAgent, TDAgentGNU, RandomAgent, evaluate_agents 5 | from gnubg.gnubg_backgammon import GnubgInterface, GnubgEnv, evaluate_vs_gnubg 6 | from gym_backgammon.envs.backgammon import WHITE, BLACK 7 | from model import TDGammon, TDGammonCNN 8 | from web_gui.gui import GUI 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | # tensorboard --logdir=runs/ --host localhost --port 8001 12 | 13 | 14 | def write_file(path, **kwargs): 15 | with open('{}/parameters.txt'.format(path), 'w+') as file: 16 | print("Parameters:") 17 | for key, value in kwargs.items(): 18 | file.write("{}={}\n".format(key, value)) 19 | print("{}={}".format(key, value)) 20 | print() 21 | 22 | 23 | def path_exists(path): 24 | if os.path.exists(path): 25 | return True 26 | else: 27 | print("The path {} doesn't exists".format(path)) 28 | sys.exit() 29 | 30 | 31 | 32 | # ==================================== TRAINING PARAMETERS =================================== 33 | def args_train(args): 34 | save_step = args.save_step 35 | save_path = None 36 | n_episodes = args.episodes 37 | init_weights = args.init_weights 38 | lr = args.lr 39 | hidden_units = args.hidden_units 40 | lamda = args.lamda 41 | name = args.name 42 | model_type = args.type 43 | seed = args.seed 44 | 45 | eligibility = False 46 | optimizer = None 47 | 48 | if model_type == 'nn': 49 | net = TDGammon(hidden_units=hidden_units, lr=lr, lamda=lamda, 
init_weights=init_weights, seed=seed) 50 | eligibility = True 51 | env = gym.make('gym_backgammon:backgammon-v0') 52 | 53 | else: 54 | net = TDGammonCNN(lr=lr, seed=seed) 55 | optimizer = True 56 | env = gym.make('gym_backgammon:backgammon-pixel-v0') 57 | 58 | if args.model and path_exists(args.model): 59 | # assert os.path.exists(args.model), print("The path {} doesn't exists".format(args.model)) 60 | net.load(checkpoint_path=args.model, optimizer=optimizer, eligibility_traces=eligibility) 61 | 62 | if args.save_path and path_exists(args.save_path): 63 | # assert os.path.exists(args.save_path), print("The path {} doesn't exists".format(args.save_path)) 64 | save_path = args.save_path 65 | 66 | write_file( 67 | save_path, save_path=args.save_path, command_line_args=args, type=model_type, hidden_units=hidden_units, init_weights=init_weights, alpha=net.lr, lamda=net.lamda, 68 | n_episodes=n_episodes, save_step=save_step, start_episode=net.start_episode, name_experiment=name, env=env.spec.id, restored_model=args.model, seed=seed, 69 | eligibility=eligibility, optimizer=optimizer, modules=[module for module in net.modules()] 70 | ) 71 | 72 | net.train_agent(env=env, n_episodes=n_episodes, save_path=save_path, save_step=save_step, eligibility=eligibility, name_experiment=name) 73 | 74 | 75 | # ==================================== WEB GUI PARAMETERS ==================================== 76 | def args_gui(args): 77 | if path_exists(args.model): 78 | # assert os.path.exists(args.model), print("The path {} doesn't exists".format(args.model)) 79 | 80 | if args.type == 'nn': 81 | net = TDGammon(hidden_units=args.hidden_units, lr=0.1, lamda=None, init_weights=False) 82 | env = gym.make('gym_backgammon:backgammon-v0') 83 | else: 84 | net = TDGammonCNN(lr=0.0001) 85 | env = gym.make('gym_backgammon:backgammon-pixel-v0') 86 | 87 | net.load(checkpoint_path=args.model, optimizer=None, eligibility_traces=False) 88 | 89 | agents = {BLACK: TDAgent(BLACK, net=net), WHITE: HumanAgent(WHITE)} 90 | gui = GUI(env=env, host=args.host, port=args.port, agents=agents) 91 | gui.run() 92 | 93 | 94 | # =================================== EVALUATE PARAMETERS ==================================== 95 | def args_evaluate(args): 96 | model_agent0 = args.model_agent0 97 | model_agent1 = args.model_agent1 98 | model_type = args.type 99 | hidden_units_agent0 = args.hidden_units_agent0 100 | hidden_units_agent1 = args.hidden_units_agent1 101 | n_episodes = args.episodes 102 | 103 | if path_exists(model_agent0) and path_exists(model_agent1): 104 | # assert os.path.exists(model_agent0), print("The path {} doesn't exists".format(model_agent0)) 105 | # assert os.path.exists(model_agent1), print("The path {} doesn't exists".format(model_agent1)) 106 | 107 | if model_type == 'nn': 108 | net0 = TDGammon(hidden_units=hidden_units_agent0, lr=0.1, lamda=None, init_weights=False) 109 | net1 = TDGammon(hidden_units=hidden_units_agent1, lr=0.1, lamda=None, init_weights=False) 110 | env = gym.make('gym_backgammon:backgammon-v0') 111 | else: 112 | net0 = TDGammonCNN(lr=0.0001) 113 | net1 = TDGammonCNN(lr=0.0001) 114 | env = gym.make('gym_backgammon:backgammon-pixel-v0') 115 | 116 | net0.load(checkpoint_path=model_agent0, optimizer=None, eligibility_traces=False) 117 | net1.load(checkpoint_path=model_agent1, optimizer=None, eligibility_traces=False) 118 | 119 | agents = {WHITE: TDAgent(WHITE, net=net1), BLACK: TDAgent(BLACK, net=net0)} 120 | 121 | evaluate_agents(agents, env, n_episodes) 122 | 123 | 124 | # ===================================== 
GNUBG PARAMETERS ===================================== 125 | def args_gnubg(args): 126 | model_agent0 = args.model_agent0 127 | model_type = args.type 128 | hidden_units_agent0 = args.hidden_units_agent0 129 | n_episodes = args.episodes 130 | host = args.host 131 | port = args.port 132 | difficulty = args.difficulty 133 | 134 | if path_exists(model_agent0): 135 | # assert os.path.exists(model_agent0), print("The path {} doesn't exists".format(model_agent0)) 136 | if model_type == 'nn': 137 | net0 = TDGammon(hidden_units=hidden_units_agent0, lr=0.1, lamda=None, init_weights=False) 138 | else: 139 | net0 = TDGammonCNN(lr=0.0001) 140 | 141 | net0.load(checkpoint_path=model_agent0, optimizer=None, eligibility_traces=False) 142 | 143 | gnubg_interface = GnubgInterface(host=host, port=port) 144 | gnubg_env = GnubgEnv(gnubg_interface, difficulty=difficulty, model_type=model_type) 145 | evaluate_vs_gnubg(agent=TDAgentGNU(WHITE, net=net0, gnubg_interface=gnubg_interface), env=gnubg_env, n_episodes=n_episodes) 146 | 147 | 148 | # ===================================== PLOT PARAMETERS ====================================== 149 | def args_plot(args, parser): 150 | ''' 151 | This method is used to plot the number of time an agent wins when it plays against an opponent. 152 | Instead of evaluating the agent during training (it can require some time and slow down the training), I decided to plot the wins separately, loading the different 153 | model saved during training. 154 | For example, suppose I run the training for 100 games and save my model every 10 games. 155 | Later I will load these 10 models, and for each of them, I will compute how many times the agent would win against an opponent. 156 | :return: None 157 | ''' 158 | 159 | src = args.save_path 160 | hidden_units = args.hidden_units 161 | n_episodes = args.episodes 162 | opponents = args.opponent.split(',') 163 | host = args.host 164 | port = args.port 165 | difficulties = args.difficulty.split(',') 166 | model_type = args.type 167 | 168 | if path_exists(src): 169 | # assert os.path.exists(src), print("The path {} doesn't exists".format(src)) 170 | 171 | for d in difficulties: 172 | if d not in ['beginner', 'intermediate', 'advanced', 'world_class']: 173 | parser.error("--difficulty should be (one or more of) 'beginner','intermediate', 'advanced' ,'world_class'") 174 | 175 | dst = args.dst 176 | 177 | if 'gnubg' in opponents and (not host or not port): 178 | parser.error("--host and --port are required when 'gnubg' is specified in --opponent") 179 | 180 | for root, dirs, files in os.walk(src): 181 | global_step = 0 182 | files = sorted(files) 183 | 184 | writer = SummaryWriter(dst) 185 | 186 | for file in files: 187 | if ".tar" in file: 188 | print("\nLoad {}".format(os.path.join(root, file))) 189 | 190 | if model_type == 'nn': 191 | net = TDGammon(hidden_units=hidden_units, lr=0.1, lamda=None, init_weights=False) 192 | env = gym.make('gym_backgammon:backgammon-v0') 193 | else: 194 | net = TDGammonCNN(lr=0.0001) 195 | env = gym.make('gym_backgammon:backgammon-pixel-v0') 196 | 197 | net.load(checkpoint_path=os.path.join(root, file), optimizer=None, eligibility_traces=False) 198 | 199 | if 'gnubg' in opponents: 200 | tag_scalar_dict = {} 201 | 202 | gnubg_interface = GnubgInterface(host=host, port=port) 203 | 204 | for difficulty in difficulties: 205 | gnubg_env = GnubgEnv(gnubg_interface, difficulty=difficulty, model_type=model_type) 206 | wins = evaluate_vs_gnubg(agent=TDAgentGNU(WHITE, net=net, gnubg_interface=gnubg_interface), 
env=gnubg_env, n_episodes=n_episodes) 207 | tag_scalar_dict[difficulty] = wins[WHITE] 208 | 209 | writer.add_scalars('wins_vs_gnubg/', tag_scalar_dict, global_step) 210 | 211 | with open(root + '/results.txt', 'a') as f: 212 | print("{};".format(file) + str(tag_scalar_dict), file=f) 213 | 214 | if 'random' in opponents: 215 | tag_scalar_dict = {} 216 | agents = {WHITE: TDAgent(WHITE, net=net), BLACK: RandomAgent(BLACK)} 217 | wins = evaluate_agents(agents, env, n_episodes) 218 | 219 | tag_scalar_dict['random'] = wins[WHITE] 220 | 221 | writer.add_scalars('wins_vs_random/', tag_scalar_dict, global_step) 222 | 223 | global_step += 1 224 | 225 | writer.close() 226 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |
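`utils.args_evaluate` and `utils.args_plot` reload saved checkpoints and measure win rates with `evaluate_agents`. A sketch of the same evaluation against the random baseline follows; the checkpoint path is a hypothetical example of the files produced by `BaseModel.checkpoint()`.

```python
# Evaluation sketch: trained TD agent (WHITE) vs. random agent (BLACK).
import gym
from agents import TDAgent, RandomAgent, evaluate_agents
from model import TDGammon
from gym_backgammon.envs.backgammon import WHITE, BLACK

net = TDGammon(hidden_units=40, lr=0.1, lamda=None, init_weights=False)
# Hypothetical checkpoint file saved during training.
net.load(checkpoint_path='saved_models/exp1_20190101_1200_00_000000_100000.tar',
         optimizer=None, eligibility_traces=False)

env = gym.make('gym_backgammon:backgammon-v0')
agents = {WHITE: TDAgent(WHITE, net=net), BLACK: RandomAgent(BLACK)}
wins = evaluate_agents(agents, env, n_episodes=20)  # returns {WHITE: n, BLACK: n}
```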
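For reference, `TDGammon.update_weights` implements the TD(λ) rule its inline comments describe: the eligibility traces accumulate the gradient of the current prediction, and each weight moves along its trace in proportion to the temporal-difference error (on the terminal step, `p_next` is replaced by the game's reward):

```latex
e_t = \lambda\, e_{t-1} + \nabla_{w} P_t, \qquad
w \leftarrow w + \alpha\,(P_{t+1} - P_t)\, e_t
```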
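The bridge server in `td_gammon/gnubg/bridge.py` accepts a form-encoded POST whose `command` field is handed verbatim to `gnubg.command(...)`, and replies with a JSON document containing `board`, `last_move` and `info`; the project's own client is `GnubgInterface` in `gnubg_backgammon.py`. A minimal hand-rolled client sketch, with an illustrative host, port and command string:

```python
# Minimal client for bridge.py (run bridge.py under gnubg's Python on the gnubg machine first).
import requests

# parse_qs on the server side expects ordinary form-encoded data.
resp = requests.post('http://localhost:8001', data={'command': 'new game'})
state = resp.json()  # {'board': [...], 'last_move': [...], 'info': [...]}
print(state['board'])
```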