├── LICENSE ├── README.md ├── lengths.py ├── play.py ├── train.py ├── uno_ai ├── __init__.py ├── actions.py ├── agent.py ├── cards.py ├── game.py ├── pool.py ├── ppo.py ├── rollouts.py └── test_game.py └── vs_baseline.py /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018-2019, Alexander Nichol. 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 2. Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation 11 | and/or other materials provided with the distribution. 12 | 13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 14 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 15 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 16 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 17 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 18 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 19 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 20 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 21 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 22 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # uno-ai 2 | 3 | This is my first foray into multi-agent RL. The idea is to train an AI to play the game Uno, by continually running it against earlier versions of itself. 4 | 5 | So far I have implemented the game dynamics and a multi-agent training setup. I'm running training now, and will report on my results. Chances are there's at least one bug lurking around somewhere. 6 | -------------------------------------------------------------------------------- /lengths.py: -------------------------------------------------------------------------------- 1 | """ 2 | Measure the lengths of random games. 3 | """ 4 | 5 | import random 6 | 7 | from uno_ai.game import Game 8 | 9 | 10 | def main(): 11 | while True: 12 | g = Game(4) 13 | num_moves = 0 14 | while g.winner() is None: 15 | action = random.choice(g.options()) 16 | g.act(action) 17 | num_moves += 1 18 | print(num_moves) 19 | 20 | 21 | if __name__ == '__main__': 22 | main() 23 | -------------------------------------------------------------------------------- /play.py: -------------------------------------------------------------------------------- 1 | """ 2 | Play against an Uno agent. 
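
The human plays as player 0; every other seat is filled by the loaded
agent. A typical invocation, assuming a checkpoint written by train.py
at the default location, is:

    python play.py --path model.pt --players 4

If no file exists at --path, a freshly initialized (untrained) agent is
used instead.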
3 | """ 4 | 5 | import argparse 6 | import os 7 | 8 | import torch 9 | 10 | from uno_ai.actions import NopAction 11 | from uno_ai.agent import Agent 12 | from uno_ai.game import Game 13 | from uno_ai.rollouts import Rollout 14 | 15 | 16 | def main(): 17 | args = arg_parser().parse_args() 18 | 19 | agent = Agent() 20 | if os.path.exists(args.path): 21 | state_dict = torch.load(args.path, map_location='cpu') 22 | agent.load_state_dict(state_dict) 23 | 24 | game = Game(args.players) 25 | agents = [HumanAgent()] + [agent] * (args.players - 1) 26 | Rollout.rollout(game, agents) 27 | 28 | 29 | def arg_parser(): 30 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 31 | parser.add_argument('--path', help='path to latest model', default='model.pt') 32 | parser.add_argument('--players', help='number of players', default=4, type=int) 33 | return parser 34 | 35 | 36 | class HumanAgent: 37 | def step(self, game, player, state): 38 | print('----------------------------') 39 | print('Player:', game.turn()) 40 | print('Discard:', game.discard()) 41 | print('') 42 | print('Cards:') 43 | for i, card in enumerate(game.hands()[player]): 44 | print('%d. %s' % (i, card)) 45 | print('') 46 | if game.turn() != player: 47 | input('Hit enter to continue.') 48 | return { 49 | 'action': NopAction(), 50 | 'state': None 51 | } 52 | else: 53 | print('Options:') 54 | for i, option in enumerate(game.options()): 55 | print('%d. %s' % (i, option)) 56 | idx = input('Choose option: ') 57 | if idx == '': 58 | idx = '0' 59 | return { 60 | 'action': game.options()[int(idx)], 61 | 'state': None, 62 | } 63 | 64 | 65 | if __name__ == '__main__': 66 | main() 67 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train an Uno agent. 
3 | """ 4 | 5 | import argparse 6 | import itertools 7 | import os 8 | import random 9 | 10 | import torch 11 | 12 | from uno_ai.agent import Agent, BaselineAgent 13 | from uno_ai.game import Game 14 | from uno_ai.pool import Pool 15 | from uno_ai.ppo import PPO 16 | from uno_ai.rollouts import Rollout, RolloutBatch 17 | 18 | 19 | def main(): 20 | args = arg_parser().parse_args() 21 | device = torch.device(args.device) 22 | agent = Agent() 23 | agent.to(device) 24 | 25 | if os.path.exists(args.path): 26 | state_dict = torch.load(args.path, map_location=device) 27 | agent.load_state_dict(state_dict) 28 | 29 | pool = Pool(args.pool) 30 | if pool.empty(): 31 | pool.add(agent) 32 | 33 | ppo = PPO(agent, epsilon=args.epsilon, lr=args.lr, ent_reg=args.entropy) 34 | for i in itertools.count(): 35 | rollouts, mean_rew, mean_len = gather_rollouts(args, agent, pool) 36 | step_res = ppo.loop(rollouts, iters=args.iters) 37 | print('reward=%f steps=%f explained=%f entropy_init=%f entropy_final=%f clipped=%f' % 38 | (mean_rew, mean_len, step_res[0]['explained'], step_res[0]['entropy'], 39 | step_res[-1]['entropy'], step_res[-1]['clipped'])) 40 | if not i % args.save_interval: 41 | pool.add(agent) 42 | torch.save(agent.state_dict(), args.path) 43 | 44 | 45 | def gather_rollouts(args, agent, pool): 46 | rollouts = [] 47 | for _ in range(args.batch): 48 | if args.baseline: 49 | agents = [agent] + [BaselineAgent() for _ in range(args.players - 1)] 50 | else: 51 | agents = [agent] + [pool.sample(agent.device()) for _ in range(args.players - 1)] 52 | random.shuffle(agents) 53 | rs = Rollout.rollout(Game(args.players), agents) 54 | rollouts.append(rs[agents.index(agent)]) 55 | mean_rew = sum(r.reward for r in rollouts) / len(rollouts) 56 | mean_len = sum(r.num_steps for r in rollouts) / len(rollouts) 57 | return (RolloutBatch(rollouts, agent.device(), gamma=args.gamma, lam=args.lam), 58 | mean_rew, mean_len) 59 | 60 | 61 | def arg_parser(): 62 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 63 | parser.add_argument('--path', help='path to latest model', default='model.pt') 64 | parser.add_argument('--pool', help='path to pool directory', default='agents') 65 | parser.add_argument('--lr', help='PPO learning rate', default=0.0001, type=float) 66 | parser.add_argument('--entropy', help='entropy bonus', default=0.01, type=float) 67 | parser.add_argument('--epsilon', help='PPO epsilon', default=0.1, type=float) 68 | parser.add_argument('--iters', help='PPO iterations', default=64, type=int) 69 | parser.add_argument('--players', help='number of players', default=4, type=int) 70 | parser.add_argument('--batch', help='rollouts per batch', default=256, type=int) 71 | parser.add_argument('--gamma', help='GAE gamma', default=1.0, type=float) 72 | parser.add_argument('--lam', help='GAE lambda', default=1.0, type=float) 73 | parser.add_argument('--device', help='torch device to use', default='cpu') 74 | parser.add_argument('--baseline', help='train against a baseline', action='store_true') 75 | parser.add_argument('--save-interval', help='iterations per agent checkpoint', 76 | default=10, type=int) 77 | return parser 78 | 79 | 80 | if __name__ == '__main__': 81 | main() 82 | -------------------------------------------------------------------------------- /uno_ai/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unixpickle/uno-ai/3124afc8fa6b0cbcced95ef03ed9672cdb4f35a7/uno_ai/__init__.py 
-------------------------------------------------------------------------------- /uno_ai/actions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Action representations for Uno. 3 | """ 4 | 5 | from abc import ABC, abstractmethod 6 | 7 | ACTION_VECTOR_SIZE = 115 8 | 9 | 10 | class Action(ABC): 11 | """ 12 | An abstract action in a game of Uno. 13 | """ 14 | @abstractmethod 15 | def index(self): 16 | """ 17 | Get the index of the action in the action vector. 18 | """ 19 | pass 20 | 21 | 22 | class NopAction(Action): 23 | def index(self): 24 | return 0 25 | 26 | def __eq__(self, other): 27 | return type(other) == NopAction 28 | 29 | def __str__(self): 30 | return 'do nothing' 31 | 32 | 33 | class ChallengeAction(Action): 34 | def index(self): 35 | return 1 36 | 37 | def __eq__(self, other): 38 | return type(other) == ChallengeAction 39 | 40 | def __str__(self): 41 | return 'challenge' 42 | 43 | 44 | class DrawAction(Action): 45 | def index(self): 46 | return 2 47 | 48 | def __eq__(self, other): 49 | return type(other) == DrawAction 50 | 51 | def __str__(self): 52 | return 'draw a card' 53 | 54 | 55 | class PickColorAction(Action): 56 | def __init__(self, color): 57 | self.color = color 58 | 59 | def index(self): 60 | return self.color.value + 3 61 | 62 | def __eq__(self, other): 63 | return type(other) == PickColorAction and other.color == self.color 64 | 65 | def __str__(self): 66 | return 'pick color %s' % self.color 67 | 68 | 69 | class PlayCardAction(Action): 70 | def __init__(self, index): 71 | self.raw_index = index 72 | 73 | def index(self): 74 | return self.raw_index + 7 75 | 76 | def __eq__(self, other): 77 | return type(other) == PlayCardAction and other.raw_index == self.raw_index 78 | 79 | def __str__(self): 80 | return 'play card %d' % self.raw_index 81 | -------------------------------------------------------------------------------- /uno_ai/agent.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reinforcement Learning agents. 3 | """ 4 | 5 | import math 6 | import random 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | 12 | from .actions import ACTION_VECTOR_SIZE, DrawAction, NopAction 13 | from .game import OBS_VECTOR_SIZE 14 | 15 | 16 | class Agent(nn.Module): 17 | """ 18 | A stochastic policy plus a value function. 19 | """ 20 | 21 | def __init__(self): 22 | super().__init__() 23 | self.input_proc = nn.Sequential( 24 | nn.Linear(OBS_VECTOR_SIZE, 128), 25 | nn.Tanh(), 26 | ) 27 | self.norm = nn.LayerNorm(128) 28 | self.rnn = nn.LSTM(128, 128) 29 | self.policy = nn.Linear(256, ACTION_VECTOR_SIZE) 30 | self.value = nn.Linear(256, 1) 31 | for param in list(self.policy.parameters()) + list(self.value.parameters()): 32 | param.data.fill_(0.0) 33 | 34 | def device(self): 35 | return next(self.parameters()).device 36 | 37 | def forward(self, inputs, states=None): 38 | """ 39 | Apply the agent to a batch of sequences. 40 | 41 | Args: 42 | inputs: a (seq_len, batch, OBS_VECTOR_SIZE) 43 | Tensor of observations. 44 | states: a tuple (h_0, c_0) of states. 45 | 46 | Returns: 47 | A tuple (logits, states): 48 | logits: A (seq_len, batch, ACTION_VECTOR_SIZE) 49 | Tensor of logits. 50 | values: A (seq_len, batch) Tensor of values. 51 | states: a new (h_0, c_0) tuple. 
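
Note that the returned tuple has three elements: the logits, the
values, and the new recurrent state. A minimal single-step call,
mirroring what step() does internally, is:

    agent = Agent()
    obs = torch.zeros(1, 1, OBS_VECTOR_SIZE)
    logits, values, (h_n, c_n) = agent(obs)
    # logits: (1, 1, ACTION_VECTOR_SIZE), values: (1, 1)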
52 | """ 53 | seq_len, batch = inputs.shape[0], inputs.shape[1] 54 | flat_in = inputs.view(-1, OBS_VECTOR_SIZE) 55 | features = self.norm(self.input_proc(flat_in)) 56 | feature_seq = features.view(seq_len, batch, -1) 57 | if states is None: 58 | outputs, (h_n, c_n) = self.rnn(feature_seq) 59 | else: 60 | outputs, (h_n, c_n) = self.rnn(feature_seq, states) 61 | flat_out = outputs.view(-1, outputs.shape[-1]) 62 | flat_out = torch.cat([flat_out, features], dim=-1) 63 | flat_logits = self.policy(flat_out) 64 | flat_values = self.value(flat_out) 65 | logits = flat_logits.view(seq_len, batch, ACTION_VECTOR_SIZE) 66 | values = flat_values.view(seq_len, batch) 67 | 68 | # Negative bias towards draw, to make initial 69 | # episodes much shorter. 70 | bias_vec = [0.0] * ACTION_VECTOR_SIZE 71 | bias_vec[DrawAction().index()] = -1 72 | logits += torch.from_numpy(np.array(bias_vec, dtype=np.float32)).to(logits.device) 73 | 74 | return logits, values, (h_n, c_n) 75 | 76 | def step(self, game, player, state): 77 | """ 78 | Pick an action in the game. 79 | 80 | Args: 81 | game: the Game we are playing. 82 | player: the index of this agent. 83 | state: the previous RNN state (or None). 84 | 85 | Returns: 86 | A dict containing the following keys: 87 | options: the options chosen from. 88 | action: the sampled action. 89 | log_prob: the log probability of the action. 90 | value: the value function output. 91 | state: the new RNN state. 92 | """ 93 | obs = torch.from_numpy(np.array(game.obs(player), dtype=np.float32)).to(self.device()) 94 | obs = obs.view(1, 1, -1) 95 | options = [NopAction()] 96 | if player == game.turn(): 97 | options = game.options() 98 | vec, values, new_state = self(obs, states=state) 99 | np_vec = vec.view(-1).detach().cpu().numpy() 100 | logits = np.array([np_vec[act.index()] for act in options]) 101 | idx, prob = sample_softmax(logits) 102 | return { 103 | 'options': options, 104 | 'action': options[idx], 105 | 'log_prob': math.log(prob), 106 | 'value': values.item(), 107 | 'state': new_state, 108 | } 109 | 110 | 111 | class RandomAgent: 112 | """ 113 | An agent that takes random actions. 114 | """ 115 | 116 | def step(self, game, player, state): 117 | options = [NopAction()] 118 | if player == game.turn(): 119 | options = game.options() 120 | return { 121 | 'options': options, 122 | 'action': random.choice(options), 123 | 'state': None, 124 | } 125 | 126 | 127 | class BaselineAgent: 128 | """ 129 | An agent that plays a card if possible, and otherwise 130 | takes a random action. 131 | """ 132 | 133 | def step(self, game, player, state): 134 | options = [NopAction()] 135 | if player == game.turn(): 136 | options = game.options() 137 | if len(options) > 2: 138 | action = random.choice(options[2:]) 139 | else: 140 | action = random.choice(options) 141 | return { 142 | 'options': options, 143 | 'action': action, 144 | 'state': None, 145 | } 146 | 147 | 148 | def sample_softmax(logits): 149 | max_value = np.max(logits) 150 | probs = np.exp(logits - max_value) 151 | probs /= np.sum(probs) 152 | x = random.random() 153 | for i, y in enumerate(probs): 154 | x -= y 155 | if x <= 0: 156 | return i, probs[i] 157 | return len(logits) - 1, probs[-1] 158 | -------------------------------------------------------------------------------- /uno_ai/cards.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | CARD_VEC_SIZE = 10 + 6 + 4 5 | 6 | 7 | def full_deck(): 8 | """ 9 | Create a complete Uno deck. 
10 | """ 11 | deck = [] 12 | for color in [Color.RED, Color.ORANGE, Color.GREEN, Color.BLUE]: 13 | for _ in range(2): 14 | for number in range(1, 10): 15 | deck.append(Card(CardType.NUMERAL, color=color, number=number)) 16 | for card_type in [CardType.SKIP, CardType.REVERSE, CardType.DRAW_TWO]: 17 | deck.append(Card(card_type, color=color)) 18 | 19 | deck.append(Card(CardType.NUMERAL, color=color, number=0)) 20 | deck.append(Card(CardType.WILD)) 21 | deck.append(Card(CardType.WILD_DRAW)) 22 | return deck 23 | 24 | 25 | class CardType(Enum): 26 | """ 27 | The type of a card. 28 | """ 29 | NUMERAL = 0 30 | SKIP = 1 31 | REVERSE = 2 32 | DRAW_TWO = 3 33 | WILD = 4 34 | WILD_DRAW = 5 35 | 36 | 37 | class Color(Enum): 38 | """ 39 | The color of a card. 40 | """ 41 | RED = 0 42 | ORANGE = 1 43 | GREEN = 2 44 | BLUE = 3 45 | 46 | def __str__(self): 47 | return self.name.lower() 48 | 49 | 50 | class Card: 51 | """ 52 | A card in the deck. 53 | """ 54 | 55 | def __init__(self, card_type, color=None, number=None): 56 | self.card_type = card_type 57 | self.color = color 58 | self.number = number 59 | 60 | def vector(self): 61 | """ 62 | Convert the card into a vector. 63 | """ 64 | vec = [0.0] * CARD_VEC_SIZE 65 | if self.number is not None: 66 | vec[self.number] = 1.0 67 | if self.color is not None: 68 | vec[10 + self.color.value] = 1.0 69 | vec[14 + self.card_type.value] = 1.0 70 | return vec 71 | 72 | def __str__(self): 73 | if self.card_type == CardType.NUMERAL: 74 | return '%s %d' % (self.color, self.number) 75 | elif self.card_type == CardType.SKIP: 76 | return '%s skip' % self.color 77 | elif self.card_type == CardType.REVERSE: 78 | return '%s reverse' % self.color 79 | elif self.card_type == CardType.DRAW_TWO: 80 | return '%s draw two' % self.color 81 | elif self.card_type == CardType.WILD: 82 | if self.color is None: 83 | return 'wild card' 84 | return 'wild card (%s)' % self.color 85 | elif self.card_type == CardType.WILD_DRAW: 86 | if self.color is None: 87 | return 'wild +4' 88 | return 'wild +4 (%s)' % self.color 89 | raise RuntimeError('unknown type') 90 | -------------------------------------------------------------------------------- /uno_ai/game.py: -------------------------------------------------------------------------------- 1 | """ 2 | Core Uno game logic. 
3 | """ 4 | 5 | from enum import Enum 6 | import random 7 | 8 | from .actions import NopAction, ChallengeAction, DrawAction, PickColorAction, PlayCardAction 9 | from .cards import CardType, Color, full_deck 10 | 11 | MAX_PLAYERS = 8 12 | OBS_VECTOR_SIZE = 109 * 20 + MAX_PLAYERS * 3 13 | 14 | 15 | class GameState(Enum): 16 | PLAY_OR_DRAW = 0 17 | PLAY = 1 18 | PLAY_DRAWN = 2 19 | PICK_COLOR = 3 20 | PICK_COLOR_INIT = 4 21 | CHALLENGE_VALID = 5 22 | CHALLENGE_INVALID = 6 23 | 24 | 25 | class Game: 26 | def __init__(self, num_players): 27 | assert num_players <= MAX_PLAYERS 28 | self._num_players = num_players 29 | self._deck = full_deck() 30 | random.shuffle(self._deck) 31 | self._discard = [] 32 | self._hands = [] 33 | for _ in range(num_players): 34 | hand = [] 35 | for _ in range(7): 36 | hand.append(self._deck.pop()) 37 | self._hands.append(hand) 38 | while self._deck[-1].card_type == CardType.WILD_DRAW: 39 | random.shuffle(self._deck) 40 | self._discard.append(self._deck.pop()) 41 | 42 | self._direction = 1 43 | self._turn = 0 44 | self._state = GameState.PLAY_OR_DRAW 45 | 46 | self._update_init_state() 47 | 48 | def hands(self): 49 | return self._hands 50 | 51 | def discard(self): 52 | return self._discard[-1] 53 | 54 | def winner(self): 55 | """ 56 | If the game has ended, get the winner. 57 | """ 58 | for i, hand in enumerate(self._hands): 59 | if not len(hand): 60 | return i 61 | return None 62 | 63 | def turn(self): 64 | """ 65 | Get the current player. 66 | """ 67 | return self._turn 68 | 69 | def obs(self, player): 70 | """ 71 | Generate an observation vector for a player. 72 | """ 73 | vec = [] 74 | hand = self._hands[player] 75 | for card in hand: 76 | vec += card.vector() 77 | vec += [0.0] * (20 * (108 - len(hand))) 78 | vec += self._discard[-1].vector() 79 | vec += _player_vec(player) 80 | vec += _player_vec(self.turn()) 81 | for hand in self._hands: 82 | vec += [float(len(hand))] 83 | vec += [0.0] * (MAX_PLAYERS - self._num_players) 84 | return vec 85 | 86 | def options(self): 87 | """ 88 | Get the valid actions for the current player. 89 | """ 90 | if self._state == GameState.PLAY_OR_DRAW: 91 | return [NopAction(), DrawAction()] + self._play_options() 92 | elif self._state == GameState.PLAY: 93 | return [NopAction()] + self._play_options() 94 | elif self._state == GameState.PLAY_DRAWN: 95 | res = [NopAction()] 96 | if self._can_play(self._current_hand()[-1]): 97 | res += [PlayCardAction(len(self._current_hand()) - 1)] 98 | return res 99 | elif self._state == GameState.PICK_COLOR or self._state == GameState.PICK_COLOR_INIT: 100 | return [PickColorAction(c) for c in [Color.RED, Color.ORANGE, Color.GREEN, Color.BLUE]] 101 | elif self._state == GameState.CHALLENGE_VALID or self._state == GameState.CHALLENGE_INVALID: 102 | return [NopAction(), ChallengeAction()] 103 | raise RuntimeError('invalid state') 104 | 105 | def act(self, action): 106 | """ 107 | Take a turn by selecting the action for the 108 | current player. 
109 | """ 110 | assert action in self.options() 111 | if self._state == GameState.PLAY_OR_DRAW: 112 | if isinstance(action, NopAction): 113 | self._advance_turn() 114 | elif isinstance(action, DrawAction): 115 | self._state = GameState.PLAY_DRAWN 116 | self._current_hand().append(self._draw()) 117 | else: 118 | self._play_card(action) 119 | elif self._state == GameState.PLAY: 120 | if isinstance(action, NopAction): 121 | self._state = GameState.PLAY_OR_DRAW 122 | self._advance_turn() 123 | else: 124 | self._play_card(action) 125 | elif self._state == GameState.PLAY_DRAWN: 126 | if isinstance(action, NopAction): 127 | self._state = GameState.PLAY_OR_DRAW 128 | self._advance_turn() 129 | else: 130 | self._play_card(action) 131 | elif self._state == GameState.PICK_COLOR or self._state == GameState.PICK_COLOR_INIT: 132 | disc = self._discard[-1] 133 | disc.color = action.color 134 | if self._state == GameState.PICK_COLOR: 135 | last_disc = self._discard[-2] 136 | if disc.card_type == CardType.WILD: 137 | self._state = GameState.PLAY_OR_DRAW 138 | elif any(x.color == last_disc.color for x in self._current_hand()): 139 | self._state = GameState.CHALLENGE_INVALID 140 | else: 141 | self._state = GameState.CHALLENGE_VALID 142 | self._advance_turn() 143 | else: 144 | self._state = GameState.PLAY 145 | elif self._state == GameState.CHALLENGE_VALID or self._state == GameState.CHALLENGE_INVALID: 146 | if isinstance(action, NopAction): 147 | for _ in range(4): 148 | self._current_hand().append(self._draw()) 149 | self._advance_turn() 150 | elif self._state == GameState.CHALLENGE_INVALID: 151 | self._advance_turn(by=-1) 152 | for _ in range(4): 153 | self._current_hand().append(self._draw()) 154 | self._advance_turn(by=2) 155 | else: 156 | for _ in range(6): 157 | self._current_hand().append(self._draw()) 158 | self._advance_turn() 159 | self._state = GameState.PLAY_OR_DRAW 160 | 161 | def _advance_turn(self, by=1): 162 | self._turn += by * self._direction 163 | while self._turn < 0: 164 | self._turn += self._num_players 165 | while self._turn >= self._num_players: 166 | self._turn -= self._num_players 167 | 168 | def _play_card(self, action): 169 | card = self._current_hand()[action.raw_index] 170 | self._current_hand().remove(card) 171 | self._discard.append(card) 172 | if card.card_type == CardType.NUMERAL: 173 | self._advance_turn() 174 | elif card.card_type == CardType.SKIP: 175 | self._advance_turn(by=2) 176 | elif card.card_type == CardType.REVERSE: 177 | self._direction *= -1 178 | self._advance_turn() 179 | elif card.card_type == CardType.DRAW_TWO: 180 | self._advance_turn() 181 | for _ in range(2): 182 | self._current_hand().append(self._draw()) 183 | self._advance_turn() 184 | elif card.card_type == CardType.WILD or card.card_type == CardType.WILD_DRAW: 185 | self._state = GameState.PICK_COLOR 186 | return 187 | self._state = GameState.PLAY_OR_DRAW 188 | 189 | def _draw(self): 190 | if len(self._deck): 191 | return self._deck.pop() 192 | self._deck = self._discard[:-1] 193 | self._discard = [self._discard[-1]] 194 | random.shuffle(self._deck) 195 | for card in self._deck: 196 | if card.card_type == CardType.WILD or card.card_type == CardType.WILD_DRAW: 197 | card.color = None 198 | return self._deck.pop() 199 | 200 | def _update_init_state(self): 201 | first_card = self._discard[0] 202 | if first_card.card_type == CardType.SKIP: 203 | self._turn += 1 204 | elif first_card.card_type == CardType.REVERSE: 205 | self._direction = -1 206 | self._turn = self._num_players - 1 207 | elif 
first_card.card_type == CardType.DRAW_TWO: 208 | for _ in range(2): 209 | self._hands[0].append(self._draw()) 210 | self._turn = 1 211 | elif first_card.card_type == CardType.WILD: 212 | self._state = GameState.PICK_COLOR_INIT 213 | 214 | def _can_play(self, card): 215 | if card.card_type == CardType.WILD or card.card_type == CardType.WILD_DRAW: 216 | return True 217 | disc = self._discard[-1] 218 | if card.color == disc.color: 219 | return True 220 | if card.card_type == CardType.NUMERAL and disc.card_type == CardType.NUMERAL: 221 | return card.number == disc.number 222 | return card.card_type == disc.card_type 223 | 224 | def _play_options(self): 225 | res = [] 226 | for i, card in enumerate(self._current_hand()): 227 | if self._can_play(card): 228 | res.append(PlayCardAction(i)) 229 | return res 230 | 231 | def _current_hand(self): 232 | return self._hands[self._turn] 233 | 234 | 235 | def _player_vec(idx): 236 | res = [0.0] * MAX_PLAYERS 237 | res[idx] = 1.0 238 | return res 239 | -------------------------------------------------------------------------------- /uno_ai/pool.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pools of agents. 3 | """ 4 | 5 | import os 6 | import random 7 | 8 | import torch 9 | 10 | from .agent import Agent 11 | 12 | 13 | class Pool: 14 | def __init__(self, dir_path): 15 | self.dir_path = dir_path 16 | if not os.path.exists(dir_path): 17 | os.mkdir(dir_path) 18 | self.agent_names = [x for x in os.listdir(dir_path) if x.endswith('.pt')] 19 | 20 | def empty(self): 21 | return len(self.agent_names) == 0 22 | 23 | def add(self, agent): 24 | """ 25 | Add an agent to the pool. 26 | """ 27 | idx = 0 28 | while ('%d.pt' % idx) in self.agent_names: 29 | idx += 1 30 | name = '%d.pt' % idx 31 | torch.save(agent.state_dict(), os.path.join(self.dir_path, name)) 32 | self.agent_names.append(name) 33 | 34 | def sample(self, device): 35 | """ 36 | Sample an agent from the pool. 37 | """ 38 | name = random.choice(self.agent_names) 39 | state_dict = torch.load(os.path.join(self.dir_path, name), map_location=device) 40 | res = Agent() 41 | res.to(device) 42 | res.load_state_dict(state_dict) 43 | return res 44 | -------------------------------------------------------------------------------- /uno_ai/ppo.py: -------------------------------------------------------------------------------- 1 | """ 2 | Proximal Policy Optimization. 
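
PPO optimizes the clipped surrogate objective together with a value
loss and an entropy bonus over a RolloutBatch (see rollouts.py).
Typical usage, with train.py's default hyper-parameters and assuming an
Agent `agent` and a RolloutBatch `batch` are already available:

    ppo = PPO(agent, epsilon=0.1, lr=0.0001, ent_reg=0.01)
    stats = ppo.loop(batch, iters=64)
    print(stats[-1]['explained'], stats[-1]['entropy'])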
3 | """ 4 | 5 | import torch 6 | import torch.optim as optim 7 | 8 | 9 | class PPO: 10 | def __init__(self, agent, epsilon=0.2, lr=1e-4, ent_reg=1e-2): 11 | self.agent = agent 12 | self.epsilon = epsilon 13 | self.optim = optim.Adam(agent.parameters(), lr=lr) 14 | self.ent_reg = ent_reg 15 | 16 | def loop(self, batch, iters=8): 17 | steps = [] 18 | for _ in range(iters): 19 | steps.append(self.step(batch)) 20 | return steps 21 | 22 | def step(self, batch): 23 | logits, values, _ = self.agent(batch.observations) 24 | masked_logits = logits - (1 - batch.masks) * 10000 25 | all_probs = torch.log_softmax(masked_logits, dim=-1) 26 | log_probs = torch.sum(all_probs * batch.actions, dim=-1) 27 | 28 | vf_loss = batch.masked_mean(torch.pow(values - batch.targets, 2)) 29 | variance = batch.masked_var(batch.targets) 30 | explained = 1 - vf_loss / variance 31 | 32 | ratio = torch.exp(log_probs - batch.log_probs) 33 | clip_ratio = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon) 34 | pi_loss = -batch.masked_mean(torch.min(ratio * batch.advs, clip_ratio * batch.advs)) 35 | clip_frac = batch.masked_mean(torch.gt(ratio * batch.advs, clip_ratio * batch.advs).float()) 36 | 37 | neg_entropy = batch.masked_mean(torch.sum(torch.exp(all_probs) * all_probs, dim=-1)) 38 | ent_loss = self.ent_reg * neg_entropy 39 | 40 | loss = vf_loss + pi_loss + ent_loss 41 | self.optim.zero_grad() 42 | loss.backward() 43 | self.optim.step() 44 | 45 | return { 46 | 'vf_loss': vf_loss.item(), 47 | 'pi_loss': pi_loss.item(), 48 | 'ent_loss': ent_loss.item(), 49 | 'entropy': -neg_entropy.item(), 50 | 'clipped': clip_frac.item(), 51 | 'explained': explained.item(), 52 | } 53 | -------------------------------------------------------------------------------- /uno_ai/rollouts.py: -------------------------------------------------------------------------------- 1 | """ 2 | Gather trajectories from agents on a game. 3 | """ 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from .actions import ACTION_VECTOR_SIZE 9 | from .game import OBS_VECTOR_SIZE 10 | 11 | 12 | class Rollout: 13 | """ 14 | A single episode from the perspective of a single 15 | agent. 16 | 17 | Stores the observation for every timestamp, the output 18 | of the agent at every timestamp, and the reward of the 19 | final episode from the agent's perspective. 20 | """ 21 | 22 | def __init__(self, observations, outputs, reward): 23 | self.observations = observations 24 | self.outputs = outputs 25 | self.reward = reward 26 | 27 | @property 28 | def num_steps(self): 29 | """ 30 | Get the number of timesteps. 31 | """ 32 | return len(self.observations) 33 | 34 | def advantages(self, gamma, lam): 35 | """ 36 | Compute the advantages using Generalized 37 | Advantage Estimation. 
38 | """ 39 | res = [] 40 | adv = 0 41 | for i in range(self.num_steps)[::-1]: 42 | if i == self.num_steps - 1: 43 | delta = self.reward - self.outputs[-1]['value'] 44 | else: 45 | delta = gamma * self.outputs[i + 1]['value'] - self.outputs[i]['value'] 46 | adv *= lam * gamma 47 | adv += delta 48 | res.append(adv) 49 | return res[::-1] 50 | 51 | @classmethod 52 | def empty(cls): 53 | return cls([], [], 0.0) 54 | 55 | @classmethod 56 | def rollout(cls, game, agents): 57 | rollouts = [cls.empty() for _ in agents] 58 | states = [None] * len(agents) 59 | while game.winner() is None: 60 | for i, agent in enumerate(agents): 61 | res = agent.step(game, i, states[i]) 62 | states[i] = res['state'] 63 | r = rollouts[i] 64 | r.observations.append(game.obs(i)) 65 | r.outputs.append(res) 66 | game.act(rollouts[game.turn()].outputs[-1]['action']) 67 | for i, r in enumerate(rollouts): 68 | if i == game.winner(): 69 | r.reward = 1.0 70 | return rollouts 71 | 72 | 73 | class RolloutBatch: 74 | """ 75 | A packed batch of rollouts. 76 | 77 | Specifically, stores the following: 78 | observations: a (seq_len, batch, OBS_VECTOR_SIZE) 79 | Tensor of observations. 80 | actions: a (seq_len, batch, ACTION_VECTOR_SIZE) 81 | Tensor of one-hot vectors indicating which action 82 | was taken at every timestep. 83 | log_probs: a (seq_len, batch) Tensor storing the 84 | initial log probabilities of the actions. 85 | masks: a (seq_len, batch, ACTION_VECTOR_SIZE) Tensor 86 | of values where a 1 indicates that an action was 87 | allowed, and a 0 indicates otherwise. 88 | advs: a (seq_len, batch) Tensor of advantages. 89 | targets: a (seq_len, batch) Tensor of target values. 90 | seq_mask: a (seq_len, batch) Tensor of values where 91 | a 1 indicates that the element is valid. 92 | """ 93 | 94 | def __init__(self, rollouts, device, gamma, lam): 95 | seq_len = max(r.num_steps for r in rollouts) 96 | batch = len(rollouts) 97 | observations = np.zeros([seq_len, batch, OBS_VECTOR_SIZE], dtype=np.float32) 98 | actions = np.zeros([seq_len, batch, ACTION_VECTOR_SIZE], dtype=np.float32) 99 | log_probs = np.zeros([seq_len, batch], dtype=np.float32) 100 | masks = np.zeros([seq_len, batch, ACTION_VECTOR_SIZE], dtype=np.float32) 101 | advs = np.zeros([seq_len, batch], dtype=np.float32) 102 | targets = np.zeros([seq_len, batch], dtype=np.float32) 103 | seq_mask = np.zeros([seq_len, batch], dtype=np.float32) 104 | for i, r in enumerate(rollouts): 105 | observations[:r.num_steps, i, :] = r.observations 106 | actions[:r.num_steps, i, :] = [_one_hot_action(o['action']) for o in r.outputs] 107 | actions[r.num_steps:, i, 0] = 1 108 | log_probs[:r.num_steps, i] = [o['log_prob'] for o in r.outputs] 109 | masks[:r.num_steps, i, :] = [_option_mask(o['options']) for o in r.outputs] 110 | masks[r.num_steps:, i, 0] = 1 111 | our_advs = r.advantages(gamma=gamma, lam=lam) 112 | advs[:r.num_steps, i] = our_advs 113 | targets[:r.num_steps, i] = [adv + o['value'] for adv, o in zip(our_advs, r.outputs)] 114 | seq_mask[:r.num_steps, i] = 1 115 | 116 | def proc_list(l): 117 | return torch.from_numpy(l).to(device) 118 | 119 | self.observations = proc_list(observations) 120 | self.actions = proc_list(actions) 121 | self.log_probs = proc_list(log_probs) 122 | self.masks = proc_list(masks) 123 | self.advs = proc_list(advs) 124 | self.targets = proc_list(targets) 125 | self.seq_mask = proc_list(seq_mask) 126 | 127 | def masked_mean(self, seqs): 128 | """ 129 | Compute a mean using the sequence mask. 130 | 131 | Args: 132 | seqs: a (seq_len, batch) Tensor. 
133 |
134 | Returns:
135 | A masked mean.
136 | """
137 | return torch.sum(seqs * self.seq_mask) / torch.sum(self.seq_mask)
138 |
139 | def masked_var(self, seqs):
140 | """
141 | Compute a variance using the sequence mask.
142 | """
143 | num_entries = torch.sum(self.seq_mask)
144 | ex2 = torch.sum(torch.pow(seqs, 2) * self.seq_mask) / num_entries
145 | ex = torch.sum(seqs * self.seq_mask) / num_entries
146 | return ex2 - torch.pow(ex, 2)
147 |
148 |
149 | def _one_hot_action(action):
150 | res = [0.0] * ACTION_VECTOR_SIZE
151 | res[action.index()] = 1.0
152 | return res
153 |
154 |
155 | def _option_mask(options):
156 | res = [0.0] * ACTION_VECTOR_SIZE
157 | for a in options:
158 | res[a.index()] = 1.0
159 | return res
160 |
--------------------------------------------------------------------------------
/uno_ai/test_game.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from .game import Game
4 |
5 |
6 | def test_game():
7 | """
8 | Run through a few games and make sure there are no
9 | exceptions.
10 | """
11 | for i in range(10):
12 | for n in range(2, 5):
13 | g = Game(n)
14 | while g.winner() is None:  # winner() can legitimately return 0, so compare with None
15 | g.act(random.choice(g.options()))
16 |
--------------------------------------------------------------------------------
/vs_baseline.py:
--------------------------------------------------------------------------------
1 | """
2 | Play an agent against baseline agents.
3 | """
4 |
5 | import argparse
6 | import os
7 | import random
8 |
9 | import torch
10 |
11 | from uno_ai.agent import Agent, BaselineAgent
12 | from uno_ai.game import Game
13 | from uno_ai.rollouts import Rollout
14 |
15 |
16 | def main():
17 | args = arg_parser().parse_args()
18 |
19 | agent = Agent()
20 | if os.path.exists(args.path):
21 | state_dict = torch.load(args.path, map_location='cpu')
22 | agent.load_state_dict(state_dict)
23 |
24 | baseline = BaselineAgent()
25 | agents = [baseline] * (args.players - 1) + [agent]
26 |
27 | rewards = []
28 | while True:
29 | game = Game(args.players)
30 | random.shuffle(agents)
31 | rs = Rollout.rollout(game, agents)
32 | rewards.append(rs[agents.index(agent)].reward)
33 | print('mean=%f' % (sum(rewards) / len(rewards)))
34 |
35 |
36 | def arg_parser():
37 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
38 | parser.add_argument('--path', help='path to latest model', default='model.pt')
39 | parser.add_argument('--players', help='number of players', default=4, type=int)
40 | return parser
41 |
42 |
43 | if __name__ == '__main__':
44 | main()
45 |
--------------------------------------------------------------------------------