├── .gitattributes
├── .gitignore
├── README.md
├── __init__.py
├── ai_train.dockerfile
├── conf
│   └── train_config.yaml
├── cpu_train.sh
├── docker-compose.yml
├── evaluate
│   ├── AICollection.py
│   ├── ChessBoard.py
│   ├── ChessClient.py
│   ├── ChessHelper.py
│   ├── ChessServer.py
│   ├── Hall.py
│   ├── README.md
│   ├── ai_listen_ucloud.bat
│   ├── gobang.py
│   ├── logs
│   │   ├── error.log
│   │   └── info.log
│   ├── page
│   │   ├── chessboard.html
│   │   └── login.html
│   ├── run_ai.sh
│   ├── static
│   │   ├── black.png
│   │   ├── blank.png
│   │   ├── jquery-3.3.1.min.js
│   │   ├── touming.png
│   │   └── white.png
│   ├── unittest
│   │   └── ChessBoardTest.py
│   └── yixin_ai
│       ├── engine.exe
│       ├── msvcp140.dll
│       ├── vcruntime140.dll
│       └── yixin.exe
├── game.py
├── game_ai.py
├── human_play_mxnet.py
├── inception-resnet-v2.py
├── logs
│   ├── .gitkeep
│   └── download_model.sh
├── mcts_alphaZero.py
├── mcts_pure.py
├── papers
│   └── thinking-fast-and-slow-with-deep-learning-and-tree-search.pdf
├── play_vs_yixin.py
├── policy_value_loss.json
├── policy_value_net_mxnet.py
├── policy_value_net_mxnet_simple.py
├── requirements.txt
├── run_ai.sh
├── run_server.sh
├── self_play.py
├── sgf_data
│   └── download.sh
├── start_train.sh
├── train_mxnet.py
└── utils
    ├── __init__.py
    ├── config_loader.py
    ├── send_email.py
    └── sgf_dataIter.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | pickle_ai_data/dump55.txt filter=lfs diff=lfs merge=lfs -text
2 | pickle_ai_data/dump1.txt filter=lfs diff=lfs merge=lfs -text
3 | pickle_ai_data/dump11.txt filter=lfs diff=lfs merge=lfs -text
4 | pickle_ai_data/dump2.txt filter=lfs diff=lfs merge=lfs -text
5 | pickle_ai_data/dump22.txt filter=lfs diff=lfs merge=lfs -text
6 | pickle_ai_data/dump33.txt filter=lfs diff=lfs merge=lfs -text
7 | pickle_ai_data/dump44.txt filter=lfs diff=lfs merge=lfs -text
8 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | __pycache__/
3 | *.swp
4 | *.swo
5 | *.sgf
6 | *.model
7 | policy_value_loss.json
8 | evaluate/chess_output/
9 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## AlphaPig
2 | An implementation of the AlphaZero algorithm for Gomoku (five-in-a-row). Gomoku is simpler than Go and somewhat easier to train, so it was chosen as the game for this AlphaZero reproduction.
3 |
4 | References:
5 | 1. [junxiaosong/AlphaZero_Gomoku](https://github.com/junxiaosong/AlphaZero_Gomoku)
6 | 2. [starimpact/AlphaZero_Gomoku](https://github.com/starimpact/AlphaZero_Gomoku)
7 | 3. [yonghenglh6/GobangServer](https://github.com/yonghenglh6/GobangServer)
8 |
9 |
10 | ### Quick start
11 |
12 | ```bash
13 | docker-compose up   # requires docker / nvidia-docker on the host; GPU 0 is used by default, adapt run_ai.sh as needed
14 | ```
15 |
16 | After modifying run_ai.sh (and/or run_server.sh), update docker-compose.yml so the edited scripts are mounted into the containers:
17 |
18 | ```yaml
19 | version: "2.3"
20 |
21 | services:
22 | gobang_server:
23 | image: gospelslave/alphapig:v0.1.11
24 | entrypoint: /bin/bash run_server.sh
25 | privileged: true
26 | environment:
27 | - TZ=Asia/Shanghai
28 | volumes:
29 |       - $PWD/run_server.sh:/workspace/run_server.sh # mount the modified script here
30 | ports:
31 | - 8888:8888
32 | restart: always
33 | logging:
34 | driver: json-file
35 | options:
36 | max-size: "10M"
37 | max-file: "5"
38 |
39 | gobang_ai:
40 | image: gospelslave/alphapig:v0.1.11
41 | entrypoint: /bin/bash run_ai.sh
42 | privileged: true
43 | environment:
44 | - TZ=Asia/Shanghai
45 | volumes:
46 |       - $PWD/run_ai.sh:/workspace/run_ai.sh # mount the modified script here
47 | runtime: nvidia
48 | restart: always
49 | logging:
50 | driver: json-file
51 | options:
52 | max-size: "10M"
53 | max-file: "5"
54 | ```
55 |
56 | ### Getting started
57 | + cd AlphaPig/sgf_data/
58 |
59 | + ```
60 | sh ./download.sh
61 | ```
62 |
63 | This downloads and unpacks the SGF game records; after extraction they should sit directly under sgf_data/, i.e. AlphaPig/sgf_data/*.sgf.
64 |
65 | You can also download SGF records yourself and preprocess them on your own.
66 |
67 | + Run start_train.sh in the repository root to start training.
68 |
69 | + Alternatively, edit train_mxnet.py to change the network structure and other parameters; the .yaml file under conf/ defines the training parameters and can be adjusted to your setup.
70 |
71 | + The SGF format explained (a small parsing sketch follows the block below):
72 |
73 | ```
74 | FF[4]            SGF format version; 4 is the latest
75 | SZ[15]           board size, here 15x15
76 | PW[Pig]          White player's name
77 | WR[2a]           White player's rank
78 | PB[stupid]       Black player's name
79 | BR[2c]           Black player's rank
80 | DT[2018-07-06]   date the game record was created
81 | PC[CA]           place where the game was played
82 | KM[6.5]          komi
83 | RE[B+Resign]     result: B+ means Black won, W+ means White won, Resign means the loser resigned
84 | CA[utf-8]        character encoding of the record
85 | TM[0]            time limit; 0 means unlimited
86 | OT[]             overtime (byo-yomi) rule
87 | ;B[pp];W[dd];B[pc];W[dq] ......  the sequence of moves
88 | ```
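
A minimal parsing sketch for the move sequence, illustrative only (the project's own SGF handling presumably lives in utils/sgf_dataIter.py, which is not shown here):

```python
import re

def parse_sgf_moves(sgf_text):
    """Turn ';B[hh];W[ih];B[hi]...' into (player, col, row) tuples.
    Standard SGF coordinates: first letter is the column, second the row, 'a' == 0."""
    moves = []
    for player, coord in re.findall(r';([BW])\[([a-z]{2})\]', sgf_text):
        moves.append((player, ord(coord[0]) - ord('a'), ord(coord[1]) - ord('a')))
    return moves

print(parse_sgf_moves(';B[hh];W[ih];B[hi]'))
# [('B', 7, 7), ('W', 8, 7), ('B', 7, 8)]
```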
89 |
90 |
91 |
92 | + To play against your own AI, go to the evaluate directory and run
93 |
94 | ```
95 | python ChessServer.py --port 8888
96 | ```
97 |
98 | You can then play against your own AI, or against Yixin. See the README under the evaluate directory for details.
99 |
100 | Example games:
101 |
102 |
103 |
104 |
105 |
106 | + Download some of the models I trained (still plenty of bugs):
107 |
108 | ```
109 | cd AlphaPig/logs
110 | sh ./download_model.sh
111 | ```
112 |
113 | ## Acknowledgements
114 |
115 | + The original project is [junxiaosong/AlphaZero_Gomoku](https://github.com/junxiaosong/AlphaZero_Gomoku); many thanks to its author for the numerous issues and the guidance.
116 |
117 | + Special thanks to DeepGlint (格灵深瞳) for the generous training support (courses and compute resources); without that help the training would have gone nowhere.
118 |
119 | + Thanks to [UCloud](https://www.ucloud.cn/) for the P40 AI-Train service: 1256 instance-hours of training let me validate quite a few ideas, they waived the bill in the end, and their support team put up with plenty of questions along the way. Special thanks to them.
120 |
121 |
122 |
123 |
124 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from . import game
3 | from . import mcts_pure
4 | from . import mcts_alphaZero  # needed by evaluate/ChessClient.py, which accesses AlphaPig.mcts_alphaZero
5 | from . import policy_value_net_mxnet
6 | from . import policy_value_net_mxnet_simple
6 |
--------------------------------------------------------------------------------
/ai_train.dockerfile:
--------------------------------------------------------------------------------
1 | FROM gospelslave/alphapig:v0.1.3
2 |
3 | WORKDIR /workspace
4 |
5 | ADD ./ /workspace
6 |
7 | RUN /bin/bash -c "source ~/.bashrc"
8 |
9 |
10 |
--------------------------------------------------------------------------------
/conf/train_config.yaml:
--------------------------------------------------------------------------------
1 | # SGF game-record directory
2 | sgf_dir: './sgf_data'
3 | # directory for AI self-play data
4 | ai_data_dir: './pickle_ai_data'
5 | # board settings
6 | board_width: 15
7 | board_height: 15
8 | n_in_row: 5
9 | # learning rate
10 | learn_rate: 0.0004
11 | # learning-rate multiplier, adjusted dynamically from the KL divergence
12 | lr_multiplier: 1.0
13 | temp: 1.0
14 | # number of MCTS simulations per move
15 | n_playout: 400
16 | # TODO: MCTS selection should rely more on the prior; the more accurate the value estimate, the smaller c_puct should be (favouring depth)
17 | c_puct: 5
18 | # maximum amount of training data kept (deque length)
19 | buffer_size: 2198800
20 | batch_size: 128
21 | epochs: 8
22 | play_batch_size: 1
23 | # target KL divergence
24 | kl_targ: 0.02
25 | # evaluate playing strength every check_freq iterations
26 | check_freq: 1000
27 | # playout depth of the pure-MCTS opponent used for evaluation
28 | pure_mcts_playout_num: 1000
29 | # number of training iterations
30 | game_batch_num: 240000
31 |
32 |
33 | # training log configuration
34 | train_logging:
35 | version: 1
36 | formatters:
37 | simpleFormater:
38 | format: '%(asctime)s - %(levelname)s - %(name)s[line:%(lineno)d]: %(message)s'
39 | datefmt: '%Y-%m-%d %H:%M:%S'
40 | handlers:
41 |     # console: everything at DEBUG level and above goes to stdout
42 | console:
43 | class: logging.StreamHandler
44 | formatter: simpleFormater
45 | level: DEBUG
46 | stream: ext://sys.stdout
47 |     # INFO and above (intended as rolling files: keep 20, 100 MB each)
48 | info_file_handler:
49 | class : logging.FileHandler
50 | formatter: simpleFormater
51 | level: INFO
52 | filename: ./logs/info.log
53 |     # ERROR and above
54 | error_file_handler:
55 | class : logging.FileHandler
56 | formatter: simpleFormater
57 | level: ERROR
58 | filename: ./logs/error.log
59 | root:
60 | level: DEBUG
61 | handlers: [console, info_file_handler, error_file_handler]
62 |
--------------------------------------------------------------------------------
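
Editorial note on the parameters above: learn_rate, lr_multiplier and kl_targ are typically combined as in the referenced AlphaZero_Gomoku project, shrinking the multiplier when a policy update overshoots the KL target and growing it when the update is too timid. train_mxnet.py is not shown in this dump, so the exact rule used here may differ; the sketch below only illustrates the idea.

```python
def adjust_lr_multiplier(kl, kl_targ, lr_multiplier):
    """Adaptively rescale the learning-rate multiplier from the measured KL divergence."""
    if kl > kl_targ * 2 and lr_multiplier > 0.1:      # update too aggressive -> slow down
        lr_multiplier /= 1.5
    elif kl < kl_targ / 2 and lr_multiplier < 10:     # update too timid -> speed up
        lr_multiplier *= 1.5
    return lr_multiplier

learn_rate, lr_multiplier, kl_targ = 0.0004, 1.0, 0.02   # values from train_config.yaml
kl = 0.05                                                # example measured KL after an update
lr_multiplier = adjust_lr_multiplier(kl, kl_targ, lr_multiplier)
effective_lr = learn_rate * lr_multiplier                # learning rate for the next update
```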
/cpu_train.sh:
--------------------------------------------------------------------------------
1 | docker run -it \
2 | -v /home/ubuntu/uai-sdk/examples/mxnet/train/AlphaPig:/data \
3 | -v /home/ubuntu/uai-sdk/examples/mxnet/train/AlphaPig/sgf_data:/data/data \
4 | -v /home/ubuntu/uai-sdk/examples/mxnet/train/AlphaPig/logs:/data/output \
5 | uhub.service.ucloud.cn/uaishare/cpu_uaitrain_ubuntu-14.04_python-2.7.6_mxnet-1.0.0:v1.0 \
6 | /bin/bash -c "cd /data && /usr/bin/python /data/train_mxnet.py --model-prefix=siler_Alpha --work_dir=/data --output_dir=/data/output"
7 |
8 |
9 |
10 |
11 |
12 |
13 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: "2.3"
2 |
3 | services:
4 | gobang_server:
5 | image: gospelslave/alphapig:v0.1.11
6 | entrypoint: /bin/bash run_server.sh
7 | privileged: true
8 | environment:
9 | - TZ=Asia/Shanghai
10 | ports:
11 | - 8888:8888
12 | restart: always
13 | logging:
14 | driver: json-file
15 | options:
16 | max-size: "10M"
17 | max-file: "5"
18 |
19 | gobang_ai:
20 | image: gospelslave/alphapig:v0.1.11
21 | entrypoint: /bin/bash run_ai.sh
22 | privileged: true
23 | environment:
24 | - TZ=Asia/Shanghai
25 | runtime: nvidia
26 | restart: always
27 | logging:
28 | driver: json-file
29 | options:
30 | max-size: "10M"
31 | max-file: "5"
32 |
--------------------------------------------------------------------------------
/evaluate/AICollection.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import time
3 | import ChessHelper
4 | import threading
5 | from ChessBoard import ChessBoard
6 | from Hall import GameRoom
7 | from Hall import User
8 | import random
9 | import os
10 | from ChessClient import ChessClient
11 |
12 |
13 | class GameStrategy(object):
14 | def __init__(self):
15 | import gobang
16 | self.searcher = gobang.searcher()
17 |
18 | def play_one_piece(self, user, gameboard):
19 | # user = User()
20 | # gameboard = ChessBoard()
21 | turn = user.game_role
22 | self.searcher.board = [[gameboard.get_piece(m, n) for n in xrange(gameboard.SIZE)] for m in
23 | xrange(gameboard.SIZE)]
24 | # gameboard=ChessBoard()
25 | # gameboard.move_history
26 | score, row, col = self.searcher.search(turn, 2)
27 | # print "score:", score
28 | return (row, col)
29 |
30 |
31 | class GameStrategy_yixin(object):
32 |
33 | def __init__(self):
34 | self.muid = str(random.randint(0, 1000000))
35 | self.comm_folder = 'yixin_comm/'
36 | if not os.path.exists(self.comm_folder):
37 | os.makedirs(self.comm_folder)
38 | self.chess_state_file = self.comm_folder + 'game_state_' + self.muid
39 | self.action_file = self.comm_folder + 'action_' + self.muid
40 |
41 | def play_one_piece(self, user, gameboard):
42 | # user = User()
43 | # gameboard = ChessBoard()
44 | with open(self.chess_state_file, 'w') as chess_state_file:
45 | for userrole, move_num, row, col in gameboard.move_history:
46 | chess_state_file.write('%d,%d\n' % (row, col))
47 | if os.path.exists(self.action_file):
48 | os.remove(self.action_file)
49 | os.system("yixin_ai\yixin.exe %s %s" % (self.chess_state_file, self.action_file))
50 |         row, col = random.randint(0, 14), random.randint(0, 14)  # fallback inside the 15x15 board (valid indices 0-14)
51 | with open(self.action_file) as action_file:
52 | line = action_file.readline()
53 | row, col = line.strip().split(',')
54 | row, col = int(row), int(col)
55 |
56 | return (row, col)
57 |
58 |
59 | class GameStrategy_random(object):
60 | def __init__(self):
61 | self._chess_helper_move_set = []
62 | for i in range(15):
63 | for j in range(15):
64 | self._chess_helper_move_set.append((i, j))
65 | random.shuffle(self._chess_helper_move_set)
66 | self.try_step = 0
67 |
68 | def play_one_piece(self, user, gameboard):
69 | move = self._chess_helper_move_set[self.try_step]
70 | while gameboard.get_piece(move[0], move[1]) != 0 and self.try_step < 15 * 15:
71 | self.try_step += 1
72 | move = self._chess_helper_move_set[self.try_step]
73 | self.try_step += 1
74 | return move
75 |
76 |
77 | class GameCommunicator(threading.Thread):
78 | def __init__(self, roomid, stragegy, server_url):
79 | threading.Thread.__init__(self)
80 | self.room_id = roomid
81 |         self.stragegy = stragegy
82 | self.server_url = server_url
83 |
84 | def run(self):
85 | client = ChessClient(self.server_url)
86 | client.login_in_guest()
87 | client.join_room(self.room_id)
88 | client.join_game()
89 | exp_interval = 0.5
90 | max_exp_interval = 200
91 | while True:
92 | single_max_time = exp_interval * 100
93 | wait_time = client.wait_game_info_changed(interval=exp_interval, max_time=single_max_time)
94 | if wait_time > single_max_time:
95 | exp_interval *= 2
96 | else:
97 | exp_interval = 0.5
98 | if exp_interval > max_exp_interval:
99 | exp_interval = max_exp_interval
100 | room = client.get_room_info()
101 | user = client.get_user_info()
102 | gameboard = client.get_game_info()
103 | print 'waittime:',wait_time,',room_status:',room.get_status(),',ask_take_back:',room.ask_take_back
104 | if room.get_status() == 1 or room.get_status() == 2:
105 | continue
106 | elif room.get_status() == 3:
107 | if room.ask_take_back != 0 and room.ask_take_back != user.game_role:
108 | client.answer_take_back()
109 | continue
110 | if gameboard.get_current_user() == user.game_role:
111 | one_legal_piece = self.stragegy.play_one_piece(user, gameboard)
112 | action_result = client.put_piece(*one_legal_piece)
113 | client.wait_game_info_changed(interval=exp_interval, max_time=single_max_time)
114 | if action_result['id'] != 0:
115 | print ChessHelper.numToAlp(one_legal_piece[0]), ChessHelper.numToAlp(one_legal_piece[1])
116 | print action_result['info']
117 | break
118 | continue
119 | elif room.get_status() == 4:
120 | break
121 |
122 |
123 | class GameListener(object):
124 | def __init__(self, prefix_stategy_map, server_url):
125 | self.client = ChessClient(server_url)
126 | self.client.login_in_guest()
127 | self.prefix_stategy_map = prefix_stategy_map
128 | self.server_url = server_url
129 | self.accupied = set()
130 |
131 | def listen(self):
132 | while True:
133 | all_rooms = self.client.get_all_rooms()
134 | for room in all_rooms:
135 | room_name = room[0]
136 | room_status = room[1]
137 | for prefix in self.prefix_stategy_map:
138 | if room_name.startswith(prefix) and room_status == GameRoom.ROOM_STATUS_ONEWAITING:
139 | if room_name in self.accupied:
140 | continue
141 | print 'Evoke:', room_name
142 | strg = self.prefix_stategy_map[prefix]()
143 | self.accupied.add(room_name)
144 | commu = GameCommunicator(room_name, strg, self.server_url)
145 | commu.start()
146 | break
147 | time.sleep(4)
148 |
149 |
150 | def go_listen():
151 | import argparse
152 | parser = argparse.ArgumentParser()
153 | parser.add_argument('--server_url', default='http://120.132.59.147:11111')
154 | args = parser.parse_args()
155 |
156 |
157 | if args.server_url.endswith('/'):
158 | args.server_url = args.server_url[:-1]
159 | if not args.server_url.startswith('http://'):
160 | args.server_url = 'http://' + args.server_url
161 |
162 | prefix_stategy_map = {'ai_': lambda: GameStrategy(), 'yixin_': lambda: GameStrategy_yixin(),
163 | 'random_': lambda: GameStrategy_random()}
164 | listen = GameListener(prefix_stategy_map, args.server_url)
165 | listen.listen()
166 |
167 |
168 | if __name__ == "__main__":
169 | go_listen()
170 |
--------------------------------------------------------------------------------
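
An editorial sketch of how a new engine plugs into the listener above: any object exposing play_one_piece(user, gameboard) and returning a (row, col) tuple can be registered under a room-name prefix, and GameListener will attach it to one-waiting rooms whose names start with that prefix. The strategy class, room prefix and server URL below are hypothetical, and a running ChessServer is assumed.

```python
from AICollection import GameListener

# Hypothetical example strategy: plays the free cell closest to the board centre.
class GameStrategy_center(object):
    def play_one_piece(self, user, gameboard):
        center = gameboard.SIZE // 2
        cells = [(r, c) for r in range(gameboard.SIZE) for c in range(gameboard.SIZE)]
        cells.sort(key=lambda rc: abs(rc[0] - center) + abs(rc[1] - center))
        for row, col in cells:
            if gameboard.get_piece(row, col) == 0:   # first empty cell, nearest the centre
                return (row, col)

# Rooms named e.g. 'center_demo' would then be served by this strategy.
prefix_stategy_map = {'center_': lambda: GameStrategy_center()}
GameListener(prefix_stategy_map, 'http://127.0.0.1:11111').listen()
```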
/evaluate/ChessBoard.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import time
3 | import cPickle as pickle
4 |
5 |
6 | class ChessBoard(object):
7 | """ChessBoard
8 |
9 | Attributes:
10 | SIZE: The chess board's size.
11 | board: To store the board information.
12 | state: Indicate if the game is over.
13 | current_user: The user who put the next piece.
14 | """
15 |
16 | STATE_RUNNING = 0
17 |     STATE_DONE = 1
18 |     STATE_ABORT = 2  # distinct from STATE_DONE so that aborted games report no winner
19 |
20 | PIECE_STATE_BLANK = 0
21 | PIECE_STATE_FIRST = 1
22 | PIECE_STATE_SECOND = 2
23 |
24 | PAD = 4
25 |
26 | CHECK_DIRECTION = [[[0, 1], [0, -1]], [[1, 0], [-1, 0]], [[1, 1], [-1, -1]], [[1, -1], [-1, 1]]]
27 |
28 | def __init__(self, size=15):
29 | self.SIZE = size
30 | self.board = np.zeros((self.SIZE + ChessBoard.PAD * 2, self.SIZE + ChessBoard.PAD * 2), dtype=np.uint8)
31 | self.state = ChessBoard.STATE_RUNNING
32 | self.current_user = ChessBoard.PIECE_STATE_FIRST
33 |
34 | self.move_num = 0
35 | self.move_history = []
36 |
37 | self.dump_cache = None
38 |
39 | def changed(func):
40 | def wrapper_func(self,*args, **kwargs):
41 | ret=func(self,*args, **kwargs)
42 | self.dump_cache = None
43 | return ret
44 |
45 | return wrapper_func
46 |
47 | def get_piece(self, row, col):
48 | return self.board[row + ChessBoard.PAD, col + ChessBoard.PAD]
49 |
50 | @changed
51 | def set_piece(self, row, col, user):
52 | self.board[row + ChessBoard.PAD, col + ChessBoard.PAD] = user
53 |
54 | @changed
55 | def put_piece(self, row, col, user):
56 | """Put a piece in the board and check if he wins.
57 | Returns:
58 | 0 successful move.
59 | 1 successful and win move.
60 | -1 move out of range.
61 | -2 piece has been occupied.
62 | -3 game is over
63 | -4 not your turn.
64 | """
65 | if row < 0 or row >= self.SIZE or col < 0 or col >= self.SIZE:
66 | return -1
67 | if self.get_piece(row, col) != ChessBoard.PIECE_STATE_BLANK:
68 | return -2
69 | if self.state != ChessBoard.STATE_RUNNING:
70 | return -3
71 | if user != self.current_user:
72 | return -4
73 |
74 | self.set_piece(row, col, user)
75 | self.move_num += 1
76 | self.move_history.append((user, self.move_num, row, col,))
77 | # self.last_move = (row, col)
78 |
79 | # check if win
80 | for dx in xrange(4):
81 | connected_piece_num = 1
82 | for dy in xrange(2):
83 | current_direct = ChessBoard.CHECK_DIRECTION[dx][dy]
84 | c_row = row
85 | c_col = col
86 |
87 | # if else realization
88 | for dz in xrange(4):
89 | c_row += current_direct[0]
90 | c_col += current_direct[1]
91 | if self.get_piece(c_row, c_col) == user:
92 | connected_piece_num += 1
93 | else:
94 | break
95 |
96 | # remove if, but not faster
97 | # p = 1
98 | # for dz in xrange(4):
99 | # c_row += current_direct[0]
100 | # c_col += current_direct[1]
101 | # p = p & (self.board[c_row, c_col] == user)
102 | # connected_piece_num += p
103 |
104 | if connected_piece_num >= 5:
105 | self.state = ChessBoard.STATE_DONE
106 | return 1
107 |
108 | if self.current_user == ChessBoard.PIECE_STATE_SECOND:
109 | self.current_user = ChessBoard.PIECE_STATE_FIRST
110 | else:
111 | self.current_user = ChessBoard.PIECE_STATE_SECOND
112 |
113 | if self.move_num == self.SIZE * self.SIZE:
114 | # self.state = ChessBoard.STATE_DONE
115 | self.state = ChessBoard.STATE_ABORT
116 |
117 | return 0
118 |
119 | def get_winner(self):
120 | return self.current_user if self.state == ChessBoard.STATE_DONE else -1
121 |
122 | def get_state(self):
123 | return self.state
124 |
125 | def get_current_user(self):
126 | return self.current_user
127 |
128 | def get_lastmove(self):
129 | return self.move_history[-1] if len(self.move_history) > 0 else (-1, -1, -1, -1)
130 |
131 | @changed
132 | def take_one_back(self):
133 | if len(self.move_history) > 0:
134 | last_move = self.move_history.pop()
135 | self.set_piece(last_move[-2], last_move[-1], ChessBoard.PIECE_STATE_BLANK)
136 | self.move_num -= 1
137 |
138 | if self.current_user == ChessBoard.PIECE_STATE_SECOND:
139 | self.current_user = ChessBoard.PIECE_STATE_FIRST
140 | else:
141 | self.current_user = ChessBoard.PIECE_STATE_SECOND
142 |
143 | def is_over(self):
144 | return self.state == ChessBoard.STATE_DONE or self.state == ChessBoard.STATE_ABORT
145 |
146 | def dumps(self):
147 | if self.dump_cache is None:
148 | self.dump_cache = pickle.dumps((self.SIZE, self.board, self.state, self.current_user, self.move_history))
149 | return self.dump_cache
150 |
151 | @changed
152 | def loads(self, chess_str):
153 | self.SIZE, self.board, self.state, self.current_user, self.move_history = pickle.loads(chess_str)
154 |
155 |
156 | @changed
157 | def reset(self):
158 | self.board = np.zeros((self.SIZE + ChessBoard.PAD * 2, self.SIZE + ChessBoard.PAD * 2), dtype=np.uint8)
159 | self.state = ChessBoard.STATE_RUNNING
160 | self.current_user = ChessBoard.PIECE_STATE_FIRST
161 | self.move_num = 0
162 | self.move_history = []
163 |
164 | @changed
165 | def abort(self):
166 | self.state = ChessBoard.STATE_ABORT
167 |
--------------------------------------------------------------------------------
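
A small usage sketch (editorial, not part of the repository; it runs under Python 2, which the module's cPickle import requires) exercising ChessBoard's put_piece return codes and the five-in-a-row check:

```python
from ChessBoard import ChessBoard

board = ChessBoard(size=15)
user = ChessBoard.PIECE_STATE_FIRST
# four stones each for Black (row 7) and White (row 0); all are legal, non-winning moves
for row, col in [(7, 7), (0, 0), (7, 8), (0, 1), (7, 9), (0, 2), (7, 10), (0, 3)]:
    assert board.put_piece(row, col, user) == 0
    user = 3 - user                              # alternate between players 1 and 2
# the fifth stone in a row returns 1 and finishes the game
assert board.put_piece(7, 11, ChessBoard.PIECE_STATE_FIRST) == 1
assert board.get_winner() == ChessBoard.PIECE_STATE_FIRST
assert board.is_over()
```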
/evaluate/ChessClient.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import requests
3 | import cookielib
4 | # import http.cookiejar
5 | from bs4 import BeautifulSoup
6 | import json
7 | import time
8 | import cPickle as pickle
9 | # import _pickle as pickle
10 | import ChessHelper
11 | from ChessBoard import ChessBoard
12 | from Hall import GameRoom
13 | from Hall import User
14 | import random
15 | import sys
16 | import os
17 | # make the AlphaPig package importable from the repository's parent directory
18 | abs_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../')
19 | sys.path.insert(0, abs_path)
20 | import AlphaPig as gomoku_zm
21 |
22 |
23 | class ChessClient():
24 | def __init__(self, server_url):
25 | self.session = requests.Session()
26 | self.session.cookies = cookielib.CookieJar()
27 | agent = 'Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Maxthon/5.1.2.3000 Chrome/55.0.2883.75 Safari/537.36'
28 | self.headers = {
29 | "Host": server_url,
30 | "Origin": server_url,
31 | "Referer": server_url,
32 | 'User-Agent': agent
33 | }
34 | self.server_url = server_url
35 | self.board = ChessBoard()
36 | self.last_status_signature = ""
37 |
38 | def send_get(self, url):
39 | return self.session.get(self.server_url + url, headers=self.headers)
40 |
41 | def send_post(self, url, data):
42 | return self.session.post(self.server_url + url, data, headers=self.headers)
43 |
44 | def login_in_guest(self):
45 | response = self.send_get('/login?action=login_in_guest')
46 | soup = BeautifulSoup(response.content, "html.parser")
47 | username_span = soup.find('span', attrs={'id': 'username'})
48 | if username_span:
49 | return username_span.text
50 | else:
51 | return None
52 |
53 | def login(self, username, password):
54 | response = self.send_post('/login?action=login',
55 | data={'username': username, 'password': password})
56 | soup = BeautifulSoup(response.content, "html.parser")
57 | username_span = soup.find('span', attrs={'id': 'username'})
58 | if username_span:
59 | return username_span.text
60 | else:
61 | return None
62 |
63 | def logout(self):
64 | self.send_get('/login?action=logout')
65 |
66 | def join_room(self, roomid):
67 | response = self.send_post('/action?action=joinroom',
68 | data={'roomid': roomid})
69 | action_result = json.loads(response.content)
70 | return action_result
71 |
72 | def join_game(self, cur_role):
73 | #response = self.send_get('/action?action=joingame')
74 | response = self.send_post('/action?action=joingame',
75 | data={'position': str(cur_role)})
76 | action_result = json.loads(response.content)
77 | return action_result
78 |
79 | def put_piece(self, row, col):
80 | response = self.send_get(
81 | '/action?action=gameaction&actionid=%s&piece_i=%d&piece_j=%d' % ('put_piece', row, col))
82 | action_result = json.loads(response.content)
83 | return action_result
84 |
85 | def get_room_info(self):
86 | response = self.send_get(
87 | '/action?action=gameaction&actionid=%s' % 'get_room_info')
88 | action_result = json.loads(response.content)
89 | room = pickle.loads(str(action_result['info']))
90 | return room
91 |
92 | def get_game_info(self):
93 | response = self.send_get(
94 | '/action?action=gameaction&actionid=%s' % 'get_game_info')
95 | action_result = json.loads(response.content)
96 | room = pickle.loads(str(action_result['info']))
97 | return room
98 |
99 | def get_user_info(self):
100 | response = self.send_get(
101 | '/action?action=gameaction&actionid=%s' % 'get_user_info')
102 | action_result = json.loads(response.content)
103 | room = pickle.loads(str(action_result['info']))
104 | return room
105 |
106 | def wait_game_info_changed(self, interval=0.5, max_time=100):
107 | wait_time = 0
108 | assert interval > 0, "interval must be positive"
109 | while True:
110 | response = self.send_get(
111 | '/action?action=gameaction&actionid=%s' % ('get_status_signature'))
112 | action_result = json.loads(response.content)
113 | if action_result['id'] == 0:
114 | status_signature = action_result['info']
115 | if self.last_status_signature != status_signature:
116 | self.last_status_signature = status_signature
117 | break
118 | else:
119 | print("ERROR get_status_signature,", action_result['id'], action_result['info'])
120 | break
121 | time.sleep(interval)
122 | wait_time += interval
123 | if wait_time > max_time:
124 | break
125 |
126 | return wait_time
127 |
128 | def get_all_rooms(self):
129 | response = self.send_get(
130 | '/action?action=get_all_rooms')
131 | action_result = json.loads(response.content)
132 | all_rooms = action_result['info']
133 | return all_rooms
134 |
135 | def answer_take_back(self, agree=True):
136 | response = self.send_get(
137 | '/action?action=gameaction&actionid=answer_take_back&agree=' + ('true' if agree else 'false'))
138 | action_result = json.loads(response.content)
139 | return action_result
140 |
141 |
142 | class GameStrategy_random():
143 | def __init__(self):
144 | self._chess_helper_move_set = []
145 | for i in range(15):
146 | for j in range(15):
147 | self._chess_helper_move_set.append((i, j))
148 | random.shuffle(self._chess_helper_move_set)
149 | self.try_step = 0
150 |
151 | def play_one_piece(self, user, gameboard):
152 | move = self._chess_helper_move_set[self.try_step]
153 | while gameboard.get_piece(move[0], move[1]) != 0 and self.try_step < 15 * 15:
154 | self.try_step += 1
155 | move = self._chess_helper_move_set[self.try_step]
156 | self.try_step += 1
157 | return move
158 |
159 |
160 | class GameStrategy_MZhang():
161 | def __init__(self, startplayer=0, complex_='s'):
162 | abs_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../')
163 |         # plain convolutional networks (the alternative model paths below are kept for reference but disabled)
164 | """
165 | model_file2 = os.path.join(abs_path, './logs/current_policy_tf_small.model')
166 | model_file1 = os.path.join(abs_path, './logs/best_policy_tf_10999.model')
167 | model_file3 = os.path.join(abs_path, './logs/current_policy_1024.model')
168 |         # 5-block residual network
169 | model_file_res = os.path.join(abs_path, './logs/current_res_5.model')
170 | model_file_res10 = os.path.join(abs_path, './logs/current_res_10.model')
171 | """
172 | model_file1 = os.path.join(abs_path, "./logs/current_policy.model")
173 | model_file2 = os.path.join(abs_path, "./logs/current_policy.model")
174 | model_file3 = os.path.join(abs_path, "./logs/current_policy.model")
175 | model_file_res = os.path.join(abs_path, "./logs/current_policy.model")
176 | model_file_res10 = os.path.join(abs_path, "./logs/current_policy.model")
177 |
178 | policy_param = None
179 | self.height = 15
180 | self.width = 15
181 | if model_file1 is not None:
182 | print('loading...', model_file1)
183 | try:
184 | policy_param = pickle.load(open(model_file1, 'rb'))
185 | except:
186 | policy_param = pickle.load(open(model_file1, 'rb'), encoding='bytes')
187 |
188 | if complex_ == 's':
189 | policy_value_net = gomoku_zm.policy_value_net_mxnet_simple.PolicyValueNet(self.height, self.width, batch_size=16, model_params=policy_param)
190 |             self.mcts_player = gomoku_zm.mcts_alphaZero.MCTSPlayer(policy_value_net.policy_value_fn, c_puct=3, n_playout=80)  # n_playout: 160 was too slow, reduced to 80
191 |             # timings: n_playout 500: 50s, 200: 32s
192 |         # 5-block residual network
193 | elif complex_ == 'r':
194 | if model_file_res is not None:
195 | try:
196 | policy_param_res = pickle.load(open(model_file_res, 'rb'))
197 | except:
198 | policy_param_res = pickle.load(open(model_file_res, 'rb'), encoding='bytes')
199 | policy_value_net_res = gomoku_zm.policy_value_net_mxnet.PolicyValueNet(self.height, self.width, batch_size=16, n_blocks=5, n_filter=128, model_params=policy_param_res)
200 | self.mcts_player = gomoku_zm.mcts_alphaZero.MCTSPlayer(policy_value_net_res.policy_value_fn, c_puct=3, n_playout=1290)
201 |         # 10-block residual network
202 | elif complex_ == 'r10':
203 | if model_file_res10 is not None:
204 | try:
205 | policy_param_res = pickle.load(open(model_file_res10, 'rb'))
206 | except:
207 | policy_param_res = pickle.load(open(model_file_res10, 'rb'), encoding='bytes')
208 | policy_value_net_res = gomoku_zm.policy_value_net_mxnet.PolicyValueNet(self.height, self.width, batch_size=16, n_blocks=10, n_filter=128, model_params=policy_param_res)
209 | self.mcts_player = gomoku_zm.mcts_alphaZero.MCTSPlayer(policy_value_net_res.policy_value_fn, c_puct=3, n_playout=990)
210 | else:
211 | print("*=*" * 20, ": 模型参数错误")
212 | self.board = gomoku_zm.game.Board(width=self.width, height=self.height, n_in_row=5)
213 | self.board.init_board(startplayer)
214 | self.game = gomoku_zm.game.Game(self.board)
215 | p1, p2 = self.board.players
216 | print('players:', p1, p2)
217 | self.mcts_player.set_player_ind(p1)
218 | pass
219 |
220 | def play_one_piece(self, user, gameboard):
221 | print('user:', gameboard.get_current_user())
222 | print('gameboard:', gameboard.move_history)
223 | lastm = gameboard.get_lastmove()
224 | if lastm[0] != -1:
225 | usr, n, row, col = lastm
226 | mv = (self.height-row-1)*self.height+col
227 |             if mv not in self.board.states:
228 | self.board.do_move(mv)
229 |
230 | print('board:', self.board.states.items())
231 | move = self.mcts_player.get_action(self.board)
232 | print('***' * 10)
233 | print('move: ', move)
234 | print('\n')
235 | self.board.do_move(move)
236 | # self.game.graphic(self.board, *self.board.players)
237 | outmv = (self.height-move//self.height-1, move%self.width)
238 |
239 | return outmv
240 |
241 |
242 |
243 | def go_play(args):
244 |
245 | client = ChessClient(args.server_url)
246 | client.login_in_guest()
247 | client.join_room(args.room_name)
248 | client.join_game(args.cur_role)
249 | user = client.get_user_info()
250 | print("加入游戏成功,你是:" + ("黑方" if user.game_role == 1 else "白方"))
251 | print('model is: ', args.model)
252 |
253 | if args.ai == 'random':
254 | strategy = GameStrategy_MZhang(user.game_role-1, args.model)
255 | else:
256 | assert False, "No other ai, you can add one or import the AICollection's ai."
257 |
258 | while True:
259 | wait_time = client.wait_game_info_changed()
260 | print('wait_time:', wait_time)
261 |
262 | room = client.get_room_info()
263 | # room=GameRoom()
264 | user = client.get_user_info()
265 | # user=User()
266 | gameboard = client.get_game_info()
267 | # gameboard = ChessBoard()
268 |
269 | print('room.get_status():', room.get_status())
270 | print('user.game_status():', user.game_status)
271 | print('gameboard.game_status():')
272 | ChessHelper.printBoard(gameboard)
273 |
274 | if room.get_status() == GameRoom.ROOM_STATUS_NOONE or room.get_status() == GameRoom.ROOM_STATUS_ONEWAITING:
275 | print("等待另一个对手加入游戏:")
276 | continue
277 | elif room.get_status() == GameRoom.ROOM_STATUS_PLAYING:
278 | if room.ask_take_back != 0 and room.ask_take_back != user.game_role:
279 | client.answer_take_back()
280 | return 0
281 | # break
282 | if gameboard.get_current_user() == user.game_role:
283 | print("轮到你走:")
284 | current_time = time.time()
285 | one_legal_piece = strategy.play_one_piece(user, gameboard)
286 | action_result = client.put_piece(*one_legal_piece)
287 | print('***' * 20)
288 | print('COST TIME: %s' % (time.time() - current_time))
289 | if action_result['id'] != 0:
290 | print("走棋失败:")
291 | print(ChessHelper.numToAlp(one_legal_piece[0]), ChessHelper.numToAlp(one_legal_piece[1]))
292 | print(action_result['info'])
293 |
294 | else:
295 | print("轮到对手走....")
296 | continue
297 | elif room.get_status() == GameRoom.ROOM_STATUS_FINISH:
298 | print("游戏已经结束了," + ("黑方" if gameboard.get_winner() == 1 else "白方") + " 赢了")
299 | return 1
300 | # break
301 |
302 |
303 | if __name__ == "__main__":
304 | import argparse
305 | import time
306 | parser = argparse.ArgumentParser()
307 |     # a room name of the form yixin_XX_X will be picked up by the Yixin listener (see AICollection.py)
308 | temp_room_name = 'yixin_anxingle_' + str(time.time())
309 | parser.add_argument('--room_name', type=str, default=temp_room_name)
310 | parser.add_argument('--cur_role', default=1)
311 | parser.add_argument('--model', default='s')
312 | print('room_name: ', temp_room_name)
313 | parser.add_argument('--server_url', default='http://127.0.0.1:33512')
314 | parser.add_argument('--ai', default='random')
315 | args = parser.parse_args()
316 | while True:
317 | status = go_play(args)
318 | if status == 1:
319 | time.sleep(5)
320 | else:
321 | print("房间有人!")
322 | time.sleep(10)
323 |
--------------------------------------------------------------------------------
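
An editorial note on the coordinate handling in GameStrategy_MZhang above: the web board's (row, col) is converted to game.Board's flat move index with mv = (height - row - 1) * height + col and mapped back with (height - move // height - 1, move % width), i.e. the row axis is flipped between the two representations. A tiny round-trip sketch (height and width assumed to be 15, as in the client):

```python
HEIGHT = WIDTH = 15

def rowcol_to_move(row, col):
    return (HEIGHT - row - 1) * HEIGHT + col

def move_to_rowcol(move):
    return (HEIGHT - move // HEIGHT - 1, move % WIDTH)

assert rowcol_to_move(3, 11) == 176
assert move_to_rowcol(176) == (3, 11)   # the round trip restores the original square
```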
/evaluate/ChessHelper.py:
--------------------------------------------------------------------------------
1 | from ChessBoard import ChessBoard
2 | import random
3 | from line_profiler import LineProfiler
4 |
5 |
6 | def numToAlp(num):
7 | return chr(ord('A') + num)
8 |
9 |
10 | def transferSymbol(sym):
11 | if sym == 0:
12 | return "."
13 | if sym == 1:
14 | return "O"
15 | if sym == 2:
16 | return "X"
17 | return "E"
18 |
19 |
20 | def printBoard(chessboard):
21 | for i in range(chessboard.SIZE + 1):
22 | for j in range(chessboard.SIZE + 1):
23 | if i == 0 and j == 0:
24 | print(' ')
25 | elif i == 0:
26 | print(numToAlp(j - 1))
27 | elif j == 0:
28 | print(numToAlp(i - 1))
29 | else:
30 | print(transferSymbol(chessboard.get_piece(i - 1, j - 1)))
31 | print()
32 |
33 | def printBoard2Str(chessboard):
34 | info_array=[]
35 |
36 | for i in range(chessboard.SIZE + 1):
37 | info_str = ""
38 | for j in range(chessboard.SIZE + 1):
39 |
40 | if i == 0 and j == 0:
41 | info_str+= ' '
42 | elif i == 0:
43 | info_str += numToAlp(j - 1)
44 | elif j == 0:
45 | info_str +=numToAlp(i - 1)
46 | else:
47 | info_str +=transferSymbol(chessboard.get_piece(i - 1, j - 1))
48 | info_str += '\t'
49 | # info_str +='\n'
50 |
51 | info_array.append(info_str)
52 | return info_array
53 |
54 | def playRandomGame(chessboard):
55 | muser = 1
56 | _chess_helper_move_set = []
57 | for i in range(15):
58 | for j in range(15):
59 | _chess_helper_move_set.append((i, j))
60 | random.shuffle(_chess_helper_move_set)
61 | for move in _chess_helper_move_set:
62 | if chessboard.is_over():
63 | # print "No place to put."
64 | return -5
65 |
66 | r_row = move[0]
67 | r_col = move[1]
68 |
69 | return_value = chessboard.put_piece(r_row, r_col, muser)
70 | muser = 2 if muser == 1 else 1
71 | if return_value != 0:
72 | # print ("\n%s win one board. last move is %s %s, return value is %d" % (
73 | # transferSymbol(chessboard.get_winner()), numToAlp(r_row), numToAlp(r_col), return_value))
74 | return return_value
75 |
76 |
77 | if __name__ == '__main__':
78 | def linePro():
79 | lp = LineProfiler()
80 |
81 | # lp_wrapper = lp(cb.put_piece)
82 | # lp_wrapper(7, 7, 1)
83 |
84 | def playMuch(num):
85 | oi = 0
86 | for i in xrange(num):
87 | cb = ChessBoard()
88 | playRandomGame(cb)
89 | oi += cb.move_num
90 | print(oi / num)
91 |
92 | lp_wrapper = lp(playMuch)
93 | lp_wrapper(1000)
94 | lp.print_stats()
95 |
96 |
97 | def playRandom():
98 | cb = ChessBoard()
99 | return_value = playRandomGame(cb)
100 | printBoard(cb)
101 | print ("\n%s win one board. last move is %s %s, return value is %d" % (
102 | transferSymbol(cb.get_winner()), numToAlp(cb.get_lastmove()[0]), numToAlp(cb.get_lastmove()[1]),
103 | return_value))
104 |
105 |
106 | playRandom()
107 | #linePro()
108 |
--------------------------------------------------------------------------------
/evaluate/ChessServer.py:
--------------------------------------------------------------------------------
1 | import tornado.ioloop
2 | import tornado.web
3 | import os
4 | import argparse
5 |
6 | from Hall import Hall
7 | from Hall import GameRoom
8 | from Hall import User
9 |
10 | hall = Hall()
11 |
12 |
13 | class BaseHandler(tornado.web.RequestHandler):
14 | def get_current_user(self):
15 | return self.get_secure_cookie("username")
16 |
17 |
18 | class ChessHandler(BaseHandler):
19 |
20 | def get_post(self):
21 | info_ = ""
22 | user = hall.get_user_with_uid(self.current_user)
23 | room = user.game_room
24 | chess_board = room.board if room else None
25 | self.render("page/chessboard.html", username=self.current_user, room=room,
26 | chess_board=chess_board, user=user)
27 |
28 | @tornado.web.authenticated
29 | def get(self):
30 | self.get_post()
31 |
32 | @tornado.web.authenticated
33 | def post(self):
34 | self.get_post()
35 |
36 |
37 | class ActionHandler(BaseHandler):
38 |
39 | def get_post(self):
40 | action = self.get_argument("action", None)
41 | action_result = {"id": -1, "info": "Failure"}
42 |
43 | if action:
44 | if action == "joinroom":
45 | roomid = self.get_argument("roomid", None)
46 | if roomid:
47 | hall.join_room(self.current_user, roomid)
48 | action_result["id"] = 0
49 | action_result["info"] = "Join room success."
50 | else:
51 | action_result["id"] = -1
52 | action_result["info"] = "Not legal room id."
53 |
54 | elif action == "joingame":
55 | user_role = int(self.get_argument("position", -1))
56 |
57 | if hall.join_game(self.current_user, user_role) == 0:
58 | action_result["id"] = 0
59 | action_result["info"] = "Join game success."
60 |
61 | else:
62 | action_result["id"] = -1
63 | action_result["info"] = "Join game failed, join a room first or have joined game or game is full."
64 |
65 | elif action == "gameaction":
66 | actionid = self.get_argument("actionid", None)
67 | game_action_result = hall.game_action(self.current_user, actionid, self)
68 | if game_action_result.result_id == 0:
69 | action_result["id"] = 0
70 | action_result["info"] = game_action_result.result_info
71 | else:
72 | action_result["id"] = -1
73 | action_result["info"] = "Game action failed:" + str(
74 | game_action_result.result_id) + "," + game_action_result.result_info
75 | elif action == "getboardinfo":
76 | room = hall.get_room_with_user(self.current_user)
77 | # room=GameRoom()
78 | if room:
79 | action_result["id"] = 0
80 | action_result["info"] = room.board.dumps()
81 | else:
82 | action_result["id"] = -1
83 | action_result["info"] = "Not in room, please join one."
84 | elif action == "get_all_rooms":
85 | action_result["id"] = 0
86 | action_result["info"] = [[room_name, hall.id2room[room_name].get_status()] for room_name in
87 | hall.id2room]
88 | elif action == "reset_room":
89 | user = hall.get_user_with_uid(self.current_user)
90 | room = user.game_room
91 | # room=GameRoom()
92 | if room and room.get_status() == GameRoom.ROOM_STATUS_FINISH and user in room.play_users:
93 | room.reset_game()
94 | action_result["id"] = 0
95 | action_result["info"] = "reset success"
96 | else:
97 | action_result["id"] = -1
98 | action_result["info"] = "reset failed."
99 |
100 | else:
101 | action_result["id"] = -1
102 | action_result["info"] = "Not recognition action" + action
103 | else:
104 | action_result["info"] = "Not action arg set"
105 |
106 | # self.write(tornado.escape.json_encode(action_result))
107 | self.finish(action_result)
108 |
109 | @tornado.web.authenticated
110 | def get(self):
111 | self.get_post()
112 |
113 | @tornado.web.authenticated
114 | def post(self):
115 | self.get_post()
116 |
117 |
118 | class LoginHandler(BaseHandler):
119 |
120 | def get_post(self):
121 |
122 | action = self.get_argument("action", None)
123 |
124 | if action == "login":
125 | if self.current_user is not None:
126 | self.clear_cookie("username")
127 | hall.logout(self.current_user)
128 |
129 | username = self.get_argument("username")
130 | password = self.get_argument("password")
131 | username = hall.login(username, password)
132 | if username:
133 | self.set_secure_cookie("username", username)
134 | self.redirect("/")
135 | else:
136 | self.redirect("/login?status=wrong_password_or_name")
137 | elif action == "login_in_guest":
138 | if self.current_user is not None:
139 | self.clear_cookie("username")
140 | hall.logout(self.current_user)
141 |
142 | username = hall.login_in_guest()
143 | print(username)
144 | if username:
145 | self.set_secure_cookie("username", username)
146 | self.redirect("/")
147 | elif action == "logout":
148 | if self.current_user is not None:
149 | self.clear_cookie("username")
150 | hall.logout(self.current_user)
151 | self.redirect("/login")
152 | else:
153 | self.render('page/login.html')
154 |
155 | def get(self):
156 | self.get_post()
157 |
158 | def post(self):
159 | self.get_post()
160 |
161 |
162 | def main(listen_port):
163 | settings = {
164 | # "template_path": os.path.join(os.path.dirname(__file__), "templates"),
165 | "cookie_secret": "bZJc2sWbQLKos6GkHn/VB9oXwQt8S0R0kRvJ5/xJ89E=",
166 | # "xsrf_cookies": True,
167 | "login_url": "/login",
168 | "static_path": os.path.join(os.path.dirname(__file__), "static"),
169 | }
170 | app = tornado.web.Application([
171 | (r"/", ChessHandler),
172 | (r"/login", LoginHandler),
173 | (r"/action", ActionHandler),
174 | ], **settings)
175 | app.listen(listen_port)
176 | tornado.ioloop.IOLoop.current().start()
177 |
178 |
179 | if __name__ == "__main__":
180 | parser = argparse.ArgumentParser()
181 |     parser.add_argument('--port', type=int, default=8888)
182 | args = parser.parse_args()
183 | main(args.port)
184 |
--------------------------------------------------------------------------------
/evaluate/Hall.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | from ChessBoard import ChessBoard
4 | import cPickle as pickle
5 | # import _pickle as pickle
6 | import random
7 | import string
8 | import os
9 |
10 |
11 | class User(object):
12 | def __init__(self, uid_, hall_):
13 | self.uid = uid_
14 | self.game_room = None
15 | self.hall = hall_
16 | self.game_status = User.USER_GAME_STATUS_NOJOIN
17 | self.game_role = -1
18 |
19 | USER_GAME_STATUS_NOJOIN = 0
20 | USER_GAME_STATUS_GAMEJOINED = 1
21 |
22 | def send_message(self):
23 | pass
24 |
25 | def receive_message(self, message):
26 | pass
27 |
28 | def send_game_state(self):
29 | pass
30 |
31 | def action(self, action):
32 | pass
33 |
34 | def join_room(self, room):
35 | if self.game_room is not None:
36 | self.leave_room()
37 |
38 | if room.join_room(self) == GameRoom.ACTION_SUCCESS:
39 | self.game_room = room
40 | else:
41 | return -1
42 |
43 | def join_game(self, user_role):
44 | if self.game_room is None:
45 | return -1
46 |
47 | if self.game_status == User.USER_GAME_STATUS_GAMEJOINED:
48 | return -1
49 | if self.game_room.join_game(self, user_role) == GameRoom.ACTION_SUCCESS:
50 | self.game_status = User.USER_GAME_STATUS_GAMEJOINED
51 | return 0
52 | else:
53 | return -1
54 |
55 | def leave_room(self):
56 | if self.game_room is None:
57 | return -1
58 | self.leave_game()
59 | self.game_room.leave_room(self)
60 |
61 | def leave_game(self):
62 | if self.game_room is None:
63 | return -1
64 | self.game_room.leave_game(self)
65 | self.game_status = User.USER_GAME_STATUS_NOJOIN
66 | self.game_role = -1
67 |
68 |
69 | class ActionResult(object):
70 | def __init__(self, result_id_=0, result_info_=""):
71 | self.result_id = result_id_
72 | self.result_info = result_info_
73 |
74 |
75 | class GameRoom(object):
76 | ACTION_SUCCESS = 0
77 | ACTION_FAILURE = -1
78 |
79 | def __init__(self, room_id_):
80 | self.play_users = []
81 | self.position2users = {} # 1 black role 2 white role
82 | self.room_id = room_id_
83 | self.users = []
84 | self.max_player_num = 2
85 | self.max_user_num = 10000
86 | self.board = ChessBoard()
87 | self.status_signature = None
88 | self.set_changed()
89 | self.chess_folder = 'chess_output'
90 | # self.game_status = GameRoom.GAME_STATUS_NOTBEGIN
91 | self.ask_take_back = 0
92 |
93 | def set_changed(self):
94 | self.status_signature = ''.join(random.sample(string.ascii_letters + string.digits, 8))
95 |
96 | def broadcast_message_to_all(self, message):
97 | self.set_changed()
98 | pass
99 |
100 | def send_message(self, to_user_id, message):
101 | self.set_changed()
102 | pass
103 |
104 | def get_last_move(self):
105 | (userrole, move_num, row, col) = self.board.get_lastmove()
106 | if userrole < 0:
107 | userrole = -1 * self.get_status() - 1
108 | last_move = {
109 | 'role': userrole,
110 | 'move_num': move_num,
111 | 'row': row,
112 | 'col': col,
113 | }
114 | return last_move
115 |
116 | # GAME_STATUS_NOTBEGIN = 0
117 | # GAME_STATUS_RUNING_HEIFANG = 1
118 | # GAME_STATUS_RUNING_BAIFANG = 2
119 | # GAME_STATUS_FINISH = 3
120 | # GAME_STATUS_ASKTAKEBACK_HEIFANG = 4
121 | # GAME_STATUS_ASKTAKEBACK_BAIFANG = 5
122 | # GAME_STATUS_ASKREBEGIN_HEIFANG = 6
123 | # GAME_STATUS_ASKREBEGIN_BAIFANG = 7
124 |
125 | def get_signature(self):
126 | return self.status_signature
127 |
128 | def action(self, user, action_code, action_args):
129 | if action_code == "put_piece":
130 | if self.ask_take_back != 0:
131 | return ActionResult(-1, "Wait take back answer before.")
132 | if not (user.game_role == 1) and not (user.game_role == 2):
133 | return ActionResult(-1, "Not the right time or role to put piece")
134 | piece_i = action_args.get_argument('piece_i', None)
135 | piece_j = action_args.get_argument('piece_j', None)
136 | if piece_i and piece_j:
137 | return_code = self.board.put_piece(int(piece_i), int(piece_j), user.game_role)
138 | if return_code >= 0:
139 | if return_code == 1:
140 | self.finish_game()
141 | self.set_changed()
142 |                     return ActionResult(0, "put_piece success:" + str(return_code))
143 | else:
144 | return ActionResult(-4, "put_piece failed, because " + str(return_code))
145 | else:
146 | return ActionResult(-3, "Not set the piece_i and piece_j")
147 | elif action_code == "getlastmove":
148 | return ActionResult(0, self.get_last_move())
149 | elif action_code == "get_status_signature":
150 | return ActionResult(0, self.get_signature())
151 | elif action_code == "get_room_info":
152 | return ActionResult(0, pickle.dumps(self))
153 | elif action_code == "get_game_info":
154 | return ActionResult(0, pickle.dumps(self.board))
155 | elif action_code == "get_user_info":
156 | return ActionResult(0, pickle.dumps(user))
157 | elif action_code == "ask_take_back":
158 | if self.ask_take_back != 0:
159 | return ActionResult(-1, "Not right time to ask take back.")
160 | if user.game_role != 1 and user.game_role != 2:
161 | return ActionResult(-1, "Not right role to ask take back.")
162 | self.ask_take_back = user.game_role
163 | self.set_changed()
164 | return ActionResult(0, "Ask take back success")
165 | elif action_code == "answer_take_back":
166 |
167 | if not (self.ask_take_back == 1 and user.game_role == 2) and not (
168 | self.ask_take_back == 2 and user.game_role == 1):
169 | return ActionResult(-1, "Not the right time or role to answer take back")
170 |
171 | agree = action_args.get_argument('agree', 'false')
172 | if agree == 'true':
173 | if self.ask_take_back == 1 and user.game_role == 2:
174 | if self.board.get_current_user() == 1:
175 | self.board.take_one_back()
176 | self.board.take_one_back()
177 |
178 | elif self.ask_take_back == 2 and user.game_role == 1:
179 | if self.board.get_current_user() == 2:
180 | self.board.take_one_back()
181 | self.board.take_one_back()
182 | self.ask_take_back = 0
183 | self.set_changed()
184 | return ActionResult(0, "answer take back success")
185 | else:
186 | return ActionResult(-2, "Not recognized game action")
187 |
188 | # def join(self, user):
189 | # if self.join_game(user) == GameRoom.ACTION_SUCCESS:
190 | # return GameRoom.ACTION_SUCCESS
191 | # return self.join_watch(user)
192 |
193 | def join_game(self, user, user_role):
194 | if user not in self.users:
195 | return GameRoom.ACTION_FAILURE
196 |
197 | if len(self.play_users) >= self.max_player_num:
198 | return GameRoom.ACTION_FAILURE
199 | idle_position = [i for i in range(1, self.max_player_num + 1) if i not in self.position2users]
200 | if len(idle_position) == 0:
201 | return GameRoom.ACTION_FAILURE
202 | if user_role is None or user_role == -1:
203 | user_role = idle_position[0]
204 | elif user_role == 0:
205 | random.shuffle(idle_position)
206 | user_role = idle_position[0]
207 | else:
208 | if user_role in self.position2users or user_role > self.max_player_num:
209 | return GameRoom.ACTION_FAILURE
210 | user.game_role = user_role
211 | self.position2users[user_role] = user
212 | self.play_users.append(user)
213 |
214 | self.set_changed()
215 | return GameRoom.ACTION_SUCCESS
216 |
217 | def join_room(self, user):
218 | if len(self.users) >= self.max_user_num:
219 | return GameRoom.ACTION_FAILURE
220 | self.users.append(user)
221 | self.set_changed()
222 | return GameRoom.ACTION_SUCCESS
223 |
224 | def leave_game(self, user):
225 | if user not in self.play_users:
226 | return
227 |
228 | if user in self.play_users:
229 | if self.get_status() == GameRoom.ROOM_STATUS_PLAYING:
230 | self.finish_game(state=-1)
231 | self.play_users.remove(user)
232 | if user.game_role in self.position2users:
233 | del self.position2users[user.game_role]
234 | self.set_changed()
235 |
236 | def leave_room(self, user):
237 | if user not in self.users:
238 | return
239 | self.leave_game(user)
240 | self.users.remove(user)
241 | self.set_changed()
242 |
243 | ROOM_STATUS_FINISH = 4
244 | ROOM_STATUS_NOONE = 1
245 | ROOM_STATUS_ONEWAITING = 2
246 | ROOM_STATUS_PLAYING = 3
247 | ROOM_STATUS_WRONG = -1
248 | ROOM_STATUS_NOTINROOM = 0
249 |
250 | def get_status(self):
251 | if self.board.is_over():
252 |             return GameRoom.ROOM_STATUS_FINISH
253 |         if len(self.play_users) == 0:
254 |             return GameRoom.ROOM_STATUS_NOONE
255 |         if len(self.play_users) == 1:
256 |             return GameRoom.ROOM_STATUS_ONEWAITING
257 |         if len(self.play_users) == 2:
258 |             return GameRoom.ROOM_STATUS_PLAYING
259 |         return GameRoom.ROOM_STATUS_WRONG
260 |
261 | def reset_game(self):
262 | while len(self.play_users) > 0:
263 | self.play_users[0].leave_game()
264 | self.board.reset()
265 |
266 | def finish_game(self, state=0):
267 | if state == -1:
268 | self.board.abort()
269 |
270 | if not os.path.exists(self.chess_folder):
271 | os.makedirs(self.chess_folder)
272 | import datetime
273 | tm = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
274 | chess_file = self.chess_folder + '/' + self.room_id + '_[' + '|'.join(
275 | [user.uid for user in self.play_users]) + ']_' + tm + '.txt'
276 | with open(chess_file, 'w') as f:
277 | f.write(self.board.dumps())
278 |
279 | # def get_room_info(self):
280 | # room_info = {'status': 0, 'roomid': -1}
281 | # room_info['status'] = self.get_status()
282 | # room_info['roomid'] = self.room_id
283 | # room_info['room'] = self
284 | # room_info['users_uid'] = []
285 | # room_info['play_users_uid'] = []
286 | # room_info['users'] = []
287 | # room_info['play_users'] = []
288 | # for user in self.users:
289 | # room_info['users_uid'].append(user.uid)
290 | # room_info['users'].append(user)
291 | # for user in user.game_room.play_users:
292 | # room_info['play_users_uid'].append(user.uid)
293 | # room_info['play_users'].append(user)
294 |
295 |
296 | class Hall(object):
297 | def __init__(self):
298 | self.uid2user = {}
299 | self.id2room = {}
300 | self.MaxUserNum = 10000
301 | self.user_num = 0
302 |
303 | def login(self, username, password):
304 | if self.user_num > self.MaxUserNum:
305 | return None
306 | pass
307 |
308 | def login_in_guest(self):
309 | if self.user_num > self.MaxUserNum:
310 | return None
311 | import random
312 | username = "guest_"
313 | while True:
314 | rand_postfix = random.randint(10000, 1000000)
315 | username = "guest_" + str(rand_postfix)
316 | if username not in self.uid2user:
317 | break
318 | user = User(username, self)
319 | self.uid2user[username] = user
320 | return username
321 |
322 | def get_user_with_uid(self, userid):
323 | if userid not in self.uid2user:
324 | self.uid2user[userid] = User(userid, self)
325 | return self.uid2user[userid]
326 |
327 | def join_room(self, username, roomid):
328 | user = self.get_user_with_uid(username)
329 | if roomid not in self.id2room:
330 | self.id2room[roomid] = GameRoom(roomid)
331 | if user.game_room != self.id2room[roomid]:
332 | user.join_room(self.id2room[roomid])
333 |
334 | def get_room_info_with_user(self, username):
335 | room_info = {'status': 0, 'roomid': -1}
336 | user = self.get_user_with_uid(username)
337 | if user.game_room:
338 |
339 | room_info['status'] = user.game_room.get_status()
340 | room_info['roomid'] = user.game_room.room_id
341 | room_info['room'] = user.game_room
342 | room_info['users_uid'] = []
343 | room_info['play_users_uid'] = []
344 | room_info['users'] = []
345 | room_info['play_users'] = []
346 | for user in user.game_room.users:
347 | room_info['users_uid'].append(user.uid)
348 | room_info['users'].append(user)
349 | for user in user.game_room.play_users:
350 | room_info['play_users_uid'].append(user.uid)
351 | room_info['play_users'].append(user)
352 | return room_info
353 |
354 | def get_room_with_user(self, username):
355 | user = self.get_user_with_uid(username)
356 | if user.game_room:
357 | return user.game_room
358 | return None
359 |
360 | def join_game(self, username, user_role):
361 | user = self.get_user_with_uid(username)
362 | return user.join_game(user_role)
363 |
364 | def game_action(self, username, actionid, arg_pack):
365 | user = self.get_user_with_uid(username)
366 | if user.game_room:
367 | return user.game_room.action(user, actionid, arg_pack)
368 | else:
369 | return ActionResult(-1, "Not in any room")
370 |
371 | def logout(self, username):
372 | user = self.get_user_with_uid(username)
373 | if user.game_room:
374 | user.game_room.leave_room(user)
375 | self.uid2user.pop(username)
376 |
--------------------------------------------------------------------------------
/evaluate/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Usage guide
4 | ```
5 | cd AlphaPig/evaluate
6 | python ChessServer.py --port 11111  # starts the game server on port 11111; open localhost:11111 in a browser
7 | ```
8 |
9 | # Playing against the AI
10 |
11 | 1.
12 | ```
13 | # from the AlphaPig/evaluate directory
14 | python ChessClient.py --model XXX.model --cur_role 1/2 --room_name ROOM_NAME --server_url 127.0.0.1:8888
15 | ```
16 | --cur_role: 1 plays Black, 2 plays White; --room_name is the room to join; --server_url is the address of the game server (your local machine by default).
17 |
18 | For an AI vs AI game, both sides pass the same arguments except --cur_role (see the example below). For AI vs human, whoever joins first gets first pick of the side.
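
For example (the room name is arbitrary and the model flag depends on which weights you have; the server is assumed to be running on port 11111 as above), an AI vs AI match could be started with two clients:

```
# terminal 1: first AI, playing Black
python ChessClient.py --model s --cur_role 1 --room_name ai_vs_ai_demo --server_url http://127.0.0.1:11111
# terminal 2: second AI, playing White
python ChessClient.py --model s --cur_role 2 --room_name ai_vs_ai_demo --server_url http://127.0.0.1:11111
```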
19 |
20 | # Human vs human
21 |
22 | Just open the browser and agree on a room; nothing more to it.
--------------------------------------------------------------------------------
/evaluate/ai_listen_ucloud.bat:
--------------------------------------------------------------------------------
1 | c:\Python27\python.exe AICollection.py --server_url http://120.132.59.147:11111
--------------------------------------------------------------------------------
/evaluate/logs/error.log:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anxingle/AlphaPig/16257003c6fa33d3583b8d43919f4ec944f8b5bb/evaluate/logs/error.log
--------------------------------------------------------------------------------
/evaluate/logs/info.log:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anxingle/AlphaPig/16257003c6fa33d3583b8d43919f4ec944f8b5bb/evaluate/logs/info.log
--------------------------------------------------------------------------------
/evaluate/page/chessboard.html:
--------------------------------------------------------------------------------
1 |
2 | {% from Hall import GameRoom %}
3 |
4 |
5 |
6 | Alpha猪
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
{{username}}
15 | ,
退出登录.
16 | 房间号:
17 |
18 |
19 |
20 |
21 | {% if room %}
22 |
23 |
34 |
35 | 你已经进入 {{room.room_id}} 房间,本房间共有{{len(room.users)}}位用户:
36 | {% for muser in room.users %}
37 | | {{muser.uid}}
38 | {% end %}
39 |
40 |
41 | 当前黑方为:
42 | {% if 1 in room.position2users %}
43 | {{room.position2users[1].uid}}
44 | {% else%}
45 | 空,
46 | {% if user.game_role<=0 %}
47 |
48 | {% end %}
49 | {% end %}
50 | ,当前白方为:
51 | {% if 2 in room.position2users %}
52 | {{room.position2users[2].uid}}
53 | {% else%}
54 | 空,
55 | {% if user.game_role<=0 %}
56 |
57 | {% end %}
58 | {% end %}
59 | {% if 1 not in room.position2users and 2 not in room.position2users and user.game_role<=0%}
60 |
61 | {% end %}
62 |
63 |
64 |
65 |
66 |
95 |
96 | 功能:
97 | {% if room.get_status()==GameRoom.ROOM_STATUS_PLAYING and room.ask_take_back==0 and (user.game_role==1 or user.game_role==2) %}
98 |
99 | {% end %}
100 |
101 | {% if room.get_status()==GameRoom.ROOM_STATUS_FINISH and (user.game_role==1 or user.game_role==2) %}
102 |
103 | {% end %}
104 |
105 |
106 |
107 |
108 | {% set winner_name="黑方" if chess_board.get_winner()==1 else "白方" %}
109 | {% set current_name="黑方" if chess_board.get_current_user()==1 else "白方" %}
110 | {% if room.get_status()==GameRoom.ROOM_STATUS_NOONE or \
111 | room.get_status()==GameRoom.ROOM_STATUS_ONEWAITING %}
112 | 等待选手加入游戏.
113 | {% elif room.get_status()==GameRoom.ROOM_STATUS_PLAYING %}
114 |
115 | {% if chess_board.get_current_user()==user.game_role %}
116 | 当前轮到你行动
117 | {% elif user.game_role<=0 %}
118 | 当前轮到{{current_name}}行动
119 | {% else %}
120 | 请等待对方下子
121 | {% end %}
122 |
123 | {% elif room.get_status()==GameRoom.ROOM_STATUS_FINISH %}
124 | 游戏结束,{{winner_name}}获胜。
125 |
130 |
131 | {% end %}
132 |
133 |
134 |
135 |
136 |
137 |
138 |
140 | {% for i in range(chess_board.SIZE) %}
141 |
142 |
143 | {% for j in range(chess_board.SIZE) %}
144 | {% if chess_board.get_piece(i, j)==0 %}
145 | |
147 | {% elif chess_board.get_piece(i, j)==1 %}
148 |
149 | {% if i== chess_board.get_lastmove()[-2] and j== chess_board.get_lastmove()[-1]%}
150 | Last
151 | {% end %}
152 | |
153 | {% elif chess_board.get_piece(i , j)==2 %}
154 |
155 | {% if i== chess_board.get_lastmove()[-2] and j== chess_board.get_lastmove()[-1]%}
156 | Last
157 | {% end %}
158 | |
159 | {% else %}
160 | not legal {{chess_board.get_piece(i , j)}}
161 | {% end%}
162 | {% end%}
163 |
164 | {% end%}
165 |
166 |
167 |
168 |
169 |
291 |
292 |
293 | {% end%}
294 |
295 |
337 |
338 |
339 |
--------------------------------------------------------------------------------
/evaluate/page/login.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Alpha猪
6 |
7 |
8 |
9 |
10 |
18 |
22 |
23 |
24 |
25 |
26 |
--------------------------------------------------------------------------------
/evaluate/run_ai.sh:
--------------------------------------------------------------------------------
1 | python ChessClient.py --cur_role 1 --model r10 --room_name 1 --server_url http://127.0.0.1:8888
2 |
--------------------------------------------------------------------------------
/evaluate/static/black.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anxingle/AlphaPig/16257003c6fa33d3583b8d43919f4ec944f8b5bb/evaluate/static/black.png
--------------------------------------------------------------------------------
/evaluate/static/blank.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anxingle/AlphaPig/16257003c6fa33d3583b8d43919f4ec944f8b5bb/evaluate/static/blank.png
--------------------------------------------------------------------------------
/evaluate/static/touming.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anxingle/AlphaPig/16257003c6fa33d3583b8d43919f4ec944f8b5bb/evaluate/static/touming.png
--------------------------------------------------------------------------------
/evaluate/static/white.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anxingle/AlphaPig/16257003c6fa33d3583b8d43919f4ec944f8b5bb/evaluate/static/white.png
--------------------------------------------------------------------------------
/evaluate/unittest/ChessBoardTest.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from ChessBoard import ChessBoard
3 | import ChessHelper
4 | from ChessHelper import printBoard
5 |
6 |
7 | class ChessBoardTest(unittest.TestCase):
8 | def setUp(self):
9 | self.a = 1
10 |
11 | def test_putpiece1(self):
12 | cb = ChessBoard()
13 |
14 | self.assertEqual(0, cb.put_piece(0, 0, 1))
15 | self.assertEqual(-2, cb.put_piece(0, 0, 1))
16 | self.assertEqual(-1, cb.put_piece(-1, 0, 1))
17 | self.assertEqual(-1, cb.put_piece(0, 16, 1))
18 |
19 | def test_putpiece2(self):
20 | cb = ChessBoard()
21 | muser = 1
22 | for i in xrange(4):
23 | for j in xrange(15):
24 | return_value = cb.put_piece(i, j, muser)
25 | self.assertEqual(0, return_value)
26 | muser = 2 if muser == 1 else 1
27 |
28 |     def test_putpiece2b(self):  # distinct name so it does not shadow test_putpiece2 above
29 | cb = ChessBoard()
30 | muser = 1
31 | for i in xrange(4):
32 | for j in xrange(15):
33 | return_value = cb.put_piece(i, j, muser)
34 | self.assertEqual(0, return_value)
35 | muser = 2 if muser == 1 else 1
36 |
37 | def test_putpiece3(self):
38 | cb = ChessBoard()
39 | ChessHelper.playRandomGame(cb)
40 |
41 |
42 | if __name__ == '__main__':
43 | unittest.main()
44 |
--------------------------------------------------------------------------------
/evaluate/yixin_ai/engine.exe:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anxingle/AlphaPig/16257003c6fa33d3583b8d43919f4ec944f8b5bb/evaluate/yixin_ai/engine.exe
--------------------------------------------------------------------------------
/evaluate/yixin_ai/msvcp140.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anxingle/AlphaPig/16257003c6fa33d3583b8d43919f4ec944f8b5bb/evaluate/yixin_ai/msvcp140.dll
--------------------------------------------------------------------------------
/evaluate/yixin_ai/vcruntime140.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anxingle/AlphaPig/16257003c6fa33d3583b8d43919f4ec944f8b5bb/evaluate/yixin_ai/vcruntime140.dll
--------------------------------------------------------------------------------
/evaluate/yixin_ai/yixin.exe:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anxingle/AlphaPig/16257003c6fa33d3583b8d43919f4ec944f8b5bb/evaluate/yixin_ai/yixin.exe
--------------------------------------------------------------------------------
/game.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author: Junxiao Song
4 | """
5 |
6 | from __future__ import print_function
7 | import numpy as np
8 | import os
9 | from policy_value_net_mxnet import PolicyValueNet # MXNet
10 | from mcts_alphaZero import MCTSPlayer
11 | from utils import sgf_dataIter, config_loader
12 |
13 |
14 | import logging
15 | import logging.config
16 | logging.config.dictConfig(config_loader.config_['train_logging'])
17 | _logger = logging.getLogger(__name__)
18 |
19 | current_relative_path = lambda x: os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), x))
20 |
21 | class Board(object):
22 | """board for the game"""
23 |
24 | def __init__(self, **kwargs):
25 | self.width = int(kwargs.get('width', 8))
26 | self.height = int(kwargs.get('height', 8))
27 | # board states stored as a dict,
28 | # key: move as location on the board,
29 | # value: player as pieces type
30 | self.states = {}
31 | # need how many pieces in a row to win
32 | self.n_in_row = int(kwargs.get('n_in_row', 5))
33 | self.players = [1, 2] # player1 and player2
34 |
35 | def init_board(self, start_player=0):
36 | if self.width < self.n_in_row or self.height < self.n_in_row:
37 | raise Exception('board width and height can not be '
38 | 'less than {}'.format(self.n_in_row))
39 | self.current_player = self.players[start_player] # start player
40 | # keep available moves in a list
41 | self.availables = list(range(self.width * self.height))
42 | self.states = {}
43 | self.history = []
44 | self.last_move = -1
45 |
46 | def move_to_location(self, move):
47 | """
48 | 3*3 board's moves like:
49 | 6 7 8
50 | 3 4 5
51 | 0 1 2
52 | and move 5's location is (1,2)
53 | """
54 | h = move // self.width
55 | w = move % self.width
56 | return [h, w]
57 |
58 | def location_to_move(self, location):
59 | if len(location) != 2:
60 | return -1
61 | h = location[0]
62 | w = location[1]
63 | move = h * self.width + w
64 | if move not in range(self.width * self.height):
65 | return -1
66 | return move
67 |
68 | def current_state(self):
69 | """return the board state from the perspective of the current player.
70 | state shape: (histlen*2+1)*width*height
71 | """
72 | histlen = 4
73 | statelen = histlen*2+1
74 | square_state = np.zeros((statelen, self.width, self.height))
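    |         # plane layout (histlen=4, 9 planes): planes come in (current player, opponent)
    |         # pairs for the last histlen history steps, oldest pair first and the most
    |         # recent pair at indices -3/-2; the final plane is all ones when an even
    |         # number of stones has been played, i.e. when the first player is to move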
75 | if self.states:
76 | histarr = np.array(self.history)
77 | moves_all = histarr[:, 0]
78 | players_all = histarr[:, 1]
79 | real_len = len(moves_all)
80 | for i in range(histlen):
81 | moves = moves_all[:real_len-i]
82 | players = players_all[:real_len-i]
83 | #print(moves, players)
84 | move_curr = moves[players == self.current_player]
85 | move_oppo = moves[players != self.current_player]
86 | square_state[-2*i-3][move_curr // self.width,
87 | move_curr % self.height] = 1.0
88 | square_state[-2*i-2][move_oppo // self.width,
89 | move_oppo % self.height] = 1.0
90 | if real_len-i == 0:
91 | break
92 | if len(self.states) % 2 == 0:
93 | square_state[-1][:, :] = 1.0 # indicate the colour to play
94 | return square_state[:, ::-1, :]
95 |
96 | def current_state_old(self):
97 | """return the board state from the perspective of the current player.
98 | state shape: 4*width*height
99 | """
100 |
101 | square_state = np.zeros((4, self.width, self.height))
102 | if self.states:
103 | moves, players = np.array(list(zip(*self.states.items())))
104 | move_curr = moves[players == self.current_player]
105 | move_oppo = moves[players != self.current_player]
106 | square_state[0][move_curr // self.width,
107 | move_curr % self.height] = 1.0
108 | square_state[1][move_oppo // self.width,
109 | move_oppo % self.height] = 1.0
110 | # indicate the last move location
111 | square_state[2][self.last_move // self.width,
112 | self.last_move % self.height] = 1.0
113 | if len(self.states) % 2 == 0:
114 | square_state[3][:, :] = 1.0 # indicate the colour to play
115 | return square_state[:, ::-1, :]
116 |
117 | def do_move(self, move):
118 | self.states[move] = self.current_player
119 | self.history.append((move, self.current_player))
120 | self.availables.remove(move)
121 | self.current_player = (
122 | self.players[0] if self.current_player == self.players[1]
123 | else self.players[1]
124 | )
125 | self.last_move = move
126 |
127 | def has_a_winner(self):
128 | width = self.width
129 | height = self.height
130 | states = self.states
131 | n = self.n_in_row
132 |
133 | moved = list(set(range(width * height)) - set(self.availables))
134 | if len(moved) < self.n_in_row + 2:
135 | return False, -1
136 |
137 | for m in moved:
138 | h = m // width
139 | w = m % width
140 | player = states[m]
141 |
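    |             # look for n consecutive stones starting at m in four directions:
    |             # horizontally (step 1), vertically (step width), and along the two
    |             # diagonals (steps width + 1 and width - 1)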
142 | if (w in range(width - n + 1) and
143 | len(set(states.get(i, -1) for i in range(m, m + n))) == 1):
144 | return True, player
145 |
146 | if (h in range(height - n + 1) and
147 | len(set(states.get(i, -1) for i in range(m, m + n * width, width))) == 1):
148 | return True, player
149 |
150 | if (w in range(width - n + 1) and h in range(height - n + 1) and
151 | len(set(states.get(i, -1) for i in range(m, m + n * (width + 1), width + 1))) == 1):
152 | return True, player
153 |
154 | if (w in range(n - 1, width) and h in range(height - n + 1) and
155 | len(set(states.get(i, -1) for i in range(m, m + n * (width - 1), width - 1))) == 1):
156 | return True, player
157 |
158 | return False, -1
159 |
160 | def game_end(self):
161 | """Check whether the game is ended or not"""
162 | win, winner = self.has_a_winner()
163 | if win:
164 | return True, winner
165 | elif not len(self.availables):
166 | return True, -1
167 | return False, -1
168 |
169 | def get_current_player(self):
170 | return self.current_player
171 |
172 |
173 | class Game(object):
174 | """game server"""
175 |
176 | def __init__(self, board, **kwargs):
177 | self.board = board
178 | self._boardSize = board.width * board.height
179 |
180 | def graphic(self, board, player1, player2):
181 | """Draw the board and show game info"""
182 | width = board.width
183 | height = board.height
184 |
185 | print("Player", player1, "with X".rjust(3))
186 | print("Player", player2, "with O".rjust(3))
187 | print()
188 | for x in range(width):
189 | print("{0:8}".format(x), end='')
190 | print('\r\n')
191 | for i in range(height - 1, -1, -1):
192 | print("{0:4d}".format(i), end='')
193 | for j in range(width):
194 | loc = i * width + j
195 | p = board.states.get(loc, -1)
196 | if p == player1:
197 | print('X'.center(8), end='')
198 | elif p == player2:
199 | print('O'.center(8), end='')
200 | else:
201 | print('_'.center(8), end='')
202 | print('\r\n\r\n')
203 |
204 | def start_play(self, player1, player2, start_player=0, is_shown=1):
205 | """start a game between two players"""
206 | if start_player not in (0, 1):
207 | raise Exception('start_player should be either 0 (player1 first) '
208 | 'or 1 (player2 first)')
209 | self.board.init_board(start_player)
210 | p1, p2 = self.board.players
211 | player1.set_player_ind(p1)
212 | player2.set_player_ind(p2)
213 | players = {p1: player1, p2: player2}
214 | if is_shown:
215 | self.graphic(self.board, player1.player, player2.player)
216 | while True:
217 | current_player = self.board.get_current_player()
218 | player_in_turn = players[current_player]
219 | move = player_in_turn.get_action(self.board)
220 | self.board.do_move(move)
221 | if is_shown:
222 | self.graphic(self.board, player1.player, player2.player)
223 | end, winner = self.board.game_end()
224 | if end:
225 | if is_shown:
226 | if winner != -1:
227 | print("Game end. Winner is", players[winner])
228 | else:
229 | print("Game end. Tie")
230 | return winner
231 |
232 |
233 | def start_self_play(self, player, is_shown=0, temp=1e-3, sgf_home=None, file_name=None):
234 | """ start a self-play game using a MCTS player, reuse the search tree,
235 | and store the self-play data: (state, mcts_probs, z) for training
236 | """
237 |         # load the recorded game data from the SGF file
238 |         X_train = sgf_dataIter.get_data_from_files(file_name, sgf_home)
239 |         data_length = len(X_train['seq_num_list']) # number of moves in this game record
240 | self.board.init_board()
241 | p1, p2 = self.board.players
242 | # print('p1: ', p1, ' p2: ', p2)
243 | states, mcts_probs, current_players = [], [], []
244 | # while True:
245 | for num_index, move in enumerate(X_train['seq_num_list']):
246 | # move_, move_probs = player.get_action(self.board,
247 | # temp=temp,
248 | # return_prob=1)
249 | probs = [0.000001 for _ in range(self._boardSize)]
250 | probs[move] = 0.99999
251 | move_probs = np.asarray(probs)
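    |             # near-one-hot training target: the move recorded in the SGF file gets
    |             # probability ~1.0 and every other square a tiny epsilon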
252 | # print('move: ')
253 | # print(move)
254 | # print('move probs: ')
255 | # print(move_probs)
256 | # print(type(move_probs))
257 | # print(move_probs.shape)
258 | # store the data
259 | # print('current_state: \n')
260 | # print(self.board.current_state())
261 | states.append(self.board.current_state())
262 | mcts_probs.append(move_probs)
263 | current_players.append(self.board.current_player)
264 | # perform a move
265 | try:
266 | self.board.do_move(move)
267 | except Exception as ee:
268 | _logger.error("\033[40;31m %s \033[0m: anxingle file_name: %s, move: %s" %('WARNING', file_name, move) )
269 | warning, winner, mcts_probs = 1, None, None
270 | return warning, winner, mcts_probs
271 | if is_shown:
272 | self.graphic(self.board, p1, p2)
273 | # go_on= input('go on:')
274 | # if go_on == 1:
275 | # break
276 |             # since we replay an existing game record, the end-of-game check has to be redefined here
277 | end, warning = 0, 1
278 | if num_index + 1 == data_length:
279 | end = 1
280 | winner = X_train['winner']
281 | try:
282 |                     # deliberately index one move past the current one; an IndexError here confirms we really are at the last move
283 | _logger.error('file_name %s has some problem! seq_num_list: %s' % (file_name, X_train['seq_num_list'][num_index+1]))
284 | except Exception as e:
285 |                     # if the lookup above failed, this is the last move, so no warning is needed
286 | warning = 0
287 | # print('you can ignore: ', e)
288 | # end, winner = self.board.game_end()
289 | if end:
290 | # winner from the perspective of the current player of each state
291 | winners_z = np.zeros(len(current_players))
292 | if winner != -1:
293 | winners_z[np.array(current_players) == winner] = 1.0
294 | winners_z[np.array(current_players) != winner] = -1.0
295 | # reset MCTS root node
296 | player.reset_player()
297 | if is_shown:
298 | if winner != -1:
299 | print("Game end. Winner is player:", winner)
300 | else:
301 | print("Game end. Tie")
302 | # go_on = input('go on:')
303 | # winner 1:2
304 | return warning, winner, zip(states, mcts_probs, winners_z)
305 |
306 |
307 | if __name__ == '__main__':
308 | model_file = 'current_policy.model'
309 | policy_value_net = PolicyValueNet(15, 15)
310 | mcts_player = MCTSPlayer(policy_value_net.policy_value_fn,
311 | c_puct=3,
312 | n_playout=2,
313 | is_selfplay=1)
314 | board = Board(width=15, height=15, n_in_row=5)
315 | game = Game(board)
316 | sgf_home = current_relative_path('./sgf_data')
317 | file_name = '1000_white_.sgf'
318 |     warning, winner, play_data = game.start_self_play(mcts_player, is_shown=1, temp=1.0, sgf_home=sgf_home, file_name=file_name)
319 |
320 |
321 |
--------------------------------------------------------------------------------
/game_ai.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author: Junxiao Song
4 | """
5 |
6 | from __future__ import print_function
7 | import numpy as np
8 | from game import Board
9 | import random
10 |
11 | class Game_AI(object):
12 | """game server"""
13 |
14 | def __init__(self, board, **kwargs):
15 | self.board = board
16 | self._boardSize = board.width * board.height
17 |
18 | def graphic(self, board, player1, player2):
19 | """Draw the board and show game info"""
20 | width = board.width
21 | height = board.height
22 |
23 | print("Player", player1, "with X".rjust(3))
24 | print("Player", player2, "with O".rjust(3))
25 | print()
26 | for x in range(width):
27 | print("{0:8}".format(x), end='')
28 | print('\r\n')
29 | for i in range(height - 1, -1, -1):
30 | print("{0:4d}".format(i), end='')
31 | for j in range(width):
32 | loc = i * width + j
33 | p = board.states.get(loc, -1)
34 | if p == player1:
35 | print('X'.center(8), end='')
36 | elif p == player2:
37 | print('O'.center(8), end='')
38 | else:
39 | print('_'.center(8), end='')
40 | print('\r\n\r\n')
41 |
42 | def start_play(self, player1, player2, start_player=0, is_shown=1):
43 | """start a game between two players"""
44 | if start_player not in (0, 1):
45 | raise Exception('start_player should be either 0 (player1 first) '
46 | 'or 1 (player2 first)')
47 | self.board.init_board(start_player)
48 | p1, p2 = self.board.players
49 | player1.set_player_ind(p1)
50 | player2.set_player_ind(p2)
51 | players = {p1: player1, p2: player2}
52 | if is_shown:
53 | self.graphic(self.board, player1.player, player2.player)
54 | while True:
55 | current_player = self.board.get_current_player()
56 | player_in_turn = players[current_player]
57 | move = player_in_turn.get_action(self.board)
58 | self.board.do_move(move)
59 | if is_shown:
60 | self.graphic(self.board, player1.player, player2.player)
61 | end, winner = self.board.game_end()
62 | if end:
63 | if is_shown:
64 | if winner != -1:
65 | print("Game end. Winner is", players[winner])
66 | else:
67 | print("Game end. Tie")
68 | return winner
69 |
70 | def start_self_play(self, player, is_shown=0, temp=1e-3):
71 | """ start a self-play game using a MCTS player, reuse the search tree,
72 | and store the self-play data: (state, mcts_probs, z) for training
73 | """
74 | self.board.init_board()
75 | p1, p2 = self.board.players
76 | states, mcts_probs, current_players = [], [], []
77 | blank_move_list = [0,1,2,3,4,5,6,7,8,15,16,17,18,19,20,21,22,23,30,31,32,33,34,35,36,37,38,45,46,47,48,49,50,51,52,53,60,61,62,63,64,65,66,67,68,75,76,77,78,79,80,81,82,83,90,91,92,93,94,95,96,97,98]
78 | white_move_list = range(0, 103)
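    |         # with probability 0.09, play a scripted random opening before handing control
    |         # to the MCTS player: one move for black drawn from blank_move_list (presumably
    |         # "black") and one for white drawn from white_move_list, each stored with a
    |         # near-one-hot probability target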
79 | if random.random() < 0.09:
80 | while True:
81 | move_blank = random.choice(blank_move_list)
82 | # move_blank = blank_move_list[random.randint(0, len()-1)]
83 | move_white = random.choice(white_move_list)
84 | if move_blank != move_white:
85 | break
86 | # store the data
87 |             # probability vector for the black opening move
88 | probs = [0.000001 for _ in range(self._boardSize)]
89 | probs[move_blank] = 0.99999
90 | move_blank_probs = np.asarray(probs)
91 |
92 | states.append(self.board.current_state())
93 | mcts_probs.append(move_blank_probs)
94 | current_players.append(self.board.current_player)
95 | # perform a move
96 | self.board.do_move(move_blank)
97 | if is_shown:
98 | self.graphic(self.board, p1, p2)
99 |
100 |             # probability vector for the white reply
101 | probs_ = [0.000001 for _ in range(self._boardSize)]
102 | probs_[move_white] = 0.99999
103 | move_white_probs = np.asarray(probs_)
104 |
105 | states.append(self.board.current_state())
106 | mcts_probs.append(move_white_probs)
107 | current_players.append(self.board.current_player)
108 | # perform a move
109 | self.board.do_move(move_white)
110 | if is_shown:
111 | self.graphic(self.board, p1, p2)
112 |
113 | while True:
114 | move, move_probs = player.get_action(self.board,
115 | temp=temp,
116 | return_prob=1)
117 | # store the data
118 | states.append(self.board.current_state())
119 | mcts_probs.append(move_probs)
120 | current_players.append(self.board.current_player)
121 | # perform a move
122 | self.board.do_move(move)
123 | if is_shown:
124 | self.graphic(self.board, p1, p2)
125 | end, winner = self.board.game_end()
126 | if end:
127 | # winner from the perspective of the current player of each state
128 | winners_z = np.zeros(len(current_players))
129 | if winner != -1:
130 | winners_z[np.array(current_players) == winner] = 1.0
131 | winners_z[np.array(current_players) != winner] = -1.0
132 | # reset MCTS root node
133 | player.reset_player()
134 | if is_shown:
135 | if winner != -1:
136 | print("Game end. Winner is player:", winner)
137 | else:
138 | print("Game end. Tie")
139 | return winner, zip(states, mcts_probs, winners_z)
140 |
--------------------------------------------------------------------------------
/human_play_mxnet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | human VS AI models
4 | Input your move in the format: 2,3
5 |
6 | @author: Junxiao Song
7 | """
8 |
9 | from __future__ import print_function
10 | import pickle
11 | from game import Board, Game
12 | from mcts_pure import MCTSPlayer as MCTS_Pure
13 | from mcts_alphaZero import MCTSPlayer
14 | from policy_value_net_mxnet import PolicyValueNet # MXNet
15 |
16 |
17 | class Human(object):
18 | """
19 | human player
20 | """
21 |
22 | def __init__(self):
23 | self.player = None
24 |
25 | def set_player_ind(self, p):
26 | self.player = p
27 |
28 | def get_action(self, board):
29 | try:
30 | location = input("Your move: ")
31 | if isinstance(location, str): # for python3
32 | location = [int(n, 10) for n in location.split(",")]
33 | move = board.location_to_move(location)
34 | except Exception as e:
35 | move = -1
36 | if move == -1 or move not in board.availables:
37 | print("invalid move")
38 | move = self.get_action(board)
39 | return move
40 |
41 | def __str__(self):
42 | return "Human {}".format(self.player)
43 |
44 |
45 | def run():
46 | n = 5
47 | width, height = 15, 15
48 | #model_file = 'best_policy.model'
49 | model_file = './logs/current_policy.model'
50 | try:
51 | board = Board(width=width, height=height, n_in_row=n)
52 | game = Game(board)
53 |
54 | # ############### human VS AI ###################
55 |         # load the trained policy_value_net (MXNet here; other backends would work the same way)
56 |
57 | # best_policy = PolicyValueNet(width, height, model_file = model_file)
58 | # mcts_player = MCTSPlayer(best_policy.policy_value_fn, c_puct=5, n_playout=400)
59 |
60 |         # load the pickled model parameters into the MXNet policy-value network used by the MCTS player
61 | try:
62 | policy_param = pickle.load(open(model_file, 'rb'))
63 | except:
64 | policy_param = pickle.load(open(model_file, 'rb'),
65 | encoding='bytes') # To support python3
66 | best_policy = PolicyValueNet(board_width=width, board_height=height, batch_size=512, model_params=policy_param)
67 | mcts_player = MCTSPlayer(best_policy.policy_value_fn,
68 | c_puct=5,
69 | n_playout=200) # set larger n_playout for better performance
70 |
71 | # uncomment the following line to play with pure MCTS (it's much weaker even with a larger n_playout)
72 | #mcts_player2 = MCTS_Pure(c_puct=5, n_playout=1000)
73 | # mcts_player2 = MCTSPlayer(best_policy.policy_value_fn,
74 | # c_puct=5,
75 | # n_playout=4000) # set larger n_playout for better performance
76 |
77 | # human player, input your move in the format: 2,3
78 | human = Human()
79 |
80 | # set start_player=0 for human first
81 | game.start_play(human, mcts_player, start_player=1, is_shown=1)
82 | except KeyboardInterrupt:
83 | print('\n\rquit')
84 |
85 |
86 | if __name__ == '__main__':
87 | run()
88 |
--------------------------------------------------------------------------------
/inception-resnet-v2.py:
--------------------------------------------------------------------------------
1 | # Licensed to the Apache Software Foundation (ASF) under one
2 | # or more contributor license agreements. See the NOTICE file
3 | # distributed with this work for additional information
4 | # regarding copyright ownership. The ASF licenses this file
5 | # to you under the Apache License, Version 2.0 (the
6 | # "License"); you may not use this file except in compliance
7 | # with the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing,
12 | # software distributed under the License is distributed on an
13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 | # KIND, either express or implied. See the License for the
15 | # specific language governing permissions and limitations
16 | # under the License.
17 |
18 | """
19 | Contains the definition of the Inception Resnet V2 architecture.
20 | As described in http://arxiv.org/abs/1602.07261.
21 | Inception-v4, Inception-ResNet and the Impact of Residual Connections
22 | on Learning
23 | Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, Alex Alemi
24 | """
25 | import mxnet as mx
26 |
27 |
28 | def ConvFactory(data, num_filter, kernel, stride=(1, 1), pad=(0, 0), act_type="relu", mirror_attr={}, with_act=True):
29 | conv = mx.symbol.Convolution(
30 | data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad)
31 | bn = mx.symbol.BatchNorm(data=conv)
32 | if with_act:
33 | act = mx.symbol.Activation(
34 | data=bn, act_type=act_type, attr=mirror_attr)
35 | return act
36 | else:
37 | return bn
38 |
39 |
40 | def block35(net, input_num_channels, scale=1.0, with_act=True, act_type='relu', mirror_attr={}):
41 | tower_conv = ConvFactory(net, 32, (1, 1))
42 | tower_conv1_0 = ConvFactory(net, 32, (1, 1))
43 | tower_conv1_1 = ConvFactory(tower_conv1_0, 32, (3, 3), pad=(1, 1))
44 | tower_conv2_0 = ConvFactory(net, 32, (1, 1))
45 | tower_conv2_1 = ConvFactory(tower_conv2_0, 48, (3, 3), pad=(1, 1))
46 | tower_conv2_2 = ConvFactory(tower_conv2_1, 64, (3, 3), pad=(1, 1))
47 | tower_mixed = mx.symbol.Concat(*[tower_conv, tower_conv1_1, tower_conv2_2])
48 | tower_out = ConvFactory(
49 | tower_mixed, input_num_channels, (1, 1), with_act=False)
50 |
51 | net += scale * tower_out
52 | if with_act:
53 | act = mx.symbol.Activation(
54 | data=net, act_type=act_type, attr=mirror_attr)
55 | return act
56 | else:
57 | return net
58 |
59 |
60 | def block17(net, input_num_channels, scale=1.0, with_act=True, act_type='relu', mirror_attr={}):
61 | tower_conv = ConvFactory(net, 192, (1, 1))
62 | tower_conv1_0 = ConvFactory(net, 129, (1, 1))
63 | tower_conv1_1 = ConvFactory(tower_conv1_0, 160, (1, 7), pad=(1, 2))
64 | tower_conv1_2 = ConvFactory(tower_conv1_1, 192, (7, 1), pad=(2, 1))
65 | tower_mixed = mx.symbol.Concat(*[tower_conv, tower_conv1_2])
66 | tower_out = ConvFactory(
67 | tower_mixed, input_num_channels, (1, 1), with_act=False)
68 | net += scale * tower_out
69 | if with_act:
70 | act = mx.symbol.Activation(
71 | data=net, act_type=act_type, attr=mirror_attr)
72 | return act
73 | else:
74 | return net
75 |
76 |
77 | def block8(net, input_num_channels, scale=1.0, with_act=True, act_type='relu', mirror_attr={}):
78 | tower_conv = ConvFactory(net, 192, (1, 1))
79 | tower_conv1_0 = ConvFactory(net, 192, (1, 1))
80 | tower_conv1_1 = ConvFactory(tower_conv1_0, 224, (1, 3), pad=(0, 1))
81 | tower_conv1_2 = ConvFactory(tower_conv1_1, 256, (3, 1), pad=(1, 0))
82 | tower_mixed = mx.symbol.Concat(*[tower_conv, tower_conv1_2])
83 | tower_out = ConvFactory(
84 | tower_mixed, input_num_channels, (1, 1), with_act=False)
85 | net += scale * tower_out
86 | if with_act:
87 | act = mx.symbol.Activation(
88 | data=net, act_type=act_type, attr=mirror_attr)
89 | return act
90 | else:
91 | return net
92 |
93 |
94 | def repeat(inputs, repetitions, layer, *args, **kwargs):
95 | outputs = inputs
96 | for i in range(repetitions):
97 | outputs = layer(outputs, *args, **kwargs)
98 | return outputs
99 |
100 |
101 | def get_symbol(num_classes=1000, **kwargs):
102 | data = mx.symbol.Variable(name='data')
103 | conv1a_3_3 = ConvFactory(data=data, num_filter=32,
104 | kernel=(3, 3), stride=(2, 2))
105 | conv2a_3_3 = ConvFactory(conv1a_3_3, 32, (3, 3))
106 | conv2b_3_3 = ConvFactory(conv2a_3_3, 64, (3, 3), pad=(1, 1))
107 | maxpool3a_3_3 = mx.symbol.Pooling(
108 | data=conv2b_3_3, kernel=(3, 3), stride=(2, 2), pool_type='max')
109 | conv3b_1_1 = ConvFactory(maxpool3a_3_3, 80, (1, 1))
110 | conv4a_3_3 = ConvFactory(conv3b_1_1, 192, (3, 3))
111 | maxpool5a_3_3 = mx.symbol.Pooling(
112 | data=conv4a_3_3, kernel=(3, 3), stride=(2, 2), pool_type='max')
113 |
114 | tower_conv = ConvFactory(maxpool5a_3_3, 96, (1, 1))
115 | tower_conv1_0 = ConvFactory(maxpool5a_3_3, 48, (1, 1))
116 | tower_conv1_1 = ConvFactory(tower_conv1_0, 64, (5, 5), pad=(2, 2))
117 |
118 | tower_conv2_0 = ConvFactory(maxpool5a_3_3, 64, (1, 1))
119 | tower_conv2_1 = ConvFactory(tower_conv2_0, 96, (3, 3), pad=(1, 1))
120 | tower_conv2_2 = ConvFactory(tower_conv2_1, 96, (3, 3), pad=(1, 1))
121 |
122 | tower_pool3_0 = mx.symbol.Pooling(data=maxpool5a_3_3, kernel=(
123 | 3, 3), stride=(1, 1), pad=(1, 1), pool_type='avg')
124 | tower_conv3_1 = ConvFactory(tower_pool3_0, 64, (1, 1))
125 | tower_5b_out = mx.symbol.Concat(
126 | *[tower_conv, tower_conv1_1, tower_conv2_2, tower_conv3_1])
127 | net = repeat(tower_5b_out, 10, block35, scale=0.17, input_num_channels=320)
128 | tower_conv = ConvFactory(net, 384, (3, 3), stride=(2, 2))
129 | tower_conv1_0 = ConvFactory(net, 256, (1, 1))
130 | tower_conv1_1 = ConvFactory(tower_conv1_0, 256, (3, 3), pad=(1, 1))
131 | tower_conv1_2 = ConvFactory(tower_conv1_1, 384, (3, 3), stride=(2, 2))
132 | tower_pool = mx.symbol.Pooling(net, kernel=(
133 | 3, 3), stride=(2, 2), pool_type='max')
134 | net = mx.symbol.Concat(*[tower_conv, tower_conv1_2, tower_pool])
135 | net = repeat(net, 20, block17, scale=0.1, input_num_channels=1088)
136 | tower_conv = ConvFactory(net, 256, (1, 1))
137 | tower_conv0_1 = ConvFactory(tower_conv, 384, (3, 3), stride=(2, 2))
138 | tower_conv1 = ConvFactory(net, 256, (1, 1))
139 | tower_conv1_1 = ConvFactory(tower_conv1, 288, (3, 3), stride=(2, 2))
140 | tower_conv2 = ConvFactory(net, 256, (1, 1))
141 | tower_conv2_1 = ConvFactory(tower_conv2, 288, (3, 3), pad=(1, 1))
142 | tower_conv2_2 = ConvFactory(tower_conv2_1, 320, (3, 3), stride=(2, 2))
143 | tower_pool = mx.symbol.Pooling(net, kernel=(
144 | 3, 3), stride=(2, 2), pool_type='max')
145 | net = mx.symbol.Concat(
146 | *[tower_conv0_1, tower_conv1_1, tower_conv2_2, tower_pool])
147 |
148 | net = repeat(net, 9, block8, scale=0.2, input_num_channels=2080)
149 | net = block8(net, with_act=False, input_num_channels=2080)
150 |
151 | net = ConvFactory(net, 1536, (1, 1))
152 | net = mx.symbol.Pooling(net, kernel=(
153 | 1, 1), global_pool=True, stride=(2, 2), pool_type='avg')
154 | net = mx.symbol.Flatten(net)
155 | net = mx.symbol.Dropout(data=net, p=0.2)
156 | net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes)
157 | softmax = mx.symbol.SoftmaxOutput(data=net, name='softmax')
158 | return softmax
159 |
--------------------------------------------------------------------------------
/logs/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anxingle/AlphaPig/16257003c6fa33d3583b8d43919f4ec944f8b5bb/logs/.gitkeep
--------------------------------------------------------------------------------
/logs/download_model.sh:
--------------------------------------------------------------------------------
1 | wget "http://p28sk9doh.bkt.clouddn.com/current_res_5.model" # for the 5-block resnet
2 | wget "http://p28sk9doh.bkt.clouddn.com/current_policy_simple_10999.model" # for the simple network (policy_value_net_mxnet_simple.py)
3 | wget "http://p28sk9doh.bkt.clouddn.com/current_policy_res_10.model" # for the 10-block resnet
--------------------------------------------------------------------------------
/mcts_alphaZero.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Monte Carlo Tree Search in AlphaGo Zero style, which uses a policy-value
4 | network to guide the tree search and evaluate the leaf nodes
5 |
6 | @author: Junxiao Song
7 | """
8 |
9 | import numpy as np
10 | import copy
11 |
12 |
13 | def softmax(x):
14 | probs = np.exp(x - np.max(x))
15 | probs /= np.sum(probs)
16 | return probs
17 |
18 |
19 | class TreeNode(object):
20 | """A node in the MCTS tree.
21 |
22 | Each node keeps track of its own value Q, prior probability P, and
23 | its visit-count-adjusted prior score u.
24 | """
25 |
26 | def __init__(self, parent, prior_p):
27 | self._parent = parent
28 | self._children = {} # a map from action to TreeNode
29 | self._n_visits = 0
30 | self._Q = 0
31 | self._u = 0
32 | self._P = prior_p
33 |
34 | def expand(self, action_priors):
35 | """Expand tree by creating new children.
36 | action_priors: a list of tuples of actions and their prior probability
37 | according to the policy function.
38 | """
39 | for action, prob in action_priors:
40 | if action not in self._children:
41 | self._children[action] = TreeNode(self, prob)
42 |
43 | def select(self, c_puct):
44 | """Select action among children that gives maximum action value Q
45 | plus bonus u(P).
46 | Return: A tuple of (action, next_node)
47 | """
48 | return max(self._children.items(),
49 | key=lambda act_node: act_node[1].get_value(c_puct))
50 |
51 | def update(self, leaf_value):
52 | """Update node values from leaf evaluation.
53 | leaf_value: the value of subtree evaluation from the current player's
54 | perspective.
55 | """
56 | # Count visit.
57 | self._n_visits += 1
58 | # Update Q, a running average of values for all visits.
59 | self._Q += 1.0*(leaf_value - self._Q) / self._n_visits
60 |
61 | def update_recursive(self, leaf_value):
62 | """Like a call to update(), but applied recursively for all ancestors.
63 | """
64 | # If it is not root, this node's parent should be updated first.
65 | if self._parent:
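    |             # the value is negated because parent and child are evaluated from
    |             # the perspectives of opposing players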
66 | self._parent.update_recursive(-leaf_value)
67 | self.update(leaf_value)
68 |
69 | def get_value(self, c_puct):
70 | """Calculate and return the value for this node.
71 | It is a combination of leaf evaluations Q, and this node's prior
72 | adjusted for its visit count, u.
73 | c_puct: a number in (0, inf) controlling the relative impact of
74 | value Q, and prior probability P, on this node's score.
75 |         For reference, the classic Upper Confidence Bound (UCB) selection rule is
76 |         $score = x_{child} + C\sqrt{\frac{\ln N_{parent}}{N_{child}}}$; the code below uses the AlphaGo Zero PUCT variant $Q + c_{puct} \cdot P \cdot \frac{\sqrt{N_{parent}}}{1 + N_{child}}$.
77 | """
78 | self._u = (c_puct * self._P *
79 | np.sqrt(self._parent._n_visits) / (1 + self._n_visits))
80 | return self._Q + self._u
81 |
82 | def is_leaf(self):
83 | """Check if leaf node (i.e. no nodes below this have been expanded)."""
84 | return self._children == {}
85 |
86 | def is_root(self):
87 | return self._parent is None
88 |
89 |
90 | class MCTS(object):
91 | """An implementation of Monte Carlo Tree Search."""
92 |
93 | def __init__(self, policy_value_fn, c_puct=5, n_playout=10000):
94 | """
95 | policy_value_fn: a function that takes in a board state and outputs
96 | a list of (action, probability) tuples and also a score in [-1, 1]
97 | (i.e. the expected value of the end game score from the current
98 | player's perspective) for the current player.
99 | c_puct: a number in (0, inf) that controls how quickly exploration
100 | converges to the maximum-value policy. A higher value means
101 | relying on the prior more.
102 | """
103 | self._root = TreeNode(None, 1.0)
104 | self._policy = policy_value_fn
105 | self._c_puct = c_puct
106 | self._n_playout = n_playout
107 |
108 | def _playout(self, state):
109 | """Run a single playout from the root to the leaf, getting a value at
110 | the leaf and propagating it back through its parents.
111 | State is modified in-place, so a copy must be provided.
112 | """
113 | node = self._root
114 | while(1):
115 | if node.is_leaf():
116 | break
117 | # Greedily select next move.
118 | action, node = node.select(self._c_puct)
119 | state.do_move(action)
120 |
121 | # Evaluate the leaf using a network which outputs a list of
122 | # (action, probability) tuples p and also a score v in [-1, 1]
123 | # for the current player.
124 | action_probs, leaf_value = self._policy(state)
125 | # Check for end of game.
126 | end, winner = state.game_end()
127 | if not end:
128 | node.expand(action_probs)
129 | else:
130 | # for end state,return the "true" leaf_value
131 | if winner == -1: # tie
132 | leaf_value = 0.0
133 | else:
134 | leaf_value = (
135 | 1.0 if winner == state.get_current_player() else -1.0
136 | )
137 |
138 | # Update value and visit count of nodes in this traversal.
139 | node.update_recursive(-leaf_value)
140 |
141 | def get_move_probs(self, state, temp=1e-3):
142 | """Run all playouts sequentially and return the available actions and
143 | their corresponding probabilities.
144 | state: the current game state
145 | temp: temperature parameter in (0, 1] controls the level of exploration
146 | """
147 | for n in range(self._n_playout):
148 | state_copy = copy.deepcopy(state)
149 | self._playout(state_copy)
150 |
151 | # calc the move probabilities based on visit counts at the root node
152 | act_visits = [(act, node._n_visits)
153 | for act, node in self._root._children.items()]
154 | acts, visits = zip(*act_visits)
155 | act_probs = softmax(1.0/temp * np.log(np.array(visits) + 1e-10))
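    |         # effectively visits**(1/temp) normalised: temp -> 0 approaches the argmax
    |         # over visit counts, while temp = 1 keeps the raw visit-count proportions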
156 |
157 | return acts, act_probs
158 |
159 | def update_with_move(self, last_move):
160 | """Step forward in the tree, keeping everything we already know
161 | about the subtree.
162 | """
163 | if last_move in self._root._children:
164 | self._root = self._root._children[last_move]
165 | self._root._parent = None
166 | else:
167 | self._root = TreeNode(None, 1.0)
168 |
169 | def __str__(self):
170 | return "MCTS"
171 |
172 |
173 | class MCTSPlayer(object):
174 | """AI player based on MCTS"""
175 |
176 | def __init__(self, policy_value_function,
177 | c_puct=5, n_playout=2000, is_selfplay=0):
178 | self.mcts = MCTS(policy_value_function, c_puct, n_playout)
179 | self._is_selfplay = is_selfplay
180 |
181 | def set_player_ind(self, p):
182 | self.player = p
183 |
184 | def reset_player(self):
185 | self.mcts.update_with_move(-1)
186 |
187 | def get_action(self, board, temp=1e-3, return_prob=0):
188 | sensible_moves = board.availables
189 |         # the pi vector returned by MCTS, as in the AlphaGo Zero paper
190 | move_probs = np.zeros(board.width*board.height)
191 | if len(sensible_moves) > 0:
192 | acts, probs = self.mcts.get_move_probs(board, temp)
193 | move_probs[list(acts)] = probs
194 | if self._is_selfplay:
195 | # add Dirichlet Noise for exploration (needed for
196 | # self-play training)
197 |                 # dirichlet_noise alpha is roughly K * 1.0/num_move_probs; try lowering 0.3 to 0.2 or even smaller
198 | move = np.random.choice(
199 | acts,
200 | p=0.75*probs + 0.25*np.random.dirichlet(0.3*np.ones(len(probs)))
201 | )
202 | # update the root node and reuse the search tree
203 | self.mcts.update_with_move(move)
204 | else:
205 | # with the default temp=1e-3, it is almost equivalent
206 | # to choosing the move with the highest prob
207 | move = np.random.choice(acts, p=probs)
208 | # reset the root node
209 | self.mcts.update_with_move(-1)
210 | # location = board.move_to_location(move)
211 | # print("AI move: %d,%d\n" % (location[0], location[1]))
212 |
213 | if return_prob:
214 | return move, move_probs
215 | else:
216 | return move
217 | else:
218 | print("WARNING: the board is full")
219 |
220 | def __str__(self):
221 | return "MCTS {}".format(self.player)
222 |
--------------------------------------------------------------------------------
/mcts_pure.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | A pure implementation of the Monte Carlo Tree Search (MCTS)
4 |
5 | @author: Junxiao Song
6 | """
7 |
8 | import numpy as np
9 | import copy
10 | from operator import itemgetter
11 |
12 |
13 | def rollout_policy_fn(board):
14 | """a coarse, fast version of policy_fn used in the rollout phase."""
15 | # rollout randomly
16 | action_probs = np.random.rand(len(board.availables))
17 | return zip(board.availables, action_probs)
18 |
19 |
20 | def policy_value_fn(board):
21 | """a function that takes in a state and outputs a list of (action, probability)
22 | tuples and a score for the state"""
23 | # return uniform probabilities and 0 score for pure MCTS
24 | action_probs = np.ones(len(board.availables))/len(board.availables)
25 | return zip(board.availables, action_probs), 0
26 |
27 |
28 | class TreeNode(object):
29 | """A node in the MCTS tree. Each node keeps track of its own value Q,
30 | prior probability P, and its visit-count-adjusted prior score u.
31 | """
32 |
33 | def __init__(self, parent, prior_p):
34 | self._parent = parent
35 | self._children = {} # a map from action to TreeNode
36 | self._n_visits = 0
37 | self._Q = 0
38 | self._u = 0
39 | self._P = prior_p
40 |
41 | def expand(self, action_priors):
42 | """Expand tree by creating new children.
43 | action_priors: a list of tuples of actions and their prior probability
44 | according to the policy function.
45 | """
46 | for action, prob in action_priors:
47 | if action not in self._children:
48 | self._children[action] = TreeNode(self, prob)
49 |
50 | def select(self, c_puct):
51 | """Select action among children that gives maximum action value Q
52 | plus bonus u(P).
53 | Return: A tuple of (action, next_node)
54 | """
55 | return max(self._children.items(),
56 | key=lambda act_node: act_node[1].get_value(c_puct))
57 |
58 | def update(self, leaf_value):
59 | """Update node values from leaf evaluation.
60 | leaf_value: the value of subtree evaluation from the current player's
61 | perspective.
62 | """
63 | # Count visit.
64 | self._n_visits += 1
65 | # Update Q, a running average of values for all visits.
66 | self._Q += 1.0*(leaf_value - self._Q) / self._n_visits
67 |
68 | def update_recursive(self, leaf_value):
69 | """Like a call to update(), but applied recursively for all ancestors.
70 | """
71 | # If it is not root, this node's parent should be updated first.
72 | if self._parent:
73 | self._parent.update_recursive(-leaf_value)
74 | self.update(leaf_value)
75 |
76 | def get_value(self, c_puct):
77 | """Calculate and return the value for this node.
78 | It is a combination of leaf evaluations Q, and this node's prior
79 | adjusted for its visit count, u.
80 | c_puct: a number in (0, inf) controlling the relative impact of
81 | value Q, and prior probability P, on this node's score.
82 | """
83 | self._u = (c_puct * self._P *
84 | np.sqrt(self._parent._n_visits) / (1 + self._n_visits))
85 | return self._Q + self._u
86 |
87 | def is_leaf(self):
88 | """Check if leaf node (i.e. no nodes below this have been expanded).
89 | """
90 | return self._children == {}
91 |
92 | def is_root(self):
93 | return self._parent is None
94 |
95 |
96 | class MCTS(object):
97 | """A simple implementation of Monte Carlo Tree Search."""
98 |
99 | def __init__(self, policy_value_fn, c_puct=5, n_playout=10000):
100 | """
101 | policy_value_fn: a function that takes in a board state and outputs
102 | a list of (action, probability) tuples and also a score in [-1, 1]
103 | (i.e. the expected value of the end game score from the current
104 | player's perspective) for the current player.
105 | c_puct: a number in (0, inf) that controls how quickly exploration
106 | converges to the maximum-value policy. A higher value means
107 | relying on the prior more.
108 | """
109 | self._root = TreeNode(None, 1.0)
110 | self._policy = policy_value_fn
111 | self._c_puct = c_puct
112 | self._n_playout = n_playout
113 |
114 | def _playout(self, state):
115 | """Run a single playout from the root to the leaf, getting a value at
116 | the leaf and propagating it back through its parents.
117 | State is modified in-place, so a copy must be provided.
118 | """
119 | node = self._root
120 | while(1):
121 |             if node.is_leaf():
122 |                 break
123 |
124 | # Greedily select next move.
125 | action, node = node.select(self._c_puct)
126 | state.do_move(action)
127 |
128 | action_probs, _ = self._policy(state)
129 | # Check for end of game
130 | end, winner = state.game_end()
131 | if not end:
132 | node.expand(action_probs)
133 | # Evaluate the leaf node by random rollout
134 | leaf_value = self._evaluate_rollout(state)
135 | # Update value and visit count of nodes in this traversal.
136 | node.update_recursive(-leaf_value)
137 |
138 | def _evaluate_rollout(self, state, limit=1000):
139 | """Use the rollout policy to play until the end of the game,
140 | returning +1 if the current player wins, -1 if the opponent wins,
141 | and 0 if it is a tie.
142 | """
143 | player = state.get_current_player()
144 | for i in range(limit):
145 | end, winner = state.game_end()
146 | if end:
147 | break
148 | action_probs = rollout_policy_fn(state)
149 | max_action = max(action_probs, key=itemgetter(1))[0]
150 | state.do_move(max_action)
151 | else:
152 | # If no break from the loop, issue a warning.
153 | print("WARNING: rollout reached move limit")
154 | if winner == -1: # tie
155 | return 0
156 | else:
157 | return 1 if winner == player else -1
158 |
159 | def get_move(self, state):
160 | """Runs all playouts sequentially and returns the most visited action.
161 | state: the current game state
162 |
163 | Return: the selected action
164 | """
165 | for n in range(self._n_playout):
166 | state_copy = copy.deepcopy(state)
167 | self._playout(state_copy)
168 | return max(self._root._children.items(),
169 | key=lambda act_node: act_node[1]._n_visits)[0]
170 |
171 | def update_with_move(self, last_move):
172 | """Step forward in the tree, keeping everything we already know
173 | about the subtree.
174 | """
175 | if last_move in self._root._children:
176 | self._root = self._root._children[last_move]
177 | self._root._parent = None
178 | else:
179 | self._root = TreeNode(None, 1.0)
180 |
181 | def __str__(self):
182 | return "MCTS"
183 |
184 |
185 | class MCTSPlayer(object):
186 | """AI player based on MCTS"""
187 | def __init__(self, c_puct=5, n_playout=2000):
188 | self.mcts = MCTS(policy_value_fn, c_puct, n_playout)
189 |
190 | def set_player_ind(self, p):
191 | self.player = p
192 |
193 | def reset_player(self):
194 | self.mcts.update_with_move(-1)
195 |
196 | def get_action(self, board):
197 | sensible_moves = board.availables
198 | if len(sensible_moves) > 0:
199 | move = self.mcts.get_move(board)
200 | self.mcts.update_with_move(-1)
201 | return move
202 | else:
203 | print("WARNING: the board is full")
204 |
205 | def __str__(self):
206 | return "MCTS {}".format(self.player)
207 |
--------------------------------------------------------------------------------
/papers/thinking-fast-and-slow-with-deep-learning-and-tree-search.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anxingle/AlphaPig/16257003c6fa33d3583b8d43919f4ec944f8b5bb/papers/thinking-fast-and-slow-with-deep-learning-and-tree-search.pdf
--------------------------------------------------------------------------------
/play_vs_yixin.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import sys, time
4 | import random
5 |
6 |
7 | # ----------------------------------------------------------------------
8 | # chessboard: board class; loads a position from a string or dumps it to one, checks for a win, etc.
9 | # ----------------------------------------------------------------------
10 | class chessboard(object):
11 |
12 | def __init__(self, forbidden=0):
13 |         # 15x15 board stored as a list of lists
14 | self.__board = [[0 for n in xrange(15)] for m in xrange(15)]
15 | self.__forbidden = forbidden
16 | self.__dirs = ((-1, 0), (-1, 1), (0, 1), (1, 1), (1, 0), \
17 | (1, -1), (0, -1), (-1, -1))
18 | self.DIRS = self.__dirs
19 | self.won = {}
20 |
21 |     # clear the board
22 | def reset(self):
23 | for j in xrange(15):
24 | for i in xrange(15):
25 | self.__board[i][j] = 0
26 | return 0
27 |
28 |     # indexer: board[row] returns that row
29 | def __getitem__(self, row):
30 | return self.__board[row]
31 |
32 |     # render the board as a string
33 | def __str__(self):
34 | text = ' A B C D E F G H I J K L M N O\n'
35 | mark = ('. ', 'O ', 'X ')
36 | nrow = 0
37 | for row in self.__board:
38 | line = ''.join([mark[n] for n in row])
39 | text += chr(ord('A') + nrow) + ' ' + line
40 | nrow += 1
41 | if nrow < 15: text += '\n'
42 | return text
43 |
44 |     # string representation
45 | def __repr__(self):
46 | return self.__str__()
47 |
48 | def get(self, row, col):
49 | if row < 0 or row >= 15 or col < 0 or col >= 15:
50 | return 0
51 | return self.__board[row][col]
52 |
53 | def put(self, row, col, x):
54 | if row >= 0 and row < 15 and col >= 0 and col < 15:
55 | self.__board[row][col] = x
56 | return 0
57 |
58 |     # check for a winner: returns 0 (no winner yet), 1 (white wins), 2 (black wins)
59 | def check(self):
60 | board = self.__board
61 | dirs = ((1, -1), (1, 0), (1, 1), (0, 1))
62 | for i in xrange(15):
63 | for j in xrange(15):
64 | if board[i][j] == 0: continue
65 |                 # id is the stone at this position (O or X): row i, column j
66 | id = board[i][j]
67 | for d in dirs:
68 | x, y = j, i
69 | count = 0
70 | for k in xrange(5):
71 | if self.get(y, x) != id: break
72 | y += d[0]
73 | x += d[1]
74 | count += 1
75 | if count == 5:
76 | self.won = {}
77 | r, c = i, j
78 | for z in xrange(5):
79 | self.won[(r, c)] = 1
80 | r += d[0]
81 | c += d[1]
82 | return id
83 | return 0
84 |
85 |     # return the underlying board array
86 | def board(self):
87 | return self.__board
88 |
89 |     # dump the position to a string such as '1:HH 2:HI '
90 | def dumps(self):
91 | import StringIO
92 | sio = StringIO.StringIO()
93 | board = self.__board
94 | for i in xrange(15):
95 | for j in xrange(15):
96 | stone = board[i][j]
97 | if stone != 0:
98 | ti = chr(ord('A') + i)
99 | tj = chr(ord('A') + j)
100 | sio.write('%d:%s%s ' % (stone, ti, tj))
101 | return sio.getvalue()
102 |
103 |     # load a position from a string produced by dumps()
104 | def loads(self, text):
105 | self.reset()
106 | board = self.__board
107 | for item in text.strip('\r\n\t ').replace(',', ' ').split(' '):
108 | n = item.strip('\r\n\t ')
109 | if not n: continue
110 | n = n.split(':')
111 | stone = int(n[0])
112 | i = ord(n[1][0].upper()) - ord('A')
113 | j = ord(n[1][1].upper()) - ord('A')
114 | board[i][j] = stone
115 | return 0
116 |
117 |     # set the terminal text colour
118 | def console(self, color):
119 | if sys.platform[:3] == 'win':
120 | try:
121 | import ctypes
122 | except:
123 | return 0
124 | kernel32 = ctypes.windll.LoadLibrary('kernel32.dll')
125 | GetStdHandle = kernel32.GetStdHandle
126 | SetConsoleTextAttribute = kernel32.SetConsoleTextAttribute
127 | GetStdHandle.argtypes = [ctypes.c_uint32]
128 | GetStdHandle.restype = ctypes.c_size_t
129 | SetConsoleTextAttribute.argtypes = [ctypes.c_size_t, ctypes.c_uint16]
130 | SetConsoleTextAttribute.restype = ctypes.c_long
131 | handle = GetStdHandle(0xfffffff5)
132 | if color < 0: color = 7
133 | result = 0
134 | if (color & 1): result |= 4
135 | if (color & 2): result |= 2
136 | if (color & 4): result |= 1
137 | if (color & 8): result |= 8
138 | if (color & 16): result |= 64
139 | if (color & 32): result |= 32
140 | if (color & 64): result |= 16
141 | if (color & 128): result |= 128
142 | SetConsoleTextAttribute(handle, result)
143 | else:
144 | if color >= 0:
145 | foreground = color & 7
146 | background = (color >> 4) & 7
147 | bold = color & 8
148 | sys.stdout.write(" \033[%s3%d;4%dm" % (bold and "01;" or "", foreground, background))
149 | sys.stdout.flush()
150 | else:
151 | sys.stdout.write(" \033[0m")
152 | sys.stdout.flush()
153 | return 0
154 |
155 |     # colourised board output
156 | def show(self):
157 | print ' A B C D E F G H I J K L M N O'
158 | mark = ('. ', 'O ', 'X ')
159 | nrow = 0
160 | self.check()
161 | color1 = 10
162 | color2 = 13
163 | for row in xrange(15):
164 | print chr(ord('A') + row),
165 | for col in xrange(15):
166 | ch = self.__board[row][col]
167 | if ch == 0:
168 | self.console(-1)
169 | print '.',
170 | elif ch == 1:
171 | if (row, col) in self.won:
172 | self.console(9)
173 | else:
174 | self.console(10)
175 | print 'O',
176 | # self.console(-1)
177 | elif ch == 2:
178 | if (row, col) in self.won:
179 | self.console(9)
180 | else:
181 | self.console(13)
182 | print 'X',
183 | # self.console(-1)
184 | self.console(-1)
185 | print ''
186 | return 0
187 |
188 |
189 | # ----------------------------------------------------------------------
190 | # evaluation: board evaluator used to score the current position
191 | # ----------------------------------------------------------------------
192 | class evaluation(object):
193 |
194 | def __init__(self):
195 | self.POS = []
196 | for i in xrange(15):
197 | row = [(7 - max(abs(i - 7), abs(j - 7))) for j in xrange(15)]
198 | self.POS.append(tuple(row))
199 | self.POS = tuple(self.POS)
200 |         self.STWO = 1 # blocked two
201 |         self.STHREE = 2 # blocked three
202 |         self.SFOUR = 3 # blocked four
203 |         self.TWO = 4 # open two
204 |         self.THREE = 5 # open three
205 |         self.FOUR = 6 # open four
206 |         self.FIVE = 7 # five in a row
207 |         self.DFOUR = 8 # double four
208 |         self.FOURT = 9 # four plus three
209 |         self.DTHREE = 10 # double three
210 |         self.NOTYPE = 11
211 |         self.ANALYSED = 255 # already analysed
212 |         self.TODO = 0 # not analysed yet
213 |         self.result = [0 for i in xrange(30)] # analysis results for the current line
214 |         self.line = [0 for i in xrange(30)] # data of the current line
215 |         self.record = [] # full-board analysis results: [row][col][direction]
216 |         for i in xrange(15):
217 |             self.record.append([])
218 |             self.record[i] = []
219 |             for j in xrange(15):
220 |                 self.record[i].append([0, 0, 0, 0])
221 |         self.count = [] # pattern counts: count[black/white][pattern]
222 | for i in xrange(3):
223 | data = [0 for i in xrange(20)]
224 | self.count.append(data)
225 | self.reset()
226 |
227 |     # reset the analysis data
228 | def reset(self):
229 | TODO = self.TODO
230 | count = self.count
231 | for i in xrange(15):
232 | line = self.record[i]
233 | for j in xrange(15):
234 | line[j][0] = TODO
235 | line[j][1] = TODO
236 | line[j][2] = TODO
237 | line[j][3] = TODO
238 | for i in xrange(20):
239 | count[0][i] = 0
240 | count[1][i] = 0
241 | count[2][i] = 0
242 | return 0
243 |
244 |     # analyse the board in four directions (horizontal, vertical, left diagonal, right diagonal), then score the result
245 | def evaluate(self, board, turn):
246 | score = self.__evaluate(board, turn)
247 | count = self.count
248 | if score < -9000:
249 | stone = turn == 1 and 2 or 1
250 | for i in xrange(20):
251 | if count[stone][i] > 0:
252 | score -= i
253 | elif score > 9000:
254 | stone = turn == 1 and 2 or 1
255 | for i in xrange(20):
256 | if count[turn][i] > 0:
257 | score += i
258 | return score
259 |
260 |     # analyse the board in four directions (horizontal, vertical, left diagonal, right diagonal), then score the result
261 | def __evaluate(self, board, turn):
262 | record, count = self.record, self.count
263 | TODO, ANALYSED = self.TODO, self.ANALYSED
264 | self.reset()
265 |         # analyse in four directions
266 |         for i in xrange(15):
267 |             boardrow = board[i]
268 |             recordrow = record[i]
269 |             for j in xrange(15):
270 |                 if boardrow[j] != 0:
271 |                     if recordrow[j][0] == TODO: # horizontal not analysed yet?
272 |                         self.__analysis_horizon(board, i, j)
273 |                     if recordrow[j][1] == TODO: # vertical not analysed yet?
274 |                         self.__analysis_vertical(board, i, j)
275 |                     if recordrow[j][2] == TODO: # left diagonal not analysed yet?
276 |                         self.__analysis_left(board, i, j)
277 |                     if recordrow[j][3] == TODO: # right diagonal not analysed yet?
278 |                         self.__analysis_right(board, i, j)
279 |
280 | FIVE, FOUR, THREE, TWO = self.FIVE, self.FOUR, self.THREE, self.TWO
281 | SFOUR, STHREE, STWO = self.SFOUR, self.STHREE, self.STWO
282 | check = {}
283 |
284 |         # count how many times FIVE, FOUR, THREE, TWO, etc. occur for black and for white
285 | for c in (FIVE, FOUR, SFOUR, THREE, STHREE, TWO, STWO):
286 | check[c] = 1
287 | for i in xrange(15):
288 | for j in xrange(15):
289 | stone = board[i][j]
290 | if stone != 0:
291 | for k in xrange(4):
292 | ch = record[i][j][k]
293 | if ch in check:
294 | count[stone][ch] += 1
295 |
296 |         # if either side already has five in a row, return the score immediately
297 |         BLACK, WHITE = 1, 2
298 |         if turn == WHITE: # white to move
299 |             if count[BLACK][FIVE]:
300 |                 return -9999
301 |             if count[WHITE][FIVE]:
302 |                 return 9999
303 |         else: # black to move
304 | if count[WHITE][FIVE]:
305 | return -9999
306 | if count[BLACK][FIVE]:
307 | return 9999
308 |
309 |         # two blocked fours are equivalent to one open four
310 | if count[WHITE][SFOUR] >= 2:
311 | count[WHITE][FOUR] += 1
312 | if count[BLACK][SFOUR] >= 2:
313 | count[BLACK][FOUR] += 1
314 |
315 |         # detailed scoring
316 | wvalue, bvalue, win = 0, 0, 0
317 | if turn == WHITE:
318 | if count[WHITE][FOUR] > 0: return 9990
319 | if count[WHITE][SFOUR] > 0: return 9980
320 | if count[BLACK][FOUR] > 0: return -9970
321 | if count[BLACK][SFOUR] and count[BLACK][THREE]:
322 | return -9960
323 | if count[WHITE][THREE] and count[BLACK][SFOUR] == 0:
324 | return 9950
325 | if count[BLACK][THREE] > 1 and \
326 | count[WHITE][SFOUR] == 0 and \
327 | count[WHITE][THREE] == 0 and \
328 | count[WHITE][STHREE] == 0:
329 | return -9940
330 | if count[WHITE][THREE] > 1:
331 | wvalue += 2000
332 | elif count[WHITE][THREE]:
333 | wvalue += 200
334 | if count[BLACK][THREE] > 1:
335 | bvalue += 500
336 | elif count[BLACK][THREE]:
337 | bvalue += 100
338 | if count[WHITE][STHREE]:
339 | wvalue += count[WHITE][STHREE] * 10
340 | if count[BLACK][STHREE]:
341 | bvalue += count[BLACK][STHREE] * 10
342 | if count[WHITE][TWO]:
343 | wvalue += count[WHITE][TWO] * 4
344 | if count[BLACK][TWO]:
345 | bvalue += count[BLACK][TWO] * 4
346 | if count[WHITE][STWO]:
347 | wvalue += count[WHITE][STWO]
348 | if count[BLACK][STWO]:
349 | bvalue += count[BLACK][STWO]
350 | else:
351 | if count[BLACK][FOUR] > 0: return 9990
352 | if count[BLACK][SFOUR] > 0: return 9980
353 | if count[WHITE][FOUR] > 0: return -9970
354 | if count[WHITE][SFOUR] and count[WHITE][THREE]:
355 | return -9960
356 | if count[BLACK][THREE] and count[WHITE][SFOUR] == 0:
357 | return 9950
358 | if count[WHITE][THREE] > 1 and \
359 | count[BLACK][SFOUR] == 0 and \
360 | count[BLACK][THREE] == 0 and \
361 | count[BLACK][STHREE] == 0:
362 | return -9940
363 | if count[BLACK][THREE] > 1:
364 | bvalue += 2000
365 | elif count[BLACK][THREE]:
366 | bvalue += 200
367 | if count[WHITE][THREE] > 1:
368 | wvalue += 500
369 | elif count[WHITE][THREE]:
370 | wvalue += 100
371 | if count[BLACK][STHREE]:
372 | bvalue += count[BLACK][STHREE] * 10
373 | if count[WHITE][STHREE]:
374 | wvalue += count[WHITE][STHREE] * 10
375 | if count[BLACK][TWO]:
376 | bvalue += count[BLACK][TWO] * 4
377 | if count[WHITE][TWO]:
378 | wvalue += count[WHITE][TWO] * 4
379 | if count[BLACK][STWO]:
380 | bvalue += count[BLACK][STWO]
381 | if count[WHITE][STWO]:
382 | wvalue += count[WHITE][STWO]
383 |
384 | # 加上位置权值,棋盘最中心点权值是7,往外一格-1,最外圈是0
385 | wc, bc = 0, 0
386 | for i in xrange(15):
387 | for j in xrange(15):
388 | stone = board[i][j]
389 | if stone != 0:
390 | if stone == WHITE:
391 | wc += self.POS[i][j]
392 | else:
393 | bc += self.POS[i][j]
394 | wvalue += wc
395 | bvalue += bc
396 |
397 | if turn == WHITE:
398 | return wvalue - bvalue
399 |
400 | return bvalue - wvalue
401 |
402 | # 分析横向
403 | def __analysis_horizon(self, board, i, j):
404 | line, result, record = self.line, self.result, self.record
405 | TODO = self.TODO
406 | for x in xrange(15):
407 | line[x] = board[i][x]
408 | self.analysis_line(line, result, 15, j)
409 | for x in xrange(15):
410 | if result[x] != TODO:
411 | record[i][x][0] = result[x]
412 | return record[i][j][0]
413 |
414 | # Vertical analysis
415 | def __analysis_vertical(self, board, i, j):
416 | line, result, record = self.line, self.result, self.record
417 | TODO = self.TODO
418 | for x in xrange(15):
419 | line[x] = board[x][j]
420 | self.analysis_line(line, result, 15, i)
421 | for x in xrange(15):
422 | if result[x] != TODO:
423 | record[x][j][1] = result[x]
424 | return record[i][j][1]
425 |
426 | # 分析左斜
427 | def __analysis_left(self, board, i, j):
428 | line, result, record = self.line, self.result, self.record
429 | TODO = self.TODO
430 | if i < j:
431 | x, y = j - i, 0
432 | else:
433 | x, y = 0, i - j
434 | k = 0
435 | while k < 15:
436 | if x + k > 14 or y + k > 14:
437 | break
438 | line[k] = board[y + k][x + k]
439 | k += 1
440 | self.analysis_line(line, result, k, j - x)
441 | for s in xrange(k):
442 | if result[s] != TODO:
443 | record[y + s][x + s][2] = result[s]
444 | return record[i][j][2]
445 |
446 | # 分析右斜
447 | def __analysis_right(self, board, i, j):
448 | line, result, record = self.line, self.result, self.record
449 | TODO = self.TODO
450 | if 14 - i < j:
451 | x, y, realnum = j - 14 + i, 14, 14 - i
452 | else:
453 | x, y, realnum = 0, i + j, j
454 | k = 0
455 | while k < 15:
456 | if x + k > 14 or y - k < 0:
457 | break
458 | line[k] = board[y - k][x + k]
459 | k += 1
460 | self.analysis_line(line, result, k, j - x)
461 | for s in xrange(k):
462 | if result[s] != TODO:
463 | record[y - s][x + s][3] = result[s]
464 | return record[i][j][3]
465 |
466 | def test(self, board):
467 | self.reset()
468 | record = self.record
469 | TODO = self.TODO
470 | for i in xrange(15):
471 | for j in xrange(15):
472 | if board[i][j] != 0:
473 | if self.record[i][j][0] == TODO:
474 | self.__analysis_horizon(board, i, j)
475 | pass
476 | if self.record[i][j][1] == TODO:
477 | self.__analysis_vertical(board, i, j)
478 | pass
479 | if self.record[i][j][2] == TODO:
480 | self.__analysis_left(board, i, j)
481 | pass
482 | if self.record[i][j][3] == TODO:
483 | self.__analysis_right(board, i, j)
484 | pass
485 | return 0
486 |
487 | # 分析一条线:五四三二等棋型
488 | def analysis_line(self, line, record, num, pos):
489 | TODO, ANALYSED = self.TODO, self.ANALYSED
490 | THREE, STHREE = self.THREE, self.STHREE
491 | FOUR, SFOUR = self.FOUR, self.SFOUR
492 | while len(line) < 30: line.append(0xf)
493 | while len(record) < 30: record.append(TODO)
494 | for i in xrange(num, 30):
495 | line[i] = 0xf
496 | for i in xrange(num):
497 | record[i] = TODO
498 | if num < 5:
499 | for i in xrange(num):
500 | record[i] = ANALYSED
501 | return 0
502 | stone = line[pos]
503 | inverse = (0, 2, 1)[stone]
504 | num -= 1
505 | xl = pos
506 | xr = pos
507 | while xl > 0: # 探索左边界
508 | if line[xl - 1] != stone: break
509 | xl -= 1
510 | while xr < num: # 探索右边界
511 | if line[xr + 1] != stone: break
512 | xr += 1
513 | left_range = xl
514 | right_range = xr
515 | while left_range > 0: # 探索左边范围(非对方棋子的格子坐标)
516 | if line[left_range - 1] == inverse: break
517 | left_range -= 1
518 | while right_range < num: # 探索右边范围(非对方棋子的格子坐标)
519 | if line[right_range + 1] == inverse: break
520 | right_range += 1
521 |
522 | # 如果该直线范围小于 5,则直接返回
523 | if right_range - left_range < 4:
524 | for k in xrange(left_range, right_range + 1):
525 | record[k] = ANALYSED
526 | return 0
527 |
528 | # 设置已经分析过
529 | for k in xrange(xl, xr + 1):
530 | record[k] = ANALYSED
531 |
532 | srange = xr - xl
533 |
534 | # 如果是 5连
535 | if srange >= 4:
536 | record[pos] = self.FIVE
537 | return self.FIVE
538 |
539 | # 如果是 4连
540 | if srange == 3:
541 | leftfour = False # 是否左边是空格
542 | if xl > 0:
543 | if line[xl - 1] == 0: # 活四
544 | leftfour = True
545 | if xr < num:
546 | if line[xr + 1] == 0:
547 | if leftfour:
548 | record[pos] = self.FOUR # 活四
549 | else:
550 | record[pos] = self.SFOUR # 冲四
551 | else:
552 | if leftfour:
553 | record[pos] = self.SFOUR # 冲四
554 | else:
555 | if leftfour:
556 | record[pos] = self.SFOUR # 冲四
557 | return record[pos]
558 |
559 | # 如果是 3连
560 | if srange == 2: # 三连
561 | left3 = False # 是否左边是空格
562 | if xl > 0:
563 | if line[xl - 1] == 0: # 左边有气
564 | if xl > 1 and line[xl - 2] == stone:
565 | record[xl] = SFOUR
566 | record[xl - 2] = ANALYSED
567 | else:
568 | left3 = True
569 | elif xr == num or line[xr + 1] != 0:
570 | return 0
571 | if xr < num:
572 | if line[xr + 1] == 0: # 右边有气
573 | if xr < num - 1 and line[xr + 2] == stone:
574 | record[xr] = SFOUR # XXX-X 相当于冲四
575 | record[xr + 2] = ANALYSED
576 | elif left3:
577 | record[xr] = THREE
578 | else:
579 | record[xr] = STHREE
580 | elif record[xl] == SFOUR:
581 | return record[xl]
582 | elif left3:
583 | record[pos] = STHREE
584 | else:
585 | if record[xl] == SFOUR:
586 | return record[xl]
587 | if left3:
588 | record[pos] = STHREE
589 | return record[pos]
590 |
591 | # 如果是 2连
592 | if srange == 1: # 两连
593 | left2 = False
594 | if xl > 2:
595 | if line[xl - 1] == 0: # 左边有气
596 | if line[xl - 2] == stone:
597 | if line[xl - 3] == stone:
598 | record[xl - 3] = ANALYSED
599 | record[xl - 2] = ANALYSED
600 | record[xl] = SFOUR
601 | elif line[xl - 3] == 0:
602 | record[xl - 2] = ANALYSED
603 | record[xl] = STHREE
604 | else:
605 | left2 = True
606 | if xr < num:
607 | if line[xr + 1] == 0: # right side is open
608 | if xr < num - 2 and line[xr + 2] == stone:
609 | if line[xr + 3] == stone:
610 | record[xr + 3] = ANALYSED
611 | record[xr + 2] = ANALYSED
612 | record[xr] = SFOUR
613 | elif line[xr + 3] == 0:
614 | record[xr + 2] = ANALYSED
615 | record[xr] = left2 and THREE or STHREE
616 | else:
617 | if record[xl] == SFOUR:
618 | return record[xl]
619 | if record[xl] == STHREE:
620 | record[xl] = THREE
621 | return record[xl]
622 | if left2:
623 | record[pos] = self.TWO
624 | else:
625 | record[pos] = self.STWO
626 | else:
627 | if record[xl] == SFOUR:
628 | return record[xl]
629 | if left2:
630 | record[pos] = self.STWO
631 | return record[pos]
632 | return 0
633 |
634 | def textrec(self, direction=0):
635 | text = []
636 | for i in xrange(15):
637 | line = ''
638 | for j in xrange(15):
639 | line += '%x ' % (self.record[i][j][direction] & 0xf)
640 | text.append(line)
641 | return '\n'.join(text)
642 |
643 |
644 | # ----------------------------------------------------------------------
645 | # DFS: game-tree search (negamax with alpha-beta pruning)
646 | # ----------------------------------------------------------------------
647 | class searcher(object):
648 |
649 | # 初始化
650 | def __init__(self):
651 | self.evaluator = evaluation()
652 | self.board = [[0 for n in xrange(15)] for i in xrange(15)]
653 | self.gameover = 0
654 | self.overvalue = 0
655 | self.maxdepth = 3
656 |
657 | # 产生当前棋局的走法
658 | def genmove(self, turn):
659 | moves = []
660 | board = self.board
661 | POSES = self.evaluator.POS
662 | for i in xrange(15):
663 | for j in xrange(15):
664 | if board[i][j] == 0:
665 | score = POSES[i][j]
666 | moves.append((score, i, j))
667 | moves.sort()
668 | moves.reverse()
669 | return moves
670 |
671 | # 递归搜索:返回最佳分数
672 | def __search(self, turn, depth, alpha=-0x7fffffff, beta=0x7fffffff):
673 |
674 | # 深度为零则评估棋盘并返回
675 | if depth <= 0:
676 | score = self.evaluator.evaluate(self.board, turn)
677 | return score
678 |
679 | # 如果游戏结束则立马返回
680 | score = self.evaluator.evaluate(self.board, turn)
681 | if abs(score) >= 9999 and depth < self.maxdepth:
682 | return score
683 |
684 | # 产生新的走法
685 | moves = self.genmove(turn)
686 | bestmove = None
687 |
688 | # 枚举当前所有走法
689 | for score, row, col in moves:
690 |
691 | # 标记当前走法到棋盘
692 | self.board[row][col] = turn
693 |
694 | # 计算下一回合该谁走
695 | nturn = turn == 1 and 2 or 1
696 |
697 | # 深度优先搜索,返回评分,走的行和走的列
698 | score = - self.__search(nturn, depth - 1, -beta, -alpha)
699 |
700 | # 棋盘上清除当前走法
701 | self.board[row][col] = 0
702 |
703 | # 计算最好分值的走法
704 | # alpha/beta 剪枝
705 | if score > alpha:
706 | alpha = score
707 | bestmove = (row, col)
708 | if alpha >= beta:
709 | break
710 |
711 | # 如果是第一层则记录最好的走法
712 | if depth == self.maxdepth and bestmove:
713 | self.bestmove = bestmove
714 |
715 | # 返回当前最好的分数,和该分数的对应走法
716 | return alpha
717 |
718 | # 具体搜索:传入当前是该谁走(turn=1/2),以及搜索深度(depth)
719 | def search(self, turn, depth=3):
720 | self.maxdepth = depth # remember the requested search depth
721 | self.bestmove = None
722 | score = self.__search(turn, depth)
723 | if abs(score) > 8000:
724 | self.maxdepth = depth # near-decisive score: re-check with a shallow 1-ply search
725 | score = self.__search(turn, 1)
726 | row, col = self.bestmove
727 | return score, row, col
728 |
729 |
730 | # ----------------------------------------------------------------------
731 | # psyco speedup
732 | # ----------------------------------------------------------------------
733 | def psyco_speedup():
734 | try:
735 | import psyco
736 | psyco.bind(chessboard)
737 | psyco.bind(evaluation)
738 | except:
739 | pass
740 | return 0
741 |
742 |
743 | psyco_speedup()
744 |
745 |
746 | # ----------------------------------------------------------------------
747 | # main game
748 | # ----------------------------------------------------------------------
749 | def gamemain():
750 | b = chessboard()
751 | # 黑手AI
752 | s_blank = searcher()
753 | s_blank.board = b.board()
754 | s = searcher()
755 | s.board = b.board()
756 |
757 | opening = [
758 | '1:HH 2:GI',
759 | '2:IG 2:GI 1:HH',
760 | '1:IH 2:GI',
761 | '1:HG 2:HI',
762 | '2:HG 2:HI 1:HH',
763 | '1:HH 2:IH 2:GI',
764 | '1:HH 2:IH 2:HI',
765 | '1:HH 2:IH 2:HJ',
766 | '1:HG 2:HH 2:HI',
767 | '1:GH 2:HH 2:HI',
768 | ]
769 |
770 | openid = random.randint(0, len(opening) - 1)
771 | b.loads(opening[openid])
772 | turn = 2
773 | history = []
774 | undo = False
775 |
776 | # 设置难度
777 | DEPTH = 1
778 | DEPTH_BLACK = 5
779 |
780 | while 1:
781 | print ''
782 | while 1:
783 | print 'ROUND %d' % (len(history) + 1)
784 | b.show()
785 | print 'Robot move (u:undo, q:quit): \n',
786 | # 黑手AI自动下
787 | # text = raw_input().strip('\r\n\t ')
788 | score_b, tr, tc = s.search(1, DEPTH_BLACK)
789 | cord_b = '%s%s' % (chr(ord('A') + tr), chr(ord('A') + tc))
790 | print 'Robot move to %s (%d) \n' % (cord_b, score_b)
791 | # if len(text) == 2:
792 | # tr = ord(text[0].upper()) - ord('A')
793 | # tc = ord(text[1].upper()) - ord('A')
794 | if tr >= 0 and tc >= 0 and tr < 15 and tc < 15:
795 | if b[tr][tc] == 0:
796 | row, col = tr, tc
797 | break
798 | else:
799 | print 'can not move there'
800 | else:
801 | print 'bad position'
802 |
803 | if undo == True:
804 | undo = False
805 | if len(history) == 0:
806 | print 'no history to undo'
807 | else:
808 | print 'rollback from history ...'
809 | move = history.pop()
810 | b.loads(move)
811 | else:
812 | history.append(b.dumps())
813 | b[row][col] = 1
814 | b.show()
815 | time.sleep(0.2)
816 |
817 | if b.check() == 1:
818 | b.show()
819 | print b.dumps()
820 | print ''
821 | print 'BLACK (robot) WINS !!'
822 | return 0
823 |
824 | print 'Your move (two letters, e.g. HH):\n'
825 | text = raw_input().strip('\r\n\t ')
826 | if len(text) == 2:
827 | tr = ord(text[0].upper()) - ord('A')
828 | tc = ord(text[1].upper()) - ord('A')
829 | if tr >= 0 and tc >= 0 and tr < 15 and tc < 15:
830 | if b[tr][tc] == 0:
831 | row, col = tr, tc
832 | # break
833 | else:
834 | print 'can not move there'
835 | else:
836 | print 'bad position'
837 | # xtt = input('go on: ')
838 | # score, row, col = s.search(2, DEPTH)
839 | cord = '%s%s' % (chr(ord('A') + row), chr(ord('A') + col))
840 | print 'robot move to %s ' % cord
841 | # xtt = input('go on: ')
842 | b[row][col] = 2
843 | time.sleep(0.2)
844 |
845 | if b.check() == 2:
846 | b.show()
847 | print b.dumps()
848 | print ''
849 | print 'WHITE WINS.'
850 | return 0
851 |
852 | return 0
853 |
854 |
855 | # ----------------------------------------------------------------------
856 | # testing case
857 | # ----------------------------------------------------------------------
858 | if __name__ == '__main__':
859 | start_time = time.time()
860 | gamemain()
861 | print('Elapsed: ', time.time() - start_time)
862 |
863 |
864 |
--------------------------------------------------------------------------------
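
A quick, hedged usage sketch for the brute-force engine above: the `searcher` class wraps the heuristic `evaluation` scorer in a negamax/alpha-beta search, so it can be driven directly without the interactive `gamemain()` loop. The module name below is a placeholder (an assumption, not part of the repo); import from wherever this file lives in your checkout.

```python
# Minimal sketch (not part of the repo): ask the searcher above for White's reply
# after Black opens at the centre of the 15x15 board.
# `gobang_engine` is a placeholder module name for this file.
from gobang_engine import searcher

s = searcher()
s.board[7][7] = 1                        # Black stone on the centre point
score, row, col = s.search(2, depth=1)   # turn=2 -> White to move; depth=1 keeps it fast
print('white reply: row=%d col=%d score=%d' % (row, col, score))
```
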
/policy_value_net_mxnet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | An implementation of the policyValueNet with MXNet
4 | Written against the MXNet symbolic/Module API (requirements.txt pins mxnet==1.6.0)
5 |
6 | @author: Mingxu Zhang
7 | """
8 | from __future__ import print_function
9 | import sys
10 | import os
11 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
12 | # sys.path.insert(0, '/home/mingzhang/work/dmlc/python_mxnet/python')
13 |
14 | import mxnet as mx
15 | import numpy as np
16 | import pickle
17 |
18 |
19 | class PolicyValueNet():
20 | """policy-value network """
21 | def __init__(self, board_width, board_height, batch_size=512, n_blocks=8, n_filter=128, model_params=None):
22 | self.context = mx.gpu(0)
23 | self.batchsize = batch_size # must be the same as TrainPipeline's self.batch_size
24 | self.channelnum = 9
25 | self.board_width = board_width
26 | self.board_height = board_height
27 | self._n_blocks = n_blocks
28 | self._n_filter = n_filter
29 | self.l2_const = 1e-4 # coef of l2 penalty
30 | self.train_batch = self.create_policy_value_train(self.batchsize)
31 | self.predict_batch = self.create_policy_value_predict(self.batchsize)
32 | self.predict_one = self.create_policy_value_predict(1)
33 | self.num = 0
34 |
35 | if model_params:
36 | self.train_batch.set_params(*model_params)
37 | self.predict_batch.set_params(*model_params)
38 | self.predict_one.set_params(*model_params)
39 | pass
40 |
41 | def conv_act(self, data, num_filter=32, kernel=(3, 3), stride=(1, 1), act='relu', dobn=True, name=''):
42 | # self convolution activation
43 | assert(name!='' and name!=None)
44 | pad = (int(kernel[0]/2), int(kernel[1]/2))
45 | w = mx.sym.Variable(name+'_weight')
46 | b = mx.sym.Variable(name+'_bias')
47 | conv1 = mx.sym.Convolution(data=data, weight=w, bias=b, num_filter=num_filter, kernel=kernel, pad=pad, name=name)
48 | act1 = conv1
49 | if dobn:
50 | gamma = mx.sym.Variable(name+'_gamma')
51 | beta = mx.sym.Variable(name+'_beta')
52 | mean = mx.sym.Variable(name+'_mean')
53 | var = mx.sym.Variable(name+'_var')
54 | bn = mx.sym.BatchNorm(data=conv1, gamma=gamma, beta=beta, moving_mean=mean, moving_var=var, name=name+'_bn')
55 | act1 = bn
56 | if act is not None and act!='':
57 | #print('....', act)
58 | act1 = mx.sym.Activation(data=act1, act_type=act, name=name+'_act')
59 |
60 | return act1
61 |
62 | def fc_self(self, data, num_hidden, name=''):
63 | assert(name!='' and name!=None)
64 | w = mx.sym.Variable(name+'_weight')
65 | b = mx.sym.Variable(name+'_bias')
66 | fc_1 = mx.sym.FullyConnected(data, weight=w, bias=b, num_hidden=num_hidden, name=name)
67 |
68 | return fc_1
69 |
70 | def create_backbone_resnet(self, input_states):
71 | """ 8层残差网络 """
72 |
73 | con_net = self.conv_act(input_states, 128, (3, 3), name='res_conv1')
74 | for i in range(1, self._n_blocks+1):
75 | # 残差结构定义
76 | pre_identity = con_net # 保存残差之前部分
77 | con_net = mx.sym.Convolution(con_net, name='convA'+str(i), kernel=(3, 3), pad=(1, 1), num_filter=self._n_filter)
78 | con_net = mx.sym.BatchNorm(con_net, name='bnA'+str(i), fix_gamma=False)
79 | con_net = mx.sym.Activation(con_net, name='actA'+str(i), act_type='relu')
80 | con_net = mx.sym.Convolution(con_net, name='convB'+str(i), kernel=(3, 3), pad=(1, 1), num_filter=self._n_filter)
81 | con_net = mx.sym.BatchNorm(con_net, name='bnB'+str(i), fix_gamma=False)
82 | con_net = con_net + pre_identity # 加上之前的输出,即为残差结构
83 | con_net = mx.sym.Activation(con_net, name='actB'+str(i), act_type='relu')
84 |
85 | # action output
86 | conv3_act = self.conv_act(con_net, 4, (1, 1), name='conv3_1_1')
87 | flatten_1 = mx.sym.Flatten(conv3_act)
88 | flatten_1 = mx.sym.Dropout(flatten_1, p=0.5)
89 | fc_3_1_1 = self.fc_self(flatten_1, self.board_height*self.board_width, name='fc_3_1_1')
90 | action_1 = mx.sym.SoftmaxActivation(fc_3_1_1, name='Act_SILER')
91 |
92 | # value output
93 | conv3_2_1 = self.conv_act(con_net, 2, (1, 1), name='conv3_2_1')
94 | flatten_2 = mx.sym.Flatten(conv3_2_1)
95 | flatten_2 = mx.sym.Dropout(flatten_2, p=0.5)
96 | fc_3_2_1 = self.fc_self(flatten_2, 1, name='fc_3_2_1')
97 | evaluation = mx.sym.Activation(fc_3_2_1, act_type='tanh')
98 |
99 | # mx.viz.plot_network(action_1).view()
100 | # mx.viz.plot_network(evaluation).view()
101 |
102 | return action_1, evaluation
103 |
104 |
105 | def create_backbone(self, input_states):
106 | """ 原始策略价值网络 """
107 |
108 | conv1 = self.conv_act(input_states, 64, (3, 3), name='conv1')
109 | conv2 = self.conv_act(conv1, 64, (3, 3), name='conv2')
110 | conv3 = self.conv_act(conv2, 128, (3, 3), name='conv3')
111 | conv4 = self.conv_act(conv3, 128, (3, 3), name='conv4')
112 | conv5 = self.conv_act(conv4, 256, (3, 3), name='conv5')
113 | final = self.conv_act(conv5, 256, (3, 3), name='conv_final')
114 |
115 | # action policy layers
116 | conv3_1_1 = self.conv_act(final, 4, (1, 1), name='conv3_1_1')
117 | flatten_1 = mx.sym.Flatten(conv3_1_1)
118 | flatten_1 = mx.sym.Dropout(flatten_1, p=0.5)
119 | fc_3_1_1 = self.fc_self(flatten_1, self.board_height*self.board_width, name='fc_3_1_1')
120 | action_1 = mx.sym.SoftmaxActivation(fc_3_1_1, name='Act_SILER')
121 | # arg_shapes, out_shapes, aux_shapes = action_1.infer_shape()
122 | # print('arg_shape: ', arg_shapes)
123 | # print('out_shape: ', out_shapes)
124 | # print('aux_shape: ', aux_shapes)
125 | # arg_shape: [(512L, 9L, 15L, 15L), (64L, 9L, 3L, 3L), (64L,), (64L,), (64L,), (64L, 64L, 3L, 3L), (64L,), (64L,), (64L,), (128L, 64L, 3L, 3L), (128L,), (128L,), (128L,), (128L, 128L, 3L, 3L), (128L,), (128L,), (128L,), (256L, 128L, 3L, 3L), (256L,), (256L,), (256L,), (256L, 256L, 3L, 3L), (256L,), (256L,), (256L,), (4L, 256L, 1L, 1L), (4L,), (4L,), (4L,), (225L, 900L), (225L,)]
126 | # out_shape: [(512L, 225L)]
127 | # aux_shape: [(64L,), (64L,), (64L,), (64L,), (128L,), (128L,), (128L,), (128L,), (256L,), (256L,), (256L,), (256L,), (4L,), (4L,)]
128 |
129 | # state value layers
130 | conv3_2_1 = self.conv_act(final, 2, (1, 1), name='conv3_2_1')
131 | flatten_2 = mx.sym.Flatten(conv3_2_1)
132 | flatten_2 = mx.sym.Dropout(flatten_2, p=0.5)
133 | fc_3_2_1 = self.fc_self(flatten_2, 1, name='fc_3_2_1')
134 | evaluation = mx.sym.Activation(fc_3_2_1, act_type='tanh')
135 | # args_shapes, out_shapes, aux_shapes = evaluation.infer_shape()
136 | # args_shape: [(5L, 9L, 15L, 15L), (64L, 9L, 3L, 3L), (64L,), (64L,), (64L,), (64L, 64L, 3L, 3L), (64L,), (64L,), (64L,), (128L, 64L, 3L, 3L), (128L,), (128L,), (128L,), (128L, 128L, 3L, 3L), (128L,), (128L,), (128L,), (256L, 128L, 3L, 3L), (256L,), (256L,), (256L,), (256L, 256L, 3L, 3L), (256L,), (256L,), (256L,), (2L, 256L, 1L, 1L), (2L,), (2L,), (2L,), (1L, 450L), (1L,)]
137 | # out_shapes: [(5L, 1L)]
138 | # aux_shapes: [(64L,), (64L,), (64L,), (64L,), (128L,), (128L,), (128L,), (128L,), (256L,), (256L,), (256L,), (256L,), (2L,), (2L,)]
139 | mx.viz.plot_network(action_1).view()
140 | # mx.viz.plot_network(evaluation).view()
141 | go_on = input("go on?:")
142 | if go_on == 0: exit()
143 |
144 | return action_1, evaluation
145 |
146 |
147 |
148 | def create_backbone2(self, input_states):
149 | """create the policy value network """
150 |
151 | conv1 = self.conv_act(input_states, 32, (3, 3), name='conv1')
152 | conv2 = self.conv_act(conv1, 64, (3, 3), name='conv2')
153 | conv3 = self.conv_act(conv2, 128, (3, 3), name='conv3')
154 | conv4 = self.conv_act(conv3, 128, (3, 3), name='conv4')
155 | conv5 = self.conv_act(conv4, 128, (3, 3), name='conv5')
156 | final = self.conv_act(conv5, 128, (3, 3), name='conv_final')
157 |
158 | # action policy layers
159 | conv3_1_1 = self.conv_act(final, 1024, (1, 1), name='conv3_1_1')
160 | conv3_1_2 = self.conv_act(conv3_1_1, 1, (1, 1), act=None, dobn=False, name='conv3_1_2')
161 | flatten_1 = mx.sym.Flatten(conv3_1_2)
162 | action_1 = mx.sym.SoftmaxActivation(flatten_1)
163 |
164 | # state value layers
165 | conv3_2_1 = self.conv_act(final, 256, (1, 1), name='conv3_2_1')
166 | conv3_2_2 = self.conv_act(conv3_2_1, 1, (1, 1), act=None, dobn=False, name='conv3_2_2')
167 | flatten_2 = mx.sym.Flatten(conv3_2_2)
168 | mean_2 = mx.sym.mean(flatten_2, axis=1, keepdims=True)
169 | evaluation = mx.sym.Activation(mean_2, act_type='tanh')
170 |
171 | return action_1, evaluation
172 |
173 | def create_policy_value_train(self, batch_size):
174 | input_states_shape = (batch_size, self.channelnum, self.board_height, self.board_width)
175 | input_states = mx.sym.Variable(name='input_states', shape=input_states_shape)
176 | # action_1, evaluation = self.create_backbone(input_states)
177 | action_1, evaluation = self.create_backbone_resnet(input_states)
178 |
179 | mcts_probs_shape = (batch_size, self.board_height * self.board_width)
180 | mcts_probs = mx.sym.Variable(name='mcts_probs', shape=mcts_probs_shape)
181 | policy_loss = -mx.sym.sum(mx.sym.log(action_1) * mcts_probs, axis=1)
182 | policy_loss = mx.sym.mean(policy_loss)
183 |
184 | input_labels_shape = (batch_size, 1)
185 | input_labels = mx.sym.Variable(name='input_labels', shape=input_labels_shape)
186 | value_loss = mx.sym.mean(mx.sym.square(input_labels - evaluation))
187 |
188 | loss = value_loss + policy_loss
189 | loss = mx.sym.MakeLoss(loss)
190 |
191 | entropy = mx.sym.sum(-action_1 * mx.sym.log(action_1), axis=1)
192 | entropy = mx.sym.mean(entropy)
193 | entropy = mx.sym.BlockGrad(entropy)
194 | entropy = mx.sym.MakeLoss(entropy)
195 | policy_value_loss = mx.sym.Group([loss, entropy])
196 | policy_value_loss.save('policy_value_loss.json')
197 |
198 | pv_train = mx.mod.Module(symbol=policy_value_loss,
199 | data_names=['input_states'],
200 | label_names=['input_labels', 'mcts_probs'],
201 | context=self.context)
202 | pv_train.bind(data_shapes=[('input_states', input_states_shape)],
203 | label_shapes=[('input_labels', input_labels_shape), ('mcts_probs', mcts_probs_shape)],
204 | for_training=True)
205 | pv_train.init_params(initializer=mx.init.Xavier())
206 | pv_train.init_optimizer(optimizer='adam',
207 | optimizer_params={'learning_rate':0.001,
208 | #'clip_gradient':0.1,
209 | #'momentum':0.9,
210 | 'wd':0.0001})
211 |
212 | return pv_train
213 |
214 | def create_policy_value_predict(self, batch_size):
215 | input_states_shape = (batch_size, self.channelnum, self.board_height, self.board_width)
216 | input_states = mx.sym.Variable(name='input_states', shape=input_states_shape)
217 | # action_1, evaluation = self.create_backbone(input_states)
218 | action_1, evaluation = self.create_backbone_resnet(input_states)
219 | policy_value_output = mx.sym.Group([action_1, evaluation])
220 |
221 | pv_predict = mx.mod.Module(symbol=policy_value_output,
222 | data_names=['input_states'],
223 | label_names=None,
224 | context=self.context)
225 |
226 | pv_predict.bind(data_shapes=[('input_states', input_states_shape)], for_training=False)
227 | args, auxs = self.train_batch.get_params()
228 | pv_predict.set_params(args, auxs)
229 |
230 | return pv_predict
231 |
232 | def policy_value(self, state_batch):
233 | states = np.asarray(state_batch)
234 | #print('policy_value:', states.shape)
235 | state_nd = mx.nd.array(states)
236 | self.predict_batch.forward(mx.io.DataBatch([state_nd]))
237 | acts, vals = self.predict_batch.get_outputs()
238 | acts = acts.asnumpy()
239 | vals = vals.asnumpy()
240 | #print(acts[0], vals[0])
241 |
242 | return acts, vals
243 |
244 | def policy_value2(self, state_batch):
245 | actsall = []
246 | valsall = []
247 | for state in state_batch:
248 | state = state.reshape(1, self.channelnum, self.board_height, self.board_width)
249 | #print(state.shape)
250 | state_nd = mx.nd.array(state)
251 | self.predict_one.forward(mx.io.DataBatch([state_nd]))
252 | act, val = self.predict_one.get_outputs()
253 | actsall.append(act[0].asnumpy())
254 | valsall.append(val[0].asnumpy())
255 | acts = np.asarray(actsall)
256 | vals = np.asarray(valsall)
257 | #print(acts.shape, vals.shape)
258 |
259 | return acts, vals
260 |
261 | def policy_value_fn(self, board):
262 | """
263 | input: board
264 | output: a list of (action, probability) tuples for each available action and the score of the board state
265 | """
266 | legal_positions = board.availables
267 | current_state = board.current_state().reshape(1, self.channelnum, self.board_height, self.board_width)
268 | state_nd = mx.nd.array(current_state)
269 | self.predict_one.forward(mx.io.DataBatch([state_nd]))
270 | acts_probs, values = self.predict_one.get_outputs()
271 | acts_probs = acts_probs.asnumpy()
272 | values = values.asnumpy()
273 | #print(acts_probs[0, :4])
274 | legal_actprob = acts_probs[0][legal_positions]
275 | act_probs = zip(legal_positions, legal_actprob)
276 | # print(len(legal_positions), legal_actprob.shape, acts_probs.shape)
277 | # if len(legal_positions)==0:
278 | # exit()
279 |
280 | return act_probs, values[0]
281 |
282 | def train_step(self, state_batch, mcts_probs, winner_batch, learning_rate):
283 | # winner_batch: 1.0/-1.0
284 | #print('hello training....')
285 | #print(mcts_probs[0], winner_batch[0])
286 | self.train_batch._optimizer.lr = learning_rate
287 | state_batch = mx.nd.array(np.asarray(state_batch).reshape(-1, self.channelnum, self.board_height, self.board_width))
288 | mcts_probs = mx.nd.array(np.asarray(mcts_probs).reshape(-1, self.board_height*self.board_width))
289 | winner_batch = mx.nd.array(np.asarray(winner_batch).reshape(-1, 1))
290 | self.train_batch.forward(mx.io.DataBatch([state_batch], [winner_batch, mcts_probs]))
291 | self.train_batch.backward()
292 | self.train_batch.update()
293 | loss, entropy = self.train_batch.get_outputs()
294 |
295 | args, auxs = self.train_batch.get_params()
296 | self.predict_batch.set_params(args, auxs)
297 | self.predict_one.set_params(args, auxs)
298 |
299 | return loss.asnumpy(), entropy.asnumpy()
300 |
301 | def get_policy_param(self):
302 | net_params = self.train_batch.get_params()
303 | return net_params
304 |
305 | def save_model(self, model_file):
306 | """ save model params to file """
307 | net_params = self.get_policy_param()
308 | print('>>>>>>>>>> saved into', model_file)
309 | pickle.dump(net_params, open(model_file, 'wb'), protocol=2)
310 |
--------------------------------------------------------------------------------
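
A hedged sketch of driving the MXNet `PolicyValueNet` above outside the training pipeline: restore parameters written by `save_model()` and query `policy_value_fn` for one position. The model path and `n_in_row` value are placeholders, and `init_board()` is assumed from the upstream AlphaZero_Gomoku `Board` API used by `train_mxnet.py` (that file is not shown here).

```python
# Sketch only: load pickled params saved by save_model() and run one prediction.
import pickle
from game import Board                           # same Board that train_mxnet.py imports
from policy_value_net_mxnet import PolicyValueNet

params = pickle.load(open('./logs/current_policy.model', 'rb'))   # placeholder path
net = PolicyValueNet(15, 15, batch_size=512, model_params=params)

board = Board(width=15, height=15, n_in_row=5)   # n_in_row=5 assumed, as in Gomoku
board.init_board()                               # assumed Board API (upstream repo)
act_probs, value = net.policy_value_fn(board)    # (move, prob) pairs and a scalar value
print(list(act_probs)[:5], value)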
/policy_value_net_mxnet_simple.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | An implementation of the policyValueNet with MXNet
4 | Written against the MXNet symbolic/Module API (requirements.txt pins mxnet==1.6.0)
5 |
6 | @author: Mingxu Zhang
7 | """
8 | from __future__ import print_function
9 | import sys
10 | import os
11 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
12 | # sys.path.insert(0, '/home/mingzhang/work/dmlc/python_mxnet/python')
13 |
14 | import mxnet as mx
15 | import numpy as np
16 | import pickle
17 |
18 |
19 | class PolicyValueNet():
20 | """policy-value network """
21 | def __init__(self, board_width, board_height, batch_size=512, model_params=None):
22 | self.context = mx.cpu()
23 | self.batchsize = batch_size # must be the same as TrainPipeline's self.batch_size
24 | self.channelnum = 9
25 | self.board_width = board_width
26 | self.board_height = board_height
27 | self.l2_const = 1e-4 # coef of l2 penalty
28 | self.train_batch = self.create_policy_value_train(self.batchsize)
29 | self.predict_batch = self.create_policy_value_predict(self.batchsize)
30 | self.predict_one = self.create_policy_value_predict(1)
31 | self.num = 0
32 |
33 | if model_params:
34 | self.train_batch.set_params(*model_params)
35 | self.predict_batch.set_params(*model_params)
36 | self.predict_one.set_params(*model_params)
37 | pass
38 |
39 | def conv_act(self, data, num_filter=32, kernel=(3, 3), stride=(1, 1), act='relu', dobn=True, name=''):
40 | # self convolution activation
41 | assert(name!='' and name!=None)
42 | pad = (int(kernel[0]/2), int(kernel[1]/2))
43 | w = mx.sym.Variable(name+'_weight')
44 | b = mx.sym.Variable(name+'_bias')
45 | conv1 = mx.sym.Convolution(data=data, weight=w, bias=b, num_filter=num_filter, kernel=kernel, pad=pad, name=name)
46 | act1 = conv1
47 | if dobn:
48 | gamma = mx.sym.Variable(name+'_gamma')
49 | beta = mx.sym.Variable(name+'_beta')
50 | mean = mx.sym.Variable(name+'_mean')
51 | var = mx.sym.Variable(name+'_var')
52 | bn = mx.sym.BatchNorm(data=conv1, gamma=gamma, beta=beta, moving_mean=mean, moving_var=var, name=name+'_bn')
53 | act1 = bn
54 | if act is not None and act!='':
55 | #print('....', act)
56 | act1 = mx.sym.Activation(data=act1, act_type=act, name=name+'_act')
57 |
58 | return act1
59 |
60 | def fc_self(self, data, num_hidden, name=''):
61 | assert(name!='' and name!=None)
62 | w = mx.sym.Variable(name+'_weight')
63 | b = mx.sym.Variable(name+'_bias')
64 | fc_1 = mx.sym.FullyConnected(data, weight=w, bias=b, num_hidden=num_hidden, name=name)
65 |
66 | return fc_1
67 |
68 | def create_backbone(self, input_states):
69 | """create the policy value network """
70 |
71 | conv1 = self.conv_act(input_states, 64, (3, 3), name='conv1')
72 | conv2 = self.conv_act(conv1, 64, (3, 3), name='conv2')
73 | conv3 = self.conv_act(conv2, 128, (3, 3), name='conv3')
74 | conv4 = self.conv_act(conv3, 128, (3, 3), name='conv4')
75 | conv5 = self.conv_act(conv4, 256, (3, 3), name='conv5')
76 | final = self.conv_act(conv5, 256, (3, 3), name='conv_final')
77 |
78 | # action policy layers
79 | conv3_1_1 = self.conv_act(final, 4, (1, 1), name='conv3_1_1')
80 | flatten_1 = mx.sym.Flatten(conv3_1_1)
81 | flatten_1 = mx.sym.Dropout(flatten_1, p=0.5)
82 | fc_3_1_1 = self.fc_self(flatten_1, self.board_height*self.board_width, name='fc_3_1_1')
83 | action_1 = mx.sym.SoftmaxActivation(fc_3_1_1)
84 |
85 | # state value layers
86 | conv3_2_1 = self.conv_act(final, 2, (1, 1), name='conv3_2_1')
87 | flatten_2 = mx.sym.Flatten(conv3_2_1)
88 | flatten_2 = mx.sym.Dropout(flatten_2, p=0.5)
89 | fc_3_2_1 = self.fc_self(flatten_2, 1, name='fc_3_2_1')
90 | evaluation = mx.sym.Activation(fc_3_2_1, act_type='tanh')
91 |
92 | return action_1, evaluation
93 |
94 |
95 |
96 | def create_backbone2(self, input_states):
97 | """create the policy value network """
98 |
99 | conv1 = self.conv_act(input_states, 32, (3, 3), name='conv1')
100 | conv2 = self.conv_act(conv1, 64, (3, 3), name='conv2')
101 | conv3 = self.conv_act(conv2, 128, (3, 3), name='conv3')
102 | conv4 = self.conv_act(conv3, 128, (3, 3), name='conv4')
103 | conv5 = self.conv_act(conv4, 128, (3, 3), name='conv5')
104 | final = self.conv_act(conv5, 128, (3, 3), name='conv_final')
105 |
106 | # action policy layers
107 | conv3_1_1 = self.conv_act(final, 1024, (1, 1), name='conv3_1_1')
108 | conv3_1_2 = self.conv_act(conv3_1_1, 1, (1, 1), act=None, dobn=False, name='conv3_1_2')
109 | flatten_1 = mx.sym.Flatten(conv3_1_2)
110 | action_1 = mx.sym.SoftmaxActivation(flatten_1)
111 |
112 | # state value layers
113 | conv3_2_1 = self.conv_act(final, 256, (1, 1), name='conv3_2_1')
114 | conv3_2_2 = self.conv_act(conv3_2_1, 1, (1, 1), act=None, dobn=False, name='conv3_2_2')
115 | flatten_2 = mx.sym.Flatten(conv3_2_2)
116 | mean_2 = mx.sym.mean(flatten_2, axis=1, keepdims=True)
117 | evaluation = mx.sym.Activation(mean_2, act_type='tanh')
118 |
119 | return action_1, evaluation
120 |
121 | def create_policy_value_train(self, batch_size):
122 | input_states_shape = (batch_size, self.channelnum, self.board_height, self.board_width)
123 | input_states = mx.sym.Variable(name='input_states', shape=input_states_shape)
124 | action_1, evaluation = self.create_backbone(input_states)
125 |
126 | mcts_probs_shape = (batch_size, self.board_height * self.board_width)
127 | mcts_probs = mx.sym.Variable(name='mcts_probs', shape=mcts_probs_shape)
128 | policy_loss = -mx.sym.sum(mx.sym.log(action_1) * mcts_probs, axis=1)
129 | policy_loss = mx.sym.mean(policy_loss)
130 |
131 | input_labels_shape = (batch_size, 1)
132 | input_labels = mx.sym.Variable(name='input_labels', shape=input_labels_shape)
133 | value_loss = mx.sym.mean(mx.sym.square(input_labels - evaluation))
134 |
135 | loss = value_loss + policy_loss
136 | loss = mx.sym.MakeLoss(loss)
137 |
138 | entropy = mx.sym.sum(-action_1 * mx.sym.log(action_1), axis=1)
139 | entropy = mx.sym.mean(entropy)
140 | entropy = mx.sym.BlockGrad(entropy)
141 | entropy = mx.sym.MakeLoss(entropy)
142 | policy_value_loss = mx.sym.Group([loss, entropy])
143 | policy_value_loss.save('policy_value_loss.json')
144 |
145 | pv_train = mx.mod.Module(symbol=policy_value_loss,
146 | data_names=['input_states'],
147 | label_names=['input_labels', 'mcts_probs'],
148 | context=self.context)
149 | pv_train.bind(data_shapes=[('input_states', input_states_shape)],
150 | label_shapes=[('input_labels', input_labels_shape), ('mcts_probs', mcts_probs_shape)],
151 | for_training=True)
152 | pv_train.init_params(initializer=mx.init.Xavier())
153 | pv_train.init_optimizer(optimizer='adam',
154 | optimizer_params={'learning_rate':0.001,
155 | #'clip_gradient':0.1,
156 | #'momentum':0.9,
157 | 'wd':0.0001})
158 |
159 | return pv_train
160 |
161 | def create_policy_value_predict(self, batch_size):
162 | input_states_shape = (batch_size, self.channelnum, self.board_height, self.board_width)
163 | input_states = mx.sym.Variable(name='input_states', shape=input_states_shape)
164 | action_1, evaluation = self.create_backbone(input_states)
165 | policy_value_output = mx.sym.Group([action_1, evaluation])
166 |
167 | pv_predict = mx.mod.Module(symbol=policy_value_output,
168 | data_names=['input_states'],
169 | label_names=None,
170 | context=self.context)
171 |
172 | pv_predict.bind(data_shapes=[('input_states', input_states_shape)], for_training=False)
173 | args, auxs = self.train_batch.get_params()
174 | pv_predict.set_params(args, auxs)
175 |
176 | return pv_predict
177 |
178 | def policy_value(self, state_batch):
179 | states = np.asarray(state_batch)
180 | #print('policy_value:', states.shape)
181 | state_nd = mx.nd.array(states)
182 | self.predict_batch.forward(mx.io.DataBatch([state_nd]))
183 | acts, vals = self.predict_batch.get_outputs()
184 | acts = acts.asnumpy()
185 | vals = vals.asnumpy()
186 | #print(acts[0], vals[0])
187 |
188 | return acts, vals
189 |
190 | def policy_value2(self, state_batch):
191 | actsall = []
192 | valsall = []
193 | for state in state_batch:
194 | state = state.reshape(1, self.channelnum, self.board_height, self.board_width)
195 | #print(state.shape)
196 | state_nd = mx.nd.array(state)
197 | self.predict_one.forward(mx.io.DataBatch([state_nd]))
198 | act, val = self.predict_one.get_outputs()
199 | actsall.append(act[0].asnumpy())
200 | valsall.append(val[0].asnumpy())
201 | acts = np.asarray(actsall)
202 | vals = np.asarray(valsall)
203 | #print(acts.shape, vals.shape)
204 |
205 | return acts, vals
206 |
207 | def policy_value_fn(self, board):
208 | """
209 | input: board
210 | output: a list of (action, probability) tuples for each available action and the score of the board state
211 | """
212 | legal_positions = board.availables
213 | current_state = board.current_state().reshape(1, self.channelnum, self.board_height, self.board_width)
214 | state_nd = mx.nd.array(current_state)
215 | self.predict_one.forward(mx.io.DataBatch([state_nd]))
216 | acts_probs, values = self.predict_one.get_outputs()
217 | acts_probs = acts_probs.asnumpy()
218 | values = values.asnumpy()
219 | #print(acts_probs[0, :4])
220 | legal_actprob = acts_probs[0][legal_positions]
221 | act_probs = zip(legal_positions, legal_actprob)
222 | # print(len(legal_positions), legal_actprob.shape, acts_probs.shape)
223 | # if len(legal_positions)==0:
224 | # exit()
225 |
226 | return act_probs, values[0]
227 |
228 | def train_step(self, state_batch, mcts_probs, winner_batch, learning_rate):
229 | #print('hello training....')
230 | #print(mcts_probs[0], winner_batch[0])
231 | self.train_batch._optimizer.lr = learning_rate
232 | state_batch = mx.nd.array(np.asarray(state_batch).reshape(-1, self.channelnum, self.board_height, self.board_width))
233 | mcts_probs = mx.nd.array(np.asarray(mcts_probs).reshape(-1, self.board_height*self.board_width))
234 | winner_batch = mx.nd.array(np.asarray(winner_batch).reshape(-1, 1))
235 | self.train_batch.forward(mx.io.DataBatch([state_batch], [winner_batch, mcts_probs]))
236 | self.train_batch.backward()
237 | self.train_batch.update()
238 | loss, entropy = self.train_batch.get_outputs()
239 |
240 | args, auxs = self.train_batch.get_params()
241 | self.predict_batch.set_params(args, auxs)
242 | self.predict_one.set_params(args, auxs)
243 |
244 | return loss.asnumpy(), entropy.asnumpy()
245 |
246 | def get_policy_param(self):
247 | net_params = self.train_batch.get_params()
248 | return net_params
249 |
250 | def save_model(self, model_file):
251 | """ save model params to file """
252 | net_params = self.get_policy_param()
253 | print('>>>>>>>>>> saved into', model_file)
254 | pickle.dump(net_params, open(model_file, 'wb'), protocol=2)
255 |
--------------------------------------------------------------------------------
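
Note that the `_simple` variant above runs on `mx.cpu()` and its constructor takes no `n_blocks`/`n_filter` arguments, so it is not a drop-in replacement for the GPU net as `TrainPipeline` currently instantiates it (that pipeline passes those keywords). A minimal hedged instantiation sketch:

```python
# Sketch: the simple CPU variant is built without the n_blocks/n_filter arguments.
from policy_value_net_mxnet_simple import PolicyValueNet

net = PolicyValueNet(15, 15, batch_size=512)   # mx.cpu() context, plain conv backbone
```
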
/requirements.txt:
--------------------------------------------------------------------------------
1 | line_profiler==2.1.2
2 | requests==2.18.4
3 | numpy==1.14.2
4 | beautifulsoup4==4.6.0
5 | PyYAML==3.13
6 | psyco
7 | tornado
8 | mxnet==1.6.0
--------------------------------------------------------------------------------
/run_ai.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo "start for AI.." && source /workspace/dev.env/bin/activate
4 | export PATH=/usr/local/cuda-10.2/bin:$PATH
5 | export LD_LIBRARY_PATH=/usr/local/cuda-10.2/lib64:$LD_LIBRARY_PATH
6 | sleep 5
7 | export CUDA_VISIBLE_DEVICES=1 # select which GPU to use
8 | echo "cuda device: $CUDA_VISIBLE_DEVICES"
9 | cd /workspace/AlphaPig/evaluate/ && echo "run AI is OK "
10 | nohup python ChessClient.py --cur_role 1 --model r10 --room_name 1 --server_url http://gobang_server:8888 &
11 | nohup python ChessClient.py --cur_role 1 --model r10 --room_name 2 --server_url http://gobang_server:8888 &
12 | nohup python ChessClient.py --cur_role 1 --model r10 --room_name 3 --server_url http://gobang_server:8888 &
13 | nohup python ChessClient.py --cur_role 1 --model r10 --room_name 4 --server_url http://gobang_server:8888 &
14 | nohup python ChessClient.py --cur_role 1 --model r10 --room_name 5 --server_url http://gobang_server:8888 &
15 | nohup python ChessClient.py --cur_role 2 --model r10 --room_name 6 --server_url http://gobang_server:8888 &
16 | nohup python ChessClient.py --cur_role 2 --model r10 --room_name 7 --server_url http://gobang_server:8888 &
17 | nohup python ChessClient.py --cur_role 2 --model r10 --room_name 8 --server_url http://gobang_server:8888 &
18 | nohup python ChessClient.py --cur_role 2 --model r10 --room_name 9 --server_url http://gobang_server:8888 &
19 | python ChessClient.py --cur_role 2 --model r10 --room_name 10 --server_url http://gobang_server:8888
20 |
--------------------------------------------------------------------------------
/run_server.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo "start...2" && source /workspace/dev.env/bin/activate && \
4 | cd /workspace/AlphaPig/evaluate/ && echo "run server is OK " \
5 | && python ChessServer.py --port 8888
--------------------------------------------------------------------------------
/sgf_data/download.sh:
--------------------------------------------------------------------------------
1 | wget "http://p324ywv2g.bkt.clouddn.com/sgf_data.zip"
2 | unzip sgf_data.zip
--------------------------------------------------------------------------------
/start_train.sh:
--------------------------------------------------------------------------------
1 | nohup python -u train_mxnet.py &
--------------------------------------------------------------------------------
/train_mxnet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | An implementation of the training pipeline of AlphaZero for Gomoku
4 |
5 | @author: Junxiao Song
6 | anxingle
7 | """
8 |
9 | from __future__ import print_function
10 | import pickle
11 | import random
12 | import os
13 | import time
14 | import numpy as np
15 | from optparse import OptionParser
16 | import multiprocessing as mp
17 | from collections import defaultdict, deque
18 | from game import Board, Game
19 | from game_ai import Game_AI
20 | from mcts_pure import MCTSPlayer as MCTS_Pure
21 | from mcts_alphaZero import MCTSPlayer
22 | from utils import config_loader, send_email
23 |
24 | os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
25 | # from policy_value_net import PolicyValueNet # Theano and Lasagne
26 | # from policy_value_net_pytorch import PolicyValueNet # Pytorch
27 | # from policy_value_net_tensorflow import PolicyValueNet # Tensorflow
28 | # from policy_value_net_keras import PolicyValueNet # Keras
29 | from policy_value_net_mxnet import PolicyValueNet # Mxnet
30 |
31 | import logging
32 | import logging.config
33 | logging.config.dictConfig(config_loader.config_['train_logging'])
34 | _logger = logging.getLogger(__name__)
35 |
36 | current_relative_path = lambda x: os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), x))
37 |
38 | class TrainPipeline():
39 | def __init__(self, conf, init_model=None):
40 | # params of the board and the game
41 | self.board_width = conf['board_width']
42 | self.board_height = conf['board_height']
43 | self.n_in_row = conf['n_in_row']
44 | self.board = Board(width=self.board_width,
45 | height=self.board_height,
46 | n_in_row=self.n_in_row)
47 | self.game = Game(self.board)
48 | self.game_ai = Game_AI(self.board)
49 | # training params
50 | self.learn_rate = conf['learn_rate']
51 | self.lr_multiplier = conf['lr_multiplier'] # adaptively adjust the learning rate based on KL
52 | self.temp = conf['temp'] # the temperature param
53 | self.n_playout = conf['n_playout'] # 500 # num of simulations for each move
54 | self.c_puct = conf['c_puct']
55 | self.buffer_size = conf['buffer_size']
56 | self.batch_size = conf['batch_size'] # mini-batch size for training
57 | self.data_buffer = deque(maxlen=self.buffer_size)
58 | self.play_batch_size = conf['play_batch_size']
59 | self.epochs = conf['epochs'] # num of train_steps for each update
60 | self.kl_targ = conf['kl_targ']
61 | self.check_freq = conf['check_freq']
62 | self.game_batch_num =conf['game_batch_num']
63 | self.best_win_ratio = 0.0
64 | # 多线程相关
65 | self._cpu_count = mp.cpu_count() - 8
66 | # num of simulations used for the pure mcts, which is used as
67 | # the opponent to evaluate the trained policy
68 | self.pure_mcts_playout_num = conf['pure_mcts_playout_num']
69 | # 训练集文件
70 | self._sgf_home = current_relative_path(conf['sgf_dir'])
71 | _logger.info('path: %s' % self._sgf_home)
72 | self._ai_data_home = current_relative_path(conf['ai_data_dir'])
73 | # 加载人类对弈数据
74 | self._load_training_data(self._sgf_home)
75 | # 加载保存的自对弈数据
76 | # self._load_pickle_data(self._ai_data_home)
77 | if init_model:
78 | # start training from an initial policy-value net
79 | self.policy_value_net = PolicyValueNet(self.board_width,
80 | self.board_height,
81 | self.batch_size,
82 | n_blocks=10,
83 | n_filter=128,
84 | model_params=init_model)
85 | else:
86 | # start training from a new policy-value net
87 | self.policy_value_net = PolicyValueNet(self.board_width,
88 | self.board_height,
89 | self.batch_size,
90 | n_blocks=10,
91 | n_filter=128)
92 | self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
93 | c_puct=self.c_puct,
94 | n_playout=self.n_playout,
95 | is_selfplay=1)
96 |
97 | def _load_training_data(self, data_dir):
98 | file_list = os.listdir(data_dir)
99 | self._training_data = [item for item in file_list if item.endswith('.sgf') and os.path.isfile(os.path.join(data_dir, item))]
100 | random.shuffle(self._training_data)
101 | self._length_train_data = len(self._training_data)
102 |
103 | """"
104 | def _load_pickle_data(self, data_dir):
105 | file_list = os.listdir(data_dir)
106 | txt_list = [item for item in file_list if item.endswith('.txt') and os.path.isfile(os.path.join(data_dir, item))]
107 | self._ai_history_data = []
108 | for txt_f in txt_list:
109 | with open(os.path.join(data_dir, txt_f), 'rb') as f_object:
110 | d = pickle.load(f_object)
111 | self._ai_history_data += d
112 | f_object.close()
113 | """
114 |
115 | def get_equi_data(self, play_data):
116 | """augment the data set by rotation and flipping
117 | play_data: [(state, mcts_prob, winner_z), ..., ...]
118 | """
119 | extend_data = []
120 | for state, mcts_porb, winner in play_data:
121 | for i in [1, 2, 3, 4]:
122 | # rotate counterclockwise
123 | equi_state = np.array([np.rot90(s, i) for s in state])
124 | equi_mcts_prob = np.rot90(np.flipud(
125 | mcts_porb.reshape(self.board_height, self.board_width)), i)
126 | extend_data.append((equi_state,
127 | np.flipud(equi_mcts_prob).flatten(),
128 | winner))
129 | # flip horizontally
130 | equi_state = np.array([np.fliplr(s) for s in equi_state])
131 | equi_mcts_prob = np.fliplr(equi_mcts_prob)
132 | extend_data.append((equi_state,
133 | np.flipud(equi_mcts_prob).flatten(),
134 | winner))
135 | return extend_data
136 |
137 | def collect_selfplay_data(self, n_games=1, training_index=None):
138 | """collect SGF file data for training"""
139 | data_index = training_index % self._length_train_data
140 | if data_index == 0:
141 | random.shuffle(self._training_data)
142 | for i in range(n_games):
143 | warning, winner, play_data = self.game.start_self_play(self.mcts_player, temp=self.temp, sgf_home=self._sgf_home, file_name=self._training_data[data_index])
144 | if warning:
145 | _logger.error('\033[0;41m %s \033[0m anxingle_training_index: %s, data_index: %s, file: %s' % ('WARNING', training_index, data_index, self._training_data[data_index]))
146 | else:
147 | _logger.info('winner: %s, file: %s ' % (winner, self._training_data[data_index]))
148 | # print('play_data: ', play_data)
149 | play_data = list(play_data)[:]
150 | self.episode_len = len(play_data)
151 | # augment the data
152 | play_data = self.get_equi_data(play_data)
153 | self.data_buffer.extend(play_data)
154 | _logger.info('game_batch_index: %s, length of data_buffer: %s' % (training_index, len(self.data_buffer)))
155 |
156 | """
157 | def collect_selfplay_data_pickle(self, n_games=1, training_index=None):
158 | # load AI self play data(auto save for every N game play nums)
159 | data_index = training_index % len(self._ai_history_data)
160 | if data_index == 0:
161 | random.shuffle(self._ai_history_data)
162 | for i in range(n_games):
163 | play_data = self._ai_history_data[data_index]
164 | self.episode_len = len(play_data)
165 | # augment the data
166 | play_data = self.get_equi_data(play_data)
167 | self.data_buffer.extend(play_data)
168 | """
169 |
170 | def collect_selfplay_data_ai(self, n_games=1, training_index=None):
171 | """collect AI self-play data for training"""
172 | for i in range(n_games):
173 | winner, play_data = self.game_ai.start_self_play(self.mcts_player,
174 | temp=self.temp)
175 | _logger.info('training_index: %s, winner is: %s' % (training_index, winner))
176 | play_data = list(play_data)[:]
177 | self.episode_len = len(play_data)
178 | # augment the data
179 | play_data = self.get_equi_data(play_data)
180 | self.data_buffer.extend(play_data)
181 |
182 | # def _multiprocess_collect_selfplay_data(self, q, process_index):
183 | # """
184 | # TODO: CUDA multiprocessing have bugs!
185 | # winner, play_data = self.game.start_self_play(self.mcts_player,
186 | # temp=self.temp)
187 | # play_data = list(play_data)[:]
188 | # self.episode_len = len(play_data)
189 | # # augment the data
190 | # play_data = self.get_equi_data(play_data)
191 | # q.put(play_data)
192 |
193 |
194 | def policy_update(self):
195 | """update the policy-value net"""
196 | mini_batch = random.sample(self.data_buffer, self.batch_size)
197 | state_batch = [data[0] for data in mini_batch]
198 | mcts_probs_batch = [data[1] for data in mini_batch]
199 | winner_batch = [data[2] for data in mini_batch]
200 | old_probs, old_v = self.policy_value_net.policy_value(state_batch)
201 | learn_rate = self.learn_rate*self.lr_multiplier
202 | for i in range(self.epochs):
203 | loss, entropy = self.policy_value_net.train_step(
204 | state_batch,
205 | mcts_probs_batch,
206 | winner_batch,
207 | learn_rate)
208 | new_probs, new_v = self.policy_value_net.policy_value(state_batch)
209 | kl = np.mean(np.sum(old_probs * (
210 | np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
211 | axis=1)
212 | )
213 | if kl > self.kl_targ * 4: # early stopping if D_KL diverges badly
214 | _logger.info('early stopping. i:%s. epochs: %s' % (i, self.epochs))
215 | break
216 | # adaptively adjust the learning rate
217 | if kl > self.kl_targ * 2 and self.lr_multiplier > 0.05:
218 | self.lr_multiplier /= 1.5
219 | elif kl < self.kl_targ / 2 and self.lr_multiplier < 20:
220 | self.lr_multiplier *= 1.5
221 |
222 | explained_var_old = (1 -
223 | np.var(np.array(winner_batch) - old_v.flatten()) /
224 | np.var(np.array(winner_batch)))
225 | explained_var_new = (1 -
226 | np.var(np.array(winner_batch) - new_v.flatten()) /
227 | np.var(np.array(winner_batch)))
228 | _logger.info(("kl:{:.4f},"
229 | "lr:{:.1e},"
230 | "loss:{},"
231 | "entropy:{},"
232 | "explained_var_old:{:.3f},"
233 | "explained_var_new:{:.3f}"
234 | ).format(kl,
235 | learn_rate,
236 | loss,
237 | entropy,
238 | explained_var_old,
239 | explained_var_new))
240 | return loss, entropy
241 |
242 | def policy_evaluate(self, n_games=10):
243 | """
244 | Evaluate the trained policy by playing against the pure MCTS player
245 | Note: this is only for monitoring the progress of training
246 | """
247 | current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
248 | c_puct=self.c_puct,
249 | n_playout=self.n_playout)
250 | pure_mcts_player = MCTS_Pure(c_puct=5,
251 | n_playout=self.pure_mcts_playout_num)
252 | win_cnt = defaultdict(int)
253 | for i in range(n_games):
254 | winner = self.game.start_play(current_mcts_player,
255 | pure_mcts_player,
256 | start_player=i % 2,
257 | is_shown=0)
258 | win_cnt[winner] += 1
259 | win_ratio = 1.0*(win_cnt[1] + 0.5*win_cnt[-1]) / n_games
260 | _logger.info("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
261 | self.pure_mcts_playout_num,
262 | win_cnt[1], win_cnt[2], win_cnt[-1]))
263 | return win_ratio
264 |
265 | def run(self):
266 | """run the training pipeline"""
267 | try:
268 | for i in range(self.game_batch_num):
269 | current_time = time.time()
270 | if i < 4000:
271 | self.collect_selfplay_data(1, training_index=i)
272 | else:
273 | self.collect_selfplay_data_ai(1, training_index=i)
274 | _logger.info('collection cost time: %d ' % (time.time() - current_time))
275 | _logger.info("batch i:{}, episode_len:{}, buffer_len:{}".format(
276 | i+1, self.episode_len, len(self.data_buffer)))
277 | if len(self.data_buffer) > self.batch_size:
278 | batch_time = time.time()
279 | loss, entropy = self.policy_update()
280 | _logger.info('train batch cost time: %d' % (time.time() - batch_time))
281 | # check the performance of the current model,
282 | # and save the model params
283 | if (i+1) % 50 == 0:
284 | self.policy_value_net.save_model('./logs/current_policy.model')
285 | if (i+1) % self.check_freq == 0:
286 | check_time = time.time()
287 | _logger.info("current self-play batch: {}".format(i+1))
288 | win_ratio = self.policy_evaluate()
289 | _logger.info('evaluate the network cost time: %s ', int(time.time() - check_time))
290 | if win_ratio > self.best_win_ratio:
291 | _logger.info("New best policy!!!!!!!!")
292 | self.best_win_ratio = win_ratio
293 | # update the best_policy
294 | self.policy_value_net.save_model('./logs/best_policy_%s.model' % i)
295 | if (self.best_win_ratio >= 0.98 and
296 | self.pure_mcts_playout_num < 8000):
297 | self.pure_mcts_playout_num += 1000
298 | self.best_win_ratio = 0.0
299 | except KeyboardInterrupt:
300 | _logger.info('\n\rquit')
301 |
302 |
303 | if __name__ == '__main__':
304 | try:
305 | start_time = time.time()
306 | # model_file = './logs/current_policy.model'
307 | model_file = None
308 | policy_param = None
309 | conf = config_loader.load_config('./conf/train_config.yaml')
310 | if model_file is not None:
311 | _logger.info('loading...%s' % model_file)
312 | try:
313 | policy_param = pickle.load(open(model_file, 'rb'))
314 | except:
315 | policy_param = pickle.load(open(model_file, 'rb'),
316 | encoding='bytes') # To support python3
317 | training_pipeline = TrainPipeline(conf, policy_param)
318 | _logger.info('enter training!')
319 | # training_pipeline.collect_selfplay_data(1, 1)
320 | training_pipeline.run()
321 | except Exception as e:
322 | _logger.exception(e)
323 | finally:
324 | cost_time = int(time.time() - start_time)
325 | format_time = "Elapsed: %s h %s min %s s" % (cost_time // 3600, (cost_time % 3600) // 60, (cost_time % 3600) % 60)
326 | send_email.send_mail('Training finished', format_time, 'XXX')
327 |
--------------------------------------------------------------------------------
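
The training pipeline above leans on the 8-fold board symmetry in `get_equi_data()`: every `(state, mcts_prob, winner)` sample yields four rotations plus a horizontal flip of each. Below is a small hedged NumPy sketch of the same transform on a dummy probability map (the values are arbitrary; only the shapes and operations mirror the code above).

```python
# Sketch of the 8-fold symmetry augmentation from get_equi_data(), on a dummy 15x15 map.
import numpy as np

h = w = 15
prob = np.zeros((h, w))
prob[7, 8] = 1.0                                   # pretend MCTS put all mass on one move

augmented = []
for i in [1, 2, 3, 4]:
    rot = np.rot90(np.flipud(prob), i)             # rotate, matching the training code
    augmented.append(np.flipud(rot).flatten())                # rotated sample
    augmented.append(np.flipud(np.fliplr(rot)).flatten())     # plus its horizontal flip
print(len(augmented))                              # 8 equivalent training targets
```
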
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | from . import config_loader
3 | from . import sgf_dataIter
--------------------------------------------------------------------------------
/utils/config_loader.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | import os
3 | import yaml
4 | import socket
5 | import time
6 | import sys
7 | import logging
8 |
9 |
10 | def load_config(data_path):
11 | f = open(data_path, 'r')
12 | conf = yaml.load(f)
13 | f.close()
14 | return conf
15 |
16 | configure_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../conf/train_config.yaml')
17 | config_ = load_config(configure_path)
18 |
19 |
--------------------------------------------------------------------------------
/utils/send_email.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 |
3 | import smtplib
4 | from email.mime.text import MIMEText
5 | from email.utils import formataddr
6 |
7 |
8 | def send_mail(message_title, message_text, pass_wd):
9 | my_sender='987683297@qq.com' # 发件人邮箱账号
10 | my_pass = 'oktbbgwyddpybeih' # 发件人邮箱密码
11 | my_pass = pass_wd
12 | my_user='anxingle820@gmail.com' # 收件人邮箱账号,我这边发送给自己
13 | ret=True
14 | try:
15 |         # msg=MIMEText('the AI is good enough now','plain','utf-8')
16 |         msg=MIMEText(message_text,'plain','utf-8')
17 |         msg['From']=formataddr(["FromRunoob",my_sender]) # sender display name and address
18 |         msg['To']=formataddr(["FK",my_user]) # recipient display name and address
19 |         msg['Subject'] = message_title # email subject line
20 |
21 |         server=smtplib.SMTP_SSL("smtp.qq.com", 465) # sender's SMTP server, SSL on port 465
22 |         server.login(my_sender, my_pass) # sender address and authorization code
23 |         server.sendmail(my_sender,[my_user,],msg.as_string()) # sender, recipient list, message body
24 |         server.quit() # close the connection
25 |     except Exception: # any failure above falls through to ret=False
26 |         ret=False
27 | return ret
28 |
29 |
--------------------------------------------------------------------------------
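A usage sketch matching the call in train_mxnet.py; `send_mail` returns True on success and False on any SMTP failure. The third argument is the sender's SMTP authorization code, which the caller must supply (the placeholder below is not a real code):

```python
from utils import send_email

ok = send_email.send_mail('Training finished',
                          'Elapsed: 3 h 12 min 5 s',
                          '<smtp-auth-code>')  # placeholder, replace with a real authorization code
print('mail sent' if ok else 'mail failed')
```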
/utils/sgf_dataIter.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | import os
3 | import sys
4 | import time
5 | # visualize game records; import only when needed
6 | # from gobang_board_utils import chessboard, evaluation, searcher, psyco_speedup
7 | # psyco speed-up helper
8 | # psyco_speedup()
9 |
10 | LETTER_NUM = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o']
11 | BIG_LETTER_NUM = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O']
12 | NUM_LIST = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
13 | # lookup tables between board letters and indices
14 | seq_lookup = dict(zip(LETTER_NUM, NUM_LIST))
15 | num2char_lookup = dict(zip(NUM_LIST, BIG_LETTER_NUM))
16 |
17 | # directory containing the SGF game records
18 | sgf_home = '/Users/anxingle/Downloads/SGF_Gomoku/sgf/'
19 |
20 |
21 | def get_files_as_list(data_dir):
22 |     # list the SGF files under the given directory
23 | file_list = os.listdir(data_dir)
24 | file_list = [item for item in file_list if item.endswith('.sgf') and os.path.isfile(os.path.join(data_dir, item))]
25 | return file_list
26 |
27 | def content_to_order(sequence):
28 |     # convert record letters into integer board indices
29 |
30 |     global seq_lookup # letter -> index lookup table
31 |     seq_list = sequence.split(';')
32 |     # list: ['hh', 'ii', 'hi', ...]
33 |     seq_list = [item[2:4] for item in seq_list]
34 |     # list: [112, 128, ...] (row*15 + col)
35 |     seq_num_list = [seq_lookup[item[0]]*15+seq_lookup[item[1]] for item in seq_list]
36 | return seq_list, seq_num_list
37 |
38 |
39 | def num2char(order_):
40 | global num2char_lookup
41 |     Y_axis = num2char_lookup[order_ // 15]  # integer division keeps the key an int under Python 3
42 | X_axis = num2char_lookup[order_ % 15]
43 | return '%s%s' % (Y_axis, X_axis)
44 |
45 | def get_data_from_files(file_name, data_dir):
46 |     """ Read the SGF record content for the given file name """
47 |
48 |     assert file_name.endswith('.sgf'), 'file: %s is not an SGF file' % file_name
49 | with open(os.path.join(data_dir, file_name)) as f:
50 | p = f.read()
51 |     # start / end positions of the record content (currently unused)
52 | start = file_name.index('_') + 1
53 | end = file_name.index('_.')
54 |
55 | sequence = p[p.index('SZ[15]')+7:-4]
56 | try:
57 | seq_list, seq_num_list = content_to_order(sequence)
58 | except Exception as e:
59 | print('***' * 20)
60 | print(e)
61 | print(file_name)
62 |     # winner from the file name: 'Blank'/'blank' -> black (1), 'White'/'white' -> white (2), otherwise 0
63 |     winner_tag = file_name[file_name.index('_')+1:file_name.index('_')+6].lower()
64 |     winner = 1 if winner_tag == 'blank' else (2 if winner_tag == 'white' else 0)
65 |     # the default of 0 avoids an unbound 'winner' when the tag is unrecognized
66 | return {'winner': winner, 'seq_list': seq_list, 'seq_num_list': seq_num_list, 'file_name':file_name}
67 |
68 |
69 | def read_files(data_dir):
70 |     # generator: iterate over the SGF records in a directory
71 |
72 |     # scan data_dir for all SGF files
73 | file_list = get_files_as_list(data_dir)
74 | index = 0
75 | while True:
76 | if index >= len(file_list): yield None
77 | with open(data_dir+file_list[index]) as f:
78 | p = f.read()
79 |         # start / end positions of the record content (currently unused)
80 | start = file_list[index].index('_') + 1
81 | end = file_list[index].index('_.')
82 |
83 | sequence = p[p.index('SZ[15]')+7:-4]
84 | try:
85 | seq_list, seq_num_list = content_to_order(sequence)
86 | except Exception as e:
87 | print('***' * 20)
88 | print(e)
89 | print(file_list[index])
90 |         winner = 0  # default when the winner cannot be parsed from the record
91 |         winner_tag = sequence[-5].lower()
92 |         if winner_tag == 'b': winner = 1  # black wins
93 |         if winner_tag == 'w': winner = 2  # white wins
94 | yield {'winner': winner, 'seq_list': seq_list, 'seq_num_list': seq_num_list, 'index': index, 'file_name':file_list[index]}
95 | index += 1
96 |
97 | # only used for move-by-move visualization of a record
98 | def gamemain(seq_list):
99 | b = chessboard()
100 | s = searcher()
101 | s.board = b.board()
102 |
103 | opening = [
104 | '1:HH 2:II',
105 | # '2:IG 2:GI 1:HH',
106 | # '1:IH 2:GI',
107 | # '1:HG 2:HI',
108 | # '2:HG 2:HI 1:HH',
109 | # '1:HH 2:IH 2:GI',
110 | # '1:HH 2:IH 2:HI',
111 | # '1:HH 2:IH 2:HJ',
112 | # '1:HG 2:HH 2:HI',
113 | # '1:GH 2:HH 2:HI',
114 | ]
115 |
116 | import random
117 |     # opening board position
118 | # openid = random.randint(0, len(opening) - 1)
119 |     # the opening is black's first move
120 | def num2char(order_):
121 | global num2char_lookup
122 |         Y_axis = num2char_lookup[order_ // 15]  # integer division for the row index
123 | X_axis = num2char_lookup[order_ % 15]
124 | return '%s%s' % (Y_axis, X_axis)
125 |
126 |
127 | start_open = '1:%s 2:%s' %(num2char(seq_list[0]), num2char(seq_list[1]))
128 | b.loads(start_open)
129 | turn = 2
130 | history = []
131 | undo = False
132 |
133 |     # search depth (difficulty)
134 |     DEPTH = 1
135 |     # the replay starts from black's second move
136 |     index = 2
137 |
138 |     while True:
139 |         print('')
140 |         while 1:
141 |             print('move %s' % (len(history) + 1))
142 |             b.show()
143 |             print('Your move (u: undo, q: quit): ', end='')
144 |             # by default, just keep replaying the recorded moves
145 |             text = input().strip('\r\n\t ')
146 |             text = '%s' % num2char(seq_list[index])
147 |             print('char: ', num2char(seq_list[index]))
148 |             print('num: ', seq_list[index])
149 |             index += 1
150 |             if len(text) == 2:
151 |                 tr = ord(text[0].upper()) - ord('A')
152 |                 tc = ord(text[1].upper()) - ord('A')
153 |                 if tr >= 0 and tc >= 0 and tr < 15 and tc < 15:
154 |                     if b[tr][tc] == 0:
155 |                         row, col = tr, tc
156 |                         break
157 |                     else:
158 |                         print('There is already a stone here!')
159 |                 else:
160 |                     print(text)
161 |                     print('Not on the board!')
162 |             elif text.upper() == 'U':
163 |                 undo = True
164 |                 break
165 |             elif text.upper() == 'Q':
166 |                 print(b.dumps())
167 |                 return 0
168 |
169 |         if undo == True:
170 |             undo = False
171 |             if len(history) == 0:
172 |                 print('The board is already empty; cannot undo any further!')
173 |             else:
174 |                 print('Undoing, restoring the previous position ...')
175 |                 move = history.pop()
176 |                 b.loads(move)
177 |         else:
178 |             history.append(b.dumps())
179 |             b[row][col] = 1
180 |             b.show()
181 |
182 |             if b.check() == 1:
183 |                 # b.show()
184 |                 print(b.dumps())
185 |                 print('')
186 |                 print('YOU WIN !!')
187 |                 return 0
188 |
189 |             # print('AI is thinking ...')
190 |             # time.sleep(0.6)
191 |             # xtt = input('go on: ')
192 |             # score, row, col = s.search(2, DEPTH)
193 |             # the AI (white) replays the next move from the record
194 |             text_ai = num2char(seq_list[index])
195 |             index += 1
196 |             # board letters ==> numeric coordinates
197 |             row, col = ord(text_ai[0].upper())-65, ord(text_ai[1].upper())-65
198 |             cord = '%s%s' % (chr(ord('A') + row), chr(ord('A') + col))
199 |             print('AI moves to: %s ' % cord)
200 |             # xtt = input('go on: ')
201 |             b[row][col] = 2
202 |             xtt = input('go on:')
203 |             b.show()
204 |
205 |             if b.check() == 2:
206 |                 # b.show()
207 |                 print(b.dumps())
208 |                 print('')
209 |                 print('YOU LOSE.')
210 |                 return 0
211 |
212 |     return 0
213 |
214 |
215 | if __name__ == '__main__':
216 | data = read_files(sgf_home)
217 | x = None
218 | y = None
219 | for i in range(4800):
220 | y = x
221 |         x = next(data)
222 |         if x is None:
223 | print('whole loop: ', i)
224 | print('index: ', y['index'])
225 |             print('file_name: ', y['file_name'])
226 |             print('\n')
227 | break
228 | else:
229 | pass
230 |
231 |
--------------------------------------------------------------------------------
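As a worked example of the coordinate mapping used above: the SGF move 'hh' is row 7, column 7 on the 15x15 board, so its flat index is 7*15+7 = 112, and num2char(112) converts it back to 'HH'. A minimal sketch, assuming the module imports cleanly as part of the utils package:

```python
from utils.sgf_dataIter import content_to_order, num2char

# 'B[hh];W[ii]' style move sequence: item[2:4] extracts the two coordinate letters
seq_list, seq_num_list = content_to_order('B[hh];W[ii]')
print(seq_list)       # ['hh', 'ii']
print(seq_num_list)   # [112, 128]  (7*15 + 7 and 8*15 + 8)
print(num2char(112))  # 'HH'
```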