├── .DS_Store
├── doc
│   └── 说明文档.pdf
└── src
    ├── .DS_Store
    ├── .idea
    │   ├── misc.xml
    │   ├── modules.xml
    │   ├── src.iml
    │   └── workspace.xml
    ├── README.md
    ├── __pycache__
    │   ├── game.cpython-36.pyc
    │   ├── mcts_alphaZero.cpython-36.pyc
    │   ├── mcts_pure.cpython-36.pyc
    │   └── policy_value_net_pytorch.cpython-36.pyc
    ├── black.png
    ├── chessboard.jpg
    ├── game.py
    ├── human_play.py
    ├── info
    │   ├── .DS_Store
    │   ├── 10_10_6_loss_.txt
    │   └── 10_10_6_win_ration.txt
    ├── mcts_alphaZero.py
    ├── mcts_pure.py
    ├── model
    │   ├── .DS_Store
    │   ├── 10_10_6_best_policy_0.model
    │   ├── 10_10_6_best_policy_1.model
    │   ├── 10_10_6_best_policy_2.model
    │   ├── 10_10_6_best_policy_3.model
    │   ├── 10_10_6_current_policy_.model
    │   ├── 10_10_6_current_policy_0.model
    │   ├── 10_10_6_current_policy_1.model
    │   ├── 10_10_6_current_policy_2.model
    │   └── 10_10_6_current_policy_3.model
    ├── policy_value_net_pytorch.py
    ├── train.py
    ├── ui.py
    └── white.png
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JunguangJiang/AlphaSix/5a54b9f243bbeb790c9cd7601358748faebcc429/.DS_Store
--------------------------------------------------------------------------------
/doc/说明文档.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JunguangJiang/AlphaSix/5a54b9f243bbeb790c9cd7601358748faebcc429/doc/说明文档.pdf
--------------------------------------------------------------------------------
/src/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JunguangJiang/AlphaSix/5a54b9f243bbeb790c9cd7601358748faebcc429/src/.DS_Store
--------------------------------------------------------------------------------
/src/.idea/misc.xml:
--------------------------------------------------------------------------------
[IntelliJ IDEA project settings; XML content not preserved in this dump]
--------------------------------------------------------------------------------
/src/.idea/modules.xml:
--------------------------------------------------------------------------------
[IntelliJ IDEA module list; XML content not preserved in this dump]
--------------------------------------------------------------------------------
/src/.idea/src.iml:
--------------------------------------------------------------------------------
[IntelliJ IDEA module definition; XML content not preserved in this dump]
--------------------------------------------------------------------------------
/src/.idea/workspace.xml:
--------------------------------------------------------------------------------
[IntelliJ IDEA workspace state (editor history, run configuration timestamps, a breakpoint in file://$PROJECT_DIR$/mcts_pure.py); XML content not preserved in this dump]
--------------------------------------------------------------------------------
/src/README.md:
--------------------------------------------------------------------------------
1 | ## A Deep Reinforcement Learning AI for Connect6
2 | ### Python dependencies
3 | - numpy, PyTorch 0.4.0, PyQt5
4 |
5 | ### Usage
6 | - Example: `python ui.py -s 10 -r 6 -m 800 -i model/10_10_6_best_policy_3.model` starts a game against the AI; change the model file passed to -i as needed
7 | - Example: `python train.py -s 10 -r 6 -m 800 --graphics -n 2000 -i model/10_10_6_current_policy_.model` trains a model
8 | - Run `python human_play.py -h` or `python train.py -h` for detailed help
9 |
10 | ### Files
11 | - human_play.py  human-vs-AI play in the command-line interface
12 | - ui.py  graphical user interface
13 | - game.py  implementation of the game and the board
14 | - mcts_pure.py  pure Monte Carlo Tree Search
15 | - mcts_alphaZero.py  AlphaZero-style Monte Carlo Tree Search adapted to the Connect6 turn structure
16 | - policy_value_net_pytorch.py  implementation of the policy-value network
17 | - train.py  neural network training
18 |
--------------------------------------------------------------------------------
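
For orientation before reading the individual files, here is a minimal sketch (not part of the repository) of how the modules listed in the README are wired together from Python. It simply mirrors what human_play.py below does; it assumes the working directory is src/ and uses one of the model files shipped under src/model/.

```python
# Hypothetical driver script; every class and the model path come from this repo.
from game import Board, Game
from human_play import Human
from mcts_alphaZero import MCTSPlayer
from policy_value_net_pytorch import PolicyValueNet

board = Board(width=10, height=10, n_in_row=6)          # 10x10 Connect6 board
game = Game(board)

# load a trained policy-value network and build the AlphaZero-style MCTS player
net = PolicyValueNet(10, 10, model_file="model/10_10_6_best_policy_3.model")
ai = MCTSPlayer(net.policy_value_fn, c_puct=5, n_playout=800)

# start_player=0 lets the human (player 1) move first
game.start_play(Human(), ai, start_player=0, is_shown=1)
```
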
/src/__pycache__/game.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JunguangJiang/AlphaSix/5a54b9f243bbeb790c9cd7601358748faebcc429/src/__pycache__/game.cpython-36.pyc
--------------------------------------------------------------------------------
/src/__pycache__/mcts_alphaZero.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JunguangJiang/AlphaSix/5a54b9f243bbeb790c9cd7601358748faebcc429/src/__pycache__/mcts_alphaZero.cpython-36.pyc
--------------------------------------------------------------------------------
/src/__pycache__/mcts_pure.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JunguangJiang/AlphaSix/5a54b9f243bbeb790c9cd7601358748faebcc429/src/__pycache__/mcts_pure.cpython-36.pyc
--------------------------------------------------------------------------------
/src/__pycache__/policy_value_net_pytorch.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JunguangJiang/AlphaSix/5a54b9f243bbeb790c9cd7601358748faebcc429/src/__pycache__/policy_value_net_pytorch.cpython-36.pyc
--------------------------------------------------------------------------------
/src/black.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JunguangJiang/AlphaSix/5a54b9f243bbeb790c9cd7601358748faebcc429/src/black.png
--------------------------------------------------------------------------------
/src/chessboard.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JunguangJiang/AlphaSix/5a54b9f243bbeb790c9cd7601358748faebcc429/src/chessboard.jpg
--------------------------------------------------------------------------------
/src/game.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author: Junxiao Song
4 | @modifier: Junguang Jiang
5 |
6 | """
7 |
8 | from __future__ import print_function
9 | import numpy as np
10 | import copy
11 |
12 | class Board(object):
13 | """board for the game"""
14 |
15 | def __init__(self, **kwargs):
16 | self.width = int(kwargs.get('width', 8))
17 | self.height = int(kwargs.get('height', 8))
18 | # board states stored as a dict,
19 | # key: move as location on the board,
20 | # value: player as pieces type
21 | self.states = {}
22 | # need how many pieces in a row to win
23 | self.n_in_row = int(kwargs.get('n_in_row', 5))
24 | self.players = [1, 2] # player1 and player2
25 | self.chesses = 1 # only one stone may be placed on the very first turn
26 | self.last_moves = [] # all stones placed during the previous turn
27 | self.curr_moves = [] # all stones placed during the current turn
28 |
29 |
30 | def init_board(self, start_player=0):
31 | if self.width < self.n_in_row or self.height < self.n_in_row:
32 | raise Exception('board width and height can not be '
33 | 'less than {}'.format(self.n_in_row))
34 | self.current_player = self.players[start_player] # start player
35 | # keep available moves in a list
36 | self.availables = list(range(self.width * self.height))
37 | self.states = {}
38 | self.last_move = -1; self.chesses, self.last_moves, self.curr_moves = 1, [], []  # also reset the Connect6 turn state, since the same Board is reused across games
39 |
40 | def move_to_location(self, move):
41 | """
42 | 3*3 board's moves like:
43 | 6 7 8
44 | 3 4 5
45 | 0 1 2
46 | and move 5's location is (1,2)
47 | """
48 | h = move // self.width
49 | w = move % self.width
50 | return [h, w]
51 |
52 | def location_to_move(self, location):
53 | if len(location) != 2:
54 | return -1
55 | h = location[0]
56 | w = location[1]
57 | move = h * self.width + w
58 | if move not in range(self.width * self.height):
59 | return -1
60 | return move
61 |
62 | def current_state(self):
63 | """return the board state from the perspective of the current player.
64 | state shape: 4*width*height
65 | """
66 |
67 | square_state = np.zeros((4, self.width, self.height))
68 | if self.states:
69 | moves, players = np.array(list(zip(*self.states.items())))
70 | move_curr = moves[players == self.current_player]
71 | move_oppo = moves[players != self.current_player]
72 | square_state[0][move_curr // self.height,
73 | move_curr % self.height] = 1.0
74 | square_state[1][move_oppo // self.height,
75 | move_oppo % self.height] = 1.0
76 | for move in self.last_moves:
77 | square_state[2][move // self.height, move % self.height] = 1.0
78 | for move in self.curr_moves:
79 | square_state[3][move // self.height, move % self.height] = 1.0
80 | return square_state[:, ::-1, :]
81 |
82 | def do_move(self, move):
83 | """下一个棋子"""
84 | self.states[move] = self.current_player
85 | self.availables.remove(move)
86 | self.last_move = move
87 | self.curr_moves.append(move)
88 | self.chesses -= 1
89 | if self.chesses == 0:
90 | self._change_turn()
91 | self.chesses = 2
92 |
93 | def _change_turn(self):
94 | """交换下棋的权利"""
95 | self.current_player = (
96 | self.players[0] if self.current_player == self.players[1]
97 | else self.players[1]
98 | )
99 | self.last_moves = copy.deepcopy(self.curr_moves)
100 | self.curr_moves.clear()
101 |
102 | def has_a_winner(self):
103 | """判断当前是否有赢家了"""
104 | width = self.width
105 | height = self.height
106 | states = self.states
107 | n = self.n_in_row
108 |
109 | moved = list(set(range(width * height)) - set(self.availables))
110 | if len(moved) < self.n_in_row + 2:
111 | return False, -1
112 |
113 | for m in moved:
114 | h = m // width
115 | w = m % width
116 | player = states[m]
117 |
118 | if (w in range(width - n + 1) and
119 | len(set(states.get(i, -1) for i in range(m, m + n))) == 1): # if the n cells in this row starting at [h, w] all hold the same colour
120 | return True, player # the game is over; return the winner
121 |
122 | if (h in range(height - n + 1) and
123 | len(set(states.get(i, -1) for i in range(m, m + n * width, width))) == 1):
124 | return True, player
125 |
126 | if (w in range(width - n + 1) and h in range(height - n + 1) and
127 | len(set(states.get(i, -1) for i in range(m, m + n * (width + 1), width + 1))) == 1):
128 | return True, player
129 |
130 | if (w in range(n - 1, width) and h in range(height - n + 1) and
131 | len(set(states.get(i, -1) for i in range(m, m + n * (width - 1), width - 1))) == 1):
132 | return True, player
133 |
134 | return False, -1
135 |
136 | def game_end(self):
137 | """Check whether the game is ended or not"""
138 | win, winner = self.has_a_winner()
139 | if win:
140 | return True, winner
141 | elif not len(self.availables):
142 | return True, -1
143 | return False, -1
144 |
145 | def get_current_player(self):
146 | return self.current_player
147 |
148 | def is_start(self):
149 | """判断游戏是否为开局"""
150 | return len(self.availables) == ( self.width * self.height )
151 |
152 | def __str__(self):
153 | return str(self.height)+"_"+str(self.width)+"_"+str(self.n_in_row)
154 |
155 |
156 | class Game(object):
157 | """game server"""
158 |
159 | def __init__(self, board, **kwargs):
160 | self.board = board
161 |
162 | def graphic(self, board, player1, player2):
163 | """Draw the board and show game info"""
164 | width = board.width
165 | height = board.height
166 |
167 | print("Player", player1, "with X".rjust(3))
168 | print("Player", player2, "with O".rjust(3))
169 | print()
170 | for x in range(width):
171 | print("{0:8}".format(x), end='')
172 | print('\r\n')
173 | for i in range(height - 1, -1, -1):
174 | print("{0:4d}".format(i), end='')
175 | for j in range(width):
176 | loc = i * width + j
177 | p = board.states.get(loc, -1)
178 | if p == player1:
179 | print('X'.center(8), end='')
180 | elif p == player2:
181 | print('O'.center(8), end='')
182 | else:
183 | print('_'.center(8), end='')
184 | print('\r\n\r\n')
185 |
186 | def start_play(self, player1, player2, start_player=0, is_shown=1):
187 | """start a game between two players"""
188 | if start_player not in (0, 1):
189 | raise Exception('start_player should be either 0 (player1 first) '
190 | 'or 1 (player2 first)')
191 | self.board.init_board(start_player)
192 | p1, p2 = self.board.players
193 | player1.set_player_ind(p1)
194 | player2.set_player_ind(p2)
195 | players = {p1: player1, p2: player2}
196 | if is_shown:
197 | self.graphic(self.board, player1.player, player2.player)
198 | while True:
199 | current_player = self.board.get_current_player()
200 | player_in_turn = players[current_player]
201 | move = player_in_turn.get_action(self.board)
202 | self.board.do_move(move)
203 | if is_shown:
204 | self.graphic(self.board, player1.player, player2.player)
205 | end, winner = self.board.game_end()
206 | if end:
207 | if is_shown:
208 | if winner != -1:
209 | print("Game end. Winner is", players[winner])
210 | else:
211 | print("Game end. Tie")
212 | return winner
213 |
214 |
215 | def start_self_play(self, player, is_shown=0, temp=1e-3):
216 | """ start a self-play game using a MCTS player, reuse the search tree,
217 | and store the self-play data: (state, mcts_probs, z) for training
218 | """
219 | self.board.init_board()
220 | p1, p2 = self.board.players
221 | states, mcts_probs, current_players = [], [], []
222 | while True:
223 | move, move_probs = player.get_action(self.board,
224 | temp=temp,
225 | return_prob=1)
226 | # store the data
227 | states.append(self.board.current_state())
228 | mcts_probs.append(move_probs)
229 | current_players.append(self.board.current_player)
230 | # perform a move
231 | self.board.do_move(move)
232 | if is_shown:
233 | self.graphic(self.board, p1, p2)
234 | end, winner = self.board.game_end()
235 | if end:
236 | # winner from the perspective of the current player of each state
237 | winners_z = np.zeros(len(current_players))
238 | if winner != -1:
239 | winners_z[np.array(current_players) == winner] = 1.0
240 | winners_z[np.array(current_players) != winner] = -1.0
241 | # reset MCTS root node
242 | player.reset_player()
243 | if is_shown:
244 | if winner != -1:
245 | print("Game end. Winner is player:", winner)
246 | else:
247 | print("Game end. Tie")
248 | return winner, zip(states, mcts_probs, winners_z)
249 |
250 |
--------------------------------------------------------------------------------
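
To make the Connect6 turn rule implemented by do_move and _change_turn above concrete, here is a small sketch (not part of the repository, assumed to run from src/; the coordinates are arbitrary): the opening player places a single stone, after which the players alternate placing two stones per turn.

```python
from game import Board

board = Board(width=10, height=10, n_in_row=6)
board.init_board()

print(board.get_current_player())                 # 1: black opens with one stone
board.do_move(board.location_to_move([5, 5]))

print(board.get_current_player())                 # 2: white now places two stones
board.do_move(board.location_to_move([5, 6]))
print(board.get_current_player())                 # still 2
board.do_move(board.location_to_move([4, 6]))

print(board.get_current_player())                 # 1: black again, two stones from now on
```
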
/src/human_play.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | human VS AI models
4 | Input your move in the format: 2,3
5 |
6 | @author: Junxiao Song
7 | @modifier: Junguang Jiang
8 | """
9 |
10 | from __future__ import print_function
11 | from game import Board, Game
12 | from mcts_alphaZero import MCTSPlayer
13 | from policy_value_net_pytorch import PolicyValueNet # Pytorch
14 |
15 |
16 | # Please read the Human class carefully
17 | class Human(object):
18 | """
19 | human player
20 | """
21 |
22 | def __init__(self):
23 | self.player = None
24 |
25 | def set_player_ind(self, p):
26 | """设置人类玩家的编号,黑:1,白:2"""
27 | self.player = p
28 |
29 | def get_action(self, board):
30 | """根据棋盘返回动作"""
31 | try:
32 | location = input("Your move: ") # read a position from the keyboard, e.g. "0,2" means row 0, column 2
33 | if isinstance(location, str): # location should be a string
34 | location = [int(n, 10) for n in location.split(",")] # convert it into a coordinate pair
35 | move = board.location_to_move(location) # convert the coordinates into a one-dimensional move in [0, width*height)
36 | except Exception as e: # invalid input
37 | move = -1
38 | if move == -1 or move not in board.availables: # the move is illegal
39 | print("invalid move")
40 | move = self.get_action(board) # ask for input again
41 | return move
42 |
43 | def __str__(self):
44 | return "Human {}".format(self.player)
45 |
46 |
47 | # The functions below can be skimmed
48 | def run(n_in_row, width, height, # how many in a row to win, board width, board height
49 | model_file, ai_first, # model file to load, whether the AI moves first
50 | n_playout, use_gpu): # number of MCTS playouts per AI move, whether to use the GPU
51 | try:
52 | board = Board(width=width, height=height, n_in_row=n_in_row) # create a board
53 | game = Game(board) # create a game
54 |
55 | # ############### human VS AI ###################
56 | best_policy = PolicyValueNet(width, height, model_file=model_file, use_gpu=use_gpu) # load the best policy network
57 | mcts_player = MCTSPlayer(best_policy.policy_value_fn, c_puct=5, n_playout=n_playout) # create an AI player
58 | human = Human() # create a human player
59 |
60 | # start_player=0 means player1 (the human) moves first; ai_first=True passes 1, so the AI moves first
61 | game.start_play(human, mcts_player, start_player=ai_first, is_shown=1) # start the game
62 | except KeyboardInterrupt:
63 | print('\n\rquit')
64 |
65 | def usage():
66 | print("-s            set the board size (default: 10)")
67 | print("-r            set how many stones in a row are needed to win (default: 6)")
68 | print("-m            set the number of MCTS playouts per move (default: 800)")
69 | print("-i            model file used by the AI (default: model/10_10_6_best_policy_3.model)")
70 | print("--use_gpu     run the network on the GPU")
71 | print("--human_first let the human move first")
72 |
73 |
74 | if __name__ == '__main__':
75 | import sys, getopt
76 |
77 | height = 10
78 | width = 10
79 | n_in_row = 6
80 | use_gpu = False
81 | n_playout = 800
82 | model_file = "model/10_10_6_best_policy_3.model"
83 | ai_first=True
84 |
85 | opts, args = getopt.getopt(sys.argv[1:], "hs:r:m:i:", ["use_gpu", "graphics", "human_first"])
86 | for op, value in opts:
87 | if op == "-h":
88 | usage()
89 | sys.exit()
90 | elif op == "-s":
91 | height = width = int(value)
92 | elif op == "-r":
93 | n_in_row = int(value)
94 | elif op == "--use_gpu":
95 | use_gpu = True
96 | elif op == "-m":
97 | n_playout = int(value)
98 | elif op == "-i":
99 | model_file = value
100 | elif op == "--human_first":
101 | ai_first=False
102 | run(height=height, width=width, n_in_row=n_in_row, use_gpu=use_gpu, n_playout=n_playout,
103 | model_file=model_file, ai_first=ai_first)
104 |
--------------------------------------------------------------------------------
/src/info/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JunguangJiang/AlphaSix/5a54b9f243bbeb790c9cd7601358748faebcc429/src/info/.DS_Store
--------------------------------------------------------------------------------
/src/info/10_10_6_loss_.txt:
--------------------------------------------------------------------------------
1 | self-play games,loss,entropy
2 |
--------------------------------------------------------------------------------
/src/info/10_10_6_win_ration.txt:
--------------------------------------------------------------------------------
1 | self-play games, pure MCTS playouts, win ratio
2 |
--------------------------------------------------------------------------------
/src/mcts_alphaZero.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Monte Carlo Tree Search in AlphaGo Zero style, which uses a policy-value
4 | network to guide the tree search and evaluate the leaf nodes
5 |
6 | @author: Junxiao Song
7 | @modifier: Junguang Jiang
8 |
9 | """
10 |
11 | import numpy as np
12 | import copy
13 |
14 |
15 | def softmax(x):
16 | probs = np.exp(x - np.max(x))
17 | probs /= np.sum(probs)
18 | return probs
19 |
20 |
21 | class TreeNode(object):
22 | """A node in the MCTS tree.
23 |
24 | Each node keeps track of its own value Q, prior probability P, and
25 | its visit-count-adjusted prior score u.
26 | """
27 |
28 | def __init__(self, parent, prior_p):
29 | self._parent = parent
30 | self._children = {} # a map from action to TreeNode
31 | self._n_visits = 0
32 | self._Q = 0
33 | self._u = 0
34 | self._P = prior_p
35 | # self.flag = flag # whether this node is the last stone placed by the corresponding player
36 |
37 | def expand(self, action_priors):
38 | """Expand tree by creating new children.
39 | action_priors: a list of tuples of actions and their prior probability
40 | according to the policy function.
41 | """
42 | for action, prob in action_priors:
43 | if action not in self._children:
44 | self._children[action] = TreeNode(self, prob)
45 |
46 | def select(self, c_puct):
47 | """Select action among children that gives maximum action value Q
48 | plus bonus u(P).
49 | Return: A tuple of (action, next_node)
50 | """
51 | return max(self._children.items(),
52 | key=lambda act_node: act_node[1].get_value(c_puct))
53 |
54 | def update(self, leaf_value):
55 | """Update node values from leaf evaluation.
56 | leaf_value: the value of subtree evaluation from the current player's
57 | perspective.
58 | """
59 | # Count visit.
60 | self._n_visits += 1
61 | # Update Q, a running average of values for all visits.
62 | self._Q += 1.0*(leaf_value - self._Q) / self._n_visits
63 |
64 | def update_recursive(self, leaf_value, flag):
65 | """Like a call to update(), but applied recursively for all ancestors.
66 | flag=1 means this node is the second stone the player places in a turn, so its parent is the first stone of the same turn
67 | leaf_value expresses, from the parent node's perspective, how good it is to choose this child
68 | """
69 | # If it is not root, this node's parent should be updated first.
70 | if self._parent:
71 | if flag:
72 | self._parent.update_recursive(-leaf_value, 1-flag)
73 | else:
74 | self._parent.update_recursive(leaf_value, 1-flag)
75 | self.update(leaf_value)
76 |
77 | def get_value(self, c_puct):
78 | """Calculate and return the value for this node.
79 | It is a combination of leaf evaluations Q, and this node's prior
80 | adjusted for its visit count, u.
81 | c_puct: a number in (0, inf) controlling the relative impact of
82 | value Q, and prior probability P, on this node's score.
83 | """
84 | self._u = (c_puct * self._P *
85 | np.sqrt(self._parent._n_visits) / (1 + self._n_visits))
86 | return self._Q + self._u
87 |
88 | def is_leaf(self):
89 | """Check if leaf node (i.e. no nodes below this have been expanded)."""
90 | return self._children == {}
91 |
92 | def is_root(self):
93 | return self._parent is None
94 |
95 |
96 | class MCTS(object):
97 | """An implementation of Monte Carlo Tree Search."""
98 |
99 | def __init__(self, policy_value_fn, c_puct=5, n_playout=10000):
100 | """
101 | policy_value_fn: a function that takes in a board state and outputs
102 | a list of (action, probability) tuples and also a score in [-1, 1]
103 | (i.e. the expected value of the end game score from the current
104 | player's perspective) for the current player.
105 | c_puct: a number in (0, inf) that controls how quickly exploration
106 | converges to the maximum-value policy. A higher value means
107 | relying on the prior more.
108 | """
109 | self._root = TreeNode(None, 1.0)
110 | self._policy = policy_value_fn
111 | self._c_puct = c_puct
112 | self._n_playout = n_playout
113 |
114 | def _playout(self, state):
115 | """Run a single playout from the root to the leaf, getting a value at
116 | the leaf and propagating it back through its parents.
117 | State is modified in-place, so a copy must be provided.
118 | """
119 | node = self._root
120 | # current_player = state.get_current_player() # the player making the current move
121 | while(1):
122 | if node.is_leaf():
123 | break
124 | # Greedily select next move.
125 | action, node = node.select(self._c_puct)
126 | state.do_move(action)
127 |
128 | # Evaluate the leaf using a network which outputs a list of
129 | # (action, probability) tuples p and also a score v in [-1, 1]
130 | # for the current player.
131 | action_probs, leaf_value = self._policy(state)
132 | # Check for end of game.
133 | end, winner = state.game_end()
134 | if not end:
135 | node.expand(action_probs)
136 | else:
137 | # for end state,return the "true" leaf_value
138 | if winner == -1: # tie
139 | leaf_value = 0.0
140 | else:
141 | leaf_value = (
142 | 1.0 if winner == state.get_current_player() else -1.0
143 | )
144 |
145 | # Update value and visit count of nodes in this traversal.
146 | # node.update_recursive(-leaf_value) # questionable for Connect6, kept for reference
147 | # node.update_recursive(leaf_value)
148 | if state.chesses == 2:
149 | node.update_recursive(-leaf_value, 0)
150 | else:
151 | node.update_recursive(leaf_value, 1)
152 |
153 | def get_move_probs(self, state, temp=1e-3):
154 | """Run all playouts sequentially and return the available actions and
155 | their corresponding probabilities.
156 | state: the current game state
157 | temp: temperature parameter in (0, 1] controls the level of exploration
158 | """
159 | for n in range(self._n_playout):
160 | state_copy = copy.deepcopy(state)
161 | self._playout(state_copy)
162 |
163 | # calc the move probabilities based on visit counts at the root node
164 | act_visits = [(act, node._n_visits)
165 | for act, node in self._root._children.items()]
166 | acts, visits = zip(*act_visits)
167 | act_probs = softmax(1.0/temp * np.log(np.array(visits) + 1e-10))
168 |
169 | return acts, act_probs
170 |
171 | def update_with_move(self, last_move):
172 | """Step forward in the tree, keeping everything we already know
173 | about the subtree.
174 | """
175 | if last_move in self._root._children:
176 | self._root = self._root._children[last_move]
177 | self._root._parent = None
178 | else:
179 | self._root = TreeNode(None, 1.0)
180 |
181 | def __str__(self):
182 | return "MCTS"
183 |
184 |
185 | class MCTSPlayer(object):
186 | """AI player based on MCTS"""
187 |
188 | def __init__(self, policy_value_function,
189 | c_puct=5, n_playout=2000, is_selfplay=0):
190 | self.mcts = MCTS(policy_value_function, c_puct, n_playout)
191 | self._is_selfplay = is_selfplay
192 |
193 | def set_player_ind(self, p):
194 | self.player = p
195 |
196 | def reset_player(self):
197 | self.mcts.update_with_move(-1)
198 |
199 | def get_action(self, board, temp=1e-3, return_prob=0):
200 | sensible_moves = board.availables
201 | # the pi vector returned by MCTS as in the alphaGo Zero paper
202 | move_probs = np.zeros(board.width*board.height)
203 | if len(sensible_moves) > 0:
204 | acts, probs = self.mcts.get_move_probs(board, temp)
205 | move_probs[list(acts)] = probs
206 | if self._is_selfplay:
207 | # add Dirichlet Noise for exploration (needed for
208 | # self-play training)
209 | move = np.random.choice(
210 | acts,
211 | p=0.75*probs + 0.25*np.random.dirichlet(0.3*np.ones(len(probs)))
212 | )
213 | # update the root node and reuse the search tree
214 | self.mcts.update_with_move(move)
215 | else:
216 | # with the default temp=1e-3, it is almost equivalent
217 | # to choosing the move with the highest prob
218 | move = np.random.choice(acts, p=probs)
219 | # reset the root node
220 | self.mcts.update_with_move(-1)
221 | # location = board.move_to_location(move)
222 | # print("AI move: %d,%d\n" % (location[0], location[1]))
223 |
224 | if return_prob:
225 | return move, move_probs
226 | else:
227 | return move
228 | else:
229 | print("WARNING: the board is full")
230 |
231 | def __str__(self):
232 | return "Alpha Zero MCTS {}".format(self.player)
233 |
--------------------------------------------------------------------------------
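
The temperature parameter in get_move_probs controls how sharply the root visit counts are turned into a move distribution via softmax(1/temp * log(visits)). A small illustration with made-up visit counts (the softmax is the one defined at the top of the file; the numbers are not from a real search):

```python
import numpy as np

def softmax(x):
    probs = np.exp(x - np.max(x))
    return probs / np.sum(probs)

visits = np.array([400.0, 250.0, 100.0, 50.0])     # hypothetical root visit counts

for temp in (1.0, 1e-3):
    probs = softmax(1.0 / temp * np.log(visits + 1e-10))
    print(temp, np.round(probs, 3))
# temp = 1.0   -> probabilities roughly proportional to visit counts
# temp = 0.001 -> essentially all mass on the most-visited move (used during play)
```

During self-play, MCTSPlayer.get_action additionally mixes Dirichlet noise into these probabilities at the root to keep exploring.
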
/src/mcts_pure.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | A pure implementation of the Monte Carlo Tree Search (MCTS)
4 |
5 | @author: Junxiao Song
6 | """
7 |
8 | import numpy as np
9 | import copy
10 | from operator import itemgetter
11 |
12 |
13 | def rollout_policy_fn(board):
14 | """a coarse, fast version of policy_fn used in the rollout phase."""
15 | # rollout randomly
16 | action_probs = np.random.rand(len(board.availables))
17 | return zip(board.availables, action_probs)
18 |
19 |
20 | def policy_value_fn(board):
21 | """a function that takes in a state and outputs a list of (action, probability)
22 | tuples and a score for the state"""
23 | # return uniform probabilities and 0 score for pure MCTS
24 | action_probs = np.ones(len(board.availables))/len(board.availables)
25 | return zip(board.availables, action_probs), 0
26 |
27 |
28 | class TreeNode(object):
29 | """A node in the MCTS tree. Each node keeps track of its own value Q,
30 | prior probability P, and its visit-count-adjusted prior score u.
31 | """
32 |
33 | def __init__(self, parent, prior_p):
34 | self._parent = parent
35 | self._children = {} # a map from action to TreeNode
36 | self._n_visits = 0
37 | self._Q = 0
38 | self._u = 0
39 | self._P = prior_p
40 | # self.flag = flag # whether this node is the last stone placed by the corresponding player
41 |
42 |
43 | def expand(self, action_priors):
44 | """Expand tree by creating new children.
45 | action_priors: a list of tuples of actions and their prior probability
46 | according to the policy function.
47 | """
48 | for action, prob in action_priors:
49 | if action not in self._children:
50 | self._children[action] = TreeNode(self, prob)
51 |
52 | def select(self, c_puct):
53 | """Select action among children that gives maximum action value Q
54 | plus bonus u(P).
55 | Return: A tuple of (action, next_node)
56 | """
57 | return max(self._children.items(),
58 | key=lambda act_node: act_node[1].get_value(c_puct))
59 |
60 | def update(self, leaf_value):
61 | """Update node values from leaf evaluation.
62 | leaf_value: the value of subtree evaluation from the current player's
63 | perspective.
64 | """
65 | # Count visit.
66 | self._n_visits += 1
67 | # Update Q, a running average of values for all visits.
68 | self._Q += 1.0*(leaf_value - self._Q) / self._n_visits
69 |
70 | def update_recursive(self, leaf_value, flag):
71 | """Like a call to update(), but applied recursively for all ancestors.
72 | when flag = 1 the value passed up to self._parent changes sign, otherwise it keeps its sign
73 | """
74 | # If it is not root, this node's parent should be updated first.
75 | if self._parent:
76 | if flag:
77 | self._parent.update_recursive(-leaf_value, 1-flag)
78 | else:
79 | self._parent.update_recursive(leaf_value, 1-flag)
80 | self.update(leaf_value)
81 |
82 | def get_value(self, c_puct):
83 | """Calculate and return the value for this node.
84 | It is a combination of leaf evaluations Q, and this node's prior
85 | adjusted for its visit count, u.
86 | c_puct: a number in (0, inf) controlling the relative impact of
87 | value Q, and prior probability P, on this node's score.
88 | """
89 | self._u = (c_puct * self._P *
90 | np.sqrt(self._parent._n_visits) / (1 + self._n_visits))
91 | return self._Q + self._u
92 |
93 | def is_leaf(self):
94 | """Check if leaf node (i.e. no nodes below this have been expanded).
95 | """
96 | return self._children == {}
97 |
98 | def is_root(self):
99 | return self._parent is None
100 |
101 |
102 | class MCTS(object):
103 | """A simple implementation of Monte Carlo Tree Search."""
104 |
105 | def __init__(self, policy_value_fn, c_puct=5, n_playout=10000):
106 | """
107 | policy_value_fn: a function that takes in a board state and outputs
108 | a list of (action, probability) tuples and also a score in [-1, 1]
109 | (i.e. the expected value of the end game score from the current
110 | player's perspective) for the current player.
111 | c_puct: a number in (0, inf) that controls how quickly exploration
112 | converges to the maximum-value policy. A higher value means
113 | relying on the prior more.
114 | """
115 | self._root = TreeNode(None, 1.0)
116 | self._policy = policy_value_fn
117 | self._c_puct = c_puct
118 | self._n_playout = n_playout
119 |
120 | def _playout(self, state):
121 | """Run a single playout from the root to the leaf, getting a value at
122 | the leaf and propagating it back through its parents.
123 | State is modified in-place, so a copy must be provided.
124 | """
125 | node = self._root
126 | while(1):
127 | if node.is_leaf():
128 | break
129 |
130 | # Greedily select next move.
131 | action, node = node.select(self._c_puct)
132 | state.do_move(action)
133 |
134 | action_probs, _ = self._policy(state)
135 | # Check for end of game
136 | end, winner = state.game_end()
137 | if not end:
138 | node.expand(action_probs)
139 | # Evaluate the leaf node by random rollout
140 | leaf_value = self._evaluate_rollout(state)
141 | # Update value and visit count of nodes in this traversal.
142 | if state.chesses == 2:
143 | node.update_recursive(-leaf_value, 0)
144 | else:
145 | node.update_recursive(leaf_value, 1)
146 |
147 |
148 | def _evaluate_rollout(self, state, limit=1000):
149 | """Use the rollout policy to play until the end of the game,
150 | returning +1 if the current player wins, -1 if the opponent wins,
151 | and 0 if it is a tie.
152 | """
153 | player = state.get_current_player()
154 | for i in range(limit):
155 | end, winner = state.game_end()
156 | if end:
157 | break
158 | action_probs = rollout_policy_fn(state)
159 | max_action = max(action_probs, key=itemgetter(1))[0]
160 | state.do_move(max_action)
161 | else:
162 | # If no break from the loop, issue a warning.
163 | print("WARNING: rollout reached move limit")
164 | if winner == -1: # tie
165 | return 0
166 | else:
167 | return 1 if winner == player else -1
168 |
169 | def get_move(self, state):
170 | """Runs all playouts sequentially and returns the most visited action.
171 | state: the current game state
172 |
173 | Return: the selected action
174 | """
175 | for n in range(self._n_playout):
176 | state_copy = copy.deepcopy(state)
177 | self._playout(state_copy)
178 | return max(self._root._children.items(),
179 | key=lambda act_node: act_node[1]._n_visits)[0]
180 |
181 | def update_with_move(self, last_move):
182 | """Step forward in the tree, keeping everything we already know
183 | about the subtree.
184 | """
185 | if last_move in self._root._children:
186 | self._root = self._root._children[last_move]
187 | self._root._parent = None
188 | else:
189 | self._root = TreeNode(None, 1.0)
190 |
191 | def __str__(self):
192 | return "MCTS"
193 |
194 |
195 | class MCTSPlayer(object):
196 | """AI player based on MCTS"""
197 | def __init__(self, c_puct=5, n_playout=2000):
198 | self.mcts = MCTS(policy_value_fn, c_puct, n_playout)
199 |
200 | def set_player_ind(self, p):
201 | self.player = p
202 |
203 | def reset_player(self):
204 | self.mcts.update_with_move(-1)
205 |
206 | def get_action(self, board):
207 | sensible_moves = board.availables
208 | if len(sensible_moves) > 0:
209 | move = self.mcts.get_move(board)
210 | self.mcts.update_with_move(-1)
211 | return move
212 | else:
213 | print("WARNING: the board is full")
214 |
215 | def __str__(self):
216 | return "Pure MCTS {}".format(self.player)
217 |
--------------------------------------------------------------------------------
/src/model/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JunguangJiang/AlphaSix/5a54b9f243bbeb790c9cd7601358748faebcc429/src/model/.DS_Store
--------------------------------------------------------------------------------
/src/model/10_10_6_best_policy_0.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JunguangJiang/AlphaSix/5a54b9f243bbeb790c9cd7601358748faebcc429/src/model/10_10_6_best_policy_0.model
--------------------------------------------------------------------------------
/src/model/10_10_6_best_policy_1.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JunguangJiang/AlphaSix/5a54b9f243bbeb790c9cd7601358748faebcc429/src/model/10_10_6_best_policy_1.model
--------------------------------------------------------------------------------
/src/model/10_10_6_best_policy_2.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JunguangJiang/AlphaSix/5a54b9f243bbeb790c9cd7601358748faebcc429/src/model/10_10_6_best_policy_2.model
--------------------------------------------------------------------------------
/src/model/10_10_6_best_policy_3.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JunguangJiang/AlphaSix/5a54b9f243bbeb790c9cd7601358748faebcc429/src/model/10_10_6_best_policy_3.model
--------------------------------------------------------------------------------
/src/model/10_10_6_current_policy_.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JunguangJiang/AlphaSix/5a54b9f243bbeb790c9cd7601358748faebcc429/src/model/10_10_6_current_policy_.model
--------------------------------------------------------------------------------
/src/model/10_10_6_current_policy_0.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JunguangJiang/AlphaSix/5a54b9f243bbeb790c9cd7601358748faebcc429/src/model/10_10_6_current_policy_0.model
--------------------------------------------------------------------------------
/src/model/10_10_6_current_policy_1.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JunguangJiang/AlphaSix/5a54b9f243bbeb790c9cd7601358748faebcc429/src/model/10_10_6_current_policy_1.model
--------------------------------------------------------------------------------
/src/model/10_10_6_current_policy_2.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JunguangJiang/AlphaSix/5a54b9f243bbeb790c9cd7601358748faebcc429/src/model/10_10_6_current_policy_2.model
--------------------------------------------------------------------------------
/src/model/10_10_6_current_policy_3.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JunguangJiang/AlphaSix/5a54b9f243bbeb790c9cd7601358748faebcc429/src/model/10_10_6_current_policy_3.model
--------------------------------------------------------------------------------
/src/policy_value_net_pytorch.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | An implementation of the policyValueNet in PyTorch
4 | Originally tested in PyTorch 0.2.0 and 0.3.0; this modified version targets PyTorch 0.4.0 (see README)
5 |
6 | @author: Junxiao Song
7 | @modifier: Junguang Jiang
8 |
9 | """
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.optim as optim
14 | import torch.nn.functional as F
15 | from torch.autograd import Variable
16 | import numpy as np
17 |
18 |
19 | def set_learning_rate(optimizer, lr):
20 | """Sets the learning rate to the given value"""
21 | for param_group in optimizer.param_groups:
22 | param_group['lr'] = lr
23 |
24 |
25 | class Net(nn.Module):
26 | """policy-value network module"""
27 | def __init__(self, board_width, board_height):
28 | super(Net, self).__init__()
29 |
30 | self.board_width = board_width
31 | self.board_height = board_height
32 | # common layers
33 | self.conv1 = nn.Conv2d(4, 32, kernel_size=3, padding=1)
34 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
35 | self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
36 | # action policy layers
37 | self.act_conv1 = nn.Conv2d(128, 4, kernel_size=1)
38 | self.act_fc1 = nn.Linear(4*board_width*board_height,
39 | board_width*board_height)
40 | # state value layers
41 | self.val_conv1 = nn.Conv2d(128, 2, kernel_size=1)
42 | self.val_fc1 = nn.Linear(2*board_width*board_height, 64)
43 | self.val_fc2 = nn.Linear(64, 1)
44 |
45 | def forward(self, state_input):
46 | # common layers
47 | x = F.relu(self.conv1(state_input))
48 | x = F.relu(self.conv2(x))
49 | x = F.relu(self.conv3(x))
50 | # action policy layers
51 | x_act = F.relu(self.act_conv1(x))
52 | x_act = x_act.view(-1, 4*self.board_width*self.board_height)
53 | x_act = F.log_softmax(self.act_fc1(x_act), dim=1)
54 | # state value layers
55 | x_val = F.relu(self.val_conv1(x))
56 | x_val = x_val.view(-1, 2*self.board_width*self.board_height)
57 | x_val = F.relu(self.val_fc1(x_val))
58 | x_val = F.tanh(self.val_fc2(x_val))
59 | return x_act, x_val
60 |
61 |
62 | class PolicyValueNet():
63 | """policy-value network """
64 | def __init__(self, board_width, board_height,
65 | model_file=None, use_gpu=False):
66 | self.use_gpu = use_gpu
67 | self.board_width = board_width
68 | self.board_height = board_height
69 | self.l2_const = 1e-4 # coef of l2 penalty
70 | # the policy value net module
71 | if self.use_gpu:
72 | self.policy_value_net = Net(board_width, board_height).cuda()
73 | else:
74 | self.policy_value_net = Net(board_width, board_height)
75 | self.optimizer = optim.Adam(self.policy_value_net.parameters(),
76 | weight_decay=self.l2_const)
77 |
78 | if model_file:
79 | self.load_model(model_file=model_file)
80 |
81 | def policy_value(self, state_batch):
82 | """
83 | input: a batch of states
84 | output: a batch of action probabilities and state values
85 | """
86 | if self.use_gpu:
87 | state_batch = Variable(torch.FloatTensor(state_batch).cuda())
88 | log_act_probs, value = self.policy_value_net(state_batch)
89 | act_probs = np.exp(log_act_probs.data.cpu().numpy())
90 | return act_probs, value.data.cpu().numpy()
91 | else:
92 | state_batch = Variable(torch.FloatTensor(state_batch))
93 | log_act_probs, value = self.policy_value_net(state_batch)
94 | act_probs = np.exp(log_act_probs.data.numpy())
95 | return act_probs, value.data.numpy()
96 |
97 | def policy_value_fn(self, board):
98 | """
99 | input: board
100 | output: a list of (action, probability) tuples for each available
101 | action and the score of the board state
102 | """
103 | legal_positions = board.availables
104 | current_state = np.ascontiguousarray(board.current_state().reshape(
105 | -1, 4, self.board_width, self.board_height))
106 | if self.use_gpu:
107 | log_act_probs, value = self.policy_value_net(
108 | Variable(torch.from_numpy(current_state)).cuda().float())
109 | act_probs = np.exp(log_act_probs.data.cpu().numpy().flatten())
110 | else:
111 | log_act_probs, value = self.policy_value_net(
112 | Variable(torch.from_numpy(current_state)).float())
113 | act_probs = np.exp(log_act_probs.data.numpy().flatten())
114 | act_probs = zip(legal_positions, act_probs[legal_positions])
115 | value = value.data[0][0]
116 | return act_probs, value
117 |
118 | def train_step(self, state_batch, mcts_probs, winner_batch, lr):
119 | """perform a training step"""
120 | # wrap in Variable
121 | if self.use_gpu:
122 | state_batch = Variable(torch.FloatTensor(state_batch).cuda())
123 | mcts_probs = Variable(torch.FloatTensor(mcts_probs).cuda())
124 | winner_batch = Variable(torch.FloatTensor(winner_batch).cuda())
125 | else:
126 | state_batch = Variable(torch.FloatTensor(state_batch))
127 | mcts_probs = Variable(torch.FloatTensor(mcts_probs))
128 | winner_batch = Variable(torch.FloatTensor(winner_batch))
129 |
130 | # zero the parameter gradients
131 | self.optimizer.zero_grad()
132 | # set learning rate
133 | set_learning_rate(self.optimizer, lr)
134 |
135 | # forward
136 | log_act_probs, value = self.policy_value_net(state_batch)
137 | # define the loss = (z - v)^2 - pi^T * log(p) + c||theta||^2
138 | # Note: the L2 penalty is incorporated in optimizer
139 | mseloss = nn.MSELoss()
140 | value_loss = mseloss(value.view(-1), winner_batch)
141 | policy_loss = -torch.mean(torch.sum(mcts_probs*log_act_probs, 1))
142 | loss = value_loss + policy_loss
143 | # backward and optimize
144 | loss.backward()
145 | self.optimizer.step()
146 | # calc policy entropy, for monitoring only
147 | entropy = -torch.mean(
148 | torch.sum(torch.exp(log_act_probs) * log_act_probs, 1)
149 | )
150 | # .item() extracts a Python scalar from both CPU and GPU tensors in
151 | # PyTorch 0.4+, where indexing a 0-dim tensor (loss.data[0]) fails
152 | return loss.item(), entropy.item()
153 |
154 |
155 | def get_policy_param(self):
156 | net_params = self.policy_value_net.state_dict()
157 | return net_params
158 |
159 | def save_model(self, model_file):
160 | """ save model params to file """
161 | net_params = self.get_policy_param() # get model params
162 | torch.save(net_params, model_file)
163 |
164 | def load_model(self, model_file):
165 | """load model params from file"""
166 | # net_params = torch.load(model_file)
167 | net_params = torch.load(model_file, map_location=lambda storage, loc:storage)
168 | self.policy_value_net.load_state_dict(net_params)
--------------------------------------------------------------------------------
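
As a quick sanity check of the interface above (a sketch, assuming PyTorch is installed and the working directory is src/), policy_value_fn takes a Board and returns a prior over the legal moves plus a scalar evaluation of the position; with no model_file the weights are randomly initialised.

```python
from game import Board
from policy_value_net_pytorch import PolicyValueNet

board = Board(width=10, height=10, n_in_row=6)
board.init_board()

net = PolicyValueNet(10, 10)              # randomly initialised (no model_file)
act_probs, value = net.policy_value_fn(board)

priors = dict(act_probs)                  # {move index: prior probability}
print(len(priors))                        # 100 legal moves on an empty 10x10 board
print(float(value))                       # position evaluation in [-1, 1]
```
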
/src/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | An implementation of the training pipeline of AlphaZero for Gomoku
4 |
5 | @author: Junxiao Song
6 | @modifier: Junguang Jiang
7 |
8 | """
9 |
10 | from __future__ import print_function
11 | import random
12 | import numpy as np
13 | from collections import defaultdict, deque
14 | from game import Board, Game
15 | from mcts_pure import MCTSPlayer as MCTS_Pure
16 | from mcts_alphaZero import MCTSPlayer
17 | from policy_value_net_pytorch import PolicyValueNet # Pytorch
18 |
19 |
20 |
21 | class TrainPipeline():
22 | def __init__(self, init_model=None, board_width=6, board_height=6,
23 | n_in_row=4, n_playout=400, use_gpu=False, is_shown=False,
24 | output_file_name="", game_batch_number=1500):
25 | # params of the board and the game
26 | self.board_width = board_width
27 | self.board_height = board_height
28 | self.n_in_row = n_in_row
29 | self.board = Board(width=self.board_width,
30 | height=self.board_height,
31 | n_in_row=self.n_in_row)
32 | self.game = Game(self.board)
33 | # training params
34 | self.learn_rate = 2e-3
35 | self.lr_multiplier = 1.0 # adaptively adjust the learning rate based on KL
36 | self.temp = 1.0 # the temperature param
37 | self.n_playout = n_playout # num of simulations for each move
38 | self.c_puct = 5
39 | self.buffer_size = 10000
40 | self.batch_size = 512 # mini-batch size for training
41 | self.data_buffer = deque(maxlen=self.buffer_size)
42 | self.play_batch_size = 1
43 | self.epochs = 5 # num of train_steps for each update
44 | self.kl_targ = 0.02
45 | self.check_freq = 50
46 | self.game_batch_num = game_batch_number
47 | self.best_win_ratio = 0.0
48 | # num of simulations used for the pure mcts, which is used as
49 | # the opponent to evaluate the trained policy
50 | self.pure_mcts_playout_num = 1000
51 | self.use_gpu = use_gpu
52 | self.is_shown = is_shown
53 | self.output_file_name = output_file_name
54 | self.policy_value_net = PolicyValueNet(self.board_width,
55 | self.board_height,
56 | model_file=init_model,
57 | use_gpu=self.use_gpu
58 | )
59 | self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
60 | c_puct=self.c_puct,
61 | n_playout=self.n_playout,
62 | is_selfplay=1)
63 |
64 |
65 | def get_equi_data(self, play_data):
66 | """augment the data set by rotation and flipping
67 | play_data: [(state, mcts_prob, winner_z), ..., ...]
68 | """
69 | extend_data = []
70 | for state, mcts_porb, winner in play_data:
71 | for i in [1, 2, 3, 4]:
72 | # rotate counterclockwise
73 | equi_state = np.array([np.rot90(s, i) for s in state])
74 | equi_mcts_prob = np.rot90(np.flipud(
75 | mcts_porb.reshape(self.board_height, self.board_width)), i)
76 | extend_data.append((equi_state,
77 | np.flipud(equi_mcts_prob).flatten(),
78 | winner))
79 | # flip horizontally
80 | equi_state = np.array([np.fliplr(s) for s in equi_state])
81 | equi_mcts_prob = np.fliplr(equi_mcts_prob)
82 | extend_data.append((equi_state,
83 | np.flipud(equi_mcts_prob).flatten(),
84 | winner))
85 | return extend_data
86 |
87 | def collect_selfplay_data(self, n_games=1):
88 | """collect self-play data for training"""
89 | for i in range(n_games):
90 | winner, play_data = self.game.start_self_play(self.mcts_player,
91 | temp=self.temp)
92 | play_data = list(play_data)[:]
93 | self.episode_len = len(play_data)
94 | # augment the data
95 | play_data = self.get_equi_data(play_data)
96 | self.data_buffer.extend(play_data)
97 |
98 | def policy_update(self):
99 | """update the policy-value net"""
100 | mini_batch = random.sample(self.data_buffer, self.batch_size)
101 | state_batch = [data[0] for data in mini_batch]
102 | mcts_probs_batch = [data[1] for data in mini_batch]
103 | winner_batch = [data[2] for data in mini_batch]
104 | old_probs, old_v = self.policy_value_net.policy_value(state_batch)
105 | for i in range(self.epochs):
106 | loss, entropy = self.policy_value_net.train_step(
107 | state_batch,
108 | mcts_probs_batch,
109 | winner_batch,
110 | self.learn_rate*self.lr_multiplier)
111 | new_probs, new_v = self.policy_value_net.policy_value(state_batch)
112 | kl = np.mean(np.sum(old_probs * (
113 | np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
114 | axis=1)
115 | )
116 | if kl > self.kl_targ * 4: # early stopping if D_KL diverges badly
117 | break
118 | # adaptively adjust the learning rate
119 | if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
120 | self.lr_multiplier /= 1.5
121 | elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
122 | self.lr_multiplier *= 1.5
123 |
124 | explained_var_old = (1 -
125 | np.var(np.array(winner_batch) - old_v.flatten()) /
126 | np.var(np.array(winner_batch)))
127 | explained_var_new = (1 -
128 | np.var(np.array(winner_batch) - new_v.flatten()) /
129 | np.var(np.array(winner_batch)))
130 | print(("kl:{:.5f},"
131 | "lr_multiplier:{:.3f},"
132 | "loss:{},"
133 | "entropy:{},"
134 | "explained_var_old:{:.3f},"
135 | "explained_var_new:{:.3f}"
136 | ).format(kl,
137 | self.lr_multiplier,
138 | loss,
139 | entropy,
140 | explained_var_old,
141 | explained_var_new))
142 | return loss, entropy
143 |
144 | def policy_evaluate(self, n_games=10):
145 | """
146 | Evaluate the trained policy by playing against the pure MCTS player
147 | Note: this is only for monitoring the progress of training
148 | """
149 | current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
150 | c_puct=self.c_puct,
151 | n_playout=self.n_playout)
152 | pure_mcts_player = MCTS_Pure(c_puct=5,
153 | n_playout=self.pure_mcts_playout_num)
154 | win_cnt = defaultdict(int)
155 | for i in range(n_games):
156 | winner = self.game.start_play(current_mcts_player,
157 | pure_mcts_player,
158 | start_player=i % 2,
159 | is_shown=self.is_shown)
160 | win_cnt[winner] += 1
161 | win_ratio = 1.0*(win_cnt[1] + 0.5*win_cnt[-1]) / n_games
162 | print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
163 | self.pure_mcts_playout_num,
164 | win_cnt[1], win_cnt[2], win_cnt[-1]))
165 | return win_ratio
166 |
167 | def run(self):
168 | """run the training pipeline"""
169 | with open("info/"+str(self.board)+"_loss_"+self.output_file_name+".txt",'w') as loss_file:
170 | loss_file.write("self-play games,loss,entropy\n")
171 | with open("info/"+str(self.board)+"_win_ration"+self.output_file_name+".txt", 'w') as win_ratio_file:
172 | win_ratio_file.write("self-play games, pure MCTS playouts, win ratio\n")
173 | try:
174 | for i in range(self.game_batch_num):
175 | self.collect_selfplay_data(self.play_batch_size)
176 | print("batch i:{}, episode_len:{}".format(
177 | i+1, self.episode_len))
178 | if len(self.data_buffer) > self.batch_size:
179 | loss, entropy = self.policy_update()
180 | with open("info/" + str(self.board) + "_loss_" + self.output_file_name + ".txt", 'a') as loss_file:
181 | loss_file.write(str(i+1)+','+str(loss)+','+str(entropy)+'\n')
182 | # check the performance of the current model,
183 | # and save the model params
184 | if (i+1) % self.check_freq == 0:
185 | print("current self-play batch: {}".format(i+1))
186 | win_ratio = self.policy_evaluate()
187 | with open("info/" + str(self.board) + "_win_ration" + self.output_file_name + ".txt",
188 | 'a') as win_ratio_file:
189 | win_ratio_file.write(str(i+1)+','+str(self.pure_mcts_playout_num)+','+str(win_ratio)+'\n')
190 | self.policy_value_net.save_model('./model/'+str(self.board_height)
191 | +'_'+str(self.board_width)
192 | +'_'+str(self.n_in_row)+
193 | '_current_policy_'+self.output_file_name+'.model')
194 | if win_ratio >= self.best_win_ratio:
195 | print("New best policy!!!!!!!!")
196 | self.best_win_ratio = win_ratio
197 | # update the best_policy
198 | self.policy_value_net.save_model('./model/'+str(self.board_height)
199 | +'_'+str(self.board_width)
200 | +'_'+str(self.n_in_row)+
201 | '_best_policy_'+self.output_file_name+'.model')
202 | if (self.best_win_ratio == 1.0 and
203 | self.pure_mcts_playout_num < 50000):
204 | self.pure_mcts_playout_num += 1000
205 | self.best_win_ratio = 0.0
206 | except KeyboardInterrupt:
207 | print('\n\rquit')
208 | loss_file.close()
209 | win_ratio_file.close()
210 |
211 | def usage():
212 | print("-s          set the board size (default: 10)")
213 | print("-r          set how many stones in a row are needed to win (default: 8)")
214 | print("-m          set the number of MCTS playouts per move (default: 800)")
215 | print("-o          identifier appended to the saved model file name (the first part of the name is generated from the board parameters)")
216 | print("-n          set the number of self-play games to train on (default: 1500)")
217 | print("--use_gpu   train on the GPU")
218 | print("--graphics  show the board during model evaluation")
219 |
220 |
221 | if __name__ == '__main__':
222 | import sys, getopt
223 |
224 | height = 10
225 | width = 10
226 | n_in_row = 8
227 | use_gpu = False
228 | n_playout = 800
229 | is_shown = False
230 | output_file_name = ""
231 | game_batch_number = 1500
232 | init_model_name = None
233 | battle=False
234 |
235 | opts, args = getopt.getopt(sys.argv[1:], "hs:r:m:go:n:i:", ["use_gpu", "graphics"])
236 | for op, value in opts:
237 | if op == "-h":
238 | usage()
239 | sys.exit()
240 | elif op == "-s":
241 | height = width = int(value)
242 | elif op == "-r":
243 | n_in_row = int(value)
244 | elif op == "--use_gpu":
245 | use_gpu = True
246 | elif op == "-m":
247 | n_playout = int(value)
248 | elif op == "-g" or op == "--graphics":
249 | is_shown = True
250 | elif op == "-o":
251 | output_file_name = value
252 | elif op == "-i":
253 | init_model_name = value
254 | elif op == "-n":
255 | game_batch_number = int(value)
256 |
257 | training_pipeline = TrainPipeline(board_height=height, board_width=width,
258 | n_in_row=n_in_row, use_gpu=use_gpu,
259 | n_playout=n_playout, is_shown=is_shown,
260 | output_file_name=output_file_name,
261 | init_model=init_model_name,
262 | game_batch_number=game_batch_number)
263 | training_pipeline.run()
264 |
--------------------------------------------------------------------------------
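
The augmentation step in get_equi_data exploits the board's symmetries: each self-play position is expanded into its four rotations and their horizontal flips, eight samples in total. A sketch with fabricated data (not from a real game, assumed to run from src/; instantiating TrainPipeline builds a fresh untrained network, which is fine for this illustration):

```python
import numpy as np
from train import TrainPipeline

pipeline = TrainPipeline(board_width=6, board_height=6, n_in_row=4)

state = np.random.rand(4, 6, 6)           # fake 4-plane board state
mcts_prob = np.random.rand(36)            # fake move probabilities
mcts_prob /= mcts_prob.sum()

augmented = pipeline.get_equi_data([(state, mcts_prob, 1.0)])
print(len(augmented))                     # 8 symmetric copies of the one sample
```
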
/src/ui.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | ##############ONLY for testing################
4 | WIDTH = 540
5 | HEIGHT = 540
6 | MARGIN = 22
7 | GRID = (WIDTH - 2 * MARGIN) / (15 - 1)
8 | PIECE = 34
9 | EMPTY = 0
10 | BLACK = 1
11 | WHITE = 2
12 | SCALE = 5
13 | width = 8
14 | height = 8
15 | n_in_row = 5
16 | model_file = 'model/8_8_5_best_policy_.model'
17 | use_gpu = False
18 | n_playout = 800
19 | ##############################################
20 | from game import *
21 | from mcts_alphaZero import MCTSPlayer
22 | from policy_value_net_pytorch import PolicyValueNet # Pytorch
23 | import sys
24 | import os
25 | import time
26 | from PyQt5 import QtWidgets, QtCore, QtGui
27 | from PyQt5.QtWidgets import *
28 | from PyQt5.QtCore import *
29 | from PyQt5.QtGui import *
30 | import threading
31 |
32 | global AIChess
33 |
34 | ##################A Tool for arrangement of playing order###################
35 | #The point is to treat the game routine as a cycle, like
36 | #
37 | # 1 2 3 4
38 | # -> HUMAN -> HUMAN -> AI -> AI --
39 | # | |
40 | # | |
41 | # ------------------------------------
42 | #
43 | #and our game will start from position 2 or 4.
44 |
45 | class cycleGroup(tuple):
46 | def __init__(self, parent):
47 | self.elements = parent
48 | self.order = len(parent)
49 | self.point = 0
50 | def pointTurnRight(self):
51 | self.point = (self.point + 1) % self.order
52 | def pointTurnLeft(self):
53 | self.point = (self.point + self.order - 1) % self.order
54 | def element(self):
55 | return self.elements[self.point]
56 | ############################################################################
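# Illustration only (this helper is not called anywhere in the module, and its
# name _cycle_group_demo is ours): starting from position 2 or 4 of the
# (HUMAN, HUMAN, AI, AI) cycle yields the "one move, then two moves each" rhythm
# sketched above.
def _cycle_group_demo():
    order = cycleGroup(('HUMAN', 'HUMAN', 'AI', 'AI'))
    order.point = 1                   # start from position 2, i.e. the human opens with a single move
    sequence = []
    for _ in range(6):
        sequence.append(order.element())
        order.pointTurnRight()
    return sequence                   # ['HUMAN', 'AI', 'AI', 'HUMAN', 'HUMAN', 'AI']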
57 |
58 | class chessDetail(object):
59 | def __init__(self, i=0, j=0, x=0, y=0, chess=0, chessPic=None):
60 | self.gridCoordinate_i = i
61 | self.gridCoordinate_j = j
62 | self.pixelCoordinate_x = x
63 | self.pixelCoordinate_y = y
64 | self.chessType = chess
65 | self.chessPicture = chessPic
66 |
67 | class ChessBoard(QWidget):
68 | signalClicked = pyqtSignal()
69 | signalAIFirst = pyqtSignal(bool)
70 | signalHumanDraw_ChessCoordinates = pyqtSignal(int, int)
71 | signalDraw_Finished = pyqtSignal(bool)
72 | def __init__(self):
73 | super(ChessBoard, self).__init__()
74 |
75 | def initialize(self, scale):
76 | self.graphicsParameterSet(scale)
77 | self.graphicsUIInterfaceSet()
78 | self.boardRunningLogicSet()
79 |
80 | def graphicsParameterSet(self, scale):
81 | self.WIDTH = 540
82 | self.HEIGHT = 540
83 | self.MARGIN = 22
84 | self.GRID = (self.WIDTH - 2 * self.MARGIN) / (15 - 1)
85 | self.PIECE = 34
86 | self.EMPTY = 0
87 | self.BLACK = 1
88 | self.WHITE = 2
89 | self.SCALE = scale
90 | self.effectiveWIDTH = self.GRID * (self.SCALE - 1) + 2 * self.MARGIN
91 | self.effectiveHEIGHT = self.effectiveWIDTH
92 |
93 | def graphicsUIInterfaceSet(self):
94 | self.graphicsElementSet()
95 | self.graphicsUserSelfDefine()
96 | self.graphicsUserChoosingChessType()
97 | self.graphicsUserChoosingFirstHand()
98 | self.graphicsChessBoard()
99 |
100 | def graphicsElementSet(self):
101 | self.black = QPixmap('black.png')
102 | self.white = QPixmap('white.png')
103 |
104 | def graphicsUserSelfDefine(self):
105 | pass
106 |
107 | def graphicsUserChoosingChessType(self):
108 | message = QMessageBox()
109 | message.setIconPixmap(self.black)
110 | message.setWindowTitle("Choice 1")
111 | message.setText("Choose your piece colour")
112 | message.addButton(QPushButton("Black"), QMessageBox.YesRole)
113 | message.addButton(QPushButton("White"), QMessageBox.NoRole)
114 |
115 | self.humanChessType = None
116 | answer = message.exec()
117 | global AIChess
118 | if answer == 0:
119 | self.humanChessType = BLACK
120 | self.AIChessType = WHITE
121 | AIChess = WHITE
122 | else:
123 | self.humanChessType = WHITE
124 | self.AIChessType = BLACK
125 | AIChess = BLACK
126 | self.dictionaryFromNameToElement = {'HUMAN':self.humanChessType, 'AI':self.AIChessType}
127 |
128 | def graphicsUserChoosingFirstHand(self):
129 | message = QMessageBox()
130 | message.setIconPixmap(self.white)
131 | message.setWindowTitle("Choose first move")
132 | message.setText("Decide who moves first")
133 | message.addButton(QPushButton("AI first"), QMessageBox.YesRole)
134 | message.addButton(QPushButton("Human first"), QMessageBox.NoRole)
135 |
136 | answer = message.exec()
137 | if answer == 0:
138 | ai_first = True
139 | else:
140 | ai_first = False
141 |
142 | self.signalAIFirst.emit(ai_first)
143 |
144 | def graphicsChessBoard(self):
145 |
146 | # print("showBoard begins")
147 |
148 | palette1 = QPalette() # set the board background
149 | palette1.setBrush(self.backgroundRole(), QtGui.QBrush(QtGui.QPixmap('chessboard.jpg')))
150 | self.setPalette(palette1)
151 |
152 | self.setCursor(Qt.PointingHandCursor)
153 |
154 | self.setMaximumSize(QtCore.QSize(WIDTH, HEIGHT))
155 | # print("hhh")
156 | self.setWindowTitle("NINAROW")
157 | self.setWindowIcon(QIcon('black.png'))
158 | self.resize(self.effectiveWIDTH, self.effectiveHEIGHT)
159 |
160 | self.mouse_point = QLabel()
161 | self.mouse_point.setScaledContents(True)
162 | self.mouse_point.setPixmap(self.black) # load the black piece as the cursor image
163 | self.mouse_point.setGeometry(270, 270, PIECE, PIECE)
164 | self.mouse_point.raise_() # keep the cursor piece on top
165 | self.setMouseTracking(True)
166 |
167 |
168 | def graphicsGameOver(self, winner):
169 | self.gameOverMessage = QMessageBox()
170 | reply = self.gameOverMessage.information(self, 'Game over', winner, QMessageBox.Yes)
171 | if reply == QMessageBox.Yes:
172 | self.close()
173 | else:
174 | self.close()
175 | sys.exit()
176 |
177 | def boardRunningLogicSet(self):
178 | #set order cycle#
179 | # `1 is True` is always False in Python, so the original branch always fell
180 | # through to the else case; state that behaviour directly: the human may
181 | # move as soon as the board is shown.
182 | self.isAIAlreadyDrawn = False
183 | self.humanAvailable = True
184 | self.dictConstantToPic = {BLACK:self.black, WHITE:self.white}
185 | self.chessGrid = [[chessDetail(i, j, self.coordinate_transform_map2pixel(i, j)[0],
186 | self.coordinate_transform_map2pixel(i, j)[1], 0,
187 | QLabel(self)) for i in range(1+self.SCALE)] for j in range(1+self.SCALE)]
188 |
189 | def aiHasDrawn(self, AIHasDrawn):
190 | if AIHasDrawn is True:
191 | self.humanAvailable = True
192 |
193 |
194 | ###########Redefine the mouse event functions#########
195 | # Override mouseMoveEvent() so the piece-shaped cursor follows the mouse #
196 | def mouseMoveEvent(self, e):
197 | # self.lb1.setText(str(e.x()) + ' ' + str(e.y()))
198 | self.mouse_point.move(e.x() - 16, e.y() - 16)
199 |
200 | def mousePressEvent(self, e):
201 | if e.button() == Qt.LeftButton and self.humanAvailable:
202 | x, y = e.x(), e.y()
203 | i, j = self.coordinate_transform_pixel2map(x, y)
204 | if i is not None and j is not None and i < self.SCALE and j < self.SCALE:
205 | self.draw(i, j, 'HUMAN')
206 | self.signalHumanDraw_ChessCoordinates.emit(i, j)
207 | self.signalClicked.emit()
208 |
209 | def closeEvent(self, event):
210 | event.accept()
211 | sys.exit()
212 | ######################################################
213 |
214 |
215 | ###########Place pieces on the chessboard#############
216 | def draw(self, i, j, Player):
217 | # print("Drawing begins!")
218 | # print('Player to draw = ', Player)
219 | chessToDraw = self.dictionaryFromNameToElement[Player]
220 | # print('chess to draw is ', chessToDraw)
221 | x, y = self.coordinate_transform_map2pixel(i, j)
222 | # print('x, y = ',x, y)
223 | self.chessGrid[i][j].chessType = chessToDraw
224 | self.chessGrid[i][j].chessPicture.setMouseTracking(True)
225 | self.chessGrid[i][j].chessPicture.setVisible(True)
226 | self.chessGrid[i][j].chessPicture.setPixmap(self.dictConstantToPic[chessToDraw])
227 | self.chessGrid[i][j].chessPicture.setGeometry(x, y, self.PIECE, self.PIECE)
228 | # print("Drawing ends!")
229 | self.update()
230 | QApplication.processEvents()
231 | self.signalDraw_Finished.emit(True)
232 | ######################################################
233 |
234 |
235 | ################Coordinate Transformations################
236 | # self-defined functions offering transformation between
237 | # pixel coordinates and
238 | # relative grid coordinates
239 |
240 | def coordinate_transform_map2pixel(self, i, j):
241 | # convert logical coordinates in chessMap to drawing coordinates in the UI
242 | return self.MARGIN + j * self.GRID - self.PIECE / 2, self.MARGIN + i * self.GRID - self.PIECE / 2
243 |
244 | def coordinate_transform_pixel2map(self, x, y):
245 | # convert drawing coordinates in the UI back to logical coordinates in chessMap
246 | i, j = int(round((y - self.MARGIN) / self.GRID)), int(round((x - self.MARGIN) / self.GRID))
247 | # because of the MARGIN, clicks near the edge can land out of range, so reject i, j that go out of bounds
248 | if i < 0 or i >= 15 or j < 0 or j >= 15:
249 | return None, None
250 | else:
251 | return i, j
252 | #########################################################
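# A round-trip check (illustration only; the helper name _coordinate_round_trip is
# ours). It mirrors the two methods above using the module-level MARGIN/GRID/PIECE
# constants, which equal the defaults set in graphicsParameterSet. The PIECE / 2
# offset (17 px) is smaller than half a GRID cell (about 17.7 px), so rounding
# recovers the original cell.
def _coordinate_round_trip(i=3, j=5):
    # grid -> pixel, as coordinate_transform_map2pixel does
    x = MARGIN + j * GRID - PIECE / 2
    y = MARGIN + i * GRID - PIECE / 2
    # pixel -> grid, as coordinate_transform_pixel2map does
    i2, j2 = int(round((y - MARGIN) / GRID)), int(round((x - MARGIN) / GRID))
    return (i2, j2) == (i, j)         # True for any cell on the 15 x 15 grid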
253 |
254 | class HumanAgent(object):
255 | """
256 | human player
257 | """
258 |
259 | def __init__(self, interface):
260 | self.player = None
261 | self.interface = interface
262 | self.interface.signalHumanDraw_ChessCoordinates.connect(self.get_location)
263 | def set_player_ind(self, p):
264 | """设置人类玩家的编号,黑:1,白:2"""
265 | self.player = p
266 |
267 | def get_action(self, board):
268 | """根据棋盘返回动作"""
269 | location = self.get_location_from_window() # 从键盘上读入位置,eg. "0,2"代表第0行第2列
270 | # print('location = ', location)
271 | if isinstance(location, str): # 如果location确实是字符串
272 | location = [int(n, 10) for n in location.split(",")] # 将location转换为对应的坐标点
273 | move = board.location_to_move(location) # 坐标点转换为一个一维的值move,介于[0,width*height)
274 | if move not in board.availables:
275 | # print("Invalid move")
276 | move = self.get_action(board)
277 | return move
278 |
279 | def get_location(self, i, j):
280 | self.currentLocation = [i, j]
281 | return
282 |
283 | def get_location_from_window(self, timeout = 10000):
284 | loop = QEventLoop()
285 | self.interface.signalClicked.connect(loop.quit)
286 | loop.exec_()
287 | return self.currentLocation
288 |
289 | def __str__(self):
290 | return "Human {}".format(self.player)
291 | # try:
292 | # location = self.get_location_from_window() # read the position chosen in the window, e.g. "0,2" means row 0, column 2
293 | # print('location = ', location)
294 | # if isinstance(location, str): # if location really is a string
295 | # location = [int(n, 10) for n in location.split(",")] # convert location into a coordinate pair
296 | # move = board.location_to_move(location) # convert the coordinates into a one-dimensional move in [0, width*height)
297 | # except Exception as e: # on any error
298 | # move = -1
299 | # if move == -1 or move not in board.availables: # if the move is invalid
300 | # print("invalid move")
301 | # move = self.get_action(board) # wait for another input
302 | # return move
303 |
304 |
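# A sketch of the bookkeeping HumanAgent.get_action relies on (documentation only,
# assuming the usual row-major encoding in game.Board; Board.location_to_move and
# Board.move_to_location are the authoritative versions):
#
#     location = [2, 3]             # row 2, column 3 on a width x height board
#     move = 2 * width + 3          # assumed single index in [0, width * height)
#     board.move_to_location(move)  # should give back [2, 3]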
305 |
306 |
307 | class UserInterface_GO_Human_vs_AI(QWidget):
308 | signalOfDrawnChess = pyqtSignal(int, int, str)
309 | signalOfWinner = pyqtSignal(str)
310 | def __init__(self, AIPlayer, board_logic, width, height):
311 | super().__init__()
312 | self.AI = AIPlayer
313 | self.board = board_logic
314 | self.interface = ChessBoard()
315 | self.human = HumanAgent(self.interface)
316 | self.logicProcess()
317 |
318 | # the UI assumes a square board; its scale is the number of lines on a side
319 | self.scale = width
320 |
321 | self.width = width
322 | self.height = height
323 |
324 | def run(self):
325 | self.interface.show()
326 |
327 | def test(self):
328 | self.interface.signalAIFirst.connect(self.cycleInitialize)
329 | self.signalOfDrawnChess.connect(self.interface.draw)
330 | self.signalOfWinner.connect(self.interface.graphicsGameOver)
331 | self.interface.initialize(self.scale)
332 | self.interface.show()
333 | self.playChess()
334 |
335 |
336 | def playChess(self):
337 | # print('PlayChess Begin!!')
338 | end, winner = self.board.game_end()
339 | while end is False:
340 | currentPlayer = self.dictionary[self.chesses.element()]
341 | playerName = self.chesses.element()
342 | # print('The pointer points to', playerName)
343 | # print('current Player is ', currentPlayer, '.')
344 | nextMove = currentPlayer.get_action(self.board)
345 | self.board.do_move(nextMove)
346 | nextLocation = self.board.move_to_location(nextMove)
347 | i, j = nextLocation[0], nextLocation[1]
348 | self.signalOfDrawnChess.emit(i, j, playerName)
349 | self.chesses.pointTurnRight()
350 | end, winner = self.board.game_end()
351 | if winner == 1:
352 | print("Human wins")
353 | strWinner = 'Human wins'
354 | else:
355 | print("AI wins")
356 | strWinner = 'AI wins'
357 | self.signalOfWinner.emit(strWinner)
358 |
359 | def logicProcess(self):
360 | self.dictionary = {'HUMAN':self.human, 'AI':self.AI}
361 | cycle = ('HUMAN', 'HUMAN', 'AI', 'AI')
362 | self.chesses = cycleGroup(cycle)
363 |
364 | def cycleInitialize(self, aiFirst):
365 | # print('Here we are~')
366 | if aiFirst is True:
367 | self.chesses.point = 3
368 | self.board.init_board(1)
369 | else:
370 | self.chesses.point = 1
371 | self.board.init_board(0)
372 |
373 | def run(n_in_row, width, height, # pieces in a row to win, board width, board height
374 | model_file, ai_first, # model file to load, whether the AI moves first
375 | n_playout, use_gpu): # MCTS playouts the AI runs per move, whether to use the GPU
376 | try:
377 | board = Board(width=width, height=height, n_in_row=n_in_row) # create the board
378 |
379 | # ############### human VS AI ###################
380 | best_policy = PolicyValueNet(width, height, model_file=model_file, use_gpu=use_gpu) # load the best policy network
381 | mcts_player = MCTSPlayer(best_policy.policy_value_fn, c_puct=5, n_playout=n_playout) # create an AI player
382 | main = UserInterface_GO_Human_vs_AI(mcts_player, board, width, height,)
383 |
384 | main.test()
385 | # set start_player=0 for human first
386 | # game.start_play(human, mcts_player, start_player=ai_first, is_shown=1) # start the game
387 | except KeyboardInterrupt:
388 | print('\n\rquit')
389 |
390 | def usage():
391 | print("-s  board size (default 10)")
392 | print("-r  number of pieces in a row needed to win (default 6)")
393 | print("-m  number of MCTS playouts per move (default 800)")
394 | print("-i  model file the AI loads (default model/10_10_6_current_policy_.model)")
395 | print("--use_gpu  run the network on the GPU")
396 | print("--human_first  let the human move first")
397 |
398 |
399 | #if __name__ == '__main__':
400 | # from PyQt5 import QtWidgets
401 | # if not QtWidgets.QApplication.instance():
402 | # app = QtWidgets.QApplication(sys.argv)
403 | # else:
404 | # app = QtWidgets.QApplication.instance()
405 | #
406 | # best_policy = PolicyValueNet(width, height, model_file=model_file, use_gpu=use_gpu) # load the best policy network
407 | # mcts_player = MCTSPlayer(best_policy.policy_value_fn, c_puct=5, n_playout=n_playout) # create an AI player
408 | #
409 | # board = Board()
410 | # board.width = width
411 | # board.height = height
412 | # main = UserInterface_GO_Human_vs_AI(mcts_player, board)
413 | ## main.interface.signalAIFirst.connect(main.cycleInitialize)
414 | # main.test()
415 | # sys.exit(app.exec_())
416 |
417 |
418 | if __name__ == '__main__':
419 | import sys, getopt
420 | from PyQt5 import QtWidgets
421 | if not QtWidgets.QApplication.instance():
422 | app = QtWidgets.QApplication(sys.argv)
423 | else:
424 | app = QtWidgets.QApplication.instance()
425 | height = 10
426 | width = 10
427 | n_in_row = 6
428 | use_gpu = False
429 | n_playout = 800
430 | model_file = "model/10_10_6_current_policy_.model"
431 | # model_file = "model/10_10_6_best_policy_3.model"
432 | ai_first = True
433 |
434 | opts, args = getopt.getopt(sys.argv[1:], "hs:r:m:i:", ["use_gpu", "graphics", "human_first"])
435 | for op, value in opts:
436 | if op == "-h":
437 | usage()
438 | sys.exit()
439 | elif op == "-s":
440 | height = width = int(value)
441 | elif op == "-r":
442 | n_in_row = int(value)
443 | elif op == "--use_gpu":
444 | use_gpu = True
445 | elif op == "-m":
446 | n_playout = int(value)
447 | elif op == "-i":
448 | model_file = value
449 | elif op == "--human_first":
450 | ai_first = False
451 | run(height=height, width=width, n_in_row=n_in_row, use_gpu=use_gpu, n_playout=n_playout,
452 | model_file=model_file, ai_first=ai_first)
453 |
--------------------------------------------------------------------------------
/src/white.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JunguangJiang/AlphaSix/5a54b9f243bbeb790c9cd7601358748faebcc429/src/white.png
--------------------------------------------------------------------------------