├── .gitignore ├── LICENSE ├── README.md ├── demonstration ├── Main.py ├── img │ ├── black_btn.png │ ├── close_btn.png │ ├── easy_btn.png │ ├── exit_btn.png │ ├── hard_btn.png │ ├── idle.jpg │ ├── medium_btn.png │ ├── think_btn.png │ ├── white_btn.png │ └── wood.png ├── mdl │ ├── gomoku_11x11_1000.model.data-00000-of-00001 │ ├── gomoku_11x11_1000.model.index │ ├── gomoku_11x11_1000.model.meta │ ├── gomoku_11x11_2000.model.data-00000-of-00001 │ ├── gomoku_11x11_2000.model.index │ ├── gomoku_11x11_2000.model.meta │ ├── gomoku_11x11_3000.model.data-00000-of-00001 │ ├── gomoku_11x11_3000.model.index │ ├── gomoku_11x11_3000.model.meta │ ├── gomoku_11x11_5000.model.data-00000-of-00001 │ ├── gomoku_11x11_5000.model.index │ └── gomoku_11x11_5000.model.meta ├── src │ ├── GameApp.py │ ├── GameScene.py │ ├── IdleScene.py │ ├── OrderSelectionScene.py │ ├── ThinkScene.py │ ├── WelcomeScene.py │ ├── __init__.py │ ├── components │ │ ├── ChessCanvas.py │ │ ├── GameCanvas.py │ │ ├── ThinkCanvas.py │ │ └── __init__.py │ ├── config.py │ ├── model │ │ ├── CrossPoint.py │ │ ├── Gomoku.py │ │ ├── MCTS_AlphaZero.py │ │ ├── MCTS_Pure.py │ │ ├── PVN_11.py │ │ └── __init__.py │ └── play_data.py └── static │ └── play_data.pkl ├── screenshots ├── 1.jpg ├── 2.jpg └── 3.jpg └── training ├── main.py └── src ├── __init__.py ├── human_play.py ├── model ├── __init__.py ├── game.py ├── inception_resnet_v2.py ├── mcts_alphaZero.py ├── mcts_pure.py ├── pvn_inception.py ├── pvn_resnet.py └── resnet.py ├── train_pipeline.py ├── train_thread.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .idea/* 3 | .DS_Store 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 杨卓谦 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AlphaGomokuZero 2 | An illustration program which visualizes CNN filters inside AlphaZero and its MCTS mechanism in order to provide a better understanding of how an AI makes decisions. Based Tensorflow and Tkinter. Exhibited at China Science and Technology Museum. 
3 | 一个通过可视化AlphaGo Zero中的MCTS机制来解释AI决策方式的程序。展出于中国科技馆。 4 | 5 | ## Screenshots 6 | 7 | ![2.jpg](https://raw.githubusercontent.com/yzhq97/AlphaGomokuZero/master/screenshots/2.jpg) 8 | ![3.jpg](https://raw.githubusercontent.com/yzhq97/AlphaGomokuZero/master/screenshots/3.jpg) 9 | 10 | ## Training 11 | 12 | The training module is based on https://github.com/junxiaosong/AlphaZero_Gomoku. 13 | I experimented with numerous models and added a multi-thread training scheme for the policy value net. 14 | 15 | ``` 16 | training 17 | ├── main.py Main script. 18 | └── src 19 | ├── __init__.py 20 | ├── human_play.py A simple command-line client for human play. 21 | ├── model 22 | │   ├── __init__.py 23 | │   ├── game.py Defines the rules of the Gomoku game. 24 | │   ├── inception_resnet_v2.py 25 | │   ├── mcts_alphaZero.py AlphaZero player. 26 | │   ├── mcts_pure.py Pure MCTS player. 27 | │   ├── pvn_inception.py Inception version of the policy value net. 28 | │   ├── pvn_resnet.py ResNet version of the policy value net. 29 | │   └── resnet.py 30 | ├── train_pipeline.py Training pipeline. 31 | ├── train_thread.py A single training thread. 32 | └── utils.py Utilities. 33 | ``` 34 | 35 | ## Demonstration 36 | 37 | An illustration program with a GUI. It allows users to play against AlphaZero and see visualizations of the Monte Carlo Tree Search. 38 | Implemented with Tkinter on Python 3. 39 | 40 | ### GUI Structure 41 | ``` 42 | GameApp 43 | ├── WelcomeScene Welcome screen. 44 | ├── OrderSelectionScene Scene for selecting play order. 45 | ├── GameScene Game screen. 46 | ├── IdleScene Displayed when no game is in progress. 47 | └── ThinkScene Displays visualizations of MCTS. 48 | ``` 49 | 50 | ### Directory Structure 51 | ``` 52 | demonstration/ 53 | ├── Main.py Main script. 54 | ├── img Stores image resources. 55 | ├── mdl Stores trained models. 56 | ├── src 57 | │   ├── GameApp.py 58 | │   ├── GameScene.py 59 | │   ├── IdleScene.py 60 | │   ├── OrderSelectionScene.py 61 | │   ├── ThinkScene.py 62 | │   ├── WelcomeScene.py 63 | │   ├── __init__.py 64 | │   ├── components UI components. 65 | │   ├── config.py Configurations. 66 | │   ├── model Basic models of the game and AI. 67 | │   │   ├── CrossPoint.py 68 | │   │   ├── Gomoku.py 69 | │   │   ├── MCTS_AlphaZero.py 70 | │   │   ├── MCTS_Pure.py 71 | │   │   ├── PVN_11.py 72 | │   │   └── __init__.py 73 | │   └── play_data.py Data structure used to store play data.
74 | └── static Place to store play data 75 | ``` 76 | -------------------------------------------------------------------------------- /demonstration/Main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*- coding: utf-8 -*- 3 | #author: Zhuoqian Yang yzhq97@buaa.edu.cn 4 | 5 | from src.GameApp import GameApp 6 | import tkinter as tk 7 | 8 | root = tk.Tk() 9 | window2 = tk.Toplevel() 10 | GameApp(root, window2) 11 | root.mainloop() -------------------------------------------------------------------------------- /demonstration/img/black_btn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/demonstration/img/black_btn.png -------------------------------------------------------------------------------- /demonstration/img/close_btn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/demonstration/img/close_btn.png -------------------------------------------------------------------------------- /demonstration/img/easy_btn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/demonstration/img/easy_btn.png -------------------------------------------------------------------------------- /demonstration/img/exit_btn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/demonstration/img/exit_btn.png -------------------------------------------------------------------------------- /demonstration/img/hard_btn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/demonstration/img/hard_btn.png -------------------------------------------------------------------------------- /demonstration/img/idle.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/demonstration/img/idle.jpg -------------------------------------------------------------------------------- /demonstration/img/medium_btn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/demonstration/img/medium_btn.png -------------------------------------------------------------------------------- /demonstration/img/think_btn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/demonstration/img/think_btn.png -------------------------------------------------------------------------------- /demonstration/img/white_btn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/demonstration/img/white_btn.png -------------------------------------------------------------------------------- 
/demonstration/img/wood.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/demonstration/img/wood.png -------------------------------------------------------------------------------- /demonstration/mdl/gomoku_11x11_1000.model.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/demonstration/mdl/gomoku_11x11_1000.model.data-00000-of-00001 -------------------------------------------------------------------------------- /demonstration/mdl/gomoku_11x11_1000.model.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/demonstration/mdl/gomoku_11x11_1000.model.index -------------------------------------------------------------------------------- /demonstration/mdl/gomoku_11x11_1000.model.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/demonstration/mdl/gomoku_11x11_1000.model.meta -------------------------------------------------------------------------------- /demonstration/mdl/gomoku_11x11_2000.model.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/demonstration/mdl/gomoku_11x11_2000.model.data-00000-of-00001 -------------------------------------------------------------------------------- /demonstration/mdl/gomoku_11x11_2000.model.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/demonstration/mdl/gomoku_11x11_2000.model.index -------------------------------------------------------------------------------- /demonstration/mdl/gomoku_11x11_2000.model.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/demonstration/mdl/gomoku_11x11_2000.model.meta -------------------------------------------------------------------------------- /demonstration/mdl/gomoku_11x11_3000.model.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/demonstration/mdl/gomoku_11x11_3000.model.data-00000-of-00001 -------------------------------------------------------------------------------- /demonstration/mdl/gomoku_11x11_3000.model.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/demonstration/mdl/gomoku_11x11_3000.model.index -------------------------------------------------------------------------------- /demonstration/mdl/gomoku_11x11_3000.model.meta: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/demonstration/mdl/gomoku_11x11_3000.model.meta -------------------------------------------------------------------------------- /demonstration/mdl/gomoku_11x11_5000.model.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/demonstration/mdl/gomoku_11x11_5000.model.data-00000-of-00001 -------------------------------------------------------------------------------- /demonstration/mdl/gomoku_11x11_5000.model.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/demonstration/mdl/gomoku_11x11_5000.model.index -------------------------------------------------------------------------------- /demonstration/mdl/gomoku_11x11_5000.model.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/demonstration/mdl/gomoku_11x11_5000.model.meta -------------------------------------------------------------------------------- /demonstration/src/GameApp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*- coding: utf-8 -*- 3 | #author: Zhuoqian Yang yzhq97@buaa.edu.cn 4 | 5 | import tkinter as tk 6 | import sys 7 | import os 8 | import pickle 9 | import src.config as config 10 | from src.WelcomeScene import WelcomeScene 11 | from src.IdleScene import IdleScene 12 | from src.OrderSelectionScene import OrderSelectionScene 13 | from src.GameScene import GameScene 14 | from src.ThinkScene import ThinkScene 15 | from src.play_data import PlayData 16 | 17 | class GameApp: 18 | def __init__(self, master, win2): 19 | self.root = master 20 | self.win2 = win2 21 | 22 | self.root.configure(background=config.BGCOLOR) 23 | self.win2.configure(background=config.BGCOLOR) 24 | # self.scn_width = self.root.winfo_screenwidth() 25 | # self.scn_height = self.root.winfo_screenheight() 26 | self.scn_width = 1920 27 | self.scn_height = 1080 28 | 29 | self.root.geometry('%dx%d%+d+%d' % (self.scn_width, self.scn_height, 0, 0)) 30 | # self.win2.geometry('%dx%d%+d+%d' % (self.scn_width, self.scn_height, 0, -self.scn_height)) 31 | self.win2.geometry('%dx%d%+d+%d' % (self.scn_width, self.scn_height, 0, 0)) 32 | # self.root.attributes("-fullscreen", True) 33 | # self.win2.attributes("-fullscreen", True) 34 | 35 | self.model = None 36 | self.human_order = None 37 | 38 | self.play_data_file = os.path.join('.', 'static', 'play_data.pkl') 39 | self.play_data = None 40 | self.init_data_file() 41 | 42 | self.board = None 43 | self.mcts_root = None 44 | 45 | self.welcome_scene = WelcomeScene(self.root, self) 46 | self.idle_scene = IdleScene(self.win2, self) 47 | self.order_selection_scene = OrderSelectionScene(self.root, self) 48 | self.game_scene = GameScene(self.root, self) 49 | self.think_scene = ThinkScene(self.win2, self, None, None) 50 | 51 | self.welcome_scene.pack() 52 | self.idle_scene.pack() 53 | self.root.bind("", self.exit) 54 | 55 | def on_select_difficulty_finish(self): 56 | self.welcome_scene.pack_forget() 57 | self.order_selection_scene.pack() 58 | 59 | def on_select_order_finish(self): 60 | self.order_selection_scene.pack_forget() 61 | 
self.game_scene.start_game(self.model, self.human_order) 62 | self.game_scene.pack() 63 | 64 | def show_think(self): 65 | if self.board is None: return 66 | self.idle_scene.pack_forget() 67 | self.think_scene.board = self.board 68 | self.think_scene.mcts_root = self.mcts_root 69 | self.think_scene.show() 70 | # self.game_scene.pack_forget() 71 | self.think_scene.pack() 72 | 73 | def show_think_finish(self): 74 | del self.board 75 | del self.mcts_root 76 | # self.think_scene.pack_forget() 77 | # del self.think_scene 78 | # self.game_scene.pack() 79 | 80 | def on_game_exit(self): 81 | self.game_scene.pack_forget() 82 | self.think_scene.pack_forget() 83 | self.think_scene.reset() 84 | self.idle_scene.pack() 85 | self.welcome_scene.pack() 86 | 87 | def on_game_end(self, winner): 88 | self.save_data(winner) 89 | 90 | def exit(self, arg): 91 | sys.exit(0) 92 | 93 | def new_data_file(self): 94 | self.play_data = PlayData() 95 | with open(self.play_data_file, 'wb') as output: 96 | pickle.dump(self.play_data, output, pickle.HIGHEST_PROTOCOL) 97 | 98 | 99 | def init_data_file(self): 100 | if os.path.isfile(self.play_data_file): 101 | with open(self.play_data_file, 'rb') as input: 102 | self.play_data = pickle.load(input) 103 | else: 104 | self.new_data_file() 105 | 106 | def save_data(self, winner): 107 | if self.model == config.MODEL_EASY: 108 | if self.human_order == 1: 109 | if winner == 1 - self.human_order: 110 | self.play_data.easy_black_win += 1 111 | elif winner == self.human_order: 112 | self.play_data.easy_black_lose += 1 113 | else: 114 | self.play_data.easy_black_draw += 1 115 | else: 116 | if winner == 1 - self.human_order: 117 | self.play_data.easy_white_win += 1 118 | elif winner == self.human_order: 119 | self.play_data.easy_white_lose += 1 120 | else: 121 | self.play_data.easy_white_draw += 1 122 | if self.model == config.MODEL_MED: 123 | if self.human_order == 1: 124 | if winner == 1 - self.human_order: 125 | self.play_data.medium_black_win += 1 126 | elif winner == self.human_order: 127 | self.play_data.medium_black_lose += 1 128 | else: 129 | self.play_data.medium_black_draw += 1 130 | else: 131 | if winner == 1 - self.human_order: 132 | self.play_data.medium_white_win += 1 133 | elif winner == self.human_order: 134 | self.play_data.medium_white_lose += 1 135 | else: 136 | self.play_data.medium_white_draw += 1 137 | if self.model == config.MODEL_HARD: 138 | if self.human_order == 1: 139 | if winner == 1 - self.human_order: 140 | self.play_data.hard_black_win += 1 141 | elif winner == self.human_order: 142 | self.play_data.hard_black_lose += 1 143 | else: 144 | self.play_data.hard_black_draw += 1 145 | else: 146 | if winner == 1 - self.human_order: 147 | self.play_data.hard_white_win += 1 148 | elif winner == self.human_order: 149 | self.play_data.hard_white_lose += 1 150 | else: 151 | self.play_data.hard_white_draw += 1 152 | 153 | self.welcome_scene.show_play_data() 154 | with open(self.play_data_file, 'wb') as output: 155 | pickle.dump(self.play_data, output) 156 | 157 | 158 | -------------------------------------------------------------------------------- /demonstration/src/GameScene.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | #author: Zhuoqian Yang yzhq97@buaa.edu.cn 3 | 4 | import os 5 | import tkinter as tk 6 | import tkinter.font as tkf 7 | import src.config as config 8 | from src.components.GameCanvas import GameCanvas 9 | from PIL import ImageTk, Image 10 | 11 | class GameScene(tk.Frame): 12 | 13 | 
def __init__(self, master, game_app): 14 | self.N = config.N 15 | self.scn_width = game_app.scn_width 16 | self.scn_height = game_app.scn_height 17 | self.canvas_width = self.scn_height 18 | self.canvas_height = self.scn_height 19 | self.M = self.canvas_height / (self.N+1) 20 | self.play_out = config.PLAYOUT 21 | self.model = None 22 | self.human_order = None 23 | 24 | self.master = master 25 | self.game_app = game_app 26 | self.play_data = self.game_app.play_data 27 | tk.Frame.__init__(self, master, background=config.BGCOLOR) 28 | 29 | self.construct() 30 | 31 | def reset(self): 32 | self.msg_text.set("") 33 | self.status_text.set("") 34 | self.stats_text.set("") 35 | self.canvas.reset() 36 | 37 | def construct(self): 38 | panel_width = (self.scn_width - self.scn_height)/2 39 | panel_height = self.scn_height 40 | 41 | title_font = tkf.Font(family=config.FONTFAMILY, size=24, weight=tkf.BOLD) 42 | msg_font = tkf.Font(family=config.FONTFAMILY, size=20) 43 | 44 | self.msg_text = tk.StringVar() 45 | self.status_text = tk.StringVar() 46 | self.stats_text = tk.StringVar() 47 | 48 | self.mid_frame = tk.Frame(self, background=config.BGCOLOR, padx=20) 49 | 50 | self.canvas = GameCanvas(self.mid_frame, self) 51 | self.canvas.pack() 52 | 53 | self.right_panel = tk.Frame(self, width=panel_width, height=panel_height, padx=10, background=config.BGCOLOR) 54 | 55 | self.title_label = tk.Label(self.right_panel, height=3, text=u"五子棋", bg=config.BGCOLOR, font=title_font) 56 | self.status_label = tk.Label(self.right_panel, height=3, bg=config.BGCOLOR, font=msg_font, textvariable=self.status_text) 57 | self.stats_label = tk.Label(self.right_panel, height=3, bg=config.BGCOLOR, font=msg_font, textvariable=self.stats_text) 58 | self.msg_label = tk.Label(self.right_panel, height=3, bg=config.BGCOLOR, font=msg_font, textvariable=self.msg_text) 59 | self.exit_image = ImageTk.PhotoImage(Image.open(os.path.join(".", "img", "exit_btn.png"))) 60 | self.exit_btn = tk.Button(self.right_panel, command=self.on_exit, 61 | width=260, height=80, border=0, 62 | bg=config.BGCOLOR, activebackground=config.BGCOLOR, 63 | highlightbackground=config.BGCOLOR, 64 | image=self.exit_image) 65 | 66 | self.title_label.pack(side='top') 67 | self.status_label.pack(side='top') 68 | self.stats_label.pack(side='top') 69 | self.msg_label.pack(side='top') 70 | self.exit_btn.pack(side='top') 71 | 72 | self.mid_frame.pack(side='left') 73 | self.right_panel.pack(side='left') 74 | 75 | def start_game(self, model, human_order): 76 | self.reset() 77 | self.model = model 78 | self.human_order = human_order 79 | self.show_play_data() 80 | 81 | status_str = u"当前棋局:" 82 | 83 | if model == config.MODEL_EASY: 84 | status_str += u"简单难度\n" 85 | elif model == config.MODEL_MED: 86 | status_str += u"中等难度\n" 87 | else: 88 | status_str += u"困难难度\n" 89 | 90 | if human_order == 0: 91 | status_str += u"人类执黑,AI执白" 92 | else: 93 | status_str += u"AI执黑,人类执白" 94 | 95 | self.status_text.set(status_str) 96 | 97 | self.canvas.game_start() 98 | 99 | def on_exit(self): 100 | self.canvas.pvn.close() 101 | self.game_app.on_game_exit() 102 | 103 | def on_game_end(self, winner): 104 | self.game_app.on_game_end(winner) 105 | self.show_play_data() 106 | 107 | def show_message(self, msg): 108 | self.msg_text.set(msg) 109 | 110 | def show_play_data(self): 111 | stats_str = u"" 112 | 113 | if self.model == config.MODEL_EASY: 114 | stats_str += u"简单模型,自我对弈训练6000局\n" 115 | stats_str += u"该难度AI在展览期间胜率:\n" \ 116 | u"执黑%4.1f%% 执白%4.1f%%" % ( 117 | 100*self.play_data.easy_black_winrate(), 
118 | 100*self.play_data.easy_white_winrate() 119 | ) 120 | elif self.model == config.MODEL_MED: 121 | stats_str += u"中等模型,自我对弈训练8000局\n" 122 | stats_str += u"该难度AI在展览期间胜率:\n" \ 123 | u"执黑%4.1f%% 执白%4.1f%%" % ( 124 | 100 * self.play_data.medium_black_winrate(), 125 | 100 * self.play_data.medium_white_winrate() 126 | ) 127 | else: 128 | stats_str += u"困难模型,自我对弈训练15000局\n" 129 | stats_str += u"该难度AI在展览期间胜率:\n" \ 130 | u"执黑%4.1f%% 执白%4.1f%%" % ( 131 | 100 * self.play_data.hard_black_winrate(), 132 | 100 * self.play_data.hard_white_winrate() 133 | ) 134 | 135 | self.stats_text.set(stats_str) 136 | 137 | 138 | def show_think(self, board, mcts_root): 139 | self.game_app.board = board 140 | self.game_app.mcts_root = mcts_root 141 | self.game_app.show_think() 142 | 143 | def show_think_finish(self): 144 | self.game_app.show_think_finish() -------------------------------------------------------------------------------- /demonstration/src/IdleScene.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | #author: Zhuoqian Yang yzhq97@buaa.edu.cn 3 | 4 | import os 5 | import tkinter as tk 6 | import src.config as config 7 | from PIL import ImageTk, Image 8 | 9 | class IdleScene(tk.Frame): 10 | 11 | def __init__(self, master, game_app): 12 | self.master = master 13 | self.game_app = game_app 14 | 15 | tk.Frame.__init__(self, master) 16 | self.construct() 17 | 18 | def construct(self): 19 | self.idle_image = ImageTk.PhotoImage(Image.open(os.path.join(".", "img", "idle.jpg"))) 20 | self.label = tk.Label(self, image=self.idle_image) 21 | self.label.place(x=0, y=0, relwidth=1, relheight=1) 22 | self.label.pack() -------------------------------------------------------------------------------- /demonstration/src/OrderSelectionScene.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | #author: Zhuoqian Yang yzhq97@buaa.edu.cn 3 | 4 | 5 | import tkinter as tk 6 | import tkinter.font as tkf 7 | import src.config as config 8 | import os 9 | from PIL import ImageTk, Image 10 | 11 | class OrderSelectionScene(tk.Frame): 12 | def __init__(self, master, game_app): 13 | self.game_app = game_app 14 | self.scn_height = game_app.scn_height 15 | tk.Frame.__init__(self, master, 16 | height=self.scn_height, 17 | background=config.BGCOLOR) 18 | self.construct() 19 | 20 | def construct(self): 21 | title_font = tkf.Font(family=config.FONTFAMILY, size=36, weight=tkf.BOLD) 22 | text_font = tkf.Font(family=config.FONTFAMILY, size=20) 23 | 24 | self.title_panel = tk.Frame(self, pady=self.scn_height * 0.2, bg=config.BGCOLOR) 25 | self.title_label = tk.Label(self.title_panel, height=1, text=u"请选择", 26 | bg=config.BGCOLOR, font=title_font) 27 | self.title_label.pack() 28 | 29 | self.btn_panel = tk.Frame(self) 30 | 31 | self.black_image = ImageTk.PhotoImage(Image.open(os.path.join(".", "img", "black_btn.png"))) 32 | self.white_image = ImageTk.PhotoImage(Image.open(os.path.join(".", "img", "white_btn.png"))) 33 | 34 | self.black_btn = tk.Button(self.btn_panel, command=self.on_black, 35 | width=260, height=80, border=0, 36 | bg=config.BGCOLOR, activebackground=config.BGCOLOR, 37 | highlightbackground=config.BGCOLOR, 38 | image=self.black_image) 39 | self.white_btn = tk.Button(self.btn_panel, command=self.on_white, 40 | width=260, height=80, border=0, 41 | bg=config.BGCOLOR, activebackground=config.BGCOLOR, 42 | highlightbackground=config.BGCOLOR, 43 | image=self.white_image) 44 | self.black_btn.pack(side='top') 45 | 
self.white_btn.pack(side='top') 46 | 47 | self.title_panel.pack(side='top') 48 | self.btn_panel.pack(side='top') 49 | 50 | def on_black(self): 51 | self.game_app.human_order = 0 52 | self.game_app.on_select_order_finish() 53 | 54 | def on_white(self): 55 | self.game_app.human_order = 1 56 | self.game_app.on_select_order_finish() -------------------------------------------------------------------------------- /demonstration/src/ThinkScene.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | #author: Zhuoqian Yang yzhq97@buaa.edu.cn 3 | 4 | import os 5 | import tkinter as tk 6 | import tkinter.font as tkf 7 | import src.config as config 8 | from src.components.ThinkCanvas import ThinkCanvas 9 | from src.components.GameCanvas import GameCanvas 10 | from PIL import ImageTk, Image 11 | 12 | class ThinkScene(tk.Frame): 13 | def __init__(self, master, game_app, board, mcts_root): 14 | self.N = config.N 15 | self.scn_width = game_app.scn_width 16 | self.scn_height = game_app.scn_height 17 | self.canvas_width = self.scn_height 18 | self.canvas_height = self.scn_height 19 | self.M = (self.canvas_height-200)/ 3 / (self.N+1) 20 | self.board = board 21 | self.mcts_root = mcts_root 22 | 23 | self.master = master 24 | self.game_app = game_app 25 | tk.Frame.__init__(self, master, background=config.BGCOLOR) 26 | 27 | self.construct() 28 | 29 | def construct(self): 30 | title_font = tkf.Font(family=config.FONTFAMILY, size=18, weight=tkf.BOLD) 31 | status_font = tkf.Font(family=config.FONTFAMILY, size=14) 32 | 33 | self.left_panel = tk.Frame(self, pady=10, bg=config.BGCOLOR) 34 | 35 | self.title_label = tk.Label(self.left_panel, font=title_font, height=1, text=u"AI思维过程展示", bg=config.BGCOLOR) 36 | self.desc_label = tk.Label(self.left_panel, font=status_font, height=15, width=32, 37 | wraplength = 375, anchor='w', justify='left', 38 | text=config.THINK_DESC, bg=config.BGCOLOR) 39 | self.close_image = ImageTk.PhotoImage(Image.open(os.path.join(".", "img", "close_btn.png"))) 40 | self.close_btn = tk.Button(self.left_panel, command=self.show_think_finish, 41 | width=260, height=80, border=0, 42 | bg=config.BGCOLOR, activebackground=config.BGCOLOR, 43 | highlightbackground=config.BGCOLOR, 44 | image=self.close_image) 45 | self.title_label.pack(side='top') 46 | self.desc_label.pack(side='top') 47 | # self.close_btn.pack(side='top') 48 | 49 | self.main_panel = tk.Frame(self, padx=30, pady=10, bg=config.BGCOLOR) 50 | 51 | self.board_col_1 = tk.Frame(self.main_panel, padx=30, bg=config.BGCOLOR) 52 | self.desc_label_1 = tk.Label(self.board_col_1, font=title_font, height=1, text=u"第一轮", bg=config.BGCOLOR) 53 | self.canvas_1_1 = ThinkCanvas(self.board_col_1, self.N, self.M) 54 | self.desc_label_1_1 = tk.Label(self.board_col_1, font=status_font, height=1, text=u"走法1", bg=config.BGCOLOR) 55 | self.canvas_1_2 = ThinkCanvas(self.board_col_1, self.N, self.M) 56 | self.desc_label_1_2 = tk.Label(self.board_col_1, font=status_font, height=1, text=u"走法2", bg=config.BGCOLOR) 57 | self.canvas_1_3 = ThinkCanvas(self.board_col_1, self.N, self.M) 58 | self.desc_label_1_3 = tk.Label(self.board_col_1, font=status_font, height=1, text=u"走法3", bg=config.BGCOLOR) 59 | self.desc_label_1.pack(side='top') 60 | self.canvas_1_1.pack(side='top') 61 | self.desc_label_1_1.pack(side='top') 62 | self.canvas_1_2.pack(side='top') 63 | self.desc_label_1_2.pack(side='top') 64 | self.canvas_1_3.pack(side='top') 65 | self.desc_label_1_3.pack(side='top') 66 | self.board_col_1.pack(side='left') 
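        # Columns two and three below repeat this layout for the second and third
        # look-ahead rounds. ThinkScene.show() fills canvas_<round>_<rank> by calling
        # ThinkCanvas.show_think(board, mcts_root, level=round-1, order=rank-1),
        # i.e. the rank-th strongest candidate move at that search depth
        # (candidates are ranked by visits and value in TreeNode.max_next).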
67 | 68 | self.board_col_2 = tk.Frame(self.main_panel, padx=60, bg=config.BGCOLOR) 69 | self.desc_label_2 = tk.Label(self.board_col_2, font=title_font, height=1, text=u"第二轮", bg=config.BGCOLOR) 70 | self.canvas_2_1 = ThinkCanvas(self.board_col_2, self.N, self.M) 71 | self.desc_label_2_1 = tk.Label(self.board_col_2, font=status_font, height=1, text=u"走法1", bg=config.BGCOLOR) 72 | self.canvas_2_2 = ThinkCanvas(self.board_col_2, self.N, self.M) 73 | self.desc_label_2_2 = tk.Label(self.board_col_2, font=status_font, height=1, text=u"走法2", bg=config.BGCOLOR) 74 | self.canvas_2_3 = ThinkCanvas(self.board_col_2, self.N, self.M) 75 | self.desc_label_2_3 = tk.Label(self.board_col_2, font=status_font, height=1, text=u"走法3", bg=config.BGCOLOR) 76 | self.desc_label_2.pack(side='top') 77 | self.canvas_2_1.pack(side='top') 78 | self.desc_label_2_1.pack(side='top') 79 | self.canvas_2_2.pack(side='top') 80 | self.desc_label_2_2.pack(side='top') 81 | self.canvas_2_3.pack(side='top') 82 | self.desc_label_2_3.pack(side='top') 83 | self.board_col_2.pack(side='left') 84 | 85 | self.board_col_3 = tk.Frame(self.main_panel, padx=30, bg=config.BGCOLOR) 86 | self.desc_label_3 = tk.Label(self.board_col_3, font=title_font, height=1, text=u"第三轮", bg=config.BGCOLOR) 87 | self.canvas_3_1 = ThinkCanvas(self.board_col_3, self.N, self.M) 88 | self.desc_label_3_1 = tk.Label(self.board_col_3, font=status_font, height=1, text=u"走法1", bg=config.BGCOLOR) 89 | self.canvas_3_2 = ThinkCanvas(self.board_col_3, self.N, self.M) 90 | self.desc_label_3_2 = tk.Label(self.board_col_3, font=status_font, height=1, text=u"走法2", bg=config.BGCOLOR) 91 | self.canvas_3_3 = ThinkCanvas(self.board_col_3, self.N, self.M) 92 | self.desc_label_3_3 = tk.Label(self.board_col_3, font=status_font, height=1, text=u"走法3", bg=config.BGCOLOR) 93 | self.desc_label_3.pack(side='top') 94 | self.canvas_3_1.pack(side='top') 95 | self.desc_label_3_1.pack(side='top') 96 | self.canvas_3_2.pack(side='top') 97 | self.desc_label_3_2.pack(side='top') 98 | self.canvas_3_3.pack(side='top') 99 | self.desc_label_3_3.pack(side='top') 100 | self.board_col_3.pack(side='left') 101 | 102 | self.left_panel.pack(side='left') 103 | self.main_panel.pack(side='left') 104 | 105 | def show(self): 106 | self.canvas_1_1.show_think(self.board, self.mcts_root, level=0, order=0) 107 | self.canvas_1_2.show_think(self.board, self.mcts_root, level=0, order=1) 108 | self.canvas_1_3.show_think(self.board, self.mcts_root, level=0, order=2) 109 | 110 | self.canvas_2_1.show_think(self.board, self.mcts_root, level=1, order=0) 111 | self.canvas_2_2.show_think(self.board, self.mcts_root, level=1, order=1) 112 | self.canvas_2_3.show_think(self.board, self.mcts_root, level=1, order=2) 113 | 114 | self.canvas_3_1.show_think(self.board, self.mcts_root, level=2, order=0) 115 | self.canvas_3_2.show_think(self.board, self.mcts_root, level=2, order=1) 116 | self.canvas_3_3.show_think(self.board, self.mcts_root, level=2, order=2) 117 | 118 | def reset(self): 119 | self.canvas_1_1.reset() 120 | self.canvas_1_2.reset() 121 | self.canvas_1_3.reset() 122 | self.canvas_2_1.reset() 123 | self.canvas_2_2.reset() 124 | self.canvas_2_3.reset() 125 | self.canvas_3_1.reset() 126 | self.canvas_3_2.reset() 127 | self.canvas_3_3.reset() 128 | 129 | def show_think_finish(self): 130 | self.reset() 131 | self.game_app.show_think_finish() -------------------------------------------------------------------------------- /demonstration/src/WelcomeScene.py: 
-------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | #author: Zhuoqian Yang yzhq97@buaa.edu.cn 3 | 4 | import tkinter as tk 5 | import tkinter.font as tkf 6 | import src.config as config 7 | import os 8 | from PIL import ImageTk, Image 9 | from src.GameScene import GameScene 10 | 11 | class WelcomeScene(tk.Frame): 12 | def __init__(self, master, game_app): 13 | self.game_app = game_app 14 | self.scn_width = game_app.scn_width 15 | self.scn_height = game_app.scn_height 16 | self.play_data = self.game_app.play_data 17 | tk.Frame.__init__(self, master, 18 | height=self.scn_height, 19 | background=config.BGCOLOR) 20 | self.construct() 21 | self.show_play_data() 22 | 23 | def construct(self): 24 | title_font = tkf.Font(family=config.FONTFAMILY, size=36, weight=tkf.BOLD) 25 | text_font = tkf.Font(family=config.FONTFAMILY, size=20) 26 | desc_font = tkf.Font(family=config.FONTFAMILY, size=16) 27 | 28 | self.title_panel = tk.Frame(self, pady=self.scn_height*0.1, bg=config.BGCOLOR) 29 | self.title_label = tk.Label(self.title_panel, height=1, text=u"五子棋", 30 | bg=config.BGCOLOR, font=title_font) 31 | self.title_label.pack() 32 | 33 | 34 | 35 | self.stats_panel = tk.Frame(self, pady=self.scn_height*0.05, bg=config.BGCOLOR) 36 | 37 | self.easy_var = tk.StringVar() 38 | self.medium_var = tk.StringVar() 39 | self.hard_var = tk.StringVar() 40 | desc_str = u"* 胜率数据通过统计展览期间对局数据得到" 41 | 42 | self.easy_label = tk.Label(self.stats_panel, height=2, bg=config.BGCOLOR, font=text_font, 43 | textvariable=self.easy_var) 44 | self.medium_label = tk.Label(self.stats_panel, height=2, bg=config.BGCOLOR, font=text_font, 45 | textvariable=self.medium_var) 46 | self.hard_label = tk.Label(self.stats_panel, height=2, bg=config.BGCOLOR, font=text_font, 47 | textvariable=self.hard_var) 48 | self.desc_label = tk.Label(self.stats_panel, height=2, bg=config.BGCOLOR, font=desc_font, 49 | text=desc_str) 50 | 51 | self.easy_label.pack() 52 | self.medium_label.pack() 53 | self.hard_label.pack() 54 | self.desc_label.pack() 55 | 56 | self.btn_panel = tk.Frame(self) 57 | 58 | self.easy_btn_image = ImageTk.PhotoImage(Image.open(os.path.join(".", "img", "easy_btn.png"))) 59 | self.medium_btn_image = ImageTk.PhotoImage(Image.open(os.path.join(".", "img", "medium_btn.png"))) 60 | self.hard_btn_image = ImageTk.PhotoImage(Image.open(os.path.join(".", "img", "hard_btn.png"))) 61 | 62 | self.easy_mode_btn = tk.Button(self.btn_panel, command=self.on_easy, 63 | width=260, height=80, border=0, 64 | bg=config.BGCOLOR, activebackground=config.BGCOLOR, 65 | highlightbackground=config.BGCOLOR, 66 | image=self.easy_btn_image) 67 | 68 | self.medium_mode_btn = tk.Button(self.btn_panel, command=self.on_medium, 69 | width=260, height=80, border=0, 70 | bg=config.BGCOLOR, activebackground=config.BGCOLOR, 71 | highlightbackground=config.BGCOLOR, 72 | image=self.medium_btn_image) 73 | 74 | self.hard_mode_btn = tk.Button(self.btn_panel, command=self.on_hard, 75 | width=260, height=80, border=0, 76 | bg=config.BGCOLOR, activebackground=config.BGCOLOR, 77 | highlightbackground=config.BGCOLOR, 78 | image=self.hard_btn_image) 79 | self.easy_mode_btn.pack(side='top') 80 | self.medium_mode_btn.pack(side='top') 81 | self.hard_mode_btn.pack(side='top') 82 | 83 | self.title_panel.pack(side='top') 84 | self.btn_panel.pack(side='top') 85 | self.stats_panel.pack(side='top') 86 | 87 | def show_play_data(self): 88 | easy_str = u"简单模型:自我对弈训练6000局,它在展览期间执黑胜率%4.1f%% 执白胜率%4.1f%%" \ 89 | % (100 * 
self.play_data.easy_black_winrate(), 100 * self.play_data.easy_white_winrate()) 90 | medium_str = u"中等模型:自我对弈训练8000局,它在展览期间执黑胜率%4.1f%% 执白胜率%4.1f%%" \ 91 | % (100 * self.play_data.medium_black_winrate(), 100 * self.play_data.medium_white_winrate()) 92 | hard_str = u"困难模型:自我对弈训练15000局,它在展览期间执黑胜率%4.1f%% 执白胜率%4.1f%%" \ 93 | % (100 * self.play_data.hard_black_winrate(), 100 * self.play_data.hard_white_winrate()) 94 | self.easy_var.set(easy_str) 95 | self.medium_var.set(medium_str) 96 | self.hard_var.set(hard_str) 97 | 98 | def on_easy(self): 99 | self.game_app.model = config.MODEL_EASY 100 | self.game_app.on_select_difficulty_finish() 101 | 102 | def on_medium(self): 103 | self.game_app.model = config.MODEL_MED 104 | self.game_app.on_select_difficulty_finish() 105 | 106 | def on_hard(self): 107 | self.game_app.model = config.MODEL_HARD 108 | self.game_app.on_select_difficulty_finish() 109 | 110 | -------------------------------------------------------------------------------- /demonstration/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/demonstration/src/__init__.py -------------------------------------------------------------------------------- /demonstration/src/components/ChessCanvas.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | #author: Zhuoqian Yang yzhq97@buaa.edu.cn 3 | 4 | import tkinter as tk 5 | import os 6 | from src.model.CrossPoint import CrossPoint 7 | from PIL import ImageTk, Image 8 | 9 | class ChessCanvas(tk.Canvas): 10 | 11 | def __init__(self, master=None, N=8, M=48): 12 | self.master = master 13 | 14 | self.N = N 15 | self.M = M 16 | self.width = N * M + M 17 | self.height = N * M + M 18 | self.piece_size = M / 3 19 | 20 | tk.Canvas.__init__(self, master, 21 | height=self.height, 22 | width=self.width) 23 | self.bg = ImageTk.PhotoImage(Image.open(os.path.join(".", "img", "wood.png"))) 24 | self.create_image(0, 0, image=self.bg, anchor="nw") 25 | 26 | self.chess_points = [] 27 | self.init_chess_points() 28 | self.init_chess_canvas() 29 | 30 | self.pieces = {} 31 | self.inds = {} 32 | 33 | def reset(self): 34 | for location, piece in self.pieces.items(): 35 | self.delete(piece) 36 | for location, ind in self.inds.items(): 37 | if ind is not None: 38 | self.delete(ind) 39 | self.pieces = {} 40 | self.inds = {} 41 | 42 | def init_chess_points(self): 43 | N, M = self.N, self.M 44 | self.chess_points = [[CrossPoint(i, j, M) for j in range(N)] for i in range(N)] 45 | 46 | def init_chess_canvas(self): 47 | N, M = self.N, self.M 48 | for i in range(N): #绘制竖线 49 | self.create_line(self.chess_points[i][0].pixel_x, 50 | self.chess_points[i][0].pixel_y, 51 | self.chess_points[i][N-1].pixel_x, 52 | self.chess_points[i][N-1].pixel_y, 53 | width=1) 54 | 55 | for j in range(N): #绘制横线 56 | self.create_line(self.chess_points[0][j].pixel_x, 57 | self.chess_points[0][j].pixel_y, 58 | self.chess_points[N-1][j].pixel_x, 59 | self.chess_points[N-1][j].pixel_y, 60 | width=1) 61 | for i in range(N): 62 | for j in range(N): 63 | r = M / 30.0 64 | self.create_oval(self.chess_points[i][j].pixel_x-r, 65 | self.chess_points[i][j].pixel_y-r, 66 | self.chess_points[i][j].pixel_x+r, 67 | self.chess_points[i][j].pixel_y+r, 68 | fill='black') 69 | 70 | def render_location(self, location, r, color, outline=None, text=None, text_color=None, width=3.0): 71 | i, j = location 72 | if outline is 
None: outline = color 73 | piece = self.create_oval(self.chess_points[i][j].pixel_x - r, 74 | self.chess_points[i][j].pixel_y - r, 75 | self.chess_points[i][j].pixel_x + r, 76 | self.chess_points[i][j].pixel_y + r, 77 | fill=color, outline=outline, width=width) 78 | ind = None 79 | if text is not None: 80 | ind = self.create_text(self.chess_points[i][j].pixel_x, 81 | self.chess_points[i][j].pixel_y, 82 | text=text, fill=text_color) 83 | 84 | self.pieces[location] = piece 85 | self.inds[location] = ind 86 | 87 | def delete_location(self, location): 88 | if self.pieces[location] is not None: 89 | self.delete(self.pieces[location]) 90 | if self.inds[location] is not None: 91 | self.delete(self.inds[location]) 92 | 93 | -------------------------------------------------------------------------------- /demonstration/src/components/GameCanvas.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | #author: Zhuoqian Yang yzhq97@buaa.edu.cn 3 | 4 | import math 5 | import src.config as config 6 | from src.components.ChessCanvas import ChessCanvas 7 | from src.model.Gomoku import Board 8 | from src.model.MCTS_AlphaZero import MCTSPlayer 9 | from src.model.PVN_11 import PolicyValueNet 10 | from copy import deepcopy 11 | import sys 12 | 13 | 14 | class GameCanvas(ChessCanvas): 15 | 16 | def __init__(self, master, scene): 17 | ChessCanvas.__init__(self, master, config.N, scene.M) 18 | self.scene = scene 19 | self.N = config.N 20 | self.M = scene.M 21 | 22 | self.board = Board(self.N) 23 | 24 | self.colors = ["black", "white"] 25 | 26 | self.last_location = None 27 | 28 | self.ai_ready = False 29 | self.ai_probs = None 30 | # self.ai_prob_vis = [[None for _ in range(self.N)] for _ in range(self.N)] 31 | self.ai_location = None 32 | self.pvn = None 33 | self.ai = None 34 | 35 | self.think_path_is_shown = False 36 | self.ended = False 37 | self.ai_thinking = False 38 | 39 | def reset(self): 40 | self.board.init_board(0) 41 | self.ai_ready = False 42 | self.ai_location = None 43 | self.last_location = None 44 | ChessCanvas.reset(self) 45 | 46 | def init_ai(self, model_file, play_out): 47 | self.pvn = PolicyValueNet(self.N, self.N, model_file) 48 | self.ai = MCTSPlayer(self.pvn.policy_value_fn, 49 | c_puct=5, 50 | n_playout=play_out) 51 | 52 | def game_start(self): 53 | self.reset() 54 | self.init_ai(self.scene.model, config.PLAYOUT) 55 | self.human_order = self.scene.human_order 56 | self.ai_order = 1 - self.scene.human_order 57 | if self.ai_order == 0: 58 | self.scene.show_message("AI思考中") 59 | self.after(50, self.ai_move) 60 | else: 61 | self.scene.show_message("由您先开始") 62 | self.bind('', self.mouse_click) 63 | 64 | def place_location(self, location): 65 | current_player = self.board.get_current_player() 66 | 67 | if self.last_location is not None: 68 | self.delete_location(self.last_location) 69 | if current_player == 0: 70 | self.render_location(self.last_location, self.piece_size, 'white') 71 | else: 72 | self.render_location(self.last_location, self.piece_size, 'black') 73 | self.last_location = location 74 | 75 | if current_player == 0: 76 | self.render_location(location, self.piece_size, 'black', outline=config.PROB_COLOR, width=5.0) 77 | else: 78 | self.render_location(location, self.piece_size, 'white', outline=config.PROB_COLOR, width=5.0) 79 | 80 | self.board.do_location(location) 81 | 82 | current_player = self.board.get_current_player() 83 | if current_player == self.human_order: 84 | self.scene.show_message(u"这一轮由您走") 85 | else: 86 | 
self.scene.show_message(u"AI思考中") 87 | 88 | self.think_path_is_shown = False 89 | return self.board.game_end() 90 | 91 | def game_end(self, winner): 92 | if winner == -1: 93 | self.scene.show_message(u"游戏结束,平局") 94 | elif winner == self.human_order: 95 | self.scene.show_message(u"恭喜您获得了胜利!") 96 | else: 97 | self.scene.show_message(u"AI获得了胜利") 98 | self.scene.on_game_end(winner) 99 | self.unbind("") 100 | self.pvn.close() 101 | 102 | def mouse_click(self, event): 103 | if self.ai_thinking: return 104 | 105 | N, M = self.N, self.M 106 | 107 | human_location = None 108 | mouse_found = False 109 | for i in range(N): 110 | if mouse_found: break 111 | for j in range(N): 112 | square_distance = math.pow((event.x - self.chess_points[i][j].pixel_x), 2) \ 113 | + math.pow((event.y - self.chess_points[i][j].pixel_y), 2) 114 | distance = square_distance ** 0.5 115 | human_location = (i, j) 116 | if (distance < self.M/2.0) and \ 117 | self.board.location_is_valid(human_location): 118 | mouse_found = True 119 | break 120 | 121 | if not mouse_found: return 122 | 123 | current_player = self.board.get_current_player() 124 | if current_player != self.human_order: return 125 | 126 | end, winner = self.place_location(human_location) 127 | if end: self.game_end(winner) 128 | else: 129 | self.after(50, self.ai_move) 130 | 131 | def ai_think(self): 132 | self.show_think_finish() 133 | 134 | self.unbind("") 135 | self.ai_thinking = True 136 | self.scene.show_message(u"AI思考中") 137 | if self.ai_ready: 138 | return 139 | 140 | ai_move, self.ai_probs, mcts_root = self.ai.get_action(self.board, temp=1e-6, return_prob=True) 141 | self.ai_location = Board.move_to_location(ai_move, self.N) 142 | 143 | self.scene.board = self.board 144 | self.scene.mcts_root = mcts_root 145 | 146 | self.ai_ready = True 147 | self.ai_thinking = False 148 | self.scene.show_think(self.scene.board, self.scene.mcts_root) 149 | self.bind('', self.mouse_click) 150 | 151 | def show_think_finish(self): 152 | self.scene.show_think_finish() 153 | 154 | def ai_move(self): 155 | if self.ended: 156 | return 157 | if self.board.current_player != self.ai_order: 158 | return 159 | if not self.ai_ready: 160 | self.ai_think() 161 | 162 | end, winner = self.place_location(self.ai_location) 163 | if end: self.game_end(winner) 164 | 165 | self.ai_ready = False 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | -------------------------------------------------------------------------------- /demonstration/src/components/ThinkCanvas.py: -------------------------------------------------------------------------------- 1 | import src.config as config 2 | from src.model.Gomoku import Board 3 | from src.components.ChessCanvas import ChessCanvas 4 | #author: Zhuoqian Yang yzhq97@buaa.edu.cn 5 | 6 | class ThinkCanvas(ChessCanvas): 7 | 8 | def __init__(self, master=None, N=8, M=48): 9 | self.master = master 10 | ChessCanvas.__init__(self, master, N, M) 11 | 12 | self.colors = ["black", "white"] 13 | 14 | self.value_node = None 15 | self.board = None 16 | self.node = None 17 | self.order = None 18 | self.location = None 19 | self.current_player = None 20 | 21 | def reset(self): 22 | self.board = None 23 | self.node = None 24 | self.order = None 25 | self.location = None 26 | self.current_player = None 27 | ChessCanvas.reset(self) 28 | 29 | def recover_state(self, state): 30 | for (move, player) in state.items(): 31 | location = Board.move_to_location(move, self.N) 32 | self.render_location(location, self.M/3.0, self.colors[player]) 33 | 34 | def 
place_location(self, location, r, color, outline=None, text=None, text_color=None, width=2.0): 35 | self.render_location(location, r, color, outline, text, text_color, width) 36 | 37 | def data_check(self, mcts_root, level, order): 38 | move, node = None, mcts_root 39 | for i in range(2 * level): 40 | if node.is_leaf(): 41 | return False 42 | move, node = node.max_next(1) 43 | 44 | best_moves = node.max_next(3) 45 | if order >= len(best_moves): 46 | return False 47 | # tgt_move, tgt_node = best_moves[order] 48 | 49 | # human_moves = tgt_node.max_next(3) 50 | # visits_sum = 0 51 | # for tup in human_moves: 52 | # move, node = tup 53 | # visits_sum += node.n_visits 54 | # if visits_sum == 0: 55 | # return False 56 | 57 | return True 58 | 59 | 60 | def show_think(self, board, mcts_root, level, order): 61 | self.reset() 62 | 63 | if not self.data_check(mcts_root, level, order): 64 | return 65 | 66 | self.recover_state(board.states) 67 | self.current_player = board.get_current_player() 68 | 69 | move, node = None, mcts_root 70 | for i in range(2 * level): 71 | move, node = node.max_next(1) 72 | location = Board.move_to_location(move, self.N) 73 | r = self.M/3.0 74 | color = self.colors[self.current_player] 75 | text = "%d"%(i+1) 76 | self.place_location(location, r, color, text=text, text_color=config.PROB_COLOR) 77 | self.current_player = 1 - self.current_player 78 | 79 | best_moves = node.max_next(3) 80 | if order >= len(best_moves): 81 | self.clear() 82 | return 83 | tgt_move, tgt_node = best_moves[order] 84 | tgt_location = Board.move_to_location(tgt_move, self.N) 85 | tgt_color = self.colors[self.current_player] 86 | text = "%d" % (level * 2 + 1) 87 | self.place_location(tgt_location, self.M/3.0, outline=config.CURR_COLOR, color=tgt_color, text=text, text_color=config.CURR_COLOR) 88 | self.current_player = 1 - self.current_player 89 | 90 | human_moves = tgt_node.max_next(3) 91 | visits_sum = 0 92 | for tup in human_moves: 93 | move, node = tup 94 | visits_sum += node.n_visits 95 | if visits_sum == 0: 96 | return 97 | else: 98 | for tup in human_moves: 99 | move, node = tup 100 | location = Board.move_to_location(move, self.N) 101 | r = 1.2 * (self.M/3.0) * (0.4 + 0.6*(node.n_visits/visits_sum)) 102 | self.place_location(location, r, color=config.PROB_COLOR) 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | -------------------------------------------------------------------------------- /demonstration/src/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/demonstration/src/components/__init__.py -------------------------------------------------------------------------------- /demonstration/src/config.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | #author: Zhuoqian Yang yzhq97@buaa.edu.cn 3 | 4 | import os 5 | 6 | N = 11 7 | PLAYOUT = 500 8 | MODEL_EASY = os.path.join(".", "mdl", "gomoku_11x11_2000.model") 9 | MODEL_MED = os.path.join(".", "mdl", "gomoku_11x11_3000.model") 10 | MODEL_HARD = os.path.join(".", "mdl", "gomoku_11x11_5000.model") 11 | 12 | BGCOLOR = "#fdf6E3" 13 | PROB_COLOR = "#64b5f6" 14 | CURR_COLOR = "#ff5252" 15 | FONTFAMILY = "微软雅黑" 16 | 17 | THINK_DESC = u"屏幕中展示了AI在接下来三轮的外推中,对它认为最好的三种走法的外推情况。红色的棋子是AI当前评估的走法,蓝色的点是AI认为您会下棋的位置,大小表示概率。\n\n"+\ 18 | u"AI的思考方式:AI先用神经网络分析棋局,然后对未来可能的情况进行外推。AI在外推时会选择最佳的走法,同时假设对手也会选择最佳的走法。" 19 | 
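# Example: a minimal usage sketch of how the settings above are consumed, mirroring
# GameCanvas.init_ai in src/components/GameCanvas.py (the names `pvn` and `ai` are
# illustrative, not part of this module):
#
#   from src.model.PVN_11 import PolicyValueNet
#   from src.model.MCTS_AlphaZero import MCTSPlayer
#
#   pvn = PolicyValueNet(N, N, MODEL_HARD)   # restore an 11x11 checkpoint from ./mdl
#   ai = MCTSPlayer(pvn.policy_value_fn,     # the net guides and evaluates the search
#                   c_puct=5,                # PUCT exploration constant
#                   n_playout=PLAYOUT)       # MCTS playouts per move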
-------------------------------------------------------------------------------- /demonstration/src/model/CrossPoint.py: -------------------------------------------------------------------------------- 1 | #author: Zhuoqian Yang yzhq97@buaa.edu.cn 2 | 3 | class CrossPoint: 4 | 5 | def __init__(self, x, y, M): 6 | self.x = x 7 | self.y = y 8 | self.pixel_x = M + M * self.x 9 | self.pixel_y = M + M * self.y -------------------------------------------------------------------------------- /demonstration/src/model/Gomoku.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: Junxiao Song 4 | """ 5 | 6 | from __future__ import print_function 7 | import numpy as np 8 | 9 | 10 | class Board(object): 11 | """board for the game""" 12 | 13 | def __init__(self, N): 14 | self.N = N 15 | self.width = N 16 | self.height = N 17 | # board states stored as a dict, 18 | # key: move as location on the board, 19 | # value: player as pieces type 20 | self.states = {} 21 | # need how many pieces in a row to win 22 | self.n_in_row = 5 23 | self.players = [0, 1] # player1 and player2 24 | 25 | def init_board(self, start_player=0): 26 | if self.width < self.n_in_row or self.height < self.n_in_row: 27 | raise Exception('board width and height can not be ' 28 | 'less than {}'.format(self.n_in_row)) 29 | self.current_player = self.players[start_player] # start player 30 | # keep available moves in a list 31 | self.availables = list(range(self.width * self.height)) 32 | self.states = {} 33 | self.last_move = -1 34 | 35 | def move_is_valid(self, move): 36 | return move in self.availables 37 | 38 | def location_is_valid(self, location): 39 | move = self.location_to_move(location, self.N) 40 | return move in self.availables 41 | 42 | @staticmethod 43 | def move_to_location(move, N): 44 | j = N - (move // N) - 1 45 | i = move % N 46 | return (i, j) 47 | 48 | @staticmethod 49 | def location_to_move(location, N): 50 | if len(location) != 2: 51 | return -1 52 | i, j = location 53 | move = (N - j - 1) * N + i 54 | if move not in range(N * N): 55 | return -1 56 | return move 57 | 58 | def current_state(self): 59 | """return the board state from the perspective of the current player. 
60 | state shape: 4*width*height 61 | """ 62 | 63 | square_state = np.zeros((4, self.width, self.height)) 64 | if self.states: 65 | moves, players = np.array(list(zip(*self.states.items()))) 66 | move_curr = moves[players == self.current_player] 67 | move_oppo = moves[players != self.current_player] 68 | square_state[0][move_curr // self.width, 69 | move_curr % self.height] = 1.0 70 | square_state[1][move_oppo // self.width, 71 | move_oppo % self.height] = 1.0 72 | # indicate the last move location 73 | square_state[2][self.last_move // self.width, 74 | self.last_move % self.height] = 1.0 75 | if len(self.states) % 2 == 0: 76 | square_state[3][:, :] = 1.0 # indicate the colour to play 77 | return square_state[:, ::-1, :] 78 | 79 | def do_location(self, location): 80 | self.do_move(self.location_to_move(location, self.N)) 81 | 82 | def do_move(self, move): 83 | self.states[move] = self.current_player 84 | self.availables.remove(move) 85 | self.current_player = ( 86 | self.players[0] if self.current_player == self.players[1] 87 | else self.players[1] 88 | ) 89 | self.last_move = move 90 | 91 | def has_a_winner(self): 92 | width = self.width 93 | height = self.height 94 | states = self.states 95 | n = self.n_in_row 96 | 97 | moved = list(set(range(width * height)) - set(self.availables)) 98 | if len(moved) < self.n_in_row + 2: 99 | return False, -1 100 | 101 | for m in moved: 102 | h = m // width 103 | w = m % width 104 | player = states[m] 105 | 106 | if (w in range(width - n + 1) and 107 | len(set(states.get(i, -1) for i in range(m, m + n))) == 1): 108 | return True, player 109 | 110 | if (h in range(height - n + 1) and 111 | len(set(states.get(i, -1) for i in range(m, m + n * width, width))) == 1): 112 | return True, player 113 | 114 | if (w in range(width - n + 1) and h in range(height - n + 1) and 115 | len(set(states.get(i, -1) for i in range(m, m + n * (width + 1), width + 1))) == 1): 116 | return True, player 117 | 118 | if (w in range(n - 1, width) and h in range(height - n + 1) and 119 | len(set(states.get(i, -1) for i in range(m, m + n * (width - 1), width - 1))) == 1): 120 | return True, player 121 | 122 | return False, -1 123 | 124 | def game_end(self): 125 | """Check whether the game is ended or not""" 126 | win, winner = self.has_a_winner() 127 | if win: 128 | return True, winner 129 | elif not len(self.availables): 130 | return True, -1 131 | return False, -1 132 | 133 | def get_current_player(self): 134 | return self.current_player 135 | 136 | 137 | class Game(object): 138 | """game server""" 139 | 140 | def __init__(self, board, **kwargs): 141 | self.board = board 142 | 143 | def graphic(self, board, player1, player2): 144 | """Draw the board and show game info""" 145 | width = board.width 146 | height = board.height 147 | 148 | print("Player", player1, "with X".rjust(3)) 149 | print("Player", player2, "with O".rjust(3)) 150 | print() 151 | for x in range(width): 152 | print("{0:8}".format(x), end='') 153 | print('\r\n') 154 | for i in range(height - 1, -1, -1): 155 | print("{0:4d}".format(i), end='') 156 | for j in range(width): 157 | loc = i * width + j 158 | p = board.states.get(loc, -1) 159 | if p == player1: 160 | print('X'.center(8), end='') 161 | elif p == player2: 162 | print('O'.center(8), end='') 163 | else: 164 | print('_'.center(8), end='') 165 | print('\r\n\r\n') 166 | 167 | def start_play(self, player1, player2, start_player=0, is_shown=1): 168 | """start a game between two players""" 169 | if start_player not in (0, 1): 170 | raise Exception('start_player 
should be either 0 (player1 first) ' 171 | 'or 1 (player2 first)') 172 | self.board.init_board(start_player) 173 | p1, p2 = self.board.players 174 | player1.set_player_ind(p1) 175 | player2.set_player_ind(p2) 176 | players = {p1: player1, p2: player2} 177 | if is_shown: 178 | self.graphic(self.board, player1.player, player2.player) 179 | while True: 180 | current_player = self.board.get_current_player() 181 | player_in_turn = players[current_player] 182 | move = player_in_turn.get_action(self.board) 183 | print(move) 184 | self.board.do_move(move) 185 | if is_shown: 186 | self.graphic(self.board, player1.player, player2.player) 187 | end, winner = self.board.game_end() 188 | if end: 189 | if is_shown: 190 | if winner != -1: 191 | print("Game end. Winner is", players[winner]) 192 | else: 193 | print("Game end. Tie") 194 | return winner 195 | 196 | def start_self_play(self, player, is_shown=0, temp=1e-3): 197 | """ start a self-play game using a MCTS player, reuse the search tree, 198 | and store the self-play data: (state, mcts_probs, z) for training 199 | """ 200 | self.board.init_board() 201 | p1, p2 = self.board.players 202 | states, mcts_probs, current_players = [], [], [] 203 | while True: 204 | move, move_probs = player.get_action(self.board, 205 | temp=temp, 206 | return_prob=1) 207 | # store the data 208 | states.append(self.board.current_state()) 209 | mcts_probs.append(move_probs) 210 | current_players.append(self.board.current_player) 211 | # perform a move 212 | self.board.do_move(move) 213 | if is_shown: 214 | self.graphic(self.board, p1, p2) 215 | end, winner = self.board.game_end() 216 | if end: 217 | # winner from the perspective of the current player of each state 218 | winners_z = np.zeros(len(current_players)) 219 | if winner != -1: 220 | winners_z[np.array(current_players) == winner] = 1.0 221 | winners_z[np.array(current_players) != winner] = -1.0 222 | # reset MCTS root node 223 | player.reset_player() 224 | if is_shown: 225 | if winner != -1: 226 | print("Game end. Winner is player:", winner) 227 | else: 228 | print("Game end. Tie") 229 | return winner, zip(states, mcts_probs, winners_z) 230 | -------------------------------------------------------------------------------- /demonstration/src/model/MCTS_AlphaZero.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Monte Carlo Tree Search in AlphaGo Zero style, which uses a policy-value 4 | network to guide the tree search and evaluate the leaf nodes 5 | 6 | @author: Junxiao Song 7 | """ 8 | 9 | import numpy as np 10 | import copy 11 | import time 12 | 13 | seed = int(time.time() * 10**6) % (10**6) 14 | np.random.seed(seed) 15 | 16 | 17 | def softmax(x): 18 | probs = np.exp(x - np.max(x)) 19 | probs /= np.sum(probs) 20 | return probs 21 | 22 | 23 | class TreeNode(object): 24 | """A node in the MCTS tree. 25 | 26 | Each node keeps track of its own value Q, prior probability P, and 27 | its visit-count-adjusted prior score u. 28 | """ 29 | 30 | def __init__(self, parent, prior_p): 31 | self.parent = parent 32 | self.children = {} # a map from action to TreeNode 33 | self.n_visits = 0 34 | self.Q = 0 35 | self.u = 0 36 | self.value = 0 37 | self.P = prior_p 38 | 39 | def expand(self, action_priors): 40 | """Expand tree by creating new children. 41 | action_priors: a list of tuples of actions and their prior probability 42 | according to the policy function. 
43 | """ 44 | for action, prob in action_priors: 45 | if action not in self.children: 46 | self.children[action] = TreeNode(self, prob) 47 | 48 | def select(self, c_puct): 49 | """Select action among children that gives maximum action value Q 50 | plus bonus u(P). 51 | Return: A tuple of (action, next_node) 52 | """ 53 | return max(self.children.items(), 54 | key=lambda act_node: act_node[1].get_value(c_puct)) 55 | 56 | def update(self, leaf_value): 57 | """Update node values from leaf evaluation. 58 | leaf_value: the value of subtree evaluation from the current player's 59 | perspective. 60 | """ 61 | # Count visit. 62 | self.n_visits += 1 63 | # Update Q, a running average of values for all visits. 64 | self.Q += 1.0*(leaf_value - self.Q) / self.n_visits 65 | 66 | def update_recursive(self, leaf_value): 67 | """Like a call to update(), but applied recursively for all ancestors. 68 | """ 69 | # If it is not root, this node's parent should be updated first. 70 | if self.parent: 71 | self.parent.update_recursive(-leaf_value) 72 | self.update(leaf_value) 73 | 74 | def get_value(self, c_puct): 75 | """Calculate and return the value for this node. 76 | It is a combination of leaf evaluations Q, and this node's prior 77 | adjusted for its visit count, u. 78 | c_puct: a number in (0, inf) controlling the relative impact of 79 | value Q, and prior probability P, on this node's score. 80 | """ 81 | self.u = (c_puct * self.P * 82 | np.sqrt(self.parent.n_visits) / (1 + self.n_visits)) 83 | self.value = self.Q + self.u 84 | return self.value 85 | 86 | def max_next(self, k): 87 | sorted_next = sorted(self.children.items(), reverse=True, 88 | key= lambda x: x[1].n_visits + x[1].value) 89 | if k == 0: 90 | return sorted_next 91 | elif k == 1: 92 | return sorted_next[0] 93 | else: 94 | return sorted_next[:k] 95 | 96 | def is_leaf(self): 97 | """Check if leaf node (i.e. no nodes below this have been expanded).""" 98 | return self.children == {} 99 | 100 | def is_one_above_leaf(self): 101 | for action, node in self.children.items(): 102 | if not node.is_leaf(): 103 | return False 104 | return True 105 | 106 | def is_root(self): 107 | return self.parent is None 108 | 109 | 110 | class MCTS(object): 111 | """An implementation of Monte Carlo Tree Search.""" 112 | 113 | def __init__(self, policy_value_fn, c_puct=5, n_playout=10000): 114 | """ 115 | policy_value_fn: a function that takes in a board state and outputs 116 | a list of (action, probability) tuples and also a score in [-1, 1] 117 | (i.e. the expected value of the end game score from the current 118 | player's perspective) for the current player. 119 | c_puct: a number in (0, inf) that controls how quickly exploration 120 | converges to the maximum-value policy. A higher value means 121 | relying on the prior more. 122 | """ 123 | self._root = TreeNode(None, 1.0) 124 | self._policy = policy_value_fn 125 | self._c_puct = c_puct 126 | self._n_playout = n_playout 127 | 128 | def _playout(self, state): 129 | """Run a single playout from the root to the leaf, getting a value at 130 | the leaf and propagating it back through its parents. 131 | State is modified in-place, so a copy must be provided. 132 | """ 133 | node = self._root 134 | while(1): 135 | if node.is_leaf(): 136 | break 137 | # Greedily select next move. 
138 | action, node = node.select(self._c_puct) 139 | state.do_move(action) 140 | 141 | # Evaluate the leaf using a network which outputs a list of 142 | # (action, probability) tuples p and also a score v in [-1, 1] 143 | # for the current player. 144 | action_probs, leaf_value = self._policy(state) 145 | # Check for end of game. 146 | end, winner = state.game_end() 147 | if not end: 148 | node.expand(action_probs) 149 | else: 150 | # for end state,return the "true" leaf_value 151 | if winner == -1: # tie 152 | leaf_value = 0.0 153 | else: 154 | leaf_value = ( 155 | 1.0 if winner == state.get_current_player() else -1.0 156 | ) 157 | 158 | # Update value and visit count of nodes in this traversal. 159 | node.update_recursive(-leaf_value) 160 | 161 | def get_move_probs(self, state, temp=1e-3): 162 | """Run all playouts sequentially and return the available actions and 163 | their corresponding probabilities. 164 | state: the current game state 165 | temp: temperature parameter in (0, 1] controls the level of exploration 166 | """ 167 | for n in range(self._n_playout): 168 | state_copy = copy.deepcopy(state) 169 | self._playout(state_copy) 170 | 171 | # calc the move probabilities based on visit counts at the root node 172 | act_visits = [(act, node.n_visits) 173 | for act, node in self._root.children.items()] 174 | acts, visits = zip(*act_visits) 175 | act_probs = softmax(1.0/temp * np.log(np.array(visits) + 1e-10)) 176 | 177 | return acts, act_probs, self._root 178 | 179 | def update_with_move(self, last_move): 180 | """Step forward in the tree, keeping everything we already know 181 | about the subtree. 182 | """ 183 | if last_move in self._root.children: 184 | self._root = self._root.children[last_move] 185 | self._root.parent = None 186 | else: 187 | self._root = TreeNode(None, 1.0) 188 | 189 | def __str__(self): 190 | return "MCTS" 191 | 192 | 193 | class MCTSPlayer(object): 194 | """AI player based on MCTS""" 195 | 196 | def __init__(self, policy_value_function, 197 | c_puct=5, n_playout=2000, is_selfplay=0): 198 | self.mcts = MCTS(policy_value_function, c_puct, n_playout) 199 | self._is_selfplay = is_selfplay 200 | 201 | def set_player_ind(self, p): 202 | self.player = p 203 | 204 | def reset_player(self): 205 | self.mcts.update_with_move(-1) 206 | 207 | def get_action(self, board, temp=1e-3, return_prob=0): 208 | sensible_moves = board.availables 209 | # the pi vector returned by MCTS as in the alphaGo Zero paper 210 | move_probs = np.zeros(board.width*board.height) 211 | if len(sensible_moves) > 0: 212 | acts, probs, root_node = self.mcts.get_move_probs(board, temp) 213 | move_probs[list(acts)] = probs 214 | if self._is_selfplay: 215 | # add Dirichlet Noise for exploration (needed for 216 | # self-play training) 217 | move = np.random.choice( 218 | acts, 219 | p=0.75*probs + 0.25*np.random.dirichlet(0.3*np.ones(len(probs))) 220 | ) 221 | # update the root node and reuse the search tree 222 | self.mcts.update_with_move(move) 223 | else: 224 | # with the default temp=1e-3, it is almost equivalent 225 | # to choosing the move with the highest prob 226 | move = np.random.choice(acts, p=probs) 227 | # reset the root node 228 | self.mcts.update_with_move(-1) 229 | # location = board.move_to_location(move) 230 | # print("AI move: %d,%d\n" % (location[0], location[1])) 231 | 232 | if return_prob: 233 | return move, move_probs, root_node 234 | else: 235 | return move 236 | else: 237 | print("WARNING: the board is full") 238 | 239 | def __str__(self): 240 | return "MCTS 
{}".format(self.player) 241 | -------------------------------------------------------------------------------- /demonstration/src/model/MCTS_Pure.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | A pure implementation of the Monte Carlo Tree Search (MCTS) 4 | 5 | @author: Junxiao Song 6 | """ 7 | 8 | import numpy as np 9 | import copy 10 | from operator import itemgetter 11 | 12 | 13 | def rollout_policy_fn(board): 14 | """a coarse, fast version of policy_fn used in the rollout phase.""" 15 | # rollout randomly 16 | action_probs = np.random.rand(len(board.availables)) 17 | return zip(board.availables, action_probs) 18 | 19 | 20 | def policy_value_fn(board): 21 | """a function that takes in a state and outputs a list of (action, probability) 22 | tuples and a score for the state""" 23 | # return uniform probabilities and 0 score for pure MCTS 24 | action_probs = np.ones(len(board.availables))/len(board.availables) 25 | return zip(board.availables, action_probs), 0 26 | 27 | 28 | class TreeNode(object): 29 | """A node in the MCTS tree. Each node keeps track of its own value Q, 30 | prior probability P, and its visit-count-adjusted prior score u. 31 | """ 32 | 33 | def __init__(self, parent, prior_p): 34 | self._parent = parent 35 | self._children = {} # a map from action to TreeNode 36 | self._n_visits = 0 37 | self._Q = 0 38 | self._u = 0 39 | self._P = prior_p 40 | 41 | def expand(self, action_priors): 42 | """Expand tree by creating new children. 43 | action_priors: a list of tuples of actions and their prior probability 44 | according to the policy function. 45 | """ 46 | for action, prob in action_priors: 47 | if action not in self._children: 48 | self._children[action] = TreeNode(self, prob) 49 | 50 | def select(self, c_puct): 51 | """Select action among children that gives maximum action value Q 52 | plus bonus u(P). 53 | Return: A tuple of (action, next_node) 54 | """ 55 | return max(self._children.items(), 56 | key=lambda act_node: act_node[1].get_value(c_puct)) 57 | 58 | def update(self, leaf_value): 59 | """Update node values from leaf evaluation. 60 | leaf_value: the value of subtree evaluation from the current player's 61 | perspective. 62 | """ 63 | # Count visit. 64 | self._n_visits += 1 65 | # Update Q, a running average of values for all visits. 66 | self._Q += 1.0*(leaf_value - self._Q) / self._n_visits 67 | 68 | def update_recursive(self, leaf_value): 69 | """Like a call to update(), but applied recursively for all ancestors. 70 | """ 71 | # If it is not root, this node's parent should be updated first. 72 | if self._parent: 73 | self._parent.update_recursive(-leaf_value) 74 | self.update(leaf_value) 75 | 76 | def get_value(self, c_puct): 77 | """Calculate and return the value for this node. 78 | It is a combination of leaf evaluations Q, and this node's prior 79 | adjusted for its visit count, u. 80 | c_puct: a number in (0, inf) controlling the relative impact of 81 | value Q, and prior probability P, on this node's score. 82 | """ 83 | self._u = (c_puct * self._P * 84 | np.sqrt(self._parent._n_visits) / (1 + self._n_visits)) 85 | return self._Q + self._u 86 | 87 | def is_leaf(self): 88 | """Check if leaf node (i.e. no nodes below this have been expanded). 
89 | """ 90 | return self._children == {} 91 | 92 | def is_root(self): 93 | return self._parent is None 94 | 95 | 96 | class MCTS(object): 97 | """A simple implementation of Monte Carlo Tree Search.""" 98 | 99 | def __init__(self, policy_value_fn, c_puct=5, n_playout=10000): 100 | """ 101 | policy_value_fn: a function that takes in a board state and outputs 102 | a list of (action, probability) tuples and also a score in [-1, 1] 103 | (i.e. the expected value of the end game score from the current 104 | player's perspective) for the current player. 105 | c_puct: a number in (0, inf) that controls how quickly exploration 106 | converges to the maximum-value policy. A higher value means 107 | relying on the prior more. 108 | """ 109 | self._root = TreeNode(None, 1.0) 110 | self._policy = policy_value_fn 111 | self._c_puct = c_puct 112 | self._n_playout = n_playout 113 | 114 | def _playout(self, state): 115 | """Run a single playout from the root to the leaf, getting a value at 116 | the leaf and propagating it back through its parents. 117 | State is modified in-place, so a copy must be provided. 118 | """ 119 | node = self._root 120 | while(1): 121 | if node.is_leaf(): 122 | 123 | break 124 | # Greedily select next move. 125 | action, node = node.select(self._c_puct) 126 | state.do_move(action) 127 | 128 | action_probs, _ = self._policy(state) 129 | # Check for end of game 130 | end, winner = state.game_end() 131 | if not end: 132 | node.expand(action_probs) 133 | # Evaluate the leaf node by random rollout 134 | leaf_value = self._evaluate_rollout(state) 135 | # Update value and visit count of nodes in this traversal. 136 | node.update_recursive(-leaf_value) 137 | 138 | def _evaluate_rollout(self, state, limit=1000): 139 | """Use the rollout policy to play until the end of the game, 140 | returning +1 if the current player wins, -1 if the opponent wins, 141 | and 0 if it is a tie. 142 | """ 143 | player = state.get_current_player() 144 | for i in range(limit): 145 | end, winner = state.game_end() 146 | if end: 147 | break 148 | action_probs = rollout_policy_fn(state) 149 | max_action = max(action_probs, key=itemgetter(1))[0] 150 | state.do_move(max_action) 151 | else: 152 | # If no break from the loop, issue a warning. 153 | print("WARNING: rollout reached move limit") 154 | if winner == -1: # tie 155 | return 0 156 | else: 157 | return 1 if winner == player else -1 158 | 159 | def get_move(self, state): 160 | """Runs all playouts sequentially and returns the most visited action. 161 | state: the current game state 162 | 163 | Return: the selected action 164 | """ 165 | for n in range(self._n_playout): 166 | state_copy = copy.deepcopy(state) 167 | self._playout(state_copy) 168 | return max(self._root._children.items(), 169 | key=lambda act_node: act_node[1]._n_visits)[0] 170 | 171 | def update_with_move(self, last_move): 172 | """Step forward in the tree, keeping everything we already know 173 | about the subtree. 
174 | """ 175 | if last_move in self._root._children: 176 | self._root = self._root._children[last_move] 177 | self._root._parent = None 178 | else: 179 | self._root = TreeNode(None, 1.0) 180 | 181 | def __str__(self): 182 | return "MCTS" 183 | 184 | 185 | class MCTSPlayer(object): 186 | """AI player based on MCTS""" 187 | def __init__(self, c_puct=5, n_playout=2000): 188 | self.mcts = MCTS(policy_value_fn, c_puct, n_playout) 189 | 190 | def set_player_ind(self, p): 191 | self.player = p 192 | 193 | def reset_player(self): 194 | self.mcts.update_with_move(-1) 195 | 196 | def get_action(self, board): 197 | sensible_moves = board.availables 198 | if len(sensible_moves) > 0: 199 | move = self.mcts.get_move(board) 200 | self.mcts.update_with_move(-1) 201 | return move 202 | else: 203 | print("WARNING: the board is full") 204 | 205 | def __str__(self): 206 | return "MCTS {}".format(self.player) 207 | -------------------------------------------------------------------------------- /demonstration/src/model/PVN_11.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | A new design of the policyValueNet in Tensorflow 4 | @author: Zhuoqian Yang 5 | """ 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | import pickle 10 | 11 | cpu_conf = tf.ConfigProto(device_count = {'GPU': 0}) 12 | 13 | class PolicyValueNet(): 14 | def __init__(self, board_width, board_height, model_file=None): 15 | self.board_width = board_width 16 | self.board_height = board_height 17 | 18 | # Define the tensorflow neural network 19 | # 1. Input: 20 | self.input_states = tf.placeholder( 21 | tf.float32, shape=[None, 4, board_height, board_width]) 22 | self.input_state = tf.transpose(self.input_states, [0, 2, 3, 1]) 23 | # 2. 
Common Networks Layers 24 | 25 | self.conv1 = tf.layers.conv2d(inputs=self.input_state, 26 | filters=64, kernel_size=[3, 3], 27 | padding="same", data_format="channels_last", 28 | activation=tf.nn.relu) 29 | self.conv1_1 = tf.layers.conv2d(inputs=self.conv1, filters=64, 30 | kernel_size=[3, 3], padding="same", 31 | data_format="channels_last", 32 | activation=tf.nn.relu) 33 | # self.conv1_2 = tf.layers.conv2d(inputs=self.conv1_1, filters=64, 34 | # kernel_size=[3, 3], padding="same", 35 | # data_format="channels_last", 36 | # activation=tf.nn.relu) 37 | self.conv1_res = self.conv1 + self.conv1_1 38 | 39 | self.conv2 = tf.layers.conv2d(inputs=self.conv1_res, filters=128, 40 | kernel_size=[3, 3], padding="same", 41 | data_format="channels_last", 42 | activation=tf.nn.relu) 43 | self.conv2_1 = tf.layers.conv2d(inputs=self.conv2, filters=128, 44 | kernel_size=[3, 3], padding="same", 45 | data_format="channels_last", 46 | activation=tf.nn.relu) 47 | # self.conv2_2 = tf.layers.conv2d(inputs=self.conv2_1, filters=128, 48 | # kernel_size=[3, 3], padding="same", 49 | # data_format="channels_last", 50 | # activation=tf.nn.relu) 51 | self.conv2_res = self.conv2 + self.conv2_1 52 | 53 | self.conv3 = tf.layers.conv2d(inputs=self.conv2_res, filters=256, 54 | kernel_size=[3, 3], padding="same", 55 | data_format="channels_last", 56 | activation=tf.nn.relu) 57 | self.conv3_1 = tf.layers.conv2d(inputs=self.conv3, filters=256, 58 | kernel_size=[3, 3], padding="same", 59 | data_format="channels_last", 60 | activation=tf.nn.relu) 61 | # self.conv3_2 = tf.layers.conv2d(inputs=self.conv3_1, filters=256, 62 | # kernel_size=[3, 3], padding="same", 63 | # data_format="channels_last", 64 | # activation=tf.nn.relu) 65 | self.conv3_res = self.conv3 + self.conv3_1 66 | 67 | # 3-1 Action Networks 68 | self.action_conv = tf.layers.conv2d(inputs=self.conv3_res, filters=4, 69 | kernel_size=[1, 1], padding="same", 70 | data_format="channels_last", 71 | activation=tf.nn.relu) 72 | # Flatten the tensor 73 | self.action_conv_flat = tf.reshape( 74 | self.action_conv, [-1, 4 * board_height * board_width]) 75 | # 3-2 Full connected layer, the output is the log probability of moves 76 | # on each slot on the board 77 | self.action_fc = tf.layers.dense(inputs=self.action_conv_flat, 78 | units=board_height * board_width, 79 | activation=tf.nn.log_softmax) 80 | # 4 Evaluation Networks 81 | self.evaluation_conv = tf.layers.conv2d(inputs=self.conv3, filters=2, 82 | kernel_size=[1, 1], 83 | padding="same", 84 | data_format="channels_last", 85 | activation=tf.nn.relu) 86 | self.evaluation_conv_flat = tf.reshape( 87 | self.evaluation_conv, [-1, 2 * board_height * board_width]) 88 | self.evaluation_fc1 = tf.layers.dense(inputs=self.evaluation_conv_flat, 89 | units=64, activation=tf.nn.relu) 90 | # output the score of evaluation on current state 91 | self.evaluation_fc2 = tf.layers.dense(inputs=self.evaluation_fc1, 92 | units=1, activation=tf.nn.tanh) 93 | 94 | # Define the Loss function 95 | # 1. Label: the array containing if the game wins or not for each state 96 | self.labels = tf.placeholder(tf.float32, shape=[None, 1]) 97 | # 2. Predictions: the array containing the evaluation score of each state 98 | # which is self.evaluation_fc2 99 | # 3-1. Value Loss function 100 | self.value_loss = tf.losses.mean_squared_error(self.labels, 101 | self.evaluation_fc2) 102 | # 3-2. 
Policy Loss function 103 | self.mcts_probs = tf.placeholder( 104 | tf.float32, shape=[None, board_height * board_width]) 105 | self.policy_loss = tf.negative(tf.reduce_mean( 106 | tf.reduce_sum(tf.multiply(self.mcts_probs, self.action_fc), 1))) 107 | # 3-3. L2 penalty (regularization) 108 | l2_penalty_beta = 1e-4 109 | vars = tf.trainable_variables() 110 | l2_penalty = l2_penalty_beta * tf.add_n( 111 | [tf.nn.l2_loss(v) for v in vars if 'bias' not in v.name.lower()]) 112 | # 3-4 Add up to be the Loss function 113 | self.loss = self.value_loss + self.policy_loss + l2_penalty 114 | 115 | # Define the optimizer we use for training 116 | self.learning_rate = tf.placeholder(tf.float32) 117 | self.optimizer = tf.train.AdamOptimizer( 118 | learning_rate=self.learning_rate).minimize(self.loss) 119 | 120 | # Make a session 121 | self.session = tf.Session() 122 | 123 | # calc policy entropy, for monitoring only 124 | self.entropy = tf.negative(tf.reduce_mean( 125 | tf.reduce_sum(tf.exp(self.action_fc) * self.action_fc, 1))) 126 | 127 | # Initialize variables 128 | init = tf.global_variables_initializer() 129 | self.session.run(init) 130 | 131 | # For saving and restoring 132 | self.saver = tf.train.Saver(max_to_keep=500) 133 | if model_file is not None: 134 | self.restore_model(model_file) 135 | 136 | def policy_value(self, state_batch): 137 | """ 138 | input: a batch of states 139 | output: a batch of action probabilities and state values 140 | """ 141 | log_act_probs, value = self.session.run( 142 | [self.action_fc, self.evaluation_fc2], 143 | feed_dict={self.input_states: state_batch} 144 | ) 145 | act_probs = np.exp(log_act_probs) 146 | return act_probs, value 147 | 148 | def policy_value_fn(self, board): 149 | """ 150 | input: board 151 | output: a list of (action, probability) tuples for each available 152 | action and the score of the board state 153 | """ 154 | legal_positions = board.availables 155 | current_state = np.ascontiguousarray(board.current_state().reshape( 156 | -1, 4, self.board_width, self.board_height)) 157 | 158 | # np.save("/home/yzhq/state_2.npy", current_state, allow_pickle=False) 159 | 160 | act_probs, value = self.policy_value(current_state) 161 | act_probs = zip(legal_positions, act_probs[0][legal_positions]) 162 | return act_probs, value 163 | 164 | def train_step(self, state_batch, mcts_probs, winner_batch, lr): 165 | """perform a training step""" 166 | winner_batch = np.reshape(winner_batch, (-1, 1)) 167 | loss, entropy, _ = self.session.run( 168 | [self.loss, self.entropy, self.optimizer], 169 | feed_dict={self.input_states: state_batch, 170 | self.mcts_probs: mcts_probs, 171 | self.labels: winner_batch, 172 | self.learning_rate: lr}) 173 | return loss, entropy 174 | 175 | def save_model(self, model_path): 176 | self.saver.save(self.session, model_path) 177 | 178 | def restore_model(self, model_path): 179 | self.saver.restore(self.session, model_path) 180 | 181 | def close(self): 182 | tf.reset_default_graph() 183 | self.session.close() 184 | -------------------------------------------------------------------------------- /demonstration/src/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/demonstration/src/model/__init__.py -------------------------------------------------------------------------------- /demonstration/src/play_data.py: -------------------------------------------------------------------------------- 1 
| # author: Zhuoqian Yang yzhq97@buaa.edu.cn 2 | 3 | class PlayData: # win/draw/loss counters per difficulty (easy/medium/hard) and stone colour (black/white) 4 | def __init__(self): 5 | self.easy_black_win = 0 6 | self.easy_black_draw = 0 7 | self.easy_black_lose = 0 8 | self.easy_white_win = 0 9 | self.easy_white_draw = 0 10 | self.easy_white_lose = 0 11 | self.medium_black_win = 0 12 | self.medium_black_draw = 0 13 | self.medium_black_lose = 0 14 | self.medium_white_win = 0 15 | self.medium_white_draw = 0 16 | self.medium_white_lose = 0 17 | self.hard_black_win = 0 18 | self.hard_black_draw = 0 19 | self.hard_black_lose = 0 20 | self.hard_white_win = 0 21 | self.hard_white_draw = 0 22 | self.hard_white_lose = 0 23 | 24 | def easy_black_winrate(self): 25 | total = self.easy_black_win + self.easy_black_draw + self.easy_black_lose 26 | if total == 0: return 0.0 27 | else: return 1.0 * self.easy_black_win / total 28 | 29 | def easy_white_winrate(self): 30 | total = self.easy_white_win + self.easy_white_draw + self.easy_white_lose 31 | if total == 0: return 0.0 32 | else: return 1.0 * self.easy_white_win / total 33 | 34 | def medium_black_winrate(self): 35 | total = self.medium_black_win + self.medium_black_draw + self.medium_black_lose 36 | if total == 0: 37 | return 0.0 38 | else: 39 | return 1.0 * self.medium_black_win / total 40 | 41 | def medium_white_winrate(self): 42 | total = self.medium_white_win + self.medium_white_draw + self.medium_white_lose 43 | if total == 0: 44 | return 0.0 45 | else: 46 | return 1.0 * self.medium_white_win / total 47 | 48 | def hard_black_winrate(self): 49 | total = self.hard_black_win + self.hard_black_draw + self.hard_black_lose 50 | if total == 0: 51 | return 0.0 52 | else: 53 | return 1.0 * self.hard_black_win / total 54 | 55 | def hard_white_winrate(self): 56 | total = self.hard_white_win + self.hard_white_draw + self.hard_white_lose 57 | if total == 0: 58 | return 0.0 59 | else: 60 | return 1.0 * self.hard_white_win / total 61 | 62 | -------------------------------------------------------------------------------- /demonstration/static/play_data.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/demonstration/static/play_data.pkl -------------------------------------------------------------------------------- /screenshots/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/screenshots/1.jpg -------------------------------------------------------------------------------- /screenshots/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/screenshots/2.jpg -------------------------------------------------------------------------------- /screenshots/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/screenshots/3.jpg -------------------------------------------------------------------------------- /training/main.py: -------------------------------------------------------------------------------- 1 | from src.train_pipeline import TrainPipeline 2 | 3 | tp = TrainPipeline(10450, visualize=True) 4 | tp.run() 5 | -------------------------------------------------------------------------------- /training/src/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/training/src/__init__.py -------------------------------------------------------------------------------- /training/src/human_play.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | human VS AI models 4 | Input your move in the format: 2,3 5 | 6 | @author: Junxiao Song 7 | """ 8 | 9 | from __future__ import print_function 10 | import pickle 11 | from game import Board, Game 12 | from mcts_pure import MCTSPlayer as MCTS_Pure 13 | from mcts_alphaZero import MCTSPlayer 14 | # from policy_value_net_numpy import PolicyValueNetNumpy 15 | # from policy_value_net import PolicyValueNet # Theano and Lasagne 16 | # from policy_value_net_pytorch import PolicyValueNet # Pytorch 17 | from pvn_13 import PolicyValueNet # Tensorflow 18 | # from policy_value_net_keras import PolicyValueNet # Keras 19 | 20 | 21 | class Human(object): 22 | """ 23 | human player 24 | """ 25 | 26 | def __init__(self): 27 | self.player = None 28 | 29 | def set_player_ind(self, p): 30 | self.player = p 31 | 32 | def get_action(self, board): 33 | try: 34 | location = input("Your move: ") 35 | if isinstance(location, str): # for python3 36 | location = [int(n, 10) for n in location.split(",")] 37 | move = board.location_to_move(location) 38 | except Exception as e: 39 | move = -1 40 | if move == -1 or move not in board.availables: 41 | print("invalid move") 42 | move = self.get_action(board) 43 | return move 44 | 45 | def __str__(self): 46 | return "Human {}".format(self.player) 47 | 48 | 49 | def run(): 50 | n = 5 51 | width, height = 13, 13 52 | model_file = './train13/gomoku_13x13_370.model' 53 | try: 54 | board = Board(width=width, height=height, n_in_row=n) 55 | game = Game(board) 56 | 57 | # ############### human VS AI ################### 58 | # load the trained policy_value_net in either Theano/Lasagne, PyTorch or TensorFlow 59 | 60 | # best_policy = PolicyValueNet(width, height, model_file = model_file) 61 | # mcts_player = MCTSPlayer(best_policy.policy_value_fn, c_puct=5, n_playout=400) 62 | 63 | # load the provided model (trained in Theano/Lasagne) into a MCTS player written in pure numpy 64 | # try: 65 | # policy_param = pickle.load(open(model_file, 'rb')) 66 | # except: 67 | # policy_param = pickle.load(open(model_file, 'rb'), 68 | # encoding='bytes') # To support python3 69 | 70 | best_policy = PolicyValueNet(width, height, model_file) 71 | mcts_player = MCTSPlayer(best_policy.policy_value_fn, 72 | c_puct=5, 73 | n_playout=500) # set larger n_playout for better performance 74 | 75 | # uncomment the following line to play with pure MCTS (it's much weaker even with a larger n_playout) 76 | # mcts_player = MCTS_Pure(c_puct=5, n_playout=1000) 77 | 78 | # human player, input your move in the format: 2,3 79 | human = Human() 80 | 81 | # set start_player=0 for human first 82 | game.start_play(human, mcts_player, start_player=1, is_shown=1) 83 | except KeyboardInterrupt: 84 | print('\n\rquit') 85 | 86 | 87 | if __name__ == '__main__': 88 | run() 89 | -------------------------------------------------------------------------------- /training/src/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inrainbws/AlphaGomokuZero/daab636949701eefce1ebd20e349bf4f06e32cc1/training/src/model/__init__.py -------------------------------------------------------------------------------- 
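For orientation, here is a minimal sketch of how the training-side Board/Game API (defined in game.py below) and the pure-MCTS player are typically wired together. It is an illustration, not code from the repository: the import paths assume the script is run from the training/ directory so that the src package is importable, and the board size and playout counts are arbitrary example values.

```
# Hypothetical usage sketch (assumes it is run from the training/ directory).
from src.model.game import Board, Game
from src.model.mcts_pure import MCTSPlayer as MCTS_Pure

# Board accepts keyword arguments; n_in_row=5 is standard Gomoku.
board = Board(width=8, height=8, n_in_row=5)
game = Game(board)

# Two pure-MCTS players of different strengths (more playouts = stronger, slower).
weak = MCTS_Pure(c_puct=5, n_playout=400)
strong = MCTS_Pure(c_puct=5, n_playout=2000)

# start_player=0 lets player1 (weak) move first; is_shown=1 prints the board each turn.
winner = game.start_play(weak, strong, start_player=0, is_shown=1)
print("Winner:", winner)  # 1 or 2, or -1 for a tie
```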
/training/src/model/game.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: Junxiao Song 4 | """ 5 | 6 | from __future__ import print_function 7 | import numpy as np 8 | import time 9 | 10 | class Board(object): 11 | """board for the game""" 12 | 13 | def __init__(self, **kwargs): 14 | self.width = int(kwargs.get('width', 8)) 15 | self.height = int(kwargs.get('height', 8)) 16 | # board states stored as a dict, 17 | # key: move as location on the board, 18 | # value: player as pieces type 19 | self.states = {} 20 | # need how many pieces in a row to win 21 | self.n_in_row = int(kwargs.get('n_in_row', 5)) 22 | self.players = [1, 2] # player1 and player2 23 | 24 | def init_board(self, start_player=0): 25 | if self.width < self.n_in_row or self.height < self.n_in_row: 26 | raise Exception('board width and height can not be ' 27 | 'less than {}'.format(self.n_in_row)) 28 | self.current_player = self.players[start_player] # start player 29 | # keep available moves in a list 30 | self.availables = list(range(self.width * self.height)) 31 | self.states = {} 32 | self.last_move = -1 33 | 34 | def move_to_location(self, move): 35 | """ 36 | 3*3 board's moves like: 37 | 6 7 8 38 | 3 4 5 39 | 0 1 2 40 | and move 5's location is (1,2) 41 | """ 42 | h = move // self.width 43 | w = move % self.width 44 | return [h, w] 45 | 46 | def location_to_move(self, location): 47 | if len(location) != 2: 48 | return -1 49 | h = location[0] 50 | w = location[1] 51 | move = h * self.width + w 52 | if move not in range(self.width * self.height): 53 | return -1 54 | return move 55 | 56 | def current_state(self): 57 | """return the board state from the perspective of the current player. 58 | state shape: 4*width*height 59 | """ 60 | 61 | square_state = np.zeros((4, self.width, self.height)) 62 | if self.states: 63 | moves, players = np.array(list(zip(*self.states.items()))) 64 | move_curr = moves[players == self.current_player] 65 | move_oppo = moves[players != self.current_player] 66 | square_state[0][move_curr // self.width, 67 | move_curr % self.height] = 1.0 68 | square_state[1][move_oppo // self.width, 69 | move_oppo % self.height] = 1.0 70 | # indicate the last move location 71 | square_state[2][self.last_move // self.width, 72 | self.last_move % self.height] = 1.0 73 | if len(self.states) % 2 == 0: 74 | square_state[3][:, :] = 1.0 # indicate the colour to play 75 | return square_state[:, ::-1, :] 76 | 77 | def do_move(self, move): 78 | self.states[move] = self.current_player 79 | self.availables.remove(move) 80 | self.current_player = ( 81 | self.players[0] if self.current_player == self.players[1] 82 | else self.players[1] 83 | ) 84 | self.last_move = move 85 | 86 | def has_a_winner(self): 87 | width = self.width 88 | height = self.height 89 | states = self.states 90 | n = self.n_in_row 91 | 92 | moved = list(set(range(width * height)) - set(self.availables)) 93 | if len(moved) < self.n_in_row + 2: 94 | return False, -1 95 | 96 | for m in moved: 97 | h = m // width 98 | w = m % width 99 | player = states[m] 100 | 101 | if (w in range(width - n + 1) and 102 | len(set(states.get(i, -1) for i in range(m, m + n))) == 1): 103 | return True, player 104 | 105 | if (h in range(height - n + 1) and 106 | len(set(states.get(i, -1) for i in range(m, m + n * width, width))) == 1): 107 | return True, player 108 | 109 | if (w in range(width - n + 1) and h in range(height - n + 1) and 110 | len(set(states.get(i, -1) for i in range(m, m + n * 
(width + 1), width + 1))) == 1): 111 | return True, player 112 | 113 | if (w in range(n - 1, width) and h in range(height - n + 1) and 114 | len(set(states.get(i, -1) for i in range(m, m + n * (width - 1), width - 1))) == 1): 115 | return True, player 116 | 117 | return False, -1 118 | 119 | def game_end(self): 120 | """Check whether the game is ended or not""" 121 | win, winner = self.has_a_winner() 122 | if win: 123 | return True, winner 124 | elif not len(self.availables): 125 | return True, -1 126 | return False, -1 127 | 128 | def get_current_player(self): 129 | return self.current_player 130 | 131 | 132 | class Game(object): 133 | """game server""" 134 | 135 | def __init__(self, board, **kwargs): 136 | self.board = board 137 | 138 | def graphic(self, board, player1, player2): 139 | """Draw the board and show game info""" 140 | width = board.width 141 | height = board.height 142 | 143 | print("Player", player1, "with X".rjust(3)) 144 | print("Player", player2, "with O".rjust(3)) 145 | print() 146 | for x in range(width): 147 | print("{0:8}".format(x), end='') 148 | print('\r\n') 149 | for i in range(height - 1, -1, -1): 150 | print("{0:4d}".format(i), end='') 151 | for j in range(width): 152 | loc = i * width + j 153 | p = board.states.get(loc, -1) 154 | if p == player1: 155 | print('X'.center(8), end='') 156 | elif p == player2: 157 | print('O'.center(8), end='') 158 | else: 159 | print('_'.center(8), end='') 160 | print('\r\n\r\n') 161 | 162 | def start_play(self, player1, player2, start_player=0, is_shown=1): 163 | """start a game between two players""" 164 | if start_player not in (0, 1): 165 | raise Exception('start_player should be either 0 (player1 first) ' 166 | 'or 1 (player2 first)') 167 | self.board.init_board(start_player) 168 | p1, p2 = self.board.players 169 | player1.set_player_ind(p1) 170 | player2.set_player_ind(p2) 171 | players = {p1: player1, p2: player2} 172 | if is_shown: 173 | self.graphic(self.board, player1.player, player2.player) 174 | while True: 175 | current_player = self.board.get_current_player() 176 | player_in_turn = players[current_player] 177 | move = player_in_turn.get_action(self.board) 178 | self.board.do_move(move) 179 | if is_shown: 180 | self.graphic(self.board, player1.player, player2.player) 181 | end, winner = self.board.game_end() 182 | if end: 183 | if is_shown: 184 | if winner != -1: 185 | print("Game end. Winner is", players[winner]) 186 | else: 187 | print("Game end. 
Tie") 188 | return winner 189 | 190 | def start_self_play(self, player, is_shown=0, temp=1e-3): 191 | """ start a self-play game using a MCTS player, reuse the search tree, 192 | and store the self-play data: (state, mcts_probs, z) for training 193 | """ 194 | self.board.init_board() 195 | p1, p2 = self.board.players 196 | states, mcts_probs, current_players = [], [], [] 197 | moves = 0 198 | while True: 199 | move, move_probs = player.get_action(self.board, 200 | temp=temp, 201 | return_prob=1) 202 | # store the data 203 | states.append(self.board.current_state()) 204 | mcts_probs.append(move_probs) 205 | current_players.append(self.board.current_player) 206 | # perform a move 207 | self.board.do_move(move) 208 | if is_shown: 209 | self.graphic(self.board, p1, p2) 210 | end, winner = self.board.game_end() 211 | moves += 1 212 | if end: 213 | # winner from the perspective of the current player of each state 214 | winners_z = np.zeros(len(current_players)) 215 | if winner != -1: 216 | winners_z[np.array(current_players) == winner] = 1.0 217 | winners_z[np.array(current_players) != winner] = -1.0 218 | # reset MCTS root node 219 | player.reset_player() 220 | if is_shown: 221 | if winner != -1: 222 | print("Game end. Winner is player:", winner) 223 | else: 224 | print("Game end. Tie") 225 | return winner, zip(states, mcts_probs, winners_z), moves 226 | -------------------------------------------------------------------------------- /training/src/model/inception_resnet_v2.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | 6 | import tensorflow as tf 7 | 8 | slim = tf.contrib.slim 9 | 10 | def block8(net, scale=1.0, activation_fn=None, scope=None, reuse=None): 11 | """Builds the 8x8 resnet block.""" 12 | with tf.variable_scope(scope, 'Block8', [net], reuse=reuse): 13 | with tf.variable_scope('Branch_0'): 14 | tower_conv = slim.conv2d(net, 96, 1, scope='Conv2d_1x1') 15 | with tf.variable_scope('Branch_1'): 16 | tower_conv1_0 = slim.conv2d(net, 96, 1, scope='Conv2d_0a_1x1') 17 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 144, [1, 3], 18 | scope='Conv2d_0b_1x3') 19 | tower_conv1_2 = slim.conv2d(tower_conv1_1, 192, [3, 1], 20 | scope='Conv2d_0c_3x1') 21 | mixed = tf.concat(axis=3, values=[tower_conv, tower_conv1_2]) 22 | up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None, 23 | activation_fn=None, scope='Conv2d_1x1') 24 | net = net + scale * up 25 | if activation_fn is not None: 26 | net = activation_fn(net) 27 | return net 28 | 29 | def inception_resnet_v2_conv(inputs, is_training, reuse=None, scope='InceptionResnetV2'): 30 | end_points = {} 31 | with tf.variable_scope(scope, 'InceptionResnetV2', [inputs], reuse=reuse): 32 | with slim.arg_scope([slim.batch_norm], is_training=is_training): 33 | with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d], stride=1, padding='SAME'): 34 | # 15 x 15 x 64 35 | net = slim.conv2d(inputs, 64, 3, scope='Conv2d_1a_3x3') 36 | end_points['Conv2d_1a_3x3'] = net 37 | # 15 x 15 x 128 38 | net = slim.conv2d(net, 128, 3, scope='Conv2d_2a_3x3') 39 | end_points['Conv2d_2a_3x3'] = net 40 | 41 | # 15 x 15 x 256 42 | with tf.variable_scope('Mixed_5b'): 43 | with tf.variable_scope('Branch_0'): 44 | tower_conv = slim.conv2d(net, 64, 1, scope='Conv2d_1x1') 45 | with tf.variable_scope('Branch_1'): 46 | tower_conv1_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1') 47 | tower_conv1_1 = 
slim.conv2d(tower_conv1_0, 64, 5, 48 | scope='Conv2d_0b_5x5') 49 | with tf.variable_scope('Branch_2'): 50 | tower_conv2_0 = slim.conv2d(net, 64, 1, scope='Conv2d_0a_1x1') 51 | tower_conv2_1 = slim.conv2d(tower_conv2_0, 64, 3, 52 | scope='Conv2d_0b_3x3') 53 | tower_conv2_2 = slim.conv2d(tower_conv2_1, 64, 3, 54 | scope='Conv2d_0c_3x3') 55 | with tf.variable_scope('Branch_3'): 56 | tower_pool = slim.avg_pool2d(net, 3, stride=1, padding='SAME', 57 | scope='AvgPool_0a_3x3') 58 | tower_pool_1 = slim.conv2d(tower_pool, 64, 1, 59 | scope='Conv2d_0b_1x1') 60 | net = tf.concat(axis=3, values=[tower_conv, tower_conv1_1, 61 | tower_conv2_2, tower_pool_1]) 62 | 63 | end_points['Mixed_5b'] = net 64 | 65 | net = block8(net, scale=0.50, activation_fn=tf.nn.relu) 66 | net = block8(net, scale=0.50, activation_fn=tf.nn.relu) 67 | net = block8(net, scale=0.50, activation_fn=tf.nn.relu) 68 | net = block8(net, scale=0.50, activation_fn=tf.nn.relu) 69 | 70 | end_points['End'] = net 71 | 72 | return net, end_points 73 | 74 | def inception_resnet_v2_arg_scope(weight_decay=0.0004, 75 | batch_norm_decay=0.997, 76 | batch_norm_epsilon=0.001): 77 | """Yields the scope with the default parameters for inception_resnet_v2. 78 | Args: 79 | weight_decay: the weight decay for weights variables. 80 | batch_norm_decay: decay for the moving average of batch_norm momentums. 81 | batch_norm_epsilon: small float added to variance to avoid dividing by zero. 82 | Returns: 83 | a arg_scope with the parameters needed for inception_resnet_v2. 84 | """ 85 | # Set weight_decay for weights in conv2d and fully_connected layers. 86 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 87 | weights_regularizer=slim.l2_regularizer(weight_decay), 88 | biases_regularizer=slim.l2_regularizer(weight_decay)): 89 | 90 | batch_norm_params = { 91 | 'decay': batch_norm_decay, 92 | 'epsilon': batch_norm_epsilon, 93 | } 94 | # Set activation_fn and parameters for batch_norm. 95 | with slim.arg_scope([slim.conv2d], activation_fn=tf.nn.relu, 96 | normalizer_fn=slim.batch_norm, 97 | normalizer_params=batch_norm_params) as scope: 98 | return scope -------------------------------------------------------------------------------- /training/src/model/mcts_alphaZero.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Monte Carlo Tree Search in AlphaGo Zero style, which uses a policy-value 4 | network to guide the tree search and evaluate the leaf nodes 5 | 6 | @author: Junxiao Song 7 | """ 8 | 9 | import numpy as np 10 | import copy 11 | import time 12 | 13 | def softmax(x): 14 | probs = np.exp(x - np.max(x)) 15 | probs /= np.sum(probs) 16 | return probs 17 | 18 | 19 | class TreeNode(object): 20 | """A node in the MCTS tree. 21 | 22 | Each node keeps track of its own value Q, prior probability P, and 23 | its visit-count-adjusted prior score u. 24 | """ 25 | 26 | def __init__(self, parent, prior_p): 27 | self._parent = parent 28 | self._children = {} # a map from action to TreeNode 29 | self._n_visits = 0 30 | self._Q = 0 31 | self._u = 0 32 | self._P = prior_p 33 | 34 | def expand(self, action_priors): 35 | """Expand tree by creating new children. 36 | action_priors: a list of tuples of actions and their prior probability 37 | according to the policy function. 
38 | """ 39 | for action, prob in action_priors: 40 | if action not in self._children: 41 | self._children[action] = TreeNode(self, prob) 42 | 43 | def select(self, c_puct): 44 | """Select action among children that gives maximum action value Q 45 | plus bonus u(P). 46 | Return: A tuple of (action, next_node) 47 | """ 48 | return max(self._children.items(), 49 | key=lambda act_node: act_node[1].get_value(c_puct)) 50 | 51 | def update(self, leaf_value): 52 | """Update node values from leaf evaluation. 53 | leaf_value: the value of subtree evaluation from the current player's 54 | perspective. 55 | """ 56 | # Count visit. 57 | self._n_visits += 1 58 | # Update Q, a running average of values for all visits. 59 | self._Q += 1.0*(leaf_value - self._Q) / self._n_visits 60 | 61 | def update_recursive(self, leaf_value): 62 | """Like a call to update(), but applied recursively for all ancestors. 63 | """ 64 | # If it is not root, this node's parent should be updated first. 65 | if self._parent: 66 | self._parent.update_recursive(-leaf_value) 67 | self.update(leaf_value) 68 | 69 | def get_value(self, c_puct): 70 | """Calculate and return the value for this node. 71 | It is a combination of leaf evaluations Q, and this node's prior 72 | adjusted for its visit count, u. 73 | c_puct: a number in (0, inf) controlling the relative impact of 74 | value Q, and prior probability P, on this node's score. 75 | """ 76 | self._u = (c_puct * self._P * 77 | np.sqrt(self._parent._n_visits) / (1 + self._n_visits)) 78 | return self._Q + self._u 79 | 80 | def is_leaf(self): 81 | """Check if leaf node (i.e. no nodes below this have been expanded).""" 82 | return self._children == {} 83 | 84 | def is_root(self): 85 | return self._parent is None 86 | 87 | 88 | class MCTS(object): 89 | """An implementation of Monte Carlo Tree Search.""" 90 | 91 | def __init__(self, policy_value_fn, c_puct=5, n_playout=10000): 92 | """ 93 | policy_value_fn: a function that takes in a board state and outputs 94 | a list of (action, probability) tuples and also a score in [-1, 1] 95 | (i.e. the expected value of the end game score from the current 96 | player's perspective) for the current player. 97 | c_puct: a number in (0, inf) that controls how quickly exploration 98 | converges to the maximum-value policy. A higher value means 99 | relying on the prior more. 100 | """ 101 | self._root = TreeNode(None, 1.0) 102 | self._policy = policy_value_fn 103 | self._c_puct = c_puct 104 | self._n_playout = n_playout 105 | np.random.seed(int((time.time()%1)*3000000)) 106 | 107 | def _playout(self, state): 108 | """Run a single playout from the root to the leaf, getting a value at 109 | the leaf and propagating it back through its parents. 110 | State is modified in-place, so a copy must be provided. 111 | """ 112 | node = self._root 113 | while(1): 114 | if node.is_leaf(): 115 | break 116 | # Greedily select next move. 117 | action, node = node.select(self._c_puct) 118 | state.do_move(action) 119 | 120 | # Evaluate the leaf using a network which outputs a list of 121 | # (action, probability) tuples p and also a score v in [-1, 1] 122 | # for the current player. 123 | action_probs, leaf_value = self._policy(state) 124 | # Check for end of game. 
125 | end, winner = state.game_end() 126 | if not end: 127 | node.expand(action_probs) 128 | else: 129 | # for end state,return the "true" leaf_value 130 | if winner == -1: # tie 131 | leaf_value = 0.0 132 | else: 133 | leaf_value = ( 134 | 1.0 if winner == state.get_current_player() else -1.0 135 | ) 136 | 137 | # Update value and visit count of nodes in this traversal. 138 | node.update_recursive(-leaf_value) 139 | 140 | def get_move_probs(self, state, temp=1e-3): 141 | """Run all playouts sequentially and return the available actions and 142 | their corresponding probabilities. 143 | state: the current game state 144 | temp: temperature parameter in (0, 1] controls the level of exploration 145 | """ 146 | for n in range(self._n_playout): 147 | state_copy = copy.deepcopy(state) 148 | self._playout(state_copy) 149 | 150 | # calc the move probabilities based on visit counts at the root node 151 | act_visits = [(act, node._n_visits) 152 | for act, node in self._root._children.items()] 153 | acts, visits = zip(*act_visits) 154 | act_probs = softmax(1.0/temp * np.log(np.array(visits) + 1e-10)) 155 | 156 | return acts, act_probs 157 | 158 | def update_with_move(self, last_move): 159 | """Step forward in the tree, keeping everything we already know 160 | about the subtree. 161 | """ 162 | if last_move in self._root._children: 163 | self._root = self._root._children[last_move] 164 | self._root._parent = None 165 | else: 166 | self._root = TreeNode(None, 1.0) 167 | 168 | def __str__(self): 169 | return "MCTS" 170 | 171 | 172 | class MCTSPlayer(object): 173 | """AI player based on MCTS""" 174 | 175 | def __init__(self, policy_value_function, 176 | c_puct=5, n_playout=2000, is_selfplay=0): 177 | self.mcts = MCTS(policy_value_function, c_puct, n_playout) 178 | self._is_selfplay = is_selfplay 179 | np.random.seed(int((time.time()%1)*3000000)) 180 | 181 | def set_player_ind(self, p): 182 | self.player = p 183 | 184 | def reset_player(self): 185 | self.mcts.update_with_move(-1) 186 | 187 | def get_action(self, board, temp=1e-3, return_prob=0): 188 | sensible_moves = board.availables 189 | # the pi vector returned by MCTS as in the alphaGo Zero paper 190 | move_probs = np.zeros(board.width*board.height) 191 | if len(sensible_moves) > 0: 192 | acts, probs = self.mcts.get_move_probs(board, temp) 193 | move_probs[list(acts)] = probs 194 | if self._is_selfplay: 195 | # add Dirichlet Noise for exploration (needed for 196 | # self-play training) 197 | move = np.random.choice( 198 | acts, 199 | p=0.75*probs + 0.25*np.random.dirichlet(0.3*np.ones(len(probs))) 200 | ) 201 | # update the root node and reuse the search tree 202 | self.mcts.update_with_move(move) 203 | else: 204 | # with the default temp=1e-3, it is almost equivalent 205 | # to choosing the move with the highest prob 206 | move = np.random.choice(acts, p=probs) 207 | # reset the root node 208 | self.mcts.update_with_move(-1) 209 | # location = board.move_to_location(move) 210 | # print("AI move: %d,%d\n" % (location[0], location[1])) 211 | 212 | if return_prob: 213 | return move, move_probs 214 | else: 215 | return move 216 | else: 217 | print("WARNING: the board is full") 218 | 219 | def __str__(self): 220 | return "MCTS {}".format(self.player) 221 | -------------------------------------------------------------------------------- /training/src/model/mcts_pure.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | A pure implementation of the Monte Carlo Tree Search (MCTS) 4 | 5 
| @author: Junxiao Song 6 | """ 7 | 8 | import numpy as np 9 | import copy 10 | from operator import itemgetter 11 | 12 | 13 | def rollout_policy_fn(board): 14 | """a coarse, fast version of policy_fn used in the rollout phase.""" 15 | # rollout randomly 16 | action_probs = np.random.rand(len(board.availables)) 17 | return zip(board.availables, action_probs) 18 | 19 | 20 | def policy_value_fn(board): 21 | """a function that takes in a state and outputs a list of (action, probability) 22 | tuples and a score for the state""" 23 | # return uniform probabilities and 0 score for pure MCTS 24 | action_probs = np.ones(len(board.availables))/len(board.availables) 25 | return zip(board.availables, action_probs), 0 26 | 27 | 28 | class TreeNode(object): 29 | """A node in the MCTS tree. Each node keeps track of its own value Q, 30 | prior probability P, and its visit-count-adjusted prior score u. 31 | """ 32 | 33 | def __init__(self, parent, prior_p): 34 | self._parent = parent 35 | self._children = {} # a map from action to TreeNode 36 | self._n_visits = 0 37 | self._Q = 0 38 | self._u = 0 39 | self._P = prior_p 40 | 41 | def expand(self, action_priors): 42 | """Expand tree by creating new children. 43 | action_priors: a list of tuples of actions and their prior probability 44 | according to the policy function. 45 | """ 46 | for action, prob in action_priors: 47 | if action not in self._children: 48 | self._children[action] = TreeNode(self, prob) 49 | 50 | def select(self, c_puct): 51 | """Select action among children that gives maximum action value Q 52 | plus bonus u(P). 53 | Return: A tuple of (action, next_node) 54 | """ 55 | return max(self._children.items(), 56 | key=lambda act_node: act_node[1].get_value(c_puct)) 57 | 58 | def update(self, leaf_value): 59 | """Update node values from leaf evaluation. 60 | leaf_value: the value of subtree evaluation from the current player's 61 | perspective. 62 | """ 63 | # Count visit. 64 | self._n_visits += 1 65 | # Update Q, a running average of values for all visits. 66 | self._Q += 1.0*(leaf_value - self._Q) / self._n_visits 67 | 68 | def update_recursive(self, leaf_value): 69 | """Like a call to update(), but applied recursively for all ancestors. 70 | """ 71 | # If it is not root, this node's parent should be updated first. 72 | if self._parent: 73 | self._parent.update_recursive(-leaf_value) 74 | self.update(leaf_value) 75 | 76 | def get_value(self, c_puct): 77 | """Calculate and return the value for this node. 78 | It is a combination of leaf evaluations Q, and this node's prior 79 | adjusted for its visit count, u. 80 | c_puct: a number in (0, inf) controlling the relative impact of 81 | value Q, and prior probability P, on this node's score. 82 | """ 83 | self._u = (c_puct * self._P * 84 | np.sqrt(self._parent._n_visits) / (1 + self._n_visits)) 85 | return self._Q + self._u 86 | 87 | def is_leaf(self): 88 | """Check if leaf node (i.e. no nodes below this have been expanded). 89 | """ 90 | return self._children == {} 91 | 92 | def is_root(self): 93 | return self._parent is None 94 | 95 | 96 | class MCTS(object): 97 | """A simple implementation of Monte Carlo Tree Search.""" 98 | 99 | def __init__(self, policy_value_fn, c_puct=5, n_playout=10000): 100 | """ 101 | policy_value_fn: a function that takes in a board state and outputs 102 | a list of (action, probability) tuples and also a score in [-1, 1] 103 | (i.e. the expected value of the end game score from the current 104 | player's perspective) for the current player. 
105 | c_puct: a number in (0, inf) that controls how quickly exploration 106 | converges to the maximum-value policy. A higher value means 107 | relying on the prior more. 108 | """ 109 | self._root = TreeNode(None, 1.0) 110 | self._policy = policy_value_fn 111 | self._c_puct = c_puct 112 | self._n_playout = n_playout 113 | 114 | def _playout(self, state): 115 | """Run a single playout from the root to the leaf, getting a value at 116 | the leaf and propagating it back through its parents. 117 | State is modified in-place, so a copy must be provided. 118 | """ 119 | node = self._root 120 | while(1): 121 | if node.is_leaf(): 122 | 123 | break 124 | # Greedily select next move. 125 | action, node = node.select(self._c_puct) 126 | state.do_move(action) 127 | 128 | action_probs, _ = self._policy(state) 129 | # Check for end of game 130 | end, winner = state.game_end() 131 | if not end: 132 | node.expand(action_probs) 133 | # Evaluate the leaf node by random rollout 134 | leaf_value = self._evaluate_rollout(state) 135 | # Update value and visit count of nodes in this traversal. 136 | node.update_recursive(-leaf_value) 137 | 138 | def _evaluate_rollout(self, state, limit=1000): 139 | """Use the rollout policy to play until the end of the game, 140 | returning +1 if the current player wins, -1 if the opponent wins, 141 | and 0 if it is a tie. 142 | """ 143 | player = state.get_current_player() 144 | for i in range(limit): 145 | end, winner = state.game_end() 146 | if end: 147 | break 148 | action_probs = rollout_policy_fn(state) 149 | max_action = max(action_probs, key=itemgetter(1))[0] 150 | state.do_move(max_action) 151 | else: 152 | # If no break from the loop, issue a warning. 153 | print("WARNING: rollout reached move limit") 154 | if winner == -1: # tie 155 | return 0 156 | else: 157 | return 1 if winner == player else -1 158 | 159 | def get_move(self, state): 160 | """Runs all playouts sequentially and returns the most visited action. 161 | state: the current game state 162 | 163 | Return: the selected action 164 | """ 165 | for n in range(self._n_playout): 166 | state_copy = copy.deepcopy(state) 167 | self._playout(state_copy) 168 | return max(self._root._children.items(), 169 | key=lambda act_node: act_node[1]._n_visits)[0] 170 | 171 | def update_with_move(self, last_move): 172 | """Step forward in the tree, keeping everything we already know 173 | about the subtree. 
174 | """ 175 | if last_move in self._root._children: 176 | self._root = self._root._children[last_move] 177 | self._root._parent = None 178 | else: 179 | self._root = TreeNode(None, 1.0) 180 | 181 | def __str__(self): 182 | return "MCTS" 183 | 184 | 185 | class MCTSPlayer(object): 186 | """AI player based on MCTS""" 187 | def __init__(self, c_puct=5, n_playout=2000): 188 | self.mcts = MCTS(policy_value_fn, c_puct, n_playout) 189 | 190 | def set_player_ind(self, p): 191 | self.player = p 192 | 193 | def reset_player(self): 194 | self.mcts.update_with_move(-1) 195 | 196 | def get_action(self, board): 197 | sensible_moves = board.availables 198 | if len(sensible_moves) > 0: 199 | move = self.mcts.get_move(board) 200 | self.mcts.update_with_move(-1) 201 | return move 202 | else: 203 | print("WARNING: the board is full") 204 | 205 | def __str__(self): 206 | return "MCTS {}".format(self.player) 207 | -------------------------------------------------------------------------------- /training/src/model/pvn_inception.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | An new design of the policyValueNet in Tensorflow 4 | Tested in Tensorflow 1.5 5 | 6 | @author: Zhuoqian Yang 7 | """ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import division 11 | from __future__ import print_function 12 | import tensorflow as tf 13 | slim = tf.contrib.slim 14 | import numpy as np 15 | from src.model.inception_resnet_v2 import * 16 | 17 | cpu_conf = tf.ConfigProto(device_count = {'GPU': 0}) 18 | 19 | class PolicyValueNet(): 20 | def __init__(self, board_width, board_height, log_dir=None, model_file=None, is_training=False): 21 | self.board_width = board_width 22 | self.board_height = board_height 23 | 24 | # Define the tensorflow neural network 25 | # 1. Input: 26 | self.input_states = tf.placeholder( 27 | tf.float32, shape=[None, 4, board_height, board_width]) 28 | self.input_state = tf.transpose(self.input_states, [0, 2, 3, 1]) 29 | # 2. Common Networks Layers 30 | 31 | with slim.arg_scope(inception_resnet_v2_arg_scope()): 32 | self.conv, self.conv_end_points = inception_resnet_v2_conv(self.input_state, is_training=is_training) 33 | 34 | # 3-1 Action Networks 35 | self.action_conv = tf.layers.conv2d(inputs=self.conv, filters=8, 36 | kernel_size=[1, 1], padding="same", 37 | data_format="channels_last", 38 | activation=tf.nn.relu) 39 | # Flatten the tensor 40 | self.action_conv_flat = tf.reshape( 41 | self.action_conv, [-1, 8 * board_height * board_width]) 42 | # 3-2 Full connected layer, the output is the log probability of moves 43 | # on each slot on the board 44 | self.action_fc = tf.layers.dense(inputs=self.action_conv_flat, 45 | units=board_height * board_width, 46 | activation=tf.nn.log_softmax) 47 | # 4 Evaluation Networks 48 | self.evaluation_conv = tf.layers.conv2d(inputs=self.conv, filters=4, 49 | kernel_size=[1, 1], 50 | padding="same", 51 | data_format="channels_last", 52 | activation=tf.nn.relu) 53 | self.evaluation_conv_flat = tf.reshape( 54 | self.evaluation_conv, [-1, 4 * board_height * board_width]) 55 | self.evaluation_fc1 = tf.layers.dense(inputs=self.evaluation_conv_flat, 56 | units=26, activation=tf.nn.relu) 57 | # output the score of evaluation on current state 58 | self.evaluation_fc2 = tf.layers.dense(inputs=self.evaluation_fc1, 59 | units=1, activation=tf.nn.tanh) 60 | 61 | # Define the Loss function 62 | # 1. 
Label: the array containing if the game wins or not for each state 63 | self.labels = tf.placeholder(tf.float32, shape=[None, 1]) 64 | # 2. Predictions: the array containing the evaluation score of each state 65 | # which is self.evaluation_fc2 66 | # 3-1. Value Loss function 67 | self.value_loss = tf.losses.mean_squared_error(self.labels, 68 | self.evaluation_fc2) 69 | # 3-2. Policy Loss function 70 | self.mcts_probs = tf.placeholder( 71 | tf.float32, shape=[None, board_height * board_width]) 72 | self.policy_loss = tf.negative(tf.reduce_mean( 73 | tf.reduce_sum(tf.multiply(self.mcts_probs, self.action_fc), 1))) 74 | # 3-3. L2 penalty (regularization) 75 | l2_penalty_beta = 1e-4 76 | vars = tf.trainable_variables() 77 | l2_penalty = l2_penalty_beta * tf.add_n( 78 | [tf.nn.l2_loss(v) for v in vars if 'bias' not in v.name.lower()]) 79 | # 3-4 Add up to be the Loss function 80 | self.loss = self.value_loss + self.policy_loss + l2_penalty 81 | 82 | # Define the optimizer we use for training 83 | self.learning_rate = tf.placeholder(tf.float32) 84 | self.optimizer = tf.train.AdamOptimizer( 85 | learning_rate=self.learning_rate).minimize(self.loss) 86 | 87 | # Make a session 88 | self.session = tf.Session() 89 | 90 | # calc policy entropy, for monitoring only 91 | self.entropy = tf.negative(tf.reduce_mean( 92 | tf.reduce_sum(tf.exp(self.action_fc) * self.action_fc, 1))) 93 | 94 | # Tensorboard summary 95 | if log_dir is not None: 96 | tf.summary.scalar('loss', self.loss) 97 | tf.summary.scalar('entropy', self.entropy) 98 | self.summary = tf.summary.merge_all() 99 | self.train_writer = tf.summary.FileWriter(log_dir, self.session.graph) 100 | 101 | # Initialize variables 102 | init = tf.global_variables_initializer() 103 | self.session.run(init) 104 | 105 | # For saving and restoring 106 | self.saver = tf.train.Saver(max_to_keep=1000) 107 | if model_file is not None: 108 | self.restore_model(model_file) 109 | 110 | def policy_value(self, state_batch): 111 | """ 112 | input: a batch of states 113 | output: a batch of action probabilities and state values 114 | """ 115 | log_act_probs, value = self.session.run( 116 | [self.action_fc, self.evaluation_fc2], 117 | feed_dict={self.input_states: state_batch} 118 | ) 119 | act_probs = np.exp(log_act_probs) 120 | return act_probs, value 121 | 122 | def policy_value_fn(self, board): 123 | """ 124 | input: board 125 | output: a list of (action, probability) tuples for each available 126 | action and the score of the board state 127 | """ 128 | legal_positions = board.availables 129 | current_state = np.ascontiguousarray(board.current_state().reshape( 130 | -1, 4, self.board_width, self.board_height)) 131 | act_probs, value = self.policy_value(current_state) 132 | act_probs = zip(legal_positions, act_probs[0][legal_positions]) 133 | return act_probs, value 134 | 135 | def train_step(self, state_batch, mcts_probs, winner_batch, lr): 136 | """perform a training step""" 137 | winner_batch = np.reshape(winner_batch, (-1, 1)) 138 | loss, entropy, _, summary = self.session.run( 139 | [self.loss, self.entropy, self.optimizer, self.summary], 140 | feed_dict={self.input_states: state_batch, 141 | self.mcts_probs: mcts_probs, 142 | self.labels: winner_batch, 143 | self.learning_rate: lr}) 144 | return loss, entropy, summary 145 | 146 | def write_summary(self, summary, i): 147 | self.train_writer.add_summary(summary, i) 148 | 149 | def save_model(self, model_path): 150 | self.saver.save(self.session, model_path) 151 | 152 | def restore_model(self, model_path): 153 | 
self.saver.restore(self.session, model_path) 154 | -------------------------------------------------------------------------------- /training/src/model/pvn_resnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | A new design of the policyValueNet in Tensorflow 4 | Tested in Tensorflow 1.5 5 | 6 | @author: Zhuoqian Yang 7 | """ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import division 11 | from __future__ import print_function 12 | import tensorflow as tf 13 | slim = tf.contrib.slim 14 | import numpy as np 15 | import src.model.resnet as resnet 16 | 17 | 18 | cpu_conf = tf.ConfigProto(device_count = {'GPU': 0}) 19 | 20 | class PolicyValueNet(): 21 | def __init__(self, board_width, board_height, log_dir=None, model_file=None, is_training=False): 22 | self.board_width = board_width 23 | self.board_height = board_height 24 | 25 | # Define the tensorflow neural network 26 | # 1. Input: 27 | self.input_states = tf.placeholder( 28 | tf.float32, shape=[None, 4, board_height, board_width]) 29 | self.input_state = tf.transpose(self.input_states, [0, 2, 3, 1]) 30 | # 2. Common Network Layers 31 | 32 | self.conv = resnet.inference(self.input_state, [32, 64, 64, 128, 256]) 33 | 34 | # 3-1 Action Networks 35 | self.action_conv = tf.layers.conv2d(inputs=self.conv, filters=8, 36 | kernel_size=[1, 1], padding="same", 37 | data_format="channels_last", 38 | activation=tf.nn.relu) 39 | # Flatten the tensor 40 | self.action_conv_flat = tf.reshape( 41 | self.action_conv, [-1, 8 * board_height * board_width]) 42 | # 3-2 Fully connected layer, the output is the log probability of moves 43 | # on each slot on the board 44 | self.action_fc = tf.layers.dense(inputs=self.action_conv_flat, 45 | units=board_height * board_width, 46 | activation=tf.nn.log_softmax) 47 | # 4 Evaluation Networks 48 | self.evaluation_conv = tf.layers.conv2d(inputs=self.conv, filters=4, 49 | kernel_size=[1, 1], 50 | padding="same", 51 | data_format="channels_last", 52 | activation=tf.nn.relu) 53 | self.evaluation_conv_flat = tf.reshape( 54 | self.evaluation_conv, [-1, 4 * board_height * board_width]) 55 | self.evaluation_fc1 = tf.layers.dense(inputs=self.evaluation_conv_flat, 56 | units=22, activation=tf.nn.relu) 57 | # output the score of evaluation on current state 58 | self.evaluation_fc2 = tf.layers.dense(inputs=self.evaluation_fc1, 59 | units=1, activation=tf.nn.tanh) 60 | 61 | # Define the Loss function 62 | # 1. Label: the array containing if the game wins or not for each state 63 | self.labels = tf.placeholder(tf.float32, shape=[None, 1]) 64 | # 2. Predictions: the array containing the evaluation score of each state 65 | # which is self.evaluation_fc2 66 | # 3-1. Value Loss function 67 | self.value_loss = tf.losses.mean_squared_error(self.labels, 68 | self.evaluation_fc2) 69 | # 3-2. Policy Loss function 70 | self.mcts_probs = tf.placeholder( 71 | tf.float32, shape=[None, board_height * board_width]) 72 | self.policy_loss = tf.negative(tf.reduce_mean( 73 | tf.reduce_sum(tf.multiply(self.mcts_probs, self.action_fc), 1))) 74 | # 3-3.
L2 penalty (regularization) 75 | l2_penalty_beta = 1e-4 76 | vars = tf.trainable_variables() 77 | l2_penalty = l2_penalty_beta * tf.add_n( 78 | [tf.nn.l2_loss(v) for v in vars if 'bias' not in v.name.lower()]) 79 | # 3-4 Add up to be the Loss function 80 | self.loss = self.value_loss + self.policy_loss + l2_penalty 81 | 82 | # Define the optimizer we use for training 83 | self.learning_rate = tf.placeholder(tf.float32) 84 | self.optimizer = tf.train.AdamOptimizer( 85 | learning_rate=self.learning_rate).minimize(self.loss) 86 | 87 | # Make a session 88 | self.session = tf.Session() 89 | 90 | # calc policy entropy, for monitoring only 91 | self.entropy = tf.negative(tf.reduce_mean( 92 | tf.reduce_sum(tf.exp(self.action_fc) * self.action_fc, 1))) 93 | 94 | # Tensorboard summary 95 | if log_dir is not None: 96 | tf.summary.histogram('conv/activations', self.conv) 97 | tf.summary.scalar('conv/sparsity', tf.nn.zero_fraction(self.conv)) 98 | tf.summary.scalar('loss', self.loss) 99 | tf.summary.scalar('entropy', self.entropy) 100 | self.summary = tf.summary.merge_all() 101 | self.train_writer = tf.summary.FileWriter(log_dir, self.session.graph) 102 | 103 | # Initialize variables 104 | init = tf.global_variables_initializer() 105 | self.session.run(init) 106 | 107 | # For saving and restoring 108 | self.saver = tf.train.Saver(max_to_keep=1000) 109 | if model_file is not None: 110 | self.restore_model(model_file) 111 | 112 | def policy_value(self, state_batch): 113 | """ 114 | input: a batch of states 115 | output: a batch of action probabilities and state values 116 | """ 117 | log_act_probs, value = self.session.run( 118 | [self.action_fc, self.evaluation_fc2], 119 | feed_dict={self.input_states: state_batch} 120 | ) 121 | act_probs = np.exp(log_act_probs) 122 | return act_probs, value 123 | 124 | def policy_value_fn(self, board): 125 | """ 126 | input: board 127 | output: a list of (action, probability) tuples for each available 128 | action and the score of the board state 129 | """ 130 | legal_positions = board.availables 131 | current_state = np.ascontiguousarray(board.current_state().reshape( 132 | -1, 4, self.board_width, self.board_height)) 133 | act_probs, value = self.policy_value(current_state) 134 | act_probs = zip(legal_positions, act_probs[0][legal_positions]) 135 | return act_probs, value 136 | 137 | def train_step(self, state_batch, mcts_probs, winner_batch, lr): 138 | """perform a training step""" 139 | winner_batch = np.reshape(winner_batch, (-1, 1)) 140 | loss, entropy, _, summary = self.session.run( 141 | [self.loss, self.entropy, self.optimizer, self.summary], 142 | feed_dict={self.input_states: state_batch, 143 | self.mcts_probs: mcts_probs, 144 | self.labels: winner_batch, 145 | self.learning_rate: lr}) 146 | return loss, entropy, summary 147 | 148 | def write_summary(self, summary, i): 149 | self.train_writer.add_summary(summary, i) 150 | 151 | def save_model(self, model_path): 152 | self.saver.save(self.session, model_path) 153 | 154 | def restore_model(self, model_path): 155 | self.saver.restore(self.session, model_path) 156 | -------------------------------------------------------------------------------- /training/src/model/resnet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | 5 | BN_EPSILON = 0.001 6 | 7 | def create_variables(name, shape, initializer=tf.contrib.layers.xavier_initializer(), is_fc_layer=False): 8 | ''' 9 | :param name: A string. 
The name of the new variable 10 | :param shape: A list of dimensions 11 | :param initializer: Uses Xavier as default. 12 | :param is_fc_layer: Whether this variable belongs to a fully connected layer; may use a different weight_decay for fc 13 | layers. 14 | :return: The created variable 15 | ''' 16 | 17 | ## TODO: to allow different weight decay to fully connected layer and conv layer 18 | regularizer = tf.contrib.layers.l2_regularizer(scale=0.0002) 19 | 20 | new_variables = tf.get_variable(name, shape=shape, initializer=initializer, 21 | regularizer=regularizer) 22 | return new_variables 23 | 24 | def batch_normalization_layer(input_layer, dimension): 25 | ''' 26 | Helper function to do batch normalization 27 | :param input_layer: 4D tensor 28 | :param dimension: input_layer.get_shape().as_list()[-1]. The depth of the 4D tensor 29 | :return: the 4D tensor after being normalized 30 | ''' 31 | mean, variance = tf.nn.moments(input_layer, axes=[0, 1, 2]) 32 | beta = tf.get_variable('beta', dimension, tf.float32, 33 | initializer=tf.constant_initializer(0.0, tf.float32)) 34 | gamma = tf.get_variable('gamma', dimension, tf.float32, 35 | initializer=tf.constant_initializer(1.0, tf.float32)) 36 | bn_layer = tf.nn.batch_normalization(input_layer, mean, variance, beta, gamma, BN_EPSILON) 37 | 38 | return bn_layer 39 | 40 | 41 | def conv_bn_relu_layer(input_layer, filter_shape, stride): 42 | ''' 43 | A helper function to conv, batch normalize and relu the input tensor sequentially 44 | :param input_layer: 4D tensor 45 | :param filter_shape: list. [filter_height, filter_width, filter_depth, filter_number] 46 | :param stride: stride size for conv 47 | :return: 4D tensor. Y = Relu(batch_normalize(conv(X))) 48 | ''' 49 | 50 | out_channel = filter_shape[-1] 51 | filter = create_variables(name='conv', shape=filter_shape) 52 | 53 | conv_layer = tf.nn.conv2d(input_layer, filter, strides=[1, stride, stride, 1], padding='SAME') 54 | bn_layer = batch_normalization_layer(conv_layer, out_channel) 55 | 56 | output = tf.nn.relu(bn_layer) 57 | return output 58 | 59 | 60 | def bn_relu_conv_layer(input_layer, filter_shape, stride): 61 | ''' 62 | A helper function to batch normalize, relu and conv the input layer sequentially 63 | :param input_layer: 4D tensor 64 | :param filter_shape: list. [filter_height, filter_width, filter_depth, filter_number] 65 | :param stride: stride size for conv 66 | :return: 4D tensor. Y = conv(Relu(batch_normalize(X))) 67 | ''' 68 | 69 | in_channel = input_layer.get_shape().as_list()[-1] 70 | 71 | bn_layer = batch_normalization_layer(input_layer, in_channel) 72 | relu_layer = tf.nn.relu(bn_layer) 73 | 74 | filter = create_variables(name='conv', shape=filter_shape) 75 | conv_layer = tf.nn.conv2d(relu_layer, filter, strides=[1, stride, stride, 1], padding='SAME') 76 | return conv_layer 77 | 78 | 79 | 80 | def residual_block(input_layer, output_channel, first_block=False): 81 | ''' 82 | Defines a residual block in ResNet 83 | :param input_layer: 4D tensor 84 | :param output_channel: int. return_tensor.get_shape().as_list()[-1] = output_channel 85 | :param first_block: if this is the first residual block of the whole network 86 | :return: 4D tensor.
87 | ''' 88 | input_channel = input_layer.get_shape().as_list()[-1] 89 | 90 | # When it's time to "shrink" the image size, we use stride = 2 91 | if input_channel * 2 == output_channel: 92 | increase_dim = True 93 | elif input_channel == output_channel: 94 | increase_dim = False 95 | else: 96 | raise ValueError('Output and input channels do not match in residual blocks!') 97 | 98 | # The first conv layer of the first residual block does not need to be normalized and relu-ed. 99 | with tf.variable_scope('conv1_in_block'): 100 | if first_block: 101 | filter = create_variables(name='conv', shape=[3, 3, input_channel, output_channel]) 102 | conv1 = tf.nn.conv2d(input_layer, filter=filter, strides=[1, 1, 1, 1], padding='SAME') 103 | else: 104 | conv1 = bn_relu_conv_layer(input_layer, [3, 3, input_channel, output_channel], 1) 105 | 106 | with tf.variable_scope('conv2_in_block'): 107 | conv2 = bn_relu_conv_layer(conv1, [3, 3, output_channel, output_channel], 1) 108 | 109 | if increase_dim is True: 110 | padded_input = tf.pad(input_layer, [[0, 0], [0, 0], [0, 0], [input_channel, 0]]) 111 | else: 112 | padded_input = input_layer 113 | 114 | output = conv2 + padded_input 115 | return output 116 | 117 | 118 | def inference(input_tensor_batch, layers, reuse=False): 119 | ''' 120 | The main function that defines the ResNet. total conv layers = 2 * len(layers) + 1 121 | :param input_tensor_batch: 4D tensor 122 | :param layers: list of output channel counts, one residual block per entry 123 | :param reuse: To build train graph, reuse=False. To build validation graph and share weights 124 | with train graph, reuse=True 125 | :return: last layer in the network. Not softmax-ed 126 | ''' 127 | 128 | with tf.variable_scope('conv0', reuse=reuse): 129 | conv = conv_bn_relu_layer(input_tensor_batch, [3, 3, 4, layers[0]], 1) 130 | 131 | for i in range(len(layers)): 132 | with tf.variable_scope('conv%d'%(i+1), reuse=reuse): 133 | if i == 0: 134 | conv = residual_block(conv, layers[i], first_block=True) 135 | else: 136 | conv = residual_block(conv, layers[i]) 137 | 138 | return conv -------------------------------------------------------------------------------- /training/src/train_pipeline.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Multi-thread implementation of the training pipeline of AlphaZero for Gomoku 4 | 5 | @author: Zhuoqian Yang 6 | """ 7 | 8 | from __future__ import print_function 9 | import random 10 | import time 11 | import os 12 | import sys 13 | import pickle 14 | from collections import defaultdict, deque 15 | from src.model.mcts_pure import MCTSPlayer as MCTS_Pure 16 | from src.model.mcts_alphaZero import MCTSPlayer 17 | from src.model.pvn_resnet import PolicyValueNet 18 | from src.train_thread import TrainThread 19 | from src.utils import * 20 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 21 | 22 | class TrainPipeline(): 23 | def __init__(self, ckpt=0, visualize=False): 24 | # training params 25 | self.learn_rate = 2e-4 26 | self.lr_multiplier = 1.0 # adaptively adjust the learning rate based on KL 27 | self.temperature = 1.0 # the temperature param 28 | self.n_playout = 400 # num of simulations for each move 29 | self.c_puct = 5 30 | self.buffer_size = 10000 31 | self.batch_size = 512 # mini-batch size for training 32 | self.epochs = 5 # num of train_steps for each update 33 | self.kl_targ = 0.02 34 | 35 | # accounting 36 | self.train_dir = os.path.join('.', 'train', 'resnet_11x11') 37 | self.log_dir = os.path.join(self.train_dir, 'tensorboard') 38
| self.save_freq = 50 39 | self.eval_points = [1000, 2000, 3000, 4000] 40 | self.total_trained = 0.0 41 | self.last_update = time.time() 42 | 43 | # load progress 44 | self.i = ckpt+1 45 | self.init_model(ckpt) 46 | self.init_memory(ckpt) 47 | self.init_progress(ckpt) 48 | 49 | # threads 50 | if visualize: self.n_thread = 1 51 | else: self.n_thread = 1 52 | self.remaining_train_jobs = 0 53 | self.accumulating = True 54 | self.accumulation_threshold = self.n_thread 55 | self.threads = [TrainThread(i, self, visualize=visualize) for i in range(self.n_thread)] 56 | self.mcts_players = [] 57 | 58 | self.play_locked = False 59 | self.train_locked = False 60 | self.dispatch_locked = False 61 | 62 | random.seed(int((time.time()%1)*1000000)) 63 | 64 | def init_model(self, ckpt): 65 | if ckpt > 0: 66 | model_path = os.path.join(self.train_dir, 'models', '%d.model' % ckpt) 67 | self.pvn = PolicyValueNet(N, N, model_file=model_path, is_training=True, log_dir=self.log_dir) 68 | else: 69 | self.pvn = PolicyValueNet(N, N, is_training=True, log_dir=self.log_dir) 70 | 71 | def init_memory(self, ckpt): 72 | if ckpt > 0: 73 | mem_file = open(os.path.join(self.train_dir, 'memory.deque'), "rb") 74 | self.memory = pickle.load(mem_file) 75 | mem_file.close() 76 | if len(self.memory) != self.buffer_size: 77 | self.tee("memory size changed from %d to %d" % (len(self.memory), self.buffer_size)) 78 | old = self.memory 79 | self.memory = deque(maxlen=self.buffer_size) 80 | for obj in old: 81 | self.memory.append(obj) 82 | del old 83 | 84 | else: 85 | self.memory = deque(maxlen=self.buffer_size) 86 | 87 | 88 | def init_progress(self, ckpt): 89 | if ckpt > 0: 90 | with open(os.path.join(self.train_dir, 'progress.pkl'), 'rb') as prog_file: 91 | prog = pickle.load(prog_file) 92 | self.total_trained = prog['total_trained'] 93 | self.lr_multiplier = prog['lr_multiplier'] 94 | else: 95 | self.total_trained = 0.0 96 | self.lr_multiplier = 1.0 97 | 98 | def dispatch(self): 99 | while self.dispatch_locked: time.sleep(0.01) 100 | self.dispatch_locked = True 101 | 102 | if self.remaining_train_jobs >= self.accumulation_threshold: 103 | self.accumulating = False 104 | 105 | job = PLAY 106 | if self.remaining_train_jobs > 0 and not self.accumulating: 107 | self.remaining_train_jobs -= 1 108 | job = TRAIN 109 | if self.remaining_train_jobs == 0: 110 | self.accumulating = True 111 | 112 | self.dispatch_locked = False 113 | return job 114 | 115 | def tee(self, str): 116 | with open(os.path.join(self.train_dir, 'train_log.txt'), "a") as myfile: 117 | myfile.write(str + "\n") 118 | print(str) 119 | 120 | def play_job_deliver(self, play_data): 121 | while self.play_locked: time.sleep(0.01) 122 | self.play_locked = True 123 | self.memory.extend(play_data) 124 | if len(self.memory) > self.batch_size: 125 | self.remaining_train_jobs += 1 126 | self.play_locked = False 127 | 128 | def policy_update(self): 129 | """update the policy-value net""" 130 | 131 | while self.train_locked: time.sleep(0.1) 132 | self.train_locked = True 133 | 134 | self.train_step() 135 | 136 | if self.i % self.save_freq == 0: 137 | self.save() 138 | 139 | self.i += 1 140 | self.train_locked = False 141 | 142 | def train_step(self): 143 | mini_batch = random.sample(self.memory, self.batch_size) 144 | state_batch = [data[0] for data in mini_batch] 145 | mcts_probs_batch = [data[1] for data in mini_batch] 146 | winner_batch = [data[2] for data in mini_batch] 147 | old_probs, old_v = self.pvn.policy_value(state_batch) 148 | 149 | loss, entropy, summary, kl, new_v = 
None, None, None, None, None 150 | for i in range(self.epochs): 151 | loss, entropy, summary = self.pvn.train_step( 152 | state_batch, 153 | mcts_probs_batch, 154 | winner_batch, 155 | self.learn_rate*self.lr_multiplier) 156 | new_probs, new_v = self.pvn.policy_value(state_batch) 157 | kl = np.mean(np.sum(old_probs * ( 158 | np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)), 159 | axis=1) 160 | ) 161 | if kl > self.kl_targ * 4: # early stopping if D_KL diverges badly 162 | break 163 | 164 | # record summary 165 | self.pvn.write_summary(summary, self.i) 166 | 167 | # adaptively adjust the learning rate 168 | if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1: 169 | self.lr_multiplier /= 1.5 170 | elif kl < self.kl_targ / 2 and self.lr_multiplier < 10: 171 | self.lr_multiplier *= 1.5 172 | 173 | explained_var_old = (1 - 174 | np.var(np.array(winner_batch) - old_v.flatten()) / 175 | np.var(np.array(winner_batch))) 176 | explained_var_new = (1 - 177 | np.var(np.array(winner_batch) - new_v.flatten()) / 178 | np.var(np.array(winner_batch))) 179 | 180 | now = time.time() 181 | self.total_trained += now-self.last_update 182 | self.last_update = now 183 | self.tee(("update {} at {:.2f} hrs | " 184 | "kl:{:.5f}, lr_multiplier:{:.3f}, loss:{}, entropy:{}, " 185 | "explained_var_old:{:.3f}, explained_var_new:{:.3f}" 186 | ).format(self.i, self.total_trained/3600, 187 | kl, self.lr_multiplier, loss, entropy, 188 | explained_var_old, explained_var_new)) 189 | 190 | 191 | def save(self): 192 | self.pvn.save_model(os.path.join(self.train_dir, 'models', '%d.model' % self.i)) 193 | with open(os.path.join(self.train_dir, 'progress.pkl'), "wb") as prog_file: 194 | prog = {} 195 | prog['total_trained'] = self.total_trained 196 | prog['lr_multiplier'] = self.lr_multiplier 197 | pickle.dump(prog, prog_file) 198 | with open(os.path.join(self.train_dir, 'memory.deque'), "wb") as mem_file: 199 | pickle.dump(self.memory, mem_file) 200 | self.tee("progress saved at iteration %d" % self.i) 201 | 202 | 203 | def run(self): 204 | for i in range(self.n_thread): 205 | player = MCTSPlayer(self.pvn.policy_value_fn, 206 | c_puct=self.c_puct, 207 | n_playout=self.n_playout, 208 | is_selfplay=True) 209 | self.mcts_players.append(player) 210 | time.sleep(0.1) 211 | 212 | self.tee("training started at iteration %d" % self.i) 213 | 214 | try: 215 | for thread in self.threads: 216 | thread.daemon = True 217 | thread.start() 218 | while True: time.sleep(3600) 219 | except (KeyboardInterrupt, SystemExit): 220 | print("Interrupted at iteration %d" % self.i) 221 | sys.exit(0) -------------------------------------------------------------------------------- /training/src/train_thread.py: -------------------------------------------------------------------------------- 1 | from threading import Thread 2 | import time 3 | from src.model.game import Board, Game 4 | from src.utils import * 5 | 6 | class TrainThread(Thread): 7 | def __init__(self, id, master, visualize=False): 8 | Thread.__init__(self) 9 | self.id = id 10 | self.master = master 11 | self.job = None 12 | self.visualize = visualize 13 | 14 | def run(self): 15 | try: 16 | while True: 17 | self.job = self.master.dispatch() 18 | if self.job == TRAIN: 19 | self.run_train() 20 | elif self.job == PLAY: 21 | self.run_play() 22 | elif self.job == EVAL: 23 | self.run_eval() 24 | except KeyboardInterrupt: 25 | return 26 | 27 | def run_train(self): 28 | self.master.policy_update() 29 | 30 | def run_play(self): 31 | time1 = time.time() 32 | self.board = Board(width=N, height=N,
n_in_row=N_WIN) 33 | self.game = Game(self.board) 34 | winner, play_data, moves = self.game.start_self_play(self.master.mcts_players[self.id], 35 | temp=self.master.temperature, 36 | is_shown=self.visualize) 37 | play_data = list(play_data)[:] 38 | time2 = time.time() 39 | play_data = augment(play_data) 40 | self.master.tee("thrd %d, %d moves in %.2f mins, %.2fs per move" % (self.id, moves, ((time2 - time1) / 60), ((time2 - time1) / moves))) 41 | self.play_data = play_data 42 | self.master.play_job_deliver(self.play_data) 43 | 44 | def run_eval(self): 45 | pass -------------------------------------------------------------------------------- /training/src/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | N = 13 4 | N_WIN = 5 5 | 6 | PLAY = 0 7 | TRAIN = 1 8 | EVAL = 2 9 | IDLE = 10 10 | LOAD = 11 11 | 12 | def augment(play_data): 13 | """augment the data set by rotation and flipping 14 | play_data: [(state, mcts_prob, winner_z), ..., ...] 15 | """ 16 | extend_data = [] 17 | for state, mcts_prob, winner in play_data: 18 | for i in [1, 2, 3, 4]: 19 | # rotate counterclockwise 20 | equi_state = np.array([np.rot90(s, i) for s in state]) 21 | equi_mcts_prob = np.rot90(np.flipud( 22 | mcts_prob.reshape(N, N)), i) 23 | extend_data.append((equi_state, 24 | np.flipud(equi_mcts_prob).flatten(), 25 | winner)) 26 | # flip horizontally 27 | equi_state = np.array([np.fliplr(s) for s in equi_state]) 28 | equi_mcts_prob = np.fliplr(equi_mcts_prob) 29 | extend_data.append((equi_state, 30 | np.flipud(equi_mcts_prob).flatten(), 31 | winner)) 32 | return extend_data --------------------------------------------------------------------------------
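
The `augment` helper in `training/src/utils.py` is what gives each self-play game an eight-fold return on data: every position is rotated four ways and each rotation is also flipped. Below is a minimal usage sketch, assuming it is run from the `training` directory so that `src.utils` is importable; the random arrays are placeholders standing in for a real self-play sample, not code from the repository.

```python
import numpy as np

from src.utils import N, augment

# One placeholder self-play sample: a 4-plane board state, a flat
# MCTS visit-count distribution over the N*N grid, and a winner label.
state = np.random.rand(4, N, N)
mcts_prob = np.random.rand(N * N)
mcts_prob /= mcts_prob.sum()

# augment() emits the four rotations of the sample plus a horizontal
# flip of each, so one sample becomes eight equivalent training examples.
augmented = augment([(state, mcts_prob, 1.0)])
print(len(augmented))  # 8
```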