├── .gitattributes
├── .gitignore
├── README.md
├── __init__.py
├── ai_train.dockerfile
├── conf
│   └── train_config.yaml
├── cpu_train.sh
├── docker-compose.yml
├── evaluate
│   ├── AICollection.py
│   ├── ChessBoard.py
│   ├── ChessClient.py
│   ├── ChessHelper.py
│   ├── ChessServer.py
│   ├── Hall.py
│   ├── README.md
│   ├── ai_listen_ucloud.bat
│   ├── gobang.py
│   ├── logs
│   │   ├── error.log
│   │   └── info.log
│   ├── page
│   │   ├── chessboard.html
│   │   └── login.html
│   ├── run_ai.sh
│   ├── static
│   │   ├── black.png
│   │   ├── blank.png
│   │   ├── jquery-3.3.1.min.js
│   │   ├── touming.png
│   │   └── white.png
│   ├── unittest
│   │   └── ChessBoardTest.py
│   └── yixin_ai
│       ├── engine.exe
│       ├── msvcp140.dll
│       ├── vcruntime140.dll
│       └── yixin.exe
├── game.py
├── game_ai.py
├── human_play_mxnet.py
├── inception-resnet-v2.py
├── logs
│   ├── .gitkeep
│   └── download_model.sh
├── mcts_alphaZero.py
├── mcts_pure.py
├── papers
│   └── thinking-fast-and-slow-with-deep-learning-and-tree-search.pdf
├── play_vs_yixin.py
├── policy_value_loss.json
├── policy_value_net_mxnet.py
├── policy_value_net_mxnet_simple.py
├── requirements.txt
├── run_ai.sh
├── run_server.sh
├── self_play.py
├── sgf_data
│   └── download.sh
├── start_train.sh
├── train_mxnet.py
└── utils
    ├── __init__.py
    ├── config_loader.py
    ├── send_email.py
    └── sgf_dataIter.py

-------------------------------------------------------------------------------- /.gitattributes: --------------------------------------------------------------------------------

1 | pickle_ai_data/dump55.txt filter=lfs diff=lfs merge=lfs -text
2 | pickle_ai_data/dump1.txt filter=lfs diff=lfs merge=lfs -text
3 | pickle_ai_data/dump11.txt filter=lfs diff=lfs merge=lfs -text
4 | pickle_ai_data/dump2.txt filter=lfs diff=lfs merge=lfs -text
5 | pickle_ai_data/dump22.txt filter=lfs diff=lfs merge=lfs -text
6 | pickle_ai_data/dump33.txt filter=lfs diff=lfs merge=lfs -text
7 | pickle_ai_data/dump44.txt filter=lfs diff=lfs merge=lfs -text

-------------------------------------------------------------------------------- /.gitignore: --------------------------------------------------------------------------------

1 | *.pyc
2 | __pycache__/
3 | *.swp
4 | *.swo
5 | *.sgf
6 | *.model
7 | policy_value_loss.json
8 | evaluate/chess_output/

-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------

1 | ## AlphaPig
2 | An implementation of the AlphaZero algorithm for Gomoku (five-in-a-row). Gomoku is simpler than Go and somewhat easier to train on, which is why it was chosen for this AlphaZero reproduction.
3 |
4 | References:
5 | 1. [junxiaosong/AlphaZero_Gomoku](https://github.com/junxiaosong/AlphaZero_Gomoku)
6 | 2. [starimpact/AlphaZero_Gomoku](https://github.com/starimpact/AlphaZero_Gomoku)
7 | 3. 
[yonghenglh6/GobangServer](https://github.com/yonghenglh6/GobangServer)
8 |
9 |
10 | ### Quick start
11 |
12 | ```(make sure docker/nvidia-docker is installed on this machine)
13 | docker-compose up # (uses GPU 0 by default; adapt run_ai.sh to your setup)
14 | ```
15 |
16 | After modifying run_ai.sh, update docker-compose.yml accordingly:
17 |
18 | ```
19 | version: "2.3"
20 |
21 | services:
22 | gobang_server:
23 | image: gospelslave/alphapig:v0.1.11
24 | entrypoint: /bin/bash run_server.sh
25 | privileged: true
26 | environment:
27 | - TZ=Asia/Shanghai
28 | volumes:
29 | - $PWD/run_server.sh:/workspace/run_server.sh # the mapping for the modified script
30 | ports:
31 | - 8888:8888
32 | restart: always
33 | logging:
34 | driver: json-file
35 | options:
36 | max-size: "10M"
37 | max-file: "5"
38 |
39 | gobang_ai:
40 | image: gospelslave/alphapig:v0.1.11
41 | entrypoint: /bin/bash run_ai.sh
42 | privileged: true
43 | environment:
44 | - TZ=Asia/Shanghai
45 | volumes:
46 | - $PWD/run_ai.sh:/workspace/run_ai.sh # the mapping for the modified script
47 | runtime: nvidia
48 | restart: always
49 | logging:
50 | driver: json-file
51 | options:
52 | max-size: "10M"
53 | max-file: "5"
54 | ```
55 |
56 | ### Getting started
57 | + cd AlphaPig/sgf_data/
58 |
59 | + ```
60 | sh ./download.sh
61 | ```
62 |
63 | This downloads and unpacks the SGF game records. After extraction they should sit directly under sgf_data/, i.e. AlphaPig/sgf_data/*.sgf.
64 |
65 | You can also download SGF records yourself and preprocess them on your own.
66 |
67 | + Run start_train.sh in the project root to start training.
68 |
69 | + Alternatively, edit train_mxnet.py to change the network structure and other settings; the .yaml under conf/ defines the training parameters and can be adapted to your environment.
70 |
71 | + The SGF format explained
72 |
73 | ```
74 | FF[4] SGF format version; 4 is the latest
75 | SZ[15] board size; 15x15 here
76 | PW[Pig] white player's name
77 | WR[2a] white player's rank
78 | PB[stupid] black player's name
79 | BR[2c] black player's rank
80 | DT[2018-07-06] date the record was created
81 | PC[CA] where the game was played
82 | KM[6.5] komi
83 | RE[B+Resign] B+ means black won, W+ means white won; Resign means the loser resigned
84 | CA[utf-8] text encoding of the record
85 | TM[0] time limit; 0 means no limit
86 | OT[] overtime (byo-yomi) rule
87 | ;B[pp];W[dd];B[pc];W[dq] …… the sequence of moves
88 | ```
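To get a feel for the format, here is a minimal, illustrative parser. It is not part of this repo (the repo's own SGF reader presumably lives in utils/sgf_dataIter.py, which is not shown in this dump); it only demonstrates pulling the header tags and the move list out of a record like the one above:

```python
# Illustrative only: extract SGF header tags and the move sequence.
import re

def parse_sgf(record):
    # header tags are two uppercase letters, e.g. SZ[15]; on duplicates the last wins
    header = dict(re.findall(r'([A-Z]{2})\[([^\]]*)\]', record))
    # moves look like ;B[hh] / ;W[ii]; the letters index the board, with 'a' == 0
    moves = [(color, ord(col) - ord('a'), ord(row) - ord('a'))
             for color, col, row in re.findall(r';([BW])\[([a-z])([a-z])\]', record)]
    return header, moves

header, moves = parse_sgf('FF[4]SZ[15]PB[stupid]PW[Pig];B[hh];W[ii]')
print(header['SZ'], moves)  # 15 [('B', 7, 7), ('W', 8, 8)]
```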
89 |
90 |
91 |
92 | + To play against your own AI, go into the evaluate directory and run
93 |
94 | ```
95 | python ChessServer.py --port 8888
96 | ```
97 |
98 | You can then play against your own AI, or against Yixin. See the README in the evaluate directory for details.
99 |
100 | An example game:
101 |
102 |
103 |
104 |
105 |
106 | + Download some of the models I trained (still quite buggy)
107 |
108 | ```
109 | cd AlphaPig/logs
110 | sh ./download_model.sh
111 | ```
112 |
113 | ## Acknowledgements
114 |
115 | + The original project is [junxiaosong/AlphaZero_Gomoku](https://github.com/junxiaosong/AlphaZero_Gomoku); many thanks to its author for the numerous issues and the guidance.
116 |
117 | + Special thanks to DeepGlint (格灵深瞳) for a great deal of training help (strong support with courses and training resources); without it, training would have been directionless.
118 |
119 | + Thanks to [UCloud](https://www.ucloud.cn/) for the P40 AI-train service: 1256 hours/instance of training validated quite a few ideas, and in the end they even waived the bill after I bothered their technical support more than a little. Special thanks to them.
120 |
121 |
122 |
123 |
124 |
-------------------------------------------------------------------------------- /__init__.py: --------------------------------------------------------------------------------

1 | # -*- coding: utf-8 -*-
2 | from . import game
3 | from . import mcts_pure
4 | from . import policy_value_net_mxnet
5 | from . import policy_value_net_mxnet_simple
6 |
-------------------------------------------------------------------------------- /ai_train.dockerfile: --------------------------------------------------------------------------------

1 | FROM gospelslave/alphapig:v0.1.3
2 |
3 | WORKDIR /workspace
4 |
5 | ADD ./ /workspace
6 |
7 | RUN /bin/bash -c "source ~/.bashrc"
8 |
9 |
10 |
-------------------------------------------------------------------------------- /conf/train_config.yaml: --------------------------------------------------------------------------------

1 | # SGF data directory
2 | sgf_dir: './sgf_data'
3 | # directory for AI self-play data
4 | ai_data_dir: './pickle_ai_data'
5 | # board settings
6 | board_width: 15
7 | board_height: 15
8 | n_in_row: 5
9 | # learning rate
10 | learn_rate: 0.0004
11 | # multiplier adjusted dynamically from the KL divergence
12 | lr_multiplier: 1.0
13 | temp: 1.0
14 | # number of MCTS simulations per move
15 | n_playout: 400
16 | # TODO: the more MCTS selection relies on the prior, and the more accurate the value estimate, the more c_puct should lean toward depth (i.e. the smaller it should be)
17 | c_puct: 5
18 | # maximum dataset size (deque length)
19 | buffer_size: 2198800
20 | batch_size: 128
21 | epochs: 8
22 | play_batch_size: 1
23 | # target KL divergence
24 | kl_targ: 0.02
25 | # evaluate playing strength every check_freq iterations
26 | check_freq: 1000
27 | # playout depth of the pure-MCTS opponent used for evaluation
28 | pure_mcts_playout_num: 1000
29 | # number of training iterations
30 | game_batch_num: 240000
31 |
32 |
33 | # training log configuration
34 | train_logging:
35 | version: 1
36 | formatters:
37 | simpleFormater:
38 | format: '%(asctime)s - %(levelname)s - %(name)s[line:%(lineno)d]: %(message)s'
39 | datefmt: '%Y-%m-%d %H:%M:%S'
40 | handlers:
41 | # stdout: everything at DEBUG and above is printed
42 | console:
43 | class: logging.StreamHandler
44 | formatter: simpleFormater
45 | level: DEBUG
46 | stream: ext://sys.stdout
47 | # INFO and above (note: a plain FileHandler, so no size-based rotation is actually configured)
48 | info_file_handler:
49 | class: logging.FileHandler
50 | formatter: simpleFormater
51 | level: INFO
52 | filename: ./logs/info.log
53 | # ERROR and above
54 | error_file_handler:
55 | class: logging.FileHandler
56 | formatter: simpleFormater
57 | level: ERROR
58 | filename: ./logs/error.log
59 | root:
60 | level: DEBUG
61 | handlers: [console, info_file_handler, error_file_handler]
62 |
-------------------------------------------------------------------------------- /cpu_train.sh: --------------------------------------------------------------------------------

1 | docker run -it \
2 | -v /home/ubuntu/uai-sdk/examples/mxnet/train/AlphaPig:/data \
3 | -v /home/ubuntu/uai-sdk/examples/mxnet/train/AlphaPig/sgf_data:/data/data \
4 | -v /home/ubuntu/uai-sdk/examples/mxnet/train/AlphaPig/logs:/data/output \
5 | uhub.service.ucloud.cn/uaishare/cpu_uaitrain_ubuntu-14.04_python-2.7.6_mxnet-1.0.0:v1.0 \
6 | /bin/bash -c "cd /data && /usr/bin/python /data/train_mxnet.py --model-prefix=siler_Alpha --work_dir=/data --output_dir=/data/output"
7 |
8 |
9 |
10 |
11 |
12 |
13 |
-------------------------------------------------------------------------------- /docker-compose.yml: --------------------------------------------------------------------------------

1 | version: "2.3"
2 |
3 | services:
4 | gobang_server:
5 | image: gospelslave/alphapig:v0.1.11
6 | entrypoint: /bin/bash run_server.sh
7 | privileged: true
8 | environment:
9 | - TZ=Asia/Shanghai
10 | ports:
11 | - 8888:8888
12 | restart: always
13 | logging:
14 | driver: json-file
15 | options:
16 | max-size: "10M"
17 | max-file: "5"
18 |
19 | gobang_ai:
20 | image: gospelslave/alphapig:v0.1.11
21 | entrypoint: /bin/bash run_ai.sh
22 | privileged: true
23 | environment:
24 | - TZ=Asia/Shanghai
25 | runtime: nvidia
26 | restart: always
27 | logging:
28 | driver: json-file
29 | options:
30 | max-size: "10M"
31 | max-file: "5"
32 |
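A side note on conf/train_config.yaml above: the repo ships utils/config_loader.py for reading it, but that file is not included in this dump, so the snippet below is only a minimal equivalent sketch (PyYAML assumed). The train_logging block is already in logging.config.dictConfig format:

```python
# Minimal sketch, not the repo's actual config_loader API. Run from the repo
# root so the relative ./logs/*.log paths in the logging config resolve.
import logging.config
import yaml  # PyYAML, assumed installed

with open('conf/train_config.yaml') as f:
    cfg = yaml.safe_load(f)

logging.config.dictConfig(cfg['train_logging'])  # console + info/error file handlers
logging.getLogger(__name__).info(
    'lr=%s batch_size=%s n_playout=%s',
    cfg['learn_rate'], cfg['batch_size'], cfg['n_playout'])
```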
-------------------------------------------------------------------------------- /evaluate/AICollection.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import time 3 | import ChessHelper 4 | import threading 5 | from ChessBoard import ChessBoard 6 | from Hall import GameRoom 7 | from Hall import User 8 | import random 9 | import os 10 | from ChessClient import ChessClient 11 | 12 | 13 | class GameStrategy(object): 14 | def __init__(self): 15 | import gobang 16 | self.searcher = gobang.searcher() 17 | 18 | def play_one_piece(self, user, gameboard): 19 | # user = User() 20 | # gameboard = ChessBoard() 21 | turn = user.game_role 22 | self.searcher.board = [[gameboard.get_piece(m, n) for n in xrange(gameboard.SIZE)] for m in 23 | xrange(gameboard.SIZE)] 24 | # gameboard=ChessBoard() 25 | # gameboard.move_history 26 | score, row, col = self.searcher.search(turn, 2) 27 | # print "score:", score 28 | return (row, col) 29 | 30 | 31 | class GameStrategy_yixin(object): 32 | 33 | def __init__(self): 34 | self.muid = str(random.randint(0, 1000000)) 35 | self.comm_folder = 'yixin_comm/' 36 | if not os.path.exists(self.comm_folder): 37 | os.makedirs(self.comm_folder) 38 | self.chess_state_file = self.comm_folder + 'game_state_' + self.muid 39 | self.action_file = self.comm_folder + 'action_' + self.muid 40 | 41 | def play_one_piece(self, user, gameboard): 42 | # user = User() 43 | # gameboard = ChessBoard() 44 | with open(self.chess_state_file, 'w') as chess_state_file: 45 | for userrole, move_num, row, col in gameboard.move_history: 46 | chess_state_file.write('%d,%d\n' % (row, col)) 47 | if os.path.exists(self.action_file): 48 | os.remove(self.action_file) 49 | os.system("yixin_ai\yixin.exe %s %s" % (self.chess_state_file, self.action_file)) 50 | row, col = random.randint(0, 15), random.randint(0, 15) 51 | with open(self.action_file) as action_file: 52 | line = action_file.readline() 53 | row, col = line.strip().split(',') 54 | row, col = int(row), int(col) 55 | 56 | return (row, col) 57 | 58 | 59 | class GameStrategy_random(object): 60 | def __init__(self): 61 | self._chess_helper_move_set = [] 62 | for i in range(15): 63 | for j in range(15): 64 | self._chess_helper_move_set.append((i, j)) 65 | random.shuffle(self._chess_helper_move_set) 66 | self.try_step = 0 67 | 68 | def play_one_piece(self, user, gameboard): 69 | move = self._chess_helper_move_set[self.try_step] 70 | while gameboard.get_piece(move[0], move[1]) != 0 and self.try_step < 15 * 15: 71 | self.try_step += 1 72 | move = self._chess_helper_move_set[self.try_step] 73 | self.try_step += 1 74 | return move 75 | 76 | 77 | class GameCommunicator(threading.Thread): 78 | def __init__(self, roomid, stragegy, server_url): 79 | threading.Thread.__init__(self) 80 | self.room_id = roomid 81 | self.stragegy = stragegy; 82 | self.server_url = server_url 83 | 84 | def run(self): 85 | client = ChessClient(self.server_url) 86 | client.login_in_guest() 87 | client.join_room(self.room_id) 88 | client.join_game() 89 | exp_interval = 0.5 90 | max_exp_interval = 200 91 | while True: 92 | single_max_time = exp_interval * 100 93 | wait_time = client.wait_game_info_changed(interval=exp_interval, max_time=single_max_time) 94 | if wait_time > single_max_time: 95 | exp_interval *= 2 96 | else: 97 | exp_interval = 0.5 98 | if exp_interval > max_exp_interval: 99 | exp_interval = max_exp_interval 100 | room = client.get_room_info() 101 | user = client.get_user_info() 102 | gameboard = 
client.get_game_info() 103 | print 'waittime:',wait_time,',room_status:',room.get_status(),',ask_take_back:',room.ask_take_back 104 | if room.get_status() == 1 or room.get_status() == 2: 105 | continue 106 | elif room.get_status() == 3: 107 | if room.ask_take_back != 0 and room.ask_take_back != user.game_role: 108 | client.answer_take_back() 109 | continue 110 | if gameboard.get_current_user() == user.game_role: 111 | one_legal_piece = self.stragegy.play_one_piece(user, gameboard) 112 | action_result = client.put_piece(*one_legal_piece) 113 | client.wait_game_info_changed(interval=exp_interval, max_time=single_max_time) 114 | if action_result['id'] != 0: 115 | print ChessHelper.numToAlp(one_legal_piece[0]), ChessHelper.numToAlp(one_legal_piece[1]) 116 | print action_result['info'] 117 | break 118 | continue 119 | elif room.get_status() == 4: 120 | break 121 | 122 | 123 | class GameListener(object): 124 | def __init__(self, prefix_stategy_map, server_url): 125 | self.client = ChessClient(server_url) 126 | self.client.login_in_guest() 127 | self.prefix_stategy_map = prefix_stategy_map 128 | self.server_url = server_url 129 | self.accupied = set() 130 | 131 | def listen(self): 132 | while True: 133 | all_rooms = self.client.get_all_rooms() 134 | for room in all_rooms: 135 | room_name = room[0] 136 | room_status = room[1] 137 | for prefix in self.prefix_stategy_map: 138 | if room_name.startswith(prefix) and room_status == GameRoom.ROOM_STATUS_ONEWAITING: 139 | if room_name in self.accupied: 140 | continue 141 | print 'Evoke:', room_name 142 | strg = self.prefix_stategy_map[prefix]() 143 | self.accupied.add(room_name) 144 | commu = GameCommunicator(room_name, strg, self.server_url) 145 | commu.start() 146 | break 147 | time.sleep(4) 148 | 149 | 150 | def go_listen(): 151 | import argparse 152 | parser = argparse.ArgumentParser() 153 | parser.add_argument('--server_url', default='http://120.132.59.147:11111') 154 | args = parser.parse_args() 155 | 156 | 157 | if args.server_url.endswith('/'): 158 | args.server_url = args.server_url[:-1] 159 | if not args.server_url.startswith('http://'): 160 | args.server_url = 'http://' + args.server_url 161 | 162 | prefix_stategy_map = {'ai_': lambda: GameStrategy(), 'yixin_': lambda: GameStrategy_yixin(), 163 | 'random_': lambda: GameStrategy_random()} 164 | listen = GameListener(prefix_stategy_map, args.server_url) 165 | listen.listen() 166 | 167 | 168 | if __name__ == "__main__": 169 | go_listen() 170 | -------------------------------------------------------------------------------- /evaluate/ChessBoard.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import cPickle as pickle 4 | 5 | 6 | class ChessBoard(object): 7 | """ChessBoard 8 | 9 | Attributes: 10 | SIZE: The chess board's size. 11 | board: To store the board information. 12 | state: Indicate if the game is over. 13 | current_user: The user who put the next piece. 
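move_num: Count of moves played so far.
move_history: One (user, move_num, row, col) tuple per move, in order.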
14 | """ 15 | 16 | STATE_RUNNING = 0 17 | STATE_DONE = 1 18 | STATE_ABORT = 1 19 | 20 | PIECE_STATE_BLANK = 0 21 | PIECE_STATE_FIRST = 1 22 | PIECE_STATE_SECOND = 2 23 | 24 | PAD = 4 25 | 26 | CHECK_DIRECTION = [[[0, 1], [0, -1]], [[1, 0], [-1, 0]], [[1, 1], [-1, -1]], [[1, -1], [-1, 1]]] 27 | 28 | def __init__(self, size=15): 29 | self.SIZE = size 30 | self.board = np.zeros((self.SIZE + ChessBoard.PAD * 2, self.SIZE + ChessBoard.PAD * 2), dtype=np.uint8) 31 | self.state = ChessBoard.STATE_RUNNING 32 | self.current_user = ChessBoard.PIECE_STATE_FIRST 33 | 34 | self.move_num = 0 35 | self.move_history = [] 36 | 37 | self.dump_cache = None 38 | 39 | def changed(func): 40 | def wrapper_func(self,*args, **kwargs): 41 | ret=func(self,*args, **kwargs) 42 | self.dump_cache = None 43 | return ret 44 | 45 | return wrapper_func 46 | 47 | def get_piece(self, row, col): 48 | return self.board[row + ChessBoard.PAD, col + ChessBoard.PAD] 49 | 50 | @changed 51 | def set_piece(self, row, col, user): 52 | self.board[row + ChessBoard.PAD, col + ChessBoard.PAD] = user 53 | 54 | @changed 55 | def put_piece(self, row, col, user): 56 | """Put a piece in the board and check if he wins. 57 | Returns: 58 | 0 successful move. 59 | 1 successful and win move. 60 | -1 move out of range. 61 | -2 piece has been occupied. 62 | -3 game is over 63 | -4 not your turn. 64 | """ 65 | if row < 0 or row >= self.SIZE or col < 0 or col >= self.SIZE: 66 | return -1 67 | if self.get_piece(row, col) != ChessBoard.PIECE_STATE_BLANK: 68 | return -2 69 | if self.state != ChessBoard.STATE_RUNNING: 70 | return -3 71 | if user != self.current_user: 72 | return -4 73 | 74 | self.set_piece(row, col, user) 75 | self.move_num += 1 76 | self.move_history.append((user, self.move_num, row, col,)) 77 | # self.last_move = (row, col) 78 | 79 | # check if win 80 | for dx in xrange(4): 81 | connected_piece_num = 1 82 | for dy in xrange(2): 83 | current_direct = ChessBoard.CHECK_DIRECTION[dx][dy] 84 | c_row = row 85 | c_col = col 86 | 87 | # if else realization 88 | for dz in xrange(4): 89 | c_row += current_direct[0] 90 | c_col += current_direct[1] 91 | if self.get_piece(c_row, c_col) == user: 92 | connected_piece_num += 1 93 | else: 94 | break 95 | 96 | # remove if, but not faster 97 | # p = 1 98 | # for dz in xrange(4): 99 | # c_row += current_direct[0] 100 | # c_col += current_direct[1] 101 | # p = p & (self.board[c_row, c_col] == user) 102 | # connected_piece_num += p 103 | 104 | if connected_piece_num >= 5: 105 | self.state = ChessBoard.STATE_DONE 106 | return 1 107 | 108 | if self.current_user == ChessBoard.PIECE_STATE_SECOND: 109 | self.current_user = ChessBoard.PIECE_STATE_FIRST 110 | else: 111 | self.current_user = ChessBoard.PIECE_STATE_SECOND 112 | 113 | if self.move_num == self.SIZE * self.SIZE: 114 | # self.state = ChessBoard.STATE_DONE 115 | self.state = ChessBoard.STATE_ABORT 116 | 117 | return 0 118 | 119 | def get_winner(self): 120 | return self.current_user if self.state == ChessBoard.STATE_DONE else -1 121 | 122 | def get_state(self): 123 | return self.state 124 | 125 | def get_current_user(self): 126 | return self.current_user 127 | 128 | def get_lastmove(self): 129 | return self.move_history[-1] if len(self.move_history) > 0 else (-1, -1, -1, -1) 130 | 131 | @changed 132 | def take_one_back(self): 133 | if len(self.move_history) > 0: 134 | last_move = self.move_history.pop() 135 | self.set_piece(last_move[-2], last_move[-1], ChessBoard.PIECE_STATE_BLANK) 136 | self.move_num -= 1 137 | 138 | if self.current_user == 
ChessBoard.PIECE_STATE_SECOND: 139 | self.current_user = ChessBoard.PIECE_STATE_FIRST 140 | else: 141 | self.current_user = ChessBoard.PIECE_STATE_SECOND 142 | 143 | def is_over(self): 144 | return self.state == ChessBoard.STATE_DONE or self.state == ChessBoard.STATE_ABORT 145 | 146 | def dumps(self): 147 | if self.dump_cache is None: 148 | self.dump_cache = pickle.dumps((self.SIZE, self.board, self.state, self.current_user, self.move_history)) 149 | return self.dump_cache 150 | 151 | @changed 152 | def loads(self, chess_str): 153 | self.SIZE, self.board, self.state, self.current_user, self.move_history = pickle.loads(chess_str) 154 | 155 | 156 | @changed 157 | def reset(self): 158 | self.board = np.zeros((self.SIZE + ChessBoard.PAD * 2, self.SIZE + ChessBoard.PAD * 2), dtype=np.uint8) 159 | self.state = ChessBoard.STATE_RUNNING 160 | self.current_user = ChessBoard.PIECE_STATE_FIRST 161 | self.move_num = 0 162 | self.move_history = [] 163 | 164 | @changed 165 | def abort(self): 166 | self.state = ChessBoard.STATE_ABORT 167 | -------------------------------------------------------------------------------- /evaluate/ChessClient.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import requests 3 | import cookielib 4 | # import http.cookiejar 5 | from bs4 import BeautifulSoup 6 | import json 7 | import time 8 | import cPickle as pickle 9 | # import _pickle as pickle 10 | import ChessHelper 11 | from ChessBoard import ChessBoard 12 | from Hall import GameRoom 13 | from Hall import User 14 | import random 15 | import sys 16 | import os 17 | # 方便引入 AlphaPig 18 | abs_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../') 19 | sys.path.insert(0, abs_path) 20 | import AlphaPig as gomoku_zm 21 | 22 | 23 | class ChessClient(): 24 | def __init__(self, server_url): 25 | self.session = requests.Session() 26 | self.session.cookies = cookielib.CookieJar() 27 | agent = 'Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Maxthon/5.1.2.3000 Chrome/55.0.2883.75 Safari/537.36' 28 | self.headers = { 29 | "Host": server_url, 30 | "Origin": server_url, 31 | "Referer": server_url, 32 | 'User-Agent': agent 33 | } 34 | self.server_url = server_url 35 | self.board = ChessBoard() 36 | self.last_status_signature = "" 37 | 38 | def send_get(self, url): 39 | return self.session.get(self.server_url + url, headers=self.headers) 40 | 41 | def send_post(self, url, data): 42 | return self.session.post(self.server_url + url, data, headers=self.headers) 43 | 44 | def login_in_guest(self): 45 | response = self.send_get('/login?action=login_in_guest') 46 | soup = BeautifulSoup(response.content, "html.parser") 47 | username_span = soup.find('span', attrs={'id': 'username'}) 48 | if username_span: 49 | return username_span.text 50 | else: 51 | return None 52 | 53 | def login(self, username, password): 54 | response = self.send_post('/login?action=login', 55 | data={'username': username, 'password': password}) 56 | soup = BeautifulSoup(response.content, "html.parser") 57 | username_span = soup.find('span', attrs={'id': 'username'}) 58 | if username_span: 59 | return username_span.text 60 | else: 61 | return None 62 | 63 | def logout(self): 64 | self.send_get('/login?action=logout') 65 | 66 | def join_room(self, roomid): 67 | response = self.send_post('/action?action=joinroom', 68 | data={'roomid': roomid}) 69 | action_result = json.loads(response.content) 70 | return action_result 71 | 72 | def join_game(self, cur_role): 73 
| #response = self.send_get('/action?action=joingame') 74 | response = self.send_post('/action?action=joingame', 75 | data={'position': str(cur_role)}) 76 | action_result = json.loads(response.content) 77 | return action_result 78 | 79 | def put_piece(self, row, col): 80 | response = self.send_get( 81 | '/action?action=gameaction&actionid=%s&piece_i=%d&piece_j=%d' % ('put_piece', row, col)) 82 | action_result = json.loads(response.content) 83 | return action_result 84 | 85 | def get_room_info(self): 86 | response = self.send_get( 87 | '/action?action=gameaction&actionid=%s' % 'get_room_info') 88 | action_result = json.loads(response.content) 89 | room = pickle.loads(str(action_result['info'])) 90 | return room 91 | 92 | def get_game_info(self): 93 | response = self.send_get( 94 | '/action?action=gameaction&actionid=%s' % 'get_game_info') 95 | action_result = json.loads(response.content) 96 | room = pickle.loads(str(action_result['info'])) 97 | return room 98 | 99 | def get_user_info(self): 100 | response = self.send_get( 101 | '/action?action=gameaction&actionid=%s' % 'get_user_info') 102 | action_result = json.loads(response.content) 103 | room = pickle.loads(str(action_result['info'])) 104 | return room 105 | 106 | def wait_game_info_changed(self, interval=0.5, max_time=100): 107 | wait_time = 0 108 | assert interval > 0, "interval must be positive" 109 | while True: 110 | response = self.send_get( 111 | '/action?action=gameaction&actionid=%s' % ('get_status_signature')) 112 | action_result = json.loads(response.content) 113 | if action_result['id'] == 0: 114 | status_signature = action_result['info'] 115 | if self.last_status_signature != status_signature: 116 | self.last_status_signature = status_signature 117 | break 118 | else: 119 | print("ERROR get_status_signature,", action_result['id'], action_result['info']) 120 | break 121 | time.sleep(interval) 122 | wait_time += interval 123 | if wait_time > max_time: 124 | break 125 | 126 | return wait_time 127 | 128 | def get_all_rooms(self): 129 | response = self.send_get( 130 | '/action?action=get_all_rooms') 131 | action_result = json.loads(response.content) 132 | all_rooms = action_result['info'] 133 | return all_rooms 134 | 135 | def answer_take_back(self, agree=True): 136 | response = self.send_get( 137 | '/action?action=gameaction&actionid=answer_take_back&agree=' + ('true' if agree else 'false')) 138 | action_result = json.loads(response.content) 139 | return action_result 140 | 141 | 142 | class GameStrategy_random(): 143 | def __init__(self): 144 | self._chess_helper_move_set = [] 145 | for i in range(15): 146 | for j in range(15): 147 | self._chess_helper_move_set.append((i, j)) 148 | random.shuffle(self._chess_helper_move_set) 149 | self.try_step = 0 150 | 151 | def play_one_piece(self, user, gameboard): 152 | move = self._chess_helper_move_set[self.try_step] 153 | while gameboard.get_piece(move[0], move[1]) != 0 and self.try_step < 15 * 15: 154 | self.try_step += 1 155 | move = self._chess_helper_move_set[self.try_step] 156 | self.try_step += 1 157 | return move 158 | 159 | 160 | class GameStrategy_MZhang(): 161 | def __init__(self, startplayer=0, complex_='s'): 162 | abs_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../') 163 | # 普通卷积 164 | """ 165 | model_file2 = os.path.join(abs_path, './logs/current_policy_tf_small.model') 166 | model_file1 = os.path.join(abs_path, './logs/best_policy_tf_10999.model') 167 | model_file3 = os.path.join(abs_path, './logs/current_policy_1024.model') 168 | # 残差5层 169 | 
model_file_res = os.path.join(abs_path, './logs/current_res_5.model') 170 | model_file_res10 = os.path.join(abs_path, './logs/current_res_10.model') 171 | """ 172 | model_file1 = os.path.join(abs_path, "./logs/current_policy.model") 173 | model_file2 = os.path.join(abs_path, "./logs/current_policy.model") 174 | model_file3 = os.path.join(abs_path, "./logs/current_policy.model") 175 | model_file_res = os.path.join(abs_path, "./logs/current_policy.model") 176 | model_file_res10 = os.path.join(abs_path, "./logs/current_policy.model") 177 | 178 | policy_param = None 179 | self.height = 15 180 | self.width = 15 181 | if model_file1 is not None: 182 | print('loading...', model_file1) 183 | try: 184 | policy_param = pickle.load(open(model_file1, 'rb')) 185 | except: 186 | policy_param = pickle.load(open(model_file1, 'rb'), encoding='bytes') 187 | 188 | if complex_ == 's': 189 | policy_value_net = gomoku_zm.policy_value_net_mxnet_simple.PolicyValueNet(self.height, self.width, batch_size=16, model_params=policy_param) 190 | self.mcts_player = gomoku_zm.mcts_alphaZero.MCTSPlayer(policy_value_net.policy_value_fn, c_puct=3, n_playout=80) # n_playout: 160 太久了 改成 80 191 | # 500:50s 200:32s 192 | # 残差网络 5层 193 | elif complex_ == 'r': 194 | if model_file_res is not None: 195 | try: 196 | policy_param_res = pickle.load(open(model_file_res, 'rb')) 197 | except: 198 | policy_param_res = pickle.load(open(model_file_res, 'rb'), encoding='bytes') 199 | policy_value_net_res = gomoku_zm.policy_value_net_mxnet.PolicyValueNet(self.height, self.width, batch_size=16, n_blocks=5, n_filter=128, model_params=policy_param_res) 200 | self.mcts_player = gomoku_zm.mcts_alphaZero.MCTSPlayer(policy_value_net_res.policy_value_fn, c_puct=3, n_playout=1290) 201 | # 10层残差 202 | elif complex_ == 'r10': 203 | if model_file_res10 is not None: 204 | try: 205 | policy_param_res = pickle.load(open(model_file_res10, 'rb')) 206 | except: 207 | policy_param_res = pickle.load(open(model_file_res10, 'rb'), encoding='bytes') 208 | policy_value_net_res = gomoku_zm.policy_value_net_mxnet.PolicyValueNet(self.height, self.width, batch_size=16, n_blocks=10, n_filter=128, model_params=policy_param_res) 209 | self.mcts_player = gomoku_zm.mcts_alphaZero.MCTSPlayer(policy_value_net_res.policy_value_fn, c_puct=3, n_playout=990) 210 | else: 211 | print("*=*" * 20, ": 模型参数错误") 212 | self.board = gomoku_zm.game.Board(width=self.width, height=self.height, n_in_row=5) 213 | self.board.init_board(startplayer) 214 | self.game = gomoku_zm.game.Game(self.board) 215 | p1, p2 = self.board.players 216 | print('players:', p1, p2) 217 | self.mcts_player.set_player_ind(p1) 218 | pass 219 | 220 | def play_one_piece(self, user, gameboard): 221 | print('user:', gameboard.get_current_user()) 222 | print('gameboard:', gameboard.move_history) 223 | lastm = gameboard.get_lastmove() 224 | if lastm[0] != -1: 225 | usr, n, row, col = lastm 226 | mv = (self.height-row-1)*self.height+col 227 | if not self.board.states.has_key(mv): 228 | self.board.do_move(mv) 229 | 230 | print('board:', self.board.states.items()) 231 | move = self.mcts_player.get_action(self.board) 232 | print('***' * 10) 233 | print('move: ', move) 234 | print('\n') 235 | self.board.do_move(move) 236 | # self.game.graphic(self.board, *self.board.players) 237 | outmv = (self.height-move//self.height-1, move%self.width) 238 | 239 | return outmv 240 | 241 | 242 | 243 | def go_play(args): 244 | 245 | client = ChessClient(args.server_url) 246 | client.login_in_guest() 247 | client.join_room(args.room_name) 248 | 
client.join_game(args.cur_role) 249 | user = client.get_user_info() 250 | print("加入游戏成功,你是:" + ("黑方" if user.game_role == 1 else "白方")) 251 | print('model is: ', args.model) 252 | 253 | if args.ai == 'random': 254 | strategy = GameStrategy_MZhang(user.game_role-1, args.model) 255 | else: 256 | assert False, "No other ai, you can add one or import the AICollection's ai." 257 | 258 | while True: 259 | wait_time = client.wait_game_info_changed() 260 | print('wait_time:', wait_time) 261 | 262 | room = client.get_room_info() 263 | # room=GameRoom() 264 | user = client.get_user_info() 265 | # user=User() 266 | gameboard = client.get_game_info() 267 | # gameboard = ChessBoard() 268 | 269 | print('room.get_status():', room.get_status()) 270 | print('user.game_status():', user.game_status) 271 | print('gameboard.game_status():') 272 | ChessHelper.printBoard(gameboard) 273 | 274 | if room.get_status() == GameRoom.ROOM_STATUS_NOONE or room.get_status() == GameRoom.ROOM_STATUS_ONEWAITING: 275 | print("等待另一个对手加入游戏:") 276 | continue 277 | elif room.get_status() == GameRoom.ROOM_STATUS_PLAYING: 278 | if room.ask_take_back != 0 and room.ask_take_back != user.game_role: 279 | client.answer_take_back() 280 | return 0 281 | # break 282 | if gameboard.get_current_user() == user.game_role: 283 | print("轮到你走:") 284 | current_time = time.time() 285 | one_legal_piece = strategy.play_one_piece(user, gameboard) 286 | action_result = client.put_piece(*one_legal_piece) 287 | print('***' * 20) 288 | print('COST TIME: %s' % (time.time() - current_time)) 289 | if action_result['id'] != 0: 290 | print("走棋失败:") 291 | print(ChessHelper.numToAlp(one_legal_piece[0]), ChessHelper.numToAlp(one_legal_piece[1])) 292 | print(action_result['info']) 293 | 294 | else: 295 | print("轮到对手走....") 296 | continue 297 | elif room.get_status() == GameRoom.ROOM_STATUS_FINISH: 298 | print("游戏已经结束了," + ("黑方" if gameboard.get_winner() == 1 else "白方") + " 赢了") 299 | return 1 300 | # break 301 | 302 | 303 | if __name__ == "__main__": 304 | import argparse 305 | import time 306 | parser = argparse.ArgumentParser() 307 | # yixin_XX_X 格式将会进入和yixin的对战中 308 | temp_room_name = 'yixin_anxingle_' + str(time.time()) 309 | parser.add_argument('--room_name', type=str, default=temp_room_name) 310 | parser.add_argument('--cur_role', default=1) 311 | parser.add_argument('--model', default='s') 312 | print('room_name: ', temp_room_name) 313 | parser.add_argument('--server_url', default='http://127.0.0.1:33512') 314 | parser.add_argument('--ai', default='random') 315 | args = parser.parse_args() 316 | while True: 317 | status = go_play(args) 318 | if status == 1: 319 | time.sleep(5) 320 | else: 321 | print("房间有人!") 322 | time.sleep(10) 323 | -------------------------------------------------------------------------------- /evaluate/ChessHelper.py: -------------------------------------------------------------------------------- 1 | from ChessBoard import ChessBoard 2 | import random 3 | from line_profiler import LineProfiler 4 | 5 | 6 | def numToAlp(num): 7 | return chr(ord('A') + num) 8 | 9 | 10 | def transferSymbol(sym): 11 | if sym == 0: 12 | return "." 
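# symbol map: 0 -> "." (blank), 1 -> "O" (first player / black), 2 -> "X" (second player / white), anything else -> "E" (error)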
13 | if sym == 1: 14 | return "O" 15 | if sym == 2: 16 | return "X" 17 | return "E" 18 | 19 | 20 | def printBoard(chessboard): 21 | for i in range(chessboard.SIZE + 1): 22 | for j in range(chessboard.SIZE + 1): 23 | if i == 0 and j == 0: 24 | print(' ') 25 | elif i == 0: 26 | print(numToAlp(j - 1)) 27 | elif j == 0: 28 | print(numToAlp(i - 1)) 29 | else: 30 | print(transferSymbol(chessboard.get_piece(i - 1, j - 1))) 31 | print() 32 | 33 | def printBoard2Str(chessboard): 34 | info_array=[] 35 | 36 | for i in range(chessboard.SIZE + 1): 37 | info_str = "" 38 | for j in range(chessboard.SIZE + 1): 39 | 40 | if i == 0 and j == 0: 41 | info_str+= ' ' 42 | elif i == 0: 43 | info_str += numToAlp(j - 1) 44 | elif j == 0: 45 | info_str +=numToAlp(i - 1) 46 | else: 47 | info_str +=transferSymbol(chessboard.get_piece(i - 1, j - 1)) 48 | info_str += '\t' 49 | # info_str +='\n' 50 | 51 | info_array.append(info_str) 52 | return info_array 53 | 54 | def playRandomGame(chessboard): 55 | muser = 1 56 | _chess_helper_move_set = [] 57 | for i in range(15): 58 | for j in range(15): 59 | _chess_helper_move_set.append((i, j)) 60 | random.shuffle(_chess_helper_move_set) 61 | for move in _chess_helper_move_set: 62 | if chessboard.is_over(): 63 | # print "No place to put." 64 | return -5 65 | 66 | r_row = move[0] 67 | r_col = move[1] 68 | 69 | return_value = chessboard.put_piece(r_row, r_col, muser) 70 | muser = 2 if muser == 1 else 1 71 | if return_value != 0: 72 | # print ("\n%s win one board. last move is %s %s, return value is %d" % ( 73 | # transferSymbol(chessboard.get_winner()), numToAlp(r_row), numToAlp(r_col), return_value)) 74 | return return_value 75 | 76 | 77 | if __name__ == '__main__': 78 | def linePro(): 79 | lp = LineProfiler() 80 | 81 | # lp_wrapper = lp(cb.put_piece) 82 | # lp_wrapper(7, 7, 1) 83 | 84 | def playMuch(num): 85 | oi = 0 86 | for i in xrange(num): 87 | cb = ChessBoard() 88 | playRandomGame(cb) 89 | oi += cb.move_num 90 | print(oi / num) 91 | 92 | lp_wrapper = lp(playMuch) 93 | lp_wrapper(1000) 94 | lp.print_stats() 95 | 96 | 97 | def playRandom(): 98 | cb = ChessBoard() 99 | return_value = playRandomGame(cb) 100 | printBoard(cb) 101 | print ("\n%s win one board. 
last move is %s %s, return value is %d" % ( 102 | transferSymbol(cb.get_winner()), numToAlp(cb.get_lastmove()[0]), numToAlp(cb.get_lastmove()[1]), 103 | return_value)) 104 | 105 | 106 | playRandom() 107 | #linePro() 108 | -------------------------------------------------------------------------------- /evaluate/ChessServer.py: -------------------------------------------------------------------------------- 1 | import tornado.ioloop 2 | import tornado.web 3 | import os 4 | import argparse 5 | 6 | from Hall import Hall 7 | from Hall import GameRoom 8 | from Hall import User 9 | 10 | hall = Hall() 11 | 12 | 13 | class BaseHandler(tornado.web.RequestHandler): 14 | def get_current_user(self): 15 | return self.get_secure_cookie("username") 16 | 17 | 18 | class ChessHandler(BaseHandler): 19 | 20 | def get_post(self): 21 | info_ = "" 22 | user = hall.get_user_with_uid(self.current_user) 23 | room = user.game_room 24 | chess_board = room.board if room else None 25 | self.render("page/chessboard.html", username=self.current_user, room=room, 26 | chess_board=chess_board, user=user) 27 | 28 | @tornado.web.authenticated 29 | def get(self): 30 | self.get_post() 31 | 32 | @tornado.web.authenticated 33 | def post(self): 34 | self.get_post() 35 | 36 | 37 | class ActionHandler(BaseHandler): 38 | 39 | def get_post(self): 40 | action = self.get_argument("action", None) 41 | action_result = {"id": -1, "info": "Failure"} 42 | 43 | if action: 44 | if action == "joinroom": 45 | roomid = self.get_argument("roomid", None) 46 | if roomid: 47 | hall.join_room(self.current_user, roomid) 48 | action_result["id"] = 0 49 | action_result["info"] = "Join room success." 50 | else: 51 | action_result["id"] = -1 52 | action_result["info"] = "Not legal room id." 53 | 54 | elif action == "joingame": 55 | user_role = int(self.get_argument("position", -1)) 56 | 57 | if hall.join_game(self.current_user, user_role) == 0: 58 | action_result["id"] = 0 59 | action_result["info"] = "Join game success." 60 | 61 | else: 62 | action_result["id"] = -1 63 | action_result["info"] = "Join game failed, join a room first or have joined game or game is full." 64 | 65 | elif action == "gameaction": 66 | actionid = self.get_argument("actionid", None) 67 | game_action_result = hall.game_action(self.current_user, actionid, self) 68 | if game_action_result.result_id == 0: 69 | action_result["id"] = 0 70 | action_result["info"] = game_action_result.result_info 71 | else: 72 | action_result["id"] = -1 73 | action_result["info"] = "Game action failed:" + str( 74 | game_action_result.result_id) + "," + game_action_result.result_info 75 | elif action == "getboardinfo": 76 | room = hall.get_room_with_user(self.current_user) 77 | # room=GameRoom() 78 | if room: 79 | action_result["id"] = 0 80 | action_result["info"] = room.board.dumps() 81 | else: 82 | action_result["id"] = -1 83 | action_result["info"] = "Not in room, please join one." 84 | elif action == "get_all_rooms": 85 | action_result["id"] = 0 86 | action_result["info"] = [[room_name, hall.id2room[room_name].get_status()] for room_name in 87 | hall.id2room] 88 | elif action == "reset_room": 89 | user = hall.get_user_with_uid(self.current_user) 90 | room = user.game_room 91 | # room=GameRoom() 92 | if room and room.get_status() == GameRoom.ROOM_STATUS_FINISH and user in room.play_users: 93 | room.reset_game() 94 | action_result["id"] = 0 95 | action_result["info"] = "reset success" 96 | else: 97 | action_result["id"] = -1 98 | action_result["info"] = "reset failed." 
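# note: reset_room is only honored for a player of a finished game (checked above)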
99 | 100 | else: 101 | action_result["id"] = -1 102 | action_result["info"] = "Not recognition action" + action 103 | else: 104 | action_result["info"] = "Not action arg set" 105 | 106 | # self.write(tornado.escape.json_encode(action_result)) 107 | self.finish(action_result) 108 | 109 | @tornado.web.authenticated 110 | def get(self): 111 | self.get_post() 112 | 113 | @tornado.web.authenticated 114 | def post(self): 115 | self.get_post() 116 | 117 | 118 | class LoginHandler(BaseHandler): 119 | 120 | def get_post(self): 121 | 122 | action = self.get_argument("action", None) 123 | 124 | if action == "login": 125 | if self.current_user is not None: 126 | self.clear_cookie("username") 127 | hall.logout(self.current_user) 128 | 129 | username = self.get_argument("username") 130 | password = self.get_argument("password") 131 | username = hall.login(username, password) 132 | if username: 133 | self.set_secure_cookie("username", username) 134 | self.redirect("/") 135 | else: 136 | self.redirect("/login?status=wrong_password_or_name") 137 | elif action == "login_in_guest": 138 | if self.current_user is not None: 139 | self.clear_cookie("username") 140 | hall.logout(self.current_user) 141 | 142 | username = hall.login_in_guest() 143 | print(username) 144 | if username: 145 | self.set_secure_cookie("username", username) 146 | self.redirect("/") 147 | elif action == "logout": 148 | if self.current_user is not None: 149 | self.clear_cookie("username") 150 | hall.logout(self.current_user) 151 | self.redirect("/login") 152 | else: 153 | self.render('page/login.html') 154 | 155 | def get(self): 156 | self.get_post() 157 | 158 | def post(self): 159 | self.get_post() 160 | 161 | 162 | def main(listen_port): 163 | settings = { 164 | # "template_path": os.path.join(os.path.dirname(__file__), "templates"), 165 | "cookie_secret": "bZJc2sWbQLKos6GkHn/VB9oXwQt8S0R0kRvJ5/xJ89E=", 166 | # "xsrf_cookies": True, 167 | "login_url": "/login", 168 | "static_path": os.path.join(os.path.dirname(__file__), "static"), 169 | } 170 | app = tornado.web.Application([ 171 | (r"/", ChessHandler), 172 | (r"/login", LoginHandler), 173 | (r"/action", ActionHandler), 174 | ], **settings) 175 | app.listen(listen_port) 176 | tornado.ioloop.IOLoop.current().start() 177 | 178 | 179 | if __name__ == "__main__": 180 | parser = argparse.ArgumentParser() 181 | parser.add_argument('--port', default=8888) 182 | args = parser.parse_args() 183 | main(args.port) 184 | -------------------------------------------------------------------------------- /evaluate/Hall.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from ChessBoard import ChessBoard 4 | import cPickle as pickle 5 | # import _pickle as pickle 6 | import random 7 | import string 8 | import os 9 | 10 | 11 | class User(object): 12 | def __init__(self, uid_, hall_): 13 | self.uid = uid_ 14 | self.game_room = None 15 | self.hall = hall_ 16 | self.game_status = User.USER_GAME_STATUS_NOJOIN 17 | self.game_role = -1 18 | 19 | USER_GAME_STATUS_NOJOIN = 0 20 | USER_GAME_STATUS_GAMEJOINED = 1 21 | 22 | def send_message(self): 23 | pass 24 | 25 | def receive_message(self, message): 26 | pass 27 | 28 | def send_game_state(self): 29 | pass 30 | 31 | def action(self, action): 32 | pass 33 | 34 | def join_room(self, room): 35 | if self.game_room is not None: 36 | self.leave_room() 37 | 38 | if room.join_room(self) == GameRoom.ACTION_SUCCESS: 39 | self.game_room = room 40 | else: 41 | return -1 42 | 43 | def join_game(self, user_role): 44 | if 
self.game_room is None: 45 | return -1 46 | 47 | if self.game_status == User.USER_GAME_STATUS_GAMEJOINED: 48 | return -1 49 | if self.game_room.join_game(self, user_role) == GameRoom.ACTION_SUCCESS: 50 | self.game_status = User.USER_GAME_STATUS_GAMEJOINED 51 | return 0 52 | else: 53 | return -1 54 | 55 | def leave_room(self): 56 | if self.game_room is None: 57 | return -1 58 | self.leave_game() 59 | self.game_room.leave_room(self) 60 | 61 | def leave_game(self): 62 | if self.game_room is None: 63 | return -1 64 | self.game_room.leave_game(self) 65 | self.game_status = User.USER_GAME_STATUS_NOJOIN 66 | self.game_role = -1 67 | 68 | 69 | class ActionResult(object): 70 | def __init__(self, result_id_=0, result_info_=""): 71 | self.result_id = result_id_ 72 | self.result_info = result_info_ 73 | 74 | 75 | class GameRoom(object): 76 | ACTION_SUCCESS = 0 77 | ACTION_FAILURE = -1 78 | 79 | def __init__(self, room_id_): 80 | self.play_users = [] 81 | self.position2users = {} # 1 black role 2 white role 82 | self.room_id = room_id_ 83 | self.users = [] 84 | self.max_player_num = 2 85 | self.max_user_num = 10000 86 | self.board = ChessBoard() 87 | self.status_signature = None 88 | self.set_changed() 89 | self.chess_folder = 'chess_output' 90 | # self.game_status = GameRoom.GAME_STATUS_NOTBEGIN 91 | self.ask_take_back = 0 92 | 93 | def set_changed(self): 94 | self.status_signature = ''.join(random.sample(string.ascii_letters + string.digits, 8)) 95 | 96 | def broadcast_message_to_all(self, message): 97 | self.set_changed() 98 | pass 99 | 100 | def send_message(self, to_user_id, message): 101 | self.set_changed() 102 | pass 103 | 104 | def get_last_move(self): 105 | (userrole, move_num, row, col) = self.board.get_lastmove() 106 | if userrole < 0: 107 | userrole = -1 * self.get_status() - 1 108 | last_move = { 109 | 'role': userrole, 110 | 'move_num': move_num, 111 | 'row': row, 112 | 'col': col, 113 | } 114 | return last_move 115 | 116 | # GAME_STATUS_NOTBEGIN = 0 117 | # GAME_STATUS_RUNING_HEIFANG = 1 118 | # GAME_STATUS_RUNING_BAIFANG = 2 119 | # GAME_STATUS_FINISH = 3 120 | # GAME_STATUS_ASKTAKEBACK_HEIFANG = 4 121 | # GAME_STATUS_ASKTAKEBACK_BAIFANG = 5 122 | # GAME_STATUS_ASKREBEGIN_HEIFANG = 6 123 | # GAME_STATUS_ASKREBEGIN_BAIFANG = 7 124 | 125 | def get_signature(self): 126 | return self.status_signature 127 | 128 | def action(self, user, action_code, action_args): 129 | if action_code == "put_piece": 130 | if self.ask_take_back != 0: 131 | return ActionResult(-1, "Wait take back answer before.") 132 | if not (user.game_role == 1) and not (user.game_role == 2): 133 | return ActionResult(-1, "Not the right time or role to put piece") 134 | piece_i = action_args.get_argument('piece_i', None) 135 | piece_j = action_args.get_argument('piece_j', None) 136 | if piece_i and piece_j: 137 | return_code = self.board.put_piece(int(piece_i), int(piece_j), user.game_role) 138 | if return_code >= 0: 139 | if return_code == 1: 140 | self.finish_game() 141 | self.set_changed() 142 | return ActionResult(0, "put_piece success:" + str(return_code)); 143 | else: 144 | return ActionResult(-4, "put_piece failed, because " + str(return_code)) 145 | else: 146 | return ActionResult(-3, "Not set the piece_i and piece_j") 147 | elif action_code == "getlastmove": 148 | return ActionResult(0, self.get_last_move()) 149 | elif action_code == "get_status_signature": 150 | return ActionResult(0, self.get_signature()) 151 | elif action_code == "get_room_info": 152 | return ActionResult(0, pickle.dumps(self)) 153 | elif 
action_code == "get_game_info": 154 | return ActionResult(0, pickle.dumps(self.board)) 155 | elif action_code == "get_user_info": 156 | return ActionResult(0, pickle.dumps(user)) 157 | elif action_code == "ask_take_back": 158 | if self.ask_take_back != 0: 159 | return ActionResult(-1, "Not right time to ask take back.") 160 | if user.game_role != 1 and user.game_role != 2: 161 | return ActionResult(-1, "Not right role to ask take back.") 162 | self.ask_take_back = user.game_role 163 | self.set_changed() 164 | return ActionResult(0, "Ask take back success") 165 | elif action_code == "answer_take_back": 166 | 167 | if not (self.ask_take_back == 1 and user.game_role == 2) and not ( 168 | self.ask_take_back == 2 and user.game_role == 1): 169 | return ActionResult(-1, "Not the right time or role to answer take back") 170 | 171 | agree = action_args.get_argument('agree', 'false') 172 | if agree == 'true': 173 | if self.ask_take_back == 1 and user.game_role == 2: 174 | if self.board.get_current_user() == 1: 175 | self.board.take_one_back() 176 | self.board.take_one_back() 177 | 178 | elif self.ask_take_back == 2 and user.game_role == 1: 179 | if self.board.get_current_user() == 2: 180 | self.board.take_one_back() 181 | self.board.take_one_back() 182 | self.ask_take_back = 0 183 | self.set_changed() 184 | return ActionResult(0, "answer take back success") 185 | else: 186 | return ActionResult(-2, "Not recognized game action") 187 | 188 | # def join(self, user): 189 | # if self.join_game(user) == GameRoom.ACTION_SUCCESS: 190 | # return GameRoom.ACTION_SUCCESS 191 | # return self.join_watch(user) 192 | 193 | def join_game(self, user, user_role): 194 | if user not in self.users: 195 | return GameRoom.ACTION_FAILURE 196 | 197 | if len(self.play_users) >= self.max_player_num: 198 | return GameRoom.ACTION_FAILURE 199 | idle_position = [i for i in range(1, self.max_player_num + 1) if i not in self.position2users] 200 | if len(idle_position) == 0: 201 | return GameRoom.ACTION_FAILURE 202 | if user_role is None or user_role == -1: 203 | user_role = idle_position[0] 204 | elif user_role == 0: 205 | random.shuffle(idle_position) 206 | user_role = idle_position[0] 207 | else: 208 | if user_role in self.position2users or user_role > self.max_player_num: 209 | return GameRoom.ACTION_FAILURE 210 | user.game_role = user_role 211 | self.position2users[user_role] = user 212 | self.play_users.append(user) 213 | 214 | self.set_changed() 215 | return GameRoom.ACTION_SUCCESS 216 | 217 | def join_room(self, user): 218 | if len(self.users) >= self.max_user_num: 219 | return GameRoom.ACTION_FAILURE 220 | self.users.append(user) 221 | self.set_changed() 222 | return GameRoom.ACTION_SUCCESS 223 | 224 | def leave_game(self, user): 225 | if user not in self.play_users: 226 | return 227 | 228 | if user in self.play_users: 229 | if self.get_status() == GameRoom.ROOM_STATUS_PLAYING: 230 | self.finish_game(state=-1) 231 | self.play_users.remove(user) 232 | if user.game_role in self.position2users: 233 | del self.position2users[user.game_role] 234 | self.set_changed() 235 | 236 | def leave_room(self, user): 237 | if user not in self.users: 238 | return 239 | self.leave_game(user) 240 | self.users.remove(user) 241 | self.set_changed() 242 | 243 | ROOM_STATUS_FINISH = 4 244 | ROOM_STATUS_NOONE = 1 245 | ROOM_STATUS_ONEWAITING = 2 246 | ROOM_STATUS_PLAYING = 3 247 | ROOM_STATUS_WRONG = -1 248 | ROOM_STATUS_NOTINROOM = 0 249 | 250 | def get_status(self): 251 | if self.board.is_over(): 252 | return GameRoom.ROOM_STATUS_FINISH; 253 | if 
len(self.play_users) == 0: 254 | return GameRoom.ROOM_STATUS_NOONE; 255 | if len(self.play_users) == 1: 256 | return GameRoom.ROOM_STATUS_ONEWAITING; 257 | if len(self.play_users) == 2: 258 | return GameRoom.ROOM_STATUS_PLAYING; 259 | return GameRoom.ROOM_STATUS_WRONG; 260 | 261 | def reset_game(self): 262 | while len(self.play_users) > 0: 263 | self.play_users[0].leave_game() 264 | self.board.reset() 265 | 266 | def finish_game(self, state=0): 267 | if state == -1: 268 | self.board.abort() 269 | 270 | if not os.path.exists(self.chess_folder): 271 | os.makedirs(self.chess_folder) 272 | import datetime 273 | tm = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') 274 | chess_file = self.chess_folder + '/' + self.room_id + '_[' + '|'.join( 275 | [user.uid for user in self.play_users]) + ']_' + tm + '.txt' 276 | with open(chess_file, 'w') as f: 277 | f.write(self.board.dumps()) 278 | 279 | # def get_room_info(self): 280 | # room_info = {'status': 0, 'roomid': -1} 281 | # room_info['status'] = self.get_status() 282 | # room_info['roomid'] = self.room_id 283 | # room_info['room'] = self 284 | # room_info['users_uid'] = [] 285 | # room_info['play_users_uid'] = [] 286 | # room_info['users'] = [] 287 | # room_info['play_users'] = [] 288 | # for user in self.users: 289 | # room_info['users_uid'].append(user.uid) 290 | # room_info['users'].append(user) 291 | # for user in user.game_room.play_users: 292 | # room_info['play_users_uid'].append(user.uid) 293 | # room_info['play_users'].append(user) 294 | 295 | 296 | class Hall(object): 297 | def __init__(self): 298 | self.uid2user = {} 299 | self.id2room = {} 300 | self.MaxUserNum = 10000 301 | self.user_num = 0 302 | 303 | def login(self, username, password): 304 | if self.user_num > self.MaxUserNum: 305 | return None 306 | pass 307 | 308 | def login_in_guest(self): 309 | if self.user_num > self.MaxUserNum: 310 | return None 311 | import random 312 | username = "guest_" 313 | while True: 314 | rand_postfix = random.randint(10000, 1000000) 315 | username = "guest_" + str(rand_postfix) 316 | if username not in self.uid2user: 317 | break 318 | user = User(username, self) 319 | self.uid2user[username] = user 320 | return username 321 | 322 | def get_user_with_uid(self, userid): 323 | if userid not in self.uid2user: 324 | self.uid2user[userid] = User(userid, self) 325 | return self.uid2user[userid] 326 | 327 | def join_room(self, username, roomid): 328 | user = self.get_user_with_uid(username) 329 | if roomid not in self.id2room: 330 | self.id2room[roomid] = GameRoom(roomid) 331 | if user.game_room != self.id2room[roomid]: 332 | user.join_room(self.id2room[roomid]) 333 | 334 | def get_room_info_with_user(self, username): 335 | room_info = {'status': 0, 'roomid': -1} 336 | user = self.get_user_with_uid(username) 337 | if user.game_room: 338 | 339 | room_info['status'] = user.game_room.get_status() 340 | room_info['roomid'] = user.game_room.room_id 341 | room_info['room'] = user.game_room 342 | room_info['users_uid'] = [] 343 | room_info['play_users_uid'] = [] 344 | room_info['users'] = [] 345 | room_info['play_users'] = [] 346 | for user in user.game_room.users: 347 | room_info['users_uid'].append(user.uid) 348 | room_info['users'].append(user) 349 | for user in user.game_room.play_users: 350 | room_info['play_users_uid'].append(user.uid) 351 | room_info['play_users'].append(user) 352 | return room_info 353 | 354 | def get_room_with_user(self, username): 355 | user = self.get_user_with_uid(username) 356 | if user.game_room: 357 | return user.game_room 358 
| return None
359 |
360 | def join_game(self, username, user_role):
361 | user = self.get_user_with_uid(username)
362 | return user.join_game(user_role)
363 |
364 | def game_action(self, username, actionid, arg_pack):
365 | user = self.get_user_with_uid(username)
366 | if user.game_room:
367 | return user.game_room.action(user, actionid, arg_pack)
368 | else:
369 | return ActionResult(-1, "Not in any room")
370 |
371 | def logout(self, username):
372 | user = self.get_user_with_uid(username)
373 | if user.game_room:
374 | user.game_room.leave_room(user)
375 | self.uid2user.pop(username)
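A quick way to see how Hall, GameRoom and ChessBoard fit together is the driver below. It is an illustrative sketch, not part of the repo, and like the rest of evaluate/ it assumes Python 2 (Hall.py imports cPickle):

```python
# Illustrative sketch of the Hall API as ChessServer's handlers use it.
from Hall import Hall, GameRoom   # run from AlphaPig/evaluate/

hall = Hall()
black = hall.login_in_guest()      # returns a generated "guest_NNNNN" uid
white = hall.login_in_guest()
hall.join_room(black, 'demo')      # the first join creates the GameRoom
hall.join_room(white, 'demo')
hall.join_game(black, 1)           # 1 = black, 2 = white, 0 = random seat
hall.join_game(white, 2)

room = hall.get_room_with_user(black)
print(room.get_status() == GameRoom.ROOM_STATUS_PLAYING)  # True
print(room.board.put_piece(7, 7, 1))  # 0 = legal move (1 would mean a win)
```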
-------------------------------------------------------------------------------- /evaluate/README.md: --------------------------------------------------------------------------------

1 |
2 |
3 | # Usage guide
4 | ```
5 | cd AlphaPig/evaluate
6 | python ChessServer.py # starts the game server on this machine (default port 8888, override with --port); open localhost:<port> in a browser
7 | ```
8 |
9 | # Playing against an AI
10 |
11 | 1.
12 | ```
13 | # from the AlphaPig/evaluate directory
14 | python ChessClient.py --model XXX.model --cur_role 1/2 --room_name ROOM_NAME --server_url 127.0.0.1:8888
15 | ```
16 | --cur_role: 1 plays black, 2 plays white; --room_name is the room to play in; --server_url is the game server's address (the local machine by default)
17 |
18 | For an AI-vs-AI game, both sides enter the same arguments except for --cur_role. For AI vs. human, whoever joins first gets first pick.
19 |
20 | # Human vs. human
21 |
22 | Just open a browser and agree on a room; nothing more to it.
-------------------------------------------------------------------------------- /evaluate/ai_listen_ucloud.bat: --------------------------------------------------------------------------------

1 | c:\Python27\python.exe AICollection.py --server_url http://120.132.59.147:11111
-------------------------------------------------------------------------------- /evaluate/logs/error.log: --------------------------------------------------------------------------------

https://raw.githubusercontent.com/anxingle/AlphaPig/16257003c6fa33d3583b8d43919f4ec944f8b5bb/evaluate/logs/error.log

-------------------------------------------------------------------------------- /evaluate/logs/info.log: --------------------------------------------------------------------------------

https://raw.githubusercontent.com/anxingle/AlphaPig/16257003c6fa33d3583b8d43919f4ec944f8b5bb/evaluate/logs/info.log

-------------------------------------------------------------------------------- /evaluate/page/chessboard.html: --------------------------------------------------------------------------------

1 |
2 | {% from Hall import GameRoom %}
3 |
4 |
5 |
6 | Alpha猪
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 | {{username}}
15 | , Log out.
16 | Room ID:
17 |
18 |
19 | 20 |
21 | {% if room %} 22 | 23 | 34 |
35 | You have entered room {{room.room_id}}; this room currently has {{len(room.users)}} user(s):
36 | {% for muser in room.users %}
37 | | {{muser.uid}}
38 | {% end %}
39 |
40 |
41 | Current black player:
42 | {% if 1 in room.position2users %}
43 | {{room.position2users[1].uid}}
44 | {% else %}
45 | empty,
46 | {% if user.game_role<=0 %}
47 |
48 | {% end %}
49 | {% end %}
50 | , current white player:
51 | {% if 2 in room.position2users %}
52 | {{room.position2users[2].uid}}
53 | {% else %}
54 | empty,
55 | {% if user.game_role<=0 %}
56 |
57 | {% end %}
58 | {% end %}
59 | {% if 1 not in room.position2users and 2 not in room.position2users and user.game_role<=0 %}
60 |
61 | {% end %}
62 |
63 |
64 |
65 |
66 | 95 |
96 | Actions:
97 | {% if room.get_status()==GameRoom.ROOM_STATUS_PLAYING and room.ask_take_back==0 and (user.game_role==1 or user.game_role==2) %}
98 |
99 | {% end %}
100 |
101 | {% if room.get_status()==GameRoom.ROOM_STATUS_FINISH and (user.game_role==1 or user.game_role==2) %}
102 |
103 | {% end %}
104 |
105 |
106 |
107 |
108 | {% set winner_name="Black" if chess_board.get_winner()==1 else "White" %}
109 | {% set current_name="Black" if chess_board.get_current_user()==1 else "White" %}
110 | {% if room.get_status()==GameRoom.ROOM_STATUS_NOONE or \
111 | room.get_status()==GameRoom.ROOM_STATUS_ONEWAITING %}
112 | Waiting for players to join the game.
113 | {% elif room.get_status()==GameRoom.ROOM_STATUS_PLAYING %}
114 |
115 | {% if chess_board.get_current_user()==user.game_role %}
116 | It is your turn
117 | {% elif user.game_role<=0 %}
118 | It is {{current_name}}'s turn
119 | {% else %}
120 | Waiting for the opponent to move
121 | {% end %}
122 |
123 | {% elif room.get_status()==GameRoom.ROOM_STATUS_FINISH %}
124 | Game over; {{winner_name}} wins.
125 |
130 |
131 | {% end %}
132 |
133 |
134 |
135 | 136 | 137 | 138 | 140 | {% for i in range(chess_board.SIZE) %} 141 | 142 | 143 | {% for j in range(chess_board.SIZE) %} 144 | {% if chess_board.get_piece(i, j)==0 %} 145 | 147 | {% elif chess_board.get_piece(i, j)==1 %} 148 | 153 | {% elif chess_board.get_piece(i , j)==2 %} 154 | 159 | {% else %} 160 | not legal {{chess_board.get_piece(i , j)}} 161 | {% end%} 162 | {% end%} 163 | 164 | {% end%} 165 |
149 | {% if i== chess_board.get_lastmove()[-2] and j== chess_board.get_lastmove()[-1]%} 150 | Last 151 | {% end %} 152 | 155 | {% if i== chess_board.get_lastmove()[-2] and j== chess_board.get_lastmove()[-1]%} 156 | Last 157 | {% end %} 158 |
166 |
167 |
168 | 169 | 291 | 292 | 293 | {% end%} 294 |
295 | 337 | 338 | 339 | -------------------------------------------------------------------------------- /evaluate/page/login.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Alpha猪 6 | 7 | 8 | 9 |
10 |
11 | 16 | 17 |
18 |
19 |
20 | 注意:1~5房间为AI执黑(先手),6~10房间为AI执白(后手) 21 |
22 | 23 |
24 | 25 | 26 | -------------------------------------------------------------------------------- /evaluate/run_ai.sh: -------------------------------------------------------------------------------- 1 | python ChessClient.py --cur_role 1 --model r10 --room_name 1 --server_url http://127.0.0.1:8888 2 | -------------------------------------------------------------------------------- /evaluate/static/black.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anxingle/AlphaPig/16257003c6fa33d3583b8d43919f4ec944f8b5bb/evaluate/static/black.png -------------------------------------------------------------------------------- /evaluate/static/blank.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anxingle/AlphaPig/16257003c6fa33d3583b8d43919f4ec944f8b5bb/evaluate/static/blank.png -------------------------------------------------------------------------------- /evaluate/static/touming.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anxingle/AlphaPig/16257003c6fa33d3583b8d43919f4ec944f8b5bb/evaluate/static/touming.png -------------------------------------------------------------------------------- /evaluate/static/white.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anxingle/AlphaPig/16257003c6fa33d3583b8d43919f4ec944f8b5bb/evaluate/static/white.png -------------------------------------------------------------------------------- /evaluate/unittest/ChessBoardTest.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from ChessBoard import ChessBoard 3 | import ChessHelper 4 | from ChessHelper import printBoard 5 | 6 | 7 | class ChessBoardTest(unittest.TestCase): 8 | def setUp(self): 9 | self.a = 1 10 | 11 | def test_putpiece1(self): 12 | cb = ChessBoard() 13 | 14 | self.assertEqual(0, cb.put_piece(0, 0, 1)) 15 | self.assertEqual(-2, cb.put_piece(0, 0, 1)) 16 | self.assertEqual(-1, cb.put_piece(-1, 0, 1)) 17 | self.assertEqual(-1, cb.put_piece(0, 16, 1)) 18 | 19 | def test_putpiece2(self): 20 | cb = ChessBoard() 21 | muser = 1 22 | for i in xrange(4): 23 | for j in xrange(15): 24 | return_value = cb.put_piece(i, j, muser) 25 | self.assertEqual(0, return_value) 26 | muser = 2 if muser == 1 else 1 27 | 28 | def test_putpiece2(self): 29 | cb = ChessBoard() 30 | muser = 1 31 | for i in xrange(4): 32 | for j in xrange(15): 33 | return_value = cb.put_piece(i, j, muser) 34 | self.assertEqual(0, return_value) 35 | muser = 2 if muser == 1 else 1 36 | 37 | def test_putpiece3(self): 38 | cb = ChessBoard() 39 | ChessHelper.playRandomGame(cb) 40 | 41 | 42 | if __name__ == '__main__': 43 | unittest.main() 44 | -------------------------------------------------------------------------------- /evaluate/yixin_ai/engine.exe: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anxingle/AlphaPig/16257003c6fa33d3583b8d43919f4ec944f8b5bb/evaluate/yixin_ai/engine.exe -------------------------------------------------------------------------------- /evaluate/yixin_ai/msvcp140.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anxingle/AlphaPig/16257003c6fa33d3583b8d43919f4ec944f8b5bb/evaluate/yixin_ai/msvcp140.dll 
-------------------------------------------------------------------------------- /evaluate/yixin_ai/vcruntime140.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anxingle/AlphaPig/16257003c6fa33d3583b8d43919f4ec944f8b5bb/evaluate/yixin_ai/vcruntime140.dll -------------------------------------------------------------------------------- /evaluate/yixin_ai/yixin.exe: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anxingle/AlphaPig/16257003c6fa33d3583b8d43919f4ec944f8b5bb/evaluate/yixin_ai/yixin.exe -------------------------------------------------------------------------------- /game.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: Junxiao Song 4 | """ 5 | 6 | from __future__ import print_function 7 | import numpy as np 8 | import os 9 | from policy_value_net_mxnet import PolicyValueNet # Keras 10 | from mcts_alphaZero import MCTSPlayer 11 | from utils import sgf_dataIter, config_loader 12 | 13 | 14 | import logging 15 | import logging.config 16 | logging.config.dictConfig(config_loader.config_['train_logging']) 17 | _logger = logging.getLogger(__name__) 18 | 19 | current_relative_path = lambda x: os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), x)) 20 | 21 | class Board(object): 22 | """board for the game""" 23 | 24 | def __init__(self, **kwargs): 25 | self.width = int(kwargs.get('width', 8)) 26 | self.height = int(kwargs.get('height', 8)) 27 | # board states stored as a dict, 28 | # key: move as location on the board, 29 | # value: player as pieces type 30 | self.states = {} 31 | # need how many pieces in a row to win 32 | self.n_in_row = int(kwargs.get('n_in_row', 5)) 33 | self.players = [1, 2] # player1 and player2 34 | 35 | def init_board(self, start_player=0): 36 | if self.width < self.n_in_row or self.height < self.n_in_row: 37 | raise Exception('board width and height can not be ' 38 | 'less than {}'.format(self.n_in_row)) 39 | self.current_player = self.players[start_player] # start player 40 | # keep available moves in a list 41 | self.availables = list(range(self.width * self.height)) 42 | self.states = {} 43 | self.history = [] 44 | self.last_move = -1 45 | 46 | def move_to_location(self, move): 47 | """ 48 | 3*3 board's moves like: 49 | 6 7 8 50 | 3 4 5 51 | 0 1 2 52 | and move 5's location is (1,2) 53 | """ 54 | h = move // self.width 55 | w = move % self.width 56 | return [h, w] 57 | 58 | def location_to_move(self, location): 59 | if len(location) != 2: 60 | return -1 61 | h = location[0] 62 | w = location[1] 63 | move = h * self.width + w 64 | if move not in range(self.width * self.height): 65 | return -1 66 | return move 67 | 68 | def current_state(self): 69 | """return the board state from the perspective of the current player. 
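Each of the last `histlen` half-moves contributes a pair of binary planes (the to-move player's stones first, then the opponent's), with the oldest snapshot in the lowest-indexed pair; the final plane is all ones when an even number of stones has been played, i.e. when the first player is to move. The array is flipped along the row axis before being returned.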
70 | state shape: (histlen*2+1)*width*height 71 | """ 72 | histlen = 4 73 | statelen = histlen*2+1 74 | square_state = np.zeros((statelen, self.width, self.height)) 75 | if self.states: 76 | histarr = np.array(self.history) 77 | moves_all = histarr[:, 0] 78 | players_all = histarr[:, 1] 79 | real_len = len(moves_all) 80 | for i in range(histlen): 81 | moves = moves_all[:real_len-i] 82 | players = players_all[:real_len-i] 83 | #print(moves, players) 84 | move_curr = moves[players == self.current_player] 85 | move_oppo = moves[players != self.current_player] 86 | square_state[-2*i-3][move_curr // self.width, 87 | move_curr % self.height] = 1.0 88 | square_state[-2*i-2][move_oppo // self.width, 89 | move_oppo % self.height] = 1.0 90 | if real_len-i == 0: 91 | break 92 | if len(self.states) % 2 == 0: 93 | square_state[-1][:, :] = 1.0 # indicate the colour to play 94 | return square_state[:, ::-1, :] 95 | 96 | def current_state_old(self): 97 | """return the board state from the perspective of the current player. 98 | state shape: 4*width*height 99 | """ 100 | 101 | square_state = np.zeros((4, self.width, self.height)) 102 | if self.states: 103 | moves, players = np.array(list(zip(*self.states.items()))) 104 | move_curr = moves[players == self.current_player] 105 | move_oppo = moves[players != self.current_player] 106 | square_state[0][move_curr // self.width, 107 | move_curr % self.height] = 1.0 108 | square_state[1][move_oppo // self.width, 109 | move_oppo % self.height] = 1.0 110 | # indicate the last move location 111 | square_state[2][self.last_move // self.width, 112 | self.last_move % self.height] = 1.0 113 | if len(self.states) % 2 == 0: 114 | square_state[3][:, :] = 1.0 # indicate the colour to play 115 | return square_state[:, ::-1, :] 116 | 117 | def do_move(self, move): 118 | self.states[move] = self.current_player 119 | self.history.append((move, self.current_player)) 120 | self.availables.remove(move) 121 | self.current_player = ( 122 | self.players[0] if self.current_player == self.players[1] 123 | else self.players[1] 124 | ) 125 | self.last_move = move 126 | 127 | def has_a_winner(self): 128 | width = self.width 129 | height = self.height 130 | states = self.states 131 | n = self.n_in_row 132 | 133 | moved = list(set(range(width * height)) - set(self.availables)) 134 | if len(moved) < self.n_in_row + 2: 135 | return False, -1 136 | 137 | for m in moved: 138 | h = m // width 139 | w = m % width 140 | player = states[m] 141 | 142 | if (w in range(width - n + 1) and 143 | len(set(states.get(i, -1) for i in range(m, m + n))) == 1): 144 | return True, player 145 | 146 | if (h in range(height - n + 1) and 147 | len(set(states.get(i, -1) for i in range(m, m + n * width, width))) == 1): 148 | return True, player 149 | 150 | if (w in range(width - n + 1) and h in range(height - n + 1) and 151 | len(set(states.get(i, -1) for i in range(m, m + n * (width + 1), width + 1))) == 1): 152 | return True, player 153 | 154 | if (w in range(n - 1, width) and h in range(height - n + 1) and 155 | len(set(states.get(i, -1) for i in range(m, m + n * (width - 1), width - 1))) == 1): 156 | return True, player 157 | 158 | return False, -1 159 | 160 | def game_end(self): 161 | """Check whether the game is ended or not""" 162 | win, winner = self.has_a_winner() 163 | if win: 164 | return True, winner 165 | elif not len(self.availables): 166 | return True, -1 167 | return False, -1 168 | 169 | def get_current_player(self): 170 | return self.current_player 171 | 172 | 173 | class Game(object): 174 | """game 
server""" 175 | 176 | def __init__(self, board, **kwargs): 177 | self.board = board 178 | self._boardSize = board.width * board.height 179 | 180 | def graphic(self, board, player1, player2): 181 | """Draw the board and show game info""" 182 | width = board.width 183 | height = board.height 184 | 185 | print("Player", player1, "with X".rjust(3)) 186 | print("Player", player2, "with O".rjust(3)) 187 | print() 188 | for x in range(width): 189 | print("{0:8}".format(x), end='') 190 | print('\r\n') 191 | for i in range(height - 1, -1, -1): 192 | print("{0:4d}".format(i), end='') 193 | for j in range(width): 194 | loc = i * width + j 195 | p = board.states.get(loc, -1) 196 | if p == player1: 197 | print('X'.center(8), end='') 198 | elif p == player2: 199 | print('O'.center(8), end='') 200 | else: 201 | print('_'.center(8), end='') 202 | print('\r\n\r\n') 203 | 204 | def start_play(self, player1, player2, start_player=0, is_shown=1): 205 | """start a game between two players""" 206 | if start_player not in (0, 1): 207 | raise Exception('start_player should be either 0 (player1 first) ' 208 | 'or 1 (player2 first)') 209 | self.board.init_board(start_player) 210 | p1, p2 = self.board.players 211 | player1.set_player_ind(p1) 212 | player2.set_player_ind(p2) 213 | players = {p1: player1, p2: player2} 214 | if is_shown: 215 | self.graphic(self.board, player1.player, player2.player) 216 | while True: 217 | current_player = self.board.get_current_player() 218 | player_in_turn = players[current_player] 219 | move = player_in_turn.get_action(self.board) 220 | self.board.do_move(move) 221 | if is_shown: 222 | self.graphic(self.board, player1.player, player2.player) 223 | end, winner = self.board.game_end() 224 | if end: 225 | if is_shown: 226 | if winner != -1: 227 | print("Game end. Winner is", players[winner]) 228 | else: 229 | print("Game end. 
Tie") 230 | return winner 231 | 232 | 233 | def start_self_play(self, player, is_shown=0, temp=1e-3, sgf_home=None, file_name=None): 234 | """ start a self-play game using a MCTS player, reuse the search tree, 235 | and store the self-play data: (state, mcts_probs, z) for training 236 | """ 237 | # 获取棋盘数据 238 | X_train = sgf_dataIter.get_data_from_files(file_name, sgf_home) 239 | data_length = len(X_train['seq_num_list']) # 对弈长度(一盘棋盘数据的长度) 240 | self.board.init_board() 241 | p1, p2 = self.board.players 242 | # print('p1: ', p1, ' p2: ', p2) 243 | states, mcts_probs, current_players = [], [], [] 244 | # while True: 245 | for num_index, move in enumerate(X_train['seq_num_list']): 246 | # move_, move_probs = player.get_action(self.board, 247 | # temp=temp, 248 | # return_prob=1) 249 | probs = [0.000001 for _ in range(self._boardSize)] 250 | probs[move] = 0.99999 251 | move_probs = np.asarray(probs) 252 | # print('move: ') 253 | # print(move) 254 | # print('move probs: ') 255 | # print(move_probs) 256 | # print(type(move_probs)) 257 | # print(move_probs.shape) 258 | # store the data 259 | # print('current_state: \n') 260 | # print(self.board.current_state()) 261 | states.append(self.board.current_state()) 262 | mcts_probs.append(move_probs) 263 | current_players.append(self.board.current_player) 264 | # perform a move 265 | try: 266 | self.board.do_move(move) 267 | except Exception as ee: 268 | _logger.error("\033[40;31m %s \033[0m: anxingle file_name: %s, move: %s" %('WARNING', file_name, move) ) 269 | warning, winner, mcts_probs = 1, None, None 270 | return warning, winner, mcts_probs 271 | if is_shown: 272 | self.graphic(self.board, p1, p2) 273 | # go_on= input('go on:') 274 | # if go_on == 1: 275 | # break 276 | # 既然使用现成的棋局文件, end判断当然也需要重新设置 277 | end, warning = 0, 1 278 | if num_index + 1 == data_length: 279 | end = 1 280 | winner = X_train['winner'] 281 | try: 282 | # 这是一个故意的“bug”,目的在于检验是否end。不用也行。 283 | _logger.error('file_name %s has some problem! seq_num_list: %s' % (file_name, X_train['seq_num_list'][num_index+1])) 284 | except Exception as e: 285 | # 倘若进入了这个“bug”, 则不用报告warning 286 | warning = 0 287 | # print('you can ignore: ', e) 288 | # end, winner = self.board.game_end() 289 | if end: 290 | # winner from the perspective of the current player of each state 291 | winners_z = np.zeros(len(current_players)) 292 | if winner != -1: 293 | winners_z[np.array(current_players) == winner] = 1.0 294 | winners_z[np.array(current_players) != winner] = -1.0 295 | # reset MCTS root node 296 | player.reset_player() 297 | if is_shown: 298 | if winner != -1: 299 | print("Game end. Winner is player:", winner) 300 | else: 301 | print("Game end. 
Tie") 302 | # go_on = input('go on:') 303 | # winner 1:2 304 | return warning, winner, zip(states, mcts_probs, winners_z) 305 | 306 | 307 | if __name__ == '__main__': 308 | model_file = 'current_policy.model' 309 | policy_value_net = PolicyValueNet(15, 15) 310 | mcts_player = MCTSPlayer(policy_value_net.policy_value_fn, 311 | c_puct=3, 312 | n_playout=2, 313 | is_selfplay=1) 314 | board = Board(width=15, height=15, n_in_row=5) 315 | game = Game(board) 316 | sgf_home = current_relative_path('./sgf_data') 317 | file_name = '1000_white_.sgf' 318 | winner, play_data = game.start_self_play(mcts_player, is_shown=1, temp=1.0, sgf_home=sgf_home, file_name=file_name) 319 | 320 | 321 | -------------------------------------------------------------------------------- /game_ai.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: Junxiao Song 4 | """ 5 | 6 | from __future__ import print_function 7 | import numpy as np 8 | from game import Board 9 | import random 10 | 11 | class Game_AI(object): 12 | """game server""" 13 | 14 | def __init__(self, board, **kwargs): 15 | self.board = board 16 | self._boardSize = board.width * board.height 17 | 18 | def graphic(self, board, player1, player2): 19 | """Draw the board and show game info""" 20 | width = board.width 21 | height = board.height 22 | 23 | print("Player", player1, "with X".rjust(3)) 24 | print("Player", player2, "with O".rjust(3)) 25 | print() 26 | for x in range(width): 27 | print("{0:8}".format(x), end='') 28 | print('\r\n') 29 | for i in range(height - 1, -1, -1): 30 | print("{0:4d}".format(i), end='') 31 | for j in range(width): 32 | loc = i * width + j 33 | p = board.states.get(loc, -1) 34 | if p == player1: 35 | print('X'.center(8), end='') 36 | elif p == player2: 37 | print('O'.center(8), end='') 38 | else: 39 | print('_'.center(8), end='') 40 | print('\r\n\r\n') 41 | 42 | def start_play(self, player1, player2, start_player=0, is_shown=1): 43 | """start a game between two players""" 44 | if start_player not in (0, 1): 45 | raise Exception('start_player should be either 0 (player1 first) ' 46 | 'or 1 (player2 first)') 47 | self.board.init_board(start_player) 48 | p1, p2 = self.board.players 49 | player1.set_player_ind(p1) 50 | player2.set_player_ind(p2) 51 | players = {p1: player1, p2: player2} 52 | if is_shown: 53 | self.graphic(self.board, player1.player, player2.player) 54 | while True: 55 | current_player = self.board.get_current_player() 56 | player_in_turn = players[current_player] 57 | move = player_in_turn.get_action(self.board) 58 | self.board.do_move(move) 59 | if is_shown: 60 | self.graphic(self.board, player1.player, player2.player) 61 | end, winner = self.board.game_end() 62 | if end: 63 | if is_shown: 64 | if winner != -1: 65 | print("Game end. Winner is", players[winner]) 66 | else: 67 | print("Game end. 
Tie") 68 | return winner 69 | 70 | def start_self_play(self, player, is_shown=0, temp=1e-3): 71 | """ start a self-play game using a MCTS player, reuse the search tree, 72 | and store the self-play data: (state, mcts_probs, z) for training 73 | """ 74 | self.board.init_board() 75 | p1, p2 = self.board.players 76 | states, mcts_probs, current_players = [], [], [] 77 | blank_move_list = [0,1,2,3,4,5,6,7,8,15,16,17,18,19,20,21,22,23,30,31,32,33,34,35,36,37,38,45,46,47,48,49,50,51,52,53,60,61,62,63,64,65,66,67,68,75,76,77,78,79,80,81,82,83,90,91,92,93,94,95,96,97,98] 78 | white_move_list = range(0, 103) 79 | if random.random() < 0.09: 80 | while True: 81 | move_blank = random.choice(blank_move_list) 82 | # move_blank = blank_move_list[random.randint(0, len()-1)] 83 | move_white = random.choice(white_move_list) 84 | if move_blank != move_white: 85 | break 86 | # store the data 87 | # 黑子走子概率 88 | probs = [0.000001 for _ in range(self._boardSize)] 89 | probs[move_blank] = 0.99999 90 | move_blank_probs = np.asarray(probs) 91 | 92 | states.append(self.board.current_state()) 93 | mcts_probs.append(move_blank_probs) 94 | current_players.append(self.board.current_player) 95 | # perform a move 96 | self.board.do_move(move_blank) 97 | if is_shown: 98 | self.graphic(self.board, p1, p2) 99 | 100 | # 白子走子概率 101 | probs_ = [0.000001 for _ in range(self._boardSize)] 102 | probs_[move_white] = 0.99999 103 | move_white_probs = np.asarray(probs_) 104 | 105 | states.append(self.board.current_state()) 106 | mcts_probs.append(move_white_probs) 107 | current_players.append(self.board.current_player) 108 | # perform a move 109 | self.board.do_move(move_white) 110 | if is_shown: 111 | self.graphic(self.board, p1, p2) 112 | 113 | while True: 114 | move, move_probs = player.get_action(self.board, 115 | temp=temp, 116 | return_prob=1) 117 | # store the data 118 | states.append(self.board.current_state()) 119 | mcts_probs.append(move_probs) 120 | current_players.append(self.board.current_player) 121 | # perform a move 122 | self.board.do_move(move) 123 | if is_shown: 124 | self.graphic(self.board, p1, p2) 125 | end, winner = self.board.game_end() 126 | if end: 127 | # winner from the perspective of the current player of each state 128 | winners_z = np.zeros(len(current_players)) 129 | if winner != -1: 130 | winners_z[np.array(current_players) == winner] = 1.0 131 | winners_z[np.array(current_players) != winner] = -1.0 132 | # reset MCTS root node 133 | player.reset_player() 134 | if is_shown: 135 | if winner != -1: 136 | print("Game end. Winner is player:", winner) 137 | else: 138 | print("Game end. 
Tie") 139 | return winner, zip(states, mcts_probs, winners_z) 140 | -------------------------------------------------------------------------------- /human_play_mxnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | human VS AI models 4 | Input your move in the format: 2,3 5 | 6 | @author: Junxiao Song 7 | """ 8 | 9 | from __future__ import print_function 10 | import pickle 11 | from game import Board, Game 12 | from mcts_pure import MCTSPlayer as MCTS_Pure 13 | from mcts_alphaZero import MCTSPlayer 14 | from policy_value_net_mxnet import PolicyValueNet # Keras 15 | 16 | 17 | class Human(object): 18 | """ 19 | human player 20 | """ 21 | 22 | def __init__(self): 23 | self.player = None 24 | 25 | def set_player_ind(self, p): 26 | self.player = p 27 | 28 | def get_action(self, board): 29 | try: 30 | location = input("Your move: ") 31 | if isinstance(location, str): # for python3 32 | location = [int(n, 10) for n in location.split(",")] 33 | move = board.location_to_move(location) 34 | except Exception as e: 35 | move = -1 36 | if move == -1 or move not in board.availables: 37 | print("invalid move") 38 | move = self.get_action(board) 39 | return move 40 | 41 | def __str__(self): 42 | return "Human {}".format(self.player) 43 | 44 | 45 | def run(): 46 | n = 5 47 | width, height = 15, 15 48 | #model_file = 'best_policy.model' 49 | model_file = './logs/current_policy.model' 50 | try: 51 | board = Board(width=width, height=height, n_in_row=n) 52 | game = Game(board) 53 | 54 | # ############### human VS AI ################### 55 | # load the trained policy_value_net in either Theano/Lasagne, PyTorch or TensorFlow 56 | 57 | # best_policy = PolicyValueNet(width, height, model_file = model_file) 58 | # mcts_player = MCTSPlayer(best_policy.policy_value_fn, c_puct=5, n_playout=400) 59 | 60 | # load the provided model (trained in Theano/Lasagne) into a MCTS player written in pure numpy 61 | try: 62 | policy_param = pickle.load(open(model_file, 'rb')) 63 | except: 64 | policy_param = pickle.load(open(model_file, 'rb'), 65 | encoding='bytes') # To support python3 66 | best_policy = PolicyValueNet(board_width=width, board_height=height, batch_size=512, model_params=policy_param) 67 | mcts_player = MCTSPlayer(best_policy.policy_value_fn, 68 | c_puct=5, 69 | n_playout=200) # set larger n_playout for better performance 70 | 71 | # uncomment the following line to play with pure MCTS (it's much weaker even with a larger n_playout) 72 | #mcts_player2 = MCTS_Pure(c_puct=5, n_playout=1000) 73 | # mcts_player2 = MCTSPlayer(best_policy.policy_value_fn, 74 | # c_puct=5, 75 | # n_playout=4000) # set larger n_playout for better performance 76 | 77 | # human player, input your move in the format: 2,3 78 | human = Human() 79 | 80 | # set start_player=0 for human first 81 | game.start_play(human, mcts_player, start_player=1, is_shown=1) 82 | except KeyboardInterrupt: 83 | print('\n\rquit') 84 | 85 | 86 | if __name__ == '__main__': 87 | run() 88 | -------------------------------------------------------------------------------- /inception-resnet-v2.py: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. 
The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 17 | 18 | """ 19 | Contains the definition of the Inception Resnet V2 architecture. 20 | As described in http://arxiv.org/abs/1602.07261. 21 | Inception-v4, Inception-ResNet and the Impact of Residual Connections 22 | on Learning 23 | Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, Alex Alemi 24 | """ 25 | import mxnet as mx 26 | 27 | 28 | def ConvFactory(data, num_filter, kernel, stride=(1, 1), pad=(0, 0), act_type="relu", mirror_attr={}, with_act=True): 29 | conv = mx.symbol.Convolution( 30 | data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad) 31 | bn = mx.symbol.BatchNorm(data=conv) 32 | if with_act: 33 | act = mx.symbol.Activation( 34 | data=bn, act_type=act_type, attr=mirror_attr) 35 | return act 36 | else: 37 | return bn 38 | 39 | 40 | def block35(net, input_num_channels, scale=1.0, with_act=True, act_type='relu', mirror_attr={}): 41 | tower_conv = ConvFactory(net, 32, (1, 1)) 42 | tower_conv1_0 = ConvFactory(net, 32, (1, 1)) 43 | tower_conv1_1 = ConvFactory(tower_conv1_0, 32, (3, 3), pad=(1, 1)) 44 | tower_conv2_0 = ConvFactory(net, 32, (1, 1)) 45 | tower_conv2_1 = ConvFactory(tower_conv2_0, 48, (3, 3), pad=(1, 1)) 46 | tower_conv2_2 = ConvFactory(tower_conv2_1, 64, (3, 3), pad=(1, 1)) 47 | tower_mixed = mx.symbol.Concat(*[tower_conv, tower_conv1_1, tower_conv2_2]) 48 | tower_out = ConvFactory( 49 | tower_mixed, input_num_channels, (1, 1), with_act=False) 50 | 51 | net += scale * tower_out 52 | if with_act: 53 | act = mx.symbol.Activation( 54 | data=net, act_type=act_type, attr=mirror_attr) 55 | return act 56 | else: 57 | return net 58 | 59 | 60 | def block17(net, input_num_channels, scale=1.0, with_act=True, act_type='relu', mirror_attr={}): 61 | tower_conv = ConvFactory(net, 192, (1, 1)) 62 | tower_conv1_0 = ConvFactory(net, 129, (1, 1)) 63 | tower_conv1_1 = ConvFactory(tower_conv1_0, 160, (1, 7), pad=(1, 2)) 64 | tower_conv1_2 = ConvFactory(tower_conv1_1, 192, (7, 1), pad=(2, 1)) 65 | tower_mixed = mx.symbol.Concat(*[tower_conv, tower_conv1_2]) 66 | tower_out = ConvFactory( 67 | tower_mixed, input_num_channels, (1, 1), with_act=False) 68 | net += scale * tower_out 69 | if with_act: 70 | act = mx.symbol.Activation( 71 | data=net, act_type=act_type, attr=mirror_attr) 72 | return act 73 | else: 74 | return net 75 | 76 | 77 | def block8(net, input_num_channels, scale=1.0, with_act=True, act_type='relu', mirror_attr={}): 78 | tower_conv = ConvFactory(net, 192, (1, 1)) 79 | tower_conv1_0 = ConvFactory(net, 192, (1, 1)) 80 | tower_conv1_1 = ConvFactory(tower_conv1_0, 224, (1, 3), pad=(0, 1)) 81 | tower_conv1_2 = ConvFactory(tower_conv1_1, 256, (3, 1), pad=(1, 0)) 82 | tower_mixed = mx.symbol.Concat(*[tower_conv, tower_conv1_2]) 83 | tower_out = ConvFactory( 84 | tower_mixed, input_num_channels, (1, 1), with_act=False) 85 | net += scale * tower_out 86 | if with_act: 87 | act = mx.symbol.Activation( 88 | data=net, act_type=act_type, attr=mirror_attr) 89 
| return act 90 | else: 91 | return net 92 | 93 | 94 | def repeat(inputs, repetitions, layer, *args, **kwargs): 95 | outputs = inputs 96 | for i in range(repetitions): 97 | outputs = layer(outputs, *args, **kwargs) 98 | return outputs 99 | 100 | 101 | def get_symbol(num_classes=1000, **kwargs): 102 | data = mx.symbol.Variable(name='data') 103 | conv1a_3_3 = ConvFactory(data=data, num_filter=32, 104 | kernel=(3, 3), stride=(2, 2)) 105 | conv2a_3_3 = ConvFactory(conv1a_3_3, 32, (3, 3)) 106 | conv2b_3_3 = ConvFactory(conv2a_3_3, 64, (3, 3), pad=(1, 1)) 107 | maxpool3a_3_3 = mx.symbol.Pooling( 108 | data=conv2b_3_3, kernel=(3, 3), stride=(2, 2), pool_type='max') 109 | conv3b_1_1 = ConvFactory(maxpool3a_3_3, 80, (1, 1)) 110 | conv4a_3_3 = ConvFactory(conv3b_1_1, 192, (3, 3)) 111 | maxpool5a_3_3 = mx.symbol.Pooling( 112 | data=conv4a_3_3, kernel=(3, 3), stride=(2, 2), pool_type='max') 113 | 114 | tower_conv = ConvFactory(maxpool5a_3_3, 96, (1, 1)) 115 | tower_conv1_0 = ConvFactory(maxpool5a_3_3, 48, (1, 1)) 116 | tower_conv1_1 = ConvFactory(tower_conv1_0, 64, (5, 5), pad=(2, 2)) 117 | 118 | tower_conv2_0 = ConvFactory(maxpool5a_3_3, 64, (1, 1)) 119 | tower_conv2_1 = ConvFactory(tower_conv2_0, 96, (3, 3), pad=(1, 1)) 120 | tower_conv2_2 = ConvFactory(tower_conv2_1, 96, (3, 3), pad=(1, 1)) 121 | 122 | tower_pool3_0 = mx.symbol.Pooling(data=maxpool5a_3_3, kernel=( 123 | 3, 3), stride=(1, 1), pad=(1, 1), pool_type='avg') 124 | tower_conv3_1 = ConvFactory(tower_pool3_0, 64, (1, 1)) 125 | tower_5b_out = mx.symbol.Concat( 126 | *[tower_conv, tower_conv1_1, tower_conv2_2, tower_conv3_1]) 127 | net = repeat(tower_5b_out, 10, block35, scale=0.17, input_num_channels=320) 128 | tower_conv = ConvFactory(net, 384, (3, 3), stride=(2, 2)) 129 | tower_conv1_0 = ConvFactory(net, 256, (1, 1)) 130 | tower_conv1_1 = ConvFactory(tower_conv1_0, 256, (3, 3), pad=(1, 1)) 131 | tower_conv1_2 = ConvFactory(tower_conv1_1, 384, (3, 3), stride=(2, 2)) 132 | tower_pool = mx.symbol.Pooling(net, kernel=( 133 | 3, 3), stride=(2, 2), pool_type='max') 134 | net = mx.symbol.Concat(*[tower_conv, tower_conv1_2, tower_pool]) 135 | net = repeat(net, 20, block17, scale=0.1, input_num_channels=1088) 136 | tower_conv = ConvFactory(net, 256, (1, 1)) 137 | tower_conv0_1 = ConvFactory(tower_conv, 384, (3, 3), stride=(2, 2)) 138 | tower_conv1 = ConvFactory(net, 256, (1, 1)) 139 | tower_conv1_1 = ConvFactory(tower_conv1, 288, (3, 3), stride=(2, 2)) 140 | tower_conv2 = ConvFactory(net, 256, (1, 1)) 141 | tower_conv2_1 = ConvFactory(tower_conv2, 288, (3, 3), pad=(1, 1)) 142 | tower_conv2_2 = ConvFactory(tower_conv2_1, 320, (3, 3), stride=(2, 2)) 143 | tower_pool = mx.symbol.Pooling(net, kernel=( 144 | 3, 3), stride=(2, 2), pool_type='max') 145 | net = mx.symbol.Concat( 146 | *[tower_conv0_1, tower_conv1_1, tower_conv2_2, tower_pool]) 147 | 148 | net = repeat(net, 9, block8, scale=0.2, input_num_channels=2080) 149 | net = block8(net, with_act=False, input_num_channels=2080) 150 | 151 | net = ConvFactory(net, 1536, (1, 1)) 152 | net = mx.symbol.Pooling(net, kernel=( 153 | 1, 1), global_pool=True, stride=(2, 2), pool_type='avg') 154 | net = mx.symbol.Flatten(net) 155 | net = mx.symbol.Dropout(data=net, p=0.2) 156 | net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes) 157 | softmax = mx.symbol.SoftmaxOutput(data=net, name='softmax') 158 | return softmax 159 | -------------------------------------------------------------------------------- /logs/.gitkeep: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/anxingle/AlphaPig/16257003c6fa33d3583b8d43919f4ec944f8b5bb/logs/.gitkeep -------------------------------------------------------------------------------- /logs/download_model.sh: -------------------------------------------------------------------------------- 1 | wget "http://p28sk9doh.bkt.clouddn.com/current_res_5.model" # for the 5-block resnet 2 | wget "http://p28sk9doh.bkt.clouddn.com/current_policy_simple_10999.model" # for the simple network (policy_value_net_mxnet_simple.py) 3 | wget "http://p28sk9doh.bkt.clouddn.com/current_policy_res_10.model" # for the 10-block resnet -------------------------------------------------------------------------------- /mcts_alphaZero.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Monte Carlo Tree Search in AlphaGo Zero style, which uses a policy-value 4 | network to guide the tree search and evaluate the leaf nodes 5 | 6 | @author: Junxiao Song 7 | """ 8 | 9 | import numpy as np 10 | import copy 11 | 12 | 13 | def softmax(x): 14 | probs = np.exp(x - np.max(x)) 15 | probs /= np.sum(probs) 16 | return probs 17 | 18 | 19 | class TreeNode(object): 20 | """A node in the MCTS tree. 21 | 22 | Each node keeps track of its own value Q, prior probability P, and 23 | its visit-count-adjusted prior score u. 24 | """ 25 | 26 | def __init__(self, parent, prior_p): 27 | self._parent = parent 28 | self._children = {} # a map from action to TreeNode 29 | self._n_visits = 0 30 | self._Q = 0 31 | self._u = 0 32 | self._P = prior_p 33 | 34 | def expand(self, action_priors): 35 | """Expand tree by creating new children. 36 | action_priors: a list of tuples of actions and their prior probability 37 | according to the policy function. 38 | """ 39 | for action, prob in action_priors: 40 | if action not in self._children: 41 | self._children[action] = TreeNode(self, prob) 42 | 43 | def select(self, c_puct): 44 | """Select action among children that gives maximum action value Q 45 | plus bonus u(P). 46 | Return: A tuple of (action, next_node) 47 | """ 48 | return max(self._children.items(), 49 | key=lambda act_node: act_node[1].get_value(c_puct)) 50 | 51 | def update(self, leaf_value): 52 | """Update node values from leaf evaluation. 53 | leaf_value: the value of subtree evaluation from the current player's 54 | perspective. 55 | """ 56 | # Count visit. 57 | self._n_visits += 1 58 | # Update Q, a running average of values for all visits. 59 | self._Q += 1.0*(leaf_value - self._Q) / self._n_visits 60 | 61 | def update_recursive(self, leaf_value): 62 | """Like a call to update(), but applied recursively for all ancestors. 63 | """ 64 | # If it is not root, this node's parent should be updated first. 65 | if self._parent: 66 | self._parent.update_recursive(-leaf_value) 67 | self.update(leaf_value) 68 | 69 | def get_value(self, c_puct): 70 | """Calculate and return the value for this node. 71 | It is a combination of leaf evaluations Q, and this node's prior 72 | adjusted for its visit count, u. 73 | c_puct: a number in (0, inf) controlling the relative impact of 74 | value Q, and prior probability P, on this node's score. 75 | Compare the Upper Confidence Bounds (UCB) selection formula 76 | $score = x_{child} + C\sqrt{\frac{\ln N_{parent}}{N_{child}}}$; the PUCT variant below scales the exploration term by the prior P instead. 77 | """ 78 | self._u = (c_puct * self._P * 79 | np.sqrt(self._parent._n_visits) / (1 + self._n_visits)) 80 | return self._Q + self._u 81 | 82 | def is_leaf(self): 83 | """Check if leaf node (i.e.
no nodes below this have been expanded).""" 84 | return self._children == {} 85 | 86 | def is_root(self): 87 | return self._parent is None 88 | 89 | 90 | class MCTS(object): 91 | """An implementation of Monte Carlo Tree Search.""" 92 | 93 | def __init__(self, policy_value_fn, c_puct=5, n_playout=10000): 94 | """ 95 | policy_value_fn: a function that takes in a board state and outputs 96 | a list of (action, probability) tuples and also a score in [-1, 1] 97 | (i.e. the expected value of the end game score from the current 98 | player's perspective) for the current player. 99 | c_puct: a number in (0, inf) that controls how quickly exploration 100 | converges to the maximum-value policy. A higher value means 101 | relying on the prior more. 102 | """ 103 | self._root = TreeNode(None, 1.0) 104 | self._policy = policy_value_fn 105 | self._c_puct = c_puct 106 | self._n_playout = n_playout 107 | 108 | def _playout(self, state): 109 | """Run a single playout from the root to the leaf, getting a value at 110 | the leaf and propagating it back through its parents. 111 | State is modified in-place, so a copy must be provided. 112 | """ 113 | node = self._root 114 | while(1): 115 | if node.is_leaf(): 116 | break 117 | # Greedily select next move. 118 | action, node = node.select(self._c_puct) 119 | state.do_move(action) 120 | 121 | # Evaluate the leaf using a network which outputs a list of 122 | # (action, probability) tuples p and also a score v in [-1, 1] 123 | # for the current player. 124 | action_probs, leaf_value = self._policy(state) 125 | # Check for end of game. 126 | end, winner = state.game_end() 127 | if not end: 128 | node.expand(action_probs) 129 | else: 130 | # for end state,return the "true" leaf_value 131 | if winner == -1: # tie 132 | leaf_value = 0.0 133 | else: 134 | leaf_value = ( 135 | 1.0 if winner == state.get_current_player() else -1.0 136 | ) 137 | 138 | # Update value and visit count of nodes in this traversal. 139 | node.update_recursive(-leaf_value) 140 | 141 | def get_move_probs(self, state, temp=1e-3): 142 | """Run all playouts sequentially and return the available actions and 143 | their corresponding probabilities. 144 | state: the current game state 145 | temp: temperature parameter in (0, 1] controls the level of exploration 146 | """ 147 | for n in range(self._n_playout): 148 | state_copy = copy.deepcopy(state) 149 | self._playout(state_copy) 150 | 151 | # calc the move probabilities based on visit counts at the root node 152 | act_visits = [(act, node._n_visits) 153 | for act, node in self._root._children.items()] 154 | acts, visits = zip(*act_visits) 155 | act_probs = softmax(1.0/temp * np.log(np.array(visits) + 1e-10)) 156 | 157 | return acts, act_probs 158 | 159 | def update_with_move(self, last_move): 160 | """Step forward in the tree, keeping everything we already know 161 | about the subtree. 
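Passing the move that was actually played promotes the matching child to the new root, so its visit counts and value estimates are reused by the next search; any other value (such as the -1 sentinel used by reset_player) discards the whole tree and starts from a fresh root.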
162 | """ 163 | if last_move in self._root._children: 164 | self._root = self._root._children[last_move] 165 | self._root._parent = None 166 | else: 167 | self._root = TreeNode(None, 1.0) 168 | 169 | def __str__(self): 170 | return "MCTS" 171 | 172 | 173 | class MCTSPlayer(object): 174 | """AI player based on MCTS""" 175 | 176 | def __init__(self, policy_value_function, 177 | c_puct=5, n_playout=2000, is_selfplay=0): 178 | self.mcts = MCTS(policy_value_function, c_puct, n_playout) 179 | self._is_selfplay = is_selfplay 180 | 181 | def set_player_ind(self, p): 182 | self.player = p 183 | 184 | def reset_player(self): 185 | self.mcts.update_with_move(-1) 186 | 187 | def get_action(self, board, temp=1e-3, return_prob=0): 188 | sensible_moves = board.availables 189 | # the pi vector returned by MCTS as in the alphaGo Zero paper 190 | move_probs = np.zeros(board.width*board.height) 191 | if len(sensible_moves) > 0: 192 | acts, probs = self.mcts.get_move_probs(board, temp) 193 | move_probs[list(acts)] = probs 194 | if self._is_selfplay: 195 | # add Dirichlet Noise for exploration (needed for 196 | # self-play training) 197 | # dirichlet_noise = K * 1.0/num_move_probs 尝试将0.3设置为0.2甚至更小 198 | move = np.random.choice( 199 | acts, 200 | p=0.75*probs + 0.25*np.random.dirichlet(0.3*np.ones(len(probs))) 201 | ) 202 | # update the root node and reuse the search tree 203 | self.mcts.update_with_move(move) 204 | else: 205 | # with the default temp=1e-3, it is almost equivalent 206 | # to choosing the move with the highest prob 207 | move = np.random.choice(acts, p=probs) 208 | # reset the root node 209 | self.mcts.update_with_move(-1) 210 | # location = board.move_to_location(move) 211 | # print("AI move: %d,%d\n" % (location[0], location[1])) 212 | 213 | if return_prob: 214 | return move, move_probs 215 | else: 216 | return move 217 | else: 218 | print("WARNING: the board is full") 219 | 220 | def __str__(self): 221 | return "MCTS {}".format(self.player) 222 | -------------------------------------------------------------------------------- /mcts_pure.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | A pure implementation of the Monte Carlo Tree Search (MCTS) 4 | 5 | @author: Junxiao Song 6 | """ 7 | 8 | import numpy as np 9 | import copy 10 | from operator import itemgetter 11 | 12 | 13 | def rollout_policy_fn(board): 14 | """a coarse, fast version of policy_fn used in the rollout phase.""" 15 | # rollout randomly 16 | action_probs = np.random.rand(len(board.availables)) 17 | return zip(board.availables, action_probs) 18 | 19 | 20 | def policy_value_fn(board): 21 | """a function that takes in a state and outputs a list of (action, probability) 22 | tuples and a score for the state""" 23 | # return uniform probabilities and 0 score for pure MCTS 24 | action_probs = np.ones(len(board.availables))/len(board.availables) 25 | return zip(board.availables, action_probs), 0 26 | 27 | 28 | class TreeNode(object): 29 | """A node in the MCTS tree. Each node keeps track of its own value Q, 30 | prior probability P, and its visit-count-adjusted prior score u. 31 | """ 32 | 33 | def __init__(self, parent, prior_p): 34 | self._parent = parent 35 | self._children = {} # a map from action to TreeNode 36 | self._n_visits = 0 37 | self._Q = 0 38 | self._u = 0 39 | self._P = prior_p 40 | 41 | def expand(self, action_priors): 42 | """Expand tree by creating new children. 
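Actions that already have a child node are skipped, so calling expand on a partially expanded node is safe.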
43 | action_priors: a list of tuples of actions and their prior probability 44 | according to the policy function. 45 | """ 46 | for action, prob in action_priors: 47 | if action not in self._children: 48 | self._children[action] = TreeNode(self, prob) 49 | 50 | def select(self, c_puct): 51 | """Select action among children that gives maximum action value Q 52 | plus bonus u(P). 53 | Return: A tuple of (action, next_node) 54 | """ 55 | return max(self._children.items(), 56 | key=lambda act_node: act_node[1].get_value(c_puct)) 57 | 58 | def update(self, leaf_value): 59 | """Update node values from leaf evaluation. 60 | leaf_value: the value of subtree evaluation from the current player's 61 | perspective. 62 | """ 63 | # Count visit. 64 | self._n_visits += 1 65 | # Update Q, a running average of values for all visits. 66 | self._Q += 1.0*(leaf_value - self._Q) / self._n_visits 67 | 68 | def update_recursive(self, leaf_value): 69 | """Like a call to update(), but applied recursively for all ancestors. 70 | """ 71 | # If it is not root, this node's parent should be updated first. 72 | if self._parent: 73 | self._parent.update_recursive(-leaf_value) 74 | self.update(leaf_value) 75 | 76 | def get_value(self, c_puct): 77 | """Calculate and return the value for this node. 78 | It is a combination of leaf evaluations Q, and this node's prior 79 | adjusted for its visit count, u. 80 | c_puct: a number in (0, inf) controlling the relative impact of 81 | value Q, and prior probability P, on this node's score. 82 | """ 83 | self._u = (c_puct * self._P * 84 | np.sqrt(self._parent._n_visits) / (1 + self._n_visits)) 85 | return self._Q + self._u 86 | 87 | def is_leaf(self): 88 | """Check if leaf node (i.e. no nodes below this have been expanded). 89 | """ 90 | return self._children == {} 91 | 92 | def is_root(self): 93 | return self._parent is None 94 | 95 | 96 | class MCTS(object): 97 | """A simple implementation of Monte Carlo Tree Search.""" 98 | 99 | def __init__(self, policy_value_fn, c_puct=5, n_playout=10000): 100 | """ 101 | policy_value_fn: a function that takes in a board state and outputs 102 | a list of (action, probability) tuples and also a score in [-1, 1] 103 | (i.e. the expected value of the end game score from the current 104 | player's perspective) for the current player. 105 | c_puct: a number in (0, inf) that controls how quickly exploration 106 | converges to the maximum-value policy. A higher value means 107 | relying on the prior more. 108 | """ 109 | self._root = TreeNode(None, 1.0) 110 | self._policy = policy_value_fn 111 | self._c_puct = c_puct 112 | self._n_playout = n_playout 113 | 114 | def _playout(self, state): 115 | """Run a single playout from the root to the leaf, getting a value at 116 | the leaf and propagating it back through its parents. 117 | State is modified in-place, so a copy must be provided. 118 | """ 119 | node = self._root 120 | while(1): 121 | if node.is_leaf(): 122 | 123 | break 124 | # Greedily select next move. 125 | action, node = node.select(self._c_puct) 126 | state.do_move(action) 127 | 128 | action_probs, _ = self._policy(state) 129 | # Check for end of game 130 | end, winner = state.game_end() 131 | if not end: 132 | node.expand(action_probs) 133 | # Evaluate the leaf node by random rollout 134 | leaf_value = self._evaluate_rollout(state) 135 | # Update value and visit count of nodes in this traversal. 
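# The value is negated on the way up: _evaluate_rollout scores the position
# for the player to move at the leaf, while each node stores values from the
# perspective of the player who moved into it, and that perspective
# alternates at every level of the tree.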
136 | node.update_recursive(-leaf_value) 137 | 138 | def _evaluate_rollout(self, state, limit=1000): 139 | """Use the rollout policy to play until the end of the game, 140 | returning +1 if the current player wins, -1 if the opponent wins, 141 | and 0 if it is a tie. 142 | """ 143 | player = state.get_current_player() 144 | for i in range(limit): 145 | end, winner = state.game_end() 146 | if end: 147 | break 148 | action_probs = rollout_policy_fn(state) 149 | max_action = max(action_probs, key=itemgetter(1))[0] 150 | state.do_move(max_action) 151 | else: 152 | # If no break from the loop, issue a warning. 153 | print("WARNING: rollout reached move limit") 154 | if winner == -1: # tie 155 | return 0 156 | else: 157 | return 1 if winner == player else -1 158 | 159 | def get_move(self, state): 160 | """Runs all playouts sequentially and returns the most visited action. 161 | state: the current game state 162 | 163 | Return: the selected action 164 | """ 165 | for n in range(self._n_playout): 166 | state_copy = copy.deepcopy(state) 167 | self._playout(state_copy) 168 | return max(self._root._children.items(), 169 | key=lambda act_node: act_node[1]._n_visits)[0] 170 | 171 | def update_with_move(self, last_move): 172 | """Step forward in the tree, keeping everything we already know 173 | about the subtree. 174 | """ 175 | if last_move in self._root._children: 176 | self._root = self._root._children[last_move] 177 | self._root._parent = None 178 | else: 179 | self._root = TreeNode(None, 1.0) 180 | 181 | def __str__(self): 182 | return "MCTS" 183 | 184 | 185 | class MCTSPlayer(object): 186 | """AI player based on MCTS""" 187 | def __init__(self, c_puct=5, n_playout=2000): 188 | self.mcts = MCTS(policy_value_fn, c_puct, n_playout) 189 | 190 | def set_player_ind(self, p): 191 | self.player = p 192 | 193 | def reset_player(self): 194 | self.mcts.update_with_move(-1) 195 | 196 | def get_action(self, board): 197 | sensible_moves = board.availables 198 | if len(sensible_moves) > 0: 199 | move = self.mcts.get_move(board) 200 | self.mcts.update_with_move(-1) 201 | return move 202 | else: 203 | print("WARNING: the board is full") 204 | 205 | def __str__(self): 206 | return "MCTS {}".format(self.player) 207 | -------------------------------------------------------------------------------- /papers/thinking-fast-and-slow-with-deep-learning-and-tree-search.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anxingle/AlphaPig/16257003c6fa33d3583b8d43919f4ec944f8b5bb/papers/thinking-fast-and-slow-with-deep-learning-and-tree-search.pdf -------------------------------------------------------------------------------- /play_vs_yixin.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import sys, time 4 | import random 5 | 6 | 7 | # ---------------------------------------------------------------------- 8 | # chessboard: 棋盘类,简单从字符串加载棋局或者导出字符串,判断输赢等 9 | # ---------------------------------------------------------------------- 10 | class chessboard(object): 11 | 12 | def __init__(self, forbidden=0): 13 | # list内list 14 | self.__board = [[0 for n in xrange(15)] for m in xrange(15)] 15 | self.__forbidden = forbidden 16 | self.__dirs = ((-1, 0), (-1, 1), (0, 1), (1, 1), (1, 0), \ 17 | (1, -1), (0, -1), (-1, -1)) 18 | self.DIRS = self.__dirs 19 | self.won = {} 20 | 21 | # 清空棋盘 22 | def reset(self): 23 | for j in xrange(15): 24 | for i in xrange(15): 25 | self.__board[i][j] = 0 26 | return 0 27 | 28 | # 索引器 29 | def __getitem__(self, row): 30 | return self.__board[row] 31 | 32 | # 将棋盘转换成字符串 33 | def __str__(self): 34 | text = ' A B C D E F G H I J K L M N O\n' 35 | mark = ('. ', 'O ', 'X ') 36 | nrow = 0 37 | for row in self.__board: 38 | line = ''.join([mark[n] for n in row]) 39 | text += chr(ord('A') + nrow) + ' ' + line 40 | nrow += 1 41 | if nrow < 15: text += '\n' 42 | return text 43 | 44 | # 转成字符串 45 | def __repr__(self): 46 | return self.__str__() 47 | 48 | def get(self, row, col): 49 | if row < 0 or row >= 15 or col < 0 or col >= 15: 50 | return 0 51 | return self.__board[row][col] 52 | 53 | def put(self, row, col, x): 54 | if row >= 0 and row < 15 and col >= 0 and col < 15: 55 | self.__board[row][col] = x 56 | return 0 57 | 58 | # 判断输赢,返回0(无输赢),1(白棋赢),2(黑棋赢) 59 | def check(self): 60 | board = self.__board 61 | dirs = ((1, -1), (1, 0), (1, 1), (0, 1)) 62 | for i in xrange(15): 63 | for j in xrange(15): 64 | if board[i][j] == 0: continue 65 | # id 是该位置的棋子(0或X): i行,j列 66 | id = board[i][j] 67 | for d in dirs: 68 | x, y = j, i 69 | count = 0 70 | for k in xrange(5): 71 | if self.get(y, x) != id: break 72 | y += d[0] 73 | x += d[1] 74 | count += 1 75 | if count == 5: 76 | self.won = {} 77 | r, c = i, j 78 | for z in xrange(5): 79 | self.won[(r, c)] = 1 80 | r += d[0] 81 | c += d[1] 82 | return id 83 | return 0 84 | 85 | # 返回数组对象 86 | def board(self): 87 | return self.__board 88 | 89 | # 导出棋局到字符串 90 | def dumps(self): 91 | import StringIO 92 | sio = StringIO.StringIO() 93 | board = self.__board 94 | for i in xrange(15): 95 | for j in xrange(15): 96 | stone = board[i][j] 97 | if stone != 0: 98 | ti = chr(ord('A') + i) 99 | tj = chr(ord('A') + j) 100 | sio.write('%d:%s%s ' % (stone, ti, tj)) 101 | return sio.getvalue() 102 | 103 | # 从字符串加载棋局 104 | def loads(self, text): 105 | self.reset() 106 | board = self.__board 107 | for item in text.strip('\r\n\t ').replace(',', ' ').split(' '): 108 | n = item.strip('\r\n\t ') 109 | if not n: continue 110 | n = n.split(':') 111 | stone = int(n[0]) 112 | i = ord(n[1][0].upper()) - ord('A') 113 | j = ord(n[1][1].upper()) - ord('A') 114 | board[i][j] = stone 115 | return 0 116 | 117 | # 设置终端颜色 118 | def console(self, color): 119 | if sys.platform[:3] == 'win': 120 | try: 121 | import ctypes 122 | except: 123 | return 0 124 | kernel32 = ctypes.windll.LoadLibrary('kernel32.dll') 125 | GetStdHandle = kernel32.GetStdHandle 126 | SetConsoleTextAttribute = kernel32.SetConsoleTextAttribute 127 | GetStdHandle.argtypes = [ctypes.c_uint32] 128 | GetStdHandle.restype = ctypes.c_size_t 129 | SetConsoleTextAttribute.argtypes = [ctypes.c_size_t, ctypes.c_uint16] 130 | SetConsoleTextAttribute.restype = ctypes.c_long 131 | handle = GetStdHandle(0xfffffff5) 132 | if color < 0: color = 7 133 | result = 0 134 
| if (color & 1): result |= 4 135 | if (color & 2): result |= 2 136 | if (color & 4): result |= 1 137 | if (color & 8): result |= 8 138 | if (color & 16): result |= 64 139 | if (color & 32): result |= 32 140 | if (color & 64): result |= 16 141 | if (color & 128): result |= 128 142 | SetConsoleTextAttribute(handle, result) 143 | else: 144 | if color >= 0: 145 | foreground = color & 7 146 | background = (color >> 4) & 7 147 | bold = color & 8 148 | sys.stdout.write(" \033[%s3%d;4%dm" % (bold and "01;" or "", foreground, background)) 149 | sys.stdout.flush() 150 | else: 151 | sys.stdout.write(" \033[0m") 152 | sys.stdout.flush() 153 | return 0 154 | 155 | # 彩色输出 156 | def show(self): 157 | print ' A B C D E F G H I J K L M N O' 158 | mark = ('. ', 'O ', 'X ') 159 | nrow = 0 160 | self.check() 161 | color1 = 10 162 | color2 = 13 163 | for row in xrange(15): 164 | print chr(ord('A') + row), 165 | for col in xrange(15): 166 | ch = self.__board[row][col] 167 | if ch == 0: 168 | self.console(-1) 169 | print '.', 170 | elif ch == 1: 171 | if (row, col) in self.won: 172 | self.console(9) 173 | else: 174 | self.console(10) 175 | print 'O', 176 | # self.console(-1) 177 | elif ch == 2: 178 | if (row, col) in self.won: 179 | self.console(9) 180 | else: 181 | self.console(13) 182 | print 'X', 183 | # self.console(-1) 184 | self.console(-1) 185 | print '' 186 | return 0 187 | 188 | 189 | # ---------------------------------------------------------------------- 190 | # evaluation: 棋盘评估类,给当前棋盘打分用 191 | # ---------------------------------------------------------------------- 192 | class evaluation(object): 193 | 194 | def __init__(self): 195 | self.POS = [] 196 | for i in xrange(15): 197 | row = [(7 - max(abs(i - 7), abs(j - 7))) for j in xrange(15)] 198 | self.POS.append(tuple(row)) 199 | self.POS = tuple(self.POS) 200 | self.STWO = 1 # 冲二 201 | self.STHREE = 2 # 冲三 202 | self.SFOUR = 3 # 冲四 203 | self.TWO = 4 # 活二 204 | self.THREE = 5 # 活三 205 | self.FOUR = 6 # 活四 206 | self.FIVE = 7 # 活五 207 | self.DFOUR = 8 # 双四 208 | self.FOURT = 9 # 四三 209 | self.DTHREE = 10 # 双三 210 | self.NOTYPE = 11 211 | self.ANALYSED = 255 # 已经分析过 212 | self.TODO = 0 # 没有分析过 213 | self.result = [0 for i in xrange(30)] # 保存当前直线分析值 214 | self.line = [0 for i in xrange(30)] # 当前直线数据 215 | self.record = [] # 全盘分析结果 [row][col][方向] 216 | for i in xrange(15): 217 | self.record.append([]) 218 | self.record[i] = [] 219 | for j in xrange(15): 220 | self.record[i].append([0, 0, 0, 0]) 221 | self.count = [] # 每种棋局的个数:count[黑棋/白棋][模式] 222 | for i in xrange(3): 223 | data = [0 for i in xrange(20)] 224 | self.count.append(data) 225 | self.reset() 226 | 227 | # 复位数据 228 | def reset(self): 229 | TODO = self.TODO 230 | count = self.count 231 | for i in xrange(15): 232 | line = self.record[i] 233 | for j in xrange(15): 234 | line[j][0] = TODO 235 | line[j][1] = TODO 236 | line[j][2] = TODO 237 | line[j][3] = TODO 238 | for i in xrange(20): 239 | count[0][i] = 0 240 | count[1][i] = 0 241 | count[2][i] = 0 242 | return 0 243 | 244 | # 四个方向(水平,垂直,左斜,右斜)分析评估棋盘,然后根据分析结果打分 245 | def evaluate(self, board, turn): 246 | score = self.__evaluate(board, turn) 247 | count = self.count 248 | if score < -9000: 249 | stone = turn == 1 and 2 or 1 250 | for i in xrange(20): 251 | if count[stone][i] > 0: 252 | score -= i 253 | elif score > 9000: 254 | stone = turn == 1 and 2 or 1 255 | for i in xrange(20): 256 | if count[turn][i] > 0: 257 | score += i 258 | return score 259 | 260 | # 四个方向(水平,垂直,左斜,右斜)分析评估棋盘,然后根据分析结果打分 261 | def __evaluate(self, board, turn): 262 | 
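# One pass over the board: each stone's four lines (horizontal, vertical and
# the two diagonals) are analysed at most once and the detected pattern
# (five / live or blocked four / three / two) is recorded per direction;
# the per-colour pattern counts are then tallied and converted into a score
# from `turn`'s perspective.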
record, count = self.record, self.count 263 | TODO, ANALYSED = self.TODO, self.ANALYSED 264 | self.reset() 265 | # 四个方向分析 266 | for i in xrange(15): 267 | boardrow = board[i] 268 | recordrow = record[i] 269 | for j in xrange(15): 270 | if boardrow[j] != 0: 271 | if recordrow[j][0] == TODO: # 水平没有分析过? 272 | self.__analysis_horizon(board, i, j) 273 | if recordrow[j][1] == TODO: # 垂直没有分析过? 274 | self.__analysis_vertical(board, i, j) 275 | if recordrow[j][2] == TODO: # 左斜没有分析过? 276 | self.__analysis_left(board, i, j) 277 | if recordrow[j][3] == TODO: # 右斜没有分析过 278 | self.__analysis_right(board, i, j) 279 | 280 | FIVE, FOUR, THREE, TWO = self.FIVE, self.FOUR, self.THREE, self.TWO 281 | SFOUR, STHREE, STWO = self.SFOUR, self.STHREE, self.STWO 282 | check = {} 283 | 284 | # 分别对白棋黑棋计算:FIVE, FOUR, THREE, TWO等出现的次数 285 | for c in (FIVE, FOUR, SFOUR, THREE, STHREE, TWO, STWO): 286 | check[c] = 1 287 | for i in xrange(15): 288 | for j in xrange(15): 289 | stone = board[i][j] 290 | if stone != 0: 291 | for k in xrange(4): 292 | ch = record[i][j][k] 293 | if ch in check: 294 | count[stone][ch] += 1 295 | 296 | # 如果有五连则马上返回分数 297 | BLACK, WHITE = 1, 2 298 | if turn == WHITE: # 当前是白棋 299 | if count[BLACK][FIVE]: 300 | return -9999 301 | if count[WHITE][FIVE]: 302 | return 9999 303 | else: # 当前是黑棋 304 | if count[WHITE][FIVE]: 305 | return -9999 306 | if count[BLACK][FIVE]: 307 | return 9999 308 | 309 | # 如果存在两个冲四,则相当于有一个活四 310 | if count[WHITE][SFOUR] >= 2: 311 | count[WHITE][FOUR] += 1 312 | if count[BLACK][SFOUR] >= 2: 313 | count[BLACK][FOUR] += 1 314 | 315 | # 具体打分 316 | wvalue, bvalue, win = 0, 0, 0 317 | if turn == WHITE: 318 | if count[WHITE][FOUR] > 0: return 9990 319 | if count[WHITE][SFOUR] > 0: return 9980 320 | if count[BLACK][FOUR] > 0: return -9970 321 | if count[BLACK][SFOUR] and count[BLACK][THREE]: 322 | return -9960 323 | if count[WHITE][THREE] and count[BLACK][SFOUR] == 0: 324 | return 9950 325 | if count[BLACK][THREE] > 1 and \ 326 | count[WHITE][SFOUR] == 0 and \ 327 | count[WHITE][THREE] == 0 and \ 328 | count[WHITE][STHREE] == 0: 329 | return -9940 330 | if count[WHITE][THREE] > 1: 331 | wvalue += 2000 332 | elif count[WHITE][THREE]: 333 | wvalue += 200 334 | if count[BLACK][THREE] > 1: 335 | bvalue += 500 336 | elif count[BLACK][THREE]: 337 | bvalue += 100 338 | if count[WHITE][STHREE]: 339 | wvalue += count[WHITE][STHREE] * 10 340 | if count[BLACK][STHREE]: 341 | bvalue += count[BLACK][STHREE] * 10 342 | if count[WHITE][TWO]: 343 | wvalue += count[WHITE][TWO] * 4 344 | if count[BLACK][TWO]: 345 | bvalue += count[BLACK][TWO] * 4 346 | if count[WHITE][STWO]: 347 | wvalue += count[WHITE][STWO] 348 | if count[BLACK][STWO]: 349 | bvalue += count[BLACK][STWO] 350 | else: 351 | if count[BLACK][FOUR] > 0: return 9990 352 | if count[BLACK][SFOUR] > 0: return 9980 353 | if count[WHITE][FOUR] > 0: return -9970 354 | if count[WHITE][SFOUR] and count[WHITE][THREE]: 355 | return -9960 356 | if count[BLACK][THREE] and count[WHITE][SFOUR] == 0: 357 | return 9950 358 | if count[WHITE][THREE] > 1 and \ 359 | count[BLACK][SFOUR] == 0 and \ 360 | count[BLACK][THREE] == 0 and \ 361 | count[BLACK][STHREE] == 0: 362 | return -9940 363 | if count[BLACK][THREE] > 1: 364 | bvalue += 2000 365 | elif count[BLACK][THREE]: 366 | bvalue += 200 367 | if count[WHITE][THREE] > 1: 368 | wvalue += 500 369 | elif count[WHITE][THREE]: 370 | wvalue += 100 371 | if count[BLACK][STHREE]: 372 | bvalue += count[BLACK][STHREE] * 10 373 | if count[WHITE][STHREE]: 374 | wvalue += count[WHITE][STHREE] * 10 375 | if 
count[BLACK][TWO]: 376 | bvalue += count[BLACK][TWO] * 4 377 | if count[WHITE][TWO]: 378 | wvalue += count[WHITE][TWO] * 4 379 | if count[BLACK][STWO]: 380 | bvalue += count[BLACK][STWO] 381 | if count[WHITE][STWO]: 382 | wvalue += count[WHITE][STWO] 383 | 384 | # 加上位置权值,棋盘最中心点权值是7,往外一格-1,最外圈是0 385 | wc, bc = 0, 0 386 | for i in xrange(15): 387 | for j in xrange(15): 388 | stone = board[i][j] 389 | if stone != 0: 390 | if stone == WHITE: 391 | wc += self.POS[i][j] 392 | else: 393 | bc += self.POS[i][j] 394 | wvalue += wc 395 | bvalue += bc 396 | 397 | if turn == WHITE: 398 | return wvalue - bvalue 399 | 400 | return bvalue - wvalue 401 | 402 | # 分析横向 403 | def __analysis_horizon(self, board, i, j): 404 | line, result, record = self.line, self.result, self.record 405 | TODO = self.TODO 406 | for x in xrange(15): 407 | line[x] = board[i][x] 408 | self.analysis_line(line, result, 15, j) 409 | for x in xrange(15): 410 | if result[x] != TODO: 411 | record[i][x][0] = result[x] 412 | return record[i][j][0] 413 | 414 | # 分析横向 415 | def __analysis_vertical(self, board, i, j): 416 | line, result, record = self.line, self.result, self.record 417 | TODO = self.TODO 418 | for x in xrange(15): 419 | line[x] = board[x][j] 420 | self.analysis_line(line, result, 15, i) 421 | for x in xrange(15): 422 | if result[x] != TODO: 423 | record[x][j][1] = result[x] 424 | return record[i][j][1] 425 | 426 | # 分析左斜 427 | def __analysis_left(self, board, i, j): 428 | line, result, record = self.line, self.result, self.record 429 | TODO = self.TODO 430 | if i < j: 431 | x, y = j - i, 0 432 | else: 433 | x, y = 0, i - j 434 | k = 0 435 | while k < 15: 436 | if x + k > 14 or y + k > 14: 437 | break 438 | line[k] = board[y + k][x + k] 439 | k += 1 440 | self.analysis_line(line, result, k, j - x) 441 | for s in xrange(k): 442 | if result[s] != TODO: 443 | record[y + s][x + s][2] = result[s] 444 | return record[i][j][2] 445 | 446 | # 分析右斜 447 | def __analysis_right(self, board, i, j): 448 | line, result, record = self.line, self.result, self.record 449 | TODO = self.TODO 450 | if 14 - i < j: 451 | x, y, realnum = j - 14 + i, 14, 14 - i 452 | else: 453 | x, y, realnum = 0, i + j, j 454 | k = 0 455 | while k < 15: 456 | if x + k > 14 or y - k < 0: 457 | break 458 | line[k] = board[y - k][x + k] 459 | k += 1 460 | self.analysis_line(line, result, k, j - x) 461 | for s in xrange(k): 462 | if result[s] != TODO: 463 | record[y - s][x + s][3] = result[s] 464 | return record[i][j][3] 465 | 466 | def test(self, board): 467 | self.reset() 468 | record = self.record 469 | TODO = self.TODO 470 | for i in xrange(15): 471 | for j in xrange(15): 472 | if board[i][j] != 0 and 1: 473 | if self.record[i][j][0] == TODO: 474 | self.__analysis_horizon(board, i, j) 475 | pass 476 | if self.record[i][j][1] == TODO: 477 | self.__analysis_vertical(board, i, j) 478 | pass 479 | if self.record[i][j][2] == TODO: 480 | self.__analysis_left(board, i, j) 481 | pass 482 | if self.record[i][j][3] == TODO: 483 | self.__analysis_right(board, i, j) 484 | pass 485 | return 0 486 | 487 | # 分析一条线:五四三二等棋型 488 | def analysis_line(self, line, record, num, pos): 489 | TODO, ANALYSED = self.TODO, self.ANALYSED 490 | THREE, STHREE = self.THREE, self.STHREE 491 | FOUR, SFOUR = self.FOUR, self.SFOUR 492 | while len(line) < 30: line.append(0xf) 493 | while len(record) < 30: record.append(TODO) 494 | for i in xrange(num, 30): 495 | line[i] = 0xf 496 | for i in xrange(num): 497 | record[i] = TODO 498 | if num < 5: 499 | for i in xrange(num): 500 | record[i] = ANALYSED 501 | 
return 0 502 | stone = line[pos] 503 | inverse = (0, 2, 1)[stone] 504 | num -= 1 505 | xl = pos 506 | xr = pos 507 | while xl > 0: # 探索左边界 508 | if line[xl - 1] != stone: break 509 | xl -= 1 510 | while xr < num: # 探索右边界 511 | if line[xr + 1] != stone: break 512 | xr += 1 513 | left_range = xl 514 | right_range = xr 515 | while left_range > 0: # 探索左边范围(非对方棋子的格子坐标) 516 | if line[left_range - 1] == inverse: break 517 | left_range -= 1 518 | while right_range < num: # 探索右边范围(非对方棋子的格子坐标) 519 | if line[right_range + 1] == inverse: break 520 | right_range += 1 521 | 522 | # 如果该直线范围小于 5,则直接返回 523 | if right_range - left_range < 4: 524 | for k in xrange(left_range, right_range + 1): 525 | record[k] = ANALYSED 526 | return 0 527 | 528 | # 设置已经分析过 529 | for k in xrange(xl, xr + 1): 530 | record[k] = ANALYSED 531 | 532 | srange = xr - xl 533 | 534 | # 如果是 5连 535 | if srange >= 4: 536 | record[pos] = self.FIVE 537 | return self.FIVE 538 | 539 | # 如果是 4连 540 | if srange == 3: 541 | leftfour = False # 是否左边是空格 542 | if xl > 0: 543 | if line[xl - 1] == 0: # 活四 544 | leftfour = True 545 | if xr < num: 546 | if line[xr + 1] == 0: 547 | if leftfour: 548 | record[pos] = self.FOUR # 活四 549 | else: 550 | record[pos] = self.SFOUR # 冲四 551 | else: 552 | if leftfour: 553 | record[pos] = self.SFOUR # 冲四 554 | else: 555 | if leftfour: 556 | record[pos] = self.SFOUR # 冲四 557 | return record[pos] 558 | 559 | # 如果是 3连 560 | if srange == 2: # 三连 561 | left3 = False # 是否左边是空格 562 | if xl > 0: 563 | if line[xl - 1] == 0: # 左边有气 564 | if xl > 1 and line[xl - 2] == stone: 565 | record[xl] = SFOUR 566 | record[xl - 2] = ANALYSED 567 | else: 568 | left3 = True 569 | elif xr == num or line[xr + 1] != 0: 570 | return 0 571 | if xr < num: 572 | if line[xr + 1] == 0: # 右边有气 573 | if xr < num - 1 and line[xr + 2] == stone: 574 | record[xr] = SFOUR # XXX-X 相当于冲四 575 | record[xr + 2] = ANALYSED 576 | elif left3: 577 | record[xr] = THREE 578 | else: 579 | record[xr] = STHREE 580 | elif record[xl] == SFOUR: 581 | return record[xl] 582 | elif left3: 583 | record[pos] = STHREE 584 | else: 585 | if record[xl] == SFOUR: 586 | return record[xl] 587 | if left3: 588 | record[pos] = STHREE 589 | return record[pos] 590 | 591 | # 如果是 2连 592 | if srange == 1: # 两连 593 | left2 = False 594 | if xl > 2: 595 | if line[xl - 1] == 0: # 左边有气 596 | if line[xl - 2] == stone: 597 | if line[xl - 3] == stone: 598 | record[xl - 3] = ANALYSED 599 | record[xl - 2] = ANALYSED 600 | record[xl] = SFOUR 601 | elif line[xl - 3] == 0: 602 | record[xl - 2] = ANALYSED 603 | record[xl] = STHREE 604 | else: 605 | left2 = True 606 | if xr < num: 607 | if line[xr + 1] == 0: # 左边有气 608 | if xr < num - 2 and line[xr + 2] == stone: 609 | if line[xr + 3] == stone: 610 | record[xr + 3] = ANALYSED 611 | record[xr + 2] = ANALYSED 612 | record[xr] = SFOUR 613 | elif line[xr + 3] == 0: 614 | record[xr + 2] = ANALYSED 615 | record[xr] = left2 and THREE or STHREE 616 | else: 617 | if record[xl] == SFOUR: 618 | return record[xl] 619 | if record[xl] == STHREE: 620 | record[xl] = THREE 621 | return record[xl] 622 | if left2: 623 | record[pos] = self.TWO 624 | else: 625 | record[pos] = self.STWO 626 | else: 627 | if record[xl] == SFOUR: 628 | return record[xl] 629 | if left2: 630 | record[pos] = self.STWO 631 | return record[pos] 632 | return 0 633 | 634 | def textrec(self, direction=0): 635 | text = [] 636 | for i in xrange(15): 637 | line = '' 638 | for j in xrange(15): 639 | line += '%x ' % (self.record[i][j][direction] & 0xf) 640 | text.append(line) 641 | return '\n'.join(text) 642 
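# ----------------------------------------------------------------------
# (补充示例,非原始代码)evaluation 用法的最小示意:构造空棋盘,摆出
# 黑方(值为 1)的一个横向三连,再从黑方视角打分;函数名 _demo_evaluation
# 为演示而虚构,不属于原工程。
# ----------------------------------------------------------------------
def _demo_evaluation():
    board = [[0 for n in xrange(15)] for i in xrange(15)]
    for col in (6, 7, 8):             # 在中心行附近放三个黑子,构成活三
        board[7][col] = 1
    ev = evaluation()
    score = ev.evaluate(board, 1)     # turn=1 表示当前轮到黑方
    return score                      # 返回值为正,说明局面对走子方有利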
643 | 
644 | # ----------------------------------------------------------------------
645 | # DFS: 博弈树搜索
646 | # ----------------------------------------------------------------------
647 | class searcher(object):
648 | 
649 |     # 初始化
650 |     def __init__(self):
651 |         self.evaluator = evaluation()
652 |         self.board = [[0 for n in xrange(15)] for i in xrange(15)]
653 |         self.gameover = 0
654 |         self.overvalue = 0
655 |         self.maxdepth = 3
656 | 
657 |     # 产生当前棋局的走法(按位置权值从高到低排序,先搜中心附近)
658 |     def genmove(self, turn):
659 |         moves = []
660 |         board = self.board
661 |         POSES = self.evaluator.POS
662 |         for i in xrange(15):
663 |             for j in xrange(15):
664 |                 if board[i][j] == 0:
665 |                     score = POSES[i][j]
666 |                     moves.append((score, i, j))
667 |         moves.sort()
668 |         moves.reverse()
669 |         return moves
670 | 
671 |     # 负极大形式的 alpha/beta 递归搜索:返回最佳分数
672 |     def __search(self, turn, depth, alpha=-0x7fffffff, beta=0x7fffffff):
673 | 
674 |         # 深度为零则评估棋盘并返回
675 |         if depth <= 0:
676 |             score = self.evaluator.evaluate(self.board, turn)
677 |             return score
678 | 
679 |         # 如果游戏结束则立马返回
680 |         score = self.evaluator.evaluate(self.board, turn)
681 |         if abs(score) >= 9999 and depth < self.maxdepth:
682 |             return score
683 | 
684 |         # 产生新的走法
685 |         moves = self.genmove(turn)
686 |         bestmove = None
687 | 
688 |         # 枚举当前所有走法
689 |         for score, row, col in moves:
690 | 
691 |             # 标记当前走法到棋盘
692 |             self.board[row][col] = turn
693 | 
694 |             # 计算下一回合该谁走
695 |             nturn = turn == 1 and 2 or 1
696 | 
697 |             # 深度优先搜索:对手的最好结果取负,即为我方评分
698 |             score = - self.__search(nturn, depth - 1, -beta, -alpha)
699 | 
700 |             # 棋盘上清除当前走法
701 |             self.board[row][col] = 0
702 | 
703 |             # 计算最好分值的走法
704 |             # alpha/beta 剪枝
705 |             if score > alpha:
706 |                 alpha = score
707 |                 bestmove = (row, col)
708 |                 if alpha >= beta:
709 |                     break
710 | 
711 |         # 如果是第一层则记录最好的走法
712 |         if depth == self.maxdepth and bestmove:
713 |             self.bestmove = bestmove
714 | 
715 |         # 返回当前最好的分数,和该分数的对应走法
716 |         return alpha
717 | 
718 |     # 具体搜索:传入当前是该谁走(turn=1/2),以及搜索深度(depth)
719 |     def search(self, turn, depth=3):
720 |         self.maxdepth = depth    # 记录最大深度,__search 以此识别第一层
721 |         self.bestmove = None
722 |         score = self.__search(turn, depth)
723 |         if abs(score) > 8000:
724 |             self.maxdepth = 1    # 已出现必胜/必败分支:改用 1 层搜索直接拿最近的胜着
725 |             score = self.__search(turn, 1)
726 |         row, col = self.bestmove
727 |         return score, row, col
728 | 
729 | 
730 | # ----------------------------------------------------------------------
731 | # psyco speedup
732 | # ----------------------------------------------------------------------
733 | def psyco_speedup():
734 |     try:
735 |         import psyco
736 |         psyco.bind(chessboard)
737 |         psyco.bind(evaluation)
738 |     except:
739 |         pass
740 |     return 0
741 | 
742 | 
743 | psyco_speedup()
744 | 
745 | 
746 | # ----------------------------------------------------------------------
747 | # main game
748 | # ----------------------------------------------------------------------
749 | def gamemain():
750 |     b = chessboard()
751 |     # 黑方AI
752 |     s_blank = searcher()
753 |     s_blank.board = b.board()
754 |     s = searcher()
755 |     s.board = b.board()
756 | 
757 |     opening = [
758 |         '1:HH 2:GI',
759 |         '2:IG 2:GI 1:HH',
760 |         '1:IH 2:GI',
761 |         '1:HG 2:HI',
762 |         '2:HG 2:HI 1:HH',
763 |         '1:HH 2:IH 2:GI',
764 |         '1:HH 2:IH 2:HI',
765 |         '1:HH 2:IH 2:HJ',
766 |         '1:HG 2:HH 2:HI',
767 |         '1:GH 2:HH 2:HI',
768 |     ]
769 | 
770 |     openid = random.randint(0, len(opening) - 1)
771 |     b.loads(opening[openid])
772 |     turn = 2
773 |     history = []
774 |     undo = False
775 | 
776 |     # 设置难度
777 |     DEPTH = 1
778 |     DEPTH_BLACK = 5
779 | 
780 |     while 1:
781 |         print ''
782 |         while 1:
783 |             print '<ROUND %d>' % (len(history) + 1)
784 |             b.show()
785 |             print 'Robot move (u:undo, q:quit): \n',
786 |             # 黑方AI自动下
787 |             # text = raw_input().strip('\r\n\t ')
788 |             score_b, tr, tc = s.search(1, DEPTH_BLACK)
789 |             cord_b = '%s%s' % (chr(ord('A') + tr), chr(ord('A') + tc))
790 |             print 'Robot move to %s (%d) \n' % (cord_b, score_b)
791 |             # if len(text) == 2:
792 |             #     tr = ord(text[0].upper()) - ord('A')
793 |             #     tc = ord(text[1].upper()) - ord('A')
794 |             if tr >= 0 and tc >= 0 and tr < 15 and tc < 15:
795 |                 if b[tr][tc] == 0:
796 |                     row, col = tr, tc
797 |                     break
798 |                 else:
799 |                     print 'can not move there'
800 |             else:
801 |                 print 'bad position'
802 | 
803 |         if undo == True:
804 |             undo = False
805 |             if len(history) == 0:
806 |                 print 'no history to undo'
807 |             else:
808 |                 print 'rollback from history ...'
809 |                 move = history.pop()
810 |                 b.loads(move)
811 |         else:
812 |             history.append(b.dumps())
813 |             b[row][col] = 1
814 |             b.show()
815 |             time.sleep(0.2)
816 | 
817 |         if b.check() == 1:
818 |             b.show()
819 |             print b.dumps()
820 |             print ''
821 |             print 'YOU WIN !!'
822 |             return 0
823 | 
824 |         print 'you should input ...\n'
825 |         text = raw_input().strip('\r\n\t ')
826 |         if len(text) == 2:
827 |             tr = ord(text[0].upper()) - ord('A')
828 |             tc = ord(text[1].upper()) - ord('A')
829 |             if tr >= 0 and tc >= 0 and tr < 15 and tc < 15:
830 |                 if b[tr][tc] == 0:
831 |                     row, col = tr, tc
832 |                     # break
833 |                 else:
834 |                     print 'can not move there'
835 |             else:
836 |                 print 'bad position'
837 |         # xtt = input('go on: ')
838 |         # score, row, col = s.search(2, DEPTH)
839 |         cord = '%s%s' % (chr(ord('A') + row), chr(ord('A') + col))
840 |         print 'robot move to %s ' % cord    # 白方实际采用上面人工输入的落点
841 |         # xtt = input('go on: ')
842 |         b[row][col] = 2
843 |         time.sleep(0.2)
844 | 
845 |         if b.check() == 2:
846 |             b.show()
847 |             print b.dumps()
848 |             print ''
849 |             print 'YOU LOSE.'
850 |             return 0
851 | 
852 |     return 0
853 | 
854 | 
855 | # ----------------------------------------------------------------------
856 | # testing case
857 | # ----------------------------------------------------------------------
858 | if __name__ == '__main__':
859 |     start_time = time.time()
860 |     gamemain()
861 |     print('耗时: ', time.time() - start_time)
862 | 
863 | 
864 | 
-------------------------------------------------------------------------------- /policy_value_net_mxnet.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | An implementation of the policyValueNet with MXNet
4 | (pinned as mxnet==1.6.0 in requirements.txt)
5 | 
6 | @author: Mingxu Zhang
7 | """
8 | from __future__ import print_function
9 | import sys
10 | import os
11 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
12 | # sys.path.insert(0, '/home/mingzhang/work/dmlc/python_mxnet/python')
13 | 
14 | import mxnet as mx
15 | import numpy as np
16 | import pickle
17 | 
18 | 
19 | class PolicyValueNet():
20 |     """policy-value network """
21 |     def __init__(self, board_width, board_height, batch_size=512, n_blocks=8, n_filter=128, model_params=None):
22 |         self.context = mx.gpu(0)
23 |         self.batchsize = batch_size  # must be the same as TrainPipeline's self.batch_size.
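        # (补充注释,非原始代码)输入状态是 self.channelnum(此处为 9)个
        # board_height x board_width 的特征平面,各平面的具体含义由 game.py
        # 中 Board.current_state() 决定,此处未展示。下面创建的三个 Module
        # 共享同一套参数:train_batch 负责前向+反向训练,predict_batch 和
        # predict_one 只做前向预测,train_step() 里会通过 set_params 把最新
        # 权重同步给两个预测模块。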
24 | self.channelnum = 9 25 | self.board_width = board_width 26 | self.board_height = board_height 27 | self._n_blocks = n_blocks 28 | self._n_filter = n_filter 29 | self.l2_const = 1e-4 # coef of l2 penalty 30 | self.train_batch = self.create_policy_value_train(self.batchsize) 31 | self.predict_batch = self.create_policy_value_predict(self.batchsize) 32 | self.predict_one = self.create_policy_value_predict(1) 33 | self.num = 0 34 | 35 | if model_params: 36 | self.train_batch.set_params(*model_params) 37 | self.predict_batch.set_params(*model_params) 38 | self.predict_one.set_params(*model_params) 39 | pass 40 | 41 | def conv_act(self, data, num_filter=32, kernel=(3, 3), stride=(1, 1), act='relu', dobn=True, name=''): 42 | # self convolution activation 43 | assert(name!='' and name!=None) 44 | pad = (int(kernel[0]/2), int(kernel[1]/2)) 45 | w = mx.sym.Variable(name+'_weight') 46 | b = mx.sym.Variable(name+'_bias') 47 | conv1 = mx.sym.Convolution(data=data, weight=w, bias=b, num_filter=num_filter, kernel=kernel, pad=pad, name=name) 48 | act1 = conv1 49 | if dobn: 50 | gamma = mx.sym.Variable(name+'_gamma') 51 | beta = mx.sym.Variable(name+'_beta') 52 | mean = mx.sym.Variable(name+'_mean') 53 | var = mx.sym.Variable(name+'_var') 54 | bn = mx.sym.BatchNorm(data=conv1, gamma=gamma, beta=beta, moving_mean=mean, moving_var=var, name=name+'_bn') 55 | act1 = bn 56 | if act is not None and act!='': 57 | #print('....', act) 58 | act1 = mx.sym.Activation(data=act1, act_type=act, name=name+'_act') 59 | 60 | return act1 61 | 62 | def fc_self(self, data, num_hidden, name=''): 63 | assert(name!='' and name!=None) 64 | w = mx.sym.Variable(name+'_weight') 65 | b = mx.sym.Variable(name+'_bias') 66 | fc_1 = mx.sym.FullyConnected(data, weight=w, bias=b, num_hidden=num_hidden, name=name) 67 | 68 | return fc_1 69 | 70 | def create_backbone_resnet(self, input_states): 71 | """ 8层残差网络 """ 72 | 73 | con_net = self.conv_act(input_states, 128, (3, 3), name='res_conv1') 74 | for i in range(1, self._n_blocks+1): 75 | # 残差结构定义 76 | pre_identity = con_net # 保存残差之前部分 77 | con_net = mx.sym.Convolution(con_net, name='convA'+str(i), kernel=(3, 3), pad=(1, 1), num_filter=self._n_filter) 78 | con_net = mx.sym.BatchNorm(con_net, name='bnA'+str(i), fix_gamma=False) 79 | con_net = mx.sym.Activation(con_net, name='actA'+str(i), act_type='relu') 80 | con_net = mx.sym.Convolution(con_net, name='convB'+str(i), kernel=(3, 3), pad=(1, 1), num_filter=self._n_filter) 81 | con_net = mx.sym.BatchNorm(con_net, name='bnB'+str(i), fix_gamma=False) 82 | con_net = con_net + pre_identity # 加上之前的输出,即为残差结构 83 | con_net = mx.sym.Activation(con_net, name='actB'+str(i), act_type='relu') 84 | 85 | # action output 86 | conv3_act = self.conv_act(con_net, 4, (1, 1), name='conv3_1_1') 87 | flatten_1 = mx.sym.Flatten(conv3_act) 88 | flatten_1 = mx.sym.Dropout(flatten_1, p=0.5) 89 | fc_3_1_1 = self.fc_self(flatten_1, self.board_height*self.board_width, name='fc_3_1_1') 90 | action_1 = mx.sym.SoftmaxActivation(fc_3_1_1, name='Act_SILER') 91 | 92 | # value output 93 | conv3_2_1 = self.conv_act(con_net, 2, (1, 1), name='conv3_2_1') 94 | flatten_2 = mx.sym.Flatten(conv3_2_1) 95 | flatten_2 = mx.sym.Dropout(flatten_2, p=0.5) 96 | fc_3_2_1 = self.fc_self(flatten_2, 1, name='fc_3_2_1') 97 | evaluation = mx.sym.Activation(fc_3_2_1, act_type='tanh') 98 | 99 | # mx.viz.plot_network(action_1).view() 100 | # mx.viz.plot_network(evaluation).view() 101 | 102 | return action_1, evaluation 103 | 104 | 105 | def create_backbone(self, input_states): 106 | """ 原始策略价值网络 """ 
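        # (补充注释,非原始代码)该骨干为 6 层普通卷积(64-64-128-128-256-256),
        # 之后分出 AlphaZero 式的两个输出头:
        #   策略头:1x1 卷积降到 4 个平面 -> Flatten -> Dropout -> 全连接到
        #           board_height*board_width -> softmax,得到每个落点的概率;
        #   价值头:1x1 卷积降到 2 个平面 -> Flatten -> Dropout -> 全连接到 1
        #           -> tanh,得到 [-1, 1] 之间的局面评估。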
107 | 108 | conv1 = self.conv_act(input_states, 64, (3, 3), name='conv1') 109 | conv2 = self.conv_act(conv1, 64, (3, 3), name='conv2') 110 | conv3 = self.conv_act(conv2, 128, (3, 3), name='conv3') 111 | conv4 = self.conv_act(conv3, 128, (3, 3), name='conv4') 112 | conv5 = self.conv_act(conv4, 256, (3, 3), name='conv5') 113 | final = self.conv_act(conv5, 256, (3, 3), name='conv_final') 114 | 115 | # action policy layers 116 | conv3_1_1 = self.conv_act(final, 4, (1, 1), name='conv3_1_1') 117 | flatten_1 = mx.sym.Flatten(conv3_1_1) 118 | flatten_1 = mx.sym.Dropout(flatten_1, p=0.5) 119 | fc_3_1_1 = self.fc_self(flatten_1, self.board_height*self.board_width, name='fc_3_1_1') 120 | action_1 = mx.sym.SoftmaxActivation(fc_3_1_1, name='Act_SILER') 121 | # arg_shapes, out_shapes, aux_shapes = action_1.infer_shape() 122 | # print('arg_shape: ', arg_shapes) 123 | # print('out_shape: ', out_shapes) 124 | # print('aux_shape: ', aux_shapes) 125 | # arg_shape: [(512L, 9L, 15L, 15L), (64L, 9L, 3L, 3L), (64L,), (64L,), (64L,), (64L, 64L, 3L, 3L), (64L,), (64L,), (64L,), (128L, 64L, 3L, 3L), (128L,), (128L,), (128L,), (128L, 128L, 3L, 3L), (128L,), (128L,), (128L,), (256L, 128L, 3L, 3L), (256L,), (256L,), (256L,), (256L, 256L, 3L, 3L), (256L,), (256L,), (256L,), (4L, 256L, 1L, 1L), (4L,), (4L,), (4L,), (225L, 900L), (225L,)] 126 | # out_shape: [(512L, 225L)] 127 | # aux_shape: [(64L,), (64L,), (64L,), (64L,), (128L,), (128L,), (128L,), (128L,), (256L,), (256L,), (256L,), (256L,), (4L,), (4L,)] 128 | 129 | # state value layers 130 | conv3_2_1 = self.conv_act(final, 2, (1, 1), name='conv3_2_1') 131 | flatten_2 = mx.sym.Flatten(conv3_2_1) 132 | flatten_2 = mx.sym.Dropout(flatten_2, p=0.5) 133 | fc_3_2_1 = self.fc_self(flatten_2, 1, name='fc_3_2_1') 134 | evaluation = mx.sym.Activation(fc_3_2_1, act_type='tanh') 135 | # args_shapes, out_shapes, aux_shapes = evaluation.infer_shape() 136 | # args_shape: [(5L, 9L, 15L, 15L), (64L, 9L, 3L, 3L), (64L,), (64L,), (64L,), (64L, 64L, 3L, 3L), (64L,), (64L,), (64L,), (128L, 64L, 3L, 3L), (128L,), (128L,), (128L,), (128L, 128L, 3L, 3L), (128L,), (128L,), (128L,), (256L, 128L, 3L, 3L), (256L,), (256L,), (256L,), (256L, 256L, 3L, 3L), (256L,), (256L,), (256L,), (2L, 256L, 1L, 1L), (2L,), (2L,), (2L,), (1L, 450L), (1L,)] 137 | # out_shapes: [(5L, 1L)] 138 | # aux_shapes: [(64L,), (64L,), (64L,), (64L,), (128L,), (128L,), (128L,), (128L,), (256L,), (256L,), (256L,), (256L,), (2L,), (2L,)] 139 | mx.viz.plot_network(action_1).view() 140 | # mx.viz.plot_network(evaluation).view() 141 | go_on = input("go on?:") 142 | if go_on == 0: exit() 143 | 144 | return action_1, evaluation 145 | 146 | 147 | 148 | def create_backbone2(self, input_states): 149 | """create the policy value network """ 150 | 151 | conv1 = self.conv_act(input_states, 32, (3, 3), name='conv1') 152 | conv2 = self.conv_act(conv1, 64, (3, 3), name='conv2') 153 | conv3 = self.conv_act(conv2, 128, (3, 3), name='conv3') 154 | conv4 = self.conv_act(conv3, 128, (3, 3), name='conv4') 155 | conv5 = self.conv_act(conv4, 128, (3, 3), name='conv5') 156 | final = self.conv_act(conv5, 128, (3, 3), name='conv_final') 157 | 158 | # action policy layers 159 | conv3_1_1 = self.conv_act(final, 1024, (1, 1), name='conv3_1_1') 160 | conv3_1_2 = self.conv_act(conv3_1_1, 1, (1, 1), act=None, dobn=False, name='conv3_1_2') 161 | flatten_1 = mx.sym.Flatten(conv3_1_2) 162 | action_1 = mx.sym.SoftmaxActivation(flatten_1) 163 | 164 | # state value layers 165 | conv3_2_1 = self.conv_act(final, 256, (1, 1), name='conv3_2_1') 166 | conv3_2_2 = 
self.conv_act(conv3_2_1, 1, (1, 1), act=None, dobn=False, name='conv3_2_2') 167 | flatten_2 = mx.sym.Flatten(conv3_2_2) 168 | mean_2 = mx.sym.mean(flatten_2, axis=1, keepdims=True) 169 | evaluation = mx.sym.Activation(mean_2, act_type='tanh') 170 | 171 | return action_1, evaluation 172 | 173 | def create_policy_value_train(self, batch_size): 174 | input_states_shape = (batch_size, self.channelnum, self.board_height, self.board_width) 175 | input_states = mx.sym.Variable(name='input_states', shape=input_states_shape) 176 | # action_1, evaluation = self.create_backbone(input_states) 177 | action_1, evaluation = self.create_backbone_resnet(input_states) 178 | 179 | mcts_probs_shape = (batch_size, self.board_height * self.board_width) 180 | mcts_probs = mx.sym.Variable(name='mcts_probs', shape=mcts_probs_shape) 181 | policy_loss = -mx.sym.sum(mx.sym.log(action_1) * mcts_probs, axis=1) 182 | policy_loss = mx.sym.mean(policy_loss) 183 | 184 | input_labels_shape = (batch_size, 1) 185 | input_labels = mx.sym.Variable(name='input_labels', shape=input_labels_shape) 186 | value_loss = mx.sym.mean(mx.sym.square(input_labels - evaluation)) 187 | 188 | loss = value_loss + policy_loss 189 | loss = mx.sym.MakeLoss(loss) 190 | 191 | entropy = mx.sym.sum(-action_1 * mx.sym.log(action_1), axis=1) 192 | entropy = mx.sym.mean(entropy) 193 | entropy = mx.sym.BlockGrad(entropy) 194 | entropy = mx.sym.MakeLoss(entropy) 195 | policy_value_loss = mx.sym.Group([loss, entropy]) 196 | policy_value_loss.save('policy_value_loss.json') 197 | 198 | pv_train = mx.mod.Module(symbol=policy_value_loss, 199 | data_names=['input_states'], 200 | label_names=['input_labels', 'mcts_probs'], 201 | context=self.context) 202 | pv_train.bind(data_shapes=[('input_states', input_states_shape)], 203 | label_shapes=[('input_labels', input_labels_shape), ('mcts_probs', mcts_probs_shape)], 204 | for_training=True) 205 | pv_train.init_params(initializer=mx.init.Xavier()) 206 | pv_train.init_optimizer(optimizer='adam', 207 | optimizer_params={'learning_rate':0.001, 208 | #'clip_gradient':0.1, 209 | #'momentum':0.9, 210 | 'wd':0.0001}) 211 | 212 | return pv_train 213 | 214 | def create_policy_value_predict(self, batch_size): 215 | input_states_shape = (batch_size, self.channelnum, self.board_height, self.board_width) 216 | input_states = mx.sym.Variable(name='input_states', shape=input_states_shape) 217 | # action_1, evaluation = self.create_backbone(input_states) 218 | action_1, evaluation = self.create_backbone_resnet(input_states) 219 | policy_value_output = mx.sym.Group([action_1, evaluation]) 220 | 221 | pv_predict = mx.mod.Module(symbol=policy_value_output, 222 | data_names=['input_states'], 223 | label_names=None, 224 | context=self.context) 225 | 226 | pv_predict.bind(data_shapes=[('input_states', input_states_shape)], for_training=False) 227 | args, auxs = self.train_batch.get_params() 228 | pv_predict.set_params(args, auxs) 229 | 230 | return pv_predict 231 | 232 | def policy_value(self, state_batch): 233 | states = np.asarray(state_batch) 234 | #print('policy_value:', states.shape) 235 | state_nd = mx.nd.array(states) 236 | self.predict_batch.forward(mx.io.DataBatch([state_nd])) 237 | acts, vals = self.predict_batch.get_outputs() 238 | acts = acts.asnumpy() 239 | vals = vals.asnumpy() 240 | #print(acts[0], vals[0]) 241 | 242 | return acts, vals 243 | 244 | def policy_value2(self, state_batch): 245 | actsall = [] 246 | valsall = [] 247 | for state in state_batch: 248 | state = state.reshape(1, self.channelnum, self.board_height, 
self.board_width)
249 |             #print(state.shape)
250 |             state_nd = mx.nd.array(state)
251 |             self.predict_one.forward(mx.io.DataBatch([state_nd]))
252 |             act, val = self.predict_one.get_outputs()
253 |             actsall.append(act[0].asnumpy())
254 |             valsall.append(val[0].asnumpy())
255 |         acts = np.asarray(actsall)
256 |         vals = np.asarray(valsall)
257 |         #print(acts.shape, vals.shape)
258 | 
259 |         return acts, vals
260 | 
261 |     def policy_value_fn(self, board):
262 |         """
263 |         input: board
264 |         output: a list of (action, probability) tuples for each available action and the score of the board state
265 |         """
266 |         legal_positions = board.availables
267 |         current_state = board.current_state().reshape(1, self.channelnum, self.board_height, self.board_width)
268 |         state_nd = mx.nd.array(current_state)
269 |         self.predict_one.forward(mx.io.DataBatch([state_nd]))
270 |         acts_probs, values = self.predict_one.get_outputs()
271 |         acts_probs = acts_probs.asnumpy()
272 |         values = values.asnumpy()
273 |         #print(acts_probs[0, :4])
274 |         legal_actprob = acts_probs[0][legal_positions]
275 |         act_probs = zip(legal_positions, legal_actprob)
276 |         # print(len(legal_positions), legal_actprob.shape, acts_probs.shape)
277 |         # if len(legal_positions)==0:
278 |         #     exit()
279 | 
280 |         return act_probs, values[0]
281 | 
282 |     def train_step(self, state_batch, mcts_probs, winner_batch, learning_rate):
283 |         # winner_batch: 1.0/-1.0
284 |         #print('hello training....')
285 |         #print(mcts_probs[0], winner_batch[0])
286 |         self.train_batch._optimizer.lr = learning_rate
287 |         state_batch = mx.nd.array(np.asarray(state_batch).reshape(-1, self.channelnum, self.board_height, self.board_width))
288 |         mcts_probs = mx.nd.array(np.asarray(mcts_probs).reshape(-1, self.board_height*self.board_width))
289 |         winner_batch = mx.nd.array(np.asarray(winner_batch).reshape(-1, 1))
290 |         self.train_batch.forward(mx.io.DataBatch([state_batch], [winner_batch, mcts_probs]))
291 |         self.train_batch.backward()
292 |         self.train_batch.update()
293 |         loss, entropy = self.train_batch.get_outputs()
294 | 
295 |         args, auxs = self.train_batch.get_params()
296 |         self.predict_batch.set_params(args, auxs)
297 |         self.predict_one.set_params(args, auxs)
298 | 
299 |         return loss.asnumpy(), entropy.asnumpy()
300 | 
301 |     def get_policy_param(self):
302 |         net_params = self.train_batch.get_params()
303 |         return net_params
304 | 
305 |     def save_model(self, model_file):
306 |         """ save model params to file """
307 |         net_params = self.get_policy_param()
308 |         print('>>>>>>>>>> saved into', model_file)
309 |         pickle.dump(net_params, open(model_file, 'wb'), protocol=2)
310 | 
-------------------------------------------------------------------------------- /policy_value_net_mxnet_simple.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | An implementation of the policyValueNet with MXNet
4 | (pinned as mxnet==1.6.0 in requirements.txt)
5 | 
6 | @author: Mingxu Zhang
7 | """
8 | from __future__ import print_function
9 | import sys
10 | import os
11 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
12 | # sys.path.insert(0, '/home/mingzhang/work/dmlc/python_mxnet/python')
13 | 
14 | import mxnet as mx
15 | import numpy as np
16 | import pickle
17 | 
18 | 
19 | class PolicyValueNet():
20 |     """policy-value network """
21 |     def __init__(self, board_width, board_height, batch_size=512, model_params=None):
22 |         self.context = mx.cpu()
23 |         self.batchsize = batch_size  # must be the same as TrainPipeline's self.batch_size.
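        # (补充注释,非原始代码)本文件是 policy_value_net_mxnet.py 的精简版:
        # 上下文固定为 mx.cpu(),骨干只保留普通卷积版 create_backbone(没有
        # 残差块),训练/预测/参数存取等接口与 GPU 版保持一致。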
24 | self.channelnum = 9 25 | self.board_width = board_width 26 | self.board_height = board_height 27 | self.l2_const = 1e-4 # coef of l2 penalty 28 | self.train_batch = self.create_policy_value_train(self.batchsize) 29 | self.predict_batch = self.create_policy_value_predict(self.batchsize) 30 | self.predict_one = self.create_policy_value_predict(1) 31 | self.num = 0 32 | 33 | if model_params: 34 | self.train_batch.set_params(*model_params) 35 | self.predict_batch.set_params(*model_params) 36 | self.predict_one.set_params(*model_params) 37 | pass 38 | 39 | def conv_act(self, data, num_filter=32, kernel=(3, 3), stride=(1, 1), act='relu', dobn=True, name=''): 40 | # self convolution activation 41 | assert(name!='' and name!=None) 42 | pad = (int(kernel[0]/2), int(kernel[1]/2)) 43 | w = mx.sym.Variable(name+'_weight') 44 | b = mx.sym.Variable(name+'_bias') 45 | conv1 = mx.sym.Convolution(data=data, weight=w, bias=b, num_filter=num_filter, kernel=kernel, pad=pad, name=name) 46 | act1 = conv1 47 | if dobn: 48 | gamma = mx.sym.Variable(name+'_gamma') 49 | beta = mx.sym.Variable(name+'_beta') 50 | mean = mx.sym.Variable(name+'_mean') 51 | var = mx.sym.Variable(name+'_var') 52 | bn = mx.sym.BatchNorm(data=conv1, gamma=gamma, beta=beta, moving_mean=mean, moving_var=var, name=name+'_bn') 53 | act1 = bn 54 | if act is not None and act!='': 55 | #print('....', act) 56 | act1 = mx.sym.Activation(data=act1, act_type=act, name=name+'_act') 57 | 58 | return act1 59 | 60 | def fc_self(self, data, num_hidden, name=''): 61 | assert(name!='' and name!=None) 62 | w = mx.sym.Variable(name+'_weight') 63 | b = mx.sym.Variable(name+'_bias') 64 | fc_1 = mx.sym.FullyConnected(data, weight=w, bias=b, num_hidden=num_hidden, name=name) 65 | 66 | return fc_1 67 | 68 | def create_backbone(self, input_states): 69 | """create the policy value network """ 70 | 71 | conv1 = self.conv_act(input_states, 64, (3, 3), name='conv1') 72 | conv2 = self.conv_act(conv1, 64, (3, 3), name='conv2') 73 | conv3 = self.conv_act(conv2, 128, (3, 3), name='conv3') 74 | conv4 = self.conv_act(conv3, 128, (3, 3), name='conv4') 75 | conv5 = self.conv_act(conv4, 256, (3, 3), name='conv5') 76 | final = self.conv_act(conv5, 256, (3, 3), name='conv_final') 77 | 78 | # action policy layers 79 | conv3_1_1 = self.conv_act(final, 4, (1, 1), name='conv3_1_1') 80 | flatten_1 = mx.sym.Flatten(conv3_1_1) 81 | flatten_1 = mx.sym.Dropout(flatten_1, p=0.5) 82 | fc_3_1_1 = self.fc_self(flatten_1, self.board_height*self.board_width, name='fc_3_1_1') 83 | action_1 = mx.sym.SoftmaxActivation(fc_3_1_1) 84 | 85 | # state value layers 86 | conv3_2_1 = self.conv_act(final, 2, (1, 1), name='conv3_2_1') 87 | flatten_2 = mx.sym.Flatten(conv3_2_1) 88 | flatten_2 = mx.sym.Dropout(flatten_2, p=0.5) 89 | fc_3_2_1 = self.fc_self(flatten_2, 1, name='fc_3_2_1') 90 | evaluation = mx.sym.Activation(fc_3_2_1, act_type='tanh') 91 | 92 | return action_1, evaluation 93 | 94 | 95 | 96 | def create_backbone2(self, input_states): 97 | """create the policy value network """ 98 | 99 | conv1 = self.conv_act(input_states, 32, (3, 3), name='conv1') 100 | conv2 = self.conv_act(conv1, 64, (3, 3), name='conv2') 101 | conv3 = self.conv_act(conv2, 128, (3, 3), name='conv3') 102 | conv4 = self.conv_act(conv3, 128, (3, 3), name='conv4') 103 | conv5 = self.conv_act(conv4, 128, (3, 3), name='conv5') 104 | final = self.conv_act(conv5, 128, (3, 3), name='conv_final') 105 | 106 | # action policy layers 107 | conv3_1_1 = self.conv_act(final, 1024, (1, 1), name='conv3_1_1') 108 | conv3_1_2 = 
self.conv_act(conv3_1_1, 1, (1, 1), act=None, dobn=False, name='conv3_1_2') 109 | flatten_1 = mx.sym.Flatten(conv3_1_2) 110 | action_1 = mx.sym.SoftmaxActivation(flatten_1) 111 | 112 | # state value layers 113 | conv3_2_1 = self.conv_act(final, 256, (1, 1), name='conv3_2_1') 114 | conv3_2_2 = self.conv_act(conv3_2_1, 1, (1, 1), act=None, dobn=False, name='conv3_2_2') 115 | flatten_2 = mx.sym.Flatten(conv3_2_2) 116 | mean_2 = mx.sym.mean(flatten_2, axis=1, keepdims=True) 117 | evaluation = mx.sym.Activation(mean_2, act_type='tanh') 118 | 119 | return action_1, evaluation 120 | 121 | def create_policy_value_train(self, batch_size): 122 | input_states_shape = (batch_size, self.channelnum, self.board_height, self.board_width) 123 | input_states = mx.sym.Variable(name='input_states', shape=input_states_shape) 124 | action_1, evaluation = self.create_backbone(input_states) 125 | 126 | mcts_probs_shape = (batch_size, self.board_height * self.board_width) 127 | mcts_probs = mx.sym.Variable(name='mcts_probs', shape=mcts_probs_shape) 128 | policy_loss = -mx.sym.sum(mx.sym.log(action_1) * mcts_probs, axis=1) 129 | policy_loss = mx.sym.mean(policy_loss) 130 | 131 | input_labels_shape = (batch_size, 1) 132 | input_labels = mx.sym.Variable(name='input_labels', shape=input_labels_shape) 133 | value_loss = mx.sym.mean(mx.sym.square(input_labels - evaluation)) 134 | 135 | loss = value_loss + policy_loss 136 | loss = mx.sym.MakeLoss(loss) 137 | 138 | entropy = mx.sym.sum(-action_1 * mx.sym.log(action_1), axis=1) 139 | entropy = mx.sym.mean(entropy) 140 | entropy = mx.sym.BlockGrad(entropy) 141 | entropy = mx.sym.MakeLoss(entropy) 142 | policy_value_loss = mx.sym.Group([loss, entropy]) 143 | policy_value_loss.save('policy_value_loss.json') 144 | 145 | pv_train = mx.mod.Module(symbol=policy_value_loss, 146 | data_names=['input_states'], 147 | label_names=['input_labels', 'mcts_probs'], 148 | context=self.context) 149 | pv_train.bind(data_shapes=[('input_states', input_states_shape)], 150 | label_shapes=[('input_labels', input_labels_shape), ('mcts_probs', mcts_probs_shape)], 151 | for_training=True) 152 | pv_train.init_params(initializer=mx.init.Xavier()) 153 | pv_train.init_optimizer(optimizer='adam', 154 | optimizer_params={'learning_rate':0.001, 155 | #'clip_gradient':0.1, 156 | #'momentum':0.9, 157 | 'wd':0.0001}) 158 | 159 | return pv_train 160 | 161 | def create_policy_value_predict(self, batch_size): 162 | input_states_shape = (batch_size, self.channelnum, self.board_height, self.board_width) 163 | input_states = mx.sym.Variable(name='input_states', shape=input_states_shape) 164 | action_1, evaluation = self.create_backbone(input_states) 165 | policy_value_output = mx.sym.Group([action_1, evaluation]) 166 | 167 | pv_predict = mx.mod.Module(symbol=policy_value_output, 168 | data_names=['input_states'], 169 | label_names=None, 170 | context=self.context) 171 | 172 | pv_predict.bind(data_shapes=[('input_states', input_states_shape)], for_training=False) 173 | args, auxs = self.train_batch.get_params() 174 | pv_predict.set_params(args, auxs) 175 | 176 | return pv_predict 177 | 178 | def policy_value(self, state_batch): 179 | states = np.asarray(state_batch) 180 | #print('policy_value:', states.shape) 181 | state_nd = mx.nd.array(states) 182 | self.predict_batch.forward(mx.io.DataBatch([state_nd])) 183 | acts, vals = self.predict_batch.get_outputs() 184 | acts = acts.asnumpy() 185 | vals = vals.asnumpy() 186 | #print(acts[0], vals[0]) 187 | 188 | return acts, vals 189 | 190 | def policy_value2(self, 
state_batch): 191 | actsall = [] 192 | valsall = [] 193 | for state in state_batch: 194 | state = state.reshape(1, self.channelnum, self.board_height, self.board_width) 195 | #print(state.shape) 196 | state_nd = mx.nd.array(state) 197 | self.predict_one.forward(mx.io.DataBatch([state_nd])) 198 | act, val = self.predict_one.get_outputs() 199 | actsall.append(act[0].asnumpy()) 200 | valsall.append(val[0].asnumpy()) 201 | acts = np.asarray(actsall) 202 | vals = np.asarray(valsall) 203 | #print(acts.shape, vals.shape) 204 | 205 | return acts, vals 206 | 207 | def policy_value_fn(self, board): 208 | """ 209 | input: board 210 | output: a list of (action, probability) tuples for each available action and the score of the board state 211 | """ 212 | legal_positions = board.availables 213 | current_state = board.current_state().reshape(1, self.channelnum, self.board_height, self.board_width) 214 | state_nd = mx.nd.array(current_state) 215 | self.predict_one.forward(mx.io.DataBatch([state_nd])) 216 | acts_probs, values = self.predict_one.get_outputs() 217 | acts_probs = acts_probs.asnumpy() 218 | values = values.asnumpy() 219 | #print(acts_probs[0, :4]) 220 | legal_actprob = acts_probs[0][legal_positions] 221 | act_probs = zip(legal_positions, legal_actprob) 222 | # print(len(legal_positions), legal_actprob.shape, acts_probs.shape) 223 | # if len(legal_positions)==0: 224 | # exit() 225 | 226 | return act_probs, values[0] 227 | 228 | def train_step(self, state_batch, mcts_probs, winner_batch, learning_rate): 229 | #print('hello training....') 230 | #print(mcts_probs[0], winner_batch[0]) 231 | self.train_batch._optimizer.lr = learning_rate 232 | state_batch = mx.nd.array(np.asarray(state_batch).reshape(-1, self.channelnum, self.board_height, self.board_width)) 233 | mcts_probs = mx.nd.array(np.asarray(mcts_probs).reshape(-1, self.board_height*self.board_width)) 234 | winner_batch = mx.nd.array(np.asarray(winner_batch).reshape(-1, 1)) 235 | self.train_batch.forward(mx.io.DataBatch([state_batch], [winner_batch, mcts_probs])) 236 | self.train_batch.backward() 237 | self.train_batch.update() 238 | loss, entropy = self.train_batch.get_outputs() 239 | 240 | args, auxs = self.train_batch.get_params() 241 | self.predict_batch.set_params(args, auxs) 242 | self.predict_one.set_params(args, auxs) 243 | 244 | return loss.asnumpy(), entropy.asnumpy() 245 | 246 | def get_policy_param(self): 247 | net_params = self.train_batch.get_params() 248 | return net_params 249 | 250 | def save_model(self, model_file): 251 | """ save model params to file """ 252 | net_params = self.get_policy_param() 253 | print('>>>>>>>>>> saved into', model_file) 254 | pickle.dump(net_params, open(model_file, 'wb'), protocol=2) 255 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | line_profiler==2.1.2 2 | requests==2.18.4 3 | numpy==1.14.2 4 | beautifulsoup4==4.6.0 5 | PyYAML==3.13 6 | psyco 7 | tornado 8 | mxnet==1.6.0 -------------------------------------------------------------------------------- /run_ai.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "start for AI.." 
&& source /workspace/dev.env/bin/activate 4 | export PATH=/usr/local/cuda-10.2/bin:$PATH 5 | export LD_LIBRARY_PATH=/usr/local/cuda-10.2/lib64:$LD_LIBRARY_PATH 6 | sleep 5 7 | export CUDA_VISIBLE_DEVICES=1 # 设置使用的显卡 8 | echo "cuda device: ", $CUDA_VISIBLE_DEVICES 9 | cd /workspace/AlphaPig/evaluate/ && echo "run AI is OK " 10 | nohup python ChessClient.py --cur_role 1 --model r10 --room_name 1 --server_url http://gobang_server:8888 & 11 | nohup python ChessClient.py --cur_role 1 --model r10 --room_name 2 --server_url http://gobang_server:8888 & 12 | nohup python ChessClient.py --cur_role 1 --model r10 --room_name 3 --server_url http://gobang_server:8888 & 13 | nohup python ChessClient.py --cur_role 1 --model r10 --room_name 4 --server_url http://gobang_server:8888 & 14 | nohup python ChessClient.py --cur_role 1 --model r10 --room_name 5 --server_url http://gobang_server:8888 & 15 | nohup python ChessClient.py --cur_role 2 --model r10 --room_name 6 --server_url http://gobang_server:8888 & 16 | nohup python ChessClient.py --cur_role 2 --model r10 --room_name 7 --server_url http://gobang_server:8888 & 17 | nohup python ChessClient.py --cur_role 2 --model r10 --room_name 8 --server_url http://gobang_server:8888 & 18 | nohup python ChessClient.py --cur_role 2 --model r10 --room_name 9 --server_url http://gobang_server:8888 & 19 | python ChessClient.py --cur_role 2 --model r10 --room_name 10 --server_url http://gobang_server:8888 20 | -------------------------------------------------------------------------------- /run_server.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "start...2" && source /workspace/dev.env/bin/activate && \ 4 | cd /workspace/AlphaPig/evaluate/ && echo "run server is OK " \ 5 | && python ChessServer.py --port 8888 -------------------------------------------------------------------------------- /sgf_data/download.sh: -------------------------------------------------------------------------------- 1 | wget "http://p324ywv2g.bkt.clouddn.com/sgf_data.zip" 2 | unzip sgf_data.zip -------------------------------------------------------------------------------- /start_train.sh: -------------------------------------------------------------------------------- 1 | nohup python -u train_mxnet.py & -------------------------------------------------------------------------------- /train_mxnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | An implementation of the training pipeline of AlphaZero for Gomoku 4 | 5 | @author: Junxiao Song 6 | anxingle 7 | """ 8 | 9 | from __future__ import print_function 10 | import pickle 11 | import random 12 | import os 13 | import time 14 | import numpy as np 15 | from optparse import OptionParser 16 | import multiprocessing as mp 17 | from collections import defaultdict, deque 18 | from game import Board, Game 19 | from game_ai import Game_AI 20 | from mcts_pure import MCTSPlayer as MCTS_Pure 21 | from mcts_alphaZero import MCTSPlayer 22 | from utils import config_loader, send_email 23 | 24 | os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" 25 | # from policy_value_net import PolicyValueNet # Theano and Lasagne 26 | # from policy_value_net_pytorch import PolicyValueNet # Pytorch 27 | # from policy_value_net_tensorflow import PolicyValueNet # Tensorflow 28 | # from policy_value_net_keras import PolicyValueNet # Keras 29 | from policy_value_net_mxnet import PolicyValueNet # Mxnet 30 | 31 | import logging 32 | import 
logging.config 33 | logging.config.dictConfig(config_loader.config_['train_logging']) 34 | _logger = logging.getLogger(__name__) 35 | 36 | current_relative_path = lambda x: os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), x)) 37 | 38 | class TrainPipeline(): 39 | def __init__(self, conf, init_model=None): 40 | # params of the board and the game 41 | self.board_width = conf['board_width'] 42 | self.board_height = conf['board_height'] 43 | self.n_in_row = conf['n_in_row'] 44 | self.board = Board(width=self.board_width, 45 | height=self.board_height, 46 | n_in_row=self.n_in_row) 47 | self.game = Game(self.board) 48 | self.game_ai = Game_AI(self.board) 49 | # training params 50 | self.learn_rate = conf['learn_rate'] 51 | self.lr_multiplier = conf['lr_multiplier'] # adaptively adjust the learning rate based on KL 52 | self.temp = conf['temp'] # the temperature param 53 | self.n_playout = conf['n_playout'] # 500 # num of simulations for each move 54 | self.c_puct = conf['c_puct'] 55 | self.buffer_size = conf['buffer_size'] 56 | self.batch_size = conf['batch_size'] # mini-batch size for training 57 | self.data_buffer = deque(maxlen=self.buffer_size) 58 | self.play_batch_size = conf['play_batch_size'] 59 | self.epochs = conf['epochs'] # num of train_steps for each update 60 | self.kl_targ = conf['kl_targ'] 61 | self.check_freq = conf['check_freq'] 62 | self.game_batch_num =conf['game_batch_num'] 63 | self.best_win_ratio = 0.0 64 | # 多线程相关 65 | self._cpu_count = mp.cpu_count() - 8 66 | # num of simulations used for the pure mcts, which is used as 67 | # the opponent to evaluate the trained policy 68 | self.pure_mcts_playout_num = conf['pure_mcts_playout_num'] 69 | # 训练集文件 70 | self._sgf_home = current_relative_path(conf['sgf_dir']) 71 | _logger.info('path: %s' % self._sgf_home) 72 | self._ai_data_home = current_relative_path(conf['ai_data_dir']) 73 | # 加载人类对弈数据 74 | self._load_training_data(self._sgf_home) 75 | # 加载保存的自对弈数据 76 | # self._load_pickle_data(self._ai_data_home) 77 | if init_model: 78 | # start training from an initial policy-value net 79 | self.policy_value_net = PolicyValueNet(self.board_width, 80 | self.board_height, 81 | self.batch_size, 82 | n_blocks=10, 83 | n_filter=128, 84 | model_params=init_model) 85 | else: 86 | # start training from a new policy-value net 87 | self.policy_value_net = PolicyValueNet(self.board_width, 88 | self.board_height, 89 | self.batch_size, 90 | n_blocks=10, 91 | n_filter=128) 92 | self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn, 93 | c_puct=self.c_puct, 94 | n_playout=self.n_playout, 95 | is_selfplay=1) 96 | 97 | def _load_training_data(self, data_dir): 98 | file_list = os.listdir(data_dir) 99 | self._training_data = [item for item in file_list if item.endswith('.sgf') and os.path.isfile(os.path.join(data_dir, item))] 100 | random.shuffle(self._training_data) 101 | self._length_train_data = len(self._training_data) 102 | 103 | """" 104 | def _load_pickle_data(self, data_dir): 105 | file_list = os.listdir(data_dir) 106 | txt_list = [item for item in file_list if item.endswith('.txt') and os.path.isfile(os.path.join(data_dir, item))] 107 | self._ai_history_data = [] 108 | for txt_f in txt_list: 109 | with open(os.path.join(data_dir, txt_f), 'rb') as f_object: 110 | d = pickle.load(f_object) 111 | self._ai_history_data += d 112 | f_object.close() 113 | """ 114 | 115 | def get_equi_data(self, play_data): 116 | """augment the data set by rotation and flipping 117 | play_data: [(state, mcts_prob, winner_z), ..., ...] 
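        (补充说明,非原始注释)每个样本会被扩充成 8 个等价样本:4 个旋转
        角度,以及各自的水平镜像。mcts_prob 先 reshape 成 (height, width)
        并 flipud,旋转后再 flipud 一次展平;这一对翻转用于在"一维落子
        编号"与"numpy 数组行序"两种棋盘表示之间保持一致。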
118 | """ 119 | extend_data = [] 120 | for state, mcts_porb, winner in play_data: 121 | for i in [1, 2, 3, 4]: 122 | # rotate counterclockwise 123 | equi_state = np.array([np.rot90(s, i) for s in state]) 124 | equi_mcts_prob = np.rot90(np.flipud( 125 | mcts_porb.reshape(self.board_height, self.board_width)), i) 126 | extend_data.append((equi_state, 127 | np.flipud(equi_mcts_prob).flatten(), 128 | winner)) 129 | # flip horizontally 130 | equi_state = np.array([np.fliplr(s) for s in equi_state]) 131 | equi_mcts_prob = np.fliplr(equi_mcts_prob) 132 | extend_data.append((equi_state, 133 | np.flipud(equi_mcts_prob).flatten(), 134 | winner)) 135 | return extend_data 136 | 137 | def collect_selfplay_data(self, n_games=1, training_index=None): 138 | """collect SGF file data for training""" 139 | data_index = training_index % self._length_train_data 140 | if data_index == 0: 141 | random.shuffle(self._training_data) 142 | for i in range(n_games): 143 | warning, winner, play_data = self.game.start_self_play(self.mcts_player, temp=self.temp, sgf_home=self._sgf_home, file_name=self._training_data[data_index]) 144 | if warning: 145 | _logger.error('\033[0;41m %s \033[0m anxingle_training_index: %s, data_index: %s, file: %s' % ('WARNING', training_index, data_index, self._training_data[data_index])) 146 | else: 147 | _logger.info('winner: %s, file: %s ' % (winner, self._training_data[data_index])) 148 | # print('play_data: ', play_data) 149 | play_data = list(play_data)[:] 150 | self.episode_len = len(play_data) 151 | # augment the data 152 | play_data = self.get_equi_data(play_data) 153 | self.data_buffer.extend(play_data) 154 | _logger.info('game_batch_index: %s, length of data_buffer: %s' % (training_index, len(self.data_buffer))) 155 | 156 | """ 157 | def collect_selfplay_data_pickle(self, n_games=1, training_index=None): 158 | # load AI self play data(auto save for every N game play nums) 159 | data_index = training_index % len(self._ai_history_data) 160 | if data_index == 0: 161 | random.shuffle(self._ai_history_data) 162 | for i in range(n_games): 163 | play_data = self._ai_history_data[data_index] 164 | self.episode_len = len(play_data) 165 | # augment the data 166 | play_data = self.get_equi_data(play_data) 167 | self.data_buffer.extend(play_data) 168 | """ 169 | 170 | def collect_selfplay_data_ai(self, n_games=1, training_index=None): 171 | """collect AI self-play data for training""" 172 | for i in range(n_games): 173 | winner, play_data = self.game_ai.start_self_play(self.mcts_player, 174 | temp=self.temp) 175 | _logger.info('traing_index: %s, winner is: %s' % (training_index, winner)) 176 | play_data = list(play_data)[:] 177 | self.episode_len = len(play_data) 178 | # augment the data 179 | play_data = self.get_equi_data(play_data) 180 | self.data_buffer.extend(play_data) 181 | 182 | # def _multiprocess_collect_selfplay_data(self, q, process_index): 183 | # """ 184 | # TODO: CUDA multiprocessing have bugs! 
185 | # winner, play_data = self.game.start_self_play(self.mcts_player, 186 | # temp=self.temp) 187 | # play_data = list(play_data)[:] 188 | # self.episode_len = len(play_data) 189 | # # augment the data 190 | # play_data = self.get_equi_data(play_data) 191 | # q.put(play_data) 192 | 193 | 194 | def policy_update(self): 195 | """update the policy-value net""" 196 | mini_batch = random.sample(self.data_buffer, self.batch_size) 197 | state_batch = [data[0] for data in mini_batch] 198 | mcts_probs_batch = [data[1] for data in mini_batch] 199 | winner_batch = [data[2] for data in mini_batch] 200 | old_probs, old_v = self.policy_value_net.policy_value(state_batch) 201 | learn_rate = self.learn_rate*self.lr_multiplier 202 | for i in range(self.epochs): 203 | loss, entropy = self.policy_value_net.train_step( 204 | state_batch, 205 | mcts_probs_batch, 206 | winner_batch, 207 | learn_rate) 208 | new_probs, new_v = self.policy_value_net.policy_value(state_batch) 209 | kl = np.mean(np.sum(old_probs * ( 210 | np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)), 211 | axis=1) 212 | ) 213 | if kl > self.kl_targ * 4: # early stopping if D_KL diverges badly 214 | _logger.info('early stopping. i:%s. epochs: %s' % (i, self.epochs)) 215 | break 216 | # adaptively adjust the learning rate 217 | if kl > self.kl_targ * 2 and self.lr_multiplier > 0.05: 218 | self.lr_multiplier /= 1.5 219 | elif kl < self.kl_targ / 2 and self.lr_multiplier < 20: 220 | self.lr_multiplier *= 1.5 221 | 222 | explained_var_old = (1 - 223 | np.var(np.array(winner_batch) - old_v.flatten()) / 224 | np.var(np.array(winner_batch))) 225 | explained_var_new = (1 - 226 | np.var(np.array(winner_batch) - new_v.flatten()) / 227 | np.var(np.array(winner_batch))) 228 | _logger.info(("kl:{:.4f}," 229 | "lr:{:.1e}," 230 | "loss:{}," 231 | "entropy:{}," 232 | "explained_var_old:{:.3f}," 233 | "explained_var_new:{:.3f}" 234 | ).format(kl, 235 | learn_rate, 236 | loss, 237 | entropy, 238 | explained_var_old, 239 | explained_var_new)) 240 | return loss, entropy 241 | 242 | def policy_evaluate(self, n_games=10): 243 | """ 244 | Evaluate the trained policy by playing against the pure MCTS player 245 | Note: this is only for monitoring the progress of training 246 | """ 247 | current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn, 248 | c_puct=self.c_puct, 249 | n_playout=self.n_playout) 250 | pure_mcts_player = MCTS_Pure(c_puct=5, 251 | n_playout=self.pure_mcts_playout_num) 252 | win_cnt = defaultdict(int) 253 | for i in range(n_games): 254 | winner = self.game.start_play(current_mcts_player, 255 | pure_mcts_player, 256 | start_player=i % 2, 257 | is_shown=0) 258 | win_cnt[winner] += 1 259 | win_ratio = 1.0*(win_cnt[1] + 0.5*win_cnt[-1]) / n_games 260 | _logger.info("num_playouts:{}, win: {}, lose: {}, tie:{}".format( 261 | self.pure_mcts_playout_num, 262 | win_cnt[1], win_cnt[2], win_cnt[-1])) 263 | return win_ratio 264 | 265 | def run(self): 266 | """run the training pipeline""" 267 | try: 268 | for i in range(self.game_batch_num): 269 | current_time = time.time() 270 | if i < 4000: 271 | self.collect_selfplay_data(1, training_index=i) 272 | else: 273 | self.collect_selfplay_data_ai(1, training_index=i) 274 | _logger.info('collection cost time: %d ' % (time.time() - current_time)) 275 | _logger.info("batch i:{}, episode_len:{}, buffer_len:{}".format( 276 | i+1, self.episode_len, len(self.data_buffer))) 277 | if len(self.data_buffer) > self.batch_size: 278 | batch_time = time.time() 279 | loss, entropy = self.policy_update() 280 | 
_logger.info('train batch cost time: %d' % (time.time() - batch_time)) 281 | # check the performance of the current model, 282 | # and save the model params 283 | if (i+1) % 50 == 0: 284 | self.policy_value_net.save_model('./logs/current_policy.model') 285 | if (i+1) % self.check_freq == 0: 286 | check_time = time.time() 287 | _logger.info("current self-play batch: {}".format(i+1)) 288 | win_ratio = self.policy_evaluate() 289 | _logger.info('evaluate the network cost time: %s ', int(time.time() - check_time)) 290 | if win_ratio > self.best_win_ratio: 291 | _logger.info("New best policy!!!!!!!!") 292 | self.best_win_ratio = win_ratio 293 | # update the best_policy 294 | self.policy_value_net.save_model('./logs/best_policy_%s.model' % i) 295 | if (self.best_win_ratio >= 0.98 and 296 | self.pure_mcts_playout_num < 8000): 297 | self.pure_mcts_playout_num += 1000 298 | self.best_win_ratio = 0.0 299 | except KeyboardInterrupt: 300 | _logger.info('\n\rquit') 301 | 302 | 303 | if __name__ == '__main__': 304 | try: 305 | start_time = time.time() 306 | # model_file = './logs/current_policy.model' 307 | model_file = None 308 | policy_param = None 309 | conf = config_loader.load_config('./conf/train_config.yaml') 310 | if model_file is not None: 311 | _logger.info('loading...%s' % model_file) 312 | try: 313 | policy_param = pickle.load(open(model_file, 'rb')) 314 | except: 315 | policy_param = pickle.load(open(model_file, 'rb'), 316 | encoding='bytes') # To support python3 317 | training_pipeline = TrainPipeline(conf, policy_param) 318 | _logger.info('enter training!') 319 | # training_pipeline.collect_selfplay_data(1, 1) 320 | training_pipeline.run() 321 | except Exception as e: 322 | _logger.exception(e) 323 | finally: 324 | cost_time = int(time.time() - start_time) 325 | format_time = "耗时: %s 小时 %s 分 %s 秒" % (cost_time/3600, (cost_time%3600)/60, (cost_time%3600)%60 ) 326 | send_email.send_mail('训练结束', format_time, 'XXX') 327 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import config_loader 3 | import sgf_dataIter -------------------------------------------------------------------------------- /utils/config_loader.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import os 3 | import yaml 4 | import socket 5 | import time 6 | import sys 7 | import logging 8 | 9 | 10 | def load_config(data_path): 11 | f = open(data_path, 'r') 12 | conf = yaml.load(f) 13 | f.close() 14 | return conf 15 | 16 | configure_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../conf/train_config.yaml') 17 | config_ = load_config(configure_path) 18 | 19 | -------------------------------------------------------------------------------- /utils/send_email.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | 3 | import smtplib 4 | from email.mime.text import MIMEText 5 | from email.utils import formataddr 6 | 7 | 8 | def send_mail(message_title, message_text, pass_wd): 9 | my_sender='987683297@qq.com' # 发件人邮箱账号 10 | my_pass = 'oktbbgwyddpybeih' # 发件人邮箱密码 11 | my_pass = pass_wd 12 | my_user='anxingle820@gmail.com' # 收件人邮箱账号,我这边发送给自己 13 | ret=True 14 | try: 15 | # msg=MIMEText('这个AI可以了','plain','utf-8') 16 | msg=MIMEText(message_text,'plain','utf-8') 17 | msg['From']=formataddr(["FromRunoob",my_sender]) # 括号里的对应发件人邮箱昵称、发件人邮箱账号 18 | 
msg['To']=formataddr(["FK",my_user])   # 括号里的对应收件人邮箱昵称、收件人邮箱账号
19 |         msg['Subject'] = message_title         # 邮件的主题,也可以说是标题
20 | 
21 |         server=smtplib.SMTP_SSL("smtp.qq.com", 465)  # 发件人邮箱的SMTP服务器,SSL 端口为 465
22 |         server.login(my_sender, my_pass)  # 括号中对应的是发件人邮箱账号、邮箱密码
23 |         server.sendmail(my_sender,[my_user,],msg.as_string())  # 括号中对应的是发件人邮箱账号、收件人邮箱账号、发送邮件
24 |         server.quit()  # 关闭连接
25 |     except Exception:  # try 中任何一步出错,都会走到这里并返回 ret=False
26 |         ret=False
27 |     return ret
28 | 
29 | 
-------------------------------------------------------------------------------- /utils/sgf_dataIter.py: --------------------------------------------------------------------------------
1 | # coding: utf-8
2 | import os
3 | import sys
4 | import time
5 | # 可视化棋谱 视情况决定是否引入
6 | # from gobang_board_utils import chessboard, evaluation, searcher, psyco_speedup
7 | # 加速函数
8 | # psyco_speedup()
9 | 
10 | LETTER_NUM = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o']
11 | BIG_LETTER_NUM = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O']
12 | NUM_LIST = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
13 | # 棋盘字母位置速查表
14 | seq_lookup = dict(zip(LETTER_NUM, NUM_LIST))
15 | num2char_lookup = dict(zip(NUM_LIST, BIG_LETTER_NUM))
16 | 
17 | # SGF棋谱所在目录(作者本机路径)
18 | sgf_home = '/Users/anxingle/Downloads/SGF_Gomoku/sgf/'
19 | 
20 | 
21 | def get_files_as_list(data_dir):
22 |     # 扫描某目录下SGF文件列表
23 |     file_list = os.listdir(data_dir)
24 |     file_list = [item for item in file_list if item.endswith('.sgf') and os.path.isfile(os.path.join(data_dir, item))]
25 |     return file_list
26 | 
27 | def content_to_order(sequence):
28 |     # 棋谱字母转整型数字
29 | 
30 |     global seq_lookup  # 棋盘字母位置速查表
31 |     seq_list = sequence.split(';')
32 |     # list:['hh', 'ii', 'hi'....]
33 |     seq_list = [item[2:4] for item in seq_list]
34 |     # list: [112, 128, ...]
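    # (补充示例,非原始代码)换算规则:字母 a..o 映射到 0..14,
    # 一维下标 = 行*15 + 列;例如 'hh' -> (7, 7) -> 7*15+7 = 112,
    # 'ii' -> (8, 8) -> 8*15+8 = 128,与上面注释中的示例一致。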
45 | 
46 | 
47 | def num2char(order_):
48 |     # flat board index -> two upper-case letters, e.g. 112 -> 'HH'
49 |     Y_axis = num2char_lookup[order_ // 15]
50 |     X_axis = num2char_lookup[order_ % 15]
51 |     return '%s%s' % (Y_axis, X_axis)
52 | 
53 | 
54 | def get_data_from_files(file_name, data_dir):
55 |     """ read one SGF game record, given its file name """
56 |     assert file_name.endswith('.sgf'), 'file: %s is not an SGF file' % file_name
57 |     with open(os.path.join(data_dir, file_name)) as f:
58 |         p = f.read()
59 |     sequence = p[p.index('SZ[15]') + 7:-4]
60 |     try:
61 |         seq_list, seq_num_list = content_to_order(sequence)
62 |     except Exception as e:
63 |         print('***' * 20)
64 |         print(e)
65 |         print(file_name)
66 |         raise  # re-raise so a malformed record is not silently returned
67 |     # the winner is encoded in the file name, right after the first '_'
68 |     winner_tag = file_name[file_name.index('_') + 1:file_name.index('_') + 6]
69 |     winner = 0  # default, so winner is always bound
70 |     if winner_tag in ('Blank', 'blank'):
71 |         winner = 1
72 |     elif winner_tag in ('White', 'white'):
73 |         winner = 2
74 |     return {'winner': winner, 'seq_list': seq_list, 'seq_num_list': seq_num_list, 'file_name': file_name}
75 | 
76 | 
77 | def read_files(data_dir):
78 |     # generator: yield the parsed content of every SGF file under data_dir
79 |     file_list = get_files_as_list(data_dir)
80 |     index = 0
81 |     while True:
82 |         if index >= len(file_list):
83 |             yield None  # signal exhaustion to the caller...
84 |             return      # ...then stop instead of reading past the list
85 |         with open(os.path.join(data_dir, file_list[index])) as f:
86 |             p = f.read()
87 |         sequence = p[p.index('SZ[15]') + 7:-4]
88 |         try:
89 |             seq_list, seq_num_list = content_to_order(sequence)
90 |         except Exception as e:
91 |             print('***' * 20)
92 |             print(e)
93 |             print(file_list[index])
94 |             raise  # re-raise so the previous game's moves are not reused
95 |         # the winner is encoded near the end of the record, e.g. RE[B+Resign]
96 |         winner = 0
97 |         if sequence[-5] in ('B', 'b'):
98 |             winner = 1
99 |         elif sequence[-5] in ('W', 'w'):
100 |             winner = 2
101 |         yield {'winner': winner, 'seq_list': seq_list, 'seq_num_list': seq_num_list, 'index': index, 'file_name': file_list[index]}
102 |         index += 1
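103 | 
104 | 
105 | def _demo_read_files():
106 |     # Usage sketch (added; assumes sgf_home points at a directory of *.sgf files):
107 |     # pull a few games off the generator and show what each record contains.
108 |     games = read_files(sgf_home)
109 |     for _ in range(3):
110 |         game = next(games)
111 |         if game is None:
112 |             break
113 |         print('%s  winner=%d  moves=%d' % (game['file_name'], game['winner'], len(game['seq_num_list'])))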
114 | 
115 | 
116 | # only used to visualize the moves of a game; needs the commented-out
117 | # gobang_board_utils import at the top of this file
118 | def gamemain(seq_list):
119 |     b = chessboard()
120 |     s = searcher()
121 |     s.board = b.board()
122 | 
123 |     opening = [
124 |         '1:HH 2:II',
125 |         # '2:IG 2:GI 1:HH',
126 |         # '1:IH 2:GI',
127 |         # '1:HG 2:HI',
128 |         # '2:HG 2:HI 1:HH',
129 |         # '1:HH 2:IH 2:GI',
130 |         # '1:HH 2:IH 2:HI',
131 |         # '1:HH 2:IH 2:HJ',
132 |         # '1:HG 2:HH 2:HI',
133 |         # '1:GH 2:HH 2:HI',
134 |     ]
135 | 
136 |     import random
137 |     # random opening position
138 |     # openid = random.randint(0, len(opening) - 1)
139 |     # the opening is black's first move plus white's reply
140 |     start_open = '1:%s 2:%s' % (num2char(seq_list[0]), num2char(seq_list[1]))
141 |     b.loads(start_open)
142 |     turn = 2
143 |     history = []
144 |     undo = False
145 | 
146 |     # search difficulty
147 |     DEPTH = 1
148 |     # the replay starts from black's second move
149 |     index = 2
150 | 
151 |     while True:
152 |         print('')
153 |         while 1:
154 |             print('<move %d>' % (len(history) + 1))
155 |             b.show()
156 |             print('your move (u: undo, q: quit): ', end='')
157 |             # read (and discard) a keypress so the replay pauses;
158 |             # the actual move is taken from the SGF sequence
159 |             input().strip('\r\n\t ')
160 |             text = '%s' % num2char(seq_list[index])
161 |             print('char: ', num2char(seq_list[index]))
162 |             print('num: ', seq_list[index])
163 |             index += 1
164 |             if len(text) == 2:
165 |                 tr = ord(text[0].upper()) - ord('A')
166 |                 tc = ord(text[1].upper()) - ord('A')
167 |                 if tr >= 0 and tc >= 0 and tr < 15 and tc < 15:
168 |                     if b[tr][tc] == 0:
169 |                         row, col = tr, tc
170 |                         break
171 |                     else:
172 |                         print('there is already a stone here!')
173 |                 else:
174 |                     print(text)
175 |                     print('not on the board!')
176 |             elif text.upper() == 'U':
177 |                 undo = True
178 |                 break
179 |             elif text.upper() == 'Q':
180 |                 print(b.dumps())
181 |                 return 0
182 | 
183 |         if undo:
184 |             undo = False
185 |             if len(history) == 0:
186 |                 print('the board is already empty, nothing to undo!')
187 |             else:
188 |                 print('undoing, rolling back to the previous position ...')
189 |                 move = history.pop()
190 |                 b.loads(move)
191 |         else:
192 |             history.append(b.dumps())
193 |             b[row][col] = 1
194 |             b.show()
195 | 
196 |             if b.check() == 1:
197 |                 print(b.dumps())
198 |                 print('')
199 |                 print('YOU WIN !!')
200 |                 return 0
201 | 
202 |             # print('AI thinking ...')
203 |             # time.sleep(0.6)
204 |             # score, row, col = s.search(2, DEPTH)
205 |             # white's ("AI") move is also replayed from the SGF sequence
206 |             text_ai = num2char(seq_list[index])
207 |             index += 1
208 |             # board letters -> numbers
209 |             row, col = ord(text_ai[0].upper()) - 65, ord(text_ai[1].upper()) - 65
210 |             cord = '%s%s' % (chr(ord('A') + row), chr(ord('A') + col))
211 |             print('AI moves to: %s' % cord)
212 |             b[row][col] = 2
213 |             input('go on: ')  # pause so the board can be inspected
214 |             b.show()
215 | 
216 |             if b.check() == 2:
217 |                 print(b.dumps())
218 |                 print('')
219 |                 print('YOU LOSE.')
220 |                 return 0
221 | 
222 |     return 0
223 | 
224 | 
225 | if __name__ == '__main__':
226 |     data = read_files(sgf_home)
227 |     x = None
228 |     y = None
229 |     for i in range(4800):
230 |         y = x
231 |         x = next(data)
232 |         if x is None:
233 |             print('whole loop: ', i)
234 |             print('index: ', y['index'])
235 |             print('file_name: ', y['file_name'])
236 |             print('\n')
237 |             break
238 | 
--------------------------------------------------------------------------------