├── ChessBoard.py ├── ChessGame.py ├── ChessGame_tf2.py ├── ChessPiece.py ├── ChessView.py ├── LICENSE ├── Mastering_Chess_and_Shogi_by_Self-Play_with_a_General_Reinforcement_Learning_Algorithm.ipynb ├── Mastering_the_Game_of_Go_without_Human_Knowledge.ipynb ├── README.md ├── assets ├── a1.png ├── a10.png ├── a11.png ├── a2.png ├── a3.png ├── a4.png ├── a5.png ├── a6.png ├── a7.png ├── a8.png ├── a9.png ├── b1.png ├── b2.png ├── b3.png ├── b4.png ├── b5.png ├── b6.png ├── b7.png ├── b8.png ├── b9.png ├── c1.png ├── c2.png └── c3.png ├── cchess-zero.ipynb ├── chessman ├── Bing.py ├── Bing.pyc ├── Che.py ├── Che.pyc ├── Ma.py ├── Ma.pyc ├── Pao.py ├── Pao.pyc ├── Shi.py ├── Shi.pyc ├── Shuai.py ├── Shuai.pyc ├── Xiang.py ├── Xiang.pyc ├── __init__.py ├── __init__.pyc └── __pycache__ │ ├── Bing.cpython-35.pyc │ ├── Che.cpython-35.pyc │ ├── Ma.cpython-35.pyc │ ├── Pao.cpython-35.pyc │ ├── Shi.cpython-35.pyc │ ├── Shuai.cpython-35.pyc │ ├── Xiang.cpython-35.pyc │ └── __init__.cpython-35.pyc ├── images ├── BA.GIF ├── BAS.GIF ├── BB.GIF ├── BBS.GIF ├── BC.GIF ├── BCS.GIF ├── BK.GIF ├── BKM.GIF ├── BKS.GIF ├── BN.GIF ├── BNS.GIF ├── BP.GIF ├── BPS.GIF ├── BR.GIF ├── BRS.GIF ├── OOS.GIF ├── RA.GIF ├── RAS.GIF ├── RB.GIF ├── RBS.GIF ├── RC.GIF ├── RCS.GIF ├── RK.GIF ├── RKM.GIF ├── RKS.GIF ├── RN.GIF ├── RNS.GIF ├── RP.GIF ├── RPS.GIF ├── RR.GIF ├── RRS.GIF └── WHITE.GIF ├── main.py ├── main_tf2.py ├── policy_value_network.py ├── policy_value_network_gpus.py ├── policy_value_network_gpus_tf2.py └── policy_value_network_tf2.py /ChessBoard.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | from chessman.Bing import * 3 | from chessman.Shuai import * 4 | from chessman.Pao import * 5 | from chessman.Shi import * 6 | from chessman.Xiang import * 7 | from chessman.Ma import * 8 | from chessman.Che import * 9 | 10 | 11 | class ChessBoard: 12 | pieces = dict() 13 | 14 | selected_piece = None 15 | 16 | def __init__(self, north_is_red = True): 17 | # self.north_is_red = north_is_red 18 | # north area 19 | ChessBoard.pieces[4, 0] = Shuai(4, 0, north_is_red, "north") 20 | 21 | ChessBoard.pieces[0, 3] = Bing(0, 3, north_is_red, "north") 22 | ChessBoard.pieces[2, 3] = Bing(2, 3, north_is_red, "north") 23 | ChessBoard.pieces[4, 3] = Bing(4, 3, north_is_red, "north") 24 | ChessBoard.pieces[6, 3] = Bing(6, 3, north_is_red, "north") 25 | ChessBoard.pieces[8, 3] = Bing(8, 3, north_is_red, "north") 26 | 27 | ChessBoard.pieces[1, 2] = Pao(1, 2, north_is_red, "north") 28 | ChessBoard.pieces[7, 2] = Pao(7, 2, north_is_red, "north") 29 | 30 | ChessBoard.pieces[3, 0] = Shi(3, 0, north_is_red, "north") 31 | ChessBoard.pieces[5, 0] = Shi(5, 0, north_is_red, "north") 32 | 33 | ChessBoard.pieces[2, 0] = Xiang(2, 0, north_is_red, "north") 34 | ChessBoard.pieces[6, 0] = Xiang(6, 0, north_is_red, "north") 35 | 36 | ChessBoard.pieces[1, 0] = Ma(1, 0, north_is_red, "north") 37 | ChessBoard.pieces[7, 0] = Ma(7, 0, north_is_red, "north") 38 | 39 | ChessBoard.pieces[0, 0] = Che(0, 0, north_is_red, "north") 40 | ChessBoard.pieces[8, 0] = Che(8, 0, north_is_red, "north") 41 | 42 | # south area 43 | ChessBoard.pieces[4, 9] = Shuai(4, 9, not north_is_red, "south") 44 | 45 | ChessBoard.pieces[0, 6] = Bing(0, 6, not north_is_red, "south") 46 | ChessBoard.pieces[2, 6] = Bing(2, 6, not north_is_red, "south") 47 | ChessBoard.pieces[4, 6] = Bing(4, 6, not north_is_red, "south") 48 | ChessBoard.pieces[6, 6] = Bing(6, 6, not north_is_red, "south") 49 | ChessBoard.pieces[8, 6] = Bing(8, 6, not 
north_is_red, "south") 50 | 51 | ChessBoard.pieces[1, 7] = Pao(1, 7, not north_is_red, "south") 52 | ChessBoard.pieces[7, 7] = Pao(7, 7, not north_is_red, "south") 53 | 54 | ChessBoard.pieces[3, 9] = Shi(3, 9, not north_is_red, "south") 55 | ChessBoard.pieces[5, 9] = Shi(5, 9, not north_is_red, "south") 56 | 57 | ChessBoard.pieces[2, 9] = Xiang(2, 9, not north_is_red, "south") 58 | ChessBoard.pieces[6, 9] = Xiang(6, 9, not north_is_red, "south") 59 | 60 | ChessBoard.pieces[1, 9] = Ma(1, 9, not north_is_red, "south") 61 | ChessBoard.pieces[7, 9] = Ma(7, 9, not north_is_red, "south") 62 | 63 | ChessBoard.pieces[0, 9] = Che(0, 9, not north_is_red, "south") 64 | ChessBoard.pieces[8, 9] = Che(8, 9, not north_is_red, "south") 65 | 66 | def can_move(self, x, y, dx, dy): 67 | return self.pieces[x, y].can_move(self, dx, dy) 68 | 69 | def move(self, x, y, dx, dy): 70 | return self.pieces[x, y].move(self, dx, dy) 71 | 72 | def remove(self, x, y): 73 | del self.pieces[x, y] 74 | 75 | def select(self, x, y, player_is_red): 76 | # 选中棋子 77 | if not self.selected_piece: 78 | if (x, y) in self.pieces and self.pieces[x, y].is_red == player_is_red: 79 | self.pieces[x, y].selected = True 80 | self.selected_piece = self.pieces[x, y] 81 | return False, None 82 | 83 | # 移动棋子 84 | if not (x, y) in self.pieces: 85 | if self.selected_piece: 86 | ox, oy = self.selected_piece.x, self.selected_piece.y 87 | if self.can_move(ox, oy, x-ox, y-oy): 88 | self.move(ox, oy, x-ox, y-oy) 89 | self.pieces[x,y].selected = False 90 | self.selected_piece = None 91 | return True, (ox, oy, x, y) 92 | return False, None 93 | 94 | # 同一个棋子 95 | if self.pieces[x, y].selected: 96 | return False, None 97 | 98 | # 吃子 99 | if self.pieces[x, y].is_red != player_is_red: 100 | ox, oy = self.selected_piece.x, self.selected_piece.y 101 | if self.can_move(ox, oy, x-ox, y-oy): 102 | self.move(ox, oy, x-ox, y-oy) 103 | self.pieces[x,y].selected = False 104 | self.selected_piece = None 105 | return True, (ox, oy, x, y) 106 | return False, None 107 | 108 | # 取消选中 109 | for key in self.pieces.keys(): 110 | self.pieces[key].selected = False 111 | # 选择棋子 112 | self.pieces[x, y].selected = True 113 | self.selected_piece = self.pieces[x,y] 114 | return False, None -------------------------------------------------------------------------------- /ChessGame.py: -------------------------------------------------------------------------------- 1 | from ChessBoard import * 2 | from ChessView import ChessView 3 | from main import * 4 | import tkinter 5 | 6 | def real_coord(x): 7 | if x <= 50: 8 | return 0 9 | else: 10 | return (x-50)//40 + 1 11 | 12 | 13 | def board_coord(x): 14 | return 30 + 40*x 15 | 16 | 17 | class ChessGame: 18 | 19 | board = None #ChessBoard() 20 | cur_round = 1 21 | game_mode = 1 # 0:HUMAN VS HUMAN 1:HUMAN VS AI 2:AI VS AI 22 | time_red = [] 23 | time_green = [] 24 | 25 | def __init__(self, in_ai_count, in_ai_function, in_play_playout, in_delay, in_end_delay, batch_size, search_threads, 26 | processor, num_gpus, res_block_nums, human_color = "b"): 27 | self.human_color = human_color 28 | self.current_player = "w" 29 | self.players = {} 30 | self.players[self.human_color] = "human" 31 | ai_color = "w" if self.human_color == "b" else "b" 32 | self.players[ai_color] = "AI" 33 | 34 | ChessGame.board = ChessBoard(self.human_color == 'b') 35 | self.view = ChessView(self, board=ChessGame.board) 36 | self.view.showMsg("Loading Models...") #"Red" player_color 37 | self.view.draw_board(self.board) 38 | ChessGame.game_mode = in_ai_count 39 | 
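        # Remaining constructor arguments, stored for later use:
        #   in_ai_function  - evaluator passed through to the engine for move selection and hints
        #   in_play_playout - number of MCTS playouts the engine runs per move
        #   in_delay        - pause (seconds) between moves in AI-vs-AI mode
        #   in_end_delay    - pause (seconds) before the window closes when a game ends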
self.ai_function = in_ai_function 40 | self.play_playout = in_play_playout 41 | self.delay = in_delay 42 | self.end_delay = in_end_delay 43 | 44 | self.win_rate = {} 45 | self.win_rate['w'] = 0.0 46 | self.win_rate['b'] = 0.0 47 | 48 | self.view.root.update() 49 | self.cchess_engine = cchess_main(playout=self.play_playout, in_batch_size=batch_size, exploration=False, in_search_threads=search_threads, 50 | processor=processor, num_gpus=num_gpus, res_block_nums=res_block_nums, human_color=human_color) 51 | 52 | def player_is_red(self): 53 | return self.current_player == "w" 54 | 55 | def start(self): 56 | # below added by Fei Li 57 | self.view.showMsg("Red") 58 | if self.game_mode == 1: 59 | print ('-----Round %d-----' % self.cur_round) 60 | if self.players["w"] == "AI": 61 | self.win_rate['w'] = self.perform_AI() 62 | self.view.draw_board(self.board) 63 | self.change_player() 64 | elif self.game_mode == 2: 65 | print('-----Round %d-----' % self.cur_round) 66 | self.win_rate['w'] = self.perform_AI() 67 | self.view.draw_board(self.board) 68 | 69 | self.view.start() 70 | 71 | def disp_mcts_msg(self): 72 | self.view.showMsg("MCTS Searching...") 73 | 74 | def callback(self, event): 75 | if self.game_mode == 1 and self.players[self.current_player] == "AI": 76 | return 77 | if self.game_mode == 2: 78 | return 79 | rx, ry = real_coord(event.x), real_coord(event.y) 80 | # print(rx, ry) 81 | change, coord = self.board.select(rx, ry, self.player_is_red()) 82 | if self.view.print_text_flag == True: 83 | self.view.print_text_flag = False 84 | self.view.can.create_image(0, 0, image=self.view.img, anchor=tkinter.NW) 85 | self.view.draw_board(self.board) 86 | if self.check_end(): 87 | self.view.root.update() 88 | self.quit() 89 | return 90 | if change: 91 | if self.cur_round == 1 and self.human_color == 'w': 92 | self.view.showMsg("MCTS Searching...") 93 | 94 | self.win_rate[self.current_player] = self.cchess_engine.human_move(coord, self.ai_function) 95 | if self.check_end(): 96 | self.view.root.update() 97 | self.quit() 98 | return 99 | performed = self.change_player() 100 | if performed: 101 | self.view.draw_board(self.board) 102 | if self.check_end(): 103 | self.view.root.update() 104 | self.quit() 105 | return 106 | self.change_player() 107 | 108 | 109 | # below added by Fei Li 110 | 111 | def quit(self): 112 | time.sleep(self.end_delay) 113 | self.view.quit() 114 | 115 | def check_end(self): 116 | ret, winner = self.cchess_engine.check_end() 117 | if ret == True: 118 | if winner == "b": 119 | self.view.showMsg('*****Green Wins at Round %d*****' % self.cur_round) 120 | self.view.root.update() 121 | elif winner == "w": 122 | self.view.showMsg('*****Red Wins at Round %d*****' % self.cur_round) 123 | self.view.root.update() 124 | elif winner == "t": 125 | self.view.showMsg('*****Draw at Round %d*****' % self.cur_round) 126 | self.view.root.update() 127 | return ret 128 | 129 | def _check_end(self, board): 130 | red_king = False 131 | green_king = False 132 | pieces = board.pieces 133 | for (x, y) in pieces.keys(): 134 | if pieces[x, y].is_king: 135 | if pieces[x, y].is_red: 136 | red_king = True 137 | else: 138 | green_king = True 139 | if not red_king: 140 | self.view.showMsg('*****Green Wins at Round %d*****' % self.cur_round) 141 | self.view.root.update() 142 | return True 143 | elif not green_king: 144 | self.view.showMsg('*****Red Wins at Round %d*****' % self.cur_round) 145 | self.view.root.update() 146 | return True 147 | elif self.cur_round >= 200: 148 | self.view.showMsg('*****Draw at Round 
%d*****' % self.cur_round) 149 | self.view.root.update() 150 | return True 151 | return False 152 | 153 | def change_player(self): 154 | self.current_player = "w" if self.current_player == "b" else "b" 155 | if self.current_player == "w": 156 | self.cur_round += 1 157 | print ('-----Round %d-----' % self.cur_round) 158 | red_msg = " ({:.4f})".format(self.win_rate['w']) 159 | green_msg = " ({:.4f})".format(self.win_rate['b']) 160 | sorted_move_probs = self.cchess_engine.get_hint(self.ai_function, True, self.disp_mcts_msg) 161 | # print(sorted_move_probs) 162 | self.view.print_all_hint(sorted_move_probs) 163 | # self.move_images.append(tkinter.PhotoImage(file="images/OOS.gif")) 164 | # self.can.create_image(board_coord(x), board_coord(y), image=self.move_images[-1]) 165 | 166 | self.view.showMsg("Red" + red_msg + " Green" + green_msg if self.current_player == "w" else "Green" + green_msg + " Red" + red_msg) 167 | self.view.root.update() 168 | # if self.game_mode == 0: 169 | # return False 170 | if self.game_mode == 1: 171 | if self.players[self.current_player] == "AI": 172 | self.win_rate[self.current_player] = self.perform_AI() 173 | return True 174 | return False 175 | elif self.game_mode == 2: 176 | # if self.current_player == "w": 177 | # self.human_win_rate = self.perform_AI() 178 | # else: 179 | self.win_rate[self.current_player] = self.perform_AI() 180 | return True 181 | return False 182 | 183 | def perform_AI(self): 184 | print ('...AI is calculating...') 185 | START_TIME = time.clock() 186 | move, win_rate = self.cchess_engine.select_move(self.ai_function) 187 | time_used = time.clock() - START_TIME 188 | print ('...Use %fs...' % time_used) 189 | if self.current_player == "w": 190 | self.time_red.append(time_used) 191 | else: 192 | self.time_green.append(time_used) 193 | if move is not None: 194 | self.board.move(move[0], move[1], move[2], move[3]) 195 | return win_rate 196 | 197 | # AI VS AI mode 198 | def game_mode_2(self): 199 | self.change_player() 200 | self.view.draw_board(self.board) 201 | self.view.root.update() 202 | if self.check_end(): 203 | return True 204 | return False 205 | 206 | # game = ChessGame() 207 | # game.start() 208 | -------------------------------------------------------------------------------- /ChessGame_tf2.py: -------------------------------------------------------------------------------- 1 | from ChessBoard import * 2 | from ChessView import ChessView 3 | from main_tf2 import * 4 | import tkinter 5 | 6 | def real_coord(x): 7 | if x <= 50: 8 | return 0 9 | else: 10 | return (x-50)//40 + 1 11 | 12 | 13 | def board_coord(x): 14 | return 30 + 40*x 15 | 16 | 17 | class ChessGame: 18 | 19 | board = None #ChessBoard() 20 | cur_round = 1 21 | game_mode = 1 # 0:HUMAN VS HUMAN 1:HUMAN VS AI 2:AI VS AI 22 | time_red = [] 23 | time_green = [] 24 | 25 | def __init__(self, in_ai_count, in_ai_function, in_play_playout, in_delay, in_end_delay, batch_size, search_threads, 26 | processor, num_gpus, res_block_nums, human_color = "b"): 27 | self.human_color = human_color 28 | self.current_player = "w" 29 | self.players = {} 30 | self.players[self.human_color] = "human" 31 | ai_color = "w" if self.human_color == "b" else "b" 32 | self.players[ai_color] = "AI" 33 | 34 | ChessGame.board = ChessBoard(self.human_color == 'b') 35 | self.view = ChessView(self, board=ChessGame.board) 36 | self.view.showMsg("Loading Models...") #"Red" player_color 37 | self.view.draw_board(self.board) 38 | ChessGame.game_mode = in_ai_count 39 | self.ai_function = in_ai_function 40 | 
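        # in_play_playout: number of MCTS playouts per engine move (forwarded to cchess_main below)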
self.play_playout = in_play_playout 41 | self.delay = in_delay 42 | self.end_delay = in_end_delay 43 | 44 | self.win_rate = {} 45 | self.win_rate['w'] = 0.0 46 | self.win_rate['b'] = 0.0 47 | 48 | self.view.root.update() 49 | self.cchess_engine = cchess_main(playout=self.play_playout, in_batch_size=batch_size, exploration=False, in_search_threads=search_threads, 50 | processor=processor, num_gpus=num_gpus, res_block_nums=res_block_nums, human_color=human_color) 51 | 52 | def player_is_red(self): 53 | return self.current_player == "w" 54 | 55 | def start(self): 56 | # below added by Fei Li 57 | self.view.showMsg("Red") 58 | if self.game_mode == 1: 59 | print ('-----Round %d-----' % self.cur_round) 60 | if self.players["w"] == "AI": 61 | self.win_rate['w'] = self.perform_AI() 62 | self.view.draw_board(self.board) 63 | self.change_player() 64 | elif self.game_mode == 2: 65 | print('-----Round %d-----' % self.cur_round) 66 | self.win_rate['w'] = self.perform_AI() 67 | self.view.draw_board(self.board) 68 | 69 | self.view.start() 70 | 71 | def disp_mcts_msg(self): 72 | self.view.showMsg("MCTS Searching...") 73 | 74 | def callback(self, event): 75 | if self.game_mode == 1 and self.players[self.current_player] == "AI": 76 | return 77 | if self.game_mode == 2: 78 | return 79 | rx, ry = real_coord(event.x), real_coord(event.y) 80 | # print(rx, ry) 81 | change, coord = self.board.select(rx, ry, self.player_is_red()) 82 | if self.view.print_text_flag == True: 83 | self.view.print_text_flag = False 84 | self.view.can.create_image(0, 0, image=self.view.img, anchor=tkinter.NW) 85 | self.view.draw_board(self.board) 86 | if self.check_end(): 87 | self.view.root.update() 88 | self.quit() 89 | return 90 | if change: 91 | if self.cur_round == 1 and self.human_color == 'w': 92 | self.view.showMsg("MCTS Searching...") 93 | 94 | self.win_rate[self.current_player] = self.cchess_engine.human_move(coord, self.ai_function) 95 | if self.check_end(): 96 | self.view.root.update() 97 | self.quit() 98 | return 99 | performed = self.change_player() 100 | if performed: 101 | self.view.draw_board(self.board) 102 | if self.check_end(): 103 | self.view.root.update() 104 | self.quit() 105 | return 106 | self.change_player() 107 | 108 | 109 | # below added by Fei Li 110 | 111 | def quit(self): 112 | time.sleep(self.end_delay) 113 | self.view.quit() 114 | 115 | def check_end(self): 116 | ret, winner = self.cchess_engine.check_end() 117 | if ret == True: 118 | if winner == "b": 119 | self.view.showMsg('*****Green Wins at Round %d*****' % self.cur_round) 120 | self.view.root.update() 121 | elif winner == "w": 122 | self.view.showMsg('*****Red Wins at Round %d*****' % self.cur_round) 123 | self.view.root.update() 124 | elif winner == "t": 125 | self.view.showMsg('*****Draw at Round %d*****' % self.cur_round) 126 | self.view.root.update() 127 | return ret 128 | 129 | def _check_end(self, board): 130 | red_king = False 131 | green_king = False 132 | pieces = board.pieces 133 | for (x, y) in pieces.keys(): 134 | if pieces[x, y].is_king: 135 | if pieces[x, y].is_red: 136 | red_king = True 137 | else: 138 | green_king = True 139 | if not red_king: 140 | self.view.showMsg('*****Green Wins at Round %d*****' % self.cur_round) 141 | self.view.root.update() 142 | return True 143 | elif not green_king: 144 | self.view.showMsg('*****Red Wins at Round %d*****' % self.cur_round) 145 | self.view.root.update() 146 | return True 147 | elif self.cur_round >= 200: 148 | self.view.showMsg('*****Draw at Round %d*****' % self.cur_round) 149 | 
self.view.root.update() 150 | return True 151 | return False 152 | 153 | def change_player(self): 154 | self.current_player = "w" if self.current_player == "b" else "b" 155 | if self.current_player == "w": 156 | self.cur_round += 1 157 | print ('-----Round %d-----' % self.cur_round) 158 | red_msg = " ({:.4f})".format(self.win_rate['w']) 159 | green_msg = " ({:.4f})".format(self.win_rate['b']) 160 | sorted_move_probs = self.cchess_engine.get_hint(self.ai_function, True, self.disp_mcts_msg) 161 | # print(sorted_move_probs) 162 | self.view.print_all_hint(sorted_move_probs) 163 | # self.move_images.append(tkinter.PhotoImage(file="images/OOS.gif")) 164 | # self.can.create_image(board_coord(x), board_coord(y), image=self.move_images[-1]) 165 | 166 | self.view.showMsg("Red" + red_msg + " Green" + green_msg if self.current_player == "w" else "Green" + green_msg + " Red" + red_msg) 167 | self.view.root.update() 168 | # if self.game_mode == 0: 169 | # return False 170 | if self.game_mode == 1: 171 | if self.players[self.current_player] == "AI": 172 | self.win_rate[self.current_player] = self.perform_AI() 173 | return True 174 | return False 175 | elif self.game_mode == 2: 176 | # if self.current_player == "w": 177 | # self.human_win_rate = self.perform_AI() 178 | # else: 179 | self.win_rate[self.current_player] = self.perform_AI() 180 | return True 181 | return False 182 | 183 | def perform_AI(self): 184 | print ('...AI is calculating...') 185 | START_TIME = time.clock() 186 | move, win_rate = self.cchess_engine.select_move(self.ai_function) 187 | time_used = time.clock() - START_TIME 188 | print ('...Use %fs...' % time_used) 189 | if self.current_player == "w": 190 | self.time_red.append(time_used) 191 | else: 192 | self.time_green.append(time_used) 193 | if move is not None: 194 | self.board.move(move[0], move[1], move[2], move[3]) 195 | return win_rate 196 | 197 | # AI VS AI mode 198 | def game_mode_2(self): 199 | self.change_player() 200 | self.view.draw_board(self.board) 201 | self.view.root.update() 202 | if self.check_end(): 203 | return True 204 | return False 205 | 206 | # game = ChessGame() 207 | # game.start() 208 | -------------------------------------------------------------------------------- /ChessPiece.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class ChessPiece: 4 | 5 | selected = False 6 | is_king = False 7 | def __init__(self, x, y, is_red, direction): 8 | self.x = x 9 | self.y = y 10 | self.is_red = is_red 11 | self.direction = direction 12 | 13 | def is_north(self): 14 | return self.direction == 'north' 15 | 16 | def is_south(self): 17 | return self.direction == 'south' 18 | 19 | def get_move_locs(self, board): 20 | moves = [] 21 | for x in range(9): 22 | for y in range(10): 23 | if (x,y) in board.pieces and board.pieces[x,y].is_red == self.is_red: 24 | continue 25 | if self.can_move(board, x-self.x, y-self.y): 26 | moves.append((x,y)) 27 | return moves 28 | def move(self, board, dx, dy): 29 | nx, ny = self.x + dx, self.y + dy 30 | if (nx, ny) in board.pieces: 31 | board.remove(nx, ny) 32 | board.remove(self.x, self.y) 33 | #print 'Move a chessman from (%d,%d) to (%d,%d)'%(self.x, self.y, self.x+dx, self.y+dy) 34 | self.x += dx 35 | self.y += dy 36 | board.pieces[self.x, self.y] = self 37 | return True 38 | 39 | def count_pieces(self, board, x, y, dx, dy): 40 | sx = dx/abs(dx) if dx!=0 else 0 41 | sy = dy/abs(dy) if dy!=0 else 0 42 | nx, ny = x + dx, y + dy 43 | x, y = x + sx, y + sy 44 | cnt = 0 45 | while x != nx or y != ny: 
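            # Step one square at a time toward (nx, ny), counting every occupied
            # square strictly between the start and end points (presumably used by
            # sliding pieces such as Che and Pao to detect blocked paths).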
46 | if (x, y) in board.pieces: 47 | cnt += 1 48 | x += sx 49 | y += sy 50 | return cnt 51 | -------------------------------------------------------------------------------- /ChessView.py: -------------------------------------------------------------------------------- 1 | import tkinter 2 | import time 3 | 4 | def board_coord(x): 5 | return 30 + 40*x 6 | 7 | class ChessView: 8 | root = tkinter.Tk() 9 | root.title("Chinese Chess") 10 | root.resizable(0, 0) 11 | can = tkinter.Canvas(root, width=373, height=410) 12 | can.pack(expand=tkinter.YES, fill=tkinter.BOTH) 13 | img = tkinter.PhotoImage(file="images/WHITE.gif") 14 | can.create_image(0, 0, image=img, anchor=tkinter.NW) 15 | piece_images = dict() 16 | move_images = [] 17 | def draw_board(self, board): 18 | self.piece_images.clear() 19 | self.move_images = [] 20 | pieces = board.pieces 21 | for (x, y) in pieces.keys(): 22 | self.piece_images[x, y] = tkinter.PhotoImage(file=pieces[x, y].get_image_file_name()) 23 | self.can.create_image(board_coord(x), board_coord(y), image=self.piece_images[x, y]) 24 | if board.selected_piece: 25 | for (x, y) in board.selected_piece.get_move_locs(board): 26 | self.move_images.append(tkinter.PhotoImage(file="images/OOS.gif")) 27 | self.can.create_image(board_coord(x), board_coord(y), image=self.move_images[-1]) 28 | # self.can.create_text(board_coord(x), board_coord(y),text="Hello") 29 | 30 | # label = tkinter.Label(self.root, text='Hello world!') 31 | # label.place(x=30,y=30) 32 | # label.pack(fill='x', expand=1) 33 | 34 | def disp_hint_on_board(self, action, percentage): 35 | board = self.board 36 | for key in board.pieces.keys(): 37 | board.pieces[key].selected = False 38 | board.selected_piece = None 39 | 40 | self.can.create_image(0, 0, image=self.img, anchor=tkinter.NW) 41 | self.draw_board(board) 42 | # self.can.create_text(board_coord(self.last_text_x), board_coord(self.last_text_y), text="") 43 | x_trans = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, 'f': 5, 'g': 6, 'h': 7, 'i': 8} 44 | 45 | src = action[0:2] 46 | dst = action[2:4] 47 | 48 | src_x = int(x_trans[src[0]]) 49 | src_y = int(src[1]) 50 | 51 | dst_x = int(x_trans[dst[0]]) 52 | dst_y = int(dst[1]) 53 | 54 | pieces = board.pieces 55 | if (src_x, src_y) in pieces.keys(): 56 | self.piece_images[src_x, src_y] = tkinter.PhotoImage(file=pieces[src_x, src_y].get_selected_image()) 57 | self.can.create_image(board_coord(src_x), board_coord(src_y), image=self.piece_images[src_x, src_y]) 58 | 59 | if (dst_x, dst_y) in pieces.keys(): 60 | self.piece_images[dst_x, dst_y] = tkinter.PhotoImage(file=pieces[dst_x, dst_y].get_selected_image()) 61 | self.can.create_image(board_coord(dst_x), board_coord(dst_y), image=self.piece_images[dst_x, dst_y]) 62 | self.can.create_text(board_coord(dst_x), board_coord(dst_y), text="{:.3f}".format(percentage)) 63 | self.last_text_x = dst_x 64 | self.last_text_y = dst_y 65 | else: 66 | self.move_images.append(tkinter.PhotoImage(file="images/OOS.gif")) 67 | self.can.create_image(board_coord(dst_x), board_coord(dst_y), image=self.move_images[-1]) 68 | self.can.create_text(board_coord(dst_x), board_coord(dst_y),text="{:.3f}".format(percentage)) 69 | self.last_text_x = dst_x 70 | self.last_text_y = dst_y 71 | self.print_text_flag = True 72 | # return (src_x, src_y, dst_x - src_x, dst_y - src_y), win_rate 73 | 74 | def print_all_hint(self, sorted_move_probs): 75 | 76 | # for i in range(len(sorted_move_probs)): 77 | # self.lb.insert(END, str(i * 100)) 78 | 79 | self.lb.delete(0, "end") 80 | for item in sorted_move_probs: 81 | # 
print(item[0], item[1]) 82 | self.lb.insert("end", item) 83 | self.lb.pack() 84 | 85 | def showMsg(self, msg): 86 | print(msg) 87 | self.root.title(msg) 88 | 89 | def printList(self, event): 90 | # print(self.lb.curselection()) 91 | # print(self.lb.get(self.lb.curselection())) 92 | # for i in range(self.lb.size()): 93 | # print(i, self.lb.selection_includes(i)) 94 | w = event.widget 95 | index = int(w.curselection()[0]) 96 | value = w.get(index) 97 | print(value) 98 | self.disp_hint_on_board(value[0], value[1]) 99 | 100 | 101 | def __init__(self, control, board): 102 | self.control = control 103 | if self.control.game_mode != 2: 104 | self.can.bind('', self.control.callback) 105 | 106 | self.lb = tkinter.Listbox(ChessView.root,selectmode="browse") 107 | self.scr1 = tkinter.Scrollbar(ChessView.root) 108 | self.lb.configure(yscrollcommand=self.scr1.set) 109 | self.scr1['command'] = self.lb.yview 110 | self.scr1.pack(side='right',fill="y") 111 | self.lb.pack(fill="x") 112 | 113 | self.lb.bind('<>', self.printList) # Double- 114 | self.board = board 115 | self.last_text_x = 0 116 | self.last_text_y = 0 117 | self.print_text_flag = False 118 | 119 | # def start(self): 120 | # tkinter.mainloop() 121 | def start(self): 122 | if self.control.game_mode == 2: 123 | self.root.update() 124 | time.sleep(self.control.delay) 125 | while True: 126 | game_end = self.control.game_mode_2() 127 | self.root.update() 128 | time.sleep(self.control.delay) 129 | if game_end: 130 | time.sleep(self.control.end_delay) 131 | self.quit() 132 | return 133 | else: 134 | tkinter.mainloop() 135 | # self.root.mainloop() 136 | 137 | # below added by Fei Li 138 | 139 | def quit(self): 140 | self.root.quit() 141 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Mastering_Chess_and_Shogi_by_Self-Play_with_a_General_Reinforcement_Learning_Algorithm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 用通用强化学习算法自我对弈,掌握国际象棋和将棋" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "[`程世东`](http://zhihu.com/people/cheng-shi-dong-47) 翻译\n", 15 | "\n", 16 | "[`GitHub`](http://github.com/chengstone) [`Mail`](mailto:69558140@163.com)" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": {}, 22 | "source": [ 23 | "国际象棋是人工智能史上研究最为广泛的领域。最强大的象棋程序是基于复杂的搜索技术、适应于特定领域、和过去几十年里人类专家手工提炼的评估函数的结合。相比之下,通过自我对弈进行“白板”强化学习,在围棋游戏中AlphaGo Zero取得了超越人类的成绩。在本文中,我们将这种方法推广到一个单一的AlphaZero算法中,从“白板”开始学习,可以在许多具有挑战性的领域具有超越人类的表现。从随机下棋开始,除了游戏规则之外没有给予任何领域知识,AlphaZero在24小时内实现了在国际象棋、将棋(日本象棋)和围棋上的超人类水平,并击败了每一个世界冠军程序。" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "对计算机象棋的研究和计算机科学本身一样历史悠久。巴贝奇,图灵,香农和冯诺依曼都设计过计算机硬件、算法和理论来分析和指导下棋。国际象棋随后成为一代人工智能研究人员的挑战性任务,最终以高性能的超越人类水平的计算机国际象棋程序的出现而告终。然而,这些系统高度适应与它们的特定领域,不投入大量的人力是不能推广到其他问题的。" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "人工智能的长期目标是创造出可以从最初规则中自我学习的程序。最近,AlphaGo Zero算法通过使用深度卷积神经网络来表达围棋知识,只通过自我对弈的强化学习来训练,在围棋中实现了超人的表现。在本文中,我们应用了一个类似但完全通用的算法,我们将该算法[arXiv:1712.01815v1 [cs.AI] 5 Dec 2017]称之为AlphaZero,用来像围棋一样下国际象棋和将棋,除了游戏的规则外没有给予任何额外的领域知识,这个算法表明通用强化学习算法可以实现以“白板”方式学习,在多个具有挑战性的领域获得超人的表现。" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "1997年,“深蓝”击败了国际象棋人类世界冠军,这是人工智能的一个里程碑。计算机国际象棋程序在那以后的二十多年继续稳步超越人类水平。这些程序使用人类专家的知识和精心调校的参数评估[`走子`](https://baike.baidu.com/item/%E8%B5%B0%E5%AD%90/97274)位置,结合高性能的alpha-beta搜索,使用大量启发式和领域特定的适应性来扩展巨大的搜索树。在[`方法`](#方法)一节我们描述这些增强方法,重点关注2016年顶级国际象棋引擎锦标赛(TCEC)世界冠军Stockfish,其他强大的国际象棋程序,包括深蓝,使用的是非常相似的架构。" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "在计算复杂性方面,相比国际象棋,将棋更难:它是在一个更大的棋盘上玩,任何被俘获的对手棋子都会改变方向,随后可能被放置在棋盘上的任何位置。最强大的将棋程序,如电脑将棋协会(CSA)的世界冠军Elmo,直到最近才击败人类冠军。这些程序使用与计算机国际象棋程序类似的算法,基于高度优化的alpha-beta搜索引擎,具有许多特定领域的适应性。" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "metadata": {}, 57 | "source": [ 58 | "围棋非常适合AlphaGo中使用的神经网络架构,因为游戏规则是平移不变的(匹配卷积网络的权值共享结构),是根据棋盘上的走子点之间的相邻点的自由度来定义的(匹配卷积网络的局部结构),并且是旋转和反射对称的(允许数据增强和合成)。而且,动作空间很简单(一颗棋子可以放在任何可能的位置),游戏结果只有二元结果赢或输,这两者都有助于神经网络的训练。" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "国际象棋和将棋不太适合AlphaGo的神经网络架构。这些规则是与位置有关的(例如,兵可以从第二横线前进两步,在第八横线上升变)和不对称的(例如,兵只能向前移动,在王翼和后翼的王车易位是不同的)。规则包括远程交互(例如皇后可以一次穿过整个棋盘,或者从棋盘的另一边将军)。国际象棋的行动空间包括棋盘上所有棋手棋子的所有符合规则的位置;将棋允许将被吃掉的棋子放回棋盘上。国际象棋和将棋都可能造成平局;事实上,人们认为国际象棋最佳的解决方案是平局。" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": {}, 71 | "source": [ 72 | "AlphaZero算法是AlphaGo Zero算法的更通用的版本。它用深度神经网络和白板强化学习算法,替代传统程序中使用的人工先验知识和特定领域增强。" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "metadata": {}, 78 | "source": [ 79 | "为取代手工制作的评估函数和启发式移动排序,AlphaZero使用参数为θ的深度神经网络$(p,v)= f_θ (s)$。这个神经网络使用局面(棋盘状态)s作为输入,输出走子概率向量p,它包含每一个走子动作a的概率分量$p_a = Pr(a|s)$,同时输出一个标量值v(胜率)——从局面s估算预期结果$z,v ≈ Е[z|s]$。AlphaZero完全从自我对弈中学习这些走子概率和价值估计;然后将学到的知识指导其搜索。" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": {}, 85 | "source": [ 86 | 
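    "An illustrative aside (not part of the original paper): the preceding paragraph defines the network interface $(p,v)= f_θ (s)$. Below is a minimal, hypothetical Keras sketch of such a two-headed policy/value network; the board size, plane count and move-label count are placeholders, and this repository's own implementation (see `policy_value_network*.py`) differs in detail.\n",
    "\n",
    "```python\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import layers\n",
    "\n",
    "def build_policy_value_net(height=10, width=9, in_planes=14, n_moves=2086):\n",
    "    # input s: stacked binary feature planes describing the position (channels-last);\n",
    "    # the default sizes above are illustrative placeholders only\n",
    "    s = layers.Input(shape=(height, width, in_planes))\n",
    "    x = layers.Conv2D(128, 3, padding='same')(s)\n",
    "    x = layers.ReLU()(layers.BatchNormalization()(x))\n",
    "    # policy head: a probability p_a = Pr(a|s) for every candidate move a\n",
    "    p = layers.Flatten()(layers.Conv2D(2, 1)(x))\n",
    "    p = layers.Dense(n_moves, activation='softmax', name='p')(p)\n",
    "    # value head: a scalar v in [-1, 1] estimating the expected outcome from s\n",
    "    v = layers.Flatten()(layers.Conv2D(1, 1)(x))\n",
    "    v = layers.Dense(64, activation='relu')(v)\n",
    "    v = layers.Dense(1, activation='tanh', name='v')(v)\n",
    "    return tf.keras.Model(inputs=s, outputs=[p, v])\n",
    "```\n",
    "\n",
    "Such a model could be trained with a mean-squared-error term on v and a cross-entropy term on p plus L2 weight regularization, mirroring the combined loss described below.\n",
    "\n",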
"为取代具有特定领域增强的alpha-beta搜索,AlphaZero使用通用的蒙特卡洛树搜索(MCTS)算法。每次搜索都包含一系列从根节点$s_{root}$到叶子节点遍历树的自我对弈模拟。每次模拟都是通过在每个状态s下,根据当前的神经网络$f_θ$,选择一个访问次数低、走子概率高和价值高的走子走法a(这些值是从状态s中选择的动作a的叶子节点状态上做平均)。搜索返回一个表示走子概率分布的向量π ,是在根节点状态下关于访问计数的概率分布(无论是按比例还是贪婪算法)。" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": {}, 92 | "source": [ 93 | "AlphaZero深度神经网络的参数θ,从随机初始化参数开始,通过自我对弈强化学习进行训练。通过MCTS($a_t$ ∼ $π_t$ )轮流为两个棋手选择走子进行下棋。在棋局结束时,根据游戏规则计算游戏结果z 作为结束位置$s_T$的评分:-1代表失败,0代表平局,+1代表胜利。更新神经网络参数θ以使预测结果$v_t$与游戏结果z之间的误差最小,并且使策略向量$p_t$与搜索概率$π_t$的相似度最大。具体而言,参数θ通过在均方误差和交叉熵损失之和上的损失函数l上做梯度下降进行调整," 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": {}, 99 | "source": [ 100 | "$(p,v)= f_θ (s), l = (z - v)^2- π^T log p + c||θ||^2$" 101 | ] 102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "metadata": {}, 106 | "source": [ 107 | "其中c是控制L2正则化水平的参数。更新的参数被用于随后的自我对弈中。" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "metadata": {}, 113 | "source": [ 114 | "本文描述的AlphaZero算法在几个方面与原始的AlphaGo Zero算法不同。" 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "metadata": {}, 120 | "source": [ 121 | "AlphaGo Zero在假设只有赢或输二元结果的情况下,对获胜概率进行估计和优化。AlphaZero会考虑平局或潜在的其他结果,对预期的结果进行估算和优化。" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": {}, 127 | "source": [ 128 | "围棋的规则是旋转和反转不变的。对此,在AlphaGo和AlphaGo Zero中有两种使用方式。首先,训练数据通过为每个局面生成8个对称图像来增强。其次,MCTS期间,棋盘位置在被神经网络评估前,会使用随机选择的旋转或反转变换进行转换,以便蒙特卡洛评估在不同的偏差上进行平均。国际象棋和将棋的规则是不对称的。AlphaZero不会增强训练数据,也不会在MCTS期间转换棋盘位置。" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "metadata": {}, 134 | "source": [ 135 | "在AlphaGo Zero中,自我对弈是由以前所有迭代中最好的玩家生成的。每次训练迭代之后,与最好玩家对弈测量新玩家的能力;如果以55%的优势获胜,那么它将取代最好的玩家,而自我对弈将由这个新玩家产生。相反,AlphaZero只维护一个不断更新的单个神经网络,而不是等待迭代完成。自我对弈是通过使用这个神经网络的最新参数生成的,省略了评估步骤和选择最佳玩家的过程。" 136 | ] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "metadata": {}, 141 | "source": [ 142 | "AlphaGo Zero通过贝叶斯优化调整搜索的超参数。在AlphaZero中,我们为所有棋局重复使用相同的超参数,而无需进行特定于某种游戏的调整。唯一的例外是为保证探索而添加到先验策略中的噪声;这与棋局类型的典型合法走子的数量成比例。" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": {}, 148 | "source": [ 149 | "像AlphaGo Zero一样,棋盘状态仅由基于每个游戏的基本规则的空间平面编码。下棋的行动是由空间平面或平面向量编码的,而且仅仅基于每种游戏的基本规则(参见[`方法`](#方法))。" 150 | ] 151 | }, 152 | { 153 | "cell_type": "markdown", 154 | "metadata": {}, 155 | "source": [ 156 | "我们将AlphaZero算法应用于国际象棋,将棋,还有围棋。除非另有说明,所有三个游戏都使用相同的算法,网络架构和超参数。我们为每一种棋单独训练了一个AlphaZero。从随机初始化的参数开始,使用5,000个第一代TPU生成自我对弈数据和64个第二代TPU来训练神经网络,训练进行了700,000步(mini-batches 大小是4096)。 [`方法`](#方法)中提供了训练步骤的更多细节。" 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "metadata": {}, 162 | "source": [ 163 | "图1显示了AlphaZero在自我对弈强化学习期间的表现。在国际象棋中,AlphaZero仅仅用了4小时(300k步)就胜过了Stockfish;在将棋中,AlphaZero在不到2小时(110K步)就胜过了Elmo;而在围棋中,AlphaZero 8小时(165k步)就胜过了AlphaGo Lee。" 164 | ] 165 | }, 166 | { 167 | "cell_type": "markdown", 168 | "metadata": {}, 169 | "source": [ 170 | "\n", 171 | "![b1](assets\\b1.png\")\n", 172 | "图1:训练AlphaZero 70万步。国际等级分是在不同的玩家之间的比赛进行评估计算出来的,每一步棋有1秒的思考时间。a国际象棋中AlphaZero的表现,与2016年TCEC世界冠军程序Stockfish比较。b在将棋中AlphaZero的表现,与2017年CSA世界冠军程序Elmo比较。c 在围棋中AlphaZero的表现,与AlphaGo Lee和AlphaGo Zero(20 block / 3天)比较。" 173 | ] 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "metadata": {}, 178 | "source": [ 179 | "我们使用所有训练好的AlphaZero,分别在国际象棋、将棋和围棋中与Stockfish, Elmo和上一个版本的AlphaGo Zero(训练3天)进行了100场比赛,时间控制在每步棋1分钟。AlphaZero和之前的AlphaGo Zero使用一台带有4个TPU的机器。Elmo和Stockfish使用他们最强的版本,使用64个线程和1GB hash。AlphaZero击败了所有的对手,对Stockfish 零封对手,对 Elmo输了8局(见几个棋局的补充材料),以及击败以前版本的AlphaGo Zero(见表1)。" 180 | ] 181 | }, 182 | { 183 | "cell_type": "markdown", 184 | 
"metadata": {}, 185 | "source": [ 186 | "\n", 187 | "![b2](assets\\b2.png\")\n", 188 | "表1: 在国际象棋,将棋和围棋中评估AlphaZero,以AlphaZero的角度的胜平负,与Stockfish, Elmo,和训练了三天的AlphaGo Zero进行100场比赛。每个程序下一步棋有1分钟的思考时间。" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "metadata": {}, 194 | "source": [ 195 | "我们还分析了AlphaZero的MCTS搜索的表现,与Stockfish和Elmo使用的alpha-beta搜索引擎进行比较。AlphaZero在国际象棋中每秒只搜索8万个局面(positions),在将棋中搜索4万个,相比之下,Stockfish要搜索7000万个,Elmo搜索3500万个。AlphaZero通过使用其深层神经网络更有选择性地关注最有希望的[`变着`](https://baike.baidu.com/item/%E5%9B%BD%E9%99%85%E8%B1%A1%E6%A3%8B%E6%9C%AF%E8%AF%AD/7549734?fr=aladdin),补偿较低数量的评估- 可以说是像Shannon最初提出的那样,是一种更“人性化”的搜索方法。图2显示了每个玩家的思考时间,以国际等级分衡量,相对于Stockfish或者Elmo,思考时间为40ms。AlphaZero的MCTS的思考时间比Stockfish或Elmo更有效,这使得人们对普遍持有的观点认为alpha-beta搜索在这些领域本质上是优越的产生了质疑。" 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": {}, 201 | "source": [ 202 | "最后,我们分析了AlphaZero发现的国际象棋知识。表2分析了最常见的人类开局(在人类国际象棋游戏的在线数据库中出现过超过10万次)。在自我对弈训练期间,AlphaZero独立地发现和使用了这些开局。从每个人类的开局开始,AlphaZero击败了Stockfish,表明它确实掌握了广泛的国际象棋玩法。" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "metadata": {}, 208 | "source": [ 209 | "国际象棋代表了过去几十年人工智能研究的巅峰。最先进的象棋程序基于强大的引擎,搜索数以百万计的局面,利用领域的专业知识和复杂的领域适应性。AlphaZero是一个通用的强化学习算法 - 最初为围棋而设计 - 在几个小时内取得了优异的成绩,搜索次数减少了1000倍,除了国际象棋规则之外不需要任何领域知识。此外,同样的算法不经修改也适用于更具挑战性的将棋游戏,在几小时内再次超越了当前最先进的水平。" 210 | ] 211 | }, 212 | { 213 | "cell_type": "markdown", 214 | "metadata": {}, 215 | "source": [ 216 | "![b3](assets\\b3.png\")" 217 | ] 218 | }, 219 | { 220 | "cell_type": "markdown", 221 | "metadata": {}, 222 | "source": [ 223 | "![b4](assets\\b4.png\")\n", 224 | "表2:分析12个最受欢迎的人类开局(在线数据库中出现超过10万次)。每个开局标有其ECO代码和通用名称。该图显示了AlphaZero在自我训练比赛时使用的每次开局的比例。我们还从AlphaZero的角度报告了从每个开局开始与Stockfish 100场比赛的胜负/平局/失败结果,无论是白色(W)还是黑色(B)。最后,从每个开局提供AlphaZero的主要变着(PV)。" 225 | ] 226 | }, 227 | { 228 | "cell_type": "markdown", 229 | "metadata": {}, 230 | "source": [ 231 | "![b5](assets\\b5.png\")\n", 232 | "图2:关于AlphaZero思考时间的可扩展性,以国际等级分衡量。a在国际象棋中的AlphaZero和Stockfish的表现,描画每一步的思考时间。b在将棋中AlphaZero和Elmo的表现,描画每一步的思考时间。" 233 | ] 234 | }, 235 | { 236 | "cell_type": "markdown", 237 | "metadata": {}, 238 | "source": [ 239 | "# 方法" 240 | ] 241 | }, 242 | { 243 | "cell_type": "markdown", 244 | "metadata": {}, 245 | "source": [ 246 | "## 计算机国际象棋程序剖析" 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "metadata": {}, 252 | "source": [ 253 | "在本节中,我们将描述一个典型的计算机国际象棋程序的组件,特别关注Stockfish,这是一个赢得2016年TCEC电脑国际象棋锦标赛的开源程序。" 254 | ] 255 | }, 256 | { 257 | "cell_type": "markdown", 258 | "metadata": {}, 259 | "source": [ 260 | "每个局面s由手工特征φ(s)的稀疏向量描述,包括特定中局/残局[`子力`](http://chessprogramming.wikispaces.com/material)(译者注:由每条线上棋子价值的总和确定的一个术语。所有的棋子和兵。[`子力优势`](https://baike.baidu.com/item/%E5%9B%BD%E9%99%85%E8%B1%A1%E6%A3%8B%E6%9C%AF%E8%AF%AD/7549734?fr=aladdin)是棋手在棋盘上有比对手更多的棋子或棋子的价值更大。)点的价值,[`子力不平衡表`](http://chessprogramming.wikispaces.com/Material+Tables)(译者注:例如 车vs两个[`轻子`](https://baike.baidu.com/item/%E8%BD%BB%E5%AD%90)(Minor 
pieces:象和马),皇后vs两个车或三个轻子,三个兵vs普通棋子),[`Piece-Square表`](http://chessprogramming.wikispaces.com/Piece-Square+Tables)(译者注:给特定位置上的特定棋子分配一个值),[`机动性`](http://chessprogramming.wikispaces.com/Mobility)(译者注:衡量一个玩家在一个给定的位置上合法移动的选择数量,[`棋子的行动自由`](https://baike.baidu.com/item/%E5%9B%BD%E9%99%85%E8%B1%A1%E6%A3%8B%E6%9C%AF%E8%AF%AD/7549734?fr=aladdin)。)和[`被困棋子`](http://chessprogramming.wikispaces.com/Trapped+Pieces)(译者注:被困棋子是移动性差的极端例子),[`兵型`](http://chessprogramming.wikispaces.com/Pawn+Structure)(译者注:用来描述棋盘上所有兵的位置,忽略所有其他棋子。也指兵骨架。所有兵的位置的各个方面。),[`国王安全性`](http://chessprogramming.wikispaces.com/King+Safety),[`前哨`](http://chessprogramming.wikispaces.com/Outposts)(译者注:通常与马在棋盘中心或敌方一侧有关的国际象棋术语,被自己的棋子保护不再受对手棋子的攻击,或者在半开放线削弱对手的棋子,不再徒劳无功),[`双象`](https://en.wikipedia.org/wiki/Glossary_of_chess#Bishop_pair)(译者注:棋手是否有两个象),和其他复杂的评估 模型。通过手动和自动调整的组合,每个特征$φ_i$被分配相应的权重$w_i$,并且通过线性组合$v(s,w)=φ(s)^T w$来评估局面。然而,对于安全的位置,这个原始评估仅被认为是准确的,不包括未解决的[`吃子`](http://chessprogramming.wikispaces.com/Captures)和[`将军`](http://chessprogramming.wikispaces.com/Check)。在应用评估函数之前,使用领域专用的[`静止搜索`](http://chessprogramming.wikispaces.com/Quiescence+Search)来解决正在进行的战术局势。" 261 | ] 262 | }, 263 | { 264 | "cell_type": "markdown", 265 | "metadata": {}, 266 | "source": [ 267 | "局面s的最终评估是通过使用静止搜索评估每个叶子的极小极大搜索来计算的。alpha-beta剪枝被用来安全地剪切任何可能被另一个变着控制的分支。额外的剪切是使用愿望窗口和主要变着搜索实现的。其他剪枝策略包括无效走子修剪(假定走子以后结果比任何变着还要差),徒劳修剪(假设知道评估中可能的最大变着),和其他依赖于领域的修剪规则(假设知道被吃棋子的价值)。" 268 | ] 269 | }, 270 | { 271 | "cell_type": "markdown", 272 | "metadata": {}, 273 | "source": [ 274 | "搜索的重点是关注有希望的变着,通过扩展有希望的变着的搜索深度,并通过基于历史,静态交换评估(SEE)和移动的棋子类型等启发式技术减少没有希望的变着的搜索深度。扩展是基于独立于领域的规则,这些规则用于识别单一的走子,没有合适的选择余地,以及依赖于领域的规则,比如扩展检查走子。减少(译者注:搜索深度),如后期走子减少,主要依赖于领域知识。" 275 | ] 276 | }, 277 | { 278 | "cell_type": "markdown", 279 | "metadata": {}, 280 | "source": [ 281 | "alpha-beta搜索的效率主要取决于下棋走子的顺序。因此,走子是通过迭代加深来排序的(使用更浅的搜索命令移动以进行更深入的搜索)。此外,结合了与领域无关的启发式走子排序,如杀手启发式,历史启发式,相反走子启发式,以及基于捕获(SEE)和潜在捕获(MVV / LVA)的领域相关知识。" 282 | ] 283 | }, 284 | { 285 | "cell_type": "markdown", 286 | "metadata": {}, 287 | "source": [ 288 | "[`换位表`](http://chessprogramming.wikispaces.com/Transposition+Table)(译者注:是存储先前执行的搜索的结果的数据库)便于重复使用在多个路径达到相同位置时的下棋顺序和值。经过仔细调整的开局库用于在棋局开始时选择走子。通过对残局位置的彻底逆向分析预先设计的残局库,在六个、有时七个或更少的所有位置提供最佳的走子。" 289 | ] 290 | }, 291 | { 292 | "cell_type": "markdown", 293 | "metadata": {}, 294 | "source": [ 295 | "其他强大的国际象棋程序,以及像“深蓝”这样的早期程序,都使用了非常类似的架构,包括上述的大部分组件,虽然重要的细节差别很大。" 296 | ] 297 | }, 298 | { 299 | "cell_type": "markdown", 300 | "metadata": {}, 301 | "source": [ 302 | "AlphaZero不使用本节中描述的技术。这些技术中的一些可能会进一步提高AlphaZero的性能;然而,我们专注于纯粹的自我对弈强化学习方法,并将这些扩展留给未来研究。" 303 | ] 304 | }, 305 | { 306 | "cell_type": "markdown", 307 | "metadata": {}, 308 | "source": [ 309 | "## 计算机国际象棋和将棋上的早期工作" 310 | ] 311 | }, 312 | { 313 | "cell_type": "markdown", 314 | "metadata": {}, 315 | "source": [ 316 | "在本节中,我们将讨论一些关于计算机国际象棋在强化学习上的重要早期工作成果。" 317 | ] 318 | }, 319 | { 320 | "cell_type": "markdown", 321 | "metadata": {}, 322 | "source": [ 323 | "NeuroChess通过使用175个手工输入特征的神经网络评估局面。它被训练通过时序差分学习预测最终的游戏结果,以及两步棋之后的预期特征。NeuroChess赢得了对GnuChess 13%的比赛,使用固定的深度2搜索。" 324 | ] 325 | }, 326 | { 327 | "cell_type": "markdown", 328 | "metadata": {}, 329 | "source": [ 330 | "Beal和Smith应用时序差分学习来估计国际象棋和将棋中的棋子值,从随机值开始,单独通过自我对弈来学习。" 331 | ] 332 | }, 333 | { 334 | "cell_type": "markdown", 335 | "metadata": {}, 336 | "source": [ 337 | "KnightCap通过一个神经网络来评估局面,这个神经网络使用了一个基于哪个区域受到哪些棋子攻击或防守的知识的攻击表。它是由时序差分学习的一种变体(称为TD(叶))进行训练的,它更新了alpha-beta搜索的主要变着的叶子值。KnightCap在训练之后与使用手动初始化棋子值权重的强大计算机对手对弈,达到了人类大师级别的能力。" 338 | ] 339 | }, 340 | { 341 | "cell_type": "markdown", 342 | "metadata": 
{}, 343 | "source": [ 344 | "Meep通过基于手工特征的线性评估函数来评估局面。它是由另一个时序差分学习的变种(被称为TreeStrap)训练的,它更新了一个alpha-beta搜索的所有节点。Meep经过随机初始权重自我对弈训练之后,在15场比赛中13场击败了人类国际大师级棋手。" 345 | ] 346 | }, 347 | { 348 | "cell_type": "markdown", 349 | "metadata": {}, 350 | "source": [ 351 | "Kaneko和Hoki通过学习在alpha-beta 搜索期间选择人类专家的走子,来训练包括一百万个特征的将棋评估函数的权重。他们还基于根据专家棋局日志调整的极小极大搜索进行了大规模优化;这是获得2013年世界计算机将棋冠军的Bonanza引擎的一部分。" 352 | ] 353 | }, 354 | { 355 | "cell_type": "markdown", 356 | "metadata": {}, 357 | "source": [ 358 | "Giraffe通过一个神经网络评估局面,包括移动能力地图和描述每个方格(走子点)的攻击者和防御者的最低值的攻击和捍卫地图。它通过使用TD(叶)的自我对弈训练,也达到了与国际大师相当的水平。" 359 | ] 360 | }, 361 | { 362 | "cell_type": "markdown", 363 | "metadata": {}, 364 | "source": [ 365 | "DeepChess训练了一个神经网络来执行成对的局面评估。它是通过监督学习从人类专家对弈数据库进行训练的,这些棋局是经过预先过滤的,以避免吃子棋和平局。DeepChess达到了一个强大的特级大师的水平。" 366 | ] 367 | }, 368 | { 369 | "cell_type": "markdown", 370 | "metadata": {}, 371 | "source": [ 372 | "所有这些程序都将他们学到的评估函数与各种扩展增强的alpha-beta搜索功能相结合。" 373 | ] 374 | }, 375 | { 376 | "cell_type": "markdown", 377 | "metadata": {}, 378 | "source": [ 379 | "基于使用像AlphaZero策略迭代的双重策略和价值网络的训练方法已经成功应用于改进Hex的最新技术。" 380 | ] 381 | }, 382 | { 383 | "cell_type": "markdown", 384 | "metadata": {}, 385 | "source": [ 386 | "## MCTS 和Alpha-Beta 搜索" 387 | ] 388 | }, 389 | { 390 | "cell_type": "markdown", 391 | "metadata": {}, 392 | "source": [ 393 | "至少四十年来,最强大的计算机国际象棋程序已经使用alpha-beta搜索。AlphaZero使用明显不同的方法来平均子树内的局面评估,而不是计算该子树的最小最大估计。但是,使用传统MCTS的国际象棋程序比Alpha-Beta搜索程序弱得多;而基于神经网络的alpha-beta程序以前不能与更快的手工评估函数对抗。" 394 | ] 395 | }, 396 | { 397 | "cell_type": "markdown", 398 | "metadata": {}, 399 | "source": [ 400 | "AlphaZero局面评估使用基于深度神经网络的非线性函数逼近,而不是典型国际象棋程序中使用的线性函数逼近。这提供了更强大的表示,但是也可能引入虚假逼近误差。MCTS对这些近似误差进行平均,因此在评估大型子树时趋向于误差抵消。相比之下,alpha-beta搜索计算明确的最小最大值,它将最大的近似误差传播到子树的根节点。使用MCTS可以允许AlphaZero将其神经网络的表示与强大的、独立于领域的搜索有效地结合起来。" 401 | ] 402 | }, 403 | { 404 | "cell_type": "markdown", 405 | "metadata": {}, 406 | "source": [ 407 | "### 领域知识" 408 | ] 409 | }, 410 | { 411 | "cell_type": "markdown", 412 | "metadata": {}, 413 | "source": [ 414 | "- 1.描述位置的输入特征和描述走子的输出特征被构造为一组平面;即神经网络结构与棋盘的网格结构相匹配。\n", 415 | "- 2.为AlphaZero提供了完善的游戏规则知识。这些在MCTS期间被用来模拟由一系列走子产生的位置,以确定游戏的结束,并对达到结束状态的任何模拟对弈进行评分。\n", 416 | "- 3.对规则的了解也被用来编码输入平面(即[`王车易位`](http://chessprogramming.wikispaces.com/Castling),[`重复局面`](https://baike.baidu.com/item/%E5%9B%BD%E9%99%85%E8%B1%A1%E6%A3%8B%E6%9C%AF%E8%AF%AD/7549734?fr=aladdin),没有进展)和输出平面(棋子如何走子,升变和将棋中的[`取驹`](https://baike.baidu.com/item/%E5%B0%86%E6%A3%8B/491643)([`piece drops`](https://en.wikipedia.org/wiki/Shogi#Drops)))。\n", 417 | "- 4.合法走子的典型数量用于缩放探索噪音(见下文)。\n", 418 | "- 5.国际象棋和将棋比赛超过最大步数(由典型比赛长度决定)将被终止,并被判为平局;围棋比赛结束,使用Tromp-Taylor规则打分。\n" 419 | ] 420 | }, 421 | { 422 | "cell_type": "markdown", 423 | "metadata": {}, 424 | "source": [ 425 | "除了上面列出的要点,AlphaZero没有使用任何形式的领域知识。" 426 | ] 427 | }, 428 | { 429 | "cell_type": "markdown", 430 | "metadata": {}, 431 | "source": [ 432 | "### 表示" 433 | ] 434 | }, 435 | { 436 | "cell_type": "markdown", 437 | "metadata": {}, 438 | "source": [ 439 | "在本节中,我们将描述棋盘输入的表示形式,以及AlphaZero中神经网络使用的走子动作输出的表示形式。其他的表示本来也可以使用; 在我们的实验中,训练算法对于许多合理的选择可以有效地工作。" 440 | ] 441 | }, 442 | { 443 | "cell_type": "markdown", 444 | "metadata": {}, 445 | "source": [ 446 | "![b6](assets\\b6.png\")\n", 447 | "表S1:分别在围棋,国际象棋和将棋中AlphaZero使用的输入特征。第一组特征是8步历史走子记录的每个局面。计数由实数值表示;其他输入特征通过使用指定数量的二值输入平面的独热编码来表示。当前玩家由P1表示,对手由P2表示。" 448 | ] 449 | }, 450 | { 451 | "cell_type": "markdown", 452 | "metadata": {}, 453 | "source": [ 454 | "神经网络的输入是N × N ×(MT + L)图像栈,其表示状态使用大小为N×N的T组M个平面的级联组成。每一组平面代表时间步t-T + 
1,...,t的棋盘位置,在小于1的时间步中设置为零。棋盘朝向当前玩家的角度。M特征平面由棋手存在的棋子的二值特征平面组成,每个棋子类型具有一个平面,第二组平面表示对手存在的棋子。对于将棋,还有额外的平面显示每种类型的持驹数。还有一个额外的L个常值输入平面,表示玩家的颜色,总的回合数量和特殊规则的状态:在国际象棋合法的王车易位(王翼或者后翼);局面的重复次数(3次重复在国际象棋中自动判为平局;在将棋中是4次);和在国际象棋中没有进展的走子次数(没有进展的50次走子自动判为平局)。表格S1中总结了输入特征。" 455 | ] 456 | }, 457 | { 458 | "cell_type": "markdown", 459 | "metadata": {}, 460 | "source": [ 461 | "下棋走子可以分为两部分:选择要移动的棋子,然后在棋子的合法棋步中进行选择。我们用一个8 × 8 × 73的平面栈来表示策略π(a|s),它编码了4,672个可能的走子的概率分布。每个8×8的位置标识从哪个方块“拾取”一个棋子。前56个平面编码对于任何棋子可能的“皇后走子”,沿着八个相对方向{北,东北,东,东南,南,西南,西,西北}中的一个,若干方块[1..7]中将被吃的棋子。接下来的8个平面编码可能的马的走子。最后的9个平面编码对于兵底线升变后在对角线上可能的走子和吃子,分别对于车,马或象。和其他兵从第七横线升变为为皇后的走子或吃子。" 462 | ] 463 | }, 464 | { 465 | "cell_type": "markdown", 466 | "metadata": {}, 467 | "source": [ 468 | "将棋中的策略由一个9 × 9 × 139的平面栈表示,类似地对11,259个可能的走子进行概率分布编码。前64个平面编码“皇后走子”,接下来的2个编码马走子。另有64 + 2个平面分别编码升变成皇后的走子和升变成马的走子。最后7个平面编码将一个被捕获的棋子放回棋盘上的位置。" 469 | ] 470 | }, 471 | { 472 | "cell_type": "markdown", 473 | "metadata": {}, 474 | "source": [ 475 | "围棋中的策略与AlphaGo Zero表示相同,使用一个包含19 × 19 + 1的走子平坦分布代表可能的放置和走子。我们在国际象棋和将棋上也尝试用关于走子的平坦分布,最后的结果几乎相同,尽管训练稍微慢了点。" 476 | ] 477 | }, 478 | { 479 | "cell_type": "markdown", 480 | "metadata": {}, 481 | "source": [ 482 | "![b7](assets\\b7.png\")\n", 483 | "表S2:国际象棋和将棋中AlphaZero使用的动作表示。该策略是由一堆编码合法走子概率分布的平面表示的;平面对应于表中的条目。" 484 | ] 485 | }, 486 | { 487 | "cell_type": "markdown", 488 | "metadata": {}, 489 | "source": [ 490 | "表S2中总结了行动表示。通过将概率设置为零,并将使用的走子概率重新归一化,可以屏蔽非法走子。" 491 | ] 492 | }, 493 | { 494 | "cell_type": "markdown", 495 | "metadata": {}, 496 | "source": [ 497 | "### 配置" 498 | ] 499 | }, 500 | { 501 | "cell_type": "markdown", 502 | "metadata": {}, 503 | "source": [ 504 | "在训练期间,每个MCTS使用了800次模拟。棋局的数量,局面和思考时间由于不同的棋盘大小和游戏长度而有所不同,如表S3所示。" 505 | ] 506 | }, 507 | { 508 | "cell_type": "markdown", 509 | "metadata": {}, 510 | "source": [ 511 | "每场比赛的学习率为0.2,在训练过程中分别下降了三次(分别为0.02,0.002和0.0002)。走子的选择与根节点的访问次数成正比。Dirichlet噪声Dir(α)被添加到根节点的先验概率中;这个比例与典型位置的合法走子的近似数量成反比,分别为国际象棋、将棋和围棋的α 取{0.3,0.15,0.03}。除非另有说明,否则训练和搜索算法和参数与AlphaGo Zero相同。" 512 | ] 513 | }, 514 | { 515 | "cell_type": "markdown", 516 | "metadata": {}, 517 | "source": [ 518 | "![b8](assets\\b8.png\")\n", 519 | "表S3:国际象棋,将棋和围棋中AlphaZero训练的选择统计。" 520 | ] 521 | }, 522 | { 523 | "cell_type": "markdown", 524 | "metadata": {}, 525 | "source": [ 526 | "在评估期间,AlphaZero选择走子使用关于根节点访问次数的贪婪方法。每台MCTS在一台带有4个TPU的机器上执行。" 527 | ] 528 | }, 529 | { 530 | "cell_type": "markdown", 531 | "metadata": {}, 532 | "source": [ 533 | "### 评估" 534 | ] 535 | }, 536 | { 537 | "cell_type": "markdown", 538 | "metadata": {}, 539 | "source": [ 540 | "为了评估国际象棋的性能,我们使用了Stockfish版本8(官方的Linux版本)作为基准程序,使用64个CPU线程和1GB Hash。" 541 | ] 542 | }, 543 | { 544 | "cell_type": "markdown", 545 | "metadata": {}, 546 | "source": [ 547 | "为了评估将棋中的性能,我们使用了Elmo版本WCSC27并结合了有64个CPU线程和1GB Hash,EnteringKingRule的usi选项设置为NoEnteringKing,YaneuraOu 2017早期的KPPT 4.73 64AVX2。我们通过测量每个选手的国际等级分来评估AlphaZero的相对强度(图1)。我们通过Logistic函数p(a defeats b) =$\\frac{1}{1+exp (c_{elo} (e(b)-e(a)) }$估计玩家a击败玩家b的概率,并通过贝叶斯逻辑回归估计等级分e(·),用BayesElo程序使用标准常数$c_{elo}$ = 1/400计算。国际等级评分是根据与AlphaZero在训练迭代期间进行比赛1秒每次走子的结果计算得出的,同时Stockfish,Elmo或者AlphaGo Lee分别是基准选手。基准玩家的国际等级分是以公开可用的价值为基础的。" 548 | ] 549 | }, 550 | { 551 | "cell_type": "markdown", 552 | "metadata": {}, 553 | "source": [ 554 | "我们还测量了AlphaZero对每个基准玩家快棋的表现。设置被选择为符合计算机象棋比赛的条件:每个棋手允许1分钟下一步棋,所有棋手都可以投降认负(Stockfish和Elmo 10次连续走子-900 centipawns,AlphaZero 5%胜率)。所有玩家的思考都被禁止了。" 555 | ] 556 | }, 557 | { 558 | "cell_type": "markdown", 559 | "metadata": {}, 560 | "source": [ 561 | "![b9](assets\\b9.png\")\n", 562 | 
"表S4:国际象棋、将棋和围棋中AlphaZero,Stockfish和Elmo的评估速度(局面/秒)" 563 | ] 564 | } 565 | ], 566 | "metadata": { 567 | "kernelspec": { 568 | "display_name": "Python 3", 569 | "language": "python", 570 | "name": "python3" 571 | }, 572 | "language_info": { 573 | "codemirror_mode": { 574 | "name": "ipython", 575 | "version": 3 576 | }, 577 | "file_extension": ".py", 578 | "mimetype": "text/x-python", 579 | "name": "python", 580 | "nbconvert_exporter": "python", 581 | "pygments_lexer": "ipython3", 582 | "version": "3.6.7" 583 | } 584 | }, 585 | "nbformat": 4, 586 | "nbformat_minor": 1 587 | } 588 | -------------------------------------------------------------------------------- /Mastering_the_Game_of_Go_without_Human_Knowledge.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 无须人类知识掌握围棋" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "[`程世东`](http://zhihu.com/people/cheng-shi-dong-47) 翻译\n", 15 | "\n", 16 | "[`GitHub`](http://github.com/chengstone) [`Mail`](mailto:69558140@163.com)" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": {}, 22 | "source": [ 23 | "人工智能的一个长期目标是从“白板”开始,在挑战性领域达到超越人类的学习算法。最近,AlphaGo成为第一个在围棋中击败世界冠军的程序。AlphaGo中的树搜索使用深度神经网络评估局面和选定走子。这些神经网络是通过来自人类专家下棋的监督学习和通过从自我对弈中强化学习来训练的。在这里,我们介绍一个完全基于强化学习的算法,无须人类数据、指导或游戏规则以外的领域知识。AlphaGo成为自己的老师:训练一个神经网络,预测AlphaGo自己的走子选择。这个神经网络提高了树搜索的强度,在下一次迭代中产生更高质量的走子选择和更强的自我对弈。从“白板”开始,我们的新程序AlphaGo Zero实现了超人的表现,对之前的程序AlphaGo赢得了100-0的成绩。" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "在人工智能方面取得了很大的进展,使用了经过训练的监督学习系统来复制人类专家的决策。但是,专家数据往往昂贵,不可靠,或者根本无法使用。即使有可靠的数据,也可能会对以这种方式进行训练的系统的性能施加上限。相比之下,强化学习系统是根据它们自己的经验进行训练的,原则上允许它们超越人类的能力,并在缺乏人类专家的领域工作。最近,通过强化学习训练的深度神经网络,朝着这个目标快速发展。在Atari和3D虚拟环境等电脑游戏中,这些系统已经超越了人类。然而,在人类智力方面最具挑战性的领域 - 比如被广泛认为是人工智能面临的巨大挑战的围棋游戏 - 在广阔的搜索空间中需要精确和复杂的前瞻。完全通用的方法以前在这些领域没有实现人类级别的表现。" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "AlphaGo是第一个在围棋中实现超人表现的程序。我们称之为AlphaGo Fan的发行版在2015年10月击败了欧洲冠军樊麾。AlphaGo Fan使用了两个深度神经网络:输出走子概率的策略网络和输出局面评估的价值网络。策略网络最初是通过监督学习来训练的,以准确地预测人类专家的走子,并随后通过策略梯度强化学习进行改进。价值网络经过训练,可以预测策略网络与自己下棋的胜利者。一旦训练完成,这些网络与蒙特卡洛树搜索(MCTS)相结合,提供了一个前瞻性搜索,使用策略网络缩小高概率走子的搜索范围,并使用价值网络(与Monte-Carlo 使用快速展开策略rollout)来评估树中的局面。随后的版本,我们称之为AlphaGo Lee,采用了类似的方法(见[`方法`](#方法)),并在2016年3月击败了赢得18个国际冠军的李世石。" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "我们的AlphaGo Zero程序在几个重要方面与AlphaGo Fan和AlphaGo Lee 不同。首先,它是完全由自我对弈强化学习训练的,从随机下棋开始,没有任何监督或使用人类的数据。其次,它只使用黑子和白子作为输入特征。第三,它使用单一的神经网络,而不是分别使用策略和价值网络。最后,它使用更简单的树搜索,依靠这个单一的神经网络来评估局面和选择走子,而不需要执行任何rollouts(快速走子)。为了达到这些效果,我们引入了一种新的强化学习算法,在训练迭代中引入了前瞻搜索,从而实现了快速的提升和精确稳定的学习。搜索算法、训练过程和网络结构中的技术差异将在[`方法`](#方法)中描述。" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "## 1、AlphaGo Zero的强化学习" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "metadata": {}, 57 | "source": [ 58 | "我们的新方法使用参数为θ的深度神经网络$f_θ$。这个神经网络把位置的原始表示形式s(棋盘状态)和它的历史走子记录作为输入,并输出走子概率和一个评估值(胜率),$(p,v)= f_θ(s)$。向量走子概率p表示选择每个走子的概率,$p_a = Pr(a|s)$。评估值v是标量,估计当前玩家从局面s获胜的概率。这个神经网络将策略网络和价值网络的任务结合到一个体系结构中。神经网络由许多残差块卷积层组成,包含批归一化和整流非线性(见[`方法`](#方法))。" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "AlphaGo 
Zero中的神经网络是通过一种新的强化学习算法从自我对弈中训练出来的。在每个局面s,由神经网络$f_θ$的预测结果作为参考执行MCTS搜索。MCTS搜索输出每个可能走子的概率π。MCTS搜索输出的概率通常比神经网络$f_θ(s)$输出的原始概率p选择更强的走子;因此,MCTS可被视为一个强有力的策略提升器。搜索 - 使用基于MCTS提升的策略来选择每个走子,然后使用获胜者z作为价值- 可以被视为一个强大的策略评估器。强化学习算法的主要思想是在策略迭代过程中重复使用这些搜索:不断更新神经网络的参数以使得网络输出的走子概率和胜率$(p,v)= f_θ (s)$,更接近地匹配MCTS提升的走子概率和自对弈胜者(π, z);这些新的参数被用在下一次自我对弈中,使得MCTS搜索更加强大。图1显示了自我对弈训练过程。" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": {}, 71 | "source": [ 72 | "![图1](assets\\a1.png\")\n", 73 | "图1:AlphaGo Zero中的自我对弈强化学习。a 程序自我对弈状态$s_1$, ..., $s_T$。在每个局面$s_t$,使用最新的神经网络$f_θ$执行蒙特卡洛树搜索(MCTS)$α_θ$(见图2)。根据由MCTS计算的概率选择走子,$a_t$∼ $π_t$。按照游戏规则对棋局结束状态$s_T$进行评分得到游戏获胜者z。b 训练AlphaGo Zero神经网络。神经网络将棋局原始走子状态$s_t$作为其输入,将其输入给具有参数θ的多个卷积层,并且输出表示走子概率分布的向量$p_t$和表示当前玩家在局面$s_t$的获胜概率标量值$v_t$。更新神经网络的参数θ以使策略向量$p_t$与搜索概率$π_t$的相似度最大化,并使得预测的获胜者$v_t$和游戏胜者z之间的误差最小化(参见等式1)。新参数用于下一次迭代的自我对弈a中。" 74 | ] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "metadata": {}, 79 | "source": [ 80 | "蒙特卡洛树搜索使用神经网络$f_θ$来指导其搜索(见图2)。搜索树中的每个边(s, a)存储先验概率P(s, a),访问计数N(s, a)和行动价值Q(s, a)。每次搜索从根状态开始并且迭代地选择使置信上限区间$Q(s, a) +U(s, a)$最大化的走子,其中$U(s, a) ∝ P(s, a)/(1 + N(s, a))$,直到遇到叶节点$s'$。叶子节点只会被神经网络扩展和评估一次,以产生先验概率和胜率评估值,$(P(s', ·), V (s')) = f_θ(s')$。更新搜索中遍历的每个边(s, a),增加其访问次数N(s, a),将它的行动价值更新为在这些搜索上的平均估计值,即 $Q(s,a) = 1 / N (s,a)∑_{s’ |s,a→s’} V(s’)$,其中$s,a →s'$表示从局面s选择走子a后搜索最终达到$s'$。" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "![a2](assets\\a2.png\")\n", 88 | "图2:AlphaGo Zero中的蒙特卡洛树搜索。a.每次模拟通过选择具有最大行动价值Q的边加上取决于所存储的先验概率P和该边的访问计数N(每次访问都被增加一次)的上限置信区间U来遍历树。b.展开叶子节点,通过神经网络$(P(s, ·), V (s)) = f_θ(s)$来评估局面s;向量P的值存储在叶子结点扩展的边上。c.更新行动价值Q等于在该行动下的子树中的所有评估值V的均值。d.一旦MCTS搜索完成,返回局面s下的落子概率π,与$N^{1 /τ}$成正比,其中N是从根状态每次移动的访问计数, τ是控制温度的参数。" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "metadata": {}, 94 | "source": [ 95 | "MCTS可以被看作一种自我对弈算法,在给定神经网络参数θ和根状态s的情况下,计算推荐走子的先验概率向量$π = α_θ(s)$,与每次走子的指数访问计数成正比 ,$π_a ∝ N(s,a)^{1 /τ}$,其中τ是温度参数。" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": {}, 101 | "source": [ 102 | "神经网络通过自我对弈强化学习算法进行训练,该算法使用MCTS下每一步棋。首先,神经网络随机初始化权重$θ_0$。在随后的每次迭代i≥1时,生成自我对弈棋局(图1a)。在每个时间步t,使用前一次迭代的神经网络$f_{θ_{i-1} }$执行MCTS搜索$π_t=α_{θ_{i-1} }(s_t)$,并且通过对搜索概率$π_t$进行采样来执行走子。双方都选择跳过;当搜索值低于认输阈值时,或当游戏超过最大长度(棋盘无处走子)时,游戏在步骤T终止;然后对游戏进行评分给出最终奖励$r_T ∈ ${−1, +1}(详见[`方法`](#方法))。MCTS搜索过程中每个时间步t的数据存储为($s_t$ , $π_t$ , $z_t$),其中$z_t=± r_T$是当前玩家在步骤t的角度的游戏获胜者。同时(图1b),新的网络参数$θ_i$从上一次自我对弈的所有时间步产生的数据(s, π, z)中采样进行训练。调整神经网络$(p,v)= f_{θ_i } (s)$以最小化预测胜率v和自我对弈的胜者z之间的误差,并使神经网络走子概率p与搜索概率π的相似度最大化。具体而言,参数θ通过梯度下降分别在均方误差和交叉熵损失之和上的损失函数l进行调整," 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "metadata": {}, 108 | "source": [ 109 | "#### $(p,v)= f_θ (s)$, $l = (z - v)^2- π^T log p + c||θ||^2$ (等式1)" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "metadata": {}, 115 | "source": [ 116 | "其中c是控制L2权重正则化水平的参数(防止过拟合)。" 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": {}, 122 | "source": [ 123 | "## 2、AlphaGo Zero训练经验" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": {}, 129 | "source": [ 130 | "我们使用强化学习来训练AlphaGo Zero。从随机下棋开始,没有人为干预,持续训练约3天。在训练过程中,生成了490万个自对弈棋局,每次MCTS使用1600次模拟,每次走子花费大约0.4s的思考时间。使用2048个局面的70万个mini-batches训练参数。神经网络包含20个残差块(详见[`方法`](#方法))。" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "metadata": {}, 136 | "source": [ 137 | "图3a显示了AlphaGo Zero在自我对弈强化学习期间的表现,作为训练时间和国际等级分的函数。学习在整个训练过程中顺利进行,并没有出现先前文献中提出的振荡或灾难性遗忘。令人惊讶的是,AlphaGo Zero仅仅36小时就打败了AlphaGo Lee;为了比较,AlphaGo 
Lee训练了几个月的时间。在72小时之后,我们根据首尔人机比赛中的比赛条件(见[`方法`](#方法))和2小时时间控制,评估了AlphaGo Zero与击败Lee Sedol版本的AlphaGo Lee。AlphaGo Zero使用了一台带有4个TPU的机器,而AlphaGo Lee则分布在多台机器上,并使用了48个TPU。AlphaGo Zero以100比0击败了AlphaGo Lee(参见扩展数据图5和补充信息)。" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": {}, 143 | "source": [ 144 | "![a3](assets\\a3.png\")\n", 145 | "图3:AlphaGo Zero的训练评估。a.自我对弈强化学习的表现。该图显示了来自AlphaGo Zero强化学习的每次迭代i的每个MCTS棋手的表现$α_{θ_i }$。国际等级分是估算不同棋手间下棋计算出来的,每次走子使用0.4秒的思考时间(见[`方法`](#方法))。为了比较,还显示了使用KGS数据集通过人类数据监督学习训练的棋手。b.人类走子的预测准确率。该图显示了神经网络$f_{θ_i }$在自我对弈i的每次迭代中,用于预测来自GoKifu数据集的人类走子的准确率。准确率度量的是神经网络为人类走子分配高概率的位置的百分比。也显示了通过监督学习训练的神经网络的准确率。c.在人类下棋结果的均方误差(MSE)。该图显示神经网络$f_{θ_i }$的MSE,在自我对弈i的每次迭代中,使用GoKifu数据集预测人类下棋的结果。MSE介于实际结果$z ∈ ${-1,+1}和神经网络价值v(胜率)之间,范围是[0,1]。还显示了通过监督学习训练的神经网络的MSE。" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "metadata": {}, 151 | "source": [ 152 | "为了评估自我对弈强化学习的优点,与从人类数据中学习相比,我们训练了第二个神经网络(使用相同的体系结构)来预测KGS数据集中的专家走子;与之前的工作(分别参见扩展数据表1和2)相比,这实现了最好的预测精度。监督学习能够获得更好的初期表现,更好地预测人类专家下棋的结果(图3)。值得注意的是,尽管监督学习获得了更高的走子预测准确率,但是自学者的整体表现更好,在训练的前24小时内击败了学习人类的棋手。这表明AlphaGo Zero可能正在学习一种与人类有着本质区别的策略。" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "metadata": {}, 158 | "source": [ 159 | "为了分别评估架构和算法的贡献,我们比较了AlphaGo Zero的神经网络架构与AlphaGo Lee使用的神经网络架构(见图4)的表现。使用AlphaGo Lee独立的策略和价值网络或者AlphaGo Zero中组合的策略和价值网络来创建四个神经网络;并使用AlphaGo Lee的卷积网络架构或AlphaGo Zero的残差网络架构。在自我对弈训练72小时后,使用由AlphaGo Zero产生的自我对弈的固定数据集来训练每个网络以最小化相同的损失函数(等式1)。使用残差网络更准确,达到更低的误差,AlphaGo的性能提高了600多个等级分。将策略和价值组合在一个网络中,略微降低了移动走子的准确率,但减少了价值错误,并提升了AlphaGo 600个等级分的下棋表现。这部分由于提高了计算效率,但更重要的是,将网络双重目标规范化为支持多个用例的共同表示。" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "metadata": {}, 165 | "source": [ 166 | "![a4](assets\\a4.png\")\n", 167 | "图4:AlphaGo Zero和AlphaGo Lee神经网络架构的比较。使用单独的(“sep”)或组合的策略和价值网络(“dual”)以及使用卷积(“conv”)或残差网络(“res”)的神经网络架构的比较。“dual-res”和“sep-conv”组合分别对应于AlphaGo Zero和AlphaGo Lee中使用的神经网络架构。每一个网络都是通过之前运行的AlphaGo Zero生成的固定数据集来训练的。a.每个训练过的网络都与AlphaGo Zero的搜索相结合,得到不同的棋手。国际等级分是根据这些不同的棋手之间的对弈计算出来的,每一步棋的思考时间为5秒。b.对每个网络架构在人类专家走子(来自GoKifu数据集)的预测准确率。c.每个网络架构的在人类专家下棋结果(来自GoKifu数据集)的均方误差。" 168 | ] 169 | }, 170 | { 171 | "cell_type": "markdown", 172 | "metadata": {}, 173 | "source": [ 174 | "## 3、AlphaGo Zero学到的知识" 175 | ] 176 | }, 177 | { 178 | "cell_type": "markdown", 179 | "metadata": {}, 180 | "source": [ 181 | "AlphaGo Zero在其自我对弈训练过程中发现了非凡的围棋知识。这包括人类围棋知识的基本要素,以及超出传统围棋知识范围的特殊策略。" 182 | ] 183 | }, 184 | { 185 | "cell_type": "markdown", 186 | "metadata": {}, 187 | "source": [ 188 | "图5显示了一个时间线,指示何时发现了专业的定式(托角序列)(图5a,扩展数据图1);最终AlphaGo Zero更喜欢以前未发现的新定式(图5b,扩展数据图2)。图5c和补充信息显示了几个在不同训练阶段的快速自我对弈。AlphaGo Zero从完全随机的走子过渡到对围棋概念的完善理解,包括布局(开场),手筋(战术),死活,劫(重复棋局),缓气劫(终局),对杀,先手(主动 ),形状,势和地,都是从最基本规则中发现的。令人惊讶的是,Shocho(“征”吃序列可能跨越整个棋盘)是人类学习的围棋知识的第一个元素之一,后来在AlphaGo Zero的训练中得到了理解。" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "metadata": {}, 194 | "source": [ 195 | "![a5](assets\\a5.png\")\n", 196 | "图5:AlphaGo Zero学到的围棋知识。a.在AlphaGo Zero训练期间发现了五种人类定式(常见的托角序列)。相关的时间戳指示在自我对弈训练期间每个序列第一次发生的时间(考虑旋转和反转)。扩展数据图1提供了每个序列训练出现的频率。b.在自我训练的不同阶段偏爱的五个定式。在自我对弈训练的迭代中,每个显示的托角序列在所有托角序列中高频率出现。该迭代的时间戳在时间线上标识了出来。训练10个小时,弱托角是首选。训练47小时,点三三入侵使用最频繁。这个定式也是在人类的对弈中常见的;然而,AlphaGo Zero后来发现并选择了一个新的变着。扩展数据图2提供了所有五个频繁出现的序列和新变着。c.在不同的训练阶段进行的三个自我对弈的前80步棋,每次搜索使用1600次模拟(大约0.4s)。训练3小时,游戏贪婪地关注于吃子,很像人类初学者。在19小时关注于死活、势和地,即围棋的根本。在70小时,下棋表现非常完美平衡,涉及多场战斗和一次复杂的劫争,最终以白子多半子获胜。请参阅完整对弈的补充信息。" 197 | ] 198 | }, 199 | { 200 | "cell_type": "markdown", 201 | "metadata": {}, 202 | "source": [ 203 | "## 4、AlphaGo Zero的最终表现" 204 | ] 205 | }, 206 | { 207 | 
"cell_type": "markdown", 208 | "metadata": {}, 209 | "source": [ 210 | "随后,我们使用更大的神经网络和更长的持续时间将我们的强化学习应用于AlphaGo Zero的第二个实例。再次训练从完全随机下棋开始,持续约40天。在训练过程中,共产生了2900万次自我对弈棋局。从310万个mini-batches(2048个局面)更新参数。神经网络包含40个残差块。学习曲线如图6a所示。" 211 | ] 212 | }, 213 | { 214 | "cell_type": "markdown", 215 | "metadata": {}, 216 | "source": [ 217 | "我们使用与AlphaGo Fan,AlphaGo Lee和之前的几个围棋程序进行内部比赛来评估训练好的AlphaGo Zero。我们还与最强大的现有程序AlphaGo Master进行了比赛,该程序基于论文提供的算法和体系结构,但是利用人类数据和特征(请参阅[`方法`](#方法))——在2017年1月的网络比赛中以60-0击败了最强的人类专业玩家。在我们的评估中,所有的程序都被允许每步棋5秒的思考时间;AlphaGo Zero和AlphaGo Master均部署在一台带有4个TPU的机器上;AlphaGo Fan和AlphaGo Lee分别分配了176个GPU和48个TPU。我们还包括一个基于AlphaGo Zero的raw neural network(不使用MCTS的网络);这个网络只是以最大的概率选择如何走子。" 218 | ] 219 | }, 220 | { 221 | "cell_type": "markdown", 222 | "metadata": {}, 223 | "source": [ 224 | "![a6](assets\\a6.png\")\n", 225 | "图6:AlphaGo Zero的表现。a. AlphaGo Zero的学习曲线,使用40 block的残差网络训练40天。该图显示了我们的强化学习算法的每次迭代i的每个棋手$α_{θ_i}$的表现。国际等级分是根据不同棋手之间的棋局评估来计算的,每个搜索使用0.4秒(参见[`方法`](#方法))。b. AlphaGo Zero的最终表现。AlphaGo Zero使用40个残差块神经网络,训练了40天。该图显示了各个围棋程序的比赛结果:AlphaGo Zero,AlphaGo Master(在线比赛中击败顶级人类玩家60-0),AlphaGo Lee(击败李世石),AlphaGo Fan(击败樊麾)以及之前的围棋程序Crazy Stone,Pachi和GnuGo。每个程序每步棋都有5秒的思考时间。AlphaGo Zero和AlphaGo Master部署在单独一台机器在Google Cloud上;AlphaGo Fan和AlphaGo Lee分布式部署在机群上。还包括来自AlphaGo Zero的raw neural network,其直接选择网络输出$p_a$中的最大概率走子a,而不使用MCTS。程序以Elo国际等级分25为单位进行评估:200点差距对应于胜率75%。" 226 | ] 227 | }, 228 | { 229 | "cell_type": "markdown", 230 | "metadata": {}, 231 | "source": [ 232 | "图6b显示了每个程序在Elo国际等级分上的表现。未使用任何预测的raw neural network达到了3,055的Elo等级分。AlphaGo Zero获得了5,185的评分,而AlphaGo Master则为4,858,AlphaGo Lee为3,739,AlphaGo Fan为3,144。" 233 | ] 234 | }, 235 | { 236 | "cell_type": "markdown", 237 | "metadata": {}, 238 | "source": [ 239 | "最后,我们评估了AlphaGo Zero与 AlphaGo Master在2小时时间的100场比赛。AlphaGo Zero以 89:11赢得了比赛(参见扩展数据图6和补充信息)。" 240 | ] 241 | }, 242 | { 243 | "cell_type": "markdown", 244 | "metadata": {}, 245 | "source": [ 246 | "## 5、结论" 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "metadata": {}, 252 | "source": [ 253 | "我们的结果彻底地证明,即使在最具挑战性的领域中,纯强化学习方法也是完全可行的:在没有基本规则之外的领域知识的情况下,没有人类的例子或指导,就有可能训练成超人类的水平。此外,与对人类专家数据的训练相比,纯粹的强化学习方法只需要几个小时的训练时间,能够逐渐获得更好的表现。使用这种方法,AlphaGo Zero大幅度地击败了之前最强大的使用手工特征的人类数据进行训练的AlphaGo版本。" 254 | ] 255 | }, 256 | { 257 | "cell_type": "markdown", 258 | "metadata": {}, 259 | "source": [ 260 | "人类已经从几千年来数百万的棋局中积累了围棋知识,总结了通用的模式、术语和书籍。在几天的时间里,从“白板”开始,AlphaGo Zero能够重新发现围棋的许多知识,以及为最古老的游戏提供新见解的新颖策略。" 261 | ] 262 | }, 263 | { 264 | "cell_type": "markdown", 265 | "metadata": {}, 266 | "source": [ 267 | "# 方法" 268 | ] 269 | }, 270 | { 271 | "cell_type": "markdown", 272 | "metadata": {}, 273 | "source": [ 274 | "强化学习策略迭代是一种经典算法,通过轮流在策略评估(估算当前策略的价值函数 )和策略改进(使用当前价值函数来产生更好的策略)之间产生一系列改进策略。策略评估的一个简单方法是从采样的结果中估计价值函数。策略改进的一个简单方法是关于价值函数贪婪地选择行动。在庞大的状态空间中,需要近似评估每个策略并表示其改进。" 275 | ] 276 | }, 277 | { 278 | "cell_type": "markdown", 279 | "metadata": {}, 280 | "source": [ 281 | "基于分类的强化学习使用简单的MonteCarlo搜索来改进策略。每个动作都会执行许多rollouts(快速走子)。具有最大均值的动作提供了一个积极的训练实例,而所有其他的动作提供了负面的训练实例;然后策略被训练为将动作分为正面或负面,并在随后的rollouts中使用。当τ →0时,这可以被视为AlphaGo Zero训练算法的策略组件的前身。" 282 | ] 283 | }, 284 | { 285 | "cell_type": "markdown", 286 | "metadata": {}, 287 | "source": [ 288 | "更近期的具体事例,基于分类的修改策略迭代(CBMPI),也通过将价值函数回归到剪枝rollout值来执行策略评估,类似于AlphaGo Zero的价值组件;这在俄罗斯方块的游戏中取得了最好的结果。然而,这个先前的工作仅限于使用手工特征的简单rollouts和线性函数逼近。" 289 | ] 290 | }, 291 | { 292 | "cell_type": "markdown", 293 | "metadata": {}, 294 | "source": [ 295 | "AlphaGo 
Zero自我对弈算法可以类似地被理解为近似策略迭代方案,其中MCTS被用于策略改进和策略评估。策略改进始于神经网络策略,根据该策略的建议执行MCTS,然后将更强大的搜索策略投影到神经网络的函数空间中。策略评估应用于(更强大的)搜索策略:自我对弈的结果也被投射回神经网络的函数空间。这些投影步骤通过训练神经网络参数分别匹配搜索概率和自我对弈结果来实现。" 296 | ] 297 | }, 298 | { 299 | "cell_type": "markdown", 300 | "metadata": {}, 301 | "source": [ 302 | "Guo等人还将MCTS的输出投影到神经网络中,或者通过对搜索值进行价值网络回归,或者通过对MCTS选择的动作进行分类。这种方法被用来训练神经网络来玩Atari游戏;然而,MCTS是固定的 - 没有策略迭代 - 并没有使用任何训练过的网络。" 303 | ] 304 | }, 305 | { 306 | "cell_type": "markdown", 307 | "metadata": {}, 308 | "source": [ 309 | "## 游戏中的自我对弈强化学习" 310 | ] 311 | }, 312 | { 313 | "cell_type": "markdown", 314 | "metadata": {}, 315 | "source": [ 316 | "我们的方法最直接适用于完全信息的零和游戏。我们遵循先前工作中描述的交替马尔可夫博弈的形式,指出基于价值或策略迭代的算法自然延伸到这个环境。" 317 | ] 318 | }, 319 | { 320 | "cell_type": "markdown", 321 | "metadata": {}, 322 | "source": [ 323 | "自我对弈强化学习以前已经应用于围棋游戏。NeuroGo使用神经网络来代表价值函数,使用基于围棋关于连,地和眼的知识的复杂架构。这个神经网络是由时序差分学习(temporal-difference learning)进行训练,在自我对弈中预测地。一个相关的方法,RLGO,代表替代特征的线性组合的价值函数,而不是由功能的线性组合,详尽地列举所有3 × 3模式的棋子;它是通过时序差分学习来预测自我对弈中的赢家。NeuroGo和RLGO达到了很弱的业余水平。" 324 | ] 325 | }, 326 | { 327 | "cell_type": "markdown", 328 | "metadata": {}, 329 | "source": [ 330 | "蒙特卡洛树搜索(MCTS)也可以被看作是自我对弈强化学习的一种形式。搜索树的节点包含在搜索期间遇到的局面的价值函数;这些价值被更新以预测模拟自我对弈的胜者。MCTS程序之前在围棋中已经取得了非常强大的业余水平,但是使用了大量的领域专业知识:基于手工特征的快速走子策略,通过运行模拟下棋(rollout)来评估局面,直到游戏结束;以及基于手工特征的树策略,该策略在搜索树中选择走子。" 331 | ] 332 | }, 333 | { 334 | "cell_type": "markdown", 335 | "metadata": {}, 336 | "source": [ 337 | "自我对弈强化学习方法在其他游戏中取得了高水平的表现:国际象棋,跳棋,西洋双陆棋,奥赛罗棋(又叫黑白棋、反棋(Reversi)、翻转棋),Scrabble(英语拼字游戏)和最近的扑克。在所有这些例子中,价值函数都是通过回归或时序差分学习从自我对弈生成的数据来训练的。经过训练的价值函数在alpha-beta搜索中被用作一个简单的蒙特卡洛搜索或反事实遗憾最小化(Counterfactual Regret Minimization)的评估函数。但是,这些方法使用手工输入特征或手工特征模板。此外,学习过程使用监督式学习来初始化权重,手动选择棋子价值的权重,手动限制动作空间,或使用预先存在的计算机程序作为训练对手或产生游戏记录。" 338 | ] 339 | }, 340 | { 341 | "cell_type": "markdown", 342 | "metadata": {}, 343 | "source": [ 344 | "许多最成功和广泛使用的强化学习方法首先在零和游戏的背景下引入:时序差分学习首先被引入在一个跳棋游戏程序,而MCTS被引入在围棋游戏。然而,非常相似的算法在视频游戏机器人,工业控制和在线推荐系统中被证明是非常有效的。" 345 | ] 346 | }, 347 | { 348 | "cell_type": "markdown", 349 | "metadata": {}, 350 | "source": [ 351 | "## AlphaGo版本" 352 | ] 353 | }, 354 | { 355 | "cell_type": "markdown", 356 | "metadata": {}, 357 | "source": [ 358 | "我们比较AlphaGo的三个不同版本:" 359 | ] 360 | }, 361 | { 362 | "cell_type": "markdown", 363 | "metadata": {}, 364 | "source": [ 365 | "1.\tAlphaGo Fan是之前发行的2015年10月与Fan Hui对弈的一个程序。这个程序分布式部署在使用176个GPU的许多机器上。\n", 366 | "2.\tAlphaGo Lee是在2016年3月4月1击败李世石的程序。它与以前未发表的AlphaGo Fan在很多方面都很相似。但是,我们强调了几个关键差异,以促进公平比较。首先,价值网络是通过AlphaGo自我对弈的快速游戏的结果来训练的,而不是使用策略网络自我对弈;这个过程迭代了好几次 - 这是本文提出的白板算法的第一步。其次,策略和价值网络比原始论文中描述的要大 - 分别使用256个平面的12个卷积层,并且接受更多迭代的训练。该网络还分布在许多使用48个TPU而不是GPU的机器上,使其能够在搜索过程中更快地评估神经网络。\n", 367 | "3.\tAlphaGo Master是2017年1月份以60-0击败顶级人类选手的程序。它以前是未发表的,但是使用本文所述的相同的神经网络结构,强化学习算法和MCTS算法。然而,它使用与AlphaGo Lee相同的手工特征和rollouts,并且通过从人类数据监督学习来初始化训练。\n", 368 | "4.\tAlphaGo Zero是本文描述的程序。它从自我对弈强化学习中学习,从随机初始权重开始,不使用rollouts,没有人为的监督,只用历史棋局作为输入特征。它只使用一台在Google Cloud上的带有4个TPU的机器(AlphaGo Zero也可以分布式部署,但我们选择使用最简单的搜索算法)。\n" 369 | ] 370 | }, 371 | { 372 | "cell_type": "markdown", 373 | "metadata": {}, 374 | "source": [ 375 | "## 领域知识" 376 | ] 377 | }, 378 | { 379 | "cell_type": "markdown", 380 | "metadata": {}, 381 | "source": [ 382 | "我们的主要贡献是证明超人的表现可以在没有人类领域知识的情况下实现。" 383 | ] 384 | }, 385 | { 386 | "cell_type": "markdown", 387 | "metadata": {}, 388 | "source": [ 389 | "为了阐明这一贡献,我们列举了AlphaGo Zero在其训练过程或蒙特卡洛树搜索中明确或隐含地使用的领域知识;这些是AlphaGo Zero需要替换以学习不同的(交替马尔可夫)游戏的知识项目。" 390 | ] 391 | }, 392 | { 393 | "cell_type": "markdown", 394 | "metadata": {}, 395 | "source": [ 396 | "1.\tAlphaGo 
Zero获得了完整的游戏规则知识。这些知识在MCTS中使用,模拟由一系列走子产生的局面,并对达到结束状态的任何模拟进行评分。当双方都放弃行棋(pass)时,或在19 · 19 · 2 = 722次走子后,游戏终止。此外,玩家在每个局面提供一组合法的走子。\n", 397 | "2.\tAlphaGo Zero在MCTS模拟和自我对弈训练中使用Tromp-Taylor评分。这是因为如果游戏在领土边界得到解决之前终止,那么人类的分数(中国,日本或者韩国的规则)就没有明确的定义。不过,所有比赛和游戏评估都是使用中国规则进行评分的。\n", 398 | "3.\t描述局面的输入特征被构造为19 × 19图像;即神经网络结构与棋盘的网格结构相匹配。\n", 399 | "4.\t围棋的规则在旋转和反转下是不变的; AlphaGo Zero中已经使用了这种知识,通过增加训练期间的数据集以包括每个局面的旋转和反转,并且在MCTS期间对局面的随机旋转或反转进行采样(参见[`搜索算法`](#搜索算法))。除了贴目以外,围棋的规则也是颜色转换不变的。 这个知识是通过从当前玩家的角度来表示棋盘来使用的(见[`神经网络结构`](#神经网络结构))。\n" 400 | ] 401 | }, 402 | { 403 | "cell_type": "markdown", 404 | "metadata": {}, 405 | "source": [ 406 | "除了上面列出的要点之外,AlphaGo Zero不使用任何形式的领域知识。它只使用其深层神经网络来评估叶节点并选择走子(参见下面的部分)。它不使用任何rollout策略或树策略,并且MCTS不会被任何其他启发式或特定于领域的规则增强。排除非法走子 - 即使是填充玩家自己的眼(以前所有程序中使用的标准启发式)。" 407 | ] 408 | }, 409 | { 410 | "cell_type": "markdown", 411 | "metadata": {}, 412 | "source": [ 413 | "该算法从神经网络的随机初始参数开始。神经网络体系结构(参见[`神经网络结构`](#神经网络结构))基于当前的图像识别技术,并据此选择用于训练的超参数(参见[`自我对弈训练`](#自我对弈训练))。通过高斯过程优化来选择MCTS搜索参数,以便使用在初步训练的神经网络来优化AlphaGo Zero的自我对弈表现。对于较大的运行(40块,40天),使用较小运行(20块,3天)训练的神经网络重新优化MCTS搜索参数。训练算法是在没有人为干预的情况下自主执行的。" 414 | ] 415 | }, 416 | { 417 | "cell_type": "markdown", 418 | "metadata": {}, 419 | "source": [ 420 | "## 自我对弈训练" 421 | ] 422 | }, 423 | { 424 | "cell_type": "markdown", 425 | "metadata": {}, 426 | "source": [ 427 | "AlphaGo Zero的自我对弈训练由三个主要组件组成,所有组件都是异步并行执行的。神经网络参数$θ_i$从最近的自我对弈数据持续优化;持续对AlphaGo Zero棋手 $α_{θ_i}$进行评估;迄今为止表现最好的棋手$α_{θ_*}$被用来产生新的自我对弈数据。" 428 | ] 429 | }, 430 | { 431 | "cell_type": "markdown", 432 | "metadata": {}, 433 | "source": [ 434 | "## 优化" 435 | ] 436 | }, 437 | { 438 | "cell_type": "markdown", 439 | "metadata": { 440 | "collapsed": true 441 | }, 442 | "source": [ 443 | "每个神经网络$f_{θ_i}$在Google Cloud上使用TensorFlow进行优化,包括64个GPU worker和19个CPU参数服务器。batch-size为每个worker 32个,总的mini-batch为2,048个。从最近的50万次自我对弈的所有局面随机抽取每个mini-batch数据。神经网络参数使用动量和学习率退火随机梯度下降法,使用方程1中的损失进行优化。学习率按照扩展数据表3中的标准时间表进行退火。动量参数设置为0.9。交叉熵和均方误差损失的权重相同(这是合理的,因为奖励被归一化到$r ∈ ${−1, +1}),L2正则化参数设为$c = 10^{-4}$。优化过程每1,000个训练步骤产生一个新的检查点。这个检查点由评估者评估,并且可以用来生成下一个batch自我对弈数据,我们将在后面解释。" 444 | ] 445 | }, 446 | { 447 | "cell_type": "markdown", 448 | "metadata": {}, 449 | "source": [ 450 | "## 评估者" 451 | ] 452 | }, 453 | { 454 | "cell_type": "markdown", 455 | "metadata": { 456 | "collapsed": true 457 | }, 458 | "source": [ 459 | "为了确保我们始终生成最好质量的数据,我们将每个新的神经网络检查点与当前最好的网络$f_{θ_*}$进行比较,然后将其用于数据生成。通过使用$f_{θ_i}$评估叶子节点的局面和先验概率(参见[`搜索算法`](#搜索算法))的MCTS搜索$α_{θ_i}$的性能来评估神经网络$f_{θ_i}$。每个评估包括400场比赛,使用1,600个模拟的MCTS来选择每个走子,使用无限小的温度τ → 0(即我们确定性地选择最大访问次数的走子,以提供最强的可能性)。如果新玩家赢得55%以上的比赛(为了避免在单独噪音上选择),那么它将成为最好的玩家$α_{θ_*}$,随后用于生成自我对弈数据,并成为后续比较的基准。" 460 | ] 461 | }, 462 | { 463 | "cell_type": "markdown", 464 | "metadata": {}, 465 | "source": [ 466 | "## 自我对弈 " 467 | ] 468 | }, 469 | { 470 | "cell_type": "markdown", 471 | "metadata": {}, 472 | "source": [ 473 | "由评估者选择的当前最佳玩家$α_{θ_*}$被用于产生数据。在每次迭代中,$α_{θ_*}$进行25,000次自我对弈, MCTS搜索使用1,600次模拟来选择每次走子(每次搜索大约需要0.4s)。对于每场比赛的前30个走子,温度设置为τ = 1;这将根据MCTS中的访问次数按比例选择走子,并确保覆盖不同的局面(模拟不同的走法)。对于游戏的其余部分,使用无限小的温度τ→0。在根节点s0中的先验概率中添加Dirichlet噪声 $P(s,a) = (1 - ϵ)p_a + ϵη_a$,其中η∼Dir(0.03)和ϵ = 0.25,实现额外的探索;这种噪音确保所有的走子都可以被尝试,但是搜索仍然可能下臭棋。为了节省计算,明显会输的游戏将投子认负。选择认输阈值$v_resign$以自动保持误报率(如果AlphaGo没有认输可能赢得的游戏比例)小于5%。为了测量误报(false positives),我们在10%的自我对弈中禁止认输,并下到游戏结束。" 474 | ] 475 | }, 476 | { 477 | "cell_type": "markdown", 478 | "metadata": {}, 479 | "source": [ 480 | "## 监督学习" 481 | ] 482 | }, 483 | { 484 | "cell_type": "markdown", 485 | "metadata": {}, 486 | "source": [ 487 | "为了比较,我们还通过监督学习来训练神经网络参数$θ_{SL}$。神经网络架构与AlphaGo 
Zero相同。从KGS数据集中随机抽取小批量(Mini-batches)的数据(s, π, z),为人类专家走子a设定$π_a= 1$。通过动量和学习率退火的随机梯度下降,使用与方程1中相同的损失对参数进行优化,但是将均方误差系数设为0.01。学习率按照扩展数据表3中的标准时间表退火。动量参数设为0.9,L2正则化参数设为$c = 10^{-4}$。" 488 | ] 489 | }, 490 | { 491 | "cell_type": "markdown", 492 | "metadata": {}, 493 | "source": [ 494 | "通过使用组合的策略和价值网络体系结构,并且通过对价值组件使用较低的权重,可以避免价值的过拟合(在之前的工作中描述的问题)。72小时后,走子预测准确率超过了之前的网络,KGS测试集达到了60.4%;价值预测误差也大大好于以前。验证集由GoKifu的专业游戏数据组成。准确率和均方误差分别在扩展数据表1和扩展数据表2中报告。" 495 | ] 496 | }, 497 | { 498 | "cell_type": "markdown", 499 | "metadata": {}, 500 | "source": [ 501 | "## 搜索算法" 502 | ] 503 | }, 504 | { 505 | "cell_type": "markdown", 506 | "metadata": {}, 507 | "source": [ 508 | "AlphaGo Zero使用AlphaGo Fan和AlphaGo Lee中使用的非常简单的异步策略和价值MCTS算法(APV-MCTS)的变种。" 509 | ] 510 | }, 511 | { 512 | "cell_type": "markdown", 513 | "metadata": {}, 514 | "source": [ 515 | "搜索树中的每个节点都包含所有合法动作a ∈ A(s)的边(s, a)。 每个边缘存储一组统计信息," 516 | ] 517 | }, 518 | { 519 | "cell_type": "markdown", 520 | "metadata": {}, 521 | "source": [ 522 | "{N(s, a), W(s, a), Q(s, a), P(s, a)}," 523 | ] 524 | }, 525 | { 526 | "cell_type": "markdown", 527 | "metadata": {}, 528 | "source": [ 529 | "其中N(s, a)是访问次数,W(s, a)是总行动价值,Q(s, a)是平均行动价值,P(s, a)选择边的先验概率。多个模拟分别在搜索线程上并行执行。该算法通过三个阶段(图2中的a-c)的迭代,然后选择走子(图2中的d)。" 530 | ] 531 | }, 532 | { 533 | "cell_type": "markdown", 534 | "metadata": {}, 535 | "source": [ 536 | "### Select(图2a)" 537 | ] 538 | }, 539 | { 540 | "cell_type": "markdown", 541 | "metadata": {}, 542 | "source": [ 543 | "选择阶段几乎与AlphaGo Fan相同;为了完整性在这里概括一下。每个模拟的第一个入树阶段开始于搜索树的根节点$s_0$,并且在时间步骤L模拟到达叶节点$s_L$时结束。在每个时间步t 2 | # cchess-zero 3 | AlphaZero implemented Chinese chess. AlphaGo Zero / AlphaZero实践项目,实现中国象棋。 4 | 5 | __Author__ chengstone 6 | 7 | __e-Mail__ 69558140@163.com 8 | 9 | 代码详解请参见文内jupyter notebook和↓↓↓ 10 | 11 | 知乎专栏:https://zhuanlan.zhihu.com/p/34433581 12 | 13 | 博客:http://blog.csdn.net/chengcheng1394/article/details/79526474 14 | 15 | 欢迎转发扩散 ^_^ 16 | 17 | 这是一个AlphaZero的实践项目,实现了一个中国象棋程序,使用TensorFlow1.0和Python 3.5开发,还要安装uvloop。 18 | 19 | 因为我的模型训练的不充分,只训练了不到4K次,模型刚刚学会用象和士防守,总之仍然下棋很烂。 20 | 21 | 如果您有条件可以再多训练试试,我自从收到信用卡扣款400美元通知以后就把aws下线了:D 贫穷限制了我的想象力O(∩_∩)O 22 | 23 | 我训练的模型文件下载地址:https://pan.baidu.com/s/1dLvxFFpeWZK-aZ2Koewrvg 24 | 25 | 解压后放到项目根目录下即可,文件夹名叫做gpu_models 26 | 27 | 现在介绍下命令如何使用: 28 | 29 | 命令分为两类,一类是训练,一类是下棋。 30 | 31 | 训练专用: 32 | 33 | - --mode 指定是训练(train)还是下棋(play),默认是训练 34 | - --train_playout 指定MCTS的模拟次数,论文中是1600,我做训练时使用1200 35 | - --batch_size 指定训练数据达到多少时开始训练,默认512 36 | - --search_threads 指定执行MCTS时的线程个数,默认16 37 | - --processor 指定是使用cpu还是gpu,默认是cpu 38 | - --num_gpus 指定gpu的个数,默认是1 39 | - --res_block_nums 指定残差块的层数,论文中是19或39层,我默认是7 40 | 41 | 下棋专用: 42 | 43 | - --ai_count 指定ai的个数,1是人机对战,2是看两个ai下棋 44 | - --ai_function 指定ai的下棋方法,是思考(mcts,会慢),还是直觉(net,下棋快) 45 | - --play_playout 指定ai进行MCTS的模拟次数 46 | - --delay和--end_delay默认就好,两个ai下棋太快,就不知道俩ai怎么下的了:) 47 | - --human_color 指定人类棋手的颜色,w是先手,b是后手 48 | 49 | 训练命令举例: 50 | 51 | python main.py --mode train --train_playout 1200 --batch_size 512 --search_threads 16 --processor gpu --num_gpus 2 --res_block_nums 7 52 | 53 | 下棋命令举例: 54 | 55 | python main.py --mode play --ai_count 1 --ai_function mcts --play_playout 1200 --human_color w 56 | 57 | # 许可 58 | Licensed under the MIT License with the [`996ICU License`](https://github.com/996icu/996.ICU/blob/master/LICENSE). 
59 | -------------------------------------------------------------------------------- /assets/a1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/assets/a1.png -------------------------------------------------------------------------------- /assets/a10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/assets/a10.png -------------------------------------------------------------------------------- /assets/a11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/assets/a11.png -------------------------------------------------------------------------------- /assets/a2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/assets/a2.png -------------------------------------------------------------------------------- /assets/a3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/assets/a3.png -------------------------------------------------------------------------------- /assets/a4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/assets/a4.png -------------------------------------------------------------------------------- /assets/a5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/assets/a5.png -------------------------------------------------------------------------------- /assets/a6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/assets/a6.png -------------------------------------------------------------------------------- /assets/a7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/assets/a7.png -------------------------------------------------------------------------------- /assets/a8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/assets/a8.png -------------------------------------------------------------------------------- /assets/a9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/assets/a9.png -------------------------------------------------------------------------------- /assets/b1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/assets/b1.png 
-------------------------------------------------------------------------------- /assets/b2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/assets/b2.png -------------------------------------------------------------------------------- /assets/b3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/assets/b3.png -------------------------------------------------------------------------------- /assets/b4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/assets/b4.png -------------------------------------------------------------------------------- /assets/b5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/assets/b5.png -------------------------------------------------------------------------------- /assets/b6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/assets/b6.png -------------------------------------------------------------------------------- /assets/b7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/assets/b7.png -------------------------------------------------------------------------------- /assets/b8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/assets/b8.png -------------------------------------------------------------------------------- /assets/b9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/assets/b9.png -------------------------------------------------------------------------------- /assets/c1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/assets/c1.png -------------------------------------------------------------------------------- /assets/c2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/assets/c2.png -------------------------------------------------------------------------------- /assets/c3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/assets/c3.png -------------------------------------------------------------------------------- /chessman/Bing.py: -------------------------------------------------------------------------------- 1 | from ChessPiece import ChessPiece 2 | import sys 3 | 4 | 5 | class Bing(ChessPiece): 6 | 7 | def get_image_file_name(self): 8 | if self.selected: 9 | if self.is_red: 10 | return 
"images/RPS.gif" 11 | else: 12 | return "images/BPS.gif" 13 | else: 14 | if self.is_red: 15 | return "images/RP.gif" 16 | else: 17 | return "images/BP.gif" 18 | 19 | def get_selected_image(self): 20 | if self.is_red: 21 | return "images/RPS.gif" 22 | else: 23 | return "images/BPS.gif" 24 | 25 | def can_move(self, board, dx, dy): 26 | if abs(dx) + abs(dy) != 1: 27 | # print('Too far') 28 | return False 29 | if (self.is_north() and dy == -1) or (self.is_south() and dy==1): 30 | # print('cannot go back') 31 | return False 32 | if dy == 0: 33 | if (self.is_north() and self.y <5) or (self.is_south() and self.y >=5): 34 | # print('behind river') 35 | return False 36 | nx, ny = self.x + dx, self.y + dy 37 | if nx < 0 or nx > 8 or ny < 0 or ny > 9: 38 | return False 39 | if (nx, ny) in board.pieces: 40 | if board.pieces[nx, ny].is_red == self.is_red: 41 | # print('blocked by yourself') 42 | return False 43 | else: 44 | pass 45 | #print 'kill a chessman' 46 | return True 47 | 48 | def __init__(self, x, y, is_red, direction): 49 | ChessPiece.__init__(self, x, y, is_red, direction) 50 | 51 | def display(self): 52 | sys.stdout.write('B') -------------------------------------------------------------------------------- /chessman/Bing.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/chessman/Bing.pyc -------------------------------------------------------------------------------- /chessman/Che.py: -------------------------------------------------------------------------------- 1 | 2 | from ChessPiece import ChessPiece 3 | 4 | class Che(ChessPiece): 5 | 6 | def get_image_file_name(self): 7 | if self.selected: 8 | if self.is_red: 9 | return "images/RRS.gif" 10 | else: 11 | return "images/BRS.gif" 12 | else: 13 | if self.is_red: 14 | return "images/RR.gif" 15 | else: 16 | return "images/BR.gif" 17 | 18 | def get_selected_image(self): 19 | if self.is_red: 20 | return "images/RRS.gif" 21 | else: 22 | return "images/BRS.gif" 23 | 24 | def can_move(self, board, dx, dy): 25 | if dx != 0 and dy != 0: 26 | #print 'no diag' 27 | return False 28 | nx, ny = self.x + dx, self.y + dy 29 | if nx < 0 or nx > 8 or ny < 0 or ny > 9: 30 | return False 31 | if (nx, ny) in board.pieces: 32 | if board.pieces[nx, ny].is_red == self.is_red: 33 | #print 'blocked by yourself' 34 | return False 35 | cnt = self.count_pieces(board, self.x, self.y, dx, dy) 36 | # print 'Che cnt', cnt 37 | if (nx, ny) not in board.pieces: 38 | if cnt!= 0: 39 | #print 'blocked' 40 | return False 41 | else: 42 | if cnt != 0: 43 | #print 'cannot kill' 44 | return False 45 | print ('kill a chessman') 46 | return True 47 | 48 | def __init__(self, x, y, is_red, direction): 49 | ChessPiece.__init__(self, x, y, is_red, direction) 50 | 51 | -------------------------------------------------------------------------------- /chessman/Che.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/chessman/Che.pyc -------------------------------------------------------------------------------- /chessman/Ma.py: -------------------------------------------------------------------------------- 1 | 2 | from ChessPiece import ChessPiece 3 | 4 | 5 | class Ma(ChessPiece): 6 | 7 | def get_image_file_name(self): 8 | if self.selected: 9 | if self.is_red: 10 | return "images/RNS.gif" 11 | else: 12 | return "images/BNS.gif" 
13 | else: 14 | if self.is_red: 15 | return "images/RN.gif" 16 | else: 17 | return "images/BN.gif" 18 | 19 | def get_selected_image(self): 20 | if self.is_red: 21 | return "images/RNS.gif" 22 | else: 23 | return "images/BNS.gif" 24 | 25 | def can_move(self, board, dx, dy): 26 | x, y = self.x, self.y 27 | nx, ny = x+dx, y+dy 28 | if nx < 0 or nx > 8 or ny < 0 or ny > 9: 29 | return False 30 | if dx == 0 or dy == 0: 31 | #print 'no straight' 32 | return False 33 | if abs(dx) + abs(dy) !=3: 34 | #print 'not normal' 35 | return False 36 | if (nx, ny) in board.pieces: 37 | if board.pieces[nx, ny].is_red == self.is_red: 38 | #print 'blocked by yourself' 39 | return False 40 | if (x if abs(dx) ==1 else x+dx/2, y if abs(dy) ==1 else y+ (dy/2)) in board.pieces: 41 | #print 'blocked' 42 | return False 43 | return True 44 | 45 | def __init__(self, x, y, is_red, direction): 46 | ChessPiece.__init__(self, x, y, is_red, direction) 47 | 48 | -------------------------------------------------------------------------------- /chessman/Ma.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/chessman/Ma.pyc -------------------------------------------------------------------------------- /chessman/Pao.py: -------------------------------------------------------------------------------- 1 | 2 | from ChessPiece import ChessPiece 3 | 4 | 5 | class Pao(ChessPiece): 6 | 7 | def get_image_file_name(self): 8 | if self.selected: 9 | if self.is_red: 10 | return "images/RCS.gif" 11 | else: 12 | return "images/BCS.gif" 13 | else: 14 | if self.is_red: 15 | return "images/RC.gif" 16 | else: 17 | return "images/BC.gif" 18 | 19 | def get_selected_image(self): 20 | if self.is_red: 21 | return "images/RCS.gif" 22 | else: 23 | return "images/BCS.gif" 24 | 25 | def can_move(self, board, dx, dy): 26 | if dx != 0 and dy != 0: 27 | #print 'no diag' 28 | return False 29 | nx, ny = self.x + dx, self.y + dy 30 | if nx < 0 or nx > 8 or ny < 0 or ny > 9: 31 | return False 32 | if (nx, ny) in board.pieces: 33 | if board.pieces[nx, ny].is_red == self.is_red: 34 | #print 'blocked by yourself' 35 | return False 36 | cnt = self.count_pieces(board, self.x, self.y, dx, dy) 37 | # print 'Pao cnt',cnt 38 | if (nx, ny) not in board.pieces: 39 | if cnt!= 0: 40 | #print 'blocked' 41 | return False 42 | else: 43 | if cnt != 1: 44 | #print 'cannot kill' 45 | return False 46 | return True 47 | 48 | def __init__(self, x, y, is_red, direction): 49 | ChessPiece.__init__(self, x, y, is_red, direction) 50 | 51 | -------------------------------------------------------------------------------- /chessman/Pao.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/chessman/Pao.pyc -------------------------------------------------------------------------------- /chessman/Shi.py: -------------------------------------------------------------------------------- 1 | __author__ = 'Zhaoliang' 2 | from ChessPiece import ChessPiece 3 | 4 | 5 | class Shi(ChessPiece): 6 | 7 | def get_image_file_name(self): 8 | if self.selected: 9 | if self.is_red: 10 | return "images/RAS.gif" 11 | else: 12 | return "images/BAS.gif" 13 | else: 14 | if self.is_red: 15 | return "images/RA.gif" 16 | else: 17 | return "images/BA.gif" 18 | 19 | def get_selected_image(self): 20 | if self.is_red: 21 | return "images/RAS.gif" 22 | else: 23 | 
return "images/BAS.gif" 24 | 25 | def can_move(self, board, dx, dy): 26 | nx, ny = self.x + dx, self.y + dy 27 | if nx < 0 or nx > 8 or ny < 0 or ny > 9: 28 | return False 29 | if (nx, ny) in board.pieces: 30 | if board.pieces[nx, ny].is_red == self.is_red: 31 | #print 'blocked by yourself' 32 | return False 33 | x, y = self.x, self.y 34 | if not (self.is_north() and 3 <= nx <=5 and 0<= ny <=2) and\ 35 | not (self.is_south() and 3 <= nx <= 5 and 7 <= ny <= 9): 36 | #print 'out of castle' 37 | return False 38 | if self.is_north() and (nx, ny) == (4, 1) or (x,y) == (4,1): 39 | if abs(dx)>1 or abs(dy)>1: 40 | #print 'too far' 41 | return False 42 | if self.is_south() and (nx, ny) == (4, 8) or (x,y) == (4,8): 43 | if abs(dx)>1 or abs(dy)>1: 44 | #print 'too far' 45 | return False 46 | #below modified by Fei Li 47 | if abs(dx) != 1 or abs(dy) != 1: 48 | #print 'no diag' 49 | return False 50 | return True 51 | 52 | def __init__(self, x, y, is_red, direction): 53 | ChessPiece.__init__(self, x, y, is_red, direction) 54 | 55 | -------------------------------------------------------------------------------- /chessman/Shi.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/chessman/Shi.pyc -------------------------------------------------------------------------------- /chessman/Shuai.py: -------------------------------------------------------------------------------- 1 | from ChessPiece import ChessPiece 2 | 3 | 4 | class Shuai(ChessPiece): 5 | 6 | is_king = True 7 | def get_image_file_name(self): 8 | if self.selected: 9 | if self.is_red: 10 | return "images/RKS.gif" 11 | else: 12 | return "images/BKS.gif" 13 | else: 14 | if self.is_red: 15 | return "images/RK.gif" 16 | else: 17 | return "images/BK.gif" 18 | 19 | def get_selected_image(self): 20 | if self.is_red: 21 | return "images/RKS.gif" 22 | else: 23 | return "images/BKS.gif" 24 | 25 | def can_move(self, board, dx, dy): 26 | # print 'king' 27 | nx, ny = self.x + dx, self.y + dy 28 | if nx < 0 or nx > 8 or ny < 0 or ny > 9: 29 | return False 30 | if (nx, ny) in board.pieces: 31 | if board.pieces[nx, ny].is_red == self.is_red: 32 | #print 'blocked by yourself' 33 | return False 34 | if dx == 0 and self.count_pieces(board, self.x, self.y, dx, dy) == 0 and ((nx, ny) in board.pieces) and board.pieces[nx, ny].is_king: 35 | return True 36 | if not (self.is_north() and 3 <= nx <=5 and 0<= ny <=2) and not (self.is_south() and 3 <= nx <= 5 and 7 <= ny <= 9): 37 | # print 'out of castle' 38 | return False 39 | if abs(dx) + abs(dy) !=1: 40 | #print 'too far' 41 | return False 42 | return True 43 | 44 | def __init__(self, x, y, is_red, direction): 45 | ChessPiece.__init__(self, x, y, is_red, direction) 46 | 47 | -------------------------------------------------------------------------------- /chessman/Shuai.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/chessman/Shuai.pyc -------------------------------------------------------------------------------- /chessman/Xiang.py: -------------------------------------------------------------------------------- 1 | from ChessPiece import ChessPiece 2 | 3 | 4 | class Xiang(ChessPiece): 5 | 6 | def get_image_file_name(self): 7 | if self.selected: 8 | if self.is_red: 9 | return "images/RBS.gif" 10 | else: 11 | return "images/BBS.gif" 12 | else: 13 | if 
self.is_red: 14 | return "images/RB.gif" 15 | else: 16 | return "images/BB.gif" 17 | 18 | def get_selected_image(self): 19 | if self.is_red: 20 | return "images/RBS.gif" 21 | else: 22 | return "images/BBS.gif" 23 | 24 | def can_move(self, board, dx, dy): 25 | x,y = self.x, self.y 26 | nx, ny = x + dx, y + dy 27 | if nx < 0 or nx > 8 or ny < 0 or ny > 9: 28 | return False 29 | if (nx, ny) in board.pieces: 30 | if board.pieces[nx, ny].is_red == self.is_red: 31 | #print 'blocked by yourself' 32 | return False 33 | if (self.is_north() and ny > 4) or (self.is_south() and ny <5): 34 | #print 'no river cross' 35 | return False 36 | 37 | if abs(dx)!=2 or abs(dy)!=2: 38 | #print 'not normal' 39 | return False 40 | sx, sy = dx/abs(dx), dy/abs(dy) 41 | if (x+sx, y+sy) in board.pieces: 42 | #print 'blocked' 43 | return False 44 | return True 45 | 46 | def __init__(self, x, y, is_red, direction): 47 | ChessPiece.__init__(self, x, y, is_red, direction) 48 | 49 | -------------------------------------------------------------------------------- /chessman/Xiang.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/chessman/Xiang.pyc -------------------------------------------------------------------------------- /chessman/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'Zhaoliang' 2 | -------------------------------------------------------------------------------- /chessman/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/chessman/__init__.pyc -------------------------------------------------------------------------------- /chessman/__pycache__/Bing.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/chessman/__pycache__/Bing.cpython-35.pyc -------------------------------------------------------------------------------- /chessman/__pycache__/Che.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/chessman/__pycache__/Che.cpython-35.pyc -------------------------------------------------------------------------------- /chessman/__pycache__/Ma.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/chessman/__pycache__/Ma.cpython-35.pyc -------------------------------------------------------------------------------- /chessman/__pycache__/Pao.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/chessman/__pycache__/Pao.cpython-35.pyc -------------------------------------------------------------------------------- /chessman/__pycache__/Shi.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/chessman/__pycache__/Shi.cpython-35.pyc 
-------------------------------------------------------------------------------- /chessman/__pycache__/Shuai.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/chessman/__pycache__/Shuai.cpython-35.pyc -------------------------------------------------------------------------------- /chessman/__pycache__/Xiang.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/chessman/__pycache__/Xiang.cpython-35.pyc -------------------------------------------------------------------------------- /chessman/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/chessman/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /images/BA.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/BA.GIF -------------------------------------------------------------------------------- /images/BAS.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/BAS.GIF -------------------------------------------------------------------------------- /images/BB.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/BB.GIF -------------------------------------------------------------------------------- /images/BBS.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/BBS.GIF -------------------------------------------------------------------------------- /images/BC.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/BC.GIF -------------------------------------------------------------------------------- /images/BCS.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/BCS.GIF -------------------------------------------------------------------------------- /images/BK.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/BK.GIF -------------------------------------------------------------------------------- /images/BKM.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/BKM.GIF -------------------------------------------------------------------------------- /images/BKS.GIF: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/BKS.GIF -------------------------------------------------------------------------------- /images/BN.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/BN.GIF -------------------------------------------------------------------------------- /images/BNS.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/BNS.GIF -------------------------------------------------------------------------------- /images/BP.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/BP.GIF -------------------------------------------------------------------------------- /images/BPS.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/BPS.GIF -------------------------------------------------------------------------------- /images/BR.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/BR.GIF -------------------------------------------------------------------------------- /images/BRS.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/BRS.GIF -------------------------------------------------------------------------------- /images/OOS.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/OOS.GIF -------------------------------------------------------------------------------- /images/RA.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/RA.GIF -------------------------------------------------------------------------------- /images/RAS.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/RAS.GIF -------------------------------------------------------------------------------- /images/RB.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/RB.GIF -------------------------------------------------------------------------------- /images/RBS.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/RBS.GIF -------------------------------------------------------------------------------- /images/RC.GIF: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/RC.GIF -------------------------------------------------------------------------------- /images/RCS.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/RCS.GIF -------------------------------------------------------------------------------- /images/RK.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/RK.GIF -------------------------------------------------------------------------------- /images/RKM.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/RKM.GIF -------------------------------------------------------------------------------- /images/RKS.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/RKS.GIF -------------------------------------------------------------------------------- /images/RN.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/RN.GIF -------------------------------------------------------------------------------- /images/RNS.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/RNS.GIF -------------------------------------------------------------------------------- /images/RP.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/RP.GIF -------------------------------------------------------------------------------- /images/RPS.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/RPS.GIF -------------------------------------------------------------------------------- /images/RR.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/RR.GIF -------------------------------------------------------------------------------- /images/RRS.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/RRS.GIF -------------------------------------------------------------------------------- /images/WHITE.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengstone/cchess-zero/7661eea60404bbf19c490096d66c16b174f81688/images/WHITE.GIF -------------------------------------------------------------------------------- /policy_value_network.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import tensorflow as tf 
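# --- Editor's summary of the network defined below (derived from this file's own code) ---
# input   : `inputs_` placeholder of shape [None, 9, 10, 14] -- one 9x10 plane per
#           feature channel of the Xiangqi board position.
# tower   : 3x3 conv with `filters_size` (128) filters + batch norm, followed by
#           `res_block_nums` residual blocks (default 7).
# policy  : 1x1 conv (2 filters) -> batch norm -> fully connected layer producing
#           `prob_size` = 2086 move logits, trained against the MCTS probabilities `pi_`.
# value   : 1x1 conv (1 filter) -> batch norm -> FC(256, relu) -> FC(1, tanh),
#           trained against the self-play outcome `z_`.
# loss    : softmax cross-entropy (policy) + mean squared error (value) + L2
#           regularization (c_l2 = 1e-4), i.e. the AlphaZero objective
#           l = (z - v)^2 - pi^T * log(p) + c * ||theta||^2.
# training: Momentum SGD (Nesterov, momentum 0.9); gradients are clipped to a
#           global norm of 100 and checked for NaNs before being applied.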
3 | import numpy as np 4 | 5 | import os 6 | 7 | 8 | class policy_value_network(object): 9 | def __init__(self, res_block_nums = 7): 10 | # self.ckpt = os.path.join(os.getcwd(), 'models/best_model.ckpt-13999') # TODO 11 | self.save_dir = "./models" 12 | self.is_logging = True 13 | 14 | """reset TF Graph""" 15 | tf.reset_default_graph() 16 | """Creat a new graph for the network""" 17 | # g = tf.Graph() 18 | 19 | self.sess = tf.Session() 20 | # self.sess = tf.InteractiveSession() 21 | 22 | # Variables 23 | self.filters_size = 128 # or 256 24 | self.prob_size = 2086 25 | self.digest = None 26 | self.training = tf.placeholder(tf.bool, name='training') 27 | self.inputs_ = tf.placeholder(tf.float32, [None, 9, 10, 14], name='inputs') # + 2 # TODO C plain x 2 28 | self.c_l2 = 0.0001 29 | self.momentum = 0.9 30 | self.global_norm = 100 31 | self.learning_rate = tf.placeholder(tf.float32, name='learning_rate') #0.001 #5e-3 #0.05 # 32 | tf.summary.scalar('learning_rate', self.learning_rate) 33 | 34 | # First block 35 | self.pi_ = tf.placeholder(tf.float32, [None, self.prob_size], name='pi') 36 | self.z_ = tf.placeholder(tf.float32, [None, 1], name='z') 37 | 38 | # NWHC format 39 | # batch, 9 * 10, 14 channels 40 | # inputs_ = tf.reshape(self.inputs_, [-1, 9, 10, 14]) 41 | # data_format: A string, one of `channels_last` (default) or `channels_first`. 42 | # The ordering of the dimensions in the inputs. 43 | # `channels_last` corresponds to inputs with shape `(batch, width, height, channels)` 44 | # while `channels_first` corresponds to inputs with shape `(batch, channels, width, height)`. 45 | self.layer = tf.layers.conv2d(self.inputs_, self.filters_size, 3, padding='SAME') # filters 128(or 256) 46 | 47 | self.layer = tf.contrib.layers.batch_norm(self.layer, center=False, epsilon=1e-5, fused=True, 48 | is_training=self.training, activation_fn=tf.nn.relu) # epsilon = 0.25 49 | 50 | # residual_block 51 | with tf.name_scope("residual_block"): 52 | for _ in range(res_block_nums): 53 | self.layer = self.residual_block(self.layer) 54 | 55 | # policy_head 56 | with tf.name_scope("policy_head"): 57 | self.policy_head = tf.layers.conv2d(self.layer, 2, 1, padding='SAME') 58 | self.policy_head = tf.contrib.layers.batch_norm(self.policy_head, center=False, epsilon=1e-5, fused=True, 59 | is_training=self.training, activation_fn=tf.nn.relu) 60 | 61 | # print(self.policy_head.shape) # (?, 9, 10, 2) 62 | self.policy_head = tf.reshape(self.policy_head, [-1, 9 * 10 * 2]) 63 | self.policy_head = tf.contrib.layers.fully_connected(self.policy_head, self.prob_size, activation_fn=None) 64 | # self.prediction = tf.nn.softmax(self.policy_head) 65 | 66 | # value_head 67 | with tf.name_scope("value_head"): 68 | self.value_head = tf.layers.conv2d(self.layer, 1, 1, padding='SAME') 69 | self.value_head = tf.contrib.layers.batch_norm(self.value_head, center=False, epsilon=1e-5, fused=True, 70 | is_training=self.training, activation_fn=tf.nn.relu) 71 | # print(self.value_head.shape) # (?, 9, 10, 1) 72 | self.value_head = tf.reshape(self.value_head, [-1, 9 * 10 * 1]) 73 | self.value_head = tf.contrib.layers.fully_connected(self.value_head, 256, activation_fn=tf.nn.relu) 74 | self.value_head = tf.contrib.layers.fully_connected(self.value_head, 1, activation_fn=tf.nn.tanh) 75 | 76 | # loss 77 | with tf.name_scope("loss"): 78 | self.policy_loss = tf.nn.softmax_cross_entropy_with_logits(labels=self.pi_, logits=self.policy_head) 79 | self.policy_loss = tf.reduce_mean(self.policy_loss) 80 | 81 | # self.value_loss = 
tf.squared_difference(self.z_, self.value_head) 82 | self.value_loss = tf.losses.mean_squared_error(labels=self.z_, predictions=self.value_head) 83 | self.value_loss = tf.reduce_mean(self.value_loss) 84 | tf.summary.scalar('mse_loss', self.value_loss) 85 | 86 | regularizer = tf.contrib.layers.l2_regularizer(scale=self.c_l2) 87 | regular_variables = tf.trainable_variables() 88 | self.l2_loss = tf.contrib.layers.apply_regularization(regularizer, regular_variables) 89 | 90 | # self.loss = self.value_loss - self.policy_loss + self.l2_loss 91 | self.loss = self.value_loss + self.policy_loss + self.l2_loss 92 | tf.summary.scalar('loss', self.loss) 93 | 94 | # train_op = tf.train.AdamOptimizer(self.learning_rate).minimize(self.loss) 95 | self.global_step = tf.Variable(0, name="global_step", trainable=False) 96 | # optimizer = tf.train.AdamOptimizer(self.learning_rate) 97 | # gradients = optimizer.compute_gradients(self.loss) 98 | # train_op = optimizer.apply_gradients(gradients, global_step=global_step) 99 | 100 | # 优化损失 101 | optimizer = tf.train.MomentumOptimizer( 102 | learning_rate=self.learning_rate, momentum=self.momentum, use_nesterov=True) 103 | 104 | # self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 105 | # with tf.control_dependencies(self.update_ops): 106 | # self.train_op = optimizer.minimize(self.loss, global_step=self.global_step) 107 | 108 | # Accuracy 109 | correct_prediction = tf.equal(tf.argmax(self.policy_head, 1), tf.argmax(self.pi_, 1)) 110 | correct_prediction = tf.cast(correct_prediction, tf.float32) 111 | self.accuracy = tf.reduce_mean(correct_prediction, name='accuracy') 112 | tf.summary.scalar('move_accuracy', self.accuracy) 113 | 114 | # grads = self.average_gradients(tower_grads) 115 | grads = optimizer.compute_gradients(self.loss) 116 | # defensive step 2 to clip norm 117 | clipped_grads, self.norm = tf.clip_by_global_norm( 118 | [g for g, _ in grads], self.global_norm) 119 | 120 | # defensive step 3 check NaN 121 | # See: https://stackoverflow.com/questions/40701712/how-to-check-nan-in-gradients-in-tensorflow-when-updating 122 | grad_check = [tf.check_numerics(g, message='NaN Found!') for g in clipped_grads] 123 | with tf.control_dependencies(grad_check): 124 | self.train_op = optimizer.apply_gradients( 125 | zip(clipped_grads, [v for _, v in grads]), 126 | global_step=self.global_step, name='train_step') 127 | 128 | if self.is_logging: 129 | for grad, var in grads: 130 | if grad is not None: 131 | tf.summary.histogram(var.op.name + '/gradients', grad) 132 | for var in tf.trainable_variables(): 133 | tf.summary.histogram(var.op.name, var) 134 | 135 | self.summaries_op = tf.summary.merge_all() 136 | 137 | # Train Summaries 138 | self.train_writer = tf.summary.FileWriter( 139 | os.path.join(os.getcwd(), "cchesslogs/train"), self.sess.graph) 140 | 141 | # Test summaries 142 | self.test_writer = tf.summary.FileWriter( 143 | os.path.join(os.getcwd(), "cchesslogs/test"), self.sess.graph) 144 | 145 | self.sess.run(tf.global_variables_initializer()) 146 | # self.sess.run(tf.local_variables_initializer()) 147 | # self.sess.run(tf.initialize_all_variables()) 148 | self.saver = tf.train.Saver() 149 | self.train_restore() 150 | 151 | def residual_block(self, in_layer): 152 | orig = tf.identity(in_layer) 153 | 154 | layer = tf.layers.conv2d(in_layer, self.filters_size, 3, padding='SAME') # filters 128(or 256) 155 | layer = tf.contrib.layers.batch_norm(layer, center=False, epsilon=1e-5, fused=True, 156 | is_training=self.training, activation_fn=tf.nn.relu) 157 | 158 | 
layer = tf.layers.conv2d(layer, self.filters_size, 3, padding='SAME') # filters 128(or 256) 159 | layer = tf.contrib.layers.batch_norm(layer, center=False, epsilon=1e-5, fused=True, is_training=self.training) 160 | out = tf.nn.relu(tf.add(orig, layer)) 161 | 162 | return out 163 | 164 | def train_restore(self): 165 | if not os.path.isdir(self.save_dir): 166 | os.mkdir(self.save_dir) 167 | checkpoint = tf.train.get_checkpoint_state(self.save_dir) 168 | if checkpoint and checkpoint.model_checkpoint_path: 169 | # self.saver.restore(self.sess, checkpoint.model_checkpoint_path) 170 | self.saver.restore(self.sess, tf.train.latest_checkpoint(self.save_dir)) 171 | print("Successfully loaded:", tf.train.latest_checkpoint(self.save_dir)) 172 | # print("Successfully loaded:", checkpoint.model_checkpoint_path) 173 | else: 174 | print("Could not find old network weights") 175 | 176 | def restore(self, file): 177 | print("Restoring from {0}".format(file)) 178 | self.saver.restore(self.sess, file) # self.ckpt 179 | 180 | def save(self, in_global_step): 181 | # save_path = self.saver.save(self.sess, path, global_step=self.global_step) 182 | save_path = self.saver.save(self.sess, os.path.join(self.save_dir, 'best_model.ckpt'), 183 | global_step=in_global_step) #self.global_step 184 | print("Model saved in file: {}".format(save_path)) 185 | 186 | def train_step(self, positions, probs, winners, learning_rate): 187 | feed_dict = { 188 | self.inputs_: positions, 189 | self.training: True, 190 | self.learning_rate: learning_rate, 191 | self.pi_: probs, 192 | self.z_: winners 193 | } 194 | 195 | _, accuracy, loss, global_step, summary = self.sess.run([self.train_op, self.accuracy, self.loss, self.global_step, self.summaries_op], feed_dict=feed_dict) 196 | self.train_writer.add_summary(summary, global_step) 197 | # print(accuracy) 198 | # print(loss) 199 | return accuracy, loss, global_step 200 | 201 | #@profile 202 | def forward(self, positions): # , probs, winners 203 | feed_dict = { 204 | self.inputs_: positions, 205 | self.training: False 206 | } 207 | # , 208 | # self.pi_: probs, 209 | # self.z_: winners 210 | action_probs, value = self.sess.run([self.policy_head, self.value_head], feed_dict=feed_dict) # self.prediction 211 | # print(action_probs.shape) 212 | # print(value.shape) 213 | 214 | return action_probs, value 215 | # return action_probs, value -------------------------------------------------------------------------------- /policy_value_network_gpus.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import tensorflow as tf 3 | import numpy as np 4 | 5 | import os 6 | 7 | PS_OPS = ['Variable', 'VariableV2', 'AutoReloadVariable'] 8 | 9 | 10 | class policy_value_network_gpus(object): 11 | def __init__(self, num_gpus = 1, res_block_nums = 7): 12 | # self.ckpt = os.path.join(os.getcwd(), 'models/best_model.ckpt-13999') # TODO 13 | self.num_gpus = num_gpus 14 | self.save_dir = "./gpu_models" 15 | self.is_logging = True 16 | self.res_block_nums = res_block_nums 17 | 18 | """reset TF Graph""" 19 | tf.reset_default_graph() 20 | """Creat a new graph for the network""" 21 | # g = tf.Graph() 22 | 23 | config = tf.ConfigProto( 24 | inter_op_parallelism_threads=4, 25 | intra_op_parallelism_threads=4) 26 | config.gpu_options.allow_growth = True 27 | config.allow_soft_placement = True 28 | """Assign a Session that excute the network""" 29 | # config.gpu_options.per_process_gpu_memory_fraction = 0.75 30 | # self.sess = tf.Session(config=config, graph=g) 31 | 32 
| # gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.75) 33 | # config = tf.ConfigProto(gpu_options=gpu_options) 34 | self.sess = tf.Session(config=config) 35 | # self.sess = tf.InteractiveSession() 36 | 37 | with tf.device('/cpu:0'): 38 | # Variables 39 | self.filters_size = 128 # or 256 40 | self.prob_size = 2086 41 | self.digest = None 42 | self.training = tf.placeholder(tf.bool, name='training') 43 | self.inputs_ = tf.placeholder(tf.float32, [None, 9, 10, 14], name='inputs') # + 2 # TODO C plain x 2 44 | self.c_l2 = 0.0001 45 | self.momentum = 0.9 46 | self.global_norm = 100 47 | self.learning_rate = tf.placeholder(tf.float32, name='learning_rate') #0.001 #5e-3 #0.05 # 48 | 49 | self.global_step = tf.Variable(0, name="global_step", trainable=False) 50 | # self.learning_rate = tf.maximum(tf.train.exponential_decay( 51 | # 0.001, self.global_step, 1e3, 0.66), 1e-5) 52 | # self.learning_rate = tf.Variable(self.hps.lrn_rate, dtype=tf.float32, trainable=False) 53 | tf.summary.scalar('learning_rate', self.learning_rate) 54 | 55 | # 优化损失 56 | optimizer = tf.train.MomentumOptimizer( 57 | learning_rate=self.learning_rate, momentum=self.momentum, use_nesterov=True) # , use_locking=True 58 | # optimizer = tf.train.AdamOptimizer(self.learning_rate) 59 | 60 | # First block 61 | self.pi_ = tf.placeholder(tf.float32, [None, self.prob_size], name='pi') 62 | self.z_ = tf.placeholder(tf.float32, [None, 1], name='z') 63 | 64 | # batch_queue = tf.contrib.slim.prefetch_queue.prefetch_queue([self.inputs_, self.pi_, self.z_], capacity=3 * self.num_gpus) 65 | 66 | inputs_batches = tf.split(self.inputs_, self.num_gpus, axis=0) 67 | pi_batches = tf.split(self.pi_, self.num_gpus, axis=0) 68 | z_batches = tf.split(self.z_, self.num_gpus, axis=0) 69 | 70 | 71 | tower_grads = [None] * self.num_gpus 72 | 73 | self.loss = 0 74 | self.accuracy = 0 75 | self.policy_head = [] 76 | self.value_head = [] 77 | 78 | with tf.variable_scope(tf.get_variable_scope()): 79 | """Build the core model within the graph.""" 80 | for i in range(self.num_gpus): 81 | with tf.device(self.assign_to_device('/gpu:{}'.format(i), ps_device='/cpu:0')): #tf.device('/gpu:{i}'): 82 | with tf.name_scope('TOWER_{}'.format(i)) as scope: 83 | inputs_batch, pi_batch, z_batch = inputs_batches[i], pi_batches[i], z_batches[i] # batch_queue.dequeue() # 84 | # NWHC format 85 | # batch, 9 * 10, 14 channels 86 | # inputs_ = tf.reshape(self.inputs_, [-1, 9, 10, 14]) 87 | loss = self.tower_loss(inputs_batch, pi_batch, z_batch, i) 88 | # reuse variable happens here 89 | tf.get_variable_scope().reuse_variables() 90 | grad = optimizer.compute_gradients(loss) 91 | tower_grads[i] = grad 92 | 93 | self.loss /= self.num_gpus 94 | self.accuracy /= self.num_gpus 95 | grads = self.average_gradients(tower_grads) 96 | # defensive step 2 to clip norm 97 | clipped_grads, self.norm = tf.clip_by_global_norm( 98 | [g for g, _ in grads], self.global_norm) 99 | 100 | # defensive step 3 check NaN 101 | # See: https://stackoverflow.com/questions/40701712/how-to-check-nan-in-gradients-in-tensorflow-when-updating 102 | grad_check = [tf.check_numerics(g, message='NaN Found!') for g in clipped_grads] 103 | with tf.control_dependencies(grad_check): 104 | self.train_op = optimizer.apply_gradients( 105 | zip(clipped_grads, [v for _, v in grads]), 106 | global_step=self.global_step, name='train_step') 107 | 108 | if self.is_logging: 109 | for grad, var in grads: 110 | if grad is not None: 111 | tf.summary.histogram(var.op.name + '/gradients', grad) 112 | for var in 
tf.trainable_variables(): 113 | tf.summary.histogram(var.op.name, var) 114 | 115 | self.summaries_op = tf.summary.merge_all() 116 | # Train Summaries 117 | self.train_writer = tf.summary.FileWriter( 118 | os.path.join(os.getcwd(), "cchesslogs/train"), self.sess.graph) 119 | 120 | # Test summaries 121 | self.test_writer = tf.summary.FileWriter( 122 | os.path.join(os.getcwd(), "cchesslogs/test"), self.sess.graph) 123 | 124 | self.sess.run(tf.global_variables_initializer()) 125 | # self.sess.run(tf.local_variables_initializer()) 126 | # self.sess.run(tf.initialize_all_variables()) 127 | self.saver = tf.train.Saver() 128 | self.train_restore() 129 | 130 | def tower_loss(self, inputs_batch, pi_batch, z_batch, i): 131 | with tf.variable_scope('init'): 132 | layer = tf.layers.conv2d(inputs_batch, self.filters_size, 3, padding='SAME') # filters 128(or 256) 133 | 134 | layer = tf.contrib.layers.batch_norm(layer, center=False, epsilon=1e-5, fused=True, 135 | is_training=self.training, activation_fn=tf.nn.relu) # epsilon = 0.25 136 | 137 | # residual_block 138 | with tf.variable_scope("residual_block"): 139 | for _ in range(self.res_block_nums): 140 | layer = self.residual_block(layer) 141 | 142 | # policy_head 143 | with tf.variable_scope("policy_head"): 144 | policy_head = tf.layers.conv2d(layer, 2, 1, padding='SAME') 145 | policy_head = tf.contrib.layers.batch_norm(policy_head, center=False, epsilon=1e-5, fused=True, 146 | is_training=self.training, activation_fn=tf.nn.relu) 147 | 148 | # print(self.policy_head.shape) # (?, 9, 10, 2) 149 | policy_head = tf.reshape(policy_head, [-1, 9 * 10 * 2]) 150 | policy_head = tf.contrib.layers.fully_connected(policy_head, self.prob_size, activation_fn=None) 151 | # prediction = tf.nn.softmax(policy_head) 152 | self.policy_head.append(policy_head) #prediction 153 | 154 | # value_head 155 | with tf.variable_scope("value_head"): 156 | value_head = tf.layers.conv2d(layer, 1, 1, padding='SAME') 157 | value_head = tf.contrib.layers.batch_norm(value_head, center=False, epsilon=1e-5, fused=True, 158 | is_training=self.training, activation_fn=tf.nn.relu) 159 | # print(self.value_head.shape) # (?, 9, 10, 1) 160 | value_head = tf.reshape(value_head, [-1, 9 * 10 * 1]) 161 | value_head = tf.contrib.layers.fully_connected(value_head, 256, activation_fn=tf.nn.relu) 162 | value_head = tf.contrib.layers.fully_connected(value_head, 1, activation_fn=tf.nn.tanh) 163 | self.value_head.append(value_head) 164 | 165 | # loss 166 | with tf.variable_scope("loss"): 167 | policy_loss = tf.nn.softmax_cross_entropy_with_logits(labels=pi_batch, logits=policy_head) #self.pi_ 168 | policy_loss = tf.reduce_mean(policy_loss) 169 | 170 | # self.value_loss = tf.squared_difference(self.z_, self.value_head) 171 | value_loss = tf.losses.mean_squared_error(labels=z_batch, predictions=value_head) #self.z_ 172 | value_loss = tf.reduce_mean(value_loss) 173 | tf.summary.scalar('mse_tower_{}'.format(i), value_loss) 174 | 175 | regularizer = tf.contrib.layers.l2_regularizer(scale=self.c_l2) 176 | regular_variables = tf.trainable_variables() 177 | l2_loss = tf.contrib.layers.apply_regularization(regularizer, regular_variables) 178 | 179 | # self.loss = self.value_loss - self.policy_loss + self.l2_loss 180 | loss = value_loss + policy_loss + l2_loss 181 | self.loss += loss 182 | tf.summary.scalar('loss_tower_{}'.format(i), loss) 183 | # train_op = tf.train.AdamOptimizer(self.learning_rate).minimize(self.loss) 184 | # self.global_step = tf.Variable(0, name="global_step", trainable=False) 185 | # optimizer = 
tf.train.AdamOptimizer(self.learning_rate) 186 | # gradients = optimizer.compute_gradients(self.loss) 187 | # train_op = optimizer.apply_gradients(gradients, global_step=global_step) 188 | 189 | # self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 190 | # with tf.control_dependencies(self.update_ops): 191 | # self.train_op = optimizer.minimize(self.loss, global_step=self.global_step) 192 | with tf.variable_scope("accuracy"): 193 | # Accuracy 194 | correct_prediction = tf.equal(tf.argmax(policy_head, 1), tf.argmax(pi_batch, 1)) #self.pi_ 195 | correct_prediction = tf.cast(correct_prediction, tf.float32) 196 | accuracy = tf.reduce_mean(correct_prediction, name='accuracy') 197 | self.accuracy += accuracy 198 | tf.summary.scalar('move_accuracy_tower_{}'.format(i), accuracy) 199 | return loss 200 | 201 | 202 | # By default, all variables will be placed on '/gpu:0' 203 | # So we need a custom device function, to assign all variables to '/cpu:0' 204 | # Note: If GPUs are peered, '/gpu:0' can be a faster option 205 | 206 | def assign_to_device(self, device, ps_device='/cpu:0'): 207 | def _assign(op): 208 | node_def = op if isinstance(op, tf.NodeDef) else op.node_def 209 | if node_def.op in PS_OPS: 210 | return "/" + ps_device 211 | else: 212 | return device 213 | 214 | return _assign 215 | 216 | def average_gradients(self, tower_grads): 217 | """Calculate the average gradient for each shared variable across all towers. 218 | Note that this function provides a synchronization point across all towers. 219 | Args: 220 | tower_grads: List of lists of (gradient, variable) tuples. The outer list 221 | is over individual gradients. The inner list is over the gradient 222 | calculation for each tower. 223 | Returns: 224 | List of pairs of (gradient, variable) where the gradient has been averaged 225 | across all towers. 226 | """ 227 | average_grads = [] 228 | for grad_and_vars in zip(*tower_grads): 229 | # Note that each grad_and_vars looks like the following: 230 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) 231 | grads = [] 232 | for g, var in grad_and_vars: 233 | # Add 0 dimension to the gradients to represent the tower. 234 | # print('Network variables: {var.name}') 235 | expanded_g = tf.expand_dims(g, 0) 236 | 237 | # Append on a 'tower' dimension which we will average over below. 238 | grads.append(expanded_g) 239 | 240 | # Average over the 'tower' dimension. 241 | grad = tf.concat(axis=0, values=grads) 242 | grad = tf.reduce_mean(grad, 0) 243 | 244 | # Keep in mind that the Variables are redundant because they are shared 245 | # across towers. So .. we will just return the first tower's pointer to 246 | # the Variable. 
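            # Illustrative example (not taken from the original source): with two
            # towers sharing a single variable w, tower_grads == [[(g0, w)], [(g1, w)]],
            # so zip(*tower_grads) yields ((g0, w), (g1, w)). The code above stacks
            # g0 and g1 along a new leading "tower" axis and reduces it with the mean,
            # i.e. grad == (g0 + g1) / 2, and the lines below pair that averaged
            # gradient with the first tower's copy of w.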
247 | v = grad_and_vars[0][1] 248 | grad_and_var = (grad, v) 249 | average_grads.append(grad_and_var) 250 | return average_grads 251 | 252 | def residual_block(self, in_layer): 253 | orig = tf.identity(in_layer) 254 | 255 | layer = tf.layers.conv2d(in_layer, self.filters_size, 3, padding='SAME') # filters 128(or 256) 256 | layer = tf.contrib.layers.batch_norm(layer, center=False, epsilon=1e-5, fused=True, 257 | is_training=self.training, activation_fn=tf.nn.relu) 258 | 259 | layer = tf.layers.conv2d(layer, self.filters_size, 3, padding='SAME') # filters 128(or 256) 260 | layer = tf.contrib.layers.batch_norm(layer, center=False, epsilon=1e-5, fused=True, is_training=self.training) 261 | out = tf.nn.relu(tf.add(orig, layer)) 262 | 263 | return out 264 | 265 | def train_restore(self): 266 | if not os.path.isdir(self.save_dir): 267 | os.mkdir(self.save_dir) 268 | checkpoint = tf.train.get_checkpoint_state(self.save_dir) 269 | if checkpoint and checkpoint.model_checkpoint_path: 270 | # self.saver.restore(self.sess, checkpoint.model_checkpoint_path) 271 | self.saver.restore(self.sess, tf.train.latest_checkpoint(self.save_dir)) 272 | print("Successfully loaded:", tf.train.latest_checkpoint(self.save_dir)) 273 | # print("Successfully loaded:", checkpoint.model_checkpoint_path) 274 | else: 275 | print("Could not find old network weights") 276 | 277 | def restore(self, file): 278 | print("Restoring from {0}".format(file)) 279 | self.saver.restore(self.sess, file) # self.ckpt 280 | 281 | def save(self, in_global_step): 282 | # save_path = self.saver.save(self.sess, path, global_step=self.global_step) 283 | save_path = self.saver.save(self.sess, os.path.join(self.save_dir, 'best_model.ckpt'), 284 | global_step=in_global_step) #self.global_step 285 | print("Model saved in file: {}".format(save_path)) 286 | 287 | def train_step(self, positions, probs, winners, learning_rate): 288 | feed_dict = { 289 | self.inputs_: positions, 290 | self.training: True, 291 | self.learning_rate: learning_rate, 292 | self.pi_: probs, 293 | self.z_: winners 294 | } 295 | 296 | 297 | # try: 298 | _, accuracy, loss, global_step, summary = self.sess.run([self.train_op, self.accuracy, self.loss, self.global_step, self.summaries_op], feed_dict=feed_dict) 299 | self.train_writer.add_summary(summary, global_step) 300 | # print(accuracy) 301 | # print(loss) 302 | return accuracy, loss, global_step 303 | # except tf.errors.InvalidArgumentError: 304 | # print('Contains NaN gradients.') 305 | # continue 306 | 307 | #@profile 308 | def forward(self, positions): # , probs, winners 309 | # print("positions.shape : ", positions.shape) 310 | positions = np.array(positions) 311 | batch_n = positions.shape[0] // self.num_gpus 312 | alone = positions.shape[0] % self.num_gpus 313 | 314 | if alone != 0: 315 | if(positions.shape[0] != 1): 316 | feed_dict = { 317 | self.inputs_: positions[:positions.shape[0] - alone], 318 | self.training: False 319 | } 320 | action_probs, value = self.sess.run([self.policy_head, self.value_head], feed_dict=feed_dict) 321 | action_probs, value = np.vstack(action_probs), np.vstack(value) 322 | 323 | new_positions = positions[positions.shape[0] - alone:] 324 | pos_lst = [] 325 | while len(pos_lst) == 0 or (np.array(pos_lst).shape[0] * np.array(pos_lst).shape[1]) % self.num_gpus != 0: 326 | pos_lst.append(new_positions) 327 | 328 | if(len(pos_lst) != 0): 329 | shape = np.array(pos_lst).shape 330 | pos_lst = np.array(pos_lst).reshape([shape[0] * shape[1], 9, 10, 14]) 331 | 332 | feed_dict = { 333 | self.inputs_: 
pos_lst, 334 | self.training: False 335 | } 336 | action_probs_2, value_2 = self.sess.run([self.policy_head, self.value_head], feed_dict=feed_dict) 337 | # print("action_probs_2.shape : ", np.array(action_probs_2).shape) 338 | # print("value_2.shape : ", np.array(value_2).shape) 339 | action_probs_2, value_2 = action_probs_2[0], value_2[0] 340 | # print("------------------------") 341 | # print("action_probs_2.shape : ", np.array(action_probs_2).shape) 342 | # print("value_2.shape : ", np.array(value_2).shape) 343 | 344 | if(positions.shape[0] != 1): 345 | action_probs = np.concatenate((action_probs, action_probs_2),axis=0) 346 | value = np.concatenate((value, value_2),axis=0) 347 | 348 | # print("action_probs.shape : ", np.array(action_probs).shape) 349 | # print("value.shape : ", np.array(value).shape) 350 | return action_probs, value 351 | else: 352 | return action_probs_2, value_2 353 | else: 354 | feed_dict = { 355 | self.inputs_: positions, 356 | self.training: False 357 | } 358 | action_probs, value = self.sess.run([self.policy_head, self.value_head], feed_dict=feed_dict) 359 | # print("np.vstack(action_probs) shape : ", np.vstack(action_probs).shape) 360 | # print("np.vstack(value) shape : ", np.vstack(value).shape) 361 | 362 | return np.vstack(action_probs), np.vstack(value) 363 | # feed_dict = { 364 | # self.inputs_: positions if len(pos_lst) == 0 else pos_lst, 365 | # self.training: False 366 | # } 367 | 368 | # , 369 | # self.pi_: probs, 370 | # self.z_: winners 371 | 372 | # action_probs, value = self.sess.run([self.policy_head, self.value_head], feed_dict=feed_dict) 373 | # print(action_probs.shape) 374 | # print(value.shape) 375 | 376 | # with multi-gpu, porbs and values are separated in each outputs 377 | # so vstack will merge them together. 
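        # Illustrative shapes (assumed, not stated in the original source): with
        # num_gpus == 2 and a batch of 8 positions, self.policy_head and
        # self.value_head are Python lists holding one tensor per tower, so
        # sess.run returns two (4, 2086) policy arrays and two (4, 1) value
        # arrays; np.vstack concatenates them back into a single (8, 2086) and
        # (8, 1) batch in the original input order.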
378 | 379 | # return np.vstack(action_probs), np.vstack(value) 380 | # return action_probs, value -------------------------------------------------------------------------------- /policy_value_network_gpus_tf2.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import tensorflow as tf 3 | import numpy as np 4 | from tensorflow.python.ops import summary_ops_v2 5 | import os 6 | 7 | PS_OPS = ['Variable', 'VariableV2', 'AutoReloadVariable'] 8 | 9 | # pip install tf-nightly-gpu-2.0-preview 10 | # require compute capabilities >= 3.5 11 | # cuda 10 12 | 13 | class policy_value_network_gpus(object): 14 | def __init__(self, learning_rate_fn, res_block_nums = 7): 15 | # self.ckpt = os.path.join(os.getcwd(), 'models/best_model.ckpt-13999') # TODO 16 | self.save_dir = "./models" 17 | self.is_logging = True 18 | 19 | if tf.io.gfile.exists(self.save_dir): 20 | # print('Removing existing model dir: {}'.format(MODEL_DIR)) 21 | # tf.io.gfile.rmtree(MODEL_DIR) 22 | pass 23 | else: 24 | tf.io.gfile.makedirs(self.save_dir) 25 | 26 | train_dir = os.path.join(self.save_dir, 'summaries', 'train') 27 | test_dir = os.path.join(self.save_dir, 'summaries', 'eval') 28 | 29 | self.train_summary_writer = summary_ops_v2.create_file_writer(train_dir, flush_millis=10000) 30 | self.test_summary_writer = summary_ops_v2.create_file_writer(test_dir, flush_millis=10000, name='test') 31 | 32 | self.strategy = tf.distribute.MirroredStrategy() 33 | print ('Number of devices: {}'.format(self.strategy.num_replicas_in_sync)) 34 | 35 | self.distributed_train = lambda it: self.strategy.experimental_run(self.train_step, it) 36 | self.distributed_train = tf.function(self.distributed_train) 37 | 38 | with tf.device('/cpu:0'): 39 | self.global_step = tf.Variable(0, name="global_step", trainable=False) 40 | 41 | with self.strategy.scope(): 42 | 43 | # Variables 44 | self.filters_size = 128 # or 256 45 | self.prob_size = 2086 46 | self.digest = None 47 | 48 | self.inputs_ = tf.keras.layers.Input([9, 10, 14], dtype='float32', name='inputs') # TODO C plain x 2 49 | self.c_l2 = 0.0001 50 | self.momentum = 0.9 51 | self.global_norm = 100 52 | 53 | self.layer = tf.keras.layers.Conv2D(kernel_size=3, filters=self.filters_size, padding='same')(self.inputs_) 54 | self.layer = tf.keras.layers.BatchNormalization(epsilon=1e-5, fused=True)(self.layer) 55 | self.layer = tf.keras.layers.ReLU()(self.layer) 56 | 57 | # residual_block 58 | with tf.name_scope("residual_block"): 59 | for _ in range(res_block_nums): 60 | self.layer = self.residual_block(self.layer) 61 | 62 | # policy_head 63 | with tf.name_scope("policy_head"): 64 | self.policy_head = tf.keras.layers.Conv2D(filters=2, kernel_size=1, padding='same')(self.layer) 65 | self.policy_head = tf.keras.layers.BatchNormalization(epsilon=1e-5, fused=True)(self.policy_head) 66 | self.policy_head = tf.keras.layers.ReLU()(self.policy_head) 67 | 68 | self.policy_head = tf.keras.layers.Reshape([9 * 10 * 2])(self.policy_head) 69 | self.policy_head = tf.keras.layers.Dense(self.prob_size)(self.policy_head) 70 | 71 | # value_head 72 | with tf.name_scope("value_head"): 73 | self.value_head = tf.keras.layers.Conv2D(filters=1, kernel_size=1, padding='same')(self.layer) 74 | self.value_head = tf.keras.layers.BatchNormalization(epsilon=1e-5, fused=True)( 75 | self.value_head) 76 | self.value_head = tf.keras.layers.ReLU()(self.value_head) 77 | 78 | self.value_head = tf.keras.layers.Reshape([9 * 10 * 1])(self.value_head) 79 | self.value_head = 
tf.keras.layers.Dense(256, activation='relu')(self.value_head) 80 | self.value_head = tf.keras.layers.Dense(1, activation='tanh')(self.value_head) 81 | 82 | self.model = tf.keras.Model( 83 | inputs=[self.inputs_], 84 | outputs=[self.policy_head, self.value_head]) 85 | 86 | self.model.summary() 87 | 88 | 89 | # 优化损失 90 | self.optimizer = tf.compat.v1.train.MomentumOptimizer( 91 | learning_rate=learning_rate_fn, momentum=self.momentum, use_nesterov=True) 92 | 93 | self.CategoricalCrossentropyLoss = tf.keras.losses.CategoricalCrossentropy(from_logits=True) 94 | self.MSE = tf.keras.losses.MeanSquaredError() 95 | self.ComputeMetrics = tf.keras.metrics.CategoricalAccuracy() 96 | self.avg_loss = tf.keras.metrics.Mean('loss', dtype=tf.float32) 97 | 98 | # self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 99 | # with tf.control_dependencies(self.update_ops): 100 | # self.train_op = optimizer.minimize(self.loss, global_step=self.global_step) 101 | 102 | self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoints') 103 | self.checkpoint_prefix = os.path.join(self.checkpoint_dir, 'ckpt') 104 | self.checkpoint = tf.train.Checkpoint(model=self.model, optimizer=self.optimizer) 105 | 106 | # Restore variables on creation if a checkpoint exists. 107 | self.checkpoint.restore(tf.train.latest_checkpoint(self.checkpoint_dir)) 108 | 109 | 110 | def residual_block(self, in_layer): 111 | orig = tf.convert_to_tensor(in_layer) # tf.identity(in_layer) 112 | layer = tf.keras.layers.Conv2D(kernel_size=3, filters=self.filters_size, padding='same')(in_layer) 113 | layer = tf.keras.layers.BatchNormalization(epsilon=1e-5, fused=True)(layer) 114 | layer = tf.keras.layers.ReLU()(layer) 115 | 116 | layer = tf.keras.layers.Conv2D(kernel_size=3, filters=self.filters_size, padding='same')(layer) 117 | layer = tf.keras.layers.BatchNormalization(epsilon=1e-5, fused=True)(layer) 118 | add_layer = tf.keras.layers.add([orig, layer]) 119 | out = tf.keras.layers.ReLU()(add_layer) 120 | 121 | return out 122 | 123 | # def train_restore(self): 124 | # if not os.path.isdir(self.save_dir): 125 | # os.mkdir(self.save_dir) 126 | # checkpoint = tf.train.get_checkpoint_state(self.save_dir) 127 | # if checkpoint and checkpoint.model_checkpoint_path: 128 | # # self.saver.restore(self.sess, checkpoint.model_checkpoint_path) 129 | # self.saver.restore(self.sess, tf.train.latest_checkpoint(self.save_dir)) 130 | # print("Successfully loaded:", tf.train.latest_checkpoint(self.save_dir)) 131 | # # print("Successfully loaded:", checkpoint.model_checkpoint_path) 132 | # else: 133 | # print("Could not find old network weights") 134 | 135 | # def restore(self, file): 136 | # print("Restoring from {0}".format(file)) 137 | # self.saver.restore(self.sess, file) # self.ckpt 138 | 139 | def save(self, in_global_step): 140 | with self.strategy.scope(): 141 | self.checkpoint.save(self.checkpoint_prefix) 142 | # print("Model saved in file: {}".format(save_path)) 143 | 144 | def compute_metrics(self, pi_, policy_head): 145 | # Accuracy 146 | correct_prediction = tf.equal(tf.argmax(input=policy_head, axis=1), tf.argmax(input=pi_, axis=1)) 147 | correct_prediction = tf.cast(correct_prediction, tf.float32) 148 | accuracy = tf.reduce_mean(input_tensor=correct_prediction, name='accuracy') 149 | 150 | # summary_ops_v2.scalar('move_accuracy', accuracy) 151 | return accuracy 152 | 153 | def apply_regularization(self, regularizer, weights_list=None): 154 | """Returns the summed penalty by applying `regularizer` to the `weights_list`. 
155 | Adding a regularization penalty over the layer weights and embedding weights 156 | can help prevent overfitting the training data. Regularization over layer 157 | biases is less common/useful, but assuming proper data preprocessing/mean 158 | subtraction, it usually shouldn't hurt much either. 159 | Args: 160 | regularizer: A function that takes a single `Tensor` argument and returns 161 | a scalar `Tensor` output. 162 | weights_list: List of weights `Tensors` or `Variables` to apply 163 | `regularizer` over. Defaults to the `GraphKeys.WEIGHTS` collection if 164 | `None`. 165 | Returns: 166 | A scalar representing the overall regularization penalty. 167 | Raises: 168 | ValueError: If `regularizer` does not return a scalar output, or if we find 169 | no weights. 170 | """ 171 | # if not weights_list: 172 | # weights_list = ops.get_collection(ops.GraphKeys.WEIGHTS) 173 | if not weights_list: 174 | raise ValueError('No weights to regularize.') 175 | with tf.name_scope('get_regularization_penalty', 176 | values=weights_list) as scope: 177 | penalties = [regularizer(w) for w in weights_list] 178 | penalties = [ 179 | p if p is not None else tf.constant(0.0) for p in penalties 180 | ] 181 | for p in penalties: 182 | if p.get_shape().ndims != 0: 183 | raise ValueError('regularizer must return a scalar Tensor instead of a ' 184 | 'Tensor with rank %d.' % p.get_shape().ndims) 185 | 186 | summed_penalty = tf.add_n(penalties, name=scope) 187 | # ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, summed_penalty) 188 | return summed_penalty 189 | 190 | def compute_loss(self, pi_, z_, policy_head, value_head): 191 | 192 | # loss 193 | with tf.name_scope("loss"): 194 | policy_loss = tf.keras.losses.categorical_crossentropy(y_true=pi_, y_pred=policy_head, from_logits=True) 195 | policy_loss = tf.reduce_mean(policy_loss) 196 | 197 | value_loss = tf.keras.losses.mean_squared_error(z_, value_head) 198 | value_loss = tf.reduce_mean(value_loss) 199 | # summary_ops_v2.scalar('mse_loss', value_loss) 200 | 201 | regularizer = tf.keras.regularizers.l2(self.c_l2) 202 | regular_variables = self.model.trainable_variables 203 | l2_loss = self.apply_regularization(regularizer, regular_variables) 204 | 205 | # self.loss = value_loss - policy_loss + l2_loss 206 | self.loss = value_loss + policy_loss + l2_loss 207 | # summary_ops_v2.scalar('loss', self.loss) 208 | 209 | return self.loss 210 | 211 | # TODO(yashkatariya): Add tf.function when b/123315763 is resolved 212 | # @tf.function 213 | def train_step(self, it, learning_rate=0): 214 | positions = it[0] 215 | pi = it[1] 216 | z = it[2] 217 | # print("tf.executing_eagerly() ", tf.executing_eagerly()) 218 | # print("positions.shape ", positions.shape) 219 | # print("pi ", pi) 220 | # print("z ", z) 221 | # print("learning_rate ", learning_rate) 222 | 223 | # Record the operations used to compute the loss, so that the gradient 224 | # of the loss with respect to the variables can be computed. 
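        # Expected per-replica shapes (assumed from the model definition above,
        # not stated here): positions (N, 9, 10, 14), pi (N, 2086), z (N, 1).
        # Under MirroredStrategy, experimental_run presumably invokes this method
        # once per GPU on that replica's shard of the input iterator, and
        # apply_gradients below is expected to aggregate the per-replica
        # gradients because the optimizer was created inside the strategy scope.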
225 | # metrics = 0 226 | 227 | # with self.strategy.scope(): 228 | if True: 229 | with tf.GradientTape() as tape: 230 | policy_head, value_head = self.model(positions, training=True) 231 | loss = self.compute_loss(pi, z, policy_head, value_head) 232 | # loss = self.compute_loss(labels, logits) 233 | self.ComputeMetrics(pi, policy_head) 234 | self.avg_loss(loss) 235 | # metrics = self.compute_metrics(pi, policy_head) 236 | grads = tape.gradient(loss, self.model.trainable_variables) 237 | # print("grads ", grads) 238 | # print("metrics ", self.ComputeMetrics.result()) 239 | # print("loss ", loss) 240 | 241 | # grads = self.average_gradients(tower_grads) 242 | # grads = self.optimizer.compute_gradients(self.loss) 243 | # defensive step 2 to clip norm 244 | # grads0_lst = tf.map_fn(lambda x: x[0], grads) # [g for g, _ in grads] 245 | # clipped_grads, self.norm = tf.clip_by_global_norm(grads, self.global_norm) 246 | 247 | self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables)) 248 | 249 | # self.optimizer.apply_gradients(zip(clipped_grads, self.model.trainable_variables)) 250 | 251 | # defensive step 3 check NaN 252 | # See: https://stackoverflow.com/questions/40701712/how-to-check-nan-in-gradients-in-tensorflow-when-updating 253 | # grad_check = [tf.debugging.check_numerics(g, message='NaN Found!') for g in clipped_grads] 254 | # with tf.control_dependencies(grad_check): 255 | # self.optimizer.apply_gradients( 256 | # zip(clipped_grads, self.model.trainable_variables), # [v for _, v in grads] 257 | # global_step=self.global_step, name='train_step') 258 | 259 | # if self.is_logging: 260 | # for grad, var in zip(grads, self.model.trainable_variables): 261 | # if grad is not None: 262 | # summary_ops_v2.histogram(var.name + '/gradients', grad) 263 | # for var in self.model.trainable_variables: 264 | # summary_ops_v2.histogram(var.name, var) 265 | 266 | 267 | # self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables)) 268 | return self.ComputeMetrics.result(), self.avg_loss.result(), self.global_step 269 | 270 | #@profile 271 | def forward(self, positions): 272 | 273 | with self.strategy.scope(): 274 | positions=np.array(positions) 275 | if len(positions.shape) == 3: 276 | sp = positions.shape 277 | positions=np.reshape(positions, [1, sp[0], sp[1], sp[2]]) 278 | action_probs, value = self.model(positions, training=False) 279 | 280 | return action_probs, value 281 | -------------------------------------------------------------------------------- /policy_value_network_tf2.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import tensorflow as tf 3 | import numpy as np 4 | from tensorflow.python.ops import summary_ops_v2 5 | import os 6 | 7 | 8 | class policy_value_network(object): 9 | def __init__(self, learning_rate_fn, res_block_nums = 7): 10 | # self.ckpt = os.path.join(os.getcwd(), 'models/best_model.ckpt-13999') # TODO 11 | self.save_dir = "./models" 12 | self.is_logging = True 13 | 14 | if tf.io.gfile.exists(self.save_dir): 15 | # print('Removing existing model dir: {}'.format(MODEL_DIR)) 16 | # tf.io.gfile.rmtree(MODEL_DIR) 17 | pass 18 | else: 19 | tf.io.gfile.makedirs(self.save_dir) 20 | 21 | train_dir = os.path.join(self.save_dir, 'summaries', 'train') 22 | test_dir = os.path.join(self.save_dir, 'summaries', 'eval') 23 | 24 | self.train_summary_writer = summary_ops_v2.create_file_writer(train_dir, flush_millis=10000) 25 | self.test_summary_writer = summary_ops_v2.create_file_writer(test_dir, 
flush_millis=10000, name='test') 26 | 27 | # Variables 28 | self.filters_size = 128 # or 256 29 | self.prob_size = 2086 30 | self.digest = None 31 | 32 | self.inputs_ = tf.keras.layers.Input([9, 10, 14], dtype='float32', name='inputs') # TODO C plain x 2 33 | self.c_l2 = 0.0001 34 | self.momentum = 0.9 35 | self.global_norm = 100 36 | 37 | self.layer = tf.keras.layers.Conv2D(kernel_size=3, filters=self.filters_size, padding='same')(self.inputs_) 38 | self.layer = tf.keras.layers.BatchNormalization(epsilon=1e-5, fused=True)(self.layer) 39 | self.layer = tf.keras.layers.ReLU()(self.layer) 40 | 41 | # residual_block 42 | with tf.name_scope("residual_block"): 43 | for _ in range(res_block_nums): 44 | self.layer = self.residual_block(self.layer) 45 | 46 | # policy_head 47 | with tf.name_scope("policy_head"): 48 | self.policy_head = tf.keras.layers.Conv2D(filters=2, kernel_size=1, padding='same')(self.layer) 49 | self.policy_head = tf.keras.layers.BatchNormalization(epsilon=1e-5, fused=True)(self.policy_head) 50 | self.policy_head = tf.keras.layers.ReLU()(self.policy_head) 51 | 52 | self.policy_head = tf.keras.layers.Reshape([9 * 10 * 2])(self.policy_head) 53 | self.policy_head = tf.keras.layers.Dense(self.prob_size)(self.policy_head) 54 | 55 | # value_head 56 | with tf.name_scope("value_head"): 57 | self.value_head = tf.keras.layers.Conv2D(filters=1, kernel_size=1, padding='same')(self.layer) 58 | self.value_head = tf.keras.layers.BatchNormalization(epsilon=1e-5, fused=True)( 59 | self.value_head) 60 | self.value_head = tf.keras.layers.ReLU()(self.value_head) 61 | 62 | self.value_head = tf.keras.layers.Reshape([9 * 10 * 1])(self.value_head) 63 | self.value_head = tf.keras.layers.Dense(256, activation='relu')(self.value_head) 64 | self.value_head = tf.keras.layers.Dense(1, activation='tanh')(self.value_head) 65 | 66 | self.model = tf.keras.Model( 67 | inputs=[self.inputs_], 68 | outputs=[self.policy_head, self.value_head]) 69 | 70 | self.model.summary() 71 | 72 | self.global_step = tf.Variable(0, name="global_step", trainable=False) 73 | # optimizer = tf.train.AdamOptimizer(self.learning_rate) 74 | 75 | # 优化损失 76 | self.optimizer = tf.compat.v1.train.MomentumOptimizer( 77 | learning_rate=learning_rate_fn, momentum=self.momentum, use_nesterov=True) 78 | 79 | self.CategoricalCrossentropyLoss = tf.keras.losses.CategoricalCrossentropy(from_logits=True) 80 | self.MSE = tf.keras.losses.MeanSquaredError() 81 | self.ComputeMetrics = tf.keras.metrics.MeanAbsoluteError() 82 | 83 | 84 | # self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 85 | # with tf.control_dependencies(self.update_ops): 86 | # self.train_op = optimizer.minimize(self.loss, global_step=self.global_step) 87 | 88 | self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoints') 89 | self.checkpoint_prefix = os.path.join(self.checkpoint_dir, 'ckpt') 90 | self.checkpoint = tf.train.Checkpoint(model=self.model, optimizer=self.optimizer) 91 | 92 | # Restore variables on creation if a checkpoint exists. 
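        # tf.train.latest_checkpoint returns None when nothing has been saved
        # under self.checkpoint_dir yet; Checkpoint.restore(None) then restores
        # nothing, and training simply starts from the freshly initialized
        # Keras weights.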
93 | self.checkpoint.restore(tf.train.latest_checkpoint(self.checkpoint_dir)) 94 | 95 | def residual_block(self, in_layer): 96 | orig = tf.convert_to_tensor(in_layer) # tf.identity(in_layer) 97 | layer = tf.keras.layers.Conv2D(kernel_size=3, filters=self.filters_size, padding='same')(in_layer) 98 | layer = tf.keras.layers.BatchNormalization(epsilon=1e-5, fused=True)(layer) 99 | layer = tf.keras.layers.ReLU()(layer) 100 | 101 | layer = tf.keras.layers.Conv2D(kernel_size=3, filters=self.filters_size, padding='same')(layer) 102 | layer = tf.keras.layers.BatchNormalization(epsilon=1e-5, fused=True)(layer) 103 | add_layer = tf.keras.layers.add([orig, layer]) 104 | out = tf.keras.layers.ReLU()(add_layer) 105 | 106 | return out 107 | 108 | # def train_restore(self): 109 | # if not os.path.isdir(self.save_dir): 110 | # os.mkdir(self.save_dir) 111 | # checkpoint = tf.train.get_checkpoint_state(self.save_dir) 112 | # if checkpoint and checkpoint.model_checkpoint_path: 113 | # # self.saver.restore(self.sess, checkpoint.model_checkpoint_path) 114 | # self.saver.restore(self.sess, tf.train.latest_checkpoint(self.save_dir)) 115 | # print("Successfully loaded:", tf.train.latest_checkpoint(self.save_dir)) 116 | # # print("Successfully loaded:", checkpoint.model_checkpoint_path) 117 | # else: 118 | # print("Could not find old network weights") 119 | 120 | # def restore(self, file): 121 | # print("Restoring from {0}".format(file)) 122 | # self.saver.restore(self.sess, file) # self.ckpt 123 | 124 | def save(self, in_global_step): 125 | self.checkpoint.save(self.checkpoint_prefix) 126 | # print("Model saved in file: {}".format(save_path)) 127 | 128 | def compute_metrics(self, pi_, policy_head): 129 | # Accuracy 130 | correct_prediction = tf.equal(tf.argmax(input=policy_head, axis=1), tf.argmax(input=pi_, axis=1)) 131 | correct_prediction = tf.cast(correct_prediction, tf.float32) 132 | accuracy = tf.reduce_mean(input_tensor=correct_prediction, name='accuracy') 133 | 134 | summary_ops_v2.scalar('move_accuracy', accuracy) 135 | return accuracy 136 | 137 | def apply_regularization(self, regularizer, weights_list=None): 138 | """Returns the summed penalty by applying `regularizer` to the `weights_list`. 139 | Adding a regularization penalty over the layer weights and embedding weights 140 | can help prevent overfitting the training data. Regularization over layer 141 | biases is less common/useful, but assuming proper data preprocessing/mean 142 | subtraction, it usually shouldn't hurt much either. 143 | Args: 144 | regularizer: A function that takes a single `Tensor` argument and returns 145 | a scalar `Tensor` output. 146 | weights_list: List of weights `Tensors` or `Variables` to apply 147 | `regularizer` over. Defaults to the `GraphKeys.WEIGHTS` collection if 148 | `None`. 149 | Returns: 150 | A scalar representing the overall regularization penalty. 151 | Raises: 152 | ValueError: If `regularizer` does not return a scalar output, or if we find 153 | no weights. 
154 | """ 155 | # if not weights_list: 156 | # weights_list = ops.get_collection(ops.GraphKeys.WEIGHTS) 157 | if not weights_list: 158 | raise ValueError('No weights to regularize.') 159 | with tf.name_scope('get_regularization_penalty', 160 | values=weights_list) as scope: 161 | penalties = [regularizer(w) for w in weights_list] 162 | penalties = [ 163 | p if p is not None else tf.constant(0.0) for p in penalties 164 | ] 165 | for p in penalties: 166 | if p.get_shape().ndims != 0: 167 | raise ValueError('regularizer must return a scalar Tensor instead of a ' 168 | 'Tensor with rank %d.' % p.get_shape().ndims) 169 | 170 | summed_penalty = tf.add_n(penalties, name=scope) 171 | # ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, summed_penalty) 172 | return summed_penalty 173 | 174 | def compute_loss(self, pi_, z_, policy_head, value_head): 175 | 176 | # loss 177 | with tf.name_scope("loss"): 178 | policy_loss = tf.keras.losses.categorical_crossentropy(y_true=pi_, y_pred=policy_head, from_logits=True) 179 | policy_loss = tf.reduce_mean(policy_loss) 180 | 181 | value_loss = tf.keras.losses.mean_squared_error(z_, value_head) 182 | value_loss = tf.reduce_mean(value_loss) 183 | summary_ops_v2.scalar('mse_loss', value_loss) 184 | 185 | regularizer = tf.keras.regularizers.l2(self.c_l2) 186 | regular_variables = self.model.trainable_variables 187 | l2_loss = self.apply_regularization(regularizer, regular_variables) 188 | 189 | # self.loss = value_loss - policy_loss + l2_loss 190 | self.loss = value_loss + policy_loss + l2_loss 191 | summary_ops_v2.scalar('loss', self.loss) 192 | 193 | return self.loss 194 | 195 | @tf.function 196 | def train_step(self, positions, pi, z, learning_rate=0): 197 | # Record the operations used to compute the loss, so that the gradient 198 | # of the loss with respect to the variables can be computed. 
199 | # metrics = 0 200 | 201 | with tf.GradientTape() as tape: 202 | policy_head, value_head = self.model(positions, training=True) 203 | loss = self.compute_loss(pi, z, policy_head, value_head) 204 | # self.ComputeMetrics(y, logits) 205 | metrics = self.compute_metrics(pi, policy_head) 206 | grads = tape.gradient(loss, self.model.trainable_variables) 207 | 208 | # grads = self.average_gradients(tower_grads) 209 | # grads = self.optimizer.compute_gradients(self.loss) 210 | # defensive step 2 to clip norm 211 | # grads0_lst = tf.map_fn(lambda x: x[0], grads) # [g for g, _ in grads] 212 | clipped_grads, self.norm = tf.clip_by_global_norm(grads, self.global_norm) 213 | 214 | # defensive step 3 check NaN 215 | # See: https://stackoverflow.com/questions/40701712/how-to-check-nan-in-gradients-in-tensorflow-when-updating 216 | grad_check = [tf.debugging.check_numerics(g, message='NaN Found!') for g in clipped_grads] 217 | with tf.control_dependencies(grad_check): 218 | self.optimizer.apply_gradients( 219 | zip(clipped_grads, self.model.trainable_variables), # [v for _, v in grads] 220 | global_step=self.global_step, name='train_step') 221 | 222 | if self.is_logging: 223 | for grad, var in zip(grads, self.model.trainable_variables): 224 | if grad is not None: 225 | summary_ops_v2.histogram(var.name + '/gradients', grad) 226 | for var in self.model.trainable_variables: 227 | summary_ops_v2.histogram(var.name, var) 228 | 229 | return metrics, loss, self.global_step 230 | 231 | #@profile 232 | def forward(self, positions): 233 | 234 | positions=np.array(positions) 235 | if len(positions.shape) == 3: 236 | sp = positions.shape 237 | positions=np.reshape(positions, [1, sp[0], sp[1], sp[2]]) 238 | action_probs, value = self.model(positions, training=False) 239 | 240 | return action_probs, value 241 | 242 | --------------------------------------------------------------------------------
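A minimal smoke test for the single-device TF2 network above, offered as a sketch only. It assumes a TF 2.0-preview-era environment of the kind these files target (later TF releases changed some APIs the classes rely on, e.g. the `values=` argument of tf.name_scope used in apply_regularization). The batch size of 8, the constant learning rate passed as `learning_rate_fn`, and the random dummy data are illustrative assumptions, not values taken from main_tf2.py; only the shapes (9x10x14 input planes, 2086-way move space) come from the code above.

#coding:utf-8
import numpy as np
import tensorflow as tf

from policy_value_network_tf2 import policy_value_network

# The constructor only needs something the v1 MomentumOptimizer accepts as a
# learning rate; a zero-argument callable returning a constant is enough here.
net = policy_value_network(learning_rate_fn=lambda: 1e-3, res_block_nums=7)

# One fake batch of self-play data: board planes, MCTS move probabilities,
# and game outcomes in {-1, +1}.
positions = np.random.rand(8, 9, 10, 14).astype(np.float32)
pi = np.random.dirichlet(np.ones(2086), size=8).astype(np.float32)
z = np.random.choice([-1.0, 1.0], size=(8, 1)).astype(np.float32)

# Inference: forward() accepts a single position or a batch and returns
# policy logits of shape (N, 2086) and values of shape (N, 1).
probs, value = net.forward(positions[0])
print(probs.shape, value.shape)

# One optimization step; train_step returns move accuracy, the total loss,
# and the global step counter.
accuracy, loss, global_step = net.train_step(positions, pi, z)
print("move accuracy:", float(accuracy), "loss:", float(loss))

# Checkpointing goes through tf.train.Checkpoint under ./models/checkpoints;
# the in_global_step argument is unused because Checkpoint numbers saves itself.
net.save(global_step)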