├── .gitignore ├── GymGo ├── .gitignore ├── README.md ├── demo.py ├── gym_go │ ├── __init__.py │ ├── envs │ │ ├── __init__.py │ │ ├── go_env.py │ │ └── go_extrahard_env.py │ ├── gogame.py │ ├── govars.py │ ├── rendering.py │ ├── state_utils.py │ └── tests │ │ ├── efficiency.py │ │ ├── test_basics.py │ │ ├── test_batch_fns.py │ │ ├── test_invalid_moves.py │ │ └── test_valid_moves.py ├── screenshots │ └── human_ui.png └── setup.py ├── LICENSE ├── README.md ├── assets ├── audios │ ├── Button.wav │ └── Stone.wav ├── fonts │ ├── msyh.ttc │ ├── msyhbd.ttc │ └── msyhl.ttc └── pictures │ ├── B-13.png │ ├── B-19.png │ ├── B-9.png │ ├── B.png │ ├── W-13.png │ ├── W-19.png │ ├── W-9.png │ └── W.png ├── docs ├── 围棋基本知识.md ├── 围棋程序逻辑.md ├── 机巧围棋(CleverGo)开发计划文档.md ├── 机巧围棋(CleverGo)技术原理文档.md ├── 机巧围棋(CleverGo)项目总览及介绍.md ├── 深度学习框架(PaddlePaddle)使用教程.md ├── 深度强化学习基础.md ├── 游戏开发引擎(Pygame)核心方法.md ├── 蒙特卡洛树搜索(MCTS).md ├── 训练策略网络和价值网络.md └── 阿尔法狗与机巧围棋的网络结构.md ├── game_engine.py ├── go_engine.py ├── mcts.py ├── models └── alpha_go.pdparams ├── pgutils ├── manager.py ├── pgcontrols │ ├── __init__.py │ ├── button.py │ └── ctbase.py ├── pgtools │ ├── __init__.py │ ├── information_display.py │ └── toolbase.py ├── position.py └── text.py ├── pictures ├── 1_1.png ├── 2_1.png ├── 2_10.png ├── 2_2.png ├── 2_3.png ├── 2_4.png ├── 2_5.png ├── 2_6.png ├── 2_7.png ├── 2_8.gif ├── 2_9.png ├── 3_1.png ├── 4_1.png ├── 4_2.png ├── 4_3.png ├── 6_1.png ├── 6_2.png ├── 6_3.png ├── 7_1.png ├── 7_2.png ├── 7_3.png ├── 7_4.png ├── 8_1.png ├── 8_2.png ├── 8_3.png ├── 9_1.png ├── 启动界面.png ├── 对弈.png ├── 训练初始界面.png └── 训练过程.png ├── play_game.py ├── player.py ├── policy_value_net.py ├── requirements.txt ├── test.py └── trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .idea/ 3 | 4 | # music files 5 | assets/musics/* -------------------------------------------------------------------------------- /GymGo/.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | *.egg-info/ 3 | __pycache__/ 4 | .ipynb_checkpoints/ 5 | .vscode/ 6 | -------------------------------------------------------------------------------- /GymGo/README.md: -------------------------------------------------------------------------------- 1 | # About 2 | An environment for the board game Go. It is implemented using OpenAI's Gym API. 3 | It is also optimized to be as efficient as possible in order to efficiently train ML models. 4 | 5 | # Installation 6 | ```bash 7 | # In the root directory 8 | pip install -e . 9 | ``` 10 | 11 | # API 12 | ### Basic example 13 | ```bash 14 | # In the root directory 15 | python3 demo.py 16 | ``` 17 | ![alt text](screenshots/human_ui.png) 18 | 19 | ### Coding example 20 | ```python 21 | import gym 22 | 23 | go_env = gym.make('gym_go:go-v0', size=7, komi=0, reward_method='real') 24 | 25 | first_action = (2,5) 26 | second_action = (5,2) 27 | state, reward, done, info = go_env.step(first_action) 28 | go_env.render('terminal') 29 | ``` 30 | 31 | ``` 32 | 0 1 2 3 4 5 6 33 | ----------------------------- 34 | 0 | | | | | | | | 35 | ----------------------------- 36 | 1 | | | | | | | | 37 | ----------------------------- 38 | 2 | | | | | | B | | 39 | ----------------------------- 40 | 3 | | | | | | | | 41 | ----------------------------- 42 | 4 | | | | | | | | 43 | ----------------------------- 44 | 5 | | | | | | | | 45 | ----------------------------- 46 | 6 | | | | | | | | 47 | ----------------------------- 48 | Turn: WHITE, Last Turn Passed: False, Game Over: False 49 | Black Area: 49, White Area: 0, Reward: 0 50 | ``` 51 | 52 | ```python 53 | state, reward, done, info = go_env.step(second_action) 54 | go_env.render('terminal') 55 | ``` 56 | 57 | ``` 58 | 0 1 2 3 4 5 6 59 | ----------------------------- 60 | 0 | | | | | | | | 61 | ----------------------------- 62 | 1 | | | | | | | | 63 | ----------------------------- 64 | 2 | | | | | | B | | 65 | ----------------------------- 66 | 3 | | | | | | | | 67 | ----------------------------- 68 | 4 | | | | | | | | 69 | ----------------------------- 70 | 5 | | | W | | | | | 71 | ----------------------------- 72 | 6 | | | | | | | | 73 | ----------------------------- 74 | Turn: BLACK, Last Turn Passed: False, Game Over: False 75 | Black Area: 1, White Area: 1, Reward: 0 76 | ``` 77 | 78 | ### High level API 79 | [GoEnv](gym_go/envs/go_env.py) defines the Gym environment for Go. 80 | It contains the highest level API for basic Go usage. 81 | 82 | ### Low level API 83 | [GoGame](gym_go/gogame.py) is the set of low-level functions that defines all the game logic of Go. 84 | `GoEnv`'s high level API is built on `GoGame`. 85 | These sets of functions are intended for a more detailed and finetuned 86 | usage of Go. 87 | 88 | # Scoring 89 | We use Trump Taylor scoring, a simple area scoring, to determine the winner. A player's _area_ is defined as the number of empty points a 90 | player's pieces surround plus the number of player's pieces on the board. The _winner_ is the player with the larger 91 | area (a game is tied if both players have an equal amount of area on the board). 92 | 93 | There is also support for `komi`, a bias score constant to balance the advantage of black going first. 94 | By default `komi` is set to 0. 95 | 96 | # Game ending 97 | A game ends when both players pass consecutively 98 | 99 | # Reward methods 100 | Reward methods are in _black_'s perspective 101 | * **Real**: 102 | * If game ended: 103 | * `-1` - White won 104 | * `0` - Game is tied 105 | * `1` - Black won 106 | * `0` - Otherwise 107 | * **Heuristic**: If the game is ongoing, the reward is `black area - white area`. 108 | If black won, the reward is `BOARD_SIZE**2`. 109 | If white won, the reward is `-BOARD_SIZE**2`. 110 | If tied, the reward is `0`. 111 | 112 | # State 113 | The `state` object that is returned by the `reset` and `step` functions of the environment is a 114 | `6 x BOARD_SIZE x BOARD_SIZE` numpy array. All values in the array are either `0` or `1` 115 | * **First and second channel:** represent the black and white pieces respectively. 116 | * **Third channel:** Indicator layer for whose turn it is 117 | * **Fourth channel:** Invalid moves (including ko-protection) for the next action 118 | * **Fifth channel:** Indicator layer for whether the previous move was a pass 119 | * **Sixth channel:** Indicator layer for whether the game is over 120 | 121 | # Action 122 | The `step` function takes in the action to execute and can be in the following forms: 123 | * a tuple/list of 2 integers representing the row and column or `None` for passing 124 | * a single integer representing the action in 1d space (i.e 9 would be (1,2) and 49 would be a pass for a 7x7 board) 125 | -------------------------------------------------------------------------------- /GymGo/demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import gym 4 | 5 | # Arguments 6 | parser = argparse.ArgumentParser(description='Demo Go Environment') 7 | parser.add_argument('--boardsize', type=int, default=7) 8 | parser.add_argument('--komi', type=float, default=0) 9 | args = parser.parse_args() 10 | 11 | # Initialize environment 12 | go_env = gym.make('gym_go:go-v0', size=3, komi=args.komi) 13 | 14 | # Game loop 15 | done = False 16 | while not done: 17 | action = go_env.render(mode="human") 18 | state, reward, done, info = go_env.step(action) 19 | 20 | if go_env.game_ended(): 21 | break 22 | action = go_env.uniform_random_action() 23 | state, reward, done, info = go_env.step(action) 24 | go_env.render(mode="human") 25 | -------------------------------------------------------------------------------- /GymGo/gym_go/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | 3 | register( 4 | id='go-v0', 5 | entry_point='gym_go.envs:GoEnv', 6 | ) 7 | register( 8 | id='go-extrahard-v0', 9 | entry_point='gym_go.envs:GoExtraHardEnv', 10 | ) 11 | -------------------------------------------------------------------------------- /GymGo/gym_go/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from gym_go.envs.go_env import GoEnv 2 | from gym_go.envs.go_extrahard_env import GoExtraHardEnv 3 | -------------------------------------------------------------------------------- /GymGo/gym_go/envs/go_env.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import gym 4 | import numpy as np 5 | 6 | from GymGo.gym_go import govars, rendering, gogame 7 | 8 | 9 | class RewardMethod(Enum): 10 | """ 11 | REAL: 0 = game is ongoing, 1 = black won, -1 = game tied or white won 12 | HEURISTIC: If game is ongoing, the reward is the area difference between black and white. 13 | Otherwise the game has ended, and if black has more area, the reward is BOARD_SIZE**2, otherwise it's -BOARD_SIZE**2 14 | """ 15 | REAL = 'real' 16 | HEURISTIC = 'heuristic' 17 | 18 | 19 | class GoEnv(gym.Env): 20 | metadata = {'render.modes': ['terminal', 'human']} 21 | govars = govars 22 | gogame = gogame 23 | 24 | def __init__(self, size, komi=0, reward_method='real'): 25 | ''' 26 | @param reward_method: either 'heuristic' or 'real' 27 | heuristic: gives # black pieces - # white pieces. 28 | real: gives 0 for in-game move, 1 for winning, -1 for losing, 29 | 0 for draw, all from black player's perspective 30 | ''' 31 | self.size = size 32 | self.komi = komi 33 | self.state_ = gogame.init_state(size) 34 | self.reward_method = RewardMethod(reward_method) 35 | self.observation_space = gym.spaces.Box(np.float32(0), np.float32(govars.NUM_CHNLS), 36 | shape=(govars.NUM_CHNLS, size, size)) 37 | self.action_space = gym.spaces.Discrete(gogame.action_size(self.state_)) 38 | self.done = False 39 | 40 | def reset(self): 41 | ''' 42 | Reset state, go_board, curr_player, prev_player_passed, 43 | done, return state 44 | ''' 45 | self.state_ = gogame.init_state(self.size) 46 | self.done = False 47 | return np.copy(self.state_) 48 | 49 | def step(self, action): 50 | ''' 51 | Assumes the correct player is making a move. Black goes first. 52 | return observation, reward, done, info 53 | ''' 54 | assert not self.done 55 | if isinstance(action, tuple) or isinstance(action, list) or isinstance(action, np.ndarray): 56 | assert 0 <= action[0] < self.size 57 | assert 0 <= action[1] < self.size 58 | action = self.size * action[0] + action[1] 59 | elif action is None: 60 | action = self.size ** 2 61 | 62 | self.state_ = gogame.next_state(self.state_, action, canonical=False) 63 | self.done = gogame.game_ended(self.state_) 64 | return np.copy(self.state_), self.reward(), self.done, self.info() 65 | 66 | def game_ended(self): 67 | return self.done 68 | 69 | def turn(self): 70 | return gogame.turn(self.state_) 71 | 72 | def prev_player_passed(self): 73 | return gogame.prev_player_passed(self.state_) 74 | 75 | def valid_moves(self): 76 | return gogame.valid_moves(self.state_) 77 | 78 | def uniform_random_action(self): 79 | valid_moves = self.valid_moves() 80 | valid_move_idcs = np.argwhere(valid_moves).flatten() 81 | return np.random.choice(valid_move_idcs) 82 | 83 | def info(self): 84 | """ 85 | :return: Debugging info for the state 86 | """ 87 | return { 88 | 'turn': gogame.turn(self.state_), 89 | 'invalid_moves': gogame.invalid_moves(self.state_), 90 | 'prev_player_passed': gogame.prev_player_passed(self.state_), 91 | } 92 | 93 | def state(self): 94 | """ 95 | :return: copy of state 96 | """ 97 | return np.copy(self.state_) 98 | 99 | def canonical_state(self): 100 | """ 101 | :return: canonical shallow copy of state 102 | """ 103 | return gogame.canonical_form(self.state_) 104 | 105 | def children(self, canonical=False, padded=True): 106 | """ 107 | :return: Same as get_children, but in canonical form 108 | """ 109 | return gogame.children(self.state_, canonical, padded) 110 | 111 | def winning(self): 112 | """ 113 | :return: Who's currently winning in BLACK's perspective, regardless if the game is over 114 | """ 115 | return gogame.winning(self.state_, self.komi) 116 | 117 | def winner(self): 118 | """ 119 | Get's the winner in BLACK's perspective 120 | :return: 121 | """ 122 | 123 | if self.game_ended(): 124 | return self.winning() 125 | else: 126 | return 0 127 | 128 | def reward(self): 129 | ''' 130 | Return reward based on reward_method. 131 | heuristic: black total area - white total area 132 | real: 0 for in-game move, 1 for winning, 0 for losing, 133 | 0.5 for draw, from black player's perspective. 134 | Winning and losing based on the Area rule 135 | Also known as Trump Taylor Scoring 136 | Area rule definition: https://en.wikipedia.org/wiki/Rules_of_Go#End 137 | ''' 138 | if self.reward_method == RewardMethod.REAL: 139 | return self.winner() 140 | 141 | elif self.reward_method == RewardMethod.HEURISTIC: 142 | black_area, white_area = gogame.areas(self.state_) 143 | area_difference = black_area - white_area 144 | komi_correction = area_difference - self.komi 145 | if self.game_ended(): 146 | return (1 if komi_correction > 0 else -1) * self.size ** 2 147 | return komi_correction 148 | else: 149 | raise Exception("Unknown Reward Method") 150 | 151 | def __str__(self): 152 | return gogame.str(self.state_) 153 | 154 | def close(self): 155 | if hasattr(self, 'window'): 156 | assert hasattr(self, 'pyglet') 157 | self.window.close() 158 | self.pyglet.app.exit() 159 | 160 | def render(self, mode='terminal'): 161 | if mode == 'terminal': 162 | print(self.__str__()) 163 | elif mode == 'human': 164 | import pyglet 165 | from pyglet.window import mouse 166 | from pyglet.window import key 167 | 168 | screen = pyglet.canvas.get_display().get_default_screen() 169 | window_width = int(min(screen.width, screen.height) * 2 / 3) 170 | window_height = int(window_width * 1.2) 171 | window = pyglet.window.Window(window_width, window_height) 172 | 173 | self.window = window 174 | self.pyglet = pyglet 175 | self.user_action = None 176 | 177 | # Set Cursor 178 | cursor = window.get_system_mouse_cursor(window.CURSOR_CROSSHAIR) 179 | window.set_mouse_cursor(cursor) 180 | 181 | # Outlines 182 | lower_grid_coord = window_width * 0.075 183 | board_size = window_width * 0.85 184 | upper_grid_coord = board_size + lower_grid_coord 185 | delta = board_size / (self.size - 1) 186 | piece_r = delta / 3.3 # radius 187 | 188 | @window.event 189 | def on_draw(): 190 | pyglet.gl.glClearColor(0.7, 0.5, 0.3, 1) 191 | window.clear() 192 | 193 | pyglet.gl.glLineWidth(3) 194 | batch = pyglet.graphics.Batch() 195 | 196 | # draw the grid and labels 197 | rendering.draw_grid(batch, delta, self.size, lower_grid_coord, upper_grid_coord) 198 | 199 | # info on top of the board 200 | rendering.draw_info(batch, window_width, window_height, upper_grid_coord, self.state_) 201 | 202 | # Inform user what they can do 203 | rendering.draw_command_labels(batch, window_width, window_height) 204 | 205 | rendering.draw_title(batch, window_width, window_height) 206 | 207 | batch.draw() 208 | 209 | # draw the pieces 210 | rendering.draw_pieces(batch, lower_grid_coord, delta, piece_r, self.size, self.state_) 211 | 212 | @window.event 213 | def on_mouse_press(x, y, button, modifiers): 214 | if button == mouse.LEFT: 215 | grid_x = (x - lower_grid_coord) 216 | grid_y = (y - lower_grid_coord) 217 | x_coord = round(grid_x / delta) 218 | y_coord = round(grid_y / delta) 219 | try: 220 | self.window.close() 221 | pyglet.app.exit() 222 | self.user_action = (x_coord, y_coord) 223 | except: 224 | pass 225 | 226 | @window.event 227 | def on_key_press(symbol, modifiers): 228 | if symbol == key.P: 229 | self.window.close() 230 | pyglet.app.exit() 231 | self.user_action = None 232 | elif symbol == key.R: 233 | self.reset() 234 | self.window.close() 235 | pyglet.app.exit() 236 | elif symbol == key.E: 237 | self.window.close() 238 | pyglet.app.exit() 239 | self.user_action = -1 240 | 241 | pyglet.app.run() 242 | 243 | return self.user_action 244 | -------------------------------------------------------------------------------- /GymGo/gym_go/envs/go_extrahard_env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | 4 | class GoExtraHardEnv(gym.Env): 5 | metadata = {'render.modes': ['human', 'terminal']} 6 | -------------------------------------------------------------------------------- /GymGo/gym_go/gogame.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import ndimage 3 | from sklearn import preprocessing 4 | 5 | from GymGo.gym_go import state_utils, govars 6 | 7 | """ 8 | The state of the game is a numpy array 9 | * Are values are either 0 or 1 10 | 11 | * Shape [NUM_CHNLS, SIZE, SIZE] 12 | 13 | 0 - Black pieces 14 | CHANNEL[0]: 黑棋棋子分布。有黑棋棋子位置为1,否则为0; 15 | 1 - White pieces 16 | CHANNEL[1]: 白棋棋子分布。有白棋棋子位置为1,否则为0; 17 | 2 - Turn (0 - black, 1 - white) 18 | CHANNEL[2]: 下一步落子方,一个全0或全1的矩阵。0:黑方,1:白方; 19 | 3 - Invalid moves (including ko-protection) 20 | CHANNEL[3]: 下一步的落子无效位置。无效位置为1,其余为0; 21 | 4 - Previous move was a pass 22 | CHANNEL[4]: 上一步是否为PASS,一个全0或全1的矩阵。0:不是PASS,1:是PASS; 23 | 5 - Game over 24 | CHANNEL[5]: 上一步落子后,游戏是否结束,一个全0或全1的矩阵。0:未结束,1:已结束。 25 | """ 26 | 27 | 28 | def init_state(size): 29 | # return initial board (numpy board) 30 | state = np.zeros((govars.NUM_CHNLS, size, size)) 31 | return state 32 | 33 | 34 | def batch_init_state(batch_size, board_size): 35 | # return initial board (numpy board) 36 | batch_state = np.zeros((batch_size, govars.NUM_CHNLS, board_size, board_size)) 37 | return batch_state 38 | 39 | 40 | def next_state(state, action1d, canonical=False): 41 | # Deep copy the state to modify 42 | state = np.copy(state) 43 | 44 | # Initialize basic variables 45 | board_shape = state.shape[1:] # state.shape为(通道数, 棋盘高度, 棋盘宽度) 46 | pass_idx = np.prod(board_shape) # np.prod()将参数内所有元素连乘,pass_idx:"pass"对应的id 47 | passed = action1d == pass_idx # 如果action id等于pass_idx,则passed为True 48 | action2d = action1d // board_shape[0], action1d % board_shape[1] # 将action1d转换成action2d 49 | 50 | player = turn(state) # 获取下一步落子方 51 | previously_passed = prev_player_passed(state) # 获取上一步是否为pass 52 | ko_protect = None 53 | 54 | if passed: 55 | # We passed 56 | # 如果下一步为pass,则将next_state中PASS_CHNL矩阵置为全1矩阵 57 | state[govars.PASS_CHNL] = 1 58 | if previously_passed: 59 | # Game ended 60 | # 如果上一步也为pass,则游戏结束【双方连续各pass,则游戏结束】 61 | # 将next_state中DONE_CHNL矩阵置为全1矩阵 62 | state[govars.DONE_CHNL] = 1 63 | else: 64 | # Move was not pass 65 | state[govars.PASS_CHNL] = 0 66 | 67 | # Assert move is valid 检查落子是否有效【state中INVD_CHNL对应位置为0】 68 | assert state[govars.INVD_CHNL, action2d[0], action2d[1]] == 0, ("Invalid move", action2d) 69 | 70 | # Add piece 71 | state[player, action2d[0], action2d[1]] = 1 72 | 73 | # Get adjacent location and check whether the piece will be surrounded by opponent's piece 74 | # 获取下一步落子位置的相邻位置(仅在棋盘内)、下一步落子位置是否被下一步落子方对手的棋子包围 75 | adj_locs, surrounded = state_utils.adj_data(state, action2d, player) 76 | 77 | # Update pieces 78 | # 更新棋盘黑白棋子分布矩阵,并返回各组被杀死的棋子列表 79 | killed_groups = state_utils.update_pieces(state, adj_locs, player) 80 | 81 | # If only killed one group, and that one group was one piece, and piece set is surrounded, 82 | # activate ko protection 83 | if len(killed_groups) == 1 and surrounded: 84 | killed_group = killed_groups[0] 85 | if len(killed_group) == 1: 86 | ko_protect = killed_group[0] 87 | 88 | # Update invalid moves 89 | state[govars.INVD_CHNL] = state_utils.compute_invalid_moves(state, player, ko_protect) 90 | 91 | # Switch turn 92 | # 设置下一步落子方 93 | state_utils.set_turn(state) 94 | 95 | # 该标记是选择是否始终以黑棋视角看待当前游戏局面 96 | if canonical: 97 | # Set canonical form 98 | # 该函数将黑白棋子分布对换,并更改下一手落子方为黑棋 99 | state = canonical_form(state) 100 | 101 | return state 102 | 103 | 104 | def batch_next_states(batch_states, batch_action1d, canonical=False): 105 | # Deep copy the state to modify 106 | batch_states = np.copy(batch_states) 107 | 108 | # Initialize basic variables 109 | board_shape = batch_states.shape[2:] 110 | pass_idx = np.prod(board_shape) 111 | batch_pass = np.nonzero(batch_action1d == pass_idx) 112 | batch_non_pass = np.nonzero(batch_action1d != pass_idx)[0] 113 | batch_prev_passed = batch_prev_player_passed(batch_states) 114 | batch_game_ended = np.nonzero(batch_prev_passed & (batch_action1d == pass_idx)) 115 | batch_action2d = np.array([batch_action1d[batch_non_pass] // board_shape[0], 116 | batch_action1d[batch_non_pass] % board_shape[1]]).T 117 | # batch_action2d在.T之前是一个shape为(2, batch_size)的二维数组,第一行代表所有非pass动作的坐标行号, 118 | # 第二行代表所有非pass动作列号。因此需要.T进行转置 119 | 120 | batch_players = batch_turn(batch_states) 121 | batch_non_pass_players = batch_players[batch_non_pass] 122 | batch_ko_protect = np.empty(len(batch_states), dtype=object) 123 | 124 | # Pass moves 125 | batch_states[batch_pass, govars.PASS_CHNL] = 1 126 | # Game ended 127 | batch_states[batch_game_ended, govars.DONE_CHNL] = 1 128 | 129 | # Non-pass moves 130 | batch_states[batch_non_pass, govars.PASS_CHNL] = 0 131 | 132 | # Assert all non-pass moves are valid 133 | assert (batch_states[batch_non_pass, govars.INVD_CHNL, batch_action2d[:, 0], batch_action2d[:, 1]] == 0).all() 134 | 135 | # Add piece 136 | batch_states[batch_non_pass, batch_non_pass_players, batch_action2d[:, 0], batch_action2d[:, 1]] = 1 137 | 138 | # Get adjacent location and check whether the piece will be surrounded by opponent's piece 139 | batch_adj_locs, batch_surrounded = state_utils.batch_adj_data(batch_states[batch_non_pass], batch_action2d, 140 | batch_non_pass_players) 141 | 142 | # Update pieces 143 | batch_killed_groups = state_utils.batch_update_pieces(batch_non_pass, batch_states, batch_adj_locs, 144 | batch_non_pass_players) 145 | 146 | # Ko-protection 147 | for i, (killed_groups, surrounded) in enumerate(zip(batch_killed_groups, batch_surrounded)): 148 | # If only killed one group, and that one group was one piece, and piece set is surrounded, 149 | # activate ko protection 150 | if len(killed_groups) == 1 and surrounded: 151 | killed_group = killed_groups[0] 152 | if len(killed_group) == 1: 153 | batch_ko_protect[batch_non_pass[i]] = killed_group[0] 154 | 155 | # Update invalid moves 156 | batch_states[:, govars.INVD_CHNL] = state_utils.batch_compute_invalid_moves(batch_states, batch_players, 157 | batch_ko_protect) 158 | 159 | # Switch turn 160 | state_utils.batch_set_turn(batch_states) 161 | 162 | if canonical: 163 | # Set canonical form 164 | batch_states = batch_canonical_form(batch_states) 165 | 166 | return batch_states 167 | 168 | 169 | def invalid_moves(state): 170 | # return a fixed size binary vector 171 | if game_ended(state): 172 | return np.zeros(action_size(state)) 173 | return np.append(state[govars.INVD_CHNL].flatten(), 0) 174 | 175 | 176 | def valid_moves(state): 177 | return 1 - invalid_moves(state) 178 | 179 | 180 | def batch_invalid_moves(batch_state): 181 | n = len(batch_state) 182 | batch_invalid_moves_bool = batch_state[:, govars.INVD_CHNL].reshape(n, -1) 183 | batch_invalid_moves_bool = np.append(batch_invalid_moves_bool, np.zeros((n, 1)), axis=1) 184 | return batch_invalid_moves_bool 185 | 186 | 187 | def batch_valid_moves(batch_state): 188 | return 1 - batch_invalid_moves(batch_state) 189 | 190 | 191 | def children(state, canonical=False, padded=True): 192 | valid_moves_bool = valid_moves(state) 193 | n = len(valid_moves_bool) 194 | valid_move_idcs = np.argwhere(valid_moves_bool).flatten() 195 | batch_states = np.tile(state[np.newaxis], (len(valid_move_idcs), 1, 1, 1)) 196 | children = batch_next_states(batch_states, valid_move_idcs, canonical) 197 | 198 | if padded: 199 | padded_children = np.zeros((n, *state.shape)) 200 | padded_children[valid_move_idcs] = children 201 | children = padded_children 202 | return children 203 | 204 | 205 | def action_size(state=None, board_size: int = None): 206 | # return number of actions 207 | if state is not None: 208 | m, n = state.shape[1:] 209 | elif board_size is not None: 210 | m, n = board_size, board_size 211 | else: 212 | raise RuntimeError('No argument passed') 213 | return m * n + 1 214 | 215 | 216 | def prev_player_passed(state): 217 | return np.max(state[govars.PASS_CHNL] == 1) == 1 218 | 219 | 220 | def batch_prev_player_passed(batch_state): 221 | return np.max(batch_state[:, govars.PASS_CHNL], axis=(1, 2)) == 1 222 | 223 | 224 | def game_ended(state): 225 | """ 226 | :param state: 227 | :return: 0/1 = game not ended / game ended respectively 228 | """ 229 | m, n = state.shape[1:] 230 | return int(np.count_nonzero(state[govars.DONE_CHNL] == 1) == m * n) 231 | 232 | 233 | def batch_game_ended(batch_state): 234 | """ 235 | :param batch_state: 236 | :return: 0/1 = game not ended / game ended respectively 237 | """ 238 | return np.max(batch_state[:, govars.DONE_CHNL], axis=(1, 2)) 239 | 240 | 241 | def winning(state, komi=0): 242 | black_area, white_area = areas(state) 243 | area_difference = black_area - white_area 244 | komi_correction = area_difference - komi 245 | 246 | return np.sign(komi_correction) 247 | 248 | 249 | def batch_winning(state, komi=0): 250 | batch_black_area, batch_white_area = batch_areas(state) 251 | batch_area_difference = batch_black_area - batch_white_area 252 | batch_komi_correction = batch_area_difference - komi 253 | 254 | return np.sign(batch_komi_correction) 255 | 256 | 257 | def turn(state): 258 | """ 259 | :param state: 260 | :return: Who's turn it is (govars.BLACK/govars.WHITE) 261 | """ 262 | return int(np.max(state[govars.TURN_CHNL])) 263 | 264 | 265 | def batch_turn(batch_state): 266 | return np.max(batch_state[:, govars.TURN_CHNL], axis=(1, 2)).astype(np.int) 267 | 268 | 269 | def liberties(state: np.ndarray): 270 | blacks = state[govars.BLACK] 271 | whites = state[govars.WHITE] 272 | all_pieces = np.sum(state[[govars.BLACK, govars.WHITE]], axis=0) 273 | 274 | liberty_list = [] 275 | for player_pieces in [blacks, whites]: 276 | liberties = ndimage.binary_dilation(player_pieces, state_utils.surround_struct) 277 | liberties *= (1 - all_pieces).astype(np.bool) 278 | liberty_list.append(liberties) 279 | 280 | return liberty_list[0], liberty_list[1] 281 | 282 | 283 | def num_liberties(state: np.ndarray): 284 | black_liberties, white_liberties = liberties(state) 285 | black_liberties = np.count_nonzero(black_liberties) 286 | white_liberties = np.count_nonzero(white_liberties) 287 | 288 | return black_liberties, white_liberties 289 | 290 | 291 | def areas(state): 292 | ''' 293 | Return black area, white area 294 | ''' 295 | 296 | all_pieces = np.sum(state[[govars.BLACK, govars.WHITE]], axis=0) 297 | empties = 1 - all_pieces 298 | 299 | empty_labels, num_empty_areas = ndimage.measurements.label(empties) 300 | 301 | black_area, white_area = np.sum(state[govars.BLACK]), np.sum(state[govars.WHITE]) 302 | for label in range(1, num_empty_areas + 1): 303 | empty_area = empty_labels == label 304 | neighbors = ndimage.binary_dilation(empty_area) 305 | black_claim = False 306 | white_claim = False 307 | if (state[govars.BLACK] * neighbors > 0).any(): 308 | black_claim = True 309 | if (state[govars.WHITE] * neighbors > 0).any(): 310 | white_claim = True 311 | if black_claim and not white_claim: 312 | black_area += np.sum(empty_area) 313 | elif white_claim and not black_claim: 314 | white_area += np.sum(empty_area) 315 | 316 | return black_area, white_area 317 | 318 | 319 | def batch_areas(batch_state): 320 | black_areas, white_areas = [], [] 321 | 322 | for state in batch_state: 323 | ba, wa = areas(state) 324 | black_areas.append(ba) 325 | white_areas.append(wa) 326 | return np.array(black_areas), np.array(white_areas) 327 | 328 | 329 | def canonical_form(state): 330 | state = np.copy(state) 331 | if turn(state) == govars.WHITE: 332 | channels = np.arange(govars.NUM_CHNLS) 333 | channels[govars.BLACK] = govars.WHITE 334 | channels[govars.WHITE] = govars.BLACK 335 | state = state[channels] 336 | state_utils.set_turn(state) 337 | return state 338 | 339 | 340 | def batch_canonical_form(batch_state): 341 | batch_state = np.copy(batch_state) 342 | batch_player = batch_turn(batch_state) 343 | white_players_idcs = np.nonzero(batch_player == govars.WHITE)[0] 344 | 345 | channels = np.arange(govars.NUM_CHNLS) 346 | channels[govars.BLACK] = govars.WHITE 347 | channels[govars.WHITE] = govars.BLACK 348 | 349 | for i in white_players_idcs: 350 | batch_state[i] = batch_state[i, channels] 351 | batch_state[i, govars.TURN_CHNL] = 1 - batch_player[i] 352 | 353 | return batch_state 354 | 355 | 356 | def random_symmetry(image): 357 | """ 358 | Returns a random symmetry of the image 359 | :param image: A (C, BOARD_SIZE, BOARD_SIZE) numpy array, where C is any number 360 | :return: 361 | """ 362 | orientation = np.random.randint(0, 8) 363 | 364 | if (orientation >> 0) % 2: 365 | # Horizontal flip 366 | image = np.flip(image, 2) 367 | if (orientation >> 1) % 2: 368 | # Vertical flip 369 | image = np.flip(image, 1) 370 | if (orientation >> 2) % 2: 371 | # Rotate 90 degrees 372 | image = np.rot90(image, axes=(1, 2)) 373 | 374 | return image 375 | 376 | 377 | def all_symmetries(image): 378 | """ 379 | :param image: A (C, BOARD_SIZE, BOARD_SIZE) numpy array, where C is any number 380 | :return: All 8 orientations that are symmetrical in a Go game over the 2nd and 3rd axes 381 | (i.e. rotations, flipping and combos of them) 382 | """ 383 | symmetries = [] 384 | 385 | for i in range(8): 386 | x = image 387 | if (i >> 0) % 2: 388 | # Horizontal flip 389 | x = np.flip(x, 2) 390 | if (i >> 1) % 2: 391 | # Vertical flip 392 | x = np.flip(x, 1) 393 | if (i >> 2) % 2: 394 | # Rotation 90 degrees 395 | x = np.rot90(x, axes=(1, 2)) 396 | symmetries.append(x) 397 | 398 | return symmetries 399 | 400 | 401 | def random_weighted_action(move_weights): 402 | """ 403 | Assumes all invalid moves have weight 0 404 | Action is 1D 405 | Expected shape is (NUM OF MOVES, ) 406 | """ 407 | move_weights = preprocessing.normalize(move_weights[np.newaxis], norm='l1') 408 | return np.random.choice(np.arange(len(move_weights[0])), p=move_weights[0]) 409 | 410 | 411 | def random_action(state): 412 | """ 413 | Assumed to be (NUM_CHNLS, BOARD_SIZE, BOARD_SIZE) 414 | Action is 1D 415 | """ 416 | invalid_moves = state[govars.INVD_CHNL].flatten() 417 | invalid_moves = np.append(invalid_moves, 0) 418 | move_weights = 1 - invalid_moves 419 | 420 | return random_weighted_action(move_weights) 421 | 422 | 423 | def str(state): 424 | board_str = ' ' 425 | 426 | size = state.shape[1] 427 | for i in range(size): 428 | board_str += ' {}'.format(i) 429 | board_str += '\n ' 430 | board_str += '----' * size + '-' 431 | board_str += '\n' 432 | for i in range(size): 433 | board_str += '{} |'.format(i) 434 | for j in range(size): 435 | if state[0, i, j] == 1: 436 | board_str += ' B' 437 | elif state[1, i, j] == 1: 438 | board_str += ' W' 439 | elif state[2, i, j] == 1: 440 | board_str += ' .' 441 | else: 442 | board_str += ' ' 443 | 444 | board_str += ' |' 445 | 446 | board_str += '\n ' 447 | board_str += '----' * size + '-' 448 | board_str += '\n' 449 | 450 | black_area, white_area = areas(state) 451 | done = game_ended(state) 452 | ppp = prev_player_passed(state) 453 | t = turn(state) 454 | board_str += '\tTurn: {}, Last Turn Passed: {}, Game Over: {}\n'.format('B' if t == 0 else 'W', ppp, done) 455 | board_str += '\tBlack Area: {}, White Area: {}\n'.format(black_area, white_area) 456 | return board_str 457 | -------------------------------------------------------------------------------- /GymGo/gym_go/govars.py: -------------------------------------------------------------------------------- 1 | ANYONE = None 2 | NOONE = -1 3 | 4 | BLACK = 0 5 | WHITE = 1 6 | TURN_CHNL = 2 7 | INVD_CHNL = 3 8 | PASS_CHNL = 4 9 | DONE_CHNL = 5 10 | 11 | NUM_CHNLS = 6 12 | -------------------------------------------------------------------------------- /GymGo/gym_go/rendering.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pyglet 3 | 4 | from gym_go import govars, gogame 5 | 6 | 7 | def draw_circle(x, y, color, radius): 8 | num_sides = 50 9 | verts = [x, y] 10 | colors = list(color) 11 | for i in range(num_sides + 1): 12 | verts.append(x + radius * np.cos(i * np.pi * 2 / num_sides)) 13 | verts.append(y + radius * np.sin(i * np.pi * 2 / num_sides)) 14 | colors.extend(color) 15 | pyglet.graphics.draw(len(verts) // 2, pyglet.gl.GL_TRIANGLE_FAN, 16 | ('v2f', verts), ('c3f', colors)) 17 | 18 | 19 | def draw_command_labels(batch, window_width, window_height): 20 | pyglet.text.Label('Pass (p) | Reset (r) | Exit (e)', 21 | font_name='Helvetica', 22 | font_size=11, 23 | x=20, y=window_height - 20, anchor_y='top', batch=batch, multiline=True, width=window_width) 24 | 25 | 26 | def draw_info(batch, window_width, window_height, upper_grid_coord, state): 27 | turn = gogame.turn(state) 28 | turn_str = 'B' if turn == govars.BLACK else 'W' 29 | prev_player_passed = gogame.prev_player_passed(state) 30 | game_ended = gogame.game_ended(state) 31 | info_label = "Turn: {}\nPassed: {}\nGame: {}".format(turn_str, prev_player_passed, 32 | "OVER" if game_ended else "ONGOING") 33 | 34 | pyglet.text.Label(info_label, font_name='Helvetica', font_size=11, x=window_width - 20, y=window_height - 20, 35 | anchor_x='right', anchor_y='top', color=(0, 0, 0, 192), batch=batch, width=window_width / 2, 36 | align='right', multiline=True) 37 | 38 | # Areas 39 | black_area, white_area = gogame.areas(state) 40 | pyglet.text.Label("{}B | {}W".format(black_area, white_area), font_name='Helvetica', font_size=16, 41 | x=window_width / 2, y=upper_grid_coord + 80, anchor_x='center', color=(0, 0, 0, 192), batch=batch, 42 | width=window_width, align='center') 43 | 44 | 45 | def draw_title(batch, window_width, window_height): 46 | pyglet.text.Label("Go", font_name='Helvetica', font_size=20, bold=True, x=window_width / 2, y=window_height - 20, 47 | anchor_x='center', anchor_y='top', color=(0, 0, 0, 255), batch=batch, width=window_width / 2, 48 | align='center') 49 | 50 | 51 | def draw_grid(batch, delta, board_size, lower_grid_coord, upper_grid_coord): 52 | label_offset = 20 53 | left_coord = lower_grid_coord 54 | right_coord = lower_grid_coord 55 | ver_list = [] 56 | color_list = [] 57 | num_vert = 0 58 | for i in range(board_size): 59 | # horizontal 60 | ver_list.extend((lower_grid_coord, left_coord, 61 | upper_grid_coord, right_coord)) 62 | # vertical 63 | ver_list.extend((left_coord, lower_grid_coord, 64 | right_coord, upper_grid_coord)) 65 | color_list.extend([0.3, 0.3, 0.3] * 4) # black 66 | # label on the left 67 | pyglet.text.Label(str(i), 68 | font_name='Courier', font_size=11, 69 | x=lower_grid_coord - label_offset, y=left_coord, 70 | anchor_x='center', anchor_y='center', 71 | color=(0, 0, 0, 255), batch=batch) 72 | # label on the bottom 73 | pyglet.text.Label(str(i), 74 | font_name='Courier', font_size=11, 75 | x=left_coord, y=lower_grid_coord - label_offset, 76 | anchor_x='center', anchor_y='center', 77 | color=(0, 0, 0, 255), batch=batch) 78 | left_coord += delta 79 | right_coord += delta 80 | num_vert += 4 81 | batch.add(num_vert, pyglet.gl.GL_LINES, None, 82 | ('v2f/static', ver_list), ('c3f/static', color_list)) 83 | 84 | 85 | def draw_pieces(batch, lower_grid_coord, delta, piece_r, size, state): 86 | for i in range(size): 87 | for j in range(size): 88 | # black piece 89 | if state[0, i, j] == 1: 90 | draw_circle(lower_grid_coord + i * delta, lower_grid_coord + j * delta, 91 | [0.05882352963, 0.180392161, 0.2470588237], 92 | piece_r) # 0 for black 93 | 94 | # white piece 95 | if state[1, i, j] == 1: 96 | draw_circle(lower_grid_coord + i * delta, lower_grid_coord + j * delta, 97 | [0.9754120272] * 3, piece_r) # 255 for white 98 | -------------------------------------------------------------------------------- /GymGo/gym_go/state_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import ndimage 3 | from scipy.ndimage import measurements 4 | 5 | from GymGo.gym_go import govars 6 | 7 | group_struct = np.array([[[0, 0, 0], 8 | [0, 0, 0], 9 | [0, 0, 0]], 10 | [[0, 1, 0], 11 | [1, 1, 1], 12 | [0, 1, 0]], 13 | [[0, 0, 0], 14 | [0, 0, 0], 15 | [0, 0, 0]]]) 16 | 17 | surround_struct = np.array([[0, 1, 0], 18 | [1, 0, 1], 19 | [0, 1, 0]]) 20 | 21 | neighbor_deltas = np.array([[-1, 0], [1, 0], [0, -1], [0, 1]]) 22 | 23 | 24 | def compute_invalid_moves(state, player, ko_protect=None): 25 | """ 26 | Updates invalid moves in the OPPONENT's perspective 27 | 1.) Opponent cannot move at a location 28 | i.) If it's occupied(被占领的) 29 | i.) If it's protected by ko 30 | 2.) Opponent can move at a location 31 | i.) If it can kill 32 | 3.) Opponent cannot move at a location 33 | i.) If it's adjacent to one of their groups with only one liberty and 34 | not adjacent to other groups with more than one liberty and is completely surrounded 35 | ii.) If it's surrounded by our pieces and all of those corresponding groups 36 | move more than one liberty 37 | """ 38 | 39 | # All pieces and empty spaces 40 | # 棋盘所有有棋子的分布矩阵,有棋子的位置为1 41 | all_pieces = np.sum(state[[govars.BLACK, govars.WHITE]], axis=0) 42 | # 棋盘上所有空交叉点的分布矩阵,空交叉点位置为1 43 | empties = 1 - all_pieces 44 | 45 | # Setup invalid and valid arrays 46 | possible_invalid_array = np.zeros(state.shape[1:]) 47 | definite_valids_array = np.zeros(state.shape[1:]) 48 | 49 | # Get all groups 50 | # 上一步落子方各块棋子分布矩阵,及棋子块数 51 | all_own_groups, num_own_groups = measurements.label(state[player]) 52 | # 下一步落子方各块棋子分布矩阵,及棋子块数 53 | all_opp_groups, num_opp_groups = measurements.label(state[1 - player]) 54 | expanded_own_groups = np.zeros((num_own_groups, *state.shape[1:])) 55 | expanded_opp_groups = np.zeros((num_opp_groups, *state.shape[1:])) 56 | 57 | # Expand the groups such that each group is in its own channel 58 | for i in range(num_own_groups): 59 | expanded_own_groups[i] = all_own_groups == (i + 1) 60 | 61 | for i in range(num_opp_groups): 62 | expanded_opp_groups[i] = all_opp_groups == (i + 1) 63 | 64 | # Get all liberties in the expanded form 65 | # 计算每一块棋子的气分布矩阵 66 | # 其中np.newaxis == None,matrix[None]意思是在第0维增加一个维度 67 | # all_own_liberties和all_opp_liberties均是三维矩阵,代表每块棋子的气的分布 68 | all_own_liberties = empties[np.newaxis] * ndimage.binary_dilation(expanded_own_groups, surround_struct[np.newaxis]) 69 | all_opp_liberties = empties[np.newaxis] * ndimage.binary_dilation(expanded_opp_groups, surround_struct[np.newaxis]) 70 | 71 | # all_own_liberties和all_opp_liberties均是三维矩阵, np.sum( , axis=(1,2))针对每块棋子计算其气数 72 | # own_liberty_counts和opp_liberty_counts均是一维数组,每个元素代表每块棋的气 73 | own_liberty_counts = np.sum(all_own_liberties, axis=(1, 2)) 74 | opp_liberty_counts = np.sum(all_opp_liberties, axis=(1, 2)) 75 | 76 | # Possible invalids are on single liberties of opponent groups and on multi-liberties of own groups 77 | # Definite valids are on single liberties of own groups, multi-liberties of opponent groups 78 | # or you are not surrounded 79 | possible_invalid_array += np.sum(all_own_liberties[own_liberty_counts > 1], axis=0) 80 | possible_invalid_array += np.sum(all_opp_liberties[opp_liberty_counts == 1], axis=0) 81 | 82 | definite_valids_array += np.sum(all_own_liberties[own_liberty_counts == 1], axis=0) 83 | definite_valids_array += np.sum(all_opp_liberties[opp_liberty_counts > 1], axis=0) 84 | 85 | # All invalid moves are occupied spaces + (possible invalids minus the definite valids and it's surrounded) 86 | surrounded = ndimage.convolve(all_pieces, surround_struct, mode='constant', cval=1) == 4 87 | invalid_moves = all_pieces + possible_invalid_array * (definite_valids_array == 0) * surrounded 88 | 89 | # Ko-protection 90 | if ko_protect is not None: 91 | invalid_moves[ko_protect[0], ko_protect[1]] = 1 92 | return invalid_moves > 0 93 | 94 | 95 | def batch_compute_invalid_moves(batch_state, batch_player, batch_ko_protect): 96 | """ 97 | Updates invalid moves in the OPPONENT's perspective 98 | 1.) Opponent cannot move at a location 99 | i.) If it's occupied 100 | i.) If it's protected by ko 101 | 2.) Opponent can move at a location 102 | i.) If it can kill 103 | 3.) Opponent cannot move at a location 104 | i.) If it's adjacent to one of their groups with only one liberty and 105 | not adjacent to other groups with more than one liberty and is completely surrounded 106 | ii.) If it's surrounded by our pieces and all of those corresponding groups 107 | move more than one liberty 108 | """ 109 | batch_idcs = np.arange(len(batch_state)) 110 | 111 | # All pieces and empty spaces 112 | batch_all_pieces = np.sum(batch_state[:, [govars.BLACK, govars.WHITE]], axis=1) 113 | batch_empties = 1 - batch_all_pieces 114 | 115 | # Setup invalid and valid arrays 116 | batch_possible_invalid_array = np.zeros(batch_state.shape[:1] + batch_state.shape[2:]) 117 | batch_definite_valids_array = np.zeros(batch_state.shape[:1] + batch_state.shape[2:]) 118 | 119 | # Get all groups 120 | batch_all_own_groups, _ = measurements.label(batch_state[batch_idcs, batch_player], group_struct) 121 | batch_all_opp_groups, _ = measurements.label(batch_state[batch_idcs, 1 - batch_player], group_struct) 122 | 123 | batch_data = enumerate(zip(batch_all_own_groups, batch_all_opp_groups, batch_empties)) 124 | for i, (all_own_groups, all_opp_groups, empties) in batch_data: 125 | own_labels = np.unique(all_own_groups) 126 | opp_labels = np.unique(all_opp_groups) 127 | own_labels = own_labels[np.nonzero(own_labels)] 128 | opp_labels = opp_labels[np.nonzero(opp_labels)] 129 | expanded_own_groups = np.zeros((len(own_labels), *all_own_groups.shape)) 130 | expanded_opp_groups = np.zeros((len(opp_labels), *all_opp_groups.shape)) 131 | 132 | # Expand the groups such that each group is in its own channel 133 | for j, label in enumerate(own_labels): 134 | expanded_own_groups[j] = all_own_groups == label 135 | 136 | for j, label in enumerate(opp_labels): 137 | expanded_opp_groups[j] = all_opp_groups == label 138 | 139 | # Get all liberties in the expanded form 140 | all_own_liberties = empties[np.newaxis] * ndimage.binary_dilation(expanded_own_groups, 141 | surround_struct[np.newaxis]) 142 | all_opp_liberties = empties[np.newaxis] * ndimage.binary_dilation(expanded_opp_groups, 143 | surround_struct[np.newaxis]) 144 | 145 | own_liberty_counts = np.sum(all_own_liberties, axis=(1, 2)) 146 | opp_liberty_counts = np.sum(all_opp_liberties, axis=(1, 2)) 147 | 148 | # Possible invalids are on single liberties of opponent groups and on multi-liberties of own groups 149 | # Definite valids are on single liberties of own groups, multi-liberties of opponent groups 150 | # or you are not surrounded 151 | batch_possible_invalid_array[i] += np.sum(all_own_liberties[own_liberty_counts > 1], axis=0) 152 | batch_possible_invalid_array[i] += np.sum(all_opp_liberties[opp_liberty_counts == 1], axis=0) 153 | 154 | batch_definite_valids_array[i] += np.sum(all_own_liberties[own_liberty_counts == 1], axis=0) 155 | batch_definite_valids_array[i] += np.sum(all_opp_liberties[opp_liberty_counts > 1], axis=0) 156 | 157 | # All invalid moves are occupied spaces + (possible invalids minus the definite valids and it's surrounded) 158 | surrounded = ndimage.convolve(batch_all_pieces, surround_struct[np.newaxis], mode='constant', cval=1) == 4 159 | invalid_moves = batch_all_pieces + batch_possible_invalid_array * (batch_definite_valids_array == 0) * surrounded 160 | 161 | # Ko-protection 162 | for i, ko_protect in enumerate(batch_ko_protect): 163 | if ko_protect is not None: 164 | invalid_moves[i, ko_protect[0], ko_protect[1]] = 1 165 | return invalid_moves > 0 166 | 167 | 168 | def update_pieces(state, adj_locs, player): 169 | opponent = 1 - player 170 | killed_groups = [] 171 | 172 | all_pieces = np.sum(state[[govars.BLACK, govars.WHITE]], axis=0) 173 | empties = 1 - all_pieces 174 | 175 | all_opp_groups, _ = ndimage.measurements.label(state[opponent]) 176 | 177 | # Go through opponent groups 178 | all_adj_labels = all_opp_groups[adj_locs[:, 0], adj_locs[:, 1]] 179 | all_adj_labels = np.unique(all_adj_labels) 180 | for opp_group_idx in all_adj_labels[np.nonzero(all_adj_labels)]: 181 | opp_group = all_opp_groups == opp_group_idx 182 | liberties = empties * ndimage.binary_dilation(opp_group) 183 | if np.sum(liberties) <= 0: 184 | # Killed group 185 | opp_group_locs = np.argwhere(opp_group) 186 | state[opponent, opp_group_locs[:, 0], opp_group_locs[:, 1]] = 0 187 | killed_groups.append(opp_group_locs) 188 | 189 | return killed_groups 190 | 191 | 192 | def batch_update_pieces(batch_non_pass, batch_state, batch_adj_locs, batch_player): 193 | batch_opponent = 1 - batch_player 194 | batch_killed_groups = [] 195 | 196 | batch_all_pieces = np.sum(batch_state[:, [govars.BLACK, govars.WHITE]], axis=1) 197 | batch_empties = 1 - batch_all_pieces 198 | 199 | batch_all_opp_groups, _ = ndimage.measurements.label(batch_state[batch_non_pass, batch_opponent], 200 | group_struct) 201 | 202 | batch_data = enumerate(zip(batch_all_opp_groups, batch_all_pieces, batch_empties, batch_adj_locs, batch_opponent)) 203 | for i, (all_opp_groups, all_pieces, empties, adj_locs, opponent) in batch_data: 204 | killed_groups = [] 205 | 206 | # Go through opponent groups 207 | all_adj_labels = all_opp_groups[adj_locs[:, 0], adj_locs[:, 1]] 208 | all_adj_labels = np.unique(all_adj_labels) 209 | for opp_group_idx in all_adj_labels[np.nonzero(all_adj_labels)]: 210 | opp_group = all_opp_groups == opp_group_idx 211 | liberties = empties * ndimage.binary_dilation(opp_group) 212 | if np.sum(liberties) <= 0: 213 | # Killed group 214 | opp_group_locs = np.argwhere(opp_group) 215 | batch_state[batch_non_pass[i], opponent, opp_group_locs[:, 0], opp_group_locs[:, 1]] = 0 216 | killed_groups.append(opp_group_locs) 217 | 218 | batch_killed_groups.append(killed_groups) 219 | 220 | return batch_killed_groups 221 | 222 | 223 | def adj_data(state, action2d, player): 224 | neighbors = neighbor_deltas + action2d 225 | valid = (neighbors >= 0) & (neighbors < state.shape[1]) 226 | valid = np.prod(valid, axis=1) 227 | neighbors = neighbors[np.nonzero(valid)] 228 | 229 | opp_pieces = state[1 - player] 230 | surrounded = (opp_pieces[neighbors[:, 0], neighbors[:, 1]] > 0).all() 231 | 232 | return neighbors, surrounded 233 | 234 | 235 | def batch_adj_data(batch_state, batch_action2d, batch_player): 236 | batch_neighbors, batch_surrounded = [], [] 237 | for state, action2d, player in zip(batch_state, batch_action2d, batch_player): 238 | neighbors, surrounded = adj_data(state, action2d, player) 239 | batch_neighbors.append(neighbors) 240 | batch_surrounded.append(surrounded) 241 | return batch_neighbors, batch_surrounded 242 | 243 | 244 | def set_turn(state): 245 | """ 246 | Swaps turn 247 | :param state: 248 | :return: 249 | """ 250 | state[govars.TURN_CHNL] = 1 - state[govars.TURN_CHNL] 251 | 252 | 253 | def batch_set_turn(batch_state): 254 | """ 255 | Swaps turn 256 | :param batch_state: 257 | :return: 258 | """ 259 | batch_state[:, govars.TURN_CHNL] = 1 - batch_state[:, govars.TURN_CHNL] 260 | -------------------------------------------------------------------------------- /GymGo/gym_go/tests/efficiency.py: -------------------------------------------------------------------------------- 1 | import time 2 | import unittest 3 | 4 | import gym 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | 9 | class Efficiency(unittest.TestCase): 10 | boardsize = 9 11 | iterations = 64 12 | 13 | def setUp(self) -> None: 14 | self.env = gym.make('gym_go:go-v0', size=self.boardsize, reward_method='real') 15 | 16 | def testOrderedTrajs(self): 17 | durs = [] 18 | for _ in tqdm(range(self.iterations)): 19 | start = time.time() 20 | self.env.reset() 21 | for a in range(self.boardsize ** 2 - 2): 22 | self.env.step(a) 23 | end = time.time() 24 | 25 | dur = end - start 26 | durs.append(dur) 27 | 28 | avg_time = np.mean(durs) 29 | std_time = np.std(durs) 30 | print(f"Ordered Trajs: {avg_time:.3f} AVG, {std_time:.3f} STD", flush=True) 31 | 32 | def testLowerBound(self): 33 | durs = [] 34 | for _ in tqdm(range(self.iterations)): 35 | start = time.time() 36 | state = self.env.reset() 37 | 38 | max_steps = self.boardsize ** 2 39 | for s in range(max_steps): 40 | for _ in range(max_steps - s): 41 | np.copy(state) 42 | 43 | pi = np.ones(self.boardsize ** 2 + 1) / (self.boardsize ** 2 + 1) 44 | a = np.random.choice(np.arange(self.boardsize ** 2 + 1), p=pi) 45 | np.copy(state) 46 | 47 | end = time.time() 48 | 49 | dur = end - start 50 | durs.append(dur) 51 | 52 | avg_time = np.mean(durs) 53 | std_time = np.std(durs) 54 | print(f"Lower bound: {avg_time:.3f} AVG, {std_time:.3f} STD", flush=True) 55 | 56 | def testRandTrajsWithChildren(self): 57 | durs = [] 58 | num_steps = [] 59 | for _ in tqdm(range(self.iterations)): 60 | start = time.time() 61 | self.env.reset() 62 | 63 | max_steps = 2 * self.boardsize ** 2 64 | s = 0 65 | for s in range(max_steps): 66 | valid_moves = self.env.valid_moves() 67 | self.env.children(canonical=True) 68 | # Do not pass if possible 69 | if np.sum(valid_moves) > 1: 70 | valid_moves[-1] = 0 71 | probs = valid_moves / np.sum(valid_moves) 72 | a = np.random.choice(np.arange(self.boardsize ** 2 + 1), p=probs) 73 | state, _, done, _ = self.env.step(a) 74 | if done: 75 | break 76 | num_steps.append(s) 77 | 78 | end = time.time() 79 | 80 | dur = end - start 81 | durs.append(dur) 82 | 83 | avg_time = np.mean(durs) 84 | std_time = np.std(durs) 85 | avg_steps = np.mean(num_steps) 86 | print(f"Rand Trajs w/ Children: {avg_time:.3f} AVG SEC, {std_time:.3f} STD SEC, {avg_steps:.1f} AVG STEPS", 87 | flush=True) 88 | 89 | 90 | if __name__ == '__main__': 91 | unittest.main() 92 | -------------------------------------------------------------------------------- /GymGo/gym_go/tests/test_basics.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import gym 4 | import numpy as np 5 | 6 | from gym_go import govars 7 | 8 | 9 | class TestGoEnvBasics(unittest.TestCase): 10 | 11 | def __init__(self, *args, **kwargs): 12 | super().__init__(*args, **kwargs) 13 | self.env = gym.make('gym_go:go-v0', size=7, reward_method='real') 14 | 15 | def setUp(self): 16 | self.env.reset() 17 | 18 | def test_state(self): 19 | env = gym.make('gym_go:go-v0', size=7) 20 | state = env.reset() 21 | self.assertIsInstance(state, np.ndarray) 22 | self.assertEqual(state.shape[0], govars.NUM_CHNLS) 23 | 24 | env.close() 25 | 26 | def test_board_sizes(self): 27 | expected_sizes = [7, 13, 19] 28 | 29 | for expec_size in expected_sizes: 30 | env = gym.make('gym_go:go-v0', size=expec_size) 31 | state = env.reset() 32 | self.assertEqual(state.shape[1], expec_size) 33 | self.assertEqual(state.shape[2], expec_size) 34 | 35 | env.close() 36 | 37 | def test_empty_board(self): 38 | state = self.env.reset() 39 | self.assertEqual(np.count_nonzero(state), 0) 40 | 41 | def test_reset(self): 42 | state, reward, done, info = self.env.step((0, 0)) 43 | self.assertEqual(np.count_nonzero(state[[govars.BLACK, govars.WHITE, govars.INVD_CHNL]]), 2) 44 | self.assertEqual(np.count_nonzero(state), 51) 45 | state = self.env.reset() 46 | self.assertEqual(np.count_nonzero(state), 0) 47 | 48 | def test_preserve_original_state(self): 49 | state = self.env.reset() 50 | original_state = np.copy(state) 51 | self.env.gogame.next_state(state, 0) 52 | assert (original_state == state).all() 53 | 54 | def test_black_moves_first(self): 55 | """ 56 | Make a move at 0,0 and assert that a black piece was placed 57 | :return: 58 | """ 59 | next_state, reward, done, info = self.env.step((0, 0)) 60 | self.assertEqual(next_state[govars.BLACK, 0, 0], 1) 61 | self.assertEqual(next_state[govars.WHITE, 0, 0], 0) 62 | 63 | def test_turns(self): 64 | for i in range(7): 65 | # For the first move at i == 0, black went so now it should be white's turn 66 | state, reward, done, info = self.env.step((i, 0)) 67 | self.assertIn('turn', info) 68 | self.assertEqual(info['turn'], 1 if i % 2 == 0 else 0) 69 | 70 | def test_multiple_action_formats(self): 71 | for _ in range(10): 72 | action_1d = np.random.randint(50) 73 | action_2d = None if action_1d == 49 else (action_1d // 7, action_1d % 7) 74 | 75 | self.env.reset() 76 | state_from_1d, _, _, _ = self.env.step(action_1d) 77 | 78 | self.env.reset() 79 | state_from_2d, _, _, _ = self.env.step(action_2d) 80 | 81 | self.assertTrue((state_from_1d == state_from_2d).all()) 82 | 83 | def test_passing(self): 84 | """ 85 | None indicates pass 86 | :return: 87 | """ 88 | 89 | # Pass on first move 90 | state, reward, done, info = self.env.step(None) 91 | # Expect empty board still 92 | self.assertEqual(np.count_nonzero(state[[govars.BLACK, govars.WHITE]]), 0) 93 | # Expect passing layer and turn layer channels to be all ones 94 | self.assertEqual(np.count_nonzero(state), 98, state) 95 | self.assertEqual(np.count_nonzero(state[govars.PASS_CHNL]), 49) 96 | self.assertEqual(np.count_nonzero(state[govars.PASS_CHNL] == 1), 49) 97 | 98 | self.assertIn('turn', info) 99 | self.assertEqual(info['turn'], 1) 100 | 101 | # Make a move 102 | state, reward, done, info = self.env.step((0, 0)) 103 | 104 | # Expect the passing layer channel to be empty 105 | self.assertEqual(np.count_nonzero(state), 2) 106 | self.assertEqual(np.count_nonzero(state[govars.WHITE]), 1) 107 | self.assertEqual(np.count_nonzero(state[govars.WHITE] == 1), 1) 108 | self.assertEqual(np.count_nonzero(state[govars.PASS_CHNL]), 0) 109 | 110 | # Pass on second move 111 | self.env.reset() 112 | state, reward, done, info = self.env.step((0, 0)) 113 | # Expect two pieces (one in the invalid channel) 114 | # Plus turn layer is all ones 115 | self.assertEqual(np.count_nonzero(state), 51, state) 116 | self.assertEqual(np.count_nonzero(state[[govars.BLACK, govars.WHITE, govars.INVD_CHNL]]), 2, state) 117 | 118 | self.assertIn('turn', info) 119 | self.assertEqual(info['turn'], 1) 120 | 121 | # Pass 122 | state, reward, done, info = self.env.step(None) 123 | # Expect two pieces (one in the invalid channel) 124 | self.assertEqual(np.count_nonzero(state[[govars.BLACK, govars.WHITE, govars.INVD_CHNL]]), 2, 125 | state[[govars.BLACK, govars.WHITE, govars.INVD_CHNL]]) 126 | self.assertIn('turn', info) 127 | self.assertEqual(info['turn'], 0) 128 | 129 | def test_game_ends(self): 130 | state, reward, done, info = self.env.step(None) 131 | self.assertFalse(done) 132 | state, reward, done, info = self.env.step(None) 133 | self.assertTrue(done) 134 | 135 | self.env.reset() 136 | 137 | state, reward, done, info = self.env.step((0, 0)) 138 | self.assertFalse(done) 139 | state, reward, done, info = self.env.step(None) 140 | self.assertFalse(done) 141 | state, reward, done, info = self.env.step(None) 142 | self.assertTrue(done) 143 | 144 | def test_game_does_not_end_with_disjoint_passes(self): 145 | state, reward, done, info = self.env.step(None) 146 | self.assertFalse(done) 147 | state, reward, done, info = self.env.step((0, 0)) 148 | self.assertFalse(done) 149 | state, reward, done, info = self.env.step(None) 150 | self.assertFalse(done) 151 | 152 | def test_num_liberties(self): 153 | env = gym.make('gym_go:go-v0', size=7) 154 | 155 | steps = [(0, 0), (0, 1)] 156 | libs = [(2, 0), (1, 2)] 157 | 158 | env.reset() 159 | for step, libs in zip(steps, libs): 160 | state, _, _, _ = env.step(step) 161 | blacklibs, whitelibs = env.gogame.num_liberties(state) 162 | self.assertEqual(blacklibs, libs[0], state) 163 | self.assertEqual(whitelibs, libs[1], state) 164 | 165 | steps = [(2, 1), None, (1, 2), None, (2, 3), None, (3, 2), None] 166 | libs = [(4, 0), (4, 0), (6, 0), (6, 0), (8, 0), (8, 0), (9, 0), (9, 0)] 167 | 168 | env.reset() 169 | for step, libs in zip(steps, libs): 170 | state, _, _, _ = env.step(step) 171 | blacklibs, whitelibs = env.gogame.num_liberties(state) 172 | self.assertEqual(blacklibs, libs[0], state) 173 | self.assertEqual(whitelibs, libs[1], state) 174 | 175 | def test_komi(self): 176 | env = gym.make('gym_go:go-v0', size=7, komi=2.5, reward_method='real') 177 | 178 | # White win 179 | _ = env.step(None) 180 | state, reward, done, info = env.step(None) 181 | self.assertEqual(-1, reward) 182 | 183 | # White still win 184 | env.reset() 185 | _ = env.step(0) 186 | _ = env.step(2) 187 | 188 | _ = env.step(1) 189 | _ = env.step(None) 190 | 191 | state, reward, done, info = env.step(None) 192 | self.assertEqual(-1, reward) 193 | 194 | # Black win 195 | env.reset() 196 | _ = env.step(0) 197 | _ = env.step(None) 198 | 199 | _ = env.step(1) 200 | _ = env.step(None) 201 | 202 | _ = env.step(2) 203 | _ = env.step(None) 204 | state, reward, done, info = env.step(None) 205 | self.assertEqual(1, reward) 206 | 207 | env.close() 208 | 209 | def test_children(self): 210 | for canonical in [False, True]: 211 | for _ in range(20): 212 | action = self.env.uniform_random_action() 213 | self.env.step(action) 214 | state = self.env.state() 215 | children = self.env.children(canonical, padded=True) 216 | valid_moves = self.env.valid_moves() 217 | for a in range(len(valid_moves)): 218 | if valid_moves[a]: 219 | child = self.env.gogame.next_state(state, a, canonical) 220 | equal = children[a] == child 221 | self.assertTrue(equal.all(), (canonical, np.argwhere(~equal))) 222 | else: 223 | self.assertTrue((children[a] == 0).all()) 224 | 225 | def test_real_reward(self): 226 | env = gym.make('gym_go:go-v0', size=7, reward_method='real') 227 | 228 | # In game 229 | state, reward, done, info = env.step((0, 0)) 230 | self.assertEqual(reward, 0) 231 | state, reward, done, info = env.step(None) 232 | self.assertEqual(reward, 0) 233 | 234 | # Win 235 | state, reward, done, info = env.step(None) 236 | self.assertEqual(reward, 1) 237 | 238 | # Lose 239 | env.reset() 240 | 241 | state, reward, done, info = env.step(None) 242 | self.assertEqual(reward, 0) 243 | state, reward, done, info = env.step((0, 0)) 244 | self.assertEqual(reward, 0) 245 | state, reward, done, info = env.step(None) 246 | self.assertEqual(reward, 0) 247 | state, reward, done, info = env.step(None) 248 | self.assertEqual(reward, -1) 249 | 250 | # Tie 251 | env.reset() 252 | 253 | state, reward, done, info = env.step(None) 254 | self.assertEqual(reward, 0) 255 | state, reward, done, info = env.step(None) 256 | self.assertEqual(reward, 0) 257 | 258 | env.close() 259 | 260 | def test_heuristic_reward(self): 261 | env = gym.make('gym_go:go-v0', size=7, reward_method='heuristic') 262 | 263 | # In game 264 | state, reward, done, info = env.step((0, 0)) 265 | self.assertEqual(reward, 49) 266 | state, reward, done, info = env.step((0, 1)) 267 | self.assertEqual(reward, 0) 268 | state, reward, done, info = env.step(None) 269 | self.assertEqual(reward, 0) 270 | state, reward, done, info = env.step((1, 0)) 271 | self.assertEqual(reward, -49) 272 | 273 | # Lose 274 | state, reward, done, info = env.step(None) 275 | self.assertEqual(reward, -49) 276 | state, reward, done, info = env.step(None) 277 | self.assertEqual(reward, -49) 278 | 279 | # Win 280 | env.reset() 281 | 282 | state, reward, done, info = env.step((0, 0)) 283 | self.assertEqual(reward, 49) 284 | state, reward, done, info = env.step(None) 285 | self.assertEqual(reward, 49) 286 | state, reward, done, info = env.step(None) 287 | self.assertEqual(reward, 49) 288 | 289 | env.close() 290 | 291 | 292 | if __name__ == '__main__': 293 | unittest.main() 294 | -------------------------------------------------------------------------------- /GymGo/gym_go/tests/test_batch_fns.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from gym_go import gogame, govars 4 | 5 | 6 | class TestBatchFns(unittest.TestCase): 7 | 8 | def __init__(self, *args, **kwargs): 9 | super().__init__(*args, **kwargs) 10 | 11 | def setUp(self): 12 | pass 13 | 14 | def test_batch_canonical_form(self): 15 | states = gogame.batch_init_state(2, 7) 16 | states[0] = gogame.next_state(states[0], 0) 17 | 18 | self.assertEqual(states[0, govars.BLACK].sum(), 1) 19 | self.assertEqual(states[0, govars.WHITE].sum(), 0) 20 | 21 | states = gogame.batch_canonical_form(states) 22 | 23 | self.assertEqual(states[0, govars.BLACK].sum(), 0) 24 | self.assertEqual(states[0, govars.WHITE].sum(), 1) 25 | 26 | self.assertEqual(states[1, govars.BLACK].sum(), 0) 27 | self.assertEqual(states[1, govars.WHITE].sum(), 0) 28 | 29 | for i in range(2): 30 | self.assertEqual(gogame.turn(states[i]), govars.BLACK) 31 | 32 | canon_again = gogame.batch_canonical_form(states) 33 | 34 | self.assertTrue((canon_again == states).all()) 35 | 36 | 37 | if __name__ == '__main__': 38 | unittest.main() 39 | -------------------------------------------------------------------------------- /GymGo/gym_go/tests/test_invalid_moves.py: -------------------------------------------------------------------------------- 1 | import random 2 | import unittest 3 | 4 | import gym 5 | import numpy as np 6 | 7 | from gym_go import govars 8 | 9 | 10 | class TestGoEnvInvalidMoves(unittest.TestCase): 11 | 12 | def __init__(self, *args, **kwargs): 13 | super().__init__(*args, **kwargs) 14 | self.env = gym.make('gym_go:go-v0', size=7, reward_method='real') 15 | 16 | def setUp(self): 17 | self.env.reset() 18 | 19 | def test_out_of_bounds_action(self): 20 | with self.assertRaises(Exception): 21 | self.env.step((-1, 0)) 22 | 23 | with self.assertRaises(Exception): 24 | self.env.step((0, 100)) 25 | 26 | def test_invalid_occupied_moves(self): 27 | # Test this 8 times at random 28 | for _ in range(8): 29 | self.env.reset() 30 | row = random.randint(0, 6) 31 | col = random.randint(0, 6) 32 | 33 | state, reward, done, info = self.env.step((row, col)) 34 | 35 | # Assert that the invalid layer is correct 36 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL]), 1) 37 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL] == 1), 1) 38 | self.assertEqual(state[govars.INVD_CHNL, row, col], 1) 39 | 40 | with self.assertRaises(Exception): 41 | self.env.step((row, col)) 42 | 43 | def test_invalid_ko_protection_moves(self): 44 | """ 45 | _, 1, 2, _, _, _, _, 46 | 47 | 3, 8, 7/9, 4, _, _, _, 48 | 49 | _, 5, 6, _, _, _, _, 50 | 51 | _, _, _, _, _, _, _, 52 | 53 | _, _, _, _, _, _, _, 54 | 55 | _, _, _, _, _, _, _, 56 | 57 | _, _, _, _, _, _, _, 58 | 59 | :return: 60 | """ 61 | 62 | for move in [(0, 1), (0, 2), (1, 0), (1, 3), (2, 1), (2, 2), (1, 2), (1, 1)]: 63 | state, reward, done, info = self.env.step(move) 64 | 65 | # Test invalid channel 66 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL]), 8, state[govars.INVD_CHNL]) 67 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL] == 1), 8) 68 | self.assertEqual(state[govars.INVD_CHNL, 1, 2], 1) 69 | 70 | # Assert pieces channel is empty at ko-protection coordinate 71 | self.assertEqual(state[govars.BLACK, 1, 2], 0) 72 | self.assertEqual(state[govars.WHITE, 1, 2], 0) 73 | 74 | final_move = (1, 2) 75 | with self.assertRaises(Exception): 76 | self.env.step(final_move) 77 | 78 | # Assert ko-protection goes off 79 | state, reward, done, info = self.env.step((6, 6)) 80 | state, reward, done, info = self.env.step(None) 81 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL]), 8) 82 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL] == 1), 8) 83 | self.assertEqual(state[govars.INVD_CHNL, 1, 2], 0) 84 | 85 | def test_invalid_ko_wall_protection_moves(self): 86 | """ 87 | 2/8, 7, 6, _, _, _, _, 88 | 89 | 1, 4, _, _, _, _, _, 90 | 91 | _, _, _, _, _, _, _, 92 | 93 | _, _, _, _, _, _, _, 94 | 95 | _, _, _, _, _, _, _, 96 | 97 | _, _, _, _, _, _, _, 98 | 99 | _, _, _, _, _, _, _, 100 | 101 | :return: 102 | """ 103 | 104 | for move in [(1, 0), (0, 0), None, (1, 1), None, (0, 2), (0, 1)]: 105 | state, reward, done, info = self.env.step(move) 106 | 107 | # Test invalid channel 108 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL]), 5, state[govars.INVD_CHNL]) 109 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL] == 1), 5) 110 | self.assertEqual(state[govars.INVD_CHNL, 0, 0], 1) 111 | 112 | # Assert pieces channel is empty at ko-protection coordinate 113 | self.assertEqual(state[govars.BLACK, 0, 0], 0) 114 | self.assertEqual(state[govars.WHITE, 0, 0], 0) 115 | 116 | final_move = (0, 0) 117 | with self.assertRaises(Exception): 118 | self.env.step(final_move) 119 | 120 | # Assert ko-protection goes off 121 | state, reward, done, info = self.env.step((6, 6)) 122 | state, reward, done, info = self.env.step(None) 123 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL]), 5) 124 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL] == 1), 5) 125 | self.assertEqual(state[govars.INVD_CHNL, 0, 0], 0) 126 | 127 | def test_invalid_no_liberty_move(self): 128 | """ 129 | _, 1, 2, _, _, _, _, 130 | 131 | 3, 8, 7, _, 4, _, _, 132 | 133 | _, 5, 6, _, _, _, _, 134 | 135 | _, _, _, _, _, _, _, 136 | 137 | _, _, _, _, _, _, _, 138 | 139 | _, _, _, _, _, _, _, 140 | 141 | _, _, _, _, _, _, _, 142 | 143 | :return: 144 | """ 145 | for move in [(0, 1), (0, 2), (1, 0), (1, 4), (2, 1), (2, 2), (1, 2)]: 146 | state, reward, done, info = self.env.step(move) 147 | 148 | # Test invalid channel 149 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL]), 9, state[govars.INVD_CHNL]) 150 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL] == 1), 9) 151 | self.assertEqual(state[govars.INVD_CHNL, 1, 1], 1) 152 | self.assertEqual(state[govars.INVD_CHNL, 0, 0], 1) 153 | # Assert empty space in pieces channels 154 | self.assertEqual(state[govars.BLACK, 1, 1], 0) 155 | self.assertEqual(state[govars.WHITE, 1, 1], 0) 156 | self.assertEqual(state[govars.BLACK, 0, 0], 0) 157 | self.assertEqual(state[govars.WHITE, 0, 0], 0) 158 | 159 | final_move = (1, 1) 160 | with self.assertRaises(Exception): 161 | self.env.step(final_move) 162 | 163 | def test_invalid_game_already_over_move(self): 164 | self.env.step(None) 165 | self.env.step(None) 166 | 167 | with self.assertRaises(Exception): 168 | self.env.step(None) 169 | 170 | self.env.reset() 171 | 172 | self.env.step(None) 173 | self.env.step(None) 174 | 175 | with self.assertRaises(Exception): 176 | self.env.step((0, 0)) 177 | 178 | def test_small_suicide(self): 179 | """ 180 | 7, 8, 0, 181 | 182 | 0, 5, 4, 183 | 184 | 1, 2, 3/6, 185 | :return: 186 | """ 187 | 188 | self.env = gym.make('gym_go:go-v0', size=3, reward_method='real') 189 | for move in [6, 7, 8, 5, 4, 8, 0, 1]: 190 | state, reward, done, info = self.env.step(move) 191 | 192 | with self.assertRaises(Exception): 193 | self.env.step(3) 194 | 195 | def test_invalid_after_capture(self): 196 | """ 197 | 1, 5, 6, 198 | 199 | 7, 4, _, 200 | 201 | 3, 8, 2, 202 | :return: 203 | """ 204 | 205 | self.env = gym.make('gym_go:go-v0', size=3, reward_method='real') 206 | for move in [0, 8, 6, 4, 1, 2, 3, 7]: 207 | state, reward, done, info = self.env.step(move) 208 | 209 | with self.assertRaises(Exception): 210 | self.env.step(5) 211 | 212 | def test_cannot_capture_groups_with_multiple_holes(self): 213 | """ 214 | _, 2, 4, 6, 8, 10, _, 215 | 216 | 32, 1, 3, 5, 7, 9, 12, 217 | 218 | 30, 25, 34, 19, _, 11, 14, 219 | 220 | 28, 23, 21, 17, 15, 13, 16, 221 | 222 | _, 26, 24, 22, 20, 18, _, 223 | 224 | _, _, _, _, _, _, _, 225 | 226 | _, _, _, _, _, _, _, 227 | 228 | :return: 229 | """ 230 | for move in [(1, 1), (0, 1), (1, 2), (0, 2), (1, 3), (0, 3), (1, 4), (0, 4), (1, 5), (0, 5), (2, 5), (1, 6), 231 | (3, 5), (2, 6), (3, 4), (3, 6), 232 | (3, 3), (4, 5), (2, 3), (4, 4), (3, 2), (4, 3), (3, 1), (4, 2), (2, 1), (4, 1), None, (3, 0), None, 233 | (2, 0), None, (1, 0)]: 234 | state, reward, done, info = self.env.step(move) 235 | 236 | self.env.step(None) 237 | final_move = (2, 2) 238 | with self.assertRaises(Exception): 239 | self.env.step(final_move) 240 | 241 | 242 | if __name__ == '__main__': 243 | unittest.main() 244 | -------------------------------------------------------------------------------- /GymGo/gym_go/tests/test_valid_moves.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import gym 4 | import numpy as np 5 | 6 | from gym_go import govars 7 | 8 | 9 | class TestGoEnvValidMoves(unittest.TestCase): 10 | 11 | def __init__(self, *args, **kwargs): 12 | super().__init__(*args, **kwargs) 13 | self.env = gym.make('gym_go:go-v0', size=7, reward_method='real') 14 | 15 | def setUp(self): 16 | self.env.reset() 17 | 18 | def test_simple_valid_moves(self): 19 | for i in range(7): 20 | state, reward, done, info = self.env.step((0, i)) 21 | self.assertEqual(done, False) 22 | 23 | self.env.reset() 24 | 25 | for i in range(7): 26 | state, reward, done, info = self.env.step((i, i)) 27 | self.assertEqual(done, False) 28 | 29 | self.env.reset() 30 | 31 | for i in range(7): 32 | state, reward, done, info = self.env.step((i, 0)) 33 | self.assertEqual(done, False) 34 | 35 | def test_valid_no_liberty_move(self): 36 | """ 37 | _, 1, 2, _, _, _, _, 38 | 39 | 3, 8, 7, 4, _, _, _, 40 | 41 | _, 5, 6, _, _, _, _, 42 | 43 | _, _, _, _, _, _, _, 44 | 45 | _, _, _, _, _, _, _, 46 | 47 | _, _, _, _, _, _, _, 48 | 49 | _, _, _, _, _, _, _, 50 | 51 | 52 | :return: 53 | """ 54 | for move in [(0, 1), (0, 2), (1, 0), (1, 3), (2, 1), (2, 2), (1, 2), (1, 1)]: 55 | state, reward, done, info = self.env.step(move) 56 | 57 | # Black should have 3 pieces 58 | self.assertEqual(np.count_nonzero(state[govars.BLACK]), 3) 59 | 60 | # White should have 4 pieces 61 | self.assertEqual(np.count_nonzero(state[govars.WHITE]), 4) 62 | # Assert values are ones 63 | self.assertEqual(np.count_nonzero(state[govars.WHITE] == 1), 4) 64 | 65 | def test_valid_no_liberty_capture(self): 66 | """ 67 | 1, 7, 2, 3, _, _, _, 68 | 69 | 6, 4, 5, _, _, _, _, 70 | 71 | _, _, _, _, _, _, _, 72 | 73 | _, _, _, _, _, _, _, 74 | 75 | _, _, _, _, _, _, _, 76 | 77 | _, _, _, _, _, _, _, 78 | 79 | _, _, _, _, _, _, _, 80 | 81 | :return: 82 | """ 83 | for move in [(0, 0), (0, 2), (0, 3), (1, 1), (1, 2), (1, 0)]: 84 | state, reward, done, info = self.env.step(move) 85 | 86 | # Test invalid channel 87 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL]), 6, state[govars.INVD_CHNL]) 88 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL] == 1), 6) 89 | self.assertEqual(state[govars.INVD_CHNL, 0, 1], 0, state[govars.INVD_CHNL]) 90 | # Assert empty space in pieces channels 91 | self.assertEqual(state[govars.BLACK, 0, 1], 0) 92 | self.assertEqual(state[govars.WHITE, 0, 1], 0) 93 | 94 | final_move = (0, 1) 95 | state, reward, done, info = self.env.step(final_move) 96 | 97 | # White should only have 2 pieces 98 | self.assertEqual(np.count_nonzero(state[govars.WHITE]), 2, state[govars.WHITE]) 99 | self.assertEqual(np.count_nonzero(state[govars.WHITE] == 1), 2) 100 | # Black should have 4 pieces 101 | self.assertEqual(np.count_nonzero(state[govars.BLACK]), 4, state[govars.BLACK]) 102 | self.assertEqual(np.count_nonzero(state[govars.BLACK] == 1), 4) 103 | 104 | def test_simple_capture(self): 105 | """ 106 | _, 1, _, _, _, _, _, 107 | 108 | 3, 2, 5, _, _, _, _, 109 | 110 | _, 7, _, _, _, _, _, 111 | 112 | _, _, _, _, _, _, _, 113 | 114 | _, _, _, _, _, _, _, 115 | 116 | _, _, _, _, _, _, _, 117 | 118 | _, _, _, _, _, _, _, 119 | 120 | :return: 121 | """ 122 | 123 | for move in [(0, 1), (1, 1), (1, 0), None, (1, 2), None, (2, 1)]: 124 | state, reward, done, info = self.env.step(move) 125 | 126 | # White should have no pieces 127 | self.assertEqual(np.count_nonzero(state[govars.WHITE]), 0) 128 | 129 | # Black should have 4 pieces 130 | self.assertEqual(np.count_nonzero(state[govars.BLACK]), 4) 131 | # Assert values are ones 132 | self.assertEqual(np.count_nonzero(state[govars.BLACK] == 1), 4) 133 | 134 | def test_large_group_capture(self): 135 | """ 136 | _, _, _, _, _, _, _, 137 | 138 | _, _, 2, 4, 6, _, _, 139 | 140 | _, 20, 1, 3, 5, 8, _, 141 | 142 | _, 18, 11, 9, 7, 10, _, 143 | 144 | _, _, 16, 14, 12, _, _, 145 | 146 | _, _, _, _, _, _, _, 147 | 148 | _, _, _, _, _, _, _, 149 | 150 | :return: 151 | """ 152 | for move in [(2, 2), (1, 2), (2, 3), (1, 3), (2, 4), (1, 4), (3, 4), (2, 5), (3, 3), (3, 5), (3, 2), (4, 4), 153 | None, (4, 3), None, (4, 2), None, 154 | (3, 1), None, (2, 1)]: 155 | state, reward, done, info = self.env.step(move) 156 | 157 | # Black should have no pieces 158 | self.assertEqual(np.count_nonzero(state[govars.BLACK]), 0) 159 | 160 | # White should have 10 pieces 161 | self.assertEqual(np.count_nonzero(state[govars.WHITE]), 10) 162 | # Assert they are ones 163 | self.assertEqual(np.count_nonzero(state[govars.WHITE] == 1), 10) 164 | 165 | def test_large_group_suicide(self): 166 | """ 167 | _, _, _, _, _, _, _, 168 | 169 | _, _, _, _, _, _, _, 170 | 171 | _, _, _, _, _, _, _, 172 | 173 | _, _, _, _, _, _, _, 174 | 175 | 1, 3, _, _, _, _, _, 176 | 177 | 4, 6, 5, _, _, _, _, 178 | 179 | 2, 8, 7, _, _, _, _, 180 | 181 | :return: 182 | """ 183 | for move in [(4, 0), (6, 0), (4, 1), (5, 0), (5, 2), (5, 1), (6, 2)]: 184 | state, reward, done, info = self.env.step(move) 185 | 186 | # Test invalid channel 187 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL]), 8, state[govars.INVD_CHNL]) 188 | self.assertEqual(np.count_nonzero(state[govars.INVD_CHNL] == 1), 8) 189 | # Assert empty space in pieces channels 190 | self.assertEqual(state[govars.BLACK, 6, 1], 0) 191 | self.assertEqual(state[govars.WHITE, 6, 1], 0) 192 | 193 | final_move = (6, 1) 194 | with self.assertRaises(Exception): 195 | self.env.step(final_move) 196 | 197 | def test_group_edge_capture(self): 198 | """ 199 | 1, 3, 2, _, _, _, _, 200 | 201 | 7, 5, 4, _, _, _, _, 202 | 203 | 8, 6, _, _, _, _, _, 204 | 205 | _, _, _, _, _, _, _, 206 | 207 | _, _, _, _, _, _, _, 208 | 209 | _, _, _, _, _, _, _, 210 | 211 | _, _, _, _, _, _, _, 212 | 213 | :return: 214 | """ 215 | 216 | for move in [(0, 0), (0, 2), (0, 1), (1, 2), (1, 1), (2, 1), (1, 0), (2, 0)]: 217 | state, reward, done, info = self.env.step(move) 218 | 219 | # Black should have no pieces 220 | self.assertEqual(np.count_nonzero(state[govars.BLACK]), 0) 221 | 222 | # White should have 4 pieces 223 | self.assertEqual(np.count_nonzero(state[govars.WHITE]), 4) 224 | # Assert they are ones 225 | self.assertEqual(np.count_nonzero(state[govars.WHITE] == 1), 4) 226 | 227 | def test_group_kill_no_ko_protection(self): 228 | """ 229 | Thanks to DeepGeGe for finding this bug. 230 | 231 | _, _, _, _, 2, 1, 13, 232 | 233 | _, _, _, _, 4, 3, 12/14, 234 | 235 | _, _, _, _, 6, 5, 7, 236 | 237 | _, _, _, _, _, 8, 10, 238 | 239 | _, _, _, _, _, _, _, 240 | 241 | _, _, _, _, _, _, _, 242 | 243 | _, _, _, _, _, _, _, 244 | 245 | :return: 246 | """ 247 | 248 | for move in [(0, 5), (0, 4), (1, 5), (1, 4), (2, 5), (2, 4), (2, 6), (3, 5), None, (3, 6), None, (1, 6), 249 | (0, 6)]: 250 | state, reward, done, info = self.env.step(move) 251 | 252 | # Test final kill move (1, 6) is valid 253 | final_move = (1, 6) 254 | self.assertEqual(state[govars.INVD_CHNL, 1, 6], 0) 255 | state, _, _, _ = self.env.step(final_move) 256 | 257 | # Assert black is removed 258 | self.assertEqual(state[govars.BLACK].sum(), 0) 259 | 260 | # Assert 6 white pieces still on the board 261 | self.assertEqual(state[govars.WHITE].sum(), 6) 262 | 263 | 264 | if __name__ == '__main__': 265 | unittest.main() 266 | -------------------------------------------------------------------------------- /GymGo/screenshots/human_ui.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/GymGo/screenshots/human_ui.png -------------------------------------------------------------------------------- /GymGo/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='gym_go', 5 | version='0.0.1', 6 | install_requires=['gym'] # and other dependencies 7 | ) 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 QPT Family(Org) and DeepGeGe(Owner) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 机巧围棋(CleverGo) 2 | ![GitHub Repo stars](https://img.shields.io/github/stars/QPT-Family/QPT-CleverGo) 3 | ![GitHub forks](https://img.shields.io/github/forks/QPT-Family/QPT-CleverGo) 4 | ![GitHub](https://img.shields.io/github/license/QPT-Family/QPT-CleverGo) 5 | [![QQGroup](https://img.shields.io/badge/QQ群-935098082-9cf?logo=tencent-qq&logoColor=000&logoWidth=15)](https://jq.qq.com/?_wv=1027&k=qFlk0VWG) 6 | 7 | 8 | [GitHub主页](https://github.com/QPT-Family/QPT-CleverGo) 9 | 10 | 机巧围棋(CleverGo)是基于Python+Pygame+PaddlePaddle打造的一款点击按钮就能可视化地训练围棋人工智能的程序。 11 | 12 | 机巧围棋通过模块化设计,搭建了一整套简单易用的围棋AI学习、开发、训练及效果可视化验证框架。 13 | 14 | 期望大家能够*Star*支持机巧围棋鸭~! 15 | 16 | ## 版本说明 17 | > 当前版本为尝鲜版,可能会有未测试出的Bug。如发现问题,强烈建议加QQ群935098082与我们进行交流,我们仍在更新~ 18 | 19 | ## 安装说明 20 | Python版本: `3.7` 21 | 22 | 下载链接: `https://github.com/QPT-Family/QPT-CleverGo/archive/refs/heads/main.zip` 23 | 24 | 依赖安装: `pip install -r requirements.txt` 25 | 26 | 音乐资源: 在项目`assets/`文件夹下创建`musics`文件夹,并将任意`.mp3`格式音频文件放入该文件夹下即可。 27 | 可选音乐资源包:[下载链接](https://pan.baidu.com/s/1XPWUcVkfy3NLGLKb3VkLRA) ,提取码`tixk`。 28 | 29 | ## 功能说明 30 | - 程序启动入口:`play_game.py` 31 | ![启动界面](https://github.com/QPT-Family/QPT-CleverGo/blob/main/pictures/%E5%90%AF%E5%8A%A8%E7%95%8C%E9%9D%A2.png) 32 | 33 | - 点击训练幼生阿尔法狗: 34 | ![训练初始界面](https://github.com/QPT-Family/QPT-CleverGo/blob/main/pictures/%E8%AE%AD%E7%BB%83%E5%88%9D%E5%A7%8B%E7%95%8C%E9%9D%A2.png) 35 | 36 | - 点击开始训练: 37 | ![训练过程](https://github.com/QPT-Family/QPT-CleverGo/blob/main/pictures/%E8%AE%AD%E7%BB%83%E8%BF%87%E7%A8%8B.png) 38 | 39 | - 对弈: 40 | ![对弈](https://github.com/QPT-Family/QPT-CleverGo/blob/main/pictures/%E5%AF%B9%E5%BC%88.png) 41 | 42 | 43 | ## 其它说明 44 | 目前正在编写项目技术原理文档,技术原理文档将从以下8个方面详细介绍项目核心技术原理: 45 | 46 | 1. 围棋基本知识 47 | 2. 围棋模拟器(GymGo)核心方法原理 48 | 3. 游戏引擎(Pygame)核心方法 49 | 4. 深度学习框架(PaddlePaddle)核心方法 50 | 5. 深度强化学习基本原理方法 51 | 6. AlphaGo基本原理 52 | 7. CleverGo项目程序设计方法原理 53 | 8. CleverGo未来规划 54 | 55 | 具体规划请参见:[机巧围棋(CleverGo)技术原理文档](https://github.com/QPT-Family/QPT-CleverGo/blob/main/docs/%E6%9C%BA%E5%B7%A7%E5%9B%B4%E6%A3%8B(CleverGo)%E6%8A%80%E6%9C%AF%E5%8E%9F%E7%90%86%E6%96%87%E6%A1%A3.md) 56 | -------------------------------------------------------------------------------- /assets/audios/Button.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/assets/audios/Button.wav -------------------------------------------------------------------------------- /assets/audios/Stone.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/assets/audios/Stone.wav -------------------------------------------------------------------------------- /assets/fonts/msyh.ttc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/assets/fonts/msyh.ttc -------------------------------------------------------------------------------- /assets/fonts/msyhbd.ttc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/assets/fonts/msyhbd.ttc -------------------------------------------------------------------------------- /assets/fonts/msyhl.ttc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/assets/fonts/msyhl.ttc -------------------------------------------------------------------------------- /assets/pictures/B-13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/assets/pictures/B-13.png -------------------------------------------------------------------------------- /assets/pictures/B-19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/assets/pictures/B-19.png -------------------------------------------------------------------------------- /assets/pictures/B-9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/assets/pictures/B-9.png -------------------------------------------------------------------------------- /assets/pictures/B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/assets/pictures/B.png -------------------------------------------------------------------------------- /assets/pictures/W-13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/assets/pictures/W-13.png -------------------------------------------------------------------------------- /assets/pictures/W-19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/assets/pictures/W-19.png -------------------------------------------------------------------------------- /assets/pictures/W-9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/assets/pictures/W-9.png -------------------------------------------------------------------------------- /assets/pictures/W.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/assets/pictures/W.png -------------------------------------------------------------------------------- /docs/围棋基本知识.md: -------------------------------------------------------------------------------- 1 | # 围棋基本知识 2 | 3 | 围棋使用格状棋盘及黑白二色圆形棋子进行对弈,棋盘上有纵横各19条线段将棋盘分成361个交叉点。在野狐、奕城等各大围棋平台,除19路围棋外,还存在9路和13路围棋用于练习和娱乐。9路围棋即棋盘由纵横各9条线段分成81个交叉点,13路围棋即棋盘由纵横各13条线段分成169个交叉点。围棋以围地多者为胜,被认为是世界上最复杂的棋盘游戏。 4 | 5 | 由于13路和19路围棋复杂度较高,训练相应的围棋AI所需计算资源非常多,所需训练时间非常长,因此机巧围棋主要提供9路围棋一键式训练功能。 6 | 7 | 本文介绍围棋规则及基本行棋知识。 8 | 9 | 10 | 11 | ## 1. 围棋规则 12 | 13 | ### 1.1 围棋棋盘 14 | 15 | 如图一所示,围棋棋盘上有纵横19条线段,共构成了361个交叉点,其中包括9个标记的交叉点。中间的被标记的交叉点叫做“天元”,其余被标记的交叉点叫做“星位”。星位和天元的作用是帮助对局者更方便地定位棋盘上交叉点的位置,无其余作用。 16 | 17 | ![2_1](https://github.com/QPT-Family/QPT-CleverGo/blob/main/pictures/2_1.png) 18 | 19 | 20 | 21 | ### 1.2 基本下法 22 | 23 | 1. 对局双方各执一色棋子,黑先白后,交替下子,每次只能下一子; 24 | 25 | 2. 棋子下在棋盘上的交叉点上; 26 | 27 | 3. 棋子下定后,不得再向其他位置移动; 28 | 29 | 4. 轮流下子是双方的权利,但允许任何一方放弃下子权而使用虚着。 30 | 31 | 32 | 33 | ### 1.3 基本规则 34 | 35 | 1. 无气自提; 36 | 2. 禁止全局同形再现; 37 | 3. 地大者胜。 38 | 39 | > 全局同形再现是妨碍终局的唯一技术性原因,原则上必须禁止。禁止全局同型再现包括以下内容: 40 | > 41 | > 1. 禁止单劫立即回提; 42 | > 2. 禁止假生类多劫循环; 43 | > 3. 原则上禁止三劫循环、四劫循环、长生、双提两子等全局同形再现的罕见特例。根据不同比赛,也可制定相应的补充规定,如无胜负、和棋、加赛等。 44 | > 45 | > 在一般围棋比赛中,只会严格禁止单劫立即回提和禁止假生类多劫循环,三劫循环等情况一般作无胜负处理。由于假生类多劫循环、三劫循环等情况非常罕见,文本讲解围棋基本知识默认禁止全局同形再现等同于禁止单劫立即回提,不考虑其余两种情况。 46 | > 47 | > 上述规则的具体含义将在下面基本行棋知识部分讲解。 48 | 49 | 50 | 51 | ## 2. 基本行棋知识 52 | 53 | ### 2.1 棋子的气 54 | 55 | 当一个棋子被放置到棋盘的某个交叉点上,其直线相邻的交叉点就是该棋子的气,气的单位为“口”。如图二所示,棋子A共有4个**X**所在位置的4口气。 56 | 57 | 可将不同颜色的棋子放置在一个棋子直接相邻的交叉点,以堵住或填掉相应棋子的气。图二中黑子B的4个直接相邻交叉点中有两个位置存在白子C和D,则黑子B只有2口气。同样的,白子C和D分别存在3口气。 58 | 59 | 若棋子位于棋盘边线上,显然只存在3个与其直线相邻的交叉点,即只有3口气。若位于棋盘的角上,则只存在2个与其直接相邻的交叉点,即只有2口气。图二中黑子G只有3个**X**所在位置的3口气,黑子H只有2个**X**所在位置的2口气。 60 | 61 | 若多个相同颜色的棋子连在一起,则这些棋子必须被视为一个整体,称为一块棋。一块棋的气是共用的,而且同生共死。图二中△标识的5个黑子连在了一起,因此为一块棋。△标识的5个黑子所组成的这一块棋共有11个**X**所在位置的11口气。黑子E和F没有连在一起,因此他们不是一块棋。 62 | 63 | ![2_2](https://github.com/QPT-Family/QPT-CleverGo/blob/main/pictures/2_2.png) 64 | 65 | 棋子的气即为棋子的生命,若一块棋(单独的一颗棋子也可以视为一块棋)的所有气均被堵住,则该块棋子被杀死,必须从棋盘上全部提掉(拿走),即**无气自提**。 66 | 67 | 68 | 69 | ### 2.2 禁着点 70 | 71 | 在围棋棋盘上,并不是所有空的交叉点均能够放置棋子(已经存在棋子的位置当然不能够放置棋子啦),这些位置被称为禁着点。分为如下两种情况: 72 | 73 | 1. 己方棋子放入后呈无气状态且不能吃掉任意周围的对手棋子; 74 | 2. 己方棋子放入后使得对手面临其上一手落子时相同的局面。 75 | 76 | 如图三所示,空交叉点A位置和B位置均为白方的禁着点。A位置白方棋子放入后呈无气状态,B位置白方棋子放入后,放入的棋子与▲标识的两颗白子一起所组成的一块棋呈无气状态,而且不能吃掉周围的对手的棋子。 77 | 78 | C位置不是白方的禁着点,因为虽然白方棋子放入后呈无气状态,但是在C位置放入白子可填掉周围黑子的最后一口气,即可提掉周围无气的黑子。在围棋规则中,一方落子后,判定对手棋子的状态优先于对己方棋子状态的判定,白方在C位置放入一颗白子,吃掉周围黑子,会使得自己刚放入的棋子是有气的。白方在C位置放入一颗白子,左框局面会变成右框状态。 79 | 80 | ![2_3](https://github.com/QPT-Family/QPT-CleverGo/blob/main/pictures/2_3.png) 81 | 82 | > 第2种情况在围棋中有更为专业的表示,即“劫争”。本节不对第2种情况进行展开,请参见第**2.6**部分。 83 | 84 | 85 | 86 | ### 2.3 死棋与活棋 87 | 88 | 思考下图左上角方框中所示情况,其中A位置和B位置均为白棋的禁着点,如果黑棋不主动在A位置和B位置放置棋子,白棋能够把黑棋吃掉吗? 89 | 90 | ![2_4](https://github.com/QPT-Family/QPT-CleverGo/blob/main/pictures/2_4.png) 91 | 92 | 答案是否定的。只要黑棋不在A位置和B位置放置棋子,则任何时候白棋均不能在A位置或B位置放置棋子,因此黑棋不论如何均不会被吃掉。A位置和B位置被称为黑棋的“眼”。一块棋子如果有不少于两只眼,则该块棋子则是活棋,否则为死棋。 93 | 94 | > 注意,一块棋子有不少于两只眼则为活棋,但是一块棋只有在没有气的情况下才会被提掉。 95 | 96 | 再思考图四右上角方框中所示情况,方框中的黑棋是活棋吗? 97 | 98 | 答案也是否定的,虽然黑棋目前有两口气,但是其只有一只眼。白方要杀黑棋,只需要在C位置或者D位置放置一颗棋子。如果白方在C位置放置一颗棋子,黑棋不论是否在D位置放置棋子,这块黑棋均会只有一口气。黑棋如果不在D位置放置棋子,则白方只需要进一步在D位置放置一颗棋子,即可吃掉这块黑棋。黑棋如果在D位置放置棋子,吃掉白方在C位置放置的白子,这块黑棋仍然会只有一口气(因为D位置多了一颗黑子),白方只需要在C位置放入一颗棋子,即可吃掉这块黑棋。 99 | 100 | 再思考图四中间方框中所示情况,如果此时轮到黑棋下,黑棋该如何是的自己的这块棋变成活棋呢? 101 | 102 | 答案是下在F位置。因为如果黑棋在F位置放入一颗棋子,则E位置和G位置均会变成白棋的禁着点,这块黑棋有了2只眼,因此成为活棋,永远也不会被白棋吃掉。同理如果此时轮到白棋下,白棋只需要在F位置放入一颗棋子,则黑棋无法做出2只眼,无法做活,免不了被白棋吃掉的结局。 103 | 104 | > 注意,禁着点并不与眼等同。 105 | 106 | 107 | 108 | ### 2.4 真眼与假眼 109 | 110 | 一个空交叉点周围4个直接相邻的交叉点上均有黑子(在边线上即3个直接相邻的交叉点上均有黑子,在角落上即2个直接相邻的交叉点上均有黑子),则该空交叉点为黑棋的眼。对于白棋亦然。 111 | 112 | 如图五所示,三个▲标识的位置均为黑棋的眼。眼的4个角落是该眼的眼角,即图五中**X**位置分别是对应眼的眼角。若一个眼是真眼,当该眼位于中央,则必须至少占领3个眼角。当该眼位于边线上,则必须占领全部两个眼角。当该眼位于角落,则必须保证占领唯一的一个眼角。否则眼即为假眼。即在图五中,眼A、B、C均为黑棋的真眼,眼D、E、F均为黑棋的假眼。 113 | 114 | ![2_5](https://github.com/QPT-Family/QPT-CleverGo/blob/main/pictures/2_5.png) 115 | 116 | 为什么空交叉点D、E、F均是黑棋的假眼? 117 | 118 | 对于空交叉点D来说,如果白棋依次落子GHIJD,则阴影标识的三个黑子将被吃掉,相应的眼D将不复存在。对于空交叉点E和F也是同理。即假眼是能够被对手破坏的眼,真眼是对手无法破坏的眼。一块棋只有存在大于或等于两只真眼才是活棋。 119 | 120 | 121 | 122 | ### 2.5 胜负的计算 123 | 124 | 在围棋游戏中胜负分为两种情况: 125 | 126 | 1. 中盘胜/中盘负:如果在对局过程中,白方认输,则黑中盘胜,或白中盘负; 127 | 2. 围棋游戏进行至某一局面,黑白双方协商终局,通过围地大小判定胜负。 128 | 129 | 当黑白双方协商终局,按照如下3个步骤计算胜负: 130 | 131 | 1. 清理死子: 132 | 133 | 在图六中,▲标识的黑子和白子均为死子,因为在一人一手轮流落子的情况下,这些棋子无法做成两只眼,免不了被吃掉的结局。 134 | 135 | ![2_6](https://github.com/QPT-Family/QPT-CleverGo/blob/main/pictures/2_6.png) 136 | 137 | 2. 分别计算黑棋和白棋所为围住的地盘大小(围住的空交叉点数+相应颜色的棋子数): 138 | 139 | 如图七所示,黑棋共围住了**X**所示的18个交叉点,再加上棋盘上共有29颗黑子,因此黑棋地盘为18+29=47。白棋共围住了▲所示的7个交叉点,再加上棋盘上共有27颗白子,因此白棋地盘为7+27=34。 140 | 141 | ![2_7](https://github.com/QPT-Family/QPT-CleverGo/blob/main/pictures/2_7.png) 142 | 143 | 3. 计算胜负: 144 | 145 | 在图七所示的棋盘上,总共有9*9=81个交叉点,则整个围棋棋盘中一半的地盘为40.5。因为黑棋先行,具有先招优势,因此在计算胜负时,黑棋需要贴三又3/4子,即黑棋所围住的地盘,必须超过棋盘中一半的地盘还要多三又3/4(即44.25),才能算胜利。因为黑棋地盘为47,47-44.25=2.75,因此对局结果为黑胜2又3/4子。 146 | 147 | > 协商终局:协商终局是指对局双方均认可地盘划分现状。因为在一人一手交替落子的情况下,双方均认为不可能更改目前的地盘占领情况。 148 | > 149 | > 黑贴3又3/4子:这是中国围棋规则所规定的,是中国围棋协会根据职业棋手对局情况商讨定下来的一个比较合理的值。 150 | 151 | 152 | 153 | ### 2.6 劫 154 | 155 | 在下图中,1位置黑棋只有1口气。此时白棋如落子A位置,则可以吃掉1位置的黑子。当白棋吃掉1位置黑子,规则规定黑棋不能马上将2位置白子吃掉。如果黑棋马上将2位置白子吃掉,则会导致白棋面临上一手落子时相同的局面,即此时被吃掉的1位置黑子处的空交叉点为黑棋的禁着点。这种黑白双方在同一个位置反复吃掉对方一颗棋子的情况称为“劫”,黑白双方争夺“劫”所在位置的空交叉点的情况称为“劫争”或“打劫”。 156 | 157 | ![2_8](https://github.com/QPT-Family/QPT-CleverGo/blob/main/pictures/2_8.gif) 158 | 159 | > 为什么规则要这么规定? 160 | > 161 | > 其实很好理解。如果允许白棋吃掉1位置黑子之后,黑棋能够马上重新落子1位置吃掉2位置白子,显然这会导致双方提来提去,一局棋永远下不完。 162 | 163 | #### 2.6.1 劫争的意义 164 | 165 | 如图九所示,白棋落子1位置,吃掉▲处原本存在的一颗黑子,在▲处形成了一个劫。因为A位置是黑棋的假眼,这一整块黑棋只有B位置一个真眼。因此黑棋如果想要做出两只真眼,形成活棋,则必须赢下这个劫争,即占领▲位置。 166 | 167 | 由于此时黑棋不能马上落子▲位置,吃掉1位置白子,因此黑棋可以在棋盘上其它关键位置落子,使得白方来不及落子▲位置,则再次轮到黑棋落子时,黑棋可落子▲位置,使得这块黑棋能够做出两只眼,成为活棋。 168 | 169 | ![2_9](https://github.com/QPT-Family/QPT-CleverGo/blob/main/pictures/2_9.png) 170 | 171 | > 所谓的关键位置即称为“劫材”,黑棋落子后,如果白棋不理,则黑子再下一子可给白方造成很严重的损失。 172 | > 173 | > 劫争是围棋中的一种高级技巧和艺术。在很多棋局中,一个劫争有可能决定多达数十颗棋子的生死存亡,甚至会直接决定一局棋的胜负。 174 | 175 | 176 | 177 | ## 3. 作者的话 178 | 179 | 围棋与中华文化一脉相承,息息相通,下围棋的原则与中华传统文化倡导的为人处世原则是相互统一的。历史上曾记载了关于围棋的十句口诀: 180 | 181 | 1. 不得贪胜 182 | 2. 入界宜缓 183 | 3. 攻彼顾我 184 | 4. 弃子争先 185 | 5. 舍小就大 186 | 6. 逢危须弃 187 | 7. 慎勿轻速 188 | 8. 动须相应 189 | 9. 彼强自保 190 | 10. 势孤取和 191 | 192 | 机巧围棋作者学习围棋没有上过任何学习班,完全是自学,从在野狐围棋平台18级一直打到了3段。围棋是一个很有意思的游戏,在作者学会围棋并对弈升至野狐3段的这一过程中,总结出了围棋的四境界: 193 | 194 | 1. 君子不立于危墙之下,善战者不败; 195 | 2. 重剑无锋,大巧不工,善弈者通盘无妙手; 196 | 3. 你走你的,我走我的; 197 | 4. 旁观者,局外人。 198 | 199 | 如果你也喜欢围棋,或者看了本文有兴趣学习围棋,欢迎下载[野狐围棋](https://www.foxwq.com/),然后加我为好友(我的昵称是:秋临铜雀台),和我一起对弈吧~ 200 | 201 | ![2_10](https://github.com/QPT-Family/QPT-CleverGo/blob/main/pictures/2_10.png) 202 | 203 | 204 | 205 | ## 4. 围棋学习资料 206 | 207 | 1. 学习围棋基本规则及常见技巧:[20天从零学会下围棋](https://www.bilibili.com/video/BV1uW411W7xK?p=1) 208 | 2. 感受围棋的竞技魅力:[2014 古李十番棋合集](https://www.bilibili.com/video/BV1Rs411B791?p=6) 209 | 3. 聆听围棋人工智能与顶尖棋手的对话:[【完整合集】【中英解说】【赛后分析】AlphaGo VS 柯洁 2017中国乌镇·围棋峰会](https://www.bilibili.com/video/BV1ix411Y7M9?spm_id_from=333.999.0.0) 210 | -------------------------------------------------------------------------------- /docs/机巧围棋(CleverGo)开发计划文档.md: -------------------------------------------------------------------------------- 1 | # 机巧围棋(CleverGo)开发计划文档 2 | 开发计划文档主要分阶段地罗列各版本开发功能,具体主线如下 3 | 1. 基于目前项目状态,将基本功能完善 4 | 2. 项目重构 5 | 3. 项目各功能模块优化 6 | 4. 功能追加 7 | 5. 项目维护 8 | 9 | ## 1. 项目基础开发:当前项目-->V1.0 10 | 完善CleverGo项目基本核心功能 11 | - Pygame信息展示控件封装【取消该控件实现】 12 | - ~~Tensorflow框架转PaddlePaddle框架~~ 13 | - ~~策略网络和价值网络分离(网络单独定义,与训练等操作分离)~~ 14 | - ~~训练方法集成,实现点击按钮训练阿尔法狗~~ 15 | - ~~实现策略网络、价值网络,训练好的阿尔法狗添加至可选玩家~~ 16 | 17 | ## 2. 项目重构:V1.0-->V2.0 18 | 重构整个项目,对不合理的程序设计及代码实现进行重构。 19 | - 前后端完全分离,引用项目之间完全分离 20 | 在原始GymGo开源项目上封装一层,定义为goengine.py,该文件与GymGo项目直接进行交互,并依据自己项目情况,定义相关后端方法,项目前端只与该文件中相关方法进行交互。 21 | ![项目架构图](https://img-blog.csdnimg.cn/e87f822779b04a2c9cc3f13e2b226767.png?x-oss-process=image/watermark,type_ZHJvaWRzYW5zZmFsbGJhY2s,shadow_50,text_Q1NETiBARGVlcEdlR2U=,size_20,color_FFFFFF,t_70,g_se,x_16#pic_center) 22 | - Pygame控件实现逻辑优化 23 | 按钮控件及显示屏控件继承自一个控件基类,相关控件定义时候直接注册到控件状态更新器,可以在程序主循环中通过一行代码,如`controls.update(event)`实现所有控件的状态更新。 24 | ![控件系统架构图](https://img-blog.csdnimg.cn/c2e4b12120bc4d5d86ba29d8dc08a1c4.png?x-oss-process=image/watermark,type_ZHJvaWRzYW5zZmFsbGJhY2s,shadow_50,text_Q1NETiBARGVlcEdlR2U=,size_19,color_FFFFFF,t_70,g_se,x_16#pic_center) 25 | 26 | - 程序多线程处理逻辑优化 27 | 由于不同Player计算下一步落子需要一定时间,程序必须满足在计算下一步落子期间能够正确执行玩家其余指令(如切换游戏状态,切换音乐,更改Player等等),因此必须使用多线程方法。目前多线程处理逻辑存在一定问题,具体优化方法待讨论后决定。 28 | - 程序各函数方法接口规范化 29 | 参数命名规范化; 30 | 方法说明及相关注释规范化; 31 | 配置规范化; 32 | Log规范化。 33 | - 其余重构任务 34 | …… 35 | 36 | ## 3. 项目各功能模块优化 37 | 1. 音乐播放系统优化:无需显式地定义出各待播放音乐,只需将相关音乐资源放置到musics文件夹,即可加载文件 38 | (针对音乐名称长度进行判定,长度超过按钮长度的加载,不能简单地判定字符数,而需判定Pygame.fonts对象size) 39 | 2. 音乐播放、Player选择、音乐模式选择等相关按钮功能优化:点击按钮左半边向前翻滚,右半边向后翻滚 40 | 3. 其它重要紧急优化 41 | 42 | ## 4. 功能追加 43 | **第一阶段** 44 | - 给阿尔法狗命名(不能与系统内置确定名称,如人类玩家,策略网络等冲突) 45 | - 领养多条阿尔法狗 46 | 47 | **第二阶段** 48 | - 阿尔法狗可自定义,点击按钮进入阿尔法狗属性配置界面,用户可一定程度上自定义网络结构(考虑交互方式,如按钮等等),可自己设置相关训练超参数等等 49 | 50 | **第三阶段** 51 | 考虑向13路棋拓展 52 | 53 | *** 54 | **远景计划一** 55 | 多人PK平台,用户可与他人联机PK(【优先考虑局域网】在这个平台上,各玩家可以与其他玩家发布的AI进行对弈,也可以玩家之间对弈,或者让自己的AI与其它玩家的AI进行对弈) 56 | 57 | **远景计划二** 58 | 网络个人AI对弈平台:由局域网拓展到互联网。 59 | 60 | **远景计划三** 61 | 考虑向19路棋拓展 62 | 63 | ## 5. 项目维护 64 | 修改BUG,及各种Tricks更新。 65 | 66 | 67 | 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /docs/机巧围棋(CleverGo)技术原理文档.md: -------------------------------------------------------------------------------- 1 | # 机巧围棋(CleverGo)技术原理文档 2 | 技术原理文档主要详解以下8个方面核心知识、技术、方法原理: 3 | 1. 围棋基本知识 4 | 2. 围棋模拟器(GymGo)核心方法原理 5 | 3. 游戏引擎(Pygame)核心方法 6 | 4. 深度学习框架(PaddlePaddle)核心方法 7 | 5. 深度强化学习基本原理方法 8 | 6. AlphaGo基本原理 9 | 7. CleverGo项目程序设计方法原理 10 | 8. CleverGo未来规划 11 | 12 | 原理撰写遵循上述主线,并针对每个方面进行展开,对每个方面用一篇或几篇文章,详细介绍知识原理及程序实现逻辑,具体规划如下。 13 | 14 | ## 1. 机巧围棋(CleverGo)项目总览及介绍 15 | 项目总体介绍、我们说几句话、感谢一些人及相关项目 16 | - 项目简介 17 | - 作者介绍 18 | - 效果展示 19 | - 技术原理文档目录 20 | - 作者的话 21 | - 致谢 22 | - 参考资料链接 23 | 24 | ## 2. 围棋基本知识 25 | 介绍围棋基本规则、围棋学习基本视频资料、介绍围棋对弈平台及作者围棋ID,欢迎大家切磋交流 26 | - 围棋基本说明及基本介绍 27 | - 棋子的气 28 | - 禁入点 29 | - 棋子的眼(真眼与假眼) 30 | - 死棋与活棋 31 | - 棋子的分块 32 | - 双打吃 33 | - 征吃 34 | - 枷吃 35 | - 倒扑 36 | - 扑与接不归 37 | - 胜负的计算 38 | - 目与单官 39 | - 劫 40 | - 说一说关于围棋的理解(善战者不败;善弈者通盘无妙手;你下你的,我下我的;旁观者,局外人) 41 | - 围棋学习资料链接 42 | - 围棋对弈平台介绍 43 | 44 | ## 3. 围棋模拟器(GymGo)核心方法原理 45 | 介绍围棋程序核心方法原理及实现逻辑 46 | - 棋盘状态定义 47 | - 计算无效落子位置方法原理及程序实现逻辑 48 | - 计算棋子击杀状态方法原理及程序实现逻辑 49 | - 棋盘状态更新方法原理及程序实现逻辑 50 | - 落子合法性检测方法原理及程序实现逻辑 51 | - 棋局结束判定方法即程序实现逻辑 52 | - 棋子的气计算方法原理及程序实现逻辑 53 | - 黑白占领实地计算方法原理及程序实现逻辑 54 | - 胜负统计方法即程序实现逻辑 55 | 56 | ## 4. 游戏引擎(Pygame)核心方法 57 | 介绍Pygame中项目相关的方法 58 | - Pygame游戏开发基本框架 59 | - 棋盘、棋子、落子标记、落子进度、落子提示区等绘制方法 60 | - Pygame中的音乐播放器 61 | - 棋盘文字绘制方法 62 | - 按钮控件程序实现方法 63 | - 信息显示屏控件程序实现方法 64 | 65 | ## 5. 深度学习框架(PaddlePaddle)核心方法 66 | 简单介绍一下项目用到的PP框架中网络定义,训练等方法 67 | - 深度神经网络定义方法 68 | - 深度神经网络训练方法 69 | - 深度神经网络预测方法 70 | 71 | ## 6. 深度强化学习基本知识介绍 72 | 介绍与CleverGo相关的强化学习基本知识 73 | - 状态 74 | - 状态空间 75 | - 动作 76 | - 动作空间 77 | - 智能体 78 | - 策略函数 79 | - 状态转移 80 | - 价值函数 81 | - …… 82 | 83 | ## 7. AlphaGo基本原理 84 | 介绍阿尔法狗的基本知识、技术、方法原理\ 85 | 推演:蒙特卡洛树搜索\ 86 | 棋感:策略网络\ 87 | 局面评估:价值网络 88 | - 阿尔法狗中的动作、状态、策略网络和价值网络 89 | - 蒙特卡洛树搜索\ 90 | 蒙特卡洛方法与蒙特卡洛数搜索 91 | - 训练策略网络和价值网络 92 | 93 | ## 8. CleverGo项目程序设计方法原理 94 | 拆解CleverGo项目设计思想及代码结构 95 | - 程序设计框架 96 | - 项目各文件夹及文件内容介绍 97 | - 各模块设计思想及实现方法 98 | 99 | ## 9. CleverGo未来规划 100 | 说一下未来的计划,和对整个项目的期待 101 | - 未来维护开发计划 102 | 103 | ## 10. 写在最后 104 | 项目总结,心得,历程回顾,期望。总的来说,我们再过来说一些话。 105 | - 项目总结 106 | - 历程回顾 107 | - 作者的话 108 | -------------------------------------------------------------------------------- /docs/机巧围棋(CleverGo)项目总览及介绍.md: -------------------------------------------------------------------------------- 1 | # 机巧围棋(CleverGo)项目总览及介绍 2 | 3 | ## 1. 项目简介 4 | 5 | 2016年3月,阿尔法狗以4:1战胜围棋世界冠军李世石。自此开始,深度强化学习受到空前的关注并成为AI领域的研究热点,彻底引爆了以深度学习为核心技术的第三次人工智能热潮。 6 | 7 | 机巧围棋利用Python+Pygame+PaddlePaddle基于AlphaGo Zero算法打造了一款点击按钮就能可视化的训练9路围棋人工智能的程序,并搭建了一整套简单易用的围棋AI学习、开发、训练及效果可视化验证框架。 8 | 9 | 机巧围棋项目源码及技术原理文档全部免费开源,真诚期望您能够在GitHub上点个**Star**支持机巧围棋鸭~ 10 | 11 | **项目GitHub仓库地址**:[https://github.com/QPT-Family/QPT-CleverGo](https://github.com/QPT-Family/QPT-CleverGo) 12 | 13 | **QQ群**:935098082 14 | 15 | 16 | 17 | ## 2. 开发者简介 18 | 19 | 机巧围棋项目归属于GitHub组织QPT软件包家族(QPT Family),由DeepGeGe和GT-Zhang共同开发并维护。 20 | 21 | - QPT Family主页:[https://github.com/QPT-Family](https://github.com/QPT-Family) 22 | - QPT Family官方交流群:935098082 23 | 24 | DeepGeGe:QPT Family成员 25 | 26 | - CSDN博客:[https://blog.csdn.net/qq_24178985](https://blog.csdn.net/qq_24178985) 27 | - GitHub主页:[https://github.com/DeepGeGe](https://github.com/DeepGeGe) 28 | 29 | GT-Zhang:QPT Family创始人 30 | 31 | - GitHub主页:[https://github.com/GT-ZhangAcer](https://github.com/GT-ZhangAcer) 32 | 33 | 34 | 35 | ## 3. 效果展示 36 | 37 | 机巧围棋程序界面: 38 | 39 | ![1_1](https://github.com/QPT-Family/QPT-CleverGo/blob/main/pictures/1_1.png) 40 | 41 | 效果展示视频链接:[https://www.bilibili.com/video/BV1N3411C742?spm_id_from=333.999.0.0](https://www.bilibili.com/video/BV1N3411C742?spm_id_from=333.999.0.0) 42 | 43 | 44 | 45 | ## 4. 技术原理文档目录 46 | 47 | 机巧围棋技术原理文档主要讲解项目相关知识、算法原理及程序逻辑,具体目录如下: 48 | 49 | 1. 机巧围棋(CleverGo)项目总览及介绍 50 | 2. 围棋基本知识 51 | 3. 围棋程序逻辑 52 | 4. 游戏开发引擎(Pygame)核心方法 53 | 5. 深度学习框架(PaddlePaddle)使用教程 54 | 6. 深度强化学习基本知识 55 | 7. 阿尔法狗(AlphaGo)算法原理 56 | 57 | 8. 机巧围棋(CleverGo)程序设计 58 | 9. 机巧围棋(CleverGo)远景规划 59 | 10. 机巧围棋(CleverGo)项目总结 60 | 61 | 62 | 63 | ## 5. 开发者说 64 | 65 | **DeepGeGe**:2016年,阿尔法狗横空出世,使我深深地感受到了围棋和人工智能的魅力。自此开始,我自学围棋和人工智能,成为了AI算法工程师。围棋人工智能前有阿尔法狗,后有大名鼎鼎的绝艺,但是他们都离我们非常遥远,就像天上的星星。当了解阿尔法狗算法原理之后,我就在想是不是能够训练出一个属于我自己的阿尔法狗,或者说能不能做出一款不需要任何人工智能领域专业知识,只需点击一个按钮就能训练一个阿尔法狗? 66 | 67 | 机巧围棋从2021年3月6日开始开发,直至4月底,各大功能模块基本完成。5-8月项目搁置。9月初,我找到GT-Zhang大佬商定一起合作开发并维护机巧围棋项目。从9月底到10月中旬,重构了整个项目,并完成了和优化全部核心功能。 68 | 69 | 机巧围棋不需要任何专业背景知识,只需要点击按钮就能够体验训练属于自己的围棋人工智能阿尔法狗,核心理念是:Easy AI for Everyone! 70 | 71 | 期望大家能够去GitHub上给机巧围棋点个**Star**鸭~ 72 | 73 | 74 | 75 | ## 6. 致谢 76 | 77 | 1. 机巧围棋界面设计参考了HapHac作者的weiqi项目,采用了该项目的部分素材,但是围棋程序逻辑及游戏引擎与该项目不同。 78 | 79 | 参考项目地址:[https://github.com/HapHac/weiqi](https://github.com/HapHac/weiqi) 80 | 81 | 2. 机巧围棋中围棋程序内核采用了aigagror作者的GymGo项目,机巧围棋模拟器环境在该项目的基础上进行了自定义封装。此外,机巧围棋项目开发者DeepGeGe(鸽鸽)也是GymGo项目的Contributor~ 82 | 83 | 参考项目地址:[https://github.com/aigagror/GymGo](https://github.com/aigagror/GymGo) 84 | 85 | 3. 机巧围棋技术原理文档中,深度强化学习基本知识及阿尔法狗算法原理部分参考了wangshusen作者的DRL项目。综合该项目中相关深度强化学习知识,讲解了零狗算法在机巧围棋中的应用。 86 | 87 | 参考项目地址:[https://github.com/wangshusen/DRL](https://github.com/wangshusen/DRL) 88 | 89 | 4. 机巧围棋中,有关蒙特卡洛树搜索等部分参考了junxiaosong作者的AlphaZero_Gomoku项目。其中蒙特卡洛树搜索的实现方式基本与该项目中实现保持一致,并结合机巧围棋中相关功能需求进行了部分更改。 90 | 91 | 参考项目地址:[https://github.com/junxiaosong/AlphaZero_Gomoku](https://github.com/junxiaosong/AlphaZero_Gomoku) 92 | -------------------------------------------------------------------------------- /docs/深度学习框架(PaddlePaddle)使用教程.md: -------------------------------------------------------------------------------- 1 | # 深度学习框架(PaddlePaddle)使用教程 2 | 3 | 机巧围棋使用[飞桨(PaddlePaddle)](https://www.paddlepaddle.org.cn/) 框架构建并训练阿尔法狗落子策略中的深度神经网络——策略网络和价值网络。 4 | 5 | 本文讲解飞桨框架的基本使用方法:第1部分介绍神经网络的构建;第2部分讲解神经网络训练过程;第3部分介绍模型权重的保存和加载。 6 | 7 | 8 | 9 | ## 1. 神经网络构建 10 | 11 | 使用飞桨框架构建神经网络的流程如下: 12 | 13 | 1. 导入`paddle`库; 14 | 2. 定义继承自`paddle.nn.Layer`的类,在`__init__()`方法中初始化神经网络的子层(或参数); 15 | 3. 重写`forward()`方法,在该方法中实现神经网络计算流程。 16 | 17 | > 在`__init__()`方法中初始化神经网络子层的本质是初始化神经网络的参数,不同的子层实质上是不同的部分参数初始化及前向计算流程的封装。下面两种网络构建方式是等同的: 18 | > 19 | > ```python 20 | > import paddle 21 | > 22 | > 23 | > # 构建方法一:使用飞桨框架内置封装好的子层 24 | > class LinearNet1(paddle.nn.Layer): 25 | > def __init__(self): 26 | > super(LinearNet1, self).__init__() 27 | > 28 | > # 使用飞桨框架封装好的Linear层 29 | > self.linear = paddle.nn.Linear(in_features=3, out_features=2) 30 | > 31 | > def forward(self, x): 32 | > return self.linear(x) 33 | > 34 | > 35 | > # 构建方法二:自定义神经网络参数 36 | > class LinearNet2(paddle.nn.Layer): 37 | > def __init__(self): 38 | > super(LinearNet2, self).__init__() 39 | > 40 | > # 自定义神经网络参数 41 | > w = self.create_parameter(shape=[3, 2]) 42 | > b = self.create_parameter(shape=[2], is_bias=True) 43 | > self.add_parameter('w', w) 44 | > self.add_parameter('b', b) 45 | > 46 | > def forward(self, x): 47 | > x = paddle.matmul(x, self.w) 48 | > x = x + self.b 49 | > return x 50 | > 51 | > 52 | > if __name__ == "__main__": 53 | > model1 = LinearNet1() 54 | > model2 = LinearNet2() 55 | > 56 | > print('LinearNet1模型结构信息:') 57 | > paddle.summary(model1, input_size=(None, 3)) 58 | > print('LinearNet2模型结构信息:') 59 | > paddle.summary(model2, input_size=(None, 3)) 60 | > ``` 61 | > 62 | > 输出两种模型结构信息如下: 63 | > 64 | > ``` 65 | > LinearNet1模型结构信息: 66 | > --------------------------------------------------------------------------- 67 | > Layer (type) Input Shape Output Shape Param # 68 | > =========================================================================== 69 | > Linear-1 [[1, 3]] [1, 2] 8 70 | > =========================================================================== 71 | > Total params: 8 72 | > Trainable params: 8 73 | > Non-trainable params: 0 74 | > --------------------------------------------------------------------------- 75 | > Input size (MB): 0.00 76 | > Forward/backward pass size (MB): 0.00 77 | > Params size (MB): 0.00 78 | > Estimated Total Size (MB): 0.00 79 | > --------------------------------------------------------------------------- 80 | > 81 | > LinearNet2模型结构信息: 82 | > --------------------------------------------------------------------------- 83 | > Layer (type) Input Shape Output Shape Param # 84 | > =========================================================================== 85 | > LinearNet2-1 [[1, 3]] [1, 2] 8 86 | > =========================================================================== 87 | > Total params: 8 88 | > Trainable params: 8 89 | > Non-trainable params: 0 90 | > --------------------------------------------------------------------------- 91 | > Input size (MB): 0.00 92 | > Forward/backward pass size (MB): 0.00 93 | > Params size (MB): 0.00 94 | > Estimated Total Size (MB): 0.00 95 | > --------------------------------------------------------------------------- 96 | > ``` 97 | 98 | 99 | 100 | ## 2. 神经网络训练 101 | 102 | 使用飞桨框架训练神经网络的流程如下: 103 | 104 | 1. 实例化模型对象`model`; 105 | 2. 使用`model.eval()`,将模型更改为`eval`模式; 106 | 3. 定义优化器`opt`,指定优化参数; 107 | 4. 在循环中输入数据,执行模型前向计算流程,得到前向输出结果; 108 | 5. 计算前向输出结果和数据的标签的损失loss; 109 | 6. 使用`loss.backward()`进行后向传播,计算`loss`关于模型参数的梯度; 110 | 7. 使用`opt.step()`更新一次模型参数; 111 | 8. 使用`opt.clear_grad()`清除模型参数梯度; 112 | 9. 回到4,继续优化模型参数。 113 | 114 | 示例代码如下: 115 | 116 | ```python 117 | def train(epochs: int = 5): 118 | """ 119 | 训练过程示例 120 | 121 | :param epochs: 对数据集的遍历次数 122 | :return: 123 | """ 124 | # 实例化模型对象 125 | model = LinearNet1() 126 | # 更改为eval模式 127 | model.eval() 128 | 129 | # 定义优化器 130 | opt = paddle.optimizer.SGD(learning_rate=1e-2, parameters=model.parameters()) 131 | 132 | for epoch in range(epochs): 133 | # 生成随机的输入和标签 134 | fake_inputs = paddle.randn(shape=(10, 3), dtype='float32') 135 | fake_labels = paddle.randn(shape=(10, 2), dtype='float32') 136 | 137 | # 前向计算 138 | output = model(fake_inputs) 139 | # 计算损失 140 | loss = paddle.nn.functional.mse_loss(output, fake_labels) 141 | 142 | print(f'Epoch:{epoch}, Loss:{loss.numpy()}') 143 | 144 | # 后向传播 145 | loss.backward() 146 | # 参数更新 147 | opt.step() 148 | # 清除梯度 149 | opt.clear_grad() 150 | ``` 151 | 152 | 打印输出如下: 153 | 154 | ``` 155 | Epoch:0, Loss:[1.5520184] 156 | Epoch:1, Loss:[1.6992496] 157 | Epoch:2, Loss:[1.9622276] 158 | Epoch:3, Loss:[2.1343968] 159 | Epoch:4, Loss:[1.221286] 160 | ``` 161 | 162 | 163 | 164 | ## 3. 模型权重的保存和加载 165 | 166 | 飞桨框架提供了非常简单易用的API实现在模型训练和应用时保存或加载模型参数。具体如下: 167 | 168 | 保存模型参数: 169 | 170 | - `paddle.save(model.state_dict(), 'save_path/model.pdparams')` 171 | 172 | 加载模型参数: 173 | 174 | - `state_dict = paddle.load('save_path/model.pdparams')` 175 | - `model.set_state_dict(state_dict)` 176 | 177 | 178 | 179 | ## 4. 结束语 180 | 181 | 飞桨(PaddlePaddle)是中国首个自主研发、功能完备、 开源开放的产业级深度学习平台。其不仅提供了简单易用的深度学习模型构建及训练API,同时还集深度学习核心训练和推理框架、基础模型库、端到端开发套件和丰富的工具组件于一体。 182 | 183 | 机巧围棋使用飞桨框架构建并训练阿尔法狗落子策略中的价值网络和策略网络。本文详细讲解了飞桨框架最核心最基础的使用方法,相信读者在阅读完本文后,可以清晰地了解到机巧围棋中阿尔法狗的神经网络模型训练机制。 184 | 185 | 最后,期待您能够给本文点个赞,同时去GitHub上给机巧围棋项目点个Star呀~ 186 | 187 | 机巧围棋项目链接:[https://github.com/QPT-Family/QPT-CleverGo](https://github.com/QPT-Family/QPT-CleverGo) -------------------------------------------------------------------------------- /docs/深度强化学习基础.md: -------------------------------------------------------------------------------- 1 | # 深度强化学习基础 2 | 3 | 深度强化学习(Deep Reinforcement Learning)是值得深入学习研究且非常有意思的领域,但是其数学原理复杂,远胜于深度学习,且脉络复杂,概念繁杂。强化学习是一个序贯决策过程,它通过智能体(Agent)与环境进行交互收集信息,并试图找到一系列决策规则(即策略)使得系统获得最大的累积奖励,即获得最大价值。环境(Environment)是与智能体交互的对象,可以抽象地理解为交互过程中的规则或机理,在围棋游戏中,游戏规则就是环境。强化学习的数学基础和建模工具是马尔可夫决策过程(Markov Decision Process,MDP)。一个MDP通常由状态空间、动作空间、状态转移函数、奖励函数等组成。 4 | 5 | 本文介绍与机巧围棋相关的深度强化学习基础知识,辅助理解描述阿尔法狗算法原理的强化学习语言。 6 | 7 | 8 | 9 | ## 1. 基本概念 10 | 11 | - **状态(state)**:是对当前时刻环境的概括,可以将状态理解成做决策的唯一依据。在围棋游戏中,棋盘上所有棋子的分布情况就是状态。 12 | 13 | - **状态空间(state space)**:是指所有可能存在状态的集合,一般记作花体字母![1](https://latex.codecogs.com/png.latex?\mathcal{S}) 。状态空间可以是离散的,也可以是连续的。可以是有限集合,也可以是无限可数集合。在围棋游戏中,状态空间是离散有限集合,可以枚举出所有可能存在的状态(也就是棋盘上可能出现的格局)。 14 | 15 | - **动作(action)**:是智能体基于当前的状态做出的决策。在围棋游戏中,棋盘上有361个位置,而且可以选择PASS(放弃一次落子权利),于是有362种动作。动作的选取可以是确定的,也可以依照某个概率分布随机选取一个动作。 16 | 17 | - **动作空间(action space)**:是指所有可能动作的集合,一般记作花体字母![2](https://latex.codecogs.com/png.latex?\mathcal{A}) 。在围棋例子中,动作空间是![3](https://latex.codecogs.com/png.latex?\mathcal{A}=\{0,1,2,\cdots,361\}) ,其中第![4](https://latex.codecogs.com/png.latex?i) 种动作是指把棋子放到第![5](https://latex.codecogs.com/png.latex?i) 个位置上(从0开始),第361种动作是指PASS。 18 | 19 | - **奖励(reward)**:是指智能体执行一个动作之后,环境返回给智能体的一个数值。奖励往往由我们自己来定义,奖励定义得好坏非常影响强化学习的结果。一般来说,奖励是状态和动作的函数。 20 | 21 | - **状态转移(state transition)**:是指从当前$t$时刻的状态![6](https://latex.codecogs.com/png.latex?s) 转移到下一个时刻状态![7](https://latex.codecogs.com/png.latex?s^\prime) 的过程。在围棋的例子中,基于当前状态(棋盘上的格局),黑方或白方落下一子,那么环境(即游戏规则)就会生成新的状态(棋盘上新的格局)。 22 | 23 | 状态转移可以是确定的,也可以是随机的。在强化学习中,一般假设状态转移是随机的,随机性来自于环境。比如贪吃蛇游戏中,贪吃蛇吃掉苹果,新苹果出现的位置是随机的。 24 | 25 | - **策略(policy)**:的意思是根据观测到的状态,如何做出决策,即从动作空间中选取一个动作的方法。策略可以是确定性的,也可以是随机性的。强化学习中无模型方法(model -free)可以大致分为策略学习和价值学习,策略学习的目标就是得到一个**策略函数**,在每个时刻根据观测到的状态,用策略函数做出决策。 26 | 27 | 将状态记作![8](https://latex.codecogs.com/png.latex?S) 或![9](https://latex.codecogs.com/png.latex?s) ,动作记作![10](https://latex.codecogs.com/png.latex?A) 或![11](https://latex.codecogs.com/png.latex?a) ,随机策略函数![12](https://latex.codecogs.com/png.latex?\pi:\mathcal{S}\times\mathcal{A}\mapsto[0,1]) 是一个概率密度函数,记作![13](https://latex.codecogs.com/png.latex?\pi(a|s)=\mathbb{P}(A=a|S=s)) 。策略函数的输入是状态![14](https://latex.codecogs.com/png.latex?s) 和动作![15](https://latex.codecogs.com/png.latex?a) ,输出是一个0到1之间的概率值。将当前状态和动作空间中所有动作输入策略函数,得到每个动作的概率值,根据动作的概率值抽样,即可选取一个动作。 28 | 29 | 确定策略是随机策略![16](https://latex.codecogs.com/png.latex?\mu:\mathcal{S}\mapsto\mathcal{A}) 的一个特例,它根据输入状态![17](https://latex.codecogs.com/png.latex?s) ,直接输出动作![18](https://latex.codecogs.com/png.latex?a=\mu(s)) ,而不是输出概率值。对于给定的状态![19](https://latex.codecogs.com/png.latex?s) ,做出的决策![20](https://latex.codecogs.com/png.latex?a) 是确定的,没有随机性。 30 | 31 | - **状态转移函数(state transition function)**:是指环境用于生成新的状态![21](https://latex.codecogs.com/png.latex?s^\prime) 时用到的函数。由于状态转移一般是随机的,因此在强化学习中用**状态转移概率函数(state transition probability function)** 来描述状态转移。状态转移概率函数是一个条件概率密度函数,记作![22](https://latex.codecogs.com/png.latex?p(s^\prime|s,a)=\mathbb{P}(S^\prime=s^\prime|S=s,A=a)) ,表示观测到当前状态为![23](https://latex.codecogs.com/png.latex?s) ,智能体执的行动作为![24](https://latex.codecogs.com/png.latex?a) ,环境状态变成![25](https://latex.codecogs.com/png.latex?s^\prime) 的概率。 32 | 33 | 确定状态转移是随机状态转移的一个特例,即概率全部集中在一个状态![26](https://latex.codecogs.com/png.latex?s^\prime) 上。 34 | 35 | - **智能体与环境交互(agent environment interaction)**:是指智能体观测到环境的状态![27](https://latex.codecogs.com/png.latex?s) ,做出动作![28](https://latex.codecogs.com/png.latex?a) ,动作会改变环境的状态,环境反馈给智能体奖励![29](https://latex.codecogs.com/png.latex?r) 以及新的状态![30](https://latex.codecogs.com/png.latex?s^\prime) 。 36 | 37 | - **回合(episodes)**:“回合”的概念来自游戏,是指智能体从游戏开始到通关或者游戏结束的过程。 38 | 39 | - **轨迹(trajectory)**:是指一回合游戏中,智能体观测到的所有的状态、动作、奖励:![31](https://latex.codecogs.com/png.latex?s_1,a_1,r_1,s_2,a_2,r_2,s_3,a_3,r_3,\cdots) 。 40 | 41 | - **马尔可夫性质(Markov property)**:是指下一时刻状态![32](https://latex.codecogs.com/png.latex?S_{t+1}) 近依赖于当前状态![33](https://latex.codecogs.com/png.latex?S_t) 和动作![34](https://latex.codecogs.com/png.latex?A_t) ,而不依赖于过去的状态和动作。如果状态转移具有马尔可夫性质,则![35](https://latex.codecogs.com/png.latex?\mathbb{P}(S_{t+1}|S_t,A_t)=\mathbb{P}(S_{t+1}|S_1,A_1,S_2,A_2,\cdots,S_t,A_t)) 。 42 | 43 | 44 | 45 | ## 2. 回报与折扣回报 46 | 47 | ### 2.1 回报(return) 48 | 49 | 回报是从当前时刻开始到本回合结束的所有奖励的总和,所以回报也叫做**累计奖励(cumulative future reward)**。由于奖励是状态和动作的函数,因此回报具有随机性。将![36](https://latex.codecogs.com/png.latex?t) 时刻的回报记作随机变量![37](https://latex.codecogs.com/png.latex?U_t) ,假设本回合在时刻![38](https://latex.codecogs.com/png.latex?n) 结束,则![39](https://latex.codecogs.com/png.latex?U_t=R_t+R_{t+1}+R_{t+2}+R_{t+3}+\cdots+R_n) 。 50 | 51 | 回报是未来获得的奖励总和,强化学习的一种方法就是寻找一个策略,使得回报的期望最大化。这个策略称为**最优策略(optimum policy)**。这种以最大化回报的期望为目标,去寻找最优策略的强化学习方法就是策略学习。 52 | 53 | 54 | 55 | ### 2.2 折扣回报(discount return) 56 | 57 | 假如我给你两个选项:第一,现在我立刻给你100元钱;第二,等一年后我给你100元钱。你选哪一个?相信理性人都会选择现在拿到100元钱,因为未来具有不确定性,未来的收益是会具有折扣的。即在强化学习中,奖励![40](https://latex.codecogs.com/png.latex?r_t) 和![41](https://latex.codecogs.com/png.latex?r_{t+1}) 的重要性并不等同。 58 | 59 | 在MDP中,通常会给未来的奖励做折扣,基于折扣的奖励的回报即为折扣回报,折扣回报的定义为![42](https://latex.codecogs.com/png.latex?U_t=R_t+\gamma{R_{t+1}}+\gamma^2R_{t+2}+\gamma^3R_{t+3}+\cdots) 。其中![43](https://latex.codecogs.com/png.latex?\gamma\in[0,1]) 为折扣率,对待越久远的未来,给奖励打的折扣越大。 60 | 61 | > 由于回报是折扣率等于1的特殊折扣回报,下文中将“回报”和“折扣回报”统称为“回报”,不再对二者进行区分。 62 | 63 | 64 | 65 | ## 3. 价值函数(value function) 66 | 67 | ### 3.1 动作价值函数(action-value function) 68 | 69 | 回报![44](https://latex.codecogs.com/png.latex?U_t) 是![45](https://latex.codecogs.com/png.latex?t) 时刻及未来所有时刻奖励的加权和。在![46](https://latex.codecogs.com/png.latex?t) 时刻,如果知道![47](https://latex.codecogs.com/png.latex?U_t) 的值,我们就可以知道局势的好坏。![48](https://latex.codecogs.com/png.latex?U_t) 是一个随机变量,假设在![49](https://latex.codecogs.com/png.latex?t) 时刻我们已经观测到状态为![50](https://latex.codecogs.com/png.latex?s_t) ,基于状态![51](https://latex.codecogs.com/png.latex?s_t) ,已经做完决策并选择了动作![52](https://latex.codecogs.com/png.latex?a_t) ,则随机变量![53](https://latex.codecogs.com/png.latex?U_t) 的随机性来自于![54](https://latex.codecogs.com/png.latex?t+1) 时刻起的所有的状态和动作:![55](https://latex.codecogs.com/png.latex?S_{t+1},A_{t+1},S_{t+2},A_{t+2},\cdots,S_n,A_n) 。 70 | 71 | 在![56](https://latex.codecogs.com/png.latex?t) 时刻,我们并不知道![57](https://latex.codecogs.com/png.latex?U_t) 的值,但是我们又想估计![58](https://latex.codecogs.com/png.latex?U_t) 的值,解决方案就是对![59](https://latex.codecogs.com/png.latex?U_t) 求期望,消除掉其中的随机性。 72 | 73 | 对![60](https://latex.codecogs.com/png.latex?U_t) 关于变量![61](https://latex.codecogs.com/png.latex?S_{t+1},A_{t+1},S_{t+2},A_{t+2},\cdots,S_n,A_n) 求条件期望,得到: 74 | 75 | ![62](https://latex.codecogs.com/png.latex?Q_\pi(s_t,a_t)=\mathbb{E}_{S_{t+1},A_{t+1},\cdots,S_n,A_n}[U_t|S_t=s_t,A_t=a_t]~~~~~~~~~~~~~~~~~~~~~~~~~~~~~(1)) 76 | 77 | 期望中的![63](https://latex.codecogs.com/png.latex?S_t=s_t) 和![64](https://latex.codecogs.com/png.latex?A_t=a_t) 是条件,意思是已经观测到![65](https://latex.codecogs.com/png.latex?S_t) 和![66](https://latex.codecogs.com/png.latex?A_t) 的值。条件期望的结果![67](https://latex.codecogs.com/png.latex?Q_\pi(s_t,a_t)) 被称为**动作价值函数**。 78 | 79 | 期望消除了随机变量![68](https://latex.codecogs.com/png.latex?S_{t+1},A_{t+1},S_{t+2},A_{t+2},\cdots,S_n,A_n) ,因此动作价值函数![69](https://latex.codecogs.com/png.latex?Q_\pi(s_t,a_t)) 依赖于![70](https://latex.codecogs.com/png.latex?s_t) 和![71](https://latex.codecogs.com/png.latex?a_t) ,而不依赖于![72](https://latex.codecogs.com/png.latex?t+1) 时刻及其之后的状态和动作。由于动作![73](https://latex.codecogs.com/png.latex?A_{t+1},A_{t+2},\cdots,A_n) 的概率质量函数都是![74](https://latex.codecogs.com/png.latex?\pi) ,因此使用不同的![75](https://latex.codecogs.com/png.latex?\pi) ,求期望得到的结果就会有所不同,因此![76](https://latex.codecogs.com/png.latex?Q_\pi(s_t,a_t)) 还依赖于策略函数![77](https://latex.codecogs.com/png.latex?\pi) 。 80 | 81 | 综上所述,![78](https://latex.codecogs.com/png.latex?t) 时刻的动作价值函数![79](https://latex.codecogs.com/png.latex?Q_\pi(s_t,a_t)) 依赖于以下三个因素: 82 | 83 | 1. 当前状态![80](https://latex.codecogs.com/png.latex?s_t) 。当前状态越好,则![81](https://latex.codecogs.com/png.latex?Q_\pi(s_t,a_t)) 的值越大,也就是说回报的期望越大; 84 | 2. 当前动作![82](https://latex.codecogs.com/png.latex?a_t) 。智能体执行的动作越好,则![83](https://latex.codecogs.com/png.latex?Q_\pi(s_t,a_t)) 的值越大; 85 | 3. 策略函数![84](https://latex.codecogs.com/png.latex?\pi) 。策略决定未来的动作![85](https://latex.codecogs.com/png.latex?A_{t+1},A_{t+2},\cdots,A_n) 的好坏,策略越好,则![86](https://latex.codecogs.com/png.latex?Q_\pi(s_t,a_t)) 的值越大。比如同样一局棋,柯洁(好的策略)来下,肯定会比我(差的策略)来下,得到的回报的期望会更大。 86 | 87 | > 更准确地说,![87](https://latex.codecogs.com/png.latex?Q_\pi(s_t,a_t)) 应该叫做“动作状态价值函数”,但是一般习惯性地称之为“动作价值函数”。 88 | 89 | 90 | 91 | ### 3.2 状态价值函数(state-value function) 92 | 93 | 当阿尔法狗下棋时,它想知道当前状态![88](https://latex.codecogs.com/png.latex?s_t) (即棋盘上的格局)是否对自己有利,以及自己和对手的胜算各有多大。这种用来量化双方胜算的函数就是**状态价值函数**。 94 | 95 | 将动作价值函数![89](https://latex.codecogs.com/png.latex?Q_\pi(s_t,a_t)) 中动作作为随机变量![90](https://latex.codecogs.com/png.latex?A_t) ,然后关于![91](https://latex.codecogs.com/png.latex?A_t) 求期望,把![92](https://latex.codecogs.com/png.latex?A_t) 消掉,即得到状态价值函数: 96 | 97 | ![93](https://latex.codecogs.com/png.latex?V_\pi(s_t)=\mathbb{E}_{A_t\sim\pi(\cdot|s_t)}[Q_\pi(s_t|,A_t)]=\displaystyle\sum_{a_\in\mathcal{A}}\pi(a|s_t)\cdot{Q_\pi(s_t,a)}~~~~~~~~~~~~~~~~~~~~~~~~~~~~(2)) 98 | 99 | 状态价值函数![94](https://latex.codecogs.com/png.latex?V_\pi(s_t)) 只依赖于策略![95](https://latex.codecogs.com/png.latex?\pi) 和当前状态![96](https://latex.codecogs.com/png.latex?s_t) ,不依赖于动作。状态价值函数![97](https://latex.codecogs.com/png.latex?V_\pi(s_t)) 也是回报![98](https://latex.codecogs.com/png.latex?U_t) 的期望,![99](https://latex.codecogs.com/png.latex?V_\pi(s_t)=\mathbb{E}_{A_t,S_{t+1},A_{t+1},\cdots,S_n,A_n}[U_t|S_t=s_t]) 。期望消掉了回报![100](https://latex.codecogs.com/png.latex?U_t) 依赖的随机变量![101](https://latex.codecogs.com/png.latex?A_t,S_{t+1},A_{t+1},\cdots,S_n,A_n) ,状态价值越大,则意味着回报的期望越大。用状态价值函数可以衡量策略![102](https://latex.codecogs.com/png.latex?\pi) 与当前状态![103](https://latex.codecogs.com/png.latex?s_t) 的好坏。 100 | 101 | 102 | 103 | ## 4. 策略网络与价值网络 104 | 105 | ### 4.1 策略网络(policy network) 106 | 107 | 在围棋游戏中,动作空间![104](https://latex.codecogs.com/png.latex?\mathcal{A}=\{0,1,2,\cdots,360,361\}) 。策略函数![105](https://latex.codecogs.com/png.latex?\pi) 是个条件概率质量函数: 108 | 109 | ![106](https://latex.codecogs.com/png.latex?\pi(a|s)\overset{\triangle}{=}\mathbb{P}(A=a|S=s)~~~~~~~~~~~~~~~~~~~~~~~~(3)) 110 | 111 | 策略函数![107](https://latex.codecogs.com/png.latex?\pi) 的输入是状态![108](https://latex.codecogs.com/png.latex?s) 和动作![109](https://latex.codecogs.com/png.latex?a) ,输出是一个0到1之间的概率值,表示在状态![110](https://latex.codecogs.com/png.latex?s) 的情况下,做出决策,从动作空间中选取动作![111](https://latex.codecogs.com/png.latex?a) 的概率。 112 | 113 | 策略网络是用神经网络![112](https://latex.codecogs.com/png.latex?\pi(a|s;\theta)) 近似策略函数![113](https://latex.codecogs.com/png.latex?\pi(a|s)) ,其中![114](https://latex.codecogs.com/png.latex?\theta) 表示神经网络的参数。一开始随机初始化![115](https://latex.codecogs.com/png.latex?\theta) ,然后用收集到的状态、动作、奖励去更新![116](https://latex.codecogs.com/png.latex?\theta) 。 114 | 115 | 策略网络的结构如图一所示。策略网络的输入是状态![117](https://latex.codecogs.com/png.latex?s) ,在围棋游戏中,由于状态是张量,一般会使用卷积网络处理输入,生成特征向量。策略网络的输出层的激活函数是Softmax,因此输出的向量(记作![118](https://latex.codecogs.com/png.latex?f) )所有元素都是正数,而且相加等于1。向量![119](https://latex.codecogs.com/png.latex?f) 的维度与动作空间![120](https://latex.codecogs.com/png.latex?\mathcal{A}) 的大小相同,在围棋游戏中,动作空间![121](https://latex.codecogs.com/png.latex?\mathcal{A}) 大小为362,因此向量![122](https://latex.codecogs.com/png.latex?f) 就是一个362维的向量。 116 | 117 | ![6_1](https://github.com/QPT-Family/QPT-CleverGo/blob/main/pictures/6_1.png) 118 | 119 | 120 | 121 | ### 4.2 价值网络(value network) 122 | 123 | 价值网络是用神经网络![123](https://latex.codecogs.com/png.latex?q(s,a;\omega)) 来近似动作价值函数![124](https://latex.codecogs.com/png.latex?Q_\pi(s,a)) 或用![125](https://latex.codecogs.com/png.latex?v(s;\theta)) 来近似状态价值函数![126](https://latex.codecogs.com/png.latex?V_\pi(s)) ,其中![127](https://latex.codecogs.com/png.latex?\omega) 表示神经网络的参数。神经网络的结构是人为预先设定的,参数![128](https://latex.codecogs.com/png.latex?\omega) 一开始随机初始化,并通过智能体与环境的交互来学习。 124 | 125 | 价值网络的结构如图二和图三所示。价值网络的输入是状态![129](https://latex.codecogs.com/png.latex?s) ,在围棋游戏中,由于状态是一个张量,因此会使用卷积网络处理![130](https://latex.codecogs.com/png.latex?s) ,生成特征向量。对于动作价值函数,价值网络输出每个动作的价值,动作空间![131](https://latex.codecogs.com/png.latex?\mathcal{A}) 中有多少种动作,则价值网络的输出就是多少维的向量。对于状态价值函数,价值网络的输出是一个实数,表示状态的价值。 126 | 127 | ![6_2](https://github.com/QPT-Family/QPT-CleverGo/blob/main/pictures/6_2.png) 128 | 129 | ![6_3](https://github.com/QPT-Family/QPT-CleverGo/blob/main/pictures/6_3.png) 130 | 131 | 132 | 133 | ## 5. 蒙特卡洛(Monte Carlo) 134 | 135 | 蒙特卡洛是一大类通过随机样本估算真实值的随机算法(Randomized Algorithms)的总称,如通过实际观测值估算期望值、通过随机梯度近似目标函数关于神经网络参数的梯度。 136 | 137 | 价值网络的输出是回报![132](https://latex.codecogs.com/png.latex?U_t) 的期望。在强化学习中,可以将一局游戏进行到底,观测到所有的奖励![133](https://latex.codecogs.com/png.latex?r_1,r_2,\cdots,r_n) ,然后计算出回报![134](https://latex.codecogs.com/png.latex?u_t=\sum_{i=0}^{n-t}\gamma^ir_{t+i}) 。训练价值网络的时候以![135](https://latex.codecogs.com/png.latex?u_t) 作为目标,这种方式被称作“**蒙特卡洛**”。 138 | 139 | 原因非常显然:以动作价值函数为例,动作价值函数可以写作![136](https://latex.codecogs.com/png.latex?Q_\pi(s_t,a_t)=\mathbb{E}[U_t|S_t=s_t,A_t=a_t]) ,而我们用实际观测![137](https://latex.codecogs.com/png.latex?u_t) 去近似期望,这就是典型的蒙特卡洛近似。 140 | 141 | 蒙特卡洛的好处是**无偏性**:![138](https://latex.codecogs.com/png.latex?u_t) 是![139](https://latex.codecogs.com/png.latex?Q_\pi(s_t,a_t)) 的无偏估计。由于![140](https://latex.codecogs.com/png.latex?u_t) 的无偏性,拿![141](https://latex.codecogs.com/png.latex?u_t) 作为目标训练价值网络,得到的价值网络也是无偏的。 142 | 143 | 蒙特卡洛的坏处是**方差大**:随机变量![142](https://latex.codecogs.com/png.latex?U_t) 依赖于![143](https://latex.codecogs.com/png.latex?S_{t+1},A_{t+1},\cdots,S_n,A_n) 这些随机变量,其中不确定性很大。观测值![144](https://latex.codecogs.com/png.latex?u_t) 虽然是![145](https://latex.codecogs.com/png.latex?U_t) 的无偏估计,但可能实际上离![146](https://latex.codecogs.com/png.latex?\mathbb{E}[U_t]) 很远。因此拿![147](https://latex.codecogs.com/png.latex?u_t) 作为目标训练价值网络,收敛会非常慢。 144 | 145 | 阿尔法狗的训练基于蒙特卡洛树搜索(后续文章会详细介绍),由于蒙特卡洛树搜索方差大的缺点,训练阿尔法狗的过程非常慢,据说DeepMind公司训练阿尔法狗用了5000块TPU?! 146 | 147 | 148 | 149 | ## 6. 结束语 150 | 151 | 机巧围棋核心AlphaGo Zero算法是一种深度强化学习算法,本文介绍了深度强化学习基础,相信通过本文能够让大家更好地理解后续文章中介绍的阿尔法狗算法原理。 152 | 153 | 最后,期待您能够给本文点个赞,同时去GitHub上给机巧围棋项目点个Star呀~ 154 | 155 | 机巧围棋项目链接:[https://github.com/QPT-Family/QPT-CleverGo](https://github.com/QPT-Family/QPT-CleverGo) 156 | 157 | -------------------------------------------------------------------------------- /docs/蒙特卡洛树搜索(MCTS).md: -------------------------------------------------------------------------------- 1 | # 蒙特卡洛树搜索(MCTS) 2 | 3 | 阿尔法狗下棋的时候,做决策的不是策略网络和价值网络,而是蒙特卡洛树搜索(Monte Carlo Tree Search, MCTS)。 4 | 5 | 训练好的策略网络和价值网络均能单独地直接做决策。MCTS不需要训练,也可以单独地直接做决策。在阿尔法狗中,训练策略网络和价值网络的目的是辅助MCTS,降低MCTS的深度和宽度。 6 | 7 | 在机巧围棋中,除阿尔法狗之外,还分别集成了策略网络、价值网络和蒙特卡洛落子策略,可以任意更改黑白双方的落子策略,查看不同落子策略之间的效果。 8 | 9 | 10 | 11 | ## 1. MCTS的基本思想 12 | 13 | 人类玩家下围棋时,通常会往前看几步,越是高手,看的越远。与此同时,人类玩家不会分析棋盘上所有不违反规则的走法,而只会针对性地分析几个貌似可能的走法。 14 | 15 | 假如现在该我放置棋子了,我会这样思考:现在貌似有几个可行的走法,如果我的动作是![1](https://latex.codecogs.com/png.latex?a_t=234) ,对手会怎么走呢?假如对手接下来将棋子放在![2](https://latex.codecogs.com/png.latex?a^\prime=30) 的位置上,那我下一步动作![3](https://latex.codecogs.com/png.latex?a_{t+1}) 应该是什么呢? 16 | 17 | 人类玩家在做决策之前,会在大脑里面进行推演,确保几步以后很可能会占优势。同样的道理,AI下棋时候,也应该枚举未来可能发生的情况,从而判断当前执行什么动作的胜算更大。这样做远好于使用策略网络直接算出一个动作。 18 | 19 | MCTS的基本原理就是向前看,模拟未来可能发生的情况,从而找出当前最优的动作。这种向前看不是遍历所有可能的情况,而是与人类玩家类似,只遍历几种貌似可能的走法,而哪些动作是貌似可行的动作以及几步之后的局面优劣情况是由神经网络所决定的。阿尔法狗每走一步棋,都要用MCTS做成千上万次模拟,从而判断出哪个动作的胜算更大,并执行胜算最大的动作。 20 | 21 | 22 | 23 | ## 2. 阿尔法狗2016版中的MCTS(MCTS in AlphaGo) 24 | 25 | 在阿尔法狗2016版本中,MCTS的每一次模拟分为4个步骤:选择(selection)、扩展(expansion)、求值(evaluation)和回溯(backup)。 26 | 27 | 28 | 29 | ### 2.1 选择(Selection) 30 | 31 | 给定棋盘上的当前格局,可以确定所有符合围棋规则的可落子位置,每个位置对应一个可行的动作。在围棋中,每一步很有可能存在几十甚至上百个可行的动作,挨个搜索和评估所有可行的动作,计算量会大到无法承受。人类玩家做决策前,在大脑里面推演的时候不会考虑所有可行的动作,只会考虑少数几个认为胜算较高的动作。 32 | 33 | MCTS第一步【选择】的目的就是找出胜算较高的动作,只搜索这些好的动作,忽略掉其它的动作。 34 | 35 | 判断动作![4](https://latex.codecogs.com/png.latex?a) 的好坏有两个指标:第一,动作![5](https://latex.codecogs.com/png.latex?a) 的胜率;第二,策略网络给动作![6](https://latex.codecogs.com/png.latex?a) 的评分(概率值)。结合这两个指标,用下面的公式评价动作![7](https://latex.codecogs.com/png.latex?a) 的好坏: 36 | 37 | ![8](https://latex.codecogs.com/png.latex?score(a)\triangleq{Q(a)}+\frac{\eta}{1+N(a)}\cdot\pi(a|s;\theta)~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~(1)) 38 | 39 | 其中: 40 | 41 | - ![9](https://latex.codecogs.com/png.latex?N(a)) 是动作![10](https://latex.codecogs.com/png.latex?a) 已经被访问过的次数。初始时,对于所有![11](https://latex.codecogs.com/png.latex?a) ,令![12](https://latex.codecogs.com/png.latex?N(a)\gets{0}) 。动作![13](https://latex.codecogs.com/png.latex?a) 每被选中一次,就把![14](https://latex.codecogs.com/png.latex?N(a)) 的值加1:![15](https://latex.codecogs.com/png.latex?N(a)\gets{N(a)+1}) ; 42 | 43 | - ![16](https://latex.codecogs.com/png.latex?Q(a)) 是之前![17](https://latex.codecogs.com/png.latex?N(a)) 次模拟算出来的动作价值,主要由胜率和价值函数决定。![18](https://latex.codecogs.com/png.latex?Q(a)) 的初始值是0,动作![19](https://latex.codecogs.com/png.latex?a) 每被选中一次,就会更新一次![20](https://latex.codecogs.com/png.latex?Q(a)) ; 44 | - ![21](https://latex.codecogs.com/png.latex?\eta) 是一个超参数,需要手动调整; 45 | - ![22](https://latex.codecogs.com/png.latex?\pi(a|s;\theta)) 是策略网络对动作![23](https://latex.codecogs.com/png.latex?a) 的评分。 46 | 47 | 可以这样理解上述公式(1): 48 | 49 | - 如果动作![24](https://latex.codecogs.com/png.latex?a) 还没有被选中过,则![25](https://latex.codecogs.com/png.latex?Q(a)) 和![26](https://latex.codecogs.com/png.latex?N(a)) 均等于0,因此![27](https://latex.codecogs.com/png.latex?score(a)\propto{\pi(a|s;\theta)}) ,即完全由策略网络评价动作![28](https://latex.codecogs.com/png.latex?a) 的好坏; 50 | - 如果动作![29](https://latex.codecogs.com/png.latex?a) 已经被选中过很多次,则![30](https://latex.codecogs.com/png.latex?N(a)) 会很大,因此策略网络给动作![31](https://latex.codecogs.com/png.latex?a) 的评分在![32](https://latex.codecogs.com/png.latex?score(a)) 中的权重会降低。当![33](https://latex.codecogs.com/png.latex?N(a)) 很大的时候,有![34](https://latex.codecogs.com/png.latex?score(a)\approx{Q(a)}) ,即此时主要基于![35](https://latex.codecogs.com/png.latex?Q(a)) 判断![36](https://latex.codecogs.com/png.latex?a) 的好坏,策略网络给动作![37](https://latex.codecogs.com/png.latex?a) 的评分已经无关紧要了; 51 | - 系数![38](https://latex.codecogs.com/png.latex?\frac{\eta}{1+N(a)}) 的另一个作用是鼓励探索。如果两个动作有相近的![39](https://latex.codecogs.com/png.latex?Q) 分数和![40](https://latex.codecogs.com/png.latex?\pi) 分数,那么被选中次数少的动作的![41](https://latex.codecogs.com/png.latex?score) 会更高,也就是让被选中次数少的动作有更多的机会被选中。 52 | 53 | 给定某个状态![42](https://latex.codecogs.com/png.latex?s) ,对于所有可行的动作,MCTS会使用公式(1)算出所有动作的分数![43](https://latex.codecogs.com/png.latex?score(a)) ,找到分数最高的动作,并在这轮模拟中,执行这个动作(选择出的动作只是在模拟器中执行,类似于人类玩家在大脑中推演,并不是阿尔法狗真正走的一步棋)。 54 | 55 | 56 | 57 | ### 2.2 扩展(Expansion) 58 | 59 | 将第一步选中的动作记为![44](https://latex.codecogs.com/png.latex?a_t) ,在模拟器中执行动作![45](https://latex.codecogs.com/png.latex?a_t) ,环境应该根据状态转移函数![46](https://latex.codecogs.com/png.latex?p(s_{k+1}|s_k,a_k)) 返回给阿尔法狗一个新的状态![47](https://latex.codecogs.com/png.latex?s_{t+1}) 。 60 | 61 | 假如阿尔法狗执行动作![48](https://latex.codecogs.com/png.latex?a_t) ,对手并不会告诉阿尔法狗他会执行什么动作,因此阿尔法狗只能自己猜测对手的动作,从而确定新的状态![49](https://latex.codecogs.com/png.latex?s_{t+1}) 。和人类玩家一样,阿尔法狗可以推己及人:如果阿尔法狗认为几个动作很好,那么就假设对手也怎么认为。所以阿尔法狗用策略网络模拟对手,根据策略网络随机抽样一个动作: 62 | 63 | ![50](https://latex.codecogs.com/png.latex?a_t^\prime\sim\pi(\cdot|s_t^\prime;\theta)~~~~~~~~~~~~~~~~~~~~~~~~~~~~~(2)) 64 | 65 | 此处的状态![51](https://latex.codecogs.com/png.latex?s_t^\prime) 是站在对手的角度观测到的棋盘上的格局,动作![53](https://latex.codecogs.com/png.latex?a_t^\prime) 是假想的对手选择的动作。 66 | 67 | 进行MCTS需要模拟阿尔法狗和对手对局,阿尔法狗每执行一个动作![53](https://latex.codecogs.com/png.latex?a_k) ,环境应该返回一个新的状态![54](https://latex.codecogs.com/png.latex?s_{k+1}) 。围棋游戏具有对称性,阿尔法狗的策略,在对手看来是状态转移函数;对手的策略,在阿尔法狗看来是状态转移函数。最理想情况下,模拟器的状态转移函数是对手的真实策略,然而阿尔法狗并不知道对手的真实策略,因此阿尔法狗退而求其次,用自己训练出的策略网络![55](https://latex.codecogs.com/png.latex?\pi) 代替对手的策略,作为模拟器的状态转移函数。 68 | 69 | 70 | 71 | ### 2.3 求值(Evaluation) 72 | 73 | 从状态![56](https://latex.codecogs.com/png.latex?s_{t+1}) 开始,双方都用策略网络![57](https://latex.codecogs.com/png.latex?\pi) 做决策,在模拟器中交替落子,直至分出胜负(见图一)。阿尔法狗基于状态![58](https://latex.codecogs.com/png.latex?s_k) ,根据策略网络抽样得到动作: 74 | 75 | ![59](https://latex.codecogs.com/png.latex?a_k\sim\pi(\cdot|s_k;\theta)~~~~~~~~~~~~~~~~~~~~~~~~~~~~~(3)) 76 | 77 | 对手基于状态![60](https://latex.codecogs.com/png.latex?s_k^\prime) (从对手角度观测到的棋盘上的格局),根据策略网络抽样得到动作: 78 | 79 | ![61](https://latex.codecogs.com/png.latex?a_k^\prime\sim\pi(\cdot|s_k^\prime;\theta)~~~~~~~~~~~~~~~~~~~~~~~~~~~~~(4)) 80 | 81 | 模拟对局直至分出胜负,可以观测到奖励![62](https://latex.codecogs.com/png.latex?r) 。如果阿尔法狗胜利,则![63](https://latex.codecogs.com/png.latex?r=+1) ,否则![64](https://latex.codecogs.com/png.latex?r=-1) 。 82 | 83 | ![8_1](https://github.com/QPT-Family/QPT-CleverGo/blob/main/pictures/8_1.png) 84 | 85 | 综上所述,棋盘上真实的状态是![65](https://latex.codecogs.com/png.latex?s_t) ,阿尔法狗在模拟器中执行动作![66](https://latex.codecogs.com/png.latex?a_t) ,然后模拟器中的对手执行动作![67](https://latex.codecogs.com/png.latex?a_t^\prime) ,带来新的状态![68](https://latex.codecogs.com/png.latex?s_{t+1}) 。对于阿尔法狗来说,如果状态![69](https://latex.codecogs.com/png.latex?s_{t+1}) 越好,则这局游戏胜算越大,因此: 86 | 87 | - 如果阿尔法狗赢得了这局模拟(![70](https://latex.codecogs.com/png.latex?r=+1) ),则说明![71](https://latex.codecogs.com/png.latex?s_{t+1}) 可能很好;如果输了(![72](https://latex.codecogs.com/png.latex?r=-1) ),则说明可能不好。因此,奖励![73](https://latex.codecogs.com/png.latex?r) 可以反映出![74](https://latex.codecogs.com/png.latex?s_{t+1}) 的好坏。 88 | - 此外,还可以使用价值网络评价状态![75](https://latex.codecogs.com/png.latex?s_{t+1}) 的好坏。价值![76](https://latex.codecogs.com/png.latex?v(s_{t+1};\omega)) 越大,则说明状态![77](https://latex.codecogs.com/png.latex?s_{t+1}) 越好。 89 | 90 | 奖励![78](https://latex.codecogs.com/png.latex?r) 是模拟对局获得的胜负,是对![79](https://latex.codecogs.com/png.latex?s_{t+1}) 很可靠的评价,但是随机性很大。价值网络的评估![80](https://latex.codecogs.com/png.latex?v(s_{t+1};\omega)) 没有![81](https://latex.codecogs.com/png.latex?r) 可靠,但是价值网络更稳定,随机性小。阿尔法狗将奖励![82](https://latex.codecogs.com/png.latex?r) 与价值网络的输出![83](https://latex.codecogs.com/png.latex?v(s_{t+1};\omega)) 取平均,作为对状态![84](https://latex.codecogs.com/png.latex?s_{t+1}) 的评价,记作:![85](https://latex.codecogs.com/png.latex?V(s_{t+1})\triangleq\frac{r+v(s_{t+1;w})}{2}) 。 91 | 92 | 使用策略网络交替落子,直至分出胜负,通常要走一两百步。在实际实现时候,阿尔法狗训练了一个更小的神经网络(称为快速走子网络)来代替大的策略网络,以加速MCTS。 93 | 94 | 95 | 96 | ### 2.4 回溯(Backup) 97 | 98 | 第三步【求值】计算出了![86](https://latex.codecogs.com/png.latex?t+1) 步某一个状态的价值,记作![87](https://latex.codecogs.com/png.latex?V(s_{t+1})) 。每一次模拟都会得出这样一个价值,并且记录下来。模拟会重复很多次,于是第![88](https://latex.codecogs.com/png.latex?t+1) 步每一种状态下面可以有多条记录,如图二所示。 99 | 100 | ![8_2](https://github.com/QPT-Family/QPT-CleverGo/blob/main/pictures/8_2.png) 101 | 102 | 第![89](https://latex.codecogs.com/png.latex?t) 步的动作![90](https://latex.codecogs.com/png.latex?a_t) 下面有多个可能的状态(子节点),每个状态下面有若干条记录。把![91](https://latex.codecogs.com/png.latex?a_t) 下面所有的记录取平均,记为价值![92](https://latex.codecogs.com/png.latex?Q(a_t)) ,它可以反映出动作![93](https://latex.codecogs.com/png.latex?a_t) 的好坏。 103 | 104 | 给定棋盘上真实的状态![94](https://latex.codecogs.com/png.latex?s_t) ,有多个可行的动作![95](https://latex.codecogs.com/png.latex?a) 可供选择。对于所有的![96](https://latex.codecogs.com/png.latex?a) ,价值![97](https://latex.codecogs.com/png.latex?Q(a)) 的初始值为0。动作![98](https://latex.codecogs.com/png.latex?a) 每被选中一次(成为![99](https://latex.codecogs.com/png.latex?a_t) ),它下面就会多一条记录,我们就对![100](https://latex.codecogs.com/png.latex?Q(a)) 做一次更新。 105 | 106 | 107 | 108 | ### 2.5 MCTS的决策 109 | 110 | 上述4个步骤为一次MCTS的流程,MCTS想要真正做出一个决策(即往真正的棋盘上落一个棋子),需要做成千上万次模拟。在无数次模拟之后,MCTS做出真正的决策: 111 | 112 | ![101](https://latex.codecogs.com/png.latex?a_t=\overset{argmax}{_a}~N(a)~~~~~~~~~~~~~~~~~~~~~~~~~~~~~(5)) 113 | 114 | 此时阿尔法狗才会真正往棋盘上放一个棋子。 115 | 116 | > 为什么要依据![102](https://latex.codecogs.com/png.latex?N(a)) 来做决策呢? 117 | > 118 | > 在每一次模拟中,MCTS找出所有可行的动作![103](https://latex.codecogs.com/png.latex?\{a\}) ,计算他们的分数![104](https://latex.codecogs.com/png.latex?score(a)) ,然后选择其中分数最高的动作,并在模拟器中执行。如果某个动作![105](https://latex.codecogs.com/png.latex?a) 在模拟时胜率很大,那么它的价值![106](https://latex.codecogs.com/png.latex?Q(a)) 就会很大,它的分数![107](https://latex.codecogs.com/png.latex?score(a)) 会很高,于是它被选中的几率就很大。也就是说如果某个动作![108](https://latex.codecogs.com/png.latex?a) 很好,他被选中的次数![109](https://latex.codecogs.com/png.latex?N(a)) 就会大。 119 | 120 | 观测到棋盘上当前状态![110](https://latex.codecogs.com/png.latex?s_t) ,MCTS做成千上万次模拟,记录每个动作![111](https://latex.codecogs.com/png.latex?a) 被选中的次数![112](https://latex.codecogs.com/png.latex?N(a)) ,最终做出决策![113](https://latex.codecogs.com/png.latex?a_t=\overset{argmax}{_a}~N(a)) 。到了下一时刻,状态变成了![114](https://latex.codecogs.com/png.latex?s_{t+1}) ,MCTS会把所有动作![115](https://latex.codecogs.com/png.latex?a) 的![116](https://latex.codecogs.com/png.latex?Q(a)) 、![117](https://latex.codecogs.com/png.latex?N(a)) 全都初始化为0,然后从头开始做模拟,而不能利用上一次的结果。 121 | 122 | 123 | 124 | ## 3. 零狗中的MCTS(MCTS in AlphaGo Zero) 125 | 126 | 零狗中对MCTS进行了简化,放弃了快速走子网络,合并了【扩展】和【求值】,并且更改了【选择】和【决策】逻辑。零狗中维护了一个蒙特卡洛搜索树,搜索树的每一个节点保存了![118](https://latex.codecogs.com/png.latex?N(s,a)) (节点访问次数)、![119](https://latex.codecogs.com/png.latex?W(s,a)) (合计动作价值)、![120](https://latex.codecogs.com/png.latex?Q(s,a)) (平均动作价值)和![121](https://latex.codecogs.com/png.latex?P(s,a)) (选择该节点的先验概率)。每一次模拟会遍历一条从搜索树根结点到叶节点的路径。 127 | 128 | ![8_3](https://github.com/QPT-Family/QPT-CleverGo/blob/main/pictures/8_3.png) 129 | 130 | 如图三所示,零狗中每一次MCTS共有三个流程: 131 | 132 | - **选择(Select):** 133 | 134 | 在选择阶段,从搜索树的根节点开始,不断选择![122](https://latex.codecogs.com/png.latex?a_c=\overset{argmax}{_a}[Q(s,a)+U(s,a)]) ,其中![123](https://latex.codecogs.com/png.latex?U(s,a)=c_{puct}P(s,a)\frac{\sqrt{\sum_bN(s,b)}}{1+N(s,a)}) ,直至搜索树的叶节点终止。 135 | 136 | > ![124](https://latex.codecogs.com/png.latex?s) :为搜索树的一个节点代表的棋局状态; 137 | > 138 | > ![125](https://latex.codecogs.com/png.latex?a) :表示某一个可行的动作; 139 | > 140 | > ![126](https://latex.codecogs.com/png.latex?N(s,a)) :表示状态![127](https://latex.codecogs.com/png.latex?s) 下可行动作![128](https://latex.codecogs.com/png.latex?a) 被选中的次数; 141 | > 142 | > ![129](https://latex.codecogs.com/png.latex?P(s,a)) :为状态![130](https://latex.codecogs.com/png.latex?s) 下的可行动作![131](https://latex.codecogs.com/png.latex?a) 的先验概率; 143 | > 144 | > ![132](https://latex.codecogs.com/png.latex?Q(s,a)) :表示状态![133](https://latex.codecogs.com/png.latex?s) 下可行动作![134](https://latex.codecogs.com/png.latex?a) 的动作价值; 145 | > 146 | > ![135](https://latex.codecogs.com/png.latex?c_{puct}) :为一个决定探索程度超参数。 147 | 148 | - **拓展和求值(Expand and evaluate):** 149 | 150 | 选择阶段,在搜索树中不断选择![136](https://latex.codecogs.com/png.latex?Q+U) 最大的动作,直至游戏结束或某一个不是终局的叶结点。如果到了不是终局的叶结点![137](https://latex.codecogs.com/png.latex?l) ,对于![138](https://latex.codecogs.com/png.latex?l) 对应的棋局状态![139](https://latex.codecogs.com/png.latex?s) ,使用策略网络和价值网络对状态![140](https://latex.codecogs.com/png.latex?s) 进行评估,得到![141](https://latex.codecogs.com/png.latex?l) 对应棋局状态![142](https://latex.codecogs.com/png.latex?s) 下一步各个可能动作的概率![143](https://latex.codecogs.com/png.latex?p) 和![144](https://latex.codecogs.com/png.latex?l) 的价值![145](https://latex.codecogs.com/png.latex?v) 。为所有可能动作对应的棋局状态分别创建一个节点,将这些节点的先验概率设置为策略网络的输出概率值。 151 | 152 | - **回溯(Backup):** 153 | 154 | 进过上述扩展之后,之前的叶子节点![146](https://latex.codecogs.com/png.latex?l) ,现在变成了内部节点。做完了扩展和求值后,从节点![147](https://latex.codecogs.com/png.latex?l) 开始,逐层向搜索树根节点回溯,并依次更新搜索树当次被遍历的路径上各层节点的信息: 155 | 156 | ![148](https://latex.codecogs.com/png.latex?N(s_n,a_n)=N(s_n,a_n)+1\\\\W(s_n,a_n)=W(s_n,a_n)+v_n\\\\Q(s_n,a_n)=\frac{W(s_n,a_n)}{N(s_n,a_n)}) 157 | 158 | > ![149](https://latex.codecogs.com/png.latex?s_n) :表示搜索树中当次被遍历路径上节点对应的棋局状态; 159 | > 160 | > ![150](https://latex.codecogs.com/png.latex?a_n) :表示搜索树中当次被遍历路径上节点对应棋局状态下选择的动作; 161 | > 162 | > ![151](https://latex.codecogs.com/png.latex?v_n) :表示搜索树中当次被遍历路径上节点的价值,由于搜索树中相邻两层的落子方是不同的,因此相邻两层的节点价值互为相反数。 163 | 164 | 上述三个流程为零狗中的一次MCTS模拟,在零狗往真正的棋盘上落一个棋子之前,会进行1600次模拟。在上千次MCTS完成之后,MCTS基于下述公式做出真正的决策: 165 | 166 | ![152](https://latex.codecogs.com/png.latex?\pi(a|s)=\frac{N(s,a)^{1/\tau}}{\sum_bN(s,b)^{1/\tau}}~~~~~~~~~~~~~~~~~~~~~~~~~~~~~(6)) 167 | 168 | > ![153](https://latex.codecogs.com/png.latex?\tau) 为温度参数,控制探索的程度。 ![154](https://latex.codecogs.com/png.latex?\tau) 越大,不同走法间差异变小,探索比例增大。反之,则更多选择当前最优操作。在零狗中,每一次自我对弈的前30步,参数![155](https://latex.codecogs.com/png.latex?\tau=1) ,即早期鼓励探索。游戏剩下的步数,该参数将逐渐降低至0。如果是比赛,则直接为0。 169 | 170 | 171 | 172 | ## 4. MCTS的程序实现 173 | 174 | 机巧围棋是基于AlphaGo Zero算法的一款点击按钮就能可视化训练围棋人工智能的程序,机巧围棋中的MCTS与零狗中的MCTS一致,不过不支持多线程搜索,具体代码如下: 175 | 176 | ```python 177 | class TreeNode: 178 | """蒙特卡洛树节点""" 179 | def __init__(self, parent, prior_p): 180 | self.parent = parent # 节点的父节点 181 | self.children = {} # 一个字典,用来存节点的子节点 182 | self.n_visits = 0 # 节点被访问的次数 183 | self.Q = 0 # 节点的平均行动价值 184 | self.U = 0 # MCTS选择Q+U最大的节点,公式里的U 185 | self.P = prior_p # 节点被选择的概率 186 | 187 | def select(self, c_puct): 188 | """ 189 | 蒙特卡洛树搜索的第一步:选择 190 | 蒙特卡洛树搜索通过不断选择 最大上置信限Q+U 的子节点,直至一个树的叶结点 191 | 该函数为进行一步选择函数 192 | 193 | :param c_puct: 为计算U值公式中的c_puct,是一个决定探索水平的常数 194 | :return: 返回一个元组(action, next_node) 195 | """ 196 | return max(self.children.items(), 197 | key=lambda act_node: act_node[1].get_value(c_puct)) 198 | 199 | def expand(self, action_priors): 200 | """ 201 | 当select搜索到一个叶结点,且该叶节点代表的局面游戏没有结束, 202 | 需要expand树,创建一系列可能得节点,即对应节点所有可能选择的动作对应的子节点 203 | 204 | :param action_priors: 为一个列表,列表中的每一个元素为一个 特定动作及其先验概率 的元组 205 | :return: 206 | """ 207 | for action, prob in action_priors: 208 | if action not in self.children: 209 | self.children[action] = TreeNode(self, prob) 210 | 211 | def update(self, leaf_value): 212 | """ 213 | 根据子树的价值更新当前节点的价值 214 | 215 | :param leaf_value: 以当前玩家的视角看待得到的子树的价值 216 | :return: 217 | """ 218 | self.n_visits += 1 # 当前节点的访问次数+1 219 | # 更新当前节点的Q值,下述公式可由Q = W / N 推导得到 220 | # Q_old = W_old / N_old 221 | # Q = (W_old + v) / (N_old + 1) = (Q_old * N_old + v) / (N_old + 1) 222 | self.Q += 1.0 * (leaf_value - self.Q) / self.n_visits 223 | 224 | def update_recursive(self, leaf_value): 225 | """ 226 | 跟心所有祖先的Q值及访问次数 227 | 228 | :param leaf_value: 229 | :return: 230 | """ 231 | if self.parent: # 如果有父节点,证明还没到根节点 232 | self.parent.update_recursive(-leaf_value) # -leaf_value是因为每向上一层,以当前玩家视角,价值反转 233 | self.update(leaf_value) 234 | 235 | def get_value(self, c_puct): 236 | """ 237 | 计算并返回一个节点的 上置信限 评价,即Q+U值 238 | 239 | :param c_puct: 为计算U值公式中的c_puct,是一个决定探索水平的常数 240 | :return: 241 | """ 242 | self.U = c_puct * self.P * np.sqrt(self.parent.n_visits) / (1 + self.n_visits) 243 | return self.Q + self.U 244 | 245 | def is_leaf(self): 246 | """ 247 | 判断当前节点是否为叶结点 248 | 249 | :return: 250 | """ 251 | return self.children == {} 252 | 253 | def is_root(self): 254 | """ 255 | 判断当前节点是否为根节点 256 | 257 | :return: 258 | """ 259 | return self.parent is None 260 | 261 | 262 | class MCTS: 263 | """蒙特卡洛树搜索主体""" 264 | def __init__(self, policy_value_fn, c_puct=5, n_playout=10000): 265 | self.root = TreeNode(None, 1.0) # 整个蒙特卡洛搜索树的根节点 266 | # policy_value_fn是一个函数,该函数的输入为game_state, 267 | # 输出为一个列表,列表中的每一个元素为(action, probability)形式的元组 268 | self.policy = policy_value_fn 269 | # c_puct为一个正数,用于控制多块收敛到策略的最大值。这个数越大,意味着越依赖前面的结果。 270 | self.c_puct = c_puct 271 | self.n_playout = n_playout 272 | 273 | def playout(self, simulate_game_state): 274 | """ 275 | 从根节点不断选择直到叶结点,并获取叶结点的值,反向传播到叶结点的祖先节点 276 | 277 | :param simulate_game_state: 模拟游戏对象 278 | :return: 279 | """ 280 | node = self.root 281 | while True: # 从根节点一直定位到叶结点 282 | if node.is_leaf(): 283 | break 284 | # 贪婪地选择下一步动作 285 | action, node = node.select(self.c_puct) 286 | simulate_game_state.step(action) 287 | # 使用网络来评估叶结点,产生一个每一个元素均为(action, probability)元组的列表,以及 288 | # 一个以当前玩家视角看待的在[-1, 1]之间的v值 289 | action_probs, leaf_value = self.policy(simulate_game_state) 290 | # 检查模拟游戏是否结束 291 | end, winner = simulate_game_state.game_ended(), simulate_game_state.winner() 292 | if not end: # 没结束则扩展 293 | node.expand(action_probs) 294 | else: 295 | if winner == -1: # 和棋 296 | leaf_value = 0.0 297 | else: 298 | leaf_value = ( 299 | 1.0 if winner == simulate_game_state.turn() else -1.0 300 | ) 301 | # 更新此条遍历路径上的节点的访问次数和value 302 | # 这里的值要符号反转,因为这个值是根据根节点的player视角来得到的 303 | # 但是做出下一步落子的是根节点对应player的对手 304 | node.update_recursive(-leaf_value) 305 | 306 | def get_move_probs(self, game, temp=1e-3, player=None): 307 | """ 308 | 执行n_playout次模拟,并根据子节点的访问次数,获得每个动作对应的概率 309 | 310 | :param game: 游戏模拟器 311 | :param temp: 制探索水平的温度参数 312 | :param player: 调用该函数的player,用于进行进度绘制 313 | :return: 314 | """ 315 | for i in range(self.n_playout): 316 | if not player.valid: 317 | return -1, -1 318 | if player is not None: 319 | player.speed = (i + 1, self.n_playout) 320 | simulate_game_state = game.game_state_simulator(player.is_selfplay) 321 | self.playout(simulate_game_state) 322 | # 基于节点访问次数,计算每个动作对应的概率 323 | act_visits = [(act, node.n_visits) 324 | for act, node in self.root.children.items()] 325 | acts, visits = zip(*act_visits) 326 | act_probs = softmax(1.0 / temp * np.log(np.array(visits) + 1e-10)) 327 | return acts, act_probs 328 | 329 | def get_move(self, game, player=None): 330 | """ 331 | 执行n_playout次模拟,返回访问次数最多的动作 332 | 333 | :param game: 游戏模拟器 334 | :param player: 调用该函数的player,用于进行进度绘制 335 | :return: 返回访问次数最多的动作 336 | """ 337 | for i in range(self.n_playout): 338 | if not player.valid: 339 | return -1 340 | if player is not None: 341 | player.speed = (i + 1, self.n_playout) 342 | game_state = game.game_state_simulator() 343 | self.playout(game_state) 344 | return max(self.root.children.items(), key=lambda act_node: act_node[1].n_visits)[0] 345 | 346 | def update_with_move(self, last_move): 347 | """ 348 | 蒙特卡洛搜索树向深层前进一步,并且保存对应子树的全部信息 349 | 350 | :param last_move: 上一步选择的动作 351 | :return: 352 | """ 353 | if last_move in self.root.children: 354 | self.root = self.root.children[last_move] 355 | self.root.parent = None 356 | else: 357 | self.root = TreeNode(None, 1.0) 358 | ``` 359 | 360 | 361 | 362 | ## 5. 结束语 363 | 364 | 本文介绍了阿尔法狗2016版本和零狗中的蒙特卡洛树搜索及其实现,在机巧围棋中也集成了纯蒙特卡洛落子策略(所有可行动作的概率值是随机的,节点的状态价值通过随机落子到游戏终局,根据胜负确定),大家可以在GitHub上clone机巧围棋的代码,体验纯蒙特卡洛落子策略的效果。 365 | 366 | 最后,期待您能够给本文点个赞,同时去GitHub上给机巧围棋项目点个Star呀~ 367 | 368 | 机巧围棋项目链接:[https://github.com/QPT-Family/QPT-CleverGo](https://github.com/QPT-Family/QPT-CleverGo) 369 | 370 | -------------------------------------------------------------------------------- /docs/训练策略网络和价值网络.md: -------------------------------------------------------------------------------- 1 | # 训练策略网络和价值网络 2 | 3 | 阿尔法狗2016版本使用人类高手棋谱数据初步训练策略网络,并使用深度强化学习中的REINFORCE算法进一步训练策略网络。策略网络训练好之后,使用策略网络辅助训练价值网络。零狗(AlphaGo Zero)使用MCTS控制两个玩家对弈,用自对弈生成的棋谱数据和胜负关系同时训练策略网络和价值网络。 4 | 5 | 在机巧围棋中,训练策略网络和价值网络的方法原理与零狗基本相同。 6 | 7 | 本文将详细讲解阿尔法狗2016版本和零狗中两个神经网络的训练方法。 8 | 9 | 10 | 11 | ## 1. 阿尔法狗2016版本的训练方法 12 | 13 | 2016年3月,阿尔法狗以4 : 1战胜了世界冠军李世石九段。赛前(2016年1月27日)DeepMind公司在nature上发表论文[Mastering the game of Go with deep neural networks and tree search](https://www.cs.princeton.edu/courses/archive/spring16/cos598F/Google-go-nature16.pdf)详细介绍了阿尔法狗的算法原理。 14 | 15 | 阿尔法狗的训练分为三步: 16 | 17 | 1. 随机初始化策略网络![1](https://latex.codecogs.com/png.latex?\pi(a|s;\theta)) 的参数之后,使用行为克隆(Behavior Cloning)从人类高手棋谱中学习策略网络; 18 | 2. 让两个策略网络自我博弈,使用REINFORCE算法改进策略网络; 19 | 3. 使用两个已经训练好的策略网络自我博弈,根据胜负关系数据训练价值网络![2](https://latex.codecogs.com/png.latex?v(s;\omega)) 。 20 | 21 | 22 | 23 | ### 1.1 行为克隆 24 | 25 | REINFORCE算法会让两个策略网络博弈直至游戏结束,使用游戏结束后实际观测到的回报![3](https://latex.codecogs.com/png.latex?u) 对策略梯度中的动作价值函数![4](https://latex.codecogs.com/png.latex?Q_{\pi}) 做蒙特卡洛近似,从而计算出策略梯度的无偏估计值,并做随机梯度上升更新策略网络参数。 26 | 27 | 一开始的时候,策略网络的参数都是随机初始化的。假如直接使用REINFORCE算法学习策略网络,会让两个随机初始化的策略网络博弈。由于策略网络的参数是随机初始化的,它们会做出随机的动作,要经过一个很久的随机摸索过程才能做出合理的动作。因此,阿尔法狗2016版本使用人类专家知识,通过行为克隆初步训练一个策略网络。 28 | 29 | 行为克隆是一种最简单的模仿学习,目的是模仿人的动作,学出一个随机策略网络![5](https://latex.codecogs.com/png.latex?\pi(a|s;\theta)) 。行为克隆的本质是监督学习(分类或回归),其利用事先准备好的数据集,用人类的动作指导策略网络做改进,目的是让策略网络的决策更像人类的决策。 30 | 31 | 在[这个网站](https://u-go.net/gamerecords/)上可以下载到K Go Server(KGS,原名Kiseido Go Server)上大量6段以上高手玩家的对局数据,每一局有很多步,每一步棋盘上的格局作为一个状态![6](https://latex.codecogs.com/png.latex?s_k) ,下一个棋子的位置作为动作![7](https://latex.codecogs.com/png.latex?a_k) ,这样得到数据集![8](https://latex.codecogs.com/png.latex?\{(s_k,a_k)\}) 。 32 | 33 | 设362维向量![9](https://latex.codecogs.com/png.latex?f_k=\pi(\cdot|s_k;\theta)=[\pi(0|s_k;\theta),\pi(1|s_k;\theta),\cdots,\pi(361|s_k;\theta)]) 是策略网络的输出,![10](https://latex.codecogs.com/png.latex?\bar{a}_k) 是对动作![11](https://latex.codecogs.com/png.latex?a_k) 的独热编码(one-hot)。可以使用![12](https://latex.codecogs.com/png.latex?\bar{a}_k) 和![13](https://latex.codecogs.com/png.latex?f_k) 的交叉熵![14](https://latex.codecogs.com/png.latex?H(\bar{a}_k,f_k)) 作为损失函数,计算损失函数关于策略网络参数的梯度,使用随机梯度下降更新策略网络参数,最小化损失函数的值,使策略网络的决策更接近人类高手的动作。 34 | 35 | 行为克隆得到的策略网络模仿人类高手的动作,可以做出比较合理的决策。根据阿尔法狗的论文,它在实战中可以打败业余玩家,但是打不过职业玩家。由于人类高手在实际对局中很少探索奇怪的状态和动作,因此训练数据集上的状态和动作缺乏多样性。在数据集![15](https://latex.codecogs.com/png.latex?\{(s_k,a_k)\}) 上做完行为克隆之后,策略网络在真正对局时,可能会见到陌生的状态,此时做出的决策可能会很糟糕。如果策略网络做出的动作![16](https://latex.codecogs.com/png.latex?a_t) 不够好,那么下一时刻的状态![17](https://latex.codecogs.com/png.latex?s_{t+1}) 可能会比较罕见,于是做出的下一个动作![18](https://latex.codecogs.com/png.latex?a_{t+1}) 会很差;这又导致状态![19](https://latex.codecogs.com/png.latex?s_{t+2}) 非常奇怪,使得动作![20](https://latex.codecogs.com/png.latex?a_{t+2}) 更糟糕。如此“错误累加”,进入这种恶性循环。 36 | 37 | 为了克服上述行为克隆的缺陷,还需要用强化学习训练策略网络。在行为克隆之后再做强化学习改进策略网络,可以击败只用行为克隆的策略网络,胜算是80%。 38 | 39 | > 为什么可以使用策略网络输出和人类高手动作独热编码的交叉熵作为损失函数,可以参见博客:[为什么交叉熵常被用作分类问题的损失函数](https://blog.csdn.net/qq_24178985/article/details/122682830)。 40 | 41 | 42 | 43 | ### 1.2 使用REINFORCE算法改进策略网络 44 | 45 | REINFORCE是一种策略梯度方法,其使用实际观测到的回报![21](https://latex.codecogs.com/png.latex?u) 对策略梯度的无偏估计![22](https://latex.codecogs.com/png.latex?g(s,a;\theta)\triangleq{Q_\pi(s,a)}\cdot\nabla_\theta{\ln\pi(a|s;\theta)}) 中![23](https://latex.codecogs.com/png.latex?Q_\pi) 做蒙特卡洛近似,并通过下面的公式做随机梯度上升,更新策略网络: 46 | 47 | ![24](https://latex.codecogs.com/png.latex?\theta_{new}\leftarrow\theta_{now}+\beta\cdot\sum_{t=1}^nu_t\cdot\nabla\ln\pi(a_t|s_t;\theta_{now})~~~~~~~~~~~~~~~~~~~~~~~(1)) 48 | 49 | 其中,![25](https://latex.codecogs.com/png.latex?\beta) 是学习率,![26](https://latex.codecogs.com/png.latex?n) 是游戏至终局共进行的步数,![27](https://latex.codecogs.com/png.latex?\pi(a_t|s_t;\theta_{now})) 是策略网络的输出,![28](https://latex.codecogs.com/png.latex?\ln\pi(a_t|s_t;\theta_{now})) 是策略网络输出值的对数,![29](https://latex.codecogs.com/png.latex?\nabla\ln\pi(a_t|s_t;\theta_{now})) 是策略网络输出值的对数对策略网络参数求的导数。 50 | 51 | 如图一所示,阿尔法狗使用两个策略网络进行博弈,将胜负作为奖励,计算回报![30](https://latex.codecogs.com/png.latex?u_t) 的值。参与博弈的一个策略网络叫做“玩家”,用最新的参数![31](https://latex.codecogs.com/png.latex?\theta_{now}) ;另一个叫做“对手”,它的参数是从过时的参数中随机选出来的,记作![32](https://latex.codecogs.com/png.latex?\theta_{old}) 。“对手”的作用相当于模拟器(环境)的状态转移函数,在训练过程中,只更新“玩家”的参数,不更新“对手”的参数。 52 | 53 | ![9_1](https://github.com/QPT-Family/QPT-CleverGo/blob/main/pictures/9_1.png) 54 | 55 | 让“玩家”和“对手”博弈,将一局游戏进行到底,根据博弈的胜负关系可以确定公式(1)中回报![33](https://latex.codecogs.com/png.latex?u_t) 的值。假设一局游戏“玩家”共走了![34](https://latex.codecogs.com/png.latex?n) 步,设游戏未结束时,奖励![35](https://latex.codecogs.com/png.latex?r_1=r_2=\cdots=r_{n-1}=0) 。游戏结束时,如果“玩家”赢了,则奖励![36](https://latex.codecogs.com/png.latex?r_n=+1) 。“玩家”输了,则奖励![37](https://latex.codecogs.com/png.latex?r_n=-1) 。设定折扣率![38](https://latex.codecogs.com/png.latex?\gamma=1) ,当“玩家”赢了,所有的回报![39](https://latex.codecogs.com/png.latex?u_1=u_2=\cdots=u_n=+1) ,当“玩家”输了,所有的回报![40](https://latex.codecogs.com/png.latex?u_1=u_2=\cdots=u_n=-1) 。 56 | 57 | > 回报的定义是![41](https://latex.codecogs.com/png.latex?u_t=r_t+\gamma{r_{t+1}}+\gamma^2r_{t+2}\cdots+\gamma^nr_n) ,在阿尔法狗中,设定折扣率![42](https://latex.codecogs.com/png.latex?\gamma=1) 。 58 | > 59 | > REINFORCE是算法深度强化学习领域的一种策略梯度方法,策略梯度的推导比较麻烦,因此本文不会讲解公式(1)的由来。或许以后有时间会系统地讲解深度强化学习相关算法,可以关注我的博客[DeepGeGe的博客主页](https://blog.csdn.net/qq_24178985)哟~ 60 | 61 | 62 | 63 | ### 1.3 训练价值网络 64 | 65 | 在阿尔法狗中,价值网络![43](https://latex.codecogs.com/png.latex?v(s;\omega)) 是对状态价值函数![44](https://latex.codecogs.com/png.latex?V_\pi(s)) 的近似,用于评估状态![45](https://latex.codecogs.com/png.latex?s) 的好坏。状态价值函数![46](https://latex.codecogs.com/png.latex?V_\pi(s)) 依赖于状态![47](https://latex.codecogs.com/png.latex?s) ,状态![48](https://latex.codecogs.com/png.latex?s) 越好,那么价值![49](https://latex.codecogs.com/png.latex?V_\pi(s)) 就越大;![50](https://latex.codecogs.com/png.latex?V_\pi(s)) 还依赖于策略函数![51](https://latex.codecogs.com/png.latex?\pi) ,策略![52](https://latex.codecogs.com/png.latex?\pi) 越好,同样价值![53](https://latex.codecogs.com/png.latex?V_\pi(s)) 也就越大。如果策略![54](https://latex.codecogs.com/png.latex?\pi) 是固定的,则可以用状态价值函数![55](https://latex.codecogs.com/png.latex?V_\pi(s)) 评估状态![56](https://latex.codecogs.com/png.latex?s) 的好坏。因此,阿尔法狗在完成第二步——训练策略网络![57](https://latex.codecogs.com/png.latex?\pi) 之后,用![58](https://latex.codecogs.com/png.latex?\pi) 辅助训练![59](https://latex.codecogs.com/png.latex?v) 。 66 | 67 | 让训练好的策略网络做自我博弈,每对弈完一局,可以记录(状态—回报)二元组![60](https://latex.codecogs.com/png.latex?(s_k,u_k)) 。自我博弈需要重复非常多次,将最终得到的数据集记作![61](https://latex.codecogs.com/png.latex?\{(s_k,u_k)\}_{k=1}^m) 。根据定义,状态价值![62](https://latex.codecogs.com/png.latex?V_\pi(s_k)) 是回报![63](https://latex.codecogs.com/png.latex?U_k) 的期望:![64](https://latex.codecogs.com/png.latex?V_\pi(s_k)=\mathbb{E}[U_k|S_k=s_k]) 。训练价值网络![65](https://latex.codecogs.com/png.latex?v(s;w)) 的目标是使其接近![66](https://latex.codecogs.com/png.latex?V_\pi) ,即让![67](https://latex.codecogs.com/png.latex?v(s;\omega)) 拟合回报![68](https://latex.codecogs.com/png.latex?u_k) 。 68 | 69 | 定义回归问题: 70 | 71 | ![69](https://latex.codecogs.com/png.latex?\overset{min}{_\omega}\frac{1}{2m}\sum_{k=1}^m[v(s_k;\omega)-u_k]^2) 72 | 73 | 用均方误差(MSE)作为损失函数,训练价值网络![70](https://latex.codecogs.com/png.latex?v(s;\omega)) ,求解这个回归问题。 74 | 75 | 76 | 77 | ## 2. 零狗的训练方法 78 | 79 | 根据论文[Mastering the game of Go without human knowledge](https://www.tensorflownews.com/wp-content/uploads/2017/10/nature24270.pdf)可知,零狗和阿尔法狗2016版本的最大区别在于训练策略网络![71](https://latex.codecogs.com/png.latex?\pi(a|s;\theta)) 的方式。训练![72](https://latex.codecogs.com/png.latex?\pi) 的时候,不再向人类高手学习,也不用REINFORCE方法,而是向MCTS学习。可以把零狗训练![73](https://latex.codecogs.com/png.latex?\pi) 的方法看做是模仿学习,被模仿的对象不是人类高手,而是MCTS。 80 | 81 | 82 | 83 | ### 2.1 自我博弈 84 | 85 | 用MCTS控制两个玩家对弈,每走一步棋,需要进行成千上万次模拟,并记录下每个动作被选中的次数![74](https://latex.codecogs.com/png.latex?N(a),\forall{a\in\{0,1,2,\cdots,361\}}) 。设当前时刻为![75](https://latex.codecogs.com/png.latex?t) ,棋盘上状态为![76](https://latex.codecogs.com/png.latex?s_t) ,执行MCTS得到362个动作被选中的次数: 86 | 87 | ![77](https://latex.codecogs.com/png.latex?N(0),N(1),\cdots,N(361)) 88 | 89 | 对这些动作被选中的次数做归一化,得到362个和为1的正数,将这362个数记作362维向量![78](https://latex.codecogs.com/png.latex?p_t=normalize\Big([N(0),N(1),\cdots,N(361)]^T\Big)) 。 90 | 91 | 设这一局游戏走了![79](https://latex.codecogs.com/png.latex?n) 步之后分出胜负,奖励![80](https://latex.codecogs.com/png.latex?r_n) 要么等于![81](https://latex.codecogs.com/png.latex?+1) ,要么等于![82](https://latex.codecogs.com/png.latex?-1) ,取决于游戏的胜负。游戏结束之后,可以得到回报![83](https://latex.codecogs.com/png.latex?u_1=u_2=\cdots=u_n=r_n) 。 92 | 93 | 每自对弈一局可以得到数据:![84](https://latex.codecogs.com/png.latex?(s_1,p_1,u_1),(s_2,p_2,u_2),\cdots,(s_n,p_n,u_n)) 。使用这些数训练策略网络![85](https://latex.codecogs.com/png.latex?\pi) 和价值网络![86](https://latex.codecogs.com/png.latex?v) 。对![87](https://latex.codecogs.com/png.latex?\pi) 和![88](https://latex.codecogs.com/png.latex?v) 的训练同时进行。 94 | 95 | 96 | 97 | ### 2.2 训练策略网络和价值网络 98 | 99 | 根据技术原理文档[蒙特卡洛树搜索(MCTS)](https://github.com/QPT-Family/QPT-CleverGo/blob/main/docs/%E8%92%99%E7%89%B9%E5%8D%A1%E6%B4%9B%E6%A0%91%E6%90%9C%E7%B4%A2(MCTS).md)可知,MCTS做出的决策优于策略网络![89](https://latex.codecogs.com/png.latex?\pi) 的决策(这也是阿尔法狗使用MCTS做决策,而![90](https://latex.codecogs.com/png.latex?\pi) 只是用来辅助MCTS的原因)。既然MCTS做出的决策比![91](https://latex.codecogs.com/png.latex?\pi) 更好,那么可以把MCTS的决策作为目标,让![92](https://latex.codecogs.com/png.latex?\pi) 去模仿。与1.1节所述行为克隆一致,只不过被模仿的对象不是人类高手,而是MCTS,即训练策略网络的数据不是收集到的人类高手对局数据,而是2.1节所述MCTS控制两个玩家对弈生成的对局数据。 100 | 101 | 训练价值网络的目标与阿尔法狗2016版本一致,都是让![93](https://latex.codecogs.com/png.latex?v(s_t;\omega)) 拟合回报![94](https://latex.codecogs.com/png.latex?u_t) 。其中回报![95](https://latex.codecogs.com/png.latex?u_t) 不是通过策略网络做自我博弈胜负得到,而是2.1节所述方法生成。 102 | 103 | 在零狗中,对策略网络和价值网络的训练是同时进行的。将策略网络的损失与价值网络的损失相加,作为训练时优化的目标函数: 104 | 105 | ![96](https://latex.codecogs.com/png.latex?l=(u-v)^2-\pi^Tlog~p+c\big(||\theta||^2+||\omega||^2\big)) 106 | 107 | 其中,![97](https://latex.codecogs.com/png.latex?u) 是2.1所述通过MCTS自我博弈收集到的回报数据,![98](https://latex.codecogs.com/png.latex?v) 是价值网络输出,![99](https://latex.codecogs.com/png.latex?(u-v)^2) 即为价值网络的损失(均方损失);![100](https://latex.codecogs.com/png.latex?\pi) 是策略网络输出,![101](https://latex.codecogs.com/png.latex?p) 是2.1所述通过MCTS自我博弈收集到的被归一化了的每个动作被选中次数数据,![102](https://latex.codecogs.com/png.latex?-\pi^Tlog~p) 即为策略网络的损失(交叉熵损失);![103](https://latex.codecogs.com/png.latex?\theta) 和![104](https://latex.codecogs.com/png.latex?\omega) 分别是策略网络参数和价值网络参数,![105](https://latex.codecogs.com/png.latex?\big(||\theta||^2+||\omega||^2\big)) 即为防止过拟合的正则项(L2正则);![106](https://latex.codecogs.com/png.latex?c) 是一个超参数,用于控制L2正则化权重。 108 | 109 | > 零狗论文中所述神经网络![107](https://latex.codecogs.com/png.latex?f_\theta(s)=\big(P(s,\cdot),V(s)\big)) 由策略网络和价值网络构成,因此论文中神经网络参数![108](https://latex.codecogs.com/png.latex?\theta) ,等同于本文中的策略网络和价值网络参数![109](https://latex.codecogs.com/png.latex?\theta+\omega) 。如果读者留意到零狗论文中目标函数![110](https://latex.codecogs.com/png.latex?l=(u-v)^2-\pi^Tlog~p+c||\theta||^2) 与文本所述存在一定差别,不必感到疑惑,也不必质疑本文的正确性。 110 | > 111 | > 本文对论文中相关原理以更容易理解的方式来表述,但是相关方法在本质上是相同的。 112 | 113 | 114 | 115 | ### 2.3 训练流程 116 | 117 | 随机初始化策略网络参数![111](https://latex.codecogs.com/png.latex?\theta) 和价值网络参数![112](https://latex.codecogs.com/png.latex?\omega) ,然后让MCTS自我博弈,玩很多局游戏。每完成一局游戏,更新一次![113](https://latex.codecogs.com/png.latex?\theta) 和![114](https://latex.codecogs.com/png.latex?\omega) 。具体训练流程如下,训练会重复如下步骤直到收敛: 118 | 119 | 1. 让MCTS自我博弈,完成一局游戏,收集到![115](https://latex.codecogs.com/png.latex?n) 个三元组:![116](https://latex.codecogs.com/png.latex?(s_1,p_1,u_1),(s_2,p_2,u_2),\cdots,(s_n,p_n,u_n)) ; 120 | 2. 做梯度下降,同时更新策略网络参数![117](https://latex.codecogs.com/png.latex?\theta) 和价值网络参数![118](https://latex.codecogs.com/png.latex?\omega) 。 121 | 122 | 123 | 124 | ## 3. 结束语 125 | 126 | 本文介绍了阿尔法狗2016版本和零狗中训练策略网络和价值网络的方法,机巧围棋中训练方法与零狗基本一致。大家可以在GitHub上clone机巧围棋的代码,结合本文理解和学习零狗的训练方法。 127 | 128 | 最后,期待您能够给本文点个赞,同时去GitHub上给机巧围棋项目点个Star呀~ 129 | 130 | 机巧围棋项目链接:[https://github.com/QPT-Family/QPT-CleverGo](https://github.com/QPT-Family/QPT-CleverGo) 131 | 132 | -------------------------------------------------------------------------------- /docs/阿尔法狗与机巧围棋的网络结构.md: -------------------------------------------------------------------------------- 1 | # 阿尔法狗与机巧围棋的网络结构 2 | 3 | 阿尔法狗(AlphaGo)的意思是“围棋王”,俗称“阿尔法狗”,它是世界上第一个打败人类围棋冠军的AI。2015年10月,阿尔法狗以5 : 0战胜了欧洲围棋冠军樊麾二段,在2016年3月,阿尔法狗以4 : 1战胜了世界冠军李世石。2017年,新版不依赖人类经验完全从零开始自学的零狗(AlphaGo Zero)以100 : 0战胜阿尔法狗。 4 | 5 | 阿尔法狗使用策略网络和价值网络辅助蒙特卡洛树搜索,以降低搜索的深度和宽度。机巧围棋落子策略完全基于零狗算法,本文将用强化学习的语言描述围棋游戏的状态和动作,并介绍阿尔法狗和机巧围棋中构造的策略网络和价值网络。 6 | 7 | 8 | 9 | ## 1. 动作和状态 10 | 11 | 围棋的棋盘是19 X 19的网格,黑白双方轮流在两条线的交叉点处放置棋子。一共有19 X 19 = 361个可以放置棋子的位置,同时可以选择PASS(放弃一次当前落子的权利),因此动作空间是![1](https://latex.codecogs.com/png.latex?\mathcal{A}=\{0,1,2,\cdots,361\}) ,其中第![2](https://latex.codecogs.com/png.latex?i) 种动作表示在第![3](https://latex.codecogs.com/png.latex?i) 个位置(从0开始)放置棋子,第361种动作表示PASS。 12 | 13 | 机巧围棋是基于9路围棋的人工智能程序,即棋盘是9 X 9的网格。相应地动作空间![4](https://latex.codecogs.com/png.latex?\mathcal{A}=\{0,1,2,\cdots,81\}) 。 14 | 15 | ![7_1](https://github.com/QPT-Family/QPT-CleverGo/blob/main/pictures/7_1.png) 16 | 17 | 阿尔法狗2016版本使用19 X 19 X 48的张量(tensor)表示一个状态,零狗使用19 X 19 X 17的张量表示一个状态。如图一所示,零狗中使用的状态张量的意义如下: 18 | 19 | - 状态张量的每个切片(slice)是19 X 19的矩阵,对应19 X 19的棋盘。一个19 X 19的矩阵可以表示棋盘上所有黑色棋子的位置,如果一个位置上有黑色棋子,则矩阵对应位置的元素为1,否则为0。同样的道理,可以用一个19 X 19的矩阵表示棋盘上所有白色棋子的位置。 20 | - 在零狗的状态张量中,一共存在17个矩阵。其中8个矩阵记录最近8步棋盘上黑子的位置,8个矩阵记录最近8步白子的位置。还有一个矩阵表示下一步落子方,如果接下来由黑方落子,则该矩阵元素全部等于1,如果接下来由白方落子,则该矩阵的元素全部都等于0。 21 | 22 | 为了减少计算量,机巧围棋对状态张量做了一定的简化。在机巧围棋中,使用9 X 9 X 10的张量表示一个状态,其中4个9 X 9的矩阵记录最近4步棋盘上黑子的位置,4个矩阵记录白子的位置。一个矩阵表示下一步落子方,如果接下来由黑方落子,则该矩阵元素全部等于0,由白方落子则等于1。还有最后一个矩阵表示上一步落子位置,即上一步落子位置元素为1,其余位置元素为0,若上一步为PASS,则该矩阵元素全部为0。 23 | 24 | > 阿尔法狗2016版本的状态张量意义比较复杂,本文不详细展开,具体可参加下图: 25 | > 26 | > ![7_2](https://github.com/QPT-Family/QPT-CleverGo/blob/main/pictures/7_2.png) 27 | 28 | 29 | 30 | ## 2. 策略网络 31 | 32 | 策略网络![5](https://latex.codecogs.com/png.latex?\pi(a|s;\theta)) 的结构如图三所示。零狗策略网络的输入是19 X 19 X 17的状态![6](https://latex.codecogs.com/png.latex?s) ,输出是362维的向量![7](https://latex.codecogs.com/png.latex?f) ,它的每个元素对应动作空间中的一个动作。策略网络的输出层激活函数为Softmax,因此向量![8](https://latex.codecogs.com/png.latex?f) 所有元素均是正数,而且相加等于1。 33 | 34 | ![7_3](https://github.com/QPT-Family/QPT-CleverGo/blob/main/pictures/7_3.png) 35 | 36 | 37 | 38 | ## 3. 价值网络 39 | 40 | 在阿尔法中还有一个价值网络![9](https://latex.codecogs.com/png.latex?v_\pi(s;\omega)) ,它是对状态价值函数![10](https://latex.codecogs.com/png.latex?V_\pi(s)) 的近似,价值网络的结构如图四所示。价值网络的输入是19 X 19 X 17的状态![11](https://latex.codecogs.com/png.latex?s) ,输出是一个![12](https://latex.codecogs.com/png.latex?[-1,+1]) 的实数,它的大小评价当前状态![13](https://latex.codecogs.com/png.latex?s) 的好坏。 41 | 42 | ![7_4](https://github.com/QPT-Family/QPT-CleverGo/blob/main/pictures/7_4.png) 43 | 44 | 策略网络和价值网络的输入相同,都是状态![14](https://latex.codecogs.com/png.latex?s) 。而且它们都用卷积层将![15](https://latex.codecogs.com/png.latex?s) 映射到特征向量,因此零狗中让策略网络和价值网络共用卷积层。 45 | 46 | > 零狗中策略网络和价值网络共用卷积层,但是在阿尔法狗2016版本中没有共用。因为零狗中策略网络和价值网络是一起训练的,而阿尔法狗2016版本中是先训练策略网络,然后用策略网络来训练价值网络,二者不是同时训练的,因此不能共用卷积层。后续会详细介绍阿尔法狗中神经网络训练方法。 47 | 48 | 49 | 50 | ## 4. 机巧围棋网络结构 51 | 52 | 零狗训练用了5000块TPU,在机巧围棋中为了减少计算量,大幅简化了策略网络和价值网络。在机巧围棋中,使用了3个卷积层从状态![16](https://latex.codecogs.com/png.latex?s) 中提取特征,分别是: 53 | 54 | - 3 X 3步长1的32通道卷积; 55 | - 3 X 3步长1的64通道卷积; 56 | - 3 X 3步长1的128通道卷积。 57 | 58 | 在策略网络部分,首先使用1 X 1的8通道卷积对信息进行夸通道整合,再接一个全连接层将特征向量维度压缩成256,最后再接入输出层;在价值网络部分,首先使用1 X 1的4通道卷积对信息进行夸通道整合,再接入两个全连接层,最后接入输出层。具体代码如下: 59 | 60 | ```python 61 | # -*- coding: utf-8 -*- 62 | # @Time : 2021/3/29 21:01 63 | # @Author : He Ruizhi 64 | # @File : policy_value_net.py 65 | # @Software: PyCharm 66 | 67 | import paddle 68 | 69 | 70 | class PolicyValueNet(paddle.nn.Layer): 71 | def __init__(self, input_channels: int = 10, 72 | board_size: int = 9): 73 | """ 74 | 75 | :param input_channels: 输入的通道数,默认为10。双方最近4步,再加一个表示当前落子方的平面,再加上一个最近一手位置的平面 76 | :param board_size: 棋盘大小 77 | """ 78 | super(PolicyValueNet, self).__init__() 79 | 80 | # AlphaGo Zero网络架构:一个身子,两个头 81 | # 特征提取网络部分 82 | self.conv_layer = paddle.nn.Sequential( 83 | paddle.nn.Conv2D(in_channels=input_channels, out_channels=32, kernel_size=3, padding=1), 84 | paddle.nn.ReLU(), 85 | paddle.nn.Conv2D(in_channels=32, out_channels=64, kernel_size=3, padding=1), 86 | paddle.nn.ReLU(), 87 | paddle.nn.Conv2D(in_channels=64, out_channels=128, kernel_size=3, padding=1), 88 | paddle.nn.ReLU() 89 | ) 90 | 91 | # 策略网络部分 92 | self.policy_layer = paddle.nn.Sequential( 93 | paddle.nn.Conv2D(in_channels=128, out_channels=8, kernel_size=1), 94 | paddle.nn.ReLU(), 95 | paddle.nn.Flatten(), 96 | paddle.nn.Linear(in_features=9*9*8, out_features=256), 97 | paddle.nn.ReLU(), 98 | paddle.nn.Linear(in_features=256, out_features=board_size*board_size+1), 99 | paddle.nn.Softmax() 100 | ) 101 | 102 | # 价值网络部分 103 | self.value_layer = paddle.nn.Sequential( 104 | paddle.nn.Conv2D(in_channels=128, out_channels=4, kernel_size=1), 105 | paddle.nn.ReLU(), 106 | paddle.nn.Flatten(), 107 | paddle.nn.Linear(in_features=9*9*4, out_features=128), 108 | paddle.nn.ReLU(), 109 | paddle.nn.Linear(in_features=128, out_features=64), 110 | paddle.nn.ReLU(), 111 | paddle.nn.Linear(in_features=64, out_features=1), 112 | paddle.nn.Tanh() 113 | ) 114 | 115 | def forward(self, x): 116 | x = self.conv_layer(x) 117 | policy = self.policy_layer(x) 118 | value = self.value_layer(x) 119 | return policy, value 120 | ``` 121 | 122 | 123 | 124 | ## 5. 结束语 125 | 126 | 本文介绍了阿尔法狗中的两个深度神经网络——策略网络和价值网络,并讲解了机巧围棋中的网络实现。在阿尔法狗或者机巧围棋中,神经网络结构并不是一层不变的,可以依据个人经验或喜好随意调整。总的来说,浅的网络能够减少计算量,加快训练和落子过程,深的网络可能更有希望训练出更高水平狗。 127 | 128 | 最后,期待您能够给本文点个赞,同时去GitHub上给机巧围棋项目点个Star呀~ 129 | 130 | 机巧围棋项目链接:[https://github.com/QPT-Family/QPT-CleverGo](https://github.com/QPT-Family/QPT-CleverGo) 131 | 132 | 133 | 134 | -------------------------------------------------------------------------------- /go_engine.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/9/30 14:32 3 | # @Author : He Ruizhi 4 | # @File : go_engine.py 5 | # @Software: PyCharm 6 | 7 | from GymGo.gym_go import govars, gogame 8 | from typing import Union, List, Tuple 9 | import numpy as np 10 | from scipy import ndimage 11 | 12 | surround_struct = np.array([[0, 1, 0], 13 | [1, 0, 1], 14 | [0, 1, 0]]) 15 | 16 | eye_struct = np.array([[1, 1, 1], 17 | [1, 0, 1], 18 | [1, 1, 1]]) 19 | 20 | corner_struct = np.array([[1, 0, 1], 21 | [0, 0, 0], 22 | [1, 0, 1]]) 23 | BLACK = govars.BLACK 24 | WHITE = govars.WHITE 25 | 26 | 27 | class GoEngine: 28 | def __init__(self, board_size: int = 9, 29 | komi=7.5, 30 | record_step: int = 4, 31 | state_format: str = "separated", 32 | record_last: bool = True): 33 | """ 34 | 围棋引擎初始化 35 | 36 | :param board_size: 棋盘大小,默认为9 37 | :param komi: 黑棋贴目数,默认黑贴7.5目(3又3/4子) 38 | :param record_step: 记录棋盘历史状态步数,默认为4 39 | :param state_format: 记录棋盘历史状态格式 40 | 【separated:黑白棋子分别记录在不同的矩阵中,[黑棋,白棋,下一步落子方,上一步落子位置(可选)]】 41 | 【merged:黑白棋子记录在同一个矩阵中,[棋盘棋子分布(黑1白-1),下一步落子方,上一步落子位置(可选)]】 42 | :param record_last: 是否记录上一步落子位置 43 | """ 44 | assert state_format in ["separated", "merged"],\ 45 | "state_format can only be 'separated' or 'merged', but received: {}".format(state_format) 46 | 47 | self.board_size = board_size 48 | self.komi = komi 49 | self.record_step = record_step 50 | self.state_format = state_format 51 | self.record_last = record_last 52 | self.current_state = gogame.init_state(board_size) 53 | # 保存棋盘状态,用于悔棋 54 | self.board_state_history = [] 55 | # 保存历史动作,用于悔棋 56 | self.action_history = [] 57 | 58 | if state_format == "separated": 59 | record_step *= 2 60 | self.state_channels = record_step + 2 if record_last else record_step + 1 61 | self.board_state = np.zeros((self.state_channels, board_size, board_size)) 62 | self.done = False 63 | 64 | def reset(self) -> np.ndarray: 65 | """重置current_state, board_state, board_state_history, action_history""" 66 | self.current_state = gogame.init_state(self.board_size) 67 | self.board_state = np.zeros((self.state_channels, self.board_size, self.board_size)) 68 | self.board_state_history = [] 69 | self.action_history = [] 70 | self.done = False 71 | return np.copy(self.current_state) 72 | 73 | def step(self, action: Union[List[int], Tuple[int], int, None]) -> np.ndarray: 74 | """ 75 | 围棋落子 76 | 77 | :param action: 下一步落子位置 78 | :return: 79 | """ 80 | assert not self.done 81 | if isinstance(action, tuple) or isinstance(action, list) or isinstance(action, np.ndarray): 82 | assert 0 <= action[0] < self.board_size 83 | assert 0 <= action[1] < self.board_size 84 | action = self.board_size * action[0] + action[1] 85 | elif isinstance(action, int): 86 | assert 0 <= action <= self.board_size ** 2 87 | elif action is None: 88 | action = self.board_size ** 2 89 | 90 | self.current_state = gogame.next_state(self.current_state, action, canonical=False) 91 | # 更新self.board_state 92 | self.board_state = self._update_state_step(action) 93 | # 存储历史状态 94 | self.board_state_history.append(np.copy(self.current_state)) 95 | # 存储历史动作 96 | self.action_history.append(action) 97 | self.done = gogame.game_ended(self.current_state) 98 | return np.copy(self.current_state) 99 | 100 | def _update_state_step(self, action: int) -> np.ndarray: 101 | """ 102 | 更新self.board_state,须在更新完self.current_state之后更新self.board_state 103 | 104 | :param action: 下一步落子位置,1d-action 105 | :return: 106 | """ 107 | if self.state_format == "separated": 108 | # 根据上一步落子方更新self.board_state(因为self.current_state已经更新完毕) 109 | if self.turn() == govars.WHITE: 110 | # 根据更新过后的self.current_state,下一步落子方为白方,则上一步落子方为黑方 111 | self.board_state[:self.record_step - 1] = np.copy(self.board_state[1:self.record_step]) 112 | self.board_state[self.record_step - 1] = np.copy(self.current_state[govars.BLACK]) 113 | else: 114 | # 根据更新过后的self.current_state,下一步落子方为黑方,则上一步落子方为白方 115 | self.board_state[self.record_step: self.record_step * 2 - 1] = \ 116 | np.copy(self.board_state[self.record_step + 1: self.record_step * 2]) 117 | self.board_state[self.record_step * 2 - 1] = np.copy(self.current_state[govars.WHITE]) 118 | elif self.state_format == "merged": 119 | self.board_state[:self.record_step - 1] = np.copy(self.board_state[1:self.record_step]) 120 | current_state = self.current_state[[govars.BLACK, govars.WHITE]] 121 | current_state[govars.WHITE] *= -1 122 | self.board_state[self.record_step - 1] = np.sum(current_state, axis=0) 123 | 124 | if self.record_last: 125 | # 更新下一步落子方 126 | self.board_state[-2] = np.copy(self.current_state[govars.TURN_CHNL]) 127 | # 更新上一步落子位置 128 | self.board_state[-1] = np.zeros((self.board_size, self.board_size)) 129 | # 上一步不为pass 130 | if action != self.board_size ** 2: 131 | # 将action转换成position 132 | position = action // self.board_size, action % self.board_size 133 | self.board_state[-1, position[0], position[1]] = 1 134 | else: 135 | # 更新下一步落子方 136 | self.board_state[-1] = np.copy(self.current_state[govars.TURN_CHNL]) 137 | return np.copy(self.board_state) 138 | 139 | def get_board_state(self) -> np.ndarray: 140 | """用于训练神经网络的棋盘状态矩阵""" 141 | return np.copy(self.board_state) 142 | 143 | def game_ended(self) -> bool: 144 | """游戏是否结束""" 145 | return self.done 146 | 147 | def winner(self) -> int: 148 | """获胜方,游戏未结束返回-1""" 149 | if not self.done: 150 | return -1 151 | else: 152 | winner = self.winning() 153 | winner = govars.BLACK if winner == 1 else govars.WHITE 154 | return winner 155 | 156 | def action_valid(self, action) -> bool: 157 | """判断action是否合法""" 158 | return self.valid_moves()[action] 159 | 160 | def valid_move_idcs(self) -> np.ndarray: 161 | """下一步落子有效位置的id""" 162 | valid_moves = self.valid_moves() 163 | return np.argwhere(valid_moves).flatten() 164 | 165 | def advanced_valid_move_idcs(self) -> np.ndarray: 166 | """下一步落子的非真眼有效位置的id""" 167 | advanced_valid_moves = self.advanced_valid_moves() 168 | return np.argwhere(advanced_valid_moves).flatten() 169 | 170 | def uniform_random_action(self) -> np.ndarray: 171 | """随机选择落子位置""" 172 | valid_move_idcs = self.valid_move_idcs() 173 | return np.random.choice(valid_move_idcs) 174 | 175 | def advanced_uniform_random_action(self) -> np.ndarray: 176 | """不填真眼的随机位置""" 177 | advanced_valid_move_idcs = self.advanced_valid_move_idcs() 178 | return np.random.choice(advanced_valid_move_idcs) 179 | 180 | def turn(self) -> int: 181 | """下一步落子方""" 182 | return gogame.turn(self.current_state) 183 | 184 | def valid_moves(self) -> np.ndarray: 185 | """下一步落子的有效位置""" 186 | return gogame.valid_moves(self.current_state) 187 | 188 | def advanced_valid_moves(self): 189 | """下一步落子的非真眼有效位置""" 190 | valid_moves = 1 - self.current_state[govars.INVD_CHNL] 191 | eyes_mask = 1 - self.eyes() 192 | return np.append((valid_moves * eyes_mask).flatten(), 1) 193 | 194 | def winning(self): 195 | """ 196 | 当游戏结束之后,从黑方角度看待,上一步落子后,哪一方胜利 197 | 黑胜:1 白胜:-1 198 | """ 199 | return gogame.winning(self.current_state, self.komi) 200 | 201 | def areas(self): 202 | """black_area, white_area""" 203 | return gogame.areas(self.current_state) 204 | 205 | def eyes(self): 206 | """ 207 | 下一步落子方的真眼位置 208 | 1.如果在角上或者边上,则需要对应8个最近位置均有下一步落子方的棋子; 209 | 2.如果不在边上和角上,则需要对应4个最近边全有下一步落子方的棋子,且至少有三个角有下一步落子方的棋子; 210 | 3.所判断的位置没有棋子 211 | """ 212 | board_shape = self.current_state.shape[1:] 213 | 214 | side_mask = np.zeros(board_shape) 215 | side_mask[[0, -1], :] = 1 216 | side_mask[:, [0, -1]] = 1 217 | nonside_mask = 1 - side_mask 218 | 219 | # 下一步落子方 220 | next_player = self.turn() 221 | # next_player的棋子分布矩阵 222 | next_player_pieces = self.current_state[next_player] 223 | # 棋盘所有有棋子的分布矩阵,有棋子则相应位置为1 224 | all_pieces = np.sum(self.current_state[[govars.BLACK, govars.WHITE]], axis=0) 225 | # 棋盘上所有空交叉点的分布矩阵,空交叉点位置为1 226 | empties = 1 - all_pieces 227 | 228 | # 对于边角位置 229 | side_matrix = ndimage.convolve(next_player_pieces, eye_struct, mode='constant', cval=1) == 8 230 | side_matrix = side_matrix * side_mask 231 | # 对于非边角位置 232 | nonside_matrix = ndimage.convolve(next_player_pieces, surround_struct, mode='constant', cval=1) == 4 233 | nonside_matrix *= ndimage.convolve(next_player_pieces, corner_struct, mode='constant', cval=1) > 2 234 | nonside_matrix = nonside_matrix * nonside_mask 235 | 236 | return empties * (side_matrix + nonside_matrix) 237 | 238 | def all_symmetries(self) -> List[np.ndarray]: 239 | """board_state的8种等价表示""" 240 | return gogame.all_symmetries(np.copy(self.board_state)) 241 | 242 | @staticmethod 243 | def array_symmetries(array: np.ndarray) -> List[np.ndarray]: 244 | """ 245 | 指定array的8种旋转表示 246 | 247 | :param array: A (C, BOARD_SIZE, BOARD_SIZE) numpy array, where C is any number 248 | :return: 249 | """ 250 | return gogame.all_symmetries(array) 251 | -------------------------------------------------------------------------------- /mcts.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/3/20 10:51 3 | # @Author : He Ruizhi 4 | # @File : player.py 5 | # @Software: PyCharm 6 | 7 | import numpy as np 8 | import copy 9 | from operator import itemgetter 10 | 11 | 12 | def softmax(x): 13 | probs = np.exp(x - np.max(x)) 14 | probs /= np.sum(probs) 15 | return probs 16 | 17 | 18 | def evaluate_rollout(simulate_game_state, rollout_policy_fn, limit=1000): 19 | """ 20 | 使用rollout_policy_fn玩游戏直至游戏结束或达到限制数,如果当前玩家获胜,则返回+1,对手胜则返回-1,和棋则返回0 21 | 如果模拟次数超过限制游戏还没结束,则同样返回0 22 | 23 | :param simulate_game_state: 模拟游戏状态 24 | :param rollout_policy_fn: 产生下一步各合法动作及其概率的函数 25 | :param limit: 限制模拟步数,超过这个限制还没结束,游戏视为和棋 26 | :return: 27 | """ 28 | game_state_copy = copy.deepcopy(simulate_game_state) 29 | player = game_state_copy.turn() 30 | for _ in range(limit): 31 | end, winner = game_state_copy.game_ended(), game_state_copy.winner() 32 | if end: 33 | break 34 | action_probs = rollout_policy_fn(game_state_copy) 35 | max_action = max(action_probs, key=itemgetter(1))[0] 36 | game_state_copy.step(max_action) 37 | else: 38 | winner = -1 39 | if winner == -1: # 和棋 40 | return 0 41 | else: 42 | return 1 if winner == player else -1 43 | 44 | 45 | class TreeNode: 46 | """蒙特卡洛树节点""" 47 | def __init__(self, parent, prior_p): 48 | self.parent = parent # 节点的父节点 49 | self.children = {} # 一个字典,用来存节点的子节点 50 | self.n_visits = 0 # 节点被访问的次数 51 | self.Q = 0 # 节点的平均行动价值 52 | self.U = 0 # MCTS选择Q+U最大的节点,公式里的U 53 | self.P = prior_p # 节点被选择的概率 54 | 55 | def select(self, c_puct): 56 | """ 57 | 蒙特卡洛树搜索的第一步:选择 58 | 蒙特卡洛树搜索通过不断选择 最大上置信限Q+U 的子节点,直至一个树的叶结点 59 | 该函数为进行一步选择函数 60 | 61 | :param c_puct: 为计算U值公式中的c_puct,是一个决定探索水平的常数 62 | :return: 返回一个元组(action, next_node) 63 | """ 64 | return max(self.children.items(), 65 | key=lambda act_node: act_node[1].get_value(c_puct)) 66 | 67 | def expand(self, action_priors): 68 | """ 69 | 当select搜索到一个叶结点,且该叶节点代表的局面游戏没有结束, 70 | 需要expand树,创建一系列可能得节点,即对应节点所有可能选择的动作对应的子节点 71 | 72 | :param action_priors: 为一个列表,列表中的每一个元素为一个 特定动作及其先验概率 的元组 73 | :return: 74 | """ 75 | for action, prob in action_priors: 76 | if action not in self.children: 77 | self.children[action] = TreeNode(self, prob) 78 | 79 | def update(self, leaf_value): 80 | """ 81 | 根据子树的价值更新当前节点的价值 82 | 83 | :param leaf_value: 以当前玩家的视角看待得到的子树的价值 84 | :return: 85 | """ 86 | self.n_visits += 1 # 当前节点的访问次数+1 87 | # 更新当前节点的Q值,下述公式可由Q = W / N 推导得到 88 | # Q_old = W_old / N_old 89 | # Q = (W_old + v) / (N_old + 1) = (Q_old * N_old + v) / (N_old + 1) 90 | self.Q += 1.0 * (leaf_value - self.Q) / self.n_visits 91 | 92 | def update_recursive(self, leaf_value): 93 | """ 94 | 跟心所有祖先的Q值及访问次数 95 | 96 | :param leaf_value: 97 | :return: 98 | """ 99 | if self.parent: # 如果有父节点,证明还没到根节点 100 | self.parent.update_recursive(-leaf_value) # -leaf_value是因为每向上一层,以当前玩家视角,价值反转 101 | self.update(leaf_value) 102 | 103 | def get_value(self, c_puct): 104 | """ 105 | 计算并返回一个节点的 上置信限 评价,即Q+U值 106 | 107 | :param c_puct: 为计算U值公式中的c_puct,是一个决定探索水平的常数 108 | :return: 109 | """ 110 | self.U = c_puct * self.P * np.sqrt(self.parent.n_visits) / (1 + self.n_visits) 111 | return self.Q + self.U 112 | 113 | def is_leaf(self): 114 | """ 115 | 判断当前节点是否为叶结点。、 116 | 117 | :return: 118 | """ 119 | return self.children == {} 120 | 121 | def is_root(self): 122 | """ 123 | 判断当前节点是否为根节点 124 | 125 | :return: 126 | """ 127 | return self.parent is None 128 | 129 | 130 | class MCTS: 131 | """蒙特卡洛树搜索主体""" 132 | def __init__(self, policy_value_fn, c_puct=5, n_playout=10000): 133 | self.root = TreeNode(None, 1.0) # 整个蒙特卡洛搜索树的根节点 134 | # policy_value_fn是一个函数,该函数的输入为game_state, 135 | # 输出为一个列表,列表中的每一个元素为(action, probability)形式的元组 136 | self.policy = policy_value_fn 137 | # c_puct为一个正数,用于控制多块收敛到策略的最大值。这个数越大,意味着越依赖前面的结果。 138 | self.c_puct = c_puct 139 | self.n_playout = n_playout 140 | 141 | def playout(self, simulate_game_state): 142 | """ 143 | 从根节点不断选择直到叶结点,并获取叶结点的值,反向传播到叶结点的祖先节点 144 | 145 | :param simulate_game_state: 模拟游戏对象 146 | :return: 147 | """ 148 | node = self.root 149 | while True: # 从根节点一直定位到叶结点 150 | if node.is_leaf(): 151 | break 152 | # 贪婪地选择下一步动作 153 | action, node = node.select(self.c_puct) 154 | simulate_game_state.step(action) 155 | # 使用网络来评估叶结点,产生一个每一个元素均为(action, probability)元组的列表,以及 156 | # 一个以当前玩家视角看待的在[-1, 1]之间的v值 157 | action_probs, leaf_value = self.policy(simulate_game_state) 158 | # 检查模拟游戏是否结束 159 | end, winner = simulate_game_state.game_ended(), simulate_game_state.winner() 160 | if not end: # 没结束则扩展 161 | node.expand(action_probs) 162 | else: 163 | if winner == -1: # 和棋 164 | leaf_value = 0.0 165 | else: 166 | leaf_value = ( 167 | 1.0 if winner == simulate_game_state.turn() else -1.0 168 | ) 169 | # 更新此条遍历路径上的节点的访问次数和value 170 | # 这里的值要符号反转,因为这个值是根据根节点的player视角来得到的 171 | # 但是做出下一步落子的是根节点对应player的对手 172 | node.update_recursive(-leaf_value) 173 | 174 | def get_move_probs(self, game, temp=1e-3, player=None): 175 | """ 176 | 执行n_playout次模拟,并根据子节点的访问次数,获得每个动作对应的概率 177 | 178 | :param game: 游戏模拟器 179 | :param temp: 制探索水平的温度参数 180 | :param player: 调用该函数的player,用于进行进度绘制 181 | :return: 182 | """ 183 | for i in range(self.n_playout): 184 | if not player.valid: 185 | return -1, -1 186 | if player is not None: 187 | player.speed = (i + 1, self.n_playout) 188 | simulate_game_state = game.game_state_simulator(player.is_selfplay) 189 | self.playout(simulate_game_state) 190 | # 基于节点访问次数,计算每个动作对应的概率 191 | act_visits = [(act, node.n_visits) 192 | for act, node in self.root.children.items()] 193 | acts, visits = zip(*act_visits) 194 | act_probs = softmax(1.0 / temp * np.log(np.array(visits) + 1e-10)) 195 | return acts, act_probs 196 | 197 | def get_move(self, game, player=None): 198 | """ 199 | 执行n_playout次模拟,返回访问次数最多的动作 200 | 201 | :param game: 游戏模拟器 202 | :param player: 调用该函数的player,用于进行进度绘制 203 | :return: 返回访问次数最多的动作 204 | """ 205 | for i in range(self.n_playout): 206 | if not player.valid: 207 | return -1 208 | if player is not None: 209 | player.speed = (i + 1, self.n_playout) 210 | game_state = game.game_state_simulator() 211 | self.playout(game_state) 212 | return max(self.root.children.items(), key=lambda act_node: act_node[1].n_visits)[0] 213 | 214 | def update_with_move(self, last_move): 215 | """ 216 | 蒙特卡洛搜索树向深层前进一步,并且保存对应子树的全部信息 217 | 218 | :param last_move: 上一步选择的动作 219 | :return: 220 | """ 221 | if last_move in self.root.children: 222 | self.root = self.root.children[last_move] 223 | self.root.parent = None 224 | else: 225 | self.root = TreeNode(None, 1.0) 226 | -------------------------------------------------------------------------------- /models/alpha_go.pdparams: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/models/alpha_go.pdparams -------------------------------------------------------------------------------- /pgutils/manager.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/10/7 10:57 3 | # @Author : He Ruizhi 4 | # @File : manager.py 5 | # @Software: PyCharm 6 | 7 | from pgutils.pgcontrols.ctbase import CtBase 8 | from pgutils.pgtools.toolbase import ToolBase 9 | import pygame 10 | from typing import List, Union 11 | 12 | 13 | class Manager: 14 | def __init__(self): 15 | self.controls = [] 16 | self.tools = [] 17 | 18 | def control_register(self, controls: Union[List[CtBase], CtBase]): 19 | """ 20 | 控件注册 21 | 22 | :param controls: pygame控件或控件数组 23 | :return: 24 | """ 25 | if isinstance(controls, CtBase): 26 | self.controls.append(controls) 27 | else: 28 | for control in controls: 29 | self.controls.append(control) 30 | 31 | def tool_register(self, tools: Union[List[ToolBase], ToolBase]): 32 | """ 33 | 工具注册 34 | 35 | :param tools: pygame工具或工具数组 36 | :return: 37 | """ 38 | if isinstance(tools, ToolBase): 39 | self.tools.append(tools) 40 | else: 41 | for tool in tools: 42 | self.tools.append(tool) 43 | 44 | def control_update(self, event: pygame.event): 45 | """ 46 | 对所有注册的激活控件进行更新 47 | 48 | :param event: pygame事件 49 | :return: 50 | """ 51 | for control in self.controls: 52 | if control.active: 53 | control.update(event) 54 | 55 | def tool_update(self): 56 | """对所有激活的工具进行更新""" 57 | for tool in self.tools: 58 | if tool.active: 59 | tool.update() 60 | # pgtool会在更新后冻结 61 | tool.disable() 62 | -------------------------------------------------------------------------------- /pgutils/pgcontrols/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/10/4 20:57 3 | # @Author : He Ruizhi 4 | # @File : __init__.py.py 5 | # @Software: PyCharm 6 | 7 | """ 8 | Pygame controls: 以事件(pygame.event)驱动的对象 9 | """ 10 | -------------------------------------------------------------------------------- /pgutils/pgcontrols/button.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/9/25 20:45 3 | # @Author : He Ruizhi 4 | # @File : button.py 5 | # @Software: PyCharm 6 | 7 | import pygame 8 | import os 9 | import copy 10 | from pgutils.text import draw_text 11 | from pgutils.position import pos_in_surface 12 | from pgutils.pgcontrols.ctbase import CtBase 13 | from typing import Tuple, List, Union, Callable, Optional 14 | 15 | current_path = os.path.dirname(__file__) 16 | 17 | 18 | class Button(CtBase): 19 | """每一个Button均为一个pygame.surface.subsurface""" 20 | def __init__(self, surface: pygame.Surface, 21 | text: str, 22 | pos: Union[Tuple[str or int], List[str or int]], 23 | call_function: Optional[Callable] = None, 24 | click_sound: Union[str, pygame.mixer.Sound] = current_path + "/../../assets/audios/Button.wav", 25 | font_path: str = current_path + "/../../assets/fonts/msyh.ttc", 26 | font_size: int = 14, 27 | size: Union[Tuple[int], List[int]] = (87, 27), 28 | text_color: Union[Tuple[int], List[int]] = (0, 0, 0), 29 | up_color: Union[Tuple[int], List[int]] = (225, 225, 225), 30 | down_color: Union[Tuple[int], List[int]] = (190, 190, 190), 31 | outer_edge_color: Union[Tuple[int], List[int]] = (240, 240, 240), 32 | inner_edge_color: Union[Tuple[int], List[int]] = (173, 173, 173)): 33 | """ 34 | pygame按钮控件,用于在给定pygame.surface上绘制一个按钮 35 | 36 | :param surface: 绘制按钮的pygame.surface 37 | :param text: 按钮上的文本 38 | :param pos: 按钮绘制位置 39 | :param call_function: 点击按钮调用的方法 40 | :param click_sound: 按钮的点击音效 41 | :param font_path: 按钮上的文本字体路径 42 | :param text_color: 按钮上的文本颜色 43 | :param font_size: 文本大小 44 | :param size: 按钮大小 45 | :param up_color: 按钮弹起时的颜色 46 | :param down_color: 按钮按下时的颜色 47 | :param outer_edge_color: 按钮外边框颜色 48 | :param inner_edge_color: 按钮内边框颜色 49 | """ 50 | super(Button, self).__init__() 51 | 52 | pos = copy.copy(list(pos)) 53 | if isinstance(pos[0], str): 54 | assert pos[0] == "center" 55 | pos[0] = (surface.get_width() - size[0]) // 2 56 | if isinstance(pos[1], str): 57 | assert pos[1] == "center" 58 | pos[1] = (surface.get_height() - size[1]) // 2 59 | if isinstance(click_sound, str): 60 | click_sound = pygame.mixer.Sound(click_sound) 61 | 62 | # 按钮surface 63 | self.button_surface = surface.subsurface(pos[0], pos[1], size[0], size[1]) 64 | # 外边框 65 | self.outer_rect = 0, 0, size[0], size[1] 66 | # 内边框 67 | self.inner_rect = self.outer_rect[0] + 1, self.outer_rect[1] + 1, self.outer_rect[2] - 2, self.outer_rect[3] - 2 68 | 69 | self.font = pygame.font.Font(font_path, font_size) 70 | self.text = self.font.render(text, True, text_color) 71 | self.text_color = text_color 72 | self.size = size 73 | self.call_function = call_function 74 | self.click_sound = click_sound 75 | self.up_color = up_color 76 | self.down_color = down_color 77 | self.outer_edge_color = outer_edge_color 78 | self.inner_edge_color = inner_edge_color 79 | # 按钮是否被按下 80 | self.is_down = False 81 | 82 | def draw_up(self): 83 | """绘制未被点击的按钮""" 84 | self.is_down = False 85 | self.draw(self.up_color) 86 | 87 | def draw_down(self): 88 | """绘制已被点击的按钮""" 89 | self.is_down = True 90 | self.draw(self.down_color) 91 | 92 | def draw(self, base_color: Union[Tuple[int], List[int]]): 93 | """根据传入的颜色,对按钮显示效果进行更新""" 94 | # 填充按钮底色 95 | self.button_surface.fill(base_color) 96 | # 绘制外框 97 | pygame.draw.rect(self.button_surface, self.outer_edge_color, self.outer_rect, width=1) 98 | # 绘制内框 99 | pygame.draw.rect(self.button_surface, self.inner_edge_color, self.inner_rect, width=1) 100 | # 绘制按钮文本 101 | draw_text(self.button_surface, self.text, ["center", "center"]) 102 | 103 | def set_text(self, text: str, draw_update: bool = True): 104 | """设置按钮文本""" 105 | self.text = self.font.render(text, True, self.text_color) 106 | if draw_update: 107 | self.draw_up() 108 | 109 | def enable(self): 110 | """激活按钮""" 111 | self.active = True 112 | self.draw_up() 113 | 114 | def disable(self): 115 | """冻结按钮""" 116 | self.active = False 117 | self.draw_down() 118 | 119 | def update(self, event: pygame.event): 120 | """根据pygame.event对按钮进行状态更新和方法调用""" 121 | if event.type == pygame.MOUSEBUTTONDOWN and event.button == 1: 122 | # 鼠标左键按下 123 | if pos_in_surface(event.pos, self.button_surface): 124 | self.draw_down() 125 | self.is_down = True 126 | elif event.type == pygame.MOUSEMOTION: 127 | # 鼠标移动事件,用来检测按钮是否应该弹起 128 | if not pos_in_surface(event.pos, self.button_surface) and self.is_down: 129 | self.draw_up() 130 | self.is_down = False 131 | elif event.type == pygame.MOUSEBUTTONUP and event.button == 1: 132 | # 鼠标左键弹起事件 133 | if pos_in_surface(event.pos, self.button_surface) and self.is_down: 134 | self.draw_up() 135 | # 播放按钮点击音效 136 | self.click_sound.play() 137 | # 调用相应方法 138 | if self.call_function is not None: 139 | self.call_function() 140 | 141 | 142 | if __name__ == "__main__": 143 | def say_hello(): 144 | print("hello!") 145 | 146 | # 功能测试 147 | pygame.init() 148 | screen = pygame.display.set_mode((600, 400)) 149 | pygame.display.set_caption("测试") 150 | 151 | button = Button(screen, "测试按钮", ["center", "center"], call_function=say_hello) 152 | button.enable() 153 | 154 | pygame.display.update() 155 | while True: 156 | for event in pygame.event.get(): 157 | button.update(event) 158 | pygame.display.update() 159 | -------------------------------------------------------------------------------- /pgutils/pgcontrols/ctbase.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/10/7 10:56 3 | # @Author : He Ruizhi 4 | # @File : ctbase.py 5 | # @Software: PyCharm 6 | 7 | import pygame 8 | 9 | 10 | class CtBase: 11 | """pygame控件基类,所有自定义控件均需继承CtBase""" 12 | def __init__(self): 13 | # 控件是否被激活 14 | self.active = False 15 | 16 | def enable(self): 17 | """激活控件""" 18 | self.active = True 19 | 20 | def disable(self): 21 | """冻结控件""" 22 | self.active = False 23 | 24 | def update(self, event: pygame.event) -> ...: 25 | """ 26 | 根据pygame.event对控件状态进行更新 27 | 28 | 所有控件类均需重写该方法 29 | """ 30 | raise NotImplementedError 31 | -------------------------------------------------------------------------------- /pgutils/pgtools/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/10/4 20:58 3 | # @Author : He Ruizhi 4 | # @File : __init__.py 5 | # @Software: PyCharm 6 | 7 | """ 8 | Pygame tools: Pygame常用工具类方法 9 | """ -------------------------------------------------------------------------------- /pgutils/pgtools/information_display.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/10/8 12:44 3 | # @Author : He Ruizhi 4 | # @File : text.py 5 | # @Software: PyCharm 6 | 7 | import os 8 | import pygame 9 | from collections import deque 10 | from typing import List, Tuple, Optional 11 | from pgutils.text import draw_text 12 | from pgutils.pgtools.toolbase import ToolBase 13 | 14 | current_path = os.path.dirname(__file__) 15 | 16 | 17 | class InformationDisplay(ToolBase): 18 | def __init__(self, surface: pygame.Surface, 19 | display_pos: Optional[List[str or float or int]] = None, 20 | display_size: Optional[List[int or float]] = None, 21 | max_show: int = 5, 22 | bg_color: Tuple[int, int, int] = (165, 219, 214), 23 | outer_rect_color: Tuple[int, int, int] = (240, 240, 240), 24 | inner_rect_color: Tuple[int, int, int] = (173, 173, 173), 25 | font_size: int = 14, 26 | font_color: Tuple[int, int, int] = (0, 0, 0), 27 | font_path: str = current_path + "/../../assets/fonts/msyh.ttc"): 28 | """ 29 | 在指定pygame.surface上滚动显示信息 30 | 31 | :param surface: 绘制屏幕 32 | :param display_pos: 绘制位置 33 | :param display_size: display大小 34 | :param max_show: 信息滚动显示数 35 | :param bg_color: 背景颜色 36 | :param font_size: 字体大小 37 | :param font_color: 字体颜色 38 | :param font_path: 字体文件路径 39 | """ 40 | super(InformationDisplay, self).__init__() 41 | if display_pos is None: 42 | display_pos = [20, 20] 43 | if display_size is None: 44 | surface_width, surface_height = surface.get_width(), surface.get_height() 45 | display_size = [surface_width - display_pos[0] * 2, surface_height - display_pos[1] * 2] 46 | 47 | # 创建subsurface 48 | self.display_surface = surface.subsurface((*display_pos, *display_size)) 49 | # 内外边框绘制位置 50 | self.outer_rect = 0, 0, self.display_surface.get_width(), self.display_surface.get_height() 51 | self.inner_rect = self.outer_rect[0] + 1, self.outer_rect[1] + 1, self.outer_rect[2] - 2, self.outer_rect[3] - 2 52 | 53 | # 生成字体对象 54 | self.font = pygame.font.Font(font_path, font_size) 55 | # 生成信息存储器 56 | self.information_container = deque(maxlen=max_show) 57 | 58 | self.display_pos = display_pos 59 | self.display_size = display_size 60 | self.bg_color = bg_color 61 | self.font_color = font_color 62 | self.outer_rect_color = outer_rect_color 63 | self.inner_rect_color = inner_rect_color 64 | 65 | def push_text(self, text: str, update=False): 66 | self.information_container.append(text) 67 | if update: 68 | self.enable() 69 | 70 | def update(self): 71 | self.display_surface.fill(self.bg_color) 72 | # 绘制外框 73 | pygame.draw.rect(self.display_surface, self.outer_rect_color, self.outer_rect, width=1) 74 | # 绘制内框 75 | pygame.draw.rect(self.display_surface, self.inner_rect_color, self.inner_rect, width=1) 76 | 77 | # 绘制文本 78 | next_pos = [3, 2] 79 | for line in self.information_container: 80 | line = self.font.render(line, True, self.font_color) 81 | next_pos = draw_text(self.display_surface, line, next_pos) 82 | 83 | 84 | if __name__ == "__main__": 85 | import time 86 | 87 | # 功能测试 88 | pygame.init() 89 | screen = pygame.display.set_mode((600, 400)) 90 | pygame.display.set_caption("测试") 91 | 92 | info_display = InformationDisplay(screen) 93 | for i in range(10): 94 | pygame.event.pump() 95 | info_display.push_text("测试消息:{}".format(i)) 96 | info_display.update() 97 | pygame.display.update() 98 | time.sleep(1) 99 | while True: 100 | for event in pygame.event.get(): 101 | pass 102 | -------------------------------------------------------------------------------- /pgutils/pgtools/toolbase.py: -------------------------------------------------------------------------------- 1 | 2 | class ToolBase: 3 | """pygame工具基类,所有自定义工具均需继承ToolBase""" 4 | def __init__(self): 5 | # 工具是否被激活 6 | self.active = False 7 | 8 | def enable(self): 9 | """激活工具""" 10 | self.active = True 11 | 12 | def disable(self): 13 | """冻结工具""" 14 | self.active = False 15 | 16 | def update(self): 17 | """ 18 | 对工具状态进行更新 19 | 20 | 所有工具类均需重写该方法 21 | """ 22 | raise NotImplementedError 23 | -------------------------------------------------------------------------------- /pgutils/position.py: -------------------------------------------------------------------------------- 1 | import pygame 2 | from typing import Union, List, Tuple 3 | 4 | 5 | def pos_in_surface(pos: Union[List[int or float], Tuple[int or float]], 6 | surface: pygame.Surface) -> bool: 7 | """判断pos位置是否在pygame.surface范围内""" 8 | offset = surface.get_abs_offset() 9 | surface_size = surface.get_size() 10 | if offset[0] < pos[0] < (offset[0] + surface_size[0]) and \ 11 | offset[1] < pos[1] < (offset[1] + surface_size[1]): 12 | return True 13 | else: 14 | return False 15 | -------------------------------------------------------------------------------- /pgutils/text.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/10/7 11:44 3 | # @Author : He Ruizhi 4 | # @File : text.py 5 | # @Software: PyCharm 6 | 7 | import pygame 8 | from typing import List, Tuple, Union 9 | import os 10 | 11 | current_path = os.path.dirname(__file__) 12 | 13 | 14 | def draw_text(surface: pygame.Surface, 15 | text: Union[str, pygame.Surface], 16 | pos: Union[List[str or float or int], Tuple[str or float or int]], 17 | font_size: int = 48, 18 | font_color: Union[Tuple[int], List[int]] = (255, 255, 255), 19 | font_path: str = current_path + "/../assets/fonts/msyh.ttc", 20 | next_bias: Union[Tuple[int or float], List[int or float]] = (0, 0)) -> Tuple: 21 | """ 22 | 在指定pygame.surface上绘制文字的方法 23 | 24 | :param surface: 绘制文本的pygame.surface 25 | :param text: 文本内容 26 | :param pos: 文本绘制位置 27 | :param font_size: 字体大小 28 | :param font_color: 字体颜色 29 | :param font_path: 30 | :param next_bias: 下一行文本位置偏移 31 | :return: 下一行文本绘制位置 32 | """ 33 | # 创建pygame.font.Font对象 34 | if isinstance(text, str): 35 | font = pygame.font.Font(font_path, font_size) 36 | text = font.render(text, True, font_color) 37 | 38 | pos = list(pos) 39 | if isinstance(pos[0], str): 40 | assert pos[0] == "center" 41 | pos[0] = (surface.get_width() - text.get_width()) / 2 42 | if isinstance(pos[1], str): 43 | assert pos[1] == "center" 44 | pos[1] = (surface.get_height() - text.get_height()) / 2 45 | 46 | surface.blit(text, pos) 47 | 48 | return pos[0] + next_bias[0], pos[1] + text.get_height() + next_bias[1] 49 | 50 | 51 | if __name__ == "__main__": 52 | # 功能测试 53 | pygame.init() 54 | screen = pygame.display.set_mode((600, 400)) 55 | pygame.display.set_caption("测试") 56 | 57 | draw_text(screen, 'Hello!', ["center", "center"]) 58 | pygame.display.update() 59 | while True: 60 | for event in pygame.event.get(): 61 | pass 62 | -------------------------------------------------------------------------------- /pictures/1_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/pictures/1_1.png -------------------------------------------------------------------------------- /pictures/2_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/pictures/2_1.png -------------------------------------------------------------------------------- /pictures/2_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/pictures/2_10.png -------------------------------------------------------------------------------- /pictures/2_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/pictures/2_2.png -------------------------------------------------------------------------------- /pictures/2_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/pictures/2_3.png -------------------------------------------------------------------------------- /pictures/2_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/pictures/2_4.png -------------------------------------------------------------------------------- /pictures/2_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/pictures/2_5.png -------------------------------------------------------------------------------- /pictures/2_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/pictures/2_6.png -------------------------------------------------------------------------------- /pictures/2_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/pictures/2_7.png -------------------------------------------------------------------------------- /pictures/2_8.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/pictures/2_8.gif -------------------------------------------------------------------------------- /pictures/2_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/pictures/2_9.png -------------------------------------------------------------------------------- /pictures/3_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/pictures/3_1.png -------------------------------------------------------------------------------- /pictures/4_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/pictures/4_1.png -------------------------------------------------------------------------------- /pictures/4_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/pictures/4_2.png -------------------------------------------------------------------------------- /pictures/4_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/pictures/4_3.png -------------------------------------------------------------------------------- /pictures/6_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/pictures/6_1.png -------------------------------------------------------------------------------- /pictures/6_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/pictures/6_2.png -------------------------------------------------------------------------------- /pictures/6_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/pictures/6_3.png -------------------------------------------------------------------------------- /pictures/7_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/pictures/7_1.png -------------------------------------------------------------------------------- /pictures/7_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/pictures/7_2.png -------------------------------------------------------------------------------- /pictures/7_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/pictures/7_3.png -------------------------------------------------------------------------------- /pictures/7_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/pictures/7_4.png -------------------------------------------------------------------------------- /pictures/8_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/pictures/8_1.png -------------------------------------------------------------------------------- /pictures/8_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/pictures/8_2.png -------------------------------------------------------------------------------- /pictures/8_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/pictures/8_3.png -------------------------------------------------------------------------------- /pictures/9_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/pictures/9_1.png -------------------------------------------------------------------------------- /pictures/启动界面.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/pictures/启动界面.png -------------------------------------------------------------------------------- /pictures/对弈.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/pictures/对弈.png -------------------------------------------------------------------------------- /pictures/训练初始界面.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/pictures/训练初始界面.png -------------------------------------------------------------------------------- /pictures/训练过程.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QPT-Family/QPT-CleverGo/5ab381de52bf3d515dfad60152eda73d62701116/pictures/训练过程.png -------------------------------------------------------------------------------- /play_game.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/3/7 14:29 3 | # @Author : He Ruizhi 4 | # @File : play_game.py 5 | # @Software: PyCharm 6 | 7 | from game_engine import GameEngine 8 | import pygame 9 | import sys 10 | 11 | 12 | if __name__ == '__main__': 13 | game = GameEngine() 14 | 15 | while True: 16 | for event in pygame.event.get(): 17 | if event.type == pygame.QUIT: # 退出事件 18 | sys.exit() 19 | else: 20 | game.event_control(event) 21 | # 落子 22 | game.take_action() 23 | # 音乐控制 24 | game.music_control() 25 | # 屏幕刷新 26 | pygame.display.update() 27 | -------------------------------------------------------------------------------- /player.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/3/8 18:51 3 | # @Author : He Ruizhi 4 | # @File : player.py 5 | # @Software: PyCharm 6 | 7 | from threading import Thread 8 | import numpy as np 9 | from time import sleep 10 | from mcts import MCTS, evaluate_rollout 11 | from policy_value_net import PolicyValueNet 12 | import paddle 13 | import os 14 | 15 | 16 | class Player: 17 | def __init__(self): 18 | # 是否允许启动线程计算下一步action标记 19 | self.allow = True 20 | # 下一步action 21 | self.action = None 22 | # Player名字 23 | self.name = 'Player' 24 | # 该Player是否有效,用于提前退出计算循环 25 | self.valid = True 26 | # 表明落子计算进度的量(仅在Player为MCTS或AlphaGo时生效) 27 | self.speed = None 28 | 29 | def play(self, game): 30 | if self.allow and self.action is None: 31 | self.allow = False 32 | # daemon=True可以使得主线程结束时,所有子线程全部退出,使得点击退出游戏按钮后,不用等待子线程结束 33 | Thread(target=self.step, args=(game, ), daemon=True).start() 34 | 35 | def step(self, game): 36 | """ 37 | 根据当前游戏状态,获得执行动作 38 | :param game: 游戏模拟器对象 39 | :return: 40 | """ 41 | print('Hello!') 42 | 43 | 44 | class HumanPlayer(Player): 45 | def __init__(self): 46 | super().__init__() 47 | self.name = '人类玩家' 48 | 49 | 50 | class RandomPlayer(Player): 51 | def __init__(self): 52 | super().__init__() 53 | self.name = '随机落子' 54 | 55 | def step(self, game): 56 | sleep(1) 57 | self.action = self.get_action(game) 58 | 59 | @staticmethod 60 | def get_action(game): 61 | valid_move_idcs = game.game_state.advanced_valid_move_idcs() 62 | if len(valid_move_idcs) > 1: 63 | valid_move_idcs = valid_move_idcs[:-1] 64 | action = np.random.choice(valid_move_idcs) 65 | return action 66 | 67 | 68 | class MCTSPlayer(Player): 69 | def __init__(self, c_puct=5, n_playout=20): 70 | super().__init__() 71 | self.name = '蒙特卡洛{}'.format(n_playout) 72 | 73 | def rollout_policy_fn(game_state_simulator): 74 | # 选择随机动作 75 | availables = game_state_simulator.valid_move_idcs() 76 | action_probs = np.random.rand(len(availables)) 77 | return zip(availables, action_probs) 78 | 79 | def policy_value_fn(game_state_simulator): 80 | # 返回均匀概率及通过随机方法获得的节点价值 81 | availables = game_state_simulator.valid_move_idcs() 82 | action_probs = np.ones(len(availables)) / len(availables) 83 | return zip(availables, action_probs), evaluate_rollout(game_state_simulator, rollout_policy_fn) 84 | 85 | self.mcts = MCTS(policy_value_fn, c_puct, n_playout) 86 | 87 | def step(self, game): 88 | action = self.get_action(game) 89 | if action == -1: 90 | action = None 91 | self.allow = True 92 | self.action = action 93 | 94 | # 获得动作后将速度区域清空 95 | self.speed = (0, 1) 96 | 97 | def reset_player(self): 98 | self.mcts.update_with_move(-1) 99 | 100 | def get_action(self, game): 101 | move = self.mcts.get_move(game, self) 102 | self.mcts.update_with_move(-1) 103 | return move 104 | 105 | 106 | class AlphaGoPlayer(Player): 107 | def __init__(self, model_path='models/pdparams', c_puct=5, n_playout=400, is_selfplay=False): 108 | super(AlphaGoPlayer, self).__init__() 109 | if model_path == 'models/alpha_go.pdparams': 110 | self.name = '阿尔法狗' 111 | elif model_path == 'models/my_alpha_go.pdparams': 112 | self.name = '幼生阿尔法狗' 113 | else: 114 | self.name = '预期之外的错误名称' 115 | self.policy_value_net = PolicyValueNet() 116 | self.policy_value_net.eval() 117 | 118 | if os.path.exists(model_path): 119 | state_dict = paddle.load(model_path) 120 | self.policy_value_net.set_state_dict(state_dict) 121 | 122 | self.mcts = MCTS(self.policy_value_net.policy_value_fn, c_puct, n_playout) 123 | self.is_selfplay = is_selfplay 124 | 125 | def reset_player(self): 126 | self.mcts.update_with_move(-1) 127 | 128 | def step(self, game): 129 | action = self.get_action(game) 130 | if action == -1: 131 | action = None 132 | self.allow = True 133 | self.action = action 134 | self.speed = (0, 1) 135 | 136 | def get_action(self, game, temp=1e-3, return_probs=False): 137 | move_probs = np.zeros(game.board_size ** 2 + 1) 138 | acts, probs = self.mcts.get_move_probs(game, temp, self) 139 | if acts == -1 and probs == -1: 140 | return -1 141 | move_probs[list(acts)] = probs 142 | if self.is_selfplay: 143 | # 增加Dirichlet噪声用于探索(在训练时候) 144 | move = np.random.choice(acts, p=0.75*probs + 0.25*np.random.dirichlet(0.3*np.ones(len(probs)))) 145 | # 更新蒙特卡洛搜索树 146 | self.mcts.update_with_move(move) # 因为在生成自对弈棋谱时,落子是黑白交替,均由自己做出决策 147 | else: 148 | move = np.random.choice(acts, p=probs) 149 | self.mcts.update_with_move(-1) # 与其它对手对弈时,只控制黑方或白方落子,因此每步均置为-1 150 | if return_probs: 151 | return move, move_probs 152 | else: 153 | return move 154 | 155 | 156 | class PolicyNetPlayer(Player): 157 | def __init__(self, model_path='models/model.pdparams'): 158 | super(PolicyNetPlayer, self).__init__() 159 | self.name = '策略网络' 160 | self.policy_value_net = PolicyValueNet() 161 | 162 | if os.path.exists(model_path): 163 | state_dict = paddle.load(model_path) 164 | self.policy_value_net.set_state_dict(state_dict) 165 | self.policy_value_net.eval() 166 | 167 | def step(self, game): 168 | sleep(1) 169 | self.action = self.get_action(game) 170 | 171 | def get_action(self, game): 172 | valid_moves = game.game_state.valid_moves() 173 | valid_moves = paddle.to_tensor(valid_moves) 174 | 175 | current_state = game.game_state.get_board_state() 176 | current_state = paddle.to_tensor([current_state], dtype='float32') 177 | probs, _ = self.policy_value_net(current_state) 178 | probs = probs[0] 179 | probs *= valid_moves 180 | probs = probs / paddle.sum(probs) 181 | 182 | action = np.random.choice(range(82), p=probs.numpy()) 183 | return action 184 | 185 | 186 | class ValueNetPlayer(Player): 187 | def __init__(self, model_path='models/model.pdparams'): 188 | super(ValueNetPlayer, self).__init__() 189 | self.name = '价值网络' 190 | self.policy_value_net = PolicyValueNet() 191 | 192 | if os.path.exists(model_path): 193 | state_dict = paddle.load(model_path) 194 | self.policy_value_net.set_state_dict(state_dict) 195 | self.policy_value_net.eval() 196 | 197 | def step(self, game): 198 | sleep(1) 199 | self.action = self.get_action(game) 200 | 201 | def get_action(self, game): 202 | valid_move_idcs = game.game_state.valid_move_idcs() 203 | 204 | # 计算所有可落子位置,对手的局面价值,选择对手局面价值最小的落子 205 | max_value = 1 206 | action = game.board_size ** 2 207 | for simulate_action in valid_move_idcs: 208 | simulate_game_state = game.game_state_simulator() 209 | simulate_game_state.step(simulate_action) 210 | 211 | current_state = simulate_game_state.get_board_state() 212 | current_state = paddle.to_tensor([current_state], dtype='float32') 213 | 214 | _, value = self.policy_value_net(current_state) 215 | value = value.numpy().flatten()[0] 216 | 217 | if value < max_value: 218 | max_value = value 219 | action = simulate_action 220 | 221 | return action 222 | -------------------------------------------------------------------------------- /policy_value_net.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/3/29 21:01 3 | # @Author : He Ruizhi 4 | # @File : policy_value_net.py 5 | # @Software: PyCharm 6 | 7 | import numpy as np 8 | import paddle 9 | 10 | 11 | class PolicyValueNet(paddle.nn.Layer): 12 | def __init__(self, input_channels: int = 10, 13 | board_size: int = 9): 14 | """ 15 | 16 | :param input_channels: 输入的通道数,默认为10。双方最近4步,再加一个表示当前落子方的平面,再加上一个最近一手位置的平面 17 | :param board_size: 棋盘大小 18 | """ 19 | super(PolicyValueNet, self).__init__() 20 | 21 | # AlphaGo Zero网络架构:一个身子,两个头 22 | # 特征提取网络部分 23 | self.conv_layer = paddle.nn.Sequential( 24 | paddle.nn.Conv2D(in_channels=input_channels, out_channels=32, kernel_size=3, padding=1), 25 | paddle.nn.ReLU(), 26 | paddle.nn.Conv2D(in_channels=32, out_channels=64, kernel_size=3, padding=1), 27 | paddle.nn.ReLU(), 28 | paddle.nn.Conv2D(in_channels=64, out_channels=128, kernel_size=3, padding=1), 29 | paddle.nn.ReLU() 30 | ) 31 | 32 | # 策略网络部分 33 | self.policy_layer = paddle.nn.Sequential( 34 | paddle.nn.Conv2D(in_channels=128, out_channels=8, kernel_size=1), 35 | paddle.nn.ReLU(), 36 | paddle.nn.Flatten(), 37 | paddle.nn.Linear(in_features=9*9*8, out_features=256), 38 | paddle.nn.ReLU(), 39 | paddle.nn.Linear(in_features=256, out_features=board_size*board_size+1), 40 | paddle.nn.Softmax() 41 | ) 42 | 43 | # 价值网络部分 44 | self.value_layer = paddle.nn.Sequential( 45 | paddle.nn.Conv2D(in_channels=128, out_channels=4, kernel_size=1), 46 | paddle.nn.ReLU(), 47 | paddle.nn.Flatten(), 48 | paddle.nn.Linear(in_features=9*9*4, out_features=128), 49 | paddle.nn.ReLU(), 50 | paddle.nn.Linear(in_features=128, out_features=64), 51 | paddle.nn.ReLU(), 52 | paddle.nn.Linear(in_features=64, out_features=1), 53 | paddle.nn.Tanh() 54 | ) 55 | 56 | def forward(self, x): 57 | x = self.conv_layer(x) 58 | policy = self.policy_layer(x) 59 | value = self.value_layer(x) 60 | return policy, value 61 | 62 | def policy_value_fn(self, simulate_game_state): 63 | """ 64 | 65 | :param simulate_game_state: 66 | :return: 67 | """ 68 | legal_positions = simulate_game_state.valid_move_idcs() 69 | current_state = paddle.to_tensor(simulate_game_state.get_board_state()[np.newaxis], dtype='float32') 70 | act_probs, value = self.forward(current_state) 71 | act_probs = zip(legal_positions, act_probs.numpy().flatten()[legal_positions]) 72 | return act_probs, value 73 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm==4.62.2 2 | pyglet==1.5.0 3 | scipy==1.5.2 4 | gym==0.18.0 5 | numpy==1.19.2 6 | pygame==2.0.1 7 | paddlepaddle==2.1.2 8 | scikit_learn==0.24.2 9 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/3/7 16:02 3 | # @Author : He Ruizhi 4 | # @File : test.py 5 | # @Software: PyCharm 6 | 7 | # from weiqi_engine import GoGameState 8 | # import copy 9 | # import sys 10 | # sys.path.append('GymGo/') 11 | # from gym_go.envs.go_env import GoEnv 12 | # 13 | # game_state = GoEnv(9) 14 | # new_game_state = copy.deepcopy(game_state) 15 | # while True: 16 | # print('afadf') 17 | # pass 18 | 19 | import numpy as np 20 | # a = np.array([[1, 2], 21 | # [3, 4]]) 22 | # b = np.array([[0, 1], 23 | # [1, 0]]) 24 | # print(a * b) 25 | 26 | # a = [[[[1]], [[2]], [[3]]], [4, 5, 6], [7, 8, 9]] 27 | # for b, c, d in a: 28 | # print(b, c, d) 29 | 30 | a = [1, 2, 3, 4, 5] 31 | b = a[6:8] 32 | print(b) 33 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import os 3 | import paddle 4 | from player import AlphaGoPlayer 5 | import numpy as np 6 | from threading import Thread 7 | from threading import Lock 8 | import go_engine 9 | 10 | lock = Lock() 11 | 12 | 13 | class Trainer: 14 | def __init__(self, epochs=10, learning_rate=1e-3, batch_size=128, temp=1.0, n_playout=100, c_puct=5, 15 | train_model_path='models/my_alpha_go.pdparams'): 16 | """ 17 | 训练阿尔法狗的训练器 18 | 19 | :param epochs: 每自对弈一局,对样本迭代训练的epoch数 20 | :param learning_rate: 学习率 21 | :param temp: 蒙特卡洛数搜索温度参数 22 | :param n_playout: 蒙特卡洛树搜索模拟次数 23 | :param c_puct: 蒙特卡洛树搜索中计算上置信限的参数 24 | :param train_model_path: 训练模型的参数路径 25 | """ 26 | self.epochs = epochs 27 | self.learning_rate = learning_rate 28 | self.batch_size = batch_size 29 | self.temp = temp 30 | self.n_playout = n_playout 31 | self.c_puct = c_puct 32 | self.train_model_path = train_model_path 33 | self.train_step = 0 34 | self.model_update_step = 0 35 | 36 | # 创建阿尔法狗 37 | self.player = AlphaGoPlayer(train_model_path, c_puct, n_playout, is_selfplay=True) 38 | 39 | # 创建训练优化器 40 | self.optimizer = paddle.optimizer.Momentum(learning_rate=learning_rate, 41 | parameters=self.player.policy_value_net.parameters()) 42 | 43 | def start(self, game): 44 | """启动阿尔法狗训练线程""" 45 | Thread(target=self._train, args=(game,), daemon=True).start() 46 | 47 | def _train(self, game): 48 | """训练阿尔法狗""" 49 | 50 | # 加载训练网络参数 51 | if os.path.exists(self.train_model_path): 52 | state_dict = paddle.load(self.train_model_path) 53 | self.player.policy_value_net.set_state_dict(state_dict) 54 | print('加载模型权重成功!') 55 | game.info_display.push_text('{} 你成功唤醒了你的幼生阿尔法狗!'.format( 56 | datetime.now().strftime(r'%m-%d %H:%M:%S')), update=True) 57 | else: 58 | print('未找到模型参数!') 59 | game.info_display.push_text('{} 你成功领养了一只幼生阿尔法狗!'.format( 60 | datetime.now().strftime(r'%m-%d %H:%M:%S')), update=True) 61 | 62 | while True: 63 | if game.surface_state == 'play': 64 | break 65 | 66 | # 自对弈一局 67 | game.info_display.push_text('{} 你的阿尔法狗开始了自对弈!'.format( 68 | datetime.now().strftime(r'%m-%d %H:%M:%S')), update=True) 69 | play_datas = self.self_play_one_game(game) 70 | if play_datas is not None: 71 | play_datas = self.get_equi_data(play_datas) 72 | # 训练网络 73 | self.update_network(game, play_datas) 74 | paddle.save(self.player.policy_value_net.state_dict(), self.train_model_path) 75 | self.model_update_step += 1 76 | print('保存模型权重一次!') 77 | game.info_display.push_text('{} 阿尔法狗成长阶段{}!'.format( 78 | datetime.now().strftime(r'%m-%d %H:%M:%S'), self.model_update_step), update=True) 79 | 80 | def self_play_one_game(self, game): 81 | """自对弈依据游戏,并获取对弈数据""" 82 | states, mcts_probs, current_players = [], [], [] 83 | 84 | while True: 85 | if game.surface_state == 'play': 86 | break 87 | 88 | # 获取动作及概率 89 | move, move_probs = self.player.get_action(game, temp=self.temp, return_probs=True) 90 | # 存数据 91 | states.append(game.train_game_state.get_board_state()) 92 | mcts_probs.append(move_probs) 93 | current_players.append(game.train_game_state.turn()) 94 | # 执行落子 95 | lock.acquire() 96 | if game.surface_state == 'train': 97 | game.train_step(move) 98 | lock.release() 99 | 100 | end, winner = game.train_game_state.game_ended(), game.train_game_state.winner() 101 | if end: 102 | print('{}胜!'.format('黑' if winner == go_engine.BLACK else '白')) 103 | game.info_display.push_text('{} {}胜!'.format( 104 | datetime.now().strftime(r'%m-%d %H:%M:%S'), '黑' if winner == go_engine.BLACK else '白'), update=True) 105 | 106 | winners = np.zeros(len(current_players)) 107 | if winner != -1: 108 | winners[np.array(current_players) == winner] = 1.0 109 | winners[np.array(current_players) != winner] = -1.0 110 | # 重置蒙特卡洛搜索树 111 | self.player.reset_player() 112 | # 重置train_game_state 113 | game.train_game_state.reset() 114 | states = np.array(states) 115 | mcts_probs = np.array(mcts_probs) 116 | return zip(states, mcts_probs, winners) 117 | 118 | @staticmethod 119 | def get_equi_data(play_data): 120 | """通过旋转和翻转来扩增数据""" 121 | extend_data = [] 122 | for state, mcts_porb, winner in play_data: 123 | board_size = state.shape[-1] 124 | for i in [1, 2, 3, 4]: 125 | # 逆时针旋转 126 | equi_state = np.array([np.rot90(s, i) for s in state]) 127 | pass_move_prob = mcts_porb[-1] 128 | equi_mcts_prob = np.rot90(np.flipud(mcts_porb[:-1].reshape(board_size, board_size)), i) 129 | extend_data.append((equi_state, np.append(np.flipud(equi_mcts_prob).flatten(), pass_move_prob), winner)) 130 | # 翻转 131 | equi_state = np.array([np.fliplr(s) for s in equi_state]) 132 | equi_mcts_prob = np.fliplr(equi_mcts_prob) 133 | extend_data.append((equi_state, np.append(np.flipud(equi_mcts_prob).flatten(), pass_move_prob), winner)) 134 | return extend_data 135 | 136 | def update_network(self, game, play_datas): 137 | """更新网络参数""" 138 | self.player.policy_value_net.train() 139 | for epoch in range(self.epochs): 140 | if game.surface_state == 'play': 141 | break 142 | 143 | np.random.shuffle(play_datas) 144 | for i in range(len(play_datas) // self.batch_size + 1): 145 | self.train_step += 1 146 | 147 | batch = play_datas[i * self.batch_size:(i + 1) * self.batch_size] 148 | if len(batch) == 0: 149 | continue 150 | state_batch = paddle.to_tensor([data[0] for data in batch], dtype='float32') 151 | mcts_probs_batch = paddle.to_tensor([data[1] for data in batch], dtype='float32') 152 | winner_batch = paddle.to_tensor([data[2] for data in batch], dtype='float32') 153 | 154 | act_probs, value = self.player.policy_value_net(state_batch) 155 | ce_loss = paddle.nn.functional.cross_entropy(act_probs, mcts_probs_batch, 156 | soft_label=True, use_softmax=False) 157 | mse_loss = paddle.nn.functional.mse_loss(value, winner_batch) 158 | loss = ce_loss + mse_loss 159 | 160 | loss.backward() 161 | self.optimizer.step() 162 | self.optimizer.clear_grad() 163 | 164 | print('{} Step:{} CELoss:{} MSELoss:{} Loss:{}'.format( 165 | datetime.now().strftime('%Y-%m-%d %H:%M:%S'), self.train_step, 166 | ce_loss.numpy(), mse_loss.numpy(), loss.numpy())) 167 | self.player.policy_value_net.eval() 168 | --------------------------------------------------------------------------------