├── .bazelrc
├── .gitignore
├── .travis.yml
├── BUILD
├── LICENSE.TXT
├── PhoenixGo.sln
├── README.md
├── ThirdParty.props
├── WORKSPACE
├── common
│   ├── BUILD
│   ├── errordef.h
│   ├── go_comm.cc
│   ├── go_comm.h
│   ├── go_state.cc
│   ├── go_state.h
│   ├── str_utils.cc
│   ├── str_utils.h
│   ├── task_queue.h
│   ├── thread_conductor.cc
│   ├── thread_conductor.h
│   ├── timer.cc
│   ├── timer.h
│   ├── wait_group.cc
│   └── wait_group.h
├── configure
├── configure.py
├── dist
│   ├── BUILD
│   ├── async_dist_zero_model_client.cc
│   ├── async_dist_zero_model_client.h
│   ├── async_rpc_queue.h
│   ├── dist_config.proto
│   ├── dist_zero_model.proto
│   ├── dist_zero_model_client.cc
│   ├── dist_zero_model_client.h
│   ├── dist_zero_model_server.cc
│   ├── leaky_bucket.cc
│   └── leaky_bucket.h
├── docs
│   ├── FAQ.md
│   ├── benchmark-gtx1060-75w.md
│   ├── benchmark-teslaV100.md
│   ├── go-review-partner.md
│   ├── mcts-main-help.md
│   ├── minimalist-bazel-configure.md
│   ├── path-errors.md
│   └── tested-versions.md
├── etc
│   ├── mcts_1gpu.conf
│   ├── mcts_1gpu_grp.conf
│   ├── mcts_1gpu_notensorrt.conf
│   ├── mcts_1gpu_notensorrt_grp.conf
│   ├── mcts_2gpu.conf
│   ├── mcts_2gpu_grp.conf
│   ├── mcts_2gpu_notensorrt.conf
│   ├── mcts_2gpu_notensorrt_grp.conf
│   ├── mcts_3gpu.conf
│   ├── mcts_3gpu_notensorrt.conf
│   ├── mcts_4gpu.conf
│   ├── mcts_4gpu_notensorrt.conf
│   ├── mcts_5gpu.conf
│   ├── mcts_5gpu_notensorrt.conf
│   ├── mcts_6gpu.conf
│   ├── mcts_6gpu_notensorrt.conf
│   ├── mcts_7gpu.conf
│   ├── mcts_7gpu_notensorrt.conf
│   ├── mcts_8gpu.conf
│   ├── mcts_8gpu_notensorrt.conf
│   ├── mcts_async_dist.conf
│   ├── mcts_cpu.conf
│   ├── mcts_cpu_grp.conf
│   └── mcts_dist.conf
├── images
│   └── logo.jpg
├── mcts
│   ├── BUILD
│   ├── byo_yomi_timer.cc
│   ├── byo_yomi_timer.h
│   ├── debug_tool.cc
│   ├── mcts_config.cc
│   ├── mcts_config.h
│   ├── mcts_config.proto
│   ├── mcts_debugger.cc
│   ├── mcts_debugger.h
│   ├── mcts_engine.cc
│   ├── mcts_engine.h
│   ├── mcts_main.cc
│   ├── mcts_monitor.cc
│   └── mcts_monitor.h
├── mcts_main.filters
├── mcts_main.vcxproj
├── model
│   ├── BUILD
│   ├── build_tensorrt_model.cc
│   ├── checkpoint_state.proto
│   ├── checkpoint_utils.cc
│   ├── checkpoint_utils.h
│   ├── model_config.proto
│   ├── trt_zero_model.cc
│   ├── trt_zero_model.h
│   ├── zero_model.cc
│   ├── zero_model.h
│   └── zero_model_base.h
├── rules.bzl
├── scripts
│   ├── build_tensorrt_model.sh
│   ├── get_global_step.py
│   ├── graph_transform.py
│   ├── start.sh
│   ├── start_cpu.bat
│   └── start_gpu.bat
├── third_party
│   ├── glog
│   │   ├── BUILD
│   │   └── glog.patch
│   └── tensorflow
│       ├── .bazelrc
│       ├── BUILD
│       └── tensorflow.patch
└── tools
    └── .keep
/.bazelrc:
--------------------------------------------------------------------------------
1 | import %workspace%/third_party/tensorflow/.bazelrc
2 |
3 | build --config=opt
4 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | log
2 | conf
3 | ckpt
4 | bazel-*
5 | compile_commands.json
6 | tools/actions
7 | third_party/bazel
8 | *.pyc
9 | .tf_configure.bazelrc
10 | tools/python_bin_path.sh
11 | *.pb.h
12 | *.pb.cc
13 | *.user
14 | *.VC.db
15 | *.VC.VC.opendb
16 | .vs
17 | x64
18 | x86
19 | Release
20 | Debug
21 | .clang
22 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | dist: trusty
2 | sudo: required
3 |
4 | before_install:
5 | - wget https://github.com/bazelbuild/bazel/releases/download/0.11.1/bazel-0.11.1-installer-linux-x86_64.sh
6 | - chmod +x bazel-0.11.1-installer-linux-x86_64.sh
7 | - ./bazel-0.11.1-installer-linux-x86_64.sh --user
8 | - export PATH="$PATH:$HOME/bin"
9 |
10 | script:
11 | - bazel build //mcts:mcts_main
12 | - bazel build //dist:dist_zero_model_server
13 |
--------------------------------------------------------------------------------
/BUILD:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent/PhoenixGo/fbf67f9aec42531bff9569c44b85eb4c3f37b7be/BUILD
--------------------------------------------------------------------------------
/PhoenixGo.sln:
--------------------------------------------------------------------------------
1 |
2 | Microsoft Visual Studio Solution File, Format Version 12.00
3 | # Visual Studio 14
4 | VisualStudioVersion = 14.0.25420.1
5 | MinimumVisualStudioVersion = 10.0.40219.1
6 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "mcts_main", "mcts_main.vcxproj", "{ABF9E6C0-D295-4B39-B367-3B67DF2CA3E9}"
7 | EndProject
8 | Global
9 | GlobalSection(SolutionConfigurationPlatforms) = preSolution
10 | Debug|x64 = Debug|x64
11 | Debug|x86 = Debug|x86
12 | Release|x64 = Release|x64
13 | Release|x86 = Release|x86
14 | EndGlobalSection
15 | GlobalSection(ProjectConfigurationPlatforms) = postSolution
16 | {ABF9E6C0-D295-4B39-B367-3B67DF2CA3E9}.Debug|x64.ActiveCfg = Debug|x64
17 | {ABF9E6C0-D295-4B39-B367-3B67DF2CA3E9}.Debug|x64.Build.0 = Debug|x64
18 | {ABF9E6C0-D295-4B39-B367-3B67DF2CA3E9}.Debug|x86.ActiveCfg = Debug|Win32
19 | {ABF9E6C0-D295-4B39-B367-3B67DF2CA3E9}.Debug|x86.Build.0 = Debug|Win32
20 | {ABF9E6C0-D295-4B39-B367-3B67DF2CA3E9}.Release|x64.ActiveCfg = Release|x64
21 | {ABF9E6C0-D295-4B39-B367-3B67DF2CA3E9}.Release|x64.Build.0 = Release|x64
22 | {ABF9E6C0-D295-4B39-B367-3B67DF2CA3E9}.Release|x86.ActiveCfg = Release|Win32
23 | {ABF9E6C0-D295-4B39-B367-3B67DF2CA3E9}.Release|x86.Build.0 = Release|Win32
24 | EndGlobalSection
25 | GlobalSection(SolutionProperties) = preSolution
26 | HideSolutionNode = FALSE
27 | EndGlobalSection
28 | EndGlobal
29 |
--------------------------------------------------------------------------------
/ThirdParty.props:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | C:\Users\jchzhang\tensorflow
6 | C:\Users\jchzhang\tensorflow\tensorflow\contrib\cmake\build
7 | C:\Users\jchzhang\glog-0.3.5
8 | C:\Users\jchzhang\glog-0.3.5
9 | C:\Users\jchzhang\gflags-2.2.1
10 | C:\Users\jchzhang\gflags-2.2.1
11 | C:\Users\jchzhang\boost_1_66_0
12 | C:\Users\jchzhang\boost_1_66_0\stage\lib
13 |
14 |
15 |
16 |
17 |
18 | $(tensorflow_SourcePath)
19 |
20 |
21 | $(tensorflow_BuildPath)
22 |
23 |
24 | $(glog_SourcePath)
25 |
26 |
27 | $(glog_BuildPath)
28 |
29 |
30 | $(gflags_SourcePath)
31 |
32 |
33 | $(gflags_BuildPath)
34 |
35 |
36 | $(boost_IncludePath)
37 |
38 |
39 | $(boost_LibPath)
40 |
41 |
42 |
--------------------------------------------------------------------------------
/WORKSPACE:
--------------------------------------------------------------------------------
1 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
2 |
3 | http_archive(
4 | name = "com_github_google_glog",
5 | urls = ["https://github.com/google/glog/archive/55cc27b6eca3d7906fc1a920ca95df7717deb4e7.tar.gz"],
6 | sha256 = "4966f4233e4dcac53c7c7dd26054d17c045447e183bf9df7081575d2b888b95b",
7 | strip_prefix = "glog-55cc27b6eca3d7906fc1a920ca95df7717deb4e7",
8 | patches = ["//third_party/glog:glog.patch"],
9 | )
10 |
11 | http_archive(
12 | name = "org_tensorflow",
13 | urls = ["https://github.com/tensorflow/tensorflow/archive/v1.13.1.tar.gz"],
14 | sha256 = "7cd19978e6bc7edc2c847bce19f95515a742b34ea5e28e4389dade35348f58ed",
15 | strip_prefix = "tensorflow-1.13.1",
16 | patches = ["//third_party/tensorflow:tensorflow.patch"],
17 | )
18 |
19 | http_archive(
20 | name = "io_bazel_rules_closure",
21 | sha256 = "a38539c5b5c358548e75b44141b4ab637bba7c4dc02b46b1f62a96d6433f56ae",
22 | strip_prefix = "rules_closure-dbb96841cc0a5fb2664c37822803b06dab20c7d1",
23 | urls = [
24 | "https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/dbb96841cc0a5fb2664c37822803b06dab20c7d1.tar.gz",
25 | "https://github.com/bazelbuild/rules_closure/archive/dbb96841cc0a5fb2664c37822803b06dab20c7d1.tar.gz", # 2018-04-13
26 | ],
27 | )
28 |
29 | load('@org_tensorflow//tensorflow:workspace.bzl', 'tf_workspace')
30 | tf_workspace(path_prefix = "", tf_repo_name = "org_tensorflow")
31 |
32 | http_archive(
33 | name = "com_github_nelhage_rules_boost",
34 | urls = ["https://github.com/nelhage/rules_boost/archive/6d6fd834281cb8f8e758dd9ad76df86304bf1869.tar.gz"],
35 | sha256 = "9adb4899e40fc10871bab1ff2e8feee950c194eec9940490f65a2761bbe6941d",
36 | strip_prefix = "rules_boost-6d6fd834281cb8f8e758dd9ad76df86304bf1869",
37 | )
38 |
39 | load("@com_github_nelhage_rules_boost//:boost/boost.bzl", "boost_deps")
40 | boost_deps()
41 |
--------------------------------------------------------------------------------
/common/BUILD:
--------------------------------------------------------------------------------
1 | cc_library(
2 | name = "go_comm",
3 | srcs = ["go_comm.cc"],
4 | hdrs = ["go_comm.h"],
5 | visibility = ["//visibility:public"],
6 | )
7 |
8 | cc_library(
9 | name = "go_state",
10 | srcs = ["go_state.cc"],
11 | hdrs = ["go_state.h"],
12 | deps = [":go_comm"],
13 | visibility = ["//visibility:public"],
14 | )
15 |
16 | cc_library(
17 | name = "errordef",
18 | hdrs = ["errordef.h"],
19 | visibility = ["//visibility:public"],
20 | )
21 |
22 | cc_library(
23 | name = "task_queue",
24 | hdrs = ["task_queue.h"],
25 | visibility = ["//visibility:public"],
26 | )
27 |
28 | cc_library(
29 | name = "wait_group",
30 | srcs = ["wait_group.cc"],
31 | hdrs = ["wait_group.h"],
32 | visibility = ["//visibility:public"],
33 | )
34 |
35 | cc_library(
36 | name = "thread_conductor",
37 | srcs = ["thread_conductor.cc"],
38 | hdrs = ["thread_conductor.h"],
39 | deps = [":wait_group"],
40 | visibility = ["//visibility:public"],
41 | )
42 |
43 | cc_library(
44 | name = "timer",
45 | srcs = ["timer.cc"],
46 | hdrs = ["timer.h"],
47 | visibility = ["//visibility:public"],
48 | )
49 |
50 | cc_library(
51 | name = "str_utils",
52 | srcs = ["str_utils.cc"],
53 | hdrs = ["str_utils.h"],
54 | visibility = ["//visibility:public"],
55 | )
56 |
--------------------------------------------------------------------------------
/common/errordef.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #pragma once
19 |
20 | enum {
21 | ERR_INVALID_INPUT = -1,
22 | ERR_FORWARD_TIMEOUT = -2,
23 |
24 | ERR_READ_CHECKPOINT = -1000,
25 | ERR_CREATE_SESSION = -1001,
26 | ERR_CREATE_GRAPH = -1002,
27 | ERR_RESTORE_VAR = -1003,
28 | ERR_SESSION_RUN = -1005,
29 |
30 | // tensorrt error
31 | ERR_READ_TRT_MODEL = -2000,
32 | ERR_LOAD_TRT_ENGINE = -2001,
33 | ERR_CUDA_MALLOC = -2002,
34 | ERR_CUDA_FREE = -2003,
35 | ERR_CUDA_MEMCPY = -2004,
36 |
37 | // rpc error
38 | ERR_GLOBAL_STEP_CONFLICT = -3000,
39 | ERR_EMPTY_RESP = -3001,
40 | };
41 |
--------------------------------------------------------------------------------
/common/go_comm.cc:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #include "go_comm.h"
19 |
20 | #include <cstring>
21 | #include <mutex>
22 | // #include <glog/logging.h>
23 |
24 | #define x first
25 | #define y second
26 |
27 | using namespace std;
28 | using namespace GoComm;
29 |
30 | GoHashValuePair g_hash_weight[BORDER_SIZE][BORDER_SIZE];
31 |
32 | vector<GoPosition> g_neighbour_cache_by_coord[BORDER_SIZE][BORDER_SIZE];
33 | GoSize g_neighbour_size[GOBOARD_SIZE];
34 | GoCoordId g_neighbour_cache_by_id[GOBOARD_SIZE][5];
35 | GoCoordId g_log2_table[67];
36 | uint64_t g_zobrist_board_hash_weight[4][GOBOARD_SIZE];
37 | uint64_t g_zobrist_ko_hash_weight[GOBOARD_SIZE];
38 | uint64_t g_zobrist_player_hash_weight[4];
39 |
40 | namespace GoFunction {
41 |
42 | bool InBoard(const GoCoordId id) {
43 | return 0 <= id && id < GOBOARD_SIZE;
44 | }
45 |
46 | bool InBoard(const GoCoordId x, const GoCoordId y) {
47 | return 0 <= x && x < BORDER_SIZE
48 | && 0 <= y && y < BORDER_SIZE;
49 | }
50 |
51 | bool IsPass(const GoCoordId id) {
52 | return COORD_PASS == id;
53 | }
54 |
55 | bool IsPass(const GoCoordId x, const GoCoordId y) {
56 | return COORD_PASS == CoordToId(x, y);
57 | }
58 |
59 | bool IsUnset(const GoCoordId id) {
60 | return COORD_UNSET == id;
61 | }
62 |
63 | bool IsUnset(const GoCoordId x, const GoCoordId y) {
64 | return COORD_UNSET == CoordToId(x, y);
65 | }
66 |
67 | bool IsResign(const GoCoordId id) {
68 | return COORD_RESIGN == id;
69 | }
70 |
71 | bool IsResign(const GoCoordId x, const GoCoordId y) {
72 | return COORD_RESIGN == CoordToId(x, y);
73 | }
74 |
75 |
76 | void IdToCoord(const GoCoordId id, GoCoordId &x, GoCoordId &y) {
77 | if (COORD_PASS == id) {
78 | x = y = COORD_PASS;
79 | } else if (COORD_RESIGN == id) {
80 | x = y = COORD_RESIGN;
81 | } else if (!InBoard(id)) {
82 | x = y = COORD_UNSET;
83 | } else {
84 | x = id / BORDER_SIZE;
85 | y = id % BORDER_SIZE;
86 | }
87 | }
88 |
89 | GoCoordId CoordToId(const GoCoordId x, const GoCoordId y) {
90 | if (COORD_PASS == x && COORD_PASS == y) {
91 | return COORD_PASS;
92 | }
93 | if (COORD_RESIGN == x && COORD_RESIGN == y) {
94 | return COORD_RESIGN;
95 | }
96 | if (!InBoard(x, y)) {
97 | return COORD_UNSET;
98 | }
99 | return x * BORDER_SIZE + y;
100 | }
101 |
102 | void StrToCoord(const string &str, GoCoordId &x, GoCoordId &y) {
103 | // CHECK_EQ(str.length(), 2) << "string[" << str << "] length not equal to 2";
104 | x = str[0] - 'a';
105 | y = str[1] - 'a';
106 | if (str == "zz") {
107 | x = y = COORD_PASS;
108 | } else if (!InBoard(x, y)) {
109 | x = y = COORD_UNSET;
110 | }
111 | }
112 |
113 | string CoordToStr(const GoCoordId x, const GoCoordId y) {
114 | char buffer[3];
115 | if (!InBoard(x, y)) {
116 | buffer[0] = buffer[1] = 'z';
117 | } else {
118 | buffer[0] = x + 'a';
119 | buffer[1] = y + 'a';
120 | }
121 | return string(buffer, 2);
122 | }
123 |
124 | std::string IdToStr(const GoCoordId id) {
125 | GoCoordId x, y;
126 | IdToCoord(id, x, y);
127 | return CoordToStr(x, y);
128 | }
129 |
130 | GoCoordId StrToId(const std::string &str) {
131 | GoCoordId x, y;
132 | StrToCoord(str, x, y);
133 | return CoordToId(x, y);
134 | }
135 |
136 | once_flag CreateGlobalVariables_once;
137 | void CreateGlobalVariables() {
138 | call_once(
139 | CreateGlobalVariables_once,
140 | []() {
141 | CreateNeighbourCache();
142 | CreateHashWeights();
143 | CreateQuickLog2Table();
144 | CreateZobristHash();
145 | }
146 | );
147 | }
148 |
149 |
150 | void CreateHashWeights() {
151 | g_hash_weight[0][0] = GoHashValuePair(1, 1);
152 | for (GoCoordId i = 1; i < GOBOARD_SIZE; ++i) {
153 | g_hash_weight[i / BORDER_SIZE][i % BORDER_SIZE] =
154 | GoHashValuePair(g_hash_weight[(i - 1) / BORDER_SIZE][(i - 1) % BORDER_SIZE].x * g_hash_unit.x,
155 | g_hash_weight[(i - 1) / BORDER_SIZE][(i - 1) % BORDER_SIZE].y * g_hash_unit.y);
156 | }
157 | }
158 |
159 | void CreateNeighbourCache() {
160 | for (GoCoordId x = 0; x < BORDER_SIZE; ++x) {
161 | for (GoCoordId y = 0; y < BORDER_SIZE; ++y) {
162 | GoCoordId id = CoordToId(x, y);
163 |
164 | g_neighbour_cache_by_coord[x][y].clear();
165 | for (int i = 0; i <= DELTA_SIZE; ++i) {
166 | g_neighbour_cache_by_id[id][i] = COORD_UNSET;
167 | }
168 | for (int i = 0; i < DELTA_SIZE; ++i) {
169 | GoCoordId nx = x + DELTA_X[i];
170 | GoCoordId ny = y + DELTA_Y[i];
171 |
172 | if (InBoard(nx, ny)) {
173 | g_neighbour_cache_by_coord[x][y].push_back(GoPosition(nx, ny));
174 | }
175 | }
176 | g_neighbour_size[id] = g_neighbour_cache_by_coord[x][y].size();
177 | for (GoSize i = 0; i < g_neighbour_cache_by_coord[x][y].size(); ++i) {
178 | g_neighbour_cache_by_id[id][i] = CoordToId(g_neighbour_cache_by_coord[x][y][i].x,
179 | g_neighbour_cache_by_coord[x][y][i].y);
180 | }
181 | }
182 | }
183 | // cerr << hex << int(g_neighbour_cache_by_coord) << endl;
184 | }
185 |
186 | void CreateQuickLog2Table() {
187 | memset(g_log2_table, -1, sizeof(g_log2_table));
188 | int tmp = 1;
189 |
190 | for (GoCoordId i = 0; i < 64; ++i) {
191 | g_log2_table[tmp] = i;
192 | tmp *= 2;
193 | tmp %= 67;
194 | }
195 | }
196 |
197 | #if defined(_WIN32) || defined(_WIN64)
198 | static int rand_r(unsigned int *seed)
199 | {
200 | unsigned int next = *seed;
201 | int result;
202 |
203 | next *= 1103515245;
204 | next += 12345;
205 | result = (unsigned int)(next / 65536) % 2048;
206 |
207 | next *= 1103515245;
208 | next += 12345;
209 | result <<= 10;
210 | result ^= (unsigned int)(next / 65536) % 1024;
211 |
212 | next *= 1103515245;
213 | next += 12345;
214 | result <<= 10;
215 | result ^= (unsigned int)(next / 65536) % 1024;
216 |
217 | *seed = next;
218 |
219 | return result;
220 | }
221 | #endif
222 |
223 | void CreateZobristHash() {
224 | uint32_t seed = 0xdeadbeaf;
225 |
226 | for (int i = 0; i < 4; ++i) {
227 | g_zobrist_player_hash_weight[i] = (uint64_t) rand_r(&seed) << 32 | rand_r(&seed);
228 | for (int j = 0; j < GOBOARD_SIZE; ++j) {
229 | g_zobrist_board_hash_weight[i][j] = (uint64_t) rand_r(&seed) << 32 | rand_r(&seed);
230 | }
231 | }
232 |
233 | for (int i = 0; i < GOBOARD_SIZE; ++i) {
234 | g_zobrist_ko_hash_weight[i] = (uint64_t) rand_r(&seed) << 32 | rand_r(&seed);
235 | }
236 | }
237 |
238 | } // namespace GoFunction
239 |
240 | #undef y
241 | #undef x
242 |
--------------------------------------------------------------------------------
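The Zobrist tables built above are consumed elsewhere in the engine (go_state.cc, not shown in this section). As a rough illustration of why they have this shape: a position hash is normally maintained by XOR-ing the per-(color, point) weight of every stone on the board, so placing and later removing the same stone cancels out. A minimal sketch, assuming the tables have been filled via CreateGlobalVariables() and that the header is reachable as common/go_comm.h:

    #include <cstdint>
    #include <iostream>

    #include "common/go_comm.h"

    int main() {
        GoFunction::CreateGlobalVariables();  // fills the Zobrist weight tables once

        // Hypothetical incremental hash: XOR in the weight of each (color, point) pair.
        uint64_t hash = 0;
        GoCoordId d4 = GoFunction::CoordToId(3, 3);
        GoCoordId q16 = GoFunction::CoordToId(15, 15);
        hash ^= g_zobrist_board_hash_weight[GoComm::BLACK][d4];   // black stone placed
        hash ^= g_zobrist_board_hash_weight[GoComm::WHITE][q16];  // white stone placed
        hash ^= g_zobrist_board_hash_weight[GoComm::BLACK][d4];   // removing the black stone cancels it
        std::cout << std::hex << hash << "\n";
        return 0;
    }
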
/common/go_comm.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #pragma once
19 |
20 | #include <stdint.h>
21 | #include <string>
22 | #include <vector>
23 |
24 | // Return code of functions should be "int"
25 | typedef uint8_t GoStoneColor; // Stone color
26 | typedef int16_t GoCoordId; // Stone IDs or coordinates
27 | typedef int16_t GoBlockId; // Block IDs
28 | typedef int16_t GoSize; // Counts of visit times, used blocks, ..
29 |
30 | namespace GoComm {
31 |
32 | const GoCoordId BORDER_SIZE = 19;
33 | const GoCoordId GOBOARD_SIZE = BORDER_SIZE * BORDER_SIZE;
34 | const GoCoordId COORD_UNSET = -2;
35 | const GoCoordId COORD_PASS = -1;
36 | const GoCoordId COORD_RESIGN = -3;
37 |
38 | const GoSize SIZE_NONE = 0;
39 |
40 | const GoBlockId MAX_BLOCK_SIZE = 1 << 8;
41 | const GoBlockId BLOCK_UNSET = -1;
42 |
43 | const GoStoneColor EMPTY = 0;
44 | const GoStoneColor BLACK = 1;
45 | const GoStoneColor WHITE = 2;
46 | const GoStoneColor WALL = 3;
47 | const GoStoneColor COLOR_UNKNOWN = -1;
48 | const char *const COLOR_STRING[] = { "Empty", "Black", "White", "Wall" };
49 |
50 | const GoCoordId DELTA_X[] = { 0, 1, 0, -1 };
51 | const GoCoordId DELTA_Y[] = { -1, 0, 1, 0 };
52 | const GoSize DELTA_SIZE = sizeof(DELTA_X) / sizeof(*DELTA_X);
53 |
54 | const GoSize UINT64_BITS = sizeof(uint64_t) * 8;
55 | const GoSize LIBERTY_STATE_SIZE = (GOBOARD_SIZE + UINT64_BITS - 1) / UINT64_BITS;
56 | const GoSize BOARD_STATE_SIZE = (GOBOARD_SIZE + UINT64_BITS - 1) / UINT64_BITS;
57 |
58 | } // namespace GoComm
59 |
60 | namespace GoFeature {
61 |
62 | const int SIZE_HISTORYEACHSIDE = 16;
63 | const int SIZE_PLAYERCOLOR = 1;
64 |
65 | const int STARTPOS_HISTORYEACHSIDE = 0;
66 | const int STARTPOS_PLAYERCOLOR = STARTPOS_HISTORYEACHSIDE + SIZE_HISTORYEACHSIDE;
67 |
68 | const int FEATURE_COUNT = STARTPOS_PLAYERCOLOR + SIZE_PLAYERCOLOR;
69 |
70 | } // namespace GoFeature
71 |
72 |
73 | namespace GoFunction {
74 |
75 | extern bool InBoard(const GoCoordId id);
76 |
77 | extern bool InBoard(const GoCoordId x, const GoCoordId y);
78 |
79 | extern bool IsPass(const GoCoordId id);
80 |
81 | extern bool IsPass(const GoCoordId x, const GoCoordId y);
82 |
83 | extern bool IsUnset(const GoCoordId id);
84 |
85 | extern bool IsUnset(const GoCoordId x, const GoCoordId y);
86 |
87 | extern bool IsResign(const GoCoordId id);
88 |
89 | extern bool IsResign(const GoCoordId x, const GoCoordId y);
90 |
91 |
92 | extern void IdToCoord(const GoCoordId id, GoCoordId &x, GoCoordId &y);
93 |
94 | extern GoCoordId CoordToId(const GoCoordId x, const GoCoordId y);
95 |
96 | extern void StrToCoord(const std::string &str, GoCoordId &x, GoCoordId &y);
97 |
98 | extern std::string CoordToStr(const GoCoordId x, const GoCoordId y);
99 |
100 | extern std::string IdToStr(const GoCoordId id);
101 |
102 | extern GoCoordId StrToId(const std::string &str);
103 |
104 |
105 | extern void CreateGlobalVariables();
106 |
107 | extern void CreateHashWeights();
108 |
109 | extern void CreateNeighbourCache();
110 |
111 | extern void CreateQuickLog2Table();
112 |
113 | extern void CreateZobristHash();
114 |
115 | } // namespace GoFunction
116 |
117 |
118 | typedef std::pair<GoCoordId, GoCoordId> GoPosition;
119 | typedef std::pair<uint64_t, uint64_t> GoHashValuePair;
120 |
121 | extern GoHashValuePair g_hash_weight[GoComm::BORDER_SIZE][GoComm::BORDER_SIZE];
122 | const GoHashValuePair g_hash_unit(3, 7);
123 | extern uint64_t g_zobrist_board_hash_weight[4][GoComm::GOBOARD_SIZE];
124 | extern uint64_t g_zobrist_ko_hash_weight[GoComm::GOBOARD_SIZE];
125 | extern uint64_t g_zobrist_player_hash_weight[4];
126 |
127 | extern std::vector<GoPosition> g_neighbour_cache_by_coord[GoComm::BORDER_SIZE][GoComm::BORDER_SIZE];
128 | extern GoSize g_neighbour_size[GoComm::GOBOARD_SIZE];
129 | extern GoCoordId g_neighbour_cache_by_id[GoComm::GOBOARD_SIZE][GoComm::DELTA_SIZE + 1];
130 | extern GoCoordId g_log2_table[67];
131 |
132 | #define FOR_NEI(id, nb) for (GoCoordId *nb = g_neighbour_cache_by_id[(id)]; \
133 | GoComm::COORD_UNSET != *nb; ++nb)
134 | #define FOR_EACHCOORD(id) for (GoCoordId id = 0; id < GoComm::GOBOARD_SIZE; ++id)
135 | #define FOR_EACHBLOCK(id) for (GoBlockId id = 0; id < GoComm::MAX_BLOCK_SIZE; ++id)
136 |
--------------------------------------------------------------------------------
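For orientation, the helpers in GoFunction convert between the flat ids used throughout the engine, (x, y) pairs, and the two-letter SGF-style strings ("zz" means pass). A small sketch of how they fit together with the FOR_NEI macro; the include path is an assumption for an out-of-tree example:

    #include <iostream>

    #include "common/go_comm.h"

    int main() {
        GoCoordId x, y;
        GoFunction::StrToCoord("dd", x, y);            // x = 3, y = 3
        GoCoordId id = GoFunction::CoordToId(x, y);    // 3 * 19 + 3 = 60
        std::cout << GoFunction::IdToStr(id) << "\n";  // prints "dd"
        std::cout << GoFunction::IsPass(GoFunction::StrToId("zz")) << "\n";  // prints 1

        // FOR_NEI iterates the cached on-board neighbours of a point; the cache
        // must first be built by CreateGlobalVariables() / CreateNeighbourCache().
        GoFunction::CreateGlobalVariables();
        FOR_NEI(id, nb) {
            std::cout << GoFunction::IdToStr(*nb) << " ";
        }
        std::cout << "\n";
        return 0;
    }
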
/common/str_utils.cc:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #include "str_utils.h"
19 |
20 | std::vector<std::string> SplitStr(const std::string &str, char delim)
21 | {
22 | std::vector<std::string> ret = {""};
23 | for (char c: str) {
24 | if (c == delim) {
25 | ret.emplace_back();
26 | } else {
27 | ret.back() += c;
28 | }
29 | }
30 | return ret;
31 | }
32 |
--------------------------------------------------------------------------------
/common/str_utils.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #pragma once
19 |
20 | #include <string>
21 | #include <vector>
22 |
23 | std::vector<std::string> SplitStr(const std::string &str, char delim);
24 |
--------------------------------------------------------------------------------
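SplitStr keeps empty fields, so consecutive delimiters and an empty input still produce entries, which matters when parsing configuration lines. A quick check (include path assumed):

    #include <iostream>

    #include "common/str_utils.h"

    int main() {
        // "a,,b" splits into {"a", "", "b"}.
        for (const std::string &field : SplitStr("a,,b", ',')) {
            std::cout << "[" << field << "]";
        }
        std::cout << "\n";  // prints [a][][b]
        return 0;
    }
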
/common/task_queue.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #pragma once
19 |
20 | #include <atomic>
21 | #include <chrono>
22 | #include <condition_variable>
23 | #include <deque>
24 | #include <mutex>
25 |
26 | template <class T>
27 | class TaskQueue
28 | {
29 | public:
30 | TaskQueue(int capacity = 0)
31 | : m_capacity(capacity), m_size(0), m_is_close(false)
32 | {
33 | }
34 |
35 | template <class U>
36 | void Push(U &&elem)
37 | {
38 | std::unique_lock<std::mutex> lock(m_mutex);
39 | if (m_capacity > 0) {
40 | m_push_cond.wait(lock, [this]{ return m_queue.size() < m_capacity; });
41 | }
42 | m_queue.push_back(std::forward<U>(elem));
43 | m_size = m_queue.size();
44 | lock.unlock();
45 | m_pop_cond.notify_one();
46 | }
47 |
48 | template <class U>
49 | void PushFront(U &&elem)
50 | {
51 | std::unique_lock<std::mutex> lock(m_mutex);
52 | m_queue.push_front(std::forward<U>(elem));
53 | m_size = m_queue.size();
54 | lock.unlock();
55 | m_pop_cond.notify_one();
56 | }
57 |
58 | bool Pop(T &elem, int64_t timeout_us = -1)
59 | {
60 | std::unique_lock<std::mutex> lock(m_mutex);
61 | if (timeout_us < 0) {
62 | m_pop_cond.wait(lock, [this]{ return !m_queue.empty() || m_is_close; });
63 | } else {
64 | if (!m_pop_cond.wait_for(lock, std::chrono::microseconds(timeout_us),
65 | [this]{ return !m_queue.empty() || m_is_close; })) {
66 | return false;
67 | }
68 | }
69 | if (m_queue.empty()) {
70 | return false;
71 | }
72 | elem = std::move(m_queue.front());
73 | m_queue.pop_front();
74 | m_size = m_queue.size();
75 | lock.unlock();
76 | if (m_capacity > 0) {
77 | m_push_cond.notify_one();
78 | }
79 | return true;
80 | }
81 |
82 | void Close()
83 | {
84 | {
85 | std::lock_guard<std::mutex> lock(m_mutex);
86 | m_is_close = true;
87 | }
88 | m_push_cond.notify_all();
89 | m_pop_cond.notify_all();
90 | }
91 |
92 | bool IsClose() const
93 | {
94 | return m_is_close;
95 | }
96 |
97 | int Size() const
98 | {
99 | return m_size;
100 | }
101 |
102 | private:
103 | std::deque<T> m_queue;
104 | int m_capacity;
105 | std::atomic<int> m_size;
106 | std::mutex m_mutex;
107 | std::condition_variable m_push_cond;
108 | std::condition_variable m_pop_cond;
109 | std::atomic<bool> m_is_close;
110 | };
111 |
--------------------------------------------------------------------------------
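TaskQueue<T> is a blocking deque: with capacity > 0, Push blocks while the queue is full; Pop blocks until an element arrives, the timeout expires, or Close() wakes all waiters, after which Pop on an empty queue returns false. A minimal producer/consumer sketch under those semantics (include path assumed):

    #include <iostream>
    #include <thread>

    #include "common/task_queue.h"

    int main() {
        TaskQueue<int> queue(4);  // capacity 4: Push blocks when the queue is full

        std::thread producer([&queue] {
            for (int i = 0; i < 10; ++i) {
                queue.Push(i);
            }
            queue.Close();  // wakes consumers blocked in Pop
        });

        int task;
        while (queue.Pop(task)) {  // default timeout_us = -1: wait until data or Close
            std::cout << "got " << task << "\n";
        }

        producer.join();
        return 0;
    }
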
/common/thread_conductor.cc:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #include "thread_conductor.h"
19 |
20 | #include <chrono>
21 |
22 | ThreadConductor::ThreadConductor() : m_state(k_pause)
23 | {
24 | }
25 |
26 | ThreadConductor::~ThreadConductor()
27 | {
28 | }
29 |
30 | void ThreadConductor::Pause()
31 | {
32 | m_resume_wg.Wait();
33 | std::lock_guard<std::mutex> lock(m_mutex);
34 | m_state = k_pause;
35 | m_cond.notify_all();
36 | }
37 |
38 | void ThreadConductor::Resume(int num_threads)
39 | {
40 | {
41 | std::lock_guard<std::mutex> lock(m_mutex);
42 | if (m_state == k_running) return;
43 | m_resume_wg.Add(num_threads);
44 | m_pause_wg.Add(num_threads);
45 | m_state = k_running;
46 | }
47 | m_cond.notify_all();
48 | m_resume_wg.Wait();
49 | }
50 |
51 | void ThreadConductor::Wait()
52 | {
53 | {
54 | std::unique_lock<std::mutex> lock(m_mutex);
55 | m_cond.wait(lock, [this]{ return m_state != k_pause; });
56 | if (m_state == k_terminate) return;
57 | }
58 | m_resume_wg.Done();
59 | }
60 |
61 | void ThreadConductor::AckPause()
62 | {
63 | m_pause_wg.Done();
64 | }
65 |
66 | bool ThreadConductor::Join(int64_t timeout_us)
67 | {
68 | return m_pause_wg.Wait(timeout_us);
69 | }
70 |
71 | void ThreadConductor::Sleep(int64_t duration_us)
72 | {
73 | std::unique_lock<std::mutex> lock(m_mutex);
74 | m_cond.wait_for(lock, std::chrono::microseconds(duration_us), [this]{ return m_state == k_pause; });
75 | }
76 |
77 | bool ThreadConductor::IsRunning()
78 | {
79 | return m_state == k_running;
80 | }
81 |
82 | void ThreadConductor::Terminate()
83 | {
84 | Pause();
85 | Join();
86 | {
87 | std::lock_guard<std::mutex> lock(m_mutex);
88 | m_state = k_terminate;
89 | }
90 | m_cond.notify_all();
91 | }
92 |
93 | bool ThreadConductor::IsTerminate()
94 | {
95 | return m_state == k_terminate;
96 | }
97 |
--------------------------------------------------------------------------------
/common/thread_conductor.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #pragma once
19 |
20 | #include <atomic>
21 | #include <condition_variable>
22 | #include <mutex>
23 |
24 | #include "wait_group.h"
25 |
26 |
27 | class ThreadConductor
28 | {
29 | public:
30 | ThreadConductor();
31 | ~ThreadConductor();
32 |
33 | void Pause();
34 | void Resume(int num_threads);
35 | void Wait();
36 | void AckPause();
37 | bool Join(int64_t timeout_us = -1);
38 | void Sleep(int64_t duration_us);
39 | bool IsRunning();
40 | void Terminate();
41 | bool IsTerminate();
42 |
43 | private:
44 | std::atomic<int> m_state;
45 | std::mutex m_mutex;
46 | std::condition_variable m_cond;
47 |
48 | WaitGroup m_resume_wg;
49 | WaitGroup m_pause_wg;
50 |
51 | private:
52 | static const int k_pause = 0;
53 | static const int k_running = 1;
54 | static const int k_terminate = 2;
55 | };
56 |
--------------------------------------------------------------------------------
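ThreadConductor coordinates a pool of worker threads: Resume(n) releases n workers parked in Wait() and returns once they have all woken up, Pause() plus Join() stops them again (each worker acknowledges with AckPause()), and Terminate() releases parked workers so they can exit. The loop below is an inference from the implementation above, not code from the repository, and the include path is assumed:

    #include <iostream>
    #include <thread>
    #include <vector>

    #include "common/thread_conductor.h"

    int main() {
        const int num_threads = 4;
        ThreadConductor conductor;
        std::vector<std::thread> workers;

        for (int i = 0; i < num_threads; ++i) {
            workers.emplace_back([&conductor] {
                for (;;) {
                    conductor.Wait();                 // park until Resume() or Terminate()
                    if (conductor.IsTerminate()) break;
                    while (conductor.IsRunning()) {
                        // ... one unit of work would go here ...
                        std::this_thread::yield();
                    }
                    conductor.AckPause();             // acknowledge the Pause()
                }
            });
        }

        conductor.Resume(num_threads);  // returns once every worker has left Wait()
        conductor.Sleep(100 * 1000);    // sleep ~100 ms (returns early only if Pause() is called)
        conductor.Pause();
        conductor.Join();               // returns once every worker called AckPause()
        conductor.Terminate();          // unblocks workers parked in Wait() so they exit

        for (auto &w : workers) w.join();
        std::cout << "all workers stopped\n";
        return 0;
    }
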
/common/timer.cc:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #include "timer.h"
19 |
20 | using namespace std::chrono;
21 |
22 | Timer::Timer()
23 | : m_start(clock::now())
24 | {
25 | }
26 |
27 | void Timer::Reset()
28 | {
29 | m_start = clock::now();
30 | }
31 |
32 | int64_t Timer::sec() const
33 | {
34 | return duration_cast<seconds>(clock::now() - m_start).count();
35 | }
36 |
37 | int64_t Timer::ms() const
38 | {
39 | return duration_cast<milliseconds>(clock::now() - m_start).count();
40 | }
41 |
42 | int64_t Timer::us() const
43 | {
44 | return duration_cast<microseconds>(clock::now() - m_start).count();
45 | }
46 |
47 | float Timer::fsec() const
48 | {
49 | return std::chrono::duration<float>(clock::now() - m_start).count();
50 | }
51 |
52 | float Timer::fms() const
53 | {
54 | return std::chrono::duration<float, std::milli>(clock::now() - m_start).count();
55 | }
56 |
57 | float Timer::fus() const
58 | {
59 | return std::chrono::duration<float, std::micro>(clock::now() - m_start).count();
60 | }
61 |
--------------------------------------------------------------------------------
/common/timer.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #pragma once
19 |
20 | #include <chrono>
21 |
22 | class Timer
23 | {
24 | public:
25 | Timer();
26 | void Reset();
27 | int64_t sec() const;
28 | int64_t ms() const;
29 | int64_t us() const;
30 | float fsec() const;
31 | float fms() const;
32 | float fus() const;
33 |
34 | private:
35 | typedef std::chrono::system_clock clock;
36 | clock::time_point m_start;
37 | };
38 |
--------------------------------------------------------------------------------
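Timer measures wall-clock time since construction or the last Reset(), in integer or fractional units. For example (include path assumed):

    #include <chrono>
    #include <iostream>
    #include <thread>

    #include "common/timer.h"

    int main() {
        Timer timer;
        std::this_thread::sleep_for(std::chrono::milliseconds(25));
        std::cout << "elapsed: " << timer.ms() << " ms (" << timer.fsec() << " s)\n";

        timer.Reset();  // restart the measurement
        std::cout << "after reset: " << timer.us() << " us\n";
        return 0;
    }
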
/common/wait_group.cc:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #include "wait_group.h"
19 |
20 | void WaitGroup::Add(int v)
21 | {
22 | bool notify;
23 | {
24 | std::lock_guard<std::mutex> lock(m_mutex);
25 | m_counter += v;
26 | if (m_counter < 0) {
27 | throw std::runtime_error("WaitGroup::Add(): m_counter < 0");
28 | }
29 | notify = (m_counter == 0);
30 | }
31 | if (notify) {
32 | m_cond.notify_all();
33 | }
34 | }
35 |
36 | void WaitGroup::Done()
37 | {
38 | Add(-1);
39 | }
40 |
41 | bool WaitGroup::Wait(int64_t timeout_us)
42 | {
43 | std::unique_lock<std::mutex> lock(m_mutex);
44 | if (timeout_us < 0) {
45 | m_cond.wait(lock, [this]{ return m_counter == 0; });
46 | return true;
47 | } else {
48 | return m_cond.wait_for(lock, std::chrono::microseconds(timeout_us), [this]{ return m_counter == 0; });
49 | }
50 | }
51 |
--------------------------------------------------------------------------------
/common/wait_group.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #pragma once
19 |
20 | #include <condition_variable>
21 | #include <mutex>
22 |
23 | class WaitGroup
24 | {
25 | public:
26 | void Add(int v = 1);
27 | void Done();
28 | bool Wait(int64_t timeout_us = -1);
29 |
30 | private:
31 | int m_counter = 0;
32 | std::mutex m_mutex;
33 | std::condition_variable m_cond;
34 | };
35 |
--------------------------------------------------------------------------------
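WaitGroup mirrors Go's sync.WaitGroup: Add() the amount of work before starting it, Done() as each piece finishes, and Wait() until the counter drops back to zero (optionally with a timeout in microseconds, in which case it returns false on expiry). A minimal sketch (include path assumed):

    #include <iostream>
    #include <thread>
    #include <vector>

    #include "common/wait_group.h"

    int main() {
        WaitGroup wg;
        std::vector<std::thread> threads;

        wg.Add(3);  // count the work before spawning it
        for (int i = 0; i < 3; ++i) {
            threads.emplace_back([&wg, i] {
                std::cout << "worker " << i << " done\n";
                wg.Done();  // decrement; waiters are notified when the counter hits zero
            });
        }

        wg.Wait();  // blocks until all three workers called Done()
        // wg.Wait(1000) would instead give up after 1000 us and return false.

        for (auto &t : threads) t.join();
        return 0;
    }
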
/configure:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -e
4 | set -o pipefail
5 |
6 | if [ -z "$PYTHON_BIN_PATH" ]; then
7 | PYTHON_BIN_PATH=$(which python || which python3 || true)
8 | fi
9 |
10 | # Set all env variables
11 | CONFIGURE_DIR=$(dirname "$0")
12 | "$PYTHON_BIN_PATH" "${CONFIGURE_DIR}/configure.py" "$@"
13 |
14 | echo "Configuration finished"
15 |
16 |
--------------------------------------------------------------------------------
/dist/BUILD:
--------------------------------------------------------------------------------
1 | load("//:rules.bzl", "cc_proto_library", "tf_cc_binary")
2 |
3 | tf_cc_binary(
4 | name = "dist_zero_model_server",
5 | srcs = [
6 | "dist_zero_model_server.cc",
7 | ],
8 | deps = [
9 | ":dist_zero_model_cc_proto",
10 | "//common:timer",
11 | "//model:zero_model",
12 | "//model:trt_zero_model",
13 | "@com_github_google_glog//:glog",
14 | ],
15 | )
16 |
17 | cc_library(
18 | name = "dist_zero_model_client",
19 | srcs = ["dist_zero_model_client.cc"],
20 | hdrs = ["dist_zero_model_client.h"],
21 | deps = [
22 | ":dist_config_cc_proto",
23 | ":dist_zero_model_cc_proto",
24 | ":leaky_bucket",
25 | "//model:zero_model_base",
26 | "@com_github_google_glog//:glog",
27 | ],
28 | visibility = ["//visibility:public"],
29 | )
30 |
31 | cc_library(
32 | name = "async_dist_zero_model_client",
33 | srcs = ["async_dist_zero_model_client.cc"],
34 | hdrs = ["async_dist_zero_model_client.h"],
35 | deps = [
36 | ":dist_config_cc_proto",
37 | ":dist_zero_model_cc_proto",
38 | ":async_rpc_queue",
39 | ":leaky_bucket",
40 | "//model:zero_model_base",
41 | "@com_github_google_glog//:glog",
42 | ],
43 | visibility = ["//visibility:public"],
44 | )
45 |
46 | cc_proto_library(
47 | name = "dist_config_cc_proto",
48 | srcs = ["dist_config.proto"],
49 | visibility = ["//visibility:public"],
50 | )
51 |
52 | cc_proto_library(
53 | name = "dist_zero_model_cc_proto",
54 | srcs = ["dist_zero_model.proto"],
55 | deps = ["//model:model_config_cc_proto"],
56 | use_grpc_plugin = True,
57 | )
58 |
59 | cc_library(
60 | name = "async_rpc_queue",
61 | hdrs = ["async_rpc_queue.h"],
62 | deps = ["@grpc//:grpc++_unsecure"],
63 | )
64 |
65 | cc_library(
66 | name = "leaky_bucket",
67 | srcs = ["leaky_bucket.cc"],
68 | hdrs = ["leaky_bucket.h"],
69 | )
70 |
--------------------------------------------------------------------------------
/dist/async_dist_zero_model_client.cc:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #include "async_dist_zero_model_client.h"
19 |
20 | #include <future>
21 |
22 | #include <glog/logging.h>
23 |
24 | AsyncDistZeroModelClient::AsyncDistZeroModelClient(const std::vector<std::string> &svr_addrs,
25 | const DistConfig &dist_config)
26 | : m_config(dist_config),
27 | m_svr_addrs(svr_addrs),
28 | m_bucket_mutexes(new std::mutex[svr_addrs.size()])
29 | {
30 | CHECK(!m_svr_addrs.empty());
31 | for (size_t i = 0; i < m_svr_addrs.size(); ++i) {
32 | m_stubs.emplace_back(DistZeroModel::NewStub(
33 | grpc::CreateChannel(m_svr_addrs[i], grpc::InsecureChannelCredentials())));
34 | m_avail_stubs.push(i);
35 | m_leaky_buckets.emplace_back(m_config.leaky_bucket_size(), m_config.leaky_bucket_refill_period_ms());
36 | }
37 | m_forward_rpc_complete_thread = std::thread(&AsyncRpcQueue::Complete, &m_forward_rpc_queue, -1);
38 | }
39 |
40 | AsyncDistZeroModelClient::~AsyncDistZeroModelClient()
41 | {
42 | m_forward_rpc_queue.Shutdown();
43 | LOG(INFO) << "~AsyncDistZeroModelClient waiting async rpc complete thread stop";
44 | m_forward_rpc_complete_thread.join();
45 | LOG(INFO) << "~AsyncDistZeroModelClient waiting all stubs released";
46 | std::unique_lock<std::mutex> lock(m_mutex);
47 | m_cond.wait(lock, [this]{ return m_avail_stubs.size() == m_stubs.size(); });
48 | LOG(INFO) << "~AsyncDistZeroModelClient succ";
49 | }
50 |
51 | int AsyncDistZeroModelClient::Init(const ModelConfig &model_config)
52 | {
53 | AsyncRpcQueue queue;
54 | InitReq req;
55 | req.mutable_model_config()->CopyFrom(model_config);
56 | int ret = 0;
57 | for (size_t i = 0; i < m_stubs.size(); ++i) {
58 | queue.Call<InitReq, InitResp>(
59 | BindAsyncRpcFunc(*m_stubs[i], AsyncInit), req,
60 | [this, i, &ret](grpc::Status &status, InitResp &resp) {
61 | if (!status.ok()) {
62 | LOG(ERROR) << "DistZeroModel::Init error, " << m_svr_addrs[i] << ", ret "
63 | << status.error_code() << ": " << status.error_message();
64 | ret = status.error_code();
65 | }
66 | }
67 | );
68 | }
69 | queue.Complete(m_stubs.size());
70 | return ret;
71 | }
72 |
73 | int AsyncDistZeroModelClient::GetGlobalStep(int &global_step)
74 | {
75 | AsyncRpcQueue queue;
76 | GetGlobalStepReq req;
77 | int ret = 0;
78 | std::vector<int> global_steps(m_stubs.size());
79 | for (size_t i = 0; i < m_stubs.size(); ++i) {
80 | queue.Call<GetGlobalStepReq, GetGlobalStepResp>(
81 | BindAsyncRpcFunc(*m_stubs[i], AsyncGetGlobalStep), req,
82 | [this, i, &ret, &global_steps](grpc::Status &status, GetGlobalStepResp &resp) {
83 | if (status.ok()) {
84 | global_steps[i] = resp.global_step();
85 | } else {
86 | LOG(ERROR) << "DistZeroModel::GetGlobalStep error, " << m_svr_addrs[i] << ", ret "
87 | << status.error_code() << ": " << status.error_message();
88 | ret = status.error_code();
89 | }
90 | }
91 | );
92 | }
93 | queue.Complete(m_stubs.size());
94 | if (ret == 0) {
95 | for (size_t i = 1; i < m_stubs.size(); ++i) {
96 | if (global_steps[i] != global_steps[0]) {
97 | LOG(ERROR) << "Received different global_step, "
98 | << global_steps[i] << "(" << m_svr_addrs[i] << ")" << " vs "
99 | << global_steps[0] << "(" << m_svr_addrs[0] << ")";
100 | ret = ERR_GLOBAL_STEP_CONFLICT;
101 | }
102 | }
103 | }
104 | if (ret == 0) {
105 | global_step = global_steps[0];
106 | }
107 | return ret;
108 | }
109 |
110 | void AsyncDistZeroModelClient::Forward(const std::vector<std::vector<bool>> &inputs, callback_t callback)
111 | {
112 | int stub_id = GetStub();
113 | ForwardReq req;
114 | for (const auto &features: inputs) {
115 | if (features.size() != INPUT_DIM) {
116 | LOG(ERROR) << "Error input dim not match, need " << INPUT_DIM << ", got " << features.size();
117 | callback(ERR_INVALID_INPUT, {}, {});
118 | return;
119 | }
120 | std::string encode_features((INPUT_DIM + 7) / 8, 0);
121 | for (int i = 0; i < INPUT_DIM; ++i) {
122 | encode_features[i / 8] |= (unsigned char)features[i] << (i % 8);
123 | }
124 | req.add_inputs(encode_features);
125 | }
126 | m_forward_rpc_queue.Call<ForwardReq, ForwardResp>(
127 | BindAsyncRpcFunc(*m_stubs[stub_id], AsyncForward), req,
128 | [this, stub_id, callback](grpc::Status &status, ForwardResp &resp) {
129 | if (status.ok() && resp.outputs_size() == 0) {
130 | status = grpc::Status(grpc::StatusCode(ERR_EMPTY_RESP), "receive empty response");
131 | }
132 |
133 | bool release_now = true;
134 | if (!status.ok()) {
135 | m_stubs[stub_id] = DistZeroModel::NewStub(
136 | grpc::CreateChannel(m_svr_addrs[stub_id], grpc::InsecureChannelCredentials()));
137 | if (m_config.enable_leaky_bucket()) {
138 | std::lock_guard<std::mutex> lock(m_bucket_mutexes[stub_id]);
139 | m_leaky_buckets[stub_id].ConsumeToken();
140 | release_now = !m_leaky_buckets[stub_id].Empty();
141 | }
142 | }
143 | if (release_now) {
144 | ReleaseStub(stub_id);
145 | } else {
146 | DisableStub(stub_id);
147 | }
148 |
149 | if (status.ok()) {
150 | std::vector<std::vector<float>> policy;
151 | std::vector<float> value;
152 | for (auto &output: resp.outputs()) {
153 | policy.emplace_back(output.policy().begin(), output.policy().end());
154 | value.push_back(output.value());
155 | }
156 | callback(0, std::move(policy), std::move(value));
157 | } else if (status.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) {
158 | LOG(ERROR) << "DistZeroModel::Forward timeout, " << m_svr_addrs[stub_id];
159 | callback(ERR_FORWARD_TIMEOUT, {}, {});
160 | } else {
161 | LOG(ERROR) << "DistZeroModel::Forward error, " << m_svr_addrs[stub_id] << " "
162 | << status.error_code() << ": " << status.error_message();
163 | callback(status.error_code(), {}, {});
164 | }
165 | },
166 | m_config.timeout_ms()
167 | );
168 | }
169 |
170 | int AsyncDistZeroModelClient::Forward(const std::vector<std::vector<bool>> &inputs,
171 | std::vector<std::vector<float>> &policy, std::vector<float> &value)
172 | {
173 | std::promise<std::tuple<int, std::vector<std::vector<float>>, std::vector<float>>> promise;
174 | Forward(inputs, [&promise](int ret, std::vector<std::vector<float>> policy, std::vector<float> value) {
175 | promise.set_value(std::make_tuple(ret, std::move(policy), std::move(value)));
176 | });
177 | int ret;
178 | std::tie(ret, policy, value) = promise.get_future().get();
179 | return ret;
180 | }
181 |
182 | int AsyncDistZeroModelClient::RpcQueueSize()
183 | {
184 | return m_forward_rpc_queue.Size();
185 | }
186 |
187 | void AsyncDistZeroModelClient::Wait()
188 | {
189 | std::unique_lock<std::mutex> lock(m_mutex);
190 | m_cond.wait(lock, [this]{ return !m_avail_stubs.empty(); });
191 | }
192 |
193 | int AsyncDistZeroModelClient::GetStub()
194 | {
195 | std::unique_lock<std::mutex> lock(m_mutex);
196 | m_cond.wait(lock, [this]{ return !m_avail_stubs.empty(); });
197 | int stub_id = m_avail_stubs.top();
198 | m_avail_stubs.pop();
199 | return stub_id;
200 | }
201 |
202 | void AsyncDistZeroModelClient::ReleaseStub(int stub_id)
203 | {
204 | {
205 | std::lock_guard<std::mutex> lock(m_mutex);
206 | m_avail_stubs.push(stub_id);
207 | }
208 | m_cond.notify_one();
209 | }
210 |
211 | void AsyncDistZeroModelClient::DisableStub(int stub_id)
212 | {
213 | LOG(ERROR) << "disable DistZeroModel " << m_svr_addrs[stub_id];
214 | std::thread([this, stub_id]() {
215 | {
216 | std::lock_guard<std::mutex> lock(m_bucket_mutexes[stub_id]);
217 | m_leaky_buckets[stub_id].WaitRefill();
218 | }
219 | ReleaseStub(stub_id);
220 | LOG(INFO) << "reenable DistZeroModel " << m_svr_addrs[stub_id];
221 | }).detach();
222 | }
223 |
--------------------------------------------------------------------------------
/dist/async_dist_zero_model_client.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #pragma once
19 |
20 | #include <condition_variable>
21 | #include <memory>
22 | #include <mutex>
23 | #include <stack>
24 | #include <thread>
25 |
26 | #include "model/zero_model_base.h"
27 |
28 | #include "dist/async_rpc_queue.h"
29 | #include "dist/leaky_bucket.h"
30 | #include "dist/dist_config.pb.h"
31 | #include "dist/dist_zero_model.grpc.pb.h"
32 |
33 | class AsyncDistZeroModelClient final : public ZeroModelBase
34 | {
35 | public:
36 | AsyncDistZeroModelClient(const std::vector<std::string> &svr_addrs, const DistConfig &dist_config);
37 |
38 | ~AsyncDistZeroModelClient() override;
39 |
40 | int Init(const ModelConfig &model_config) override;
41 |
42 | int Forward(const std::vector<std::vector<bool>>& inputs,
43 | std::vector<std::vector<float>> &policy, std::vector<float> &value) override;
44 |
45 | void Forward(const std::vector<std::vector<bool>> &inputs, callback_t callback) override;
46 |
47 | int GetGlobalStep(int &global_step) override;
48 |
49 | int RpcQueueSize() override;
50 |
51 | void Wait() override;
52 |
53 | private:
54 | int GetStub();
55 |
56 | void ReleaseStub(int stub_id);
57 |
58 | void DisableStub(int stub_id);
59 |
60 | private:
61 | DistConfig m_config;
62 | std::vector<std::string> m_svr_addrs;
63 | std::vector<std::unique_ptr<DistZeroModel::Stub>> m_stubs;
64 | AsyncRpcQueue m_forward_rpc_queue;
65 | std::thread m_forward_rpc_complete_thread;
66 |
67 | std::stack<int> m_avail_stubs;
68 | std::mutex m_mutex;
69 | std::condition_variable m_cond;
70 |
71 | std::vector<LeakyBucket> m_leaky_buckets;
72 | std::unique_ptr<std::mutex[]> m_bucket_mutexes;
73 | };
74 |
--------------------------------------------------------------------------------
/dist/async_rpc_queue.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #include <atomic>
19 | #include <functional>
20 |
21 | #include <grpc++/grpc++.h>
22 |
23 | template <class Req, class Resp>
24 | using AsyncRpcFunc = std::function<std::unique_ptr<grpc::ClientAsyncResponseReader<Resp>>
25 | (grpc::ClientContext*, const Req&, grpc::CompletionQueue*)>;
26 |
27 | #define BindAsyncRpcFunc(stub, func) \
std::bind(&std::remove_reference<decltype(stub)>::type::func, stub, \
29 | std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)
30 |
31 | template <class Resp>
32 | using AsyncRpcCallback = std::function<void(grpc::Status&, Resp&)>;
33 |
34 | struct AsyncClientCallBase
35 | {
36 | virtual ~AsyncClientCallBase() {}
37 | virtual void Complete() = 0;
38 | };
39 |
40 | template <class Resp>
41 | struct AsyncClientCall: public AsyncClientCallBase
42 | {
43 | Resp resp;
44 | grpc::ClientContext context;
45 | grpc::Status status;
46 | std::unique_ptr<grpc::ClientAsyncResponseReader<Resp>> response_reader;
47 | AsyncRpcCallback<Resp> callback;
48 |
49 | void Complete() override { callback(status, resp); }
50 | };
51 |
52 | class AsyncRpcQueue
53 | {
54 |
55 | public:
56 | AsyncRpcQueue()
57 | : m_size(0), m_is_shutdown(false)
58 | {
59 | }
60 |
61 | template <class Req, class Resp>
62 | void Call(AsyncRpcFunc<Req, Resp> fn, const Req &req, AsyncRpcCallback<Resp> callback, int timeout_ms = -1)
63 | {
64 | auto call = new AsyncClientCall<Resp>;
65 | call->callback = callback;
66 | if (timeout_ms > 0) {
67 | call->context.set_deadline(std::chrono::system_clock::now() + std::chrono::milliseconds(timeout_ms));
68 | }
69 | call->response_reader = fn(&call->context, req, &m_cq);
70 | call->response_reader->Finish(&call->resp, &call->status, (void*)call);
71 |
72 | ++m_size;
73 | }
74 |
75 | void Complete(int n = -1)
76 | {
77 | void *got_tag;
78 | bool ok = false;
79 | for (int i = 0; (n < 0 || i < n) && !m_is_shutdown && m_cq.Next(&got_tag, &ok); ++i) {
80 | std::unique_ptr<AsyncClientCallBase> call(static_cast<AsyncClientCallBase*>(got_tag));
81 | --m_size;
82 | call->Complete();
83 | }
84 | }
85 |
86 | void Shutdown()
87 | {
88 | m_is_shutdown = true;
89 | m_cq.Shutdown();
90 | }
91 |
92 | int Size()
93 | {
94 | return m_size;
95 | }
96 |
97 | private:
98 | grpc::CompletionQueue m_cq;
99 | std::atomic<int> m_size;
100 | std::atomic<bool> m_is_shutdown;
101 | };
102 |
--------------------------------------------------------------------------------
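The queue pairs each outgoing asynchronous call with a callback and runs the callbacks inside whichever thread calls Complete(); the clients above use it both one-shot (Complete(n) after issuing n calls) and with a dedicated completion thread. A trimmed-down sketch against the DistZeroModel stub, mirroring the usage in async_dist_zero_model_client.cc; the address, timeout, and explicit template arguments are illustrative assumptions:

    #include <iostream>

    #include <grpc++/grpc++.h>

    #include "dist/async_rpc_queue.h"
    #include "dist/dist_zero_model.grpc.pb.h"

    int main() {
        auto stub = DistZeroModel::NewStub(
            grpc::CreateChannel("127.0.0.1:5000", grpc::InsecureChannelCredentials()));

        AsyncRpcQueue queue;
        GetGlobalStepReq req;

        // Issue the RPC; the lambda runs later, inside Complete(), on this thread.
        queue.Call<GetGlobalStepReq, GetGlobalStepResp>(
            BindAsyncRpcFunc(*stub, AsyncGetGlobalStep), req,
            [](grpc::Status &status, GetGlobalStepResp &resp) {
                if (status.ok()) {
                    std::cout << "global_step = " << resp.global_step() << "\n";
                } else {
                    std::cout << "rpc failed: " << status.error_message() << "\n";
                }
            },
            /*timeout_ms=*/3000);

        queue.Complete(1);  // block until the one outstanding call finishes
        queue.Shutdown();
        return 0;
    }
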
/dist/dist_config.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | message DistConfig {
4 | int32 timeout_ms = 1;
5 | bool enable_leaky_bucket = 2;
6 | int32 leaky_bucket_size = 3;
7 | int32 leaky_bucket_refill_period_ms = 4;
8 | }
9 |
--------------------------------------------------------------------------------
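DistConfig is embedded in the engine's text-format config files under etc/ (e.g. mcts_dist.conf and mcts_async_dist.conf). The sketch below builds one programmatically and prints the equivalent text form; the field values are illustrative, not taken from the repository:

    #include <iostream>
    #include <string>

    #include <google/protobuf/text_format.h>

    #include "dist/dist_config.pb.h"

    int main() {
        DistConfig config;
        config.set_timeout_ms(3000);                      // per-Forward RPC deadline
        config.set_enable_leaky_bucket(true);             // temporarily drop servers that keep failing
        config.set_leaky_bucket_size(10);                 // failures tolerated before a server is disabled
        config.set_leaky_bucket_refill_period_ms(60000);  // how long before a drained bucket refills

        std::string text;
        google::protobuf::TextFormat::PrintToString(config, &text);
        std::cout << text;  // same text format as the .conf files
        return 0;
    }
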
/dist/dist_zero_model.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | import "model/model_config.proto";
4 |
5 | message ModelOutput {
6 | repeated float policy = 1;
7 | float value = 2;
8 | }
9 |
10 | message InitReq { ModelConfig model_config = 1; }
11 | message InitResp {}
12 |
13 | message GetGlobalStepReq {}
14 | message GetGlobalStepResp { int32 global_step = 1; }
15 |
16 | message ForwardReq { repeated bytes inputs = 1; }
17 | message ForwardResp { repeated ModelOutput outputs = 1; }
18 |
19 | service DistZeroModel {
20 | rpc Init (InitReq) returns (InitResp) {}
21 | rpc GetGlobalStep (GetGlobalStepReq) returns (GetGlobalStepResp) {}
22 | rpc Forward (ForwardReq) returns (ForwardResp) {}
23 | }
24 |
--------------------------------------------------------------------------------
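ForwardReq.inputs carries one bytes entry per position, with the binary feature planes packed one bit per feature, least-significant bit first within each byte; both clients in this directory encode it that way. A standalone sketch of that packing (EncodeFeatures is a hypothetical helper; the real clients inline the loop and use INPUT_DIM from model/zero_model_base.h):

    #include <cstddef>
    #include <iostream>
    #include <string>
    #include <vector>

    // Feature i goes into bit (i % 8) of byte (i / 8), matching the client code.
    std::string EncodeFeatures(const std::vector<bool> &features) {
        std::string encoded((features.size() + 7) / 8, 0);
        for (size_t i = 0; i < features.size(); ++i) {
            encoded[i / 8] |= (unsigned char)features[i] << (i % 8);
        }
        return encoded;
    }

    int main() {
        std::vector<bool> features = {1, 0, 0, 0, 0, 0, 0, 0, 1};  // 9 features -> 2 bytes
        std::string bytes = EncodeFeatures(features);
        std::cout << bytes.size() << " bytes, first byte = "
                  << (int)(unsigned char)bytes[0] << "\n";  // prints "2 bytes, first byte = 1"
        return 0;
    }
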
/dist/dist_zero_model_client.cc:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #include "dist_zero_model_client.h"
19 |
20 | #include <chrono>
21 |
22 | #include <glog/logging.h>
23 | #include <grpc++/grpc++.h>
24 |
25 | DistZeroModelClient::DistZeroModelClient(const std::string &server_address, const DistConfig &dist_config)
26 | : m_config(dist_config),
27 | m_server_address(server_address),
28 | m_stub(DistZeroModel::NewStub(grpc::CreateChannel(server_address, grpc::InsecureChannelCredentials()))),
29 | m_leaky_bucket(dist_config.leaky_bucket_size(), dist_config.leaky_bucket_refill_period_ms())
30 | {
31 | }
32 |
33 | int DistZeroModelClient::Init(const ModelConfig &model_config)
34 | {
35 | InitReq req;
36 | InitResp resp;
37 |
38 | req.mutable_model_config()->CopyFrom(model_config);
39 |
40 | grpc::ClientContext context;
41 | grpc::Status status = m_stub->Init(&context, req, &resp);
42 |
43 | if (status.ok()) {
44 | return 0;
45 | } else {
46 | LOG(ERROR) << "DistZeroModel::Init error, " << m_server_address << ", ret "
47 | << status.error_code() << ": " << status.error_message();
48 | return status.error_code();
49 | }
50 | }
51 |
52 | int DistZeroModelClient::GetGlobalStep(int &global_step)
53 | {
54 | GetGlobalStepReq req;
55 | GetGlobalStepResp resp;
56 |
57 | grpc::ClientContext context;
58 | grpc::Status status = m_stub->GetGlobalStep(&context, req, &resp);
59 |
60 | if (status.ok()) {
61 | global_step = resp.global_step();
62 | return 0;
63 | } else {
64 | LOG(ERROR) << "DistZeroModel::GetGlobalStep error, " << m_server_address << ", ret "
65 | << status.error_code() << ": " << status.error_message();
66 | return status.error_code();
67 | }
68 | }
69 |
70 | int DistZeroModelClient::Forward(const std::vector<std::vector<bool>>& inputs,
71 |                                  std::vector<std::vector<float>> &policy, std::vector<float> &value)
72 | {
73 | ForwardReq req;
74 | ForwardResp resp;
75 |
76 | for (const auto &features: inputs) {
77 | if (features.size() != INPUT_DIM) {
78 | LOG(ERROR) << "Error input dim not match, need " << INPUT_DIM << ", got " << features.size();
79 | return ERR_INVALID_INPUT;
80 | }
81 | std::string encode_features((INPUT_DIM + 7) / 8, 0);
82 | for (int i = 0; i < INPUT_DIM; ++i) {
83 | encode_features[i / 8] |= (unsigned char)features[i] << (i % 8);
84 | }
85 | req.add_inputs(encode_features);
86 | }
87 |
88 | grpc::ClientContext context;
89 | if (m_config.timeout_ms() > 0) {
90 | context.set_deadline(std::chrono::system_clock::now() + std::chrono::milliseconds(m_config.timeout_ms()));
91 | }
92 | grpc::Status status = m_stub->Forward(&context, req, &resp);
93 |
94 | if (!status.ok()) {
95 | m_stub = DistZeroModel::NewStub(grpc::CreateChannel(m_server_address, grpc::InsecureChannelCredentials()));
96 | if (m_config.enable_leaky_bucket()) {
97 | m_leaky_bucket.ConsumeToken();
98 | }
99 | }
100 |
101 | if (status.ok()) {
102 | policy.clear();
103 | value.clear();
104 | for (auto &output: resp.outputs()) {
105 | policy.emplace_back(output.policy().begin(), output.policy().end());
106 | value.push_back(output.value());
107 | }
108 | return 0;
109 | } else if (status.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) {
110 | LOG(ERROR) << "DistZeroModel::Forward timeout, " << m_server_address;
111 | return ERR_FORWARD_TIMEOUT;
112 | } else {
113 | LOG(ERROR) << "DistZeroModel::Forward error, " << m_server_address << " "
114 | << status.error_code() << ": " << status.error_message();
115 | return status.error_code();
116 | }
117 | }
118 |
119 | void DistZeroModelClient::Wait()
120 | {
121 | if (m_config.enable_leaky_bucket() && m_leaky_bucket.Empty()) {
122 | LOG(ERROR) << "disable DistZeroModel " << m_server_address;
123 | m_leaky_bucket.WaitRefill();
124 | LOG(INFO) << "reenable DistZeroModel " << m_server_address;
125 | }
126 | }
127 |
--------------------------------------------------------------------------------
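
Note: the bit-packing used by `Forward` above (and unpacked symmetrically in `dist_zero_model_server.cc` below) is an LSB-first bitmap, one bool per bit. A standalone sketch of the round trip; the helper names are illustrative, not from the repository:

```
#include <string>
#include <vector>

// Pack bools 8-per-byte: bit i goes to byte i/8, bit position i%8
// (same layout as DistZeroModelClient::Forward above).
std::string PackFeatures(const std::vector<bool> &features)
{
    std::string encoded((features.size() + 7) / 8, 0);
    for (size_t i = 0; i < features.size(); ++i) {
        encoded[i / 8] |= (unsigned char)features[i] << (i % 8);
    }
    return encoded;
}

// Inverse transform (same layout as DistZeroModelServiceImpl::Forward below).
std::vector<bool> UnpackFeatures(const std::string &encoded, size_t dim)
{
    std::vector<bool> features(dim);
    for (size_t i = 0; i < dim; ++i) {
        features[i] = (unsigned char)encoded[i / 8] >> (i % 8) & 1;
    }
    return features;
}
```
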
/dist/dist_zero_model_client.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #pragma once
19 |
20 | #include <memory>
21 |
22 | #include "model/zero_model_base.h"
23 |
24 | #include "dist/leaky_bucket.h"
25 | #include "dist/dist_config.pb.h"
26 | #include "dist/dist_zero_model.grpc.pb.h"
27 |
28 | class DistZeroModelClient final : public ZeroModelBase
29 | {
30 | public:
31 |     DistZeroModelClient(const std::string &server_address, const DistConfig &dist_config);
32 |
33 | int Init(const ModelConfig &model_config) override;
34 |
35 | int Forward(const std::vector<std::vector<bool>>& inputs,
36 |             std::vector<std::vector<float>> &policy, std::vector<float> &value) override;
37 |
38 | int GetGlobalStep(int &global_step) override;
39 |
40 | void Wait() override;
41 |
42 | private:
43 | DistConfig m_config;
44 | std::string m_server_address;
45 | std::unique_ptr<DistZeroModel::Stub> m_stub;
46 | LeakyBucket m_leaky_bucket;
47 | };
48 |
--------------------------------------------------------------------------------
/dist/dist_zero_model_server.cc:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #include <gflags/gflags.h>
19 | #include <glog/logging.h>
20 | #include <grpc++/grpc++.h>
21 |
22 | #include "common/timer.h"
23 | #include "model/zero_model.h"
24 | #include "model/trt_zero_model.h"
25 |
26 | #include "dist/dist_zero_model.grpc.pb.h"
27 |
28 | DEFINE_string(server_address, "", "Server address.");
29 | DEFINE_int32(gpu, 0, "Use which gpu.");
30 |
31 | class DistZeroModelServiceImpl final : public DistZeroModel::Service
32 | {
33 | public:
34 | grpc::Status Init(grpc::ServerContext *context, const InitReq *req, InitResp *resp) override
35 | {
36 | std::lock_guard<std::mutex> lock(m_mutex);
37 |
38 | LOG(INFO) << "Init with config: " << req->model_config().DebugString();
39 |
40 | if (req->model_config().enable_mkl()) {
41 | ZeroModel::SetMKLEnv(req->model_config());
42 | }
43 |
44 | m_model.reset(new ZeroModel(FLAGS_gpu));
45 | if (req->model_config().enable_tensorrt()) {
46 | m_model.reset(new TrtZeroModel(FLAGS_gpu));
47 | }
48 |
49 | int ret = m_model->Init(req->model_config());
50 | if (ret == 0) {
51 | LOG(INFO) << "Init model succ";
52 | return grpc::Status::OK;
53 | } else {
54 | LOG(ERROR) << "Init model error: " << ret;
55 | return grpc::Status(grpc::StatusCode(ret), "Init model error");
56 | }
57 | }
58 |
59 | grpc::Status GetGlobalStep(grpc::ServerContext *context,
60 | const GetGlobalStepReq *req, GetGlobalStepResp *resp) override
61 | {
62 | std::lock_guard<std::mutex> lock(m_mutex);
63 |
64 | if (m_model == nullptr) {
65 | return grpc::Status(grpc::StatusCode(-1), "DistZeroModel hasn't init");
66 | }
67 |
68 | int global_step;
69 | int ret = m_model->GetGlobalStep(global_step);
70 |
71 | if (ret == 0) {
72 | LOG(INFO) << "Get global_step=" << global_step;
73 | resp->set_global_step(global_step);
74 | return grpc::Status::OK;
75 | } else {
76 | LOG(INFO) << "Get global_step error: " << ret;
77 | return grpc::Status(grpc::StatusCode(ret), "Get global_step error");
78 | }
79 | }
80 |
81 | grpc::Status Forward(grpc::ServerContext *context, const ForwardReq *req, ForwardResp *resp) override
82 | {
83 | Timer timer;
84 | std::lock_guard<std::mutex> lock(m_mutex);
85 |
86 | if (m_model == nullptr) {
87 | return grpc::Status(grpc::StatusCode(-1), "DistZeroModel hasn't init");
88 | }
89 |
90 | std::vector<std::vector<bool>> inputs;
91 | for (const auto &encode_features: req->inputs()) {
92 | if ((int)encode_features.size() * 8 < m_model->INPUT_DIM) {
93 | LOG(ERROR) << "Error input features need " << m_model->INPUT_DIM << " bits, recv only "
94 | << encode_features.size() * 8;
95 | return grpc::Status(grpc::StatusCode(ERR_INVALID_INPUT), "Forward error");
96 | }
97 | std::vector<bool> features(m_model->INPUT_DIM);
98 | for (int i = 0; i < m_model->INPUT_DIM; ++i) {
99 | features[i] = (unsigned char)encode_features[i / 8] >> (i % 8) & 1;
100 | }
101 | inputs.push_back(std::move(features));
102 | }
103 |
104 | std::vector<std::vector<float>> policy;
105 | std::vector<float> value;
106 | int ret = m_model->Forward(inputs, policy, value);
107 |
108 | if (ret == 0) {
109 | for (size_t i = 0; i < policy.size(); ++i) {
110 | auto *output = resp->add_outputs();
111 | for (const auto &p: policy[i]) {
112 | output->add_policy(p);
113 | }
114 | output->set_value(value[i]);
115 | }
116 | if (resp->outputs_size() == 0) LOG(ERROR) << "input batch size " << inputs.size() << ", output 0!!!";
117 | LOG_EVERY_N(INFO, 1000) << "Forward succ.";
118 | return grpc::Status::OK;
119 | } else {
120 | LOG(ERROR) << "Forward error: " << ret;
121 | return grpc::Status(grpc::StatusCode(ret), "Forward error");
122 | }
123 | }
124 |
125 | private:
126 | std::unique_ptr<ZeroModelBase> m_model;
127 | std::mutex m_mutex;
128 | };
129 |
130 | int main(int argc, char *argv[])
131 | {
132 | google::ParseCommandLineFlags(&argc, &argv, true);
133 | google::InitGoogleLogging(argv[0]);
134 | google::InstallFailureSignalHandler();
135 |
136 | DistZeroModelServiceImpl service;
137 |
138 | grpc::ServerBuilder builder;
139 | builder.AddListeningPort(FLAGS_server_address, grpc::InsecureServerCredentials());
140 | builder.RegisterService(&service);
141 | std::unique_ptr server(builder.BuildAndStart());
142 | LOG(INFO) << "Server listening on " << FLAGS_server_address;
143 | server->Wait();
144 | }
145 |
--------------------------------------------------------------------------------
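
Note: a hedged end-to-end sketch, not part of the repository, of exercising the server above from the `DistZeroModelClient` shown earlier. It assumes the server was started with `--server_address=127.0.0.1:50000` and that the generated headers are available under the paths used below:

```
// Hypothetical smoke test; the port and the model_config.pb.h path are
// assumptions for illustration.
#include <glog/logging.h>

#include "dist/dist_config.pb.h"
#include "dist/dist_zero_model_client.h"
#include "model/model_config.pb.h"   // assumed location of the generated header

int main()
{
    DistConfig dist_config;
    dist_config.set_timeout_ms(3000);

    DistZeroModelClient client("127.0.0.1:50000", dist_config);

    ModelConfig model_config;
    model_config.set_train_dir("ckpt");   // field name as used in etc/*.conf

    CHECK_EQ(client.Init(model_config), 0);

    int global_step = 0;
    CHECK_EQ(client.GetGlobalStep(global_step), 0);
    LOG(INFO) << "global_step=" << global_step;
    return 0;
}
```
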
/dist/leaky_bucket.cc:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #include <thread>
19 |
20 | #include "leaky_bucket.h"
21 |
22 | LeakyBucket::LeakyBucket(int bucket_size, int refill_period_ms)
23 | : m_bucket_size(bucket_size), m_tokens(bucket_size),
24 | m_refill_period(std::chrono::milliseconds(refill_period_ms)),
25 | m_last_refill(clock::now())
26 | {
27 | }
28 |
29 | void LeakyBucket::ConsumeToken()
30 | {
31 | auto now = clock::now();
32 | if (now - m_last_refill > m_refill_period) {
33 | m_last_refill = now;
34 | m_tokens = m_bucket_size;
35 | }
36 | --m_tokens;
37 | }
38 |
39 | bool LeakyBucket::Empty()
40 | {
41 | return m_tokens <= 0;
42 | }
43 |
44 | void LeakyBucket::WaitRefill()
45 | {
46 | if (m_tokens <= 0) {
47 | std::this_thread::sleep_until(m_last_refill + m_refill_period);
48 | m_last_refill = clock::now();
49 | m_tokens = m_bucket_size;
50 | }
51 | }
52 |
--------------------------------------------------------------------------------
/dist/leaky_bucket.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #pragma once
19 |
20 | #include <chrono>
21 |
22 | class LeakyBucket {
23 | public:
24 | LeakyBucket(int bucket_size, int refill_period_ms);
25 | void ConsumeToken();
26 | bool Empty();
27 | void WaitRefill();
28 |
29 | private:
30 | int m_bucket_size;
31 | int m_tokens;
32 | typedef std::chrono::system_clock clock;
33 | clock::duration m_refill_period;
34 | clock::time_point m_last_refill;
35 | };
36 |
--------------------------------------------------------------------------------
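
Note: a minimal sketch of how `LeakyBucket` is meant to be driven (mirroring `DistZeroModelClient::Forward`/`Wait` above); the request loop and `DoOneRequest` are hypothetical:

```
#include "dist/leaky_bucket.h"

bool DoOneRequest(); // assumed to exist: performs one RPC, returns success

// Tolerate up to 10 failed requests per 60 s window; when the bucket is
// empty, WaitRefill() sleeps until the window rolls over.
void DriveWithBackoff()
{
    LeakyBucket bucket(/*bucket_size=*/10, /*refill_period_ms=*/60000);
    for (;;) {
        if (bucket.Empty()) {
            bucket.WaitRefill();
        }
        if (!DoOneRequest()) {
            bucket.ConsumeToken(); // spend a token only on failure
        }
    }
}
```
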
/docs/benchmark-teslaV100.md:
--------------------------------------------------------------------------------
1 | # Benchmark setup :
2 |
3 | ## setup
4 | - hardware : google cloud machine with Tesla V100-SXM2-16GB, tested with 4 and
5 | 12 vcpu (skylake server or later, with avx512 support; google cloud platform
6 | currently does not allow more than 12 vcpu per die, so the maximum for 1
7 | GPU is 12 vcpu = 6 physical cores / 12 cpu threads),
8 | 16gb system ram, 40 gb hdd
9 | - software : ubuntu 18.04 LTS, nvidia 410, cuda 10.0, cudnn 7.4.2,
10 | no tensorrt, bazel 0.11.1
11 | - engine settings: limited time 60 seconds per move, all other time
12 | management settings disabled in config file, all the rest is default
13 | settings
14 |
15 | ## methodology :
16 | - most moves come from the same game played using
17 | [gtp2ogs](https://github.com/online-go/gtp2ogs), for a few moves the
18 | stderr output was copy-pasted directly
19 | - tensorRT is not used with the V100 here, because that would require
20 | building our own tensorrt model, which was not done here, see
21 | [FAQ question](/docs/FAQ.md/#a13-i-have-a-nvidia-rtx-card-turing-or-tesla-v100titan-v-volta-is-it-compatible-)
22 | for details
23 |
24 | ## credits :
25 | - credit for running these tests goes to
26 | [wonderingabout](https://github.com/wonderingabout)
27 | - credit for providing the hardware goes to google cloud
28 | platform
29 |
30 | # BATCH SIZE 4
31 |
32 | - batch size 4
33 | - tensorrt : OFF
34 | - 8 threads
35 | - children : 64
36 | - 400M tree size
37 | - unlimited sims
38 | - 60s per move
39 |
40 | ### 4 vcpu (2 physical cores/ 4 cpu threads) :
41 |
42 | `
43 | stderr: 4th move(w): pp, winrate=56.125683%, N=20226, Q=0.122514, p=0.728064, v=0.114009, cost 60014.109375ms, sims=22976, height=46, avg_height=12.719517, global_step=639200
44 | `
45 |
46 | ### 12 vcpu (6 physical cores/ 12 cpu threads) :
47 |
48 | `
49 | stderr: 2th move(w): pd, winrate=56.061207%, N=6544, Q=0.121224, p=0.212260, v=0.106201, cost 60021.523438ms, sims=23096, height=30, avg_height=9.461098, global_step=639200
50 | `
51 |
52 | # BATCH SIZE 8
53 |
54 | - batch size 8
55 | - tensorrt : OFF
56 | - 16 threads
57 | - children : 96
58 | - 2000M tree size
59 | - unlimited sims
60 | - 60s per move
61 |
62 | ### 4 vcpu (2 physical cores/ 4 cpu threads) :
63 |
64 | `
65 | stderr: 8th move(w): nq, winrate=56.569016%, N=27010, Q=0.131380, p=0.591054, v=0.117058, cost 60036.335938ms, sims=32152, height=59, avg_height=14.057215, global_step=639200
66 | `
67 |
68 | ### 12 vcpu (6 physical cores/ 12 cpu threads) :
69 |
70 | `
71 | stderr: 4th move(w): pp, winrate=56.120705%, N=28938, Q=0.122414, p=0.715722, v=0.114932, cost 60016.148438ms, sims=32728, height=54, avg_height=13.301727, global_step=639200
72 | `
73 |
74 | # BATCH SIZE 16
75 |
76 | - batch size 16
77 | - tensorrt : OFF
78 | - 32 threads
79 | - children : 128
80 | - 2000M tree size
81 | - unlimited sims
82 | - 60s per move
83 |
84 | ### 4 vcpu (2 physical cores/ 4 cpu threads) :
85 |
86 | `
87 | stderr: 2th move(w): dp, winrate=56.057503%, N=15696, Q=0.121150, p=0.207913, v=0.105841, cost 60048.324219ms, sims=53568, height=34, avg_height=9.826570, global_step=639200
88 | `
89 |
90 | ### 12 vcpu (6 physical cores/ 12 cpu threads) :
91 |
92 | `
93 | stderr: 6th move(w): qn, winrate=56.170525%, N=29628, Q=0.123410, p=0.306212, v=0.111480, cost 60020.058594ms, sims=53968, height=71, avg_height=14.943110, global_step=639200
94 | `
95 |
96 | # BATCH SIZE 32
97 |
98 | - batch size 32
99 | - tensorrt : OFF
100 | - 64 threads
101 | - children : 128
102 | - 2000M tree size
103 | - unlimited sims
104 | - 60s per move
105 |
106 | ### 4 vcpu (2 physical cores/ 4 cpu threads) :
107 |
108 | `
109 | stderr: 10th move(w): cp, winrate=65.475777%, N=58821, Q=0.309516, p=0.886150, v=0.196629, cost 60111.078125ms, sims=59444, height=53, avg_height=10.118464, global_step=639200
110 | `
111 |
112 | ### 12 vcpu (6 physical cores/ 12 cpu threads) :
113 |
114 | `
115 | stderr: 8th move(w): qf, winrate=56.714546%, N=55717, Q=0.134291, p=0.613282, v=0.114160, cost 60048.957031ms, sims=62368, height=64, avg_height=13.618808, global_step=639200
116 | `
117 |
118 | # BATCH SIZE 64
119 |
120 | - batch size 64
121 | - tensorrt : OFF
122 | - 32 threads
123 | - children : 128
124 | - 2000M tree size
125 | - unlimited sims
126 | - 60s per move
127 |
128 | ### 4 vcpu (2 physical cores/ 4 cpu threads) :
129 |
130 | `
131 | stderr: 12th move(w): bo, winrate=65.815079%, N=64431, Q=0.316302, p=0.884180, v=0.226788, cost 60263.683594ms, sims=64960, height=54, avg_height=12.711725, global_step=639200
132 | `
133 |
134 | ### 12 vcpu (6 physical cores/ 12 cpu threads) :
135 |
136 | `
137 | stderr: 10th move(w): pc, winrate=65.360603%, N=67943, Q=0.307212, p=0.887373, v=0.175454, cost 60165.914062ms, sims=69031, height=63, avg_height=7.840148, global_step=639200
138 | `
139 |
140 | # BATCH SIZE 128
141 |
142 | - batch size 128
143 | - tensorrt : OFF
144 | - 256 threads
145 | - children : 128
146 | - 2000M tree size
147 | - unlimited sims
148 | - 60s per move
149 |
150 | ### 4 vcpu (2 physical cores/ 4 cpu threads) :
151 |
152 | `
153 | stderr: 16th move(w): rf, winrate=65.895859%, N=66225, Q=0.317917, p=0.937327, v=0.202470, cost 60232.250000ms, sims=67328, height=44, avg_height=10.560142, global_step=639200
154 | `
155 |
156 | ### 12 vcpu (6 physical cores/ 12 cpu threads) :
157 |
158 | `
159 | stderr: 12th move(w): ob, winrate=65.983253%, N=70697, Q=0.319665, p=0.920173, v=0.223816, cost 60312.035156ms, sims=71664, height=49, avg_height=6.786881, global_step=639200
160 | `
161 |
162 | # CONCLUSIONS for Tesla V100 :
163 |
164 | - all the conclusions below are without tensorRT optimization, which
165 | is known to bring 15-30% extra computation performance depending on
166 | hardware and settings :
167 |
168 | ```
169 | -> for batch size 4 to 8 , gain = +43%
170 | -> for batch size 8 to 16 , gain = +60%
171 | -> for batch size 16 to 32 , gain = +14%
172 | -> for batch size 32 to 64 , gain = +10%
173 | -> for batch size 64 to 128 , gain = +3.7%
174 | ```
175 |
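These gains follow from the `sims` counts above over the ~60 s move time: for example, at 12 vcpu, batch size 16 reaches 53968 sims per move, i.e. roughly 900 simulations per second, against 23096 sims (roughly 385 per second) for batch size 4, which is the ~135% increase quoted in the list below.
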
176 | - batch sizes 8 and 16 bring large speed increases on Tesla
177 | V100 with 6 cores / 12 cpu threads or less
178 | - batch sizes from 16 up to 64 bring only small speed increases
179 | on Tesla V100 with 6 cores / 12 cpu threads or less, and considering
180 | the loss of computing accuracy, this is not an efficient choice
181 | - therefore, the most efficient batch size seems to be 16, providing
182 | **an average 900 simulations per second on Tesla V100**, and a 135%
183 | speed increase as compared to batch size 4
184 | - batch sizes higher than 64 do not bring significant speed increases on
185 | Tesla V100 with 6 cores / 12 cpu threads or less
186 |
187 | - on the CPU side, as of February 2019, the PhoenixGo engine does not
188 | benefit significantly from more than 2 physical cores / 4 cpu threads,
189 | even on a Tesla V100
190 |
191 | For comparison, you can refer to
192 | [gtx-1060-75w-benchmark](benchmark-gtx1060-75w.md)
193 |
--------------------------------------------------------------------------------
/docs/go-review-partner.md:
--------------------------------------------------------------------------------
1 | GoReviewPartner github page is [here](https://github.com/pnprog/goreviewpartner)
2 |
3 | Support is still in testing, but so far these are the recommended
4 | settings to use PhoenixGo with GoReviewPartner
5 |
6 | 3 config files with settings optimized for grp (GoReviewPartner) are
7 | provided :
8 | - GPU with tensorRT (linux only) :
9 | [mcts_1gpu_grp.conf](/etc/mcts_1gpu_grp.conf)
10 | - GPU without tensorRT (linux, windows) :
11 | [mcts_1gpu_notensorrt_grp.conf](/etc/mcts_1gpu_notensorrt_grp.conf)
12 | - CPU-only (linux, mac, windows), much slower :
13 | [mcts_cpu_grp.conf](/etc/mcts_cpu_grp.conf)
14 |
15 | note : multi-GPU setups are also possible, only the most common
16 | examples are shown here
17 |
18 | if you want to make the changes manually, you need to do the following (the corresponding config lines are shown after this list) :
19 |
20 | in phoenixgo .conf config file :
21 |
22 | - disable all time settings in config file (set to `0`)
23 | - set time to unlimited (set timeout to 0 ms)
24 | - use simulations per move (playouts) to get a fixed amount of computation
25 | per move, since this is an analysis with unlimited time there is no
26 | need to play fast
27 | - for the same reason, set `enable background search` to 0 (disable
28 | pondering) : this is not a live game, so time settings are not
29 | needed and can cause conflicts
30 | - add debug width and height with `debugger` in config file, see
31 | [FAQ question](/docs/FAQ.md/#a2-where-is-the-pv-analysis-)
32 |
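For reference, these changes correspond to the following lines in the provided [mcts_1gpu_grp.conf](/etc/mcts_1gpu_grp.conf) :

```
timeout_ms_per_step: 0
max_simulations_per_step: 3200
enable_background_search: 0
debugger {
  print_tree_depth: 20
  print_tree_width: 3
}
```
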
33 | in grp (GoReviewPartner) settings :
34 | - it is easier to use one of the pre-made grp profiles in config.ini
35 | (slightly modifying them if needed)
36 |
37 | - Don't forget to add paths in the config file as explained in
38 | [FAQ question](/docs/FAQ.md/#a5-ckptzerockpt-20b-v1fp32plan-error-no-such-file-or-directory)
39 |
40 | - If you're on windows, you also need to pay attention to the syntax,
41 | for example on windows it is `--logtostderr --v 1` not
42 | `--logtostderr --v=1` , see
43 | [FAQ question](/docs/FAQ.md/#a4-syntax-error-windows) for details and
44 | other differences
45 |
46 | - and run the mcts_main (not start.sh) with the needed parameters,
47 | see an example
48 | [here](https://github.com/wonderingabout/goreviewpartner/blob/config-profiles-phoenixgo/config.ini#L100-L116)
49 |
50 | also, see [#86](https://github.com/Tencent/PhoenixGo/issues/86) and
51 | [#99](https://github.com/pnprog/goreviewpartner/issues/99) for details
52 |
--------------------------------------------------------------------------------
/docs/mcts-main-help.md:
--------------------------------------------------------------------------------
1 | For your convenience, a copy of the `./mcts_main --help` output is provided
2 | below
3 |
4 | This output is from January 2019, so it may be outdated if newer versions are released
5 |
6 | `./mcts_main --help`
7 |
8 | Outputs the below result :
9 |
10 | ```
11 | mcts_main: Warning: SetUsageMessage() never called
12 |
13 | Flags from external/com_github_gflags_gflags/src/gflags.cc:
14 | -flagfile (load flags from file) type: string default: ""
15 | -fromenv (set flags from the environment [use 'export FLAGS_flag1=value'])
16 | type: string default: ""
17 | -tryfromenv (set flags from the environment if present) type: string
18 | default: ""
19 | -undefok (comma-separated list of flag names that it is okay to specify on
20 | the command line even if the program does not define a flag with that
21 | name. IMPORTANT: flags in this list that have arguments MUST use the
22 | flag=value format) type: string default: ""
23 |
24 | Flags from external/com_github_gflags_gflags/src/gflags_completions.cc:
25 | -tab_completion_columns (Number of columns to use in output for tab
26 | completion) type: int32 default: 80
27 | -tab_completion_word (If non-empty, HandleCommandLineCompletions() will
28 | hijack the process and attempt to do bash-style command line flag
29 | completion on this value.) type: string default: ""
30 |
31 | Flags from external/com_github_gflags_gflags/src/gflags_reporting.cc:
32 | -help (show help on all flags [tip: all flags can have two dashes])
33 | type: bool default: false currently: true
34 | -helpfull (show help on all flags -- same as -help) type: bool
35 | default: false
36 | -helpmatch (show help on modules whose name contains the specified substr)
37 | type: string default: ""
38 | -helpon (show help on the modules named by this flag value) type: string
39 | default: ""
40 | -helppackage (show help on all modules in the main package) type: bool
41 | default: false
42 | -helpshort (show help on only the main module for this program) type: bool
43 | default: false
44 | -helpxml (produce an xml version of help) type: bool default: false
45 | -version (show version and build info and exit) type: bool default: false
46 |
47 |
48 |
49 | Flags from external/com_github_google_glog/src/logging.cc:
50 | -alsologtoemail (log messages go to these email addresses in addition to
51 | logfiles) type: string default: ""
52 | -alsologtostderr (log messages go to stderr in addition to logfiles)
53 | type: bool default: false
54 | -colorlogtostderr (color messages logged to stderr (if supported by
55 | terminal)) type: bool default: false
56 | -drop_log_memory (Drop in-memory buffers of log contents. Logs can grow
57 | very quickly and they are rarely read before they need to be evicted from
58 | memory. Instead, drop them from memory as soon as they are flushed to
59 | disk.) type: bool default: true
60 | -log_backtrace_at (Emit a backtrace when logging at file:linenum.)
61 | type: string default: ""
62 | -log_dir (If specified, logfiles are written into this directory instead of
63 | the default logging directory.) type: string default: ""
64 | -log_link (Put additional links to the log files in this directory)
65 | type: string default: ""
66 | -log_prefix (Prepend the log prefix to the start of each log line)
67 | type: bool default: true
68 | -logbuflevel (Buffer log messages logged at this level or lower (-1 means
69 | don't buffer; 0 means buffer INFO only; ...)) type: int32 default: 0
70 | -logbufsecs (Buffer log messages for at most this many seconds) type: int32
71 | default: 30
72 | -logemaillevel (Email log messages logged at this level or higher (0 means
73 | email all; 3 means email FATAL only; ...)) type: int32 default: 999
74 | -logfile_mode (Log file mode/permissions.) type: int32 default: 436
75 | -logmailer (Mailer used to send logging email) type: string
76 | default: "/bin/mail"
77 | -logtostderr (log messages go to stderr instead of logfiles) type: bool
78 | default: false
79 | -max_log_size (approx. maximum log file size (in MB). A value of 0 will be
80 | silently overridden to 1.) type: int32 default: 1800
81 | -minloglevel (Messages logged at a lower level than this don't actually get
82 | logged anywhere) type: int32 default: 0
83 | -stderrthreshold (log messages at or above this level are copied to stderr
84 | in addition to logfiles. This flag obsoletes --alsologtostderr.)
85 | type: int32 default: 2
86 | -stop_logging_if_full_disk (Stop attempting to log to disk if the disk is
87 | full.) type: bool default: false
88 |
89 | Flags from external/com_github_google_glog/src/vlog_is_on.cc:
90 | -v (Show all VLOG(m) messages for m <= this. Overridable by --vmodule.)
91 | type: int32 default: 0
92 | -vmodule (per-module verbose level. Argument is a comma-separated list of
93 |     <module_name>=<log_level>. <module_name> is a glob pattern, matched
94 |     against the filename base (that is, name ignoring .cc/.h./-inl.h).
95 |     <log_level> overrides any value given by --v.) type: string default: ""
96 |
97 |
98 |
99 | Flags from mcts/mcts_main.cc:
100 | -allow_ip (List of client ip allowed to connect, seperated by comma.)
101 | type: string default: ""
102 | -config_path (Path of mcts config file.) type: string default: ""
103 | -fork_per_request (Fork for each request or not.) type: bool default: true
104 | -gpu_list (List of gpus used by neural network.) type: string default: ""
105 | -gtp (Run as gtp server.) type: bool default: false
106 | -init_moves (Initialize Go board with init_moves.) type: string default: ""
107 | -inter_op_parallelism_threads (Number of tf's inter op threads) type: int32
108 | default: 0
109 | -intra_op_parallelism_threads (Number of tf's intra op threads) type: int32
110 | default: 0
111 | -listen_port (Listen which port.) type: int32 default: 0
112 |
113 | ```
114 |
--------------------------------------------------------------------------------
/docs/minimalist-bazel-configure.md:
--------------------------------------------------------------------------------
1 | To reduce PhoenixGo's size and build time, it is wise to disable all
2 | the unneeded options during the bazel `./configure` , as was discussed
3 | here : [#76](https://github.com/Tencent/PhoenixGo/issues/76)
4 |
5 | This is an example of minimalist options that you can use :
6 |
7 | note : if you have trouble with path configurations, see
8 | [path-errors](/docs/path-errors.md)
9 |
10 | ```
11 | Please specify the location of python. [Default is /usr/bin/python]:
12 |
13 |
14 | Found possible Python library paths:
15 | /usr/local/lib/python2.7/dist-packages
16 | /usr/lib/python2.7/dist-packages
17 | Please input the desired Python library path to use. Default is [/usr/local/lib/python2.7/dist-packages]
18 |
19 | Do you wish to build TensorFlow with jemalloc as malloc support? [Y/n]:
20 | jemalloc as malloc support will be enabled for TensorFlow.
21 |
22 | Do you wish to build TensorFlow with Google Cloud Platform support? [Y/n]: n
23 | No Google Cloud Platform support will be enabled for TensorFlow.
24 |
25 | Do you wish to build TensorFlow with Hadoop File System support? [Y/n]: n
26 | No Hadoop File System support will be enabled for TensorFlow.
27 |
28 | Do you wish to build TensorFlow with Amazon S3 File System support? [Y/n]: n
29 | No Amazon S3 File System support will be enabled for TensorFlow.
30 |
31 | Do you wish to build TensorFlow with Apache Kafka Platform support? [Y/n]: n
32 | No Apache Kafka Platform support will be enabled for TensorFlow.
33 |
34 | Do you wish to build TensorFlow with XLA JIT support? [y/N]: y
35 | XLA JIT support will be enabled for TensorFlow.
36 |
37 | Do you wish to build TensorFlow with GDR support? [y/N]: y
38 | GDR support will be enabled for TensorFlow.
39 |
40 | Do you wish to build TensorFlow with VERBS support? [y/N]: y
41 | VERBS support will be enabled for TensorFlow.
42 |
43 | Do you wish to build TensorFlow with OpenCL SYCL support? [y/N]:
44 | No OpenCL SYCL support will be enabled for TensorFlow.
45 |
46 | Do you wish to build TensorFlow with CUDA support? [y/N]: y
47 | CUDA support will be enabled for TensorFlow.
48 |
49 | Please specify the CUDA SDK version you want to use, e.g. 7.0. [Leave empty to default to CUDA 9.0]:
50 |
51 |
52 | Please specify the location where CUDA 9.0 toolkit is installed. Refer to README.md for more details. [Default is /usr/local/cuda]:
53 |
54 |
55 | Please specify the cuDNN version you want to use. [Leave empty to default to cuDNN 7.0]: 7.0.5
56 |
57 |
58 | Please specify the location where cuDNN 7 library is installed. Refer to README.md for more details. [Default is /usr/local/cuda]:
59 |
60 |
61 | Do you wish to build TensorFlow with TensorRT support? [y/N]: y
62 | TensorRT support will be enabled for TensorFlow.
63 |
64 | Please specify the location where TensorRT is installed. [Default is /usr/lib/x86_64-linux-gnu]:
65 |
66 |
67 | Please specify the NCCL version you want to use. [Leave empty to default to NCCL 1.3]:
68 |
69 |
70 | Please specify a list of comma-separated Cuda compute capabilities you want to build with.
71 | You can find the compute capability of your device at: https://developer.nvidia.com/cuda-gpus.
72 | Please note that each additional compute capability significantly increases your build time and binary size. [Default is: 6.1]
73 |
74 |
75 | Do you want to use clang as CUDA compiler? [y/N]:
76 | nvcc will be used as CUDA compiler.
77 |
78 | Please specify which gcc should be used by nvcc as the host compiler. [Default is /usr/bin/gcc]:
79 |
80 |
81 | Do you wish to build TensorFlow with MPI support? [y/N]:
82 | No MPI support will be enabled for TensorFlow.
83 |
84 | Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native]:
85 |
86 |
87 | Would you like to interactively configure ./WORKSPACE for Android builds? [y/N]: n
88 |
89 | ```
90 |
--------------------------------------------------------------------------------
/docs/path-errors.md:
--------------------------------------------------------------------------------
1 | In the example below, ubuntu 16.04 LTS is used with cuda 9.0 (deb install),
2 | cudnn 7.0.5 (deb install), tensorrt 3.0.4 (deb install), as well as bazel 0.18.1
3 | (.sh file install) (0.11.1 also works)
4 |
5 | Other linux distributions with nvidia tar installs are possible too
6 | (some other versions have been [tested here](/docs/tested-versions.md)),
7 | but the deb route is easier, and remember that this is just an example.
8 | The settings below have been tested to work and to fix the most common path
9 | issues, and are shown as an interactive walkthrough :
10 |
11 | ## 1) post-install of cuda : do the path exports
12 |
13 | After the cuda 9.0 deb and cudnn 7.0.5 deb are installed successfully,
14 | one post-install step is needed : adding the path to cuda-9.0. Here we do it
15 | all in one step by also including the cudnn paths, even if cudnn is not
16 | installed yet.
17 |
18 | The `export` command alone adds the path for the current session only, and the
19 | changes will be lost after reboot. This is why we add the paths permanently
20 | using `bashrc`
21 |
22 | ### bashrc
23 |
24 | edit bashrc file :
25 |
26 | `sudo nano ~/.bashrc`
27 |
28 | you need to add the lines below at the end of the bashrc file (using nano,
29 | or the text editor of your choice) :
30 |
31 | ```
32 | # add paths for cuda-9.0
33 | export PATH=${PATH}:/usr/local/cuda-9.0/bin
34 |
35 | # in case you use tensorrt tar install
36 | # add paths for cuda and cudnn install paths
37 | export CUDA_INSTALL_DIR=/usr/local/cuda
38 | export CUDNN_INSTALL_DIR=/usr/local/cuda
39 | ```
40 |
41 | With nano, after editing is finished, save and exit with `Ctrl+X` +
42 | then press `y` + then press ENTER key.
43 |
44 | Now update bashrc file :
45 |
46 | `source ~/.bashrc`
47 |
48 | as you can now see, the command nvcc --version will work successfully :
49 |
50 | `nvcc --version`
51 |
52 | should display something like this :
53 |
54 | ```
55 | nvcc: NVIDIA (R) Cuda compiler driver
56 | Copyright (c) 2005-2017 NVIDIA Corporation
57 | Built on Fri_Sep__1_21:08:03_CDT_2017
58 | Cuda compilation tools, release 9.0, V9.0.176
59 | ```
60 |
61 | you can also check whether the cuda installation is a success with
62 | `cat /proc/driver/nvidia/version`
63 |
64 | should display something like this :
65 |
66 | ```
67 | NVRM version: NVIDIA UNIX x86_64 Kernel Module 384.130 Wed Mar 21 03:37:26 PDT 2018
68 | GCC version: gcc version 5.4.0 20160609 (Ubuntu 5.4.0-6ubuntu1~16.04.10)
69 | ```
70 |
71 | If you need help with installing the deb files of cuda 9.0, cudnn 7.0.5,
72 | and tensorrt 3.0.4 for ubuntu 16.04, you can go here
73 | [@wonderingabout](https://github.com/wonderingabout/nvidia-archives)
74 |
75 | ## 1b) after cuda 9.0 and cudnn 7.0.5 deb installs, test your cudnn
76 |
77 | run a test to see if your install works by compiling and running a
78 | cudnn code sample :
79 |
80 | ```
81 | cp -r /usr/src/cudnn_samples_v7/ ~ && cd ~/cudnn_samples_v7/mnistCUDNN && make clean && make && ./mnistCUDNN
82 | ```
83 |
84 | should display this : `Test passed!`
85 |
86 | ## 2) locate cuda and cudnn paths, and update the database if they are not found
87 |
88 | Run this command : `locate libcudart.so && locate libcudnn.so.7`
89 |
90 | you should see something like this :
91 |
92 | ```
93 | /usr/local/cuda-9.0/doc/man/man7/libcudart.so.7
94 | /usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart.so
95 | /usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart.so.9.0
96 | /usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart.so.9.0.176
97 | /usr/lib/x86_64-linux-gnu/libcudnn.so.7
98 | /usr/lib/x86_64-linux-gnu/libcudnn.so.7.0.5
99 | ```
100 | If you don't see this, run this command :
101 |
102 | `sudo updatedb && locate libcudart.so && locate libcudnn.so.7`
103 |
104 | It should now display all the cuda and cudnn paths, the same as above.
105 |
106 | ## 3) during bazel compile, these are the paths you need to enter
107 |
108 | Press ENTER for every prompt to choose default settings
109 | (or `n` if you don't want a setting), except for these :
110 |
111 | - CUDA : choose `y` , version `9.0` if it is not the default
112 | (then leave the path at its default by pressing ENTER)
113 | - cudnn : choose version `7.0.5`
114 | (then leave the default path by pressing ENTER)
115 | - if you want to use tensorrt, answer `y` and press ENTER to keep the default
116 | path if you did a tensorrt deb install, but for a tar install you need to set
117 | up your path manually, as you can see here
118 | [@wonderingabout](https://github.com/wonderingabout/nvidia-archives)
119 |
120 | same as below :
121 |
122 | ```
123 | Do you wish to build TensorFlow with CUDA support? [y/N]: y
124 | CUDA support will be enabled for TensorFlow.
125 |
126 | Please specify the CUDA SDK version you want to use, e.g. 7.0. [Leave empty to default to CUDA 9.0]:
127 |
128 | Please specify the location where CUDA 9.0 toolkit is installed. Refer to README.md for more details. [Default is /usr/local/cuda]:
129 |
130 | Please specify the cuDNN version you want to use. [Leave empty to default to cuDNN 7.0]: 7.0.5
131 |
132 | Please specify the location where cuDNN 7 library is installed. Refer to README.md for more details. [Default is /usr/local/cuda-9.0/]:
133 |
134 | Do you wish to build TensorFlow with TensorRT support? [y/N]: y
135 | TensorRT support will be enabled for TensorFlow.
136 |
137 | Please specify the location where TensorRT is installed. [Default is /usr/lib/x86_64-linux-gnu]:
138 | ```
139 |
140 | note : (a tar install is needed to install the `.whl` to increase the max batch size
141 | with tensorrt), see : [#75](https://github.com/Tencent/PhoenixGo/issues/75)
142 |
143 | Final words :
144 |
145 | Remember that these settings are just an example; other settings,
146 | package versions, or linux distributions are possible too, but this
147 | example has been tested to work successfully on ubuntu 16.04 LTS with
148 | deb install of cuda 9.0, deb install of cudnn 7.0.5, deb install of
149 | tensorrt 3.0.4, as well as .sh file install of bazel 0.11.1
150 |
151 | They are provided as general help for compiling and running on linux;
152 | they are not an obligatory method, but will hopefully make using
153 | PhoenixGo on linux systems easier
154 |
155 | credits :
156 | - [mishra.thedeepak](https://medium.com/@mishra.thedeepak/cuda-and-cudnn-installation-for-tensorflow-gpu-79beebb356d2)
157 | - [kezulin.me](https://kezunlin.me/post/dacc4196/)
158 | - [wonderingabout](https://github.com/wonderingabout/nvidia-archives)
159 |
160 | minor sources :
161 | - [nvidia pdf install guide for cuda 9.0](http://developer.download.nvidia.com/compute/cuda/9.0/Prod/docs/sidebar/CUDA_Installation_Guide_Linux.pdf)
162 | - [medium.com/@zhanwenchen/](https://medium.com/@zhanwenchen/install-cuda-and-cudnn-for-tensorflow-gpu-on-ubuntu-79306e4ac04e)
163 | - [nvidia cudnn](https://developer.nvidia.com/rdp/cudnn-archive)
164 | - [nvidia tensorrt](https://developer.nvidia.com/nvidia-tensorrt3-download)
165 |
--------------------------------------------------------------------------------
/docs/tested-versions.md:
--------------------------------------------------------------------------------
1 | Below is a list of other versions of cuda/cudnn/tensorrt/bazel (tar/deb)
2 | tested by independent contributors, showing what does and does not work with PhoenixGo
3 |
4 | compatibility for tensorrt :
5 | - 3.0.4 requires cudnn 7.0.x and +
6 | - 4.x requires cudnn 7.1.x and +
7 | - 5.x requires cudnn 7.3.x and +
8 |
9 | ### tests by [wonderingabout](https://github.com/wonderingabout/) :
10 |
11 | #### works
12 | - bazel : 0.11.1 , 0.17.2 , 0.18.1
13 | - ubuntu : 16.04 , 18.04
14 | - cuda : 9.0 deb , 10.0 deb (10.0 deb ubuntu 18.04 is with cudnn
15 | 7.4.x deb and no tensorrt), 9.0 for windows 10 and windows server
16 | 2016
17 | - cudnn : 7.0.5 deb, 7.1.4 deb , 7.4.2 deb, 7.1.4 on windows 10
18 | with cuda 9.0
19 | - tensorrt : 3.0.4 deb , 3.0.4 tar
20 | - pycuda : for tensorrt tar, pycuda with python 2.7
21 | (needs [a modification](http://0561blue.tistory.com/m/13?category=627413),
22 | pycuda specific issue, see [#75](https://github.com/Tencent/PhoenixGo/issues/75))
23 |
24 | #### does NOT work :
25 | - cuda on windows 10 : cuda 10.0
26 | - cudnn for windows 10 : cudnn 7.4.x for cuda 10.0 , cudnn 7.4.x for
27 | cuda 9.0
28 | - tensorrt : 4.x deb , 4.x tar, 5.x deb
29 | - bazel : 0.19.x and superior
30 | (needs [a modification](https://github.com/tensorflow/tensorflow/issues/23401))
31 |
--------------------------------------------------------------------------------
/etc/mcts_1gpu.conf:
--------------------------------------------------------------------------------
1 | num_eval_threads: 1
2 | num_search_threads: 8
3 | max_children_per_node: 64
4 | max_search_tree_size: 400000000
5 | timeout_ms_per_step: 30000
6 | max_simulations_per_step: 0
7 | eval_batch_size: 4
8 | eval_wait_batch_timeout_us: 100
9 | model_config {
10 | train_dir: "ckpt"
11 | enable_tensorrt: 1
12 | tensorrt_model_path: "zero.ckpt-20b-v1.FP32.PLAN"
13 | }
14 | gpu_list: "0"
15 | c_puct: 2.5
16 | virtual_loss: 1.0
17 | enable_resign: 1
18 | v_resign: -0.9
19 | enable_dirichlet_noise: 0
20 | dirichlet_noise_alpha: 0.03
21 | dirichlet_noise_ratio: 0.25
22 | monitor_log_every_ms: 0
23 | get_best_move_mode: 0
24 | enable_background_search: 1
25 | enable_policy_temperature: 0
26 | policy_temperature: 0.67
27 | inherit_default_act: 1
28 | early_stop {
29 | enable: 1
30 | check_every_ms: 100
31 | sims_factor: 1.0
32 | sims_threshold: 2000
33 | }
34 | unstable_overtime {
35 | enable: 1
36 | time_factor: 0.3
37 | }
38 | behind_overtime {
39 | enable: 1
40 | act_threshold: 0.0
41 | time_factor: 0.3
42 | }
43 | time_control {
44 | enable: 1
45 | c_denom: 20
46 | c_maxply: 40
47 | reserved_time: 1.0
48 | }
49 |
--------------------------------------------------------------------------------
/etc/mcts_1gpu_grp.conf:
--------------------------------------------------------------------------------
1 | num_eval_threads: 1
2 | num_search_threads: 8
3 | max_children_per_node: 64
4 | max_search_tree_size: 400000000
5 | timeout_ms_per_step: 0
6 | max_simulations_per_step: 3200
7 | eval_batch_size: 4
8 | eval_wait_batch_timeout_us: 100
9 | model_config {
10 | train_dir: "ckpt"
11 | enable_tensorrt: 1
12 | tensorrt_model_path: "zero.ckpt-20b-v1.FP32.PLAN"
13 | }
14 | gpu_list: "0"
15 | c_puct: 2.5
16 | virtual_loss: 1.0
17 | enable_resign: 1
18 | v_resign: -0.9
19 | enable_dirichlet_noise: 0
20 | dirichlet_noise_alpha: 0.03
21 | dirichlet_noise_ratio: 0.25
22 | monitor_log_every_ms: 0
23 | get_best_move_mode: 0
24 | enable_background_search: 0
25 | enable_policy_temperature: 0
26 | policy_temperature: 0.67
27 | inherit_default_act: 1
28 | debugger {
29 | print_tree_depth: 20
30 | print_tree_width: 3
31 | }
32 | early_stop {
33 | enable: 0
34 | check_every_ms: 100
35 | sims_factor: 1.0
36 | sims_threshold: 2000
37 | }
38 | unstable_overtime {
39 | enable: 0
40 | time_factor: 0.3
41 | }
42 | behind_overtime {
43 | enable: 0
44 | act_threshold: 0.0
45 | time_factor: 0.3
46 | }
47 | time_control {
48 | enable: 0
49 | c_denom: 20
50 | c_maxply: 40
51 | reserved_time: 1.0
52 | }
53 |
--------------------------------------------------------------------------------
/etc/mcts_1gpu_notensorrt.conf:
--------------------------------------------------------------------------------
1 | num_eval_threads: 1
2 | num_search_threads: 8
3 | max_children_per_node: 64
4 | max_search_tree_size: 400000000
5 | timeout_ms_per_step: 30000
6 | max_simulations_per_step: 0
7 | eval_batch_size: 4
8 | eval_wait_batch_timeout_us: 100
9 | model_config {
10 | train_dir: "ckpt"
11 | }
12 | gpu_list: "0"
13 | c_puct: 2.5
14 | virtual_loss: 1.0
15 | enable_resign: 1
16 | v_resign: -0.9
17 | enable_dirichlet_noise: 0
18 | dirichlet_noise_alpha: 0.03
19 | dirichlet_noise_ratio: 0.25
20 | monitor_log_every_ms: 0
21 | get_best_move_mode: 0
22 | enable_background_search: 1
23 | enable_policy_temperature: 0
24 | policy_temperature: 0.67
25 | inherit_default_act: 1
26 | early_stop {
27 | enable: 1
28 | check_every_ms: 100
29 | sims_factor: 1.0
30 | sims_threshold: 2000
31 | }
32 | unstable_overtime {
33 | enable: 1
34 | time_factor: 0.3
35 | }
36 | behind_overtime {
37 | enable: 1
38 | act_threshold: 0.0
39 | time_factor: 0.3
40 | }
41 | time_control {
42 | enable: 1
43 | c_denom: 20
44 | c_maxply: 40
45 | reserved_time: 1.0
46 | }
47 |
--------------------------------------------------------------------------------
/etc/mcts_1gpu_notensorrt_grp.conf:
--------------------------------------------------------------------------------
1 | num_eval_threads: 1
2 | num_search_threads: 8
3 | max_children_per_node: 64
4 | max_search_tree_size: 400000000
5 | timeout_ms_per_step: 0
6 | max_simulations_per_step: 3200
7 | eval_batch_size: 4
8 | eval_wait_batch_timeout_us: 100
9 | model_config {
10 | train_dir: "ckpt"
11 | }
12 | gpu_list: "0"
13 | c_puct: 2.5
14 | virtual_loss: 1.0
15 | enable_resign: 1
16 | v_resign: -0.9
17 | enable_dirichlet_noise: 0
18 | dirichlet_noise_alpha: 0.03
19 | dirichlet_noise_ratio: 0.25
20 | monitor_log_every_ms: 0
21 | get_best_move_mode: 0
22 | enable_background_search: 0
23 | enable_policy_temperature: 0
24 | policy_temperature: 0.67
25 | inherit_default_act: 1
26 | debugger {
27 | print_tree_depth: 20
28 | print_tree_width: 3
29 | }
30 | early_stop {
31 | enable: 0
32 | check_every_ms: 100
33 | sims_factor: 1.0
34 | sims_threshold: 2000
35 | }
36 | unstable_overtime {
37 | enable: 0
38 | time_factor: 0.3
39 | }
40 | behind_overtime {
41 | enable: 0
42 | act_threshold: 0.0
43 | time_factor: 0.3
44 | }
45 | time_control {
46 | enable: 0
47 | c_denom: 20
48 | c_maxply: 40
49 | reserved_time: 1.0
50 | }
51 |
--------------------------------------------------------------------------------
/etc/mcts_2gpu.conf:
--------------------------------------------------------------------------------
1 | num_eval_threads: 2
2 | num_search_threads: 12
3 | max_children_per_node: 64
4 | max_search_tree_size: 400000000
5 | timeout_ms_per_step: 30000
6 | max_simulations_per_step: 0
7 | eval_batch_size: 4
8 | eval_wait_batch_timeout_us: 100
9 | model_config {
10 | train_dir: "ckpt"
11 | enable_tensorrt: 1
12 | tensorrt_model_path: "zero.ckpt-20b-v1.FP32.PLAN"
13 | }
14 | gpu_list: "0,1"
15 | c_puct: 2.5
16 | virtual_loss: 1.0
17 | enable_resign: 1
18 | v_resign: -0.9
19 | enable_dirichlet_noise: 0
20 | dirichlet_noise_alpha: 0.03
21 | dirichlet_noise_ratio: 0.25
22 | monitor_log_every_ms: 0
23 | get_best_move_mode: 0
24 | enable_background_search: 1
25 | enable_policy_temperature: 0
26 | policy_temperature: 0.67
27 | inherit_default_act: 1
28 | early_stop {
29 | enable: 1
30 | check_every_ms: 100
31 | sims_factor: 1.0
32 | sims_threshold: 2000
33 | }
34 | unstable_overtime {
35 | enable: 1
36 | time_factor: 0.3
37 | }
38 | behind_overtime {
39 | enable: 1
40 | act_threshold: 0.0
41 | time_factor: 0.3
42 | }
43 | time_control {
44 | enable: 1
45 | c_denom: 20
46 | c_maxply: 40
47 | reserved_time: 1.0
48 | }
49 |
--------------------------------------------------------------------------------
/etc/mcts_2gpu_grp.conf:
--------------------------------------------------------------------------------
1 | num_eval_threads: 2
2 | num_search_threads: 12
3 | max_children_per_node: 64
4 | max_search_tree_size: 400000000
5 | timeout_ms_per_step: 0
6 | max_simulations_per_step: 3200
7 | eval_batch_size: 4
8 | eval_wait_batch_timeout_us: 100
9 | model_config {
10 | train_dir: "ckpt"
11 | enable_tensorrt: 1
12 | tensorrt_model_path: "zero.ckpt-20b-v1.FP32.PLAN"
13 | }
14 | gpu_list: "0,1"
15 | c_puct: 2.5
16 | virtual_loss: 1.0
17 | enable_resign: 1
18 | v_resign: -0.9
19 | enable_dirichlet_noise: 0
20 | dirichlet_noise_alpha: 0.03
21 | dirichlet_noise_ratio: 0.25
22 | monitor_log_every_ms: 0
23 | get_best_move_mode: 0
24 | enable_background_search: 0
25 | enable_policy_temperature: 0
26 | policy_temperature: 0.67
27 | inherit_default_act: 1
28 | debugger {
29 | print_tree_depth: 20
30 | print_tree_width: 3
31 | }
32 | early_stop {
33 | enable: 0
34 | check_every_ms: 100
35 | sims_factor: 1.0
36 | sims_threshold: 2000
37 | }
38 | unstable_overtime {
39 | enable: 0
40 | time_factor: 0.3
41 | }
42 | behind_overtime {
43 | enable: 0
44 | act_threshold: 0.0
45 | time_factor: 0.3
46 | }
47 | time_control {
48 | enable: 0
49 | c_denom: 20
50 | c_maxply: 40
51 | reserved_time: 1.0
52 | }
53 |
--------------------------------------------------------------------------------
/etc/mcts_2gpu_notensorrt.conf:
--------------------------------------------------------------------------------
1 | num_eval_threads: 2
2 | num_search_threads: 12
3 | max_children_per_node: 64
4 | max_search_tree_size: 400000000
5 | timeout_ms_per_step: 30000
6 | max_simulations_per_step: 0
7 | eval_batch_size: 4
8 | eval_wait_batch_timeout_us: 100
9 | model_config {
10 | train_dir: "ckpt"
11 | }
12 | gpu_list: "0,1"
13 | c_puct: 2.5
14 | virtual_loss: 1.0
15 | enable_resign: 1
16 | v_resign: -0.9
17 | enable_dirichlet_noise: 0
18 | dirichlet_noise_alpha: 0.03
19 | dirichlet_noise_ratio: 0.25
20 | monitor_log_every_ms: 0
21 | get_best_move_mode: 0
22 | enable_background_search: 1
23 | enable_policy_temperature: 0
24 | policy_temperature: 0.67
25 | inherit_default_act: 1
26 | early_stop {
27 | enable: 1
28 | check_every_ms: 100
29 | sims_factor: 1.0
30 | sims_threshold: 2000
31 | }
32 | unstable_overtime {
33 | enable: 1
34 | time_factor: 0.3
35 | }
36 | behind_overtime {
37 | enable: 1
38 | act_threshold: 0.0
39 | time_factor: 0.3
40 | }
41 | time_control {
42 | enable: 1
43 | c_denom: 20
44 | c_maxply: 40
45 | reserved_time: 1.0
46 | }
47 |
--------------------------------------------------------------------------------
/etc/mcts_2gpu_notensorrt_grp.conf:
--------------------------------------------------------------------------------
1 | num_eval_threads: 2
2 | num_search_threads: 12
3 | max_children_per_node: 64
4 | max_search_tree_size: 400000000
5 | timeout_ms_per_step: 0
6 | max_simulations_per_step: 3200
7 | eval_batch_size: 4
8 | eval_wait_batch_timeout_us: 100
9 | model_config {
10 | train_dir: "ckpt"
11 | }
12 | gpu_list: "0,1"
13 | c_puct: 2.5
14 | virtual_loss: 1.0
15 | enable_resign: 1
16 | v_resign: -0.9
17 | enable_dirichlet_noise: 0
18 | dirichlet_noise_alpha: 0.03
19 | dirichlet_noise_ratio: 0.25
20 | monitor_log_every_ms: 0
21 | get_best_move_mode: 0
22 | enable_background_search: 0
23 | enable_policy_temperature: 0
24 | policy_temperature: 0.67
25 | inherit_default_act: 1
26 | debugger {
27 | print_tree_depth: 20
28 | print_tree_width: 3
29 | }
30 | early_stop {
31 | enable: 0
32 | check_every_ms: 100
33 | sims_factor: 1.0
34 | sims_threshold: 2000
35 | }
36 | unstable_overtime {
37 | enable: 0
38 | time_factor: 0.3
39 | }
40 | behind_overtime {
41 | enable: 0
42 | act_threshold: 0.0
43 | time_factor: 0.3
44 | }
45 | time_control {
46 | enable: 0
47 | c_denom: 20
48 | c_maxply: 40
49 | reserved_time: 1.0
50 | }
51 |
--------------------------------------------------------------------------------
/etc/mcts_3gpu.conf:
--------------------------------------------------------------------------------
1 | num_eval_threads: 3
2 | num_search_threads: 16
3 | max_children_per_node: 64
4 | max_search_tree_size: 400000000
5 | timeout_ms_per_step: 30000
6 | max_simulations_per_step: 0
7 | eval_batch_size: 4
8 | eval_wait_batch_timeout_us: 100
9 | model_config {
10 | train_dir: "ckpt"
11 | enable_tensorrt: 1
12 | tensorrt_model_path: "zero.ckpt-20b-v1.FP32.PLAN"
13 | }
14 | gpu_list: "0,1,2"
15 | c_puct: 2.5
16 | virtual_loss: 1.0
17 | enable_resign: 1
18 | v_resign: -0.9
19 | enable_dirichlet_noise: 0
20 | dirichlet_noise_alpha: 0.03
21 | dirichlet_noise_ratio: 0.25
22 | monitor_log_every_ms: 0
23 | get_best_move_mode: 0
24 | enable_background_search: 1
25 | enable_policy_temperature: 0
26 | policy_temperature: 0.67
27 | inherit_default_act: 1
28 | early_stop {
29 | enable: 1
30 | check_every_ms: 100
31 | sims_factor: 1.0
32 | sims_threshold: 2000
33 | }
34 | unstable_overtime {
35 | enable: 1
36 | time_factor: 0.3
37 | }
38 | behind_overtime {
39 | enable: 1
40 | act_threshold: 0.0
41 | time_factor: 0.3
42 | }
43 | time_control {
44 | enable: 1
45 | c_denom: 20
46 | c_maxply: 40
47 | reserved_time: 1.0
48 | }
49 |
--------------------------------------------------------------------------------
/etc/mcts_3gpu_notensorrt.conf:
--------------------------------------------------------------------------------
1 | num_eval_threads: 3
2 | num_search_threads: 16
3 | max_children_per_node: 64
4 | max_search_tree_size: 400000000
5 | timeout_ms_per_step: 30000
6 | max_simulations_per_step: 0
7 | eval_batch_size: 4
8 | eval_wait_batch_timeout_us: 100
9 | model_config {
10 | train_dir: "ckpt"
11 | }
12 | gpu_list: "0,1,2"
13 | c_puct: 2.5
14 | virtual_loss: 1.0
15 | enable_resign: 1
16 | v_resign: -0.9
17 | enable_dirichlet_noise: 0
18 | dirichlet_noise_alpha: 0.03
19 | dirichlet_noise_ratio: 0.25
20 | monitor_log_every_ms: 0
21 | get_best_move_mode: 0
22 | enable_background_search: 1
23 | enable_policy_temperature: 0
24 | policy_temperature: 0.67
25 | inherit_default_act: 1
26 | early_stop {
27 | enable: 1
28 | check_every_ms: 100
29 | sims_factor: 1.0
30 | sims_threshold: 2000
31 | }
32 | unstable_overtime {
33 | enable: 1
34 | time_factor: 0.3
35 | }
36 | behind_overtime {
37 | enable: 1
38 | act_threshold: 0.0
39 | time_factor: 0.3
40 | }
41 | time_control {
42 | enable: 1
43 | c_denom: 20
44 | c_maxply: 40
45 | reserved_time: 1.0
46 | }
47 |
--------------------------------------------------------------------------------
/etc/mcts_4gpu.conf:
--------------------------------------------------------------------------------
1 | num_eval_threads: 4
2 | num_search_threads: 20
3 | max_children_per_node: 64
4 | max_search_tree_size: 400000000
5 | timeout_ms_per_step: 30000
6 | max_simulations_per_step: 0
7 | eval_batch_size: 4
8 | eval_wait_batch_timeout_us: 100
9 | model_config {
10 | train_dir: "ckpt"
11 | enable_tensorrt: 1
12 | tensorrt_model_path: "zero.ckpt-20b-v1.FP32.PLAN"
13 | }
14 | gpu_list: "0,1,2,3"
15 | c_puct: 2.5
16 | virtual_loss: 1.0
17 | enable_resign: 1
18 | v_resign: -0.9
19 | enable_dirichlet_noise: 0
20 | dirichlet_noise_alpha: 0.03
21 | dirichlet_noise_ratio: 0.25
22 | monitor_log_every_ms: 0
23 | get_best_move_mode: 0
24 | enable_background_search: 1
25 | enable_policy_temperature: 0
26 | policy_temperature: 0.67
27 | inherit_default_act: 1
28 | early_stop {
29 | enable: 1
30 | check_every_ms: 100
31 | sims_factor: 1.0
32 | sims_threshold: 2000
33 | }
34 | unstable_overtime {
35 | enable: 1
36 | time_factor: 0.3
37 | }
38 | behind_overtime {
39 | enable: 1
40 | act_threshold: 0.0
41 | time_factor: 0.3
42 | }
43 | time_control {
44 | enable: 1
45 | c_denom: 20
46 | c_maxply: 40
47 | reserved_time: 1.0
48 | }
49 |
--------------------------------------------------------------------------------
/etc/mcts_4gpu_notensorrt.conf:
--------------------------------------------------------------------------------
1 | num_eval_threads: 4
2 | num_search_threads: 20
3 | max_children_per_node: 64
4 | max_search_tree_size: 400000000
5 | timeout_ms_per_step: 30000
6 | max_simulations_per_step: 0
7 | eval_batch_size: 4
8 | eval_wait_batch_timeout_us: 100
9 | model_config {
10 | train_dir: "ckpt"
11 | }
12 | gpu_list: "0,1,2,3"
13 | c_puct: 2.5
14 | virtual_loss: 1.0
15 | enable_resign: 1
16 | v_resign: -0.9
17 | enable_dirichlet_noise: 0
18 | dirichlet_noise_alpha: 0.03
19 | dirichlet_noise_ratio: 0.25
20 | monitor_log_every_ms: 0
21 | get_best_move_mode: 0
22 | enable_background_search: 1
23 | enable_policy_temperature: 0
24 | policy_temperature: 0.67
25 | inherit_default_act: 1
26 | early_stop {
27 | enable: 1
28 | check_every_ms: 100
29 | sims_factor: 1.0
30 | sims_threshold: 2000
31 | }
32 | unstable_overtime {
33 | enable: 1
34 | time_factor: 0.3
35 | }
36 | behind_overtime {
37 | enable: 1
38 | act_threshold: 0.0
39 | time_factor: 0.3
40 | }
41 | time_control {
42 | enable: 1
43 | c_denom: 20
44 | c_maxply: 40
45 | reserved_time: 1.0
46 | }
47 |
--------------------------------------------------------------------------------
/etc/mcts_5gpu.conf:
--------------------------------------------------------------------------------
1 | num_eval_threads: 5
2 | num_search_threads: 24
3 | max_children_per_node: 64
4 | max_search_tree_size: 400000000
5 | timeout_ms_per_step: 30000
6 | max_simulations_per_step: 0
7 | eval_batch_size: 4
8 | eval_wait_batch_timeout_us: 100
9 | model_config {
10 | train_dir: "ckpt"
11 | enable_tensorrt: 1
12 | tensorrt_model_path: "zero.ckpt-20b-v1.FP32.PLAN"
13 | }
14 | gpu_list: "0,1,2,3,4"
15 | c_puct: 2.5
16 | virtual_loss: 1.0
17 | enable_resign: 1
18 | v_resign: -0.9
19 | enable_dirichlet_noise: 0
20 | dirichlet_noise_alpha: 0.03
21 | dirichlet_noise_ratio: 0.25
22 | monitor_log_every_ms: 0
23 | get_best_move_mode: 0
24 | enable_background_search: 1
25 | enable_policy_temperature: 0
26 | policy_temperature: 0.67
27 | inherit_default_act: 1
28 | early_stop {
29 | enable: 1
30 | check_every_ms: 100
31 | sims_factor: 1.0
32 | sims_threshold: 2000
33 | }
34 | unstable_overtime {
35 | enable: 1
36 | time_factor: 0.3
37 | }
38 | behind_overtime {
39 | enable: 1
40 | act_threshold: 0.0
41 | time_factor: 0.3
42 | }
43 | time_control {
44 | enable: 1
45 | c_denom: 20
46 | c_maxply: 40
47 | reserved_time: 1.0
48 | }
49 |
--------------------------------------------------------------------------------
/etc/mcts_5gpu_notensorrt.conf:
--------------------------------------------------------------------------------
1 | num_eval_threads: 5
2 | num_search_threads: 24
3 | max_children_per_node: 64
4 | max_search_tree_size: 400000000
5 | timeout_ms_per_step: 30000
6 | max_simulations_per_step: 0
7 | eval_batch_size: 4
8 | eval_wait_batch_timeout_us: 100
9 | model_config {
10 | train_dir: "ckpt"
11 | }
12 | gpu_list: "0,1,2,3,4"
13 | c_puct: 2.5
14 | virtual_loss: 1.0
15 | enable_resign: 1
16 | v_resign: -0.9
17 | enable_dirichlet_noise: 0
18 | dirichlet_noise_alpha: 0.03
19 | dirichlet_noise_ratio: 0.25
20 | monitor_log_every_ms: 0
21 | get_best_move_mode: 0
22 | enable_background_search: 1
23 | enable_policy_temperature: 0
24 | policy_temperature: 0.67
25 | inherit_default_act: 1
26 | early_stop {
27 | enable: 1
28 | check_every_ms: 100
29 | sims_factor: 1.0
30 | sims_threshold: 2000
31 | }
32 | unstable_overtime {
33 | enable: 1
34 | time_factor: 0.3
35 | }
36 | behind_overtime {
37 | enable: 1
38 | act_threshold: 0.0
39 | time_factor: 0.3
40 | }
41 | time_control {
42 | enable: 1
43 | c_denom: 20
44 | c_maxply: 40
45 | reserved_time: 1.0
46 | }
47 |
--------------------------------------------------------------------------------
/etc/mcts_6gpu.conf:
--------------------------------------------------------------------------------
1 | num_eval_threads: 6
2 | num_search_threads: 28
3 | max_children_per_node: 64
4 | max_search_tree_size: 400000000
5 | timeout_ms_per_step: 30000
6 | max_simulations_per_step: 0
7 | eval_batch_size: 4
8 | eval_wait_batch_timeout_us: 100
9 | model_config {
10 | train_dir: "ckpt"
11 | enable_tensorrt: 1
12 | tensorrt_model_path: "zero.ckpt-20b-v1.FP32.PLAN"
13 | }
14 | gpu_list: "0,1,2,3,4,5"
15 | c_puct: 2.5
16 | virtual_loss: 1.0
17 | enable_resign: 1
18 | v_resign: -0.9
19 | enable_dirichlet_noise: 0
20 | dirichlet_noise_alpha: 0.03
21 | dirichlet_noise_ratio: 0.25
22 | monitor_log_every_ms: 0
23 | get_best_move_mode: 0
24 | enable_background_search: 1
25 | enable_policy_temperature: 0
26 | policy_temperature: 0.67
27 | inherit_default_act: 1
28 | early_stop {
29 | enable: 1
30 | check_every_ms: 100
31 | sims_factor: 1.0
32 | sims_threshold: 2000
33 | }
34 | unstable_overtime {
35 | enable: 1
36 | time_factor: 0.3
37 | }
38 | behind_overtime {
39 | enable: 1
40 | act_threshold: 0.0
41 | time_factor: 0.3
42 | }
43 | time_control {
44 | enable: 1
45 | c_denom: 20
46 | c_maxply: 40
47 | reserved_time: 1.0
48 | }
49 |
--------------------------------------------------------------------------------
/etc/mcts_6gpu_notensorrt.conf:
--------------------------------------------------------------------------------
1 | num_eval_threads: 6
2 | num_search_threads: 28
3 | max_children_per_node: 64
4 | max_search_tree_size: 400000000
5 | timeout_ms_per_step: 30000
6 | max_simulations_per_step: 0
7 | eval_batch_size: 4
8 | eval_wait_batch_timeout_us: 100
9 | model_config {
10 | train_dir: "ckpt"
11 | }
12 | gpu_list: "0,1,2,3,4,5"
13 | c_puct: 2.5
14 | virtual_loss: 1.0
15 | enable_resign: 1
16 | v_resign: -0.9
17 | enable_dirichlet_noise: 0
18 | dirichlet_noise_alpha: 0.03
19 | dirichlet_noise_ratio: 0.25
20 | monitor_log_every_ms: 0
21 | get_best_move_mode: 0
22 | enable_background_search: 1
23 | enable_policy_temperature: 0
24 | policy_temperature: 0.67
25 | inherit_default_act: 1
26 | early_stop {
27 | enable: 1
28 | check_every_ms: 100
29 | sims_factor: 1.0
30 | sims_threshold: 2000
31 | }
32 | unstable_overtime {
33 | enable: 1
34 | time_factor: 0.3
35 | }
36 | behind_overtime {
37 | enable: 1
38 | act_threshold: 0.0
39 | time_factor: 0.3
40 | }
41 | time_control {
42 | enable: 1
43 | c_denom: 20
44 | c_maxply: 40
45 | reserved_time: 1.0
46 | }
47 |
--------------------------------------------------------------------------------
/etc/mcts_7gpu.conf:
--------------------------------------------------------------------------------
1 | num_eval_threads: 7
2 | num_search_threads: 32
3 | max_children_per_node: 64
4 | max_search_tree_size: 400000000
5 | timeout_ms_per_step: 30000
6 | max_simulations_per_step: 0
7 | eval_batch_size: 4
8 | eval_wait_batch_timeout_us: 100
9 | model_config {
10 | train_dir: "ckpt"
11 | enable_tensorrt: 1
12 | tensorrt_model_path: "zero.ckpt-20b-v1.FP32.PLAN"
13 | }
14 | gpu_list: "0,1,2,3,4,5,6"
15 | c_puct: 2.5
16 | virtual_loss: 1.0
17 | enable_resign: 1
18 | v_resign: -0.9
19 | enable_dirichlet_noise: 0
20 | dirichlet_noise_alpha: 0.03
21 | dirichlet_noise_ratio: 0.25
22 | monitor_log_every_ms: 0
23 | get_best_move_mode: 0
24 | enable_background_search: 1
25 | enable_policy_temperature: 0
26 | policy_temperature: 0.67
27 | inherit_default_act: 1
28 | early_stop {
29 | enable: 1
30 | check_every_ms: 100
31 | sims_factor: 1.0
32 | sims_threshold: 2000
33 | }
34 | unstable_overtime {
35 | enable: 1
36 | time_factor: 0.3
37 | }
38 | behind_overtime {
39 | enable: 1
40 | act_threshold: 0.0
41 | time_factor: 0.3
42 | }
43 | time_control {
44 | enable: 1
45 | c_denom: 20
46 | c_maxply: 40
47 | reserved_time: 1.0
48 | }
49 |
--------------------------------------------------------------------------------
/etc/mcts_7gpu_notensorrt.conf:
--------------------------------------------------------------------------------
1 | num_eval_threads: 7
2 | num_search_threads: 32
3 | max_children_per_node: 64
4 | max_search_tree_size: 400000000
5 | timeout_ms_per_step: 30000
6 | max_simulations_per_step: 0
7 | eval_batch_size: 4
8 | eval_wait_batch_timeout_us: 100
9 | model_config {
10 | train_dir: "ckpt"
11 | }
12 | gpu_list: "0,1,2,3,4,5,6"
13 | c_puct: 2.5
14 | virtual_loss: 1.0
15 | enable_resign: 1
16 | v_resign: -0.9
17 | enable_dirichlet_noise: 0
18 | dirichlet_noise_alpha: 0.03
19 | dirichlet_noise_ratio: 0.25
20 | monitor_log_every_ms: 0
21 | get_best_move_mode: 0
22 | enable_background_search: 1
23 | enable_policy_temperature: 0
24 | policy_temperature: 0.67
25 | inherit_default_act: 1
26 | early_stop {
27 | enable: 1
28 | check_every_ms: 100
29 | sims_factor: 1.0
30 | sims_threshold: 2000
31 | }
32 | unstable_overtime {
33 | enable: 1
34 | time_factor: 0.3
35 | }
36 | behind_overtime {
37 | enable: 1
38 | act_threshold: 0.0
39 | time_factor: 0.3
40 | }
41 | time_control {
42 | enable: 1
43 | c_denom: 20
44 | c_maxply: 40
45 | reserved_time: 1.0
46 | }
47 |
--------------------------------------------------------------------------------
/etc/mcts_8gpu.conf:
--------------------------------------------------------------------------------
1 | num_eval_threads: 8
2 | num_search_threads: 36
3 | max_children_per_node: 64
4 | max_search_tree_size: 400000000
5 | timeout_ms_per_step: 30000
6 | max_simulations_per_step: 0
7 | eval_batch_size: 4
8 | eval_wait_batch_timeout_us: 100
9 | model_config {
10 | train_dir: "ckpt"
11 | enable_tensorrt: 1
12 | tensorrt_model_path: "zero.ckpt-20b-v1.FP32.PLAN"
13 | }
14 | gpu_list: "0,1,2,3,4,5,6,7"
15 | c_puct: 2.5
16 | virtual_loss: 1.0
17 | enable_resign: 1
18 | v_resign: -0.9
19 | enable_dirichlet_noise: 0
20 | dirichlet_noise_alpha: 0.03
21 | dirichlet_noise_ratio: 0.25
22 | monitor_log_every_ms: 0
23 | get_best_move_mode: 0
24 | enable_background_search: 1
25 | enable_policy_temperature: 0
26 | policy_temperature: 0.67
27 | inherit_default_act: 1
28 | early_stop {
29 | enable: 1
30 | check_every_ms: 100
31 | sims_factor: 1.0
32 | sims_threshold: 2000
33 | }
34 | unstable_overtime {
35 | enable: 1
36 | time_factor: 0.3
37 | }
38 | behind_overtime {
39 | enable: 1
40 | act_threshold: 0.0
41 | time_factor: 0.3
42 | }
43 | time_control {
44 | enable: 1
45 | c_denom: 20
46 | c_maxply: 40
47 | reserved_time: 1.0
48 | }
49 |
--------------------------------------------------------------------------------
/etc/mcts_8gpu_notensorrt.conf:
--------------------------------------------------------------------------------
1 | num_eval_threads: 8
2 | num_search_threads: 36
3 | max_children_per_node: 64
4 | max_search_tree_size: 400000000
5 | timeout_ms_per_step: 30000
6 | max_simulations_per_step: 0
7 | eval_batch_size: 4
8 | eval_wait_batch_timeout_us: 100
9 | model_config {
10 | train_dir: "ckpt"
11 | }
12 | gpu_list: "0,1,2,3,4,5,6,7"
13 | c_puct: 2.5
14 | virtual_loss: 1.0
15 | enable_resign: 1
16 | v_resign: -0.9
17 | enable_dirichlet_noise: 0
18 | dirichlet_noise_alpha: 0.03
19 | dirichlet_noise_ratio: 0.25
20 | monitor_log_every_ms: 0
21 | get_best_move_mode: 0
22 | enable_background_search: 1
23 | enable_policy_temperature: 0
24 | policy_temperature: 0.67
25 | inherit_default_act: 1
26 | early_stop {
27 | enable: 1
28 | check_every_ms: 100
29 | sims_factor: 1.0
30 | sims_threshold: 2000
31 | }
32 | unstable_overtime {
33 | enable: 1
34 | time_factor: 0.3
35 | }
36 | behind_overtime {
37 | enable: 1
38 | act_threshold: 0.0
39 | time_factor: 0.3
40 | }
41 | time_control {
42 | enable: 1
43 | c_denom: 20
44 | c_maxply: 40
45 | reserved_time: 1.0
46 | }
47 |
--------------------------------------------------------------------------------
/etc/mcts_async_dist.conf:
--------------------------------------------------------------------------------
1 | num_eval_threads: 64
2 | num_search_threads: 32
3 | enable_async: 1
4 | eval_task_queue_size: 128
5 | max_children_per_node: 64
6 | max_search_tree_size: 400000000
7 | timeout_ms_per_step: 30000
8 | max_simulations_per_step: 0
9 | eval_batch_size: 4
10 | eval_wait_batch_timeout_us: 100
11 | model_config {
12 | train_dir: "ckpt"
13 | enable_tensorrt: 1
14 | tensorrt_model_path: "zero.ckpt-20b-v1.FP32.PLAN"
15 | }
16 | enable_dist: 1
17 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
18 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
19 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
20 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
21 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
22 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
23 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
24 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
25 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
26 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
27 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
28 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
29 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
30 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
31 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
32 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
33 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
34 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
35 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
36 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
37 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
38 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
39 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
40 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
41 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
42 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
43 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
44 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
45 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
46 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
47 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
48 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
49 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
50 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
51 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
52 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
53 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
54 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
55 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
56 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
57 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
58 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
59 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
60 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
61 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
62 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
63 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
64 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
65 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
66 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
67 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
68 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
69 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
70 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
71 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
72 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
73 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
74 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
75 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
76 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
77 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
78 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
79 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
80 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port"
81 | dist_config {
82 | timeout_ms: 100
83 | enable_leaky_bucket: 1
84 | leaky_bucket_size: 3
85 | leaky_bucket_refill_period_ms: 5000
86 | }
87 | c_puct: 2.5
88 | virtual_loss: 0.5
89 | enable_resign: 1
90 | v_resign: -0.9
91 | enable_dirichlet_noise: 0
92 | dirichlet_noise_alpha: 0.03
93 | dirichlet_noise_ratio: 0.25
94 | monitor_log_every_ms: 0
95 | get_best_move_mode: 0
96 | enable_background_search: 1
97 | enable_policy_temperature: 0
98 | policy_temperature: 0.67
99 | inherit_default_act: 1
100 | early_stop {
101 | enable: 1
102 | check_every_ms: 100
103 | sims_factor: 1.0
104 | sims_threshold: 100000
105 | }
106 | unstable_overtime {
107 | enable: 1
108 | time_factor: 0.3
109 | }
110 | behind_overtime {
111 | enable: 1
112 | act_threshold: 0.0
113 | time_factor: 0.3
114 | }
115 | time_control {
116 | enable: 1
117 | c_denom: 20
118 | c_maxply: 40
119 | reserved_time: 1.0
120 | }
121 |
--------------------------------------------------------------------------------
/etc/mcts_cpu.conf:
--------------------------------------------------------------------------------
1 | num_eval_threads: 1
2 | num_search_threads: 8
3 | max_children_per_node: 64
4 | max_search_tree_size: 400000000
5 | timeout_ms_per_step: 30000
6 | max_simulations_per_step: 0
7 | eval_batch_size: 4
8 | eval_wait_batch_timeout_us: 100
9 | model_config {
10 | train_dir: "ckpt"
11 | }
12 | c_puct: 2.5
13 | virtual_loss: 1.0
14 | enable_resign: 1
15 | v_resign: -0.9
16 | enable_dirichlet_noise: 0
17 | dirichlet_noise_alpha: 0.03
18 | dirichlet_noise_ratio: 0.25
19 | monitor_log_every_ms: 0
20 | get_best_move_mode: 0
21 | enable_background_search: 1
22 | enable_policy_temperature: 0
23 | policy_temperature: 0.67
24 | inherit_default_act: 1
25 | early_stop {
26 | enable: 1
27 | check_every_ms: 100
28 | sims_factor: 1.0
29 | sims_threshold: 2000
30 | }
31 | unstable_overtime {
32 | enable: 1
33 | time_factor: 0.3
34 | }
35 | behind_overtime {
36 | enable: 1
37 | act_threshold: 0.0
38 | time_factor: 0.3
39 | }
40 | time_control {
41 | enable: 1
42 | c_denom: 20
43 | c_maxply: 40
44 | reserved_time: 1.0
45 | }
46 |
--------------------------------------------------------------------------------
/etc/mcts_cpu_grp.conf:
--------------------------------------------------------------------------------
1 | num_eval_threads: 1
2 | num_search_threads: 8
3 | max_children_per_node: 64
4 | max_search_tree_size: 400000000
5 | timeout_ms_per_step: 0
6 | max_simulations_per_step: 1600
7 | eval_batch_size: 4
8 | eval_wait_batch_timeout_us: 100
9 | model_config {
10 | train_dir: "ckpt"
11 | }
12 | c_puct: 2.5
13 | virtual_loss: 1.0
14 | enable_resign: 1
15 | v_resign: -0.9
16 | enable_dirichlet_noise: 0
17 | dirichlet_noise_alpha: 0.03
18 | dirichlet_noise_ratio: 0.25
19 | monitor_log_every_ms: 0
20 | get_best_move_mode: 0
21 | enable_background_search: 0
22 | enable_policy_temperature: 0
23 | policy_temperature: 0.67
24 | inherit_default_act: 1
25 | debugger {
26 | print_tree_depth: 20
27 | print_tree_width: 3
28 | }
29 | early_stop {
30 | enable: 0
31 | check_every_ms: 100
32 | sims_factor: 1.0
33 | sims_threshold: 2000
34 | }
35 | unstable_overtime {
36 | enable: 0
37 | time_factor: 0.3
38 | }
39 | behind_overtime {
40 | enable: 0
41 | act_threshold: 0.0
42 | time_factor: 0.3
43 | }
44 | time_control {
45 | enable: 0
46 | c_denom: 20
47 | c_maxply: 40
48 | reserved_time: 1.0
49 | }
50 |
--------------------------------------------------------------------------------
/etc/mcts_dist.conf:
--------------------------------------------------------------------------------
1 | num_eval_threads: 32
2 | num_search_threads: 132
3 | max_children_per_node: 64
4 | max_search_tree_size: 400000000
5 | timeout_ms_per_step: 30000
6 | max_simulations_per_step: 0
7 | eval_batch_size: 4
8 | eval_wait_batch_timeout_us: 100
9 | model_config {
10 | train_dir: "ckpt"
11 | enable_tensorrt: 1
12 | tensorrt_model_path: "zero.ckpt-20b-v1.FP32.PLAN"
13 | }
14 | enable_dist: 1
15 | dist_svr_addrs: "ip:port"
16 | dist_svr_addrs: "ip:port"
17 | dist_svr_addrs: "ip:port"
18 | dist_svr_addrs: "ip:port"
19 | dist_svr_addrs: "ip:port"
20 | dist_svr_addrs: "ip:port"
21 | dist_svr_addrs: "ip:port"
22 | dist_svr_addrs: "ip:port"
23 | dist_svr_addrs: "ip:port"
24 | dist_svr_addrs: "ip:port"
25 | dist_svr_addrs: "ip:port"
26 | dist_svr_addrs: "ip:port"
27 | dist_svr_addrs: "ip:port"
28 | dist_svr_addrs: "ip:port"
29 | dist_svr_addrs: "ip:port"
30 | dist_svr_addrs: "ip:port"
31 | dist_svr_addrs: "ip:port"
32 | dist_svr_addrs: "ip:port"
33 | dist_svr_addrs: "ip:port"
34 | dist_svr_addrs: "ip:port"
35 | dist_svr_addrs: "ip:port"
36 | dist_svr_addrs: "ip:port"
37 | dist_svr_addrs: "ip:port"
38 | dist_svr_addrs: "ip:port"
39 | dist_svr_addrs: "ip:port"
40 | dist_svr_addrs: "ip:port"
41 | dist_svr_addrs: "ip:port"
42 | dist_svr_addrs: "ip:port"
43 | dist_svr_addrs: "ip:port"
44 | dist_svr_addrs: "ip:port"
45 | dist_svr_addrs: "ip:port"
46 | dist_svr_addrs: "ip:port"
47 | dist_config {
48 | timeout_ms: 100
49 | enable_leaky_bucket: 1
50 | leaky_bucket_size: 3
51 | leaky_bucket_refill_period_ms: 5000
52 | }
53 | c_puct: 2.5
54 | virtual_loss: 1.0
55 | enable_resign: 1
56 | v_resign: -0.9
57 | enable_dirichlet_noise: 0
58 | dirichlet_noise_alpha: 0.03
59 | dirichlet_noise_ratio: 0.25
60 | monitor_log_every_ms: 0
61 | get_best_move_mode: 0
62 | enable_background_search: 1
63 | enable_policy_temperature: 0
64 | policy_temperature: 0.67
65 | inherit_default_act: 1
66 | early_stop {
67 | enable: 1
68 | check_every_ms: 100
69 | sims_factor: 1.0
70 | sims_threshold: 10000
71 | }
72 | unstable_overtime {
73 | enable: 1
74 | time_factor: 0.3
75 | }
76 | behind_overtime {
77 | enable: 1
78 | act_threshold: 0.0
79 | time_factor: 0.3
80 | }
81 | time_control {
82 | enable: 1
83 | c_denom: 20
84 | c_maxply: 40
85 | reserved_time: 1.0
86 | }
87 |
--------------------------------------------------------------------------------
/images/logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent/PhoenixGo/fbf67f9aec42531bff9569c44b85eb4c3f37b7be/images/logo.jpg
--------------------------------------------------------------------------------
/mcts/BUILD:
--------------------------------------------------------------------------------
1 | load("//:rules.bzl", "cc_proto_library", "tf_cc_binary")
2 |
3 | tf_cc_binary(
4 | name = "mcts_main",
5 | srcs = ["mcts_main.cc"],
6 | deps = [
7 | ":mcts_engine",
8 | "@boost//:asio",
9 | ],
10 | )
11 |
12 | tf_cc_binary(
13 | name = "debug_tool",
14 | srcs = [
15 | "debug_tool.cc",
16 | "mcts_config.h",
17 | "mcts_config.cc",
18 | ],
19 | deps = [
20 | ":mcts_config_cc_proto",
21 | "//common:go_comm",
22 | "//common:go_state",
23 | "//common:timer",
24 | "//model:zero_model",
25 | "//model:trt_zero_model",
26 | ],
27 | )
28 |
29 | cc_library(
30 | name = "mcts_engine",
31 | srcs = [
32 | "mcts_engine.cc",
33 | "mcts_monitor.cc",
34 | "mcts_debugger.cc",
35 | "byo_yomi_timer.cc",
36 | ],
37 | hdrs = [
38 | "mcts_engine.h",
39 | "mcts_monitor.h",
40 | "mcts_debugger.h",
41 | "byo_yomi_timer.h",
42 | ],
43 | deps = [
44 | ":mcts_config",
45 | "//common:go_comm",
46 | "//common:go_state",
47 | "//common:task_queue",
48 | "//common:wait_group",
49 | "//common:thread_conductor",
50 | "//common:str_utils",
51 | "//common:timer",
52 | "//model:zero_model",
53 | "//model:trt_zero_model",
54 | "//dist:dist_zero_model_client",
55 | "//dist:async_dist_zero_model_client",
56 | "@com_github_google_glog//:glog",
57 | ],
58 | visibility = ["//visibility:public"],
59 | )
60 |
61 | cc_library(
62 | name = "mcts_config",
63 | srcs = ["mcts_config.cc"],
64 | hdrs = ["mcts_config.h"],
65 | deps = [
66 | ":mcts_config_cc_proto",
67 | "@com_github_google_glog//:glog",
68 | ],
69 | visibility = ["//visibility:public"],
70 | )
71 |
72 | cc_proto_library(
73 | name = "mcts_config_cc_proto",
74 | srcs = ["mcts_config.proto"],
75 | deps = [
76 | "//model:model_config_cc_proto",
77 | "//dist:dist_config_cc_proto",
78 | ],
79 | visibility = ["//visibility:public"],
80 | )
81 |
--------------------------------------------------------------------------------
/mcts/byo_yomi_timer.cc:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #include "byo_yomi_timer.h"
19 |
20 | #include <algorithm>
21 |
22 | ByoYomiTimer::ByoYomiTimer()
23 | : m_enable(false),
24 | m_remain_time{0.0f, 0.0f},
25 | m_byo_yomi_time(0.0f),
26 | m_curr_player(GoComm::BLACK)
27 | {
28 | }
29 |
30 | void ByoYomiTimer::Set(float main_time, float byo_yomi_time)
31 | {
32 | m_enable = true;
33 | m_remain_time[0] = m_remain_time[1] = main_time;
34 | m_byo_yomi_time = byo_yomi_time;
35 | m_timer.Reset();
36 | }
37 |
38 | void ByoYomiTimer::Reset()
39 | {
40 | m_enable = false;
41 | m_remain_time[0] = m_remain_time[1] = 0.0f;
42 | m_byo_yomi_time = 0.0f;
43 | m_curr_player = GoComm::BLACK;
44 | }
45 |
46 | bool ByoYomiTimer::IsEnable()
47 | {
48 | return m_enable;
49 | }
50 |
51 | void ByoYomiTimer::HandOff()
52 | {
53 | m_remain_time[m_curr_player == GoComm::BLACK ? 0 : 1] -= m_timer.fsec();
54 | m_curr_player = m_curr_player == GoComm::BLACK ? GoComm::WHITE : GoComm::BLACK;
55 | m_timer.Reset();
56 | }
57 |
58 | void ByoYomiTimer::SetRemainTime(GoStoneColor color, float time)
59 | {
60 | m_remain_time[color == GoComm::BLACK ? 0 : 1] = time;
61 | if (color == m_curr_player) {
62 | m_timer.Reset();
63 | }
64 | }
65 |
66 | float ByoYomiTimer::GetRemainTime(GoStoneColor color)
67 | {
68 | float remain_time = m_remain_time[color == GoComm::BLACK ? 0 : 1];
69 | if (color == m_curr_player) {
70 | remain_time -= m_timer.fsec();
71 | }
72 | return std::max(remain_time, 0.0f);
73 | }
74 |
75 | float ByoYomiTimer::GetByoYomiTime()
76 | {
77 | return m_byo_yomi_time;
78 | }
79 |
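A minimal usage sketch of the ByoYomiTimer API above; the time values and the standalone main() are illustrative assumptions, not code from this repository:

#include "byo_yomi_timer.h"

int main()
{
    ByoYomiTimer timer;
    timer.Set(600.0f, 30.0f);                      // e.g. 10 min main time, 30s byo-yomi
    // ... black thinks and plays a move ...
    timer.HandOff();                               // charge elapsed time to black, switch to white
    float white_remain = timer.GetRemainTime(GoComm::WHITE);
    return white_remain <= 600.0f ? 0 : 1;         // remaining time never exceeds the main time
}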
--------------------------------------------------------------------------------
/mcts/byo_yomi_timer.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #pragma once
19 |
20 | #include "common/go_comm.h"
21 | #include "common/timer.h"
22 |
23 | class ByoYomiTimer
24 | {
25 | public:
26 | ByoYomiTimer();
27 | void Set(float main_time, float byo_yomi_time);
28 | void Reset();
29 | bool IsEnable();
30 | void HandOff();
31 | void SetRemainTime(GoStoneColor color, float time);
32 | float GetRemainTime(GoStoneColor color);
33 | float GetByoYomiTime();
34 |
35 | private:
36 | bool m_enable;
37 | float m_remain_time[2];
38 | float m_byo_yomi_time;
39 | GoStoneColor m_curr_player;
40 | Timer m_timer;
41 | };
42 |
--------------------------------------------------------------------------------
/mcts/debug_tool.cc:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #include <gflags/gflags.h>
19 | #include <glog/logging.h>
20 |
21 | #include "model/zero_model.h"
22 | #include "model/trt_zero_model.h"
23 | #include "common/go_state.h"
24 | #include "common/timer.h"
25 |
26 | #include "mcts_config.h"
27 |
28 | DEFINE_string(config_path, "", "Path of mcts config file.");
29 | DEFINE_string(init_moves, "", "Initialize Go board with init_moves.");
30 | DEFINE_int32(gpu, 0, "gpu used by neural network.");
31 | DEFINE_int32(intra_op_parallelism_threads, 0, "Number of tf's intra op threads");
32 | DEFINE_int32(inter_op_parallelism_threads, 0, "Number of tf's inter op threads");
33 | DEFINE_int32(transform, 0, "Transform features.");
34 | DEFINE_int32(num_iterations, 1, "How many iterations should run.");
35 | DEFINE_int32(batch_size, 1, "Batch size of each iterations.");
36 |
37 | void InitMove(GoState& board, std::string& moves)
38 | {
39 | for (size_t i = 0; i < moves.size(); i += 3) {
40 | int x = -1, y = -1;
41 | if (moves[i] != 'z') {
42 | x = moves[i] - 'a';
43 | y = moves[i + 1] - 'a';
44 | }
45 | board.Move(x, y);
46 | }
47 | }
48 |
49 | void TransformCoord(GoCoordId &x, GoCoordId &y, int mode, bool reverse)
50 | {
51 | if (reverse) {
52 | if (mode & 4) std::swap(x, y);
53 | if (mode & 2) y = GoComm::BORDER_SIZE - y - 1;
54 | if (mode & 1) x = GoComm::BORDER_SIZE - x - 1;
55 | } else {
56 | if (mode & 1) x = GoComm::BORDER_SIZE - x - 1;
57 | if (mode & 2) y = GoComm::BORDER_SIZE - y - 1;
58 | if (mode & 4) std::swap(x, y);
59 | }
60 | }
61 |
62 | template <class T>
63 | void TransformFeatures(T &features, int mode, bool reverse)
64 | {
65 | T ret(features.size());
66 | int depth = features.size() / GoComm::GOBOARD_SIZE;
67 | for (int i = 0; i < GoComm::GOBOARD_SIZE; ++i) {
68 | GoCoordId x, y;
69 | GoFunction::IdToCoord(i, x, y);
70 | TransformCoord(x, y, mode, reverse);
71 | int j = GoFunction::CoordToId(x, y);
72 | for (int k = 0; k < depth; ++k) {
73 | ret[i * depth + k] = features[j * depth + k];
74 | }
75 | }
76 | features = std::move(ret);
77 | }
78 |
79 | int main(int argc, char* argv[])
80 | {
81 | google::ParseCommandLineFlags(&argc, &argv, true);
82 | google::InitGoogleLogging(argv[0]);
83 | google::InstallFailureSignalHandler();
84 |
85 | auto config = LoadConfig(FLAGS_config_path);
86 | CHECK(config != nullptr) << "Load mcts config file '" << FLAGS_config_path << "' failed";
87 |
88 | if (FLAGS_intra_op_parallelism_threads > 0) {
89 | config->mutable_model_config()->set_intra_op_parallelism_threads(FLAGS_intra_op_parallelism_threads);
90 | }
91 |
92 | if (FLAGS_inter_op_parallelism_threads > 0) {
93 | config->mutable_model_config()->set_inter_op_parallelism_threads(FLAGS_inter_op_parallelism_threads);
94 | }
95 |
96 | if (config->model_config().enable_mkl()) {
97 | ZeroModel::SetMKLEnv(config->model_config());
98 | }
99 |
100 | std::unique_ptr<ZeroModelBase> model(new ZeroModel(FLAGS_gpu));
101 | #if HAVE_TENSORRT
102 | if (config->model_config().enable_tensorrt()) {
103 | model.reset(new TrtZeroModel(FLAGS_gpu));
104 | }
105 | #endif
106 | CHECK_EQ(model->Init(config->model_config()), 0) << "Model Init Fail, config path " << FLAGS_config_path << ", gpu " << FLAGS_gpu;
107 |
108 | GoState board;
109 | InitMove(board, FLAGS_init_moves);
110 |
111 | auto features = board.GetFeature();
112 | TransformFeatures(features, FLAGS_transform, false);
113 | std::vector<std::vector<bool>> inputs(FLAGS_batch_size, features);
114 |
115 | std::vector<std::vector<float>> policies;
116 | std::vector<float> values;
117 |
118 | Timer timer;
119 | for (int i = 1; i <= FLAGS_num_iterations; ++i) {
120 | CHECK_EQ(model->Forward(inputs, policies, values), 0) << "Forward fail";
121 | LOG_IF(INFO, i % 100 == 0) << i << "/" << FLAGS_num_iterations << " iterations";
122 | }
123 | float avg_cost_ms = timer.fms() / FLAGS_num_iterations;
124 | LOG(INFO) << "Cost " << avg_cost_ms << "ms per iteration";
125 |
126 | std::vector<float>& policy = policies[0];
127 | TransformFeatures(policy, FLAGS_transform, true);
128 | float value = values[0];
129 | board.ShowBoard();
130 | for (int i = 0; i < GoComm::BORDER_SIZE; ++i) {
131 | for (int j = 0; j < GoComm::BORDER_SIZE; ++j) {
132 | printf("%.4f ", policy[i * GoComm::BORDER_SIZE + j]);
133 | }
134 | puts("");
135 | }
136 | printf("Value %.4f Pass %.4f\n", value, policy[361]);
137 | }
138 |
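The --transform flag above selects one of the eight board symmetries: bit 1 mirrors x, bit 2 mirrors y, bit 4 transposes, and reverse=true undoes the mapping by applying the inverse steps in the opposite order. A small standalone check of that round-trip (a sketch that hard-codes a 19x19 board instead of using GoComm::BORDER_SIZE):

#include <cassert>
#include <utility>

static const int N = 19;  // stands in for GoComm::BORDER_SIZE in this sketch

static void Transform(int &x, int &y, int mode, bool reverse)
{
    if (reverse) {
        if (mode & 4) std::swap(x, y);
        if (mode & 2) y = N - y - 1;
        if (mode & 1) x = N - x - 1;
    } else {
        if (mode & 1) x = N - x - 1;
        if (mode & 2) y = N - y - 1;
        if (mode & 4) std::swap(x, y);
    }
}

int main()
{
    for (int mode = 0; mode < 8; ++mode) {
        for (int x = 0; x < N; ++x) {
            for (int y = 0; y < N; ++y) {
                int tx = x, ty = y;
                Transform(tx, ty, mode, false);
                Transform(tx, ty, mode, true);   // inverse: same involutions, reverse order
                assert(tx == x && ty == y);
            }
        }
    }
    return 0;
}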
--------------------------------------------------------------------------------
/mcts/mcts_config.cc:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #include "mcts_config.h"
19 |
20 | #include <fstream>
21 | #include <sstream>
22 | #include <google/protobuf/text_format.h>
23 | #include <glog/logging.h>
24 |
25 | std::unique_ptr<MCTSConfig> LoadConfig(const char *config_path)
26 | {
27 | auto config = std::unique_ptr<MCTSConfig>(new MCTSConfig);
28 | std::ostringstream conf_ss;
29 | if (!(conf_ss << std::ifstream(config_path).rdbuf())) {
30 | PLOG(ERROR) << "read config file " << config_path << " error";
31 | return nullptr;
32 | }
33 | if (!google::protobuf::TextFormat::ParseFromString(conf_ss.str(), config.get())) {
34 | LOG(ERROR) << "parse config file " << config_path << " error! buf=" << conf_ss.str();
35 | return nullptr;
36 | }
37 | return config;
38 | }
39 |
40 | std::unique_ptr<MCTSConfig> LoadConfig(const std::string &config_path)
41 | {
42 | return LoadConfig(config_path.c_str());
43 | }
44 |
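A brief sketch of how the LoadConfig helper above is typically called; the config path and the field override are illustrative assumptions, mirroring what mcts/debug_tool.cc does with the parsed message:

#include <glog/logging.h>

#include "mcts_config.h"

int main()
{
    // The path is only an example; any of the files under etc/ parses the same way.
    auto config = LoadConfig("etc/mcts_1gpu.conf");
    CHECK(config != nullptr) << "loading config failed";

    // Fields can be adjusted after parsing, e.g. pointing the model at a checkpoint dir.
    config->mutable_model_config()->set_train_dir("ckpt");
    LOG(INFO) << "timeout_ms_per_step = " << config->timeout_ms_per_step();
    return 0;
}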
--------------------------------------------------------------------------------
/mcts/mcts_config.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #pragma once
19 |
20 | #include <memory>
21 | #include <string>
22 |
23 | #include "mcts/mcts_config.pb.h"
24 |
25 | std::unique_ptr<MCTSConfig> LoadConfig(const char *config_path);
26 | std::unique_ptr<MCTSConfig> LoadConfig(const std::string &config_path);
27 |
--------------------------------------------------------------------------------
/mcts/mcts_config.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | import "model/model_config.proto";
4 | import "dist/dist_config.proto";
5 |
6 | message MCTSConfig {
7 | int32 num_eval_threads = 1;
8 | int32 num_search_threads = 2;
9 | int32 max_search_tree_size = 3;
10 | int32 max_children_per_node = 4;
11 | int32 timeout_ms_per_step = 5;
12 | int32 max_simulations_per_step = 6;
13 | int32 eval_batch_size = 7;
14 | int32 eval_wait_batch_timeout_us = 8;
15 | ModelConfig model_config = 9;
16 | string gpu_list = 10;
17 | bool enable_dist = 11;
18 | repeated string dist_svr_addrs = 12;
19 | DistConfig dist_config = 13;
20 | float genmove_temperature = 14;
21 | float c_puct = 15;
22 | float virtual_loss = 16;
23 | int32 virtual_loss_mode = 17; // 0 - act+ucb, 1 - act, 2 - ucb, 3 - none
24 | bool enable_resign = 18;
25 | float v_resign = 19;
26 | int32 resign_mode = 20;
27 | bool enable_dirichlet_noise = 21;
28 | float dirichlet_noise_alpha = 22;
29 | float dirichlet_noise_ratio = 23;
30 | int32 monitor_log_every_ms = 24;
31 | int32 get_best_move_mode = 25; // 0 by mcts, 1 by policy, 2 by value
32 | bool enable_background_search = 26;
33 | bool enable_policy_temperature = 27;
34 | float policy_temperature = 28;
35 | bool disable_transform = 29;
36 | bool disable_pass = 30;
37 | float default_act = 31;
38 | bool inherit_default_act = 32;
39 | float inherit_default_act_factor = 33;
40 | bool clear_search_tree_per_move = 34;
41 |
42 | // async
43 | bool enable_async = 51;
44 | int32 eval_task_queue_size = 52;
45 |
46 | message DebuggerConfig {
47 | int32 print_tree_depth = 1;
48 | int32 print_tree_width = 2;
49 | };
50 | DebuggerConfig debugger = 60;
51 |
52 | // rules for foxwq
53 | bool disable_double_pass_scoring = 81;
54 | bool disable_positional_superko = 82;
55 | int32 max_gen_passes = 83;
56 | bool enable_pass_pass = 84;
57 |
58 | message EarlyStopConfig {
59 | bool enable = 1;
60 | int32 check_every_ms = 2;
61 | float sims_factor = 3;
62 | int32 sims_threshold = 4;
63 | };
64 | EarlyStopConfig early_stop = 91;
65 |
66 | message UnstableOvertimeConfig {
67 | bool enable = 1;
68 | float time_factor = 2;
69 | };
70 | UnstableOvertimeConfig unstable_overtime = 92;
71 |
72 | message BehindOvertimeConfig {
73 | bool enable = 1;
74 | float act_threshold = 2;
75 | float time_factor = 3;
76 | };
77 | BehindOvertimeConfig behind_overtime = 93;
78 |
79 | message TimeControlConfig {
80 | bool enable = 1;
81 | float main_time = 2;
82 | float byo_yomi_time = 3;
83 | int32 c_denom = 4;
84 | int32 c_maxply = 5;
85 | float reserved_time = 6;
86 | float min_time = 7;
87 | int32 byo_yomi_after = 8;
88 | };
89 | TimeControlConfig time_control = 94;
90 | }
91 |
--------------------------------------------------------------------------------
/mcts/mcts_debugger.cc:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #include "mcts_debugger.h"
19 |
20 | #include <algorithm>
21 | #include <numeric>
22 | #include <queue>
23 | #include <string>
24 | #include <tuple>
25 | #include <vector>
26 |
27 | #include <glog/logging.h>
28 |
29 | #include "common/go_comm.h"
30 |
31 | #include "mcts_engine.h"
32 |
33 | MCTSDebugger::MCTSDebugger(MCTSEngine *engine)
34 | : m_engine(engine)
35 | {
36 | }
37 |
38 | void MCTSDebugger::Debug() // call before move
39 | {
40 | if (VLOG_IS_ON(1)) {
41 | int ith = m_engine->m_num_moves + 1;
42 | std::string ith_str = std::to_string(ith) + "th move(" + "wb"[ith&1] + ")";
43 | VLOG(1) << "========== debug info for " << ith_str << " begin ==========";
44 | VLOG(1) << "main move path: " << GetMainMovePath(0);
45 | VLOG(1) << "second move path: " << GetMainMovePath(1);
46 | VLOG(1) << "third move path: " << GetMainMovePath(2);
47 | int depth = m_engine->GetConfig().debugger().print_tree_depth();
48 | int width = m_engine->GetConfig().debugger().print_tree_width();
49 | PrintTree(depth ? depth : 1, width ? width : 10);
50 | VLOG(1) << "model global step: " << m_engine->m_model_global_step;
51 | VLOG(1) << "========== debug info for " << ith_str << " end ==========";
52 | }
53 | }
54 |
55 | std::string MCTSDebugger::GetDebugStr()
56 | {
57 | TreeNode *root = m_engine->m_root;
58 | int ith = m_engine->m_num_moves;
59 | std::string ith_str = std::to_string(ith) + "th move(" + "wb"[ith&1] + ")";
60 | float root_action = (float)root->total_action / k_action_value_base / root->visit_count;
61 | std::string debug_str =
62 | ith_str + ": " + GoFunction::IdToStr(root->move) +
63 | ", winrate=" + std::to_string((root_action + 1) * 50) + "%" +
64 | ", N=" + std::to_string(root->visit_count) +
65 | ", Q=" + std::to_string(root_action) +
66 | ", p=" + std::to_string(root->prior_prob) +
67 | ", v=" + std::to_string(root->value);
68 | if (m_engine->m_simulation_counter > 0) {
69 | debug_str +=
70 | ", cost " + std::to_string(m_engine->m_search_timer.fms()) + "ms" +
71 | ", sims=" + std::to_string(m_engine->m_simulation_counter) +
72 | ", height=" + std::to_string(m_engine->m_monitor.MaxSearchTreeHeight()) +
73 | ", avg_height=" + std::to_string(m_engine->m_monitor.AvgSearchTreeHeight());
74 | }
75 | debug_str += ", global_step=" + std::to_string(m_engine->m_model_global_step);
76 | return debug_str;
77 | }
78 |
79 | std::string MCTSDebugger::GetLastMoveDebugStr()
80 | {
81 | return m_last_move_debug_str;
82 | }
83 |
84 | void MCTSDebugger::UpdateLastMoveDebugStr()
85 | {
86 | m_last_move_debug_str = GetDebugStr();
87 | }
88 |
89 | std::string MCTSDebugger::GetMainMovePath(int rank)
90 | {
91 | std::string moves;
92 | TreeNode *node = m_engine->m_root;
93 | while (node->expand_state == k_expanded && node->ch_len > rank) {
94 | TreeNode *ch = node->ch;
95 | std::vector<int> idx(node->ch_len);
96 | std::iota(idx.begin(), idx.end(), 0);
97 | std::nth_element(idx.begin(), idx.begin() + rank, idx.end(),
98 | [ch](int i, int j) { return ch[i].visit_count > ch[j].visit_count; });
99 | TreeNode *best_ch = &ch[idx[rank]];
100 | if (moves.size()) moves += ",";
101 | moves += GoFunction::IdToStr(best_ch->move);
102 | char buf[100];
103 | snprintf(buf, sizeof(buf), "(%d,%.2f,%.2f,%.2f)",
104 | best_ch->visit_count.load(),
105 | (float)best_ch->total_action / k_action_value_base / best_ch->visit_count,
106 | best_ch->prior_prob.load(),
107 | best_ch->value.load());
108 | moves += buf;
109 | node = best_ch;
110 | rank = 0;
111 | }
112 | return moves;
113 | }
114 |
115 | void MCTSDebugger::PrintTree(int depth, int topk, const std::string &prefix)
116 | {
117 | TreeNode *root = m_engine->m_root;
118 | std::queue<std::tuple<TreeNode*, int>> que;
119 | que.emplace(root, 1);
120 | while (!que.empty()) {
121 | TreeNode *node; int dep;
122 | std::tie(node, dep) = que.front();
123 | que.pop();
124 |
125 | TreeNode *ch = node->ch;
127 | std::vector<int> idx(node->ch_len);
127 | std::iota(idx.begin(), idx.end(), 0);
128 | std::sort(idx.begin(), idx.end(), [ch](int i, int j) { return ch[i].visit_count > ch[j].visit_count; });
129 | if (topk < (int)idx.size()) idx.erase(idx.begin() + topk, idx.end());
130 |
131 | for (int i: idx) {
132 | if (ch[i].visit_count == 0) {
133 | break;
134 | }
135 | std::string moves;
136 | for (TreeNode *t = &ch[i]; t != root; t = t->fa) {
137 | if (moves.size()) moves = "," + moves;
138 | moves = GoFunction::IdToStr(t->move) + moves;
139 | }
140 | VLOG(1) << prefix << moves
141 | << ": N=" << ch[i].visit_count
142 | << ", W=" << (float)ch[i].total_action / k_action_value_base
143 | << ", Q=" << (float)ch[i].total_action / k_action_value_base / ch[i].visit_count
144 | << ", p=" << ch[i].prior_prob
145 | << ", v=" << ch[i].value;
146 | if (dep < depth) que.emplace(&ch[i], dep + 1);
147 | }
148 | }
149 | }
150 |
--------------------------------------------------------------------------------
/mcts/mcts_debugger.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #pragma once
19 |
20 | #include <string>
21 |
22 | class MCTSEngine;
23 | class MCTSDebugger
24 | {
25 | public:
26 | MCTSDebugger(MCTSEngine *engine);
27 |
28 | void Debug(); // call before move
29 |
30 | std::string GetDebugStr();
31 | std::string GetLastMoveDebugStr(); // call after move
32 | void UpdateLastMoveDebugStr();
33 |
34 | std::string GetMainMovePath(int rank);
35 | void PrintTree(int depth, int topk, const std::string &prefix = "");
36 |
37 | private:
38 | MCTSEngine *m_engine;
39 | std::string m_last_move_debug_str;
40 | };
41 |
--------------------------------------------------------------------------------
/mcts/mcts_engine.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #pragma once
19 |
20 | #include <atomic>
21 | #include <string>
22 | #include <thread>
23 | #include <vector>
24 |
25 | #include "common/go_comm.h"
26 | #include "common/go_state.h"
27 | #include "common/task_queue.h"
28 | #include "common/wait_group.h"
29 | #include "common/thread_conductor.h"
30 | #include "common/timer.h"
31 | #include "model/zero_model_base.h"
32 |
33 | #include "mcts_config.h"
34 | #include "mcts_monitor.h"
35 | #include "mcts_debugger.h"
36 | #include "byo_yomi_timer.h"
37 |
38 | struct TreeNode
39 | {
40 | std::atomic<TreeNode*> fa;
41 | std::atomic<TreeNode*> ch; // child nodes must allocate contiguously
42 | std::atomic<int> ch_len;
43 | std::atomic<int> size;
44 | std::atomic<int> expand_state;
45 |
46 | std::atomic<int> move;
47 | std::atomic<int> visit_count;
48 | std::atomic<int> virtual_loss_count;
49 | std::atomic<int64_t> total_action;
50 | std::atomic<float> prior_prob;
51 | std::atomic<float> value;
52 | };
53 |
54 | const int64_t k_action_value_base = 1 << 16;
55 | const int k_unexpanded = 0;
56 | const int k_expanding = 1;
57 | const int k_expanded = 2;
58 |
59 | typedef std::function<void(int, std::vector<float>, float)> EvalCallback;
60 |
61 | struct EvalTask
62 | {
63 | std::vector<bool> features;
64 | EvalCallback callback;
65 | };
66 |
67 | class MCTSEngine
68 | {
69 | public:
70 | MCTSEngine(const MCTSConfig &config);
71 | ~MCTSEngine();
72 |
73 | void Reset(const std::string &init_moves="");
74 | void Move(GoCoordId x, GoCoordId y);
75 | void GenMove(GoCoordId &x, GoCoordId &y);
76 | void GenMove(GoCoordId &x, GoCoordId &y, std::vector<int> &visit_count, float &v_resign);
77 | bool Undo();
78 | const GoState &GetBoard();
79 | MCTSConfig &GetConfig();
80 | void SetPendingConfig(std::unique_ptr<MCTSConfig> config);
81 | MCTSDebugger &GetDebugger();
82 | int GetModelGlobalStep();
83 | ByoYomiTimer &GetByoYomiTimer();
84 |
85 | private:
86 | TreeNode *InitNode(TreeNode *node, TreeNode *fa, int move, float prior_prob);
87 | TreeNode *FindChild(TreeNode *node, int move);
88 |
89 | void Eval(const GoState &board, EvalCallback callback);
90 | void EvalRoutine(std::unique_ptr<ZeroModelBase> model);
91 |
92 | TreeNode *Select(GoState &board);
93 | TreeNode *SelectChild(TreeNode *node);
94 | int Expand(TreeNode *node, GoState &board, const std::vector<float> &policy);
95 | void Backup(TreeNode *node, float value, int ch_len);
96 | void UndoVirtualLoss(TreeNode *node);
97 |
98 | bool CheckEarlyStop(int64_t timeout_us);
99 | bool CheckUnstable();
100 | bool CheckBehind();
101 |
102 | int64_t GetSearchTimeoutUs();
103 | int64_t GetSearchOvertimeUs(int64_t timeout_us);
104 |
105 | void Search();
106 | void SearchWait(int64_t timeout_us, bool is_overtime);
107 | void SearchResume();
108 | void SearchPause();
109 | void SearchRoutine();
110 |
111 | void ChangeRoot(TreeNode *node);
112 | void InitRoot();
113 |
114 | void DeleteRoutine();
115 | int DeleteTree(TreeNode *node);
116 |
117 | int GetBestMove(float &v_resign);
118 | int GetSamplingMove(float temperature);
119 | std::vector<int> GetVisitCount(TreeNode *node);
120 |
121 | template <class T>
122 | void TransformFeatures(T &features, int mode, bool reverse = false);
123 | void TransformCoord(GoCoordId &x, GoCoordId &y, int mode, bool reverse = false);
124 |
125 | void ApplyTemperature(std::vector<float> &probs, float temperature);
126 |
127 | void TTableUpdate(uint64_t hash, int64_t value);
128 | void TTableSync(TreeNode *node);
129 | void TTableClear();
130 |
131 | void EvalCacheInsert(uint64_t hash, const std::vector<float> policy, float value);
132 | bool EvalCacheFind(uint64_t hash, std::vector<float> &policy, float &value);
133 |
134 | bool IsPassDisable();
135 |
136 | private:
137 | MCTSConfig m_config;
138 | std::unique_ptr<MCTSConfig> m_pending_config;
139 |
140 | TreeNode *m_root;
141 | GoState m_board;
142 |
143 | std::vector<std::thread> m_eval_threads;
144 | TaskQueue<EvalTask> m_eval_task_queue;
145 | WaitGroup m_eval_threads_init_wg;
146 | WaitGroup m_eval_tasks_wg;
147 | std::atomic<int> m_model_global_step;
148 |
149 | std::vector<std::thread> m_search_threads;
150 | ThreadConductor m_search_threads_conductor;
151 | bool m_is_searching;
152 |
153 | std::thread m_delete_thread;
154 | TaskQueue<TreeNode*> m_delete_queue;
155 |
156 | std::atomic<int> m_simulation_counter;
157 | Timer m_search_timer;
158 |
159 | int m_num_moves;
160 | std::string m_moves_str;
161 |
162 | int m_gen_passes;
163 |
164 | ByoYomiTimer m_byo_yomi_timer;
165 |
166 | MCTSMonitor m_monitor;
167 | MCTSDebugger m_debugger;
168 |
169 | friend class MCTSMonitor;
170 | friend class MCTSDebugger;
171 | };
172 |
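For reference, c_puct, virtual_loss, total_action and k_action_value_base parametrize an AlphaGo-Zero-style PUCT selection rule. The sketch below scores one child with that standard formula over the TreeNode fields above; it is illustrative only, not the engine's actual SelectChild implementation:

#include <cmath>

#include "mcts_engine.h"

// Standard PUCT score for one child: Q (mean value, corrected for pending virtual
// losses) plus an exploration bonus driven by the prior and the parent's visit count.
// Parameter names mirror the c_puct / virtual_loss config fields.
inline float PuctScore(const TreeNode &child, int parent_visit_count,
                       float c_puct, float virtual_loss)
{
    int n = child.visit_count + child.virtual_loss_count;
    float q = 0.0f;
    if (n > 0) {
        q = ((float)child.total_action / k_action_value_base
             - virtual_loss * child.virtual_loss_count) / n;
    }
    float u = c_puct * child.prior_prob *
              std::sqrt((float)parent_visit_count) / (1 + n);
    return q + u;
}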
--------------------------------------------------------------------------------
/mcts/mcts_monitor.cc:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #include "mcts_monitor.h"
19 |
20 | #include <glog/logging.h>
21 |
22 | #include "mcts_engine.h"
23 |
24 | int MCTSMonitor::g_next_monitor_id = 0;
25 | MCTSMonitor *MCTSMonitor::g_global_monitors[k_max_monitor_instances];
26 | std::mutex MCTSMonitor::g_global_monitors_mutex;
27 | thread_local std::shared_ptr MCTSMonitor::g_local_monitors[k_max_monitor_instances];
28 |
29 | MCTSMonitor::MCTSMonitor(MCTSEngine *engine)
30 | : m_engine(engine), m_slot(0)
31 | {
32 | {
33 | std::lock_guard<std::mutex> lock(g_global_monitors_mutex);
34 | m_id = g_next_monitor_id++;
35 | while (m_slot < k_max_monitor_instances && g_global_monitors[m_slot]) ++m_slot;
36 | CHECK(m_slot < k_max_monitor_instances) << "Too many MCTSMonitor instances";
37 | g_global_monitors[m_slot] = this;
38 | }
39 | if (m_engine->GetConfig().monitor_log_every_ms() > 0) {
40 | m_monitor_thread = std::thread(&MCTSMonitor::MonitorRoutine, this);
41 | }
42 | }
43 |
44 | MCTSMonitor::~MCTSMonitor()
45 | {
46 | if (m_monitor_thread.joinable()) {
47 | m_monitor_thread_conductor.Terminate();
48 | m_monitor_thread.join();
49 | }
50 | g_local_monitors[m_slot] = nullptr;
51 | std::lock_guard<std::mutex> lock(g_global_monitors_mutex);
52 | g_global_monitors[m_slot] = nullptr;
53 | }
54 |
55 | void MCTSMonitor::Pause()
56 | {
57 | m_monitor_thread_conductor.Pause();
58 | m_monitor_thread_conductor.Join();
59 | }
60 |
61 | void MCTSMonitor::Resume()
62 | {
63 | if (m_engine->GetConfig().monitor_log_every_ms() > 0) {
64 | if (!m_monitor_thread.joinable()) {
65 | m_monitor_thread = std::thread(&MCTSMonitor::MonitorRoutine, this);
66 | }
67 | m_monitor_thread_conductor.Resume(1);
68 | }
69 | }
70 |
71 | void MCTSMonitor::Reset()
72 | {
73 | std::lock_guard<std::mutex> lock(m_local_monitors_mutex);
74 | for (auto &local_monitor: m_local_monitors) {
75 | local_monitor->Reset();
76 | }
77 | m_local_monitors.erase(
78 | std::remove_if(m_local_monitors.begin(), m_local_monitors.end(),
79 | [](const std::shared_ptr &ptr) { return ptr.use_count() == 1; }),
80 | m_local_monitors.end());
81 | }
82 |
83 | void MCTSMonitor::Log()
84 | {
85 | VLOG(0) << "MCTSMonitor: avg eval cost " << AvgEvalCostMs() << "ms";
86 | VLOG(0) << "MCTSMonitor: max eval cost " << MaxEvalCostMs() << "ms";
87 | VLOG(0) << "MCTSMonitor: avg eval cost " << AvgEvalCostMsPerBatch() << "ms per batch";
88 | VLOG(0) << "MCTSMonitor: max eval cost " << MaxEvalCostMsPerBatch() << "ms per batch";
89 | VLOG(0) << "MCTSMonitor: avg eval batch size " << AvgEvalBatchSize();
90 | VLOG(0) << "MCTSMonitor: eval timeout " << EvalTimeout() << " times";
91 | VLOG(0) << "MCTSMonitor: avg simulation cost " << AvgSimulationCostMs() << "ms";
92 | VLOG(0) << "MCTSMonitor: max simulation cost " << MaxSimulationCostMs() << "ms";
93 | VLOG(0) << "MCTSMonitor: avg select cost " << AvgSelectCostMs() << "ms";
94 | VLOG(0) << "MCTSMonitor: max select cost " << MaxSelectCostMs() << "ms";
95 | VLOG(0) << "MCTSMonitor: avg expand cost " << AvgExpandCostMs() << "ms";
96 | VLOG(0) << "MCTSMonitor: max expand cost " << MaxExpandCostMs() << "ms";
97 | VLOG(0) << "MCTSMonitor: avg backup cost " << AvgBackupCostMs() << "ms";
98 | VLOG(0) << "MCTSMonitor: max backup cost " << MaxBackupCostMs() << "ms";
99 | VLOG(0) << "MCTSMonitor: select same node " << SelectSameNode() << " times";
100 | VLOG(0) << "MCTSMonitor: search tree height is " << MaxSearchTreeHeight();
101 | VLOG(0) << "MCTSMonitor: avg height of nodes is " << AvgSearchTreeHeight();
102 | VLOG(0) << "MCTSMonitor: avg eval task queue size is " << AvgTaskQueueSize();
103 |
104 |
105 | if (m_engine->GetConfig().enable_async()) {
106 | VLOG(0) << "MCTSMonitor: avg rpc queue size is " << AvgRpcQueueSize();
107 | }
108 | }
109 |
110 | void MCTSMonitor::MonitorRoutine()
111 | {
112 | m_monitor_thread_conductor.Wait();
113 | for (;;) {
114 | if (!m_monitor_thread_conductor.IsRunning()) {
115 | m_monitor_thread_conductor.AckPause();
116 | m_monitor_thread_conductor.Wait();
117 | if (m_monitor_thread_conductor.IsTerminate()) {
118 | LOG(WARNING) << "MonitorRoutine: terminate";
119 | return;
120 | }
121 | }
122 | m_monitor_thread_conductor.Sleep(m_engine->GetConfig().monitor_log_every_ms() * 1000LL);
123 | if (m_monitor_thread_conductor.IsRunning()) {
124 | Log();
125 | }
126 | google::FlushLogFiles(google::GLOG_INFO);
127 | }
128 | }
129 |
--------------------------------------------------------------------------------
/mcts_main.filters:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="utf-8"?>
2 | <Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
3 | <ItemGroup>
4 | <Filter Include="Source Files">
5 | <UniqueIdentifier>{4FC737F1-C7A5-4376-A066-2A32D752A2FF}</UniqueIdentifier>
6 | <Extensions>cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx</Extensions>
7 | </Filter>
8 | <Filter Include="Header Files">
9 | <UniqueIdentifier>{93995380-89BD-4b04-88EB-625FBE52EBFB}</UniqueIdentifier>
10 | <Extensions>h;hh;hpp;hxx;hm;inl;inc;xsd</Extensions>
11 | </Filter>
12 | <Filter Include="Resource Files">
13 | <UniqueIdentifier>{67DA6AB6-F800-4c08-8B7A-83BB121AAD01}</UniqueIdentifier>
14 | <Extensions>rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav</Extensions>
15 | </Filter>
16 | </ItemGroup>
17 | <ItemGroup>
18 | <!-- ClCompile items placing the project's .cc sources under the "Source Files" filter -->
19 | </ItemGroup>
20 | <ItemGroup>
21 | <!-- ClInclude items placing the project's headers under the "Header Files" filter -->
22 | </ItemGroup>
23 | </Project>
--------------------------------------------------------------------------------
/model/BUILD:
--------------------------------------------------------------------------------
1 | load("//:rules.bzl", "cc_proto_library")
2 | load("@org_tensorflow//tensorflow:tensorflow.bzl", "if_tensorrt")
3 |
4 | cc_binary(
5 | name = "build_tensorrt_model",
6 | srcs = ["build_tensorrt_model.cc"],
7 | deps = [
8 | "@com_github_google_glog//:glog",
9 | ] + if_tensorrt([
10 | "@local_config_tensorrt//:nv_infer",
11 | "@local_config_tensorrt//:nv_parsers",
12 | ]),
13 | copts = if_tensorrt(["-DGOOGLE_TENSORRT=1"]),
14 | )
15 |
16 | cc_library(
17 | name = "zero_model",
18 | srcs = [
19 | "zero_model.cc",
20 | ],
21 | hdrs = ["zero_model.h"],
22 | deps = [
23 | ":zero_model_base",
24 | ":checkpoint_utils",
25 | "@com_github_google_glog//:glog",
26 | "@org_tensorflow//tensorflow/core:tensorflow",
27 | ],
28 | visibility = ["//visibility:public"],
29 | )
30 |
31 | cc_library(
32 | name = "trt_zero_model",
33 | srcs = ["trt_zero_model.cc"],
34 | hdrs = ["trt_zero_model.h"],
35 | deps = [
36 | ":zero_model_base",
37 | "@boost//:filesystem",
38 | "@com_github_google_glog//:glog",
39 | ] + if_tensorrt([
40 | "@local_config_tensorrt//:nv_infer",
41 | ]),
42 | copts = if_tensorrt(["-DGOOGLE_TENSORRT=1"]),
43 | visibility = ["//visibility:public"],
44 | )
45 |
46 | cc_library(
47 | name = "zero_model_base",
48 | hdrs = ["zero_model_base.h"],
49 | deps = [
50 | ":model_config_cc_proto",
51 | "//common:errordef",
52 | ],
53 | visibility = ["//visibility:public"],
54 | )
55 |
56 | cc_proto_library(
57 | name = "model_config_cc_proto",
58 | srcs = ["model_config.proto"],
59 | visibility = ["//visibility:public"],
60 | )
61 |
62 | cc_library(
63 | name = "checkpoint_utils",
64 | srcs = ["checkpoint_utils.cc"],
65 | hdrs = ["checkpoint_utils.h"],
66 | deps = [
67 | ":checkpoint_state_cc_proto",
68 | "@boost//:filesystem",
69 | "@com_github_google_glog//:glog",
70 | ],
71 | visibility = ["//visibility:public"],
72 | )
73 |
74 | cc_proto_library(
75 | name = "checkpoint_state_cc_proto",
76 | srcs = ["checkpoint_state.proto"],
77 | )
78 |
--------------------------------------------------------------------------------
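The BUILD file above exposes both inference backends behind the shared zero_model_base interface, and model_config.proto (later in this dump) carries an enable_tensorrt switch. A minimal runtime selection could look like the sketch below; the CreateModel helper is hypothetical, the real wiring lives elsewhere in the repository (presumably the MCTS engine).

```cpp
#include <memory>

#include "model/model_config.pb.h"
#include "model/trt_zero_model.h"
#include "model/zero_model.h"
#include "model/zero_model_base.h"

// Hypothetical helper: choose a backend from the config.
std::unique_ptr<ZeroModelBase> CreateModel(const ModelConfig &config, int gpu)
{
    std::unique_ptr<ZeroModelBase> model;
    if (config.enable_tensorrt()) {
        model.reset(new TrtZeroModel(gpu));  // needs a .PLAN built by build_tensorrt_model
    } else {
        model.reset(new ZeroModel(gpu));     // loads the TensorFlow checkpoint directly
    }
    if (model->Init(config) != 0) {
        return nullptr;                      // Init returns a non-zero error code on failure
    }
    return model;
}
```
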
/model/build_tensorrt_model.cc:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making Phoenix Go available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #if GOOGLE_TENSORRT
19 |
20 | #include <fstream>
21 | #include <string>
22 |
23 | #include <gflags/gflags.h>
24 | #include <glog/logging.h>
25 |
26 | #include "cuda/include/cuda_runtime_api.h"
27 | #include "tensorrt/include/NvInfer.h"
28 | #include "tensorrt/include/NvUffParser.h"
29 |
30 | DEFINE_string(model_path, "", "Path of model.");
31 | DEFINE_string(data_type, "FP32", "Data type to build, FP32 or FP16.");
32 | DEFINE_int32(max_batch_size, 4, "Max size for input batch.");
33 | DEFINE_int32(max_workspace_size, 1<<30, "Parameter to control memory allocation (in bytes).");
34 | DEFINE_int32(calib_iterations, 10000, "Number of iterations to run during calibration.");
35 | DEFINE_string(storage_address, "", "Address of wegostorage");
36 | DEFINE_int32(gpu, 0, "Gpu used while building.");
37 |
38 | using namespace nvinfer1;
39 | using namespace nvuffparser;
40 |
41 | class Logger : public ILogger
42 | {
43 | void log(Severity severity, const char *msg) override
44 | {
45 | switch (severity) {
46 | case Severity::kINTERNAL_ERROR: LOG(ERROR) << msg; break;
47 | case Severity::kERROR: LOG(ERROR) << msg; break;
48 | case Severity::kWARNING: LOG(WARNING) << msg; break;
49 | case Severity::kINFO: LOG(INFO) << msg; break;
50 | }
51 | }
52 | } g_logger;
53 |
54 | int main(int argc, char *argv[])
55 | {
56 | gflags::ParseCommandLineFlags(&argc, &argv, true);
57 | google::InitGoogleLogging(argv[0]);
58 | google::InstallFailureSignalHandler();
59 |
60 | std::string uff_path = FLAGS_model_path + ".uff";
61 | std::string calib_path = FLAGS_model_path + ".calib";
62 | std::string output_path = FLAGS_model_path + "." + FLAGS_data_type + ".PLAN";
63 |
64 | IUffParser *parser = createUffParser();
65 | CHECK_NOTNULL(parser);
66 | parser->registerInput("inputs", DimsCHW(19, 19, 17));
67 | parser->registerOutput("policy");
68 | parser->registerOutput("value");
69 |
70 | IBuilder *builder = createInferBuilder(g_logger);
71 | INetworkDefinition *network = builder->createNetwork();
72 |
73 | CHECK(parser->parse(uff_path.c_str(), *network, DataType::kFLOAT));
74 |
75 | builder->setMaxBatchSize(FLAGS_max_batch_size);
76 | builder->setMaxWorkspaceSize(FLAGS_max_workspace_size);
77 | builder->setHalf2Mode(FLAGS_data_type == "FP16");
78 |
79 | ICudaEngine *engine = builder->buildCudaEngine(*network);
80 | CHECK_NOTNULL(engine);
81 |
82 | IHostMemory *serialized_engine = engine->serialize();
83 |
84 | std::ofstream output(output_path, std::ios::binary);
85 | output.write(reinterpret_cast<const char *>(serialized_engine->data()), serialized_engine->size());
86 | }
87 |
88 | #else // GOOGLE_TENSORRT
89 |
90 | #include <stdio.h>
91 |
92 | int main()
93 | {
94 | fprintf(stderr, "TensorRT is not enabled!\n");
95 | return -1;
96 | }
97 |
98 | #endif // GOOGLE_TENSORRT
99 |
--------------------------------------------------------------------------------
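build_tensorrt_model.cc above serializes the engine but never reads it back. As a quick sanity check of a freshly written .PLAN file, one can mirror the deserialization that trt_zero_model.cc performs later in this dump. A hedged sketch, assuming the TensorRT headers are available and an ILogger instance is passed in:

```cpp
#include <fstream>
#include <sstream>
#include <string>

#include <glog/logging.h>

#include "tensorrt/include/NvInfer.h"

// Reload a serialized engine to verify that the .PLAN file is readable.
bool VerifyPlan(const std::string &plan_path, nvinfer1::ILogger &logger)
{
    std::ostringstream blob(std::ios::binary);
    blob << std::ifstream(plan_path, std::ios::binary).rdbuf();
    std::string bytes = blob.str();

    nvinfer1::IRuntime *runtime = nvinfer1::createInferRuntime(logger);
    nvinfer1::ICudaEngine *engine =
        runtime->deserializeCudaEngine(bytes.data(), bytes.size(), nullptr);
    bool ok = (engine != nullptr);
    if (ok) {
        LOG(INFO) << "max batch size: " << engine->getMaxBatchSize();
        engine->destroy();
    }
    runtime->destroy();
    return ok;
}
```
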
/model/checkpoint_state.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package tensorflow;
4 | option cc_enable_arenas = true;
5 |
6 | // Protocol buffer representing the checkpoint state.
7 | //
8 | // TODO(touts): Add other attributes as needed.
9 | message CheckpointState {
10 | // Path to the most-recent model checkpoint.
11 | string model_checkpoint_path = 1;
12 |
13 | // Paths to all not-yet-deleted model checkpoints, sorted from oldest to
14 | // newest.
15 | // Note that the value of model_checkpoint_path should be the last item in
16 | // this list.
17 | repeated string all_model_checkpoint_paths = 2;
18 | }
19 |
--------------------------------------------------------------------------------
/model/checkpoint_utils.cc:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #include "checkpoint_utils.h"
19 |
20 | #include <glog/logging.h>
21 | #include <google/protobuf/text_format.h>
22 |
23 | #include "model/checkpoint_state.pb.h"
24 |
25 | namespace fs = boost::filesystem;
26 |
27 | fs::path GetCheckpointPath(const fs::path &train_dir)
28 | {
29 | tensorflow::CheckpointState checkpoint_state;
30 | fs::path ckpt_state_path = train_dir / "checkpoint";
31 | std::ostringstream ckpt_ss;
32 | if (!(ckpt_ss << std::ifstream(ckpt_state_path.string()).rdbuf())) {
33 | PLOG(ERROR) << "Error reading " << ckpt_state_path;
34 | return "";
35 | }
36 | if (!google::protobuf::TextFormat::ParseFromString(ckpt_ss.str(), &checkpoint_state)) {
37 | LOG(ERROR) << "Error parsing " << ckpt_state_path << ", buf=" << ckpt_ss.str();
38 | return "";
39 | }
40 | fs::path checkpoint_path = checkpoint_state.model_checkpoint_path();
41 | if (checkpoint_path.is_relative()) {
42 | checkpoint_path = train_dir / checkpoint_path;
43 | }
44 | return checkpoint_path;
45 | }
46 |
47 | bool CopyCheckpoint(const fs::path &from, const fs::path &to)
48 | {
49 | for (int i = 0; i < 3; ++i) {
50 | try {
51 | fs::path from_ckpt_path = GetCheckpointPath(from);
52 | if (from_ckpt_path.empty()) {
53 | LOG(ERROR) << "Error reading model path from " << from;
54 | continue;
55 | }
56 |
57 | fs::path ckpt_name = from_ckpt_path.filename();
58 | fs::path to_ckpt_path = to / ckpt_name;
59 |
60 | fs::create_directories(to);
61 | for (std::string suffix: {".data-00000-of-00001", ".index"}) {
62 | fs::path from_file_path = from_ckpt_path.string() + suffix;
63 | fs::path to_file_path = to_ckpt_path.string() + suffix;
64 | LOG(INFO) << "Copying from " << from_file_path << " to " << to_file_path;
65 | fs::copy_file(from_file_path, to_file_path, fs::copy_option::overwrite_if_exists);
66 | fs::path symlink_path = to / ("zero.ckpt" + suffix);
67 | fs::remove(symlink_path);
68 | fs::create_symlink(to_file_path.filename(), symlink_path);
69 | }
70 |
71 | tensorflow::CheckpointState checkpoint_state;
72 | checkpoint_state.set_model_checkpoint_path(to_ckpt_path.string());
73 | fs::path ckpt_state_path = to / "checkpoint";
74 | LOG(INFO) << "Writing checkpoint file " << ckpt_state_path;
75 | if (!(std::ofstream(ckpt_state_path.string()) << checkpoint_state.DebugString())) {
76 | PLOG(ERROR) << "Error writing " << to << "/checkpoint";
77 | continue;
78 | }
79 |
80 | fs::copy_file(from / "meta_graph", to / "meta_graph", fs::copy_option::overwrite_if_exists);
81 | LOG(INFO) << "Copy checkpoint from " << from << " to " << to << " succ.";
82 | return true;
83 | } catch (const std::exception &e) {
84 | LOG(ERROR) << e.what();
85 | }
86 | }
87 | return false;
88 | }
89 |
--------------------------------------------------------------------------------
/model/checkpoint_utils.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #pragma once
19 |
20 | #include <boost/filesystem.hpp>
21 |
22 | boost::filesystem::path GetCheckpointPath(const boost::filesystem::path &train_dir);
23 |
24 | bool CopyCheckpoint(const boost::filesystem::path &from, const boost::filesystem::path &to);
25 |
--------------------------------------------------------------------------------
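A short usage sketch for the two helpers declared above; the directory paths are placeholders invented for illustration.

```cpp
#include <boost/filesystem.hpp>
#include <glog/logging.h>

#include "model/checkpoint_utils.h"

int main()
{
    namespace fs = boost::filesystem;
    fs::path train_dir = "/data/train";      // hypothetical training directory
    fs::path local_dir = "/data/ckpt_copy";  // hypothetical destination

    fs::path latest = GetCheckpointPath(train_dir);  // empty path on failure
    if (latest.empty()) {
        LOG(ERROR) << "no readable checkpoint under " << train_dir;
        return 1;
    }
    LOG(INFO) << "latest checkpoint: " << latest;

    // Copies the .data/.index pair, rewrites the checkpoint state file,
    // refreshes the zero.ckpt.* symlinks and copies meta_graph.
    return CopyCheckpoint(train_dir, local_dir) ? 0 : 1;
}
```
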
/model/model_config.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | message ModelConfig {
4 | string train_dir = 1;
5 | string checkpoint_path = 2;
6 | string meta_graph_path = 3;
7 | int32 intra_op_parallelism_threads = 6;
8 | int32 inter_op_parallelism_threads = 7;
9 | bool enable_mkl = 8;
10 | int32 kmp_blocktime = 9;
11 | bool kmp_settings = 10;
12 | string kmp_affinity = 11;
13 | bool enable_xla = 12;
14 | bool enable_tensorrt = 15;
15 | string tensorrt_model_path = 16;
16 | }
17 |
--------------------------------------------------------------------------------
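ModelConfig is a plain proto3 message, so it can be filled in code or parsed from protobuf text format (the etc/*.conf files presumably embed it inside a larger engine config). A stand-alone parse might look like this; the literal is an illustrative fragment, not one of the shipped configs.

```cpp
#include <string>

#include <glog/logging.h>
#include <google/protobuf/text_format.h>

#include "model/model_config.pb.h"

int main()
{
    const std::string text =
        "train_dir: \"/data/train\"\n"               // hypothetical path
        "enable_tensorrt: true\n"
        "tensorrt_model_path: \"zero.FP16.PLAN\"\n"; // hypothetical file name

    ModelConfig config;
    if (!google::protobuf::TextFormat::ParseFromString(text, &config)) {
        LOG(ERROR) << "bad model config";
        return 1;
    }
    LOG(INFO) << "train_dir=" << config.train_dir()
              << " tensorrt=" << config.enable_tensorrt();
    return 0;
}
```
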
/model/trt_zero_model.cc:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #include "trt_zero_model.h"
19 |
20 | #if GOOGLE_TENSORRT
21 |
22 | #include <fstream>
23 | #include <sstream>
24 |
25 | #include <boost/filesystem.hpp>
26 | #include <glog/logging.h>
27 |
28 | #include "cuda/include/cuda_runtime_api.h"
29 | #include "tensorrt/include/NvInfer.h"
30 |
31 | namespace fs = boost::filesystem;
32 |
33 | class Logger : public nvinfer1::ILogger
34 | {
35 | void log(Severity severity, const char *msg) override
36 | {
37 | switch (severity) {
38 | case Severity::kINTERNAL_ERROR: LOG(ERROR) << msg; break;
39 | case Severity::kERROR: LOG(ERROR) << msg; break;
40 | case Severity::kWARNING: LOG(WARNING) << msg; break;
41 | case Severity::kINFO: LOG(INFO) << msg; break;
42 | }
43 | }
44 | } g_logger;
45 |
46 | TrtZeroModel::TrtZeroModel(int gpu)
47 | : m_engine(nullptr), m_runtime(nullptr), m_context(nullptr), m_gpu(gpu), m_global_step(0)
48 | {
49 | }
50 |
51 | TrtZeroModel::~TrtZeroModel()
52 | {
53 | if (m_context) {
54 | m_context->destroy();
55 | }
56 | if (m_engine) {
57 | m_engine->destroy();
58 | }
59 | if (m_runtime) {
60 | m_runtime->destroy();
61 | }
62 | for (auto buf: m_cuda_buf) {
63 | int ret = cudaFree(buf);
64 | if (ret != 0) {
65 | LOG(ERROR) << "cuda free err " << ret;
66 | }
67 | }
68 | }
69 |
70 | int TrtZeroModel::Init(const ModelConfig &model_config)
71 | {
72 | cudaSetDevice(m_gpu);
73 |
74 | fs::path train_dir = model_config.train_dir();
75 |
76 | fs::path tensorrt_model_path = model_config.tensorrt_model_path();
77 | if (tensorrt_model_path.is_relative()) {
78 | tensorrt_model_path = train_dir / tensorrt_model_path;
79 | }
80 |
81 | std::ostringstream model_ss(std::ios::binary);
82 | if (!(model_ss << std::ifstream(tensorrt_model_path.string(), std::ios::binary).rdbuf())) {
83 | PLOG(ERROR) << "read tensorrt model '" << tensorrt_model_path << "' error";
84 | return ERR_READ_TRT_MODEL;
85 | }
86 | std::string model_str = model_ss.str();
87 |
88 | m_runtime = nvinfer1::createInferRuntime(g_logger);
89 | m_engine = m_runtime->deserializeCudaEngine(model_str.c_str(), model_str.size(), nullptr);
90 | if (m_engine == nullptr) {
91 | PLOG(ERROR) << "load cuda engine error";
92 | return ERR_LOAD_TRT_ENGINE;
93 | }
94 | m_context = m_engine->createExecutionContext();
95 |
96 | int batch_size = m_engine->getMaxBatchSize();
97 | LOG(INFO) << "tensorrt max batch size: " << batch_size;
98 | for (int i = 0; i < m_engine->getNbBindings(); ++i) {
99 | auto dim = m_engine->getBindingDimensions(i);
100 | std::string dim_str = "(";
101 | int size = 1;
102 | for (int i = 0; i < dim.nbDims; ++i) {
103 | if (i) dim_str += ", ";
104 | dim_str += std::to_string(dim.d[i]);
105 | size *= dim.d[i];
106 | }
107 | dim_str += ")";
108 | LOG(INFO) << "tensorrt binding: " << m_engine->getBindingName(i) << " " << dim_str;
109 |
110 | void *buf;
111 | int ret = cudaMalloc(&buf, batch_size * size * sizeof(float));
112 | if (ret != 0) {
113 | LOG(ERROR) << "cuda malloc err " << ret;
114 | return ERR_CUDA_MALLOC;
115 | }
116 | m_cuda_buf.push_back(buf);
117 | }
118 |
119 | if (!(std::ifstream(tensorrt_model_path.string() + ".step") >> m_global_step)) {
120 | LOG(WARNING) << "read global step from " << tensorrt_model_path << ".step failed";
121 | }
122 |
123 | return 0;
124 | }
125 |
126 | int TrtZeroModel::Forward(const std::vector<std::vector<bool>> &inputs,
127 | std::vector<std::vector<float>> &policy, std::vector<float> &value)
128 | {
129 | int batch_size = inputs.size();
130 | if (batch_size == 0) {
131 | LOG(ERROR) << "Error: batch size cannot be 0.";
132 | return ERR_INVALID_INPUT;
133 | }
134 |
135 | std::vector<float> inputs_flat(batch_size * INPUT_DIM);
136 | for (int i = 0; i < batch_size; ++i) {
137 | if (inputs[i].size() != INPUT_DIM) {
138 | LOG(ERROR) << "Error: input dim does not match, need " << INPUT_DIM << ", got " << inputs[i].size();
139 | return ERR_INVALID_INPUT;
140 | }
141 | for (int j = 0; j < INPUT_DIM; ++j) {
142 | inputs_flat[i * INPUT_DIM + j] = inputs[i][j];
143 | }
144 | }
145 |
146 | int ret = cudaMemcpy(m_cuda_buf[0], inputs_flat.data(), inputs_flat.size() * sizeof(float), cudaMemcpyHostToDevice);
147 | if (ret != 0) {
148 | LOG(ERROR) << "cuda memcpy err " << ret;
149 | return ERR_CUDA_MEMCPY;
150 | }
151 |
152 | m_context->execute(batch_size, m_cuda_buf.data());
153 |
154 | std::vector<float> policy_flat(batch_size * OUTPUT_DIM);
155 | ret = cudaMemcpy(policy_flat.data(), m_cuda_buf[1], policy_flat.size() * sizeof(float), cudaMemcpyDeviceToHost);
156 | if (ret != 0) {
157 | LOG(ERROR) << "cuda memcpy err " << ret;
158 | return ERR_CUDA_MEMCPY;
159 | }
160 | policy.resize(batch_size);
161 | for (int i = 0; i < batch_size; ++i) {
162 | policy[i].resize(OUTPUT_DIM);
163 | for (int j = 0; j < OUTPUT_DIM; ++j) {
164 | policy[i][j] = policy_flat[i * OUTPUT_DIM + j];
165 | }
166 | }
167 |
168 | value.resize(batch_size);
169 | ret = cudaMemcpy(value.data(), m_cuda_buf[2], value.size() * sizeof(float), cudaMemcpyDeviceToHost);
170 | if (ret != 0) {
171 | LOG(ERROR) << "cuda memcpy err " << ret;
172 | return ERR_CUDA_MEMCPY;
173 | }
174 | for (int i = 0; i < batch_size; ++i) {
175 | value[i] = -value[i];
176 | }
177 |
178 | return 0;
179 | }
180 |
181 | int TrtZeroModel::GetGlobalStep(int &global_step)
182 | {
183 | global_step = m_global_step;
184 | return 0;
185 | }
186 |
187 | #else // GOOGLE_TENSORRT
188 |
189 | #include <glog/logging.h>
190 |
191 | TrtZeroModel::TrtZeroModel(int gpu)
192 | {
193 | LOG(FATAL) << "TensorRT is not enabled!";
194 | }
195 |
196 | TrtZeroModel::~TrtZeroModel()
197 | {
198 | LOG(FATAL) << "TensorRT is not enabled!";
199 | }
200 |
201 | int TrtZeroModel::Init(const ModelConfig &model_config)
202 | {
203 | LOG(FATAL) << "TensorRT is not enabled!";
204 | return 0;
205 | }
206 |
207 | int TrtZeroModel::Forward(const std::vector<std::vector<bool>> &inputs,
208 | std::vector<std::vector<float>> &policy, std::vector<float> &value)
209 | {
210 | LOG(FATAL) << "TensorRT is not enabled!";
211 | return 0;
212 | }
213 |
214 | int TrtZeroModel::GetGlobalStep(int &global_step)
215 | {
216 | LOG(FATAL) << "TensorRT is not enabled!";
217 | return 0;
218 | }
219 |
220 | #endif // GOOGLE_TENSORRT
221 |
--------------------------------------------------------------------------------
/model/trt_zero_model.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #pragma once
19 |
20 | #include <vector>
21 |
22 | #include "model/zero_model_base.h"
23 | #include "model/model_config.pb.h"
24 |
25 | namespace nvinfer1 {
26 | class ICudaEngine;
27 | class IRuntime;
28 | class IExecutionContext;
29 | }
30 |
31 | class TrtZeroModel final : public ZeroModelBase
32 | {
33 | public:
34 | TrtZeroModel(int gpu);
35 | ~TrtZeroModel();
36 |
37 | int Init(const ModelConfig &model_config) override;
38 |
39 | // input [batch, 19 * 19 * 17]
40 | // policy [batch, 19 * 19 + 1]
41 | int Forward(const std::vector<std::vector<bool>> &inputs,
42 | std::vector<std::vector<float>> &policy, std::vector<float> &value) override;
43 |
44 | int GetGlobalStep(int &global_step) override;
45 |
46 | private:
47 | nvinfer1::ICudaEngine *m_engine;
48 | nvinfer1::IRuntime *m_runtime;
49 | nvinfer1::IExecutionContext *m_context;
50 | std::vector<void *> m_cuda_buf;
51 | int m_gpu;
52 | int m_global_step;
53 | };
54 |
--------------------------------------------------------------------------------
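To exercise the interface declared above, a caller flattens each position into 19 * 19 * 17 binary feature planes and receives a 19 * 19 + 1 policy row plus one scalar value per position. A hedged usage sketch; the config values and GPU id are placeholders.

```cpp
#include <vector>

#include <glog/logging.h>

#include "model/model_config.pb.h"
#include "model/trt_zero_model.h"

int main()
{
    ModelConfig config;
    config.set_train_dir("/data/train");              // hypothetical
    config.set_enable_tensorrt(true);
    config.set_tensorrt_model_path("zero.FP16.PLAN"); // hypothetical

    TrtZeroModel model(/*gpu=*/0);
    if (model.Init(config) != 0) {
        LOG(ERROR) << "init failed";
        return 1;
    }

    // One dummy position: 19 * 19 * 17 binary feature planes, all zero.
    std::vector<std::vector<bool>> inputs(1, std::vector<bool>(19 * 19 * 17, false));
    std::vector<std::vector<float>> policy;  // per position: 19 * 19 + 1 move priors
    std::vector<float> value;                // per position: one scalar value estimate

    if (model.Forward(inputs, policy, value) == 0) {
        // Last policy entry, presumably the pass move.
        LOG(INFO) << "pass prior " << policy[0][19 * 19] << ", value " << value[0];
    }
    return 0;
}
```
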
/model/zero_model.cc:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #include "zero_model.h"
19 |
20 | #include <boost/filesystem.hpp>
21 |
22 | #include <glog/logging.h>
23 |
24 | #include "tensorflow/core/public/session.h"
25 | #include "tensorflow/core/protobuf/meta_graph.pb.h"
26 |
27 | #include "model/checkpoint_utils.h"
28 |
29 | namespace fs = boost::filesystem;
30 | namespace tf = tensorflow;
31 |
32 | const std::string input_tensor_name = "inputs";
33 | const std::string policy_tensor_name = "policy";
34 | const std::string value_tensor_name = "value";
35 |
36 | ZeroModel::ZeroModel(int gpu)
37 | : m_session(nullptr), m_gpu(gpu)
38 | {
39 | }
40 |
41 | ZeroModel::~ZeroModel()
42 | {
43 | if (m_session != nullptr) {
44 | tf::Status status = m_session->Close();
45 | if (!status.ok()) {
46 | LOG(ERROR) << "Error closing tf session: " << status.ToString();
47 | }
48 | }
49 | }
50 |
51 | int ZeroModel::Init(const ModelConfig &model_config)
52 | {
53 | fs::path train_dir = model_config.train_dir();
54 |
55 | fs::path meta_graph_path = model_config.meta_graph_path();
56 | if (meta_graph_path.empty()) {
57 | meta_graph_path = train_dir / "meta_graph";
58 | } else if (meta_graph_path.is_relative()) {
59 | meta_graph_path = train_dir / meta_graph_path;
60 | }
61 |
62 | fs::path checkpoint_path = model_config.checkpoint_path();
63 | if (checkpoint_path.empty()) {
64 | checkpoint_path = GetCheckpointPath(train_dir);
65 | } else if (checkpoint_path.is_relative()) {
66 | checkpoint_path = train_dir / checkpoint_path;
67 | }
68 |
69 | if (checkpoint_path.empty()) {
70 | return ERR_READ_CHECKPOINT;
71 | }
72 | LOG(INFO) << "Read checkpoint state succ";
73 |
74 | tf::MetaGraphDef meta_graph_def;
75 | tf::Status status = ReadBinaryProto(tf::Env::Default(), meta_graph_path.string(), &meta_graph_def);
76 | if (!status.ok()) {
77 | LOG(ERROR) << "Error reading graph definition from " << meta_graph_path << ": " << status.ToString();
78 | return ERR_READ_CHECKPOINT;
79 | }
80 | LOG(INFO) << "Read meta graph succ";
81 |
82 | for (auto &node: *meta_graph_def.mutable_graph_def()->mutable_node()) {
83 | node.set_device("/gpu:" + std::to_string(m_gpu));
84 | }
85 |
86 | tf::SessionOptions options;
87 | options.config.set_allow_soft_placement(true);
88 | options.config.mutable_gpu_options()->set_per_process_gpu_memory_fraction(0.5);
89 | options.config.mutable_gpu_options()->set_allow_growth(true);
90 | options.config.set_intra_op_parallelism_threads(model_config.intra_op_parallelism_threads());
91 | options.config.set_inter_op_parallelism_threads(model_config.inter_op_parallelism_threads());
92 | if (model_config.enable_xla()) {
93 | options.config.mutable_graph_options()->mutable_optimizer_options()->set_global_jit_level(tf::OptimizerOptions::ON_1);
94 | }
95 | m_session = std::unique_ptr<tf::Session>(tf::NewSession(options));
96 | if (m_session == nullptr) {
97 | LOG(ERROR) << "Could not create Tensorflow session.";
98 | return ERR_CREATE_SESSION;
99 | }
100 | LOG(INFO) << "Create session succ";
101 |
102 | status = m_session->Create(meta_graph_def.graph_def());
103 | if (!status.ok()) {
104 | LOG(ERROR) << "Error creating graph: " << status.ToString();
105 | return ERR_CREATE_GRAPH;
106 | }
107 | LOG(INFO) << "Create graph succ";
108 |
109 | tf::Tensor checkpoint_path_tensor(tf::DT_STRING, tf::TensorShape());
110 | checkpoint_path_tensor.scalar<std::string>()() = checkpoint_path.string();
111 | status = m_session->Run({{meta_graph_def.saver_def().filename_tensor_name(), checkpoint_path_tensor}},
112 | {}, /* fetches_outputs is empty */
113 | {meta_graph_def.saver_def().restore_op_name()},
114 | nullptr);
115 | if (!status.ok()) {
116 | LOG(ERROR) << "Error loading checkpoint from " << checkpoint_path << ": " << status.ToString();
117 | return ERR_RESTORE_VAR;
118 | }
119 | LOG(INFO) << "Load checkpoint succ";
120 |
121 | std::vector<std::vector<bool>> inputs(1, std::vector<bool>(INPUT_DIM, false));
122 | std::vector<std::vector<float>> policy;
123 | std::vector<float> value;
124 | Forward(inputs, policy, value);
125 |
126 | return 0;
127 | }
128 |
129 | int ZeroModel::Forward(const std::vector<std::vector<bool>> &inputs,
130 | std::vector<std::vector<float>> &policy, std::vector<float> &value)
131 | {
132 | int batch_size = inputs.size();
133 | if (batch_size == 0) {
134 | LOG(ERROR) << "Error: batch size cannot be 0.";
135 | return ERR_INVALID_INPUT;
136 | }
137 |
138 | tf::Tensor feature_tensor(tf::DT_BOOL, tf::TensorShape({batch_size, INPUT_DIM}));
139 | auto matrix = feature_tensor.matrix<bool>();
140 | for (int i = 0; i < batch_size; ++i) {
141 | if (inputs[i].size() != INPUT_DIM) {
142 | LOG(ERROR) << "Error: input dim does not match, need " << INPUT_DIM << ", got " << inputs[i].size();
143 | return ERR_INVALID_INPUT;
144 | }
145 | for (int j = 0; j < INPUT_DIM; ++j) {
146 | matrix(i, j) = inputs[i][j];
147 | }
148 | }
149 |
150 |
151 | std::vector<std::pair<std::string, tf::Tensor>> network_inputs = {{input_tensor_name, feature_tensor}};
152 | std::vector<std::string> fetch_outputs = {policy_tensor_name, value_tensor_name};
153 | std::vector<tf::Tensor> network_outputs;
154 | tf::Status status = m_session->Run(network_inputs, fetch_outputs, {}, &network_outputs);
155 | if (!status.ok()) {
156 | LOG(ERROR) << "Error session run: " << status.ToString();
157 | return ERR_SESSION_RUN;
158 | }
159 |
160 | auto policy_tensor = network_outputs[0].matrix<float>();
161 | auto value_tensor = network_outputs[1].flat<float>();
162 | policy.resize(batch_size);
163 | value.resize(batch_size);
164 | for (int i = 0; i < batch_size; ++i) {
165 | policy[i].resize(OUTPUT_DIM);
166 | for (int j = 0; j < OUTPUT_DIM; ++j) {
167 | policy[i][j] = policy_tensor(i, j);
168 | }
169 | value[i] = -value_tensor(i);
170 | }
171 |
172 | return 0;
173 | }
174 |
175 | int ZeroModel::GetGlobalStep(int &global_step)
176 | {
177 | std::vector<tf::Tensor> network_outputs;
178 | tf::Status status = m_session->Run({}, {"global_step"}, {}, &network_outputs);
179 | if (!status.ok()) {
180 | LOG(ERROR) << "Error session run: " << status.ToString();
181 | return ERR_SESSION_RUN;
182 | }
183 |
184 | global_step = network_outputs[0].scalar<tf::int64>()();
185 | return 0;
186 | }
187 |
188 | void ZeroModel::SetMKLEnv(const ModelConfig &model_config)
189 | {
190 | #if defined(_WIN32) || defined(_WIN64)
191 | _putenv_s("KMP_BLOCKTIME", std::to_string(model_config.kmp_blocktime()).c_str());
192 | _putenv_s("KMP_SETTINGS", std::to_string(model_config.kmp_settings()).c_str());
193 | _putenv_s("KMP_AFFINITY", model_config.kmp_affinity().c_str());
194 | if (model_config.intra_op_parallelism_threads() > 0) {
195 | _putenv_s("OMP_NUM_THREADS", std::to_string(model_config.intra_op_parallelism_threads()).c_str());
196 | }
197 | #else
198 | setenv("KMP_BLOCKTIME", std::to_string(model_config.kmp_blocktime()).c_str(), 0);
199 | setenv("KMP_SETTINGS", std::to_string(model_config.kmp_settings()).c_str(), 0);
200 | setenv("KMP_AFFINITY", model_config.kmp_affinity().c_str(), 0);
201 | if (model_config.intra_op_parallelism_threads() > 0) {
202 | setenv("OMP_NUM_THREADS", std::to_string(model_config.intra_op_parallelism_threads()).c_str(), 0);
203 | }
204 | #endif
205 | }
206 |
--------------------------------------------------------------------------------
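Both backends negate the raw network value before returning it (value[i] = -value[i] in each Forward), and each policy row has 19 * 19 + 1 entries, the extra entry presumably being the pass move. Picking the greedy move from a Forward result is then a plain argmax; a small helper sketch, not part of the repository:

```cpp
#include <algorithm>
#include <iterator>
#include <vector>

// Index of the highest-prior move; 0..360 are board points, 361 is the extra (pass) entry.
int GreedyMove(const std::vector<float> &policy_row)
{
    auto it = std::max_element(policy_row.begin(), policy_row.end());
    return static_cast<int>(std::distance(policy_row.begin(), it));
}
```
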
/model/zero_model.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Tencent is pleased to support the open source community by making PhoenixGo available.
3 | *
4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 | *
6 | * Licensed under the BSD 3-Clause License (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * https://opensource.org/licenses/BSD-3-Clause
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #pragma once
19 |
20 | #include