├── logo.png ├── gui_example.png ├── requirements.txt ├── .gitignore ├── LICENSE ├── td_gammon ├── gnubg │ ├── bridge.py │ └── gnubg_backgammon.py ├── main.py ├── agents.py ├── model.py ├── utils.py └── web_gui │ ├── gui.py │ └── index.html └── README.md /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dellalibera/td-gammon/HEAD/logo.png -------------------------------------------------------------------------------- /gui_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dellalibera/td-gammon/HEAD/gui_example.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gym 2 | torch>=1.2.0 3 | torchvision>=0.4.0 4 | tb-nightly 5 | requests -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # https://github.com/github/gitignore/blob/master/Python.gitignore 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | 32 | # Jupyter Notebook 33 | .ipynb_checkpoints 34 | 35 | # IPython 36 | profile_default/ 37 | ipython_config.py 38 | 39 | # pyenv 40 | .python-version 41 | 42 | # Environments 43 | .env 44 | .venv 45 | env/ 46 | venv/ 47 | ENV/ 48 | env.bak/ 49 | venv.bak/ 50 | 51 | .idea/ 52 | 53 | # https://github.com/github/gitignore/blob/master/Global/macOS.gitignore 54 | # General 55 | .DS_Store 56 | .AppleDouble 57 | .LSOverride 58 | 59 | # Added 60 | *.tar 61 | saved_models/ 62 | runs/ 63 | _*/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Alessio Della Libera 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /td_gammon/gnubg/bridge.py: -------------------------------------------------------------------------------- 1 | # THIS FILE SHOULD BE RUN ON THE SAME MACHINE WHERE gnubg IS INSTALLED. 2 | # IT USES PYTHON 2.7 3 | 4 | import gnubg 5 | from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer 6 | # from http.server import BaseHTTPRequestHandler, HTTPServer 7 | import json 8 | 9 | try: 10 | from urllib.parse import urlparse, parse_qs 11 | except ImportError: 12 | from urlparse import urlparse, parse_qs 13 | 14 | 15 | class Handler(BaseHTTPRequestHandler): 16 | 17 | def _set_headers(self, response=200): 18 | self.send_response(response) 19 | self.send_header('Content-type', 'text/html') 20 | self.end_headers() 21 | 22 | def do_POST(self): 23 | response = {'board': [], 'last_move': [], 'info': []} 24 | post_data = self.rfile.read(int(self.headers['Content-Length'])).decode('utf-8') 25 | data = parse_qs(post_data) 26 | 27 | command = data['command'][0] 28 | print(command) 29 | 30 | prev_game = gnubg.match(0)['games'][-1]['game'] if gnubg.match(0) else [] 31 | 32 | gnubg.command(command) 33 | 34 | # check if the game is started/exists (handle the case the command executed is set at the beginning) 35 | if gnubg.match(0): 36 | # get the board after the execution of a move 37 | response['board'] = gnubg.board() 38 | 39 | # get the last games 40 | games = gnubg.match(0)['games'][-1] 41 | 42 | # get the last game 43 | game = games['game'][-1] 44 | 45 | # save the state of the game before and after having executed a command 46 | response['last_move'] = [prev_game, game] 47 | 48 | # save the info al all games played so far 49 | for idx, g in enumerate(gnubg.match(0)['games']): 50 | info = g['info'] 51 | 52 | response['info'].append( 53 | { 54 | 'winner': info['winner'], 55 | 'n_moves': len(g['game']), 56 | 'resigned': info['resigned'] if 'resigned' in info else None 57 | } 58 | ) 59 | 60 | self._set_headers() 61 | self.wfile.write(json.dumps(response)) 62 | 63 | def do_GET(self): 64 | parsed = urlparse(self.path) 65 | path = parsed.path 66 | 67 | if self.path: 68 | self._set_headers() 69 | self.wfile.write(bytes("Hello! 
Welcome to Backgammon WebGUI")) 70 | 71 | 72 | def run(host, server_class=HTTPServer, handler_class=Handler, port=8001): 73 | server_address = (host, port) 74 | httpd = server_class(server_address, handler_class) 75 | print('Starting httpd ({}:{})...'.format(host, port)) 76 | httpd.serve_forever() 77 | 78 | 79 | if __name__ == "__main__": 80 | HOST = 'localhost' # <-- YOUR HOST HERE 81 | PORT = 8001 # <-- YOUR PORT HERE 82 | run(host=HOST, port=PORT) 83 | -------------------------------------------------------------------------------- /td_gammon/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import utils 3 | 4 | 5 | def formatter(prog): 6 | return argparse.ArgumentDefaultsHelpFormatter(prog, max_help_position=100, width=180) 7 | 8 | 9 | if __name__ == '__main__': 10 | parser = argparse.ArgumentParser(description='TD-Gammon', formatter_class=lambda prog: formatter(prog)) 11 | subparsers = parser.add_subparsers(help='Train TD-Network | Evaluate Agent(s) | Web GUI | Plot Wins') 12 | 13 | parser_train = subparsers.add_parser('train', help='Train TD-Network', formatter_class=lambda prog: formatter(prog)) 14 | parser_train.add_argument('--save_path', help='Save directory location', type=str, default=None) 15 | parser_train.add_argument('--save_step', help='Save the model every n episodes/games', type=int, default=0) 16 | parser_train.add_argument('--episodes', help='Number of episodes/games', type=int, default=200000) 17 | parser_train.add_argument('--init_weights', help='Init Weights with zeros', action='store_true') 18 | parser_train.add_argument('--lr', help='Learning rate', type=float, default=1e-4) 19 | parser_train.add_argument('--hidden_units', help='Hidden units', type=int, default=40) 20 | parser_train.add_argument('--lamda', help='Credit assignment parameter', type=float, default=0.7) 21 | parser_train.add_argument('--model', help='Directory location to the model to be restored', type=str, default=None) 22 | parser_train.add_argument('--name', help='Name of the experiment', type=str, default='exp1') 23 | parser_train.add_argument('--type', help='Model type', choices=['cnn', 'nn'], type=str, default='nn') 24 | parser_train.add_argument('--seed', help='Seed used to reproduce results', type=int, default=123) 25 | 26 | parser_train.set_defaults(func=utils.args_train) 27 | 28 | parser_evaluate = subparsers.add_parser('evaluate', help='Evaluate Agent(s)', formatter_class=lambda prog: formatter(prog)) 29 | parser_evaluate.add_argument('--model_agent0', help='Saved model used by the agent0 (WHITE)', required=True, type=str) 30 | parser_evaluate.add_argument('--model_agent1', help='Saved model used by the agent1 (BLACK)', required=False, type=str) 31 | parser_evaluate.add_argument('--type', help='Model type used by the agents', choices=['cnn', 'nn'], type=str, default='nn') 32 | parser_evaluate.add_argument('--hidden_units_agent0', help='Hidden Units of the model used by the agent0 (WHITE)', required=False, type=int, default=40) 33 | parser_evaluate.add_argument('--hidden_units_agent1', help='Hidden Units of the model used by the agent1 (BLACK)', required=False, type=int, default=40) 34 | parser_evaluate.add_argument('--episodes', help='Number of episodes/games', default=20, required=False, type=int) 35 | 36 | subparsers_gnubg = parser_evaluate.add_subparsers(help='Parameters for gnubg interface') 37 | parser_gnubg = subparsers_gnubg.add_parser('vs_gnubg', help='Evaluate agent0 against gnubg', formatter_class=lambda prog: 
formatter(prog)) 38 | parser_gnubg.add_argument('--host', help='Host running gnubg', type=str, required=True) 39 | parser_gnubg.add_argument('--port', help='Port listening for gnubg commands', type=int, required=True) 40 | parser_gnubg.add_argument('--difficulty', help='Difficulty level', choices=['beginner', 'intermediate', 'advanced', 'world_class'], type=str, required=False, default='beginner') 41 | 42 | parser_gnubg.set_defaults(func=utils.args_gnubg) 43 | parser_evaluate.set_defaults(func=utils.args_evaluate) 44 | 45 | parser_gui = subparsers.add_parser('gui', help='Start Web GUI', formatter_class=lambda prog: formatter(prog)) 46 | parser_gui.add_argument('--host', help='Host running the web gui', default='localhost') 47 | parser_gui.add_argument('--port', help='Port listening for command', default=8002, type=int) 48 | parser_gui.add_argument('--model', help='Model used by the AI opponent', required=True, type=str) 49 | parser_gui.add_argument('--hidden_units', help='Hidden units of the model loaded', required=False, type=int, default=40) 50 | parser_gui.add_argument('--type', help='Model type', choices=['cnn', 'nn'], type=str, default='nn') 51 | 52 | parser_gui.set_defaults(func=utils.args_gui) 53 | 54 | parser_plot = subparsers.add_parser('plot', help='Plot the performance (wins)', formatter_class=lambda prog: formatter(prog)) 55 | parser_plot.add_argument('--save_path', help='Directory where the model are saved', type=str, required=True) 56 | parser_plot.add_argument('--hidden_units', help='Hidden units of the model(s) loaded', type=int, default=40) 57 | parser_plot.add_argument('--episodes', help='Number of episodes/games against a single opponent', default=20, type=int) 58 | parser_plot.add_argument('--opponent', help='Opponent(s) agent(s) (delimited by comma) - "random" and/or "gnubg"', default='random', type=str) 59 | parser_plot.add_argument('--host', help='Host running gnubg (if gnubg in --opponent)', type=str) 60 | parser_plot.add_argument('--port', help='Port listening for gnubg commands (if gnubg in --opponent)', type=int) 61 | parser_plot.add_argument('--difficulty', help='Difficulty level(s) (delimited by comma)', type=str, default="beginner,intermediate,advanced,world_class") 62 | parser_plot.add_argument('--dst', help='Save directory location', type=str, default='myexp') 63 | parser_plot.add_argument('--type', help='Model type', choices=['cnn', 'nn'], type=str, default='nn') 64 | 65 | parser_plot.set_defaults(func=lambda args: utils.args_plot(args, parser)) 66 | 67 | args = parser.parse_args() 68 | args.func(args) 69 | -------------------------------------------------------------------------------- /td_gammon/agents.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | from itertools import count 4 | from random import randint, choice 5 | 6 | import numpy as np 7 | from gym_backgammon.envs.backgammon import WHITE, BLACK, COLORS 8 | 9 | random.seed(0) 10 | 11 | 12 | # AGENT ============================================================================================ 13 | 14 | 15 | class Agent: 16 | def __init__(self, color): 17 | self.color = color 18 | self.name = 'Agent({})'.format(COLORS[color]) 19 | 20 | def roll_dice(self): 21 | return (-randint(1, 6), -randint(1, 6)) if self.color == WHITE else (randint(1, 6), randint(1, 6)) 22 | 23 | def choose_best_action(self, actions, env): 24 | raise NotImplementedError 25 | 26 | 27 | # RANDOM AGENT 
======================================================================================= 28 | 29 | 30 | class RandomAgent(Agent): 31 | def __init__(self, color): 32 | super().__init__(color) 33 | self.name = 'RandomAgent({})'.format(COLORS[color]) 34 | 35 | def choose_best_action(self, actions, env): 36 | return choice(list(actions)) if actions else None 37 | 38 | 39 | # HUMAN AGENT ======================================================================================= 40 | 41 | 42 | class HumanAgent(Agent): 43 | def __init__(self, color): 44 | super().__init__(color) 45 | self.name = 'HumanAgent({})'.format(COLORS[color]) 46 | 47 | def choose_best_action(self, actions=None, env=None): 48 | pass 49 | 50 | 51 | # TD-GAMMON AGENT ===================================================================================== 52 | 53 | 54 | class TDAgent(Agent): 55 | def __init__(self, color, net): 56 | super().__init__(color) 57 | self.net = net 58 | self.name = 'TDAgent({})'.format(COLORS[color]) 59 | 60 | def choose_best_action(self, actions, env): 61 | best_action = None 62 | 63 | if actions: 64 | values = [0.0] * len(actions) 65 | tmp_counter = env.counter 66 | env.counter = 0 67 | state = env.game.save_state() 68 | 69 | # Iterate over all the legal moves and pick the best action 70 | for i, action in enumerate(actions): 71 | observation, reward, done, info = env.step(action) 72 | values[i] = self.net(observation) 73 | 74 | # restore the board and other variables (undo the action) 75 | env.game.restore_state(state) 76 | 77 | # practical-issues-in-temporal-difference-learning, pag.3 78 | # ... the network's output P_t is an estimate of White's probability of winning from board position x_t. 79 | # ... the move which is selected at each time step is the move which maximizes P_t when White is to play and minimizes P_t when Black is to play. 
80 | best_action_index = int(np.argmax(values)) if self.color == WHITE else int(np.argmin(values)) 81 | best_action = list(actions)[best_action_index] 82 | env.counter = tmp_counter 83 | 84 | return best_action 85 | 86 | 87 | # TD-GAMMON AGENT (play against gnubg) ================================================================ 88 | 89 | 90 | class TDAgentGNU(TDAgent): 91 | 92 | def __init__(self, color, net, gnubg_interface): 93 | super().__init__(color, net) 94 | self.gnubg_interface = gnubg_interface 95 | 96 | def roll_dice(self): 97 | gnubg = self.gnubg_interface.send_command("roll") 98 | return self.handle_opponent_move(gnubg) 99 | 100 | def choose_best_action(self, actions, env): 101 | best_action = None 102 | 103 | if actions: 104 | game = env.game 105 | values = [0.0] * len(actions) 106 | state = game.save_state() 107 | 108 | for i, action in enumerate(actions): 109 | game.execute_play(self.color, action) 110 | opponent = game.get_opponent(self.color) 111 | observation = game.get_board_features(opponent) if env.model_type == 'nn' else env.render(mode='state_pixels') 112 | values[i] = self.net(observation) 113 | game.restore_state(state) 114 | 115 | best_action_index = int(np.argmax(values)) if self.color == WHITE else int(np.argmin(values)) 116 | best_action = list(actions)[best_action_index] 117 | 118 | return best_action 119 | 120 | def handle_opponent_move(self, gnubg): 121 | # Once I roll the dice, 2 possible situations can happen: 122 | # 1) I can move (the value gnubg.roll is not None) 123 | # 2) I cannot move, so my opponent rolls the dice and makes its move, and eventually ask for doubling, so I have to roll the dice again 124 | 125 | # One way to distinguish between the above cases, is to check the color of the player that performs the last move in gnubg: 126 | # - if the player's color is the same as the TD Agent, it means I can send the 'move' command (no other moves have been performed after the 'roll' command) - case 1); 127 | # - if the player's color is not the same as the TD Agent, this means that the last move performed after the 'roll' is not of the TD agent - case 2) 128 | previous_agent = gnubg.agent 129 | if previous_agent == self.color: # case 1) 130 | return gnubg 131 | else: # case 2) 132 | while previous_agent != self.color and gnubg.winner is None: 133 | # check if my opponent asks for doubling 134 | if gnubg.double: 135 | # default action if the opponent asks for doubling is 'take' 136 | gnubg = self.gnubg_interface.send_command("take") 137 | else: 138 | gnubg = self.gnubg_interface.send_command("roll") 139 | previous_agent = gnubg.agent 140 | return gnubg 141 | 142 | 143 | def evaluate_agents(agents, env, n_episodes): 144 | wins = {WHITE: 0, BLACK: 0} 145 | 146 | for episode in range(n_episodes): 147 | 148 | agent_color, first_roll, observation = env.reset() 149 | agent = agents[agent_color] 150 | 151 | t = time.time() 152 | 153 | for i in count(): 154 | 155 | if first_roll: 156 | roll = first_roll 157 | first_roll = None 158 | else: 159 | roll = agent.roll_dice() 160 | 161 | actions = env.get_valid_actions(roll) 162 | action = agent.choose_best_action(actions, env) 163 | observation_next, reward, done, winner = env.step(action) 164 | 165 | if done: 166 | if winner is not None: 167 | wins[agent.color] += 1 168 | tot = wins[WHITE] + wins[BLACK] 169 | tot = tot if tot > 0 else 1 170 | 171 | print("EVAL => Game={:<6d} | Winner={} | after {:<4} plays || Wins: {}={:<6}({:<5.1f}%) | {}={:<6}({:<5.1f}%) | Duration={:<.3f} sec".format(episode + 1, winner, i, 172 
| agents[WHITE].name, wins[WHITE], (wins[WHITE] / tot) * 100, 173 | agents[BLACK].name, wins[BLACK], (wins[BLACK] / tot) * 100, time.time() - t)) 174 | break 175 | 176 | agent_color = env.get_opponent_agent() 177 | agent = agents[agent_color] 178 | 179 | observation = observation_next 180 | return wins 181 | -------------------------------------------------------------------------------- /td_gammon/model.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import random 3 | import time 4 | from itertools import count 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | 10 | from agents import TDAgent, RandomAgent, evaluate_agents 11 | from gym_backgammon.envs.backgammon import WHITE, BLACK 12 | 13 | torch.set_default_tensor_type('torch.DoubleTensor') 14 | 15 | 16 | class BaseModel(nn.Module): 17 | def __init__(self, lr, lamda, seed=123): 18 | super(BaseModel, self).__init__() 19 | self.lr = lr 20 | self.lamda = lamda # trace-decay parameter 21 | self.start_episode = 0 22 | 23 | self.eligibility_traces = None 24 | self.optimizer = None 25 | 26 | torch.manual_seed(seed) 27 | random.seed(seed) 28 | 29 | def update_weights(self, p, p_next): 30 | raise NotImplementedError 31 | 32 | def forward(self, x): 33 | raise NotImplementedError 34 | 35 | def init_weights(self): 36 | raise NotImplementedError 37 | 38 | def init_eligibility_traces(self): 39 | self.eligibility_traces = [torch.zeros(weights.shape, requires_grad=False) for weights in list(self.parameters())] 40 | 41 | def checkpoint(self, checkpoint_path, step, name_experiment): 42 | path = checkpoint_path + "/{}_{}_{}.tar".format(name_experiment, datetime.datetime.now().strftime('%Y%m%d_%H%M_%S_%f'), step + 1) 43 | torch.save({'step': step + 1, 'model_state_dict': self.state_dict(), 'eligibility': self.eligibility_traces if self.eligibility_traces else []}, path) 44 | print("\nCheckpoint saved: {}".format(path)) 45 | 46 | def load(self, checkpoint_path, optimizer=None, eligibility_traces=None): 47 | checkpoint = torch.load(checkpoint_path) 48 | self.start_episode = checkpoint['step'] 49 | 50 | self.load_state_dict(checkpoint['model_state_dict']) 51 | 52 | if eligibility_traces is not None: 53 | self.eligibility_traces = checkpoint['eligibility'] 54 | 55 | if optimizer is not None: 56 | self.optimizer.load_state_dict(checkpoint['optimizer']) 57 | 58 | def train_agent(self, env, n_episodes, save_path=None, eligibility=False, save_step=0, name_experiment=''): 59 | start_episode = self.start_episode 60 | n_episodes += start_episode 61 | 62 | wins = {WHITE: 0, BLACK: 0} 63 | network = self 64 | 65 | agents = {WHITE: TDAgent(WHITE, net=network), BLACK: TDAgent(BLACK, net=network)} 66 | 67 | durations = [] 68 | steps = 0 69 | start_training = time.time() 70 | 71 | for episode in range(start_episode, n_episodes): 72 | 73 | if eligibility: 74 | self.init_eligibility_traces() 75 | 76 | agent_color, first_roll, observation = env.reset() 77 | agent = agents[agent_color] 78 | 79 | t = time.time() 80 | 81 | for i in count(): 82 | if first_roll: 83 | roll = first_roll 84 | first_roll = None 85 | else: 86 | roll = agent.roll_dice() 87 | 88 | p = self(observation) 89 | 90 | actions = env.get_valid_actions(roll) 91 | action = agent.choose_best_action(actions, env) 92 | observation_next, reward, done, winner = env.step(action) 93 | p_next = self(observation_next) 94 | 95 | if done: 96 | if winner is not None: 97 | loss = self.update_weights(p, reward) 98 | 99 | wins[agent.color] += 1 
100 | 101 | tot = sum(wins.values()) 102 | tot = tot if tot > 0 else 1 103 | 104 | print("Game={:<6d} | Winner={} | after {:<4} plays || Wins: {}={:<6}({:<5.1f}%) | {}={:<6}({:<5.1f}%) | Duration={:<.3f} sec".format(episode + 1, winner, i, 105 | agents[WHITE].name, wins[WHITE], (wins[WHITE] / tot) * 100, 106 | agents[BLACK].name, wins[BLACK], (wins[BLACK] / tot) * 100, time.time() - t)) 107 | 108 | durations.append(time.time() - t) 109 | steps += i 110 | break 111 | else: 112 | loss = self.update_weights(p, p_next) 113 | 114 | agent_color = env.get_opponent_agent() 115 | agent = agents[agent_color] 116 | 117 | observation = observation_next 118 | 119 | if save_path and save_step > 0 and episode > 0 and (episode + 1) % save_step == 0: 120 | self.checkpoint(checkpoint_path=save_path, step=episode, name_experiment=name_experiment) 121 | agents_to_evaluate = {WHITE: TDAgent(WHITE, net=network), BLACK: RandomAgent(BLACK)} 122 | evaluate_agents(agents_to_evaluate, env, n_episodes=20) 123 | print() 124 | 125 | print("\nAverage duration per game: {} seconds".format(round(sum(durations) / n_episodes, 3))) 126 | print("Average game length: {} plays | Total Duration: {}".format(round(steps / n_episodes, 2), datetime.timedelta(seconds=int(time.time() - start_training)))) 127 | 128 | if save_path: 129 | self.checkpoint(checkpoint_path=save_path, step=n_episodes - 1, name_experiment=name_experiment) 130 | 131 | with open('{}/comments.txt'.format(save_path), 'a') as file: 132 | file.write("Average duration per game: {} seconds".format(round(sum(durations) / n_episodes, 3))) 133 | file.write("\nAverage game length: {} plays | Total Duration: {}".format(round(steps / n_episodes, 2), datetime.timedelta(seconds=int(time.time() - start_training)))) 134 | 135 | env.close() 136 | 137 | 138 | class TDGammonCNN(BaseModel): 139 | def __init__(self, lr, seed=123, output_units=1): 140 | super(TDGammonCNN, self).__init__(lr, seed=seed, lamda=0.7) 141 | 142 | self.loss_fn = torch.nn.MSELoss(reduction='sum') 143 | 144 | self.conv1 = nn.Sequential( 145 | nn.Conv2d(in_channels=1, out_channels=32, kernel_size=8, stride=4), # CHANNEL it was 3 146 | nn.BatchNorm2d(32), 147 | nn.ReLU() 148 | ) 149 | 150 | self.conv2 = nn.Sequential( 151 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2), 152 | nn.BatchNorm2d(64), 153 | nn.ReLU() 154 | ) 155 | 156 | self.conv3 = nn.Sequential( 157 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1), 158 | nn.BatchNorm2d(64), 159 | nn.ReLU() 160 | ) 161 | 162 | self.hidden = nn.Sequential( 163 | nn.Linear(64 * 8 * 8, 80), 164 | nn.Sigmoid() 165 | ) 166 | 167 | self.output = nn.Sequential( 168 | nn.Linear(80, output_units), 169 | nn.Sigmoid() 170 | ) 171 | 172 | self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) 173 | 174 | def init_weights(self): 175 | pass 176 | 177 | def forward(self, x): 178 | # https://stackoverflow.com/questions/12201577/how-can-i-convert-an-rgb-image-into-grayscale-in-python 179 | x = np.dot(x[..., :3], [0.2989, 0.5870, 0.1140]) 180 | x = x[np.newaxis, :] 181 | x = torch.from_numpy(np.array(x)) 182 | x = x.unsqueeze(0) 183 | x = x.type(torch.DoubleTensor) 184 | 185 | x = self.conv1(x) 186 | x = self.conv2(x) 187 | x = self.conv3(x) 188 | x = x.view(-1, 64 * 8 * 8) 189 | x = x.reshape(-1) 190 | x = self.hidden(x) 191 | x = self.output(x) 192 | return x 193 | 194 | def update_weights(self, p, p_next): 195 | 196 | if isinstance(p_next, int): 197 | p_next = torch.tensor([p_next], dtype=torch.float64) 198 | 199 | loss = 
self.loss_fn(p_next, p) 200 | self.optimizer.zero_grad() 201 | loss.backward() 202 | self.optimizer.step() 203 | return loss 204 | 205 | 206 | class TDGammon(BaseModel): 207 | def __init__(self, hidden_units, lr, lamda, init_weights, seed=123, input_units=198, output_units=1): 208 | super(TDGammon, self).__init__(lr, lamda, seed=seed) 209 | 210 | self.hidden = nn.Sequential( 211 | nn.Linear(input_units, hidden_units), 212 | nn.Sigmoid() 213 | ) 214 | 215 | # self.hidden2 = nn.Sequential( 216 | # nn.Linear(hidden_units, hidden_units), 217 | # nn.Sigmoid() 218 | # ) 219 | 220 | # self.hidden3 = nn.Sequential( 221 | # nn.Linear(hidden_units, hidden_units), 222 | # nn.Sigmoid() 223 | # ) 224 | 225 | self.output = nn.Sequential( 226 | nn.Linear(hidden_units, output_units), 227 | nn.Sigmoid() 228 | ) 229 | 230 | if init_weights: 231 | self.init_weights() 232 | 233 | def init_weights(self): 234 | for p in self.parameters(): 235 | nn.init.zeros_(p) 236 | 237 | def forward(self, x): 238 | x = torch.from_numpy(np.array(x)) 239 | x = self.hidden(x) 240 | # x = self.hidden2(x) 241 | # x = self.hidden3(x) 242 | x = self.output(x) 243 | return x 244 | 245 | def update_weights(self, p, p_next): 246 | # reset the gradients 247 | self.zero_grad() 248 | 249 | # compute the derivative of p w.r.t. the parameters 250 | p.backward() 251 | 252 | with torch.no_grad(): 253 | 254 | td_error = p_next - p 255 | 256 | # get the parameters of the model 257 | parameters = list(self.parameters()) 258 | 259 | for i, weights in enumerate(parameters): 260 | 261 | # z <- gamma * lambda * z + (grad w w.r.t P_t) 262 | self.eligibility_traces[i] = self.lamda * self.eligibility_traces[i] + weights.grad 263 | 264 | # w <- w + alpha * td_error * z 265 | new_weights = weights + self.lr * td_error * self.eligibility_traces[i] 266 | weights.copy_(new_weights) 267 | 268 | return td_error 269 | -------------------------------------------------------------------------------- /td_gammon/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gym 3 | import sys 4 | from agents import TDAgent, HumanAgent, TDAgentGNU, RandomAgent, evaluate_agents 5 | from gnubg.gnubg_backgammon import GnubgInterface, GnubgEnv, evaluate_vs_gnubg 6 | from gym_backgammon.envs.backgammon import WHITE, BLACK 7 | from model import TDGammon, TDGammonCNN 8 | from web_gui.gui import GUI 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | # tensorboard --logdir=runs/ --host localhost --port 8001 12 | 13 | 14 | def write_file(path, **kwargs): 15 | with open('{}/parameters.txt'.format(path), 'w+') as file: 16 | print("Parameters:") 17 | for key, value in kwargs.items(): 18 | file.write("{}={}\n".format(key, value)) 19 | print("{}={}".format(key, value)) 20 | print() 21 | 22 | 23 | def path_exists(path): 24 | if os.path.exists(path): 25 | return True 26 | else: 27 | print("The path {} doesn't exists".format(path)) 28 | sys.exit() 29 | 30 | 31 | 32 | # ==================================== TRAINING PARAMETERS =================================== 33 | def args_train(args): 34 | save_step = args.save_step 35 | save_path = None 36 | n_episodes = args.episodes 37 | init_weights = args.init_weights 38 | lr = args.lr 39 | hidden_units = args.hidden_units 40 | lamda = args.lamda 41 | name = args.name 42 | model_type = args.type 43 | seed = args.seed 44 | 45 | eligibility = False 46 | optimizer = None 47 | 48 | if model_type == 'nn': 49 | net = TDGammon(hidden_units=hidden_units, lr=lr, lamda=lamda, 
init_weights=init_weights, seed=seed) 50 | eligibility = True 51 | env = gym.make('gym_backgammon:backgammon-v0') 52 | 53 | else: 54 | net = TDGammonCNN(lr=lr, seed=seed) 55 | optimizer = True 56 | env = gym.make('gym_backgammon:backgammon-pixel-v0') 57 | 58 | if args.model and path_exists(args.model): 59 | # assert os.path.exists(args.model), print("The path {} doesn't exists".format(args.model)) 60 | net.load(checkpoint_path=args.model, optimizer=optimizer, eligibility_traces=eligibility) 61 | 62 | if args.save_path and path_exists(args.save_path): 63 | # assert os.path.exists(args.save_path), print("The path {} doesn't exists".format(args.save_path)) 64 | save_path = args.save_path 65 | 66 | write_file( 67 | save_path, save_path=args.save_path, command_line_args=args, type=model_type, hidden_units=hidden_units, init_weights=init_weights, alpha=net.lr, lamda=net.lamda, 68 | n_episodes=n_episodes, save_step=save_step, start_episode=net.start_episode, name_experiment=name, env=env.spec.id, restored_model=args.model, seed=seed, 69 | eligibility=eligibility, optimizer=optimizer, modules=[module for module in net.modules()] 70 | ) 71 | 72 | net.train_agent(env=env, n_episodes=n_episodes, save_path=save_path, save_step=save_step, eligibility=eligibility, name_experiment=name) 73 | 74 | 75 | # ==================================== WEB GUI PARAMETERS ==================================== 76 | def args_gui(args): 77 | if path_exists(args.model): 78 | # assert os.path.exists(args.model), print("The path {} doesn't exists".format(args.model)) 79 | 80 | if args.type == 'nn': 81 | net = TDGammon(hidden_units=args.hidden_units, lr=0.1, lamda=None, init_weights=False) 82 | env = gym.make('gym_backgammon:backgammon-v0') 83 | else: 84 | net = TDGammonCNN(lr=0.0001) 85 | env = gym.make('gym_backgammon:backgammon-pixel-v0') 86 | 87 | net.load(checkpoint_path=args.model, optimizer=None, eligibility_traces=False) 88 | 89 | agents = {BLACK: TDAgent(BLACK, net=net), WHITE: HumanAgent(WHITE)} 90 | gui = GUI(env=env, host=args.host, port=args.port, agents=agents) 91 | gui.run() 92 | 93 | 94 | # =================================== EVALUATE PARAMETERS ==================================== 95 | def args_evaluate(args): 96 | model_agent0 = args.model_agent0 97 | model_agent1 = args.model_agent1 98 | model_type = args.type 99 | hidden_units_agent0 = args.hidden_units_agent0 100 | hidden_units_agent1 = args.hidden_units_agent1 101 | n_episodes = args.episodes 102 | 103 | if path_exists(model_agent0) and path_exists(model_agent1): 104 | # assert os.path.exists(model_agent0), print("The path {} doesn't exists".format(model_agent0)) 105 | # assert os.path.exists(model_agent1), print("The path {} doesn't exists".format(model_agent1)) 106 | 107 | if model_type == 'nn': 108 | net0 = TDGammon(hidden_units=hidden_units_agent0, lr=0.1, lamda=None, init_weights=False) 109 | net1 = TDGammon(hidden_units=hidden_units_agent1, lr=0.1, lamda=None, init_weights=False) 110 | env = gym.make('gym_backgammon:backgammon-v0') 111 | else: 112 | net0 = TDGammonCNN(lr=0.0001) 113 | net1 = TDGammonCNN(lr=0.0001) 114 | env = gym.make('gym_backgammon:backgammon-pixel-v0') 115 | 116 | net0.load(checkpoint_path=model_agent0, optimizer=None, eligibility_traces=False) 117 | net1.load(checkpoint_path=model_agent1, optimizer=None, eligibility_traces=False) 118 | 119 | agents = {WHITE: TDAgent(WHITE, net=net1), BLACK: TDAgent(BLACK, net=net0)} 120 | 121 | evaluate_agents(agents, env, n_episodes) 122 | 123 | 124 | # ===================================== 
GNUBG PARAMETERS ===================================== 125 | def args_gnubg(args): 126 | model_agent0 = args.model_agent0 127 | model_type = args.type 128 | hidden_units_agent0 = args.hidden_units_agent0 129 | n_episodes = args.episodes 130 | host = args.host 131 | port = args.port 132 | difficulty = args.difficulty 133 | 134 | if path_exists(model_agent0): 135 | # assert os.path.exists(model_agent0), print("The path {} doesn't exists".format(model_agent0)) 136 | if model_type == 'nn': 137 | net0 = TDGammon(hidden_units=hidden_units_agent0, lr=0.1, lamda=None, init_weights=False) 138 | else: 139 | net0 = TDGammonCNN(lr=0.0001) 140 | 141 | net0.load(checkpoint_path=model_agent0, optimizer=None, eligibility_traces=False) 142 | 143 | gnubg_interface = GnubgInterface(host=host, port=port) 144 | gnubg_env = GnubgEnv(gnubg_interface, difficulty=difficulty, model_type=model_type) 145 | evaluate_vs_gnubg(agent=TDAgentGNU(WHITE, net=net0, gnubg_interface=gnubg_interface), env=gnubg_env, n_episodes=n_episodes) 146 | 147 | 148 | # ===================================== PLOT PARAMETERS ====================================== 149 | def args_plot(args, parser): 150 | ''' 151 | This method is used to plot the number of time an agent wins when it plays against an opponent. 152 | Instead of evaluating the agent during training (it can require some time and slow down the training), I decided to plot the wins separately, loading the different 153 | model saved during training. 154 | For example, suppose I run the training for 100 games and save my model every 10 games. 155 | Later I will load these 10 models, and for each of them, I will compute how many times the agent would win against an opponent. 156 | :return: None 157 | ''' 158 | 159 | src = args.save_path 160 | hidden_units = args.hidden_units 161 | n_episodes = args.episodes 162 | opponents = args.opponent.split(',') 163 | host = args.host 164 | port = args.port 165 | difficulties = args.difficulty.split(',') 166 | model_type = args.type 167 | 168 | if path_exists(src): 169 | # assert os.path.exists(src), print("The path {} doesn't exists".format(src)) 170 | 171 | for d in difficulties: 172 | if d not in ['beginner', 'intermediate', 'advanced', 'world_class']: 173 | parser.error("--difficulty should be (one or more of) 'beginner','intermediate', 'advanced' ,'world_class'") 174 | 175 | dst = args.dst 176 | 177 | if 'gnubg' in opponents and (not host or not port): 178 | parser.error("--host and --port are required when 'gnubg' is specified in --opponent") 179 | 180 | for root, dirs, files in os.walk(src): 181 | global_step = 0 182 | files = sorted(files) 183 | 184 | writer = SummaryWriter(dst) 185 | 186 | for file in files: 187 | if ".tar" in file: 188 | print("\nLoad {}".format(os.path.join(root, file))) 189 | 190 | if model_type == 'nn': 191 | net = TDGammon(hidden_units=hidden_units, lr=0.1, lamda=None, init_weights=False) 192 | env = gym.make('gym_backgammon:backgammon-v0') 193 | else: 194 | net = TDGammonCNN(lr=0.0001) 195 | env = gym.make('gym_backgammon:backgammon-pixel-v0') 196 | 197 | net.load(checkpoint_path=os.path.join(root, file), optimizer=None, eligibility_traces=False) 198 | 199 | if 'gnubg' in opponents: 200 | tag_scalar_dict = {} 201 | 202 | gnubg_interface = GnubgInterface(host=host, port=port) 203 | 204 | for difficulty in difficulties: 205 | gnubg_env = GnubgEnv(gnubg_interface, difficulty=difficulty, model_type=model_type) 206 | wins = evaluate_vs_gnubg(agent=TDAgentGNU(WHITE, net=net, gnubg_interface=gnubg_interface), 
env=gnubg_env, n_episodes=n_episodes) 207 | tag_scalar_dict[difficulty] = wins[WHITE] 208 | 209 | writer.add_scalars('wins_vs_gnubg/', tag_scalar_dict, global_step) 210 | 211 | with open(root + '/results.txt', 'a') as f: 212 | print("{};".format(file) + str(tag_scalar_dict), file=f) 213 | 214 | if 'random' in opponents: 215 | tag_scalar_dict = {} 216 | agents = {WHITE: TDAgent(WHITE, net=net), BLACK: RandomAgent(BLACK)} 217 | wins = evaluate_agents(agents, env, n_episodes) 218 | 219 | tag_scalar_dict['random'] = wins[WHITE] 220 | 221 | writer.add_scalars('wins_vs_random/', tag_scalar_dict, global_step) 222 | 223 | global_step += 1 224 | 225 | writer.close() 226 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |

TD-Gammon
3 | 
4 | (logo image: "Backgammon", see logo.png)
5 | 
6 | 7 | --- 8 | # Table of Contents 9 | - [Features](#features) 10 | - [Installation](#installation) 11 | - [How to interact with GNU Backgammon using Python Script?](#howto) 12 | - [Usage](#usage) 13 | - [Train TD-Network](#train) 14 | - [Evaluate Agent(s)](#evaluate) 15 | - [Web Interface](#web_interface) 16 | - [Plot Wins](#plot) 17 | - [Backgammon OpenAI Gym Environment](#env) 18 | - [Bibliography, sources of inspiration, related works](#biblio) 19 | - [License](#license) 20 | --- 21 | ## Features 22 | - PyTorch implementation of TD-Gammon [1]. 23 | - Test the trained agents against an open source implementation of the Backgammon game, [GNU Backgammon](https://www.gnu.org/software/gnubg/). 24 | - Play against a trained agent via web gui 25 | 26 | --- 27 | ## Installation 28 | 29 | I used [`Anaconda3`](https://www.anaconda.com/distribution/), with `Python 3.6.8` (I tested only with the following configurations). 30 | 31 | Create the conda environment: 32 | ``` 33 | $ conda create --name tdgammon python=3.6 34 | $ source activate tdgammon 35 | (tdgammon) $ git clone https://github.com/dellalibera/td-gammon.git 36 | ``` 37 | Install the environment [`gym-backgammon`](#https://github.com/dellalibera/gym-backgammon): 38 | ``` 39 | (tdgammon) $ git clone https://github.com/dellalibera/gym-backgammon.git 40 | (tdgammon) $ cd gym-backgammon 41 | (tdgammon) $ pip install -e . 42 | ``` 43 | 44 | Install the dependencies [`pytorch v1.2`](https://pytorch.org/get-started/locally/): 45 | ``` 46 | (tdgammon) $ pip install torch torchvision 47 | (tdgammon) $ pip install tb-nightly 48 | ``` 49 | or 50 | ``` 51 | (tdgammon) $ cd td-gammon/ 52 | (tdgammon) $ pip install -r requirements.txt 53 | ``` 54 | 55 | #### Without Anaconda Environment 56 | If you don't use Anaconda environment, run the following commands: 57 | ``` 58 | git clone https://github.com/dellalibera/td-gammon.git 59 | pip3 install -r td-gammon/requirements.txt 60 | git clone https://github.com/dellalibera/gym-backgammon.git 61 | cd gym-backgammon/ 62 | pip3 install -e . 63 | ``` 64 | If you don't use Anaconda environment, in the commands below replace `python` with `python3`. 65 | 66 | 67 | ### GNU Backgammon 68 | To play against `gnubg`, you have to install [`gnubg`](https://www.gnu.org/software/gnubg/). 69 | **NOTE**: I installed `gnubg` on `Ubuntu 18.04` (running on a Virtual Machine), with `Python 2.7` (see next section to see how to interact with GNU Backgammon). 70 | #### On Ubuntu: 71 | ``` 72 | sudo apt-get install gnubg 73 | ``` 74 | --- 75 | ## How to interact with GNU Backgammon using Python Script? 76 | I used an `http server` that runs on the Guest machine (Ubuntu), to receive commands and interact with the `gnubg` program. 77 | In this way, it's possible to send commands from the Host machine (in my case `MacOS`). 78 |
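Concretely, every interaction is a plain HTTP POST whose body carries a single `command` field; `bridge.py` forwards it to `gnubg.command(...)` and replies with a JSON payload containing the current board, the last move and the per-game info. A minimal sketch of the Host side (this mirrors what `GnubgInterface.send_command` in `td_gammon/gnubg/gnubg_backgammon.py` does; `GNUBG_HOST` is only a placeholder for the address of the Guest machine):
```python
import requests

GNUBG_HOST = "GNUBG_HOST"   # placeholder: IP/hostname of the Guest machine running gnubg
GNUBG_PORT = 8001           # must match the PORT set in bridge.py
BRIDGE_URL = "http://{}:{}".format(GNUBG_HOST, GNUBG_PORT)

# bridge.py reads the 'command' field of the POST body and passes it to gnubg.command(...)
response = requests.post(BRIDGE_URL, data={"command": "new session"}).json()

print(response["board"])      # current gnubg board representation
print(response["last_move"])  # state of the last game before/after the command
print(response["info"])       # winner, number of moves and resignation flag per game
```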
79 | The file `bridge.py` should be executed on the Guest Machine (the machine where `gnubg` is installed). 80 | #### On Ubuntu: 81 | ``` 82 | gnubg -t -p /path/to/bridge.py 83 | ``` 84 | This runs `gnubg` in command-line mode instead of the graphical interface (`-t`) and evaluates a Python code file, then exits (`-p`). 85 | For a list of parameters of `gnubg`, run `gnubg --help`. 86 |
87 | The python script `bridge.py` creates an `http server`, running on `localhost:8001`. 88 | If you want to modify the host and the port, change the following line in `bridge.py`: 89 | ```python 90 | if __name__ == "__main__": 91 | HOST = 'localhost' # <-- YOUR HOST HERE 92 | PORT = 8001 # <-- YOUR PORT HERE 93 | run(host=HOST, port=PORT) 94 | ``` 95 | The file `td_gammon/gnubg/gnubg_backgammon.py` sends messages/commands to `gnubg` and parses the response. 96 | 97 | --- 98 | ## Usage 99 | Run `python /path/to/main.py --help` for a list of parameters. 100 | 101 | ### Train TD-Network 102 | To train a neural network with a single layer of `40` hidden units, for `100000` games/episodes, saving the model every `10000` episodes, run the following command: 103 | ``` 104 | (tdgammon) $ python /path/to/main.py train --save_path ./saved_models/exp1 --save_step 10000 --episodes 100000 --name exp1 --type nn --lr 0.1 --hidden_units 40 105 | ``` 106 | Run `python /path/to/main.py train --help` for a list of parameters available for training. 107 | 108 | --- 109 | ### Evaluate Agent(s) 110 | To evaluate already trained models, you have two options: evaluate two models playing against each other, or evaluate one model against `gnubg`. 111 | Run `python /path/to/main.py evaluate --help` for a list of parameters available for evaluation. 112 | 113 | 114 | ### Agent vs Agent 115 | To evaluate two models against each other, you have to specify the paths where the models are saved, together with the corresponding number of hidden units. 116 | ``` 117 | (tdgammon) $ python /path/to/main.py evaluate --episodes 50 --hidden_units_agent0 40 --hidden_units_agent1 40 --type nn --model_agent0 path/to/saved_models/agent0.tar --model_agent1 path/to/saved_models/agent1.tar 118 | ``` 119 | 120 | ### Agent vs gnubg 121 | To evaluate one model against `gnubg`, first you have to run `gnubg` with the script `bridge.py` as input. 122 | On Ubuntu (or wherever `gnubg` is installed): 123 | ``` 124 | gnubg -t -p /path/to/bridge.py 125 | ``` 126 | Then run (for example, to play vs `gnubg` at beginner level for 50 games): 127 | ``` 128 | (tdgammon) $ python /path/to/main.py evaluate --episodes 50 --hidden_units_agent0 40 --type nn --model_agent0 path/to/saved_models/agent0.tar vs_gnubg --difficulty beginner --host GNUBG_HOST --port GNUBG_PORT 129 | ``` 130 | The hidden units (`--hidden_units_agent0`) must match those of the loaded model (`--model_agent0`). 131 | 132 | --- 133 | ### Web Interface 134 | You can play against a trained agent via a web gui: 135 | ``` 136 | (tdgammon) $ python /path/to/main.py gui --host localhost --port 8002 --model path/to/saved_models/agent0.tar --hidden_units 40 --type nn 137 | ``` 138 | Then navigate to `http://localhost:8002` in your browser: 139 |

140 | (web interface screenshot: "Web Interface", see gui_example.png) 141 | 
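You interact with the page by typing commands, which are parsed by `web_gui/gui.py`: `new game` starts a game, `roll` rolls the dice, and a move is entered as comma-separated `(source/target)` pairs, with `bar` used to re-enter a checker from the bar. A sketch of a possible exchange (the point numbers are only illustrative; after each command the server also returns the list of currently valid actions):
```
new game
roll
move (18/14),(16/13)
move (bar/20)
```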

142 | 143 | Run `python /path/to/main.py gui --help` for a list of parameters available about the web gui. 144 | 145 | --- 146 | ### Plot Wins 147 | Instead of evaluating the agent during training (it can require some time, especially if you evaluate against `gnubg` at difficulty `world_class`), you can load all the saved models in a folder and evaluate each model (saved at different times during training) against one or more opponents. 148 | The models in the directory should be of the same type (i.e. the structure of the network should be the same for all the models in the same folder). 149 | 150 | To plot the wins against `gnubg`, run on Ubuntu (or wherever `gnubg` is installed): 151 | ``` 152 | gnubg -t -p /path/to/bridge.py 153 | ``` 154 | In the example below, the trained model is evaluated against `gnubg` at two different difficulty levels - `beginner` and `advanced`: 155 | ``` 156 | (tdgammon) $ python /path/to/main.py plot --save_path /path/to/saved_models/myexp --hidden_units 40 --episodes 10 --opponent random,gnubg --dst /path/to/experiments --type nn --difficulty beginner,advanced --host GNUBG_HOST --port GNUBG_PORT 157 | ``` 158 | To visualize the plots: 159 | ``` 160 | (tdgammon) $ tensorboard --logdir=runs/path/to/experiment/ --host localhost --port 8001 161 | ``` 162 | Run `python /path/to/main.py plot --help` for a list of parameters available about plotting. 163 | 164 | ## Backgammon OpenAI Gym Environment 165 | For a detailed description of the environment: [`gym-backgammon`](https://github.com/dellalibera/gym-backgammon). 166 | 167 | --- 168 | ## Bibliography, sources of inspiration, related works 169 | - TD-Gammon and Temporal Difference Learning: 170 | - [1] [Practical Issues in Temporal Difference Learning](https://papers.nips.cc/paper/465-practical-issues-in-temporal-difference-learning.pdf) 171 | - [Temporal Difference Learning and TD-Gammon](https://researcher.watson.ibm.com/researcher/view_page.php?id=7021) 172 | - [Programming backgammon using self-teaching neural nets](www.bkgm.com/articles/tesauro/ProgrammingBackgammon.pdf) 173 | - [Implementation Details TD-Gammon](http://www.scholarpedia.org/article/User:Gerald_Tesauro/Proposed/Td-gammon) 174 | - [Chapter 9 Temporal-Difference Learning](https://web.stanford.edu/group/pdplab/pdphandbook/handbookch10.html) 175 | - [Implementation Details of the TD(λ) Procedure for the Case of Vector Predictions and Backpropagation](https://www.ece.uvic.ca/~bctill/papers/learning/Sutton_1987.pdf) 176 | - [Learning to Predict by the Methods of Temporal Differences](http://incompleteideas.net/papers/sutton-88-with-erratum.pdf) 177 |

178 | - GNU Backgammon: https://www.gnu.org/software/gnubg/ 179 |

180 | - Rules of Backgammon: 181 | - www.bkgm.com/rules.html 182 | - https://en.wikipedia.org/wiki/Backgammon 183 | - Starting Position: http://www.bkgm.com/gloss/lookup.cgi?starting+position 184 | - https://bkgm.com/faq/ 185 |

186 | - Install GNU Backgammon on Ubuntu: 187 | - https://ubuntuforums.org/showthread.php?t=2217668 188 | - https://ubuntuforums.org/showthread.php?t=1506341 189 | - https://www.reddit.com/r/backgammon/comments/5gpkov/installing_gnu_or_xg_on_linux/ 190 |

191 | - How to use python to interact with `gnubg`: [\[Bug-gnubg\] Documentation: Looking for documentation on python scripting](https://www.mail-archive.com/bug-gnubg@gnu.org/msg06794.html) 192 |

193 | - Other Implementations of the Backgammon OpenAI Gym Environment: 194 | - https://github.com/edusta/gym-backgammon 195 |

196 | - Other Implementations of TD-Gammon: 197 | - https://github.com/TobiasVogt/TD-Gammon 198 | - https://github.com/millerm/TD-Gammon 199 | - https://github.com/fomorians/td-gammon 200 |

201 | - How to setup your VMWare Fusion images to use static IP addresses on Mac OS X 202 | - https://gist.github.com/pjkelly/1068716/6d19faa0122c0e1efe350e818bb8f4e8687ea1ab 203 |

204 | - PyTorch Tensorboard: https://pytorch.org/docs/stable/tensorboard.html 205 | 206 | --- 207 | ## License 208 | [MIT](https://github.com/dellalibera/td-gammon/blob/master/LICENSE) -------------------------------------------------------------------------------- /td_gammon/web_gui/gui.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | 5 | from http.server import BaseHTTPRequestHandler, HTTPServer 6 | from urllib.parse import urlparse 7 | from gym_backgammon.envs.backgammon import WHITE, BLACK 8 | 9 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 10 | 11 | COLORS = {WHITE: "White", BLACK: 'Black'} 12 | 13 | 14 | class Handler(BaseHTTPRequestHandler): 15 | 16 | def parse_data(self, data): 17 | command = data['command'].lower() 18 | 19 | if command in ['start', 'new game']: 20 | response = self.server.dispatcher['start']() 21 | 22 | elif command in ['roll']: 23 | response = self.server.dispatcher['roll']() 24 | 25 | elif 'move' in command: 26 | response = self.server.dispatcher['move'](command) 27 | 28 | else: 29 | message = 'Invalid command\n' 30 | response = {'message': message, 'state': self.server.env.game.state, 'actions': self.server.last_commands} 31 | 32 | self.server.last_commands = response['actions'] 33 | response['message'] = response['message'][:-1] # remove the new line character of the last line 34 | 35 | return response 36 | 37 | def _set_headers(self, response=200): 38 | self.send_response(response) 39 | self.send_header('Content-type', 'text/html') 40 | self.end_headers() 41 | 42 | def do_POST(self): 43 | post_data = self.rfile.read(int(self.headers['Content-Length'])).decode('utf-8') 44 | data = json.loads(post_data) 45 | # print(data) 46 | response = self.parse_data(data) 47 | self._set_headers() 48 | self.wfile.write(bytes(json.dumps(response), encoding='utf-8')) 49 | 50 | def do_GET(self): 51 | parsed = urlparse(self.path) 52 | path = parsed.path 53 | if self.path == '/': 54 | self._set_headers() 55 | # f = open("../td_gammon/web_gui/index.html").read() 56 | f = open(os.path.dirname(__file__) + "/../web_gui/index.html").read() 57 | 58 | # RESET THE STATE EVERY TIME A NEW PAGE IS REFRESHED 59 | self.server.reset() 60 | 61 | self.wfile.write(bytes(f, encoding='utf-8')) 62 | 63 | 64 | class GUI: 65 | def __init__(self, env, host="localhost", port=8002, agents=None): 66 | self.host = host 67 | self.port = port 68 | self.server = HTTPServer((host, port), Handler) 69 | 70 | # GAME VARIABLES 71 | self.server.agents = agents 72 | self.server.env = env 73 | 74 | self.server.reset = self.reset 75 | self.server.reset() 76 | 77 | self.server.dispatcher = { 78 | 'start': self.handle_start, 79 | 'roll': self.handle_roll, 80 | 'move': self.handle_move 81 | } 82 | 83 | def reset(self): 84 | self.server.agent = None 85 | self.server.first_roll = None 86 | self.server.wins = {WHITE: 0, BLACK: 0} 87 | self.server.roll = None 88 | self.server.game_started = False 89 | self.server.game_finished = False 90 | self.server.last_commands = [] 91 | 92 | def handle_start(self): 93 | self.server.reset() 94 | # response = {'message': '', 'state': [], 'actions': []} 95 | message = '\nNew game started\n' 96 | commands = [] 97 | 98 | if self.server.game_started: 99 | message = "The game is already started. 
To start a new game, type 'new game'\n" 100 | commands.append('new game') 101 | 102 | else: 103 | self.server.game_finished = False 104 | self.server.game_started = True 105 | 106 | agent_color, self.server.first_roll, observation = self.server.env.reset() 107 | self.server.agent = self.server.agents[agent_color] 108 | 109 | if agent_color == WHITE: 110 | message += "{} Starts first | Roll={} | Run 'move (src/target)'\n".format(COLORS[self.server.agent.color], (abs(self.server.first_roll[0]), abs(self.server.first_roll[1]))) 111 | commands.extend(self.server.env.get_valid_actions(self.server.first_roll)) 112 | self.server.roll = self.server.first_roll 113 | 114 | else: 115 | opponent = self.server.agents[agent_color] 116 | message += "{} Starts first | Roll={}\n".format(COLORS[opponent.color], (abs(self.server.first_roll[0]), abs(self.server.first_roll[1]))) 117 | 118 | if self.server.first_roll: 119 | roll = self.server.first_roll 120 | self.server.first_roll = None 121 | else: 122 | roll = opponent.roll_dice() 123 | 124 | actions = self.server.env.get_valid_actions(roll) 125 | action = opponent.choose_best_action(actions, self.server.env) 126 | message += "{} | Roll={} | Action={} | Run 'roll'\n".format(COLORS[opponent.color], roll, action) 127 | commands.extend(['roll', 'new game']) 128 | observation_next, reward, done, info = self.server.env.step(action) 129 | 130 | agent_color = self.server.env.get_opponent_agent() 131 | self.server.agent = self.server.agents[agent_color] 132 | 133 | return {'message': message, 'state': self.server.env.game.state, 'actions': list(commands)} 134 | 135 | def handle_roll(self): 136 | message = '' 137 | commands = [] 138 | 139 | if self.server.roll is not None: 140 | message += "You have already rolled the dice {}. Run 'move (src/target)'\n".format((abs(self.server.roll[0]), abs(self.server.roll[1]))) 141 | actions = self.server.env.get_valid_actions(self.server.roll) 142 | if len(actions) == 0: 143 | commands.append('start') 144 | else: 145 | commands.extend(list(actions)) 146 | 147 | elif self.server.game_finished: 148 | message += "The game is finished. Type 'Start' to start a new game\n".format((abs(self.server.roll[0]), abs(self.server.roll[1]))) 149 | commands.append('start') 150 | 151 | elif not self.server.game_started: 152 | message += "The game is not started. Type 'start' to start a new game\n" 153 | commands.append('start') 154 | 155 | else: 156 | self.server.roll = self.server.agent.roll_dice() 157 | message += "{} | Roll={} | Run 'move (src/target)'\n".format(COLORS[self.server.agent.color], (abs(self.server.roll[0]), abs(self.server.roll[1]))) 158 | actions = self.server.env.get_valid_actions(self.server.roll) 159 | commands.extend(list(actions)) 160 | 161 | if len(actions) == 0: 162 | message += "You cannot move\n" 163 | 164 | agent_color = self.server.env.get_opponent_agent() 165 | opponent = self.server.agents[agent_color] 166 | 167 | roll = opponent.roll_dice() 168 | 169 | actions = self.server.env.get_valid_actions(roll) 170 | action = opponent.choose_best_action(actions, self.server.env) 171 | message += "{} | Roll={} | Action={}\n".format(COLORS[opponent.color], roll, action) 172 | observation_next, reward, done, info = self.server.env.step(action) 173 | 174 | if done: 175 | winner = self.server.env.game.get_winner() 176 | message += "Game Finished!!! 
{} wins \n".format(COLORS[winner]) 177 | commands.append('new game') 178 | self.server.game_finished = True 179 | else: 180 | agent_color = self.server.env.get_opponent_agent() 181 | self.server.agent = self.server.agents[agent_color] 182 | self.server.roll = None 183 | commands.extend(['roll', 'new game']) 184 | 185 | return {'message': message, 'state': self.server.env.game.state, 'actions': list(commands)} 186 | 187 | def handle_move(self, command): 188 | message = '' 189 | commands = [] 190 | 191 | if self.server.roll is None: 192 | message += "You must roll the dice first\n" 193 | commands = self.server.last_commands 194 | 195 | elif self.server.game_finished: 196 | message += "The game is finished. Type 'new game' to start a new game\n".format((abs(self.server.roll[0]), abs(self.server.roll[1]))) 197 | commands.append('new game') 198 | 199 | else: 200 | try: 201 | action = command.split()[1] 202 | action = action.split(',') 203 | play = [] 204 | is_bar = False 205 | 206 | for move in action: 207 | move = move.replace('(', '') 208 | move = move.replace(')', '') 209 | s, t = move.split('/') 210 | 211 | if s == 'BAR' or s == 'bar': 212 | play.append(('bar', int(t))) 213 | is_bar = True 214 | else: 215 | play.append((int(s), int(t))) 216 | 217 | if is_bar: 218 | action = tuple(play) 219 | else: 220 | action = tuple(sorted(play, reverse=True)) 221 | 222 | except Exception as e: 223 | message += "Error during parsing move\n" 224 | commands = self.server.last_commands 225 | 226 | else: 227 | actions = self.server.env.get_valid_actions(self.server.roll) 228 | 229 | if action not in actions: 230 | message += "Illegal move | Roll={}\n".format((abs(self.server.roll[0]), abs(self.server.roll[1]))) 231 | else: 232 | message += "{} | Roll={} | Action={}\n".format(COLORS[self.server.agent.color], (abs(self.server.roll[0]), abs(self.server.roll[1])), action) 233 | observation_next, reward, done, info = self.server.env.step(action) 234 | 235 | if done: 236 | winner = self.server.env.game.get_winner() 237 | message += "Game Finished!!! {} wins\n".format(COLORS[winner]) 238 | commands.append('new game') 239 | self.server.game_finished = True 240 | 241 | else: 242 | agent_color = self.server.env.get_opponent_agent() 243 | opponent = self.server.agents[agent_color] 244 | 245 | roll = opponent.roll_dice() 246 | actions = self.server.env.get_valid_actions(roll) 247 | action = opponent.choose_best_action(actions, self.server.env) 248 | 249 | message += "{} | Roll={} | Action={}\n".format(COLORS[opponent.color], roll, action) 250 | observation_next, reward, done, info = self.server.env.step(action) 251 | 252 | if done: 253 | winner = self.server.env.game.get_winner() 254 | message += "Game Finished!!! 
{} wins\n".format(COLORS[winner]) 255 | commands.append('new game') 256 | self.server.game_finished = True 257 | 258 | else: 259 | commands.extend(['roll', 'new game']) 260 | agent_color = self.server.env.get_opponent_agent() 261 | self.server.agent = self.server.agents[agent_color] 262 | self.server.roll = None 263 | 264 | return {'message': message, 'state': self.server.env.game.state, 'actions': list(commands)} 265 | 266 | def run(self): 267 | print('Starting httpd (http://{}:{})...'.format(self.host, self.port)) 268 | self.server.serve_forever() 269 | -------------------------------------------------------------------------------- /td_gammon/gnubg/gnubg_backgammon.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | from collections import namedtuple 5 | from itertools import count 6 | import requests 7 | from gym_backgammon.envs.backgammon import Backgammon as Game, WHITE, BLACK, NUM_POINTS, COLORS, assert_board 8 | from gym_backgammon.envs.backgammon_env import STATE_W, STATE_H, SCREEN_W, SCREEN_H 9 | from gym_backgammon.envs.rendering import Viewer 10 | 11 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 12 | 13 | gnubgState = namedtuple('GNUState', ['agent', 'roll', 'move', 'board', 'double', 'winner', 'n_moves', 'action', 'resigned', 'history']) 14 | 15 | 16 | class GnubgInterface: 17 | def __init__(self, host, port): 18 | self.url = "http://{}:{}".format(host, port) 19 | # Mapping from gnu board representation to representation used by the environment 20 | self.gnu_to_idx = {23 - k: k for k in range(NUM_POINTS)} 21 | # In GNU Backgammon, position 25 (here 24 because I start from 0-24) represents the 'bar' move 22 | self.gnu_to_idx[24] = 'bar' 23 | self.gnu_to_idx[-1] = -1 24 | 25 | def send_command(self, command): 26 | try: 27 | resp = requests.post(url=self.url, data={"command": command}) 28 | return self.parse_response(resp.json()) 29 | except Exception as e: 30 | print("Error during connection to {}: {} (Remember to run gnubg -t -p bridge.py)".format(self.url, e)) 31 | 32 | def parse_response(self, response): 33 | gnubg_board = response["board"] 34 | action = response["last_move"][-1] if response["last_move"] else None 35 | 36 | info = response["info"][-1] if response["info"] else None 37 | 38 | winner = None 39 | n_moves = 0 40 | resigned = False 41 | double = False 42 | move = () 43 | roll = () 44 | agent = None 45 | 46 | if info: 47 | winner = info['winner'] 48 | n_moves = info['n_moves'] 49 | resigned = info['resigned'] 50 | 51 | if action: 52 | 53 | agent = WHITE if action['player'] == 'O' else BLACK 54 | 55 | if action['action'] == "double": 56 | double = True 57 | elif 'dice' in action: 58 | roll = tuple(action['dice']) 59 | roll = (-roll[0], -roll[1]) if agent == WHITE else (roll[0], roll[1]) 60 | 61 | if action['action'] == 'move': 62 | move = tuple(tuple([self.gnu_to_idx[a - 1], self.gnu_to_idx[b - 1]]) for (a, b) in action['move']) 63 | 64 | return gnubgState(agent=agent, roll=roll, move=move, board=gnubg_board[:], double=double, winner=winner, n_moves=n_moves, action=action, resigned=resigned, history=response["info"]) 65 | 66 | def parse_action(self, action): 67 | result = "" 68 | if action: 69 | for move in action: 70 | src, target = move 71 | if src == 'bar': 72 | result += "bar/{},".format(target + 1) 73 | elif target == -1: 74 | result += "{}/off,".format(src + 1) 75 | else: 76 | result += "{}/{},".format(src + 1, target + 1) 77 | 78 | return 
result[:-1] # remove the trailing comma 79 | 80 | 81 | class GnubgEnv: 82 | DIFFICULTIES = ['beginner', 'intermediate', 'advanced', 'world_class'] 83 | 84 | def __init__(self, gnubg_interface, difficulty='beginner', model_type='nn'): 85 | self.game = Game() 86 | self.current_agent = WHITE 87 | self.gnubg_interface = gnubg_interface 88 | self.gnubg = None 89 | self.difficulty = difficulty 90 | self.is_difficulty_set = False 91 | self.model_type = model_type 92 | self.viewer = None 93 | 94 | def step(self, action): 95 | reward = 0 96 | done = False 97 | 98 | if action and self.gnubg.winner is None: 99 | action = self.gnubg_interface.parse_action(action) 100 | self.gnubg = self.gnubg_interface.send_command(action) 101 | 102 | if self.gnubg.double and self.gnubg.winner is None: 103 | self.gnubg = self.gnubg_interface.send_command("take") 104 | 105 | if self.gnubg.agent == WHITE and self.gnubg.action['action'] == 'move' and self.gnubg.winner is None: 106 | if self.gnubg.winner != 'O': 107 | self.gnubg = self.gnubg_interface.send_command("accept") 108 | assert self.gnubg.winner == 'O', print(self.gnubg) 109 | assert self.gnubg.action['action'] == 'resign' and self.gnubg.agent == BLACK and self.gnubg.action['player'] == 'X' 110 | assert self.gnubg.resigned 111 | 112 | self.update_game_board(self.gnubg.board) 113 | 114 | observation = self.game.get_board_features(self.current_agent) if self.model_type == 'nn' else self.render(mode='state_pixels') 115 | 116 | winner = self.gnubg.winner 117 | if winner is not None: 118 | winner = WHITE if winner == 'O' else BLACK 119 | 120 | if winner == WHITE: 121 | reward = 1 122 | done = True 123 | 124 | return observation, reward, done, winner 125 | 126 | def reset(self): 127 | # Start a new session in the gnubg simulator 128 | self.gnubg = self.gnubg_interface.send_command("new session") 129 | 130 | if not self.is_difficulty_set: 131 | self.set_difficulty() 132 | 133 | roll = None if self.gnubg.agent == BLACK else self.gnubg.roll 134 | 135 | self.current_agent = WHITE 136 | self.game = Game() 137 | self.update_game_board(self.gnubg.board) 138 | 139 | observation = self.game.get_board_features(self.current_agent) if self.model_type == 'nn' else self.render(mode='state_pixels') 140 | return observation, roll 141 | 142 | def update_game_board(self, gnu_board): 143 | # Update the internal board representation from the board reported by gnubg 144 | # The gnubg board is represented as two lists of 25 elements each, one per player 145 | gnu_positions_black = gnu_board[0] 146 | gnu_positions_white = gnu_board[1] 147 | board = [(0, None)] * NUM_POINTS 148 | 149 | for src, checkers in enumerate(gnu_positions_white[:-1]): 150 | if checkers > 0: 151 | board[src] = (checkers, WHITE) 152 | 153 | for src, checkers in enumerate(reversed(gnu_positions_black[:-1])): 154 | if checkers > 0: 155 | board[src] = (checkers, BLACK) 156 | 157 | self.game.board = board 158 | # the last element represents the checkers on the bar 159 | self.game.bar = [gnu_positions_white[-1], gnu_positions_black[-1]] 160 | # update the players' positions 161 | self.game.players_positions = self.game.get_players_positions() 162 | # checkers borne off the board 163 | self.game.off = [15 - sum(gnu_positions_white), 15 - sum(gnu_positions_black)] 164 | # Just for debugging 165 | # self.render() 166 | assert_board(None, self.game.board, self.game.bar, self.game.off) 167 | 168 | def get_valid_actions(self, roll): 169 | return self.game.get_valid_plays(self.current_agent, roll) 170 | 171 | def set_difficulty(self): 172 | 
self.is_difficulty_set = True 173 | 174 | self.gnubg_interface.send_command('set automatic roll off') 175 | self.gnubg_interface.send_command('set automatic game off') 176 | 177 | if self.difficulty == 'beginner': 178 | self.gnubg_interface.send_command('set player gnubg chequer evaluation plies 0') 179 | self.gnubg_interface.send_command('set player gnubg chequer evaluation prune off') 180 | self.gnubg_interface.send_command('set player gnubg chequer evaluation noise 0.060') 181 | self.gnubg_interface.send_command('set player gnubg cube evaluation plies 0') 182 | self.gnubg_interface.send_command('set player gnubg cube evaluation prune off') 183 | self.gnubg_interface.send_command('set player gnubg cube evaluation noise 0.060') 184 | 185 | elif self.difficulty == 'intermediate': 186 | self.gnubg_interface.send_command('set player gnubg chequer evaluation noise 0.040') 187 | self.gnubg_interface.send_command('set player gnubg cube evaluation noise 0.040') 188 | 189 | elif self.difficulty == 'advanced': 190 | self.gnubg_interface.send_command('set player gnubg chequer evaluation plies 0') 191 | self.gnubg_interface.send_command('set player gnubg chequer evaluation prune off') 192 | self.gnubg_interface.send_command('set player gnubg chequer evaluation noise 0.015') 193 | self.gnubg_interface.send_command('set player gnubg cube evaluation plies 0') 194 | self.gnubg_interface.send_command('set player gnubg cube evaluation prune off') 195 | self.gnubg_interface.send_command('set player gnubg cube evaluation noise 0.015') 196 | 197 | elif self.difficulty == 'world_class': 198 | self.gnubg_interface.send_command('set player gnubg chequer evaluation plies 2') 199 | self.gnubg_interface.send_command('set player gnubg chequer evaluation prune on') 200 | self.gnubg_interface.send_command('set player gnubg chequer evaluation noise 0.000') 201 | 202 | self.gnubg_interface.send_command('set player gnubg movefilter 1 0 0 8 0.160') 203 | self.gnubg_interface.send_command('set player gnubg movefilter 2 0 0 8 0.160') 204 | self.gnubg_interface.send_command('set player gnubg movefilter 3 0 0 8 0.160') 205 | self.gnubg_interface.send_command('set player gnubg movefilter 3 2 0 2 0.040') 206 | self.gnubg_interface.send_command('set player gnubg movefilter 4 0 0 8 0.160') 207 | self.gnubg_interface.send_command('set player gnubg movefilter 4 2 0 2 0.040') 208 | 209 | self.gnubg_interface.send_command('set player gnubg cube evaluation plies 2') 210 | self.gnubg_interface.send_command('set player gnubg cube evaluation prune on') 211 | self.gnubg_interface.send_command('set player gnubg cube evaluation noise 0.000') 212 | 213 | self.gnubg_interface.send_command('save setting') 214 | 215 | def render(self, mode='human'): 216 | assert mode in ['human', 'rgb_array', 'state_pixels'], print(mode) 217 | 218 | if mode == 'human': 219 | self.game.render() 220 | return True 221 | else: 222 | if self.viewer is None: 223 | self.viewer = Viewer(SCREEN_W, SCREEN_H) 224 | 225 | if mode == 'rgb_array': 226 | width = SCREEN_W 227 | height = SCREEN_H 228 | 229 | else: 230 | assert mode == 'state_pixels', print(mode) 231 | width = STATE_W 232 | height = STATE_H 233 | 234 | return self.viewer.render(board=self.game.board, bar=self.game.bar, off=self.game.off, state_w=width, state_h=height) 235 | 236 | 237 | def evaluate_vs_gnubg(agent, env, n_episodes): 238 | wins = {WHITE: 0, BLACK: 0} 239 | 240 | for episode in range(n_episodes): 241 | observation, first_roll = env.reset() 242 | t = time.time() 243 | for i in count(): 244 | if 
first_roll: 245 | roll = first_roll 246 | first_roll = None 247 | else: 248 | env.gnubg = agent.roll_dice() 249 | env.update_game_board(env.gnubg.board) 250 | roll = env.gnubg.roll 251 | 252 | actions = env.get_valid_actions(roll) 253 | action = agent.choose_best_action(actions, env) 254 | 255 | observation_next, reward, done, info = env.step(action) 256 | # env.render(mode='rgb_array') 257 | 258 | if done: 259 | winner = WHITE if env.gnubg.winner == 'O' else BLACK 260 | wins[winner] += 1 261 | tot = wins[WHITE] + wins[BLACK] 262 | 263 | print("EVAL => Game={:<6} {:>15} | Winner={} | after {:<4} plays || Wins: {}={:<6}({:<5.1f}%) | gnubg={:<6}({:<5.1f}%) | Duration={:<.3f} sec".format( 264 | episode + 1, '('+env.difficulty+')', info, env.gnubg.n_moves, agent.name, wins[WHITE], (wins[WHITE] / tot) * 100, wins[BLACK], (wins[BLACK] / tot) * 100, time.time() - t)) 265 | break 266 | observation = observation_next 267 | 268 | env.gnubg_interface.send_command("new session") 269 | return wins 270 | -------------------------------------------------------------------------------- /td_gammon/web_gui/index.html: --------------------------------------------------------------------------------
[index.html: markup not preserved in this extraction; of its 508 lines only the visible text survives: the page title "Backgammon GUI", the board heading "BACKGAMMON", and the "Player: WHITE" status label]
--------------------------------------------------------------------------------
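A minimal driver for GnubgInterface and GnubgEnv might look like the sketch below. It is illustrative only: the host/port pair localhost:8001 is a placeholder for wherever the bridge is listening (gnubg must already be running gnubg -t -p bridge.py), the import path assumes this package layout, and the random-move loop merely stands in for the trained agents defined in agents.py (not shown in this listing), whose roll_dice() is assumed to forward gnubg's 'roll' command exactly as the loop does here.

import random

from gym_backgammon.envs.backgammon import WHITE, BLACK
from td_gammon.gnubg.gnubg_backgammon import GnubgInterface, GnubgEnv

# Placeholder host/port: the bridge must already be running inside gnubg at this address
gnubg_interface = GnubgInterface(host='localhost', port=8001)
env = GnubgEnv(gnubg_interface, difficulty='beginner', model_type='nn')

observation, roll = env.reset()  # opens a 'new session' and applies the difficulty settings
done = False
winner = None

while not done:
    if roll is None:
        # Ask gnubg to roll and mirror its board locally, the same way
        # evaluate_vs_gnubg() does through agent.roll_dice() (assumed to wrap 'roll')
        env.gnubg = gnubg_interface.send_command('roll')
        env.update_game_board(env.gnubg.board)
        roll = env.gnubg.roll

    actions = env.get_valid_actions(roll)
    action = random.choice(list(actions)) if actions else None

    # step() converts the play to gnubg notation, sends it over the bridge,
    # lets gnubg answer, and reports the winner once the game is over
    observation, reward, done, winner = env.step(action)
    roll = None

print('Winner:', 'WHITE (our side)' if winner == WHITE else 'BLACK (gnubg)')

A real evaluation would construct a trained agent (agents.py / model.py) and hand it to evaluate_vs_gnubg() instead of picking random moves.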