├── LICENSE
├── README.md
├── board.py
├── checkpoint_to_weights.py
├── config.py
├── dlgo.py
├── docs
├── ComputerGoHistory.md
├── English.md
├── Methods.md
├── PyGoEngine.md
├── SmartGameFormat.md
├── Structure.md
├── Training.md
├── Tutorial.md
├── dlgoAPI.md
└── dlgoGTP.md
├── gtp.py
├── gui.py
├── img
├── alphago_zero_mcts.jpg
├── dlgo_vs_leela.gif
├── loss.gif
├── loss_plot.png
├── mcts.png
├── overfitting.png
├── policy_value.gif
├── puct.gif
├── sabaki-analysis.png
├── score_board.png
├── screenshot_sabaki_01.png
├── screenshot_sabaki_02.png
├── screenshot_sabaki_03.png
├── shortcut.png
└── ucb.gif
├── mcts.py
├── network.py
├── requirements.txt
├── sgf.py
├── sgf.zip
├── time_control.py
├── train.py
└── validate.py
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Copyright 2021-2022 Hung-Zhe Lin.
3 |
4 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
5 | associated documentation files (the "Software"), to deal in the Software without restriction,
6 | including without limitation the rights to use, copy, modify, merge, publish, distribute,
7 | sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all copies or
11 | substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
14 | NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
15 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
16 | DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
17 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
18 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pyDLGO
2 |
3 | A simple English tutorial is available [here](./docs/English.md).
4 |
5 | Ever since AlphaGo defeated the world champion, computer Go has become almost synonymous with deep learning, and many students have taken a real interest in it. However, implementing a complete Go engine involves far more than deep learning; it also includes many tedious, time-consuming parts that put most students off. dlgo implements a minimal Go engine, including the basic Go algorithms, a GTP interface, and an SGF parser, so that students can skip those parts, focus on the deep learning, and experience the appeal of computer Go. The ultimate goal is to help students build their own Go engines and take part in the TCGA computer game tournaments.
6 |
7 | #### (Black) dlgo-0.1 vs (White) Leela-0.11 (Black wins by resignation)
8 | 
9 |
10 |
11 | #### Analysis with Sabaki
12 |
13 |
14 |

15 |
16 |
17 | ## Quick Start
18 |
19 | Before starting, install the following Python dependencies (note that this program requires Python 3):
20 | 1. PyTorch (version 1.x or later; to use a GPU, install the build matching your CUDA/cuDNN version)
21 | 2. NumPy
22 | 3. Tkinter (only needed for the built-in GUI)
23 |
24 | Install them with the command below, or download a runnable release instead.
25 |
26 | pip3 install -r requirements.txt
27 |
28 | After installing the dependencies, download this repository and a set of pre-trained weights. Pre-trained weights can be found under Releases (a .pt file, no decompression needed). Put the weights into the pyDLGO directory; assuming the file is named nn_2x64.pt, run the following command to open the graphical interface.
29 |
30 | $ python3 dlgo.py --weights nn_2x64.pt --gui
31 |
32 | ## Documentation
33 | 0. [A short introduction to the rules of Go](https://www.smartgo.com/tw/go.html)
34 | 1. [Full usage tutorial and the TCGA tournament](./docs/Tutorial.md)
35 | 2. [Algorithms and implementation notes (work in progress)](./docs/Methods.md)
36 | 3. [How the GTP interface works](./docs/dlgoGTP.md)
37 | 4. [The SGF format](./docs/SmartGameFormat.md)
38 | 5. [Functions in board.py](./docs/dlgoAPI.md)
39 | 6. [List of Python Go engines (additions welcome)](./docs/PyGoEngine.md)
40 |
41 | ## License
42 |
43 | This project is released under the MIT License; some parts carry their own licenses:
44 |
45 | * [board.py](https://github.com/ymgaq/Pyaq) 和 [sgf.zip](https://github.com/ymgaq/Pyaq)
46 | * [gui.py](https://github.com/YoujiaZhang/AlphaGo-Zero-Gobang)
47 |
48 |
49 | ## Todo
50 |
51 | * Fix the GUI bugs and polish it
52 |
53 | ### Contact
54 |
55 | If you have any questions or suggestions, you can reach me at ```cglemon000@gmail.com```.
56 |
--------------------------------------------------------------------------------
/board.py:
--------------------------------------------------------------------------------
1 | from config import BOARD_SIZE, KOMI, INPUT_CHANNELS, PAST_MOVES
2 | import numpy as np
3 | import copy
4 |
5 | BLACK = 0
6 | WHITE = 1
7 | EMPTY = 2
8 | INVLD = 3
9 |
10 | NUM_VERTICES = (BOARD_SIZE+2) ** 2 # max vertices number
11 | NUM_INTESECTIONS = BOARD_SIZE ** 2 # max intersections number
12 |
13 | PASS = -1 # pass
14 | RESIGN = -2 # resign
15 | NULL_VERTEX = NUM_VERTICES+1 # invalid position
16 |
17 | class StoneLiberty(object):
18 | def __init__(self):
19 | self.lib_cnt = NULL_VERTEX # liberty count
20 | self.v_atr = NULL_VERTEX # liberty position if in atari
21 | self.libs = set() # set of liberty positions
22 |
23 | def clear(self):
24 | # Reset itself.
25 | self.lib_cnt = NULL_VERTEX
26 | self.v_atr = NULL_VERTEX
27 | self.libs.clear()
28 |
29 | def set(self):
30 | # Set one stone.
31 | self.lib_cnt = 0
32 | self.v_atr = NULL_VERTEX
33 | self.libs.clear()
34 |
35 | def add(self, v):
36 | # Add liberty at v.
37 | if v not in self.libs:
38 | self.libs.add(v)
39 | self.lib_cnt += 1
40 | self.v_atr = v
41 |
42 | def sub(self, v):
43 | # Remove liberty at v.
44 | if v in self.libs:
45 | self.libs.remove(v)
46 | self.lib_cnt -= 1
47 |
48 | def merge(self, other):
49 | # Merge itself with another stone.
50 | self.libs |= other.libs
51 | self.lib_cnt = len(self.libs)
52 | if self.lib_cnt == 1:
53 | for lib in self.libs:
54 | self.v_atr = lib
55 |
56 | '''
57 | What is a vertex? A vertex is not a real board position; it is a mailbox position. For example,
58 | suppose we set the board size to 5. The real board looks like
59 |
60 | a b c d e
61 | 1 . . . . .
62 | 2 . . . . .
63 | 3 . . . . .
64 | 4 . . . . .
65 | 5 . . . . .
66 |
67 | We call this coordinate the index, running from a1 to e5. Shifting an index is problematic
68 | because the shifted position may fall outside the board. For example, if we look for all
69 | positions adjacent to a1, two of them lie outside the board. One way to handle this is to check
70 | the boundary every time. A faster way is the mailbox structure. Here is what the mailbox looks like
71 |
72 | a b c d e
73 | - - - - - - -
74 | 1 - . . . . . -
75 | 2 - . . . . . -
76 | 3 - . . . . . -
77 | 4 - . . . . . -
78 | 5 - . . . . . -
79 | - - - - - - -
80 |
81 | The board size grows from 5 to 7. We call this new coordinate the vertex. With the mailbox,
82 | we no longer need to spend time checking the boundary. Note that '-' marks an out-of-board
83 | position.
84 |
85 | '''
86 |
87 | class Board(object):
88 | def __init__(self, board_size=BOARD_SIZE, komi=KOMI):
89 | self.state = np.full(NUM_VERTICES, INVLD) # positions state
90 | self.sl = [StoneLiberty() for _ in range(NUM_VERTICES)] # stone liberties
91 | self.reset(board_size, komi)
92 |
93 | def reset(self, board_size, komi):
94 | # Initialize all board data with current board size and komi.
95 |
96 | self.board_size = min(board_size, BOARD_SIZE)
97 | self.num_intersections = self.board_size ** 2
98 | self.num_vertices = (self.board_size+2) ** 2
99 | self.komi = komi
100 | ebsize = board_size+2
101 | self.dir4 = [1, ebsize, -1, -ebsize]
102 | self.diag4 = [1 + ebsize, ebsize - 1, -ebsize - 1, 1 - ebsize]
103 |
104 | for vtx in range(self.num_vertices):
105 | self.state[vtx] = INVLD # set invalid for out border
106 |
107 | for idx in range(self.num_intersections):
108 | self.state[self.index_to_vertex(idx)] = EMPTY # set empty for intersections
109 |
110 | '''
111 | self.id, self.next, and self.stones are the basic data structures for strings. With
112 | these structures we can traverse a whole string quickly. For example, suppose we
113 | have the board below
114 |
115 | board position
116 | a b c d e
117 | 1| . . . . .
118 | 2| . x x x .
119 | 3| . . . . .
120 | 4| . x x . .
121 | 5| . . . . .
122 |
123 | vertex position
124 | a b c d e
125 | 1| 8 9 10 11 12
126 | 2| 15 16 17 18 19
127 | 3| 22 23 24 25 26
128 | 4| 29 30 31 32 33
129 | 5| 36 37 38 39 40
130 |
131 | self.id
132 | a b c d e
133 | 1| . . . . .
134 | 2| . 16 16 16 .
135 | 3| . . . . .
136 | 4| . 30 30 . .
137 | 5| . . . . .
138 |
139 | self.next
140 | a b c d e
141 | 1| . . . . .
142 | 2| . 17 18 16 .
143 | 3| . . . . .
144 | 4| . 31 30 . .
145 | 5| . . . . .
146 |
147 | self.stones
148 | a b c d e
149 | 1| . . . . .
150 | 2| . 3 . . .
151 | 3| . . . . .
152 | 4| . 2 . . .
153 | 5| . . . . .
154 |
155 | If we want to traverse string 16, simply start from its
156 | id (the string's parent vertex). The pseudo code looks like
157 |
158 | start_pos = id[vertex]
159 | next_pos = start_pos
160 | {
161 | next_pos = next[next_pos]
162 | } while(next_pos != start_pos)
163 |
164 | '''
165 |
166 | self.id = np.arange(NUM_VERTICES) # the id(parent vertex) of string
167 | self.next = np.arange(NUM_VERTICES) # next position in the same string
168 | self.stones = np.zeros(NUM_VERTICES) # the string size
169 |
170 | for i in range(NUM_VERTICES):
171 | self.sl[i].clear() # clear liberties
172 |
173 | self.num_passes = 0 # number of passes played.
174 | self.ko = NULL_VERTEX # illegal position due to Ko
175 | self.to_move = BLACK # black
176 | self.move_num = 0 # move number
177 | self.last_move = NULL_VERTEX # last move
178 | self.removed_cnt = 0 # removed stones count
179 | self.history = [] # history board positions.
180 |
181 | def copy(self):
182 | # Deep copy the board to another board. But they will share the same
183 | # history board positions.
184 |
185 | b_cpy = Board(self.board_size, self.komi)
186 | b_cpy.state = np.copy(self.state)
187 | b_cpy.id = np.copy(self.id)
188 | b_cpy.next = np.copy(self.next)
189 | b_cpy.stones = np.copy(self.stones)
190 | for i in range(NUM_VERTICES):
191 | b_cpy.sl[i].lib_cnt = self.sl[i].lib_cnt
192 | b_cpy.sl[i].v_atr = self.sl[i].v_atr
193 | b_cpy.sl[i].libs |= self.sl[i].libs
194 |
195 | b_cpy.num_passes = self.num_passes
196 | b_cpy.ko = self.ko
197 | b_cpy.to_move = self.to_move
198 | b_cpy.move_num = self.move_num
199 | b_cpy.last_move = self.last_move
200 | b_cpy.removed_cnt = self.removed_cnt
201 |
202 | for h in self.history:
203 | b_cpy.history.append(h)
204 | return b_cpy
205 |
206 | def _remove(self, v):
207 | # Remove a string including v.
208 |
209 | v_tmp = v
210 | removed = 0
211 | while True:
212 | removed += 1
213 | self.state[v_tmp] = EMPTY # set empty
214 | self.id[v_tmp] = v_tmp # reset id
215 | for d in self.dir4:
216 | nv = v_tmp + d
217 | # Add liberty to neighbor strings.
218 | self.sl[self.id[nv]].add(v_tmp)
219 | v_next = self.next[v_tmp]
220 | self.next[v_tmp] = v_tmp
221 | v_tmp = v_next
222 | if v_tmp == v:
223 | break # Finish when all stones are removed.
224 | return removed
225 |
226 | def _merge(self, v1, v2):
227 | '''
228 | board position
229 | a b c d e
230 | 1| . . . . .
231 | 2| . x x x .
232 | 3| . [x] . . .
233 | 4| . x . . .
234 | 5| . . . . .
235 |
236 | Merge two strings...
237 |
238 | [before] >> [after]
239 |
240 | self.id
241 | a b c d e a b c d e
242 | 1| . . . . . 1| . . . . .
243 | 2| . 16 16 16 . 2| . 16 16 16 .
244 | 3| . 30 . . . >> 3| . 16 . . .
245 | 4| . 30 . . . 4| . 16 . . .
246 | 5| . . . . . 5| . . . . .
247 |
248 | self.next
249 | a b c d e a b c d e
250 | 1| . . . . . 1| . . . . .
251 | 2| . 17 18 16 . 2| . 30 18 16 .
252 | 3| . 30 . . . >> 3| . 17 . . .
253 | 4| . 23 . . . 4| . 23 . . .
254 | 5| . . . . . 5| . . . . .
255 |
256 | self.stones
257 | a b c d e a b c d e
258 | 1| . . . . . 1| . . . . .
259 | 2| . 3 . . . 2| . 5 . . .
260 | 3| . . . . . >> 3| . . . . .
261 | 4| . 2 . . . 4| . . . . .
262 | 5| . . . . . 5| . . . . .
263 |
264 | '''
265 |
266 | # Merge string including v1 with string including v2.
267 |
268 | id_base = self.id[v1]
269 | id_add = self.id[v2]
270 |
271 | # We want the large string merges the small string.
272 | if self.stones[id_base] < self.stones[id_add]:
273 | id_base, id_add = id_add, id_base # swap
274 |
275 | self.sl[id_base].merge(self.sl[id_add])
276 | self.stones[id_base] += self.stones[id_add]
277 |
278 | v_tmp = id_add
279 | while True:
280 | self.id[v_tmp] = id_base # change id to id_base
281 | v_tmp = self.next[v_tmp]
282 | if v_tmp == id_add:
283 | break
284 | # Swap next id for circulation.
285 | self.next[v1], self.next[v2] = self.next[v2], self.next[v1]
286 |
287 | def _place_stone(self, v):
288 | # Play a stone on the board and try to merge itself with adjacent strings.
289 |
290 | # Set one stone to the board and prepare data.
291 | self.state[v] = self.to_move
292 | self.id[v] = v
293 | self.stones[v] = 1
294 | self.sl[v].set()
295 |
296 | for d in self.dir4:
297 | nv = v + d
298 | if self.state[nv] == EMPTY:
299 | self.sl[self.id[v]].add(nv) # Add liberty to itself.
300 | else:
301 | self.sl[self.id[nv]].sub(v) # Remove liberty from opponent's string.
302 |
303 | # Merge the stone with my string.
304 | for d in self.dir4:
305 | nv = v + d
306 | if self.state[nv] == self.to_move and self.id[nv] != self.id[v]:
307 | self._merge(v, nv)
308 |
309 | # Remove the opponent's string.
310 | self.removed_cnt = 0
311 | for d in self.dir4:
312 | nv = v + d
313 | if self.state[nv] == int(self.to_move == 0) and \
314 | self.sl[self.id[nv]].lib_cnt == 0:
315 | self.removed_cnt += self._remove(nv)
316 |
317 | def legal(self, v):
318 | # Return true if the move is legal.
319 |
320 | if v == PASS:
321 | # The pass move is always legal in any condition.
322 | return True
323 | elif v == self.ko or self.state[v] != EMPTY:
324 | # The move is a ko recapture or the point is not empty.
325 | return False
326 |
327 | stone_cnt = [0, 0]
328 | atr_cnt = [0, 0] # atari count
329 | for d in self.dir4:
330 | nv = v + d
331 | c = self.state[nv]
332 | if c == EMPTY:
333 | return True
334 | elif c <= 1: # The color must be black or white
335 | stone_cnt[c] += 1
336 | if self.sl[self.id[nv]].lib_cnt == 1:
337 | atr_cnt[c] += 1
338 |
339 | return (atr_cnt[int(self.to_move == 0)] != 0 or # we capture at least one opponent string in atari
340 | atr_cnt[self.to_move] < stone_cnt[self.to_move]) # some adjacent friendly string keeps a liberty, so we live
341 |
342 | def play(self, v):
343 | # Play the move and update board data if the move is legal.
344 |
345 | if not self.legal(v):
346 | return False
347 | else:
348 | if v == PASS:
349 | # The game should be stopped once two consecutive passes have been played.
350 | # Be sure to check the number of passes before playing a pass.
351 | self.num_passes += 1
352 | self.ko = NULL_VERTEX
353 | else:
354 | self._place_stone(v)
355 | id = self.id[v]
356 | self.ko = NULL_VERTEX
357 | if self.removed_cnt == 1 and \
358 | self.sl[id].lib_cnt == 1 and \
359 | self.stones[id] == 1:
360 | # Set the ko point if the move captured exactly one stone and the played
361 | # stone is itself a single stone with a single liberty (surrounded by the opponent).
362 | self.ko = self.sl[id].v_atr
363 | self.num_passes = 0
364 |
365 | self.last_move = v
366 | self.to_move = int(self.to_move == 0) # switch side
367 | self.move_num += 1
368 |
369 | # Push the current board positions to history.
370 | self.history.append(copy.deepcopy(self.state))
371 |
372 | return True
373 |
374 | def _compute_reach_color(self, color):
375 | # A simple BFS to count every vertex reachable by the given color.
376 |
377 | queue = []
378 | reachable = 0
379 | buf = [False] * NUM_VERTICES
380 |
381 | # Collect my positions.
382 | for v in range(NUM_VERTICES):
383 | if self.state[v] == color:
384 | reachable += 1
385 | buf[v] = True
386 | queue.append(v)
387 |
388 | # Now start the BFS algorithm to search all reachable positions.
389 | while len(queue) != 0:
390 | v = queue.pop(0)
391 | for d in self.dir4:
392 | nv = v + d
393 | if self.state[nv] == EMPTY and buf[nv] == False:
394 | reachable += 1
395 | queue.append(nv)
396 | buf[nv] = True
397 | return reachable
398 |
399 | def final_score(self):
400 | # Score the board area under the Tromp-Taylor rule.
401 | return self._compute_reach_color(BLACK) - self._compute_reach_color(WHITE) - self.komi
402 |
403 | def get_x(self, v):
404 | # vertex to x
405 | return v % (self.board_size+2) - 1
406 |
407 | def get_y(self, v):
408 | # vertex to y
409 | return v // (self.board_size+2) - 1
410 |
411 | def get_vertex(self, x, y):
412 | # x, y to vertex
413 | return (y+1) * (self.board_size+2) + (x+1)
414 |
415 | def get_index(self, x, y):
416 | # x, y to index
417 | return y * self.board_size + x
418 |
419 | def vertex_to_index(self, v):
420 | # vertex to index
421 | return self.get_index(self.get_x(v), self.get_y(v))
422 |
423 | def index_to_vertex(self, idx):
424 | # index to vertex
425 | return self.get_vertex(idx % self.board_size, idx // self.board_size)
426 |
427 | def vertex_to_text(self, vtx):
428 | # vertex to GTP move
429 |
430 | if vtx == PASS:
431 | return "pass"
432 | elif vtx == RESIGN:
433 | return "resign"
434 |
435 | x = self.get_x(vtx)
436 | y = self.get_y(vtx)
437 | offset = 1 if x >= 8 else 0 # skip 'I'
438 | return "".join([chr(x + ord('A') + offset), str(y+1)])
439 |
440 | def get_features(self):
441 | # Planes 1~16, odd planes : stones of the side to move, for the current and past boards.
442 | # Planes 1~16, even planes: stones of the opposing side, for the current and past boards.
443 | # Plane 17 : filled with ones if the side to move is black.
444 | # Plane 18 : filled with ones if the side to move is white.
445 | my_color = self.to_move
446 | opp_color = (self.to_move + 1) % 2
447 | past_moves = min(PAST_MOVES, len(self.history))
448 |
449 | features = np.zeros((INPUT_CHANNELS, self.num_intersections), dtype=np.int8)
450 | for p in range(past_moves):
451 | # Fill past board positions features.
452 | h = self.history[len(self.history) - p - 1]
453 | for v in range(self.num_vertices):
454 | c = h[v]
455 | if c == my_color:
456 | features[p*2, self.vertex_to_index(v)] = 1
457 | elif c == opp_color:
458 | features[p*2+1, self.vertex_to_index(v)] = 1
459 |
460 | # Fill side to move features.
461 | features[INPUT_CHANNELS - 2 + self.to_move, :] = 1
462 | return np.reshape(features, (INPUT_CHANNELS, self.board_size, self.board_size))
463 |
464 | def superko(self):
465 | # Return true if the current position repeats an earlier position (superko).
466 |
467 | curr_hash = hash(self.state.tostring())
468 | s = len(self.history)
469 | for p in range(s-1):
470 | h = self.history[p]
471 | if hash(h.tostring()) == curr_hash:
472 | return True
473 | return False
474 |
475 | def __str__(self):
476 | def get_xlabel(bsize):
477 | X_LABELS = "ABCDEFGHJKLMNOPQRST"
478 | line_str = " "
479 | for x in range(bsize):
480 | line_str += " " + X_LABELS[x] + " "
481 | return line_str + "\n"
482 | out = str()
483 | out += get_xlabel(self.board_size)
484 |
485 | for y in range(0, self.board_size)[::-1]: # 9, 8, ..., 1
486 | line_str = str(y+1) if y >= 9 else " " + str(y+1)
487 | for x in range(0, self.board_size):
488 | v = self.get_vertex(x, y)
489 | x_str = " . "
490 | color = self.state[v]
491 | if color <= 1:
492 | stone_str = "O" if color == WHITE else "X"
493 | if v == self.last_move:
494 | x_str = "[" + stone_str + "]"
495 | else:
496 | x_str = " " + stone_str + " "
497 | line_str += x_str
498 | line_str += str(y+1) if y >= 10 else " " + str(y+1)
499 | out += (line_str + "\n")
500 |
501 | out += get_xlabel(self.board_size)
502 | return out + "\n"
503 |
--------------------------------------------------------------------------------
/checkpoint_to_weights.py:
--------------------------------------------------------------------------------
1 | from network import Network
2 | from config import BOARD_SIZE
3 | import argparse
4 | import torch
5 | import time
6 |
7 | def load_checkpoint(network, checkpoint):
8 | state_dict = torch.load(checkpoint, map_location=network.gpu_device)
9 | network.load_state_dict(state_dict["network"])
10 | return network
11 |
12 | def get_currtime():
13 | lt = time.localtime(time.time())
14 | return "{y}-{m}-{d}-{h:02d}-{mi:02d}-{s:02d}".format(
15 | y=lt.tm_year, m=lt.tm_mon, d=lt.tm_mday, h=lt.tm_hour, mi=lt.tm_min, s=lt.tm_sec)
16 |
17 | if __name__ == "__main__":
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument("-c", "--checkpoint", metavar="",
20 | help="The inpute checkpoint file name.", type=str)
21 | args = parser.parse_args()
22 |
23 | if args.checkpoint:
24 | network = Network(BOARD_SIZE, use_gpu=False)
25 | network = load_checkpoint(network, args.checkpoint)
26 | network.save_pt("weights-{}.pt".format(get_currtime()))
27 | else:
28 | print("Please give the checkpoint path.")
29 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | BOARD_SIZE = 9 # The default and max board size. We can reset the value later.
2 |
3 | KOMI = 7 # The default komi. We can reset the value later.
4 |
5 | USE_GPU = True # Set true will use the GPU automatically if you have one.
6 |
7 | BLOCK_SIZE = 2 # The network residual block size.
8 |
9 | BLOCK_CHANNELS = 64 # The number of network residual channels.
10 |
11 | POLICY_CHANNELS = 8 # The number of policy head channels.
12 |
13 | VALUE_CHANNELS = 4 # The number of value head channels.
14 |
15 | INPUT_CHANNELS = 18 # The number of network input planes; equals PAST_MOVES * 2 + 2.
16 |
17 | PAST_MOVES = 8 # The number of past moves encoded into the input planes.
18 |
19 | USE_SE = False # Enable Squeeze-and-Excite net struct.
20 |
--------------------------------------------------------------------------------
/dlgo.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from gtp import GTP_LOOP
4 | from gui import GUI_LOOP
5 | import argparse
6 |
7 | if __name__ == "__main__":
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument("-p", "--playouts", metavar="",
10 | help="The number of playouts.", type=int, default=400)
11 | parser.add_argument("-w", "--weights", metavar="",
12 | help="The weights file name.", type=str)
13 | parser.add_argument("-r", "--resign-threshold", metavar="",
14 | help="Resign when winrate is less than x.", type=float, default=0.1)
15 | parser.add_argument("-v", "--verbose", default=False,
16 | help="Dump some search verbose.", action="store_true")
17 | parser.add_argument("-k", "--kgs", default=False,
18 | help="Dump some hint verbose on KGS.", action="store_true")
19 | parser.add_argument("-g", "--gui", default=False,
20 | help="Open it with GUI.", action="store_true")
21 |
22 | args = parser.parse_args()
23 | if args.gui:
24 | loop = GUI_LOOP(args)
25 | else:
26 | loop = GTP_LOOP(args)
27 |
--------------------------------------------------------------------------------
/docs/ComputerGoHistory.md:
--------------------------------------------------------------------------------
1 | # History
2 |
3 | ## Preparation
4 |
5 | Please make sure you have read the [algorithms and implementation](../docs/Methods.md) chapter and the [short introduction to the rules of Go](https://www.smartgo.com/tw/go.html); some of the terms used below are introduced there.
6 |
--------------------------------------------------------------------------------
/docs/English.md:
--------------------------------------------------------------------------------
1 | # Simple Usage
2 |
3 | ## Requirements
4 | 1. PyTorch (1.x version)
5 | 2. NumPy
6 | 3. Tkinter
7 | 4. Matplotlib
8 |
9 | ## Open With Built-in GUI
10 |
11 | You may download the pre-trained model from the release section, named 預先訓練好的小型權重 (the small pre-trained weights). Then put the .pt file into the dlgo directory and enter the following command.
12 |
13 | $ python3 dlgo.py --weights nn_2x64.pt --gui
14 |
15 | dlgo will use your GPU automatically. If you want to disable the GPU, set ```USE_GPU``` to False.
16 |
17 |
18 | ## Open With GTP GUI
19 |
20 | dlgo supports GTP front ends; [Sabaki](https://sabaki.yichuanshen.de) is recommended. Some helpful optional arguments are listed below.
21 |
22 | optional arguments:
23 | -p , --playouts
24 | The number of playouts.
25 | -w , --weights
26 | The weights file name.
27 | -r , --resign-threshold
28 | Resign when winrate is less than x.
29 |
30 | A sample command:
31 |
32 | $ python3 dlgo.py --weights nn_2x64.pt -p 1600 -r 0.25
33 |
34 | ## Training
35 |
36 | Follow these simple steps to train new weights.
37 |
38 | 1. Prepare the SGF files. You may simply use sgf.zip, which includes around 35,000 9x9 games.
39 | 2. Set the network parameters in config.py, including ```BOARD_SIZE```, ```BLOCK_SIZE```, and ```BLOCK_CHANNELS```.
40 | 3. Start training.
41 |
42 | $ python3 train.py --dir sgf-directory-name --steps 128000 --batch-size 512 --learning-rate 0.005
43 |
44 | Some helpful arguments:
45 |
46 | optional arguments:
47 | -h, --help show this help message and exit
48 | -d , --dir
49 | The input SGF files directory. Will use data cache if set None.
50 | -s , --steps
51 | Terminate after these steps.
52 | -v , --verbose-steps
53 | Dump verbose on every X steps.
54 | -b , --batch-size
55 | The batch size number.
56 | -l , --learning-rate
57 | The learning rate.
58 | --noplot Disable plotting.
59 |
--------------------------------------------------------------------------------
/docs/Methods.md:
--------------------------------------------------------------------------------
1 | # Methods
2 |
3 | ## 0. Preface
4 |
5 | This chapter describes the techniques used in dlgo and how they are implemented, to help the reader understand the code.
6 |
7 | ## 1. Board Data Structures
8 |
9 | If you have ever tried to implement a Go board yourself, you probably noticed that it is closely tied to graph theory: most points on the board are equivalent to one another, and the board is a planar graph. These properties hint that finding certain structures on the board can be very hard, such as finding every unconditionally alive string; some cases, such as seki, cannot even be guaranteed to be detected. Fortunately the basic data structures in common use are well established. Below we discuss how to compute the state of every string on the board quickly.
10 |
11 | ### MailBox
12 |
13 | The more branch conditions a program has, the slower it runs. Suppose we want to find the liberties around a point: we would need four boundary checks to make sure we never look outside the board, and scanning the four neighbours is one of the most frequently used operations, so the cost is significant. To solve this we use a data structure common in board games, the mailbox. Suppose we have a board of size 5 as below.
14 |
15 | a b c d e
16 | 1 . . . . .
17 | 2 . . . . .
18 | 3 . . . . .
19 | 4 . . . . .
20 | 5 . . . . .
21 |
22 |
23 | The pseudocode before the improvement looks like this (note that a one-dimensional array is used):
24 |
25 | BLACK = 0
26 | WHITE = 1
27 | EMPTY = 2
28 | INVLD = 3 # out of board value
29 |
30 | find_adjacent(index):
31 | type_count[4] = {0,0,0,0}
32 |
33 | for adjacent in index
34 | if adjacent is out of board
35 | type_count[INVLD] += 1
36 | else
37 | type = board[adjacent]
38 | type_count[type] += 1
39 |
40 |
41 | The core idea of the mailbox is to add a ring of invalid positions (marked ```-```) around the board, so that no explicit boundary check is needed:
42 |
43 | a b c d e
44 | - - - - - - -
45 | 1 - . . . . . -
46 | 2 - . . . . . -
47 | 3 - . . . . . -
48 | 4 - . . . . . -
49 | 5 - . . . . . -
50 | - - - - - - -
51 |
52 | The improved pseudocode is shown below; not only is it faster, the code is also much simpler:
53 |
54 | BLACK = 0
55 | WHITE = 1
56 | EMPTY = 2
57 | INVLD = 3 # out of board value
58 |
59 | find_adjacent(vertex):
60 | type_count[4] = {0,0,0,0}
61 |
62 | for adjacent in vertex
63 | type_count[board[adjacent]] += 1
64 |
65 |
66 | In this project, the coordinate system of the unimproved version is called the index; it is generally used when exporting board data. The coordinate system of the improved version is called the vertex; it is generally used for internal board searches, as the short example below shows.
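
A quick way to check this mapping with the Board class from board.py (on a 5x5 board the mailbox row width is 7, so index 0, i.e. a1, becomes vertex 8, matching the vertex diagram in the next section):

    from board import Board

    board = Board(board_size=5)
    v = board.index_to_vertex(0)        # index 0 is a1
    print(v)                            # 8 == (0+1)*(5+2) + (0+1)
    print(board.vertex_to_index(v))     # 0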
67 |
68 | ### Strings
69 |
70 | A string can be seen as a sub-graph of the board whose nodes form a cycle. Look at the layout below: "board position" is the current position, with two black strings; "vertex position" is the vertex number of every point (a one-dimensional array); "string identity" is each string's identity. Note that the identity points at the string's root vertex: for the string with identity 16, vertex 16 lies inside the string and acts as its root node (why we do this is discussed later). Finally, "next position" points to the next node of the same string, and the nodes are circular: for the string with identity 16, following next gives (17->18->16) -> (17->18->16) -> ... forever.
71 |
72 | board position
73 | a b c d e
74 | 1| . . . . .
75 | 2| . x x x .
76 | 3| . . . . .
77 | 4| . x x . .
78 | 5| . . . . .
79 |
80 | vertex position
81 | a b c d e
82 | 1| 8 9 10 11 12
83 | 2| 15 16 17 18 19
84 | 3| 22 23 24 25 26
85 | 4| 29 30 31 32 33
86 | 5| 36 37 38 39 40
87 |
88 | string identity
89 | a b c d e
90 | 1| . . . . .
91 | 2| . 16 16 16 .
92 | 3| . . . . .
93 | 4| . 30 30 . .
94 | 5| . . . . .
95 |
96 | next position
97 | a b c d e
98 | 1| . . . . .
99 | 2| . 17 18 16 .
100 | 3| . . . . .
101 | 4| . 31 30 . .
102 | 5| . . . . .
103 |
104 |
105 | To count the liberties of a string, start from one node and walk along the next pointers, collecting liberties, until we come back to the starting position. Pseudocode:
106 |
107 | count_liberty(vertex):
108 | start_pos = identity[vertex] # get the start vertex position
109 |
110 | next_pos = start_pos
111 | liberty_set = set()
112 | {
113 | for adjacent in next_pos
114 | if board[adjacent] == EMPTY
115 | liberty_set.add(adjacent) # add the adjacent vertex to the set
116 |
117 | next_pos = next[next_pos] # go to the next vertex position
118 | } while(next_pos != start_pos)
119 |
120 | liberties = length(liberty_set)
121 |
122 |
123 | ### Storing string information
124 |
125 | As noted above, the identity points to the string's root vertex. The root vertex can store the string's state, such as its number of stones or its set of liberties, so that the information does not have to be recomputed each time it is needed. The data structures used in this project are shown below.
126 |
127 |
128 | string identity
129 | a b c d e
130 | 1| . . . . .
131 | 2| . 16 16 16 .
132 | 3| . . . . .
133 | 4| . 30 30 . .
134 | 5| . . . . .
135 |
136 |
137 | string stones
138 | a b c d e
139 | 1| . . . . .
140 | 2| . 3 . . .
141 | 3| . . . . .
142 | 4| . 2 . . .
143 | 5| . . . . .
144 |
145 | string liberty set
146 | a b c d e
147 | 1| . . . . .
148 | 2| . A . . .
149 | 3| . . . . . # A = liberty set of string 16
150 | 4| . B . . . # B = liberty set of string 30
151 | 5| . . . . .
152 |
153 |
154 | ### Merging strings
155 |
156 | To merge two strings, simply swap the next positions of the two touching points and update the string identity, string stones, and string liberty set, as shown below (a small runnable example follows the diagrams). To merge more than two strings, merge them pairwise, one at a time.
157 |
158 | board position
159 | a b c d e
160 | 1| . . . . .
161 | 2| . x x x .
162 | 3| . [x] . . .
163 | 4| . x . . .
164 | 5| . . . . .
165 |
166 | Merge two strings...
167 |
168 | string identity
169 | a b c d e a b c d e
170 | 1| . . . . . 1| . . . . .
171 | 2| . 16 16 16 . 2| . 16 16 16 .
172 | 3| . 30 . . . >> 3| . 16 . . .
173 | 4| . 30 . . . 4| . 16 . . .
174 | 5| . . . . . 5| . . . . .
175 |
176 | next position
177 | a b c d e a b c d e
178 | 1| . . . . . 1| . . . . .
179 | 2| . 17 18 16 . 2| . 30 18 16 .
180 | 3| . 30 . . . >> 3| . 17 . . .
181 | 4| . 23 . . . 4| . 23 . . .
182 | 5| . . . . . 5| . . . . .
183 |
184 | string stones
185 | a b c d e a b c d e
186 | 1| . . . . . 1| . . . . .
187 | 2| . 3 . . . 2| . 5 . . .
188 | 3| . . . . . >> 3| . . . . .
189 | 4| . 2 . . . 4| . . . . .
190 | 5| . . . . . 5| . . . . .
191 |
192 |
193 | string liberty set
194 | a b c d e a b c d e
195 | 1| . . . . . 1| . . . . .
196 | 2| . A . . . 2| . C . . .
197 | 3| . . . . . >> 3| . . . . . # set C = set A + set B
198 | 4| . B . . . 4| . . . . .
199 | 5| . . . . . 5| . . . . .
200 |
201 |
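
A minimal demonstration of the merge using the Board class from board.py (coordinates follow the diagrams above: on a 5x5 board, b2 is vertex 16 and c2 is vertex 17):

    from board import Board, PASS

    b = Board(board_size=5)
    b.play(b.get_vertex(1, 1))   # Black plays b2 (vertex 16)
    b.play(PASS)                 # White passes
    b.play(b.get_vertex(2, 1))   # Black plays c2 (vertex 17); the two strings merge
    print(b.id[16] == b.id[17])  # True: both stones now share one string identity
    print(b.stones[b.id[16]])    # 2.0: the merged string has two stones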
202 | ### Detecting legal moves
203 |
204 | Different rule sets define legal moves differently. To keep things simple, this project requires a legal move to satisfy two basic conditions:
205 |
206 | 1. After the move is played, the resulting string must not end up with zero liberties; in short, no suicide.
207 | 2. The same whole-board position may not appear twice in one game (super ko).
208 |
209 | Consider the no-suicide condition first. A move is not suicide if any of the following holds:
210 |
211 | 1. At least one adjacent point is empty.
212 | 2. At least one adjacent string of our own colour has more than one liberty.
213 | 3. At least one adjacent string of the opposing colour has exactly one liberty (it gets captured).
214 |
215 |
216 | is_suicide(vertex):
217 | for adjacent in vertex
218 | if board[adjacent] == EMPTY
219 | return false
220 |
221 | if board[adjacent] == MY_COLOR &&
222 | string_liberties[adjacent] > 1:
223 | return false
224 |
225 | if board[adjacent] == OPP_COLOR &&
226 | string_liberties[adjacent] == 1:
227 | return false
228 | return true
229 |
230 | If any one of these conditions holds, the move is not suicide. Next consider the ban on repeated positions. Detecting a repeated whole-board position is relatively expensive, so we mainly detect ko, i.e. whether the move creates a ko point; an immediate recapture of the ko is then illegal. The move creates a ko point when:
231 |
232 | 1. Exactly one opponent stone is captured.
233 | 2. The stone just played ends up with exactly one liberty.
234 | 3. The stone just played ends up as a string of exactly one stone.
235 |
236 | is_ko(vertex):
237 | if captured_count == 1 &&
238 | string_stones[vertex] == 1 &&
239 | string_liberties[vertex] == 1
240 | return true
241 | return false
242 |
243 | If all three hold, the move creates a ko. For repeated positions that are not a simple ko there is no special shortcut: the stone has to be placed on the board and the resulting position compared against the history. In dlgo, children that repeat an earlier position are pruned only at the root of the search tree, which keeps dlgo compliant with the Tromp-Taylor rules while retaining reasonable performance. See Section 6 for the Tromp-Taylor rules.
244 |
245 | ## 2. The Evaluation Function
246 |
247 | The earliest Go programs had no evaluation function at all. Go positions are so fluid and lack such obvious features that building a good evaluation function was a long-standing problem. Programs mainly predicted the location of the next move and played quick moves to the end of the game to obtain a result; repeating this many times from the same position gave an estimate of how good it was. Unsurprisingly, win rates obtained this way were not very accurate. With the rise of deep learning, the evaluation problem was largely solved. This section gives a basic description of the evaluation function without going deep into the deep-learning details.
248 |
249 | ### Basic setting
250 |
251 | Suppose we have a state (the current board position) with several possible actions (the legal moves), executed by an agent (the program itself). From the current state we want two kinds of information. The first is the policy, which tells the agent which actions are worth playing or searching; in AlphaGo this is a probability distribution over all legal moves. The second is the value, which tells the agent the score of the current state or of each action; in AlphaGo this is the score of the current position, which can also be read as a win rate. This is illustrated below.
252 |
253 | *(figure: the policy and value outputs for a position)*
254 |
255 | ### Training the evaluation function
256 |
257 | Following the description above, we collect the current state, the next move, and the final result of each game as training data, for example:
258 |
259 | | Current state | Move | Result |
260 | | :------------: | :---------------: | :---------------: |
261 | | S1 (black to move) | ```e5``` | black wins |
262 | | S2 (white to move) | ```d3``` | black wins |
263 | | S3 (black to move) | ```e5``` | white wins |
264 | | S4 (white to move) | ```d3``` | white wins |
265 |
266 |
267 |
268 | Next the data is converted into a form the network understands. In this implementation the current state is encoded from the past eight moves (two planes per move, one for each colour) plus the colour to move (two planes), giving 18 planes in total (see get_features() in board.py for the details). The move is converted into a one-dimensional array that is 1 at the played point and 0 elsewhere. The result is 1 if the side to move eventually wins and -1 if it loses. The converted data looks like this (a quick way to inspect the encoding in code follows the table):
269 |
270 | | Current state | Move index | Result |
271 | | :------------: | :---------------: | :---------------: |
272 | | Inputs 1 | 40 | 1 |
273 | | Inputs 2 | 21 | -1 |
274 | | Inputs 3 | 40 | -1 |
275 | | Inputs 4 | 21 | 1 |
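
Grounded in board.py above, a quick way to inspect this encoding (the printed shape follows from BOARD_SIZE = 9 and INPUT_CHANNELS = 18 in config.py):

    from board import Board

    board = Board(board_size=9)
    board.play(board.get_vertex(4, 4))    # black plays e5
    planes = board.get_features()
    print(planes.shape)                   # (18, 9, 9): 16 history planes + 2 colour planes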
276 |
277 |
278 |
279 | The optimization target of the network is given below. Modern neural networks are usually trained with [backpropagation](https://en.wikipedia.org/wiki/Backpropagation); you do not need to know the algorithm in detail, only that it uses derivatives to move toward a point of low slope, which is also the point where the network's loss is lowest. The idea is similar to Newton's method as taught in high school.
280 |
281 | The loss that training minimizes (reconstructed here from the variable descriptions that followed the original formula) is
282 |
283 |     L = (z - v)^2 - pi^T * log(p) + c * ||theta||^2
284 |
285 | where z is the game result from the data, v is the value output by the network, pi is the move array from the data, p is the policy output by the network, and c * ||theta||^2 is a penalty term on the network parameters.
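
A minimal PyTorch sketch of this loss, assuming a network that outputs policy logits and a scalar value; the function name, tensor shapes, and the constant c are illustrative assumptions, and train.py may handle the penalty through the optimizer's weight_decay instead:

    import torch.nn.functional as F

    def loss_fn(policy_logits, value, target_index, target_result, model, c=1e-4):
        # cross-entropy between the predicted policy and the recorded move index
        policy_loss = F.cross_entropy(policy_logits, target_index)
        # mean-squared error between the predicted value and the game result (+1 / -1)
        value_loss = F.mse_loss(value.squeeze(-1), target_result)
        # L2 penalty on the network parameters
        penalty = c * sum((p ** 2).sum() for p in model.parameters())
        return policy_loss + value_loss + penalty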
295 | This is not the only possible encoding; for example, ELF OpenGo records the result from black's point of view only, as below:
296 |
297 | | Current state | Move index | Result |
298 | | :------------: | :---------------: | :---------------: |
299 | | Inputs 1 | 40 | 1 |
300 | | Inputs 2 | 21 | 1 |
301 | | Inputs 3 | 40 | -1 |
302 | | Inputs 4 | 21 | -1 |
303 |
304 |
305 |
306 | ## 3. Residual Networks
307 |
308 | In theory a deeper network has more capacity and better accuracy, but in practice simply stacking layers does not give better results; it can instead cause overfitting or network degradation. The ResNet authors attributed this to vanishing/exploding gradients: when the network is too deep, gradients can no longer propagate through it. The figure below shows a 56-layer convolutional network performing worse than a 34-layer one both during training and at inference time.
309 |
310 | *(figure: degradation of the deeper plain network)*
311 |
312 | To solve this, the authors proposed the shortcut structure, which simply adds the block's input to its output. Gradients can then flow straight through many layers, which largely avoids the vanishing/exploding problem. Networks built by stacking such shortcut structures are called Residual Networks (ResNet). Their biggest advantage is that there is almost no limit to how many blocks can be stacked: deeper is generally more accurate, and the original paper trains networks of over 1000 layers without degradation. One shortcut structure is usually called a block.
313 |
314 | *(figure: a residual block with its shortcut connection)*
315 |
316 | Although the original paper implements several block variants, the ResNet used for board games is usually simpler: every layer uses a fixed kernel size of 3, each block contains two convolution layers and two normalization layers, and no max pooling is used.
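
A minimal sketch of such a block in PyTorch, assuming batch normalization as the normalization layer; the class name and layer ordering are illustrative and may differ from network.py:

    import torch.nn as nn
    import torch.nn.functional as F

    class ResBlock(nn.Module):
        # Two 3x3 convolutions, two batch norms, and a shortcut that adds
        # the block input to the block output.
        def __init__(self, channels):
            super().__init__()
            self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
            self.bn1 = nn.BatchNorm2d(channels)
            self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
            self.bn2 = nn.BatchNorm2d(channels)

        def forward(self, x):
            out = F.relu(self.bn1(self.conv1(x)))
            out = self.bn2(self.conv2(out))
            return F.relu(out + x)   # the shortcut connection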
317 |
318 | ## 4. Monte Carlo Tree Search
319 |
320 | Monte Carlo tree search (MCTS) is a heuristic search algorithm first proposed in 2006 by Rémi Coulom, the author of Crazy Stone, in his paper [Efficient Selectivity and Backup Operators in Monte-Carlo Tree Search](https://hal.inria.fr/inria-00116992/document), where he combined [negamax](http://rportal.lib.ntnu.edu.tw/bitstream/20.500.12235/106643/4/n069347008204.pdf) with the Monte Carlo method. Its key breakthrough is that, unlike earlier Go programs, it needs only a small amount of Go knowledge to implement. Since then MCTS has gone through many revisions and gained further heuristics, such as the classic UCT (Upper Confidence bounds applied to Trees) and RAVE, and the [PUCT](https://www.chessprogramming.org/Christopher_D._Rosin#PUCT) ('Predictor' + UCT) implemented in this project.
321 |
322 | ### The Monte Carlo Method
323 |
324 | The core idea of the Monte Carlo method is very simple: to estimate the probability of an event, just simulate it many times and use the observed frequency. Why use such a laborious and inexact method? Take throwing divination blocks (jiaobei) as an example. When life gets difficult, some people throw a pair of blocks to consult the gods, but has anyone wondered what the probability of a "sacred answer" (one block up, one down) really is? Under the classical assumption that each block lands on either side with probability one half, the sacred answer would come out at one half (counting the "laughing answer" too), but since the two faces of a block are not symmetric, the true probability is certainly not one half and cannot be obtained by calculation alone. This is where the Monte Carlo method shows its power: simply throw the blocks ten thousand times and count the sacred answers. Likewise in Go, the game is so complex that early programs could not compute an accurate win rate directly, but by randomly simulating the same position many times, the Monte Carlo method yields a comparatively reliable estimate.
325 |
326 | ### Basic UCT (Upper Confidence bounds applied to Trees)
327 |
328 | *(figure: the four steps of one MCTS iteration)*
329 |
330 |
331 |
332 | Each iteration of traditional MCTS goes through four basic steps:
333 |
334 | 1. Selection:
335 | Starting from the root, repeatedly pick the child with the largest UCT value until a leaf node (a terminal node that has no value yet) is reached. The UCT formula (reconstructed from the variable descriptions below) is
336 |
337 |     UCT_i = w_i / n_i + c * sqrt( ln(N) / n_i )
338 |
339 | where w_i is the node's accumulated score (or number of wins) for its own side, n_i is the node's visit count, c is the exploration parameter, and N is the parent node's visit count.
340 |
349 | 2. Expansion:
350 | Grow new child nodes under the selected leaf; each new child represents a legal move from the leaf position.
351 |
352 | 3. Simulation:
353 | Use the Monte Carlo method to estimate the score (or win rate) of the leaf selected in step 1. Usually only one playout is run, so the returned value is simply a win or a loss.
354 |
355 | 4. Backpropagation:
356 | Walk back along the selected path, updating each node's score (or win rate) from the point of view of that node's colour (own side or opponent) and incrementing each node's visit count by one.
357 |
358 | If you look closely, my description of the four steps differs slightly from the process shown in the figure, but it is only a difference in presentation; the computed result is the same.
359 |
360 | ### The PUCT improvement
361 |
362 | *(figure: the AlphaGo Zero search loop)*
363 |
364 |
365 |
366 | AlphaGo Zero (2017) proposed an improved MCTS with two main differences: a policy term is added on top of UCT, and the random simulation step is removed, so only three steps are repeated.
367 |
368 | 1. Selection:
369 | Starting from the root, pick the child with the largest PUCT value until a leaf node (which has no value yet) is reached. The PUCT formula (reconstructed from the variable descriptions below) is
370 |
371 |     PUCT_i = Q_i + c * P_i * sqrt(N) / (1 + n_i)
372 |
373 | where Q_i is the node's accumulated value for its own side (i.e. its win rate), n_i is the node's visit count, c is the exploration parameter, N is the parent node's visit count, and P_i is the node's policy value (the probability the parent assigns to the move leading to this node).
374 |
385 | 2. Expansion:
386 | Grow new child nodes under the selected leaf; each child represents a legal move, and the network's policy value is stored in the new children.
387 |
388 | 3. Backpropagation:
389 | For every node along the selected path, update it with the network's value (i.e. win rate) from the point of view of that node's colour and increment its visit count by one.
390 |
391 | The AlphaGo Zero version of MCTS is very lean; with the simulation step removed it arguably no longer has anything to do with the Monte Carlo method, and in principle the algorithm contains no randomness. Since this project implements this version of MCTS, it always produces the same result when given the same position and the same amount of computation. A minimal sketch of the selection rule is given below.
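
A minimal sketch of PUCT selection over a list of child statistics; the dictionary keys, the c_puct value, and the helper name are illustrative assumptions and not the interface of mcts.py:

    import math

    def puct_select(children, c_puct=1.5):
        # Each child carries a prior P, a visit count N, and an accumulated value W.
        parent_visits = sum(ch["N"] for ch in children)

        def puct(ch):
            q = ch["W"] / ch["N"] if ch["N"] > 0 else 0.0                    # mean value (win rate)
            u = c_puct * ch["P"] * math.sqrt(parent_visits) / (1 + ch["N"])  # exploration term
            return q + u

        return max(children, key=puct)

    children = [{"P": 0.5, "N": 10, "W": 6.0},
                {"P": 0.3, "N": 2, "W": 1.5},
                {"P": 0.2, "N": 0, "W": 0.0}]
    print(puct_select(children))   # the child to descend into next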
392 |
393 | ### Choosing the move
394 |
395 | After the final n rounds of MCTS, the child of the root with the most visits is output as the best move.
396 |
397 | ## 5. Upper Confidence Bounds
398 |
399 | We used the UCT formula above without explaining where it comes from. UCT is simply UCB applied to tree search; here is how the formula works.
400 |
401 |     UCB_i = mean_i + c * sqrt( ln(N_total) / n_i )
402 |
403 | Consider an experiment: a slot machine is pulled three times and pays out 5, 10, and 7.5 points. How should we estimate its future payout? The simplest answer is the arithmetic mean, but then we do not know how far that estimate may be from the true mean; put simply, we do not know the risk of taking 7.5 as the average payout, since the true mean could be 1 or it could be 100. The UCB formula estimates an upper bound on where the true mean payout may lie within some confidence interval. If you know some statistics this may look strange: the formula never computes a variance, so how can it produce a confidence bound? You are right, the UCB formula is an empirically constructed rule rather than a rigorous statistical bound; treat it as a black box whose output we take to be the upper confidence bound. In practice it works very well.
404 |
405 | Now consider another problem: given several slot machines with different payouts, how do we collect the most reward in a limited number of pulls? The mathematicians' answer is to compute each machine's upper confidence bound, its UCB value, before every pull and then pull the machine with the largest value. The same idea carries back to Monte Carlo tree search: the children of a node are like several slot machines, and applying the UCB formula before each selection finds the currently best path, as in the sketch below.
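
A minimal sketch of UCB1 selection over a few bandit arms; the exploration constant and the example numbers are illustrative assumptions:

    import math

    def ucb1(mean_reward, pulls, total_pulls, c=1.41):
        # exploitation term + exploration term
        return mean_reward + c * math.sqrt(math.log(total_pulls) / pulls)

    # three slot machines: (mean payout so far, number of pulls)
    arms = [(7.5, 3), (6.0, 10), (9.0, 1)]
    total = sum(n for _, n in arms)
    scores = [ucb1(m, n, total) for m, n in arms]
    print(scores.index(max(scores)))   # index of the arm to pull next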
406 |
407 | ## 6. The Tromp-Taylor Rules
408 |
409 | Go rules largely fall into two families, Japanese rules and Chinese rules. Japanese rules count only the surrounded empty points, while Chinese rules count both stones and territory. Either way, problems arise in practice, especially with cyclic kos and complicated multi-ko fights, so neither family is well suited to computer Go. The Tromp-Taylor rules introduced here define every possible board state, have been praised as the most concise Go rules in the world, and are the rules most commonly used in computer Go today.
410 |
411 | The Tromp-Taylor rules are a variant of Chinese rules, also counting both stones and territory. Their main features are:
412 | 1. There is no notion of dead stones: every stone on the board counts as its owner's area.
413 | 2. A region surrounded by black and containing no white stones counts as black territory.
414 | 3. A region touching both black and white stones counts for neither side.
415 | 4. Repeating any earlier whole-board position is forbidden (positional superko); this is what defines ko.
416 | 5. Suicide is allowed.
417 |
418 | Because there are no dead stones and no repeated positions, cyclic kos cannot arise and complicated ko fights are all well defined, so nothing ever has to be "settled over the board". The figure below shows a scoring example: the marked white stones might be judged dead under other rule sets, but under Tromp-Taylor they are alive and count as white's area; the five-point region around them touches both black and white and so counts for neither side (dame), while the remaining regions are each clearly surrounded by one side and count for that side.
419 |
420 |
421 | *(figure: Tromp-Taylor scoring example)*
422 |
423 |
424 | dlgo implements essentially all of the Tromp-Taylor rules; the one difference is that dlgo forbids suicide.
425 |
426 | ## 7. Miscellaneous
427 |
428 | * Why encode the colour to move?
429 |
430 | Because komi cannot be perceived from the stones on the board alone.
431 |
432 |
433 | * Why does dlgo use 18 input planes instead of AlphaGo Zero's 17?
434 |
435 | In the original AlphaGo Zero the colour to move is encoded as a single plane, 1 for black and 0 for white; dlgo uses two planes instead, setting the first plane to 1 when black is to move and the second plane to 1 when white is to move. According to the Leela Zero author, AlphaGo Zero's encoding makes black and white asymmetric and causes black to "see" the board edge more easily.
436 |
437 |
438 | * Why is AlphaGo's fast rollout policy network not implemented?
439 |
440 | The fast rollout network mainly helped AlphaGo estimate the value of the current position and extract extra information such as the expected margin and territory. However, a rollout network brings little benefit in Python, and modern Go networks are very large and very accurate, so rollouts compare poorly as a value estimator; there is no longer any need for one.
441 |
442 |
443 | * Why were AlphaGo's value network and policy network separate?
444 |
445 | I suspect this is mostly historical: AlphaGo implemented the policy network first, the value network later, and only merged them afterwards; the first AlphaGo paper appeared before the merge. In fact the merged network is faster at inference and more accurate, so keeping them separate has no advantage.
446 |
447 |
448 | * Why does AlphaGo Zero keep giving up points when it is ahead?
449 |
450 | This is usually blamed on Monte Carlo tree search, which is correct but not the whole story; the reinforcement-learning setup also plays a part. During self-play nothing forces the agent to play the moves that win by the largest margin, so the self-play records become lower quality once one side is ahead. If you train dlgo on strong human games, you will find that the combination of MCTS and the network shows no obvious slack play.
451 |
452 |
453 | * Is it advisable to implement reinforcement learning?
454 |
455 | Since 2023, reinforcement learning for Go implemented purely in Python has been shown to be feasible; see [TamaGo](https://github.com/kobanium/TamaGo).
456 |
457 |
--------------------------------------------------------------------------------
/docs/PyGoEngine.md:
--------------------------------------------------------------------------------
1 | # List of Python Go Engines
2 |
3 | If your Go engine is written in Python and supports the GTP protocol or has a built-in GUI, feel free to add it here.
4 |
5 | ## Engines
6 |
7 | * [Boke Go](https://github.com/meiji163/bokego)
8 |
9 | A 9x9 Go engine using a policy network and a value network combined with Monte Carlo tree search.
10 |
11 |
12 | * [Michi](https://github.com/pasky/michi)
13 |
14 | Uses a traditional pattern system combined with Monte Carlo tree search and supports arbitrary board sizes. The pattern files can be downloaded from the [Pachi 10.00 Release](https://github.com/pasky/pachi/releases/tag/pachi-10.00-satsugen); any Pachi release archive contains the two files, patterns.spat and patterns.prob.
15 |
16 |
17 | * [PikachuGo](https://github.com/wsdd2/PikachuGo)
18 |
19 | A 19x19 Go engine implemented by students of Shanghai Jian Qiao University, using a policy network and a value network. There is a [demo video](https://www.bilibili.com/video/BV1wb41177ah).
20 |
21 |
22 | * [Irene](https://github.com/GWDx/Irene)
23 |
24 | A 19x19 Go engine that plays mainly with a policy network.
25 |
26 |
27 | * [AlphaGOZero-python-tensorflow](https://github.com/yhyu13/AlphaGOZero-python-tensorflow)
28 |
29 | A 19x19 Go engine using a policy network and a value network, with both supervised learning and reinforcement learning.
30 |
31 |
32 | * [ymgaq/Pyaq](https://github.com/ymgaq/Pyaq)
33 |
34 | A 9x9 Go engine for demonstration and teaching.
35 |
36 | * [kobanium/TamaGo](https://github.com/kobanium/TamaGo)
37 |
38 | A 9x9 Go engine trained with Gumbel AlphaZero reinforcement learning; relatively strong.
39 |
--------------------------------------------------------------------------------
/docs/SmartGameFormat.md:
--------------------------------------------------------------------------------
1 | # Smart Game Format
2 |
3 | ## 1. History
4 | The Smart Game Format originated from Smart Go; it was developed first by its original author Anders Kierulf and then carried on by Martin Mueller and Arno Hollosi, which is why early versions were called the Smart Go Format. Today SGF is the default format in which Go software stores game records, and not only for Go: other games such as Othello also mostly use SGF.
5 |
6 | ## 2. Basic concepts
7 | An SGF file records a tree. Every node is separated by ```;```, and every branch is delimited by ```(``` and ```)```. For example, the tree
8 |
9 | |a
10 | |b
11 | f/ \c
12 | g/ \d
13 | \e
14 |
15 | is written in SGF as
16 |
17 | (;a;b(;f;g)(;c;d;e))
18 |
19 | ## 3. Properties
20 |
21 | Every node carries property data, written as
22 |
23 | B[aa]
24 |
25 | Here the property is ```B``` and the value in square brackets, ```aa```, is its value. In SGF this looks like
26 |
27 | (;B[aa];W[ab](;B[ac];W[ad])(;B[bc];W[bd];B[bd]))
28 |
29 | A node may also contain several properties:
30 |
31 | (;B[aa]C[Hello])
32 |
33 | Some common properties are listed below; for more property types see the [SGF Wiki](https://en.wikipedia.org/wiki/Smart_Game_Format).
34 |
35 | | Property | Meaning |
36 | | :------------: | :---------------: |
37 | | GM | Game type, 1 for Go; must be in the root node |
38 | | FF | File format version, currently 4; must be in the root node |
39 | | RU | Rules in use; must be in the root node |
40 | | RE | Game result; must be in the root node |
41 | | KM | Komi; must be in the root node |
42 | | SZ | Board size; must be in the root node |
43 | | AP | Application used; must be in the root node |
44 | | HA | Number of handicap stones; must be in the root node |
45 | | AB | Black stones placed on the initial board; must be in the root node |
46 | | AW | White stones placed on the initial board; must be in the root node |
47 | | PB | Black player's name; must be in the root node |
48 | | PW | White player's name; must be in the root node |
49 | | DT | Date; must be in the root node |
50 | | B | Coordinate of a black move |
51 | | W | Coordinate of a white move |
52 | | C | Comment |
53 |
54 |
55 |
56 | ## 4. Example
57 | Below is a sample SGF file; it can be opened with Sabaki or any other software that supports SGF.
58 |
59 | (
60 | ;GM[1]FF[4]CA[UTF-8]AP[Sabaki:0.43.3]KM[7.5]SZ[19]DT[2021-10-31]HA[2]AB[dp][pd]PB[Black Player]PW[White Player]
61 | ;W[qp];B[dd];W[fq];B[cn];W[kq]
62 | (
63 | ;B[qf];W[fc];B[df];W[jd];B[lc];W[pk];B[op];W[pn];B[qq];W[rq];B[pq];W[rr];B[mq];W[ko]
64 | )
65 | (
66 | ;B[df];W[qf];B[nc];W[pj];B[op]
67 | )
68 | )
69 |
70 | Coordinates are written with the letters ```a~z```; for example ```B[ab]``` means black plays at (1,2). On boards of size 19 or smaller a pass can be written as ```[tt]``` or ```[]```, but on boards larger than 19, ```[tt]``` means the point (20,20) and only ```[]``` denotes a pass. There is no standard notation for resignation, which is usually simply not recorded in the game record. A small conversion sketch follows.
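
A minimal helper illustrating this coordinate convention; the function name is an illustrative assumption and this is not the parser in sgf.py:

    def sgf_coord_to_xy(coord, board_size=19):
        # "" always means a pass; "tt" also means a pass on boards up to 19x19.
        if coord == "" or (board_size <= 19 and coord == "tt"):
            return None
        x = ord(coord[0]) - ord('a') + 1
        y = ord(coord[1]) - ord('a') + 1
        return (x, y)

    print(sgf_coord_to_xy("ab"))   # (1, 2)
    print(sgf_coord_to_xy("tt"))   # None (pass on a 19x19 board)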
71 |
--------------------------------------------------------------------------------
/docs/Structure.md:
--------------------------------------------------------------------------------
1 | # Structure
2 |
3 | #### dlgo.py
4 |
5 | The entry point for running dlgo.
6 |
7 | #### train.py
8 |
9 | Implements the whole training pipeline; also the entry point for training.
10 |
11 | #### board.py
12 |
13 | Implements the board algorithms, including the rules and the neural-network input features.
14 |
15 | #### config.py
16 |
17 | Settings for the board size and the network architecture.
18 |
19 | #### gtp.py
20 |
21 | Implements the GTP protocol interface.
22 |
23 | #### gui.py
24 |
25 | Implements the graphical interface.
26 |
27 | #### mcts.py
28 |
29 | Implements Monte Carlo tree search.
30 |
31 | #### network.py
32 |
33 | Implements the neural network architecture.
34 |
35 | #### sgf.py
36 |
37 | Parser for SGF files.
38 |
39 | #### time_control.py
40 |
41 | Time controller for the tree search, following the standard GTP time-control format.
42 |
--------------------------------------------------------------------------------
/docs/Training.md:
--------------------------------------------------------------------------------
1 | # Training
2 |
3 | ## Training information
4 |
5 | Enter the command below as in the training example and the terminal will print a stream of messages that help you follow the progress of training. In the first phase the program loads and parses the SGF files, generates the training data, and stores it in ```data-cache```; when ```parsed 100.00% games``` appears, all game records have been processed. The second phase is the actual network training, where ```rate``` is the number of training steps per second and ```estimate``` is the estimated number of seconds remaining.
6 |
7 | $ python3 train.py --dir sgf-directory-name --steps 128000 --batch-size 512 --learning-rate 0.001
8 | imported 34572 SGF files
9 | parsed 1.00% games
10 | parsed 2.00% games
11 | parsed 2.99% games
12 | parsed 3.99% games
13 | parsed 4.99% games
14 | parsed 5.99% games
15 | parsed 6.99% games
16 | parsed 7.98% games
17 | parsed 8.98% games
18 | parsed 9.98% games
19 | parsed 10.98% games
20 | parsed 11.98% games
21 | parsed 12.97% games
22 | .
23 | .
24 | .
25 | parsed 92.81% games
26 | parsed 93.80% games
27 | parsed 94.80% games
28 | parsed 95.80% games
29 | parsed 96.80% games
30 | parsed 97.80% games
31 | parsed 98.79% games
32 | parsed 99.79% games
33 | parsed 100.00% games
34 | [2022-3-18 22:06:42] steps: 1000/128000, 0.78% -> policy loss: 2.5627, value loss: 0.9143 | rate: 58.62(steps/sec), estimate: 2166(sec)
35 | [2022-3-18 22:07:00] steps: 2000/128000, 1.56% -> policy loss: 1.9588, value loss: 0.8503 | rate: 55.47(steps/sec), estimate: 2271(sec)
36 | [2022-3-18 22:07:18] steps: 3000/128000, 2.34% -> policy loss: 1.8491, value loss: 0.8228 | rate: 57.48(steps/sec), estimate: 2174(sec)
37 | [2022-3-18 22:07:36] steps: 4000/128000, 3.12% -> policy loss: 1.8122, value loss: 0.8065 | rate: 55.31(steps/sec), estimate: 2242(sec)
38 | [2022-3-18 22:07:54] steps: 5000/128000, 3.91% -> policy loss: 1.7586, value loss: 0.7864 | rate: 56.84(steps/sec), estimate: 2164(sec)
39 | [2022-3-18 22:08:11] steps: 6000/128000, 4.69% -> policy loss: 1.7399, value loss: 0.7695 | rate: 58.57(steps/sec), estimate: 2083(sec)
40 | [2022-3-18 22:08:28] steps: 7000/128000, 5.47% -> policy loss: 1.7173, value loss: 0.7587 | rate: 57.62(steps/sec), estimate: 2100(sec)
41 | [2022-3-18 22:08:46] steps: 8000/128000, 6.25% -> policy loss: 1.6980, value loss: 0.7596 | rate: 55.39(steps/sec), estimate: 2166(sec)
42 | [2022-3-18 22:09:04] steps: 9000/128000, 7.03% -> policy loss: 1.6809, value loss: 0.7423 | rate: 57.01(steps/sec), estimate: 2087(sec)
43 | [2022-3-18 22:09:21] steps: 10000/128000, 7.81% -> policy loss: 1.6723, value loss: 0.7393 | rate: 55.87(steps/sec), estimate: 2112(sec)
44 | [2022-3-18 22:09:38] steps: 11000/128000, 8.59% -> policy loss: 1.6539, value loss: 0.7287 | rate: 59.65(steps/sec), estimate: 1961(sec)
45 | [2022-3-18 22:09:56] steps: 12000/128000, 9.38% -> policy loss: 1.6534, value loss: 0.7135 | rate: 56.55(steps/sec), estimate: 2051(sec)
46 | [2022-3-18 22:10:12] steps: 13000/128000, 10.16% -> policy loss: 1.6464, value loss: 0.7167 | rate: 60.91(steps/sec), estimate: 1888(sec)
47 | [2022-3-18 22:10:30] steps: 14000/128000, 10.94% -> policy loss: 1.6329, value loss: 0.7065 | rate: 57.94(steps/sec), estimate: 1967(sec)
48 | [2022-3-18 22:10:47] steps: 15000/128000, 11.72% -> policy loss: 1.6203, value loss: 0.7064 | rate: 57.48(steps/sec), estimate: 1965(sec)
49 | [2022-3-18 22:11:04] steps: 16000/128000, 12.50% -> policy loss: 1.6204, value loss: 0.7007 | rate: 57.85(steps/sec), estimate: 1936(sec)
50 | [2022-3-18 22:11:22] steps: 17000/128000, 13.28% -> policy loss: 1.6226, value loss: 0.6969 | rate: 56.33(steps/sec), estimate: 1970(sec)
51 | [2022-3-18 22:11:41] steps: 18000/128000, 14.06% -> policy loss: 1.6208, value loss: 0.7004 | rate: 53.50(steps/sec), estimate: 2056(sec)
52 | [2022-3-18 22:11:58] steps: 19000/128000, 14.84% -> policy loss: 1.5990, value loss: 0.6839 | rate: 57.20(steps/sec), estimate: 1905(sec)
53 | [2022-3-18 22:12:16] steps: 20000/128000, 15.62% -> policy loss: 1.6045, value loss: 0.6868 | rate: 54.86(steps/sec), estimate: 1968(sec)
54 | [2022-3-18 22:12:33] steps: 21000/128000, 16.41% -> policy loss: 1.6040, value loss: 0.6831 | rate: 59.90(steps/sec), estimate: 1786(sec)
55 | [2022-3-18 22:12:51] steps: 22000/128000, 17.19% -> policy loss: 1.6024, value loss: 0.6887 | rate: 57.02(steps/sec), estimate: 1859(sec)
56 | [2022-3-18 22:13:08] steps: 23000/128000, 17.97% -> policy loss: 1.5903, value loss: 0.6652 | rate: 58.84(steps/sec), estimate: 1784(sec)
57 | [2022-3-18 22:13:26] steps: 24000/128000, 18.75% -> policy loss: 1.5924, value loss: 0.6760 | rate: 54.78(steps/sec), estimate: 1898(sec)
58 | [2022-3-18 22:13:42] steps: 25000/128000, 19.53% -> policy loss: 1.5900, value loss: 0.6784 | rate: 60.90(steps/sec), estimate: 1691(sec)
59 | [2022-3-18 22:13:59] steps: 26000/128000, 20.31% -> policy loss: 1.5901, value loss: 0.6820 | rate: 59.29(steps/sec), estimate: 1720(sec)
60 | [2022-3-18 22:14:16] steps: 27000/128000, 21.09% -> policy loss: 1.5855, value loss: 0.6672 | rate: 58.16(steps/sec), estimate: 1736(sec)
61 | [2022-3-18 22:14:34] steps: 28000/128000, 21.88% -> policy loss: 1.5843, value loss: 0.6664 | rate: 56.64(steps/sec), estimate: 1765(sec)
62 | [2022-3-18 22:14:52] steps: 29000/128000, 22.66% -> policy loss: 1.5688, value loss: 0.6495 | rate: 54.80(steps/sec), estimate: 1806(sec)
63 | [2022-3-18 22:15:11] steps: 30000/128000, 23.44% -> policy loss: 1.5838, value loss: 0.6698 | rate: 53.59(steps/sec), estimate: 1828(sec)
64 | [2022-3-18 22:15:29] steps: 31000/128000, 24.22% -> policy loss: 1.5638, value loss: 0.6518 | rate: 56.76(steps/sec), estimate: 1708(sec)
65 | [2022-3-18 22:15:46] steps: 32000/128000, 25.00% -> policy loss: 1.5773, value loss: 0.6582 | rate: 56.63(steps/sec), estimate: 1695(sec)
66 | [2022-3-18 22:16:04] steps: 33000/128000, 25.78% -> policy loss: 1.5702, value loss: 0.6650 | rate: 57.30(steps/sec), estimate: 1658(sec)
67 | [2022-3-18 22:16:22] steps: 34000/128000, 26.56% -> policy loss: 1.5723, value loss: 0.6501 | rate: 53.49(steps/sec), estimate: 1757(sec)
68 | [2022-3-18 22:16:39] steps: 35000/128000, 27.34% -> policy loss: 1.5739, value loss: 0.6458 | rate: 58.74(steps/sec), estimate: 1583(sec)
69 | [2022-3-18 22:16:56] steps: 36000/128000, 28.12% -> policy loss: 1.5616, value loss: 0.6486 | rate: 61.29(steps/sec), estimate: 1501(sec)
70 | [2022-3-18 22:17:14] steps: 37000/128000, 28.91% -> policy loss: 1.5699, value loss: 0.6452 | rate: 54.92(steps/sec), estimate: 1657(sec)
71 | [2022-3-18 22:17:32] steps: 38000/128000, 29.69% -> policy loss: 1.5610, value loss: 0.6483 | rate: 55.38(steps/sec), estimate: 1625(sec)
72 | [2022-3-18 22:17:49] steps: 39000/128000, 30.47% -> policy loss: 1.5623, value loss: 0.6464 | rate: 58.77(steps/sec), estimate: 1514(sec)
73 | [2022-3-18 22:18:07] steps: 40000/128000, 31.25% -> policy loss: 1.5544, value loss: 0.6313 | rate: 55.43(steps/sec), estimate: 1587(sec)
74 | [2022-3-18 22:18:24] steps: 41000/128000, 32.03% -> policy loss: 1.5561, value loss: 0.6344 | rate: 60.28(steps/sec), estimate: 1443(sec)
75 | [2022-3-18 22:18:41] steps: 42000/128000, 32.81% -> policy loss: 1.5663, value loss: 0.6550 | rate: 58.49(steps/sec), estimate: 1470(sec)
76 | [2022-3-18 22:18:59] steps: 43000/128000, 33.59% -> policy loss: 1.5520, value loss: 0.6315 | rate: 54.88(steps/sec), estimate: 1548(sec)
77 | [2022-3-18 22:19:17] steps: 44000/128000, 34.38% -> policy loss: 1.5516, value loss: 0.6166 | rate: 54.13(steps/sec), estimate: 1551(sec)
78 | [2022-3-18 22:19:35] steps: 45000/128000, 35.16% -> policy loss: 1.5543, value loss: 0.6227 | rate: 57.33(steps/sec), estimate: 1447(sec)
79 | [2022-3-18 22:19:53] steps: 46000/128000, 35.94% -> policy loss: 1.5484, value loss: 0.6334 | rate: 56.61(steps/sec), estimate: 1448(sec)
80 | [2022-3-18 22:20:10] steps: 47000/128000, 36.72% -> policy loss: 1.5526, value loss: 0.6216 | rate: 58.87(steps/sec), estimate: 1375(sec)
81 | [2022-3-18 22:20:28] steps: 48000/128000, 37.50% -> policy loss: 1.5540, value loss: 0.6420 | rate: 53.49(steps/sec), estimate: 1495(sec)
82 | [2022-3-18 22:20:46] steps: 49000/128000, 38.28% -> policy loss: 1.5425, value loss: 0.6110 | rate: 55.50(steps/sec), estimate: 1423(sec)
83 | [2022-3-18 22:21:05] steps: 50000/128000, 39.06% -> policy loss: 1.5494, value loss: 0.6300 | rate: 52.44(steps/sec), estimate: 1487(sec)
84 | [2022-3-18 22:21:23] steps: 51000/128000, 39.84% -> policy loss: 1.5448, value loss: 0.6226 | rate: 56.87(steps/sec), estimate: 1354(sec)
85 | [2022-3-18 22:21:41] steps: 52000/128000, 40.62% -> policy loss: 1.5406, value loss: 0.6203 | rate: 54.77(steps/sec), estimate: 1387(sec)
86 | [2022-3-18 22:21:59] steps: 53000/128000, 41.41% -> policy loss: 1.5428, value loss: 0.6134 | rate: 56.75(steps/sec), estimate: 1321(sec)
87 | [2022-3-18 22:22:17] steps: 54000/128000, 42.19% -> policy loss: 1.5331, value loss: 0.6079 | rate: 54.10(steps/sec), estimate: 1367(sec)
88 | [2022-3-18 22:22:34] steps: 55000/128000, 42.97% -> policy loss: 1.5387, value loss: 0.6132 | rate: 58.18(steps/sec), estimate: 1254(sec)
89 | [2022-3-18 22:22:52] steps: 56000/128000, 43.75% -> policy loss: 1.5403, value loss: 0.6069 | rate: 56.04(steps/sec), estimate: 1284(sec)
90 | [2022-3-18 22:23:11] steps: 57000/128000, 44.53% -> policy loss: 1.5446, value loss: 0.6307 | rate: 54.43(steps/sec), estimate: 1304(sec)
91 | [2022-3-18 22:23:29] steps: 58000/128000, 45.31% -> policy loss: 1.5368, value loss: 0.6115 | rate: 54.64(steps/sec), estimate: 1281(sec)
92 | [2022-3-18 22:23:46] steps: 59000/128000, 46.09% -> policy loss: 1.5420, value loss: 0.6166 | rate: 59.38(steps/sec), estimate: 1162(sec)
93 | [2022-3-18 22:24:03] steps: 60000/128000, 46.88% -> policy loss: 1.5351, value loss: 0.6047 | rate: 56.63(steps/sec), estimate: 1200(sec)
94 | [2022-3-18 22:24:20] steps: 61000/128000, 47.66% -> policy loss: 1.5338, value loss: 0.6038 | rate: 60.84(steps/sec), estimate: 1101(sec)
95 | [2022-3-18 22:24:37] steps: 62000/128000, 48.44% -> policy loss: 1.5447, value loss: 0.6215 | rate: 57.14(steps/sec), estimate: 1155(sec)
96 | [2022-3-18 22:24:55] steps: 63000/128000, 49.22% -> policy loss: 1.5469, value loss: 0.6223 | rate: 56.20(steps/sec), estimate: 1156(sec)
97 | [2022-3-18 22:25:14] steps: 64000/128000, 50.00% -> policy loss: 1.5304, value loss: 0.5915 | rate: 53.05(steps/sec), estimate: 1206(sec)
98 | [2022-3-18 22:25:32] steps: 65000/128000, 50.78% -> policy loss: 1.5378, value loss: 0.6182 | rate: 54.35(steps/sec), estimate: 1159(sec)
99 | [2022-3-18 22:25:53] steps: 66000/128000, 51.56% -> policy loss: 1.5383, value loss: 0.6068 | rate: 49.76(steps/sec), estimate: 1246(sec)
100 | [2022-3-18 22:26:10] steps: 67000/128000, 52.34% -> policy loss: 1.5390, value loss: 0.6032 | rate: 57.46(steps/sec), estimate: 1061(sec)
101 | [2022-3-18 22:26:28] steps: 68000/128000, 53.12% -> policy loss: 1.5259, value loss: 0.6061 | rate: 55.30(steps/sec), estimate: 1084(sec)
102 | [2022-3-18 22:26:46] steps: 69000/128000, 53.91% -> policy loss: 1.5372, value loss: 0.6100 | rate: 56.86(steps/sec), estimate: 1037(sec)
103 | [2022-3-18 22:27:04] steps: 70000/128000, 54.69% -> policy loss: 1.5356, value loss: 0.6017 | rate: 55.32(steps/sec), estimate: 1048(sec)
104 | [2022-3-18 22:27:21] steps: 71000/128000, 55.47% -> policy loss: 1.5336, value loss: 0.6019 | rate: 59.49(steps/sec), estimate: 958(sec)
105 | [2022-3-18 22:27:37] steps: 72000/128000, 56.25% -> policy loss: 1.5306, value loss: 0.6040 | rate: 59.23(steps/sec), estimate: 945(sec)
106 | [2022-3-18 22:27:55] steps: 73000/128000, 57.03% -> policy loss: 1.5225, value loss: 0.5848 | rate: 58.19(steps/sec), estimate: 945(sec)
107 | [2022-3-18 22:28:11] steps: 74000/128000, 57.81% -> policy loss: 1.5334, value loss: 0.6007 | rate: 60.64(steps/sec), estimate: 890(sec)
108 | [2022-3-18 22:28:29] steps: 75000/128000, 58.59% -> policy loss: 1.5237, value loss: 0.5774 | rate: 56.89(steps/sec), estimate: 931(sec)
109 | [2022-3-18 22:28:47] steps: 76000/128000, 59.38% -> policy loss: 1.5283, value loss: 0.5897 | rate: 55.31(steps/sec), estimate: 940(sec)
110 | [2022-3-18 22:29:04] steps: 77000/128000, 60.16% -> policy loss: 1.5210, value loss: 0.5764 | rate: 56.74(steps/sec), estimate: 898(sec)
111 | [2022-3-18 22:29:22] steps: 78000/128000, 60.94% -> policy loss: 1.5246, value loss: 0.5869 | rate: 55.91(steps/sec), estimate: 894(sec)
112 | [2022-3-18 22:29:40] steps: 79000/128000, 61.72% -> policy loss: 1.5331, value loss: 0.5964 | rate: 57.43(steps/sec), estimate: 853(sec)
113 | [2022-3-18 22:29:59] steps: 80000/128000, 62.50% -> policy loss: 1.5240, value loss: 0.5922 | rate: 52.22(steps/sec), estimate: 919(sec)
114 | [2022-3-18 22:30:17] steps: 81000/128000, 63.28% -> policy loss: 1.5215, value loss: 0.5796 | rate: 55.63(steps/sec), estimate: 844(sec)
115 | [2022-3-18 22:30:35] steps: 82000/128000, 64.06% -> policy loss: 1.5279, value loss: 0.5876 | rate: 55.34(steps/sec), estimate: 831(sec)
116 | [2022-3-18 22:30:53] steps: 83000/128000, 64.84% -> policy loss: 1.5252, value loss: 0.5928 | rate: 56.17(steps/sec), estimate: 801(sec)
117 | [2022-3-18 22:31:11] steps: 84000/128000, 65.62% -> policy loss: 1.5207, value loss: 0.5735 | rate: 55.96(steps/sec), estimate: 786(sec)
118 | [2022-3-18 22:31:27] steps: 85000/128000, 66.41% -> policy loss: 1.5259, value loss: 0.5954 | rate: 60.83(steps/sec), estimate: 706(sec)
119 | [2022-3-18 22:31:44] steps: 86000/128000, 67.19% -> policy loss: 1.5281, value loss: 0.5985 | rate: 58.60(steps/sec), estimate: 716(sec)
120 | [2022-3-18 22:32:01] steps: 87000/128000, 67.97% -> policy loss: 1.5248, value loss: 0.5826 | rate: 60.12(steps/sec), estimate: 681(sec)
121 | [2022-3-18 22:32:18] steps: 88000/128000, 68.75% -> policy loss: 1.5305, value loss: 0.5979 | rate: 59.27(steps/sec), estimate: 674(sec)
122 | [2022-3-18 22:32:34] steps: 89000/128000, 69.53% -> policy loss: 1.5267, value loss: 0.5845 | rate: 59.44(steps/sec), estimate: 656(sec)
123 | [2022-3-18 22:32:53] steps: 90000/128000, 70.31% -> policy loss: 1.5221, value loss: 0.5798 | rate: 54.01(steps/sec), estimate: 703(sec)
124 | [2022-3-18 22:33:10] steps: 91000/128000, 71.09% -> policy loss: 1.5087, value loss: 0.5720 | rate: 58.15(steps/sec), estimate: 636(sec)
125 | [2022-3-18 22:33:29] steps: 92000/128000, 71.88% -> policy loss: 1.5270, value loss: 0.5969 | rate: 52.00(steps/sec), estimate: 692(sec)
126 | [2022-3-18 22:33:48] steps: 93000/128000, 72.66% -> policy loss: 1.5140, value loss: 0.5797 | rate: 54.31(steps/sec), estimate: 644(sec)
127 | [2022-3-18 22:34:06] steps: 94000/128000, 73.44% -> policy loss: 1.5003, value loss: 0.5379 | rate: 55.37(steps/sec), estimate: 614(sec)
128 | [2022-3-18 22:34:23] steps: 95000/128000, 74.22% -> policy loss: 1.5165, value loss: 0.5736 | rate: 57.46(steps/sec), estimate: 574(sec)
129 | [2022-3-18 22:34:42] steps: 96000/128000, 75.00% -> policy loss: 1.5258, value loss: 0.5830 | rate: 53.00(steps/sec), estimate: 603(sec)
130 | [2022-3-18 22:34:59] steps: 97000/128000, 75.78% -> policy loss: 1.5174, value loss: 0.5783 | rate: 58.18(steps/sec), estimate: 532(sec)
131 | [2022-3-18 22:35:18] steps: 98000/128000, 76.56% -> policy loss: 1.5306, value loss: 0.6037 | rate: 51.91(steps/sec), estimate: 577(sec)
132 | [2022-3-18 22:35:35] steps: 99000/128000, 77.34% -> policy loss: 1.5095, value loss: 0.5782 | rate: 58.84(steps/sec), estimate: 492(sec)
133 | [2022-3-18 22:35:54] steps: 100000/128000, 78.12% -> policy loss: 1.5210, value loss: 0.5772 | rate: 55.29(steps/sec), estimate: 506(sec)
134 | [2022-3-18 22:36:12] steps: 101000/128000, 78.91% -> policy loss: 1.5176, value loss: 0.5769 | rate: 55.61(steps/sec), estimate: 485(sec)
135 | [2022-3-18 22:36:30] steps: 102000/128000, 79.69% -> policy loss: 1.5054, value loss: 0.5610 | rate: 53.00(steps/sec), estimate: 490(sec)
136 | [2022-3-18 22:36:47] steps: 103000/128000, 80.47% -> policy loss: 1.5197, value loss: 0.5877 | rate: 59.42(steps/sec), estimate: 420(sec)
137 | [2022-3-18 22:37:04] steps: 104000/128000, 81.25% -> policy loss: 1.5190, value loss: 0.5804 | rate: 58.66(steps/sec), estimate: 409(sec)
138 | [2022-3-18 22:37:22] steps: 105000/128000, 82.03% -> policy loss: 1.5240, value loss: 0.5798 | rate: 58.12(steps/sec), estimate: 395(sec)
139 | [2022-3-18 22:37:39] steps: 106000/128000, 82.81% -> policy loss: 1.5163, value loss: 0.5662 | rate: 56.55(steps/sec), estimate: 389(sec)
140 | [2022-3-18 22:37:57] steps: 107000/128000, 83.59% -> policy loss: 1.5183, value loss: 0.5747 | rate: 57.51(steps/sec), estimate: 365(sec)
141 | [2022-3-18 22:38:14] steps: 108000/128000, 84.38% -> policy loss: 1.5191, value loss: 0.5811 | rate: 56.56(steps/sec), estimate: 353(sec)
142 | [2022-3-18 22:38:31] steps: 109000/128000, 85.16% -> policy loss: 1.5131, value loss: 0.5695 | rate: 59.62(steps/sec), estimate: 318(sec)
143 | [2022-3-18 22:38:48] steps: 110000/128000, 85.94% -> policy loss: 1.5182, value loss: 0.5795 | rate: 58.70(steps/sec), estimate: 306(sec)
144 | [2022-3-18 22:39:06] steps: 111000/128000, 86.72% -> policy loss: 1.5225, value loss: 0.5811 | rate: 57.20(steps/sec), estimate: 297(sec)
145 | [2022-3-18 22:39:22] steps: 112000/128000, 87.50% -> policy loss: 1.5123, value loss: 0.5797 | rate: 59.83(steps/sec), estimate: 267(sec)
146 | [2022-3-18 22:39:39] steps: 113000/128000, 88.28% -> policy loss: 1.5155, value loss: 0.5838 | rate: 59.51(steps/sec), estimate: 252(sec)
147 | [2022-3-18 22:39:57] steps: 114000/128000, 89.06% -> policy loss: 1.5143, value loss: 0.5791 | rate: 57.29(steps/sec), estimate: 244(sec)
148 | [2022-3-18 22:40:15] steps: 115000/128000, 89.84% -> policy loss: 1.5059, value loss: 0.5605 | rate: 53.27(steps/sec), estimate: 244(sec)
149 | [2022-3-18 22:40:35] steps: 116000/128000, 90.62% -> policy loss: 1.5191, value loss: 0.5751 | rate: 49.64(steps/sec), estimate: 241(sec)
150 | [2022-3-18 22:40:52] steps: 117000/128000, 91.41% -> policy loss: 1.5173, value loss: 0.5786 | rate: 59.58(steps/sec), estimate: 184(sec)
151 | [2022-3-18 22:41:10] steps: 118000/128000, 92.19% -> policy loss: 1.5196, value loss: 0.5793 | rate: 57.20(steps/sec), estimate: 174(sec)
152 | [2022-3-18 22:41:27] steps: 119000/128000, 92.97% -> policy loss: 1.5108, value loss: 0.5671 | rate: 57.48(steps/sec), estimate: 156(sec)
153 | [2022-3-18 22:41:44] steps: 120000/128000, 93.75% -> policy loss: 1.5169, value loss: 0.5782 | rate: 57.85(steps/sec), estimate: 138(sec)
154 | [2022-3-18 22:42:01] steps: 121000/128000, 94.53% -> policy loss: 1.5127, value loss: 0.5712 | rate: 60.27(steps/sec), estimate: 116(sec)
155 | [2022-3-18 22:42:19] steps: 122000/128000, 95.31% -> policy loss: 1.5137, value loss: 0.5644 | rate: 54.07(steps/sec), estimate: 110(sec)
156 | [2022-3-18 22:42:36] steps: 123000/128000, 96.09% -> policy loss: 1.5297, value loss: 0.5880 | rate: 59.58(steps/sec), estimate: 83(sec)
157 | [2022-3-18 22:42:54] steps: 124000/128000, 96.88% -> policy loss: 1.5129, value loss: 0.5691 | rate: 55.93(steps/sec), estimate: 71(sec)
158 | [2022-3-18 22:43:12] steps: 125000/128000, 97.66% -> policy loss: 1.5155, value loss: 0.5816 | rate: 57.47(steps/sec), estimate: 52(sec)
159 | [2022-3-18 22:43:29] steps: 126000/128000, 98.44% -> policy loss: 1.5120, value loss: 0.5695 | rate: 56.65(steps/sec), estimate: 35(sec)
160 | [2022-3-18 22:43:46] steps: 127000/128000, 99.22% -> policy loss: 1.5142, value loss: 0.5791 | rate: 60.98(steps/sec), estimate: 16(sec)
161 | [2022-3-18 22:44:03] steps: 128000/128000, 100.00% -> policy loss: 1.5186, value loss: 0.5736 | rate: 56.67(steps/sec), estimate: 0(sec)
162 | Training is over.
163 |
164 |
165 | 完成訓練後,會出現以下圖片,以視覺化的方式顯示訓練過程
166 |
167 | 
168 |
169 | ## 指定訓練的 GPU
170 |
171 | 執行 ```python3``` 時加上環境變數 ```CUDA_VISIBLE_DEVICES``` ,可以指定要用哪個 GPU 訓練網路。GPU 的編號從 0 開始,如果有 4 個 GPU 則編號為 0 到 3;如果不指定,則默認使用 0 號 GPU。
172 |
173 | $ CUDA_VISIBLE_DEVICES=0 python3 train.py --dir sgf-directory-name --steps 128000 --batch-size 512 --learning-rate 0.01
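
例如,若機器上有多張 GPU 而想改用編號 1 的 GPU 訓練,可以這樣下指令(訓練參數僅沿用上面的示例):

    $ CUDA_VISIBLE_DEVICES=1 python3 train.py --dir sgf-directory-name --steps 128000 --batch-size 512 --learning-rate 0.01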
174 |
175 | ## 降低學習率
176 |
177 | 事實上,訓練圍棋的網路時,持續降低學習率是很重要的。使用相同的訓練資料,有降低學習率和沒有降低學習率的網路,其強度可以差距三段以上,這個差距在讓子棋中尤其明顯,未降低學習率的網路在前期通常無法有效辨識當前盤面的好壞。dlgo 提供重新載入網路的功能,可以直接從 workspace 中載入訓練到一半的網路。此時可以不用再輸入 ```--dir``` 參數,直接使用 data-cache 內的資料,避免重新解析棋譜,加速訓練流程
178 |
179 | $ python3 train.py --steps 128000 --batch-size 512 --learning-rate 0.001
180 |
181 | 或是直接輸入 ```--lr-decay-steps``` 和 ```--lr-decay-factor``` 參數,讓程式自動降低學習率。
182 |
183 | $ python3 train.py ... --steps 512000 --lr-decay-steps 128000 --lr-decay-factor 0.1
184 |
185 | 你可能會好奇,每次降低學習率大概需要多少個 steps。以經驗來看,範例給的 128000 steps 配合 512 batch 的訓練量就非常足夠,依照上面的訓練資訊,loss 已經很難再降低了。當然如果你不放心,可以選用更大的 step 數來訓練,以確保達到完全訓練,只要訓練集夠大(數萬盤以上),過度訓練並不會太影響網路強度。至於學習率最後要降低到多少,大概到 1e-5 就可以停止了,如果低於這個值,監督學習的網路可能會 overfitting,導致網路強度降低。
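
如果想知道 ```--lr-decay-steps``` 與 ```--lr-decay-factor``` 大致對應的行為,下面是一個用 PyTorch 的 StepLR 示意的片段(僅為概念示範,並非 train.py 的實際程式碼,模型與數值皆為假設):

    import torch

    model = torch.nn.Linear(8, 1)   # 假設的小模型,僅作示範
    opt = torch.optim.SGD(model.parameters(), lr=0.01)
    # 每 128000 steps 把學習率乘上 0.1,
    # 概念上等同 --lr-decay-steps 128000 --lr-decay-factor 0.1
    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=128000, gamma=0.1)

    for step in range(512000):
        # ...這裡省略讀取 batch、計算 loss 與 opt.step() 的部份...
        sched.step()   # 每個 step 呼叫一次,到達倍數時自動降低學習率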
186 |
187 | ## 改變 batch size
188 |
189 | 理論上,batch size 和學習率是一體的,當改變 batch size 時,學習率也需要一同調整才能盡可能保持一致性。當 batch size 變為兩倍時,等效的學習率會縮小兩倍,所以需要把學習率加大兩倍來保持一致性,反之亦然。例如,原本的參數為
190 |
191 | $ python3 train.py --steps 128000 --batch-size 512 --learning-rate 0.001
192 |
193 | 如果 batch size 改成 1024,則學習率需要乘以兩倍;反之,batch size 改成 256,則學習率需要減半。以下的設定理論上和上面是等效的
194 |
195 | $ python3 train.py --steps 128000 --batch-size 1024 --learning-rate 0.002
196 | $ python3 train.py --steps 128000 --batch-size 256 --learning-rate 0.0005
197 |
198 | ## 為甚麼使用 data-cache?
199 | 我們將每一個訓練的 sample 都保存在硬碟上,需要時才讀進主記憶體。因為訓練大盤面(十九路)的網路時需要大量資料,若要一次全部讀入,通常需要上百 GB 的主記憶體,使用 data-cache 可以避免主記憶體容量不足的問題,而且也不太影響訓練效率。
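
下面是一個極簡的示意,說明「先把 sample 存到硬碟、需要時才讀入」這種做法的概念(僅為概念示範,檔名、路徑與 ```planes```、```policy``` 等欄位名稱皆為假設,並非 dlgo 實際的 data-cache 實作):

    import os
    import random
    import numpy as np

    class DiskCache:
        # 每個 sample 各存成一個 .npz 檔,記憶體中只保留檔名列表。
        def __init__(self, cache_dir):
            self.cache_dir = cache_dir
            self.files = sorted(os.listdir(cache_dir))

        def sample_batch(self, batch_size):
            # 隨機抽取 batch_size 個檔案,需要時才從硬碟讀入。
            picked = random.sample(self.files, batch_size)
            batch = [np.load(os.path.join(self.cache_dir, f)) for f in picked]
            inputs = np.stack([b["planes"] for b in batch])
            targets = np.stack([b["policy"] for b in batch])
            return inputs, targets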
200 |
--------------------------------------------------------------------------------
/docs/Tutorial.md:
--------------------------------------------------------------------------------
1 | # 教學
2 |
3 | ## 零、依賴與來源
4 |
5 | 以下是部份程式碼和資源的來源
6 | 1. board.py 修改自 [ymgaq/Pyaq](https://github.com/ymgaq/Pyaq)
7 | 2. sgf.zip 來源自 [ymgaq/Pyaq](https://github.com/ymgaq/Pyaq)
8 | 3. gui.py 修改自 [YoujiaZhang/AlphaGo-Zero-Gobang](https://github.com/YoujiaZhang/AlphaGo-Zero-Gobang)
9 |
10 | 以下的 python 依賴庫需要安裝(請注意本程式使用 python3)
11 | 1. PyTorch(1.x 版本,如果要使用 GPU 請下載對應的 CUDA/cuDNN 版本)
12 | 2. NumPy
13 | 3. Tkinter(僅使用內建的 GUI 時需要)
14 | 4. Matplotlib(僅訓練時需要)
15 |
16 | 以下程式需要 Java
17 | 1. KGS GTP
18 |
19 | python 部份請輸入下列指令安裝,或自行使用下載可執行的版本
20 |
21 | pip3 install -r requirements.txt
22 |
23 | ## 一、訓練網路
24 |
25 | dlgo 包含 SGF 解析器,可以解析此格式的棋譜,並將棋譜作為訓練資料訓練一個網路,通過以下步驟可以訓練出一個基本網路。
26 |
27 | #### 第一步、收集棋譜
28 |
29 | 需要收集訓練的棋譜,如果你沒有可使用的棋譜,可以使用附上的 sgf.zip,裡面包含三萬五千盤左右的九路棋譜。也可以到 [Aya](http://www.yss-aya.com/ayaself/ayaself.html) 、[DarkGo](https://pjreddie.com/media/files/jgdb.tar.gz) 、[KGS](https://www.u-go.net/gamerecords/) 或是 [Leela Zero](https://leela.online-go.com/zero/) 上找到更多可訓練的棋譜。需要注意的是,dlgo 不能解析讓子棋棋譜,如有讓子棋棋譜需要事先清除;另外,訓練棋譜至少要有數萬盤,不然的話價值頭(value head)容易崩潰,尤其是十九路。
30 |
31 | #### 第二步、設定網路大小
32 |
33 | 網路的參數包含在 config.py 裡,所需要用到的參數如下
34 |
35 | | 參數 | 說明 |
36 | | :------------------: | :------------------------------------------: |
37 | | BLOCK_SIZE | 殘差網路的 block 的數目,數目越大網路越大 |
38 | | BLOCK_CHANNELS | 卷積網路 channel 的數目,數目越大網路越大 |
39 | | POLICY_CHANNELS      | 策略頭 channel 的數目,數目越大策略頭準確度越好 |
40 | | VALUE_CHANNELS       | 價值頭 channel 的數目,數目越大價值頭準確度越好 |
41 | | BOARD_SIZE | 棋盤大小,必須和棋譜的大小一致 |
42 | | USE_SE | 是否啟用 Squeeze-and-Excitation 網路結構 |
43 | | USE_GPU | 是否使用 GPU 訓練。如果為 True ,會自動檢查是否有可用的 GPU ,如果沒有檢測到 GPU ,則會使用 CPU 訓練,如果為 False ,則強制使用 CPU 訓練。此參數建議使用 True |
44 |
45 |
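以下是一組示意用的設定值(此處假設 nn_2x64 代表 2 個 block、64 個 channels;POLICY_CHANNELS、VALUE_CHANNELS 的數值僅為示範,實際請以 config.py 內的預設值為準):

    BLOCK_SIZE = 2        # 殘差網路 block 數目
    BLOCK_CHANNELS = 64   # 卷積網路 channel 數目
    POLICY_CHANNELS = 8   # 策略頭 channel 數目(示範值)
    VALUE_CHANNELS = 4    # 價值頭 channel 數目(示範值)
    BOARD_SIZE = 9        # 需和訓練棋譜的盤面大小一致
    USE_SE = False        # 是否使用 Squeeze-and-Excitation 結構
    USE_GPU = True        # 有可用 GPU 時自動使用,否則退回 CPU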
46 |
47 | #### 第三步、開始訓練
48 |
49 | 接下來便是開始訓練一個網路,所需要用到的參數如下
50 |
51 | | 參數 |參數類別 | 說明 |
52 | | :---------------: | :---------------: | :---------------: |
53 | | -d, --dir | string | 要訓練的 SGF 檔案夾,不指定則直接使用 ```data-cache``` 的訓練資料|
54 | | -s, --steps | integer | 要訓練的步數,越多訓練時間越久 |
55 | | -b, --batch-size | integer | 訓練的 batch size,建議至少大於 128 ,太低會無法訓練 |
56 | | -l, --learning-rate | float | 學習率大小,建議從 0.005 開始 |
57 | | --value-loss-scale | float | Value Loss 的倍率,預設是 0.25 倍 |
58 | | --lr-decay-steps | integer | 每 X steps 降低當前的學習率 |
59 | | --lr-decay-factor | float | 降低學習率的乘數 |
60 | | --noplot | NA | 訓練完後不要使用 Matplotlib 繪圖 |
61 |
62 |
63 |
64 | 以下是訓練範例命令
65 |
66 | $ python3 train.py --dir sgf-directory-name --steps 128000 --batch-size 512 --learning-rate 0.001
67 |
68 | 在一台有配備獨立顯示卡的電腦,大概數個小時內可以完成訓練,如果使用 CPU 訓練大概需要幾天時間。當網路權重出現後,就完成第一步的訓練了。如果你對當前的訓練結果不滿意,可到[這裏](../docs/Training.md)查看一些訓練時的小技巧。
69 |
70 | ## 二、啟動引擎
71 |
72 | ### Linux/MacOS
73 |
74 | 啟動引擎有四個參數是比較重要的
75 |
76 | | 參數 |參數類別 | 說明 |
77 | | :------------: | :---------------: | :---------------: |
78 | | -w, --weights | string | 要使用的網路權重名稱,如果沒給則使用 random 的權重|
79 | | -p, --playouts | integer | MCTS 的 playouts,數目越多越強。預設值是 400 |
80 | | -r, --resign-threshold | float | 投降的門檻,0.1 代表勝率低於 10% 就會投降。預設值是 0.1 |
81 | | -g, --gui | NA | 使用內建的圖形界面。|
82 |
83 |
84 |
85 | 可使用兩種方式執行,一個是直接使用 python 執行之
86 |
87 | $ python3 ./dlgo.py --weights weights-name --playouts 1600 -r 0.25
88 |
89 | 或是將程式碼當作可執行檔案執行,注意在執行以前,必須確定你有權限執行 dlgo.py ,如果沒有,請先使用 chmod 指令更改權限,以下是啟動的範例
90 |
91 | $ chmod 777 dlgo.py
92 | $ ./dlgo.py --weights weights-name --playouts 1600 -r 0.25
93 |
94 | 啟動之後,可以試試輸入 GTP 指令 ```showboard``` ,看看是否有正常運作,順利的話可以看到以下輸出
95 |
96 | showboard
97 | A B C D E F G H J
98 | 9 . . . . . . . . . 9
99 | 8 . . . . . . . . . 8
100 | 7 . . . . . . . . . 7
101 | 6 . . . . . . . . . 6
102 | 5 . . . . . . . . . 5
103 | 4 . . . . . . . . . 4
104 | 3 . . . . . . . . . 3
105 | 2 . . . . . . . . . 2
106 | 1 . . . . . . . . . 1
107 | A B C D E F G H J
108 |
109 | =
110 |
111 |
112 | ### Windows
113 |
114 | 請直接使用 python 執行之
115 |
116 | $ python3 ./dlgo.py --weights weights-name --playouts 1600 -r 0.25
117 |
118 | ## 三、使用 GTP 介面
119 |
120 | dlgo 支援基本的 GTP 介面,你可以使用任何支援 GTP 軟體,比如用 [Sabaki](https://sabaki.yichuanshen.de) 或是 [GoGui](https://github.com/Remi-Coulom/gogui) 將 dlgo 掛載上去,使用的參數參考第二部份。以下是如何在 Sabaki 上使用的教學。
121 |
122 | #### 第一步、打開引擎選項
123 |
124 |
125 |
126 |

127 |
128 |
129 |
130 | #### 第二步、新增引擎
131 |
132 |
133 |

134 |
135 |
136 |
137 | * 如果想用 python 執行之,請將 path 欄位改成 python 執行檔的位置,arguments 改為 ```path/to/dlgo.py --weights weights-name --playouts 1600 -r 0.25```
138 |
139 | #### 第三步、加載引擎
140 |
141 |
142 |

143 |
144 |
145 |
146 | 設置完成後就可以和 dlgo 對戰了。如果想知道 dlgo 支援哪些 GTP 指令,可到[這裏](../docs/dlgoGTP.md)查看。
147 |
148 | ## 四、在 KGS 上使用
149 |
150 | KGS 是一個網路圍棋伺服器,它曾經是世界最大、最多人使用的網路圍棋平台。KGS 除了可以上網下棋以外,還能掛載 GTP 引擎上去,以下將教學如何將 dlgo 掛載上去。
151 |
152 | #### 第一步、下載 KGS 客戶端並註冊
153 |
154 | 請到 [KGS 官網](https://www.gokgs.com/index.jsp?locale=zh_CN)上下載對應系統的客戶端,如果是 Linux 系統,請選擇 Raw JAR File。接下來到 [KGS 註冊網站](https://www.gokgs.com/register/index.html)創立一個帳號。
155 |
156 |
157 | #### 第二步、下載 KGS GTP 客戶端
158 |
159 | 到 [KGS GTP 網站](https://www.gokgs.com/download.jsp)下載專為 GTP 引擎設計的客戶端。
160 |
161 | #### 第三步、掛載 GTP 引擎
162 |
163 | 首先我們需要創建並設定 config.txt 的參數,你可以參考以下直接設定
164 |
165 | name=帳號
166 | password=密碼
167 | room=Computer Go
168 | mode=custom
169 |
170 | rules=chinese
171 | rules.boardSize=9
172 | rules.time=10:00
173 |
174 | undo=f
175 | reconnect=t
176 | verbose=t
177 |
178 | engine=dlgo 的路徑和參數
179 |
180 | 設定完成後,輸入以下命令即可掛載引擎
181 |
182 | $ java -jar kgsGtp.jar config.txt
183 |
184 | 詳細的參數說明可以看 [KGS GTP 文件](http://www.weddslist.com/kgs/how/kgsGtp.html)。注意這是舊版的文件,如果要看新版的文件,可以開啟同個資料夾內的 kgsGtp.xhtml,或是輸入以下命令在終端機觀看。
185 |
186 | $ java -jar kgsGtp.jar -h
187 |
188 | #### 第四步、和 dlgo 在 KGS 上下棋
189 |
190 | 登錄 KGS 客戶端(注意,你第一個申請的帳號正在被引擎使用,請申請第二個帳號或使用參觀模式),可以從 “新開對局” 中找到你的帳號,點擊你的帳號即可發出對局申請。
191 |
192 | ## 五、參加 TCGA 競賽
193 |
194 | TCGA 全名為台灣電腦對局協會,基本上每年會舉辦兩場各類型的電腦對局比賽,當然也包括圍棋。TCGA 的圍棋比賽是使用 KGS 伺服器連線的,如果你已經能順利掛載引擎到 KGS 上,恭喜你完成第一步比賽的準備,接下來可以設法強化 dlgo,像是訓練更大的權重,或是更改網路結構等等,希望你能在比賽中獲得好成績。
195 |
196 | ### 一些基本問題
197 |
198 | 如果你已經有自己的圍棋程式,並且想參加比賽,你可能會有一些問題
199 |
200 | 1. 如何繳費?費用多少?
201 |
202 |    現場繳費。一般而言,費用是每個程式一千到兩千元之間,學生可能會有優惠,實際費用需要看大會規定。如果要報名多個程式,假設每個程式一千元,則兩個程式就是兩千元,以此類推。
203 |
204 |
205 | 2. 參加比賽是否有任何限制?
206 |
207 | 沒有限制,所有人皆可參加,但要求實做程式盡可能原創,不可直接拿他人寫好的程式參加比賽。
208 |
209 |
210 | 3. 參加圍棋比賽是否一定需要實做 GTP 界面?
211 |
212 |    可以不實做 GTP 界面,但申請 KGS 帳號是必須的。你可以用手動的方式,將程式的輸出擺到盤面上,但這樣會損失很多思考時間,建議還是實做 GTP 指令。
213 |
214 |
215 | 4. 建議最低需要實做的 GTP 指令?
216 |
217 |    建議至少實做 ```quit```、```name```、```version```、```protocol_version```、```list_commands```、```play```、```genmove```、```clear_board```、```boardsize```、```komi``` 等指令。如果可行的話,建議也實做 ```time_settings```、```time_left``` 這兩個指令,因為每場比賽都有時間限制,這兩個指令可以告訴程式剩餘時間,讓程式自行分配思考時間,但如果你可以確保程式能在規定時間內執行完畢(包括網路延遲),則可不必實做。
218 |
219 |
220 | 5. 參加圍棋比賽是否一定要到現場?
221 |
222 | 不一定,但需要事先聯繫其他人,討論決定當天比賽如何進行。
223 |
224 |
225 | 6. 使用 dlgo 參加比賽是否可以?
226 |
227 | 可以,但是要求報名時標注使用此程式,並且要有一定的改進,最後希望能將改進的程式能開源。
228 |
229 | 7. 比賽規則為何?
230 |
231 | 通過 KGS 平台比賽。均採用中國規則且禁全同,九路思考時間為 10 分鐘不讀秒,十九路思考時間為 30 分鐘不讀秒。
232 |
233 |
234 | ### TCGA/ICGA 相關比賽列表
235 |
236 | 請查看[這裡](https://hackmd.io/@yrHb-fKBRoyrKDEKdPSDWg/ryoCXyXjK)
237 |
238 |
--------------------------------------------------------------------------------
/docs/dlgoAPI.md:
--------------------------------------------------------------------------------
1 | # dlgo API
2 |
3 | ## board.py
4 |
5 | 以下是棋盤裡可用的 functions 和參數。此檔案依賴於 config.py 和 NumPy,只需要少量修改,就可以移植到你的專案中。
6 |
7 | #### Functions
8 | * `void Board.__init__(size: int, komi: float)`
9 | * Board 的初始化建構。
10 |
11 | * `void Board.reset(size: int, komi: float)`
12 | * 清理盤面並重新開始。
13 |
14 | * `bool Board.legal(vertex: int)`
15 | * 測試是否為合法手,如果是合法手,返回 True。
16 |
17 | * `bool Board.play(vertex: int)`
18 |    * 走一手棋到盤面上。同時會測試是否為合法手,如果是合法手,返回 True。
19 |
20 | * `int Board.final_score()`
21 | * 計算基於 Tromp-Taylor 規則的目數。
22 |
23 | * `int Board.get_vertex(x: int, y: int)`
24 | * 將 x, y 座標轉成 vertex。
25 |
26 | * `str Board.vertex_to_text(vtx: int)`
27 | * 將 vertex 轉成文字。
28 |
29 | * `Board Board.copy()`
30 | * 快速複製當前的棋盤,複製的棋盤共用歷史盤面。
31 |
32 | * `bool Board.superko()`
33 | * 當前盤面是否為 superko,如果是則返回 True。
34 |
35 | * `nparray Board.get_features()`
36 | * 得到神經網路的輸入資料。
37 |
38 | * `str Board.__str__()`
39 | * 將當前盤面轉成文字。
40 |
41 | #### Parameters
42 |
43 | * `int BLACK = 0`
44 | * 黑棋的數值。
45 |
46 | * `int WHITE = 1`
47 | * 白棋的數值。
48 |
49 | * `int PASS = -1`
50 | * 虛手的 vertex 數值。
51 |
52 | * `int RESIGN = -2`
53 | * 投降的 vertex 數值。
54 |
55 | * `int Board.board_size`
56 | * 當前盤面大小。
57 |
58 | * `float Board.komi`
59 | * 當前貼目。
60 |
61 | * `int Board.to_move`
62 | * 當前下棋的顏色。
63 |
64 | * `int Board.move_num`
65 | * 當前手數。
66 |
67 | * `int Board.last_move`
68 | * 上一手棋下的位置。
69 |
70 | * `int Board.num_passes`
71 | * 虛手的次數。
72 |
73 | * `list[nparray] Board.history`
74 | * 歷史的盤面。
75 |
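#### 使用範例

以下是一個簡單的使用示意(假設 config.py 的 BOARD_SIZE 為 9,貼目數值僅為示範):

    from board import Board

    board = Board(9, 7.0)         # 建立 9 路盤、貼 7 目
    vtx = board.get_vertex(2, 2)  # 將 (2, 2) 座標轉成 vertex
    if board.legal(vtx):          # 確認是合法手
        board.play(vtx)           # 下一手棋
    print(board)                  # 以文字顯示當前盤面
    print(board.final_score())    # 計算 Tromp-Taylor 目數
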
--------------------------------------------------------------------------------
/docs/dlgoGTP.md:
--------------------------------------------------------------------------------
1 | # dlgo GTP
2 |
3 | ## 一、GTP 簡介
4 | GTP(Go Text Protocol)最早是 GNU Go 團隊為了簡化當時的電腦圍棋協定 Go Modem Protocol 而設計,在 GNU Go 3.0 時引入,到了現代,GTP 已成為圍棋軟體普遍的溝通方式。GTP 運作的原理相當簡單:界面(或是使用者)向引擎送出一條指令,引擎根據指令做出相應的動作,並且回覆訊息給界面。回覆的格式分成兩種,一種是執行成功,此時回覆的第一個字元是 ```=```,另一種是執行失敗,此時回覆的第一個字元是 ```?```,回覆的最後要連續輸出兩個換行表示結束。以下是 dlgo 的範例,此範例要注意的是 showboard 的盤面是輸出到 stderr ,stdout 部份依舊是標準的 GTP 回覆。
5 |
6 | name
7 | = dlgo
8 |
9 | version
10 | = 0.1
11 |
12 | protocol_version
13 | = 2
14 |
15 | play b e5
16 | =
17 |
18 | play w e3
19 | =
20 |
21 | play b g5
22 | =
23 |
24 | showboard
25 | A B C D E F G H J
26 | 9 . . . . . . . . . 9
27 | 8 . . . . . . . . . 8
28 | 7 . . . . . . . . . 7
29 | 6 . . . . . . . . . 6
30 | 5 . . . . X . [X] . . 5
31 | 4 . . . . . . . . . 4
32 | 3 . . . . O . . . . 3
33 | 2 . . . . . . . . . 2
34 | 1 . . . . . . . . . 1
35 | A B C D E F G H J
36 |
37 | =
38 |
39 | play w c4
40 | =
41 |
42 | showboard
43 | A B C D E F G H J
44 | 9 . . . . . . . . . 9
45 | 8 . . . . . . . . . 8
46 | 7 . . . . . . . . . 7
47 | 6 . . . . . . . . . 6
48 | 5 . . . . X . X . . 5
49 | 4 . . [O] . . . . . . 4
50 | 3 . . . . O . . . . 3
51 | 2 . . . . . . . . . 2
52 | 1 . . . . . . . . . 1
53 | A B C D E F G H J
54 |
55 | =
56 |
57 | aabbbcccc
58 | ? Unknown command
59 |
60 | quit
61 | =
62 |
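下面用一個極簡的 Python 控制器示意上述的回覆格式如何解析(僅為概念示範,不是 dlgo 的一部分;啟動引擎的路徑與參數僅為假設):

    import subprocess

    engine = subprocess.Popen(
        ["python3", "dlgo.py", "--weights", "nn_2x64.pt"],
        stdin=subprocess.PIPE, stdout=subprocess.PIPE, text=True)

    def send(cmd):
        engine.stdin.write(cmd + "\n")
        engine.stdin.flush()
        lines = []
        while True:
            line = engine.stdout.readline().rstrip("\n")
            if line == "" and lines:     # 空行代表回覆結束
                break
            lines.append(line)
        ok = lines[0].startswith("=")    # '=' 表示成功,'?' 表示失敗
        return ok, "\n".join(lines)

    print(send("name"))        # (True, '= dlgo')
    print(send("aabbbcccc"))   # (False, '? Unknown command')
    send("quit")
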
63 | ## 二、支援的指令
64 |
65 | dlgo 僅支援基本的 GTP 指令集,主要是為了滿足 TCGA 比賽的基本需求(KGS),當然還有很多標準指令尚未實作,如果有興趣可到 [GTP 英文文檔](https://www.gnu.org/software/gnugo/gnugo_19.html)找到更多資訊。以下是 dlgo 支援的指令。
66 |
67 | * `quit`
68 | * 退出並結束執行。
69 |
70 | * `name`
71 | * 顯示程式名字。
72 |
73 | * `version`
74 | * 顯示程式版本。
75 |
76 | * `protocol_version`
77 | * 顯示使用的 GTP 版本。
78 |
79 | * `list_commands`
80 | * 顯示所有此程式支援的 GTP 指令。
81 |
82 | * `play [black|white] <vertex>`
83 | * 走一手棋到盤面上,必須是合法手。參數中的 vertex 為 [GTP vertex](https://www.lysator.liu.se/~gunnar/gtp/gtp2-spec-draft2/gtp2-spec.html#SECTION00042000000000000000) ,例如 ```a1```、```a2```、```e5``` 等座標位置或是 pass 代表虛手,resign 代表投降。
84 |
85 | * `genmove [black|white]`
86 | * 讓引擎思考並產生下一手棋到盤面上。
87 |
88 | * `undo`
89 | * 悔棋。
90 |
91 | * `clear_board`
92 | * 清空盤面,重新開始新的一局。
93 |
94 | * `boardsize <size>`
95 |    * 設定不同的盤面大小。注意,dlgo 的神經網路接受的盤面大小是固定的,隨意調整可能使程式崩潰退出。
96 |
97 | * `komi <komi>`
98 | * 設定不同的貼目。注意,dlgo 的網路輸出勝率不會因貼目改變而動態調整。
99 |
100 | * `time_settings <main time> <byo time> <byo stones>`
101 |    * 設定初始的時限並重新計時,```main time``` 為基本思考時間,```byo time``` 為讀秒思考時間,```byo stones``` 為讀秒內要下完的棋步數,僅支援加拿大讀秒(Canadian byo-yomi)規則(使用範例見本節指令列表之後)。
102 |
103 | * `time_left [black|white] <time> <stones>`
104 | * 設定某方剩餘的時限。
105 |
106 | * `analyze [black|white] <interval>`
107 | * 背景分析結果,詳細指令的參數可到[這裡](https://github.com/SabakiHQ/Sabaki/blob/master/docs/guides/engine-analysis-integration.md)查看。
108 |
109 | * `genmove_analyze [black|white] <interval>`
110 | * 讓引擎思考並產生下一手棋到盤面上並在背景送出分析結果,詳細指令的參數可到[這裡](https://github.com/SabakiHQ/Sabaki/blob/master/docs/guides/engine-analysis-integration.md)查看。
111 |
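關於 ```time_settings``` 的用法,舉例來說(數值僅為示意),下面的指令表示基本思考時間為 600 秒,之後每 180 秒內要下完 25 手(加拿大讀秒):

    time_settings 600 180 25
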
112 | ## 三、其它 KGS 可用指令
113 |
114 | 有些指令在 KGS 上有特殊效果,或是可以提供更多功能,如果有興趣的話,可以優先實作下列指令,指令的參數和效果可到 [GTP 英文文檔](https://www.gnu.org/software/gnugo/gnugo_19.html)查看
115 |
116 | * `final_status_list [alive|dead]`
117 | * 顯示當前盤面的死棋和活棋棋串。GNU Go 還有實做其它種類的判斷,如 ```seki```、 ```white_territory```、 ```black_territory``` 和 ```dame```。
118 |
119 | * `place_free_handicap <number of stones>`
120 | * 讓引擎自己生成讓子的位置。
121 |
122 | * `set_free_handicap <vertex list>`
123 | * 使用者告訴電腦讓子的位置。
124 |
125 | * `kgs-genmove_cleanup [black|white]`
126 | * KGS 專用的生成合法手的指令,禁止虛手直到盤面沒有死棋為止,用以清除盤面死棋。
127 |
128 | * `kgs-time_settings ...`
129 | * KGS 專用的時間控制指令,相比原版的多支援 Byo-Yomi 讀秒規則,詳情請看[這裡](https://www.gokgs.com/help/timesystems.html)。
130 |
131 | * `kgs-game_over`
132 |    * 當每盤對局結束時,KGS 會發出此指令。
133 |
--------------------------------------------------------------------------------
/gtp.py:
--------------------------------------------------------------------------------
1 | from sys import stderr, stdout, stdin
2 | from board import Board, PASS, RESIGN, BLACK, WHITE, INVLD
3 | from network import Network
4 | from mcts import Search
5 | from config import BOARD_SIZE, KOMI
6 | from time_control import TimeControl
7 |
8 | class GTP_ENGINE:
9 | def __init__(self, args):
10 | self.args = args
11 | self.board = Board(BOARD_SIZE, KOMI)
12 | self.network = Network(BOARD_SIZE)
13 | self.time_control = TimeControl()
14 | self.network.trainable(False)
15 | self.board_history = [self.board.copy()]
16 | self.last_verbose = str()
17 |
18 | if self.args.weights != None:
19 | self.network.load_pt(self.args.weights)
20 |
21 | # For GTP command "clear_board". Reset the board to the initial state and
22 | # clear the move history.
23 | def clear_board(self):
24 | self.board.reset(self.board.board_size, self.board.komi)
25 | self.board_history = [self.board.copy()]
26 | self.network.clear_cache()
27 |
28 | # For GTP command "genmove". The engine returns the best move and play it.
29 | def genmove(self, color):
30 |         # Generate the next move and play it.
31 | c = self.board.to_move
32 | if color.lower()[:1] == "b":
33 | c = BLACK
34 | elif color.lower()[:1] == "w":
35 | c = WHITE
36 |
37 | self.board.to_move = c
38 | search = Search(self.board, self.network, self.time_control)
39 |
40 | # Collect the search verbose for the built-in GUI.
41 | move, self.last_verbose = search.think(
42 | self.args.playouts,
43 | self.args.resign_threshold,
44 | self.args.verbose)
45 | if self.board.play(move):
46 | self.board_history.append(self.board.copy())
47 |
48 | return self.board.vertex_to_text(move)
49 |
50 | # For GTP command "play". Play a move if it is legal.
51 | def play(self, color, move):
52 | # play move if the move is legal.
53 | c = INVLD
54 | if color.lower()[:1] == "b":
55 | c = BLACK
56 | elif color.lower()[:1] == "w":
57 | c = WHITE
58 |
59 | vtx = None
60 | if move == "pass":
61 | vtx = PASS
62 | elif move == "resign":
63 | vtx = RESIGN
64 | else:
65 | x = ord(move[0]) - (ord('A') if ord(move[0]) < ord('a') else ord('a'))
66 | y = int(move[1:]) - 1
67 |             if x >= 8: # GTP coordinates skip the letter 'I', so shift columns after 'H' back by one.
68 | x -= 1
69 | vtx = self.board.get_vertex(x,y)
70 |
71 | if c != INVLD:
72 | self.board.to_move = c
73 | if self.board.play(vtx):
74 | self.board_history.append(self.board.copy())
75 | return True
76 | return False
77 |
78 | # For GTP command "undo". Play the undo move.
79 | def undo(self):
80 | if len(self.board_history) > 1:
81 | self.board_history.pop()
82 | self.board = self.board_history[-1].copy()
83 |
84 | # For GTP command "boardsize". Set a variant board size.
85 | def boardsize(self, bsize):
86 | self.board.reset(bsize, self.board.komi)
87 | self.board_history = [self.board.copy()]
88 | self.network.clear_cache()
89 |
90 | # For GTP command "boardsize". Set a variant komi.
91 | def komi(self, k):
92 | self.board.komi = k
93 |
94 | # For GTP command "time_settings". Set initial time settings and restart it.
95 | # 'main time' is basic thinking time.
96 | # 'byo time' is byo yomi time.
97 | # 'byo stones' is byo yomi stone.
98 | def time_settings(self, main_time, byo_time, byo_stones):
99 | if not main_time.isdigit() or \
100 | not byo_time.isdigit() or \
101 | not byo_stones.isdigit():
102 | return False
103 |
104 | self.time_control.time_settings(int(main_time), int(byo_time), int(byo_stones))
105 | return True
106 |
107 | # For GTP command "time_left". Set time left value for one side.
108 | def time_left(self, color, time, stones):
109 | c = INVLD
110 | if color.lower()[:1] == "b":
111 | c = BLACK
112 | elif color.lower()[:1] == "w":
113 | c = WHITE
114 | if c == INVLD:
115 | return False
116 | self.time_control.time_left(c, int(time), int(stones))
117 | return True
118 |
119 | # For GTP command "showboard". Dump the board(stand error output).
120 | def showboard(self):
121 | stderr.write(str(self.board))
122 | stderr.flush()
123 |
124 | def lz_genmove_analyze(self, color, interval):
125 | c = self.board.to_move
126 | if color.lower()[:1] == "b":
127 | c = BLACK
128 | elif color.lower()[:1] == "w":
129 | c = WHITE
130 |
131 | self.board.to_move = c
132 | search = Search(self.board, self.network, self.time_control)
133 | search.analysis_tag["interval"] = interval/100
134 |
135 | # Collect the search verbose for the built-in GUI.
136 | move, self.last_verbose = search.think(
137 | self.args.playouts,
138 | self.args.resign_threshold,
139 | self.args.verbose)
140 | if self.board.play(move):
141 | self.board_history.append(self.board.copy())
142 |
143 | return self.board.vertex_to_text(move)
144 |
145 | def lz_analyze(self, color, interval):
146 | c = self.board.to_move
147 | if color.lower()[:1] == "b":
148 | c = BLACK
149 | elif color.lower()[:1] == "w":
150 | c = WHITE
151 |
152 | self.board.to_move = c
153 | search = Search(self.board, self.network, self.time_control)
154 | search.analysis_tag["interval"] = interval/100
155 |
156 | # Collect the search verbose for the built-in GUI.
157 | self.last_verbose = search.ponder(
158 | self.args.playouts * 100,
159 | self.args.verbose)
160 |
161 | class GTP_LOOP:
162 | COMMANDS_LIST = [
163 | "quit", "name", "version", "protocol_version", "list_commands",
164 | "play", "genmove", "undo", "clear_board", "boardsize", "komi",
165 | "time_settings", "time_left", "lz-genmove_analyze", "lz-analyze"
166 | ]
167 | def __init__(self, args):
168 | self.engine = GTP_ENGINE(args)
169 | self.args = args
170 | self.cmd_id = None
171 |
172 | # Start the main GTP loop.
173 | self.loop()
174 |
175 | def loop(self):
176 | while True:
177 | # Get the commands.
178 | cmd = stdin.readline().split()
179 |
180 | # Get the command id.
181 | self.set_id(cmd)
182 |
183 | if len(cmd) == 0:
184 | continue
185 |
186 | main = cmd[0]
187 | if main == "quit":
188 | self.success_print("")
189 | break
190 |
191 | # Parse the commands and execute it.
192 | self.process(cmd)
193 |
194 | def set_id(self, cmd):
195 | self.cmd_id = None
196 | if len(cmd) == 0:
197 | return
198 |
199 | if cmd[0].isdigit():
200 | self.cmd_id = cmd.pop(0)
201 |
202 | def process(self, cmd):
203 | # TODO: Support analyze and genmove_analyze commands.
204 | main = cmd[0]
205 |
206 | if main == "name":
207 | self.success_print("dlgo")
208 | elif main == "version":
209 |             version = "0.1"
210 |             if self.args.kgs:
211 |                 self.success_print(version + "\nI am a simple bot. I don't understand life and death. Please help me remove the dead strings when the game ends. Have a nice game.")
212 | else:
213 | self.success_print(version)
214 | elif main == "protocol_version":
215 | self.success_print("2")
216 | elif main == "list_commands":
217 | clist = str()
218 | for c in self.COMMANDS_LIST:
219 | clist += c
220 | if c is not self.COMMANDS_LIST[-1]:
221 | clist += '\n'
222 | self.success_print(clist)
223 | elif main == "clear_board":
224 | # reset the board
225 |             self.engine.clear_board()
226 | self.success_print("")
227 | elif main == "play" and len(cmd) >= 3:
228 | # play color move
229 | if self.engine.play(cmd[1], cmd[2]):
230 | self.success_print("")
231 | else:
232 | self.fail_print("")
233 | elif main == "undo":
234 | # undo move
235 |             self.engine.undo()
236 | self.success_print("")
237 | elif main == "genmove" and len(cmd) >= 2:
238 |             # generate next move
239 | self.success_print(self.engine.genmove(cmd[1]))
240 | elif main == "boardsize" and len(cmd) >= 2:
241 | # set board size and reset the board
242 | self.engine.boardsize(int(cmd[1]))
243 | self.success_print("")
244 | elif main == "komi" and len(cmd) >= 2:
245 | # set komi
246 | self.engine.komi(float(cmd[1]))
247 | self.success_print("")
248 | elif main == "showboard":
249 | # display the board
250 | self.engine.showboard()
251 | self.success_print("")
252 | elif main == "time_settings":
253 | if self.engine.time_settings(cmd[1], cmd[2], cmd[3]):
254 | self.success_print("")
255 | else:
256 | self.fail_print("")
257 | elif main == "time_left":
258 | if self.cmd_id is not None:
259 | stdout.write("={}\n".format(self.cmd_id))
260 | else:
261 | stdout.write("=\n")
262 | stdout.flush()
263 | elif main == "lz-genmove_analyze":
264 | color = "tomove"
265 | interval = 0
266 |             if len(cmd) >= 2 and cmd[1].isdigit():
267 |                 interval = cmd[1]
268 |             elif len(cmd) >= 3 and cmd[2].isdigit():
269 |                 # Check the length first to avoid an index error on short commands.
270 |                 color = cmd[1]
271 |                 interval = cmd[2]
272 | self.success_half('')
273 | m = self.engine.lz_genmove_analyze(color, int(interval))
274 | stdout.write("play {}\n\n".format(m))
275 | stdout.flush()
276 | elif main == "lz-analyze":
277 | color = "tomove"
278 | interval = 0
279 |             if len(cmd) >= 2 and cmd[1].isdigit():
280 |                 interval = cmd[1]
281 |             elif len(cmd) >= 3 and cmd[2].isdigit():
282 |                 # Check the length first to avoid an index error on short commands.
283 |                 color = cmd[1]
284 |                 interval = cmd[2]
285 | self.success_half('')
286 |             self.engine.lz_analyze(color, int(interval))  # lz_analyze only records the verbose; it has no return value.
287 | stdout.write("\n")
288 | stdout.flush()
289 | else:
290 | self.fail_print("Unknown command")
291 |
292 | def success_half(self, res):
293 | if self.cmd_id is not None:
294 | stdout.write("={} {}\n".format(self.cmd_id, res))
295 | else:
296 | stdout.write("= {}\n".format(res))
297 | stdout.flush()
298 |
299 | def success_print(self, res):
300 | if self.cmd_id is not None:
301 | stdout.write("={} {}\n\n".format(self.cmd_id, res))
302 | else:
303 | stdout.write("= {}\n\n".format(res))
304 | stdout.flush()
305 |
306 | def fail_print(self, res):
307 | stdout.write("? {}\n\n".format(res))
308 | stdout.flush()
309 |
--------------------------------------------------------------------------------
/gui.py:
--------------------------------------------------------------------------------
1 | from board import Board, PASS, RESIGN, BLACK, WHITE, INVLD, EMPTY
2 | from gtp import GTP_ENGINE
3 | from config import BOARD_SIZE, KOMI
4 |
5 | import time
6 | import argparse
7 | import tkinter as tk
8 | from threading import Thread
9 | from tkinter import scrolledtext
10 |
11 | class GUI_LOOP(GTP_ENGINE):
12 | def __init__(self, args):
13 | super(GUI_LOOP, self).__init__(args)
14 |
15 | self.init_layouts(1200, 800)
16 |
17 | self.window = tk.Tk()
18 | self.window.resizable(0, 0)
19 | self.window.title("Deep Learning of Go")
20 | self.window.geometry("{w}x{h}".format(w=self.width, h=self.height))
21 |
22 | self.oval_buffer = [None] * self.board.num_intersections
23 | self.text_buffer = [None] * self.board.num_intersections
24 |
25 | self.game_thread = None
26 | self.suspend = False
27 | self.acquire_vtx = None
28 |
29 | self.init_widgets()
30 | self.window.mainloop()
31 |
32 | def init_layouts(self, width, height):
33 | min_width = 800
34 | min_height = 500
35 | self.widgets_offset_base = 30
36 |
37 |
38 | self.width = max(width, min_width)
39 | self.height = max(height, min_height)
40 |
41 | size_base = min(self.width, self.height)
42 |
43 | # Set the canvas size and coordinate.
44 |         self.canvas_size = size_base - self.widgets_offset_base * 2 # The canvas is always square.
45 | self.canvas_x = self.widgets_offset_base
46 | self.canvas_y = self.widgets_offset_base
47 |
48 | # Set the buttons's coordinates.
49 | buttons_offset_base = self.canvas_x + self.canvas_size + self.widgets_offset_base
50 | self.buttons_x = [buttons_offset_base + 0 * 90,
51 | buttons_offset_base + 1 * 90,
52 | buttons_offset_base + 2 * 90,
53 | buttons_offset_base + 3 * 90]
54 | self.buttons_y = 4 * [self.widgets_offset_base]
55 |
56 | # Set the scrolled text size and coordinate.
57 | self.scrolled_x = self.canvas_x + self.canvas_size + self.widgets_offset_base
58 | self.scrolled_y = 3 * self.widgets_offset_base
59 |
60 | self.scrolled_width = round((self.width - self.scrolled_x - self.widgets_offset_base) / 9)
61 | self.scrolled_height = round((self.height - self.scrolled_y - self.widgets_offset_base) / 18)
62 |
63 | def init_widgets(self):
64 | self.canvas = tk.Canvas(self.window, bg="#CD853F", height=self.canvas_size, width=self.canvas_size)
65 | self.scroll_rext = scrolledtext.ScrolledText(self.window, height=self.scrolled_height, width=self.scrolled_width)
66 |
67 | self.bt_black_start = tk.Button(self.window, text="執黑開始", command=lambda : self.start_new_game(BLACK))
68 | self.bt_black_start.place(x=self.buttons_x[0], y=self.buttons_y[0])
69 |
70 | self.bt_white_start = tk.Button(self.window, text="執白開始", command=lambda : self.start_new_game(WHITE))
71 | self.bt_white_start.place(x=self.buttons_x[1], y=self.buttons_y[1])
72 |
73 | self.bt_self_play = tk.Button(self.window, text="電腦自戰", command=lambda : self.start_new_game())
74 | self.bt_self_play.place(x=self.buttons_x[2], y=self.buttons_y[2])
75 |
76 | self.bt_pass_start = tk.Button(self.window, text="虛手", command=lambda : self.acquire_move(PASS))
77 | self.bt_pass_start.place(x=self.buttons_x[3], y=self.buttons_y[3])
78 |
79 | self.draw_canvas(self.canvas_x, self.canvas_y)
80 | self.draw_scroll_text(self.scrolled_x, self.scrolled_y)
81 |
82 | def draw_canvas(self, x, y):
83 | bsize = self.board.board_size
84 | square_size = self.canvas_size / bsize
85 | lower = square_size/2
86 | upper = self.canvas_size - square_size/2
87 |
88 | for i in range(bsize):
89 | offset = i * square_size
90 | self.canvas.create_line(lower ,lower+offset, upper , lower+offset)
91 | self.canvas.create_line(lower+offset ,lower , lower+offset, upper)
92 | self.canvas.place(x=x, y=y)
93 |
94 | def draw_scroll_text(self, x, y):
95 | self.scroll_rext.place(x=x, y=y)
96 |
97 | def insert_scroll_text(self, string):
98 | self.scroll_rext.insert(tk.END, string+'\n')
99 | self.scroll_rext.see(tk.END)
100 | self.scroll_rext.update()
101 |
102 | def reset_canvas(self):
103 | self.clear_board()
104 | self.canvas.delete("all")
105 | self.scroll_rext.delete(1.0, tk.END)
106 | self.draw_canvas(self.canvas_x, self.canvas_y)
107 |         self.canvas.bind("<Button-1>", self.scan_move) # Left mouse click acquires a move.
108 | self.rect = None
109 |
110 | def draw_stone(self, to_move, rc_pos, move_num=None):
111 | r, c = rc_pos
112 | x, y = self.convert_rc_to_xy(rc_pos)
113 |
114 | bsize = self.board.board_size
115 | square_size = self.canvas_size/bsize
116 |
117 | color_stone = "black" if to_move == BLACK else "white"
118 | color_index = "white" if to_move == BLACK else "black"
119 | color_border = "#696969" if to_move == BLACK else "black"
120 |
121 |
122 | radius = max(square_size/2 - 5, 15)
123 | border = max(round(radius/15), 2)
124 | self.oval_buffer[self.board.get_index(r, c)] = self.canvas.create_oval(
125 | x-radius, y-radius, x+radius, y+radius,
126 | fill=color_stone, outline=color_border, width=border)
127 | if self.rect == None:
128 | offset = max(square_size/2 , 20)
129 | self.rect = self.canvas.create_rectangle(x-offset, y-offset, x+offset, y+offset, outline="#c1005d")
130 | self.rect_xy_pos = (x, y)
131 | else:
132 | rc_pos = self.convert_xy_to_rc((x, y))
133 | old_x, old_y = self.rect_xy_pos
134 | new_x, new_y = self.convert_rc_to_xy(rc_pos)
135 | dx, dy = new_x-old_x, new_y-old_y
136 | self.canvas.move(self.rect, dx, dy)
137 | self.rect_xy_pos = (new_x, new_y)
138 |
139 | text_size = round(1*square_size/2)
140 | if move_num == None:
141 | move_num = str()
142 |
143 | self.text_buffer[self.board.get_index(r,c)] = self.canvas.create_text(x,y, text=str(move_num), fill=color_index, font=('Arial', text_size))
144 | self.canvas.update()
145 |
146 | def convert_rc_to_xy(self, rc_pos):
147 | bsize = self.board.board_size
148 | square_size = self.canvas_size/bsize
149 | lower = square_size/2
150 |
151 | r, c = rc_pos
152 |
153 | x = c*square_size + lower
154 | y = r*square_size + lower
155 | return x, y
156 |
157 | def convert_xy_to_rc(self, xy_pos):
158 | bsize = self.board.board_size
159 | square_size = self.canvas_size/bsize
160 | lower = square_size/2
161 |
162 | x, y = xy_pos
163 | r = round((y-lower)/square_size)
164 | c = round((x-lower)/square_size)
165 | return r, c
166 |
167 | def start_new_game(self, color=None):
168 | self.suspend = True # stop the board updating.
169 |
170 | self.acquire_vtx = None
171 | self.turns = ["compute", "compute"]
172 | if color != None:
173 | self.turns[color] = "player"
174 | self.reset_canvas()
175 | self.game_over = False
176 |
177 | if self.game_thread == None:
178 | # Create one game if we don't do it.
179 | self.game_thread = Thread(target=self.process_game,)
180 |             self.game_thread.daemon = True
181 | self.game_thread.start()
182 |
183 | self.suspend = False # start the game.
184 |
185 | def process_game(self):
186 | resignd = None
187 |
188 | while True:
189 | # Short sleep in order to avoid busy running.
190 | time.sleep(0.1)
191 |
192 | if self.suspend or self.game_over:
193 | continue
194 |
195 | to_move = self.board.to_move
196 | move_num = self.board.move_num
197 |
198 | if self.turns[to_move] == "compute":
199 | move = self.genmove("black" if to_move == BLACK else "white")
200 |
201 | vtx = self.board.last_move
202 | if move == "resign":
203 | vtx = RESIGN
204 | resignd = to_move
205 |
206 | if move == "pass":
207 | self.insert_scroll_text("電腦虛手")
208 |
209 |                 if vtx != PASS and vtx != RESIGN: # Only draw real board moves.
210 | self.update_canvas(vtx, to_move, move_num+1)
211 |
212 | # Dump the search verbose.
213 | if self.args.verbose:
214 | self.insert_scroll_text(self.last_verbose)
215 | self.acquire_vtx = None
216 | else:
217 | if self.acquire_vtx != None:
218 | if self.acquire_vtx == PASS:
219 | self.board.play(PASS)
220 | self.canvas.delete(self.rect)
221 | else:
222 | self.board.play(self.acquire_vtx)
223 | self.update_canvas(self.acquire_vtx, to_move, move_num+1)
224 | self.acquire_vtx = None
225 |
226 | if resignd != None:
227 | if resignd == BLACK:
228 | self.insert_scroll_text("黑棋投降")
229 | else:
230 | self.insert_scroll_text("白棋投降")
231 | resignd = None
232 | self.game_over = True
233 | self.network.clear_cache()
234 | elif self.board.num_passes >= 2:
235 | score = self.board.final_score()
236 | if abs(score) <= 0.01:
237 | self.insert_scroll_text("和局")
238 | elif score > 0:
239 | self.insert_scroll_text("黑勝{}目".format(score))
240 | elif score < 0:
241 | self.insert_scroll_text("白勝{}目".format(-score))
242 | self.game_over = True
243 | self.network.clear_cache()
244 |
245 | def update_canvas(self, vtx, to_move, move_num):
246 | # Update the board canvas.
247 |
248 | r = self.board.get_x(vtx)
249 | c = self.board.get_y(vtx)
250 | self.draw_stone(to_move, (r,c))
251 |
252 | if self.board.removed_cnt != 0:
253 | curr = len(self.board.history) - 1
254 | post_state = self.board.history[curr-1]
255 | for v in range(len(post_state)):
256 | if self.board.state[v] == EMPTY and post_state[v] != EMPTY:
257 | self.canvas.delete(self.oval_buffer[self.board.vertex_to_index(v)])
258 | self.canvas.delete(self.text_buffer[self.board.vertex_to_index(v)])
259 |
260 | def scan_move(self, event):
261 | # Acquire a move after the player click the board.
262 |
263 | x, y = event.x, event.y
264 | r, c = self.convert_xy_to_rc((x, y))
265 |
266 | if r < 0 or r >= self.board.board_size:
267 | return
268 |
269 | if c < 0 or c >= self.board.board_size:
270 | return
271 |
272 | self.acquire_move(self.board.get_vertex(r,c))
273 |
274 | def acquire_move(self, vtx):
275 | # Set acquire move if the move is legal.
276 |
277 | if self.board.legal(vtx):
278 | self.acquire_vtx = vtx
279 |
--------------------------------------------------------------------------------
/img/alphago_zero_mcts.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CGLemon/pyDLGO/731792f82d9079aa033a84419627d7127db5f349/img/alphago_zero_mcts.jpg
--------------------------------------------------------------------------------
/img/dlgo_vs_leela.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CGLemon/pyDLGO/731792f82d9079aa033a84419627d7127db5f349/img/dlgo_vs_leela.gif
--------------------------------------------------------------------------------
/img/loss.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CGLemon/pyDLGO/731792f82d9079aa033a84419627d7127db5f349/img/loss.gif
--------------------------------------------------------------------------------
/img/loss_plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CGLemon/pyDLGO/731792f82d9079aa033a84419627d7127db5f349/img/loss_plot.png
--------------------------------------------------------------------------------
/img/mcts.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CGLemon/pyDLGO/731792f82d9079aa033a84419627d7127db5f349/img/mcts.png
--------------------------------------------------------------------------------
/img/overfitting.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CGLemon/pyDLGO/731792f82d9079aa033a84419627d7127db5f349/img/overfitting.png
--------------------------------------------------------------------------------
/img/policy_value.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CGLemon/pyDLGO/731792f82d9079aa033a84419627d7127db5f349/img/policy_value.gif
--------------------------------------------------------------------------------
/img/puct.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CGLemon/pyDLGO/731792f82d9079aa033a84419627d7127db5f349/img/puct.gif
--------------------------------------------------------------------------------
/img/sabaki-analysis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CGLemon/pyDLGO/731792f82d9079aa033a84419627d7127db5f349/img/sabaki-analysis.png
--------------------------------------------------------------------------------
/img/score_board.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CGLemon/pyDLGO/731792f82d9079aa033a84419627d7127db5f349/img/score_board.png
--------------------------------------------------------------------------------
/img/screenshot_sabaki_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CGLemon/pyDLGO/731792f82d9079aa033a84419627d7127db5f349/img/screenshot_sabaki_01.png
--------------------------------------------------------------------------------
/img/screenshot_sabaki_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CGLemon/pyDLGO/731792f82d9079aa033a84419627d7127db5f349/img/screenshot_sabaki_02.png
--------------------------------------------------------------------------------
/img/screenshot_sabaki_03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CGLemon/pyDLGO/731792f82d9079aa033a84419627d7127db5f349/img/screenshot_sabaki_03.png
--------------------------------------------------------------------------------
/img/shortcut.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CGLemon/pyDLGO/731792f82d9079aa033a84419627d7127db5f349/img/shortcut.png
--------------------------------------------------------------------------------
/img/ucb.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CGLemon/pyDLGO/731792f82d9079aa033a84419627d7127db5f349/img/ucb.gif
--------------------------------------------------------------------------------
/mcts.py:
--------------------------------------------------------------------------------
1 | from board import Board, PASS, RESIGN, BLACK, WHITE
2 | from network import Network
3 | from time_control import TimeControl
4 |
5 | from sys import stderr, stdout, stdin
6 | import math
7 | import time
8 | import select
9 |
10 | class Node:
11 |     C_PUCT = 0.5 # The PUCT hyperparameter. This value should be 1.25 in
12 |                  # AlphaGo Zero. However our value range is 0 ~ 1, not -1 ~ 1,
13 |                  # so we rescale it to 0.5 (Leela Zero uses this value).
14 | def __init__(self, p):
15 | self.policy = p # The network raw policy from its parents node.
16 | self.nn_eval = 0 # The network raw eval from this node.
17 |
18 | self.values = 0 # The accumulation winrate.
19 | self.visits = 0 # The accumulation node visits.
20 | # The Q value must be equal to (self.values / self.visits)
21 | self.children = dict() # Next node.
22 |
23 | def clamp(self, v):
24 |         # Map the winrate from -1 ~ 1 to 0 ~ 1.
25 | return (v + 1) / 2
26 |
27 | def inverse(self, v):
28 | # Swap the side to move winrate.
29 | return 1 - v
30 |
31 | def expand_children(self, board: Board, network: Network):
32 | if board.last_move == PASS:
33 | score = board.final_score()
34 | if (board.to_move == BLACK and score > 0) or \
35 | (board.to_move == WHITE and score < 0):
36 | # Play pass move if we win the game.
37 | self.children[PASS] = Node(1.0)
38 |                 return 1
39 |
40 | # Compute the net results.
41 | policy, value = network.get_outputs(board.get_features())
42 |
43 | for idx in range(board.num_intersections):
44 | vtx = board.index_to_vertex(idx)
45 |
46 |             # Only add the legal moves.
47 | if board.legal(vtx):
48 | p = policy[idx]
49 | self.children[vtx] = Node(p)
50 |
51 |         # The pass move is always a legal move. We don't need to
52 | # check it.
53 | self.children[PASS] = Node(policy[board.num_intersections])
54 |
55 | # The nn eval is side-to-move winrate.
56 | self.nn_eval = self.clamp(value[0])
57 |
58 | return self.nn_eval
59 |
60 | def remove_superko(self, board: Board):
61 | # Remove all superko moves.
62 |
63 | remove_list = list()
64 | for vtx, _ in self.children.items():
65 | if vtx != PASS:
66 | next_board = board.copy()
67 | next_board.play(vtx)
68 | if next_board.superko():
69 | remove_list.append(vtx)
70 | for vtx in remove_list:
71 | self.children.pop(vtx)
72 |
73 | def puct_select(self):
74 |         parent_visits = max(self.visits, 1) # The parent visits must be at least 1 because we want to pick the
75 |                                             # best policy value on the first selection.
76 | numerator = math.sqrt(parent_visits)
77 | puct_list = list()
78 |
79 | # Select the best node by PUCT algorithm.
80 | for vtx, child in self.children.items():
81 | q_value = 0 # init to lose
82 |
83 | if child.visits != 0:
84 | q_value = self.inverse(child.values / child.visits)
85 |
86 | puct = q_value + self.C_PUCT * child.policy * (numerator / (1+child.visits))
87 | puct_list.append((puct, vtx))
88 | return max(puct_list)[1]
89 |
90 | def update(self, v):
91 | self.values += v
92 | self.visits += 1
93 |
94 | def get_best_prob_move(self):
95 | gather_list = list()
96 | for vtx, child in self.children.items():
97 | gather_list.append((child.policy, vtx))
98 | return max(gather_list)[1]
99 |
100 | def get_best_move(self, resign_threshold):
101 | # Return best probability move if there are no playouts.
102 | if self.visits == 1:
103 | if resign_threshold is not None and \
104 | self.values < resign_threshold:
105 | return RESIGN
106 | else:
107 | return self.get_best_prob_move()
108 |
109 | # Get best move by number of node visits.
110 | gather_list = list()
111 | for vtx, child in self.children.items():
112 | gather_list.append((child.visits, vtx))
113 |
114 | vtx = max(gather_list)[1]
115 | child = self.children[vtx]
116 |
117 |         # Play the resign move if we think we have already lost.
118 | if resign_threshold is not None and \
119 | self.inverse(child.values / child.visits) < resign_threshold:
120 | return RESIGN
121 | return vtx
122 |
123 | def to_string(self, board: Board):
124 | # Collect some node information in order to debug.
125 |
126 | out = str()
127 | out += "Root -> W: {:5.2f}%, V: {}\n".format(
128 | 100.0 * self.values/self.visits,
129 | self.visits)
130 |
131 | gather_list = list()
132 | for vtx, child in self.children.items():
133 | gather_list.append((child.visits, vtx))
134 | gather_list.sort(reverse=True)
135 |
136 | for _, vtx in gather_list:
137 | child = self.children[vtx]
138 | if child.visits != 0:
139 | out += " {:4} -> W: {:5.2f}%, P: {:5.2f}%, V: {}\n".format(
140 | board.vertex_to_text(vtx),
141 | 100.0 * self.inverse(child.values/child.visits),
142 | 100.0 * child.policy,
143 | child.visits)
144 | return out
145 |
146 | def get_pv(self, board: Board, pv_str):
147 | # Get the best Principal Variation list since this
148 | # node.
149 | if len(self.children) == 0:
150 | return pv_str
151 |
152 | next_vtx = self.get_best_move(None)
153 | next = self.children[next_vtx]
154 | pv_str += "{} ".format(board.vertex_to_text(next_vtx))
155 | return next.get_pv(board, pv_str)
156 |
157 | def to_lz_analysis(self, board: Board):
158 | # Output the leela zero analysis string. Watch the detail
159 | # here: https://github.com/SabakiHQ/Sabaki/blob/master/docs/guides/engine-analysis-integration.md
160 | out = str()
161 |
162 | gather_list = list()
163 | for vtx, child in self.children.items():
164 | gather_list.append((child.visits, vtx))
165 | gather_list.sort(reverse=True)
166 |
167 | if len(gather_list) == 0:
168 | return str()
169 |
170 | i = 0
171 | for _, vtx in gather_list:
172 | child = self.children[vtx]
173 | if child.visits != 0:
174 | winrate = self.inverse(child.values/child.visits)
175 | prior = child.policy
176 | lcb = winrate
177 | order = i
178 | pv = "{} ".format(board.vertex_to_text(vtx))
179 | out += "info move {} visits {} winrate {} prior {} lcb {} order {} pv {}".format(
180 | board.vertex_to_text(vtx),
181 | child.visits,
182 | round(10000 * winrate),
183 | round(10000 * prior),
184 | round(10000 * lcb),
185 | order,
186 | child.get_pv(board, pv))
187 | i+=1
188 | out += '\n'
189 | return out
190 |
191 |
192 |
193 | # TODO: The MCTS performance is bad. Maybe the recursion is much
194 | #       slower than a loop, or self.children does too many mapping
195 | #       operations. Try to fix it.
196 | class Search:
197 | def __init__(self, board: Board, network: Network, time_control: TimeControl):
198 | self.root_board = board # Root board positions, all simulation boards will fork from it.
199 | self.root_node = None # Root node, start the PUCT search from it.
200 | self.network = network
201 | self.time_control = time_control
202 | self.analysis_tag = {
203 | "interval" : -1
204 | }
205 |
206 | def _prepare_root_node(self):
207 | # Expand the root node first.
208 | self.root_node = Node(1)
209 | val = self.root_node.expand_children(self.root_board, self.network)
210 |
211 | # In order to avoid overhead, we only remove the superko positions in
212 | # the root.
213 | self.root_node.remove_superko(self.root_board)
214 | self.root_node.update(val)
215 |
216 | def _descend(self, color, curr_board, node):
217 | value = None
218 | if curr_board.num_passes >= 2:
219 | # The game is over. Compute the final score.
220 | score = curr_board.final_score()
221 | if score > 1e-4:
222 | # The black player is winner.
223 | value = 1 if color is BLACK else 0
224 | elif score < -1e-4:
225 | # The white player is winner.
226 | value = 1 if color is WHITE else 0
227 | else:
228 | # The game is draw
229 | value = 0.5
230 | elif len(node.children) != 0:
231 | # Select the next node by PUCT algorithm.
232 | vtx = node.puct_select()
233 | curr_board.to_move = color
234 | curr_board.play(vtx)
235 | color = (color + 1) % 2
236 | next_node = node.children[vtx]
237 |
238 | # go to the next node.
239 | value = self._descend(color, curr_board, next_node)
240 | else:
241 |             # This is a terminal node. Now try to expand it.
242 | value = node.expand_children(curr_board, self.network)
243 |
244 | assert value != None, ""
245 | node.update(value)
246 |
247 | return node.inverse(value)
248 |
249 | def ponder(self, playouts, verbose):
250 | if self.root_board.num_passes >= 2:
251 | return str()
252 |
253 | analysis_clock = time.time()
254 | interval = self.analysis_tag["interval"]
255 |
256 | # Try to expand the root node first.
257 | self._prepare_root_node()
258 |
259 | for p in range(playouts):
260 | if p != 0 and \
261 | interval > 0 and \
262 | time.time() - analysis_clock > interval:
263 | analysis_clock = time.time()
264 | stdout.write(self.root_node.to_lz_analysis(self.root_board))
265 | stdout.flush()
266 |
267 | rlist, _, _ = select.select([stdin], [], [], 0)
268 | if rlist:
269 | break
270 |
271 | # Copy the root board because we need to simulate the current board.
272 | curr_board = self.root_board.copy()
273 | color = curr_board.to_move
274 |
275 | # Start the Monte Carlo tree search.
276 | self._descend(color, curr_board, self.root_node)
277 |
278 | # Always dump last tree stats for GUI, like Sabaki.
279 | if interval > 0 and \
280 | self.root_node.visits > 1:
281 | stdout.write(self.root_node.to_lz_analysis(self.root_board))
282 | stdout.flush()
283 |
284 | out_verbose = self.root_node.to_string(self.root_board)
285 | if verbose:
286 | # Dump verbose to stderr because we want to debug it on GTP
287 | # interface(sabaki).
288 | stderr.write(out_verbose)
289 | stderr.write("\n")
290 | stderr.flush()
291 |
292 | return out_verbose
293 |
294 | def think(self, playouts, resign_threshold, verbose):
295 |         # Get the best move with Monte Carlo tree search. The time controller and max playouts limit
296 |         # the search. More thinking time or more playouts makes it stronger.
297 |
298 | if self.root_board.num_passes >= 2:
299 | return PASS, str()
300 |
301 | analysis_clock = time.time()
302 | interval = self.analysis_tag["interval"]
303 | self.time_control.clock()
304 | if verbose:
305 | stderr.write(str(self.time_control))
306 | stderr.write("\n")
307 | stderr.flush()
308 |
309 | # Prepare some basic information.
310 | to_move = self.root_board.to_move
311 | bsize = self.root_board.board_size
312 | move_num = self.root_board.move_num
313 |
314 | # Compute thinking time limit.
315 | max_time = self.time_control.get_thinking_time(to_move, bsize, move_num)
316 |
317 | # Try to expand the root node first.
318 | self._prepare_root_node()
319 |
320 | for p in range(playouts):
321 | if p != 0 and \
322 | interval > 0 and \
323 | time.time() - analysis_clock > interval:
324 | analysis_clock = time.time()
325 | stdout.write(self.root_node.to_lz_analysis(self.root_board))
326 | stdout.flush()
327 |
328 | if self.time_control.should_stop(max_time):
329 | break
330 |
331 | # Copy the root board because we need to simulate the current board.
332 | curr_board = self.root_board.copy()
333 | color = curr_board.to_move
334 |
335 | # Start the Monte Carlo tree search.
336 | self._descend(color, curr_board, self.root_node)
337 |
338 | # Eat the remaining time.
339 | self.time_control.took_time(to_move)
340 |
341 | # Always dump last tree stats for GUI, like Sabaki.
342 | if interval > 0 and \
343 | self.root_node.visits > 1:
344 | stdout.write(self.root_node.to_lz_analysis(self.root_board))
345 | stdout.flush()
346 |
347 | out_verbose = self.root_node.to_string(self.root_board)
348 | if verbose:
349 | # Dump verbose to stderr because we want to debug it on GTP
350 | # interface(sabaki).
351 | stderr.write(out_verbose)
352 | stderr.write(str(self.time_control))
353 | stderr.write("\n")
354 | stderr.flush()
355 |
356 | return self.root_node.get_best_move(resign_threshold), out_verbose
357 |
--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from config import *
6 |
7 | class FullyConnect(nn.Module):
8 | def __init__(self, in_size,
9 | out_size,
10 | relu=True):
11 | super().__init__()
12 | self.relu = relu
13 | self.linear = nn.Linear(in_size, out_size)
14 |
15 | def forward(self, x):
16 | x = self.linear(x)
17 | return F.relu(x, inplace=True) if self.relu else x
18 |
19 | class ConvBlock(nn.Module):
20 | def __init__(self, in_channels,
21 | out_channels,
22 | kernel_size,
23 | relu=True):
24 | super().__init__()
25 |
26 | assert kernel_size in (1, 3)
27 | self.relu = relu
28 | self.conv = nn.Conv2d(
29 | in_channels,
30 | out_channels,
31 | kernel_size,
32 | padding="same",
33 | bias=True,
34 | )
35 | self.bn = nn.BatchNorm2d(
36 | out_channels,
37 | eps=1e-5
38 | )
39 |
40 | nn.init.kaiming_normal_(self.conv.weight,
41 | mode="fan_out",
42 | nonlinearity="relu")
43 | def forward(self, x):
44 | x = self.conv(x)
45 | x = self.bn(x)
46 | return F.relu(x, inplace=True) if self.relu else x
47 |
48 | class ResBlock(nn.Module):
49 | def __init__(self, channels, se_size=None):
50 | super().__init__()
51 | self.with_se=False
52 | self.channels=channels
53 |
54 | self.conv1 = ConvBlock(
55 | in_channels=channels,
56 | out_channels=channels,
57 | kernel_size=3
58 | )
59 | self.conv2 = ConvBlock(
60 | in_channels=channels,
61 | out_channels=channels,
62 | kernel_size=3,
63 | relu=False
64 | )
65 |
66 | if se_size != None:
67 | self.with_se = True
68 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
69 | self.squeeze = FullyConnect(
70 | in_size=channels,
71 | out_size=se_size,
72 | relu=True
73 | )
74 | self.excite = FullyConnect(
75 | in_size=se_size,
76 | out_size=2 * channels,
77 | relu=False
78 | )
79 |
80 | def forward(self, x):
81 | identity = x
82 |
83 | out = self.conv1(x)
84 | out = self.conv2(out)
85 |
86 | if self.with_se:
87 | b, c, _, _ = out.size()
88 | seprocess = self.avg_pool(out)
89 | seprocess = torch.flatten(seprocess, start_dim=1, end_dim=3)
90 | seprocess = self.squeeze(seprocess)
91 | seprocess = self.excite(seprocess)
92 |
93 | gammas, betas = torch.split(seprocess, self.channels, dim=1)
94 | gammas = torch.reshape(gammas, (b, c, 1, 1))
95 | betas = torch.reshape(betas, (b, c, 1, 1))
96 | out = torch.sigmoid(gammas) * out + betas
97 |
98 | out += identity
99 |
100 | return F.relu(out, inplace=True)
101 |
102 |
103 | class Network(nn.Module):
104 | def __init__(self, board_size,
105 | input_channels=INPUT_CHANNELS,
106 | block_size=BLOCK_SIZE,
107 | block_channels=BLOCK_CHANNELS,
108 | policy_channels=POLICY_CHANNELS,
109 | value_channels=VALUE_CHANNELS,
110 | use_se=USE_SE,
111 | use_gpu=USE_GPU):
112 | super().__init__()
113 |
114 | self.nn_cache = {}
115 |
116 | self.block_size = block_size
117 | self.residual_channels = block_channels
118 | self.policy_channels = policy_channels
119 | self.value_channels = value_channels
120 | self.value_layers = 256
121 | self.board_size = board_size
122 | self.spatial_size = self.board_size ** 2
123 | self.input_channels = input_channels
124 | self.use_se = use_se
125 | self.use_gpu = True if torch.cuda.is_available() and use_gpu else False
126 | self.gpu_device = torch.device("cpu")
127 |
128 | if self.use_se:
129 |             assert self.residual_channels % 2 == 0, "BLOCK_CHANNELS must be divisible by 2."
130 |
131 | self.construct_layers()
132 | if self.use_gpu:
133 | self.gpu_device = torch.device("cuda")
134 | self.to_gpu_device()
135 |
136 | def to_gpu_device(self):
137 |         self.to(self.gpu_device)
138 |
139 | def construct_layers(self):
140 | self.input_conv = ConvBlock(
141 | in_channels=self.input_channels,
142 | out_channels=self.residual_channels,
143 | kernel_size=3,
144 | relu=True
145 | )
146 |
147 | # residual tower
148 | self.residual_tower = nn.ModuleList()
149 | for s in range(self.block_size):
150 | se_size = self.residual_channels // 2 if self.use_se else None
151 | self.residual_tower.append(
152 | ResBlock(self.residual_channels, se_size))
153 |
154 | # policy head
155 | self.policy_conv = ConvBlock(
156 | in_channels=self.residual_channels,
157 | out_channels=self.policy_channels,
158 | kernel_size=1,
159 | relu=True
160 | )
161 | self.policy_fc = FullyConnect(
162 | in_size=self.policy_channels * self.spatial_size,
163 | out_size=self.spatial_size + 1,
164 | relu=False
165 | )
166 |
167 | # value head
168 | self.value_conv = ConvBlock(
169 | in_channels=self.residual_channels,
170 | out_channels=self.value_channels,
171 | kernel_size=1,
172 | relu=True
173 | )
174 |
175 | self.value_fc = FullyConnect(
176 | in_size=self.value_channels * self.spatial_size,
177 | out_size=self.value_layers,
178 | relu=True
179 | )
180 | self.winrate_fc = FullyConnect(
181 | in_size=self.value_layers,
182 | out_size=1,
183 | relu=False
184 | )
185 |
186 | def forward(self, planes):
187 | x = self.input_conv(planes)
188 |
189 | # residual tower
190 | for block in self.residual_tower:
191 | x = block(x)
192 |
193 | # policy head
194 | pol = self.policy_conv(x)
195 | pol = self.policy_fc(torch.flatten(pol, start_dim=1, end_dim=3))
196 |
197 | # value head
198 | val = self.value_conv(x)
199 | val = self.value_fc(torch.flatten(val, start_dim=1, end_dim=3))
200 | val = self.winrate_fc(val)
201 |
202 | return pol, torch.tanh(val)
203 |
204 | @torch.no_grad()
205 | def get_outputs(self, planes):
206 | # TODO: Limit the NN cache size.
207 |
208 |         h = hash(planes.tobytes())  # tostring() is removed in newer NumPy versions
209 | res = self.nn_cache.get(h) # search the NN computation
210 |
211 | if res is not None:
212 | p, v = res
213 | return p, v
214 |
215 | m = nn.Softmax(dim=1)
216 | x = torch.unsqueeze(torch.tensor(planes, dtype=torch.float32), dim=0)
217 | if self.use_gpu:
218 | x = x.to(self.gpu_device)
219 | p, v = self.forward(x)
220 | p, v = m(p).data.tolist()[0], v.data.tolist()[0]
221 |
222 | self.nn_cache[h] = (p, v) # save the NN computation
223 |
224 | return p, v
225 |
226 | def clear_cache(self):
227 | self.nn_cache.clear()
228 |
229 | def trainable(self, t=True):
230 | torch.set_grad_enabled(t)
231 |         if t:
232 | self.train()
233 | else:
234 | self.eval()
235 |
236 | def save_pt(self, filename):
237 | torch.save(self.state_dict(), filename)
238 |
239 | def load_pt(self, filename):
240 | self.load_state_dict(
241 | torch.load(filename, map_location=self.gpu_device, weights_only=True)
242 | )
243 |
--------------------------------------------------------------------------------
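
A minimal usage sketch of the Network class above, assuming the constants from config.py and an input shaped like the feature planes returned by board.get_features(); the weights file name is illustrative only, not a fixed part of the API.

    import numpy as np
    from config import BOARD_SIZE, INPUT_CHANNELS
    from network import Network

    net = Network(BOARD_SIZE)
    net.trainable(False)            # eval mode, gradients disabled
    net.load_pt("nn_2x64.pt")       # illustrative weights file; its architecture must match config.py

    # All-zero planes just to show the call; real planes come from board.get_features().
    planes = np.zeros((INPUT_CHANNELS, BOARD_SIZE, BOARD_SIZE), dtype=np.float32)
    prob, value = net.get_outputs(planes)
    print(len(prob))                # BOARD_SIZE * BOARD_SIZE + 1 move probabilities, including pass
    print(value[0])                 # predicted result for the side to move, in the range -1 ~ 1
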
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | torchaudio
4 | numpy
5 | tk
6 | matplotlib
7 |
--------------------------------------------------------------------------------
/sgf.py:
--------------------------------------------------------------------------------
1 | import glob, os, argparse
2 | from board import BLACK, WHITE, EMPTY, INVLD
3 |
4 | class SgfParser:
5 | def __init__(self, sgf_string):
6 | self.history = list()
7 | self.black_player = str()
8 | self.white_player = str()
9 | self.winner = INVLD
10 | self.board_size = None
11 | self.komi = None
12 | self._parse(sgf_string)
13 |
14 | def _process_key_value(self, key, val):
15 | def as_move(m, bsize=self.board_size):
16 | if len(m) == 0 or m == "tt":
17 | return None
18 | x = ord(m[0]) - ord('a')
19 | y = ord(m[1]) - ord('a')
20 | y = bsize - 1 - y
21 | return (x, y)
22 |
23 | if key == "SZ":
24 | self.board_size = int(val)
25 | elif key == "KM":
26 | self.komi = float(val)
27 | elif key == "B":
28 | self.history.append((BLACK, as_move(val)))
29 | elif key == "W":
30 | self.history.append((WHITE, as_move(val)))
31 | elif key == "PB":
32 | self.black_player = val
33 | elif key == "PW":
34 | self.white_player = val
35 | elif key == "AB" or key == "AW":
36 |             raise Exception("The AB/AW tags are not supported in SGF files.")
37 | elif key == "RE":
38 | if "B+" in val:
39 | self.winner = BLACK
40 | elif "W+" in val:
41 | self.winner = WHITE
42 | elif val == "0":
43 | self.winner = EMPTY
44 | else:
45 | self.winner = INVLD
46 |
47 | def _parse(self, sgf):
48 | nesting = 0
49 | idx = 0
50 | node_cnt = 0
51 | key = str()
52 | while idx < len(sgf):
53 | c = sgf[idx]
54 |             idx += 1
55 |
56 | if c == '(':
57 | nesting += 1
58 | elif c == ')':
59 | nesting -= 1
60 |
61 | if c in ['(', ')', '\t', '\n', '\r'] or nesting != 1:
62 | continue
63 | elif c == ';':
64 | node_cnt += 1
65 | elif c == '[':
66 | end = sgf.find(']', idx)
67 | val = sgf[idx:end]
68 | self._process_key_value(key, val)
69 | key = str()
70 | idx = end+1
71 | else:
72 | key += c
73 |
74 | def _load_file(filename):
75 | try:
76 | with open(filename, "r") as f:
77 | data = f.read().strip()
78 | except Exception as e:
79 | print(e)
80 | return None
81 | return data
82 |
83 | def chop_sgfs_string(sgfs_string):
84 | sgfs_list = list()
85 | sgfs_string = sgfs_string.strip()
86 |
87 | nesting = 0
88 | head_idx = 0
89 | tail_idx = 0
90 | while tail_idx < len(sgfs_string):
91 | c = sgfs_string[tail_idx]
92 |         tail_idx += 1
93 |
94 | if c == '(':
95 | if nesting == 0:
96 | head_idx = tail_idx - 1
97 | nesting += 1
98 | elif c == ')':
99 | nesting -= 1
100 | if nesting == 0:
101 | sgfs_list.append(sgfs_string[head_idx:tail_idx])
102 |
103 | if c in ['(', ')', ';', '\t', '\n', '\r'] or nesting != 1:
104 | continue
105 | elif c == '[':
106 | end = sgfs_string.find(']', tail_idx)
107 | tail_idx = end + 1
108 | return sgfs_list
109 |
110 | def parse_from_dir(root):
111 | sgfs_files = list()
112 | sgfs_files.extend(glob.glob(os.path.join(root, "*.sgf")))
113 | sgfs_files.extend(glob.glob(os.path.join(root, "*.sgfs")))
114 | sgfs = list()
115 | for filename in sgfs_files:
116 | data = _load_file(filename)
117 | if data:
118 | sgfs_list = chop_sgfs_string(data)
119 | for sgf_string in sgfs_list:
120 | try:
121 | sgf = SgfParser(sgf_string)
122 | sgfs.append(sgf)
123 | except Exception as e:
124 | print(e)
125 | return sgfs
126 |
127 | if __name__ == "__main__":
128 | parser = argparse.ArgumentParser()
129 | parser.add_argument("-d", "--sgf-dir", metavar="",
130 | help="input SGF directory", type=str)
131 | args = parser.parse_args()
132 |
133 | try:
134 |         sgfs = parse_from_dir(args.sgf_dir)
135 |         print("\nSuccessfully parsed every SGF string...")
136 | except Exception as e:
137 | print(e)
138 |
--------------------------------------------------------------------------------
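
A short sketch of parsing a single game record with SgfParser; the SGF string below is made up purely for illustration.

    from sgf import SgfParser
    from board import BLACK

    sgf_string = "(;GM[1]SZ[9]KM[7]PB[black]PW[white]RE[B+R];B[ee];W[gg])"
    game = SgfParser(sgf_string)

    print(game.board_size)          # 9
    print(game.komi)                # 7.0
    print(game.winner == BLACK)     # True, because RE contains "B+"
    print(game.history[0])          # the first move as a (color, (x, y)) tuple
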
/sgf.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CGLemon/pyDLGO/731792f82d9079aa033a84419627d7127db5f349/sgf.zip
--------------------------------------------------------------------------------
/time_control.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | class TimeControl:
4 | def __init__(self):
5 | self.main_time = 0
6 | self.byo_time = 7 * 24 * 60 * 60 # one week per move
7 | self.byo_stones = 1
8 |
9 | self.maintime_left = [0, 0]
10 | self.byotime_left = [0, 0]
11 | self.stones_left = [0, 0]
12 | self.in_byo = [False, False]
13 |
14 | self.clock_time = time.time()
15 | self.reset()
16 |
17 | def check_in_byo(self):
18 |         self.in_byo[0] = self.maintime_left[0] <= 0
19 |         self.in_byo[1] = self.maintime_left[1] <= 0
20 |
21 | def reset(self):
22 | self.maintime_left = [self.main_time] * 2
23 | self.byotime_left = [self.byo_time] * 2
24 | self.stones_left = [self.byo_stones] * 2
25 | self.check_in_byo()
26 |
27 | def time_settings(self, main_time, byo_time, byo_stones):
28 | self.main_time = main_time
29 | self.byo_time = byo_time
30 | self.byo_stones = byo_stones
31 | self.reset()
32 |
33 | def time_left(self, color, time, stones):
34 | if stones == 0:
35 | self.maintime_left[color] = time
36 | else:
37 | self.maintime_left[color] = 0
38 | self.byotime_left[color] = time
39 | self.stones_left[color] = stones
40 | self.check_in_byo()
41 |
42 | def clock(self):
43 | self.clock_time = time.time()
44 |
45 | def took_time(self, color):
46 | remaining_took_time = time.time() - self.clock_time
47 | if not self.in_byo[color]:
48 | if self.maintime_left[color] > remaining_took_time:
49 | self.maintime_left[color] -= remaining_took_time
50 | remaining_took_time = -1
51 | else:
52 | remaining_took_time -= self.maintime_left[color]
53 | self.maintime_left[color] = 0
54 | self.in_byo[color] = True
55 |
56 | if self.in_byo[color] and remaining_took_time > 0:
57 | self.byotime_left[color] -= remaining_took_time
58 | self.stones_left[color] -= 1
59 | if self.stones_left[color] == 0:
60 | self.stones_left[color] = self.byo_stones
61 | self.byotime_left[color] = self.byo_time
62 |
63 | def get_thinking_time(self, color, board_size, move_num):
64 | estimate_moves_left = max(4, int(board_size * board_size * 0.4) - move_num)
65 |         lag_buffer = 1 # Reserve some time for network hiccups or GUI lag.
66 | remaining_time = self.maintime_left[color] + self.byotime_left[color] - lag_buffer
67 | if self.byo_stones == 0:
68 | return remaining_time / estimate_moves_left
69 | return remaining_time / self.stones_left[color]
70 |
71 | def should_stop(self, max_time):
72 | elapsed = time.time() - self.clock_time
73 | return elapsed > max_time
74 |
75 | def get_timeleft_string(self, color):
76 | out = str()
77 | if not self.in_byo[color]:
78 | out += "{s} sec".format(
79 | s=int(self.maintime_left[color]))
80 | else:
81 | out += "{s} sec, {c} stones".format(
82 | s=int(self.byotime_left[color]),
83 | c=self.stones_left[color])
84 | return out
85 |
86 | def __str__(self):
87 | return "".join(["Black: ",
88 | self.get_timeleft_string(0),
89 | " | White: ",
90 | self.get_timeleft_string(1)])
91 |
--------------------------------------------------------------------------------
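
A rough sketch of how a single move is timed with TimeControl (in practice gtp.py and the search loop drive these calls); the time settings below are only an example.

    import time
    from time_control import TimeControl

    tc = TimeControl()
    tc.time_settings(main_time=300, byo_time=30, byo_stones=1)  # 5 min main time, 30 sec per stone

    color = 0                                # 0 = Black, 1 = White, as in get_timeleft_string()
    budget = tc.get_thinking_time(color, board_size=9, move_num=10)

    tc.clock()                               # start the clock for this move
    time.sleep(0.5)                          # ...the tree search would run here...
    if tc.should_stop(budget):
        pass                                 # the search loop would break at this point
    tc.took_time(color)                      # charge the elapsed time to Black's clock
    print(tc)                                # e.g. "Black: 299 sec | White: 300 sec"
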
/train.py:
--------------------------------------------------------------------------------
1 | from network import Network
2 | from config import BOARD_SIZE, INPUT_CHANNELS
3 | from board import Board, PASS, BLACK, WHITE, EMPTY, INVLD, NUM_INTESECTIONS
4 |
5 | import sgf, argparse
6 | import copy, time, os, shutil, glob
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.optim as optim
13 |
14 | CACHE_TRAIN_DIR = "tdata-cache"
15 | CACHE_VALID_DIR = "vdata-cache"
16 |
17 | def gather_filenames(dirname):
18 | def gather_recursive_files(root):
19 | l = list()
20 | for name in glob.glob(os.path.join(root, "*")):
21 | if os.path.isdir(name):
22 | l.extend(gather_recursive_files(name))
23 | else:
24 | l.append(name)
25 | return l
26 | return gather_recursive_files(root=dirname)
27 |
28 | def get_currtime():
29 | lt = time.localtime(time.time())
30 | return "{y}-{m}-{d} {h:02d}:{mi:02d}:{s:02d}".format(
31 | y=lt.tm_year, m=lt.tm_mon, d=lt.tm_mday, h=lt.tm_hour, mi=lt.tm_min, s=lt.tm_sec)
32 |
33 | def get_weights_name(prefix):
34 | return "{}-{}.pt".format(prefix, get_currtime().replace(":", "-").replace(" ", "-"))
35 |
36 | class Data:
37 | def __init__(self):
38 | self.inputs = None # should be numpy array, shape is [INPUT_CHANNELS, BOARD_SIZE, BOARD_SIZE]
39 | self.policy = None # should be integer, range is 0 ~ NUM_INTESECTIONS
40 | self.value = None # should be float, range is -1 ~ 1
41 | self.to_move = None
42 |
43 | def _get_symmetry_plane(self, symm, plane):
44 | use_flip = False
45 | if symm // 4 != 0:
46 | use_flip = True
47 | symm = symm % 4
48 |
49 | transformed = np.rot90(plane, symm)
50 |
51 | if use_flip:
52 | transformed = np.flip(transformed, 1)
53 | return transformed
54 |
55 | def do_symmetry(self, symm=None):
56 |         assert self.policy is not None, "policy target is not set"
57 |
58 | if symm is None:
59 | symm = int(np.random.choice(8, 1)[0])
60 |
61 |         for i in range(INPUT_CHANNELS-2): # The last two channels encode the side to move.
62 | p = self.inputs[i]
63 | self.inputs[i][:][:] = self._get_symmetry_plane(symm, p)[:][:]
64 |
65 | if self.policy != NUM_INTESECTIONS:
66 | buf = np.zeros(NUM_INTESECTIONS)
67 | buf[self.policy] = 1
68 | buf = self._get_symmetry_plane(symm, np.reshape(buf, (BOARD_SIZE, BOARD_SIZE)))
69 | self.policy = int(np.argmax(buf))
70 |
71 | def from_npfile(self, filename):
72 | npdata = np.load(filename)
73 | self.inputs = npdata["i"]
74 | self.policy = npdata["p"][0]
75 | self.value = npdata["v"][0]
76 | self.to_move = npdata["t"][0]
77 |
78 | class Dataset(torch.utils.data.Dataset):
79 | def __init__(self, source_dir, num_virtual_samples=None):
80 | self.filenames = gather_filenames(source_dir)
81 | self.num_virtual_samples = num_virtual_samples \
82 | if num_virtual_samples is not None else len(self.filenames)
83 |
84 | def __len__(self):
85 | return self.num_virtual_samples
86 |
87 | def __getitem__(self, i):
88 | current_idx = i % len(self.filenames)
89 | data = Data()
90 | data.from_npfile(self.filenames[current_idx])
91 | data.do_symmetry()
92 |
93 | inputs = torch.tensor(data.inputs).float()
94 | policy = torch.tensor(data.policy).long()
95 | value = torch.tensor([data.value]).float()
96 | return inputs, policy, value
97 |
98 | # Load the SGF files and save the training data to the disk.
99 | class DataChopper:
100 | def __init__(self, dir_name, num_sgfs):
101 | self.num_data = 0
102 | self._chop_data(dir_name, num_sgfs)
103 |
104 | def __del__(self):
105 | # Do not delete the training data in the cache dir. We may
106 | # use them next time.
107 | pass
108 |
109 | def _chop_data(self, dir_name, num_sgfs):
110 |         # Load the SGF files and convert them to training data.
111 | sgf_games = sgf.parse_from_dir(dir_name)
112 | total_games = min(len(sgf_games), num_sgfs)
113 |
114 |         print("imported {} SGF games".format(total_games))
115 |
116 | if os.path.isdir(CACHE_TRAIN_DIR):
117 | shutil.rmtree(CACHE_TRAIN_DIR, ignore_errors=True)
118 | os.makedirs(CACHE_TRAIN_DIR)
119 |
120 | if os.path.isdir(CACHE_VALID_DIR):
121 | shutil.rmtree(CACHE_VALID_DIR, ignore_errors=True)
122 | os.makedirs(CACHE_VALID_DIR)
123 |
124 | for s in range(total_games):
125 | game = sgf_games[s]
126 | buf = self._process_one_game(game)
127 |
128 | if (s+1) % (max(1, total_games//100)) == 0:
129 | print("parsed {:.2f}% games".format(100 * (s+1)/total_games))
130 | self._save_data(buf)
131 | print("done! parsed {:.2f}% games".format(100))
132 |
133 | def _save_data(self, buf):
134 | size = len(buf)
135 |
136 | for i in range(size):
137 | # Allocate data buffer
138 | inputs_buf = np.zeros((INPUT_CHANNELS, BOARD_SIZE, BOARD_SIZE), dtype=np.int8)
139 | policy_buf = np.zeros((1), dtype=np.int32)
140 | value_buf = np.zeros((1), dtype=np.float32)
141 | to_move_buf = np.zeros((1), dtype=np.int8)
142 |
143 | # Fill the data buffer.
144 | data = buf[i]
145 | inputs_buf[:] = data.inputs[:]
146 | policy_buf[:] = data.policy
147 | value_buf[:] = data.value
148 | to_move_buf[:] = data.to_move
149 |
150 |             # Save the data to disk.
151 | use_valid = int(np.random.choice(10, 1)[0]) == 0
152 | if use_valid:
153 | filename = os.path.join(CACHE_VALID_DIR, "data_{}.npz".format(self.num_data))
154 | else:
155 | filename = os.path.join(CACHE_TRAIN_DIR, "data_{}.npz".format(self.num_data))
156 | np.savez_compressed(filename, i=inputs_buf, p=policy_buf, v=value_buf, t=to_move_buf)
157 | self.num_data += 1
158 |
159 | def _process_one_game(self, game):
160 | # Collect training data from one SGF game.
161 |
162 |         if game.board_size != BOARD_SIZE:
163 | return list()
164 |
165 | temp = list()
166 | winner = game.winner
167 | board = Board(BOARD_SIZE)
168 |
169 | for color, move in game.history:
170 | data = Data()
171 | data.inputs = board.get_features()
172 | data.to_move = color
173 | if move:
174 | x, y = move
175 | data.policy = board.get_index(x, y)
176 | board.play(board.get_vertex(x, y))
177 | else:
178 | data.policy = board.num_intersections
179 | board.play(PASS)
180 | temp.append(data)
181 |
182 | for data in temp:
183 | if winner == EMPTY:
184 | data.value = 0
185 | elif winner == data.to_move:
186 | data.value = 1
187 | elif winner != data.to_move:
188 | data.value = -1
189 | return temp
190 |
191 | def plot_loss(record):
192 | if len(record) <= 1:
193 | return
194 |
195 | p_running_loss = []
196 | v_running_loss = []
197 | step = []
198 | for (s, p, v) in record:
199 | p_running_loss.append(p)
200 | v_running_loss.append(v)
201 | step.append(s)
202 |
203 | y_upper = max(max(p_running_loss), max(v_running_loss))
204 |
205 | plt.plot(step, p_running_loss, label="policy loss")
206 | plt.plot(step, v_running_loss, label="value loss")
207 | plt.ylabel("loss")
208 | plt.xlabel("steps")
209 | plt.ylim([0, y_upper * 1.1])
210 | plt.legend()
211 | plt.show()
212 |
213 | def load_checkpoint(network, optimizer, workspace):
214 | filenames = gather_filenames(workspace)
215 | if len(filenames) == 0:
216 | return network, optimizer, 0
217 |
218 | filenames.sort(key=os.path.getmtime, reverse=True)
219 | last_pt = filenames[0]
220 |
221 | state_dict = torch.load(last_pt, map_location=network.gpu_device, weights_only=True)
222 | network.load_state_dict(state_dict["network"])
223 | optimizer.load_state_dict(state_dict["optimizer"])
224 | steps = state_dict["steps"]
225 | return network, optimizer, steps
226 |
227 | def save_checkpoint(network, optimizer, steps, workspace):
228 | state_dict = dict()
229 | state_dict["network"] = network.state_dict()
230 | state_dict["optimizer"] = optimizer.state_dict()
231 | state_dict["steps"] = steps
232 | torch.save(state_dict, os.path.join(workspace, "checkpoint-s{}.pt".format(steps)))
233 |
234 | def training_process(args):
235 |     # Set up the network. It will be moved to the GPU later if one is
236 |     # available.
237 | network = Network(BOARD_SIZE)
238 | network.trainable(True)
239 |
240 |     # Use SGD instead of Adam; in practice SGD seemed to perform
241 |     # better than Adam here.
242 | optimizer = optim.SGD(network.parameters(),
243 | lr=args.learning_rate,
244 | momentum=0.9,
245 | nesterov=True,
246 | weight_decay=1e-3)
247 | if not os.path.isdir(args.workspace):
248 | os.mkdir(args.workspace)
249 | network, optimizer, steps = load_checkpoint(network, optimizer, args.workspace)
250 | cross_entry = nn.CrossEntropyLoss()
251 | mse_loss = nn.MSELoss()
252 |
253 | if args.dir is not None:
254 | data_chopper = DataChopper(
255 | args.dir,
256 | args.imported_games
257 | )
258 |
259 |     # Leave two cores free for the training pipeline.
260 | num_workers = max(min(os.cpu_count(), 16) - 2 , 1) \
261 | if args.num_workers is None else max(args.num_workers, 1)
262 |
263 | print("Use {n} workers for loader.".format(n=num_workers))
264 |
265 | data_loader = torch.utils.data.DataLoader(
266 | dataset=Dataset(CACHE_TRAIN_DIR, args.batch_size * args.steps),
267 | batch_size=args.batch_size,
268 | shuffle=True,
269 | num_workers=num_workers
270 | )
271 |
272 |     print("Start training...")
273 |
274 | # init some basic parameters
275 | p_running_loss = 0
276 | v_running_loss = 0
277 | max_steps = steps + args.steps
278 | running_loss_record = []
279 | clock_time = time.time()
280 |
281 | for _, batch in enumerate(data_loader):
282 | if args.lr_decay_steps is not None:
283 | learning_rate = optimizer.param_groups[0]["lr"]
284 | if (steps+1) % args.lr_decay_steps == 0:
285 | print("Drop the learning rate from {} to {}.".format(
286 | learning_rate,
287 | learning_rate * args.lr_decay_factor
288 | ))
289 | learning_rate = learning_rate * args.lr_decay_factor
290 | for param in optimizer.param_groups:
291 | param["lr"] = learning_rate
292 |
293 | # First, get the batch data.
294 | inputs, target_p, target_v = batch
295 |
296 | # Second, Move the data to GPU memory if we use it.
297 | if network.use_gpu:
298 | inputs = inputs.to(network.gpu_device)
299 | target_p = target_p.to(network.gpu_device)
300 | target_v = target_v.to(network.gpu_device)
301 |
302 | # Third, compute the network result.
303 | p, v = network(inputs)
304 |
305 | # Fourth, compute the loss result and update network.
306 | p_loss = cross_entry(p, target_p)
307 | v_loss = mse_loss(v, target_v)
308 | loss = p_loss + args.value_loss_scale * v_loss
309 |
310 | optimizer.zero_grad()
311 | loss.backward()
312 | optimizer.step()
313 |
314 | # Accumulate running loss.
315 | p_running_loss += p_loss.item()
316 | v_running_loss += v_loss.item()
317 |
318 | # Fifth, dump training verbose.
319 | if (steps+1) % args.verbose_steps == 0:
320 | elapsed = time.time() - clock_time
321 | rate = args.verbose_steps/elapsed
322 | remaining_steps = max_steps - steps
323 | estimate_remaining_time = int(remaining_steps/rate)
324 | print("[{}] steps: {}/{}, {:.2f}% -> policy loss: {:.4f}, value loss: {:.4f} | rate: {:.2f}(steps/sec), estimate: {}(sec)".format(
325 | get_currtime(),
326 | steps+1,
327 | max_steps,
328 | 100 * ((steps+1)/max_steps),
329 | p_running_loss/args.verbose_steps,
330 | v_running_loss/args.verbose_steps,
331 | rate,
332 | estimate_remaining_time))
333 | running_loss_record.append(
334 | (steps+1, p_running_loss/args.verbose_steps, v_running_loss/args.verbose_steps))
335 | p_running_loss = 0
336 | v_running_loss = 0
337 | save_checkpoint(network, optimizer, steps+1, args.workspace)
338 | clock_time = time.time()
339 | steps += 1
340 |
341 |     print("Training is over.")
342 | if not args.noplot:
343 |         # Sixth, plot the running-loss graph.
344 | plot_loss(running_loss_record)
345 | network.save_pt(get_weights_name("weights"))
346 |
347 | if __name__ == "__main__":
348 | parser = argparse.ArgumentParser()
349 | parser.add_argument("-d", "--dir", metavar="",
350 | help="The input SGF files directory. Will use data cache if set None.", type=str)
351 | parser.add_argument("-s", "--steps", metavar="",
352 | help="Terminate after these steps for each run.", type=int, required=True)
353 | parser.add_argument("-v", "--verbose-steps", metavar="",
354 | help="Dump verbose and save checkpoint every X steps.", type=int, default=1000)
355 | parser.add_argument("-b", "--batch-size", metavar="",
356 | help="The batch size number.", type=int, required=True)
357 | parser.add_argument("-l", "--learning-rate", metavar="",
358 | help="The learning rate.", type=float, required=True)
359 | parser.add_argument("-w", "--workspace", metavar="", default="workspace",
360 | help="Will save the checkpoint here.", type=str)
361 | parser.add_argument("-i", "--imported-games", metavar="",
362 | help="The max number of imported games.", type=int, default=10240000)
363 | parser.add_argument("--noplot", action="store_true",
364 | help="Disable plotting.", default=False)
365 | parser.add_argument("--lr-decay-steps", metavar="",
366 | help="Reduce the learning rate every X steps.", type=int, default=None)
367 | parser.add_argument("--lr-decay-factor", metavar="",
368 | help="The learning rate decay multiple factor.", type=float, default=0.1)
369 | parser.add_argument("--value-loss-scale", metavar="",
370 | help="Scaling factor of value loss. Default is 0.25 based on AlphaGo paper.", type=float, default=0.25)
371 | parser.add_argument("--num-workers", metavar="",
372 |                         help="The number of worker processes for the DataLoader.", type=int, default=None)
373 |
374 | args = parser.parse_args()
375 | training_process(args)
376 |
--------------------------------------------------------------------------------
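
A typical training invocation; the directory name and hyper-parameters here are only examples, not recommended settings.

    $ python3 train.py --dir sgf --steps 128000 --batch-size 512 --learning-rate 0.01 --lr-decay-steps 64000

Running the script again without --dir reuses the cached samples in tdata-cache, and the checkpoints saved in the workspace let an interrupted run resume from the last saved step.
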
/validate.py:
--------------------------------------------------------------------------------
1 | from network import Network
2 | from config import BOARD_SIZE
3 | from board import NUM_INTESECTIONS
4 |
5 | import argparse
6 | import torch
7 | import torch.nn as nn
8 | from train import CACHE_VALID_DIR, DataChopper, Dataset
9 |
10 | def report_stats(total, total_samples, correct_policy, total_value_loss):
11 | policy_acc = correct_policy / total if total > 0 else 0
12 | value_loss = total_value_loss / total if total > 0 else 0
13 | print(f"[{total}/{total_samples}] Policy Acc: {100 * policy_acc:.2f}% | Value MSE: {value_loss:.4f}")
14 |
15 | @torch.no_grad()
16 | def validate(args):
17 | # Prepare the validation dataset.
18 | if args.dir is not None:
19 | DataChopper(args.dir, args.imported_games)
20 |
21 | dataset = Dataset(CACHE_VALID_DIR)
22 | dataloader = torch.utils.data.DataLoader(
23 | dataset, batch_size=args.batch_size, shuffle=False)
24 |
25 | # Load the model.
26 | network = Network(BOARD_SIZE)
27 | network.trainable(False)
28 | if args.weights is not None:
29 | network.load_pt(args.weights)
30 | else:
31 | raise ValueError("Please specify --weights")
32 |
33 | # Validation loop.
34 | total = 0
35 | correct_policy = 0
36 | total_value_loss = 0.0
37 | mse_loss = nn.MSELoss(reduction='sum')
38 |
39 | total_samples = len(dataset)
40 |
41 | for idx, (inputs, policy, value) in enumerate(dataloader):
42 | inputs = inputs.to(network.gpu_device)
43 | policy = policy.to(network.gpu_device)
44 | value = value.to(network.gpu_device)
45 | pred_policy, pred_value = network(inputs)
46 | # policy: take the index of the maximum value
47 | pred_policy_idx = torch.argmax(pred_policy, dim=1)
48 | correct_policy += (pred_policy_idx == policy).sum().item()
49 | # value: MSE
50 | total_value_loss += mse_loss(pred_value.squeeze(), value.squeeze()).item()
51 | total += inputs.size(0)
52 |
53 | if idx % 10 == 0:
54 | report_stats(total, total_samples, correct_policy, total_value_loss)
55 | report_stats(total, total_samples, correct_policy, total_value_loss)
56 |
57 | if __name__ == "__main__":
58 | parser = argparse.ArgumentParser()
59 | parser.add_argument("-d", "--dir", metavar="",
60 |                         help="The input SGF directory. The existing data cache is reused if this is not set.", type=str, default=None)
61 | parser.add_argument("-w", "--weights", metavar="",
62 | help="The weights file name.", type=str, required=True)
63 | parser.add_argument("-b", "--batch-size", metavar="",
64 | help="The batch size number.", type=int, default=256)
65 | parser.add_argument("-i", "--imported-games", metavar="",
66 | help="The max number of imported games.", type=int, default=10240000)
67 | args = parser.parse_args()
68 | validate(args)
--------------------------------------------------------------------------------
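
A matching validation sketch; the weights file name is illustrative, and the architecture defined in config.py must match the file being loaded.

    $ python3 validate.py --weights nn_2x64.pt --dir sgf

Omitting --dir reuses the validation samples already cached in vdata-cache by an earlier train.py or validate.py run.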