├── .gitignore ├── GUI.py ├── choose_best_player.py ├── ckpt ├── alphaFive-6960.data-00000-of-00001 ├── alphaFive-6960.index ├── alphaFive-6960.meta └── checkpoint ├── config.py ├── data_buffer ├── data6960.pkl ├── data_len6960.pkl └── result6960.pkl ├── genData ├── __init__.py ├── network.py ├── networkAPI.py └── player.py ├── images └── back.png ├── main.py ├── note.txt ├── readme.md ├── self_play.py ├── summary └── log_20200312_11_54_18 │ ├── events.out.tfevents.1583997414.DESKTOP-T9U7R33 │ └── events.out.tfevents.1584241253.DESKTOP-T9U7R33 ├── test.py ├── tmp ├── entropy.jpg ├── episode_length.jpg ├── five_6960.gif ├── total_loss.jpg ├── value_loss.jpg └── xentropy_loss.jpg └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .git 2 | .idea 3 | __pycache__ 4 | */__pycache__ 5 | game_record 6 | test.py 7 | result.txt 8 | note.txt 9 | -------------------------------------------------------------------------------- /GUI.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from genData.network import ResNet as Model 3 | import config 4 | import pygame 5 | import os 6 | import numpy as np 7 | from genData.player import Player 8 | import utils 9 | import tensorflow as tf 10 | import imageio 11 | import cv2 12 | 13 | make_gif = False # 是否要生成gif 14 | GRID_WIDTH = 36 15 | WIDTH = (config.board_size + 2) * GRID_WIDTH 16 | HEIGHT = (config.board_size + 2) * GRID_WIDTH 17 | FPS = 100 18 | WHITE = (255, 255, 255) 19 | BLACK = (0, 0, 0) 20 | HUMAN = 0 21 | AI = 2 22 | 23 | 24 | def main(trained_ckpt): 25 | net = Model(config.board_size) 26 | player = Player(config, training=False, pv_fn=net.eval) 27 | net.restore(trained_ckpt) 28 | pygame.init() 29 | screen = pygame.display.set_mode((WIDTH, HEIGHT)) 30 | pygame.display.set_caption("五子棋") 31 | clock = pygame.time.Clock() 32 | base_folder = os.path.dirname(__file__) 33 | img_folder = os.path.join(base_folder, 'images') 34 | background_img = pygame.image.load(os.path.join(img_folder, 'back.png')).convert() 35 | background = pygame.transform.scale(background_img, (WIDTH, HEIGHT)) 36 | back_rect = background.get_rect() 37 | running = True 38 | frames = [] 39 | 40 | # def draw_stone(screen_): 41 | # for i in range(config.board_size): 42 | # for j in range(config.board_size): 43 | # if state[i, j] == 1: 44 | # pygame.draw.circle(screen_, BLACK, (int((i + 1.5) * GRID_WIDTH), int((j + 1.5) * GRID_WIDTH)), 16) 45 | # elif state[i, j] == -1: 46 | # pygame.draw.circle(screen_, WHITE, (int((i + 1.5) * GRID_WIDTH), int((j + 1.5) * GRID_WIDTH)), 16) 47 | # else: 48 | # assert state[i, j] == 0 49 | def draw_stone(screen_): 50 | for i in range(config.board_size): 51 | for j in range(config.board_size): 52 | if state[i, j] == 1: 53 | pygame.draw.circle(screen_, BLACK, (int((j + 1.5) * GRID_WIDTH), int((i + 1.5) * GRID_WIDTH)), 16) 54 | elif state[i, j] == -1: 55 | pygame.draw.circle(screen_, WHITE, (int((j + 1.5) * GRID_WIDTH), int((i + 1.5) * GRID_WIDTH)), 16) 56 | else: 57 | assert state[i, j] == 0 58 | 59 | def draw_background(surf): 60 | screen.blit(background, back_rect) 61 | rect_lines = [ 62 | ((GRID_WIDTH, GRID_WIDTH), (GRID_WIDTH, HEIGHT - GRID_WIDTH)), 63 | ((GRID_WIDTH, GRID_WIDTH), (WIDTH - GRID_WIDTH, GRID_WIDTH)), 64 | ((GRID_WIDTH, HEIGHT - GRID_WIDTH), 65 | (WIDTH - GRID_WIDTH, HEIGHT - GRID_WIDTH)), 66 | ((WIDTH - GRID_WIDTH, GRID_WIDTH), 67 | (WIDTH - GRID_WIDTH, HEIGHT - GRID_WIDTH)), 68 | ] 69 | for line in rect_lines: 70 | 
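            # the four (start, end) pairs in rect_lines are the edges of the outer board frame; they are drawn first, then the inner grid lines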
pygame.draw.line(surf, BLACK, line[0], line[1], 2) 71 | 72 | for i in range(config.board_size): 73 | pygame.draw.line(surf, BLACK, 74 | (GRID_WIDTH * (2 + i), GRID_WIDTH), 75 | (GRID_WIDTH * (2 + i), HEIGHT - GRID_WIDTH)) 76 | pygame.draw.line(surf, BLACK, 77 | (GRID_WIDTH, GRID_WIDTH * (2 + i)), 78 | (HEIGHT - GRID_WIDTH, GRID_WIDTH * (2 + i))) 79 | 80 | circle_center = [ 81 | (GRID_WIDTH * 4, GRID_WIDTH * 4), 82 | (WIDTH - GRID_WIDTH * 4, GRID_WIDTH * 4), 83 | (WIDTH - GRID_WIDTH * 4, HEIGHT - GRID_WIDTH * 4), 84 | (GRID_WIDTH * 4, HEIGHT - GRID_WIDTH * 4), 85 | ] 86 | for cc in circle_center: 87 | pygame.draw.circle(surf, BLACK, cc, 5) 88 | 89 | draw_background(screen) 90 | pygame.display.flip() 91 | image_data = pygame.surfarray.array3d(pygame.display.get_surface()) 92 | frames.append(cv2.resize(image_data, (0, 0), fx=0.5, fy=0.5)) 93 | players = [HUMAN, AI] # 0 表示人类玩家,2表示包含network的AI 94 | idx = int(input("input the fist side, (0 human), (1 AI), (2 exit): ")) 95 | while idx not in [0, 1, 2]: 96 | idx = int(input("input the fist side, (0 human), (1 AI), (2 exit): ")) 97 | if idx == 2: 98 | exit() 99 | if players[idx] == AI: 100 | print("AI first") 101 | else: 102 | print("Human first") 103 | game_over = False 104 | state_str = player.get_init_state() 105 | board = utils.state_to_board(state_str, config.board_size) 106 | state = board 107 | last_action = None 108 | huihe = 0 109 | if players[idx] == AI: 110 | _, action = player.get_action(state_str, last_action=last_action) 111 | print("AI's action, ", action) 112 | huihe += 1 113 | board = utils.step(utils.state_to_board(state_str, config.board_size), action) 114 | state_str = utils.board_to_state(board) 115 | # player.pruning_tree(board, state_str) # 走完一步以后,对其他分支进行剪枝,以节约内存 116 | game_over, value = utils.is_game_over(board, config.goal) 117 | state = -board 118 | draw_background(screen) 119 | draw_stone(screen) 120 | pygame.display.flip() 121 | image_data = pygame.surfarray.array3d(pygame.display.get_surface()) 122 | frames.append(cv2.resize(image_data, (0, 0), fx=0.5, fy=0.5)) 123 | i = 0 124 | while running: 125 | clock.tick(FPS) 126 | for event in pygame.event.get(): 127 | if event.type == pygame.QUIT: 128 | running = False 129 | break 130 | elif event.type == pygame.MOUSEBUTTONDOWN: 131 | if game_over: 132 | break 133 | pos = event.pos # 获得的坐标是(x, y) 134 | if out_of_boundry(pos): 135 | continue 136 | action = (int((pos[1] - GRID_WIDTH) / GRID_WIDTH), int((pos[0] - GRID_WIDTH) / GRID_WIDTH)) 137 | print("Human's action: ", action) 138 | huihe += 1 139 | if state[action[0], action[1]] != 0: 140 | continue 141 | board = utils.step(board, action) # 人类落子 142 | last_action = action 143 | state_str = utils.board_to_state(board) 144 | # player.pruning_tree(board, state_str) 145 | game_over, value = utils.is_game_over(board, config.goal) 146 | state = board 147 | draw_background(screen) 148 | draw_stone(screen) 149 | pygame.display.flip() 150 | image_data = pygame.surfarray.array3d(pygame.display.get_surface()) 151 | frames.append(cv2.resize(image_data, (0, 0), fx=0.5, fy=0.5)) 152 | if game_over: 153 | continue 154 | _, action = player.get_action(state_str, last_action=last_action, random_a=False) 155 | last_action = action 156 | print("AI's action ", action) 157 | huihe += 1 158 | board = utils.step(utils.state_to_board(state_str, config.board_size), action) 159 | state_str = utils.board_to_state(board) 160 | player.pruning_tree(board, state_str) # 走完一步以后,对其他分支进行剪枝,以节约内存 161 | game_over, value = utils.is_game_over(board, config.goal) 162 | 
                state = -board
163 |                 draw_background(screen)
164 |                 draw_stone(screen)
165 |                 pygame.display.flip()
166 |                 image_data = pygame.surfarray.array3d(pygame.display.get_surface())
167 |                 frames.append(cv2.resize(image_data, (0, 0), fx=0.5, fy=0.5))
168 |         if game_over:
169 |             if i == 0:
170 |                 print(f"game over, total {(huihe+1)//2} rounds")
171 |                 if huihe == config.board_size * config.board_size:  # the board is full, so the game is a draw
172 |                     print("game tied!")
173 |                 elif huihe % 2 == 1 and players[idx] == AI:
174 |                     print("AI won! You are stupid!")
175 |                 else:
176 |                     print("You won! Impressive!")
177 |             i += 1
178 |             image_data = pygame.surfarray.array3d(pygame.display.get_surface())
179 |             frames.append(cv2.resize(image_data, (0, 0), fx=0.5, fy=0.5))
180 |             if i >= 5 and make_gif:
181 |                 break
182 | 
183 |     pygame.quit()
184 |     if make_gif:
185 |         print("game finished, start to write to gif.")
186 |         gif = imageio.mimsave("tmp/five_6960.gif", frames, 'GIF', duration=1.0)
187 |         print("done!")
188 | 
189 | 
190 | def out_of_boundry(pos):
191 |     return pos[0] < GRID_WIDTH or pos[1] < GRID_WIDTH or pos[0] > WIDTH - GRID_WIDTH or pos[1] > HEIGHT - GRID_WIDTH
192 | 
193 | 
194 | if __name__ == "__main__":
195 |     main(trained_ckpt=config.ckpt_path)
196 | 
--------------------------------------------------------------------------------
/choose_best_player.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from genData.network import ResNet as Model
3 | import config
4 | import pygame
5 | import os
6 | import tensorflow as tf
7 | import numpy as np
8 | from genData.player import Player
9 | import utils
10 | from random import shuffle
11 | import random
12 | 
13 | 
14 | GRID_WIDTH = 36
15 | WIDTH = (config.board_size + 2) * GRID_WIDTH
16 | HEIGHT = (config.board_size + 2) * GRID_WIDTH
17 | FPS = 300
18 | WHITE = (255, 255, 255)
19 | BLACK = (0, 0, 0)
20 | HUMAN = 0
21 | AI = 2
22 | 
23 | 
24 | def main():
25 |     config.simulation_per_step = 500
26 |     # only compare checkpoints from 6060 onwards for now
27 |     all_ckpts = [os.path.join("ckpt", "alphaFive-"+str(num)) for num in range(60, 8800, 60)][100:-1]
28 |     net0 = Model(config.board_size, tf.Graph())
29 |     net0.restore(all_ckpts[0])
30 |     net1 = Model(config.board_size, tf.Graph())
31 |     net1.restore(all_ckpts[-1])
32 |     player0 = Player(config, training=False, pv_fn=net0.eval)
33 |     player1 = Player(config, training=False, pv_fn=net1.eval)
34 |     players = [{'p': player0, "win": 0, "ckpt": all_ckpts[0]},
35 |                {'p': player1, "win": 0, "ckpt": all_ckpts[-1]}]
36 |     result = open("result.txt", "a")
37 |     low, high = 0, len(all_ckpts)-1
38 |     while low < high:  # try to pit checkpoints that are far apart in strength against each other
39 |         print("")
40 |         print("==================================================================")
41 |         print(players[0]["ckpt"] + " vs " + players[1]["ckpt"] + '...')
42 |         for i in range(100):  # play at most 100 games
43 |             players[0]['p'].reset()  # reset the search tree before every game
44 |             players[1]['p'].reset()
45 |             game_over = False
46 |             action = None
47 |             state = player1.get_init_state()
48 |             current_ids = i % 2
49 |             value = 0.0
50 |             count = 0
51 |             while not game_over:
52 |                 _, action = players[current_ids]['p'].get_action(state, last_action=action, random_a=True)
53 |                 board = utils.step(utils.state_to_board(state, config.board_size), action)
54 |                 state = utils.board_to_state(board)
55 |                 # players[current_ids].pruning_tree(board, state)  # pruning after each move would save memory; skipped here to save time
56 |                 game_over, value = utils.is_game_over(board, config.goal)
57 |                 current_ids = (current_ids + 1) % 2  # switch to the other player
58 |                 count += 1
59 |             if value == 0.0:  # the game was drawn
60 |                 print(f"game: {i}, tied! all {count} turns.")
61 |                 continue
62 |             else:
63 |                 print(f"game: {i} {players[(current_ids+1) % 2]['ckpt']} won! all {count} turns.")
64 |                 players[(current_ids+1) % 2]["win"] += 1
65 |             if i >= 30:
66 |                 # after 30 games, break early if the score is already lopsided
67 |                 w0 = players[0]["win"]
68 |                 w1 = players[1]["win"]
69 |                 if w0 == 0 or w1 == 0:
70 |                     break
71 |                 elif w0 / w1 > 2.0 or w0 / w1 < 0.5:
72 |                     break
73 |         print_str = players[0]["ckpt"] + ": " + players[1]["ckpt"] + f' = {players[0]["win"]}: {players[1]["win"]}'
74 |         print(print_str)
75 |         print(print_str, file=result, flush=True)
76 |         if players[0]["win"] < players[1]["win"]:
77 |             low += 1
78 |             net0.restore(all_ckpts[low])
79 |             players[0]["ckpt"] = all_ckpts[low]
80 |         else:
81 |             high -= 1
82 |             net1.restore(all_ckpts[high])
83 |             players[1]["ckpt"] = all_ckpts[high]
84 | 
85 |         players[0]["win"] = players[1]["win"] = 0
86 |     result.close()
87 |     net1.close()
88 |     net0.close()
89 | 
90 | 
91 | if __name__ == "__main__":
92 |     main()
93 | 
94 | 
--------------------------------------------------------------------------------
/ckpt/alphaFive-6960.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuoYi0/alphaFive/9492d1083154d13c5ed8c8e3c2a34f0fc1fee2e7/ckpt/alphaFive-6960.data-00000-of-00001
--------------------------------------------------------------------------------
/ckpt/alphaFive-6960.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuoYi0/alphaFive/9492d1083154d13c5ed8c8e3c2a34f0fc1fee2e7/ckpt/alphaFive-6960.index
--------------------------------------------------------------------------------
/ckpt/alphaFive-6960.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuoYi0/alphaFive/9492d1083154d13c5ed8c8e3c2a34f0fc1fee2e7/ckpt/alphaFive-6960.meta
--------------------------------------------------------------------------------
/ckpt/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "alphaFive-6960"
2 | all_model_checkpoint_paths: "alphaFive-6960"
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | board_size = 11  # board size
3 | buffer_size = 12000  # the larger the board, the larger this should be; at roughly 30 moves per game it holds about 400 games, and it probably ought to be a bit larger still
4 | simulation_per_step = 542  # the empty board has 121 points; visiting each at least twice costs 242, leaving ~300 free simulations (AlphaGo used 1600)
5 | upper_simulation_per_step = 642  # some nodes have already been visited as non-root nodes, so cap the total number of visits here to save time
6 | goal = 5  # five in a row wins, hence 5
7 | batch_size = 512
8 | # learning-rate decay schedule; the first threshold of 7000 could probably be raised, i.e. train longer at 1e-3
9 | lr_ = [(7000, 1e-3), (14000, 2e-4), (28000, 4e-5), (100000000, 2e-6)]
10 | ckpt_path = "ckpt"
11 | total_step = 20000  # planned number of training steps; I stopped at around 8k because my machine is too slow
12 | tau_decay_rate = 0.94  # temperature decay; the later the move, the lower the temperature should be
13 | # this temperature is used when picking the best ckpt: in human-vs-AI play the AI moves deterministically,
14 | # but when comparing two ckpts the moves must not be deterministic; it decays faster than tau_decay_rate since deep exploration is not needed there
15 | tau_decay_rate_r = 0.9
16 | c_puct = 5.0  # set a bit on the large side
17 | dirichlet_alpha = 0.3  # larger values give a more uniform distribution; the larger the board, the slightly smaller this should be
18 | gamma = 0.94  # weighting factor: later positions occur less frequently, so they are given larger training weights (assigned from the end of the game backwards)
19 | init_temp = 1.2  # initial temperature; higher means a flatter distribution, raised a bit here because the initial distribution was too sharp
20 | max_processes = 5  # number of parallel self-play processes; data generation is the bottleneck, so set this as high as the machine can handle
21 | 
22 | 
23 | def get_lr(step):
24 |     for item in lr_:
25 |         if step < item[0]:
26 |             return item[1]
27 |     return lr_[-1][-1]
28 | 
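
# Illustrative usage of the schedule above (an added sanity check, not used anywhere by the training code):
if __name__ == "__main__":
    # expected output: 1e-3 below step 7000, then 2e-4, 4e-5 and finally 2e-6
    for s in (0, 6999, 7000, 20000, 50000):
        print(s, get_lr(s))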
-------------------------------------------------------------------------------- /data_buffer/data6960.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuoYi0/alphaFive/9492d1083154d13c5ed8c8e3c2a34f0fc1fee2e7/data_buffer/data6960.pkl -------------------------------------------------------------------------------- /data_buffer/data_len6960.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuoYi0/alphaFive/9492d1083154d13c5ed8c8e3c2a34f0fc1fee2e7/data_buffer/data_len6960.pkl -------------------------------------------------------------------------------- /data_buffer/result6960.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuoYi0/alphaFive/9492d1083154d13c5ed8c8e3c2a34f0fc1fee2e7/data_buffer/result6960.pkl -------------------------------------------------------------------------------- /genData/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /genData/network.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import tensorflow as tf 3 | from functools import reduce 4 | import numpy as np 5 | from tensorflow.python import pywrap_tensorflow 6 | from genData.networkAPI import NetworkAPI 7 | 8 | DATA_FORMAT = "channels_first" 9 | 10 | 11 | class ResNet(object): 12 | """ 13 | 针对当前局面进行评估 14 | """ 15 | def __init__(self, board_size, graph=None): 16 | self.graph = tf.get_default_graph() if graph is None else graph 17 | with self.graph.as_default(): 18 | self.board_size = board_size 19 | # 棋局的输入 20 | self.inputs = tf.placeholder(dtype=tf.float32, shape=[None, 3, board_size, board_size], name="inputs") 21 | self.winner = tf.placeholder(dtype=tf.float32, shape=[None], name="winner") # value的监督信号 22 | self.distrib = tf.placeholder(dtype=tf.float32, shape=[None, board_size * board_size], name="distrib") # policy的监督信号 23 | self.weights = tf.placeholder(dtype=tf.float32, shape=[None], name="weights") 24 | self.training = tf.placeholder(dtype=tf.bool, shape=(), name="training") 25 | self.value = None 26 | self.policy = None 27 | self.entropy = None 28 | self.log_softmax = None 29 | self.prob = None 30 | self.network() 31 | self.cross_entropy_loss, self.value_loss, self.total_loss = None, None, None 32 | self.construct_loss() 33 | gpu_options = tf.GPUOptions(allow_growth=True) 34 | self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options), graph=self.graph) 35 | # self.sess = tf.Session() 36 | self.sess.run(tf.global_variables_initializer()) 37 | self.saver = tf.train.Saver(max_to_keep=10000) 38 | self.api = None 39 | 40 | def construct_loss(self): 41 | x_entropy = tf.reduce_sum(tf.multiply(self.distrib, self.log_softmax), axis=1) 42 | self.cross_entropy_loss = tf.negative(tf.reduce_mean(x_entropy)) # 用于显示 43 | weighted_x_entropy = tf.negative(tf.reduce_mean(tf.multiply(x_entropy, self.weights))) # 用于实际计算 44 | value_loss = tf.squared_difference(self.value, self.winner) 45 | self.value_loss = tf.reduce_mean(value_loss) 46 | weighted_value_loss = tf.reduce_mean(tf.multiply(value_loss, self.weights)) 47 | L2_loss = tf.add_n( 48 | [tf.nn.l2_loss(v) for v in tf.trainable_variables() if "bias" not in v.name and 'bn' not in v.name]) 49 | # self.total_loss = cross_entropy 
+ value_loss + L2_loss * 1e-5 50 | self.total_loss = weighted_x_entropy + 2.0*weighted_value_loss + L2_loss * 4e-5 51 | 52 | def residual(self, f, units, name): 53 | res = tf.layers.conv2d(f, units, 1, padding="VALID", data_format=DATA_FORMAT, name=name+"_res", activation=None) 54 | f = tf.layers.conv2d(f, units, 3, padding="SAME", data_format=DATA_FORMAT, name=name+"_conv1", activation=tf.nn.elu) 55 | f = tf.layers.conv2d(f, units, 3, padding="SAME", data_format=DATA_FORMAT, name=name+"_conv2", activation=None) 56 | return tf.nn.elu(tf.add(res, f, name+"_add"), "elu") 57 | 58 | def network(self): 59 | # total params 44w 60 | f = self.inputs 61 | with tf.variable_scope("bone"): 62 | # params: 13w 63 | f = tf.layers.conv2d(f, 32, 5, padding="SAME", data_format=DATA_FORMAT, name="conv1", activation=tf.nn.elu) 64 | f = self.residual(f, 64, "block1") 65 | f = self.residual(f, 128, "block2") 66 | 67 | with tf.variable_scope("value"): 68 | v = self.residual(f, 32, "block3") 69 | # 为全连接层降低参数量 70 | v = tf.layers.conv2d(v, 4, 1, padding="SAME", data_format=DATA_FORMAT, name="conv", activation=tf.nn.elu) 71 | last_dim = reduce(lambda x, y: x * y, v.get_shape().as_list()[1:]) 72 | v = tf.reshape(v, (-1, last_dim)) 73 | v = tf.layers.dense(v, 64, activation=tf.nn.elu, name="fc1") 74 | # 手痒才搞的half_tanh激活函数。因为初期看到value loss很快就下降了,所有才搞的half_tanh,一方面使得实际学习率减半 75 | # 另一方面使得对logit的敏感度降低。实际上tanh就可以了 76 | self.value = tf.squeeze(tf.layers.dense(v, 1, activation=half_tanh, name="fc2"), axis=1) 77 | 78 | with tf.variable_scope("policy"): 79 | p = self.residual(f, 64, "block4") 80 | p = self.residual(p, 32, "block5") 81 | # 为全连接层降低参数量 82 | p = tf.layers.conv2d(p, 16, 1, padding="SAME", data_format=DATA_FORMAT, name="conv", activation=tf.nn.elu) 83 | last_dim = reduce(lambda x, y: x * y, p.get_shape().as_list()[1:]) 84 | p = tf.reshape(p, (-1, last_dim)) 85 | self.policy = tf.layers.dense(p, self.board_size * self.board_size, activation=None, name="fc") 86 | self.log_softmax = tf.nn.log_softmax(self.policy, axis=1) 87 | self.entropy = -tf.reduce_mean(tf.reduce_sum(tf.nn.softmax(self.policy) * self.log_softmax, axis=1)) 88 | self.prob = tf.nn.softmax(self.policy, axis=1) 89 | 90 | def eval(self, inputs): 91 | """ 92 | 把一个eval函数拆分成下面两个get_prob和get_value, 要调用的时候分开分别调用,会快很多 93 | :param inputs: 94 | :return: 95 | """ 96 | prob, value_ = self.sess.run([self.prob, self.value], feed_dict={self.inputs: inputs, self.training: False}) 97 | return prob, value_ 98 | 99 | def get_prob(self, inputs): 100 | """ 101 | 网络搭建好了以后就不要再添加结点了,不然会慢很多。所以最好先求出policy,再用numpy进行softmax 102 | :param inputs: 103 | :return: 104 | """ 105 | # prob = tf.nn.softmax(self.policy, axis=1) # 这个写法是不好了,因为这个函数每次调用,都会往gpu增加结点 106 | policy = self.sess.run(self.policy, feed_dict={self.inputs: inputs}) 107 | return softmax(policy) 108 | 109 | def get_value(self, inputs): 110 | value_ = self.sess.run(self.value, feed_dict={self.inputs: inputs}) 111 | return value_ 112 | 113 | def restore(self, ckpt_path): 114 | checkpoint = tf.train.get_checkpoint_state(ckpt_path) 115 | if checkpoint and checkpoint.model_checkpoint_path: 116 | self.saver.restore(self.sess, checkpoint.model_checkpoint_path) 117 | print("Successfully loaded:", checkpoint.model_checkpoint_path) 118 | else: 119 | try: 120 | self.saver.restore(self.sess, ckpt_path) 121 | except: 122 | raise FileNotFoundError("Could not find old network weights") 123 | 124 | def get_pipes(self, config, reload=True): 125 | """ 126 | 预测的时候,networkAPI只有一个线程,但有多个管道 127 | :param config: 128 | :param reload: 129 | 
:return: 130 | """ 131 | if self.api is None: 132 | self.api = NetworkAPI(config, self) 133 | self.api.start(reload) # 开启一个线程 134 | return self.api.get_pipe(reload) # 开启一个管道,返回管道的另一端 135 | 136 | def load_pretrained(self, data_path): 137 | reader = pywrap_tensorflow.NewCheckpointReader(data_path) 138 | var = reader.get_variable_to_shape_map() 139 | load_sucess, load_ignore = [], [] 140 | with tf.variable_scope("", reuse=True): 141 | for v in var: 142 | try: 143 | value = reader.get_tensor(v) 144 | self.sess.run(tf.assign(tf.get_variable(v), value)) 145 | load_sucess.append(v) 146 | except ValueError: 147 | load_ignore.append(v) 148 | continue 149 | print("loaded successed: ") 150 | for v in load_sucess: 151 | print(v) 152 | print("=" * 80) 153 | print("missed:") 154 | for v in load_ignore: 155 | print(v) 156 | 157 | def close(self): 158 | self.sess.close() 159 | if self.api is not None: 160 | self.api.close() 161 | 162 | 163 | def half_tanh(x): 164 | # 让tanh函数平滑一点,有点类似学习率降低了0.5 165 | return tf.nn.tanh(x / 2) 166 | 167 | 168 | def softmax(x): 169 | x -= np.max(x, axis=1, keepdims=True) 170 | ex = np.exp(x) 171 | return ex / np.sum(ex, axis=1, keepdims=True) 172 | -------------------------------------------------------------------------------- /genData/networkAPI.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from threading import Thread 3 | from multiprocessing import connection, Pipe 4 | from logging import getLogger 5 | import numpy as np 6 | 7 | logger = getLogger(__name__) 8 | 9 | 10 | class NetworkAPI(object): 11 | def __init__(self, cfg=None, agent_model=None): 12 | self.agent_model = agent_model 13 | self.config = cfg 14 | self.pipes = [] # 用于进程/线程之间通信 15 | self.reload = True 16 | self.prediction_worker = None 17 | self.done = False 18 | 19 | def start(self, reload): 20 | """ 21 | 开启一个线程,来预测数据 22 | :param reload: 23 | :return: 24 | """ 25 | self.reload = reload 26 | # 这里貌似只能给线程,不能给进程, 27 | self.prediction_worker = Thread(target=self.predict_batch_worker, name="prediction_worker") 28 | # prediction_worker = Process(target=self.predict_batch_worker, name="prediction_worker") 29 | self.prediction_worker.daemon = True # 守护线程,不必等其结束即可结束主线程 30 | self.prediction_worker.start() # 开启 31 | 32 | def get_pipe(self, reload=True): 33 | """ 34 | 定义一个管道,自己得一端,返回另一端 35 | :param reload: 36 | :return: 37 | """ 38 | me, you = Pipe() # 通信管道的两端 39 | self.pipes.append(me) 40 | self.reload = reload 41 | return you 42 | 43 | def predict_batch_worker(self): 44 | """ 45 | 把各个有用管道收集来的数据集中eval,再在原来的管道发送出去 46 | :return: 47 | """ 48 | while not self.done: 49 | ready = connection.wait(self.pipes, timeout=0.001) # 等待有通信管道可用,返回可用的管道 50 | if not ready: 51 | continue 52 | data, result_pipes, data_len = [], [], [] 53 | for pipe in ready: 54 | while pipe.poll(): # 不停地返回false,直到连接到有用数据,就返回true 55 | try: 56 | tmp = pipe.recv() # 如果没有消息可接收,recv方法会一直阻塞。如果连接的另外一端已经关闭,那么recv方法会抛出EOFError。 57 | except EOFError as e: 58 | logger.error(f"EOF error: {e}") 59 | pipe.close() # 另一端关闭,这端就关闭了 60 | else: 61 | data.extend(tmp) 62 | data_len.append(len(tmp)) 63 | result_pipes.append(pipe) 64 | if not data: 65 | continue 66 | data = np.asarray(data, dtype=np.float32) 67 | with self.agent_model.graph.as_default(): 68 | policy, value = self.agent_model.eval(data) 69 | buf = [] 70 | k, i = 0, 0 71 | for p, v in zip(policy, value): 72 | buf.append((p, float(v))) 73 | k += 1 74 | if k >= data_len[i]: 75 | result_pipes[i].send(buf) 76 | buf = [] 77 | k = 0 78 | i += 1 79 | 80 | 
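    # Usage sketch (an illustrative note, not part of the original file): a self-play worker holds the far
    # end of a pipe obtained from ResNet.get_pipes(config) and talks to this thread roughly as follows:
    #     pipe.send([inputs])              # a list of input planes, one per position to evaluate
    #     policy, value = pipe.recv()[0]   # matching (policy, value) pairs assembled by predict_batch_worker
    # predict_batch_worker above gathers the requests from all pipes and evaluates them in one batched eval() call.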
def close(self): 81 | self.done = True 82 | for pipe in self.pipes: 83 | pipe.close() 84 | -------------------------------------------------------------------------------- /genData/player.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from collections import defaultdict 3 | import utils 4 | import numpy as np 5 | import random 6 | import gc 7 | 8 | 9 | class Action(object): 10 | def __init__(self): 11 | self.n = 0 # 初始化为1,使得概率分布smooth一丢丢 12 | self.w = 0 # W(s, a) : total action value 13 | self.q = 0 # Q(s, a) = N / W : action value 14 | self.p = 0 # P(s, a) : prior probability 15 | 16 | 17 | class State(object): 18 | def __init__(self): 19 | self.a = defaultdict(Action) # key: action, value: ActionState only valid action included 20 | self.sum_n = 0 # visit count 21 | 22 | 23 | class Player(object): 24 | def __init__(self, cfg=None, training=True, pipe=None, pv_fn=None): 25 | assert pipe is not None or pv_fn is not None 26 | self.config = cfg 27 | self.training = training 28 | # 做成这个样子而不是树状,是因为不同的action序列可能最终得到同一state,做成树状就不利于搜索信息的搜集 29 | self.tree = defaultdict(State) # 一个字符串表示的状态到包含信息的状态的映射 30 | self.root_state = None 31 | self.goal = self.config.goal 32 | self.tau = self.config.init_temp # 初始温度 33 | self.pipe = pipe # 通信管道 34 | self.job_done = False 35 | self.pv_fn = pv_fn 36 | 37 | def get_init_state(self): 38 | """ 39 | 用一个字符串表示棋盘,从上至下从左至由编码 40 | 黑子用3白字用1表示,空格部分用小写字母表示,a表示0个连续空格,b表示一个连续空格,以此类推 41 | :return: 42 | """ 43 | fen = "" 44 | for i in range(self.config.board_size): 45 | fen += chr(ord("a") + self.config.board_size) + '/' 46 | return fen 47 | 48 | def reset(self, search_tree=None): 49 | self.tree = defaultdict(State) if search_tree is None else search_tree 50 | self.root_state = None 51 | self.tau = self.config.init_temp # 初始温度 52 | 53 | def run(self, e=0.25): 54 | """ 55 | 对弈一局,获得一条数据,即从初始到游戏结束的一条数据 56 | :return: 57 | """ 58 | state = self.get_init_state() 59 | game_over = False 60 | data = [] # 收集(状态,动作)二元组 61 | value = 0 62 | last_action = None 63 | while not game_over: 64 | policy, action = self.get_action(state, e, last_action) 65 | data.append((state, policy, last_action)) # 装初始局面不装最终局面,装的是动作执行之前的局面 66 | board = utils.step(utils.state_to_board(state, self.config.board_size), action) 67 | state = utils.board_to_state(board) 68 | # self.pruning_tree(board, state) # 走完一步以后,对其他分支进行剪枝,以节约内存;注释掉,以节约时间 69 | game_over, value = utils.is_game_over(board, self.goal) 70 | # assert value != 1.0 71 | last_action = action 72 | 73 | self.reset() # 把树重启 74 | turns = len(data) 75 | if turns % 2 == 1: 76 | value = -value 77 | weights = utils.construct_weights(turns, gamma=self.config.gamma) 78 | final_data = [] 79 | for i in range(turns): 80 | final_data.append((*data[i], value, weights[i])) # (状态,policy,last_action, value, weight) 81 | value = -value 82 | return final_data 83 | 84 | def calc_policy(self, state, e, random_a): 85 | """ 86 | 根据state表示的状态的状态信息来计算policy 87 | :param state: 88 | :param e: 89 | :param random_a: 这个参数是为了`choose_best_player.py`而设定的,是的在各个ckpt之间进行对弈的时候有一定的随机性 90 | 在人机对弈的时候,为了让机器的落子有一定的随机性,也可以把这个变量设置为True 91 | :return: 92 | """ 93 | node = self.tree[state] 94 | policy = np.zeros((self.config.board_size, self.config.board_size), np.float32) 95 | most_visit_count = -1 96 | candidate_actions = list(node.a.keys()) 97 | policy_valid = np.empty((len(candidate_actions),), dtype=np.float32) 98 | for i, action in enumerate(candidate_actions): 99 | policy_valid[i] = node.a[action].n 100 | most_visit_count = 
node.a[action].n if node.a[action].n > most_visit_count else most_visit_count 101 | best_actions = [action for action in candidate_actions if node.a[action].n == most_visit_count] 102 | best_action = random.choice(best_actions) 103 | # for i, action in enumerate(candidate_actions): 104 | # print(action, node.a[action].n,node.a[action].p) 105 | # from IPython import embed; embed() 106 | if not self.training and not random_a: 107 | return None, best_action 108 | if random_a: 109 | self.tau *= self.config.tau_decay_rate_r 110 | else: 111 | self.tau *= self.config.tau_decay_rate 112 | if self.tau <= 0.01: 113 | for a in best_actions: 114 | policy[a[0], a[1]] = 1.0 / len(best_actions) 115 | return policy, best_action 116 | policy_valid /= np.max(policy_valid) # 除以最大值,再取指数,以免溢出 117 | policy_valid = np.power(policy_valid, 1 / self.tau) 118 | policy_valid /= np.sum(policy_valid) 119 | for i, action in enumerate(candidate_actions): 120 | policy[action[0], action[1]] = policy_valid[i] 121 | p = policy_valid 122 | # alphaGo这里添加了一个噪声。本project因为在探索的时候加的噪声足够多了,这里就不需要了 123 | # p = (1 - e) * p + e * np.random.dirichlet(0.5 * np.ones(policy_valid.shape[0])) 124 | # p = p / p.sum() # 有精度损失,导致其和不是1了 125 | random_action = candidate_actions[int(np.random.choice(len(candidate_actions), p=p))] 126 | return policy, random_action 127 | 128 | def get_action(self, state: str, e: float=0.25, last_action: tuple=None, random_a=False): 129 | """ 130 | 根据state表示的棋局状态进行多次蒙特卡洛搜索以获取一个动作 131 | :param state: 字符串表示的当前棋局状态 132 | :param e: 为训练而添加的噪声系数。后来还是没有使用他 133 | :param last_action: 134 | :param random_a: 这个参数是为了`choose_best_player.py`而设定的,是的在各个ckpt之间进行对弈的时候有一定的随机性 135 | 在人机对弈的时候,为了让机器的落子有一定的随机性,也可以把这个变量设置为True 136 | :return: 137 | """ 138 | self.root_state = state 139 | # # 该节点已经被访问了sum_n次,最多访问642次好了,节约点时间 140 | if state not in self.tree: 141 | num = self.config.simulation_per_step 142 | else: 143 | num = min(self.config.simulation_per_step, self.config.upper_simulation_per_step-self.tree[state].sum_n) 144 | for i in range(num): 145 | self.MCTS_search(state, [state], last_action) 146 | policy, action = self.calc_policy(state, e, random_a=random_a) 147 | return policy, action 148 | 149 | def pruning_tree(self, board: np.ndarray, state: str = None): 150 | """ 151 | 主游戏前进一步以后,可以对树进行剪枝,只保留前进的那一步所对应的子树 152 | :param board: 153 | :param state: 154 | :return: 155 | """ 156 | if state is None: 157 | state = utils.board_to_state(board) 158 | keys = list(self.tree.keys()) 159 | for key in keys: 160 | b = utils.state_to_board(key, self.config.board_size) 161 | if key != state \ 162 | and np.all(np.where(board == 1, 1, 0) >= np.where(b == 1, 1, 0)) \ 163 | and np.all(np.where(board == -1, 1, 0) >= np.where(b == -1, 1, 0)): 164 | del self.tree[key] 165 | 166 | def update_tree(self, v, history: list): 167 | """ 168 | 回溯更新 169 | :param p: policy 当前局面对黑方的策略 170 | :param v: value, 当前局面对黑方的价值 171 | :param history: 包含当前局面的一个棋局,(state, action) pair 172 | :return: 173 | """ 174 | _ = history.pop() # 最近的棋局 175 | # 注意,这里并没有把v赋给当前node 176 | while len(history) > 0: 177 | action = history.pop() 178 | state = history.pop() 179 | v = -v 180 | node = self.tree[state] # 状态结点 181 | action_state = node.a[action] # 该状态下的action边 182 | action_state.n += 1 183 | action_state.w += v 184 | action_state.q = action_state.w / action_state.n 185 | 186 | def evaluate_and_expand(self, state: str, board: np.ndarray = None, last_action: tuple=None): 187 | if board is None: 188 | board = utils.state_to_board(state, self.config.board_size) 189 | data_to_send = 
utils.board_to_inputs(board, last_action=last_action) 190 | if self.pv_fn is not None: 191 | policy, value = self.pv_fn(data_to_send[np.newaxis, ...]) 192 | policy, value = policy[0], value[0] 193 | else: 194 | self.pipe.send([data_to_send]) 195 | while not self.pipe.poll(): # 等待对方处理数据,这里能收到的时候,poll()就返回true 196 | pass 197 | policy, value = self.pipe.recv()[0] # 收回来的是一个列表,batch大小就是列表长度 198 | legal_actions = utils.get_legal_actions(board) 199 | all_p = max(sum([policy[action[0] * self.config.board_size + action[1]] for action in legal_actions]), 1e-5) 200 | for action in legal_actions: 201 | self.tree[state].a[action].p = policy[action[0] * self.config.board_size + action[1]] / all_p 202 | return value 203 | 204 | def MCTS_search(self, state: str, history: list, last_action: tuple): 205 | """ 206 | 以state为根节点进行MCTS搜索,搜索历史保存在histoty之中 207 | :param state: 一个字符串代表的当前状态,根节点 208 | :param history: 包含当前状态的一个列表 209 | :param last_action: 上一次的落子位置 210 | :return: 211 | """ 212 | while True: 213 | board = utils.state_to_board(state, self.config.board_size) 214 | game_over, v = utils.is_game_over(board, self.goal) # 落子前检查game over 215 | if game_over: 216 | self.update_tree(v, history=history) 217 | break 218 | if state not in self.tree: 219 | # 未出现过的state,则评估然后展开 220 | v = self.evaluate_and_expand(state, board, last_action) # 落子前进行评估 221 | self.update_tree(v, history=history) 222 | break 223 | sel_action = self.select_action_q_and_u(state) # 根据state选择一个action 224 | history.append(sel_action) # 放进action 225 | board = utils.step(board, sel_action) 226 | state = utils.board_to_state(board) 227 | history.append(state) 228 | last_action = sel_action 229 | 230 | def select_action_q_and_u(self, state: str) -> tuple: 231 | """ 232 | 根据结点状态信息返回一个action 233 | :param state: 234 | :return: 235 | """ 236 | node = self.tree[state] 237 | node.sum_n += 1 # 从这结点出发选择动作,该节点访问次数加一 238 | action_keys = list(node.a.keys()) 239 | act_count = len(action_keys) 240 | dirichlet = np.random.dirichlet(self.config.dirichlet_alpha * np.ones(act_count)) 241 | scores = np.empty((act_count,), np.float32) 242 | q_value = np.empty((act_count,), np.float32) 243 | counts = np.empty((act_count,), np.int32) 244 | for i, ac in enumerate(action_keys): 245 | action_state = node.a[ac] 246 | p_ = action_state.p # 该动作的先验概率 247 | if self.training: 248 | # 训练时候为根节点添加较大噪声,非根节点添加较小噪声 249 | if self.root_state == state: 250 | # simulation阶段的这个噪声可以防止坍缩 251 | p_ = 0.75 * p_ + 0.25 * dirichlet[i] 252 | else: 253 | p_ = 0.9 * p_ + 0.1 * dirichlet[i] # 非根节点添加较小的噪声 254 | # else: 255 | # # 给测试的时候也适当添加噪声,以便于充分搜索,和增加一点随机性。当然,这个随机性也可以在policy的概率分布中产生 256 | # if self.root_state == state: 257 | # # simulation阶段的这个噪声可以防止坍缩 258 | # p_ = 0.85 * p_ + 0.15 * dirichlet[i] 259 | # else: 260 | # p_ = 0.95 * p_ + 0.05 * dirichlet[i] # 非根节点添加较小的噪声 261 | scores[i] = action_state.q + self.config.c_puct * p_ * np.sqrt(node.sum_n + 1) / (1 + action_state.n) 262 | q_value[i] = action_state.q 263 | counts[i] = action_state.n 264 | if self.root_state == state and self.training: 265 | # 对于根节点,保证每个结点至少被访问两次,其中一次是展开,另一次是探索。 266 | # 故要求simulation_per_step >> 2*board_size*board_size才有意义 267 | # 这么做使得概率分布更加smooth,从而探索得更好 268 | no_visits = np.where(counts == 0)[0] 269 | if no_visits.shape[0] > 0: 270 | act_idx = np.random.choice(no_visits) 271 | return action_keys[act_idx] 272 | else: 273 | one_visits = np.where(counts == 1)[0] 274 | if one_visits.shape[0] > 0: 275 | act_idx = np.random.choice(one_visits) 276 | return action_keys[act_idx] 277 | max_score = np.max(scores) 278 | act_idx = 
np.random.choice([idx for idx in range(act_count) if scores[idx] == max_score]) 279 | return action_keys[act_idx] 280 | 281 | def close(self): 282 | self.job_done = True 283 | del self.tree 284 | gc.collect() 285 | -------------------------------------------------------------------------------- /images/back.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuoYi0/alphaFive/9492d1083154d13c5ed8c8e3c2a34f0fc1fee2e7/images/back.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import time 3 | import utils 4 | from genData.network import ResNet as model 5 | import tensorflow as tf 6 | import os 7 | import config 8 | import numpy as np 9 | from concurrent.futures import ProcessPoolExecutor, wait, ALL_COMPLETED 10 | from threading import Lock, Thread 11 | from genData.player import Player 12 | from multiprocessing import Manager, Process, Queue 13 | # from queue import Queue # 这个只能在多线程之间使用,即同一个进程内部使用,不能跨进程通信 14 | from multiprocessing.managers import BaseManager 15 | from utils import RandomStack 16 | import gc 17 | 18 | # global只能在父子进程中共享只读变量 19 | 20 | cur_dir = os.path.dirname(__file__) 21 | os.chdir(cur_dir) 22 | PURE_MCST = 1 23 | AI = 2 24 | 25 | 26 | # job_lock = Lock() 27 | 28 | 29 | def main(restore=False): 30 | stack = RandomStack(board_size=config.board_size, length=config.buffer_size) # 命名为stack了,事实上是一个队列 31 | net = model(config.board_size) 32 | if restore: 33 | net.restore(config.ckpt_path) 34 | stack.load(6960) # 这里需要根据实际情况更改,看从哪一步接着训练,就写为哪一步 35 | with net.graph.as_default(): 36 | episode_length = tf.placeholder(tf.float32, (), "episode_length") 37 | total_loss, cross_entropy, value_loss, entropy = net.total_loss, net.cross_entropy_loss, net.value_loss, net.entropy 38 | lr = tf.get_variable("learning_rate", dtype=tf.float32, initializer=1e-3) 39 | opt = tf.train.AdamOptimizer(lr).minimize(total_loss) 40 | net.sess.run(tf.global_variables_initializer()) 41 | tf.summary.scalar("x_entropy_loss", cross_entropy) 42 | tf.summary.scalar("value_loss", value_loss) 43 | tf.summary.scalar("total_loss", total_loss) 44 | tf.summary.scalar("entropy", entropy) 45 | tf.summary.scalar('episode_len', episode_length) 46 | log_dir = os.path.join("summary", "log_" + time.strftime("%Y%m%d_%H_%M_%S", time.localtime())) 47 | journalist = tf.summary.FileWriter(log_dir, flush_secs=10) 48 | summury_op = tf.summary.merge_all() 49 | step = 1 # 如果接着训练,这里就改为接着的那一步的下一步。手动改了算了,懒得写成自动识别的了 50 | cur_pipes = [net.get_pipes(config) for _ in range(config.max_processes)] # 手动创建进程不需要Manager() 51 | q = Queue(50) # 用Process手动创建的进程可以使用这个Queue,否则需要Manager()来管理 52 | for i in range(config.max_processes): 53 | proc = Process(target=gen_data, args=(cur_pipes[i], q)) 54 | proc.daemon = True # 父进程结束以后,子进程就自动结束 55 | proc.start() 56 | 57 | while step < config.total_step: 58 | # 每生成一条数据,才训练一次 59 | net.sess.run(tf.assign(lr, config.get_lr(step))) 60 | data_record, result = q.get(block=True) # 获取一个item,没有则阻塞 61 | r = stack.push(data_record, result) 62 | if r and stack.is_full(): # 满了再训练会比较慢,但是消除了biase 63 | for _ in range(4): 64 | boards, weights, values, policies = stack.get_data(batch_size=config.batch_size) 65 | xcro_loss, mse_, entropy_, _, sum_res = net.sess.run( 66 | [cross_entropy, value_loss, entropy, opt, summury_op], 67 | feed_dict={net.inputs: boards, net.distrib: policies, 68 | net.winner: values, 
net.weights: weights, episode_length: len(data_record)}) 69 | step += 1 70 | journalist.add_summary(sum_res, step) 71 | print(" ") 72 | print("step: %d, xcross_loss: %0.3f, mse: %0.3f, entropy: %0.3f" % (step, xcro_loss, mse_, entropy_)) 73 | if step % 60 == 0: 74 | net.saver.save(net.sess, save_path=os.path.join(config.ckpt_path, "alphaFive"), global_step=step) 75 | stack.save(step) 76 | print("save ckpt and data successfully") 77 | net.saver.save(net.sess, save_path=os.path.join(config.ckpt_path, "alphaFive"), global_step=step) 78 | stack.save() 79 | net.close() 80 | 81 | 82 | def gen_data(pipe, q): 83 | player = Player(config, training=True, pipe=pipe) 84 | while True: 85 | game_record = player.run() 86 | value = game_record[-1][-2] 87 | game_length = len(game_record) 88 | if value == 0.0: 89 | result = utils.DRAW 90 | elif game_length % 2 == 1: 91 | result = utils.BLACK_WIN 92 | else: 93 | result = utils.WHITE_WIN 94 | q.put((game_record, result), block=True) # block=True满了则阻塞 95 | 96 | 97 | def next_unused_name(name): 98 | save_name = name 99 | iteration = 0 100 | while os.path.exists(save_name): 101 | save_name = name + '-' + str(iteration) 102 | iteration += 1 103 | return save_name 104 | 105 | 106 | if __name__ == '__main__': 107 | main(restore=False) 108 | -------------------------------------------------------------------------------- /note.txt: -------------------------------------------------------------------------------- 1 | ckpt-3540很牛逼了,我下不赢 2 | 开始训练的时候,memory里面的对局数是black: white = 145: 141 in the memory,self-play black: 132, white: 137 -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # AlphaFive 2 | 模仿AlphaGo/AlphaGo Zero写的一个五子棋AI,为了快速收敛,针对五子棋的特点做了一些小trick改进 3 | 4 | 先上效果图 5 | 6 | ![five_6960.gif](https://github.com/GuoYi0/alphaFive/blob/master/tmp/five_6960.gif) 7 | 8 | ## 运行 9 | *自我对弈* `python self_play.py` 10 | 11 | *人机对弈* `python GUI.py` 12 | 13 | *训练* `python train.py` 14 | 15 | *两个ckpt之间对弈或者在所有ckpt之间选择一个最佳* `python choose_best_player.py` 16 | 这也需要在`choose_best_player.py`的`main()`函数中修改`all_ckpts`以适应ckpt目录下的具体的ckpt 17 | 18 | ## 算法 19 | AlphaGo/AlphaGo Zero核心思想是用一个network搭配MCST进行搜索和优化。network的输入为相对于当前玩家棋盘局面, 20 | 输出为在各个地方落子的概率(即policy)和当前局面对于当前玩家而言最终的得分期望(即value),若最终当前玩家输了,得-1分, 21 | 赢了得+1分,和棋得0分。故value是介于-1到1之间的一个实数。 22 | 23 | policy的作用是为MCTS提供一个先验概率,让MCTS优先搜索对当前选手而言更可能获胜的路径,也就是说基于当前策略去采样,而不是随机采样;value的作用在于搜索到叶子节点的时候,若没有game over,则以value值进行回溯更新。 24 | 单纯的MCTS在搜索到非game over的叶子节点的时候会进行roll out,进行一次路径采样,用这个采样结果来估计基于当前局面当前选手的得分期望。 25 | 虽然这个估计是无偏估计,但是仅仅用一个样本来估计期望显然具有很大的方差。而value值虽然可能是有偏的,但是他是依据很多次棋局训练出来的,具有一定的平均意义,方差相对较小。 26 | 27 | 以某种棋局状态出发进行多次MCTS搜索以后,就可以依据各个子节点的访问次数构造一个概率分布,依据该概率分布来决定真正应该再何处落子,同时该概率分布也可以作为network训练policy的监督信号。 28 | 当一局棋结束以后,就可以知道该轮对弈在每个棋局状态下的得分,该得分将作为训练value的监督信号。 29 | 30 | 从以上算法可知,这是一个不断根据network的输出以MCST进行对弈采样,然后把对弈结果再拿来更新network的参数,这样一个不断迭代过程。仅仅用单纯的network输出的策略对弈来更新network显然是不行的,结合了network策略和MCST,会比单纯network输出的策略强一丢丢。从而可以提升network策略。 31 | 32 | ## 特点 33 | (该project的运行环境是GTX 1070显卡和六核i7-8700 CPU的笔记本电脑。对于强化学习来说穷得不该入这行的大门。用主进程训练network,五个子进程生成模拟数据,一条长度为30的对弈数据大约就得30秒钟。故有些特点是我自己加进去的,感觉可能会加快运行速度和收敛) 34 | 35 | [1] **输入特征** AlphaGo/AlphaGo Zero把过往的棋局和当前棋局叠在一起作为network的输入,一方面可能是因为围棋的落子规则决定了当前落子不仅仅依赖于当前局面,也依赖于过往局面;另一方面叠加更多的过往棋局也可能利于训练。 36 | 此外,围棋的先手和后手的最终判输赢的规则也不一样,故也需要告诉network当前玩家是先手还是后手。对于五子棋而言,落子仅仅取决于当前棋局,与过往棋局完全无关,此外先后手最终的判输赢规则是一样的,所以仅仅需要输入当前局面即可。 37 | 对于11*11的棋盘,输入shape可以是[B, 2, 11, 
11],其中B是bachsize,2个channel其中一个是当前玩家的特征,另一个是对方玩家的特征,有棋子的地方为1,没有棋子的地方为0。 38 | 理论上只需要一个channel就可以的,当前玩家的棋子为1,对方玩家的棋子为-1,没有棋子的地方是0。搞两个channel可以加速收敛。此外,该project加了第三个channel,其在对方玩家最后一次落子的地方为1,其余地方为0,即表征last action。这个channel可以起到一个attention的作用,告诉当前玩家可能需要聚焦于对方玩家的落子点的附近进行落子。 39 | 这个channel可能没啥太大卵用,后续可以做对比实验试一试。 40 | 41 | [2] **MCST树设计** 一般来说蒙特卡洛搜索树是一个树状结构,但由于五子棋的落子决策完全仅依赖于当前状态(有last action的情形除外),而不同落子顺序可能到达相同的状态,这个相同的状态的状态信息就可以复用了。 42 | 故本project并没有设计成树状结构,而是以dict的形式存储,其中key为一个字符串表示的某种状态,value是该状态的状态信息。从不同路径抵达该状态时可以共享该信息,并共同更新该信息。但是在有last action的时候,情况有些微妙的变化。 43 | last action仅仅在需要作为network的输入的时候起作用。在模拟对弈的时候,到达某一个局面以后,假设需要以当前局面为根节点出发进行500次搜索,这500次的last action是一样的,搜索完毕以后在该节点形成的概率分布将作为policy的监督信号,该监督信号都对应于同一个last action,这一点是没有问题的。但是当该局面节点曾经作为叶子节点的时候,对叶子节点的评估所使用的last action就未必是现在的last action了。 44 | 只有当对弈局数足够多以后,这个影响才可以逐渐减弱。后续可以去掉last action这个channel一试。 45 | 46 | [3] **数据处理** 五子棋有一个很大的bug,就是(貌似35步以内)先手必赢。这样产生的后果是,模拟对弈的数据里面,先手赢的数据量会多于后手赢的数据量,这样失衡的数据直接拿去训练,会导致网络进一步偏好先手赢(如果当前玩家落子前,棋局里面当前玩家的棋子数量等于对手玩家的棋子数量,则当前玩家就是先手;若当前玩家的棋子数量比对手棋子数量少一个,则是后手。故网络完全可以通过棋子数量学到当前玩家是先手还是后手。),这种偏好进一步让模拟对弈产生更多的先手赢的棋局。 47 | 最终模型可能会坍塌,即先手预测的value接近于1,policy是比较准确;后手预测的value接近于-1。坍塌以后,会产生大量长度只有9或者11的棋局,这些棋局是先手很快就连成了五颗棋子,后手连成了四颗或者比较乱的摆放。当然,如果噪声足够大,训练时间足够长,这种现象可以缓解。本project采取一个的缓解方案是,记录replay buffer里面先手和后手赢棋的棋局数,当某一方赢棋数量太少的时候,若搜集到该方的一条赢棋,则重复加入buffer。此外,对于步数太短的棋局,以一定概率舍弃。这一部分在`utils.py RandomStack`里面。 48 | 49 | [4] **训练权重** 在不断模拟对弈过程中,越是往后的的局面出现的频次就会越小,越是靠前的局面出现的频次肯定越大。故在本project设计了一个权重增长因子,使得靠后的局面获得的训练权重大于靠前的局面的训练权重。这样做的另一个原因是,靠后的落子与输赢的关联性可能更大,所以获得一个较大的训练权重。 50 | 51 | [5] **探索方案** 在训练过程中,很容易使得网络输出的policy的熵太小,不利于充分探索。缓解方案有很多,例如把熵加入loss里面,加大噪声等等。本project比较暴力,在根节点处强制要求每个子节点至少被访问两次。这样一方面可以加大探索力度,另一方面让监督信号的概率分布smooth一些,有点类似于监督学习里面的label smooth。 52 | 此外,对于根节点加入了0.25/0.75的狄利克雷噪声,非根节点加入0.1/0.9的狄利克雷噪声。在主游戏最终选择action的时候,只根据结点访问次数的概率分布选择action,不再加入噪声。 53 | 54 | 55 | 56 | ## 训练结果 57 | 上面的效果图是我训练到6960步的时候,人机对弈的结果。AI是黑方,我是白方(尽管最后还是我赢了),可以看到,AI在游戏前期还是不错的。会进攻会防守。但是后期就有点乏力了。 58 | 可能的原因是训练的次数远远不够,对弈后期的局面没有得到充分的训练。事实上,在6960步的时候,对弈的平均长度只有25步左右,需要继续往后训练,模拟对弈才会生成更长的对弈棋局。 59 | 60 | 各种loss如下图所示 61 | 62 | ![episode_length](https://github.com/GuoYi0/alphaFive/blob/master/tmp/episode_length.jpg) 63 | ![value_loss](https://github.com/GuoYi0/alphaFive/blob/master/tmp/value_loss.jpg) 64 | ![xentropy_loss](https://github.com/GuoYi0/alphaFive/blob/master/tmp/xentropy_loss.jpg), 65 | ![entropy](https://github.com/GuoYi0/alphaFive/blob/master/tmp/entropy.jpg) 66 | ![total_loss](https://github.com/GuoYi0/alphaFive/blob/master/tmp/total_loss.jpg), 67 | 68 | 图eposide_len反映的是模拟对弈的时候产生的棋局的步数,(不是回合数,黑方落子,白方再落子,即为2步)。从该图可以看出,在训练初期,由于是完全随机落子,棋局步数很长,达到了50+ 69 | 然后随着训练进行,模拟对弈的棋局步数很快下降,说明AI逐渐掌握了游戏初步规则,需要把五颗棋子摆放成一条线才能赢,赶紧很快就摆成一条线了,这时候只知道进攻,不懂得防守。随着训练的继续进行,AI才逐渐知道怎么防守,游戏逐渐变长。也只有把前期的攻防都学好以后,游戏才能发展到后期。 70 | 从该图也可以看出,游戏的长度依然在持续增长,说明8k步的训练是远远不够的。 71 | 72 | value_loss的变化图,曲线在7k步的时候突然增大了,原因是在7k步的时候把学习率由1e-3降为了2e-4。为啥学习率降低了会导致value loss陡增?有待研究。 73 | 所有的loss在8700步的时候都陡增了,是因为我这里掐断了,然后重新跑的。我明明同时保存了ckpt和data buffer,不知为啥依然会有这个现象。。不过loss下降的速度也很快。如果不用data buffer里面的数据,而是根据断点处重新生成数据,则新生成的数据都是基于断点处的network生成的,而network只对buffer里面的历史数据拟合得很好,对完全由断点处的network产生的数据拟合度未必好,产生loss陡增的现象还可以理解。 74 | 可能是代码哪里有bug。后续再研究。 75 | 76 | entropy反映了输出policy的概率分布的熵。在初期随机落子,熵比较大,随着训练进行,熵自然就减小了。 77 | 78 | 图中显示的value loss和xentropy loss都是没有加权的,即棋局初期和后期的权重保持一致。而total loss是加权value loss,加权xentropy loss,L2正则化loss之和。 79 | 80 | 该project只提供了学习率下降之前的ckpt,即6960步的ckpt。本人通过`choose_best_player.py`发现后续的ckpt依然在慢慢变得厉害。 81 | 82 | 83 | ## 参考代码 84 | [ChineseChess-alphaZero](https://github.com/NeymarL/ChineseChess-AlphaZero), 还有一个写界面的参考代码找不着了 85 | 
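
## Input-encoding sketch (supplementary)

A minimal sketch of the three-plane input described in item [1] of the 特点 section above, added for clarity. The function name `encode_board` is illustrative only; the actual encoding in this repo lives in `utils.board_to_inputs`, whose details (e.g. plane order) may differ.

```python
import numpy as np

def encode_board(board: np.ndarray, last_action=None) -> np.ndarray:
    """board is an (N, N) array from the current player's perspective:
    +1 = own stones, -1 = opponent stones, 0 = empty."""
    planes = np.zeros((3, *board.shape), dtype=np.float32)
    planes[0] = (board == 1)    # current player's stones
    planes[1] = (board == -1)   # opponent's stones
    if last_action is not None:
        # the opponent's last move, a weak "attention" hint for the network
        planes[2][last_action[0], last_action[1]] = 1.0
    return planes
```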
86 | ## 后续工作 87 | 1. 进一步往后训练。可能需要适当加大buffer size,并调整学习率; 88 | 2. 研究为何在7k步学习率下降会导致value loss陡增; 89 | 3. 研究为何掐断以后保持原来的数据和ckpt会导致loss陡增(虽然很快还是降下来了)。 90 | -------------------------------------------------------------------------------- /self_play.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from genData.network import ResNet as Model 3 | import config 4 | import pygame 5 | import os 6 | import numpy as np 7 | from genData.player import Player 8 | import utils 9 | import cv2 10 | import imageio 11 | 12 | 13 | GRID_WIDTH = 36 14 | WIDTH = (config.board_size + 2) * GRID_WIDTH 15 | HEIGHT = (config.board_size + 2) * GRID_WIDTH 16 | FPS = 30 17 | WHITE = (255, 255, 255) 18 | BLACK = (0, 0, 0) 19 | HUMAN = 0 20 | AI = 2 21 | 22 | 23 | def main(trained_ckpt): 24 | print(config.simulation_per_step) 25 | net = Model(config.board_size) 26 | player = Player(config, training=False, pv_fn=net.eval) 27 | net.restore(trained_ckpt) 28 | pygame.init() 29 | screen = pygame.display.set_mode((WIDTH, HEIGHT)) 30 | pygame.display.set_caption("五子棋") 31 | clock = pygame.time.Clock() 32 | base_folder = os.path.dirname(__file__) 33 | img_folder = os.path.join(base_folder, 'images') 34 | background_img = pygame.image.load(os.path.join(img_folder, 'back.png')).convert() 35 | background = pygame.transform.scale(background_img, (WIDTH, HEIGHT)) 36 | back_rect = background.get_rect() 37 | running = True 38 | frames = [] 39 | def draw_stone(screen_): 40 | for i in range(config.board_size): 41 | for j in range(config.board_size): 42 | if state[i, j] == 1: 43 | pygame.draw.circle(screen_, BLACK, (int((j + 1.5) * GRID_WIDTH), int((i + 1.5) * GRID_WIDTH)), 16) 44 | elif state[i, j] == -1: 45 | pygame.draw.circle(screen_, WHITE, (int((j + 1.5) * GRID_WIDTH), int((i + 1.5) * GRID_WIDTH)), 16) 46 | else: 47 | assert state[i, j] == 0 48 | 49 | def draw_background(surf): 50 | screen.blit(background, back_rect) 51 | rect_lines = [ 52 | ((GRID_WIDTH, GRID_WIDTH), (GRID_WIDTH, HEIGHT - GRID_WIDTH)), 53 | ((GRID_WIDTH, GRID_WIDTH), (WIDTH - GRID_WIDTH, GRID_WIDTH)), 54 | ((GRID_WIDTH, HEIGHT - GRID_WIDTH), 55 | (WIDTH - GRID_WIDTH, HEIGHT - GRID_WIDTH)), 56 | ((WIDTH - GRID_WIDTH, GRID_WIDTH), 57 | (WIDTH - GRID_WIDTH, HEIGHT - GRID_WIDTH)), 58 | ] 59 | for line in rect_lines: 60 | pygame.draw.line(surf, BLACK, line[0], line[1], 2) 61 | 62 | for i in range(config.board_size): 63 | pygame.draw.line(surf, BLACK, 64 | (GRID_WIDTH * (2 + i), GRID_WIDTH), 65 | (GRID_WIDTH * (2 + i), HEIGHT - GRID_WIDTH)) 66 | pygame.draw.line(surf, BLACK, 67 | (GRID_WIDTH, GRID_WIDTH * (2 + i)), 68 | (HEIGHT - GRID_WIDTH, GRID_WIDTH * (2 + i))) 69 | 70 | circle_center = [ 71 | (GRID_WIDTH * 4, GRID_WIDTH * 4), 72 | (WIDTH - GRID_WIDTH * 4, GRID_WIDTH * 4), 73 | (WIDTH - GRID_WIDTH * 4, HEIGHT - GRID_WIDTH * 4), 74 | (GRID_WIDTH * 4, HEIGHT - GRID_WIDTH * 4), 75 | ] 76 | for cc in circle_center: 77 | pygame.draw.circle(surf, BLACK, cc, 5) 78 | 79 | game_over = False 80 | state_str = player.get_init_state() 81 | board = utils.state_to_board(state_str, config.board_size) 82 | state = board 83 | draw_background(screen) 84 | pygame.display.flip() 85 | image_data = pygame.surfarray.array3d(pygame.display.get_surface()) 86 | frames.append(cv2.resize(image_data, (0, 0), fx=0.5, fy=0.5)) 87 | turn = 0 88 | i = 0 89 | while running: 90 | clock.tick(FPS) 91 | for event in pygame.event.get(): 92 | if event.type == pygame.QUIT: 93 | running = False 94 | break 95 | action = None 96 | if not game_over: 
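            # the same Player instance (backed by one network) generates the moves for both sides during self-play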
97 | _, action = player.get_action(state_str, last_action=action) 98 | board = utils.step(utils.state_to_board(state_str, config.board_size), action) 99 | state_str = utils.board_to_state(board) 100 | # player.pruning_tree(board, state_str) # 走完一步以后,对其他分支进行剪枝,以节约内存 101 | game_over, value = utils.is_game_over(board, config.goal) 102 | if turn %2 ==1: 103 | state = board 104 | else: 105 | state = -board 106 | turn += 1 107 | draw_background(screen) 108 | draw_stone(screen) 109 | pygame.display.flip() 110 | image_data = pygame.surfarray.array3d(pygame.display.get_surface()) 111 | frames.append(cv2.resize(image_data, (0, 0), fx=0.5, fy=0.5)) 112 | 113 | # draw_background(screen) 114 | # draw_stone(screen) 115 | # pygame.display.flip() 116 | if game_over: 117 | i += 1 118 | image_data = pygame.surfarray.array3d(pygame.display.get_surface()) 119 | frames.append(cv2.resize(image_data, (0, 0), fx=0.5, fy=0.5)) 120 | if i >= 3: # 最终保留三帧 121 | break 122 | 123 | pygame.quit() 124 | print("game finished, start to write to gif.") 125 | gif = imageio.mimsave("tmp/five.gif", frames,'GIF', duration=0.8) 126 | print("done!") 127 | 128 | 129 | if __name__ == "__main__": 130 | main(trained_ckpt=config.ckpt_path) 131 | -------------------------------------------------------------------------------- /summary/log_20200312_11_54_18/events.out.tfevents.1583997414.DESKTOP-T9U7R33: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuoYi0/alphaFive/9492d1083154d13c5ed8c8e3c2a34f0fc1fee2e7/summary/log_20200312_11_54_18/events.out.tfevents.1583997414.DESKTOP-T9U7R33 -------------------------------------------------------------------------------- /summary/log_20200312_11_54_18/events.out.tfevents.1584241253.DESKTOP-T9U7R33: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuoYi0/alphaFive/9492d1083154d13c5ed8c8e3c2a34f0fc1fee2e7/summary/log_20200312_11_54_18/events.out.tfevents.1584241253.DESKTOP-T9U7R33 -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | import sys 4 | 5 | # # 6 | # # def get_init_state(): 7 | # # """ 8 | # # 用一个字符串表示棋盘,从上至下从左至由编码 9 | # # 黑子用3白字用1表示,空格部分用小写字母表示,a表示一个连续空格,b表示两个连续空格,以此类推 10 | # # :return: 11 | # # """ 12 | # # fen = "" 13 | # # for i in range(4): 14 | # # fen += chr(ord("a") + i) + '/' 15 | # # return fen 16 | # # 17 | # # print('a'.isalpha()) 18 | # # print(ord('c')-ord('a')) 19 | # 20 | # a = np.array([[1,0], [0,0]]) 21 | # print(np.equal(a, 0).astype(np.float32)) 22 | # # print(np.where(a==0)) 23 | # # def get_legal_moves(board): 24 | # # zeros = np.where(board == 0) 25 | # # return [(int(i), int(j)) for i, j in zip(*zeros)] 26 | # # 27 | # # print(get_legal_moves(a)) 28 | # 29 | # def board_to_state(board: np.ndarray) -> str: 30 | # fen = "" 31 | # h, w = board.shape 32 | # for i in range(h): 33 | # c = 0 34 | # for j in range(w): 35 | # if board[i, j] == 0: 36 | # c += 1 37 | # else: 38 | # fen += chr(ord('a')+c) if c > 0 else '' 39 | # fen += str(board[i, j] + 2) 40 | # c = 0 41 | # fen += chr(ord('a') + c) if c > 0 else '' 42 | # fen += '/' 43 | # return fen 44 | # 45 | # def state_to_board(state:str, board_size:int): 46 | # """ 47 | # 根据字符串表示的state转换为棋盘。字符串中,黑子用1表示,红子用3表示 48 | # :param state: 49 | # :param board_size: 50 | # :return: 51 | # """ 52 | # board = 
np.zeros((board_size, board_size), np.int8) 53 | # i = j = 0 54 | # for ch in state: 55 | # if ch == '/': 56 | # i += 1 57 | # j = 0 58 | # elif ch.isalpha(): 59 | # j += ord(ch) - ord('a') 60 | # else: 61 | # board[i][j] = int(ch) - 2 62 | # j += 1 63 | # return board 64 | # 65 | # 66 | # b = np.array([[1,0], [-1, 1]]) 67 | # # print(board_to_state(b)) 68 | # # 69 | # # print(state_to_board(board_to_state(b), 2)) 70 | # 71 | # index = [(0,0), (1,1)] 72 | # g = list(zip(*index)) 73 | # b[g[0], g[1]] = -1 74 | # print(b) 75 | # b = [1,2,3,4,5,6,] 76 | # print(b[-3:]) 77 | # print(np.random.dirichlet(2 * np.ones(50))) 78 | # a = np.array([0,2,3,0,2]) 79 | # a = 0.28437745 80 | # print("a {}, %.3f, ".format((1,2))%a) 81 | # a = np.array([1, 2, 3, 4, 5], dtype=np.float32) 82 | # 83 | # 84 | # def fenmu(): 85 | # print("haha") 86 | # return 2.0 87 | # 88 | # 89 | # a /= fenmu() 90 | # 91 | # # print(a) 92 | # a = [1,2,3,1,2,3] 93 | # a.remove(1) 94 | # # b = a.index(1) 95 | # print(a) 96 | # from logging import getLogger 97 | # logger = getLogger(__name__) 98 | # logger.info("haha") 99 | # g = np.random.dirichlet(0.5 * np.ones(20)) 100 | # print(g) 101 | # import time 102 | # 103 | # try: 104 | # for i in range(10): 105 | # print(i) 106 | # time.sleep(5) 107 | # except KeyboardInterrupt: 108 | # print("over") 109 | from concurrent.futures import ThreadPoolExecutor 110 | import time 111 | # from multiprocessing import Process,Pipe 112 | # # 导入进程,管道模块 113 | # 114 | # def f(conn): 115 | # conn.send([1,'test',None]) 116 | # conn.send([2,'test',None]) 117 | # print(conn.recv()) 118 | # conn.close() 119 | # 120 | # if __name__ == "__main__": 121 | # parent_conn,child_conn = Pipe() # 产生两个返回对象,一个是管道这一头,一个是另一头 122 | # p = Process(target=f,args=(child_conn,)) 123 | # p.start() 124 | # parent_conn.send('father test') 125 | # print(parent_conn.recv()) 126 | # print(parent_conn.recv()) 127 | # p.join() 128 | 129 | # a = np.array([1,2,3,4,5]) 130 | # b = np.zeros((5,3)) 131 | # c = 0.1 132 | # 133 | # data = [(a,b,c), (a,b,c), (a,b,c)] 134 | # print(data) 135 | # print("") 136 | # print("======================================================================") 137 | # print("") 138 | # import pickle 139 | # 140 | # 141 | # f = open("d.pkl", "wb") 142 | # pickle.dump(data, f) 143 | # f.close() 144 | # 145 | # g = open("d.pkl", "rb") 146 | # d = pickle.load(g) 147 | # g.close() 148 | # print(d) 149 | import multiprocessing 150 | # 151 | # # 声明一个全局变量 152 | # share_var = ["start flag"] 153 | # 154 | # def sub_process(process_name): 155 | # # 企图像单个进程那样通过global声明使用全局变量 156 | # global share_var 157 | # share_var.append(process_name) 158 | # # 但是很可惜,在多进程中这样引用只能读,修改其他进程不会同步改变 159 | # for item in share_var: 160 | # print(f"{process_name}-{item}") 161 | # pass 162 | # 163 | # def main_process(): 164 | # process_list = [] 165 | # # 创建进程1 166 | # process_name = "process 1" 167 | # tmp_process = multiprocessing.Process(target=sub_process,args=(process_name,)) 168 | # process_list.append(tmp_process) 169 | # # 创建进程2 170 | # process_name = "process 2" 171 | # tmp_process = multiprocessing.Process(target=sub_process, args=(process_name,)) 172 | # process_list.append(tmp_process) 173 | # # 启动所有进程 174 | # for process in process_list: 175 | # process.start() 176 | # for process in process_list: 177 | # process.join() 178 | # 179 | # if __name__ == "__main__": 180 | # main_process() 181 | 182 | from multiprocessing import Queue, Process 183 | from concurrent.futures import ProcessPoolExecutor 184 | import time 185 | 186 | from 
multiprocessing import Process, Lock 187 | import json, time, os 188 | 189 | 190 | # def search(): 191 | # time.sleep(1) # 模拟网络io 192 | # with open('db.txt', mode='rt', encoding='utf-8') as f: 193 | # res = json.load(f) 194 | # print(f'还剩{res["count"]}') 195 | # 196 | # 197 | # def get(): 198 | # with open('db.txt', mode='rt', encoding='utf-8') as f: 199 | # res = json.load(f) 200 | # # print(f'还剩{res["count"]}') 201 | # time.sleep(1) # 模拟网络io 202 | # if res['count'] > 0: 203 | # res['count'] -= 1 204 | # with open('db.txt', mode='wt', encoding='utf-8') as f: 205 | # json.dump(res, f) 206 | # print(f'进程{os.getpid()} 抢票成功') 207 | # time.sleep(1.5) # 模拟网络io 208 | # else: 209 | # print('票已经售空啦!!!!!!!!!!!') 210 | # 211 | # 212 | # def task(lock): 213 | # search() 214 | # 215 | # # 锁住 216 | # lock.acquire() 217 | # get() 218 | # lock.release() 219 | # # 释放锁头 220 | # 221 | # 222 | # if __name__ == '__main__': 223 | # lock = Lock() # 写在主进程是为了让子进程拿到同一把锁. 224 | # for i in range(15): 225 | # p = Process(target=task, args=(lock,)) 226 | # p.start() 227 | # class A(object): 228 | # def __init__(self): 229 | # self.num = 0 230 | # 231 | # def my_print(self): 232 | # print(self.num) 233 | # 234 | # def set(self, h): 235 | # self.num = h 236 | # 237 | # 238 | # class M(object): 239 | # def __init__(self, fn): 240 | # self.fn = fn 241 | # 242 | # def zhixing(self): 243 | # self.fn() 244 | # 245 | # 246 | # a = A() 247 | # m = M(a.my_print) 248 | # m.zhixing() 249 | # a.set(20) 250 | # m.zhixing() 251 | 252 | for _ in range(10): 253 | idx = np.random.choice(5, 3) 254 | print(idx) 255 | -------------------------------------------------------------------------------- /tmp/entropy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuoYi0/alphaFive/9492d1083154d13c5ed8c8e3c2a34f0fc1fee2e7/tmp/entropy.jpg -------------------------------------------------------------------------------- /tmp/episode_length.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuoYi0/alphaFive/9492d1083154d13c5ed8c8e3c2a34f0fc1fee2e7/tmp/episode_length.jpg -------------------------------------------------------------------------------- /tmp/five_6960.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuoYi0/alphaFive/9492d1083154d13c5ed8c8e3c2a34f0fc1fee2e7/tmp/five_6960.gif -------------------------------------------------------------------------------- /tmp/total_loss.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuoYi0/alphaFive/9492d1083154d13c5ed8c8e3c2a34f0fc1fee2e7/tmp/total_loss.jpg -------------------------------------------------------------------------------- /tmp/value_loss.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuoYi0/alphaFive/9492d1083154d13c5ed8c8e3c2a34f0fc1fee2e7/tmp/value_loss.jpg -------------------------------------------------------------------------------- /tmp/xentropy_loss.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuoYi0/alphaFive/9492d1083154d13c5ed8c8e3c2a34f0fc1fee2e7/tmp/xentropy_loss.jpg -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -*- 
--------------------------------------------------------------------------------
/tmp/entropy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuoYi0/alphaFive/9492d1083154d13c5ed8c8e3c2a34f0fc1fee2e7/tmp/entropy.jpg
--------------------------------------------------------------------------------
/tmp/episode_length.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuoYi0/alphaFive/9492d1083154d13c5ed8c8e3c2a34f0fc1fee2e7/tmp/episode_length.jpg
--------------------------------------------------------------------------------
/tmp/five_6960.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuoYi0/alphaFive/9492d1083154d13c5ed8c8e3c2a34f0fc1fee2e7/tmp/five_6960.gif
--------------------------------------------------------------------------------
/tmp/total_loss.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuoYi0/alphaFive/9492d1083154d13c5ed8c8e3c2a34f0fc1fee2e7/tmp/total_loss.jpg
--------------------------------------------------------------------------------
/tmp/value_loss.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuoYi0/alphaFive/9492d1083154d13c5ed8c8e3c2a34f0fc1fee2e7/tmp/value_loss.jpg
--------------------------------------------------------------------------------
/tmp/xentropy_loss.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuoYi0/alphaFive/9492d1083154d13c5ed8c8e3c2a34f0fc1fee2e7/tmp/xentropy_loss.jpg
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 | import random
4 | from logging import getLogger
5 | from time import time
6 | import pickle
7 | 
8 | logger = getLogger(__name__)
9 | BLACK_WIN = 1
10 | WHITE_WIN = -1
11 | DRAW = 0
12 | 
13 | 
14 | class RandomStack(object):
15 |     def __init__(self, board_size, length=2000):
16 |         self.data = []  # each element is a (state, policy, last_action, value, weight) tuple
17 |         self.board_size = board_size
18 |         self.length = length
19 |         self.white_win = 0
20 |         self.black_win = 0
21 |         self.data_len = []  # number of positions contributed by each stored game
22 |         self.result = []  # result of each stored game
23 |         self.total_length = 0
24 |         self.num = 0
25 |         self.time = time()
26 |         self.self_play_black_win = 0
27 |         self.self_play_white_win = 0
28 | 
29 |     def save(self, s=""):
30 |         f1 = open(f"data_buffer/data{s}.pkl", "wb")
31 |         pickle.dump(self.data, f1)
32 |         f1.close()
33 | 
34 |         f1 = open(f"data_buffer/data_len{s}.pkl", "wb")
35 |         pickle.dump(self.data_len, f1)
36 |         f1.close()
37 | 
38 |         f1 = open(f"data_buffer/result{s}.pkl", "wb")
39 |         pickle.dump(self.result, f1)
40 |         f1.close()
41 | 
42 |     def load(self, s=""):
43 |         try:
44 |             with open(f"data_buffer/data{s}.pkl", "rb") as f:
45 |                 self.data = pickle.load(f)
46 |             with open(f"data_buffer/data_len{s}.pkl", "rb") as f:
47 |                 self.data_len = pickle.load(f)
48 |             with open(f"data_buffer/result{s}.pkl", "rb") as f:
49 |                 self.result = pickle.load(f)
50 |             self.white_win = self.result.count(WHITE_WIN)
51 |             self.black_win = self.result.count(BLACK_WIN)
52 |             print("load data successfully, with length %d" % len(self.data))
53 |             print("black: white = %d: %d in the memory" % (self.black_win, self.white_win))
54 |         except Exception:
55 |             # no usable data buffer on disk; start from an empty buffer
56 |             print("failed to load data buffer, starting from scratch")
57 |             pass
58 | 
59 |     def isEmpty(self):
60 |         return len(self.data) == 0
61 | 
62 |     def is_full(self):
63 |         return len(self.data) >= self.length
64 | 
65 |     def push(self, data: list, result: int):
66 |         data_len = len(data)  # number of positions in this game
67 |         self.total_length += data_len
68 |         self.num += 1
69 |         if result == BLACK_WIN:
70 |             self.self_play_black_win += 1
71 |         elif result == WHITE_WIN:
72 |             self.self_play_white_win += 1
73 |         if self.total_length >= 100:
74 |             t = time()
75 |             print("black: white = %d: %d in the memory, avg_length: %0.1f avg: %0.3fs per piece" % (
76 |                 self.black_win, self.white_win, self.total_length / self.num, (t - self.time) / self.total_length))
77 |             print("self-play black: %d, white: %d" % (self.self_play_black_win, self.self_play_white_win))
78 |             self.total_length = self.num = 0
79 |             self.time = t
80 |         # drop games that are too short: a 9-move game is discarded with probability 0.75, a 20-move game with probability 0.0, linear in between
81 |         if random.random() <= -0.0682 * data_len + 1.364:
82 |             return False
83 |         self.data.extend(data)  # append the positions
84 |         self.data_len.append(data_len)  # append the game length
85 |         self.result.append(result)  # append the game result
86 |         if result == BLACK_WIN:
87 |             self.black_win += 1
88 |             if random.random() < (self.white_win - self.black_win) / (self.black_win * 1.3):
89 |                 self.data.extend(data)  # duplicate the game to rebalance black/white wins
90 |                 self.data_len.append(data_len)  # append the game length
91 |                 self.result.append(result)  # append the game result
92 |                 self.black_win += 1
93 | 
94 |         elif result == WHITE_WIN:
95 |             self.white_win += 1
96 |             if random.random() < (self.black_win - self.white_win) / (self.white_win * 1.02):
97 |                 self.data.extend(data)  # duplicate the game to rebalance black/white wins
98 |                 self.data_len.append(data_len)  # append the game length
99 |                 self.result.append(result)  # append the game result
100 |                 self.white_win += 1
101 |         beyond = len(self.data) - self.length
102 |         if beyond > 0:
103 |             self.data = self.data[beyond:]
104 |             while True:
105 |                 if beyond >= self.data_len[0]:  # the overflow covers the whole oldest game
106 |                     beyond -= self.data_len[0]
107 |                     self.data_len.pop(0)
108 |                     result = self.result.pop(0)
109 |                     if result == BLACK_WIN:
110 |                         self.black_win -= 1
111 |                     elif result == WHITE_WIN:
112 |                         self.white_win -= 1
113 |                 else:
114 |                     self.data_len[0] -= beyond
115 |                     break
116 |         return True
117 | 
118 |     def get_data(self, batch_size=1):
119 |         num = min(batch_size, len(self.data))
120 |         idx = np.random.choice(len(self.data), size=num, replace=False)
121 |         boards = np.empty((num, 3, self.board_size, self.board_size), dtype=np.float32)
122 |         weights = np.empty((num,), dtype=np.float32)
123 |         values = np.empty((num,), dtype=np.float32)
124 |         policies = np.empty((num, self.board_size, self.board_size), dtype=np.float32)
125 |         for i, ix in enumerate(idx):
126 |             # the board has 8 symmetries (4 rotations x 2 flips), used here for data augmentation
127 |             state, p, la, v, w = self.data[ix]
128 |             board = state_to_board(state, self.board_size)
129 |             k = np.random.choice([0, 1, 2, 3])
130 |             board = np.rot90(board, k=k, axes=(0, 1))
131 |             p = np.rot90(p, k=k, axes=(0, 1))
132 |             if la is not None:
133 |                 la = [la, (self.board_size - 1 - la[1], la[0]),
134 |                       (self.board_size - 1 - la[0], self.board_size - 1 - la[1]),
135 |                       (la[1], self.board_size - 1 - la[0])][k]
136 |             if random.choice([1, 2]) == 1:
137 |                 board = np.flip(board, axis=0)
138 |                 p = np.flip(p, axis=0)
139 |                 if la is not None:
140 |                     la = (self.board_size - 1 - la[0], la[1])
141 |             boards[i] = board_to_inputs(board, last_action=la)
142 |             weights[i] = w
143 |             values[i] = v
144 |             policies[i] = p
145 |         policies = policies.reshape((num, self.board_size * self.board_size))
146 |         return boards, weights, values, policies
147 | 
148 | 
149 | def softmax(x):
150 |     max_value = np.max(x)
151 |     probs = np.exp(x - max_value)
152 |     probs /= np.sum(probs)
153 |     return probs
154 | 
155 | 
156 | def board_to_state(board: np.ndarray) -> str:
157 |     """
158 |     Convert a board array into its string representation.
159 |     :param board: a board
160 |     :return:
161 |     """
162 |     fen = ""
163 |     h, w = board.shape
164 |     for i in range(h):
165 |         c = 0
166 |         for j in range(w):
167 |             if board[i, j] == 0:
168 |                 c += 1
169 |             else:
170 |                 fen += chr(ord('a') + c) if c > 0 else ''
171 |                 fen += str(board[i, j] + 2)
172 |                 c = 0
173 |         fen += chr(ord('a') + c) if c > 0 else ''
174 |         fen += '/'
175 |     return fen
176 | 
177 | 
178 | def state_to_board(state: str, board_size: int):
179 |     """
180 |     Convert a state string back into a board. In the string, a digit encodes a stone (value + 2), a letter encodes a run of empty squares, and '/' ends a row.
181 |     :param state:
182 |     :param board_size:
183 |     :return:
184 |     """
185 |     board = np.zeros((board_size, board_size), np.int8)
186 |     i = j = 0
187 |     for ch in state:
188 |         if ch == '/':
189 |             i += 1
190 |             j = 0
191 |         elif ch.isalpha():
192 |             j += ord(ch) - ord('a')
193 |         else:
194 |             board[i][j] = int(ch) - 2
195 |             j += 1
196 |     return board
197 | 
198 | 
199 | def is_game_over(board: np.ndarray, goal: int) -> tuple:
200 |     """
201 |     Called before the current player moves, to decide whether the position is already over. If it is over and not a draw, the value is normally -1.0,
202 |     because it is the current player's turn yet the game has ended: the last stone was placed by the opponent, so the opponent has won and -1 is returned.
203 |     :param board:
204 |     :param goal: for gomoku (five in a row), goal equals 5
205 |     :return:
206 |     """
207 |     h, w = board.shape
208 |     for i in range(h):
209 |         for j in range(w):
210 |             hang = sum(board[i: min(i + goal, h), j])
211 |             if hang == goal:
212 |                 return True, 1.0
213 |             elif hang == -goal:
214 |                 return True, -1.0
215 |             lie = sum(board[i, j: min(j + goal, w)])
216 |             if lie == goal:
217 |                 return True, 1.0
218 |             elif lie == -goal:
219 |                 return True, -1.0
220 |             # the diagonals are a bit more fiddly
221 |             if i <= h - goal and j <= w - goal:
222 |                 xie = sum([board[i + k, j + k] for k in range(goal)])
223 |                 if xie == goal:
224 |                     return True, 1.0
225 |                 elif xie == -goal:
226 |                     return True, -1.0
227 |             if i >= goal - 1 and j <= w - goal:
228 |                 xie = sum([board[i - k, j + k] for k in range(goal)])
229 |                 if xie == goal:
230 |                     return True, 1.0
231 |                 elif xie == -goal:
232 |                     return True, -1.0
233 |     if np.where(board == 0)[0].shape[0] == 0:  # the board is full: draw
234 |         return True, 0.0
235 |     return False, 0.0
236 | 
237 | 
238 | def get_legal_actions(board: np.ndarray):
239 |     """
240 |     Return all legal moves (empty squares) for the given board.
241 |     :param board:
242 |     :return:
243 |     """
244 |     zeros = np.where(board == 0)
245 |     return [(int(i), int(j)) for i, j in zip(*zeros)]
246 | 
247 | 
248 | def board_to_inputs2(board: np.ndarray, type_=np.float32):
249 |     # return board.astype(np.float32)
250 |     tmp1 = np.equal(board, 1).astype(type_)
251 |     tmp2 = np.equal(board, -1).astype(type_)
252 |     out = np.stack([tmp1, tmp2])
253 |     return out
254 | 
255 | 
256 | def board_to_inputs(board: np.ndarray, type_=np.float32, last_action=None):
257 |     """
258 |     Build the network input from the current board and the previous move.
259 |     The third (last-action) channel could probably be dropped with little effect.
260 |     :param board:
261 |     :param type_:
262 |     :param last_action:
263 |     :return:
264 |     """
265 |     f1 = np.where(board == 1, 1.0, 0.0)
266 |     f2 = np.where(board == -1, 1.0, 0.0)
267 |     # return np.stack([f1, f2], axis=0).astype(type_)
268 |     f3 = np.zeros(shape=board.shape, dtype=np.float32)
269 |     if last_action is not None:
270 |         f3[last_action[0], last_action[1]] = 1.0
271 |     inputs = np.stack([f1, f2, f3], axis=0).astype(type_)
272 |     return inputs
273 | 
274 | 
275 | def step(board: np.ndarray, action: tuple):
276 |     """
277 |     Play the move and flip the board, so that the side to move is always the one encoded as 1.
278 |     :param board:
279 |     :param action:
280 |     :return:
281 |     """
282 |     board[action[0], action[1]] = 1
283 |     return -board
284 | 
285 | 
286 | def construct_weights(length: int, gamma=0.95):
287 |     """
288 |     :param length:
289 |     :param gamma:
290 |     :return:
291 |     """
292 |     w = np.empty((int(length),), np.float32)
293 |     w[length - 1] = 1.0  # the most recent position gets the largest weight
294 |     for i in range(length - 2, -1, -1):
295 |         w[i] = w[i + 1] * gamma
296 |     return length * w / np.sum(w)  # normalize so that the weights sum to length
297 | 
--------------------------------------------------------------------------------
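A quick usage sketch of the helpers in utils.py above (editorial example, run from the repository root; the board size 7 and goal 5 are illustrative values, not necessarily those in config.py):

import numpy as np
import utils

board = np.zeros((7, 7), np.int8)
for col in range(4):
    board = utils.step(board, (3, col))   # step() places a stone for the side encoded as 1, then flips the board
    board = -board                        # flip back so the same side keeps moving in this toy position

s = utils.board_to_state(board)           # run-length string, here "h/h/h/3333d/h/h/h/"
assert np.array_equal(utils.state_to_board(s, 7), board)

board = utils.step(board, (3, 4))         # the fifth stone completes five in a row; the board is flipped again
over, value = utils.is_game_over(board, goal=5)
print(over, value)                        # True -1.0: from the side to move, the opponent has just won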