├── .gitignore ├── model_5x5_6.h5 ├── requirements.txt ├── LICENSE ├── README.md ├── test_problem.py ├── policyvalue.py ├── resnet.py ├── scoring.py ├── gostate.py ├── training.py ├── gostate_pachi.py ├── agz.py └── goboard.py /.gitignore: -------------------------------------------------------------------------------- 1 | .agz 2 | *.prof 3 | *.pyc 4 | *.h5 5 | *.hdf5 6 | -------------------------------------------------------------------------------- /model_5x5_6.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AntonOsika/agz/HEAD/model_5x5_6.h5 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | backports.weakref==1.0.post1 2 | bleach==1.5.0 3 | enum34==1.1.6 4 | funcsigs==1.0.2 5 | futures==3.1.1 6 | html5lib==0.9999999 7 | Keras==2.0.9 8 | Markdown==2.6.9 9 | mock==2.0.0 10 | numpy==1.13.3 11 | pbr==3.1.1 12 | protobuf==3.4.0 13 | PyYAML==3.12 14 | scipy==1.0.0 15 | six==1.11.0 16 | tensorflow==1.4.0 17 | tensorflow-tensorboard==0.4.0rc2 18 | Werkzeug==0.12.2 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2017 Anton Osika 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AlphaGo Zero based RL agent 2 | Made during 'AI Weekend' in Stockholm. 3 | 4 | ## Structure 5 | ```python 6 | ├── README.md 7 | ├── agz.py # MCTS logic. File can also visualise etc 8 | ├── goboard.py # Go implementation code 9 | ├── scoring.py # Go implementation code 10 | ├── gostate.py # Go environment wrapping goboard, scoring 11 | ├── gostate_pachi.py # Go environment wrapping the fast pachi implementation 12 | ├── resnet.py # Neural network for evaluating board positions 13 | ├── policyvalue.py # Predictor class wrapping the resnet CNN 14 | └── training.py # Training loop performing self play 15 | ``` 16 | 17 | ## Installation 18 | 19 | Requires [pachi-py](https://github.com/openai/pachi-py). 
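pachi-py can be installed from the linked repository or, as the import-error hint in `gostate_pachi.py` suggests, via `pip install gym[pachi]`; the remaining pinned dependencies are listed in `requirements.txt`.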
20 | ```
21 | pip install numpy
22 | pip install keras
23 | pip install tensorflow
24 | 
25 | python agz.py
26 | ```
27 | 
28 | ## Todo
29 | - [ ] Clean up code structure with folders etc.
30 | - [ ] Implement random reflections of the board
31 | - [ ] Tune how much time is spent exploring / training (cf. AGZ paper)
32 | - [ ] Parallelize training and simulation.
33 | - [ ] Use code from `agz.play_game` to create `MCTSAgent` class
34 | - [ ] Use same model on other environments
35 | - [ ] Learn the transition dynamics of step(state, action)
36 | - [ ] Refactor `MCTSAgent` to implement `.update_state` and `.decision` methods
37 | 
--------------------------------------------------------------------------------
/test_problem.py:
--------------------------------------------------------------------------------
1 | from agz import *
2 | 
3 | class AddDivState():
4 |     """
5 |     Toy one-player environment for testing the MCTS code: each action either multiplies the state by 0.75 or adds 1.0, the game ends once the state exceeds the target, and the reward is larger the closer the final state is to the target.
6 |     TODO: Possibly replace the state with numpy arrays for less memory consumption
7 |     """
8 | 
9 |     def __init__(self, target=7.4):
10 |         print("init")
11 |         self.game_over = False
12 |         self.winner = None
13 |         self.current_player = 'b'  # TODO: representing this with (1, -1) would be faster
14 |         self.action_space = 2
15 |         self.valid_actions = self._valid_actions()
16 | 
17 |         self.state = 3.14
18 |         self.target = target
19 |         self.player_transition = {'b': 'b'}  # single-player: the turn never changes
20 | 
21 |     def step(self, choice):
22 | 
23 | 
24 |         # If illegal move: Will pass
25 |         logger.log(5, "Did action {} in:\n{}".format(choice, self))
26 | 
27 |         if choice > 0.5:
28 |             self.state = self.state*0.75
29 |         else:
30 |             self.state = self.state+1.0
31 | 
32 |         print("choice", choice, "state", self.state)
33 |         """
34 |         self.current_player = self.player_transition[self.current_player]
35 | 
36 |         self.last_action_2 = self.last_action
37 |         self.last_action = pos
38 | 
39 |         """
40 |         self._new_state_checks()  # Updates self.game_over and self.winner
41 | 
42 |     def _action_pos(self, action):
43 |         if action == self.action_space - 1:  # pass turn
44 |             return None
45 |         else:
46 |             return (action // self.board_size, action % self.board_size)
47 | 
48 |     def _new_state_checks(self):
49 | 
50 |         self.game_over = self.state > self.target
51 |         if self.game_over:
52 |             self.winner = self._compute_winner()
53 | 
54 | 
55 |     def _compute_winner(self):
56 | 
57 |         return 1/(10*abs(self.state - self.target) + 1)
58 | 
59 | 
60 |     def _valid_actions(self):
61 |         actions = []
62 |         for action in range(self.action_space):
63 |             actions.append(action)
64 | 
65 |         return actions
66 | 
67 |     def observed_state(self):
68 | 
69 |         return self.state
70 | 
71 | 
72 | class test_value_policy:
73 | 
74 |     def policy(self, state):
75 |         """Returns distribution over all allowed actions"""
76 |         # uniform placeholder:
77 |         return np.zeros([state.action_space]) + 1.0/state.action_space
78 | 
79 |     def value(self, state):
80 |         return 1/(10*abs((state.state - state.target%1.0)) + 1)
81 | 
82 | 
83 | 
84 |     def predict(self, state):
85 |         return self.policy(state), self.value(state)
86 | 
87 | if __name__ == "__main__":
88 | 
89 |     start_state = AddDivState()
90 |     tree_root = TreeStructure(start_state)
91 | 
92 |     hist, rew = play_game(start_state, policy_value=test_value_policy())
93 |     print(hist, rew)
--------------------------------------------------------------------------------
/policyvalue.py:
--------------------------------------------------------------------------------
1 | import random
2 | import copy
3 | import logging
4 | import time
5 | 
6 | import numpy as np
7 | 
8 | from resnet import ResNet
9 | 
10 | 
11 | logger =
logging.getLogger("__main__") 12 | 13 | """ 14 | Model evaluating prior-policy and value for MCTS. 15 | Classes here are Go specific so far. 16 | """ 17 | 18 | class SimpleCNN(ResNet): 19 | """ 20 | Uses the keras resnet implementation. 21 | It reverses order of input and their shape! 22 | """ 23 | def __init__(self, input_shape): 24 | super(SimpleCNN, self).__init__(input_shape=input_shape, 25 | n_filter=256, 26 | n_blocks=20) 27 | 28 | self.compile() 29 | 30 | def predict(self, state): 31 | x = state.observed_state() 32 | x = x[None, ...] # batch_size = 1 here (using queue in paper) 33 | p, v = self.model.predict(x) 34 | return p.flatten(), v.flatten()[0] 35 | 36 | def train_on_batch(self, x, y): 37 | self.model.train_on_batch(x, y) 38 | 39 | def load(self, number): 40 | fn = "model_{}x{}_{}.h5".format(self.input_shape[0], self.input_shape[1], number) 41 | try: 42 | self.model.load_weights(fn) 43 | except: 44 | print("Couldnt load model weights {}".format(fn)) 45 | 46 | 47 | class NaivePolicyValue(object): 48 | def __init__(self): 49 | pass 50 | 51 | def value_network_counter(self, state): 52 | """Some logistic regression thing on sum of stones""" 53 | black_stones = 0 54 | white_stones = 0 55 | for x in state.board.values(): 56 | if x == 'b': 57 | black_stones += 1 58 | if x == 'w': 59 | white_stones += 1 60 | value = np.tanh((black_stones - white_stones)/3.0) 61 | return value 62 | 63 | def value_network_rollout(self, state): 64 | """Returns value of position for player 1.""" 65 | # simple rollout placeholder: 66 | t0 = time.time() 67 | state = copy.deepcopy(state) 68 | t1 = time.time() 69 | counter = 0 70 | while not state.game_over: 71 | # choice = sample(policy_network(state)[state.allowed_actions]) 72 | choice = random.randint(0, len(state.valid_actions) - 1) 73 | state.step(choice) 74 | counter += 1 75 | logger.debug("took {} + {} to copy + roll out for {}:".format( 76 | t1 - t0, time.time() - t1, counter)) 77 | return state.winner 78 | 79 | def policy(self, state): 80 | """Returns distribution over all allowed actions""" 81 | # uniform placeholder: 82 | return np.zeros([state.action_space]) + 1.0/state.action_space 83 | 84 | def value(self, state): 85 | return self.value_network_rollout(state) 86 | 87 | def predict(self, state): 88 | return self.policy(state), self.value(state) 89 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import keras 3 | from keras.layers import Input 4 | from keras import layers 5 | from keras.layers import Dense 6 | from keras.layers import Activation 7 | from keras.layers import Flatten 8 | from keras.layers import Conv2D 9 | from keras.layers import BatchNormalization 10 | from keras.models import Model 11 | 12 | class ResNet(object): 13 | ''' 14 | Creates a residual neural network as described in the AlphaGo Zero paper. 15 | 16 | Output: a tuple [p, v] of value and prior prob. over the action space. 
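Concretely, the model's outputs are ordered [p, v]: p is the policy head, a softmax over input_shape[0]**2 + 1 actions, and v is the value head, a single tanh unit in [-1, 1].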
17 | ''' 18 | def __init__(self, input_shape, n_filter=256, kernel_size=(3, 3), n_blocks=20, bn_axis=3): 19 | self.input_shape = input_shape 20 | self.n_filter = n_filter #number of filters in convolutional layers 21 | self.kernel_size = kernel_size #kernel size of convolutional layers 22 | self.n_blocks = n_blocks #number of residual blocks 23 | self.bn_axis = bn_axis #batch normalization axis 24 | self.n_actions = self.input_shape[0]**2 + 1 25 | self.model = self.build_model() 26 | 27 | def build_model(self): 28 | input_ = Input(shape=self.input_shape) 29 | 30 | #Input layers 31 | x = Conv2D(self.n_filter, self.kernel_size, strides=(1, 1), padding='same')(input_) 32 | x = BatchNormalization(axis=self.bn_axis)(x) 33 | x = Activation('relu')(x) 34 | 35 | #residual tower 36 | for _ in range(self.n_blocks): 37 | cnn_1 = Conv2D(self.n_filter, self.kernel_size, padding='same') 38 | cnn_2 = Conv2D(self.n_filter, self.kernel_size, padding='same') 39 | bn_1 = BatchNormalization(axis=self.bn_axis) 40 | bn_2 = BatchNormalization(axis=self.bn_axis) 41 | relu = Activation('relu') 42 | y = bn_2(cnn_2(relu(bn_1(cnn_1(x))))) 43 | x = relu(layers.add([x, y])) 44 | 45 | #Policy part 46 | p = Activation('relu')(BatchNormalization(axis=self.bn_axis)(Conv2D(2, (1, 1), padding='same')(x))) 47 | p = Flatten()(p) 48 | p = Dense(self.n_actions, activation="softmax", name="policy_output", kernel_initializer='random_uniform', 49 | bias_initializer='ones')(p) 50 | 51 | #Value part 52 | v = Activation('relu')(BatchNormalization(axis=self.bn_axis)(Conv2D(1, (1, 1), padding='same')(x))) 53 | v = Flatten()(v) 54 | v = Dense(self.n_filter, activation='relu', kernel_initializer='random_uniform', 55 | bias_initializer='ones')(v) 56 | v = Dense(1, kernel_initializer='random_uniform', 57 | bias_initializer='ones')(v) 58 | v = Activation('tanh', name="value_output")(v) 59 | 60 | model = Model(input_, [p, v]) 61 | return model 62 | 63 | def compile(self): 64 | self.model.compile(loss={'policy_output': 'categorical_crossentropy', 'value_output': 'mse'}, 65 | loss_weights={'policy_output': 1., 'value_output': 1.}, optimizer='adam') 66 | 67 | 68 | if __name__ == '__main__': 69 | resnet = ResNet((19, 19, 17), n_blocks=20) 70 | resnet.model.summary() 71 | resnet.compile() 72 | -------------------------------------------------------------------------------- /scoring.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import itertools 3 | from six.moves import range 4 | 5 | """Original copyright Max Pumperla written for betago.""" 6 | 7 | class Territory(object): 8 | def __init__(self, territory_map): 9 | self.num_black_territory = 0 10 | self.num_white_territory = 0 11 | self.num_black_stones = 0 12 | self.num_white_stones = 0 13 | self.num_dame = 0 14 | self.dame_points = [] 15 | for point, status in territory_map.items(): 16 | if status == 'b': 17 | self.num_black_stones += 1 18 | elif status == 'w': 19 | self.num_white_stones += 1 20 | elif status == 'territory_b': 21 | self.num_black_territory += 1 22 | elif status == 'territory_w': 23 | self.num_white_territory += 1 24 | elif status == 'dame': 25 | self.num_dame += 1 26 | self.dame_points.append(point) 27 | 28 | 29 | def evaluate_territory(board): 30 | """Map a board into territory and dame. 31 | 32 | Any points that are completely surrounded by a single color are 33 | counted as territory; it makes no attempt to identify even 34 | trivially dead groups. 
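Returns a Territory instance summarising the stone, territory and dame counts (and the dame points).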
35 | """ 36 | status = {} 37 | for r, c in itertools.product(list(range(board.board_size)), list(range(board.board_size))): 38 | if (r, c) in status: 39 | # Already visited this as part of a different group. 40 | continue 41 | if (r, c) in board.board: 42 | # It's a stone. 43 | status[r, c] = board.board[r, c] 44 | else: 45 | group, neighbors = _collect_region((r, c), board) 46 | if len(neighbors) == 1: 47 | # Completely surrounded by black or white. 48 | fill_with = 'territory_' + neighbors.pop() 49 | else: 50 | # Dame. 51 | fill_with = 'dame' 52 | for pos in group: 53 | status[pos] = fill_with 54 | return Territory(status) 55 | 56 | 57 | def _collect_region(start_pos, board, visited=None): 58 | """Find the contiguous section of a board containing a point. Also 59 | identify all the boundary points. 60 | """ 61 | if visited is None: 62 | visited = {} 63 | if start_pos in visited: 64 | return [], set() 65 | all_points = [start_pos] 66 | all_borders = set() 67 | visited[start_pos] = True 68 | here = board.board.get(start_pos) 69 | r, c = start_pos 70 | deltas = [(-1, 0), (1, 0), (0, -1), (0, 1)] 71 | for delta_r, delta_c in deltas: 72 | next_r, next_c = r + delta_r, c + delta_c 73 | if next_r < 0 or next_r >= board.board_size: 74 | continue 75 | if next_c < 0 or next_c >= board.board_size: 76 | continue 77 | neighbor = board.board.get((next_r, next_c)) 78 | if neighbor == here: 79 | points, borders = _collect_region((next_r, next_c), board, visited) 80 | all_points += points 81 | all_borders |= borders 82 | else: 83 | all_borders.add(neighbor) 84 | return all_points, all_borders 85 | -------------------------------------------------------------------------------- /gostate.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | 4 | from goboard import GoBoard 5 | from scoring import evaluate_territory 6 | 7 | BOARD_SIZE = 5 8 | 9 | logger = logging.getLogger("__main__") 10 | 11 | 12 | class GoState(GoBoard): 13 | """ 14 | OpenAI-gym env for go board. 15 | Has .valid_actions to sample from. If step receives an invalid actions -> pass turn is played. 16 | Can generate the numeric observation with .observed_state. 17 | 18 | properties: 19 | 20 | .winner 21 | .game_over 22 | .current_player 23 | .action_space 24 | .valid_actions 25 | 26 | 27 | TODO: Replace go engine code so that checking valid states does not require a deepcopy. 
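observed_state() returns a (board_size, board_size, 2) float array: channel 0 marks black stones and channel 1 marks white stones.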
28 | 29 | """ 30 | 31 | def __init__(self, board_size=BOARD_SIZE): 32 | super(GoState, self).__init__(board_size) 33 | 34 | self.game_over = False 35 | self.winner = None 36 | self.current_player = 'b' 37 | self.action_space = board_size**2 + 1 38 | self.valid_actions = self._valid_actions() 39 | 40 | self.last_action = -1 41 | self.last_action_2 = -1 42 | 43 | self.player_transition = {'b': 'w', 'w': 'b'} 44 | 45 | def step(self, choice): 46 | 47 | action = self.valid_actions[choice] 48 | pos = self._action_pos(action) 49 | 50 | # If illegal move: Will pass 51 | logger.log(5, "Did action {} in:\n{}".format(pos, self)) 52 | 53 | if pos and not self.is_move_legal(self.current_player, pos): 54 | pos = None 55 | logger.log(5, "Which was not allowed") 56 | 57 | if pos: 58 | super(GoState, self).apply_move(self.current_player, pos) 59 | 60 | self.current_player = self.player_transition[self.current_player] 61 | 62 | self.last_action_2 = self.last_action 63 | self.last_action = pos 64 | 65 | self._new_state_checks() # Updates self.game_over and self.winner 66 | 67 | def _action_pos(self, action): 68 | if action == self.action_space - 1: # pass turn 69 | return None 70 | else: 71 | return (action // self.board_size, action % self.board_size) 72 | 73 | def _new_state_checks(self): 74 | """Checks if game is over and who won""" 75 | board_is_full = len(self.board) == self.board_size**2 76 | double_pass = (self.last_action is None) and \ 77 | (self.last_action_2 is None) 78 | self.game_over = board_is_full or double_pass 79 | 80 | if self.game_over: 81 | self.winner = self._compute_winner() 82 | 83 | self.valid_actions = self._valid_actions() 84 | 85 | def _compute_winner(self): 86 | counts = evaluate_territory(self) 87 | black_won = counts.num_black_stones + counts.num_black_territory > counts.num_white_stones + counts.num_white_territory 88 | white_won = counts.num_black_stones + counts.num_black_territory < counts.num_white_stones + counts.num_white_territory 89 | # Make sure tie -> 0 90 | return black_won - white_won 91 | 92 | def _valid_actions(self): 93 | actions = [] 94 | for action in range(self.action_space): 95 | if self._action_pos(action) not in self.board: 96 | actions.append(action) 97 | 98 | return actions 99 | 100 | def observed_state(self): 101 | board = np.zeros([self.board_size, self.board_size, 2]) 102 | for key, val in self.board.items(): 103 | if val == 'b': 104 | board[key, 0] = 1.0 105 | if val == 'w': 106 | board[key, 1] = 1.0 107 | 108 | return board 109 | 110 | def step(state, choice): 111 | """Functional stateless version of env.step() """ 112 | t0 = time.time() 113 | new_state = copy.deepcopy(state) 114 | logger.log(6, "took {} to deepcopy \n{}".format(time.time()-t0, state) ) 115 | new_state.step(choice) 116 | return new_state 117 | 118 | 119 | -------------------------------------------------------------------------------- /training.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | from __future__ import absolute_import 4 | 5 | import copy 6 | import random 7 | import itertools 8 | 9 | import numpy as np 10 | from six.moves import input 11 | 12 | import agz 13 | import policyvalue 14 | 15 | # from gostate_pachi import GoState 16 | from gostate_pachi import GoState 17 | 18 | N_SIMULATIONS = 50 19 | 20 | def training_loop(policy_value_class=policyvalue.SimpleCNN, 21 | board_size=5, 22 | n_simulations=N_SIMULATIONS, 23 | games_per_iteration=2, 24 | 
train_per_iteration=1000, 25 | eval_games=10, 26 | batch_size=32, 27 | visualise_freq=10): 28 | 29 | # obs_shape = GoState(board_size=board_size).observed_state().shape 30 | # memory = np.array([memory_size] + obs_shape) 31 | # memory_idx = 0 32 | # memory_used = 0 33 | 34 | max_game_length = 2*board_size**2 35 | 36 | input_shape = [board_size, board_size, 2] 37 | 38 | memory = [] 39 | 40 | improvements = 0.0 41 | duels = 0.0 42 | 43 | model = policy_value_class(input_shape) 44 | best_model = model 45 | 46 | print("Training... Abort with Ctrl-C.") 47 | for i in itertools.count(): 48 | print("Iteration", i) 49 | try: 50 | for j in range(games_per_iteration): 51 | history, winner = agz.play_game(start_state=GoState(board_size), 52 | policy_value=best_model, 53 | n_simulations=n_simulations, 54 | max_game_length=max_game_length) 55 | 56 | 57 | for state, obs, pi in history: 58 | memory.append([obs, pi, winner]) 59 | 60 | if j % visualise_freq == 0: 61 | print("Visualising one game:") 62 | for state, board, choice in history: 63 | print(state) 64 | # input("") 65 | if winner == 1: 66 | print("Black won") 67 | elif winner == 0: 68 | print("Tie") 69 | else: 70 | print("White won") 71 | 72 | for j in range(int(train_per_iteration/batch_size)): 73 | samples = [random.choice(memory) for _ in range(batch_size)] 74 | obs, pi, z = [np.stack(x) for x in zip(*samples)] 75 | 76 | model.train_on_batch(obs, [pi, z]) 77 | 78 | score = 0 79 | for i in range(eval_games): 80 | start_state = GoState(board_size) 81 | old_agent = agz.MCTSAgent(best_model, GoState(board_size), n_simulations=n_simulations, self_play=True) # FIXME: adding noise through setting self_play 82 | new_agent = agz.MCTSAgent(model, GoState(board_size), n_simulations=n_simulations, self_play=True) # FIXME: adding noise through setting self_play 83 | 84 | if i % 2 == 0: # Play equal amounts of games as black/white 85 | history, winner = agz.duel(start_state, new_agent, old_agent) 86 | score += winner 87 | else: 88 | history, winner = agz.duel(start_state, old_agent, new_agent) 89 | score -= winner 90 | 91 | # Store history: 92 | for state, obs, pi in history: 93 | memory.append([obs, pi, winner]) 94 | 95 | print("New model won {} more games than old.".format(score)) 96 | if score > eval_games*0.05: 97 | best_model = model 98 | improvements += 1 99 | duels += 1 100 | print("{:2f} % of games were improvements".format(100.0*improvements/duels)) 101 | 102 | 103 | except KeyboardInterrupt: 104 | print("Stopped training with Ctrl-C.") 105 | break 106 | 107 | best_model.model.save('model_{}x{}_{}.h5'.format(board_size, 108 | board_size, 109 | i)) 110 | return best_model 111 | 112 | def main(n_simulations=N_SIMULATIONS): 113 | board_size = 9 114 | input_shape = [board_size, board_size, 2] 115 | dumb_model = policyvalue.SimpleCNN(input_shape=input_shape) 116 | smart_model = training_loop(board_size=board_size) 117 | 118 | print("First playing against initial version:") 119 | agz.main(policy_value=dumb_model, n_simulations=n_simulations) 120 | print("Now playing against trained version:") 121 | agz.main(policy_value=smart_model, n_simulations=n_simulations) 122 | 123 | 124 | if __name__ == "__main__": 125 | main() 126 | -------------------------------------------------------------------------------- /gostate_pachi.py: -------------------------------------------------------------------------------- 1 | from gym import error 2 | try: 3 | import pachi_py 4 | except ImportError as e: 5 | # The dependency group [pachi] should match the name is setup.py. 
6 | raise error.DependencyNotInstalled( 7 | '{}. (HINT: you may need to install the Go dependencies via "pip install gym[pachi]".)' 8 | .format(e)) 9 | 10 | import numpy as np 11 | import gym 12 | from gym import spaces 13 | from gym.utils import seeding 14 | from six import StringIO 15 | import sys 16 | import six 17 | 18 | BLACK = pachi_py.BLACK 19 | WHITE = pachi_py.WHITE 20 | BOARD_SIZE = 5 21 | 22 | # The coordinate representation of Pachi (and pachi_py) is defined on a board 23 | # with extra rows and columns on the margin of the board, so positions on the board 24 | # are not numbers in [0, board_size**2) as one would expect. For this Go env, we instead 25 | # use an action representation that does fall in this more natural range. 26 | 27 | 28 | def _pass_action(board_size): 29 | return board_size**2 30 | 31 | 32 | def _resign_action(board_size): 33 | return board_size**2 + 1 34 | 35 | 36 | def _coord_to_action(board, c): 37 | '''Converts Pachi coordinates to actions''' 38 | if c == pachi_py.PASS_COORD: 39 | return _pass_action(board.size) 40 | if c == pachi_py.RESIGN_COORD: 41 | return _resign_action(board.size) 42 | i, j = board.coord_to_ij(c) 43 | return i * board.size + j 44 | 45 | 46 | def _action_to_coord(board, a): 47 | '''Converts actions to Pachi coordinates''' 48 | if a == _pass_action(board.size): 49 | return pachi_py.PASS_COORD 50 | if a == _resign_action(board.size): 51 | return pachi_py.RESIGN_COORD 52 | return board.ij_to_coord(a // board.size, a % board.size) 53 | 54 | 55 | def str_to_action(board, s): 56 | return _coord_to_action(board, board.str_to_coord(s.encode())) 57 | 58 | 59 | class GoState(object): 60 | ''' 61 | Go game state. Consists of a current player and a board. 62 | Actions are exposed as integers in [0, num_actions), which is different 63 | from Pachi's internal "coord_t" encoding. 
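The action space has board_size**2 + 1 entries: one per board point plus a final pass action. The resign action (board_size**2 + 1) is defined in the helper functions above but never appears in valid_actions.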
64 | ''' 65 | 66 | def __init__(self, board_size=BOARD_SIZE, color=BLACK, board=None): 67 | ''' 68 | Args: 69 | board_size: size of board 70 | color: color of current player 71 | board: current board 72 | ''' 73 | assert color in [BLACK, WHITE], 'Invalid player color' 74 | 75 | if board: 76 | self.board = board 77 | else: 78 | self.board = pachi_py.CreateBoard(board_size) 79 | 80 | self.last_action = -1 81 | self.last_action_2 = -1 82 | 83 | self.color = color 84 | self.board_size = board_size 85 | 86 | self.game_over = False 87 | self.winner = None 88 | self.action_space = spaces.Discrete(board_size**2 + 1) 89 | 90 | self._new_state_checks() 91 | 92 | if color == BLACK: 93 | self.current_player = 1 94 | else: 95 | self.current_player = -1 96 | 97 | def act(self, action): 98 | ''' 99 | Executes an action for the current player 100 | ''' 101 | 102 | try: 103 | self.board = self.board.play( 104 | _action_to_coord(self.board, action), self.color) 105 | except pachi_py.IllegalMove: 106 | # Will do pass turn on disallowed move 107 | action = _pass_action(self.board_size) 108 | self.board = self.board.play( 109 | _action_to_coord(self.board, action), self.color) 110 | 111 | self.color = pachi_py.stone_other(self.color) 112 | self.current_player = -self.current_player 113 | 114 | self.last_action_2 = self.last_action 115 | self.last_action = action 116 | 117 | self._new_state_checks() # Updates self.game_over and self.winner 118 | 119 | def stateless_act(self, action): 120 | ''' 121 | Executes an action for the current player on a copy of the state 122 | Returns: 123 | a new GoState with the new board and the player switched 124 | ''' 125 | try: 126 | new_board = self.board.play( 127 | _action_to_coord(self.board, action), self.color) 128 | except pachi_py.IllegalMove: 129 | # Will do pass turn on invalid move 130 | action = _pass_action(self.board_size) 131 | new_board = self.board.play( 132 | _action_to_coord(self.board, action), self.color) 133 | 134 | new_state = GoState( 135 | board_size=self.board_size, 136 | color=pachi_py.stone_other(self.color), 137 | board=new_board) 138 | 139 | new_state.last_action_2 = new_state.last_action 140 | new_state.last_action = action 141 | 142 | new_state._new_state_checks() 143 | 144 | return new_state 145 | 146 | def step(self, choice): 147 | """Makes an action from choice of valid_actions""" 148 | action = self.valid_actions[choice] 149 | self.act(action) 150 | 151 | def observed_state(self): 152 | return self._observed_state 153 | 154 | def _new_state_checks(self): 155 | """Checks if game is over, who won and updates valid states""" 156 | double_pass = (self.last_action is _pass_action(self.board_size)) and \ 157 | (self.last_action_2 is _pass_action(self.board_size)) 158 | 159 | self.game_over = self.board.is_terminal or double_pass 160 | 161 | if self.game_over: 162 | self.winner = self._compute_winner() 163 | 164 | encoded_board = self.board.encode() 165 | 166 | self._observed_state = encoded_board[:2].transpose() 167 | self.valid_actions = self._valid_actions(encoded_board[2]) 168 | 169 | def _valid_actions(self, empty_positions): 170 | actions = [] 171 | for action in range(self.board_size**2): 172 | # coord = board.ij_to_coord(action // board.size, action % board.size) 173 | if empty_positions[action // self.board.size, action % 174 | self.board.size] == 1: 175 | actions.append(action) 176 | 177 | return actions + [_pass_action(self.board_size)] 178 | 179 | def _compute_winner(self): 180 | """Returns winner as -1/0/1 for white/tie/black""" 181 | 
white_won = self.board.official_score > 0 182 | black_won = self.board.official_score < 0 183 | return black_won - white_won 184 | 185 | def __repr__(self): 186 | return 'To play: {}\n{}'.format( 187 | six.u(pachi_py.color_to_str(self.color)), 188 | self.board.__repr__().decode()) 189 | 190 | 191 | def act(state, action): 192 | """Functional version of act""" 193 | return state.stateless_act(action) 194 | 195 | 196 | def step(state, choice): 197 | """Functional version of step""" 198 | return act(state, state.valid_actions[choice]) 199 | 200 | 201 | ### Adversary policies ### 202 | def make_random_policy(np_random): 203 | def random_policy(curr_state, prev_state, prev_action): 204 | b = curr_state.board 205 | legal_coords = b.get_legal_coords(curr_state.color) 206 | return _coord_to_action(b, np_random.choice(legal_coords)) 207 | 208 | return random_policy 209 | 210 | 211 | def make_pachi_policy(board, engine_type='uct', threads=1, pachi_timestr=''): 212 | engine = pachi_py.PyPachiEngine(board, engine_type, 213 | six.b('threads=%d' % threads)) 214 | 215 | def pachi_policy(curr_state, prev_state, prev_action): 216 | if prev_state is not None: 217 | assert engine.curr_board == prev_state.board, 'Engine internal board is inconsistent with provided board. The Pachi engine must be called consistently as the game progresses.' 218 | prev_coord = _action_to_coord(prev_state.board, prev_action) 219 | engine.notify(prev_coord, prev_state.color) 220 | engine.curr_board.play_inplace(prev_coord, prev_state.color) 221 | out_coord = engine.genmove(curr_state.color, pachi_timestr) 222 | out_action = _coord_to_action(curr_state.board, out_coord) 223 | engine.curr_board.play_inplace(out_coord, curr_state.color) 224 | return out_action 225 | 226 | return pachi_policy 227 | -------------------------------------------------------------------------------- /agz.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | import copy 4 | import random 5 | import time 6 | import os 7 | import itertools 8 | 9 | import numpy as np 10 | 11 | from six.moves import input 12 | 13 | """ Slower but more hackable go implementation: 14 | from gostate import GoState 15 | """ 16 | from gostate_pachi import GoState 17 | from gostate_pachi import step 18 | 19 | from policyvalue import NaivePolicyValue 20 | from policyvalue import SimpleCNN 21 | 22 | # import tqdm 23 | 24 | BOARD_SIZE = 5 25 | C_PUCT = 1.0 26 | N_SIMULATIONS = 160 27 | 28 | """ 29 | MCTS logic and go playing / visualisation. 
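Command-line flags handled below: `-d LEVEL` sets the logging level, `-selfplay` replays one self-play game, `-40` caps the search at 40 simulations per move, and `-nogpu` hides CUDA devices so the model runs on the CPU.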
30 | 31 | TODO: 32 | - Decide on CLI arguments and use argparse 33 | 34 | """ 35 | 36 | # '-d level' argument for printing specific debug level: 37 | if "-d" in sys.argv: 38 | level_idx = sys.argv.index("-d") + 1 39 | level = int(sys.argv[level_idx]) if level_idx < len(sys.argv) else 10 40 | logging.basicConfig(level=level) 41 | else: 42 | logging.basicConfig(level=logging.INFO) 43 | 44 | logger = logging.getLogger(__name__) 45 | np.set_printoptions(3) 46 | 47 | 48 | class TreeStructure(object): 49 | """ Node in the MCTS tree structure """ 50 | def __init__(self, state, parent=None, choice_that_led_here=None): 51 | 52 | self.children = {} # map from choice to node 53 | 54 | self.parent = parent 55 | self.state = state 56 | 57 | self.w = np.zeros(len(state.valid_actions)) 58 | self.n = np.zeros(len(state.valid_actions)) 59 | self.n += (1.0 + np.random.rand(len(self.n)))*1e-10 60 | self.prior_policy = 1.0/len(self.state.valid_actions) 61 | 62 | self.sum_n = 1 63 | self.choice_that_led_here = choice_that_led_here 64 | 65 | self.move_number = 0 66 | 67 | if parent: 68 | self.move_number = parent.move_number + 1 69 | 70 | 71 | def history_sample(self): 72 | """Returns a representation of state to be stored for training""" 73 | pi = np.zeros(self.state.action_space.n) 74 | pi[self.state.valid_actions] = self.n/self.n.sum() 75 | return [self.state, self.state.observed_state(), pi] 76 | 77 | def add_noise_to_prior(self, noise_frac=0.25, dirichlet_alpha=0.03): 78 | noise = np.random.dirichlet(dirichlet_alpha*np.ones(len(self.state.valid_actions))) 79 | self.prior_policy = (1-noise_frac)*self.prior_policy + noise_frac*noise 80 | 81 | def sample(probs): 82 | """Sample from unnormalized probabilities""" 83 | 84 | probs = probs / probs.sum() 85 | return np.random.choice(np.arange(len(probs)), p=probs.flatten()) 86 | 87 | def puct_distribution(node): 88 | 89 | """Puct equation""" 90 | # this should never be a distribution but always maximised over 91 | 92 | return node.w/node.n + C_PUCT*node.prior_policy*np.sqrt(node.sum_n)/(1 + node.n) 93 | 94 | def puct_choice(node): 95 | """Selects the next move.""" 96 | return np.argmax(puct_distribution(node)) 97 | 98 | 99 | def choice_to_play(node, opponent=None): 100 | """Samples a move if self play training.""" 101 | logger.debug("Selecting move # {}".format(node.move_number)) 102 | logger.debug(node.w) 103 | logger.debug(node.n) 104 | logger.debug(node.prior_policy) 105 | 106 | if node.move_number < 30 and opponent is None: 107 | return sample(node.n) 108 | else: 109 | return np.argmax(node.n) 110 | 111 | def backpropagate(node, value): 112 | """MCTS backpropagation""" 113 | 114 | def _increment(node, choice, value): 115 | # Mirror value for odd states: 116 | value *= 1 - 2*(node.move_number % 2) # TODO: use node.state.current_player after changing it to (+1, -1) 117 | node.w[choice] += value 118 | node.n[choice] += 1 119 | node.sum_n += 1 120 | 121 | while node.parent: 122 | _increment(node.parent, node.choice_that_led_here, value) 123 | node = node.parent 124 | 125 | 126 | def mcts(tree_root, policy_value, n_simulations): 127 | # for i in tqdm.tqdm(range(n_simulations)): 128 | for i in range(n_simulations): 129 | node = tree_root 130 | # Select from "PUCT/UCB1 equation" in paper. 131 | choice = puct_choice(node) 132 | while choice in node.children.keys(): 133 | node = node.children[choice] 134 | choice = puct_choice(node) 135 | 136 | if node.state.game_over: 137 | # This happens the 'second time' we go to a winning state. 
138 | # Logic for visiting "winning nodes" multiple times is probably correct? 139 | value = node.state.winner 140 | backpropagate(node, value) 141 | continue 142 | 143 | # Expand tree: 144 | new_state = step(node.state, choice) 145 | node.children[choice] = TreeStructure(new_state, node, choice) 146 | node = node.children[choice] 147 | 148 | if new_state.game_over: 149 | value = new_state.winner 150 | else: 151 | policy, value = policy_value.predict(node.state) 152 | node.prior_policy = policy[node.state.valid_actions] 153 | 154 | backpropagate(node, value) 155 | 156 | 157 | def print_tree(tree_root, level): 158 | print(" "*level, tree_root.choice_that_led_here, tree_root.state.board, tree_root.n, tree_root.w) 159 | # [print_tree(tree_root.children[i], level + 1) for i in tree_root.children] 160 | 161 | class MCTSAgent(object): 162 | """Object that keeps track of MCTS tree and can perform actions""" 163 | 164 | def __init__(self, policy_value, state, n_simulations=N_SIMULATIONS): 165 | self.policy_value = policy_value 166 | self.game_history = list() 167 | self.tree_root = TreeStructure(state) 168 | self.n_simulations = n_simulations 169 | 170 | policy, value = self.policy_value.predict(self.tree_root.state) 171 | self.tree_root.prior_policy = policy[self.tree_root.state.valid_actions] 172 | assert type(self.tree_root.prior_policy) != float, "Prior_policy was not np array" 173 | self.tree_root.add_noise_to_prior() 174 | 175 | def update_state(self, choice): 176 | self.game_history.append(self.tree_root.history_sample()) 177 | 178 | if choice in self.tree_root.children: 179 | self.tree_root = self.tree_root.children[choice] 180 | else: 181 | new_state = step(self.tree_root.state, choice) 182 | self.tree_root = TreeStructure(new_state) 183 | policy, value = self.policy_value.predict(self.tree_root.state) 184 | self.tree_root.prior_policy = policy[self.tree_root.state.valid_actions] 185 | self.tree_root.add_noise_to_prior() 186 | self.tree_root.parent = None 187 | 188 | def perform_simulations(self, n_simulations=None): 189 | n_simulations = n_simulations or self.n_simulations 190 | mcts(self.tree_root, self.policy_value, n_simulations) 191 | 192 | def decision(self, self_play=False): 193 | return choice_to_play(self.tree_root, not self_play) 194 | 195 | 196 | def duel(state, agent_1, agent_2, max_game_length=1e99): 197 | """Plays two agants against each other""" 198 | history = [] 199 | 200 | agents = itertools.cycle([agent_1, agent_2]) 201 | 202 | move_number = 0 203 | while not state.game_over and move_number < max_game_length: 204 | actor = next(agents) 205 | actor.perform_simulations() 206 | choice = actor.decision() 207 | 208 | history.append(actor.tree_root.history_sample()) 209 | 210 | state.step(choice) 211 | agent_1.update_state(choice) 212 | agent_2.update_state(choice) 213 | 214 | move_number += 1 215 | 216 | if move_number >= max_game_length: 217 | state.winner = state._compute_winner() 218 | 219 | return history, state.winner 220 | 221 | 222 | # TODO: Create agent class from this that can be queried 223 | def play_game(start_state=GoState(), 224 | policy_value=NaivePolicyValue(), 225 | max_game_length=1e99, 226 | opponent=None, 227 | n_simulations=N_SIMULATIONS): 228 | """ 229 | Plays a game against itself or specified opponent. 230 | 231 | The state should be prepared so that it is the agents turn, 232 | and so that `self.winner == 1` when the agent won. 
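Returns (game_history, winner): game_history is a list of [state, observed_state, pi] samples produced by TreeStructure.history_sample, and winner is +1/0/-1 from black's perspective.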
233 | """ 234 | 235 | # TODO: This will set .move_number = 0, should maybe track whose turn it is instead: 236 | tree_root = TreeStructure(start_state) 237 | policy, value = policy_value.predict(tree_root.state) 238 | tree_root.prior_policy = policy[tree_root.state.valid_actions] 239 | tree_root.add_noise_to_prior() 240 | game_history = [] 241 | 242 | while not tree_root.state.game_over and tree_root.move_number < max_game_length: 243 | 244 | mcts(tree_root, policy_value, n_simulations) 245 | 246 | # print_tree(tree_root,0) 247 | # Store the state and distribution before we prune the tree: 248 | # TODO: Refactor this 249 | 250 | game_history.append(tree_root.history_sample()) 251 | 252 | choice = choice_to_play(tree_root, bool(opponent)) 253 | tree_root = tree_root.children[choice] 254 | tree_root.parent = None 255 | tree_root.add_noise_to_prior() 256 | 257 | if opponent: 258 | game_history.append(tree_root.history_sample()) 259 | choice = opponent(tree_root.state) 260 | if choice in tree_root.children: 261 | tree_root = tree_root.children[choice] 262 | else: 263 | new_state = step(tree_root.state, choice) 264 | tree_root = TreeStructure(new_state) 265 | policy, value = policy_value.predict(tree_root.state) 266 | tree_root.prior_policy = policy[tree_root.state.valid_actions] 267 | tree_root.parent = None 268 | 269 | if tree_root.move_number >= max_game_length: 270 | tree_root.state.winner = tree_root.state._compute_winner() 271 | 272 | return game_history, tree_root.state.winner 273 | 274 | 275 | # UI code below: 276 | def human_opponent(state): 277 | """Queries human for move when called.""" 278 | print(state) 279 | while True: 280 | inp = input("What is your move? \n") 281 | if inp == 'pass': 282 | return len(state.valid_actions) - 1 283 | if inp == 'random': 284 | return random.randint(0, len(state.valid_actions) - 1) 285 | 286 | try: 287 | pos = [int(x) for x in inp.split()] 288 | action = pos[0]*state.board_size + pos[1] 289 | choice = state.valid_actions.index(action) 290 | return choice 291 | except: 292 | print("Invalid move {} try again.".format(inp)) 293 | 294 | 295 | def self_play_visualisation(board_size=BOARD_SIZE): 296 | """Visualises one game of self_play""" 297 | policy_value = SimpleCNN([board_size, board_size, 2]) 298 | history, winner = play_game(policy_value=policy_value) 299 | print("Watching game replay\nPress Return to advance board") 300 | for state, board, hoice in history: 301 | print(state) 302 | input("") 303 | 304 | if winner == 1: 305 | print("Black won") 306 | else: 307 | print("White won") 308 | 309 | def duel_players(player_1, player_2): 310 | return winner(player_1, player_2) 311 | 312 | def main(policy_value=NaivePolicyValue(), board_size=BOARD_SIZE, n_simulations=N_SIMULATIONS): 313 | 314 | if "-selfplay" in sys.argv: 315 | self_play_visualisation() 316 | return 317 | 318 | if "-40" in sys.argv: 319 | n_simulations = 40 320 | print("Letting MCTS search for {} moves!".format(n_simulations)) 321 | 322 | if '-nogpu' in sys.argv: 323 | # The following is neecessary if GPU memory is full (training) 324 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 325 | 326 | # Loads weights that trained for 60 iterations 327 | policy_value = SimpleCNN([board_size, board_size, 2]) 328 | policy_value.load(6) 329 | 330 | print("") 331 | print("Welcome!") 332 | print("Format moves like: y x") 333 | print("(or pass/random)") 334 | print("") 335 | try: 336 | history, winner = play_game(start_state=GoState(board_size), 337 | policy_value=policy_value, 338 | opponent=human_opponent, 339 
| n_simulations=n_simulations)
340 |     except KeyboardInterrupt:
341 |         print("Game aborted.")
342 |         return
343 | 
344 |     if winner == 1:
345 |         print("AI won")
346 |     else:
347 |         print("Human won")
348 | 
349 | if __name__ == "__main__":
350 |     main()
--------------------------------------------------------------------------------
/goboard.py:
--------------------------------------------------------------------------------
1 | # This Source Code Form is subject to the terms of the Mozilla Public License,
2 | # v. 2.0. If a copy of the MPL was not distributed with this file, You can
3 | # obtain one at http://mozilla.org/MPL/2.0/.
4 | 
5 | from __future__ import absolute_import
6 | import copy
7 | from six.moves import range
8 | 
9 | """Original copyright Max Pumperla written for betago."""
10 | 
11 | class GoBoard(object):
12 |     '''
13 |     Representation of a go board. It contains "GoStrings" to represent stones and liberties. Moreover,
14 |     the board can account for ko and handle captured stones.
15 |     '''
16 |     def __init__(self, board_size=19):
17 |         '''
18 |         Parameters
19 |         ----------
20 |         ko_last_move_num_captured: How many stones have been captured last move. If this is not 1, it can't be ko.
21 |         ko_last_move: board position of the ko.
22 |         board_size: Side length of the board, defaulting to 19.
23 |         go_strings: Dictionary of go_string objects representing stones and liberties.
24 |         '''
25 |         self.ko_last_move_num_captured = 0
26 |         self.ko_last_move = -3
27 |         self.board_size = board_size
28 |         self.board = {}
29 |         self.go_strings = {}
30 | 
31 |     def fold_go_strings(self, target, source, join_position):
32 |         ''' Merge two go strings by joining their common moves'''
33 |         if target == source:
34 |             return
35 |         for stone_position in source.stones.stones:
36 |             self.go_strings[stone_position] = target
37 |             target.insert_stone(stone_position)
38 |         target.copy_liberties_from(source)
39 |         target.remove_liberty(join_position)
40 | 
41 |     def add_adjacent_liberty(self, pos, go_string):
42 |         '''
43 |         Append new liberties to provided GoString for the current move
44 |         '''
45 |         row, col = pos
46 |         if row < 0 or col < 0 or row > self.board_size - 1 or col > self.board_size - 1:
47 |             return
48 |         if pos not in self.board:
49 |             go_string.insert_liberty(pos)
50 | 
51 |     def is_move_on_board(self, move):
52 |         return move in self.board
53 | 
54 |     def is_move_suicide(self, color, pos):
55 |         '''Check if a proposed move would be suicide.'''
56 |         # Make a copy of ourself to apply the move.
57 | temp_board = copy.deepcopy(self) 58 | temp_board.apply_move(color, pos) 59 | new_string = temp_board.go_strings[pos] 60 | return new_string.get_num_liberties() == 0 61 | 62 | def is_move_legal(self, color, pos): 63 | '''Check if a proposed moved is legal.''' 64 | return (not self.is_move_on_board(pos)) and \ 65 | (not self.is_move_suicide(color, pos)) and \ 66 | (not self.is_simple_ko(color, pos)) 67 | 68 | def create_go_string(self, color, pos): 69 | ''' Create GoString from current Board and move ''' 70 | go_string = GoString(self.board_size, color) 71 | go_string.insert_stone(pos) 72 | self.go_strings[pos] = go_string 73 | self.board[pos] = color 74 | 75 | row, col = pos 76 | for adjpos in [(row - 1, col), (row + 1, col), (row, col - 1), (row, col + 1)]: 77 | self.add_adjacent_liberty(adjpos, go_string) 78 | return go_string 79 | 80 | def other_color(self, color): 81 | ''' 82 | Color of other player 83 | ''' 84 | if color == 'b': 85 | return 'w' 86 | if color == 'w': 87 | return 'b' 88 | 89 | def is_simple_ko(self, play_color, pos): 90 | ''' 91 | Determine ko from board position and player. 92 | 93 | Parameters: 94 | ----------- 95 | play_color: Color of the player to make the next move. 96 | pos: Current move as (row, col) 97 | ''' 98 | enemy_color = self.other_color(play_color) 99 | row, col = pos 100 | if self.ko_last_move_num_captured == 1: 101 | last_move_row, last_move_col = self.ko_last_move 102 | manhattan_distance_last_move = abs(last_move_row - row) + abs(last_move_col - col) 103 | if manhattan_distance_last_move == 1: 104 | last_go_string = self.go_strings.get((last_move_row, last_move_col)) 105 | if last_go_string is not None and last_go_string.get_num_liberties() == 1: 106 | if last_go_string.get_num_stones() == 1: 107 | num_adjacent_enemy_liberties = 0 108 | for adjpos in [(row - 1, col), (row + 1, col), (row, col - 1), (row, col + 1)]: 109 | if (self.board.get(adjpos) == enemy_color and 110 | self.go_strings[adjpos].get_num_liberties() == 1): 111 | num_adjacent_enemy_liberties = num_adjacent_enemy_liberties + 1 112 | if num_adjacent_enemy_liberties == 1: 113 | return True 114 | return False 115 | 116 | def check_enemy_liberty(self, play_color, enemy_pos, our_pos): 117 | ''' 118 | Update surrounding liberties on board after a move has been played. 
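If the move removes an adjacent enemy string's last liberty, that string is captured: its stones are deleted from the board and the capturing player's neighbouring strings regain the freed liberties.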
119 | 120 | Parameters: 121 | ----------- 122 | play_color: Color of player about to move 123 | enemy_pos: latest enemy move 124 | our_pos: our latest move 125 | ''' 126 | enemy_row, enemy_col = enemy_pos 127 | our_row, our_col = our_pos 128 | 129 | # Sanity checks 130 | if enemy_row < 0 or enemy_row >= self.board_size or enemy_col < 0 or enemy_col >= self.board_size: 131 | return 132 | enemy_color = self.other_color(play_color) 133 | if self.board.get(enemy_pos) != enemy_color: 134 | return 135 | enemy_string = self.go_strings[enemy_pos] 136 | if enemy_string is None: 137 | raise ValueError('Inconsistency between board and go_strings at %r' % enemy_pos) 138 | 139 | # Update adjacent liberties on board 140 | enemy_string.remove_liberty(our_pos) 141 | if enemy_string.get_num_liberties() == 0: 142 | for enemy_pos in enemy_string.stones.stones: 143 | string_row, string_col = enemy_pos 144 | del self.board[enemy_pos] 145 | del self.go_strings[enemy_pos] 146 | self.ko_last_move_num_captured = self.ko_last_move_num_captured + 1 147 | for adjstring in [(string_row - 1, string_col), (string_row + 1, string_col), 148 | (string_row, string_col - 1), (string_row, string_col + 1)]: 149 | self.add_liberty_to_adjacent_string(adjstring, enemy_pos, play_color) 150 | 151 | def apply_move(self, play_color, pos): 152 | ''' 153 | Execute move for given color, i.e. play current stone on this board 154 | Parameters: 155 | ----------- 156 | play_color: Color of player about to move 157 | pos: Current move as (row, col) 158 | ''' 159 | if pos in self.board: 160 | raise ValueError('Move ' + str(pos) + 'is already on board.') 161 | 162 | self.ko_last_move_num_captured = 0 163 | row, col = pos 164 | 165 | # Remove any enemy stones that no longer have a liberty 166 | self.check_enemy_liberty(play_color, (row - 1, col), pos) 167 | self.check_enemy_liberty(play_color, (row + 1, col), pos) 168 | self.check_enemy_liberty(play_color, (row, col - 1), pos) 169 | self.check_enemy_liberty(play_color, (row, col + 1), pos) 170 | 171 | # Create a GoString for our new stone, and merge with any adjacent strings 172 | play_string = self.create_go_string(play_color, pos) 173 | play_string = self.fold_our_moves(play_string, play_color, (row - 1, col), pos) 174 | play_string = self.fold_our_moves(play_string, play_color, (row + 1, col), pos) 175 | play_string = self.fold_our_moves(play_string, play_color, (row, col - 1), pos) 176 | play_string = self.fold_our_moves(play_string, play_color, (row, col + 1), pos) 177 | 178 | # Store last move for ko 179 | self.ko_last_move = pos 180 | 181 | def add_liberty_to_adjacent_string(self, string_pos, liberty_pos, color): 182 | ''' Insert liberty into corresponding GoString ''' 183 | if self.board.get(string_pos) != color: 184 | return 185 | go_string = self.go_strings[string_pos] 186 | go_string.insert_liberty(liberty_pos) 187 | 188 | def fold_our_moves(self, first_string, color, pos, join_position): 189 | ''' Fold current board situation with a new move played by us''' 190 | row, col = pos 191 | if row < 0 or row >= self.board_size or col < 0 or col >= self.board_size: 192 | return first_string 193 | if self.board.get(pos) != color: 194 | return first_string 195 | string_to_fold = self.go_strings[pos] 196 | self.fold_go_strings(string_to_fold, first_string, join_position) 197 | return string_to_fold 198 | 199 | def __str__(self): 200 | result = 'GoBoard\n' 201 | for i in range(self.board_size - 1, -1, -1): 202 | line = '' 203 | for j in range(0, self.board_size): 204 | thispiece = 
self.board.get((i, j)) 205 | if thispiece is None: 206 | line = line + '.' 207 | if thispiece == 'b': 208 | line = line + '*' 209 | if thispiece == 'w': 210 | line = line + 'O' 211 | result = result + line + '\n' 212 | return result 213 | 214 | 215 | class BoardSequence(object): 216 | ''' 217 | Store a sequence of locations on a board, which could either represent stones or liberties. 218 | ''' 219 | def __init__(self, board_size=19): 220 | self.board_size = board_size 221 | self.stones = [] 222 | self.board = {} 223 | 224 | def insert(self, combo): 225 | row, col = combo 226 | if combo in self.board: 227 | return 228 | self.stones.append(combo) 229 | self.board[combo] = len(self.stones) - 1 230 | 231 | def erase(self, combo): 232 | if combo not in self.board: 233 | return 234 | iid = self.board[combo] 235 | if iid == len(self.stones) - 1: 236 | del self.stones[iid] 237 | del self.board[combo] 238 | return 239 | self.stones[iid] = self.stones[len(self.stones) - 1] 240 | del self.stones[len(self.stones) - 1] 241 | movedcombo = self.stones[iid] 242 | self.board[movedcombo] = iid 243 | del self.board[combo] 244 | 245 | def exists(self, combo): 246 | return combo in self.board 247 | 248 | def size(self): 249 | return len(self.stones) 250 | 251 | def __getitem__(self, iid): 252 | return self.stones[iid] 253 | 254 | def __str__(self): 255 | result = 'BoardSequence\n' 256 | for row in range(self.board_size - 1, -1, -1): 257 | thisline = "" 258 | for col in range(0, self.board_size): 259 | if self.exists((row, col)): 260 | thisline = thisline + "*" 261 | else: 262 | thisline = thisline + "." 263 | result = result + thisline + "\n" 264 | return result 265 | 266 | 267 | class GoString(object): 268 | ''' 269 | Represents a string of contiguous stones of one color on the board, including a list of all its liberties. 270 | ''' 271 | def __init__(self, board_size, color): 272 | self.board_size = board_size 273 | self.color = color 274 | self.liberties = BoardSequence(board_size) 275 | self.stones = BoardSequence(board_size) 276 | 277 | def get_stone(self, index): 278 | return self.stones[index] 279 | 280 | def get_liberty(self, index): 281 | return self.liberties[index] 282 | 283 | def insert_stone(self, combo): 284 | self.stones.insert(combo) 285 | 286 | def get_num_stones(self): 287 | return self.stones.size() 288 | 289 | def remove_liberty(self, combo): 290 | self.liberties.erase(combo) 291 | 292 | def get_num_liberties(self): 293 | return self.liberties.size() 294 | 295 | def insert_liberty(self, combo): 296 | self.liberties.insert(combo) 297 | 298 | def copy_liberties_from(self, source): 299 | for libertyPos in source.liberties.stones: 300 | self.liberties.insert(libertyPos) 301 | 302 | def __str__(self): 303 | result = "go_string[ stones=" + str(self.stones) + " liberties=" + str(self.liberties) + " ]" 304 | return result 305 | 306 | 307 | def from_string(board_string): 308 | """Build a board from an ascii-art representation. 309 | 310 | 'b' for black stones 311 | 'w' for white stones 312 | '.' for empty 313 | 314 | The bottom row is row 0, and the top row is row boardsize - 1. This 315 | matches the normal way you'd use board coordinates, with A1 in the 316 | bottom-left. 317 | 318 | Rows are separated by newlines. Extra whitespace is ignored. 
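For example, from_string(".b\n..") builds a 2x2 board with a single black stone in the top-right corner, and to_string() on that board reproduces the same picture.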
319 | """ 320 | rows = [line.strip() for line in board_string.strip().split("\n")] 321 | boardsize = len(rows) 322 | if any(len(row) != boardsize for row in rows): 323 | raise ValueError('Board must be square') 324 | 325 | board = GoBoard(boardsize) 326 | rows.reverse() 327 | for r, row_string in enumerate(rows): 328 | for c, point in enumerate(row_string): 329 | if point in ('b', 'w'): 330 | board.apply_move(point, (r, c)) 331 | return board 332 | 333 | 334 | def to_string(board): 335 | """Make an ascii-art representation of a board.""" 336 | rows = [] 337 | for r in range(board.board_size): 338 | row = '' 339 | for c in range(board.board_size): 340 | row += board.board.get((r, c), '.') 341 | rows.append(row) 342 | rows.reverse() 343 | return '\n'.join(rows) 344 | --------------------------------------------------------------------------------