├── .gitignore ├── screenshots └── human_ui.png ├── gym_go ├── envs │ ├── __init__.py │ ├── go_extrahard_env.py │ └── go_env.py ├── govars.py ├── __init__.py ├── tests │ ├── test_batch_fns.py │ ├── efficiency.py │ ├── test_invalid_moves.py │ ├── test_valid_moves.py │ └── test_basics.py ├── rendering.py ├── state_utils.py └── gogame.py ├── setup.py ├── demo.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | *.egg-info/ 3 | __pycache__/ 4 | .ipynb_checkpoints/ 5 | .vscode/ 6 | -------------------------------------------------------------------------------- /screenshots/human_ui.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huangeddie/GymGo/HEAD/screenshots/human_ui.png -------------------------------------------------------------------------------- /gym_go/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from gym_go.envs.go_env import GoEnv 2 | from gym_go.envs.go_extrahard_env import GoExtraHardEnv 3 | -------------------------------------------------------------------------------- /gym_go/envs/go_extrahard_env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | 4 | class GoExtraHardEnv(gym.Env): 5 | metadata = {'render.modes': ['human', 'terminal']} 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='gym_go', 5 | version='0.0.1', 6 | install_requires=['gym'] # and other dependencies 7 | ) 8 | -------------------------------------------------------------------------------- /gym_go/govars.py: -------------------------------------------------------------------------------- 1 | ANYONE = None 2 | NOONE = -1 3 | 4 | BLACK = 0 5 | WHITE = 1 6 | TURN_CHNL = 2 7 | INVD_CHNL = 3 8 | PASS_CHNL = 4 9 | DONE_CHNL = 5 10 | 11 | NUM_CHNLS = 6 12 | -------------------------------------------------------------------------------- /gym_go/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | 3 | register( 4 | id='go-v0', 5 | entry_point='gym_go.envs:GoEnv', 6 | ) 7 | register( 8 | id='go-extrahard-v0', 9 | entry_point='gym_go.envs:GoExtraHardEnv', 10 | ) 11 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import gym 4 | 5 | # Arguments 6 | parser = argparse.ArgumentParser(description='Demo Go Environment') 7 | parser.add_argument('--boardsize', type=int, default=7) 8 | parser.add_argument('--komi', type=float, default=0) 9 | args = parser.parse_args() 10 | 11 | # Initialize environment 12 | go_env = gym.make('gym_go:go-v0', size=args.boardsize, komi=args.komi) 13 | 14 | # Game loop 15 | done = False 16 | while not done: 17 | action = go_env.render(mode="human") 18 | state, reward, done, info = go_env.step(action) 19 | 20 | if go_env.game_ended(): 21 | break 22 | action = go_env.uniform_random_action() 23 | state, reward, done, info = go_env.step(action) 24 | go_env.render(mode="human") 25 | -------------------------------------------------------------------------------- /gym_go/tests/test_batch_fns.py: 
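Before the test listings, here is a minimal sketch of how the registered environment can be exercised without the pyglet UI, using only the `GoEnv` methods that appear in `demo.py` and `go_env.py` (`reset`, `step`, `uniform_random_action`, `render`, `winner`); the 5x5 board size is an arbitrary choice for illustration:

```python
import gym

# Random-vs-random rollout; mirrors demo.py's loop but needs no graphical UI.
go_env = gym.make('gym_go:go-v0', size=5, komi=0, reward_method='real')
state = go_env.reset()

done = False
while not done:
    action = go_env.uniform_random_action()            # 1d action in [0, size**2]; the last index is "pass"
    state, reward, done, info = go_env.step(action)

go_env.render(mode='terminal')
print('Winner (black perspective):', go_env.winner())  # 1 = black, -1 = white, 0 = tie
```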
-------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from gym_go import gogame, govars 4 | 5 | 6 | class TestBatchFns(unittest.TestCase): 7 | 8 | def __init__(self, *args, **kwargs): 9 | super().__init__(*args, **kwargs) 10 | 11 | def setUp(self): 12 | pass 13 | 14 | def test_batch_canonical_form(self): 15 | states = gogame.batch_init_state(2, 7) 16 | states[0] = gogame.next_state(states[0], 0) 17 | 18 | self.assertEqual(states[0, govars.BLACK].sum(), 1) 19 | self.assertEqual(states[0, govars.WHITE].sum(), 0) 20 | 21 | states = gogame.batch_canonical_form(states) 22 | 23 | self.assertEqual(states[0, govars.BLACK].sum(), 0) 24 | self.assertEqual(states[0, govars.WHITE].sum(), 1) 25 | 26 | self.assertEqual(states[1, govars.BLACK].sum(), 0) 27 | self.assertEqual(states[1, govars.WHITE].sum(), 0) 28 | 29 | for i in range(2): 30 | self.assertEqual(gogame.turn(states[i]), govars.BLACK) 31 | 32 | canon_again = gogame.batch_canonical_form(states) 33 | 34 | self.assertTrue((canon_again == states).all()) 35 | 36 | 37 | if __name__ == '__main__': 38 | unittest.main() 39 | -------------------------------------------------------------------------------- /gym_go/tests/efficiency.py: -------------------------------------------------------------------------------- 1 | import time 2 | import unittest 3 | 4 | import gym 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | 9 | class Efficiency(unittest.TestCase): 10 | boardsize = 9 11 | iterations = 64 12 | 13 | def setUp(self) -> None: 14 | self.env = gym.make('gym_go:go-v0', size=self.boardsize, reward_method='real') 15 | 16 | def testOrderedTrajs(self): 17 | durs = [] 18 | for _ in tqdm(range(self.iterations)): 19 | start = time.time() 20 | self.env.reset() 21 | for a in range(self.boardsize ** 2 - 2): 22 | self.env.step(a) 23 | end = time.time() 24 | 25 | dur = end - start 26 | durs.append(dur) 27 | 28 | avg_time = np.mean(durs) 29 | std_time = np.std(durs) 30 | print(f"Ordered Trajs: {avg_time:.3f} AVG, {std_time:.3f} STD", flush=True) 31 | 32 | def testLowerBound(self): 33 | durs = [] 34 | for _ in tqdm(range(self.iterations)): 35 | start = time.time() 36 | state = self.env.reset() 37 | 38 | max_steps = self.boardsize ** 2 39 | for s in range(max_steps): 40 | for _ in range(max_steps - s): 41 | np.copy(state) 42 | 43 | pi = np.ones(self.boardsize ** 2 + 1) / (self.boardsize ** 2 + 1) 44 | a = np.random.choice(np.arange(self.boardsize ** 2 + 1), p=pi) 45 | np.copy(state) 46 | 47 | end = time.time() 48 | 49 | dur = end - start 50 | durs.append(dur) 51 | 52 | avg_time = np.mean(durs) 53 | std_time = np.std(durs) 54 | print(f"Lower bound: {avg_time:.3f} AVG, {std_time:.3f} STD", flush=True) 55 | 56 | def testRandTrajsWithChildren(self): 57 | durs = [] 58 | num_steps = [] 59 | for _ in tqdm(range(self.iterations)): 60 | start = time.time() 61 | self.env.reset() 62 | 63 | max_steps = 2 * self.boardsize ** 2 64 | s = 0 65 | for s in range(max_steps): 66 | valid_moves = self.env.valid_moves() 67 | self.env.children(canonical=True) 68 | # Do not pass if possible 69 | if np.sum(valid_moves) > 1: 70 | valid_moves[-1] = 0 71 | probs = valid_moves / np.sum(valid_moves) 72 | a = np.random.choice(np.arange(self.boardsize ** 2 + 1), p=probs) 73 | state, _, done, _ = self.env.step(a) 74 | if done: 75 | break 76 | num_steps.append(s) 77 | 78 | end = time.time() 79 | 80 | dur = end - start 81 | durs.append(dur) 82 | 83 | avg_time = np.mean(durs) 84 | std_time = np.std(durs) 85 | avg_steps = np.mean(num_steps) 86 
| print(f"Rand Trajs w/ Children: {avg_time:.3f} AVG SEC, {std_time:.3f} STD SEC, {avg_steps:.1f} AVG STEPS", 87 | flush=True) 88 | 89 | 90 | if __name__ == '__main__': 91 | unittest.main() 92 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # About 2 | An environment for the board game Go. It is implemented using OpenAI's Gym API. 3 | It is also optimized to be as efficient as possible so that ML models can be trained quickly. 4 | 5 | # Installation 6 | ```bash 7 | # In the root directory 8 | pip install -e . 9 | ``` 10 | 11 | # API 12 | 13 | ### Coding example 14 | ```python 15 | import gym 16 | 17 | go_env = gym.make('gym_go:go-v0', size=7, komi=0, reward_method='real') 18 | 19 | first_action = (2,5) 20 | second_action = (5,2) 21 | state, reward, done, info = go_env.step(first_action) 22 | go_env.render('terminal') 23 | ``` 24 | 25 | ``` 26 | 0 1 2 3 4 5 6 27 | 0 ╔═╤═╤═╤═╤═╤═╗ 28 | 1 ╟─┼─┼─┼─┼─┼─╢ 29 | 2 ╟─┼─┼─┼─┼─○─╢ 30 | 3 ╟─┼─┼─┼─┼─┼─╢ 31 | 4 ╟─┼─┼─┼─┼─┼─╢ 32 | 5 ╟─┼─┼─┼─┼─┼─╢ 33 | 6 ╚═╧═╧═╧═╧═╧═╝ 34 | Turn: WHITE, Game State (ONGOING|PASSED|END): ONGOING 35 | Black Area: 49, White Area: 0 36 | ``` 37 | 38 | ```python 39 | state, reward, done, info = go_env.step(second_action) 40 | go_env.render('terminal') 41 | ``` 42 | 43 | ``` 44 | 0 1 2 3 4 5 6 45 | 0 ╔═╤═╤═╤═╤═╤═╗ 46 | 1 ╟─┼─┼─┼─┼─┼─╢ 47 | 2 ╟─┼─┼─┼─┼─○─╢ 48 | 3 ╟─┼─┼─┼─┼─┼─╢ 49 | 4 ╟─┼─┼─┼─┼─┼─╢ 50 | 5 ╟─┼─●─┼─┼─┼─╢ 51 | 6 ╚═╧═╧═╧═╧═╧═╝ 52 | Turn: BLACK, Game State (ONGOING|PASSED|END): ONGOING 53 | Black Area: 1, White Area: 1 54 | ``` 55 | 56 | ### UI example 57 | ```bash 58 | # In the root directory. 59 | # Defaults to a uniform random AI opponent. 60 | python3 demo.py 61 | ``` 62 | ![Human UI](screenshots/human_ui.png) 63 | 64 | ### High level API 65 | [GoEnv](gym_go/envs/go_env.py) defines the Gym environment for Go. 66 | It contains the highest-level API for basic Go usage. 67 | 68 | ### Low level API 69 | [GoGame](gym_go/gogame.py) is the set of low-level functions that defines all the game logic of Go. 70 | `GoEnv`'s high-level API is built on `GoGame`. 71 | These functions are intended for more detailed and fine-tuned 72 | usage of Go. 73 | 74 | # Scoring 75 | We use Tromp-Taylor scoring, a simple form of area scoring, to determine the winner. A player's _area_ is defined as the number of empty points a 76 | player's pieces surround plus the number of the player's pieces on the board. The _winner_ is the player with the larger 77 | area (a game is tied if both players have an equal amount of area on the board). 78 | 79 | There is also support for `komi`, a score bias that compensates for the advantage of black going first. 80 | By default `komi` is set to 0. 81 | 82 | # Game ending 83 | A game ends when both players pass consecutively. 84 | 85 | # Reward methods 86 | Reward methods are from _black_'s perspective. 87 | * **Real**: 88 | * If game ended: 89 | * `-1` - White won 90 | * `0` - Game is tied 91 | * `1` - Black won 92 | * `0` - Otherwise 93 | * **Heuristic**: If the game is ongoing, the reward is `black area - white area`. 94 | If black won, the reward is `BOARD_SIZE**2`. 95 | If white won, the reward is `-BOARD_SIZE**2`. 96 | If tied, the reward is `0`. 97 | 98 | # State 99 | The `state` object that is returned by the `reset` and `step` functions of the environment is a 100 | `6 x BOARD_SIZE x BOARD_SIZE` numpy array.
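For illustration, a small sketch of how this array can be indexed with the channel constants from `gym_go.govars` (the channel-by-channel breakdown follows below; the move played here is an arbitrary example):

```python
import gym
import numpy as np
from gym_go import govars

go_env = gym.make('gym_go:go-v0', size=7)
state = go_env.reset()
state, reward, done, info = go_env.step((2, 5))       # black plays row 2, column 5

print(state.shape)                                    # (6, 7, 7) == (govars.NUM_CHNLS, size, size)
print(state[govars.BLACK, 2, 5])                      # 1.0 -> a black stone sits at (2, 5)
print(state[govars.TURN_CHNL].max())                  # 1.0 -> it is now white's turn
print(np.count_nonzero(state[govars.INVD_CHNL]))      # 1   -> only the occupied point is invalid
```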
All values in the array are either `0` or `1` 101 | * **First and second channel:** represent the black and white pieces respectively. 102 | * **Third channel:** Indicator layer for whose turn it is 103 | * **Fourth channel:** Invalid moves (including ko-protection) for the next action 104 | * **Fifth channel:** Indicator layer for whether the previous move was a pass 105 | * **Sixth channel:** Indicator layer for whether the game is over 106 | 107 | # Action 108 | The `step` function takes in the action to execute and can be in the following forms: 109 | * a tuple/list of 2 integers representing the row and column or `None` for passing 110 | * a single integer representing the action in 1d space (i.e 9 would be (1,2) and 49 would be a pass for a 7x7 board) 111 | -------------------------------------------------------------------------------- /gym_go/rendering.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pyglet 3 | 4 | from gym_go import govars, gogame 5 | 6 | 7 | def draw_circle(x, y, color, radius): 8 | num_sides = 50 9 | verts = [x, y] 10 | colors = list(color) 11 | for i in range(num_sides + 1): 12 | verts.append(x + radius * np.cos(i * np.pi * 2 / num_sides)) 13 | verts.append(y + radius * np.sin(i * np.pi * 2 / num_sides)) 14 | colors.extend(color) 15 | pyglet.graphics.draw(len(verts) // 2, pyglet.gl.GL_TRIANGLE_FAN, 16 | ('v2f', verts), ('c3f', colors)) 17 | 18 | 19 | def draw_command_labels(batch, window_width, window_height): 20 | pyglet.text.Label('Pass (p) | Reset (r) | Exit (e)', 21 | font_name='Helvetica', 22 | font_size=11, 23 | x=20, y=window_height - 20, anchor_y='top', batch=batch, multiline=True, width=window_width) 24 | 25 | 26 | def draw_info(batch, window_width, window_height, upper_grid_coord, state): 27 | turn = gogame.turn(state) 28 | turn_str = 'B' if turn == govars.BLACK else 'W' 29 | prev_player_passed = gogame.prev_player_passed(state) 30 | game_ended = gogame.game_ended(state) 31 | info_label = "Turn: {}\nPassed: {}\nGame: {}".format(turn_str, prev_player_passed, 32 | "OVER" if game_ended else "ONGOING") 33 | 34 | pyglet.text.Label(info_label, font_name='Helvetica', font_size=11, x=window_width - 20, y=window_height - 20, 35 | anchor_x='right', anchor_y='top', color=(0, 0, 0, 192), batch=batch, width=window_width / 2, 36 | align='right', multiline=True) 37 | 38 | # Areas 39 | black_area, white_area = gogame.areas(state) 40 | pyglet.text.Label("{}B | {}W".format(black_area, white_area), font_name='Helvetica', font_size=16, 41 | x=window_width / 2, y=upper_grid_coord + 80, anchor_x='center', color=(0, 0, 0, 192), batch=batch, 42 | width=window_width, align='center') 43 | 44 | 45 | def draw_title(batch, window_width, window_height): 46 | pyglet.text.Label("Go", font_name='Helvetica', font_size=20, bold=True, x=window_width / 2, y=window_height - 20, 47 | anchor_x='center', anchor_y='top', color=(0, 0, 0, 255), batch=batch, width=window_width / 2, 48 | align='center') 49 | 50 | 51 | def draw_grid(batch, delta, board_size, lower_grid_coord, upper_grid_coord): 52 | label_offset = 20 53 | left_coord = lower_grid_coord 54 | right_coord = lower_grid_coord 55 | ver_list = [] 56 | color_list = [] 57 | num_vert = 0 58 | for i in range(board_size): 59 | # horizontal 60 | ver_list.extend((lower_grid_coord, left_coord, 61 | upper_grid_coord, right_coord)) 62 | # vertical 63 | ver_list.extend((left_coord, lower_grid_coord, 64 | right_coord, upper_grid_coord)) 65 | color_list.extend([0.3, 0.3, 0.3] * 4) # 
black 66 | # label on the left 67 | pyglet.text.Label(str(i), 68 | font_name='Courier', font_size=11, 69 | x=lower_grid_coord - label_offset, y=left_coord, 70 | anchor_x='center', anchor_y='center', 71 | color=(0, 0, 0, 255), batch=batch) 72 | # label on the bottom 73 | pyglet.text.Label(str(i), 74 | font_name='Courier', font_size=11, 75 | x=left_coord, y=lower_grid_coord - label_offset, 76 | anchor_x='center', anchor_y='center', 77 | color=(0, 0, 0, 255), batch=batch) 78 | left_coord += delta 79 | right_coord += delta 80 | num_vert += 4 81 | batch.add(num_vert, pyglet.gl.GL_LINES, None, 82 | ('v2f/static', ver_list), ('c3f/static', color_list)) 83 | 84 | 85 | def draw_pieces(batch, lower_grid_coord, delta, piece_r, size, state): 86 | for i in range(size): 87 | for j in range(size): 88 | # black piece 89 | if state[0, i, j] == 1: 90 | draw_circle(lower_grid_coord + i * delta, lower_grid_coord + j * delta, 91 | [0.05882352963, 0.180392161, 0.2470588237], 92 | piece_r) # 0 for black 93 | 94 | # white piece 95 | if state[1, i, j] == 1: 96 | draw_circle(lower_grid_coord + i * delta, lower_grid_coord + j * delta, 97 | [0.9754120272] * 3, piece_r) # 255 for white 98 | -------------------------------------------------------------------------------- /gym_go/tests/test_invalid_moves.py: -------------------------------------------------------------------------------- 1 | import random 2 | import unittest 3 | 4 | import gym 5 | import numpy as np 6 | 7 | from gym_go import govars 8 | 9 | 10 | class TestGoEnvInvalidMoves(unittest.TestCase): 11 | 12 | def __init__(self, *args, **kwargs): 13 | super().__init__(*args, **kwargs) 14 | self.env = gym.make('gym_go:go-v0', size=7, reward_method='real') 15 | 16 | def setUp(self): 17 | self.env.reset() 18 | 19 | def test_out_of_bounds_action(self): 20 | with self.assertRaises(Exception): 21 | self.env.step((-1, 0)) 22 | 23 | with self.assertRaises(Exception): 24 | self.env.step((0, 100)) 25 | 26 | def test_invalid_occupied_moves(self): 27 | # Test this 8 times at random 28 | for _ in range(8): 29 | self.env.reset() 30 | row = random.randint(0, 6) 31 | col = random.randint(0, 6) 32 | 33 | state, reward, done, info = self.env.step((row, col)) 34 | 35 | # Assert that the invalid layer is correct 36 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL]), 1) 37 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL] == 1), 1) 38 | self.assertEqual(state[govars.INVD_CHNL, row, col], 1) 39 | 40 | with self.assertRaises(Exception): 41 | self.env.step((row, col)) 42 | 43 | def test_invalid_ko_protection_moves(self): 44 | """ 45 | _, 1, 2, _, _, _, _, 46 | 47 | 3, 8, 7/9, 4, _, _, _, 48 | 49 | _, 5, 6, _, _, _, _, 50 | 51 | _, _, _, _, _, _, _, 52 | 53 | _, _, _, _, _, _, _, 54 | 55 | _, _, _, _, _, _, _, 56 | 57 | _, _, _, _, _, _, _, 58 | 59 | :return: 60 | """ 61 | 62 | for move in [(0, 1), (0, 2), (1, 0), (1, 3), (2, 1), (2, 2), (1, 2), (1, 1)]: 63 | state, reward, done, info = self.env.step(move) 64 | 65 | # Test invalid channel 66 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL]), 8, state[govars.INVD_CHNL]) 67 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL] == 1), 8) 68 | self.assertEqual(state[govars.INVD_CHNL, 1, 2], 1) 69 | 70 | # Assert pieces channel is empty at ko-protection coordinate 71 | self.assertEqual(state[govars.BLACK, 1, 2], 0) 72 | self.assertEqual(state[govars.WHITE, 1, 2], 0) 73 | 74 | final_move = (1, 2) 75 | with self.assertRaises(Exception): 76 | self.env.step(final_move) 77 | 78 | # Assert ko-protection goes 
off 79 | state, reward, done, info = self.env.step((6, 6)) 80 | state, reward, done, info = self.env.step(None) 81 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL]), 8) 82 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL] == 1), 8) 83 | self.assertEqual(state[govars.INVD_CHNL, 1, 2], 0) 84 | 85 | def test_invalid_ko_wall_protection_moves(self): 86 | """ 87 | 2/8, 7, 6, _, _, _, _, 88 | 89 | 1, 4, _, _, _, _, _, 90 | 91 | _, _, _, _, _, _, _, 92 | 93 | _, _, _, _, _, _, _, 94 | 95 | _, _, _, _, _, _, _, 96 | 97 | _, _, _, _, _, _, _, 98 | 99 | _, _, _, _, _, _, _, 100 | 101 | :return: 102 | """ 103 | 104 | for move in [(1, 0), (0, 0), None, (1, 1), None, (0, 2), (0, 1)]: 105 | state, reward, done, info = self.env.step(move) 106 | 107 | # Test invalid channel 108 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL]), 5, state[govars.INVD_CHNL]) 109 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL] == 1), 5) 110 | self.assertEqual(state[govars.INVD_CHNL, 0, 0], 1) 111 | 112 | # Assert pieces channel is empty at ko-protection coordinate 113 | self.assertEqual(state[govars.BLACK, 0, 0], 0) 114 | self.assertEqual(state[govars.WHITE, 0, 0], 0) 115 | 116 | final_move = (0, 0) 117 | with self.assertRaises(Exception): 118 | self.env.step(final_move) 119 | 120 | # Assert ko-protection goes off 121 | state, reward, done, info = self.env.step((6, 6)) 122 | state, reward, done, info = self.env.step(None) 123 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL]), 5) 124 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL] == 1), 5) 125 | self.assertEqual(state[govars.INVD_CHNL, 0, 0], 0) 126 | 127 | def test_invalid_no_liberty_move(self): 128 | """ 129 | _, 1, 2, _, _, _, _, 130 | 131 | 3, 8, 7, _, 4, _, _, 132 | 133 | _, 5, 6, _, _, _, _, 134 | 135 | _, _, _, _, _, _, _, 136 | 137 | _, _, _, _, _, _, _, 138 | 139 | _, _, _, _, _, _, _, 140 | 141 | _, _, _, _, _, _, _, 142 | 143 | :return: 144 | """ 145 | for move in [(0, 1), (0, 2), (1, 0), (1, 4), (2, 1), (2, 2), (1, 2)]: 146 | state, reward, done, info = self.env.step(move) 147 | 148 | # Test invalid channel 149 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL]), 9, state[govars.INVD_CHNL]) 150 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL] == 1), 9) 151 | self.assertEqual(state[govars.INVD_CHNL, 1, 1], 1) 152 | self.assertEqual(state[govars.INVD_CHNL, 0, 0], 1) 153 | # Assert empty space in pieces channels 154 | self.assertEqual(state[govars.BLACK, 1, 1], 0) 155 | self.assertEqual(state[govars.WHITE, 1, 1], 0) 156 | self.assertEqual(state[govars.BLACK, 0, 0], 0) 157 | self.assertEqual(state[govars.WHITE, 0, 0], 0) 158 | 159 | final_move = (1, 1) 160 | with self.assertRaises(Exception): 161 | self.env.step(final_move) 162 | 163 | def test_invalid_game_already_over_move(self): 164 | self.env.step(None) 165 | self.env.step(None) 166 | 167 | with self.assertRaises(Exception): 168 | self.env.step(None) 169 | 170 | self.env.reset() 171 | 172 | self.env.step(None) 173 | self.env.step(None) 174 | 175 | with self.assertRaises(Exception): 176 | self.env.step((0, 0)) 177 | 178 | def test_small_suicide(self): 179 | """ 180 | 7, 8, 0, 181 | 182 | 0, 5, 4, 183 | 184 | 1, 2, 3/6, 185 | :return: 186 | """ 187 | 188 | self.env = gym.make('gym_go:go-v0', size=3, reward_method='real') 189 | for move in [6, 7, 8, 5, 4, 8, 0, 1]: 190 | state, reward, done, info = self.env.step(move) 191 | 192 | with self.assertRaises(Exception): 193 | self.env.step(3) 194 | 195 | def 
test_invalid_after_capture(self): 196 | """ 197 | 1, 5, 6, 198 | 199 | 7, 4, _, 200 | 201 | 3, 8, 2, 202 | :return: 203 | """ 204 | 205 | self.env = gym.make('gym_go:go-v0', size=3, reward_method='real') 206 | for move in [0, 8, 6, 4, 1, 2, 3, 7]: 207 | state, reward, done, info = self.env.step(move) 208 | 209 | with self.assertRaises(Exception): 210 | self.env.step(5) 211 | 212 | def test_cannot_capture_groups_with_multiple_holes(self): 213 | """ 214 | _, 2, 4, 6, 8, 10, _, 215 | 216 | 32, 1, 3, 5, 7, 9, 12, 217 | 218 | 30, 25, 34, 19, _, 11, 14, 219 | 220 | 28, 23, 21, 17, 15, 13, 16, 221 | 222 | _, 26, 24, 22, 20, 18, _, 223 | 224 | _, _, _, _, _, _, _, 225 | 226 | _, _, _, _, _, _, _, 227 | 228 | :return: 229 | """ 230 | for move in [(1, 1), (0, 1), (1, 2), (0, 2), (1, 3), (0, 3), (1, 4), (0, 4), (1, 5), (0, 5), (2, 5), (1, 6), 231 | (3, 5), (2, 6), (3, 4), (3, 6), 232 | (3, 3), (4, 5), (2, 3), (4, 4), (3, 2), (4, 3), (3, 1), (4, 2), (2, 1), (4, 1), None, (3, 0), None, 233 | (2, 0), None, (1, 0)]: 234 | state, reward, done, info = self.env.step(move) 235 | 236 | self.env.step(None) 237 | final_move = (2, 2) 238 | with self.assertRaises(Exception): 239 | self.env.step(final_move) 240 | 241 | 242 | if __name__ == '__main__': 243 | unittest.main() 244 | -------------------------------------------------------------------------------- /gym_go/tests/test_valid_moves.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import gym 4 | import numpy as np 5 | 6 | from gym_go import govars 7 | 8 | 9 | class TestGoEnvValidMoves(unittest.TestCase): 10 | 11 | def __init__(self, *args, **kwargs): 12 | super().__init__(*args, **kwargs) 13 | self.env = gym.make('gym_go:go-v0', size=7, reward_method='real') 14 | 15 | def setUp(self): 16 | self.env.reset() 17 | 18 | def test_simple_valid_moves(self): 19 | for i in range(7): 20 | state, reward, done, info = self.env.step((0, i)) 21 | self.assertEqual(done, False) 22 | 23 | self.env.reset() 24 | 25 | for i in range(7): 26 | state, reward, done, info = self.env.step((i, i)) 27 | self.assertEqual(done, False) 28 | 29 | self.env.reset() 30 | 31 | for i in range(7): 32 | state, reward, done, info = self.env.step((i, 0)) 33 | self.assertEqual(done, False) 34 | 35 | def test_valid_no_liberty_move(self): 36 | """ 37 | _, 1, 2, _, _, _, _, 38 | 39 | 3, 8, 7, 4, _, _, _, 40 | 41 | _, 5, 6, _, _, _, _, 42 | 43 | _, _, _, _, _, _, _, 44 | 45 | _, _, _, _, _, _, _, 46 | 47 | _, _, _, _, _, _, _, 48 | 49 | _, _, _, _, _, _, _, 50 | 51 | 52 | :return: 53 | """ 54 | for move in [(0, 1), (0, 2), (1, 0), (1, 3), (2, 1), (2, 2), (1, 2), (1, 1)]: 55 | state, reward, done, info = self.env.step(move) 56 | 57 | # Black should have 3 pieces 58 | self.assertEqual(np.count_nonzero(state[govars.BLACK]), 3) 59 | 60 | # White should have 4 pieces 61 | self.assertEqual(np.count_nonzero(state[govars.WHITE]), 4) 62 | # Assert values are ones 63 | self.assertEqual(np.count_nonzero(state[govars.WHITE] == 1), 4) 64 | 65 | def test_valid_no_liberty_capture(self): 66 | """ 67 | 1, 7, 2, 3, _, _, _, 68 | 69 | 6, 4, 5, _, _, _, _, 70 | 71 | _, _, _, _, _, _, _, 72 | 73 | _, _, _, _, _, _, _, 74 | 75 | _, _, _, _, _, _, _, 76 | 77 | _, _, _, _, _, _, _, 78 | 79 | _, _, _, _, _, _, _, 80 | 81 | :return: 82 | """ 83 | for move in [(0, 0), (0, 2), (0, 3), (1, 1), (1, 2), (1, 0)]: 84 | state, reward, done, info = self.env.step(move) 85 | 86 | # Test invalid channel 87 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL]), 6, 
state[govars.INVD_CHNL]) 88 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL] == 1), 6) 89 | self.assertEqual(state[govars.INVD_CHNL, 0, 1], 0, state[govars.INVD_CHNL]) 90 | # Assert empty space in pieces channels 91 | self.assertEqual(state[govars.BLACK, 0, 1], 0) 92 | self.assertEqual(state[govars.WHITE, 0, 1], 0) 93 | 94 | final_move = (0, 1) 95 | state, reward, done, info = self.env.step(final_move) 96 | 97 | # White should only have 2 pieces 98 | self.assertEqual(np.count_nonzero(state[govars.WHITE]), 2, state[govars.WHITE]) 99 | self.assertEqual(np.count_nonzero(state[govars.WHITE] == 1), 2) 100 | # Black should have 4 pieces 101 | self.assertEqual(np.count_nonzero(state[govars.BLACK]), 4, state[govars.BLACK]) 102 | self.assertEqual(np.count_nonzero(state[govars.BLACK] == 1), 4) 103 | 104 | def test_simple_capture(self): 105 | """ 106 | _, 1, _, _, _, _, _, 107 | 108 | 3, 2, 5, _, _, _, _, 109 | 110 | _, 7, _, _, _, _, _, 111 | 112 | _, _, _, _, _, _, _, 113 | 114 | _, _, _, _, _, _, _, 115 | 116 | _, _, _, _, _, _, _, 117 | 118 | _, _, _, _, _, _, _, 119 | 120 | :return: 121 | """ 122 | 123 | for move in [(0, 1), (1, 1), (1, 0), None, (1, 2), None, (2, 1)]: 124 | state, reward, done, info = self.env.step(move) 125 | 126 | # White should have no pieces 127 | self.assertEqual(np.count_nonzero(state[govars.WHITE]), 0) 128 | 129 | # Black should have 4 pieces 130 | self.assertEqual(np.count_nonzero(state[govars.BLACK]), 4) 131 | # Assert values are ones 132 | self.assertEqual(np.count_nonzero(state[govars.BLACK] == 1), 4) 133 | 134 | def test_large_group_capture(self): 135 | """ 136 | _, _, _, _, _, _, _, 137 | 138 | _, _, 2, 4, 6, _, _, 139 | 140 | _, 20, 1, 3, 5, 8, _, 141 | 142 | _, 18, 11, 9, 7, 10, _, 143 | 144 | _, _, 16, 14, 12, _, _, 145 | 146 | _, _, _, _, _, _, _, 147 | 148 | _, _, _, _, _, _, _, 149 | 150 | :return: 151 | """ 152 | for move in [(2, 2), (1, 2), (2, 3), (1, 3), (2, 4), (1, 4), (3, 4), (2, 5), (3, 3), (3, 5), (3, 2), (4, 4), 153 | None, (4, 3), None, (4, 2), None, 154 | (3, 1), None, (2, 1)]: 155 | state, reward, done, info = self.env.step(move) 156 | 157 | # Black should have no pieces 158 | self.assertEqual(np.count_nonzero(state[govars.BLACK]), 0) 159 | 160 | # White should have 10 pieces 161 | self.assertEqual(np.count_nonzero(state[govars.WHITE]), 10) 162 | # Assert they are ones 163 | self.assertEqual(np.count_nonzero(state[govars.WHITE] == 1), 10) 164 | 165 | def test_large_group_suicide(self): 166 | """ 167 | _, _, _, _, _, _, _, 168 | 169 | _, _, _, _, _, _, _, 170 | 171 | _, _, _, _, _, _, _, 172 | 173 | _, _, _, _, _, _, _, 174 | 175 | 1, 3, _, _, _, _, _, 176 | 177 | 4, 6, 5, _, _, _, _, 178 | 179 | 2, 8, 7, _, _, _, _, 180 | 181 | :return: 182 | """ 183 | for move in [(4, 0), (6, 0), (4, 1), (5, 0), (5, 2), (5, 1), (6, 2)]: 184 | state, reward, done, info = self.env.step(move) 185 | 186 | # Test invalid channel 187 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL]), 8, state[govars.INVD_CHNL]) 188 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL] == 1), 8) 189 | # Assert empty space in pieces channels 190 | self.assertEqual(state[govars.BLACK, 6, 1], 0) 191 | self.assertEqual(state[govars.WHITE, 6, 1], 0) 192 | 193 | final_move = (6, 1) 194 | with self.assertRaises(Exception): 195 | self.env.step(final_move) 196 | 197 | def test_group_edge_capture(self): 198 | """ 199 | 1, 3, 2, _, _, _, _, 200 | 201 | 7, 5, 4, _, _, _, _, 202 | 203 | 8, 6, _, _, _, _, _, 204 | 205 | _, _, _, _, _, _, _, 206 | 207 | _, _, _, _, _, _, _, 
208 | 209 | _, _, _, _, _, _, _, 210 | 211 | _, _, _, _, _, _, _, 212 | 213 | :return: 214 | """ 215 | 216 | for move in [(0, 0), (0, 2), (0, 1), (1, 2), (1, 1), (2, 1), (1, 0), (2, 0)]: 217 | state, reward, done, info = self.env.step(move) 218 | 219 | # Black should have no pieces 220 | self.assertEqual(np.count_nonzero(state[govars.BLACK]), 0) 221 | 222 | # White should have 4 pieces 223 | self.assertEqual(np.count_nonzero(state[govars.WHITE]), 4) 224 | # Assert they are ones 225 | self.assertEqual(np.count_nonzero(state[govars.WHITE] == 1), 4) 226 | 227 | def test_group_kill_no_ko_protection(self): 228 | """ 229 | Thanks to DeepGeGe for finding this bug. 230 | 231 | _, _, _, _, 2, 1, 13, 232 | 233 | _, _, _, _, 4, 3, 12/14, 234 | 235 | _, _, _, _, 6, 5, 7, 236 | 237 | _, _, _, _, _, 8, 10, 238 | 239 | _, _, _, _, _, _, _, 240 | 241 | _, _, _, _, _, _, _, 242 | 243 | _, _, _, _, _, _, _, 244 | 245 | :return: 246 | """ 247 | 248 | for move in [(0, 5), (0, 4), (1, 5), (1, 4), (2, 5), (2, 4), (2, 6), (3, 5), None, (3, 6), None, (1, 6), 249 | (0, 6)]: 250 | state, reward, done, info = self.env.step(move) 251 | 252 | # Test final kill move (1, 6) is valid 253 | final_move = (1, 6) 254 | self.assertEqual(state[govars.INVD_CHNL, 1, 6], 0) 255 | state, _, _, _ = self.env.step(final_move) 256 | 257 | # Assert black is removed 258 | self.assertEqual(state[govars.BLACK].sum(), 0) 259 | 260 | # Assert 6 white pieces still on the board 261 | self.assertEqual(state[govars.WHITE].sum(), 6) 262 | 263 | 264 | if __name__ == '__main__': 265 | unittest.main() 266 | -------------------------------------------------------------------------------- /gym_go/envs/go_env.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import gym 4 | import numpy as np 5 | 6 | from gym_go import govars, rendering, gogame 7 | 8 | 9 | class RewardMethod(Enum): 10 | """ 11 | REAL: 0 = game is ongoing or tied, 1 = black won, -1 = white won 12 | HEURISTIC: If game is ongoing, the reward is the area difference between black and white. 13 | Otherwise the game has ended, and if black has more area, the reward is BOARD_SIZE**2, otherwise it's -BOARD_SIZE**2 14 | """ 15 | REAL = 'real' 16 | HEURISTIC = 'heuristic' 17 | 18 | 19 | class GoEnv(gym.Env): 20 | metadata = {'render.modes': ['terminal', 'human']} 21 | govars = govars 22 | gogame = gogame 23 | 24 | def __init__(self, size, komi=0, reward_method='real'): 25 | ''' 26 | @param reward_method: either 'heuristic' or 'real' 27 | heuristic: gives black area - white area. 28 | real: gives 0 for in-game move, 1 for winning, -1 for losing, 29 | 0 for draw, all from black player's perspective 30 | ''' 31 | self.size = size 32 | self.komi = komi 33 | self.state_ = gogame.init_state(size) 34 | self.reward_method = RewardMethod(reward_method) 35 | self.observation_space = gym.spaces.Box(np.float32(0), np.float32(govars.NUM_CHNLS), 36 | shape=(govars.NUM_CHNLS, size, size)) 37 | self.action_space = gym.spaces.Discrete(gogame.action_size(self.state_)) 38 | self.done = False 39 | 40 | def reset(self): 41 | ''' 42 | Reset state, go_board, curr_player, prev_player_passed, 43 | done, return state 44 | ''' 45 | self.state_ = gogame.init_state(self.size) 46 | self.done = False 47 | return np.copy(self.state_) 48 | 49 | def step(self, action): 50 | ''' 51 | Assumes the correct player is making a move. Black goes first.
52 | return observation, reward, done, info 53 | ''' 54 | assert not self.done 55 | if isinstance(action, tuple) or isinstance(action, list) or isinstance(action, np.ndarray): 56 | assert 0 <= action[0] < self.size 57 | assert 0 <= action[1] < self.size 58 | action = self.size * action[0] + action[1] 59 | elif action is None: 60 | action = self.size ** 2 61 | 62 | self.state_ = gogame.next_state(self.state_, action, canonical=False) 63 | self.done = gogame.game_ended(self.state_) 64 | return np.copy(self.state_), self.reward(), self.done, self.info() 65 | 66 | def game_ended(self): 67 | return self.done 68 | 69 | def turn(self): 70 | return gogame.turn(self.state_) 71 | 72 | def prev_player_passed(self): 73 | return gogame.prev_player_passed(self.state_) 74 | 75 | def valid_moves(self): 76 | return gogame.valid_moves(self.state_) 77 | 78 | def uniform_random_action(self): 79 | valid_moves = self.valid_moves() 80 | valid_move_idcs = np.argwhere(valid_moves).flatten() 81 | return np.random.choice(valid_move_idcs) 82 | 83 | def info(self): 84 | """ 85 | :return: Debugging info for the state 86 | """ 87 | return { 88 | 'turn': gogame.turn(self.state_), 89 | 'invalid_moves': gogame.invalid_moves(self.state_), 90 | 'prev_player_passed': gogame.prev_player_passed(self.state_), 91 | } 92 | 93 | def state(self): 94 | """ 95 | :return: copy of state 96 | """ 97 | return np.copy(self.state_) 98 | 99 | def canonical_state(self): 100 | """ 101 | :return: canonical shallow copy of state 102 | """ 103 | return gogame.canonical_form(self.state_) 104 | 105 | def children(self, canonical=False, padded=True): 106 | """ 107 | :return: Child states for all valid moves (optionally in canonical form and zero-padded) 108 | """ 109 | return gogame.children(self.state_, canonical, padded) 110 | 111 | def winning(self): 112 | """ 113 | :return: Who's currently winning from BLACK's perspective, regardless of whether the game is over 114 | """ 115 | return gogame.winning(self.state_, self.komi) 116 | 117 | def winner(self): 118 | """ 119 | Gets the winner from BLACK's perspective 120 | :return: 121 | """ 122 | 123 | if self.game_ended(): 124 | return self.winning() 125 | else: 126 | return 0 127 | 128 | def reward(self): 129 | ''' 130 | Return reward based on reward_method. 131 | heuristic: black total area - white total area 132 | real: 0 for in-game move, 1 for winning, -1 for losing, 133 | 0 for draw, from black player's perspective.
134 | Winning and losing based on the Area rule 135 | Also known as Trump Taylor Scoring 136 | Area rule definition: https://en.wikipedia.org/wiki/Rules_of_Go#End 137 | ''' 138 | if self.reward_method == RewardMethod.REAL: 139 | return self.winner() 140 | 141 | elif self.reward_method == RewardMethod.HEURISTIC: 142 | black_area, white_area = gogame.areas(self.state_) 143 | area_difference = black_area - white_area 144 | komi_correction = area_difference - self.komi 145 | if self.game_ended(): 146 | return (1 if komi_correction > 0 else -1) * self.size ** 2 147 | return komi_correction 148 | else: 149 | raise Exception("Unknown Reward Method") 150 | 151 | def __str__(self): 152 | return gogame.str(self.state_) 153 | 154 | def close(self): 155 | if hasattr(self, 'window'): 156 | assert hasattr(self, 'pyglet') 157 | self.window.close() 158 | self.pyglet.app.exit() 159 | 160 | def render(self, mode='terminal'): 161 | if mode == 'terminal': 162 | print(self.__str__()) 163 | elif mode == 'human': 164 | import pyglet 165 | from pyglet.window import mouse 166 | from pyglet.window import key 167 | 168 | screen = pyglet.canvas.get_display().get_default_screen() 169 | window_width = int(min(screen.width, screen.height) * 2 / 3) 170 | window_height = int(window_width * 1.2) 171 | window = pyglet.window.Window(window_width, window_height) 172 | 173 | self.window = window 174 | self.pyglet = pyglet 175 | self.user_action = None 176 | 177 | # Set Cursor 178 | cursor = window.get_system_mouse_cursor(window.CURSOR_CROSSHAIR) 179 | window.set_mouse_cursor(cursor) 180 | 181 | # Outlines 182 | lower_grid_coord = window_width * 0.075 183 | board_size = window_width * 0.85 184 | upper_grid_coord = board_size + lower_grid_coord 185 | delta = board_size / (self.size - 1) 186 | piece_r = delta / 3.3 # radius 187 | 188 | @window.event 189 | def on_draw(): 190 | pyglet.gl.glClearColor(0.7, 0.5, 0.3, 1) 191 | window.clear() 192 | 193 | pyglet.gl.glLineWidth(3) 194 | batch = pyglet.graphics.Batch() 195 | 196 | # draw the grid and labels 197 | rendering.draw_grid(batch, delta, self.size, lower_grid_coord, upper_grid_coord) 198 | 199 | # info on top of the board 200 | rendering.draw_info(batch, window_width, window_height, upper_grid_coord, self.state_) 201 | 202 | # Inform user what they can do 203 | rendering.draw_command_labels(batch, window_width, window_height) 204 | 205 | rendering.draw_title(batch, window_width, window_height) 206 | 207 | batch.draw() 208 | 209 | # draw the pieces 210 | rendering.draw_pieces(batch, lower_grid_coord, delta, piece_r, self.size, self.state_) 211 | 212 | @window.event 213 | def on_mouse_press(x, y, button, modifiers): 214 | if button == mouse.LEFT: 215 | grid_x = (x - lower_grid_coord) 216 | grid_y = (y - lower_grid_coord) 217 | x_coord = round(grid_x / delta) 218 | y_coord = round(grid_y / delta) 219 | try: 220 | self.window.close() 221 | pyglet.app.exit() 222 | self.user_action = (x_coord, y_coord) 223 | except: 224 | pass 225 | 226 | @window.event 227 | def on_key_press(symbol, modifiers): 228 | if symbol == key.P: 229 | self.window.close() 230 | pyglet.app.exit() 231 | self.user_action = None 232 | elif symbol == key.R: 233 | self.reset() 234 | self.window.close() 235 | pyglet.app.exit() 236 | elif symbol == key.E: 237 | self.window.close() 238 | pyglet.app.exit() 239 | self.user_action = -1 240 | 241 | pyglet.app.run() 242 | 243 | return self.user_action 244 | -------------------------------------------------------------------------------- /gym_go/tests/test_basics.py: 
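Before the test listings below, a small sketch of how the high-level `GoEnv` calls defined above (`valid_moves`, `children`) and the low-level `gogame.areas` helper might be combined into a one-ply greedy player; the board size and the area-difference heuristic are illustration choices, not part of the library:

```python
import gym
import numpy as np
from gym_go import gogame

go_env = gym.make('gym_go:go-v0', size=5, komi=0, reward_method='real')
go_env.reset()

valid_moves = go_env.valid_moves()                    # length size**2 + 1; the last entry is "pass"
children = go_env.children(canonical=False, padded=True)

# Score each legal child state by black's area lead and play the best one.
scores = np.full(len(valid_moves), -np.inf)
for a in np.argwhere(valid_moves).flatten():
    black_area, white_area = gogame.areas(children[a])
    scores[a] = black_area - white_area

best_action = int(np.argmax(scores))
state, reward, done, info = go_env.step(best_action)
```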
-------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import gym 4 | import numpy as np 5 | 6 | from gym_go import govars 7 | 8 | 9 | class TestGoEnvBasics(unittest.TestCase): 10 | 11 | def __init__(self, *args, **kwargs): 12 | super().__init__(*args, **kwargs) 13 | self.env = gym.make('gym_go:go-v0', size=7, reward_method='real') 14 | 15 | def setUp(self): 16 | self.env.reset() 17 | 18 | def test_state(self): 19 | env = gym.make('gym_go:go-v0', size=7) 20 | state = env.reset() 21 | self.assertIsInstance(state, np.ndarray) 22 | self.assertEqual(state.shape[0], govars.NUM_CHNLS) 23 | 24 | env.close() 25 | 26 | def test_board_sizes(self): 27 | expected_sizes = [7, 13, 19] 28 | 29 | for expec_size in expected_sizes: 30 | env = gym.make('gym_go:go-v0', size=expec_size) 31 | state = env.reset() 32 | self.assertEqual(state.shape[1], expec_size) 33 | self.assertEqual(state.shape[2], expec_size) 34 | 35 | env.close() 36 | 37 | def test_empty_board(self): 38 | state = self.env.reset() 39 | self.assertEqual(np.count_nonzero(state), 0) 40 | 41 | def test_reset(self): 42 | state, reward, done, info = self.env.step((0, 0)) 43 | self.assertEqual(np.count_nonzero(state[[govars.BLACK, govars.WHITE, govars.INVD_CHNL]]), 2) 44 | self.assertEqual(np.count_nonzero(state), 51) 45 | state = self.env.reset() 46 | self.assertEqual(np.count_nonzero(state), 0) 47 | 48 | def test_preserve_original_state(self): 49 | state = self.env.reset() 50 | original_state = np.copy(state) 51 | self.env.gogame.next_state(state, 0) 52 | assert (original_state == state).all() 53 | 54 | def test_black_moves_first(self): 55 | """ 56 | Make a move at 0,0 and assert that a black piece was placed 57 | :return: 58 | """ 59 | next_state, reward, done, info = self.env.step((0, 0)) 60 | self.assertEqual(next_state[govars.BLACK, 0, 0], 1) 61 | self.assertEqual(next_state[govars.WHITE, 0, 0], 0) 62 | 63 | def test_turns(self): 64 | for i in range(7): 65 | # For the first move at i == 0, black went so now it should be white's turn 66 | state, reward, done, info = self.env.step((i, 0)) 67 | self.assertIn('turn', info) 68 | self.assertEqual(info['turn'], 1 if i % 2 == 0 else 0) 69 | 70 | def test_multiple_action_formats(self): 71 | for _ in range(10): 72 | action_1d = np.random.randint(50) 73 | action_2d = None if action_1d == 49 else (action_1d // 7, action_1d % 7) 74 | 75 | self.env.reset() 76 | state_from_1d, _, _, _ = self.env.step(action_1d) 77 | 78 | self.env.reset() 79 | state_from_2d, _, _, _ = self.env.step(action_2d) 80 | 81 | self.assertTrue((state_from_1d == state_from_2d).all()) 82 | 83 | def test_passing(self): 84 | """ 85 | None indicates pass 86 | :return: 87 | """ 88 | 89 | # Pass on first move 90 | state, reward, done, info = self.env.step(None) 91 | # Expect empty board still 92 | self.assertEqual(np.count_nonzero(state[[govars.BLACK, govars.WHITE]]), 0) 93 | # Expect passing layer and turn layer channels to be all ones 94 | self.assertEqual(np.count_nonzero(state), 98, state) 95 | self.assertEqual(np.count_nonzero(state[govars.PASS_CHNL]), 49) 96 | self.assertEqual(np.count_nonzero(state[govars.PASS_CHNL] == 1), 49) 97 | 98 | self.assertIn('turn', info) 99 | self.assertEqual(info['turn'], 1) 100 | 101 | # Make a move 102 | state, reward, done, info = self.env.step((0, 0)) 103 | 104 | # Expect the passing layer channel to be empty 105 | self.assertEqual(np.count_nonzero(state), 2) 106 | self.assertEqual(np.count_nonzero(state[govars.WHITE]), 1) 107 | 
self.assertEqual(np.count_nonzero(state[govars.WHITE] == 1), 1) 108 | self.assertEqual(np.count_nonzero(state[govars.PASS_CHNL]), 0) 109 | 110 | # Pass on second move 111 | self.env.reset() 112 | state, reward, done, info = self.env.step((0, 0)) 113 | # Expect two pieces (one in the invalid channel) 114 | # Plus turn layer is all ones 115 | self.assertEqual(np.count_nonzero(state), 51, state) 116 | self.assertEqual(np.count_nonzero(state[[govars.BLACK, govars.WHITE, govars.INVD_CHNL]]), 2, state) 117 | 118 | self.assertIn('turn', info) 119 | self.assertEqual(info['turn'], 1) 120 | 121 | # Pass 122 | state, reward, done, info = self.env.step(None) 123 | # Expect two pieces (one in the invalid channel) 124 | self.assertEqual(np.count_nonzero(state[[govars.BLACK, govars.WHITE, govars.INVD_CHNL]]), 2, 125 | state[[govars.BLACK, govars.WHITE, govars.INVD_CHNL]]) 126 | self.assertIn('turn', info) 127 | self.assertEqual(info['turn'], 0) 128 | 129 | def test_game_ends(self): 130 | state, reward, done, info = self.env.step(None) 131 | self.assertFalse(done) 132 | state, reward, done, info = self.env.step(None) 133 | self.assertTrue(done) 134 | 135 | self.env.reset() 136 | 137 | state, reward, done, info = self.env.step((0, 0)) 138 | self.assertFalse(done) 139 | state, reward, done, info = self.env.step(None) 140 | self.assertFalse(done) 141 | state, reward, done, info = self.env.step(None) 142 | self.assertTrue(done) 143 | 144 | def test_game_does_not_end_with_disjoint_passes(self): 145 | state, reward, done, info = self.env.step(None) 146 | self.assertFalse(done) 147 | state, reward, done, info = self.env.step((0, 0)) 148 | self.assertFalse(done) 149 | state, reward, done, info = self.env.step(None) 150 | self.assertFalse(done) 151 | 152 | def test_num_liberties(self): 153 | env = gym.make('gym_go:go-v0', size=7) 154 | 155 | steps = [(0, 0), (0, 1)] 156 | libs = [(2, 0), (1, 2)] 157 | 158 | env.reset() 159 | for step, libs in zip(steps, libs): 160 | state, _, _, _ = env.step(step) 161 | blacklibs, whitelibs = env.gogame.num_liberties(state) 162 | self.assertEqual(blacklibs, libs[0], state) 163 | self.assertEqual(whitelibs, libs[1], state) 164 | 165 | steps = [(2, 1), None, (1, 2), None, (2, 3), None, (3, 2), None] 166 | libs = [(4, 0), (4, 0), (6, 0), (6, 0), (8, 0), (8, 0), (9, 0), (9, 0)] 167 | 168 | env.reset() 169 | for step, libs in zip(steps, libs): 170 | state, _, _, _ = env.step(step) 171 | blacklibs, whitelibs = env.gogame.num_liberties(state) 172 | self.assertEqual(blacklibs, libs[0], state) 173 | self.assertEqual(whitelibs, libs[1], state) 174 | 175 | def test_komi(self): 176 | env = gym.make('gym_go:go-v0', size=7, komi=2.5, reward_method='real') 177 | 178 | # White win 179 | _ = env.step(None) 180 | state, reward, done, info = env.step(None) 181 | self.assertEqual(-1, reward) 182 | 183 | # White still win 184 | env.reset() 185 | _ = env.step(0) 186 | _ = env.step(2) 187 | 188 | _ = env.step(1) 189 | _ = env.step(None) 190 | 191 | state, reward, done, info = env.step(None) 192 | self.assertEqual(-1, reward) 193 | 194 | # Black win 195 | env.reset() 196 | _ = env.step(0) 197 | _ = env.step(None) 198 | 199 | _ = env.step(1) 200 | _ = env.step(None) 201 | 202 | _ = env.step(2) 203 | _ = env.step(None) 204 | state, reward, done, info = env.step(None) 205 | self.assertEqual(1, reward) 206 | 207 | env.close() 208 | 209 | def test_children(self): 210 | for canonical in [False, True]: 211 | for _ in range(20): 212 | action = self.env.uniform_random_action() 213 | self.env.step(action) 214 | 
state = self.env.state() 215 | children = self.env.children(canonical, padded=True) 216 | valid_moves = self.env.valid_moves() 217 | for a in range(len(valid_moves)): 218 | if valid_moves[a]: 219 | child = self.env.gogame.next_state(state, a, canonical) 220 | equal = children[a] == child 221 | self.assertTrue(equal.all(), (canonical, np.argwhere(~equal))) 222 | else: 223 | self.assertTrue((children[a] == 0).all()) 224 | 225 | def test_real_reward(self): 226 | env = gym.make('gym_go:go-v0', size=7, reward_method='real') 227 | 228 | # In game 229 | state, reward, done, info = env.step((0, 0)) 230 | self.assertEqual(reward, 0) 231 | state, reward, done, info = env.step(None) 232 | self.assertEqual(reward, 0) 233 | 234 | # Win 235 | state, reward, done, info = env.step(None) 236 | self.assertEqual(reward, 1) 237 | 238 | # Lose 239 | env.reset() 240 | 241 | state, reward, done, info = env.step(None) 242 | self.assertEqual(reward, 0) 243 | state, reward, done, info = env.step((0, 0)) 244 | self.assertEqual(reward, 0) 245 | state, reward, done, info = env.step(None) 246 | self.assertEqual(reward, 0) 247 | state, reward, done, info = env.step(None) 248 | self.assertEqual(reward, -1) 249 | 250 | # Tie 251 | env.reset() 252 | 253 | state, reward, done, info = env.step(None) 254 | self.assertEqual(reward, 0) 255 | state, reward, done, info = env.step(None) 256 | self.assertEqual(reward, 0) 257 | 258 | env.close() 259 | 260 | def test_heuristic_reward(self): 261 | env = gym.make('gym_go:go-v0', size=7, reward_method='heuristic') 262 | 263 | # In game 264 | state, reward, done, info = env.step((0, 0)) 265 | self.assertEqual(reward, 49) 266 | state, reward, done, info = env.step((0, 1)) 267 | self.assertEqual(reward, 0) 268 | state, reward, done, info = env.step(None) 269 | self.assertEqual(reward, 0) 270 | state, reward, done, info = env.step((1, 0)) 271 | self.assertEqual(reward, -49) 272 | 273 | # Lose 274 | state, reward, done, info = env.step(None) 275 | self.assertEqual(reward, -49) 276 | state, reward, done, info = env.step(None) 277 | self.assertEqual(reward, -49) 278 | 279 | # Win 280 | env.reset() 281 | 282 | state, reward, done, info = env.step((0, 0)) 283 | self.assertEqual(reward, 49) 284 | state, reward, done, info = env.step(None) 285 | self.assertEqual(reward, 49) 286 | state, reward, done, info = env.step(None) 287 | self.assertEqual(reward, 49) 288 | 289 | env.close() 290 | 291 | 292 | if __name__ == '__main__': 293 | unittest.main() 294 | -------------------------------------------------------------------------------- /gym_go/state_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import ndimage 3 | from scipy.ndimage import measurements 4 | 5 | from gym_go import govars 6 | 7 | group_struct = np.array([[[0, 0, 0], 8 | [0, 0, 0], 9 | [0, 0, 0]], 10 | [[0, 1, 0], 11 | [1, 1, 1], 12 | [0, 1, 0]], 13 | [[0, 0, 0], 14 | [0, 0, 0], 15 | [0, 0, 0]]]) 16 | 17 | surround_struct = np.array([[0, 1, 0], 18 | [1, 0, 1], 19 | [0, 1, 0]]) 20 | 21 | neighbor_deltas = np.array([[-1, 0], [1, 0], [0, -1], [0, 1]]) 22 | 23 | 24 | def compute_invalid_moves(state, player, ko_protect=None): 25 | """ 26 | Updates invalid moves in the OPPONENT's perspective 27 | 1.) Opponent cannot move at a location 28 | i.) If it's occupied 29 | i.) If it's protected by ko 30 | 2.) Opponent can move at a location 31 | i.) If it can kill 32 | 3.) Opponent cannot move at a location 33 | i.) 
If it's adjacent to one of their groups with only one liberty and 34 | not adjacent to other groups with more than one liberty and is completely surrounded 35 | ii.) If it's surrounded by our pieces and all of those corresponding groups 36 | move more than one liberty 37 | """ 38 | 39 | # All pieces and empty spaces 40 | all_pieces = np.sum(state[[govars.BLACK, govars.WHITE]], axis=0) 41 | empties = 1 - all_pieces 42 | 43 | # Setup invalid and valid arrays 44 | possible_invalid_array = np.zeros(state.shape[1:]) 45 | definite_valids_array = np.zeros(state.shape[1:]) 46 | 47 | # Get all groups 48 | all_own_groups, num_own_groups = measurements.label(state[player]) 49 | all_opp_groups, num_opp_groups = measurements.label(state[1 - player]) 50 | expanded_own_groups = np.zeros((num_own_groups, *state.shape[1:])) 51 | expanded_opp_groups = np.zeros((num_opp_groups, *state.shape[1:])) 52 | 53 | # Expand the groups such that each group is in its own channel 54 | for i in range(num_own_groups): 55 | expanded_own_groups[i] = all_own_groups == (i + 1) 56 | 57 | for i in range(num_opp_groups): 58 | expanded_opp_groups[i] = all_opp_groups == (i + 1) 59 | 60 | # Get all liberties in the expanded form 61 | all_own_liberties = empties[np.newaxis] * ndimage.binary_dilation(expanded_own_groups, surround_struct[np.newaxis]) 62 | all_opp_liberties = empties[np.newaxis] * ndimage.binary_dilation(expanded_opp_groups, surround_struct[np.newaxis]) 63 | 64 | own_liberty_counts = np.sum(all_own_liberties, axis=(1, 2)) 65 | opp_liberty_counts = np.sum(all_opp_liberties, axis=(1, 2)) 66 | 67 | # Possible invalids are on single liberties of opponent groups and on multi-liberties of own groups 68 | # Definite valids are on single liberties of own groups, multi-liberties of opponent groups 69 | # or you are not surrounded 70 | possible_invalid_array += np.sum(all_own_liberties[own_liberty_counts > 1], axis=0) 71 | possible_invalid_array += np.sum(all_opp_liberties[opp_liberty_counts == 1], axis=0) 72 | 73 | definite_valids_array += np.sum(all_own_liberties[own_liberty_counts == 1], axis=0) 74 | definite_valids_array += np.sum(all_opp_liberties[opp_liberty_counts > 1], axis=0) 75 | 76 | # All invalid moves are occupied spaces + (possible invalids minus the definite valids and it's surrounded) 77 | surrounded = ndimage.convolve(all_pieces, surround_struct, mode='constant', cval=1) == 4 78 | invalid_moves = all_pieces + possible_invalid_array * (definite_valids_array == 0) * surrounded 79 | 80 | # Ko-protection 81 | if ko_protect is not None: 82 | invalid_moves[ko_protect[0], ko_protect[1]] = 1 83 | return invalid_moves > 0 84 | 85 | 86 | def batch_compute_invalid_moves(batch_state, batch_player, batch_ko_protect): 87 | """ 88 | Updates invalid moves in the OPPONENT's perspective 89 | 1.) Opponent cannot move at a location 90 | i.) If it's occupied 91 | i.) If it's protected by ko 92 | 2.) Opponent can move at a location 93 | i.) If it can kill 94 | 3.) Opponent cannot move at a location 95 | i.) If it's adjacent to one of their groups with only one liberty and 96 | not adjacent to other groups with more than one liberty and is completely surrounded 97 | ii.) 
If it's surrounded by our pieces and all of those corresponding groups 98 | move more than one liberty 99 | """ 100 | batch_idcs = np.arange(len(batch_state)) 101 | 102 | # All pieces and empty spaces 103 | batch_all_pieces = np.sum(batch_state[:, [govars.BLACK, govars.WHITE]], axis=1) 104 | batch_empties = 1 - batch_all_pieces 105 | 106 | # Setup invalid and valid arrays 107 | batch_possible_invalid_array = np.zeros(batch_state.shape[:1] + batch_state.shape[2:]) 108 | batch_definite_valids_array = np.zeros(batch_state.shape[:1] + batch_state.shape[2:]) 109 | 110 | # Get all groups 111 | batch_all_own_groups, _ = measurements.label(batch_state[batch_idcs, batch_player], group_struct) 112 | batch_all_opp_groups, _ = measurements.label(batch_state[batch_idcs, 1 - batch_player], group_struct) 113 | 114 | batch_data = enumerate(zip(batch_all_own_groups, batch_all_opp_groups, batch_empties)) 115 | for i, (all_own_groups, all_opp_groups, empties) in batch_data: 116 | own_labels = np.unique(all_own_groups) 117 | opp_labels = np.unique(all_opp_groups) 118 | own_labels = own_labels[np.nonzero(own_labels)] 119 | opp_labels = opp_labels[np.nonzero(opp_labels)] 120 | expanded_own_groups = np.zeros((len(own_labels), *all_own_groups.shape)) 121 | expanded_opp_groups = np.zeros((len(opp_labels), *all_opp_groups.shape)) 122 | 123 | # Expand the groups such that each group is in its own channel 124 | for j, label in enumerate(own_labels): 125 | expanded_own_groups[j] = all_own_groups == label 126 | 127 | for j, label in enumerate(opp_labels): 128 | expanded_opp_groups[j] = all_opp_groups == label 129 | 130 | # Get all liberties in the expanded form 131 | all_own_liberties = empties[np.newaxis] * ndimage.binary_dilation(expanded_own_groups, 132 | surround_struct[np.newaxis]) 133 | all_opp_liberties = empties[np.newaxis] * ndimage.binary_dilation(expanded_opp_groups, 134 | surround_struct[np.newaxis]) 135 | 136 | own_liberty_counts = np.sum(all_own_liberties, axis=(1, 2)) 137 | opp_liberty_counts = np.sum(all_opp_liberties, axis=(1, 2)) 138 | 139 | # Possible invalids are on single liberties of opponent groups and on multi-liberties of own groups 140 | # Definite valids are on single liberties of own groups, multi-liberties of opponent groups 141 | # or you are not surrounded 142 | batch_possible_invalid_array[i] += np.sum(all_own_liberties[own_liberty_counts > 1], axis=0) 143 | batch_possible_invalid_array[i] += np.sum(all_opp_liberties[opp_liberty_counts == 1], axis=0) 144 | 145 | batch_definite_valids_array[i] += np.sum(all_own_liberties[own_liberty_counts == 1], axis=0) 146 | batch_definite_valids_array[i] += np.sum(all_opp_liberties[opp_liberty_counts > 1], axis=0) 147 | 148 | # All invalid moves are occupied spaces + (possible invalids minus the definite valids and it's surrounded) 149 | surrounded = ndimage.convolve(batch_all_pieces, surround_struct[np.newaxis], mode='constant', cval=1) == 4 150 | invalid_moves = batch_all_pieces + batch_possible_invalid_array * (batch_definite_valids_array == 0) * surrounded 151 | 152 | # Ko-protection 153 | for i, ko_protect in enumerate(batch_ko_protect): 154 | if ko_protect is not None: 155 | invalid_moves[i, ko_protect[0], ko_protect[1]] = 1 156 | return invalid_moves > 0 157 | 158 | 159 | def update_pieces(state, adj_locs, player): 160 | opponent = 1 - player 161 | killed_groups = [] 162 | 163 | all_pieces = np.sum(state[[govars.BLACK, govars.WHITE]], axis=0) 164 | empties = 1 - all_pieces 165 | 166 | all_opp_groups, _ = 
ndimage.measurements.label(state[opponent]) 167 | 168 | # Go through opponent groups 169 | all_adj_labels = all_opp_groups[adj_locs[:, 0], adj_locs[:, 1]] 170 | all_adj_labels = np.unique(all_adj_labels) 171 | for opp_group_idx in all_adj_labels[np.nonzero(all_adj_labels)]: 172 | opp_group = all_opp_groups == opp_group_idx 173 | liberties = empties * ndimage.binary_dilation(opp_group) 174 | if np.sum(liberties) <= 0: 175 | # Killed group 176 | opp_group_locs = np.argwhere(opp_group) 177 | state[opponent, opp_group_locs[:, 0], opp_group_locs[:, 1]] = 0 178 | killed_groups.append(opp_group_locs) 179 | 180 | return killed_groups 181 | 182 | 183 | def batch_update_pieces(batch_non_pass, batch_state, batch_adj_locs, batch_player): 184 | batch_opponent = 1 - batch_player 185 | batch_killed_groups = [] 186 | 187 | batch_all_pieces = np.sum(batch_state[:, [govars.BLACK, govars.WHITE]], axis=1) 188 | batch_empties = 1 - batch_all_pieces 189 | 190 | batch_all_opp_groups, _ = ndimage.measurements.label(batch_state[batch_non_pass, batch_opponent], 191 | group_struct) 192 | 193 | batch_data = enumerate(zip(batch_all_opp_groups, batch_all_pieces, batch_empties, batch_adj_locs, batch_opponent)) 194 | for i, (all_opp_groups, all_pieces, empties, adj_locs, opponent) in batch_data: 195 | killed_groups = [] 196 | 197 | # Go through opponent groups 198 | all_adj_labels = all_opp_groups[adj_locs[:, 0], adj_locs[:, 1]] 199 | all_adj_labels = np.unique(all_adj_labels) 200 | for opp_group_idx in all_adj_labels[np.nonzero(all_adj_labels)]: 201 | opp_group = all_opp_groups == opp_group_idx 202 | liberties = empties * ndimage.binary_dilation(opp_group) 203 | if np.sum(liberties) <= 0: 204 | # Killed group 205 | opp_group_locs = np.argwhere(opp_group) 206 | batch_state[batch_non_pass[i], opponent, opp_group_locs[:, 0], opp_group_locs[:, 1]] = 0 207 | killed_groups.append(opp_group_locs) 208 | 209 | batch_killed_groups.append(killed_groups) 210 | 211 | return batch_killed_groups 212 | 213 | 214 | def adj_data(state, action2d, player): 215 | neighbors = neighbor_deltas + action2d 216 | valid = (neighbors >= 0) & (neighbors < state.shape[1]) 217 | valid = np.prod(valid, axis=1) 218 | neighbors = neighbors[np.nonzero(valid)] 219 | 220 | opp_pieces = state[1 - player] 221 | surrounded = (opp_pieces[neighbors[:, 0], neighbors[:, 1]] > 0).all() 222 | 223 | return neighbors, surrounded 224 | 225 | 226 | def batch_adj_data(batch_state, batch_action2d, batch_player): 227 | batch_neighbors, batch_surrounded = [], [] 228 | for state, action2d, player in zip(batch_state, batch_action2d, batch_player): 229 | neighbors, surrounded = adj_data(state, action2d, player) 230 | batch_neighbors.append(neighbors) 231 | batch_surrounded.append(surrounded) 232 | return batch_neighbors, batch_surrounded 233 | 234 | 235 | def set_turn(state): 236 | """ 237 | Swaps turn 238 | :param state: 239 | :return: 240 | """ 241 | state[govars.TURN_CHNL] = 1 - state[govars.TURN_CHNL] 242 | 243 | 244 | def batch_set_turn(batch_state): 245 | """ 246 | Swaps turn 247 | :param state: 248 | :return: 249 | """ 250 | batch_state[:, govars.TURN_CHNL] = 1 - batch_state[:, govars.TURN_CHNL] 251 | -------------------------------------------------------------------------------- /gym_go/gogame.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import ndimage 3 | from sklearn import preprocessing 4 | 5 | from gym_go import state_utils, govars 6 | 7 | """ 8 | The state of the game is a numpy array 9 
| * All values are either 0 or 1 10 | 11 | * Shape [NUM_CHNLS, SIZE, SIZE] 12 | 13 | 0 - Black pieces 14 | 1 - White pieces 15 | 2 - Turn (0 - black, 1 - white) 16 | 3 - Invalid moves (including ko-protection) 17 | 4 - Previous move was a pass 18 | 5 - Game over 19 | """ 20 | 21 | 22 | def init_state(size): 23 | # return initial board (numpy board) 24 | state = np.zeros((govars.NUM_CHNLS, size, size)) 25 | return state 26 | 27 | 28 | def batch_init_state(batch_size, board_size): 29 | # return initial board (numpy board) 30 | batch_state = np.zeros((batch_size, govars.NUM_CHNLS, board_size, board_size)) 31 | return batch_state 32 | 33 | 34 | def next_state(state, action1d, canonical=False): 35 | # Deep copy the state to modify 36 | state = np.copy(state) 37 | 38 | # Initialize basic variables 39 | board_shape = state.shape[1:] 40 | pass_idx = np.prod(board_shape) 41 | passed = action1d == pass_idx 42 | action2d = action1d // board_shape[0], action1d % board_shape[1] 43 | 44 | player = turn(state) 45 | previously_passed = prev_player_passed(state) 46 | ko_protect = None 47 | 48 | if passed: 49 | # We passed 50 | state[govars.PASS_CHNL] = 1 51 | if previously_passed: 52 | # Game ended 53 | state[govars.DONE_CHNL] = 1 54 | else: 55 | # Move was not pass 56 | state[govars.PASS_CHNL] = 0 57 | 58 | # Assert move is valid 59 | assert state[govars.INVD_CHNL, action2d[0], action2d[1]] == 0, ("Invalid move", action2d) 60 | 61 | # Add piece 62 | state[player, action2d[0], action2d[1]] = 1 63 | 64 | # Get adjacent location and check whether the piece will be surrounded by opponent's piece 65 | adj_locs, surrounded = state_utils.adj_data(state, action2d, player) 66 | 67 | # Update pieces 68 | killed_groups = state_utils.update_pieces(state, adj_locs, player) 69 | 70 | # If only killed one group, and that one group was one piece, and piece set is surrounded, 71 | # activate ko protection 72 | if len(killed_groups) == 1 and surrounded: 73 | killed_group = killed_groups[0] 74 | if len(killed_group) == 1: 75 | ko_protect = killed_group[0] 76 | 77 | # Update invalid moves 78 | state[govars.INVD_CHNL] = state_utils.compute_invalid_moves(state, player, ko_protect) 79 | 80 | # Switch turn 81 | state_utils.set_turn(state) 82 | 83 | if canonical: 84 | # Set canonical form 85 | state = canonical_form(state) 86 | 87 | return state 88 | 89 | 90 | def batch_next_states(batch_states, batch_action1d, canonical=False): 91 | # Deep copy the state to modify 92 | batch_states = np.copy(batch_states) 93 | 94 | # Initialize basic variables 95 | board_shape = batch_states.shape[2:] 96 | pass_idx = np.prod(board_shape) 97 | batch_pass = np.nonzero(batch_action1d == pass_idx) 98 | batch_non_pass = np.nonzero(batch_action1d != pass_idx)[0] 99 | batch_prev_passed = batch_prev_player_passed(batch_states) 100 | batch_game_ended = np.nonzero(batch_prev_passed & (batch_action1d == pass_idx)) 101 | batch_action2d = np.array([batch_action1d[batch_non_pass] // board_shape[0], 102 | batch_action1d[batch_non_pass] % board_shape[1]]).T 103 | 104 | batch_players = batch_turn(batch_states) 105 | batch_non_pass_players = batch_players[batch_non_pass] 106 | batch_ko_protect = np.empty(len(batch_states), dtype=object) 107 | 108 | # Pass moves 109 | batch_states[batch_pass, govars.PASS_CHNL] = 1 110 | # Game ended 111 | batch_states[batch_game_ended, govars.DONE_CHNL] = 1 112 | 113 | # Non-pass moves 114 | batch_states[batch_non_pass, govars.PASS_CHNL] = 0 115 | 116 | # Assert all non-pass moves are valid 117 | assert
(batch_states[batch_non_pass, govars.INVD_CHNL, batch_action2d[:, 0], batch_action2d[:, 1]] == 0).all() 118 | 119 | # Add piece 120 | batch_states[batch_non_pass, batch_non_pass_players, batch_action2d[:, 0], batch_action2d[:, 1]] = 1 121 | 122 | # Get adjacent location and check whether the piece will be surrounded by opponent's piece 123 | batch_adj_locs, batch_surrounded = state_utils.batch_adj_data(batch_states[batch_non_pass], batch_action2d, 124 | batch_non_pass_players) 125 | 126 | # Update pieces 127 | batch_killed_groups = state_utils.batch_update_pieces(batch_non_pass, batch_states, batch_adj_locs, 128 | batch_non_pass_players) 129 | 130 | # Ko-protection 131 | for i, (killed_groups, surrounded) in enumerate(zip(batch_killed_groups, batch_surrounded)): 132 | # If only killed one group, and that one group was one piece, and piece set is surrounded, 133 | # activate ko protection 134 | if len(killed_groups) == 1 and surrounded: 135 | killed_group = killed_groups[0] 136 | if len(killed_group) == 1: 137 | batch_ko_protect[batch_non_pass[i]] = killed_group[0] 138 | 139 | # Update invalid moves 140 | batch_states[:, govars.INVD_CHNL] = state_utils.batch_compute_invalid_moves(batch_states, batch_players, 141 | batch_ko_protect) 142 | 143 | # Switch turn 144 | state_utils.batch_set_turn(batch_states) 145 | 146 | if canonical: 147 | # Set canonical form 148 | batch_states = batch_canonical_form(batch_states) 149 | 150 | return batch_states 151 | 152 | 153 | def invalid_moves(state): 154 | # return a fixed size binary vector 155 | if game_ended(state): 156 | return np.zeros(action_size(state)) 157 | return np.append(state[govars.INVD_CHNL].flatten(), 0) 158 | 159 | 160 | def valid_moves(state): 161 | return 1 - invalid_moves(state) 162 | 163 | 164 | def batch_invalid_moves(batch_state): 165 | n = len(batch_state) 166 | batch_invalid_moves_bool = batch_state[:, govars.INVD_CHNL].reshape(n, -1) 167 | batch_invalid_moves_bool = np.append(batch_invalid_moves_bool, np.zeros((n, 1)), axis=1) 168 | return batch_invalid_moves_bool 169 | 170 | 171 | def batch_valid_moves(batch_state): 172 | return 1 - batch_invalid_moves(batch_state) 173 | 174 | 175 | def children(state, canonical=False, padded=True): 176 | valid_moves_bool = valid_moves(state) 177 | n = len(valid_moves_bool) 178 | valid_move_idcs = np.argwhere(valid_moves_bool).flatten() 179 | batch_states = np.tile(state[np.newaxis], (len(valid_move_idcs), 1, 1, 1)) 180 | children = batch_next_states(batch_states, valid_move_idcs, canonical) 181 | 182 | if padded: 183 | padded_children = np.zeros((n, *state.shape)) 184 | padded_children[valid_move_idcs] = children 185 | children = padded_children 186 | return children 187 | 188 | 189 | def action_size(state=None, board_size: int = None): 190 | # return number of actions 191 | if state is not None: 192 | m, n = state.shape[1:] 193 | elif board_size is not None: 194 | m, n = board_size, board_size 195 | else: 196 | raise RuntimeError('No argument passed') 197 | return m * n + 1 198 | 199 | 200 | def prev_player_passed(state): 201 | return np.max(state[govars.PASS_CHNL] == 1) == 1 202 | 203 | 204 | def batch_prev_player_passed(batch_state): 205 | return np.max(batch_state[:, govars.PASS_CHNL], axis=(1, 2)) == 1 206 | 207 | 208 | def game_ended(state): 209 | """ 210 | :param state: 211 | :return: 0/1 = game not ended / game ended respectively 212 | """ 213 | m, n = state.shape[1:] 214 | return int(np.count_nonzero(state[govars.DONE_CHNL] == 1) == m * n) 215 | 216 | 217 | def 
batch_game_ended(batch_state): 218 | """ 219 | :param batch_state: 220 | :return: 0/1 = game not ended / game ended respectively 221 | """ 222 | return np.max(batch_state[:, govars.DONE_CHNL], axis=(1, 2)) 223 | 224 | 225 | def winning(state, komi=0): 226 | black_area, white_area = areas(state) 227 | area_difference = black_area - white_area 228 | komi_correction = area_difference - komi 229 | 230 | return np.sign(komi_correction) 231 | 232 | 233 | def batch_winning(state, komi=0): 234 | batch_black_area, batch_white_area = batch_areas(state) 235 | batch_area_difference = batch_black_area - batch_white_area 236 | batch_komi_correction = batch_area_difference - komi 237 | 238 | return np.sign(batch_komi_correction) 239 | 240 | 241 | def turn(state): 242 | """ 243 | :param state: 244 | :return: Who's turn it is (govars.BLACK/govars.WHITE) 245 | """ 246 | return int(np.max(state[govars.TURN_CHNL])) 247 | 248 | 249 | def batch_turn(batch_state): 250 | return np.max(batch_state[:, govars.TURN_CHNL], axis=(1, 2)).astype(np.int) 251 | 252 | 253 | def liberties(state: np.ndarray): 254 | blacks = state[govars.BLACK] 255 | whites = state[govars.WHITE] 256 | all_pieces = np.sum(state[[govars.BLACK, govars.WHITE]], axis=0) 257 | 258 | liberty_list = [] 259 | for player_pieces in [blacks, whites]: 260 | liberties = ndimage.binary_dilation(player_pieces, state_utils.surround_struct) 261 | liberties *= (1 - all_pieces).astype(np.bool) 262 | liberty_list.append(liberties) 263 | 264 | return liberty_list[0], liberty_list[1] 265 | 266 | 267 | def num_liberties(state: np.ndarray): 268 | black_liberties, white_liberties = liberties(state) 269 | black_liberties = np.count_nonzero(black_liberties) 270 | white_liberties = np.count_nonzero(white_liberties) 271 | 272 | return black_liberties, white_liberties 273 | 274 | 275 | def areas(state): 276 | ''' 277 | Return black area, white area 278 | ''' 279 | 280 | all_pieces = np.sum(state[[govars.BLACK, govars.WHITE]], axis=0) 281 | empties = 1 - all_pieces 282 | 283 | empty_labels, num_empty_areas = ndimage.measurements.label(empties) 284 | 285 | black_area, white_area = np.sum(state[govars.BLACK]), np.sum(state[govars.WHITE]) 286 | for label in range(1, num_empty_areas + 1): 287 | empty_area = empty_labels == label 288 | neighbors = ndimage.binary_dilation(empty_area) 289 | black_claim = False 290 | white_claim = False 291 | if (state[govars.BLACK] * neighbors > 0).any(): 292 | black_claim = True 293 | if (state[govars.WHITE] * neighbors > 0).any(): 294 | white_claim = True 295 | if black_claim and not white_claim: 296 | black_area += np.sum(empty_area) 297 | elif white_claim and not black_claim: 298 | white_area += np.sum(empty_area) 299 | 300 | return black_area, white_area 301 | 302 | 303 | def batch_areas(batch_state): 304 | black_areas, white_areas = [], [] 305 | 306 | for state in batch_state: 307 | ba, wa = areas(state) 308 | black_areas.append(ba) 309 | white_areas.append(wa) 310 | return np.array(black_areas), np.array(white_areas) 311 | 312 | 313 | def canonical_form(state): 314 | state = np.copy(state) 315 | if turn(state) == govars.WHITE: 316 | channels = np.arange(govars.NUM_CHNLS) 317 | channels[govars.BLACK] = govars.WHITE 318 | channels[govars.WHITE] = govars.BLACK 319 | state = state[channels] 320 | state_utils.set_turn(state) 321 | return state 322 | 323 | 324 | def batch_canonical_form(batch_state): 325 | batch_state = np.copy(batch_state) 326 | batch_player = batch_turn(batch_state) 327 | white_players_idcs = np.nonzero(batch_player == 
govars.WHITE)[0] 328 | 329 | channels = np.arange(govars.NUM_CHNLS) 330 | channels[govars.BLACK] = govars.WHITE 331 | channels[govars.WHITE] = govars.BLACK 332 | 333 | for i in white_players_idcs: 334 | batch_state[i] = batch_state[i, channels] 335 | batch_state[i, govars.TURN_CHNL] = 1 - batch_player[i] 336 | 337 | return batch_state 338 | 339 | 340 | def random_symmetry(image): 341 | """ 342 | Returns a random symmetry of the image 343 | :param image: A (C, BOARD_SIZE, BOARD_SIZE) numpy array, where C is any number 344 | :return: 345 | """ 346 | orientation = np.random.randint(0, 8) 347 | 348 | if (orientation >> 0) % 2: 349 | # Horizontal flip 350 | image = np.flip(image, 2) 351 | if (orientation >> 1) % 2: 352 | # Vertical flip 353 | image = np.flip(image, 1) 354 | if (orientation >> 2) % 2: 355 | # Rotate 90 degrees 356 | image = np.rot90(image, axes=(1, 2)) 357 | 358 | return image 359 | 360 | 361 | def all_symmetries(image): 362 | """ 363 | :param image: A (C, BOARD_SIZE, BOARD_SIZE) numpy array, where C is any number 364 | :return: All 8 orientations that are symmetrical in a Go game over the 2nd and 3rd axes 365 | (i.e. rotations, flipping and combos of them) 366 | """ 367 | symmetries = [] 368 | 369 | for i in range(8): 370 | x = image 371 | if (i >> 0) % 2: 372 | # Horizontal flip 373 | x = np.flip(x, 2) 374 | if (i >> 1) % 2: 375 | # Vertical flip 376 | x = np.flip(x, 1) 377 | if (i >> 2) % 2: 378 | # Rotation 90 degrees 379 | x = np.rot90(x, axes=(1, 2)) 380 | symmetries.append(x) 381 | 382 | return symmetries 383 | 384 | 385 | def random_weighted_action(move_weights): 386 | """ 387 | Assumes all invalid moves have weight 0 388 | Action is 1D 389 | Expected shape is (NUM OF MOVES, ) 390 | """ 391 | move_weights = preprocessing.normalize(move_weights[np.newaxis], norm='l1') 392 | return np.random.choice(np.arange(len(move_weights[0])), p=move_weights[0]) 393 | 394 | 395 | def random_action(state): 396 | """ 397 | Assumed to be (NUM_CHNLS, BOARD_SIZE, BOARD_SIZE) 398 | Action is 1D 399 | """ 400 | invalid_moves = state[govars.INVD_CHNL].flatten() 401 | invalid_moves = np.append(invalid_moves, 0) 402 | move_weights = 1 - invalid_moves 403 | 404 | return random_weighted_action(move_weights) 405 | 406 | 407 | def str(state): 408 | board_str = '' 409 | 410 | size = state.shape[1] 411 | board_str += '\t' 412 | for i in range(size): 413 | board_str += '{}'.format(i).ljust(2, ' ') 414 | board_str += '\n' 415 | for i in range(size): 416 | board_str += '{}\t'.format(i) 417 | for j in range(size): 418 | if state[0, i, j] == 1: 419 | board_str += '○' 420 | if j != size - 1: 421 | if i == 0 or i == size - 1: 422 | board_str += '═' 423 | else: 424 | board_str += '─' 425 | elif state[1, i, j] == 1: 426 | board_str += '●' 427 | if j != size - 1: 428 | if i == 0 or i == size - 1: 429 | board_str += '═' 430 | else: 431 | board_str += '─' 432 | else: 433 | if i == 0: 434 | if j == 0: 435 | board_str += '╔═' 436 | elif j == size - 1: 437 | board_str += '╗' 438 | else: 439 | board_str += '╤═' 440 | elif i == size - 1: 441 | if j == 0: 442 | board_str += '╚═' 443 | elif j == size - 1: 444 | board_str += '╝' 445 | else: 446 | board_str += '╧═' 447 | else: 448 | if j == 0: 449 | board_str += '╟─' 450 | elif j == size - 1: 451 | board_str += '╢' 452 | else: 453 | board_str += '┼─' 454 | board_str += '\n' 455 | 456 | black_area, white_area = areas(state) 457 | done = game_ended(state) 458 | ppp = prev_player_passed(state) 459 | t = turn(state) 460 | if done: 461 | game_state = 'END' 462 | elif ppp: 463 
| game_state = 'PASSED' 464 | else: 465 | game_state = 'ONGOING' 466 | board_str += '\tTurn: {}, Game State (ONGOING|PASSED|END): {}\n'.format('BLACK' if t == 0 else 'WHITE', game_state) 467 | board_str += '\tBlack Area: {}, White Area: {}\n'.format(int(black_area), int(white_area)) 468 | return board_str 469 | --------------------------------------------------------------------------------
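Usage note: the helpers in gogame.py operate on plain numpy arrays, so they can be driven directly, without the GoEnv gym wrapper. The sketch below is illustrative only; it assumes the package is importable as gym_go (e.g. after a local `pip install -e .`) with its module-level dependencies (numpy, scipy, scikit-learn) available, and the 5x5 board size and scripted center opening are arbitrary choices, not part of the library.

from gym_go import gogame, govars

size = 5
state = gogame.init_state(size)              # (govars.NUM_CHNLS, 5, 5) array of zeros

# 1D actions are row-major board indices; index size * size means "pass".
center = (size // 2) * size + (size // 2)
state = gogame.next_state(state, center)     # Black (channel 0) plays the center point
assert gogame.turn(state) == govars.WHITE    # next_state swapped the turn channel

# White replies with a uniformly random legal move (pass included).
state = gogame.next_state(state, gogame.random_action(state))

print(gogame.str(state))                     # text rendering of the board plus turn/area summary
print(gogame.valid_moves(state).sum())       # number of legal moves (incl. pass) now open to Black
print(gogame.areas(state))                   # (black_area, white_area)

Because every helper takes and returns ordinary numpy arrays, the same pattern extends to the batch variants shown above (batch_init_state, batch_next_states, batch_valid_moves, batch_canonical_form), which simply add a leading batch dimension.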