├── img ├── play1.PNG ├── play2.PNG ├── play3.PNG ├── play4.PNG └── play5.PNG ├── cchess ├── test │ ├── test.PGN │ ├── test.xqf │ ├── test1.xqf │ ├── test2.xqf │ ├── EmptyTest.xqf │ ├── UnitTest.xqf │ ├── WildHouse.xqf │ ├── pawn_move.xqf │ ├── eleeye │ │ ├── BOOK.DAT │ │ └── ELEEYE.EXE │ ├── ucci_test1.xqf │ ├── ucci_test2.xqf │ ├── ucci_test3.xqf │ ├── BadMoveTest1.xqf │ ├── BadMoveTest2.xqf │ ├── BadMoveTest3.xqf │ ├── BadMoveTest4.xqf │ ├── FiveGoatsTest.xqf │ ├── test.cbf │ └── test2.cbf ├── doc │ ├── XqfFormat.txt │ └── cchess_move1.gif ├── __pycache__ │ ├── board.cpython-35.pyc │ ├── game.cpython-35.pyc │ ├── move.cpython-35.pyc │ ├── piece.cpython-35.pyc │ ├── ucci.cpython-35.pyc │ ├── __init__.cpython-35.pyc │ ├── exception.cpython-35.pyc │ ├── reader_cbf.cpython-35.pyc │ ├── reader_pgn.cpython-35.pyc │ ├── reader_xqf.cpython-35.pyc │ └── reader_dhtml.cpython-35.pyc ├── exception.py ├── __init__.py ├── reader_cbf.py ├── reader_pgn.py ├── game.py ├── reader_dhtml.py ├── move.py ├── ucci.py ├── board.py ├── reader_xqf.py └── piece.py ├── README.md ├── gameplay.py ├── game_convert.py ├── utils.py ├── get_data.ipynb ├── process_data.ipynb ├── chess_value_baseline.ipynb └── chess_policy_baseline.ipynb /img/play1.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/img/play1.PNG -------------------------------------------------------------------------------- /img/play2.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/img/play2.PNG -------------------------------------------------------------------------------- /img/play3.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/img/play3.PNG -------------------------------------------------------------------------------- 
/img/play4.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/img/play4.PNG -------------------------------------------------------------------------------- /img/play5.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/img/play5.PNG -------------------------------------------------------------------------------- /cchess/test/test.PGN: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/test/test.PGN -------------------------------------------------------------------------------- /cchess/test/test.xqf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/test/test.xqf -------------------------------------------------------------------------------- /cchess/test/test1.xqf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/test/test1.xqf -------------------------------------------------------------------------------- /cchess/test/test2.xqf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/test/test2.xqf -------------------------------------------------------------------------------- /cchess/doc/XqfFormat.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/doc/XqfFormat.txt -------------------------------------------------------------------------------- /cchess/test/EmptyTest.xqf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/test/EmptyTest.xqf -------------------------------------------------------------------------------- /cchess/test/UnitTest.xqf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/test/UnitTest.xqf -------------------------------------------------------------------------------- /cchess/test/WildHouse.xqf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/test/WildHouse.xqf -------------------------------------------------------------------------------- /cchess/test/pawn_move.xqf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/test/pawn_move.xqf -------------------------------------------------------------------------------- /cchess/doc/cchess_move1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/doc/cchess_move1.gif -------------------------------------------------------------------------------- /cchess/test/eleeye/BOOK.DAT: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/test/eleeye/BOOK.DAT -------------------------------------------------------------------------------- /cchess/test/ucci_test1.xqf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/test/ucci_test1.xqf -------------------------------------------------------------------------------- /cchess/test/ucci_test2.xqf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/test/ucci_test2.xqf -------------------------------------------------------------------------------- /cchess/test/ucci_test3.xqf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/test/ucci_test3.xqf -------------------------------------------------------------------------------- /cchess/test/BadMoveTest1.xqf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/test/BadMoveTest1.xqf -------------------------------------------------------------------------------- /cchess/test/BadMoveTest2.xqf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/test/BadMoveTest2.xqf -------------------------------------------------------------------------------- /cchess/test/BadMoveTest3.xqf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/test/BadMoveTest3.xqf -------------------------------------------------------------------------------- /cchess/test/BadMoveTest4.xqf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/test/BadMoveTest4.xqf -------------------------------------------------------------------------------- /cchess/test/FiveGoatsTest.xqf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/test/FiveGoatsTest.xqf -------------------------------------------------------------------------------- /cchess/test/eleeye/ELEEYE.EXE: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/test/eleeye/ELEEYE.EXE -------------------------------------------------------------------------------- /cchess/__pycache__/board.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/__pycache__/board.cpython-35.pyc -------------------------------------------------------------------------------- /cchess/__pycache__/game.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/__pycache__/game.cpython-35.pyc -------------------------------------------------------------------------------- /cchess/__pycache__/move.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/__pycache__/move.cpython-35.pyc -------------------------------------------------------------------------------- /cchess/__pycache__/piece.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/__pycache__/piece.cpython-35.pyc -------------------------------------------------------------------------------- /cchess/__pycache__/ucci.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/__pycache__/ucci.cpython-35.pyc -------------------------------------------------------------------------------- /cchess/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/__pycache__/__init__.cpython-35.pyc 
-------------------------------------------------------------------------------- /cchess/__pycache__/exception.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/__pycache__/exception.cpython-35.pyc -------------------------------------------------------------------------------- /cchess/__pycache__/reader_cbf.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/__pycache__/reader_cbf.cpython-35.pyc -------------------------------------------------------------------------------- /cchess/__pycache__/reader_pgn.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/__pycache__/reader_pgn.cpython-35.pyc -------------------------------------------------------------------------------- /cchess/__pycache__/reader_xqf.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/__pycache__/reader_xqf.cpython-35.pyc -------------------------------------------------------------------------------- /cchess/__pycache__/reader_dhtml.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bupticybee/icyElephant/HEAD/cchess/__pycache__/reader_dhtml.cpython-35.pyc -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Icy Elephant 2 | Like Go, Chinese chess is an ancient board game. After I trained a CNN model to play Go ( https://github.com/bupticybee/icygo ), I wondered whether a similar approach could be used to play Chinese chess.
3 | 4 | So I trained a CNN to do the same thing in Chinese chess. I played against it: although the network produces some interesting results, it is not strong enough. In the 10 games I played against the neural network, I won all of them by a large margin. 5 | 6 | Here are the first five moves of a CNN self-play game: 7 | 8 | ![](./img/play1.PNG) 9 | ![](./img/play2.PNG) 10 | ![](./img/play3.PNG) 11 | ![](./img/play4.PNG) 12 | ![](./img/play5.PNG) 13 | 14 | 15 | The data I used can be downloaded here: 16 | https://pan.baidu.com/s/1JCmweZUREJxjIMXQlL-o9g 17 | 18 | Follow the code in chess_policy_resnet10.ipynb to train the model. 19 | After you have trained your model, follow the code in play_against_computer.ipynb to play against the model you trained. 20 | 21 | Have fun! -------------------------------------------------------------------------------- /cchess/exception.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | ''' 4 | Copyright (C) 2014 walker li 5 | 6 | This program is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | (at your option) any later version. 10 | 11 | This program is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with this program. If not, see .
18 | ''' 19 | 20 | class CChessException(): 21 | def __init__(self, reason): 22 | self.reason = reason 23 | 24 | #-----------------------------------------------------# 25 | if __name__ == '__main__': 26 | pass 27 | 28 | -------------------------------------------------------------------------------- /cchess/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | ''' 4 | Copyright (C) 2014 walker li 5 | 6 | This program is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | (at your option) any later version. 10 | 11 | This program is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with this program. If not, see . 
18 | ''' 19 | 20 | from cchess.piece import * 21 | from cchess.board import * 22 | from cchess.game import * 23 | from cchess.move import * 24 | from cchess.ucci import * 25 | from cchess.reader_xqf import read_from_xqf 26 | from cchess.reader_cbf import read_from_cbf 27 | from cchess.reader_pgn import read_from_pgn 28 | from cchess.reader_dhtml import read_from_dhtml 29 | from cchess.exception import * -------------------------------------------------------------------------------- /gameplay.py: -------------------------------------------------------------------------------- 1 | import xmltodict 2 | from cchess import * 3 | class GamePlay: 4 | def __init__(self): 5 | self.bb = BaseChessBoard(FULL_INIT_FEN) 6 | self.red = True 7 | 8 | def get_side(self): 9 | return "red" if self.red else "black" 10 | 11 | def make_move(self,move): 12 | i = move 13 | x1,y1,x2,y2 = int(i[0]),int(i[1]),int(i[3]),int(i[4]) 14 | #boardarr = bb.get_board_arr() 15 | if self.red: 16 | moveresult = self.bb.move(Pos(x1,y1),Pos(x2,y2)) 17 | else: 18 | moveresult = self.bb.move(Pos(x1,9-y1),Pos(x2,9-y2)) 19 | assert(moveresult != None) 20 | self.red = not self.red 21 | 22 | def print_board(self): 23 | self.bb.print_board() 24 | 25 | def get_board_arr(self): 26 | feature_list = {"red":['A', 'B', 'C', 'K', 'N', 'P', 'R'] 27 | ,"black":['a', 'b', 'c', 'k', 'n', 'p', 'r']} 28 | # chess picker features 29 | picker_x = [] 30 | picker_y = [] 31 | boardarr = self.bb.get_board_arr() 32 | if self.red: 33 | for one in feature_list['red']: 34 | picker_x.append(np.asarray(boardarr == one,dtype=np.uint8)) 35 | for one in feature_list['black']: 36 | picker_x.append(np.asarray(boardarr == one,dtype=np.uint8)) 37 | else: 38 | for one in feature_list['black']: 39 | picker_x.append(np.asarray(boardarr == one,dtype=np.uint8)) 40 | for one in feature_list['red']: 41 | picker_x.append(np.asarray(boardarr == one,dtype=np.uint8)) 42 | picker_x = np.asarray(picker_x) 43 | if self.red: 44 | return picker_x 45 | else: 
46 | return picker_x[:,::-1,:] -------------------------------------------------------------------------------- /cchess/reader_cbf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | ''' 4 | Copyright (C) 2014 walker li 5 | 6 | This program is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | (at your option) any later version. 10 | 11 | This program is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with this program. If not, see . 18 | ''' 19 | 20 | import os 21 | 22 | from xml.etree import ElementTree as et 23 | 24 | from cchess.board import * 25 | from cchess.game import * 26 | from cchess.exception import * 27 | 28 | #-----------------------------------------------------# 29 | 30 | def read_from_cbf(file_name): 31 | 32 | def decode_move(move_str): 33 | p_from = Pos(int(move_str[0]), 9 - int(move_str[1])) 34 | p_to = Pos(int(move_str[3]), 9 - int(move_str[4])) 35 | 36 | return (p_from, p_to) 37 | 38 | tree = et.parse(file_name) 39 | root = tree.getroot() 40 | 41 | head = root.find("Head") 42 | for node in head.getchildren() : 43 | if node.tag == "FEN": 44 | init_fen = node.text 45 | #print node.tag 46 | 47 | books = {} 48 | board = BaseChessBoard(init_fen) 49 | 50 | move_list = root.find("MoveList").getchildren() 51 | 52 | game = Game(board) 53 | last_move = game 54 | step_no = 1 55 | for node in move_list[1:] : 56 | move_from, move_to = decode_move(node.attrib["value"]) 57 | if board.is_valid_move(move_from, move_to) : 58 | new_move = board.move(move_from, move_to) 59 | 
last_move.append_next_move(new_move) 60 | last_move = new_move 61 | board.next_turn() 62 | else : 63 | raise CChessException("bad move at %d %s %s" % (step_no, move_from, move_to)) 64 | step_no += 1 65 | return game 66 | 67 | #-----------------------------------------------------# 68 | 69 | if __name__ == '__main__': 70 | game = read_from_cbf('test\\test.cbf') 71 | game.print_init_board() 72 | game.print_chinese_moves() 73 | 74 | -------------------------------------------------------------------------------- /cchess/test/test.cbf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 第一局 屏风马巡河炮抵当头炮局(和) 5 | 6 | 第一局 屏风马巡河炮抵当头炮局(和).xqf 7 | 8 | 居荣鑫《象棋古谱全局三种》 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 居荣鑫 26 | 27 | 28 | 29 | 30 | 2014-09-08 20:19:09 31 | C03 32 | 1 33 | 34 | 0 35 | 36 | rnbakabnr/9/1c5c1/p1p1p1p1p/9/9/P1P1P1P1P/1C5C1/9/RNBAKABNR w - - 0 1 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | -------------------------------------------------------------------------------- /game_convert.py: -------------------------------------------------------------------------------- 1 | import xmltodict 2 | from cchess import * 3 | def convert_game(onefile,feature_list): 4 | doc = xmltodict.parse(open(onefile,encoding='utf-8').read()) 5 | fen = doc['ChineseChessRecord']["Head"]["FEN"] 6 | pgnfile = doc['ChineseChessRecord']["Head"]["From"] 7 | moves = [i["@value"] for i in doc['ChineseChessRecord']['MoveList']["Move"] if i["@value"] != '00-00'] 8 | bb = BaseChessBoard(fen) 9 | red = False 10 | for i in moves: 11 | red = not red 12 | x1,y1,x2,y2 = int(i[0]),int(i[1]),int(i[3]),int(i[4]) 13 | #print("{} {}".format(i,"红" if red else "黑")) 14 | 15 | boardarr = 
bb.get_board_arr() 16 | 17 | # chess picker features 18 | picker_x = [] 19 | picker_y = [] 20 | if red: 21 | for one in feature_list['red']: 22 | picker_x.append(np.asarray(boardarr == one,dtype=np.uint8)) 23 | for one in feature_list['black']: 24 | picker_x.append(np.asarray(boardarr == one,dtype=np.uint8)) 25 | else: 26 | for one in feature_list['black']: 27 | picker_x.append(np.asarray(boardarr == one,dtype=np.uint8)) 28 | for one in feature_list['red']: 29 | picker_x.append(np.asarray(boardarr == one,dtype=np.uint8)) 30 | picker_x = np.asarray(picker_x) 31 | target = np.zeros((10,9)) 32 | target[y1,x1] = 1 33 | picker_y = target 34 | 35 | # chess mover features 36 | mover_x = [] 37 | mover_y = [] 38 | mover_x = np.concatenate((picker_x,target.reshape((1,10,9)))) 39 | mover_y = np.zeros((10,9)) 40 | mover_y[y2,x2] = 1 41 | if red: 42 | yield picker_x,picker_y,mover_x,mover_y 43 | else: 44 | yield picker_x[:,::-1,:],picker_y[::-1,:],mover_x[:,::-1,:],mover_y[::-1,:] 45 | moveresult = bb.move(Pos(x1,y1),Pos(x2,y2)) 46 | assert(moveresult != None) 47 | 48 | def convert_value(onefile,feature_list): 49 | doc = xmltodict.parse(open(onefile,encoding='utf-8').read()) 50 | fen = doc['ChineseChessRecord']["Head"]["FEN"] 51 | pgnfile = doc['ChineseChessRecord']["Head"]["From"] 52 | moves = [i["@value"] for i in doc['ChineseChessRecord']['MoveList']["Move"] if i["@value"] != '00-00'] 53 | bb = BaseChessBoard(fen) 54 | red = False 55 | for i in moves: 56 | red = not red 57 | x1,y1,x2,y2 = int(i[0]),int(i[1]),int(i[3]),int(i[4]) 58 | #print("{} {}".format(i,"红" if red else "黑")) 59 | 60 | boardarr = bb.get_board_arr() 61 | 62 | # chess picker features 63 | picker_x = [] 64 | picker_y = [] 65 | if red: 66 | for one in feature_list['red']: 67 | picker_x.append(np.asarray(boardarr == one,dtype=np.uint8)) 68 | for one in feature_list['black']: 69 | picker_x.append(np.asarray(boardarr == one,dtype=np.uint8)) 70 | else: 71 | for one in feature_list['black']: 72 | 
picker_x.append(np.asarray(boardarr == one,dtype=np.uint8)) 73 | for one in feature_list['red']: 74 | picker_x.append(np.asarray(boardarr == one,dtype=np.uint8)) 75 | picker_x = np.asarray(picker_x) 76 | 77 | 78 | picker_y = target 79 | 80 | 81 | if red: 82 | yield picker_x, 83 | else: 84 | yield picker_x[:,::-1,:], 85 | assert(moveresult != None) -------------------------------------------------------------------------------- /cchess/reader_pgn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | ''' 4 | Copyright (C) 2014 walker li 5 | 6 | This program is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | (at your option) any later version. 10 | 11 | This program is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with this program. If not, see . 
18 | ''' 19 | 20 | import os 21 | 22 | from cchess.board import * 23 | from cchess.exception import * 24 | 25 | #-----------------------------------------------------# 26 | def read_from_pgn(file_name): 27 | with open(file_name) as file: 28 | flines = file.readlines() 29 | 30 | lines = [] 31 | for line in flines : 32 | it = line.strip() #TODO, fix it in linux 33 | 34 | if len(it) == 0: 35 | continue 36 | 37 | lines.append(it) 38 | 39 | lines = __get_headers(lines) 40 | lines, docs = __get_comments(lines) 41 | #infos["Doc"] = docs 42 | __get_steps(lines) 43 | 44 | def __get_headers(lines): 45 | 46 | index = 0 47 | for line in lines: 48 | 49 | if line[0] != "[" : 50 | return lines[index:] 51 | 52 | if line[-1] != "]": 53 | raise CChessException("Format Error on line %" %(index + 1)) 54 | 55 | items = line[1:-1].split("\"") 56 | 57 | if len(items) < 3: 58 | raise CChessException("Format Error on line %" %(index + 1)) 59 | 60 | #self.infos[str(items[0]).strip()] = items[1].strip() 61 | 62 | index += 1 63 | 64 | def __get_comments(lines): 65 | 66 | if lines[0][0] != "{" : 67 | return (lines, None) 68 | 69 | docs = lines[0][1:] 70 | 71 | #处理一注释行的情况 72 | if docs[-1] == "}": 73 | return (lines[1:], docs[:-1].strip()) 74 | 75 | #处理多行注释的情况 76 | index = 1 77 | 78 | for line in lines[1:]: 79 | if line[-1] == "}": 80 | docs = docs + "\n" + line[:-1] 81 | return (lines[index+1:], docs.strip()) 82 | 83 | docs = docs + "\n" + line 84 | index += 1 85 | 86 | #代码能运行到这里,就是出了异常了 87 | raise CChessException("Comments not closed") 88 | 89 | def __get_token(token_mode, lines): 90 | pass 91 | 92 | def __get_steps(lines, next_step = 1): 93 | 94 | for line in lines : 95 | if line in["*", "1-0","0-1", "1/2-1/2"]: 96 | return 97 | 98 | print (line) 99 | items = line.split(".") 100 | 101 | if(len(items) < 2): 102 | continue 103 | raise Exception("format error") 104 | 105 | steps = items[1].strip().split(" ") 106 | print (steps) 107 | 108 | 109 | 
#-----------------------------------------------------# 110 | 111 | if __name__ == '__main__': 112 | pass 113 | -------------------------------------------------------------------------------- /cchess/test/test2.cbf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 第十局 当头炮抵当头炮局(和) 5 | 6 | 第十局 当头炮抵当头炮局(和).xqf 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 2014-09-08 22:22:20 31 | D04 32 | 1 33 | 34 | 3 35 | 36 | rnbakabnr/9/1c5c1/p1p1p1p1p/9/9/P1P1P1P1P/1C5C1/9/RNBAKABNR w - - 0 1 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /cchess/game.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | ''' 4 | Copyright (C) 2014 walker li 5 | 6 | This program is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | (at your option) any later version. 10 | 11 | This program is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with this program. If not, see . 
18 | ''' 19 | 20 | from cchess.board import * 21 | from cchess.move import * 22 | 23 | 24 | # 比赛结果 25 | UNKNOWN, RED_WIN, BLACK_WIN, PEACE = range(4) 26 | result_str = (u"未知", u"红胜", u"黑胜", u"平局" ) 27 | 28 | #存储类型 29 | BOOK_UNKNOWN, BOOK_ALL, BOOK_BEGIN, BOOK_MIDDLE, BOOK_END = range(5) 30 | book_type_str = (u"未知", u"全局", u"开局", u"中局", u"残局") 31 | 32 | 33 | #-----------------------------------------------------# 34 | class Game(object): 35 | def __init__(self, board = None, annotation = None): 36 | self.init_board = board.copy() 37 | self.annotation = annotation 38 | self.next_move = None 39 | self.infos = {} 40 | 41 | def append_next_move(self, chess_move): 42 | chess_move.parent = self 43 | if not self.next_move: 44 | self.next_move = chess_move 45 | else: 46 | #找最右一个 47 | move = self.next_move 48 | while move.right_move: 49 | move = move.right_move 50 | move.right_move = chess_move 51 | 52 | def verify_moves(self): 53 | return True 54 | move_list = self.dump_moves() 55 | for move_line in move_list: 56 | j = 0 57 | for move in move_line: 58 | if not move.is_valid_move(): 59 | #print moves_to_chinese(self.init_fen, move_line[:j]) 60 | #print j, move, move_line 61 | return False 62 | j += 1 63 | return True 64 | 65 | def dump_init_board(self): 66 | return self.init_board.dump_board() 67 | 68 | def dump_moves(self): 69 | 70 | if not self.next_move: 71 | return [] 72 | 73 | move_list = [] 74 | curr_move = [] 75 | move_list.append(curr_move) 76 | 77 | self.next_move.dump_moves(move_list, curr_move) 78 | 79 | return move_list 80 | 81 | def dump_std_moves(self): 82 | return [[str(move) for move in move_line] for move_line in self.dump_moves()] 83 | 84 | def dump_chinese_moves(self): 85 | return [[move.to_chinese() for move in move_line] for move_line in self.dump_moves()] 86 | 87 | def print_init_board(self): 88 | self.init_board.print_board() 89 | 90 | def print_chinese_moves(self, steps_per_line=3): 91 | 92 | moves = self.dump_chinese_moves() 93 | line_no = 1 94 | for 
line in moves: 95 | 96 | if len(moves) > 1: 97 | print (u'第%d分支' % line_no) 98 | 99 | i = 0 100 | for it in line: 101 | if (i%2) == 0: 102 | print ('%2d. '%(i/2+1),endl='') 103 | print (it,endl='') 104 | i += 1 105 | if (i%(steps_per_line*2)) == 0: 106 | print() 107 | print 108 | line_no += 1 109 | 110 | def dump_info(self): 111 | for key in self.info: 112 | print (key, self.info[key]) 113 | 114 | #-----------------------------------------------------# 115 | if __name__ == '__main__': 116 | from reader_xqf import * 117 | game = read_from_xqf('test\\ucci_test1.xqf') 118 | game.init_board.move_side = ChessSide.RED 119 | game.print_init_board() 120 | game.print_chinese_moves() 121 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # define dataset class to feed the model 2 | import numpy as np 3 | import os 4 | import sys 5 | import time 6 | 7 | class Dataset(): 8 | def __init__(self,data,label): 9 | self._index_in_epoch = 0 10 | self._epochs_completed = 0 11 | self._data = data 12 | self._label = label 13 | assert(data.shape[0] == label.shape[0]) 14 | self._num_examples = data.shape[0] 15 | pass 16 | 17 | @property 18 | def data(self): 19 | return self._data 20 | 21 | @property 22 | def label(self): 23 | return self._label 24 | 25 | def next_batch(self,batch_size,shuffle = True): 26 | start = self._index_in_epoch 27 | if start == 0 and self._epochs_completed == 0: 28 | idx = np.arange(0, self._num_examples) # get all possible indexes 29 | np.random.shuffle(idx) # shuffle indexe 30 | self._data = self.data[idx] # get list of `num` random samples 31 | self._label = self.label[idx] 32 | 33 | # go to the next batch 34 | if start + batch_size > self._num_examples: 35 | self._epochs_completed += 1 36 | rest_num_examples = self._num_examples - start 37 | data_rest_part = self.data[start:self._num_examples] 38 | label_rest_part = 
self.label[start:self._num_examples] 39 | idx0 = np.arange(0, self._num_examples) # get all possible indexes 40 | np.random.shuffle(idx0) # shuffle indexes 41 | self._data = self.data[idx0] # get list of `num` random samples 42 | self._label = self.label[idx0] 43 | 44 | start = 0 45 | self._index_in_epoch = batch_size - rest_num_examples #avoid the case where the #sample != integar times of batch_size 46 | end = self._index_in_epoch 47 | data_new_part = self._data[start:end] 48 | label_new_part = self._label[start:end] 49 | return np.concatenate((data_rest_part, data_new_part), axis=0),np.concatenate((label_rest_part, label_new_part), axis=0) 50 | else: 51 | self._index_in_epoch += batch_size 52 | end = self._index_in_epoch 53 | return self._data[start:end],self._label[start:end] 54 | 55 | class ProgressBar(): 56 | def __init__(self,worksum,info="",auto_display=True): 57 | self.worksum = worksum 58 | self.info = info 59 | self.finishsum = 0 60 | self.auto_display = auto_display 61 | def startjob(self): 62 | self.begin_time = time.time() 63 | def complete(self,num): 64 | self.gaptime = time.time() - self.begin_time 65 | self.finishsum += num 66 | if self.auto_display == True: 67 | self.display_progress_bar() 68 | def display_progress_bar(self): 69 | percent = self.finishsum * 100 / self.worksum 70 | eta_time = self.gaptime * 100 / (percent + 0.001) - self.gaptime 71 | strprogress = "[" + "=" * int(percent // 2) + ">" + "-" * int(50 - percent // 2) + "]" 72 | str_log = ("%s %.2f %% %s %s/%s \t used:%ds eta:%d s" % (self.info,percent,strprogress,self.finishsum,self.worksum,self.gaptime,eta_time)) 73 | sys.stdout.write('\r' + str_log) 74 | 75 | def get_dataset(paths): 76 | dataset = [] 77 | for path in paths.split(':'): 78 | path_exp = os.path.expanduser(path) 79 | classes = os.listdir(path_exp) 80 | classes.sort() 81 | nrof_classes = len(classes) 82 | for i in range(nrof_classes): 83 | class_name = classes[i] 84 | facedir = os.path.join(path_exp, class_name) 85 | if 
os.path.isdir(facedir): 86 | images = os.listdir(facedir) 87 | image_paths = [os.path.join(facedir,img) for img in images] 88 | dataset.append(ImageClass(class_name, image_paths)) 89 | 90 | return dataset 91 | 92 | class ImageClass(): 93 | "Stores the paths to images for a given class" 94 | def __init__(self, name, image_paths): 95 | self.name = name 96 | self.image_paths = image_paths 97 | 98 | def __str__(self): 99 | return self.name + ', ' + str(len(self.image_paths)) + ' images' 100 | 101 | def __len__(self): 102 | return len(self.image_paths) 103 | 104 | def split_dataset(dataset, split_ratio, mode): 105 | if mode=='SPLIT_CLASSES': 106 | nrof_classes = len(dataset) 107 | class_indices = np.arange(nrof_classes) 108 | np.random.shuffle(class_indices) 109 | split = int(round(nrof_classes*split_ratio)) 110 | train_set = [dataset[i] for i in class_indices[0:split]] 111 | test_set = [dataset[i] for i in class_indices[split:-1]] 112 | elif mode=='SPLIT_IMAGES': 113 | train_set = [] 114 | test_set = [] 115 | min_nrof_images = 2 116 | for cls in dataset: 117 | paths = cls.image_paths 118 | np.random.shuffle(paths) 119 | split = int(round(len(paths)*split_ratio)) 120 | if split 5 | 6 | This program is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | (at your option) any later version. 10 | 11 | This program is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with this program. If not, see . 
18 | ''' 19 | 20 | import os 21 | import struct 22 | 23 | 24 | from bs4 import BeautifulSoup 25 | import xml.etree.cElementTree as et 26 | 27 | from cchess.board import * 28 | from cchess.game import * 29 | from cchess.exception import * 30 | 31 | #-----------------------------------------------------# 32 | def read_from_dhtml(html_page): 33 | res_dict = __parse_dhtml(html_page) 34 | game = read_from_txt(res_dict['moves'], res_dict['init']) 35 | return game 36 | 37 | #-----------------------------------------------------# 38 | def read_from_txt(moves_txt, pos_txt = None): 39 | 40 | def decode_txt_pos(pos) : 41 | return Pos(int(pos[0]), 9-int(pos[1])) 42 | 43 | #车马相士帅士相马车炮炮兵兵兵兵兵 44 | #车马象士将士象马车炮炮卒卒卒卒卒 45 | chessman_kinds = 'RNBAKABNRCCPPPPP' 46 | 47 | if not pos_txt: 48 | board = BaseChessBoard(FULL_INIT_FEN) 49 | else: 50 | if len(pos_txt) != 64: 51 | raise CChessException("bad pos_txt") 52 | 53 | board = BaseChessBoard() 54 | for side in range(2): 55 | for man_index in range(16): 56 | pos_index = (side * 16 + man_index)*2 57 | man_pos = pos_txt[pos_index : pos_index + 2] 58 | if man_pos == '99': 59 | continue 60 | pos = decode_txt_pos(man_pos) 61 | fen_ch = chr(ord(chessman_kinds[man_index]) + side * 32) 62 | board.put_fench(fen_ch, pos) 63 | 64 | last_move = None 65 | if not moves_txt: 66 | return Game(board) 67 | step_no = 0 68 | while step_no*4 < len(moves_txt) : 69 | steps = moves_txt[step_no*4:step_no*4+4] 70 | 71 | move_from = decode_txt_pos(moves_txt[step_no*4:step_no*4+2]) 72 | move_to = decode_txt_pos(moves_txt[step_no*4+2:step_no*4+4]) 73 | 74 | if board.is_valid_move(move_from, move_to) : 75 | 76 | if not last_move: 77 | _, man_side = fench_to_species(board.get_fench(move_from)) 78 | board.move_side = man_side 79 | game = Game(board) 80 | last_move = game 81 | 82 | new_move = board.move(move_from, move_to) 83 | last_move.append_next_move(new_move) 84 | last_move = new_move 85 | board.next_turn() 86 | else : 87 | raise CChessException("bad move at %d %s 
%s" % (step_no, move_from, move_to)) 88 | step_no += 1 89 | if step_no == 0: 90 | game = Game(board) 91 | 92 | return game 93 | 94 | #-----------------------------------------------------# 95 | def __str_between(src, begin_str, end_str) : 96 | first = src.find(begin_str) + len(begin_str) 97 | last = src.find(end_str) 98 | if (first != -1) and (last != -1) : 99 | return src[first:last] 100 | else : 101 | return None 102 | 103 | def __str_between2(src, begin_str, end_str) : 104 | first = src.find(begin_str) + len(begin_str) 105 | last = src.find(end_str) 106 | if last > first: 107 | return src[first:last] 108 | if last == -1: 109 | return None 110 | 111 | src2 = src[last + len(end_str):] 112 | f2 = src2.find(begin_str) + len(begin_str) 113 | l2 = src2.find(end_str) 114 | if l2 > f2: 115 | return src2[f2:l2] 116 | else : 117 | return None 118 | 119 | def __parse_dhtml(html_page) : 120 | result_dict = {} 121 | text = html_page.decode('GB18030') 122 | result_dict['event'] = __str_between(text, '[DhtmlXQ_event]', '[/DhtmlXQ_event]') 123 | if result_dict['event'] : 124 | result_dict['event'] = result_dict['event'] 125 | 126 | result_dict['title'] = __str_between(text, '[DhtmlXQ_title]', '[/DhtmlXQ_title]') 127 | if result_dict['title'] : 128 | result_dict['title'] = result_dict['title'] 129 | 130 | result_dict['result'] = __str_between(text, '[DhtmlXQ_result]', '[/DhtmlXQ_result]') 131 | if result_dict['result'] : 132 | result_dict['result'] = result_dict['result'] 133 | 134 | init = __str_between(text, '[DhtmlXQ_binit]', '[/DhtmlXQ_binit]') 135 | result_dict['init'] = init.encode('utf-8') if init else None 136 | moves = __str_between2(text, '[DhtmlXQ_movelist]', '[/DhtmlXQ_movelist]') 137 | result_dict['moves'] = moves.encode('utf-8') if moves else None 138 | 139 | return result_dict 140 | 141 | #-----------------------------------------------------# 142 | if __name__ == '__main__': 143 | 144 | pos_s = '9999999949399920109981999999993129629999409999997109993847999999' 
145 | # 146 | # move_s = '31414050414050402032' #'77477242796770628979808166658131192710222625120209193136797136267121624117132324191724256755251547431516212226225534222434532454171600105361545113635161636061601610' 147 | # 148 | # try: 149 | # game = read_from_txt(moves_txt = move_s, pos_txt = pos_s) 150 | # except CChessException as e: 151 | # print e.reason 152 | # else: 153 | # 154 | # board_txt = game.dump_init_board() 155 | # print game.init_board.to_fen() 156 | # print 157 | # for line in board_txt: 158 | # print line 159 | # print 160 | # 161 | # moves = game.dump_std_moves() 162 | # print moves 163 | # for it in moves[0]: 164 | # print it 165 | # -------------------------------------------------------------------------------- /get_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import json\n", 12 | "import urllib\n", 13 | "import os\n", 14 | "import sys\n", 15 | "from matplotlib import pyplot as plt\n", 16 | "%matplotlib inline" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "metadata": { 23 | "collapsed": true 24 | }, 25 | "outputs": [], 26 | "source": [ 27 | "data_prefix = './data/imsa'\n", 28 | "jsonlists = os.listdir(data_prefix)" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 3, 34 | "metadata": {}, 35 | "outputs": [ 36 | { 37 | "data": { 38 | "text/plain": [ 39 | "569" 40 | ] 41 | }, 42 | "execution_count": 3, 43 | "metadata": {}, 44 | "output_type": "execute_result" 45 | } 46 | ], 47 | "source": [ 48 | "len(jsonlists)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 4, 54 | "metadata": { 55 | "collapsed": true 56 | }, 57 | "outputs": [], 58 | "source": [ 59 | "jsonlists = [os.path.join(data_prefix,i) for i in jsonlists]" 60 | ] 61 | }, 62 | { 63 | "cell_type": 
"code", 64 | "execution_count": 5, 65 | "metadata": {}, 66 | "outputs": [ 67 | { 68 | "name": "stdout", 69 | "output_type": "stream", 70 | "text": [ 71 | "./data/imsa\\27\n", 72 | "./data/imsa\\287\n", 73 | "./data/imsa\\448\n", 74 | "./data/imsa\\81\n" 75 | ] 76 | } 77 | ], 78 | "source": [ 79 | "alljson = []\n", 80 | "for i in jsonlists:\n", 81 | " try:\n", 82 | " alljson.append(json.load(open(i)))\n", 83 | " except:\n", 84 | " print(i)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 6, 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "name": "stdout", 94 | "output_type": "stream", 95 | "text": [ 96 | "Fri Nov 03 23:20:44 2017 \n", 97 | "+-----------------------------------------------------------------------------+\n", 98 | "| NVIDIA-SMI 385.69 Driver Version: 385.69 |\n", 99 | "|-------------------------------+----------------------+----------------------+\n", 100 | "| GPU Name TCC/WDDM | Bus-Id Disp.A | Volatile Uncorr. ECC |\n", 101 | "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", 102 | "|===============================+======================+======================|\n", 103 | "| 0 GeForce GTX 108... WDDM | 00000000:04:00.0 Off | N/A |\n", 104 | "| 23% 23C P8 8W / 250W | 327MiB / 11264MiB | 0% Default |\n", 105 | "+-------------------------------+----------------------+----------------------+\n", 106 | "| 1 GeForce GTX 108... 
WDDM | 00000000:82:00.0 Off | N/A |\n", 107 | "| 23% 23C P8 8W / 250W | 156MiB / 11264MiB | 0% Default |\n", 108 | "+-------------------------------+----------------------+----------------------+\n", 109 | " \n", 110 | "+-----------------------------------------------------------------------------+\n", 111 | "| Processes: GPU Memory |\n", 112 | "| GPU PID Type Process name Usage |\n", 113 | "|=============================================================================|\n", 114 | "| 0 5396 C+G ...6)\\Google\\Chrome\\Application\\chrome.exe N/A |\n", 115 | "| 0 6308 C+G Insufficient Permissions N/A |\n", 116 | "| 0 8420 C+G C:\\Windows\\explorer.exe N/A |\n", 117 | "| 0 8984 C+G ...t_cw5n1h2txyewy\\ShellExperienceHost.exe N/A |\n", 118 | "| 0 9384 C+G ...dows.Cortana_cw5n1h2txyewy\\SearchUI.exe N/A |\n", 119 | "| 0 10844 C+G ...rogram Files\\Microsoft VS Code\\Code.exe N/A |\n", 120 | "| 1 4260 C+G Insufficient Permissions N/A |\n", 121 | "+-----------------------------------------------------------------------------+\n" 122 | ] 123 | } 124 | ], 125 | "source": [ 126 | "!nvidia-smi" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 22, 132 | "metadata": {}, 133 | "outputs": [ 134 | { 135 | "data": { 136 | "text/plain": [ 137 | "697535" 138 | ] 139 | }, 140 | "execution_count": 22, 141 | "metadata": {}, 142 | "output_type": "execute_result" 143 | } 144 | ], 145 | "source": [ 146 | "alljson[0]['response']['list'][0]['playbook_id']" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 32, 152 | "metadata": {}, 153 | "outputs": [ 154 | { 155 | "name": "stdout", 156 | "output_type": "stream", 157 | "text": [ 158 | "one parse failed\n" 159 | ] 160 | } 161 | ], 162 | "source": [ 163 | "playbook_ids = []\n", 164 | "for onejson in alljson:\n", 165 | " playlist = onejson['response']['list']\n", 166 | " for play in playlist:\n", 167 | " try:\n", 168 | " playbook_ids.append(play['playbook_id'])\n", 169 | " except:\n", 170 | " 
print(\"one parse failed\")\n" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 34, 176 | "metadata": {}, 177 | "outputs": [ 178 | { 179 | "data": { 180 | "text/plain": [ 181 | "56467" 182 | ] 183 | }, 184 | "execution_count": 34, 185 | "metadata": {}, 186 | "output_type": "execute_result" 187 | } 188 | ], 189 | "source": [ 190 | "len(set(playbook_ids))" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "metadata": { 197 | "collapsed": true 198 | }, 199 | "outputs": [], 200 | "source": [ 201 | "import urllib2\n", 202 | "for playid in playbook_ids:\n", 203 | " response = urllib2.urlopen('http://www.example.com/')\n", 204 | " html = response.read()" 205 | ] 206 | } 207 | ], 208 | "metadata": { 209 | "anaconda-cloud": {}, 210 | "kernelspec": { 211 | "display_name": "Python [conda root]", 212 | "language": "python", 213 | "name": "conda-root-py" 214 | }, 215 | "language_info": { 216 | "codemirror_mode": { 217 | "name": "ipython", 218 | "version": 3 219 | }, 220 | "file_extension": ".py", 221 | "mimetype": "text/x-python", 222 | "name": "python", 223 | "nbconvert_exporter": "python", 224 | "pygments_lexer": "ipython3", 225 | "version": "3.5.2" 226 | } 227 | }, 228 | "nbformat": 4, 229 | "nbformat_minor": 1 230 | } 231 | -------------------------------------------------------------------------------- /cchess/move.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | ''' 4 | Copyright (C) 2014 walker li 5 | 6 | This program is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | (at your option) any later version. 
10 | 11 | This program is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with this program. If not, see . 18 | ''' 19 | 20 | from cchess.piece import * 21 | 22 | #-----------------------------------------------------# 23 | class Move(object): 24 | 25 | def __init__(self, board, p_from, p_to): 26 | 27 | self.board = board.copy() 28 | self.p_from = p_from 29 | self.p_to = p_to 30 | self.captured = self.board.get_fench(p_to) 31 | self.board_done = board.copy() 32 | self.board_done._move_piece(p_from, p_to) 33 | self.next_move = None 34 | self.right_move = None 35 | 36 | def is_king_killed(self): 37 | if self.captured and self.captured.lower() == 'k': 38 | return True 39 | return False 40 | 41 | def append_next_move(self, chess_move): 42 | chess_move.parent = self 43 | if not self.next_move: 44 | self.next_move = chess_move 45 | else: 46 | #找最右一个 47 | move = self.next_move 48 | while move.right_move: 49 | move = move.right_move 50 | move.right_move = chess_move 51 | 52 | def dump_moves(self, move_list, curr_move_line): 53 | 54 | if self.right_move: 55 | backup_move_line = curr_move_line[:] 56 | 57 | curr_move_line.append(self) 58 | #print curr_move_line 59 | if self.next_move: 60 | self.next_move.dump_moves(move_list, curr_move_line) 61 | #else: 62 | # print curr_move_line 63 | if self.right_move: 64 | #print self.move, 'has right', self.right.move 65 | move_list.append(backup_move_line) 66 | self.right_move.dump_moves(move_list, backup_move_line) 67 | 68 | def __str__(self): 69 | 70 | move_str = '' 71 | move_str += chr(ord('a') + self.p_from.x) 72 | move_str += str(self.p_from.y) 73 | move_str += chr(ord('a') + self.p_to.x) 74 | move_str += str(self.p_to.y) 75 | 76 | return move_str 77 | 78 | def 
from_str(self, move_str): 79 | 80 | self.p_from = Pos(ord(move_str[0]) - ord('a'), int(move_str[1])) 81 | self.p_to = Pos(ord(move_str[2]) - ord('a'), int(move_str[3])) 82 | 83 | return (self.p_from, self.p_to) 84 | 85 | def to_chinese(self): 86 | 87 | fench = self.board.get_fench(self.p_from) 88 | man_species, man_side = fench_to_species(fench) 89 | 90 | diff = self.p_to.y - self.p_from.y 91 | 92 | #黑方是红方的反向操作 93 | if man_side == ChessSide.BLACK: 94 | diff = -diff 95 | 96 | if diff == 0: 97 | diff_str = u"平" 98 | elif diff > 0: 99 | diff_str = u"进" 100 | else: 101 | diff_str = u"退" 102 | 103 | #王车炮兵规则 104 | if man_species in [ PieceT.KING, PieceT.ROOK, PieceT.CANNON, PieceT.PAWN]: 105 | if diff == 0 : 106 | dest_str = h_level_index[man_side][self.p_to.x] 107 | elif diff > 0 : 108 | dest_str = v_change_index[man_side][diff] 109 | else : 110 | dest_str = v_change_index[man_side][-diff] 111 | else : #士相马的规则 112 | dest_str = h_level_index[man_side][self.p_to.x] 113 | 114 | name_str = self.__get_chinese_name(self.p_from) 115 | 116 | return name_str + diff_str + dest_str 117 | 118 | def __get_chinese_name(self, p_from): 119 | 120 | fench = self.board.get_fench(p_from) 121 | man_species, man_side = fench_to_species(fench) 122 | man_name = fench_to_chinese(fench) 123 | 124 | #王,士,相命名规则 125 | if man_species in [ PieceT.KING, PieceT.ADVISOR, PieceT.BISHOP]: 126 | return man_name + h_level_index[man_side][p_from.x] 127 | 128 | pos_name2 = ((u'后', u'前'), (u'前', u'后')) 129 | pos_name3 = ((u'后', u'中', u'前'), (u'前', u'中', u'后')) 130 | pos_name4 = ((u'后', u'三', u'二', u'前'), (u'前', u'2', u'3', u'后')) 131 | pos_name5 = ((u'后', u'四', u'三', u'二', u'前'), (u'前', u'2', u'3', u'4', u'后')) 132 | 133 | #车马炮命名规则 134 | if man_species in [ PieceT.ROOK, PieceT.CANNON, PieceT.KNIGHT, PieceT.PAWN]: 135 | #红黑顺序相反,俩数组减少计算工作量 136 | count = 0 137 | pos_index = -1 138 | for y in range(10): 139 | if self.board._board[y][p_from.x] == fench: 140 | if p_from.y == y: 141 | pos_index = count 142 | count += 
1 143 | if count == 1: 144 | return man_name + h_level_index[man_side][p_from.x] 145 | elif count == 2: 146 | return pos_name2[man_side][pos_index] + man_name 147 | elif count == 3: 148 | #TODO 查找另一个多子行 149 | return pos_name3[man_side][pos_index] + man_name 150 | elif count == 4: 151 | return pos_name4[man_side][pos_index] + man_name 152 | elif count == 5: 153 | return pos_name5[man_side][pos_index] + man_name 154 | 155 | return man_name + h_level_index[man_side][p_from.x] 156 | 157 | def for_ucci(self, move_side, history): 158 | if self.captured: 159 | self.board_done.move_side = move_side 160 | self.ucci_fen = self.board_done.to_fen() 161 | self.ucci_moves = [] 162 | else: 163 | if not history: 164 | self.ucci_fen = self.board.to_fen() 165 | self.ucci_moves = [self.to_iccs()] 166 | else: 167 | last_move = history[-1] 168 | self.ucci_fen = last_move.ucci_fen 169 | self.ucci_moves = last_move.ucci_moves[:] 170 | self.ucci_moves.append(self.to_iccs()) 171 | 172 | def to_ucci_fen(self): 173 | if not self.ucci_moves : 174 | return self.ucci_fen 175 | 176 | move_str = ' '.join(self.ucci_moves) 177 | return ' '.join([self.ucci_fen, 'moves', move_str]) 178 | 179 | def to_iccs(self): 180 | return chr(ord('a') + self.p_from.x) + str(self.p_from.y) + chr(ord('a') + self.p_to.x) + str(self.p_to.y) 181 | 182 | @staticmethod 183 | def from_iccs(move_str): 184 | return (Pos(ord(move_str[0]) - ord('a'),int(move_str[1])), Pos(ord(move_str[2]) - ord('a'), int(move_str[3]))) 185 | 186 | @staticmethod 187 | def from_chinese(self, move_str): 188 | 189 | move_indexs = [u"前", u"中", u"后", u"一", u"二", u"三", u"四", u"五"] 190 | 191 | multi_man = False 192 | multi_lines = False 193 | 194 | if move_str[0] in move_indexs: 195 | 196 | man_index = move_indexs.index(mov_str[0]) 197 | 198 | if man_index > 2: 199 | multi_lines = True 200 | 201 | multi_man = True 202 | man_name = move_str[1] 203 | 204 | else : 205 | 206 | man_name = move_str[0] 207 | 208 | if man_name not in 
list(fench_name_dict.values())[int(self.move_side)::2]: 209 | print ("error", move_str) 210 | 211 | man_kind = chessman_show_name_dict[self.move_side].index(man_name) 212 | if not multi_man: 213 | #单子移动指示 214 | man_x = h_level_index[self.move_side].index(man_name) 215 | mans = __get_fenchs_at_vline(man_kind, self.move_side) 216 | 217 | #无子可走 218 | if len(mans) == 0: 219 | return None 220 | 221 | #同一行选出来多个 222 | if (len(mans) > 1) and (man_kind not in[ADVISOR, BISHOP]): 223 | #只有士象是可以多个子尝试移动而不用标明前后的 224 | return None 225 | 226 | for man in mans: 227 | move = man.chinese_move_to_std_move(move_str[2:]) 228 | if move : 229 | return move 230 | 231 | return None 232 | 233 | else: 234 | #多子选一移动指示 235 | mans = __get_fenchs_of_kind(man_kind, self.move_side) 236 | 237 | return (p_from, p_to) 238 | 239 | #-----------------------------------------------------# 240 | if __name__ == '__main__': 241 | board = BaseChessBoard(FULL_INIT_FEN) 242 | m = Move(board, Pos(0,0), Pos(0,1) ) 243 | print (m.to_chinese() == u'车九进一') -------------------------------------------------------------------------------- /cchess/ucci.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | ''' 4 | Copyright (C) 2014 walker li 5 | 6 | This program is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | (at your option) any later version. 10 | 11 | This program is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with this program. If not, see . 
18 | ''' 19 | 20 | import sys,time 21 | from enum import * 22 | 23 | from subprocess import PIPE, Popen 24 | from threading import Thread 25 | 26 | #from Queue import Queue, Empty 27 | #from multiprocessing import Queue,Empty 28 | from cchess.board import * 29 | from cchess.move import * 30 | 31 | #-----------------------------------------------------# 32 | 33 | #Engine status 34 | class EngineStatus(IntEnum): 35 | BOOTING = 1, 36 | READY = 2, 37 | WAITING = 3, 38 | INFO_MOVE = 4, 39 | MOVE = 5, 40 | DEAD = 6, 41 | UNKNOWN = 7, 42 | BOARD_RESET = 8 43 | 44 | ON_POSIX = 'posix' in sys.builtin_module_names 45 | 46 | #-----------------------------------------------------# 47 | 48 | class UcciEngine(Thread): 49 | def __init__(self, name = ''): 50 | super(UcciEngine, self).__init__() 51 | 52 | self.engine_name = name 53 | 54 | self.daemon = True 55 | self.running = False 56 | 57 | self.engine_status = None 58 | self.ids = [] 59 | self.options = [] 60 | 61 | self.last_fen = None 62 | self.move_queue = Queue() 63 | 64 | def run(self) : 65 | 66 | self.running = True 67 | 68 | while self.running : 69 | output = self.pout.readline().strip() 70 | self.engine_out_queque.put(output) 71 | 72 | def handle_msg_once(self) : 73 | try: 74 | output = self.engine_out_queque.get_nowait() 75 | except Empty: 76 | return False 77 | 78 | if output in ['bye','']: #stop pipe 79 | self.pipe.terminate() 80 | return False 81 | 82 | self.__handle_engine_out_line(output) 83 | 84 | return True 85 | 86 | def load(self, engine_path): 87 | 88 | self.engine_name = engine_path 89 | 90 | try: 91 | self.pipe = Popen(self.engine_name, stdin=PIPE, stdout=PIPE)#, close_fds=ON_POSIX) 92 | except OSError: 93 | return False 94 | 95 | time.sleep(0.5) 96 | 97 | (self.pin, self.pout) = (self.pipe.stdin,self.pipe.stdout) 98 | 99 | self.engine_out_queque = Queue() 100 | 101 | self.enging_status = EngineStatus.BOOTING 102 | self.send_cmd("ucci") 103 | 104 | self.start() 105 | 106 | while self.enging_status == 
EngineStatus.BOOTING : 107 | self.handle_msg_once() 108 | 109 | return True 110 | 111 | def quit(self): 112 | 113 | self.send_cmd("quit") 114 | time.sleep(0.2) 115 | 116 | def go_from(self, fen, search_depth = 8): 117 | 118 | #pass all out msg first 119 | while True: 120 | try: 121 | output = self.engine_out_queque.get_nowait() 122 | except Empty: 123 | break 124 | 125 | self.send_cmd('position fen ' + fen) 126 | 127 | self.last_fen = fen 128 | 129 | #if ban_move : 130 | # self.send_cmd('banmoves ' + ban_move) 131 | 132 | self.send_cmd('go depth %d' % (search_depth)) 133 | time.sleep(0.2) 134 | 135 | def stop_thinking(self): 136 | self.send_cmd('stop') 137 | while True: 138 | try: 139 | output = self.engine_out_queque.get_nowait() 140 | except Empty: 141 | continue 142 | outputs_list = output.split() 143 | resp_id = outputs_list[0] 144 | if resp_id in ['bestmove', 'nobestmove']: 145 | return 146 | 147 | def send_cmd(self, cmd_str) : 148 | 149 | #print ">>>", cmd_str 150 | 151 | try : 152 | self.pin.write(cmd_str + "\n") 153 | self.pin.flush() 154 | except IOError as e : 155 | print ("error in send cmd", e) 156 | 157 | def __handle_engine_out_line(self, output) : 158 | 159 | #print "<<<", output 160 | 161 | outputs_list = output.split() 162 | resp_id = outputs_list[0] 163 | 164 | if self.enging_status == EngineStatus.BOOTING: 165 | if resp_id == "id" : 166 | self.ids.append(output) 167 | elif resp_id == "option" : 168 | self.options.append(output) 169 | if resp_id == "ucciok" : 170 | self.enging_status = EngineStatus.READY 171 | 172 | elif self.enging_status == EngineStatus.READY: 173 | 174 | if resp_id == 'nobestmove': 175 | print (output) 176 | self.move_queue.put(("dead", {'fen' : self.last_fen})) 177 | 178 | elif resp_id == 'bestmove': 179 | if outputs_list[1] == 'null': 180 | print (output) 181 | self.move_queue.put(("dead", {'fen' : self.last_fen})) 182 | elif outputs_list[-1] == 'draw': 183 | self.move_queue.put(("draw", {'fen' : self.last_fen})) 184 | elif 
outputs_list[-1] == 'resign': 185 | self.move_queue.put(("resign", {'fen' : self.last_fen})) 186 | else : 187 | move_str = output[9:13] 188 | pos_move = Move.from_iccs(move_str) 189 | 190 | move_info = {} 191 | move_info["fen"] = self.last_fen 192 | move_info["move"] = pos_move 193 | 194 | self.move_queue.put(("best_move",move_info)) 195 | 196 | elif resp_id == 'info': 197 | #info depth 6 score 4 pv b0c2 b9c7 c3c4 h9i7 c2d4 h7e7 198 | if outputs_list[1] == "depth": 199 | move_info = {} 200 | info_list = output[5:].split() 201 | 202 | if len(info_list) < 5: 203 | return 204 | 205 | move_info["fen"] = self.last_fen 206 | move_info[info_list[0]] = int(info_list[1]) #depth 6 207 | move_info[info_list[2]] = int(info_list[3]) #score 4 208 | 209 | move_steps = [] 210 | for step_str in info_list[5:] : 211 | move= Move.from_iccs(step_str) 212 | move_steps.append(move) 213 | move_info["move"] = move_steps 214 | 215 | self.move_queue.put(("info_move", move_info)) 216 | 217 | def go_best_iccs_move(self, move_str): 218 | 219 | pos_move = Move.from_iccs(move_str) 220 | 221 | move_info = {} 222 | move_info["fen"] = self.last_fen 223 | move_info["move"] = pos_move 224 | 225 | self.move_queue.put(("best_move",move_info)) 226 | 227 | 228 | #-----------------------------------------------------# 229 | 230 | if __name__ == '__main__': 231 | 232 | from reader_xqf import * 233 | # 234 | # win_dict = { ChessSide.RED : u"红胜", ChessSide.BLACK : u"黑胜" } 235 | # 236 | # game = read_from_xqf('test\\ucci_test1.xqf') 237 | # game.init_board.move_side = ChessSide.RED 238 | # game.print_init_board() 239 | # game.print_chinese_moves() 240 | # 241 | # board = game.init_board.copy() 242 | # 243 | # engine = UcciEngine() 244 | # engine.load("test\\eleeye\\eleeye.exe") 245 | # 246 | # for id in engine.ids: 247 | # print id 248 | # for op in engine.options: 249 | # print op 250 | # 251 | # dead = False 252 | # while not dead: 253 | # engine.go_from(board.to_fen(), 10) 254 | # while True: 255 | # 
engine.handle_msg_once() 256 | # if engine.move_queue.empty(): 257 | # time.sleep(0.2) 258 | # continue 259 | # output = engine.move_queue.get() 260 | # if output[0] == 'best_move': 261 | # p_from, p_to = output[1]["move"] 262 | # print board.move(p_from, p_to).to_chinese(), 263 | # #board.print_board() 264 | # last_side = board.move_side 265 | # board.next_turn() 266 | # break 267 | # elif output[0] == 'dead': 268 | # print win_dict[last_side] 269 | # dead = True 270 | # break 271 | # elif output[0] == 'draw': 272 | # print u'引擎议和' 273 | # dead = True 274 | # break 275 | # elif output[0] == 'resign': 276 | # print u'引擎认输', win_dict[last_side] 277 | # dead = True 278 | # break 279 | # 280 | # engine.quit() 281 | # time.sleep(0.5) 282 | # -------------------------------------------------------------------------------- /cchess/board.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | ''' 4 | Copyright (C) 2014 walker li 5 | 6 | This program is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | (at your option) any later version. 10 | 11 | This program is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with this program. If not, see . 
18 | ''' 19 | 20 | import sys 21 | import copy 22 | import numpy as np 23 | 24 | from cchess.exception import * 25 | from cchess.piece import * 26 | from cchess.move import * 27 | 28 | #-----------------------------------------------------# 29 | FULL_INIT_FEN = 'rnbakabnr/9/1c5c1/p1p1p1p1p/9/9/P1P1P1P1P/1C5C1/9/RNBAKABNR w - - 0 1' 30 | 31 | #-----------------------------------------------------# 32 | _text_board = [ 33 | #u' 1 2 3 4 5 6 7 8 9', 34 | u'0 ┌─┬─┬─┬───┬─┬─┬─┐', 35 | u' │ │ │ │\│/│ │ │ │', 36 | u'1 ├─┼─┼─┼─※─┼─┼─┼─┤', 37 | u' │ │ │ │/│\│ │ │ │', 38 | u'2 ├─┼─┼─┼─┼─┼─┼─┼─┤', 39 | u' │ │ │ │ │ │ │ │ │', 40 | u'3 ├─┼─┼─┼─┼─┼─┼─┼─┤', 41 | u' │ │ │ │ │ │ │ │ │', 42 | u'4 ├─┴─┴─┴─┴─┴─┴─┴─┤', 43 | u' │    │', 44 | u'5 ├─┬─┬─┬─┬─┬─┬─┬─┤', 45 | u' │ │ │ │ │ │ │ │ │', 46 | u'6 ├─┼─┼─┼─┼─┼─┼─┼─┤', 47 | u' │ │ │ │ │ │ │ │ │', 48 | u'7 ├─┼─┼─┼─┼─┼─┼─┼─┤', 49 | u' │ │ │ │\│/│ │ │ │', 50 | u'8 ├─┼─┼─┼─※─┼─┼─┼─┤', 51 | u' │ │ │ │/│\│ │ │ │', 52 | u'9 └─┴─┴─┴───┴─┴─┴─┘', 53 | u' 0 1 2 3 4 5 6 7 8' 54 | #u' 九 八 七 六 五 四 三 二 一' 55 | ] 56 | 57 | _fench_txt_name_dict = { 58 | 'K': u"帅", 59 | 'k': u"将", 60 | 'A': u"仕", 61 | 'a': u"士", 62 | 'B': u"相", 63 | 'b': u"象", 64 | 'N': u"马", 65 | 'n': u"碼", 66 | 'R': u"车", 67 | 'r': u"砗", 68 | 'C': u"炮", 69 | 'c': u"砲", 70 | 'P': u"兵", 71 | 'p': u"卒" 72 | 73 | } 74 | #-----------------------------------------------------# 75 | 76 | def _pos_to_text_board_pos(pos): 77 | return Pos(2*pos.x+2, (9 - pos.y)*2) 78 | 79 | def _fench_to_txt_name(fench) : 80 | return _fench_txt_name_dict[fench] 81 | 82 | #-----------------------------------------------------# 83 | class BaseChessBoard(object) : 84 | def __init__(self, fen = None): 85 | self.clear() 86 | if fen: self.from_fen(fen) 87 | 88 | def clear(self): 89 | self._board = [[None for x in range(9)] for y in range(10)] 90 | self.move_side = ChessSide.RED 91 | 92 | def copy(self): 93 | return copy.deepcopy(self) 94 | 95 | def put_fench(self, fench, pos): 96 | if self._board[pos.y][pos.x] != 
None: 97 | return False 98 | 99 | self._board[pos.y][pos.x] = fench 100 | 101 | return True 102 | 103 | def get_fench(self, pos): 104 | return self._board[pos.y][pos.x] 105 | 106 | def get_piece(self, pos): 107 | fench = self._board[pos.y][pos.x] 108 | 109 | if not fench: 110 | return None 111 | 112 | return Piece.create(self, fench, pos) 113 | 114 | def is_valid_move_t(self, move_t): 115 | pos_from, pos_to = move_t 116 | return self.is_valid_move(pos_from, pos_to) 117 | 118 | def is_valid_move(self, pos_from, pos_to): 119 | 120 | ''' 121 | 只进行最基本的走子规则检查,不对每个子的规则进行检查,以加快文件加载之类的速度 122 | ''' 123 | 124 | if not (0 <= pos_to.x <= 8): return False 125 | if not (0 <= pos_to.y <= 9): return False 126 | 127 | fench_from = self._board[pos_from.y][pos_from.x] 128 | if not fench_from : 129 | return False 130 | 131 | _, from_side = fench_to_species(fench_from) 132 | 133 | #move_side 不是None值才会进行走子颜色检查,这样处理某些特殊的存储格式时会处理比较迅速 134 | if self.move_side and (from_side != self.move_side) : 135 | return False 136 | 137 | fench_to = self._board[pos_to.y][pos_to.x] 138 | if not fench_to : 139 | return True 140 | 141 | _, to_side = fench_to_species(fench_to) 142 | 143 | return (from_side != to_side) 144 | 145 | def _move_piece(self, pos_from, pos_to): 146 | 147 | fench = self._board[pos_from.y][pos_from.x] 148 | self._board[pos_to.y][pos_to.x] = fench 149 | self._board[pos_from.y][pos_from.x] = None 150 | 151 | return fench 152 | 153 | def move(self, pos_from, pos_to): 154 | pos_from.y = 9 - pos_from.y 155 | pos_to.y = 9 - pos_to.y 156 | if not self.is_valid_move(pos_from, pos_to): 157 | return None 158 | 159 | board = self.copy() 160 | fench = self.get_fench(pos_to) 161 | self._move_piece(pos_from, pos_to) 162 | 163 | return Move(board, pos_from, pos_to) 164 | 165 | def move_iccs(self,move_str): 166 | move_from, move_to = Move.from_iccs(move_str) 167 | return move(move_from, move_to) 168 | 169 | def move_chinese(self,move_str): 170 | move_from, move_to = Move.from_chinese(self, move_str) 
171 | return move(move_from, move_to) 172 | 173 | def next_turn(self) : 174 | if self.move_side == None : 175 | return None 176 | 177 | self.move_side = ChessSide.next_side(self.move_side) 178 | 179 | return self.move_side 180 | 181 | def from_fen(self, fen): 182 | 183 | num_set = set(('1', '2', '3', '4', '5', '6', '7', '8', '9')) 184 | ch_set = set(('k','a','b','n','r','c','p')) 185 | 186 | self.clear() 187 | 188 | if not fen or fen == '': 189 | return 190 | 191 | fen = fen.strip() 192 | 193 | x = 0 194 | y = 9 195 | 196 | for i in range(0, len(fen)): 197 | ch = fen[i] 198 | 199 | if ch == ' ': break 200 | elif ch == '/': 201 | y -= 1 202 | x = 0 203 | if y < 0: break 204 | elif ch in num_set: 205 | x += int(ch) 206 | if x > 8: x = 8 207 | elif ch.lower() in ch_set: 208 | if x <= 8: 209 | self.put_fench(ch, Pos(x, y)) 210 | x += 1 211 | else: 212 | return False 213 | 214 | fens = fen.split() 215 | 216 | self.move_side = None 217 | if (len(fens) >= 2) and (fens[1] == 'b') : 218 | self.move_side = ChessSide.BLACK 219 | else: 220 | self.move_side = ChessSide.RED 221 | 222 | if len(fens) >= 6 : 223 | self.round = int(fens[5]) 224 | else: 225 | self.round = 1 226 | 227 | return True 228 | 229 | def count_x_line_in(self, y, x_from, x_to): 230 | return reduce(lambda count, fench: count+1 if fench else count, self.x_line_in(y, x_from, x_to), 0) 231 | 232 | def count_y_line_in(self, x, y_from, y_to): 233 | return reduce(lambda count, fench: count+1 if fench else count, self.y_line_in(x, y_from, y_to), 0) 234 | 235 | def x_line_in(self, y, x_from, x_to): 236 | step = 1 if x_to > x_from else -1 237 | return [ self._board[y][x] for x in range(x_from+step, x_to, step) ] 238 | 239 | def y_line_in(self, x, y_from, y_to): 240 | step = 1 if y_to > y_from else -1 241 | return [ self._board[y][x] for y in range(y_from+step, y_to, step) ] 242 | 243 | def to_fen(self): 244 | return self.to_short_fen() + ' - - 0 1' 245 | 246 | def to_short_fen(self): 247 | fen = '' 248 | count = 0 249 
| for y in range(9, -1, -1): 250 | for x in range(9): 251 | fench = self._board[y][x] 252 | if fench: 253 | if count is not 0: 254 | fen += str(count) 255 | count = 0 256 | fen += fench 257 | else: 258 | count += 1 259 | 260 | if count > 0: 261 | fen += str(count) 262 | count = 0 263 | 264 | if y > 0: fen += '/' 265 | 266 | if self.move_side is ChessSide.BLACK: 267 | fen += ' b' 268 | elif self.move_side is ChessSide.RED : 269 | fen += ' w' 270 | else : 271 | raise CChessException('Move Side Error' + str(self.move_side)) 272 | 273 | return fen 274 | 275 | def dump_board(self): 276 | 277 | board_str = _text_board[:] 278 | 279 | y = 0 280 | for line in self._board: 281 | x = 0 282 | for ch in line: 283 | if ch : 284 | pos = _pos_to_text_board_pos(Pos(x,y)) 285 | new_text=board_str[pos.y][:pos.x] + _fench_to_txt_name(ch) + board_str[pos.y][pos.x+1:] 286 | board_str[pos.y] = new_text 287 | x += 1 288 | y += 1 289 | 290 | return board_str 291 | 292 | def print_board(self): 293 | 294 | board_txt = self.dump_board() 295 | print() 296 | for line in board_txt: 297 | print(line) 298 | print() 299 | 300 | def get_board_arr(self): 301 | return np.asarray(self._board[::-1]) 302 | 303 | #-----------------------------------------------------# 304 | 305 | class ChessBoard(BaseChessBoard): 306 | def __init__(self, fen = None): 307 | super(ChessBoard, self).__init__(fen) 308 | 309 | def put_fench(self, fench, pos): 310 | self._board[pos.y][pos.x] = fench 311 | 312 | def is_valid_move(self, pos_from, pos_to): 313 | if not super(ChessBoard, self).is_valid_move(pos_from, pos_to): 314 | return False 315 | 316 | piece = self.get_piece(pos_from) 317 | if not piece.is_valid_move(pos_to): 318 | return False 319 | return True 320 | 321 | def is_checked_move(self, pos_from, pos_to): 322 | board = self.copy() 323 | board._move_piece(pos_from, pos_to) 324 | return board.is_checked() 325 | 326 | def is_checked(self): 327 | king = self.get_king(self.move_side) 328 | king.create_moves() 329 | if 
not king : return 0 330 | killers = self.get_side_pieces(ChessSide.next_side(self.move_side)) 331 | ''' 332 | for piece in killers: 333 | if piece.is_valid_move(Pos(king.x, king.y)): 334 | #print piece.x, piece.y, piece.fench 335 | return 1 336 | return 0 337 | ''' 338 | return reduce(lambda count, piece : count+1 if piece.is_valid_move(Pos(king.x, king.y)) else count, killers, 0) 339 | 340 | def is_checkmate(self): 341 | defenders = self.get_side_pieces(self.move_side) 342 | for piece in defenders : 343 | for move_it in piece.create_moves(): 344 | if self.is_valid_move_t(move_it): 345 | if not self.is_checked_move(move_it[0], move_it[1]): 346 | return False 347 | return True 348 | 349 | def get_king(self, side): 350 | limit_y = ((0,1,2), (7,8,9)) 351 | for x in (3,4,5): 352 | for y in limit_y[side]: 353 | fench = self._board[y][x] 354 | if not fench : 355 | continue 356 | if fench.lower() == 'k': 357 | return Piece.create(self, fench, Pos(x,y)) 358 | return None 359 | 360 | def get_side_pieces(self, side): 361 | pieces = [] 362 | for x in range(9): 363 | for y in range(10): 364 | fench = self._board[y][x] 365 | if not fench : 366 | continue 367 | _, p_side = fench_to_species(fench) 368 | if p_side == side : 369 | pieces.append(Piece.create(self, fench, Pos(x,y))) 370 | return pieces 371 | 372 | 373 | #-----------------------------------------------------# 374 | if __name__ == '__main__': 375 | pass 376 | # 377 | # board = ChessBoard(FULL_INIT_FEN) 378 | # board.print_board() 379 | # 380 | # k = board.get_king(ChessSide.RED) 381 | # print (k.x, k.y) == (4,0) 382 | # k = board.get_king(ChessSide.BLACK) 383 | # print (k.x, k.y) == (4,9) 384 | # 385 | # print board.x_line_in(0, 0, 8) 386 | # print board.x_line_in(0, 8, 0) 387 | # #print board.x_line_in(9, 0, 10) 388 | # print board.y_line_in(4, -1, 10) 389 | # print board.y_line_in(4, 10, -1) 390 | # 391 | # print board.count_x_line_in(0, 0, 8) == 7 392 | # print board.count_y_line_in(4,0,9) == 2 393 | # print 
board.count_y_line_in(4,1,8) == 2 394 | # 395 | # print board.is_checked() 396 | # 397 | # print board.copy().move(Pos(7,2),Pos(4,2)).to_chinese() == u'炮二平五' 398 | # print board.copy().move(Pos(1,2),Pos(1,1)).to_chinese() == u'炮八退一' 399 | # print board.copy().move(Pos(7,2),Pos(7,6)).to_chinese() == u'炮二进四' 400 | # print board.copy().move(Pos(7,7),Pos(4,7)).to_chinese() == u'炮8平5' 401 | # print board.copy().move(Pos(7,7),Pos(7,3)).to_chinese() == u'炮8进4' 402 | # print board.copy().move(Pos(6,3),Pos(6,4)).to_chinese() == u'兵三进一' 403 | # print board.copy().move(Pos(8,0),Pos(8,1)).to_chinese() == u'车一进一' 404 | # print board.copy().move(Pos(0,9),Pos(0,8)).to_chinese() == u'车1进1' 405 | # print board.copy().move(Pos(4,0),Pos(4,1)).to_chinese() == u'帅五进一' 406 | # print board.copy().move(Pos(4,9),Pos(4,8)).to_chinese() == u'将5进1' 407 | # print board.copy().move(Pos(2,0),Pos(4,2)).to_chinese() == u'相七进五' 408 | # print board.copy().move(Pos(5,0),Pos(4,1)).to_chinese() == u'仕四进五' 409 | # print board.copy().move(Pos(7,0),Pos(6,2)).to_chinese() == u'马二进三' 410 | # -------------------------------------------------------------------------------- /cchess/reader_xqf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | ''' 4 | Copyright (C) 2014 walker li 5 | 6 | This program is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | (at your option) any later version. 10 | 11 | This program is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with this program. If not, see . 
18 | ''' 19 | 20 | import os 21 | import struct 22 | 23 | from cchess.board import * 24 | from cchess.game import * 25 | 26 | 27 | #-----------------------------------------------------# 28 | result_dict = { 0:"*", 1:"1-0", 2:"0-1", 3:"1/2-1/2", 4:"1/2-1/2" } 29 | 30 | def _decode_pos(man_pos) : 31 | return Pos(int(man_pos / 10), man_pos % 10) 32 | 33 | def _decode_pos2(man_pos) : 34 | return (Pos(int(man_pos[0] / 10), man_pos[0] % 10), Pos(int(man_pos[1] / 10), man_pos[1] % 10)) 35 | 36 | #-----------------------------------------------------# 37 | class XQFKey(object) : 38 | def __init__(self): 39 | pass 40 | 41 | #-----------------------------------------------------# 42 | class XQFBuffDecoder(object) : 43 | def __init__(self, buffer): 44 | self.buffer = buffer 45 | self.index = 0 46 | self.length = len(buffer) 47 | 48 | def __read(self, size): 49 | 50 | start = self.index 51 | stop = self.index + size 52 | 53 | if stop > self.length: 54 | stop = self.length 55 | 56 | self.index = stop 57 | 58 | return self.buffer[start:stop] 59 | 60 | def read_str(self, size, coding = "GB18030"): 61 | buff = self.__read(size) 62 | 63 | try: 64 | ret = buff.decode(coding) 65 | except: 66 | ret = None 67 | 68 | return ret 69 | 70 | def read_bytes(self, size): 71 | return bytearray(self.__read(size)) 72 | 73 | def read_int(self): 74 | bytes = self.read_bytes(4) 75 | return bytes[0] + (bytes[1] << 8) + (bytes[2] << 16) + (bytes[3] << 24) 76 | 77 | #------------------------------------------------- 78 | def __init_decrypt_key(buff_str): 79 | 80 | keys = XQFKey() 81 | 82 | key_buff =bytearray(buff_str) 83 | 84 | # Pascal code here from XQFRW.pas 85 | # KeyMask : dTByte; // 加密掩码 86 | # ProductId : dTDWord; // 产品号(厂商的产品号) 87 | # KeyOrA : dTByte; 88 | # KeyOrB : dTByte; 89 | # KeyOrC : dTByte; 90 | # KeyOrD : dTByte; 91 | # KeysSum : dTByte; // 加密的钥匙和 92 | # KeyXY : dTByte; // 棋子布局位置钥匙 93 | # KeyXYf : dTByte; // 棋谱起点钥匙 94 | # KeyXYt : dTByte; // 棋谱终点钥匙 95 | 96 | HEAD_KeyMask, 
HEAD_ProductId, \ 97 | HEAD_KeyOrA, HEAD_KeyOrB, HEAD_KeyOrC, HEAD_KeyOrD, \ 98 | HEAD_KeysSum, HEAD_KeyXY, HEAD_KeyXYf, HEAD_KeyXYt = struct.unpack("= 12: 165 | tmpMan[(keys.KeyXY + i + 1) & 0x1F] = man_buff[i] 166 | else : 167 | tmpMan[i] = man_buff[i] 168 | 169 | for i in range(32) : 170 | tmpMan[i] = (tmpMan[i] - keys.KeyXY) & 0xFF 171 | if (tmpMan[i] > 89) : 172 | tmpMan[i] = 0xFF 173 | 174 | return tmpMan 175 | 176 | #-----------------------------------------------------# 177 | def __decode_buff(keys, buff) : 178 | 179 | nPos = 0x400 180 | de_buff =bytearray(buff) 181 | 182 | for i in range(len(buff)) : 183 | KeyByte = keys.F32Keys[(nPos + i) % 32] 184 | de_buff[i] = (de_buff[i] - KeyByte) & 0xFF 185 | 186 | return str(de_buff) 187 | 188 | 189 | #-----------------------------------------------------# 190 | def __read_init_info(buff_decoder, version, keys): 191 | 192 | step_info = buff_decoder.read_bytes(4) 193 | 194 | annote_len = 0 195 | if version <= 0x0A: 196 | #低版本在走子数据后紧跟着注释长度,长度为0则没有注释 197 | annote_len = buff_decoder.read_int() 198 | else: 199 | #高版本通过flag来标记有没有注释,有则紧跟着注释长度和注释字段 200 | step_info[2] &= 0xE0 201 | if (step_info[2] & 0x20) : #有注释 202 | annote_len = buff_decoder.read_int() - keys.KeyRMKSize 203 | 204 | return buff_decoder.read_str(annote_len) if (annote_len > 0) else None 205 | 206 | #-----------------------------------------------------# 207 | def __read_steps(buff_decoder, version, keys, parent, board): 208 | 209 | step_info = buff_decoder.read_bytes(4) 210 | 211 | if len(step_info) == 0: 212 | return 213 | 214 | annote_len = 0 215 | has_next_step = False 216 | has_var_step = False 217 | board_bak = board.copy() 218 | 219 | if version <= 0x0A: 220 | #低版本在走子数据后紧跟着注释长度,长度为0则没有注释 221 | if (step_info[2] & 0xF0) : 222 | has_next_step = True 223 | if (step_info[2] & 0x0F) : 224 | has_var_step = True #有变着 225 | annote_len = buff_decoder.read_int() 226 | 227 | step_info[0] = (step_info[0] - 0x18) & 0xFF; 228 | step_info[1] = (step_info[1] - 0x20) 
& 0xFF; 229 | 230 | else : 231 | #高版本通过flag来标记有没有注释,有则紧跟着注释长度和注释字段 232 | step_info[2] &= 0xE0 233 | if (step_info[2] & 0x80) : #有后续 234 | has_next_step = True 235 | if (step_info[2] & 0x40) : #有变招 236 | has_var_step = True 237 | if (step_info[2] & 0x20) : #有注释 238 | annote_len = buff_decoder.read_int() - keys.KeyRMKSize 239 | 240 | step_info[0] = (step_info[0] - 0x18 - keys.KeyXYf) & 0xFF 241 | step_info[1] = (step_info[1] - 0x20 - keys.KeyXYt) & 0xFF 242 | 243 | move_from, move_to = _decode_pos2(step_info) 244 | annote = buff_decoder.read_str(annote_len) if annote_len > 0 else None 245 | 246 | fench = board.get_fench(move_from) 247 | 248 | if not fench: 249 | #raise CChessException("bad move at %s %s" % (str(move_from), str(move_to))) 250 | good_move = parent 251 | else: 252 | _, man_side = fench_to_species(fench) 253 | board.move_side = man_side 254 | 255 | if board.is_valid_move(move_from, move_to): 256 | #认为当前走子一方就是合理一方,避免过多走子方检查 257 | curr_move = board.move(move_from, move_to) 258 | curr_move.note = annote 259 | #print curr_move.move_str(), has_next_step, has_var_step 260 | parent.append_next_move(curr_move) 261 | good_move = curr_move 262 | else: 263 | #print "bad move at", move_from, move_to 264 | #board.print_board() 265 | good_move = parent 266 | 267 | if has_next_step : 268 | __read_steps(buff_decoder, version, keys, good_move, board) 269 | 270 | if has_var_step : 271 | #print Move.to_iccs(parent.next_move.move), 'has var' 272 | __read_steps(buff_decoder, version, keys, parent, board_bak) 273 | 274 | #-----------------------------------------------------# 275 | def read_from_xqf(full_file_name, read_annotation = True): 276 | 277 | with open(full_file_name, "rb") as f: 278 | contents = f.read() 279 | 280 | magic, version, crypt_keys, ucBoard,\ 281 | ucUn2, ucRes,\ 282 | ucUn3, ucType,\ 283 | ucUn4, ucTitleLen,szTitle,\ 284 | ucUn5, ucMatchNameLen,szMatchName,\ 285 | ucDateLen, szDate,\ 286 | ucAddrLen, szAddr,\ 287 | ucRedPlayerNameLen, szRedPlayerName,\ 
288 | ucBlackPlayerNameLen,szBlackPlayerName,\ 289 | ucTimeRuleLen,szTimeRule,\ 290 | ucRedTimeLen,szRedTime,\ 291 | ucBlackTime,szBlackTime, \ 292 | ucUn6,\ 293 | ucCommenerNameLen,szCommenerName,ucAuthorNameLen,szAuthorName,\ 294 | ucUn7 = struct.unpack("<2sB13s32s3sB12sB15sB63s64sB63sB15sB15sB15sB15sB63sB15sB15s32sB15sB15s528s", contents[:0x400]) 295 | 296 | if magic != "XQ": 297 | return None 298 | 299 | game_info = {} 300 | 301 | game_info["game_source"] = "XQF" 302 | game_info["game_version"] = version 303 | game_info["game_type"] = ucType + 1 304 | 305 | if ucRes <= 4: #It's really some file has value 4 306 | game_info["Result"] = result_dict[ucRes] 307 | else: 308 | print ("Bad Result ", ucRes, full_file_name) 309 | game_info["Result"] = '*' 310 | 311 | if ucRedPlayerNameLen > 0: 312 | try: 313 | game_info["Red"] = szRedPlayerName[:ucRedPlayerNameLen].decode("GB18030") 314 | except : pass 315 | 316 | if ucBlackPlayerNameLen > 0: 317 | try: 318 | game_info["Black"] = szBlackPlayerName[:ucBlackPlayerNameLen].decode("GB18030") 319 | except : pass 320 | 321 | if ucTitleLen > 0: 322 | try: 323 | game_info["Game"] = szTitle[:ucTitleLen].decode("GB18030") 324 | except: pass 325 | 326 | if ucMatchNameLen > 0: 327 | try: 328 | game_info["Event"] = szMatchName[:ucMatchNameLen].decode("GB18030") 329 | except: pass 330 | 331 | path, file_name=os.path.split(full_file_name) 332 | 333 | ''' 334 | if game_info["Result"] == '*' : 335 | if (u"先胜" in file_name) and (u"先和" not in file_name) and (u"先负" not in file_name) : 336 | game_info["Result"] = '1-0' 337 | elif (u"先负" in file_name) and (u"先和" not in file_name) and (u"先胜" not in file_name) : 338 | game_info["Result"] = '0-1' 339 | elif (u"先和" in file_name) and (u"先负" not in file_name) and (u"先胜" not in file_name) : 340 | game_info["Result"] = '1/2-1/2' 341 | ''' 342 | if (version <= 0x0A): 343 | keys = None 344 | chess_mans = __init_chess_board(ucBoard, version) 345 | step_base_buff =XQFBuffDecoder(contents[0x400:]) 346 | 
else: 347 | keys = __init_decrypt_key(crypt_keys) 348 | chess_mans = __init_chess_board(ucBoard, version, keys) 349 | step_base_buff = XQFBuffDecoder(__decode_buff(keys, contents[0x400:])) 350 | 351 | board = BaseChessBoard() 352 | 353 | chessman_kinds = \ 354 | ( 355 | 'R', 'N', 'B', 'A', 'K', 'A', 'B', 'N', 'R' , \ 356 | 'C', 'C', \ 357 | 'P','P','P','P','P' 358 | ) 359 | 360 | for side in range(2): 361 | for man_index in range(16): 362 | man_pos = chess_mans[side * 16 + man_index] 363 | if man_pos == 0xFF: 364 | continue 365 | pos = _decode_pos(man_pos) 366 | fen_ch = chr(ord(chessman_kinds[man_index]) +side * 32) 367 | board.put_fench(fen_ch, pos) 368 | 369 | game_annotation = __read_init_info(step_base_buff, version, keys) 370 | 371 | game = Game(board, game_annotation) 372 | game.info = game_info 373 | 374 | __read_steps(step_base_buff, version, keys, game, board) 375 | 376 | return game 377 | 378 | #-----------------------------------------------------# 379 | if __name__ == '__main__': 380 | 381 | ''' 382 | game = read_from_xqf(u"test\\FiveGoatsTest.xqf") 383 | game.dump_info() 384 | print 'verified', game.verify_moves() 385 | #moves = game.dump_moves() 386 | #print len(moves) 387 | ''' 388 | game = read_from_xqf(u"test\\EmptyTest.xqf") 389 | game.dump_info() 390 | ''' 391 | game = read_from_xqf(u"test\\BadMoveTest1.xqf") 392 | game.dump_info() 393 | print game.init_fen 394 | print 'verified', game.verify_moves() 395 | 396 | game = read_from_xqf(u"test\\BadMoveTest2.xqf") 397 | game.dump_info() 398 | print game.init_fen 399 | print game.annotation 400 | print 'verified', game.verify_moves() 401 | ''' 402 | 403 | #game = read_from_xqf(u"test\\BadMoveTest3.xqf") 404 | #game = read_from_xqf(u"test\\BadMoveTest4.xqf") 405 | game = read_from_xqf(u"test\\WildHouse.xqf") 406 | game.dump_info() 407 | #moves = game.dump_moves() 408 | #moves = game.dump_std_moves() 409 | #print moves 410 | game.print_init_board() 411 | game.print_chinese_moves(3) 412 | #print 
len(moves) 413 | #print 'verified', game.verify_moves() 414 | #print 'verified', game.verify_moves() 415 | -------------------------------------------------------------------------------- /cchess/piece.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | ''' 4 | Copyright (C) 2014 walker li 5 | 6 | This program is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | (at your option) any later version. 10 | 11 | This program is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with this program. If not, see . 18 | ''' 19 | 20 | import sys 21 | 22 | from sets import * 23 | from enum import * 24 | 25 | #-----------------------------------------------------# 26 | 27 | h_level_index = \ 28 | ( 29 | (u"九",u"八",u"七",u"六",u"五",u"四",u"三",u"二",u"一"), 30 | (u"1",u"2",u"3",u"4",u"5",u"6",u"7",u"8",u"9") 31 | ) 32 | 33 | v_change_index = \ 34 | ( 35 | (u"错", ""u"一", u"二", u"三", u"四", u"五", u"六", u"七", u"八", u"九"), 36 | (u"误", ""u"1", u"2", u"3", u"4", u"5", u"6", u"7", u"8", u"9") 37 | ) 38 | 39 | #-----------------------------------------------------# 40 | 41 | advisor_pos = ( 42 | ((3, 0), (5, 0), (4, 1), (3, 2), (5, 2)), 43 | ((3, 9), (5, 9), (4, 8), (3, 7), (5, 7)), 44 | ) 45 | 46 | bishop_pos = ( 47 | ((2, 0), (6, 0), (0, 2), (4, 2), (9, 2), (2, 4), (6, 4)), 48 | ((2, 9), (6, 9), (0, 7), (4, 7), (9, 7), (2, 5), (6, 5)), 49 | ) 50 | 51 | #-----------------------------------------------------# 52 | class ChessSide(IntEnum): 53 | RED = 0 54 | BLACK = 1 55 | 56 | @staticmethod 57 | def next_side(side): 
58 | return {ChessSide.RED:ChessSide.BLACK, ChessSide.BLACK:ChessSide.RED}[side] 59 | 60 | #-----------------------------------------------------# 61 | class PieceT(IntEnum): 62 | KING = 1 63 | ADVISOR = 2 64 | BISHOP = 3 65 | KNIGHT = 4 66 | ROOK = 5 67 | CANNON = 6 68 | PAWN = 7 69 | 70 | #-----------------------------------------------------# 71 | fench_species_dict = { 72 | 'k': PieceT.KING, 73 | 'a': PieceT.ADVISOR, 74 | 'b': PieceT.BISHOP, 75 | 'n': PieceT.KNIGHT, 76 | 'r': PieceT.ROOK, 77 | 'c': PieceT.CANNON, 78 | 'p': PieceT.PAWN 79 | } 80 | 81 | fench_name_dict = { 82 | 'K': u"帅", 83 | 'k': u"将", 84 | 'A': u"仕", 85 | 'a': u"士", 86 | 'B': u"相", 87 | 'b': u"象", 88 | 'N': u"马", 89 | 'n': u"马", 90 | 'R': u"车", 91 | 'r': u"车", 92 | 'C': u"炮", 93 | 'c': u"炮", 94 | 'P': u"兵", 95 | 'p': u"卒" 96 | } 97 | 98 | 99 | species_fench_dict = { 100 | PieceT.KING: ('K', 'k'), 101 | PieceT.ADVISOR: ('A', 'a'), 102 | PieceT.BISHOP: ('B', 'b'), 103 | PieceT.KNIGHT: ('N', 'n'), 104 | PieceT.ROOK: ('R', 'r'), 105 | PieceT.CANNON: ('C', 'c'), 106 | PieceT.PAWN: ('P', 'p') 107 | } 108 | 109 | #-----------------------------------------------------# 110 | def fench_to_chinese(fench) : 111 | return fench_name_dict[fench] 112 | 113 | def fench_to_species(fen_ch): 114 | return fench_species_dict[fen_ch.lower()], ChessSide.BLACK if fen_ch.islower() else ChessSide.RED 115 | 116 | def species_to_fench(species, side): 117 | return species_fench_dict[species][side] 118 | 119 | #KING, ADVISOR, BISHOP, KNIGHT, ROOK, CANNON, PAWN 120 | 121 | chessman_show_name_dict = { 122 | PieceT.KING: (u"帅", u"将"), 123 | PieceT.ADVISOR: (u"仕", u"士"), 124 | PieceT.BISHOP: (u"相", u"象"), 125 | PieceT.KNIGHT: (u"马", u"碼"), 126 | PieceT.ROOK: (u"车", u"砗"), 127 | PieceT.CANNON: (u"炮", u"砲"), 128 | PieceT.PAWN: (u"兵", u"卒") 129 | } 130 | 131 | def get_show_name(species, side) : 132 | return chessman_show_name_dict[species][side] 133 | 134 | 135 | #-----------------------------------------------------# 136 | class 
Pos(object): 137 | def __init__(self, x, y): 138 | self.x = x 139 | self.y = y 140 | 141 | def abs_diff(self, other): 142 | return (abs(self.x - other.x), abs(self.y - other.y)) 143 | 144 | def middle(self, other): 145 | return Pos((self.x + other.x) / 2, (self.y + other.y) / 2) 146 | 147 | def __str__(self): 148 | return str(self.x) + ":" + str(self.y) 149 | 150 | def __eq__(self, other): 151 | return (self.x == other.x) and (self.y == other.y) 152 | 153 | def __ne__(self, other): 154 | return (self.x != other.x) or (self.y != other.y) 155 | 156 | def __call__(self): 157 | return (self.x, self.y) 158 | 159 | #-----------------------------------------------------# 160 | 161 | class Piece(object): 162 | 163 | def __init__(self, board, fench, pos): 164 | 165 | self.board = board 166 | self.fench = fench 167 | 168 | species, side = fench_to_species(fench) 169 | 170 | self.species = species 171 | self.side = side 172 | 173 | self.x, self.y = pos() 174 | 175 | def is_valid_pos(self, pos): 176 | return True 177 | 178 | def is_valid_move(self, pos): 179 | return True 180 | 181 | @staticmethod 182 | def create(board, fench, pos): 183 | p_type = fench.lower() 184 | if p_type == 'k': 185 | return King(board, fench, pos) 186 | if p_type == 'a': 187 | return Advisor(board, fench, pos) 188 | if p_type == 'b': 189 | return Bishop(board, fench, pos) 190 | if p_type == 'r': 191 | return Rook(board, fench, pos) 192 | if p_type == 'c': 193 | return Cannon(board, fench, pos) 194 | if p_type == 'n': 195 | return Knight(board, fench, pos) 196 | if p_type == 'p': 197 | return Pawn(board, fench, pos) 198 | 199 | 200 | ''' 201 | def chinese_move_to_std_move(self, move_str): 202 | 203 | if self.species in self.__chinese_move_to_std_move_checks : 204 | new_pos = self.__chinese_move_to_std_move_checks[self.species](move_str) 205 | else : 206 | new_pos = self.__chinese_move_to_std_move_default(move_str) 207 | 208 | if not new_pos : 209 | return None 210 | 211 | if not 
self.can_move_to(new_pos[0] , new_pos[1]): 212 | return None 213 | 214 | return ((self.x, self.y), new_pos) 215 | 216 | 217 | def __chinese_move_to_std_move_advisor(self, move_str): 218 | 219 | if move_str[0] == u"平": 220 | return None 221 | 222 | new_x = h_level_index[self.side].index(move_str[1]) 223 | 224 | if move_str[0] == u"进" : 225 | diff_y = -1 226 | elif move_str[0] == u"退" : 227 | diff_y = 1 228 | else : 229 | return None 230 | 231 | if self.side == ChessSide.BLACK: 232 | diff_y = - diff_y 233 | 234 | new_y = self.y - diff_y 235 | 236 | return (new_x, new_y) 237 | 238 | def __chinese_move_to_std_move_bishop(self, move_str): 239 | if move_str[0] == u"平": 240 | return None 241 | 242 | new_x = h_level_index[self.side].index(move_str[1]) 243 | 244 | if move_str[0] == u"进" : 245 | diff_y = -2 246 | elif move_str[0] == u"退" : 247 | diff_y = 2 248 | else : 249 | return None 250 | 251 | if self.side == ChessSide.BLACK: 252 | diff_y = - diff_y 253 | 254 | new_y = self.y - diff_y 255 | 256 | return (new_x, new_y) 257 | 258 | def __chinese_move_to_std_move_knight(self, move_str): 259 | if move_str[0] == u"平": 260 | return None 261 | 262 | new_x = h_level_index[self.side].index(move_str[1]) 263 | 264 | diff_x = abs(self.x - new_x) 265 | 266 | if move_str[0] == u"进" : 267 | diff_y = [3, 2, 1][diff_x] 268 | 269 | elif move_str[0] == u"退" : 270 | diff_y = [-3, -2, -1][diff_x] 271 | 272 | else : 273 | return None 274 | 275 | if self.side == ChessSide.RED: 276 | diff_y = -diff_y 277 | 278 | new_y = self.y - diff_y 279 | 280 | return (new_x, new_y) 281 | 282 | def __chinese_move_to_std_move_default(self, move_str): 283 | 284 | if move_str[0] == u"平": 285 | new_x = h_level_index[self.side].index(move_str[1]) 286 | 287 | return (new_x, self.y) 288 | 289 | else : 290 | #王,车,炮,兵的前进和后退 291 | diff = v_change_index[self.side].index(move_str[1]) 292 | 293 | if move_str[0] == u"退": 294 | diff = -diff 295 | elif move_str[0] != u"进": 296 | return None 297 | 298 | if self.side == 
ChessSide.BLACK: 299 | diff = -diff 300 | 301 | new_y = self.y + diff 302 | 303 | return (self.x, new_y) 304 | ''' 305 | #-----------------------------------------------------# 306 | #王 307 | class King(Piece): 308 | 309 | def is_valid_pos(self, pos): 310 | 311 | if pos.x < 3 or pos.x > 5: 312 | return False 313 | 314 | if (self.side == ChessSide.RED) and pos.y > 2: 315 | return False 316 | 317 | if (self.side == ChessSide.BLACK) and pos.y < 7: 318 | return False 319 | 320 | return True 321 | 322 | def is_valid_move(self, pos): 323 | 324 | #先检查王吃王 325 | k2 = self.board.get_king(ChessSide.next_side(self.side)) 326 | 327 | if ((k2.x,k2.y) == pos()) and self.x == k2.x: 328 | count = self.board.count_y_line_in(self.x, self.y, k2.y) 329 | if count == 0: 330 | return True 331 | 332 | if not self.is_valid_pos(pos) : 333 | return False 334 | 335 | diff = pos.abs_diff(Pos(self.x, self.y)) 336 | 337 | return True if ((diff[0] + diff[1]) == 1) else False 338 | 339 | def create_moves(self): 340 | poss = [Pos(self.x+1,self.y),Pos(self.x-1,self.y),Pos(self.x,self.y+1),Pos(self.x,self.y-1)] 341 | curr_pos = Pos(self.x, self.y) 342 | moves = [(curr_pos, to_pos) for to_pos in poss] 343 | return filter(self.board.is_valid_move_t, moves) 344 | 345 | #-----------------------------------------------------# 346 | #士 347 | class Advisor(Piece): 348 | 349 | def is_valid_pos(self, pos): 350 | return True if pos() in advisor_pos[self.side] else False 351 | 352 | def is_valid_move(self, pos): 353 | 354 | if not self.is_valid_pos(pos) : 355 | return False 356 | 357 | if Pos(self.x, self.y).abs_diff(pos) == (1,1): 358 | return True 359 | 360 | return False 361 | 362 | def create_moves(self): 363 | poss = [Pos(self.x+1,self.y+1),Pos(self.x+1,self.y-1),Pos(self.x-1,self.y+1),Pos(self.x-1,self.y-1)] 364 | curr_pos = Pos(self.x, self.y) 365 | moves = [(curr_pos, to_pos) for to_pos in poss] 366 | return filter(self.board.is_valid_move_t, moves) 367 | 368 | 
#-----------------------------------------------------# 369 | #象 370 | class Bishop(Piece): 371 | def is_valid_pos(self, pos): 372 | return True if pos() in bishop_pos[self.side] else False 373 | 374 | def is_valid_move(self, pos): 375 | 376 | if Pos(self.x, self.y).abs_diff(pos) != (2,2): 377 | return False 378 | 379 | #塞象眼检查 380 | if self.board.get_fench(Pos(self.x, self.y).middle(pos)) != None : 381 | return False 382 | 383 | return True 384 | 385 | def create_moves(self): 386 | poss = [Pos(self.x+2,self.y+2),Pos(self.x+2,self.y-2),Pos(self.x-2,self.y+2),Pos(self.x-2,self.y-2)] 387 | curr_pos = Pos(self.x, self.y) 388 | moves = [(curr_pos, to_pos) for to_pos in poss] 389 | return filter(self.board.is_valid_move_t, moves) 390 | 391 | #-----------------------------------------------------# 392 | #马 393 | class Knight(Piece): 394 | def is_valid_move(self, pos): 395 | 396 | if (abs(self.x - pos.x) == 2) and (abs(self.y - pos.y) == 1): 397 | 398 | m_x = (self.x + pos.x) / 2 399 | m_y = self.y 400 | 401 | #别马腿检查 402 | if self.board.get_fench(Pos(m_x, m_y)) == None : 403 | return True 404 | 405 | if (abs(self.x - pos.x) == 1) and (abs(self.y - pos.y) == 2): 406 | 407 | m_x = self.x 408 | m_y = (self.y + pos.y) / 2 409 | 410 | #别马腿检查 411 | if self.board.get_fench(Pos(m_x, m_y)) == None : 412 | return True 413 | 414 | return False 415 | 416 | def create_moves(self): 417 | poss = [Pos(self.x+1,self.y+2),Pos(self.x+1,self.y-2), 418 | Pos(self.x-1,self.y+2),Pos(self.x-1,self.y-2), 419 | Pos(self.x+2,self.y+1),Pos(self.x+2,self.y-1), 420 | Pos(self.x-2,self.y+1),Pos(self.x-2,self.y-1), 421 | ] 422 | curr_pos = Pos(self.x, self.y) 423 | moves = [(curr_pos, to_pos) for to_pos in poss] 424 | return filter(self.board.is_valid_move_t, moves) 425 | 426 | #-----------------------------------------------------# 427 | #车 428 | class Rook(Piece): 429 | def is_valid_move(self, pos): 430 | if self.x != pos.x: 431 | #斜向移动是非法的 432 | if self.y != pos.y: 433 | return False 434 | 435 | #水平移动 
436 | if self.board.count_x_line_in(self.y, self.x, pos.x) == 0: 437 | return True 438 | 439 | else : 440 | #垂直移动 441 | if self.board.count_y_line_in(self.x, self.y, pos.y) == 0: 442 | return True 443 | 444 | return False 445 | 446 | def create_moves(self): 447 | moves = [] 448 | curr_pos = Pos(self.x, self.y) 449 | for x in range(9): 450 | for y in range(10): 451 | if self.x == x and self.y == y: 452 | continue 453 | moves.append((curr_pos, Pos(x,y))) 454 | return filter(self.board.is_valid_move_t, moves) 455 | 456 | #-----------------------------------------------------# 457 | #炮 458 | class Cannon(Piece): 459 | def is_valid_move(self, pos): 460 | 461 | if self.x != pos.x: 462 | #斜向移动是非法的 463 | if self.y != pos.y: 464 | return False 465 | 466 | #水平移动 467 | count = self.board.count_x_line_in(self.y, self.x, pos.x) 468 | if (count == 0) and (self.board.get_fench(pos) == None): 469 | return True 470 | if (count == 1) and (self.board.get_fench(pos) != None): 471 | return True 472 | else : 473 | #垂直移动 474 | count = self.board.count_y_line_in(self.x, self.y, pos.y) 475 | if (count == 0) and (self.board.get_fench(pos) == None): 476 | return True 477 | if (count == 1) and (self.board.get_fench(pos) != None): 478 | return True 479 | 480 | return False 481 | 482 | def create_moves(self): 483 | moves = [] 484 | curr_pos = Pos(self.x, self.y) 485 | for x in range(9): 486 | for y in range(10): 487 | if self.x == x and self.y == y: 488 | continue 489 | moves.append((curr_pos, Pos(x,y))) 490 | return filter(self.board.is_valid_move_t, moves) 491 | 492 | #-----------------------------------------------------# 493 | #兵/卒 494 | class Pawn(Piece): 495 | def is_valid_pos(self, pos): 496 | 497 | if (self.side == ChessSide.RED) and pos.y < 3: 498 | return False 499 | 500 | if (self.side == ChessSide.BLACK) and pos.y > 6: 501 | return False 502 | 503 | return True 504 | 505 | def is_valid_move(self, pos): 506 | 507 | not_over_river_step = ((0, 1), (0, -1)) 508 | over_river_step = 
(((-1, 0), (1, 0), (0, 1)),((-1, 0), (1, 0), (0, -1))) 509 | 510 | step = (pos.x - self.x, pos.y - self.y) 511 | 512 | over_river = self.is_over_river() 513 | 514 | if (not over_river) and (step == not_over_river_step[self.side]): 515 | return True 516 | 517 | if over_river and (step in over_river_step[self.side]): 518 | return True 519 | 520 | return False 521 | 522 | def is_over_river(self) : 523 | if (self.side == ChessSide.RED) and (self.y > 4) : 524 | return True 525 | 526 | if (self.side == ChessSide.BLACK) and (self.y < 5) : 527 | return True 528 | 529 | return False 530 | 531 | def create_moves(self): 532 | moves = [] 533 | curr_pos = Pos(self.x, self.y) 534 | for x in range(9): 535 | for y in range(10): 536 | if self.x == x and self.y == y: 537 | continue 538 | moves.append((curr_pos, Pos(x,y))) 539 | return filter(self.board.is_valid_move_t, moves) 540 | 541 | #-----------------------------------------------------# 542 | if __name__ == '__main__': 543 | pass 544 | 545 | -------------------------------------------------------------------------------- /process_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "%reload_ext autoreload \n", 12 | "%autoreload 2\n", 13 | "from cchess import *" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 20, 19 | "metadata": { 20 | "collapsed": true 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "bb = BaseChessBoard(FULL_INIT_FEN)" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 21, 30 | "metadata": {}, 31 | "outputs": [ 32 | { 33 | "name": "stdout", 34 | "output_type": "stream", 35 | "text": [ 36 | "\n", 37 | "0 砗─碼─象─士─将─士─象─碼─砗\n", 38 | " │ │ │ │\│/│ │ │ │\n", 39 | "1 ├─┼─┼─┼─※─┼─┼─┼─┤\n", 40 | " │ │ │ │/│\│ │ │ │\n", 41 | "2 ├─砲─┼─┼─┼─┼─┼─砲─┤\n", 42 | " │ │ │ │ │ │ │ 
│ │\n", 43 | "3 卒─┼─卒─┼─卒─┼─卒─┼─卒\n", 44 | " │ │ │ │ │ │ │ │ │\n", 45 | "4 ├─┴─┴─┴─┴─┴─┴─┴─┤\n", 46 | " │    │\n", 47 | "5 ├─┬─┬─┬─┬─┬─┬─┬─┤\n", 48 | " │ │ │ │ │ │ │ │ │\n", 49 | "6 兵─┼─兵─┼─兵─┼─兵─┼─兵\n", 50 | " │ │ │ │ │ │ │ │ │\n", 51 | "7 ├─炮─┼─┼─┼─┼─┼─炮─┤\n", 52 | " │ │ │ │\│/│ │ │ │\n", 53 | "8 ├─┼─┼─┼─※─┼─┼─┼─┤\n", 54 | " │ │ │ │/│\│ │ │ │\n", 55 | "9 车─马─相─仕─帅─仕─相─马─车\n", 56 | " 0 1 2 3 4 5 6 7 8\n", 57 | "\n" 58 | ] 59 | } 60 | ], 61 | "source": [ 62 | "bb.print_board()" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 22, 68 | "metadata": {}, 69 | "outputs": [ 70 | { 71 | "data": { 72 | "text/plain": [ 73 | "" 74 | ] 75 | }, 76 | "execution_count": 22, 77 | "metadata": {}, 78 | "output_type": "execute_result" 79 | } 80 | ], 81 | "source": [ 82 | "bb.move(Pos(7,7),Pos(4,7))" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 23, 88 | "metadata": {}, 89 | "outputs": [ 90 | { 91 | "name": "stdout", 92 | "output_type": "stream", 93 | "text": [ 94 | "\n", 95 | "0 砗─碼─象─士─将─士─象─碼─砗\n", 96 | " │ │ │ │\│/│ │ │ │\n", 97 | "1 ├─┼─┼─┼─※─┼─┼─┼─┤\n", 98 | " │ │ │ │/│\│ │ │ │\n", 99 | "2 ├─砲─┼─┼─┼─┼─┼─砲─┤\n", 100 | " │ │ │ │ │ │ │ │ │\n", 101 | "3 卒─┼─卒─┼─卒─┼─卒─┼─卒\n", 102 | " │ │ │ │ │ │ │ │ │\n", 103 | "4 ├─┴─┴─┴─┴─┴─┴─┴─┤\n", 104 | " │    │\n", 105 | "5 ├─┬─┬─┬─┬─┬─┬─┬─┤\n", 106 | " │ │ │ │ │ │ │ │ │\n", 107 | "6 兵─┼─兵─┼─兵─┼─兵─┼─兵\n", 108 | " │ │ │ │ │ │ │ │ │\n", 109 | "7 ├─炮─┼─┼─炮─┼─┼─┼─┤\n", 110 | " │ │ │ │\│/│ │ │ │\n", 111 | "8 ├─┼─┼─┼─※─┼─┼─┼─┤\n", 112 | " │ │ │ │/│\│ │ │ │\n", 113 | "9 车─马─相─仕─帅─仕─相─马─车\n", 114 | " 0 1 2 3 4 5 6 7 8\n", 115 | "\n" 116 | ] 117 | } 118 | ], 119 | "source": [ 120 | "bb.print_board()" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 24, 126 | "metadata": {}, 127 | "outputs": [ 128 | { 129 | "data": { 130 | "text/plain": [ 131 | "array([['r', 'n', 'b', 'a', 'k', 'a', 'b', 'n', 'r'],\n", 132 | " [None, None, None, None, None, None, None, None, None],\n", 133 | " 
[None, 'c', None, None, None, None, None, 'c', None],\n", 134 | " ['p', None, 'p', None, 'p', None, 'p', None, 'p'],\n", 135 | " [None, None, None, None, None, None, None, None, None],\n", 136 | " [None, None, None, None, None, None, None, None, None],\n", 137 | " ['P', None, 'P', None, 'P', None, 'P', None, 'P'],\n", 138 | " [None, 'C', None, None, 'C', None, None, None, None],\n", 139 | " [None, None, None, None, None, None, None, None, None],\n", 140 | " ['R', 'N', 'B', 'A', 'K', 'A', 'B', 'N', 'R']], dtype=object)" 141 | ] 142 | }, 143 | "execution_count": 24, 144 | "metadata": {}, 145 | "output_type": "execute_result" 146 | } 147 | ], 148 | "source": [ 149 | "bb.get_board_arr()" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 69, 155 | "metadata": {}, 156 | "outputs": [ 157 | { 158 | "data": { 159 | "text/plain": [ 160 | "(10, 9)" 161 | ] 162 | }, 163 | "execution_count": 69, 164 | "metadata": {}, 165 | "output_type": "execute_result" 166 | } 167 | ], 168 | "source": [ 169 | "bb.get_board_arr().shape" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 71, 175 | "metadata": { 176 | "collapsed": true 177 | }, 178 | "outputs": [], 179 | "source": [ 180 | "target = np.zeros((10,9))\n", 181 | "target[7,4] = 1" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 72, 187 | "metadata": {}, 188 | "outputs": [ 189 | { 190 | "data": { 191 | "text/plain": [ 192 | "array([[ 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 193 | " [ 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 194 | " [ 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 195 | " [ 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 196 | " [ 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 197 | " [ 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 198 | " [ 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 199 | " [ 0., 0., 0., 0., 1., 0., 0., 0., 0.],\n", 200 | " [ 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 201 | " [ 0., 0., 0., 0., 0., 0., 0., 0., 0.]])" 202 | ] 203 | }, 204 | 
"execution_count": 72, 205 | "metadata": {}, 206 | "output_type": "execute_result" 207 | } 208 | ], 209 | "source": [ 210 | "target" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 31, 216 | "metadata": {}, 217 | "outputs": [ 218 | { 219 | "data": { 220 | "text/plain": [ 221 | "['n', None, 'c', 'A', 'N', 'k', 'B', 'K', 'a', 'r', 'C', 'P', 'b', 'p', 'R']" 222 | ] 223 | }, 224 | "execution_count": 31, 225 | "metadata": {}, 226 | "output_type": "execute_result" 227 | } 228 | ], 229 | "source": [ 230 | "list(set(bb.get_board_arr().reshape(-1)))" 231 | ] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "metadata": { 236 | "collapsed": true 237 | }, 238 | "source": [ 239 | "# read a chess game" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": 2, 245 | "metadata": { 246 | "collapsed": true 247 | }, 248 | "outputs": [], 249 | "source": [ 250 | "import xmltodict" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 9, 256 | "metadata": { 257 | "collapsed": true 258 | }, 259 | "outputs": [], 260 | "source": [ 261 | "doc = xmltodict.parse(open('./data/samples/1966年全国象棋个人赛.cbf',encoding='utf-8').read())" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": 10, 267 | "metadata": { 268 | "collapsed": true 269 | }, 270 | "outputs": [], 271 | "source": [ 272 | "fen = doc['ChineseChessRecord'][\"Head\"][\"FEN\"]" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": 11, 278 | "metadata": { 279 | "collapsed": true 280 | }, 281 | "outputs": [], 282 | "source": [ 283 | "pgnfile = doc['ChineseChessRecord'][\"Head\"][\"From\"]" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": 12, 289 | "metadata": {}, 290 | "outputs": [ 291 | { 292 | "data": { 293 | "text/plain": [ 294 | "'\\n [Game \"Chinese Chess\"]\\n [Event \"1966年全国象棋个人赛\"]\\n [Site \"郑州\"]\\n [Seesion \"1\"]\\n [Date \"1966-04-12\"]\\n [Round \"第03轮\"]\\n [Red \"胡一鹏\"]\\n 
[RedTeam \"青海\"]\\n [Black \"杨官璘\"]\\n [BlackTeam \"广东\"]\\n [Result 黑胜]\\n 1.炮二平五 马8进7\\n\\n2.马二进三 车9平8\\n\\n3.车一平二 马2进3\\n\\n4.兵七进一 卒7进1\\n\\n5.车二进六 车1进1\\n\\n6.马八进七 车1平4\\n\\n7.车二平三 炮8退1\\n\\n8.炮八平九 车4进1\\n\\n9.车九平八 炮8平7\\n\\n10.车三进一 车4平7\\n\\n11.车八进七 士6进5\\n\\n12.马七进八 车7平4\\n\\n13.炮五平七 炮7进1\\n\\n14.炮七进四 象3进1\\n\\n15.马八进九 将5平6\\n\\n16.马九进七 炮7平3\\n\\n17.相七进五 车8进6\\n\\n18.车八平九 炮3退2\\n\\n19.车九进二 象7进5\\n\\n20.仕六进五 车8平7\\n\\n21.车九退四 车7平6\\n\\n22.马三进二 卒7进1\\n\\n23.马二进三 将6平5\\n\\n24.马三进一 车6退4\\n\\n25.马一进三 车6退1\\n\\n26.马三退二 卒7平6\\n\\n27.车九平二 士5退6\\n\\n28.相五退七 卒6进1\\n\\n29.炮九平一 卒6进1\\n\\n30.炮一进四 卒6进1\\n\\n31.炮一进三 将5进1\\n\\n32.车二平八 炮3进1\\n\\n33.炮七平九 车4平1\\n\\n34.车八进三 炮3平4\\n\\n35.车八退二 车6进2\\n\\n36.马二退三 车6进2\\n\\n37.相七进五 象5进7\\n\\n38.马三进一 车6退2\\n\\n39.车八平六 炮4进1\\n\\n40.马一进二 车6平8\\n\\n41.马二进三 车8退3\\n\\n42.车六平五 象7退5\\n\\n43.炮一退五 车8平7\\n\\n44.炮一平五 车7进4\\n\\n45.车五平三 将5平4\\n\\n46.车三平八 车1退1\\n\\n47.车八平六 车1进1\\n\\n48.车六平八 炮4平2\\n\\n49.车八平六 炮2平4\\n\\n50.兵九进一 车1平2\\n\\n51.炮五平六 将4平5\\n\\n52.炮六平五 将5平4\\n\\n53.炮九平八 车7平2\\n\\n54.炮八平七 后车进1\\n\\n55.炮五平六 前车进5\\n\\n56.炮六退四 将4平5\\n\\n57.车六平二 炮4进6\\n\\n58.车二平六 炮4平1\\n\\n59.车六平四 前车退3\\n\\n60.兵五进一 前车平4\\n\\n61.相五退七 车4退1\\n\\n62.兵五进一 车4平3\\n\\n63.炮七平五 象5退7\\n\\n64.仕五进六 车3进4\\n\\n65.车四退五 车3平4\\n\\n66.帅五平六 车2进6\\n\\n\\n '" 295 | ] 296 | }, 297 | "execution_count": 12, 298 | "metadata": {}, 299 | "output_type": "execute_result" 300 | } 301 | ], 302 | "source": [ 303 | "open(\"./data/imsa_play/\" + pgnfile,encoding='gbk').read()" 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": 13, 309 | "metadata": { 310 | "collapsed": true 311 | }, 312 | "outputs": [], 313 | "source": [ 314 | "moves = [i[\"@value\"] for i in doc['ChineseChessRecord']['MoveList'][\"Move\"] if i[\"@value\"] != '00-00']" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": 14, 320 | "metadata": {}, 321 | "outputs": [ 322 | { 323 | "data": { 324 | "text/plain": [ 325 | "132" 326 | ] 327 | }, 328 | "execution_count": 14, 329 | 
"metadata": {}, 330 | "output_type": "execute_result" 331 | } 332 | ], 333 | "source": [ 334 | "len(moves)" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": 38, 340 | "metadata": { 341 | "scrolled": true 342 | }, 343 | "outputs": [ 344 | { 345 | "ename": "NameError", 346 | "evalue": "name 'BaseChessBoard' is not defined", 347 | "output_type": "error", 348 | "traceback": [ 349 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 350 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 351 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mbb\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mBaseChessBoard\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfen\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mred\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mFalse\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mmoves\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mred\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mred\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mx1\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0my1\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mx2\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0my2\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m3\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m4\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 352 | "\u001b[0;31mNameError\u001b[0m: name 'BaseChessBoard' is not defined" 353 | ] 354 | } 355 | ], 356 | "source": [ 357 | "bb = BaseChessBoard(fen)\n", 358 | "red = False\n", 359 | "for i in moves:\n", 360 | " red = not red\n", 361 | " x1,y1,x2,y2 = int(i[0]),int(i[1]),int(i[3]),int(i[4])\n", 362 | " print(\"{} {}\".format(i,\"红\" if red else \"黑\"))\n", 363 | " moveresult = bb.move(Pos(x1,y1),Pos(x2,y2))\n", 364 | " assert(moveresult != None)\n", 365 | " bb.print_board()" 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": 43, 371 | "metadata": {}, 372 | "outputs": [ 373 | { 374 | "data": { 375 | "text/plain": [ 376 | "array([[None, None, 'b', 'a', None, 'k', None, None, None],\n", 377 | " [None, None, None, None, 'a', None, None, None, None],\n", 378 | " [None, None, None, None, 'b', None, None, None, None],\n", 379 | " ['p', None, None, None, None, None, 'p', None, 'p'],\n", 380 | " [None, None, 'P', None, None, None, None, None, None],\n", 381 | " [None, None, None, None, 'C', None, 'P', None, None],\n", 382 | " [None, None, None, None, 'P', None, None, None, 'P'],\n", 383 | " ['N', None, 'R', None, 'K', 'r', None, None, None],\n", 384 | " [None, 'c', 'C', None, None, None, None, None, None],\n", 385 | " ['R', None, 'B', None, None, None, 'B', 
'c', None]], dtype=object)" 386 | ] 387 | }, 388 | "execution_count": 43, 389 | "metadata": {}, 390 | "output_type": "execute_result" 391 | } 392 | ], 393 | "source": [ 394 | "bb.get_board_arr()" 395 | ] 396 | }, 397 | { 398 | "cell_type": "code", 399 | "execution_count": 48, 400 | "metadata": {}, 401 | "outputs": [ 402 | { 403 | "name": "stdout", 404 | "output_type": "stream", 405 | "text": [ 406 | "\n", 407 | "0 ┌─┬─象─士───将─┬─┬─┐\n", 408 | " │ │ │ │\│/│ │ │ │\n", 409 | "1 ├─┼─┼─┼─士─┼─┼─┼─┤\n", 410 | " │ │ │ │/│\│ │ │ │\n", 411 | "2 ├─┼─┼─┼─象─┼─┼─┼─┤\n", 412 | " │ │ │ │ │ │ │ │ │\n", 413 | "3 卒─┼─┼─┼─┼─┼─卒─┼─卒\n", 414 | " │ │ │ │ │ │ │ │ │\n", 415 | "4 ├─┴─兵─┴─┴─┴─┴─┴─┤\n", 416 | " │    │\n", 417 | "5 ├─┬─┬─┬─炮─┬─兵─┬─┤\n", 418 | " │ │ │ │ │ │ │ │ │\n", 419 | "6 ├─┼─┼─┼─兵─┼─┼─┼─兵\n", 420 | " │ │ │ │ │ │ │ │ │\n", 421 | "7 马─┼─车─┼─帅─砗─┼─┼─┤\n", 422 | " │ │ │ │\│/│ │ │ │\n", 423 | "8 ├─砲─炮─┼─※─┼─┼─┼─┤\n", 424 | " │ │ │ │/│\│ │ │ │\n", 425 | "9 车─┴─相─┴───┴─相─砲─┘\n", 426 | " 0 1 2 3 4 5 6 7 8\n", 427 | "\n" 428 | ] 429 | } 430 | ], 431 | "source": [ 432 | "bb.print_board()" 433 | ] 434 | }, 435 | { 436 | "cell_type": "code", 437 | "execution_count": 44, 438 | "metadata": {}, 439 | "outputs": [ 440 | { 441 | "data": { 442 | "text/plain": [ 443 | "'2ba1k3/4a4/4b4/p5p1p/2P6/4C1P2/4P3P/N1R1Kr3/1cC6/R1B3Bc1 w - - 0 1'" 444 | ] 445 | }, 446 | "execution_count": 44, 447 | "metadata": {}, 448 | "output_type": "execute_result" 449 | } 450 | ], 451 | "source": [ 452 | "bb.to_fen()" 453 | ] 454 | }, 455 | { 456 | "cell_type": "code", 457 | "execution_count": 45, 458 | "metadata": { 459 | "collapsed": true 460 | }, 461 | "outputs": [], 462 | "source": [ 463 | "cc = BaseChessBoard(bb.to_fen())" 464 | ] 465 | }, 466 | { 467 | "cell_type": "code", 468 | "execution_count": 47, 469 | "metadata": {}, 470 | "outputs": [ 471 | { 472 | "name": "stdout", 473 | "output_type": "stream", 474 | "text": [ 475 | "\n", 476 | "0 ┌─┬─象─士───将─┬─┬─┐\n", 477 | " │ │ │ │\│/│ │ │ │\n", 478 | "1 
├─┼─┼─┼─士─┼─┼─┼─┤\n", 479 | " │ │ │ │/│\│ │ │ │\n", 480 | "2 ├─┼─┼─┼─象─┼─┼─┼─┤\n", 481 | " │ │ │ │ │ │ │ │ │\n", 482 | "3 卒─┼─┼─┼─┼─┼─卒─┼─卒\n", 483 | " │ │ │ │ │ │ │ │ │\n", 484 | "4 ├─┴─兵─┴─┴─┴─┴─┴─┤\n", 485 | " │    │\n", 486 | "5 ├─┬─┬─┬─炮─┬─兵─┬─┤\n", 487 | " │ │ │ │ │ │ │ │ │\n", 488 | "6 ├─┼─┼─┼─兵─┼─┼─┼─兵\n", 489 | " │ │ │ │ │ │ │ │ │\n", 490 | "7 马─┼─车─┼─帅─砗─┼─┼─┤\n", 491 | " │ │ │ │\│/│ │ │ │\n", 492 | "8 ├─砲─炮─┼─※─┼─┼─┼─┤\n", 493 | " │ │ │ │/│\│ │ │ │\n", 494 | "9 车─┴─相─┴───┴─相─砲─┘\n", 495 | " 0 1 2 3 4 5 6 7 8\n", 496 | "\n" 497 | ] 498 | } 499 | ], 500 | "source": [ 501 | "cc.print_board()" 502 | ] 503 | }, 504 | { 505 | "cell_type": "code", 506 | "execution_count": 49, 507 | "metadata": { 508 | "collapsed": true 509 | }, 510 | "outputs": [], 511 | "source": [ 512 | "cbfdir = \"./data/imsa-cbf/\"\n", 513 | "allfiles = os.listdir(cbfdir)" 514 | ] 515 | }, 516 | { 517 | "cell_type": "code", 518 | "execution_count": 54, 519 | "metadata": { 520 | "collapsed": true 521 | }, 522 | "outputs": [], 523 | "source": [ 524 | "allfiles = [os.path.join(cbfdir,i) for i in allfiles]" 525 | ] 526 | }, 527 | { 528 | "cell_type": "code", 529 | "execution_count": 58, 530 | "metadata": { 531 | "collapsed": true 532 | }, 533 | "outputs": [], 534 | "source": [ 535 | "gap = int(len(allfiles) * 0.9)" 536 | ] 537 | }, 538 | { 539 | "cell_type": "code", 540 | "execution_count": 61, 541 | "metadata": { 542 | "collapsed": true 543 | }, 544 | "outputs": [], 545 | "source": [ 546 | "import random\n", 547 | "random.shuffle(allfiles)" 548 | ] 549 | }, 550 | { 551 | "cell_type": "code", 552 | "execution_count": 62, 553 | "metadata": { 554 | "collapsed": true 555 | }, 556 | "outputs": [], 557 | "source": [ 558 | "trainfiles = allfiles[:gap]\n", 559 | "testfiles = allfiles[gap:]" 560 | ] 561 | }, 562 | { 563 | "cell_type": "code", 564 | "execution_count": 73, 565 | "metadata": { 566 | "collapsed": true 567 | }, 568 | "outputs": [], 569 | "source": [ 570 | "import pandas as pd\n", 571 | 
"trainframe = pd.DataFrame(trainfiles)\n", 572 | "trainframe.to_csv('data/train_list.csv',header=None,index=None,encoding='utf-8')" 573 | ] 574 | }, 575 | { 576 | "cell_type": "code", 577 | "execution_count": 74, 578 | "metadata": { 579 | "collapsed": true 580 | }, 581 | "outputs": [], 582 | "source": [ 583 | "testframe = pd.DataFrame(testfiles)\n", 584 | "testframe.to_csv('data/test_list.csv',header=None,index=None,encoding='utf-8')" 585 | ] 586 | }, 587 | { 588 | "cell_type": "code", 589 | "execution_count": 68, 590 | "metadata": {}, 591 | "outputs": [ 592 | { 593 | "data": { 594 | "text/plain": [ 595 | "(49982, 5554)" 596 | ] 597 | }, 598 | "execution_count": 68, 599 | "metadata": {}, 600 | "output_type": "execute_result" 601 | } 602 | ], 603 | "source": [ 604 | "len(trainframe),len(testframe)" 605 | ] 606 | }, 607 | { 608 | "cell_type": "markdown", 609 | "metadata": { 610 | "collapsed": true 611 | }, 612 | "source": [ 613 | "# 统计" 614 | ] 615 | }, 616 | { 617 | "cell_type": "code", 618 | "execution_count": 19, 619 | "metadata": { 620 | "collapsed": true 621 | }, 622 | "outputs": [], 623 | "source": [ 624 | "import pandas as pd\n", 625 | "trainframe = pd.read_csv('data/train_list.csv',header=None)" 626 | ] 627 | }, 628 | { 629 | "cell_type": "code", 630 | "execution_count": 32, 631 | "metadata": { 632 | "collapsed": true 633 | }, 634 | "outputs": [], 635 | "source": [ 636 | "files = trainframe[0].values" 637 | ] 638 | }, 639 | { 640 | "cell_type": "code", 641 | "execution_count": 34, 642 | "metadata": { 643 | "collapsed": true 644 | }, 645 | "outputs": [], 646 | "source": [ 647 | "samplefiles = files[:100]" 648 | ] 649 | }, 650 | { 651 | "cell_type": "code", 652 | "execution_count": 35, 653 | "metadata": { 654 | "collapsed": true 655 | }, 656 | "outputs": [], 657 | "source": [ 658 | "total = 0\n", 659 | "for i in samplefiles:\n", 660 | " doc = xmltodict.parse(open(i,encoding='utf-8').read())\n", 661 | " moves = [i[\"@value\"] for i in 
doc['ChineseChessRecord']['MoveList'][\"Move\"] if i[\"@value\"] != '00-00']\n", 662 | " total += len(moves)" 663 | ] 664 | }, 665 | { 666 | "cell_type": "code", 667 | "execution_count": 36, 668 | "metadata": {}, 669 | "outputs": [ 670 | { 671 | "name": "stdout", 672 | "output_type": "stream", 673 | "text": [ 674 | "79.21\n" 675 | ] 676 | } 677 | ], 678 | "source": [ 679 | "print(total / 100)" 680 | ] 681 | }, 682 | { 683 | "cell_type": "code", 684 | "execution_count": null, 685 | "metadata": { 686 | "collapsed": true 687 | }, 688 | "outputs": [], 689 | "source": [] 690 | } 691 | ], 692 | "metadata": { 693 | "anaconda-cloud": {}, 694 | "kernelspec": { 695 | "display_name": "Python [default]", 696 | "language": "python", 697 | "name": "python3" 698 | }, 699 | "language_info": { 700 | "codemirror_mode": { 701 | "name": "ipython", 702 | "version": 3 703 | }, 704 | "file_extension": ".py", 705 | "mimetype": "text/x-python", 706 | "name": "python", 707 | "nbconvert_exporter": "python", 708 | "pygments_lexer": "ipython3", 709 | "version": "3.5.2" 710 | } 711 | }, 712 | "nbformat": 4, 713 | "nbformat_minor": 1 714 | } 715 | -------------------------------------------------------------------------------- /chess_value_baseline.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "curses is not supported on this machine (please install/reinstall curses for an optimal experience)\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "%reload_ext autoreload \n", 18 | "%autoreload 2\n", 19 | "from cchess import *\n", 20 | "import tensorflow as tf\n", 21 | "import numpy as np\n", 22 | "from matplotlib import pyplot as plt\n", 23 | "import random \n", 24 | "import time\n", 25 | "from utils import Dataset,ProgressBar\n", 26 | "from tflearn.data_flow import 
DataFlow,DataFlowStatus,FeedDictFlow\n", 27 | "from tflearn.data_utils import Preloader,ImagePreloader\n", 28 | "import scipy\n", 29 | "import pandas as pd\n", 30 | "import xmltodict\n", 31 | "from game_convert import convert_game\n", 32 | "import tflearn" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "# a network predict select and move of Chinese chess, with minimal preprocessing" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 4, 45 | "metadata": { 46 | "collapsed": true 47 | }, 48 | "outputs": [], 49 | "source": [ 50 | "GPU_CORE = 0\n", 51 | "BATCH_SIZE = 256\n", 52 | "BEGINING_LR = 0.01\n", 53 | "#TESTIMG_WIDTH = 500\n", 54 | "model_name = '11_8_resnet'\n", 55 | "data_dir = 'data/imsa-cbf/'" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": { 62 | "collapsed": true 63 | }, 64 | "outputs": [], 65 | "source": [ 66 | "def get_winner(file,data_dir = 'data/imsa-play/'):\n", 67 | " filename = os.path.join(data_dir,file)\n", 68 | " with open(filename) as fhdl:\n", 69 | " \n", 70 | " " 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 3, 76 | "metadata": { 77 | "collapsed": true 78 | }, 79 | "outputs": [], 80 | "source": [ 81 | "class ElePreloader(object):\n", 82 | " def __init__(self,datafile,batch_size=64):\n", 83 | " self.batch_size=batch_size\n", 84 | " content = pd.read_csv(datafile,header=None,index_col=None)\n", 85 | " self.filelist = [i[0] for i in content.get_values()]\n", 86 | " self.pos = 0\n", 87 | " self.feature_list = {\"red\":['A', 'B', 'C', 'K', 'N', 'P', 'R']\n", 88 | " ,\"black\":['a', 'b', 'c', 'k', 'n', 'p', 'r']}\n", 89 | " self.batch_size = batch_size\n", 90 | " self.batch_iter = self.__iter()\n", 91 | " assert(len(self.filelist) > batch_size)\n", 92 | " self.game_iterlist = [None for i in self.filelist]\n", 93 | " \n", 94 | " def __iter(self):\n", 95 | " retx1,rety1,retx2,rety2 = [],[],[],[]\n", 96 | " 
filelist = []\n", 97 | " while True:\n", 98 | " for i in range(self.batch_size):\n", 99 | " if self.game_iterlist[i] == None:\n", 100 | " if len(filelist) == 0:\n", 101 | " filelist = copy.copy(self.filelist)\n", 102 | " random.shuffle(filelist)\n", 103 | " self.game_iterlist[i] = convert_game(filelist.pop(),feature_list=self.feature_list)\n", 104 | " game_iter = self.game_iterlist[i]\n", 105 | " \n", 106 | " try:\n", 107 | " x1,y1,x2,y2 = game_iter.__next__()\n", 108 | " x1 = np.transpose(x1,[1,2,0])\n", 109 | " x2 = np.transpose(x2,[1,2,0])\n", 110 | " x1 = np.expand_dims(x1,axis=0)\n", 111 | " x2 = np.expand_dims(x2,axis=0)\n", 112 | " retx1.append(x1)\n", 113 | " rety1.append(y1)\n", 114 | " retx2.append(x2)\n", 115 | " rety2.append(y2)\n", 116 | " if len(retx1) >= self.batch_size:\n", 117 | " yield (np.concatenate(retx1,axis=0),np.asarray(rety1)\n", 118 | " ,np.concatenate(retx2,axis=0),np.asarray(rety2))\n", 119 | " retx1,rety1,retx2,rety2 = [],[],[],[]\n", 120 | " except :\n", 121 | " self.game_iterlist[i] = None\n", 122 | "\n", 123 | " def __getitem__(self, id):\n", 124 | " \n", 125 | " x1,y1,x2,y2 = self.batch_iter.__next__()\n", 126 | " return x1,y1,x2,y2\n", 127 | " \n", 128 | " def __len__(self):\n", 129 | " return 10000" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 11, 135 | "metadata": { 136 | "collapsed": true 137 | }, 138 | "outputs": [], 139 | "source": [ 140 | "def res_block(inputx,name,training,block_num=2,filters=256,kernel_size=(3,3)):\n", 141 | " net = inputx\n", 142 | " for i in range(block_num):\n", 143 | " net = tf.layers.conv2d(net,filters=filters,kernel_size=kernel_size,activation=None,name=\"{}_res_conv{}\".format(name,i),padding='same')\n", 144 | " net = tf.layers.batch_normalization(net,training=training,name=\"{}_res_bn{}\".format(name,i))\n", 145 | " if i == block_num - 1:\n", 146 | " net = net + inputx #= tf.concat((inputx,net),axis=-1)\n", 147 | " net = 
tf.nn.elu(net,name=\"{}_res_elu{}\".format(name,i))\n", 148 | " return net\n", 149 | "\n", 150 | "def conv_block(inputx,name,training,block_num=1,filters=2,kernel_size=(1,1)):\n", 151 | " net = inputx\n", 152 | " for i in range(block_num):\n", 153 | " net = tf.layers.conv2d(net,filters=filters,kernel_size=kernel_size,activation=None,name=\"{}_convblock_conv{}\".format(name,i),padding='same')\n", 154 | " net = tf.layers.batch_normalization(net,training=training,name=\"{}_convblock_bn{}\".format(name,i))\n", 155 | " net = tf.nn.elu(net,name=\"{}_convblock_elu{}\".format(name,i))\n", 156 | " # net [None,10,9,2]\n", 157 | " netshape = net.get_shape().as_list()\n", 158 | " print(\"inside conv block {}\".format(str(netshape)))\n", 159 | " net = tf.reshape(net,shape=(-1,netshape[1] * netshape[2] * netshape[3]))\n", 160 | " net = tf.layers.dense(net,10 * 9,name=\"{}_dense\".format(name))\n", 161 | " net = tf.nn.elu(net,name=\"{}_elu\".format(name))\n", 162 | " return net\n", 163 | "\n", 164 | "def res_net_board(inputx,name,training,filters=256):\n", 165 | " net = inputx\n", 166 | " net = tf.layers.conv2d(net,filters=filters,kernel_size=(3,3),activation=None,name=\"{}_res_convb\".format(name),padding='same')\n", 167 | " net = tf.layers.batch_normalization(net,training=training,name=\"{}_res_bnb\".format(name))\n", 168 | " net = tf.nn.elu(net,name=\"{}_res_elub\".format(name))\n", 169 | " for i in range(NUM_RES_LAYERS):\n", 170 | " net = res_block(net,name=\"{}_layer_{}\".format(name,i + 1),training=training)\n", 171 | " print(net.get_shape().as_list())\n", 172 | " print(\"inside res net {}\".format(str(net.get_shape().as_list())))\n", 173 | " net_unsoftmax = conv_block(net,name=\"{}_conv\".format(name),training=training)\n", 174 | " return net_unsoftmax\n", 175 | "\n", 176 | "def get_scatter(name):\n", 177 | " with tf.variable_scope(\"Test\"):\n", 178 | " ph = tf.placeholder(tf.float32,name=name)\n", 179 | " op = tf.summary.scalar(name,ph)\n", 180 | " return ph,op" 181 | ] 
182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 12, 186 | "metadata": {}, 187 | "outputs": [ 188 | { 189 | "name": "stdout", 190 | "output_type": "stream", 191 | "text": [ 192 | "[None, 10, 9, 256]\n", 193 | "[None, 10, 9, 256]\n", 194 | "[None, 10, 9, 256]\n", 195 | "[None, 10, 9, 256]\n", 196 | "[None, 10, 9, 256]\n", 197 | "[None, 10, 9, 256]\n", 198 | "[None, 10, 9, 256]\n", 199 | "[None, 10, 9, 256]\n", 200 | "[None, 10, 9, 256]\n", 201 | "[None, 10, 9, 256]\n", 202 | "inside res net [None, 10, 9, 256]\n", 203 | "inside conv block [None, 10, 9, 2]\n", 204 | "[None, 10, 9, 256]\n", 205 | "[None, 10, 9, 256]\n", 206 | "[None, 10, 9, 256]\n", 207 | "[None, 10, 9, 256]\n", 208 | "[None, 10, 9, 256]\n", 209 | "[None, 10, 9, 256]\n", 210 | "[None, 10, 9, 256]\n", 211 | "[None, 10, 9, 256]\n", 212 | "[None, 10, 9, 256]\n", 213 | "[None, 10, 9, 256]\n", 214 | "inside res net [None, 10, 9, 256]\n", 215 | "inside conv block [None, 10, 9, 2]\n" 216 | ] 217 | } 218 | ], 219 | "source": [ 220 | "tf.reset_default_graph()\n", 221 | "config = tf.ConfigProto()\n", 222 | "config.gpu_options.allow_growth = True\n", 223 | "config.allow_soft_placement = True\n", 224 | "sess = tf.Session(config=config)\n", 225 | "\n", 226 | "NUM_RES_LAYERS = 10\n", 227 | "\n", 228 | "with tf.device(\"/gpu:{}\".format(GPU_CORE)):\n", 229 | " X1 = tf.placeholder(tf.float32,[None,10,9,14])\n", 230 | " y1 = tf.placeholder(tf.float32,[None,10,9])\n", 231 | " X2 = tf.placeholder(tf.float32,[None,10,9,15])\n", 232 | " y2 = tf.placeholder(tf.float32,[None,10,9])\n", 233 | " \n", 234 | " training = tf.placeholder(tf.bool,name='training_mode')\n", 235 | " learning_rate = tf.placeholder(tf.float32)\n", 236 | " global_step = tf.train.get_or_create_global_step()\n", 237 | " \n", 238 | " net_unsoftmax1 = res_net_board(X1,\"selectnet\",training=training)\n", 239 | " net_unsoftmax2 = res_net_board(X2,\"movenet\",training=training)\n", 240 | " \n", 241 | " target1 = tf.reshape(y1,(-1,10 * 
9))\n", 242 | " target2 = tf.reshape(y2,(-1,10 * 9))\n", 243 | " with tf.variable_scope(\"Loss\"):\n", 244 | " loss_select = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=target1,logits=net_unsoftmax1))\n", 245 | " loss_move = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=target2,logits=net_unsoftmax2))\n", 246 | " loss = loss_select + loss_move\n", 247 | " \n", 248 | " loss_select_summary = tf.summary.scalar(\"loss_select\",loss_select)\n", 249 | " loss_move_summary = tf.summary.scalar(\"loss_move\",loss_move)\n", 250 | " loss_summary = tf.summary.scalar(\"total_loss\",loss)\n", 251 | " net_softmax1 = tf.nn.softmax(net_unsoftmax1)\n", 252 | " net_softmax2 = tf.nn.softmax(net_unsoftmax2)\n", 253 | " \n", 254 | " correct_prediction1 = tf.equal(tf.argmax(target1,1), tf.argmax(net_softmax1,1))\n", 255 | " correct_prediction2 = tf.equal(tf.argmax(target2,1), tf.argmax(net_softmax2,1))\n", 256 | " \n", 257 | " with tf.variable_scope(\"Accuracy\"):\n", 258 | " accuracy_select = tf.reduce_mean(tf.cast(correct_prediction1, tf.float32))\n", 259 | " accuracy_move = tf.reduce_mean(tf.cast(correct_prediction2, tf.float32))\n", 260 | " accuracy_total = accuracy_select * accuracy_move\n", 261 | " \n", 262 | " acc_select_summary = tf.summary.scalar(\"accuracy_select\",accuracy_select)\n", 263 | " acc_move_summary = tf.summary.scalar(\"accuracy_move\",accuracy_move)\n", 264 | " acc_summary = tf.summary.scalar(\"acc_summary\",accuracy_total)\n", 265 | " \n", 266 | " summary_op = tf.summary.merge([loss_select_summary,loss_move_summary,loss_summary\n", 267 | " ,acc_select_summary,acc_move_summary,acc_summary])\n", 268 | " \n", 269 | " test_select,test_select_summary = get_scatter(\"test_select_loss\")\n", 270 | " test_move,test_move_summary = get_scatter(\"test_move_loss\")\n", 271 | " test_total,test_total_summary = get_scatter(\"test_total_loss\")\n", 272 | " test_selectacc,test_selectacc_summary = get_scatter(\"test_select_acc\")\n", 273 | " 
test_moveacc,test_moveacc_summary = get_scatter(\"test_move_acc\")\n", 274 | " test_totalacc,test_totalacc_summary = get_scatter(\"test_total_acc\")\n", 275 | " \n", 276 | " test_summary_op = tf.summary.merge([test_select_summary,test_move_summary,test_total_summary\n", 277 | " ,test_selectacc_summary,test_moveacc_summary,test_totalacc_summary])\n", 278 | " \n", 279 | " update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)\n", 280 | " with tf.control_dependencies(update_ops):\n", 281 | " optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,momentum=0.9)\n", 282 | " train_op = optimizer.minimize(loss,global_step=global_step)\n", 283 | "\n", 284 | " train_summary_writer = tf.summary.FileWriter(\"./log/{}_train\".format(model_name), sess.graph)" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 13, 290 | "metadata": {}, 291 | "outputs": [ 292 | { 293 | "data": { 294 | "text/plain": [ 295 | "0" 296 | ] 297 | }, 298 | "execution_count": 13, 299 | "metadata": {}, 300 | "output_type": "execute_result" 301 | } 302 | ], 303 | "source": [ 304 | "sess.run(tf.global_variables_initializer())\n", 305 | "tf.train.global_step(sess, global_step)" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": 14, 311 | "metadata": { 312 | "collapsed": true 313 | }, 314 | "outputs": [], 315 | "source": [ 316 | "??tflearn.layers.residual_block" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": 15, 322 | "metadata": { 323 | "collapsed": true 324 | }, 325 | "outputs": [], 326 | "source": [ 327 | "import os\n", 328 | "if not os.path.exists(\"models/{}\".format(model_name)):\n", 329 | " os.mkdir(\"models/{}\".format(model_name))" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": 16, 335 | "metadata": { 336 | "collapsed": true 337 | }, 338 | "outputs": [], 339 | "source": [ 340 | "N_BATCH = 10000\n", 341 | "N_BATCH_TEST = 300" 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 
346 | "execution_count": 22, 347 | "metadata": {}, 348 | "outputs": [ 349 | { 350 | "name": "stdout", 351 | "output_type": "stream", 352 | "text": [ 353 | "EPOCH 22 STEP 9999 LR 0.001 ACC 75.04 SELACC78.82 MOVACC95.21 LOSS 0.74 100.00 % [==================================================>] 2560000/2560000 \t used:3149s eta:0 sss eta:0 s\n", 354 | "validating epoch 22 batch 9999 100.00 % [==================================================>] 76800/76800 \t used:54s eta:0 sTEST ACC 0.47125524282455444 SELACC 0.5831900835037231 MOVACC 0.8079296946525574 LOSS 2.4170286655426025\n", 355 | "\n", 356 | "EPOCH 23 STEP 9999 LR 0.001 ACC 74.33 SELACC78.43 MOVACC94.77 LOSS 0.76 100.00 % [==================================================>] 2560000/2560000 \t used:5560s eta:0 sss\n", 357 | "validating epoch 23 batch 9999 100.00 % [==================================================>] 76800/76800 \t used:54s eta:0 sTEST ACC 0.4667249023914337 SELACC 0.5788801908493042 MOVACC 0.8060286641120911 LOSS 2.4456872940063477\n", 358 | "\n", 359 | "EPOCH 24 STEP 9999 LR 0.001 ACC 75.16 SELACC79.3 MOVACC94.79 LOSS 0.74 100.00 % [==================================================>] 2560000/2560000 \t used:5559s eta:0 ssss\n", 360 | "validating epoch 24 batch 9999 100.00 % [==================================================>] 76800/76800 \t used:55s eta:0 sTEST ACC 0.47514840960502625 SELACC 0.5864583253860474 MOVACC 0.8101171851158142 LOSS 2.481386661529541\n", 361 | "\n", 362 | "EPOCH 25 STEP 9999 LR 0.001 ACC 77.37 SELACC80.81 MOVACC95.73 LOSS 0.68 100.00 % [==================================================>] 2560000/2560000 \t used:5562s eta:0 sss\n", 363 | "validating epoch 25 batch 9999 100.00 % [==================================================>] 76800/76800 \t used:55s eta:0 sTEST ACC 0.4594808518886566 SELACC 0.574023425579071 MOVACC 0.8003906011581421 LOSS 2.678408622741699\n", 364 | "\n", 365 | "EPOCH 26 STEP 9999 LR 0.001 ACC 77.06 SELACC80.59 MOVACC95.62 LOSS 0.69 100.00 % 
[==================================================>] 2560000/2560000 \t used:5561s eta:0 sss\n", 366 | "validating epoch 26 batch 9999 100.00 % [==================================================>] 76800/76800 \t used:54s eta:0 sTEST ACC 0.4758685231208801 SELACC 0.5884114503860474 MOVACC 0.80859375 LOSS 2.5563509464263916\n", 367 | "\n", 368 | "EPOCH 27 STEP 9999 LR 0.001 ACC 76.82 SELACC80.47 MOVACC95.47 LOSS 0.7 100.00 % [==================================================>] 2560000/2560000 \t used:5556s eta:0 ssss\n", 369 | "validating epoch 27 batch 9999 100.00 % [==================================================>] 76800/76800 \t used:55s eta:0 sTEST ACC 0.46077102422714233 SELACC 0.5758333206176758 MOVACC 0.8000390529632568 LOSS 2.7744829654693604\n", 370 | "\n", 371 | "EPOCH 28 STEP 9999 LR 0.001 ACC 78.71 SELACC81.58 MOVACC96.48 LOSS 0.63 100.00 % [==================================================>] 2560000/2560000 \t used:5557s eta:0 sss\n", 372 | "validating epoch 28 batch 9999 100.00 % [==================================================>] 76800/76800 \t used:55s eta:0 sTEST ACC 0.46656712889671326 SELACC 0.5788021087646484 MOVACC 0.8061718940734863 LOSS 2.853325128555298\n", 373 | "\n", 374 | "EPOCH 29 STEP 9999 LR 0.001 ACC 78.16 SELACC81.24 MOVACC96.22 LOSS 0.65 100.00 % [==================================================>] 2560000/2560000 \t used:5561s eta:0 sss\n", 375 | "validating epoch 29 batch 9999 100.00 % [==================================================>] 76800/76800 \t used:55s eta:0 sTEST ACC 0.4673483371734619 SELACC 0.5812109112739563 MOVACC 0.8039583563804626 LOSS 2.8683159351348877\n", 376 | "\n", 377 | "EPOCH 30 STEP 9999 LR 0.0001 ACC 80.65 SELACC83.27 MOVACC96.85 LOSS 0.58 100.00 % [==================================================>] 2560000/2560000 \t used:5559s eta:0 sss\n", 378 | "validating epoch 30 batch 9999 100.00 % [==================================================>] 76800/76800 \t used:54s eta:0 sTEST ACC 
0.4682302474975586 SELACC 0.5828906297683716 MOVACC 0.8033333420753479 LOSS 2.967599630355835\n", 379 | "\n", 380 | "EPOCH 31 STEP 9999 LR 0.0001 ACC 79.84 SELACC82.5 MOVACC96.77 LOSS 0.61 100.00 % [==================================================>] 2560000/2560000 \t used:5554s eta:0 ssss\n", 381 | "validating epoch 31 batch 9999 100.00 % [==================================================>] 76800/76800 \t used:59s eta:0 sTEST ACC 0.47560566663742065 SELACC 0.5861197710037231 MOVACC 0.8114323019981384 LOSS 2.9342682361602783\n", 382 | "\n", 383 | "EPOCH 32 STEP 9999 LR 0.0001 ACC 80.61 SELACC83.16 MOVACC96.93 LOSS 0.59 100.00 % [==================================================>] 2560000/2560000 \t used:5562s eta:0 sss\n", 384 | "validating epoch 32 batch 9999 100.00 % [==================================================>] 76800/76800 \t used:55s eta:0 sTEST ACC 0.47255992889404297 SELACC 0.5841666460037231 MOVACC 0.8089713454246521 LOSS 3.0034432411193848\n", 385 | "\n", 386 | "EPOCH 33 STEP 8654 LR 0.0001 ACC 80.35 SELACC82.84 MOVACC96.99 LOSS 0.59 86.55 % [===========================================>-------] 2215680/2560000 \t used:4814s eta:748 ss" 387 | ] 388 | }, 389 | { 390 | "ename": "KeyboardInterrupt", 391 | "evalue": "", 392 | "output_type": "error", 393 | "traceback": [ 394 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 395 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 396 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 48\u001b[0m _,step_loss,step_acc_move,step_acc_select,step_acc_total,step_value,step_summary = sess.run(\n\u001b[1;32m 49\u001b[0m [train_op,loss,accuracy_move,accuracy_select,accuracy_total,global_step,summary_op],feed_dict={\n\u001b[0;32m---> 50\u001b[0;31m 
\u001b[0mX1\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mbatch_x1\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0my1\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mbatch_y1\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mX2\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mbatch_x2\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0my2\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mbatch_y2\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mlearning_rate\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mbatch_lr\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mtraining\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 51\u001b[0m })\n\u001b[1;32m 52\u001b[0m \u001b[0mtrain_summary_writer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0madd_summary\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstep_summary\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mstep_value\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 397 | "\u001b[0;32mC:\\Program Files\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 787\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m 788\u001b[0m result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[0;32m--> 789\u001b[0;31m run_metadata_ptr)\n\u001b[0m\u001b[1;32m 790\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m 791\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 398 | "\u001b[0;32mC:\\Program Files\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 995\u001b[0m 
\u001b[1;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m 996\u001b[0m results = self._do_run(handle, final_targets, final_fetches,\n\u001b[0;32m--> 997\u001b[0;31m feed_dict_string, options, run_metadata)\n\u001b[0m\u001b[1;32m 998\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m 999\u001b[0m \u001b[0mresults\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 399 | "\u001b[0;32mC:\\Program Files\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_do_run\u001b[0;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 1130\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m 1131\u001b[0m return self._do_call(_run_fn, self._session, feed_dict, fetch_list,\n\u001b[0;32m-> 1132\u001b[0;31m target_list, options, run_metadata)\n\u001b[0m\u001b[1;32m 1133\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m 1134\u001b[0m return self._do_call(_prun_fn, self._session, handle, feed_dict,\n", 400 | "\u001b[0;32mC:\\Program Files\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_do_call\u001b[0;34m(self, fn, *args)\u001b[0m\n\u001b[1;32m 1137\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m 1138\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m-> 
1139\u001b[0;31m \u001b[1;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1140\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m 1141\u001b[0m \u001b[0mmessage\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0me\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 401 | "\u001b[0;32mC:\\Program Files\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_run_fn\u001b[0;34m(session, feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[1;32m 1119\u001b[0m return tf_session.TF_Run(session, options,\n\u001b[1;32m 1120\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1121\u001b[0;31m status, run_metadata)\n\u001b[0m\u001b[1;32m 1122\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m 1123\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_prun_fn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msession\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 402 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 403 | ] 404 | } 405 | ], 406 | "source": [ 407 | "restore = True\n", 408 | "N_EPOCH = 100\n", 409 | "DECAY_EPOCH = 10\n", 410 | "\n", 411 | "class ExpVal:\n", 412 | " def __init__(self,exp_a=0.97):\n", 413 | " self.val = None\n", 414 | " self.exp_a = 
exp_a\n", 415 | " def update(self,newval):\n", 416 | " if self.val == None:\n", 417 | " self.val = newval\n", 418 | " else:\n", 419 | " self.val = self.exp_a * self.val + (1 - self.exp_a) * newval\n", 420 | " def getval(self):\n", 421 | " return round(self.val,2)\n", 422 | " \n", 423 | "expacc = ExpVal()\n", 424 | "expacc_select = ExpVal()\n", 425 | "expacc_move = ExpVal()\n", 426 | "exploss = ExpVal()\n", 427 | "\n", 428 | "\n", 429 | "begining_learning_rate = 1e-1\n", 430 | "\n", 431 | "pred_image = None\n", 432 | "if restore == False:\n", 433 | " train_epoch = 1\n", 434 | " train_batch = 0\n", 435 | "for one_epoch in range(train_epoch,N_EPOCH):\n", 436 | " train_epoch = one_epoch\n", 437 | " pb = ProgressBar(worksum=N_BATCH * BATCH_SIZE,info=\" epoch {} batch {}\".format(train_epoch,train_batch))\n", 438 | " pb.startjob()\n", 439 | " \n", 440 | " for one_batch in range(N_BATCH):\n", 441 | " if restore == True and one_batch < train_batch:\n", 442 | " pb.auto_display = False\n", 443 | " pb.complete(BATCH_SIZE)\n", 444 | " pb.auto_display = True\n", 445 | " continue\n", 446 | " else:\n", 447 | " restore = False\n", 448 | " train_batch = one_batch\n", 449 | " \n", 450 | " batch_x1,batch_y1,batch_x2,batch_y2 = trainflow.next()['data']\n", 451 | " # learning rate decay strategy\n", 452 | " batch_lr = begining_learning_rate * 10 ** -(one_epoch // DECAY_EPOCH)\n", 453 | " \n", 454 | " _,step_loss,step_acc_move,step_acc_select,step_acc_total,step_value,step_summary = sess.run(\n", 455 | " [train_op,loss,accuracy_move,accuracy_select,accuracy_total,global_step,summary_op],feed_dict={\n", 456 | " X1:batch_x1,y1:batch_y1,X2:batch_x2,y2:batch_y2,learning_rate:batch_lr,training:True\n", 457 | " })\n", 458 | " train_summary_writer.add_summary(step_summary,step_value)\n", 459 | " step_acc_move *= 100\n", 460 | " step_acc_select *= 100\n", 461 | " step_acc_total *= 100\n", 462 | " expacc.update(step_acc_total)\n", 463 | " expacc_select.update(step_acc_select)\n", 464 | " 
expacc_move.update(step_acc_move)\n", 465 | " exploss.update(step_loss)\n", 466 | "\n", 467 | " \n", 468 | " pb.info = \"EPOCH {} STEP {} LR {} ACC {} SELACC{} MOVACC{} LOSS {} \".format(\n", 469 | " one_epoch,one_batch,batch_lr,expacc.getval()\n", 470 | " ,expacc_select.getval(),expacc_move.getval(),exploss.getval())\n", 471 | " \n", 472 | " pb.complete(BATCH_SIZE)\n", 473 | " print()\n", 474 | " accs = []\n", 475 | " accselects = []\n", 476 | " accmoves = []\n", 477 | " losses = []\n", 478 | " lossselects = []\n", 479 | " lossmoves = []\n", 480 | " pb = ProgressBar(worksum=N_BATCH_TEST * BATCH_SIZE,info=\"validating epoch {} batch {}\".format(train_epoch,train_batch))\n", 481 | " pb.startjob()\n", 482 | " for one_batch in range(N_BATCH_TEST):\n", 483 | " batch_x1,batch_y1,batch_x2,batch_y2 = testflow.next()['data']\n", 484 | " step_loss_move,step_loss_select,step_loss,step_accuracy_move,step_accuracy_select,step_accuracy_total = sess.run(\n", 485 | " [loss_move,loss_select,loss,accuracy_move,accuracy_select,accuracy_total],feed_dict={\n", 486 | " X1:batch_x1,y1:batch_y1,X2:batch_x2,y2:batch_y2,training:False\n", 487 | " })\n", 488 | " accs.append(step_accuracy_total)\n", 489 | " accselects.append(step_accuracy_select)\n", 490 | " accmoves.append(step_accuracy_move)\n", 491 | " losses.append(step_loss)\n", 492 | " lossselects.append(step_loss_select)\n", 493 | " lossmoves.append(step_loss_move)\n", 494 | " \n", 495 | " pb.complete(BATCH_SIZE)\n", 496 | " print(\"TEST ACC {} SELACC {} MOVACC {} LOSS {}\".format(np.average(accs),np.average(accselects)\n", 497 | " ,np.average(accmoves),np.average(losses)))\n", 498 | " #test_select_summary,test_move_summary,test_total_summary\n", 499 | " # ,test_selectacc_summary,test_moveacc_summary,test_totalacc_summary\n", 500 | " test_to_add_to_log = sess.run(test_summary_op,feed_dict={\n", 501 | " test_select:np.average(lossselects),test_move:np.average(lossmoves),test_total:np.average(losses)\n", 502 | " 
,test_selectacc:np.average(accselects),test_moveacc:np.average(accmoves),test_totalacc:np.average(accs)\n", 503 | " })\n", 504 | " train_summary_writer.add_summary(test_to_add_to_log,step_value)\n", 505 | " print()\n", 506 | " saver = tf.train.Saver(var_list=tf.global_variables())\n", 507 | " saver.save(sess,\"models/{}/model_{}\".format(model_name,one_epoch))" 508 | ] 509 | }, 510 | { 511 | "cell_type": "code", 512 | "execution_count": null, 513 | "metadata": { 514 | "collapsed": true 515 | }, 516 | "outputs": [], 517 | "source": [] 518 | } 519 | ], 520 | "metadata": { 521 | "anaconda-cloud": {}, 522 | "kernelspec": { 523 | "display_name": "Python [conda root]", 524 | "language": "python", 525 | "name": "conda-root-py" 526 | }, 527 | "language_info": { 528 | "codemirror_mode": { 529 | "name": "ipython", 530 | "version": 3 531 | }, 532 | "file_extension": ".py", 533 | "mimetype": "text/x-python", 534 | "name": "python", 535 | "nbconvert_exporter": "python", 536 | "pygments_lexer": "ipython3", 537 | "version": "3.5.2" 538 | } 539 | }, 540 | "nbformat": 4, 541 | "nbformat_minor": 1 542 | } 543 | -------------------------------------------------------------------------------- /chess_policy_baseline.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "curses is not supported on this machine (please install/reinstall curses for an optimal experience)\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "%reload_ext autoreload \n", 18 | "%autoreload 2\n", 19 | "from cchess import *\n", 20 | "import tensorflow as tf\n", 21 | "import numpy as np\n", 22 | "from matplotlib import pyplot as plt\n", 23 | "import random \n", 24 | "import time\n", 25 | "from utils import Dataset,ProgressBar\n", 26 | "from tflearn.data_flow import 
DataFlow,DataFlowStatus,FeedDictFlow\n", 27 | "from tflearn.data_utils import Preloader,ImagePreloader\n", 28 | "import scipy\n", 29 | "import pandas as pd\n", 30 | "import xmltodict\n", 31 | "from game_convert import convert_game\n", 32 | "import tflearn" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "# a network predict select and move of Chinese chess, with minimal preprocessing" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 2, 45 | "metadata": { 46 | "collapsed": true 47 | }, 48 | "outputs": [], 49 | "source": [ 50 | "GPU_CORE = 0\n", 51 | "BATCH_SIZE = 256\n", 52 | "BEGINING_LR = 0.01\n", 53 | "#TESTIMG_WIDTH = 500\n", 54 | "model_name = '11_4_resnet'\n", 55 | "data_dir = 'data/imsa-cbf/'" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 3, 61 | "metadata": { 62 | "collapsed": true 63 | }, 64 | "outputs": [], 65 | "source": [ 66 | "class ElePreloader(object):\n", 67 | " def __init__(self,datafile,batch_size=64):\n", 68 | " self.batch_size=batch_size\n", 69 | " content = pd.read_csv(datafile,header=None,index_col=None)\n", 70 | " self.filelist = [i[0] for i in content.get_values()]\n", 71 | " self.pos = 0\n", 72 | " self.feature_list = {\"red\":['A', 'B', 'C', 'K', 'N', 'P', 'R']\n", 73 | " ,\"black\":['a', 'b', 'c', 'k', 'n', 'p', 'r']}\n", 74 | " self.batch_size = batch_size\n", 75 | " self.batch_iter = self.__iter()\n", 76 | " assert(len(self.filelist) > batch_size)\n", 77 | " self.game_iterlist = [None for i in self.filelist]\n", 78 | " \n", 79 | " def __iter(self):\n", 80 | " retx1,rety1,retx2,rety2 = [],[],[],[]\n", 81 | " filelist = []\n", 82 | " while True:\n", 83 | " for i in range(self.batch_size):\n", 84 | " if self.game_iterlist[i] == None:\n", 85 | " if len(filelist) == 0:\n", 86 | " filelist = copy.copy(self.filelist)\n", 87 | " random.shuffle(filelist)\n", 88 | " self.game_iterlist[i] = convert_game(filelist.pop(),feature_list=self.feature_list)\n", 
89 | " game_iter = self.game_iterlist[i]\n", 90 | " \n", 91 | " try:\n", 92 | " x1,y1,x2,y2 = game_iter.__next__()\n", 93 | " x1 = np.transpose(x1,[1,2,0])\n", 94 | " x2 = np.transpose(x2,[1,2,0])\n", 95 | " x1 = np.expand_dims(x1,axis=0)\n", 96 | " x2 = np.expand_dims(x2,axis=0)\n", 97 | " retx1.append(x1)\n", 98 | " rety1.append(y1)\n", 99 | " retx2.append(x2)\n", 100 | " rety2.append(y2)\n", 101 | " if len(retx1) >= self.batch_size:\n", 102 | " yield (np.concatenate(retx1,axis=0),np.asarray(rety1)\n", 103 | " ,np.concatenate(retx2,axis=0),np.asarray(rety2))\n", 104 | " retx1,rety1,retx2,rety2 = [],[],[],[]\n", 105 | " except :\n", 106 | " self.game_iterlist[i] = None\n", 107 | "\n", 108 | " def __getitem__(self, id):\n", 109 | " \n", 110 | " x1,y1,x2,y2 = self.batch_iter.__next__()\n", 111 | " return x1,y1,x2,y2\n", 112 | " \n", 113 | " def __len__(self):\n", 114 | " return 10000" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 4, 120 | "metadata": { 121 | "collapsed": true 122 | }, 123 | "outputs": [], 124 | "source": [ 125 | "trainset = ElePreloader(datafile='data/train_list.csv',batch_size=BATCH_SIZE)\n", 126 | "with tf.device(\"/gpu:{}\".format(GPU_CORE)):\n", 127 | " coord = tf.train.Coordinator()\n", 128 | " trainflow = FeedDictFlow({\n", 129 | " 'data':trainset,\n", 130 | " },coord,batch_size=BATCH_SIZE,shuffle=True,continuous=True,num_threads=1)\n", 131 | "trainflow.start()" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 5, 137 | "metadata": { 138 | "collapsed": true 139 | }, 140 | "outputs": [], 141 | "source": [ 142 | "testset = ElePreloader(datafile='data/test_list.csv',batch_size=BATCH_SIZE)\n", 143 | "with tf.device(\"/gpu:{}\".format(GPU_CORE)):\n", 144 | " coord = tf.train.Coordinator()\n", 145 | " testflow = FeedDictFlow({\n", 146 | " 'data':testset,\n", 147 | " },coord,batch_size=BATCH_SIZE,shuffle=True,continuous=True,num_threads=1)\n", 148 | "testflow.start()" 149 | ] 150 | }, 151 | { 
152 | "cell_type": "code", 153 | "execution_count": 6, 154 | "metadata": { 155 | "collapsed": true 156 | }, 157 | "outputs": [], 158 | "source": [ 159 | "sample_x1,sample_y1,sample_x2,sample_y2 = trainflow.next()['data']" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 7, 165 | "metadata": { 166 | "collapsed": true 167 | }, 168 | "outputs": [], 169 | "source": [ 170 | "sample_x1,sample_y1,sample_x2,sample_y2 = testflow.next()['data']" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 8, 176 | "metadata": {}, 177 | "outputs": [ 178 | { 179 | "data": { 180 | "text/plain": [ 181 | "((256, 10, 9, 14), (256, 10, 9), (256, 10, 9, 15), (256, 10, 9))" 182 | ] 183 | }, 184 | "execution_count": 8, 185 | "metadata": {}, 186 | "output_type": "execute_result" 187 | } 188 | ], 189 | "source": [ 190 | "sample_x1.shape,sample_y1.shape,sample_x2.shape,sample_y2.shape" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 9, 196 | "metadata": {}, 197 | "outputs": [ 198 | { 199 | "data": { 200 | "text/plain": [ 201 | "array([[ 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", 202 | " [ 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 203 | " [ 0., 1., 0., 0., 0., 0., 0., 1., 0.],\n", 204 | " [ 1., 0., 1., 0., 1., 0., 1., 0., 1.],\n", 205 | " [ 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 206 | " [ 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 207 | " [ 1., 0., 2., 0., 1., 0., 1., 0., 1.],\n", 208 | " [ 0., 1., 0., 0., 0., 0., 0., 1., 0.],\n", 209 | " [ 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 210 | " [ 1., 1., 1., 1., 1., 1., 1., 1., 1.]])" 211 | ] 212 | }, 213 | "execution_count": 9, 214 | "metadata": {}, 215 | "output_type": "execute_result" 216 | } 217 | ], 218 | "source": [ 219 | "np.sum(sample_x2[0],axis=-1)" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": 10, 225 | "metadata": {}, 226 | "outputs": [ 227 | { 228 | "data": { 229 | "text/plain": [ 230 | "(10, 9)" 231 | ] 232 | }, 233 | 
"execution_count": 10, 234 | "metadata": {}, 235 | "output_type": "execute_result" 236 | } 237 | ], 238 | "source": [ 239 | "np.sum(sample_x1[0],axis=-1).shape" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": 11, 245 | "metadata": { 246 | "collapsed": true 247 | }, 248 | "outputs": [], 249 | "source": [ 250 | "def res_block(inputx,name,training,block_num=2,filters=256,kernel_size=(3,3)):\n", 251 | " net = inputx\n", 252 | " for i in range(block_num):\n", 253 | " net = tf.layers.conv2d(net,filters=filters,kernel_size=kernel_size,activation=None,name=\"{}_res_conv{}\".format(name,i),padding='same')\n", 254 | " net = tf.layers.batch_normalization(net,training=training,name=\"{}_res_bn{}\".format(name,i))\n", 255 | " if i == block_num - 1:\n", 256 | " net = net + inputx #= tf.concat((inputx,net),axis=-1)\n", 257 | " net = tf.nn.elu(net,name=\"{}_res_elu{}\".format(name,i))\n", 258 | " return net\n", 259 | "\n", 260 | "def conv_block(inputx,name,training,block_num=1,filters=2,kernel_size=(1,1)):\n", 261 | " net = inputx\n", 262 | " for i in range(block_num):\n", 263 | " net = tf.layers.conv2d(net,filters=filters,kernel_size=kernel_size,activation=None,name=\"{}_convblock_conv{}\".format(name,i),padding='same')\n", 264 | " net = tf.layers.batch_normalization(net,training=training,name=\"{}_convblock_bn{}\".format(name,i))\n", 265 | " net = tf.nn.elu(net,name=\"{}_convblock_elu{}\".format(name,i))\n", 266 | " # net [None,10,9,2]\n", 267 | " netshape = net.get_shape().as_list()\n", 268 | " print(\"inside conv block {}\".format(str(netshape)))\n", 269 | " net = tf.reshape(net,shape=(-1,netshape[1] * netshape[2] * netshape[3]))\n", 270 | " net = tf.layers.dense(net,10 * 9,name=\"{}_dense\".format(name))\n", 271 | " net = tf.nn.elu(net,name=\"{}_elu\".format(name))\n", 272 | " return net\n", 273 | "\n", 274 | "def res_net_board(inputx,name,training,filters=256):\n", 275 | " net = inputx\n", 276 | " net = 
tf.layers.conv2d(net,filters=filters,kernel_size=(3,3),activation=None,name=\"{}_res_convb\".format(name),padding='same')\n", 277 | " net = tf.layers.batch_normalization(net,training=training,name=\"{}_res_bnb\".format(name))\n", 278 | " net = tf.nn.elu(net,name=\"{}_res_elub\".format(name))\n", 279 | " for i in range(NUM_RES_LAYERS):\n", 280 | " net = res_block(net,name=\"{}_layer_{}\".format(name,i + 1),training=training)\n", 281 | " print(net.get_shape().as_list())\n", 282 | " print(\"inside res net {}\".format(str(net.get_shape().as_list())))\n", 283 | " net_unsoftmax = conv_block(net,name=\"{}_conv\".format(name),training=training)\n", 284 | " return net_unsoftmax\n", 285 | "\n", 286 | "def get_scatter(name):\n", 287 | " with tf.variable_scope(\"Test\"):\n", 288 | " ph = tf.placeholder(tf.float32,name=name)\n", 289 | " op = tf.summary.scalar(name,ph)\n", 290 | " return ph,op" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": 12, 296 | "metadata": {}, 297 | "outputs": [ 298 | { 299 | "name": "stdout", 300 | "output_type": "stream", 301 | "text": [ 302 | "[None, 10, 9, 256]\n", 303 | "[None, 10, 9, 256]\n", 304 | "[None, 10, 9, 256]\n", 305 | "[None, 10, 9, 256]\n", 306 | "[None, 10, 9, 256]\n", 307 | "[None, 10, 9, 256]\n", 308 | "[None, 10, 9, 256]\n", 309 | "[None, 10, 9, 256]\n", 310 | "[None, 10, 9, 256]\n", 311 | "[None, 10, 9, 256]\n", 312 | "inside res net [None, 10, 9, 256]\n", 313 | "inside conv block [None, 10, 9, 2]\n", 314 | "[None, 10, 9, 256]\n", 315 | "[None, 10, 9, 256]\n", 316 | "[None, 10, 9, 256]\n", 317 | "[None, 10, 9, 256]\n", 318 | "[None, 10, 9, 256]\n", 319 | "[None, 10, 9, 256]\n", 320 | "[None, 10, 9, 256]\n", 321 | "[None, 10, 9, 256]\n", 322 | "[None, 10, 9, 256]\n", 323 | "[None, 10, 9, 256]\n", 324 | "inside res net [None, 10, 9, 256]\n", 325 | "inside conv block [None, 10, 9, 2]\n" 326 | ] 327 | } 328 | ], 329 | "source": [ 330 | "tf.reset_default_graph()\n", 331 | "config = tf.ConfigProto()\n", 332 | 
"config.gpu_options.allow_growth = True\n", 333 | "config.allow_soft_placement = True\n", 334 | "sess = tf.Session(config=config)\n", 335 | "\n", 336 | "NUM_RES_LAYERS = 10\n", 337 | "\n", 338 | "with tf.device(\"/gpu:{}\".format(GPU_CORE)):\n", 339 | " X1 = tf.placeholder(tf.float32,[None,10,9,14])\n", 340 | " y1 = tf.placeholder(tf.float32,[None,10,9])\n", 341 | " X2 = tf.placeholder(tf.float32,[None,10,9,15])\n", 342 | " y2 = tf.placeholder(tf.float32,[None,10,9])\n", 343 | " \n", 344 | " training = tf.placeholder(tf.bool,name='training_mode')\n", 345 | " learning_rate = tf.placeholder(tf.float32)\n", 346 | " global_step = tf.train.get_or_create_global_step()\n", 347 | " \n", 348 | " net_unsoftmax1 = res_net_board(X1,\"selectnet\",training=training)\n", 349 | " net_unsoftmax2 = res_net_board(X2,\"movenet\",training=training)\n", 350 | " \n", 351 | " target1 = tf.reshape(y1,(-1,10 * 9))\n", 352 | " target2 = tf.reshape(y2,(-1,10 * 9))\n", 353 | " with tf.variable_scope(\"Loss\"):\n", 354 | " loss_select = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=target1,logits=net_unsoftmax1))\n", 355 | " loss_move = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=target2,logits=net_unsoftmax2))\n", 356 | " loss = loss_select + loss_move\n", 357 | " \n", 358 | " loss_select_summary = tf.summary.scalar(\"loss_select\",loss_select)\n", 359 | " loss_move_summary = tf.summary.scalar(\"loss_move\",loss_move)\n", 360 | " loss_summary = tf.summary.scalar(\"total_loss\",loss)\n", 361 | " net_softmax1 = tf.nn.softmax(net_unsoftmax1)\n", 362 | " net_softmax2 = tf.nn.softmax(net_unsoftmax2)\n", 363 | " \n", 364 | " correct_prediction1 = tf.equal(tf.argmax(target1,1), tf.argmax(net_softmax1,1))\n", 365 | " correct_prediction2 = tf.equal(tf.argmax(target2,1), tf.argmax(net_softmax2,1))\n", 366 | " \n", 367 | " with tf.variable_scope(\"Accuracy\"):\n", 368 | " accuracy_select = tf.reduce_mean(tf.cast(correct_prediction1, tf.float32))\n", 369 | " accuracy_move 
= tf.reduce_mean(tf.cast(correct_prediction2, tf.float32))\n", 370 | " accuracy_total = accuracy_select * accuracy_move\n", 371 | " \n", 372 | " acc_select_summary = tf.summary.scalar(\"accuracy_select\",accuracy_select)\n", 373 | " acc_move_summary = tf.summary.scalar(\"accuracy_move\",accuracy_move)\n", 374 | " acc_summary = tf.summary.scalar(\"acc_summary\",accuracy_total)\n", 375 | " \n", 376 | " summary_op = tf.summary.merge([loss_select_summary,loss_move_summary,loss_summary\n", 377 | " ,acc_select_summary,acc_move_summary,acc_summary])\n", 378 | " \n", 379 | " test_select,test_select_summary = get_scatter(\"test_select_loss\")\n", 380 | " test_move,test_move_summary = get_scatter(\"test_move_loss\")\n", 381 | " test_total,test_total_summary = get_scatter(\"test_total_loss\")\n", 382 | " test_selectacc,test_selectacc_summary = get_scatter(\"test_select_acc\")\n", 383 | " test_moveacc,test_moveacc_summary = get_scatter(\"test_move_acc\")\n", 384 | " test_totalacc,test_totalacc_summary = get_scatter(\"test_total_acc\")\n", 385 | " \n", 386 | " test_summary_op = tf.summary.merge([test_select_summary,test_move_summary,test_total_summary\n", 387 | " ,test_selectacc_summary,test_moveacc_summary,test_totalacc_summary])\n", 388 | " \n", 389 | " update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)\n", 390 | " with tf.control_dependencies(update_ops):\n", 391 | " optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,momentum=0.9)\n", 392 | " train_op = optimizer.minimize(loss,global_step=global_step)\n", 393 | "\n", 394 | " train_summary_writer = tf.summary.FileWriter(\"./log/{}_train\".format(model_name), sess.graph)" 395 | ] 396 | }, 397 | { 398 | "cell_type": "code", 399 | "execution_count": 13, 400 | "metadata": {}, 401 | "outputs": [ 402 | { 403 | "data": { 404 | "text/plain": [ 405 | "0" 406 | ] 407 | }, 408 | "execution_count": 13, 409 | "metadata": {}, 410 | "output_type": "execute_result" 411 | } 412 | ], 413 | "source": [ 414 | 
"sess.run(tf.global_variables_initializer())\n", 415 | "tf.train.global_step(sess, global_step)" 416 | ] 417 | }, 418 | { 419 | "cell_type": "code", 420 | "execution_count": 14, 421 | "metadata": { 422 | "collapsed": true 423 | }, 424 | "outputs": [], 425 | "source": [ 426 | "??tflearn.layers.residual_block" 427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "execution_count": 15, 432 | "metadata": { 433 | "collapsed": true 434 | }, 435 | "outputs": [], 436 | "source": [ 437 | "import os\n", 438 | "if not os.path.exists(\"models/{}\".format(model_name)):\n", 439 | " os.mkdir(\"models/{}\".format(model_name))" 440 | ] 441 | }, 442 | { 443 | "cell_type": "code", 444 | "execution_count": 16, 445 | "metadata": { 446 | "collapsed": true 447 | }, 448 | "outputs": [], 449 | "source": [ 450 | "N_BATCH = 10000\n", 451 | "N_BATCH_TEST = 300" 452 | ] 453 | }, 454 | { 455 | "cell_type": "code", 456 | "execution_count": 22, 457 | "metadata": {}, 458 | "outputs": [ 459 | { 460 | "name": "stdout", 461 | "output_type": "stream", 462 | "text": [ 463 | "EPOCH 22 STEP 9999 LR 0.001 ACC 75.04 SELACC78.82 MOVACC95.21 LOSS 0.74 100.00 % [==================================================>] 2560000/2560000 \t used:3149s eta:0 sss eta:0 s\n", 464 | "validating epoch 22 batch 9999 100.00 % [==================================================>] 76800/76800 \t used:54s eta:0 sTEST ACC 0.47125524282455444 SELACC 0.5831900835037231 MOVACC 0.8079296946525574 LOSS 2.4170286655426025\n", 465 | "\n", 466 | "EPOCH 23 STEP 9999 LR 0.001 ACC 74.33 SELACC78.43 MOVACC94.77 LOSS 0.76 100.00 % [==================================================>] 2560000/2560000 \t used:5560s eta:0 sss\n", 467 | "validating epoch 23 batch 9999 100.00 % [==================================================>] 76800/76800 \t used:54s eta:0 sTEST ACC 0.4667249023914337 SELACC 0.5788801908493042 MOVACC 0.8060286641120911 LOSS 2.4456872940063477\n", 468 | "\n", 469 | "EPOCH 24 STEP 9999 LR 0.001 ACC 75.16 SELACC79.3 
MOVACC94.79 LOSS 0.74 100.00 % [==================================================>] 2560000/2560000 \t used:5559s eta:0 ssss\n", 470 | "validating epoch 24 batch 9999 100.00 % [==================================================>] 76800/76800 \t used:55s eta:0 sTEST ACC 0.47514840960502625 SELACC 0.5864583253860474 MOVACC 0.8101171851158142 LOSS 2.481386661529541\n", 471 | "\n", 472 | "EPOCH 25 STEP 9999 LR 0.001 ACC 77.37 SELACC80.81 MOVACC95.73 LOSS 0.68 100.00 % [==================================================>] 2560000/2560000 \t used:5562s eta:0 sss\n", 473 | "validating epoch 25 batch 9999 100.00 % [==================================================>] 76800/76800 \t used:55s eta:0 sTEST ACC 0.4594808518886566 SELACC 0.574023425579071 MOVACC 0.8003906011581421 LOSS 2.678408622741699\n", 474 | "\n", 475 | "EPOCH 26 STEP 9999 LR 0.001 ACC 77.06 SELACC80.59 MOVACC95.62 LOSS 0.69 100.00 % [==================================================>] 2560000/2560000 \t used:5561s eta:0 sss\n", 476 | "validating epoch 26 batch 9999 100.00 % [==================================================>] 76800/76800 \t used:54s eta:0 sTEST ACC 0.4758685231208801 SELACC 0.5884114503860474 MOVACC 0.80859375 LOSS 2.5563509464263916\n", 477 | "\n", 478 | "EPOCH 27 STEP 9999 LR 0.001 ACC 76.82 SELACC80.47 MOVACC95.47 LOSS 0.7 100.00 % [==================================================>] 2560000/2560000 \t used:5556s eta:0 ssss\n", 479 | "validating epoch 27 batch 9999 100.00 % [==================================================>] 76800/76800 \t used:55s eta:0 sTEST ACC 0.46077102422714233 SELACC 0.5758333206176758 MOVACC 0.8000390529632568 LOSS 2.7744829654693604\n", 480 | "\n", 481 | "EPOCH 28 STEP 9999 LR 0.001 ACC 78.71 SELACC81.58 MOVACC96.48 LOSS 0.63 100.00 % [==================================================>] 2560000/2560000 \t used:5557s eta:0 sss\n", 482 | "validating epoch 28 batch 9999 100.00 % [==================================================>] 76800/76800 \t used:55s 
eta:0 sTEST ACC 0.46656712889671326 SELACC 0.5788021087646484 MOVACC 0.8061718940734863 LOSS 2.853325128555298\n", 483 | "\n", 484 | "EPOCH 29 STEP 9999 LR 0.001 ACC 78.16 SELACC81.24 MOVACC96.22 LOSS 0.65 100.00 % [==================================================>] 2560000/2560000 \t used:5561s eta:0 sss\n", 485 | "validating epoch 29 batch 9999 100.00 % [==================================================>] 76800/76800 \t used:55s eta:0 sTEST ACC 0.4673483371734619 SELACC 0.5812109112739563 MOVACC 0.8039583563804626 LOSS 2.8683159351348877\n", 486 | "\n", 487 | "EPOCH 30 STEP 9999 LR 0.0001 ACC 80.65 SELACC83.27 MOVACC96.85 LOSS 0.58 100.00 % [==================================================>] 2560000/2560000 \t used:5559s eta:0 sss\n", 488 | "validating epoch 30 batch 9999 100.00 % [==================================================>] 76800/76800 \t used:54s eta:0 sTEST ACC 0.4682302474975586 SELACC 0.5828906297683716 MOVACC 0.8033333420753479 LOSS 2.967599630355835\n", 489 | "\n", 490 | "EPOCH 31 STEP 9999 LR 0.0001 ACC 79.84 SELACC82.5 MOVACC96.77 LOSS 0.61 100.00 % [==================================================>] 2560000/2560000 \t used:5554s eta:0 ssss\n", 491 | "validating epoch 31 batch 9999 100.00 % [==================================================>] 76800/76800 \t used:59s eta:0 sTEST ACC 0.47560566663742065 SELACC 0.5861197710037231 MOVACC 0.8114323019981384 LOSS 2.9342682361602783\n", 492 | "\n", 493 | "EPOCH 32 STEP 9999 LR 0.0001 ACC 80.61 SELACC83.16 MOVACC96.93 LOSS 0.59 100.00 % [==================================================>] 2560000/2560000 \t used:5562s eta:0 sss\n", 494 | "validating epoch 32 batch 9999 100.00 % [==================================================>] 76800/76800 \t used:55s eta:0 sTEST ACC 0.47255992889404297 SELACC 0.5841666460037231 MOVACC 0.8089713454246521 LOSS 3.0034432411193848\n", 495 | "\n", 496 | "EPOCH 33 STEP 8654 LR 0.0001 ACC 80.35 SELACC82.84 MOVACC96.99 LOSS 0.59 86.55 % 
[===========================================>-------] 2215680/2560000 \t used:4814s eta:748 ss" 497 | ] 498 | }, 499 | { 500 | "ename": "KeyboardInterrupt", 501 | "evalue": "", 502 | "output_type": "error", 503 | "traceback": [ 504 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 505 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 506 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 48\u001b[0m _,step_loss,step_acc_move,step_acc_select,step_acc_total,step_value,step_summary = sess.run(\n\u001b[1;32m 49\u001b[0m [train_op,loss,accuracy_move,accuracy_select,accuracy_total,global_step,summary_op],feed_dict={\n\u001b[0;32m---> 50\u001b[0;31m \u001b[0mX1\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mbatch_x1\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0my1\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mbatch_y1\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mX2\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mbatch_x2\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0my2\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mbatch_y2\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mlearning_rate\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mbatch_lr\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mtraining\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 51\u001b[0m })\n\u001b[1;32m 52\u001b[0m \u001b[0mtrain_summary_writer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0madd_summary\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstep_summary\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mstep_value\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 507 | "\u001b[0;32mC:\\Program Files\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 787\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m 788\u001b[0m result = self._run(None, 
fetches, feed_dict, options_ptr,\n\u001b[0;32m--> 789\u001b[0;31m run_metadata_ptr)\n\u001b[0m\u001b[1;32m 790\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m 791\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 508 | "\u001b[0;32mC:\\Program Files\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 995\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m 996\u001b[0m results = self._do_run(handle, final_targets, final_fetches,\n\u001b[0;32m--> 997\u001b[0;31m feed_dict_string, options, run_metadata)\n\u001b[0m\u001b[1;32m 998\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m 999\u001b[0m \u001b[0mresults\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 509 | "\u001b[0;32mC:\\Program Files\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_do_run\u001b[0;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 1130\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m 1131\u001b[0m return self._do_call(_run_fn, self._session, feed_dict, fetch_list,\n\u001b[0;32m-> 1132\u001b[0;31m target_list, options, run_metadata)\n\u001b[0m\u001b[1;32m 1133\u001b[0m 
\u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m 1134\u001b[0m return self._do_call(_prun_fn, self._session, handle, feed_dict,\n", 510 | "\u001b[0;32mC:\\Program Files\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_do_call\u001b[0;34m(self, fn, *args)\u001b[0m\n\u001b[1;32m 1137\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m 1138\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1139\u001b[0;31m \u001b[1;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1140\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m 1141\u001b[0m \u001b[0mmessage\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0me\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 511 | "\u001b[0;32mC:\\Program Files\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_run_fn\u001b[0;34m(session, feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[1;32m 1119\u001b[0m return tf_session.TF_Run(session, options,\n\u001b[1;32m 1120\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[0mtarget_list\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1121\u001b[0;31m status, run_metadata)\n\u001b[0m\u001b[1;32m 1122\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m 1123\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_prun_fn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msession\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 512 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 513 | ] 514 | } 515 | ], 516 | "source": [ 517 | "restore = True\n", 518 | "N_EPOCH = 100\n", 519 | "DECAY_EPOCH = 10\n", 520 | "\n", 521 | "class ExpVal:\n", 522 | " def __init__(self,exp_a=0.97):\n", 523 | " self.val = None\n", 524 | " self.exp_a = exp_a\n", 525 | " def update(self,newval):\n", 526 | " if self.val == None:\n", 527 | " self.val = newval\n", 528 | " else:\n", 529 | " self.val = self.exp_a * self.val + (1 - self.exp_a) * newval\n", 530 | " def getval(self):\n", 531 | " return round(self.val,2)\n", 532 | " \n", 533 | "expacc = ExpVal()\n", 534 | "expacc_select = ExpVal()\n", 535 | "expacc_move = ExpVal()\n", 536 | "exploss = ExpVal()\n", 537 | "\n", 538 | "\n", 539 | "begining_learning_rate = 1e-1\n", 540 | "\n", 541 | "pred_image = None\n", 542 | "if restore == False:\n", 543 | " train_epoch = 1\n", 544 | " train_batch = 0\n", 545 | "for one_epoch in range(train_epoch,N_EPOCH):\n", 546 | " train_epoch = one_epoch\n", 547 | " pb = ProgressBar(worksum=N_BATCH * BATCH_SIZE,info=\" epoch {} batch {}\".format(train_epoch,train_batch))\n", 548 | " pb.startjob()\n", 549 | " \n", 550 | " for one_batch in range(N_BATCH):\n", 551 | " if restore == True and one_batch < train_batch:\n", 552 | " pb.auto_display = False\n", 553 | " pb.complete(BATCH_SIZE)\n", 554 | " pb.auto_display = True\n", 555 | " continue\n", 556 | " else:\n", 557 | " restore = False\n", 558 | " 
train_batch = one_batch\n", 559 | " \n", 560 | " batch_x1,batch_y1,batch_x2,batch_y2 = trainflow.next()['data']\n", 561 | " # learning rate decay strategy\n", 562 | " batch_lr = begining_learning_rate * 10 ** -(one_epoch // DECAY_EPOCH)\n", 563 | " \n", 564 | " _,step_loss,step_acc_move,step_acc_select,step_acc_total,step_value,step_summary = sess.run(\n", 565 | " [train_op,loss,accuracy_move,accuracy_select,accuracy_total,global_step,summary_op],feed_dict={\n", 566 | " X1:batch_x1,y1:batch_y1,X2:batch_x2,y2:batch_y2,learning_rate:batch_lr,training:True\n", 567 | " })\n", 568 | " train_summary_writer.add_summary(step_summary,step_value)\n", 569 | " step_acc_move *= 100\n", 570 | " step_acc_select *= 100\n", 571 | " step_acc_total *= 100\n", 572 | " expacc.update(step_acc_total)\n", 573 | " expacc_select.update(step_acc_select)\n", 574 | " expacc_move.update(step_acc_move)\n", 575 | " exploss.update(step_loss)\n", 576 | "\n", 577 | " \n", 578 | " pb.info = \"EPOCH {} STEP {} LR {} ACC {} SELACC{} MOVACC{} LOSS {} \".format(\n", 579 | " one_epoch,one_batch,batch_lr,expacc.getval()\n", 580 | " ,expacc_select.getval(),expacc_move.getval(),exploss.getval())\n", 581 | " \n", 582 | " pb.complete(BATCH_SIZE)\n", 583 | " print()\n", 584 | " accs = []\n", 585 | " accselects = []\n", 586 | " accmoves = []\n", 587 | " losses = []\n", 588 | " lossselects = []\n", 589 | " lossmoves = []\n", 590 | " pb = ProgressBar(worksum=N_BATCH_TEST * BATCH_SIZE,info=\"validating epoch {} batch {}\".format(train_epoch,train_batch))\n", 591 | " pb.startjob()\n", 592 | " for one_batch in range(N_BATCH_TEST):\n", 593 | " batch_x1,batch_y1,batch_x2,batch_y2 = testflow.next()['data']\n", 594 | " step_loss_move,step_loss_select,step_loss,step_accuracy_move,step_accuracy_select,step_accuracy_total = sess.run(\n", 595 | " [loss_move,loss_select,loss,accuracy_move,accuracy_select,accuracy_total],feed_dict={\n", 596 | " X1:batch_x1,y1:batch_y1,X2:batch_x2,y2:batch_y2,training:False\n", 597 | " })\n", 
598 | " accs.append(step_accuracy_total)\n", 599 | " accselects.append(step_accuracy_select)\n", 600 | " accmoves.append(step_accuracy_move)\n", 601 | " losses.append(step_loss)\n", 602 | " lossselects.append(step_loss_select)\n", 603 | " lossmoves.append(step_loss_move)\n", 604 | " \n", 605 | " pb.complete(BATCH_SIZE)\n", 606 | " print(\"TEST ACC {} SELACC {} MOVACC {} LOSS {}\".format(np.average(accs),np.average(accselects)\n", 607 | " ,np.average(accmoves),np.average(losses)))\n", 608 | " #test_select_summary,test_move_summary,test_total_summary\n", 609 | " # ,test_selectacc_summary,test_moveacc_summary,test_totalacc_summary\n", 610 | " test_to_add_to_log = sess.run(test_summary_op,feed_dict={\n", 611 | " test_select:np.average(lossselects),test_move:np.average(lossmoves),test_total:np.average(losses)\n", 612 | " ,test_selectacc:np.average(accselects),test_moveacc:np.average(accmoves),test_totalacc:np.average(accs)\n", 613 | " })\n", 614 | " train_summary_writer.add_summary(test_to_add_to_log,step_value)\n", 615 | " print()\n", 616 | " saver = tf.train.Saver(var_list=tf.global_variables())\n", 617 | " saver.save(sess,\"models/{}/model_{}\".format(model_name,one_epoch))" 618 | ] 619 | }, 620 | { 621 | "cell_type": "code", 622 | "execution_count": null, 623 | "metadata": { 624 | "collapsed": true 625 | }, 626 | "outputs": [], 627 | "source": [] 628 | } 629 | ], 630 | "metadata": { 631 | "anaconda-cloud": {}, 632 | "kernelspec": { 633 | "display_name": "Python [conda root]", 634 | "language": "python", 635 | "name": "conda-root-py" 636 | }, 637 | "language_info": { 638 | "codemirror_mode": { 639 | "name": "ipython", 640 | "version": 3 641 | }, 642 | "file_extension": ".py", 643 | "mimetype": "text/x-python", 644 | "name": "python", 645 | "nbconvert_exporter": "python", 646 | "pygments_lexer": "ipython3", 647 | "version": "3.5.2" 648 | } 649 | }, 650 | "nbformat": 4, 651 | "nbformat_minor": 1 652 | } 653 | 
--------------------------------------------------------------------------------