├── .bazelrc ├── .gitignore ├── .travis.yml ├── BUILD ├── LICENSE.TXT ├── PhoenixGo.sln ├── README.md ├── ThirdParty.props ├── WORKSPACE ├── common ├── BUILD ├── errordef.h ├── go_comm.cc ├── go_comm.h ├── go_state.cc ├── go_state.h ├── str_utils.cc ├── str_utils.h ├── task_queue.h ├── thread_conductor.cc ├── thread_conductor.h ├── timer.cc ├── timer.h ├── wait_group.cc └── wait_group.h ├── configure ├── configure.py ├── dist ├── BUILD ├── async_dist_zero_model_client.cc ├── async_dist_zero_model_client.h ├── async_rpc_queue.h ├── dist_config.proto ├── dist_zero_model.proto ├── dist_zero_model_client.cc ├── dist_zero_model_client.h ├── dist_zero_model_server.cc ├── leaky_bucket.cc └── leaky_bucket.h ├── docs ├── FAQ.md ├── benchmark-gtx1060-75w.md ├── benchmark-teslaV100.md ├── go-review-partner.md ├── mcts-main-help.md ├── minimalist-bazel-configure.md ├── path-errors.md └── tested-versions.md ├── etc ├── mcts_1gpu.conf ├── mcts_1gpu_grp.conf ├── mcts_1gpu_notensorrt.conf ├── mcts_1gpu_notensorrt_grp.conf ├── mcts_2gpu.conf ├── mcts_2gpu_grp.conf ├── mcts_2gpu_notensorrt.conf ├── mcts_2gpu_notensorrt_grp.conf ├── mcts_3gpu.conf ├── mcts_3gpu_notensorrt.conf ├── mcts_4gpu.conf ├── mcts_4gpu_notensorrt.conf ├── mcts_5gpu.conf ├── mcts_5gpu_notensorrt.conf ├── mcts_6gpu.conf ├── mcts_6gpu_notensorrt.conf ├── mcts_7gpu.conf ├── mcts_7gpu_notensorrt.conf ├── mcts_8gpu.conf ├── mcts_8gpu_notensorrt.conf ├── mcts_async_dist.conf ├── mcts_cpu.conf ├── mcts_cpu_grp.conf └── mcts_dist.conf ├── images └── logo.jpg ├── mcts ├── BUILD ├── byo_yomi_timer.cc ├── byo_yomi_timer.h ├── debug_tool.cc ├── mcts_config.cc ├── mcts_config.h ├── mcts_config.proto ├── mcts_debugger.cc ├── mcts_debugger.h ├── mcts_engine.cc ├── mcts_engine.h ├── mcts_main.cc ├── mcts_monitor.cc └── mcts_monitor.h ├── mcts_main.filters ├── mcts_main.vcxproj ├── model ├── BUILD ├── build_tensorrt_model.cc ├── checkpoint_state.proto ├── checkpoint_utils.cc ├── checkpoint_utils.h ├── model_config.proto ├── 
trt_zero_model.cc ├── trt_zero_model.h ├── zero_model.cc ├── zero_model.h └── zero_model_base.h ├── rules.bzl ├── scripts ├── build_tensorrt_model.sh ├── get_global_step.py ├── graph_transform.py ├── start.sh ├── start_cpu.bat └── start_gpu.bat ├── third_party ├── glog │ ├── BUILD │ └── glog.patch └── tensorflow │ ├── .bazelrc │ ├── BUILD │ └── tensorflow.patch └── tools └── .keep /.bazelrc: -------------------------------------------------------------------------------- 1 | import %workspace%/third_party/tensorflow/.bazelrc 2 | 3 | build --config=opt 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | log 2 | conf 3 | ckpt 4 | bazel-* 5 | compile_commands.json 6 | tools/actions 7 | third_party/bazel 8 | *.pyc 9 | .tf_configure.bazelrc 10 | tools/python_bin_path.sh 11 | *.pb.h 12 | *.pb.cc 13 | *.user 14 | *.VC.db 15 | *.VC.VC.opendb 16 | .vs 17 | x64 18 | x86 19 | Release 20 | Debug 21 | .clang 22 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | dist: trusty 2 | sudo: required 3 | 4 | before_install: 5 | - wget https://github.com/bazelbuild/bazel/releases/download/0.11.1/bazel-0.11.1-installer-linux-x86_64.sh 6 | - chmod +x bazel-0.11.1-installer-linux-x86_64.sh 7 | - ./bazel-0.11.1-installer-linux-x86_64.sh --user 8 | - export PATH="$PATH:$HOME/bin" 9 | 10 | script: 11 | - bazel build //mcts:mcts_main 12 | - bazel build //dist:dist_zero_model_server 13 | -------------------------------------------------------------------------------- /BUILD: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent/PhoenixGo/fbf67f9aec42531bff9569c44b85eb4c3f37b7be/BUILD -------------------------------------------------------------------------------- 
/PhoenixGo.sln: -------------------------------------------------------------------------------- 1 |  2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio 14 4 | VisualStudioVersion = 14.0.25420.1 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "mcts_main", "mcts_main.vcxproj", "{ABF9E6C0-D295-4B39-B367-3B67DF2CA3E9}" 7 | EndProject 8 | Global 9 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 10 | Debug|x64 = Debug|x64 11 | Debug|x86 = Debug|x86 12 | Release|x64 = Release|x64 13 | Release|x86 = Release|x86 14 | EndGlobalSection 15 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 16 | {ABF9E6C0-D295-4B39-B367-3B67DF2CA3E9}.Debug|x64.ActiveCfg = Debug|x64 17 | {ABF9E6C0-D295-4B39-B367-3B67DF2CA3E9}.Debug|x64.Build.0 = Debug|x64 18 | {ABF9E6C0-D295-4B39-B367-3B67DF2CA3E9}.Debug|x86.ActiveCfg = Debug|Win32 19 | {ABF9E6C0-D295-4B39-B367-3B67DF2CA3E9}.Debug|x86.Build.0 = Debug|Win32 20 | {ABF9E6C0-D295-4B39-B367-3B67DF2CA3E9}.Release|x64.ActiveCfg = Release|x64 21 | {ABF9E6C0-D295-4B39-B367-3B67DF2CA3E9}.Release|x64.Build.0 = Release|x64 22 | {ABF9E6C0-D295-4B39-B367-3B67DF2CA3E9}.Release|x86.ActiveCfg = Release|Win32 23 | {ABF9E6C0-D295-4B39-B367-3B67DF2CA3E9}.Release|x86.Build.0 = Release|Win32 24 | EndGlobalSection 25 | GlobalSection(SolutionProperties) = preSolution 26 | HideSolutionNode = FALSE 27 | EndGlobalSection 28 | EndGlobal 29 | -------------------------------------------------------------------------------- /ThirdParty.props: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | C:\Users\jchzhang\tensorflow 6 | C:\Users\jchzhang\tensorflow\tensorflow\contrib\cmake\build 7 | C:\Users\jchzhang\glog-0.3.5 8 | C:\Users\jchzhang\glog-0.3.5 9 | C:\Users\jchzhang\gflags-2.2.1 10 | C:\Users\jchzhang\gflags-2.2.1 11 | C:\Users\jchzhang\boost_1_66_0 12 | C:\Users\jchzhang\boost_1_66_0\stage\lib 13 | 14 | 15 | 
16 | 17 | 18 | $(tensorflow_SourcePath) 19 | 20 | 21 | $(tensorflow_BuildPath) 22 | 23 | 24 | $(glog_SourcePath) 25 | 26 | 27 | $(glog_BuildPath) 28 | 29 | 30 | $(gflags_SourcePath) 31 | 32 | 33 | $(gflags_BuildPath) 34 | 35 | 36 | $(boost_IncludePath) 37 | 38 | 39 | $(boost_LibPath) 40 | 41 | 42 | -------------------------------------------------------------------------------- /WORKSPACE: -------------------------------------------------------------------------------- 1 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 2 | 3 | http_archive( 4 | name = "com_github_google_glog", 5 | urls = ["https://github.com/google/glog/archive/55cc27b6eca3d7906fc1a920ca95df7717deb4e7.tar.gz"], 6 | sha256 = "4966f4233e4dcac53c7c7dd26054d17c045447e183bf9df7081575d2b888b95b", 7 | strip_prefix = "glog-55cc27b6eca3d7906fc1a920ca95df7717deb4e7", 8 | patches = ["//third_party/glog:glog.patch"], 9 | ) 10 | 11 | http_archive( 12 | name = "org_tensorflow", 13 | urls = ["https://github.com/tensorflow/tensorflow/archive/v1.13.1.tar.gz"], 14 | sha256 = "7cd19978e6bc7edc2c847bce19f95515a742b34ea5e28e4389dade35348f58ed", 15 | strip_prefix = "tensorflow-1.13.1", 16 | patches = ["//third_party/tensorflow:tensorflow.patch"], 17 | ) 18 | 19 | http_archive( 20 | name = "io_bazel_rules_closure", 21 | sha256 = "a38539c5b5c358548e75b44141b4ab637bba7c4dc02b46b1f62a96d6433f56ae", 22 | strip_prefix = "rules_closure-dbb96841cc0a5fb2664c37822803b06dab20c7d1", 23 | urls = [ 24 | "https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/dbb96841cc0a5fb2664c37822803b06dab20c7d1.tar.gz", 25 | "https://github.com/bazelbuild/rules_closure/archive/dbb96841cc0a5fb2664c37822803b06dab20c7d1.tar.gz", # 2018-04-13 26 | ], 27 | ) 28 | 29 | load('@org_tensorflow//tensorflow:workspace.bzl', 'tf_workspace') 30 | tf_workspace(path_prefix = "", tf_repo_name = "org_tensorflow") 31 | 32 | http_archive( 33 | name = "com_github_nelhage_rules_boost", 34 | urls = 
["https://github.com/nelhage/rules_boost/archive/6d6fd834281cb8f8e758dd9ad76df86304bf1869.tar.gz"], 35 | sha256 = "9adb4899e40fc10871bab1ff2e8feee950c194eec9940490f65a2761bbe6941d", 36 | strip_prefix = "rules_boost-6d6fd834281cb8f8e758dd9ad76df86304bf1869", 37 | ) 38 | 39 | load("@com_github_nelhage_rules_boost//:boost/boost.bzl", "boost_deps") 40 | boost_deps() 41 | -------------------------------------------------------------------------------- /common/BUILD: -------------------------------------------------------------------------------- 1 | cc_library( 2 | name = "go_comm", 3 | srcs = ["go_comm.cc"], 4 | hdrs = ["go_comm.h"], 5 | visibility = ["//visibility:public"], 6 | ) 7 | 8 | cc_library( 9 | name = "go_state", 10 | srcs = ["go_state.cc"], 11 | hdrs = ["go_state.h"], 12 | deps = [":go_comm"], 13 | visibility = ["//visibility:public"], 14 | ) 15 | 16 | cc_library( 17 | name = "errordef", 18 | hdrs = ["errordef.h"], 19 | visibility = ["//visibility:public"], 20 | ) 21 | 22 | cc_library( 23 | name = "task_queue", 24 | hdrs = ["task_queue.h"], 25 | visibility = ["//visibility:public"], 26 | ) 27 | 28 | cc_library( 29 | name = "wait_group", 30 | srcs = ["wait_group.cc"], 31 | hdrs = ["wait_group.h"], 32 | visibility = ["//visibility:public"], 33 | ) 34 | 35 | cc_library( 36 | name = "thread_conductor", 37 | srcs = ["thread_conductor.cc"], 38 | hdrs = ["thread_conductor.h"], 39 | deps = [":wait_group"], 40 | visibility = ["//visibility:public"], 41 | ) 42 | 43 | cc_library( 44 | name = "timer", 45 | srcs = ["timer.cc"], 46 | hdrs = ["timer.h"], 47 | visibility = ["//visibility:public"], 48 | ) 49 | 50 | cc_library( 51 | name = "str_utils", 52 | srcs = ["str_utils.cc"], 53 | hdrs = ["str_utils.h"], 54 | visibility = ["//visibility:public"], 55 | ) 56 | -------------------------------------------------------------------------------- /common/errordef.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased 
to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | #pragma once 19 | 20 | enum { 21 | ERR_INVALID_INPUT = -1, 22 | ERR_FORWARD_TIMEOUT = -2, 23 | 24 | ERR_READ_CHECKPOINT = -1000, 25 | ERR_CREATE_SESSION = -1001, 26 | ERR_CREATE_GRAPH = -1002, 27 | ERR_RESTORE_VAR = -1003, 28 | ERR_SESSION_RUN = -1005, 29 | 30 | // tensorrt error 31 | ERR_READ_TRT_MODEL = -2000, 32 | ERR_LOAD_TRT_ENGINE = -2001, 33 | ERR_CUDA_MALLOC = -2002, 34 | ERR_CUDA_FREE = -2003, 35 | ERR_CUDA_MEMCPY = -2004, 36 | 37 | // rpc error 38 | ERR_GLOBAL_STEP_CONFLICT = -3000, 39 | ERR_EMPTY_RESP = -3001, 40 | }; 41 | -------------------------------------------------------------------------------- /common/go_comm.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 
8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | #include "go_comm.h" 19 | 20 | #include 21 | #include 22 | // #include 23 | 24 | #define x first 25 | #define y second 26 | 27 | using namespace std; 28 | using namespace GoComm; 29 | 30 | GoHashValuePair g_hash_weight[BORDER_SIZE][BORDER_SIZE]; 31 | 32 | vector g_neighbour_cache_by_coord[BORDER_SIZE][BORDER_SIZE]; 33 | GoSize g_neighbour_size[GOBOARD_SIZE]; 34 | GoCoordId g_neighbour_cache_by_id[GOBOARD_SIZE][5]; 35 | GoCoordId g_log2_table[67]; 36 | uint64_t g_zobrist_board_hash_weight[4][GOBOARD_SIZE]; 37 | uint64_t g_zobrist_ko_hash_weight[GOBOARD_SIZE]; 38 | uint64_t g_zobrist_player_hash_weight[4]; 39 | 40 | namespace GoFunction { 41 | 42 | bool InBoard(const GoCoordId id) { 43 | return 0 <= id && id < GOBOARD_SIZE; 44 | } 45 | 46 | bool InBoard(const GoCoordId x, const GoCoordId y) { 47 | return 0 <= x && x < BORDER_SIZE 48 | && 0 <= y && y < BORDER_SIZE; 49 | } 50 | 51 | bool IsPass(const GoCoordId id) { 52 | return COORD_PASS == id; 53 | } 54 | 55 | bool IsPass(const GoCoordId x, const GoCoordId y) { 56 | return COORD_PASS == CoordToId(x, y); 57 | } 58 | 59 | bool IsUnset(const GoCoordId id) { 60 | return COORD_UNSET == id; 61 | } 62 | 63 | bool IsUnset(const GoCoordId x, const GoCoordId y) { 64 | return COORD_UNSET == CoordToId(x, y); 65 | } 66 | 67 | bool IsResign(const GoCoordId id) { 68 | return COORD_RESIGN == id; 69 | } 70 | 71 | bool IsResign(const GoCoordId x, const GoCoordId y) { 72 | return COORD_RESIGN == CoordToId(x, y); 73 | } 74 | 75 | 76 | void IdToCoord(const GoCoordId id, GoCoordId 
&x, GoCoordId &y) { 77 | if (COORD_PASS == id) { 78 | x = y = COORD_PASS; 79 | } else if (COORD_RESIGN == id) { 80 | x = y = COORD_RESIGN; 81 | } else if (!InBoard(id)) { 82 | x = y = COORD_UNSET; 83 | } else { 84 | x = id / BORDER_SIZE; 85 | y = id % BORDER_SIZE; 86 | } 87 | } 88 | 89 | GoCoordId CoordToId(const GoCoordId x, const GoCoordId y) { 90 | if (COORD_PASS == x && COORD_PASS == y) { 91 | return COORD_PASS; 92 | } 93 | if (COORD_RESIGN == x && COORD_RESIGN == y) { 94 | return COORD_RESIGN; 95 | } 96 | if (!InBoard(x, y)) { 97 | return COORD_UNSET; 98 | } 99 | return x * BORDER_SIZE + y; 100 | } 101 | 102 | void StrToCoord(const string &str, GoCoordId &x, GoCoordId &y) { 103 | // CHECK_EQ(str.length(), 2) << "string[" << str << "] length not equal to 2"; 104 | x = str[0] - 'a'; 105 | y = str[1] - 'a'; 106 | if (str == "zz") { 107 | x = y = COORD_PASS; 108 | } else if (!InBoard(x, y)) { 109 | x = y = COORD_UNSET; 110 | } 111 | } 112 | 113 | string CoordToStr(const GoCoordId x, const GoCoordId y) { 114 | char buffer[3]; 115 | if (!InBoard(x, y)) { 116 | buffer[0] = buffer[1] = 'z'; 117 | } else { 118 | buffer[0] = x + 'a'; 119 | buffer[1] = y + 'a'; 120 | } 121 | return string(buffer, 2); 122 | } 123 | 124 | std::string IdToStr(const GoCoordId id) { 125 | GoCoordId x, y; 126 | IdToCoord(id, x, y); 127 | return CoordToStr(x, y); 128 | } 129 | 130 | GoCoordId StrToId(const std::string &str) { 131 | GoCoordId x, y; 132 | StrToCoord(str, x, y); 133 | return CoordToId(x, y); 134 | } 135 | 136 | once_flag CreateGlobalVariables_once; 137 | void CreateGlobalVariables() { 138 | call_once( 139 | CreateGlobalVariables_once, 140 | []() { 141 | CreateNeighbourCache(); 142 | CreateHashWeights(); 143 | CreateQuickLog2Table(); 144 | CreateZobristHash(); 145 | } 146 | ); 147 | } 148 | 149 | 150 | void CreateHashWeights() { 151 | g_hash_weight[0][0] = GoHashValuePair(1, 1); 152 | for (GoCoordId i = 1; i < GOBOARD_SIZE; ++i) { 153 | g_hash_weight[i / BORDER_SIZE][i % 
BORDER_SIZE] = 154 | GoHashValuePair(g_hash_weight[(i - 1) / BORDER_SIZE][(i - 1) % BORDER_SIZE].x * g_hash_unit.x, 155 | g_hash_weight[(i - 1) / BORDER_SIZE][(i - 1) % BORDER_SIZE].y * g_hash_unit.y); 156 | } 157 | } 158 | 159 | void CreateNeighbourCache() { 160 | for (GoCoordId x = 0; x < BORDER_SIZE; ++x) { 161 | for (GoCoordId y = 0; y < BORDER_SIZE; ++y) { 162 | GoCoordId id = CoordToId(x, y); 163 | 164 | g_neighbour_cache_by_coord[x][y].clear(); 165 | for (int i = 0; i <= DELTA_SIZE; ++i) { 166 | g_neighbour_cache_by_id[id][i] = COORD_UNSET; 167 | } 168 | for (int i = 0; i < DELTA_SIZE; ++i) { 169 | GoCoordId nx = x + DELTA_X[i]; 170 | GoCoordId ny = y + DELTA_Y[i]; 171 | 172 | if (InBoard(nx, ny)) { 173 | g_neighbour_cache_by_coord[x][y].push_back(GoPosition(nx, ny)); 174 | } 175 | } 176 | g_neighbour_size[id] = g_neighbour_cache_by_coord[x][y].size(); 177 | for (GoSize i = 0; i < g_neighbour_cache_by_coord[x][y].size(); ++i) { 178 | g_neighbour_cache_by_id[id][i] = CoordToId(g_neighbour_cache_by_coord[x][y][i].x, 179 | g_neighbour_cache_by_coord[x][y][i].y); 180 | } 181 | } 182 | } 183 | // cerr << hex << int(g_neighbour_cache_by_coord) << endl; 184 | } 185 | 186 | void CreateQuickLog2Table() { 187 | memset(g_log2_table, -1, sizeof(g_log2_table)); 188 | int tmp = 1; 189 | 190 | for (GoCoordId i = 0; i < 64; ++i) { 191 | g_log2_table[tmp] = i; 192 | tmp *= 2; 193 | tmp %= 67; 194 | } 195 | } 196 | 197 | #if defined(_WIN32) || defined(_WIN64) 198 | static int rand_r(unsigned int *seed) 199 | { 200 | unsigned int next = *seed; 201 | int result; 202 | 203 | next *= 1103515245; 204 | next += 12345; 205 | result = (unsigned int)(next / 65536) % 2048; 206 | 207 | next *= 1103515245; 208 | next += 12345; 209 | result <<= 10; 210 | result ^= (unsigned int)(next / 65536) % 1024; 211 | 212 | next *= 1103515245; 213 | next += 12345; 214 | result <<= 10; 215 | result ^= (unsigned int)(next / 65536) % 1024; 216 | 217 | *seed = next; 218 | 219 | return result; 220 | } 221 
| #endif 222 | 223 | void CreateZobristHash() { 224 | uint32_t seed = 0xdeadbeaf; 225 | 226 | for (int i = 0; i < 4; ++i) { 227 | g_zobrist_player_hash_weight[i] = (uint64_t) rand_r(&seed) << 32 | rand_r(&seed); 228 | for (int j = 0; j < GOBOARD_SIZE; ++j) { 229 | g_zobrist_board_hash_weight[i][j] = (uint64_t) rand_r(&seed) << 32 | rand_r(&seed); 230 | } 231 | } 232 | 233 | for (int i = 0; i < GOBOARD_SIZE; ++i) { 234 | g_zobrist_ko_hash_weight[i] = (uint64_t) rand_r(&seed) << 32 | rand_r(&seed); 235 | } 236 | } 237 | 238 | } // namespace GoFunction 239 | 240 | #undef y 241 | #undef x 242 | -------------------------------------------------------------------------------- /common/go_comm.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | #pragma once 19 | 20 | #include 21 | #include 22 | #include 23 | 24 | // Return code of functions should be "int" 25 | typedef uint8_t GoStoneColor; // Stone color 26 | typedef int16_t GoCoordId; // Stone IDs or coordinates 27 | typedef int16_t GoBlockId; // Block IDs 28 | typedef int16_t GoSize; // Counts of visit times, used blocks, .. 
29 | 30 | namespace GoComm { 31 | 32 | const GoCoordId BORDER_SIZE = 19; 33 | const GoCoordId GOBOARD_SIZE = BORDER_SIZE * BORDER_SIZE; 34 | const GoCoordId COORD_UNSET = -2; 35 | const GoCoordId COORD_PASS = -1; 36 | const GoCoordId COORD_RESIGN = -3; 37 | 38 | const GoSize SIZE_NONE = 0; 39 | 40 | const GoBlockId MAX_BLOCK_SIZE = 1 << 8; 41 | const GoBlockId BLOCK_UNSET = -1; 42 | 43 | const GoStoneColor EMPTY = 0; 44 | const GoStoneColor BLACK = 1; 45 | const GoStoneColor WHITE = 2; 46 | const GoStoneColor WALL = 3; 47 | const GoStoneColor COLOR_UNKNOWN = -1; 48 | const char *const COLOR_STRING[] = { "Empty", "Black", "White", "Wall" }; 49 | 50 | const GoCoordId DELTA_X[] = { 0, 1, 0, -1 }; 51 | const GoCoordId DELTA_Y[] = { -1, 0, 1, 0 }; 52 | const GoSize DELTA_SIZE = sizeof(DELTA_X) / sizeof(*DELTA_X); 53 | 54 | const GoSize UINT64_BITS = sizeof(uint64_t) * 8; 55 | const GoSize LIBERTY_STATE_SIZE = (GOBOARD_SIZE + UINT64_BITS - 1) / UINT64_BITS; 56 | const GoSize BOARD_STATE_SIZE = (GOBOARD_SIZE + UINT64_BITS - 1) / UINT64_BITS; 57 | 58 | } // namespace GoComm 59 | 60 | namespace GoFeature { 61 | 62 | const int SIZE_HISTORYEACHSIDE = 16; 63 | const int SIZE_PLAYERCOLOR = 1; 64 | 65 | const int STARTPOS_HISTORYEACHSIDE = 0; 66 | const int STARTPOS_PLAYERCOLOR = STARTPOS_HISTORYEACHSIDE + SIZE_HISTORYEACHSIDE; 67 | 68 | const int FEATURE_COUNT = STARTPOS_PLAYERCOLOR + SIZE_PLAYERCOLOR; 69 | 70 | } // namespace GoFeature 71 | 72 | 73 | namespace GoFunction { 74 | 75 | extern bool InBoard(const GoCoordId id); 76 | 77 | extern bool InBoard(const GoCoordId x, const GoCoordId y); 78 | 79 | extern bool IsPass(const GoCoordId id); 80 | 81 | extern bool IsPass(const GoCoordId x, const GoCoordId y); 82 | 83 | extern bool IsUnset(const GoCoordId id); 84 | 85 | extern bool IsUnset(const GoCoordId x, const GoCoordId y); 86 | 87 | extern bool IsResign(const GoCoordId id); 88 | 89 | extern bool IsResign(const GoCoordId x, const GoCoordId y); 90 | 91 | 92 | extern void 
IdToCoord(const GoCoordId id, GoCoordId &x, GoCoordId &y); 93 | 94 | extern GoCoordId CoordToId(const GoCoordId x, const GoCoordId y); 95 | 96 | extern void StrToCoord(const std::string &str, GoCoordId &x, GoCoordId &y); 97 | 98 | extern std::string CoordToStr(const GoCoordId x, const GoCoordId y); 99 | 100 | extern std::string IdToStr(const GoCoordId id); 101 | 102 | extern GoCoordId StrToId(const std::string &str); 103 | 104 | 105 | extern void CreateGlobalVariables(); 106 | 107 | extern void CreateHashWeights(); 108 | 109 | extern void CreateNeighbourCache(); 110 | 111 | extern void CreateQuickLog2Table(); 112 | 113 | extern void CreateZobristHash(); 114 | 115 | } // namespace GoFunction 116 | 117 | 118 | typedef std::pair GoPosition; 119 | typedef std::pair GoHashValuePair; 120 | 121 | extern GoHashValuePair g_hash_weight[GoComm::BORDER_SIZE][GoComm::BORDER_SIZE]; 122 | const GoHashValuePair g_hash_unit(3, 7); 123 | extern uint64_t g_zobrist_board_hash_weight[4][GoComm::GOBOARD_SIZE]; 124 | extern uint64_t g_zobrist_ko_hash_weight[GoComm::GOBOARD_SIZE]; 125 | extern uint64_t g_zobrist_player_hash_weight[4]; 126 | 127 | extern std::vector g_neighbour_cache_by_coord[GoComm::BORDER_SIZE][GoComm::BORDER_SIZE]; 128 | extern GoSize g_neighbour_size[GoComm::GOBOARD_SIZE]; 129 | extern GoCoordId g_neighbour_cache_by_id[GoComm::GOBOARD_SIZE][GoComm::DELTA_SIZE + 1]; 130 | extern GoCoordId g_log2_table[67]; 131 | 132 | #define FOR_NEI(id, nb) for (GoCoordId *nb = g_neighbour_cache_by_id[(id)]; \ 133 | GoComm::COORD_UNSET != *nb; ++nb) 134 | #define FOR_EACHCOORD(id) for (GoCoordId id = 0; id < GoComm::GOBOARD_SIZE; ++id) 135 | #define FOR_EACHBLOCK(id) for (GoBlockId id = 0; id < GoComm::MAX_BLOCK_SIZE; ++id) 136 | -------------------------------------------------------------------------------- /common/str_utils.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by 
making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | #include "str_utils.h" 19 | 20 | std::vector SplitStr(const std::string &str, char delim) 21 | { 22 | std::vector ret = {""}; 23 | for (char c: str) { 24 | if (c == delim) { 25 | ret.emplace_back(); 26 | } else { 27 | ret.back() += c; 28 | } 29 | } 30 | return ret; 31 | } 32 | -------------------------------------------------------------------------------- /common/str_utils.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 
17 | */ 18 | #pragma once 19 | 20 | #include 21 | #include 22 | 23 | std::vector SplitStr(const std::string &str, char delim); 24 | -------------------------------------------------------------------------------- /common/task_queue.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 
17 | */ 18 | #pragma once 19 | 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | 26 | template 27 | class TaskQueue 28 | { 29 | public: 30 | TaskQueue(int capacity = 0) 31 | : m_capacity(capacity), m_size(0), m_is_close(false) 32 | { 33 | } 34 | 35 | template 36 | void Push(U &&elem) 37 | { 38 | std::unique_lock lock(m_mutex); 39 | if (m_capacity > 0) { 40 | m_push_cond.wait(lock, [this]{ return m_queue.size() < m_capacity; }); 41 | } 42 | m_queue.push_back(std::forward(elem)); 43 | m_size = m_queue.size(); 44 | lock.unlock(); 45 | m_pop_cond.notify_one(); 46 | } 47 | 48 | template 49 | void PushFront(U &&elem) 50 | { 51 | std::unique_lock lock(m_mutex); 52 | m_queue.push_front(std::forward(elem)); 53 | m_size = m_queue.size(); 54 | lock.unlock(); 55 | m_pop_cond.notify_one(); 56 | } 57 | 58 | bool Pop(T &elem, int64_t timeout_us = -1) 59 | { 60 | std::unique_lock lock(m_mutex); 61 | if (timeout_us < 0) { 62 | m_pop_cond.wait(lock, [this]{ return !m_queue.empty() || m_is_close; }); 63 | } else { 64 | if (!m_pop_cond.wait_for(lock, std::chrono::microseconds(timeout_us), 65 | [this]{ return !m_queue.empty() || m_is_close; })) { 66 | return false; 67 | } 68 | } 69 | if (m_queue.empty()) { 70 | return false; 71 | } 72 | elem = std::move(m_queue.front()); 73 | m_queue.pop_front(); 74 | m_size = m_queue.size(); 75 | lock.unlock(); 76 | if (m_capacity > 0) { 77 | m_push_cond.notify_one(); 78 | } 79 | return true; 80 | } 81 | 82 | void Close() 83 | { 84 | { 85 | std::lock_guard lock(m_mutex); 86 | m_is_close = true; 87 | } 88 | m_push_cond.notify_all(); 89 | m_pop_cond.notify_all(); 90 | } 91 | 92 | bool IsClose() const 93 | { 94 | return m_is_close; 95 | } 96 | 97 | int Size() const 98 | { 99 | return m_size; 100 | } 101 | 102 | private: 103 | std::deque m_queue; 104 | int m_capacity; 105 | std::atomic m_size; 106 | std::mutex m_mutex; 107 | std::condition_variable m_push_cond; 108 | std::condition_variable m_pop_cond; 109 | std::atomic 
m_is_close; 110 | }; 111 | -------------------------------------------------------------------------------- /common/thread_conductor.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | #include "thread_conductor.h" 19 | 20 | #include 21 | 22 | ThreadConductor::ThreadConductor() : m_state(k_pause) 23 | { 24 | } 25 | 26 | ThreadConductor::~ThreadConductor() 27 | { 28 | } 29 | 30 | void ThreadConductor::Pause() 31 | { 32 | m_resume_wg.Wait(); 33 | std::lock_guard lock(m_mutex); 34 | m_state = k_pause; 35 | m_cond.notify_all(); 36 | } 37 | 38 | void ThreadConductor::Resume(int num_threads) 39 | { 40 | { 41 | std::lock_guard lock(m_mutex); 42 | if (m_state == k_running) return; 43 | m_resume_wg.Add(num_threads); 44 | m_pause_wg.Add(num_threads); 45 | m_state = k_running; 46 | } 47 | m_cond.notify_all(); 48 | m_resume_wg.Wait(); 49 | } 50 | 51 | void ThreadConductor::Wait() 52 | { 53 | { 54 | std::unique_lock lock(m_mutex); 55 | m_cond.wait(lock, [this]{ return m_state != k_pause; }); 56 | if (m_state == k_terminate) return; 57 | } 58 | m_resume_wg.Done(); 59 | } 60 | 61 | void ThreadConductor::AckPause() 62 | { 63 | m_pause_wg.Done(); 64 | } 65 | 66 | bool 
ThreadConductor::Join(int64_t timeout_us) 67 | { 68 | return m_pause_wg.Wait(timeout_us); 69 | } 70 | 71 | void ThreadConductor::Sleep(int64_t duration_us) 72 | { 73 | std::unique_lock lock(m_mutex); 74 | m_cond.wait_for(lock, std::chrono::microseconds(duration_us), [this]{ return m_state == k_pause; }); 75 | } 76 | 77 | bool ThreadConductor::IsRunning() 78 | { 79 | return m_state == k_running; 80 | } 81 | 82 | void ThreadConductor::Terminate() 83 | { 84 | Pause(); 85 | Join(); 86 | { 87 | std::lock_guard lock(m_mutex); 88 | m_state = k_terminate; 89 | } 90 | m_cond.notify_all(); 91 | } 92 | 93 | bool ThreadConductor::IsTerminate() 94 | { 95 | return m_state == k_terminate; 96 | } 97 | -------------------------------------------------------------------------------- /common/thread_conductor.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 
17 | */ 18 | #pragma once 19 | 20 | #include 21 | #include 22 | #include 23 | 24 | #include "wait_group.h" 25 | 26 | 27 | class ThreadConductor 28 | { 29 | public: 30 | ThreadConductor(); 31 | ~ThreadConductor(); 32 | 33 | void Pause(); 34 | void Resume(int num_threads); 35 | void Wait(); 36 | void AckPause(); 37 | bool Join(int64_t timeout_us = -1); 38 | void Sleep(int64_t duration_us); 39 | bool IsRunning(); 40 | void Terminate(); 41 | bool IsTerminate(); 42 | 43 | private: 44 | std::atomic m_state; 45 | std::mutex m_mutex; 46 | std::condition_variable m_cond; 47 | 48 | WaitGroup m_resume_wg; 49 | WaitGroup m_pause_wg; 50 | 51 | private: 52 | static const int k_pause = 0; 53 | static const int k_running = 1; 54 | static const int k_terminate = 2; 55 | }; 56 | -------------------------------------------------------------------------------- /common/timer.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 
17 | */ 18 | #include "timer.h" 19 | 20 | using namespace std::chrono; 21 | 22 | Timer::Timer() 23 | : m_start(clock::now()) 24 | { 25 | } 26 | 27 | void Timer::Reset() 28 | { 29 | m_start = clock::now(); 30 | } 31 | 32 | int64_t Timer::sec() const 33 | { 34 | return duration_cast(clock::now() - m_start).count(); 35 | } 36 | 37 | int64_t Timer::ms() const 38 | { 39 | return duration_cast(clock::now() - m_start).count(); 40 | } 41 | 42 | int64_t Timer::us() const 43 | { 44 | return duration_cast(clock::now() - m_start).count(); 45 | } 46 | 47 | float Timer::fsec() const 48 | { 49 | return std::chrono::duration(clock::now() - m_start).count(); 50 | } 51 | 52 | float Timer::fms() const 53 | { 54 | return std::chrono::duration(clock::now() - m_start).count(); 55 | } 56 | 57 | float Timer::fus() const 58 | { 59 | return std::chrono::duration(clock::now() - m_start).count(); 60 | } 61 | -------------------------------------------------------------------------------- /common/timer.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 
17 | */ 18 | #pragma once 19 | 20 | #include 21 | 22 | class Timer 23 | { 24 | public: 25 | Timer(); 26 | void Reset(); 27 | int64_t sec() const; 28 | int64_t ms() const; 29 | int64_t us() const; 30 | float fsec() const; 31 | float fms() const; 32 | float fus() const; 33 | 34 | private: 35 | typedef std::chrono::system_clock clock; 36 | clock::time_point m_start; 37 | }; 38 | -------------------------------------------------------------------------------- /common/wait_group.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 
17 | */ 18 | #include "wait_group.h" 19 | 20 | void WaitGroup::Add(int v) 21 | { 22 | bool notify; 23 | { 24 | std::lock_guard lock(m_mutex); 25 | m_counter += v; 26 | if (m_counter < 0) { 27 | throw std::runtime_error("WaitGroup::Add(): m_counter < 0"); 28 | } 29 | notify = (m_counter == 0); 30 | } 31 | if (notify) { 32 | m_cond.notify_all(); 33 | } 34 | } 35 | 36 | void WaitGroup::Done() 37 | { 38 | Add(-1); 39 | } 40 | 41 | bool WaitGroup::Wait(int64_t timeout_us) 42 | { 43 | std::unique_lock lock(m_mutex); 44 | if (timeout_us < 0) { 45 | m_cond.wait(lock, [this]{ return m_counter == 0; }); 46 | return true; 47 | } else { 48 | return m_cond.wait_for(lock, std::chrono::microseconds(timeout_us), [this]{ return m_counter == 0; }); 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /common/wait_group.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 
17 | */ 18 | #pragma once 19 | 20 | #include 21 | #include 22 | 23 | class WaitGroup 24 | { 25 | public: 26 | void Add(int v = 1); 27 | void Done(); 28 | bool Wait(int64_t timeout_us = -1); 29 | 30 | private: 31 | int m_counter = 0; 32 | std::mutex m_mutex; 33 | std::condition_variable m_cond; 34 | }; 35 | -------------------------------------------------------------------------------- /configure: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | set -o pipefail 5 | 6 | if [ -z "$PYTHON_BIN_PATH" ]; then 7 | PYTHON_BIN_PATH=$(which python || which python3 || true) 8 | fi 9 | 10 | # Set all env variables 11 | CONFIGURE_DIR=$(dirname "$0") 12 | "$PYTHON_BIN_PATH" "${CONFIGURE_DIR}/configure.py" "$@" 13 | 14 | echo "Configuration finished" 15 | 16 | -------------------------------------------------------------------------------- /dist/BUILD: -------------------------------------------------------------------------------- 1 | load("//:rules.bzl", "cc_proto_library", "tf_cc_binary") 2 | 3 | tf_cc_binary( 4 | name = "dist_zero_model_server", 5 | srcs = [ 6 | "dist_zero_model_server.cc", 7 | ], 8 | deps = [ 9 | ":dist_zero_model_cc_proto", 10 | "//common:timer", 11 | "//model:zero_model", 12 | "//model:trt_zero_model", 13 | "@com_github_google_glog//:glog", 14 | ], 15 | ) 16 | 17 | cc_library( 18 | name = "dist_zero_model_client", 19 | srcs = ["dist_zero_model_client.cc"], 20 | hdrs = ["dist_zero_model_client.h"], 21 | deps = [ 22 | ":dist_config_cc_proto", 23 | ":dist_zero_model_cc_proto", 24 | ":leaky_bucket", 25 | "//model:zero_model_base", 26 | "@com_github_google_glog//:glog", 27 | ], 28 | visibility = ["//visibility:public"], 29 | ) 30 | 31 | cc_library( 32 | name = "async_dist_zero_model_client", 33 | srcs = ["async_dist_zero_model_client.cc"], 34 | hdrs = ["async_dist_zero_model_client.h"], 35 | deps = [ 36 | ":dist_config_cc_proto", 37 | ":dist_zero_model_cc_proto", 38 | ":async_rpc_queue", 
39 | ":leaky_bucket", 40 | "//model:zero_model_base", 41 | "@com_github_google_glog//:glog", 42 | ], 43 | visibility = ["//visibility:public"], 44 | ) 45 | 46 | cc_proto_library( 47 | name = "dist_config_cc_proto", 48 | srcs = ["dist_config.proto"], 49 | visibility = ["//visibility:public"], 50 | ) 51 | 52 | cc_proto_library( 53 | name = "dist_zero_model_cc_proto", 54 | srcs = ["dist_zero_model.proto"], 55 | deps = ["//model:model_config_cc_proto"], 56 | use_grpc_plugin = True, 57 | ) 58 | 59 | cc_library( 60 | name = "async_rpc_queue", 61 | hdrs = ["async_rpc_queue.h"], 62 | deps = ["@grpc//:grpc++_unsecure"], 63 | ) 64 | 65 | cc_library( 66 | name = "leaky_bucket", 67 | srcs = ["leaky_bucket.cc"], 68 | hdrs = ["leaky_bucket.h"], 69 | ) 70 | -------------------------------------------------------------------------------- /dist/async_dist_zero_model_client.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 
17 | */ 18 | #include "async_dist_zero_model_client.h" 19 | 20 | #include 21 | 22 | #include 23 | 24 | AsyncDistZeroModelClient::AsyncDistZeroModelClient(const std::vector &svr_addrs, 25 | const DistConfig &dist_config) 26 | : m_config(dist_config), 27 | m_svr_addrs(svr_addrs), 28 | m_bucket_mutexes(new std::mutex[svr_addrs.size()]) 29 | { 30 | CHECK(!m_svr_addrs.empty()); 31 | for (size_t i = 0; i < m_svr_addrs.size(); ++i) { 32 | m_stubs.emplace_back(DistZeroModel::NewStub( 33 | grpc::CreateChannel(m_svr_addrs[i], grpc::InsecureChannelCredentials()))); 34 | m_avail_stubs.push(i); 35 | m_leaky_buckets.emplace_back(m_config.leaky_bucket_size(), m_config.leaky_bucket_refill_period_ms()); 36 | } 37 | m_forward_rpc_complete_thread = std::thread(&AsyncRpcQueue::Complete, &m_forward_rpc_queue, -1); 38 | } 39 | 40 | AsyncDistZeroModelClient::~AsyncDistZeroModelClient() 41 | { 42 | m_forward_rpc_queue.Shutdown(); 43 | LOG(INFO) << "~AsyncDistZeroModelClient waiting async rpc complete thread stop"; 44 | m_forward_rpc_complete_thread.join(); 45 | LOG(INFO) << "~AsyncDistZeroModelClient waiting all stubs released"; 46 | std::unique_lock lock(m_mutex); 47 | m_cond.wait(lock, [this]{ return m_avail_stubs.size() == m_stubs.size(); }); 48 | LOG(INFO) << "~AsyncDistZeroModelClient succ"; 49 | } 50 | 51 | int AsyncDistZeroModelClient::Init(const ModelConfig &model_config) 52 | { 53 | AsyncRpcQueue queue; 54 | InitReq req; 55 | req.mutable_model_config()->CopyFrom(model_config); 56 | int ret = 0; 57 | for (size_t i = 0; i < m_stubs.size(); ++i) { 58 | queue.Call( 59 | BindAsyncRpcFunc(*m_stubs[i], AsyncInit), req, 60 | [this, i, &ret](grpc::Status &status, InitResp &resp) { 61 | if (!status.ok()) { 62 | LOG(ERROR) << "DistZeroModel::Init error, " << m_svr_addrs[i] << ", ret " 63 | << status.error_code() << ": " << status.error_message(); 64 | ret = status.error_code(); 65 | } 66 | } 67 | ); 68 | } 69 | queue.Complete(m_stubs.size()); 70 | return ret; 71 | } 72 | 73 | int 
AsyncDistZeroModelClient::GetGlobalStep(int &global_step) 74 | { 75 | AsyncRpcQueue queue; 76 | GetGlobalStepReq req; 77 | int ret = 0; 78 | std::vector global_steps(m_stubs.size()); 79 | for (size_t i = 0; i < m_stubs.size(); ++i) { 80 | queue.Call( 81 | BindAsyncRpcFunc(*m_stubs[i], AsyncGetGlobalStep), req, 82 | [this, i, &ret, &global_steps](grpc::Status &status, GetGlobalStepResp &resp) { 83 | if (status.ok()) { 84 | global_steps[i] = resp.global_step(); 85 | } else { 86 | LOG(ERROR) << "DistZeroModel::GetGlobalStep error, " << m_svr_addrs[i] << ", ret " 87 | << status.error_code() << ": " << status.error_message(); 88 | ret = status.error_code(); 89 | } 90 | } 91 | ); 92 | } 93 | queue.Complete(m_stubs.size()); 94 | if (ret == 0) { 95 | for (size_t i = 1; i < m_stubs.size(); ++i) { 96 | if (global_steps[i] != global_steps[0]) { 97 | LOG(ERROR) << "Recived different global_step, " 98 | << global_steps[i] << "(" << m_svr_addrs[i] << ")" << " vs " 99 | << global_steps[0] << "(" << m_svr_addrs[0] << ")"; 100 | ret = ERR_GLOBAL_STEP_CONFLICT; 101 | } 102 | } 103 | } 104 | if (ret == 0) { 105 | global_step = global_steps[0]; 106 | } 107 | return ret; 108 | } 109 | 110 | void AsyncDistZeroModelClient::Forward(const std::vector> &inputs, callback_t callback) 111 | { 112 | int stub_id = GetStub(); 113 | ForwardReq req; 114 | for (const auto &features: inputs) { 115 | if (features.size() != INPUT_DIM) { 116 | LOG(ERROR) << "Error input dim not match, need " << INPUT_DIM << ", got " << features.size(); 117 | callback(ERR_INVALID_INPUT, {}, {}); 118 | return; 119 | } 120 | std::string encode_features((INPUT_DIM + 7) / 8, 0); 121 | for (int i = 0; i < INPUT_DIM; ++i) { 122 | encode_features[i / 8] |= (unsigned char)features[i] << (i % 8); 123 | } 124 | req.add_inputs(encode_features); 125 | } 126 | m_forward_rpc_queue.Call( 127 | BindAsyncRpcFunc(*m_stubs[stub_id], AsyncForward), req, 128 | [this, stub_id, callback](grpc::Status &status, ForwardResp &resp) { 129 | if 
(status.ok() && resp.outputs_size() == 0) { 130 | status = grpc::Status(grpc::StatusCode(ERR_EMPTY_RESP), "receive empty response"); 131 | } 132 | 133 | bool release_now = true; 134 | if (!status.ok()) { 135 | m_stubs[stub_id] = DistZeroModel::NewStub( 136 | grpc::CreateChannel(m_svr_addrs[stub_id], grpc::InsecureChannelCredentials())); 137 | if (m_config.enable_leaky_bucket()) { 138 | std::lock_guard lock(m_bucket_mutexes[stub_id]); 139 | m_leaky_buckets[stub_id].ConsumeToken(); 140 | release_now = !m_leaky_buckets[stub_id].Empty(); 141 | } 142 | } 143 | if (release_now) { 144 | ReleaseStub(stub_id); 145 | } else { 146 | DisableStub(stub_id); 147 | } 148 | 149 | if (status.ok()) { 150 | std::vector> policy; 151 | std::vector value; 152 | for (auto &output: resp.outputs()) { 153 | policy.emplace_back(output.policy().begin(), output.policy().end()); 154 | value.push_back(output.value()); 155 | } 156 | callback(0, std::move(policy), std::move(value)); 157 | } else if (status.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) { 158 | LOG(ERROR) << "DistZeroModel::Forward timeout, " << m_svr_addrs[stub_id]; 159 | callback(ERR_FORWARD_TIMEOUT, {}, {}); 160 | } else { 161 | LOG(ERROR) << "DistZeroModel::Forward error, " << m_svr_addrs[stub_id] << " " 162 | << status.error_code() << ": " << status.error_message(); 163 | callback(status.error_code(), {}, {}); 164 | } 165 | }, 166 | m_config.timeout_ms() 167 | ); 168 | } 169 | 170 | int AsyncDistZeroModelClient::Forward(const std::vector> &inputs, 171 | std::vector> &policy, std::vector &value) 172 | { 173 | std::promise>, std::vector>> promise; 174 | Forward(inputs, [&promise](int ret, std::vector> policy, std::vector value) { 175 | promise.set_value(std::make_tuple(ret, std::move(policy), std::move(value))); 176 | }); 177 | int ret; 178 | std::tie(ret, policy, value) = promise.get_future().get(); 179 | return ret; 180 | } 181 | 182 | int AsyncDistZeroModelClient::RpcQueueSize() 183 | { 184 | return 
m_forward_rpc_queue.Size(); 185 | } 186 | 187 | void AsyncDistZeroModelClient::Wait() 188 | { 189 | std::unique_lock lock(m_mutex); 190 | m_cond.wait(lock, [this]{ return !m_avail_stubs.empty(); }); 191 | } 192 | 193 | int AsyncDistZeroModelClient::GetStub() 194 | { 195 | std::unique_lock lock(m_mutex); 196 | m_cond.wait(lock, [this]{ return !m_avail_stubs.empty(); }); 197 | int stub_id = m_avail_stubs.top(); 198 | m_avail_stubs.pop(); 199 | return stub_id; 200 | } 201 | 202 | void AsyncDistZeroModelClient::ReleaseStub(int stub_id) 203 | { 204 | { 205 | std::lock_guard lock(m_mutex); 206 | m_avail_stubs.push(stub_id); 207 | } 208 | m_cond.notify_one(); 209 | } 210 | 211 | void AsyncDistZeroModelClient::DisableStub(int stub_id) 212 | { 213 | LOG(ERROR) << "disable DistZeroModel " << m_svr_addrs[stub_id]; 214 | std::thread([this, stub_id]() { 215 | { 216 | std::lock_guard lock(m_bucket_mutexes[stub_id]); 217 | m_leaky_buckets[stub_id].WaitRefill(); 218 | } 219 | ReleaseStub(stub_id); 220 | LOG(INFO) << "reenable DistZeroModel " << m_svr_addrs[stub_id]; 221 | }).detach(); 222 | } 223 | -------------------------------------------------------------------------------- /dist/async_dist_zero_model_client.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | #pragma once 19 | 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | 26 | #include "model/zero_model_base.h" 27 | 28 | #include "dist/async_rpc_queue.h" 29 | #include "dist/leaky_bucket.h" 30 | #include "dist/dist_config.pb.h" 31 | #include "dist/dist_zero_model.grpc.pb.h" 32 | 33 | class AsyncDistZeroModelClient final : public ZeroModelBase 34 | { 35 | public: 36 | AsyncDistZeroModelClient(const std::vector &svr_addrs, const DistConfig &dist_config); 37 | 38 | ~AsyncDistZeroModelClient() override; 39 | 40 | int Init(const ModelConfig &model_config) override; 41 | 42 | int Forward(const std::vector>& inputs, 43 | std::vector> &policy, std::vector &value) override; 44 | 45 | void Forward(const std::vector> &inputs, callback_t callback) override; 46 | 47 | int GetGlobalStep(int &global_step) override; 48 | 49 | int RpcQueueSize() override; 50 | 51 | void Wait() override; 52 | 53 | private: 54 | int GetStub(); 55 | 56 | void ReleaseStub(int stub_id); 57 | 58 | void DisableStub(int stub_id); 59 | 60 | private: 61 | DistConfig m_config; 62 | std::vector m_svr_addrs; 63 | std::vector> m_stubs; 64 | AsyncRpcQueue m_forward_rpc_queue; 65 | std::thread m_forward_rpc_complete_thread; 66 | 67 | std::stack m_avail_stubs; 68 | std::mutex m_mutex; 69 | std::condition_variable m_cond; 70 | 71 | std::vector m_leaky_buckets; 72 | std::unique_ptr m_bucket_mutexes; 73 | }; 74 | -------------------------------------------------------------------------------- /dist/async_rpc_queue.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 
5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | #include 19 | #include 20 | 21 | #include 22 | 23 | template 24 | using AsyncRpcFunc = std::function> 25 | (grpc::ClientContext*, const Req&, grpc::CompletionQueue*)>; 26 | 27 | #define BindAsyncRpcFunc(stub, func) \ 28 | std::bind(&std::remove_reference::type::func, stub, \ 29 | std::placeholders::_1, std::placeholders::_2, std::placeholders::_3) 30 | 31 | template 32 | using AsyncRpcCallback = std::function; 33 | 34 | struct AsyncClientCallBase 35 | { 36 | virtual ~AsyncClientCallBase() {} 37 | virtual void Complete() = 0; 38 | }; 39 | 40 | template 41 | struct AsyncClientCall: public AsyncClientCallBase 42 | { 43 | Resp resp; 44 | grpc::ClientContext context; 45 | grpc::Status status; 46 | std::unique_ptr> response_reader; 47 | AsyncRpcCallback callback; 48 | 49 | void Complete() override { callback(status, resp); } 50 | }; 51 | 52 | class AsyncRpcQueue 53 | { 54 | 55 | public: 56 | AsyncRpcQueue() 57 | : m_size(0), m_is_shutdown(false) 58 | { 59 | } 60 | 61 | template 62 | void Call(AsyncRpcFunc fn, const Req &req, AsyncRpcCallback callback, int timeout_ms = -1) 63 | { 64 | auto call = new AsyncClientCall; 65 | call->callback = callback; 66 | if (timeout_ms > 0) { 67 | call->context.set_deadline(std::chrono::system_clock::now() + std::chrono::milliseconds(timeout_ms)); 68 | } 69 | call->response_reader = fn(&call->context, req, &m_cq); 70 | 
call->response_reader->Finish(&call->resp, &call->status, (void*)call); 71 | 72 | ++m_size; 73 | } 74 | 75 | void Complete(int n = -1) 76 | { 77 | void *got_tag; 78 | bool ok = false; 79 | for (int i = 0; (n < 0 || i < n) && !m_is_shutdown && m_cq.Next(&got_tag, &ok); ++i) { 80 | std::unique_ptr call(static_cast(got_tag)); 81 | --m_size; 82 | call->Complete(); 83 | } 84 | } 85 | 86 | void Shutdown() 87 | { 88 | m_is_shutdown = true; 89 | m_cq.Shutdown(); 90 | } 91 | 92 | int Size() 93 | { 94 | return m_size; 95 | } 96 | 97 | private: 98 | grpc::CompletionQueue m_cq; 99 | std::atomic m_size; 100 | std::atomic m_is_shutdown; 101 | }; 102 | -------------------------------------------------------------------------------- /dist/dist_config.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | message DistConfig { 4 | int32 timeout_ms = 1; 5 | bool enable_leaky_bucket = 2; 6 | int32 leaky_bucket_size = 3; 7 | int32 leaky_bucket_refill_period_ms = 4; 8 | } 9 | -------------------------------------------------------------------------------- /dist/dist_zero_model.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "model/model_config.proto"; 4 | 5 | message ModelOutput { 6 | repeated float policy = 1; 7 | float value = 2; 8 | } 9 | 10 | message InitReq { ModelConfig model_config = 1; } 11 | message InitResp {} 12 | 13 | message GetGlobalStepReq {} 14 | message GetGlobalStepResp { int32 global_step = 1; } 15 | 16 | message ForwardReq { repeated bytes inputs = 1; } 17 | message ForwardResp { repeated ModelOutput outputs = 1; } 18 | 19 | service DistZeroModel { 20 | rpc Init (InitReq) returns (InitResp) {} 21 | rpc GetGlobalStep (GetGlobalStepReq) returns (GetGlobalStepResp) {} 22 | rpc Forward (ForwardReq) returns (ForwardResp) {} 23 | } 24 | -------------------------------------------------------------------------------- 
/dist/dist_zero_model_client.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | #include "dist_zero_model_client.h" 19 | 20 | #include 21 | 22 | #include 23 | #include 24 | 25 | DistZeroModelClient::DistZeroModelClient(const std::string &server_address, const DistConfig &dist_config) 26 | : m_config(dist_config), 27 | m_server_address(server_address), 28 | m_stub(DistZeroModel::NewStub(grpc::CreateChannel(server_address, grpc::InsecureChannelCredentials()))), 29 | m_leaky_bucket(dist_config.leaky_bucket_size(), dist_config.leaky_bucket_refill_period_ms()) 30 | { 31 | } 32 | 33 | int DistZeroModelClient::Init(const ModelConfig &model_config) 34 | { 35 | InitReq req; 36 | InitResp resp; 37 | 38 | req.mutable_model_config()->CopyFrom(model_config); 39 | 40 | grpc::ClientContext context; 41 | grpc::Status status = m_stub->Init(&context, req, &resp); 42 | 43 | if (status.ok()) { 44 | return 0; 45 | } else { 46 | LOG(ERROR) << "DistZeroModel::Init error, " << m_server_address << ", ret " 47 | << status.error_code() << ": " << status.error_message(); 48 | return status.error_code(); 49 | } 50 | } 51 | 52 | int DistZeroModelClient::GetGlobalStep(int 
&global_step) 53 | { 54 | GetGlobalStepReq req; 55 | GetGlobalStepResp resp; 56 | 57 | grpc::ClientContext context; 58 | grpc::Status status = m_stub->GetGlobalStep(&context, req, &resp); 59 | 60 | if (status.ok()) { 61 | global_step = resp.global_step(); 62 | return 0; 63 | } else { 64 | LOG(ERROR) << "DistZeroModel::GetGlobalStep error, " << m_server_address << ", ret " 65 | << status.error_code() << ": " << status.error_message(); 66 | return status.error_code(); 67 | } 68 | } 69 | 70 | int DistZeroModelClient::Forward(const std::vector>& inputs, 71 | std::vector> &policy, std::vector &value) 72 | { 73 | ForwardReq req; 74 | ForwardResp resp; 75 | 76 | for (const auto &features: inputs) { 77 | if (features.size() != INPUT_DIM) { 78 | LOG(ERROR) << "Error input dim not match, need " << INPUT_DIM << ", got " << features.size(); 79 | return ERR_INVALID_INPUT; 80 | } 81 | std::string encode_features((INPUT_DIM + 7) / 8, 0); 82 | for (int i = 0; i < INPUT_DIM; ++i) { 83 | encode_features[i / 8] |= (unsigned char)features[i] << (i % 8); 84 | } 85 | req.add_inputs(encode_features); 86 | } 87 | 88 | grpc::ClientContext context; 89 | if (m_config.timeout_ms() > 0) { 90 | context.set_deadline(std::chrono::system_clock::now() + std::chrono::milliseconds(m_config.timeout_ms())); 91 | } 92 | grpc::Status status = m_stub->Forward(&context, req, &resp); 93 | 94 | if (!status.ok()) { 95 | m_stub = DistZeroModel::NewStub(grpc::CreateChannel(m_server_address, grpc::InsecureChannelCredentials())); 96 | if (m_config.enable_leaky_bucket()) { 97 | m_leaky_bucket.ConsumeToken(); 98 | } 99 | } 100 | 101 | if (status.ok()) { 102 | policy.clear(); 103 | value.clear(); 104 | for (auto &output: resp.outputs()) { 105 | policy.emplace_back(output.policy().begin(), output.policy().end()); 106 | value.push_back(output.value()); 107 | } 108 | return 0; 109 | } else if (status.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) { 110 | LOG(ERROR) << "DistZeroModel::Forward timeout, " << 
m_server_address; 111 | return ERR_FORWARD_TIMEOUT; 112 | } else { 113 | LOG(ERROR) << "DistZeroModel::Forward error, " << m_server_address << " " 114 | << status.error_code() << ": " << status.error_message(); 115 | return status.error_code(); 116 | } 117 | } 118 | 119 | void DistZeroModelClient::Wait() 120 | { 121 | if (m_config.enable_leaky_bucket() && m_leaky_bucket.Empty()) { 122 | LOG(ERROR) << "disable DistZeroModel " << m_server_address; 123 | m_leaky_bucket.WaitRefill(); 124 | LOG(INFO) << "reenable DistZeroModel " << m_server_address; 125 | } 126 | } 127 | -------------------------------------------------------------------------------- /dist/dist_zero_model_client.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 
17 | */ 18 | #pragma once 19 | 20 | #include 21 | 22 | #include "model/zero_model_base.h" 23 | 24 | #include "dist/leaky_bucket.h" 25 | #include "dist/dist_config.pb.h" 26 | #include "dist/dist_zero_model.grpc.pb.h" 27 | 28 | class DistZeroModelClient final : public ZeroModelBase 29 | { 30 | public: 31 | DistZeroModelClient(const std::string &server_adress, const DistConfig &dist_config); 32 | 33 | int Init(const ModelConfig &model_config) override; 34 | 35 | int Forward(const std::vector>& inputs, 36 | std::vector> &policy, std::vector &value) override; 37 | 38 | int GetGlobalStep(int &global_step) override; 39 | 40 | void Wait() override; 41 | 42 | private: 43 | DistConfig m_config; 44 | std::string m_server_address; 45 | std::unique_ptr m_stub; 46 | LeakyBucket m_leaky_bucket; 47 | }; 48 | -------------------------------------------------------------------------------- /dist/dist_zero_model_server.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 
17 | */ 18 | #include 19 | #include 20 | #include 21 | 22 | #include "common/timer.h" 23 | #include "model/zero_model.h" 24 | #include "model/trt_zero_model.h" 25 | 26 | #include "dist/dist_zero_model.grpc.pb.h" 27 | 28 | DEFINE_string(server_address, "", "Server address."); 29 | DEFINE_int32(gpu, 0, "Use which gpu."); 30 | 31 | class DistZeroModelServiceImpl final : public DistZeroModel::Service 32 | { 33 | public: 34 | grpc::Status Init(grpc::ServerContext *context, const InitReq *req, InitResp *resp) override 35 | { 36 | std::lock_guard lock(m_mutex); 37 | 38 | LOG(INFO) << "Init with config: " << req->model_config().DebugString(); 39 | 40 | if (req->model_config().enable_mkl()) { 41 | ZeroModel::SetMKLEnv(req->model_config()); 42 | } 43 | 44 | m_model.reset(new ZeroModel(FLAGS_gpu)); 45 | if (req->model_config().enable_tensorrt()) { 46 | m_model.reset(new TrtZeroModel(FLAGS_gpu)); 47 | } 48 | 49 | int ret = m_model->Init(req->model_config()); 50 | if (ret == 0) { 51 | LOG(INFO) << "Init model succ"; 52 | return grpc::Status::OK; 53 | } else { 54 | LOG(ERROR) << "Init model error: " << ret; 55 | return grpc::Status(grpc::StatusCode(ret), "Init model error"); 56 | } 57 | } 58 | 59 | grpc::Status GetGlobalStep(grpc::ServerContext *context, 60 | const GetGlobalStepReq *req, GetGlobalStepResp *resp) override 61 | { 62 | std::lock_guard lock(m_mutex); 63 | 64 | if (m_model == nullptr) { 65 | return grpc::Status(grpc::StatusCode(-1), "DistZeroModel hasn't init"); 66 | } 67 | 68 | int global_step; 69 | int ret = m_model->GetGlobalStep(global_step); 70 | 71 | if (ret == 0) { 72 | LOG(INFO) << "Get global_step=" << global_step; 73 | resp->set_global_step(global_step); 74 | return grpc::Status::OK; 75 | } else { 76 | LOG(INFO) << "Get global_step error: " << ret; 77 | return grpc::Status(grpc::StatusCode(ret), "Get global_step error"); 78 | } 79 | } 80 | 81 | grpc::Status Forward(grpc::ServerContext *context, const ForwardReq *req, ForwardResp *resp) override 82 | { 83 | 
Timer timer; 84 | std::lock_guard lock(m_mutex); 85 | 86 | if (m_model == nullptr) { 87 | return grpc::Status(grpc::StatusCode(-1), "DistZeroModel hasn't init"); 88 | } 89 | 90 | std::vector> inputs; 91 | for (const auto &encode_features: req->inputs()) { 92 | if ((int)encode_features.size() * 8 < m_model->INPUT_DIM) { 93 | LOG(ERROR) << "Error input features need " << m_model->INPUT_DIM << " bits, recv only " 94 | << encode_features.size() * 8; 95 | return grpc::Status(grpc::StatusCode(ERR_INVALID_INPUT), "Forward error"); 96 | } 97 | std::vector features(m_model->INPUT_DIM); 98 | for (int i = 0; i < m_model->INPUT_DIM; ++i) { 99 | features[i] = (unsigned char)encode_features[i / 8] >> (i % 8) & 1; 100 | } 101 | inputs.push_back(std::move(features)); 102 | } 103 | 104 | std::vector> policy; 105 | std::vector value; 106 | int ret = m_model->Forward(inputs, policy, value); 107 | 108 | if (ret == 0) { 109 | for (size_t i = 0; i < policy.size(); ++i) { 110 | auto *output = resp->add_outputs(); 111 | for (const auto &p: policy[i]) { 112 | output->add_policy(p); 113 | } 114 | output->set_value(value[i]); 115 | } 116 | if (resp->outputs_size() == 0) LOG(ERROR) << "input batch size " << inputs.size() << ", output 0!!!"; 117 | LOG_EVERY_N(INFO, 1000) << "Forward succ."; 118 | return grpc::Status::OK; 119 | } else { 120 | LOG(ERROR) << "Forward error: " << ret; 121 | return grpc::Status(grpc::StatusCode(ret), "Forward error"); 122 | } 123 | } 124 | 125 | private: 126 | std::unique_ptr m_model; 127 | std::mutex m_mutex; 128 | }; 129 | 130 | int main(int argc, char *argv[]) 131 | { 132 | google::ParseCommandLineFlags(&argc, &argv, true); 133 | google::InitGoogleLogging(argv[0]); 134 | google::InstallFailureSignalHandler(); 135 | 136 | DistZeroModelServiceImpl service; 137 | 138 | grpc::ServerBuilder builder; 139 | builder.AddListeningPort(FLAGS_server_address, grpc::InsecureServerCredentials()); 140 | builder.RegisterService(&service); 141 | std::unique_ptr 
server(builder.BuildAndStart()); 142 | LOG(INFO) << "Server listening on " << FLAGS_server_address; 143 | server->Wait(); 144 | } 145 | -------------------------------------------------------------------------------- /dist/leaky_bucket.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 
17 | */ 18 | #include 19 | 20 | #include "leaky_bucket.h" 21 | 22 | LeakyBucket::LeakyBucket(int bucket_size, int refill_period_ms) 23 | : m_bucket_size(bucket_size), m_tokens(bucket_size), 24 | m_refill_period(std::chrono::milliseconds(refill_period_ms)), 25 | m_last_refill(clock::now()) 26 | { 27 | } 28 | 29 | void LeakyBucket::ConsumeToken() 30 | { 31 | auto now = clock::now(); 32 | if (now - m_last_refill > m_refill_period) { 33 | m_last_refill = now; 34 | m_tokens = m_bucket_size; 35 | } 36 | --m_tokens; 37 | } 38 | 39 | bool LeakyBucket::Empty() 40 | { 41 | return m_tokens <= 0; 42 | } 43 | 44 | void LeakyBucket::WaitRefill() 45 | { 46 | if (m_tokens <= 0) { 47 | std::this_thread::sleep_until(m_last_refill + m_refill_period); 48 | m_last_refill = clock::now(); 49 | m_tokens = m_bucket_size; 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /dist/leaky_bucket.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 
17 | */ 18 | #pragma once 19 | 20 | #include 21 | 22 | class LeakyBucket { 23 | public: 24 | LeakyBucket(int bucket_size, int refill_period_ms); 25 | void ConsumeToken(); 26 | bool Empty(); 27 | void WaitRefill(); 28 | 29 | private: 30 | int m_bucket_size; 31 | int m_tokens; 32 | typedef std::chrono::system_clock clock; 33 | clock::duration m_refill_period; 34 | clock::time_point m_last_refill; 35 | }; 36 | -------------------------------------------------------------------------------- /docs/benchmark-teslaV100.md: -------------------------------------------------------------------------------- 1 | # Benchmark setup : 2 | 3 | ## setup 4 | - hardware : google cloud machine with Tesla V100-SXM2-16GB, 4 and 5 | 12vcpu (skylake server or later, support avx512, google cloud platform 6 | currently does not allow more than 12 vcpu per die, so maximum for 1 7 | GPU is 12 vcpu = 6 physical cores / 12 cpu threads), 8 | 16gb system ram, 40 gb hdd 9 | - software : ubuntu 18.04 LTS, nvidia 410, cuda 10.0, cudnn 7.4.2, 10 | no tensorrt, bazel 0.11.1 11 | - engine settings: limited time 60 seconds per move, all other time 12 | management settings disabled in config file, all the rest is default 13 | settings 14 | 15 | ## methodology : 16 | - most moves come from the same game played using 17 | [gtp2ogs](https://github.com/online-go/gtp2ogs), for few moves moves, 18 | copy paste stderr output 19 | - tensorRT is not used with the V100 here, because it would need to 20 | build our own tensor model, which was not done here, see 21 | [FAQ question](#a13-i-have-a-nvidia-rtx-card-turing-or-tesla-v100titan-v-volta-is-it-compatible-) 22 | for details 23 | 24 | ## credits : 25 | - credit for doing this tests go to 26 | [wonderingabout](https://github.com/wonderingabout) 27 | - credit for providing the hardware goes to google cloud 28 | platform 29 | 30 | # BATCH SIZE 4 31 | 32 | - batch size 4 33 | - tensorrt : OFF 34 | - 8 threads 35 | - children : 64 36 | - 400M tree size 37 | - 
unlimited sims 38 | - 60s per move 39 | 40 | ### 4 vcpu (2 physical cores/ 4 cpu threads) : 41 | 42 | ` 43 | stderr: 4th move(w): pp, winrate=56.125683%, N=20226, Q=0.122514, p=0.728064, v=0.114009, cost 60014.109375ms, sims=22976, height=46, avg_height=12.719517, global_step=639200 44 | ` 45 | 46 | ### 12 vcpu (6 physical cores/ 12 cpu threads) : 47 | 48 | ` 49 | stderr: 2th move(w): pd, winrate=56.061207%, N=6544, Q=0.121224, p=0.212260, v=0.106201, cost 60021.523438ms, sims=23096, height=30, avg_height=9.461098, global_step=639200 50 | ` 51 | 52 | # BATCH SIZE 8 53 | 54 | - batch size 8 55 | - tensorrt : OFF 56 | - 16 threads 57 | - children : 96 58 | - 2000M tree size 59 | - unlimited sims 60 | - 60s per move 61 | 62 | ### 4 vcpu (2 physical cores/ 4 cpu threads) : 63 | 64 | ` 65 | stderr: 8th move(w): nq, winrate=56.569016%, N=27010, Q=0.131380, p=0.591054, v=0.117058, cost 60036.335938ms, sims=32152, height=59, avg_height=14.057215, global_step=639200 66 | ` 67 | 68 | ### 12 vcpu (6 physical cores/ 12 cpu threads) : 69 | 70 | ` 71 | stderr: 4th move(w): pp, winrate=56.120705%, N=28938, Q=0.122414, p=0.715722, v=0.114932, cost 60016.148438ms, sims=32728, height=54, avg_height=13.301727, global_step=639200 72 | ` 73 | 74 | # BATCH SIZE 16 75 | 76 | - batch size 16 77 | - tensorrt : OFF 78 | - 32 threads 79 | - children : 128 80 | - 2000M tree size 81 | - unlimited sims 82 | - 60s per move 83 | 84 | ### 4 vcpu (2 physical cores/ 4 cpu threads) : 85 | 86 | ` 87 | stderr: 2th move(w): dp, winrate=56.057503%, N=15696, Q=0.121150, p=0.207913, v=0.105841, cost 60048.324219ms, sims=53568, height=34, avg_height=9.826570, global_step=639200 88 | ` 89 | 90 | ### 12 vcpu (6 physical cores/ 12 cpu threads) : 91 | 92 | ` 93 | stderr: 6th move(w): qn, winrate=56.170525%, N=29628, Q=0.123410, p=0.306212, v=0.111480, cost 60020.058594ms, sims=53968, height=71, avg_height=14.943110, global_step=639200 94 | ` 95 | 96 | # BATCH SIZE 32 97 | 98 | - batch size 32 99 | - tensorrt : 
OFF 100 | - 64 threads 101 | - children : 128 102 | - 2000M tree size 103 | - unlimited sims 104 | - 60s per move 105 | 106 | ### 4 vcpu (2 physical cores/ 4 cpu threads) : 107 | 108 | ` 109 | stderr: 10th move(w): cp, winrate=65.475777%, N=58821, Q=0.309516, p=0.886150, v=0.196629, cost 60111.078125ms, sims=59444, height=53, avg_height=10.118464, global_step=639200 110 | ` 111 | 112 | ### 12 vcpu (6 physical cores/ 12 cpu threads) : 113 | 114 | ` 115 | stderr: 8th move(w): qf, winrate=56.714546%, N=55717, Q=0.134291, p=0.613282, v=0.114160, cost 60048.957031ms, sims=62368, height=64, avg_height=13.618808, global_step=639200 116 | ` 117 | 118 | # BATCH SIZE 64 119 | 120 | - batch size 64 121 | - tensorrt : OFF 122 | - 32 threads 123 | - children : 128 124 | - 2000M tree size 125 | - unlimited sims 126 | - 60s per move 127 | 128 | ### 4 vcpu (2 physical cores/ 4 cpu threads) : 129 | 130 | ` 131 | stderr: 12th move(w): bo, winrate=65.815079%, N=64431, Q=0.316302, p=0.884180, v=0.226788, cost 60263.683594ms, sims=64960, height=54, avg_height=12.711725, global_step=639200 132 | ` 133 | 134 | ### 12 vcpu (6 physical cores/ 12 cpu threads) : 135 | 136 | ` 137 | stderr: 10th move(w): pc, winrate=65.360603%, N=67943, Q=0.307212, p=0.887373, v=0.175454, cost 60165.914062ms, sims=69031, height=63, avg_height=7.840148, global_step=639200 138 | ` 139 | 140 | # BATCH SIZE 128 141 | 142 | - batch size 128 143 | - tensorrt : OFF 144 | - 256 threads 145 | - children : 128 146 | - 2000M tree size 147 | - unlimited sims 148 | - 60s per move 149 | 150 | ### 4 vcpu (2 physical cores/ 4 cpu threads) : 151 | 152 | ` 153 | stderr: 16th move(w): rf, winrate=65.895859%, N=66225, Q=0.317917, p=0.937327, v=0.202470, cost 60232.250000ms, sims=67328, height=44, avg_height=10.560142, global_step=639200 154 | ` 155 | 156 | ### 12 vcpu (6 physical cores/ 12 cpu threads) : 157 | 158 | ` 159 | stderr: 12th move(w): ob, winrate=65.983253%, N=70697, Q=0.319665, p=0.920173, v=0.223816, cost 
60312.035156ms, sims=71664, height=49, avg_height=6.786881, global_step=639200
`

# CONCLUSIONS for Tesla V100 :

- all the conclusions below are without tensorRT optimization, which
is known to bring 15-30% extra computation performance depending on
hardware and settings :

```
-> for batch size 4 to 8 , gain = +43%
-> for batch size 8 to 16 , gain = +60%
-> for batch size 16 to 32 , gain = +14%
-> for batch size 32 to 64 , gain = +10%
-> for batch size 64 to 128 , gain = +3.7%
```

- batch sizes 8 and 16 bring a significant speed increase on Tesla
V100 with 6 cores / 12 cpu threads or less
- batch sizes higher than 16, up to 64, bring only a small speed increase
on Tesla V100 with 6 cores / 12 cpu threads or less, and considering
the loss of computing accuracy, this is not an efficient choice
- therefore, the most efficient batch size seems to be 16, providing
**an average 900 simulations per second on Tesla V100**, and a 135%
speed increase as compared to batch size 4
- batch sizes higher than 64 do not bring significant speed increases on
Tesla V100 with 6 cores / 12 cpu threads or less

- on the CPU side, as of February 2019, PhoenixGo engine does not
significantly benefit from a number of cpu threads higher than 2
cpu cores/ 4 cpu threads, even on Tesla V100

For comparison, you can refer to
[gtx-1060-75w-benchmark](benchmark-gtx1060-75w.md)
--------------------------------------------------------------------------------
/docs/go-review-partner.md:
--------------------------------------------------------------------------------
GoReviewPartner github page is [here](https://github.com/pnprog/goreviewpartner)

Support is still in testing, but so far these are the recommended
settings to use PhoenixGo with GoReviewPartner

3 config files are provided
to you using optimized settings for grp 7 | (GoReviewPartner) : 8 | - GPU with tensorRT (linux only) : 9 | [mcts_1gpu_grp.conf](/etc/mcts_1gpu_grp.conf) 10 | - GPU without tensorRT (linux, windows) : 11 | [mcts_1gpu_notensorrt_grp.conf](/etc/mcts_1gpu_notensorrt_grp.conf) 12 | - CPU-only (linux, mac, windows), much slower : 13 | [mcts_cpu_grp.conf](/etc/mcts_cpu_grp.conf) 14 | 15 | note : it is also possible with multiple GPU, only the most common 16 | examples were shown here 17 | 18 | if you want to do the changes manually, you need to : 19 | 20 | in phoenixgo .conf config file : 21 | 22 | - disable all time settings in config file (set to `0`) 23 | - set time to unlimited (set timeout to 0 ms) 24 | - use simulations per move to have fixed computation per move (playouts), 25 | because it is not needed to play fast since this is an analysis, 26 | unlimited time 27 | - for the same reason, set `enable background search` to 0 (disable 28 | pondering), because this is not a live game, time settings are not 29 | needed and can cause conflicts 30 | - add debug width and height with `debugger` in config file, see 31 | [FAQ question](/docs/FAQ.md/#a2-where-is-the-pv-analysis-) 32 | 33 | in grp (GoReviewPartner) settings : 34 | - it is easier to use one of the pre-made grp profile in config.ini 35 | (slightly modify them if needed) 36 | 37 | - Don't forget to add paths in the config file as explained in 38 | [FAQ question](/docs/FAQ.md/#a5-ckptzerockpt-20b-v1fp32plan-error-no-such-file-or-directory) 39 | 40 | - If you're on windows, you need to also pay attention to the syntax 41 | too, for example on windows it is `--logtostderr --v 1` not 42 | `--logtostderr --v=1` , see 43 | [FAQ question](/docs/FAQ.md/#a4-syntax-error-windows) for details and 44 | other differences 45 | 46 | - and run the mcts_main (not start.sh) with the needed parameters, 47 | see an example 48 | 
[here](https://github.com/wonderingabout/goreviewpartner/blob/config-profiles-phoenixgo/config.ini#L100-L116) 49 | 50 | also, see [#86](https://github.com/Tencent/PhoenixGo/issues/86) and 51 | [#99](https://github.com/pnprog/goreviewpartner/issues/99) for details 52 | -------------------------------------------------------------------------------- /docs/mcts-main-help.md: -------------------------------------------------------------------------------- 1 | For your convenience, a copy of `./mcts_main --help` output is provided to 2 | you below 3 | 4 | Date is January 2019, so it may be outdated if newer versions are released 5 | 6 | `./mcts_main --help` 7 | 8 | Outputs the below result : 9 | 10 | ``` 11 | mcts_main: Warning: SetUsageMessage() never called 12 | 13 | Flags from external/com_github_gflags_gflags/src/gflags.cc: 14 | -flagfile (load flags from file) type: string default: "" 15 | -fromenv (set flags from the environment [use 'export FLAGS_flag1=value']) 16 | type: string default: "" 17 | -tryfromenv (set flags from the environment if present) type: string 18 | default: "" 19 | -undefok (comma-separated list of flag names that it is okay to specify on 20 | the command line even if the program does not define a flag with that 21 | name. IMPORTANT: flags in this list that have arguments MUST use the 22 | flag=value format) type: string default: "" 23 | 24 | Flags from external/com_github_gflags_gflags/src/gflags_completions.cc: 25 | -tab_completion_columns (Number of columns to use in output for tab 26 | completion) type: int32 default: 80 27 | -tab_completion_word (If non-empty, HandleCommandLineCompletions() will 28 | hijack the process and attempt to do bash-style command line flag 29 | completion on this value.) 
type: string default: "" 30 | 31 | Flags from external/com_github_gflags_gflags/src/gflags_reporting.cc: 32 | -help (show help on all flags [tip: all flags can have two dashes]) 33 | type: bool default: false currently: true 34 | -helpfull (show help on all flags -- same as -help) type: bool 35 | default: false 36 | -helpmatch (show help on modules whose name contains the specified substr) 37 | type: string default: "" 38 | -helpon (show help on the modules named by this flag value) type: string 39 | default: "" 40 | -helppackage (show help on all modules in the main package) type: bool 41 | default: false 42 | -helpshort (show help on only the main module for this program) type: bool 43 | default: false 44 | -helpxml (produce an xml version of help) type: bool default: false 45 | -version (show version and build info and exit) type: bool default: false 46 | 47 | 48 | 49 | Flags from external/com_github_google_glog/src/logging.cc: 50 | -alsologtoemail (log messages go to these email addresses in addition to 51 | logfiles) type: string default: "" 52 | -alsologtostderr (log messages go to stderr in addition to logfiles) 53 | type: bool default: false 54 | -colorlogtostderr (color messages logged to stderr (if supported by 55 | terminal)) type: bool default: false 56 | -drop_log_memory (Drop in-memory buffers of log contents. Logs can grow 57 | very quickly and they are rarely read before they need to be evicted from 58 | memory. Instead, drop them from memory as soon as they are flushed to 59 | disk.) type: bool default: true 60 | -log_backtrace_at (Emit a backtrace when logging at file:linenum.) 61 | type: string default: "" 62 | -log_dir (If specified, logfiles are written into this directory instead of 63 | the default logging directory.) 
type: string default: "" 64 | -log_link (Put additional links to the log files in this directory) 65 | type: string default: "" 66 | -log_prefix (Prepend the log prefix to the start of each log line) 67 | type: bool default: true 68 | -logbuflevel (Buffer log messages logged at this level or lower (-1 means 69 | don't buffer; 0 means buffer INFO only; ...)) type: int32 default: 0 70 | -logbufsecs (Buffer log messages for at most this many seconds) type: int32 71 | default: 30 72 | -logemaillevel (Email log messages logged at this level or higher (0 means 73 | email all; 3 means email FATAL only; ...)) type: int32 default: 999 74 | -logfile_mode (Log file mode/permissions.) type: int32 default: 436 75 | -logmailer (Mailer used to send logging email) type: string 76 | default: "/bin/mail" 77 | -logtostderr (log messages go to stderr instead of logfiles) type: bool 78 | default: false 79 | -max_log_size (approx. maximum log file size (in MB). A value of 0 will be 80 | silently overridden to 1.) type: int32 default: 1800 81 | -minloglevel (Messages logged at a lower level than this don't actually get 82 | logged anywhere) type: int32 default: 0 83 | -stderrthreshold (log messages at or above this level are copied to stderr 84 | in addition to logfiles. This flag obsoletes --alsologtostderr.) 85 | type: int32 default: 2 86 | -stop_logging_if_full_disk (Stop attempting to log to disk if the disk is 87 | full.) type: bool default: false 88 | 89 | Flags from external/com_github_google_glog/src/vlog_is_on.cc: 90 | -v (Show all VLOG(m) messages for m <= this. Overridable by --vmodule.) 91 | type: int32 default: 0 92 | -vmodule (per-module verbose level. Argument is a comma-separated list of 93 | =. is a glob pattern, matched 94 | against the filename base (that is, name ignoring .cc/.h./-inl.h). overrides any value given by --v.) 
type: string default: "" 96 | 97 | 98 | 99 | Flags from mcts/mcts_main.cc: 100 | -allow_ip (List of client ip allowed to connect, seperated by comma.) 101 | type: string default: "" 102 | -config_path (Path of mcts config file.) type: string default: "" 103 | -fork_per_request (Fork for each request or not.) type: bool default: true 104 | -gpu_list (List of gpus used by neural network.) type: string default: "" 105 | -gtp (Run as gtp server.) type: bool default: false 106 | -init_moves (Initialize Go board with init_moves.) type: string default: "" 107 | -inter_op_parallelism_threads (Number of tf's inter op threads) type: int32 108 | default: 0 109 | -intra_op_parallelism_threads (Number of tf's intra op threads) type: int32 110 | default: 0 111 | -listen_port (Listen which port.) type: int32 default: 0 112 | 113 | ``` 114 | -------------------------------------------------------------------------------- /docs/minimalist-bazel-configure.md: -------------------------------------------------------------------------------- 1 | To reduce phoenixgo size and building time, it is wise to remove all 2 | the uneeded options during the bazel `./configure` , as was discussed 3 | here : [#76](https://github.com/Tencent/PhoenixGo/issues/76) 4 | 5 | This is an example of minimalist options that you can use : 6 | 7 | note : if you have trouble with path configurations, see 8 | [path-errors](/docs/path-errors.md) 9 | 10 | ``` 11 | Please specify the location of python. [Default is /usr/bin/python]: 12 | 13 | 14 | Found possible Python library paths: 15 | /usr/local/lib/python2.7/dist-packages 16 | /usr/lib/python2.7/dist-packages 17 | Please input the desired Python library path to use. Default is [/usr/local/lib/python2.7/dist-packages] 18 | 19 | Do you wish to build TensorFlow with jemalloc as malloc support? [Y/n]: 20 | jemalloc as malloc support will be enabled for TensorFlow. 21 | 22 | Do you wish to build TensorFlow with Google Cloud Platform support? 
[Y/n]: n 23 | No Google Cloud Platform support will be enabled for TensorFlow. 24 | 25 | Do you wish to build TensorFlow with Hadoop File System support? [Y/n]: n 26 | No Hadoop File System support will be enabled for TensorFlow. 27 | 28 | Do you wish to build TensorFlow with Amazon S3 File System support? [Y/n]: n 29 | No Amazon S3 File System support will be enabled for TensorFlow. 30 | 31 | Do you wish to build TensorFlow with Apache Kafka Platform support? [Y/n]: n 32 | No Apache Kafka Platform support will be enabled for TensorFlow. 33 | 34 | Do you wish to build TensorFlow with XLA JIT support? [y/N]: y 35 | XLA JIT support will be enabled for TensorFlow. 36 | 37 | Do you wish to build TensorFlow with GDR support? [y/N]: y 38 | GDR support will be enabled for TensorFlow. 39 | 40 | Do you wish to build TensorFlow with VERBS support? [y/N]: y 41 | VERBS support will be enabled for TensorFlow. 42 | 43 | Do you wish to build TensorFlow with OpenCL SYCL support? [y/N]: 44 | No OpenCL SYCL support will be enabled for TensorFlow. 45 | 46 | Do you wish to build TensorFlow with CUDA support? [y/N]: y 47 | CUDA support will be enabled for TensorFlow. 48 | 49 | Please specify the CUDA SDK version you want to use, e.g. 7.0. [Leave empty to default to CUDA 9.0]: 50 | 51 | 52 | Please specify the location where CUDA 9.0 toolkit is installed. Refer to README.md for more details. [Default is /usr/local/cuda]: 53 | 54 | 55 | Please specify the cuDNN version you want to use. [Leave empty to default to cuDNN 7.0]: 7.0.5 56 | 57 | 58 | Please specify the location where cuDNN 7 library is installed. Refer to README.md for more details. [Default is /usr/local/cuda]: 59 | 60 | 61 | Do you wish to build TensorFlow with TensorRT support? [y/N]: y 62 | TensorRT support will be enabled for TensorFlow. 63 | 64 | Please specify the location where TensorRT is installed. [Default is /usr/lib/x86_64-linux-gnu]: 65 | 66 | 67 | Please specify the NCCL version you want to use. 
[Leave empty to default to NCCL 1.3]: 68 | 69 | 70 | Please specify a list of comma-separated Cuda compute capabilities you want to build with. 71 | You can find the compute capability of your device at: https://developer.nvidia.com/cuda-gpus. 72 | Please note that each additional compute capability significantly increases your build time and binary size. [Default is: 6.1] 73 | 74 | 75 | Do you want to use clang as CUDA compiler? [y/N]: 76 | nvcc will be used as CUDA compiler. 77 | 78 | Please specify which gcc should be used by nvcc as the host compiler. [Default is /usr/bin/gcc]: 79 | 80 | 81 | Do you wish to build TensorFlow with MPI support? [y/N]: 82 | No MPI support will be enabled for TensorFlow. 83 | 84 | Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native]: 85 | 86 | 87 | Would you like to interactively configure ./WORKSPACE for Android builds? [y/N]: n 88 | 89 | ``` 90 | -------------------------------------------------------------------------------- /docs/path-errors.md: -------------------------------------------------------------------------------- 1 | In the example below, ubuntu 16.04 LTS is used with cuda 9.0 (deb install), 2 | cudnn 7.0.5 (deb install), tensorrt 3.0.4 (deb install), as well as bazel 0.18.1 3 | (.sh file install) (0.11.1 also works) 4 | 5 | Other linux distributions with nvidia tar install are possible too 6 | (some other versions have been [tested here](/docs/tested-versions.md)), 7 | but it is easier to do it like that, and remember that this is just an example. 8 | The settings below have been tested to be working and to fix most common path 9 | issues, and are shown as an interactive help : 10 | 11 | ## 1) post-install of cuda : do the path exports 12 | 13 | After cuda 9.0 deb install and cudnn 7.0.5 deb are installed successfully, 14 | one post install step is needed : add the path to cuda-9.0. 
Here we do an 15 | all in once step by also including the path of cudnn even if cudnn is not 16 | installed yet. 17 | 18 | The `export` command alone adds the path during current boot, but the 19 | changes will be lost after reboot. This is why we will add paths permanently 20 | using `bashrc` 21 | 22 | ### bashrc 23 | 24 | edit bashrc file : 25 | 26 | `sudo nano ~/.bashrc` 27 | 28 | you need to add the lines below at the end of the bashrc file (using nano, 29 | or the text editor of your choice) : 30 | 31 | ``` 32 | # add paths for cuda-9.0 33 | export PATH=${PATH}:/usr/local/cuda-9.0/bin 34 | 35 | # in case you use tensorrt tar install 36 | # add paths for cuda and cudnn install paths 37 | export CUDA_INSTALL_DIR=/usr/local/cuda 38 | export CUDNN_INSTALL_DIR=/usr/local/cuda 39 | ``` 40 | 41 | With nano, after editing is finished, save and exit with `Ctrl+X` + 42 | then press `y` + then press ENTER key. 43 | 44 | Now update bashrc file : 45 | 46 | `source ~/.bashrc` 47 | 48 | as you can now see, now the command nvcc --version will work successfully : 49 | 50 | `nvcc --version` 51 | 52 | should display something like this : 53 | 54 | ``` 55 | nvcc: NVIDIA (R) Cuda compiler driver 56 | Copyright (c) 2005-2017 NVIDIA Corporation 57 | Built on Fri_Sep__1_21:08:03_CDT_2017 58 | Cuda compilation tools, release 9.0, V9.0.176 59 | ``` 60 | 61 | you can also check if cuda installation is a success with 62 | `cat /proc/driver/nvidia/version` 63 | 64 | should display something like this : 65 | 66 | ``` 67 | NVRM version: NVIDIA UNIX x86_64 Kernel Module 384.130 Wed Mar 21 03:37:26 PDT 2018 68 | GCC version: gcc version 5.4.0 20160609 (Ubuntu 5.4.0-6ubuntu1~16.04.10) 69 | ``` 70 | 71 | If you need help of how to install the deb files of cuda 9.0, cudnn 7.0.5, 72 | and tensorrt 3.0.4 for ubuntu 16.04, you can go here 73 | [@wonderingabout](https://github.com/wonderingabout/nvidia-archives) 74 | 75 | ## 1b) after cuda 9.0 and cudnn 7.0.5 deb installs, test your cudnn 76 | 77 | 
run a test to see if your install works by compiling and runing a 78 | cudnn code sample : 79 | 80 | ``` 81 | cp -r /usr/src/cudnn_samples_v7/ ~ && cd ~/cudnn_samples_v7/mnistCUDNN && make clean && make && ./mnistCUDNN 82 | ``` 83 | 84 | should display this : `Test passed!` 85 | 86 | ## 2) locate cuda and cudnn paths, and update database if not here 87 | 88 | Run this command : `locate libcudart.so && locate libcudnn.so.7` 89 | 90 | you need to see something like this : 91 | 92 | ``` 93 | /usr/local/cuda-9.0/doc/man/man7/libcudart.so.7 94 | /usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart.so 95 | /usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart.so.9.0 96 | /usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart.so.9.0.176 97 | /usr/lib/x86_64-linux-gnu/libcudnn.so.7 98 | /usr/lib/x86_64-linux-gnu/libcudnn.so.7.0.5 99 | ``` 100 | If you don't see this, run this command : 101 | 102 | `sudo updatedb && locate libcudart.so && locate libcudnn.so.7` 103 | 104 | It should now display all the cuda and cudnn paths same as above. 105 | 106 | ## 3) during bazel compile, this is the paths you need to put 107 | 108 | Press ENTER for every prompt to choose default settings 109 | (or `n` if you dont want a setting), except for these : 110 | 111 | - CUDA : choose `y` , version `9.0` if it is not default 112 | (then leave path to default by pressing ENTER) 113 | - cudnn : choose version `7.0.5` 114 | (then leave default path by pressing ENTER) 115 | - if you want to use tensorrt do `y` and press enter to keep default 116 | path if you did tensorrt deb install, but for tar install you need to setup 117 | manually your path, as you can see here 118 | [@wonderingabout](https://github.com/wonderingabout/nvidia-archives) 119 | 120 | same as below : 121 | 122 | ``` 123 | Do you wish to build TensorFlow with CUDA support? [y/N]: y 124 | CUDA support will be enabled for TensorFlow. 125 | 126 | Please specify the CUDA SDK version you want to use, e.g. 7.0. 
[Leave empty to default to CUDA 9.0]: 127 | 128 | Please specify the location where CUDA 9.0 toolkit is installed. Refer to README.md for more details. [Default is /usr/local/cuda]: 129 | 130 | Please specify the cuDNN version you want to use. [Leave empty to default to cuDNN 7.0]: 7.0.5 131 | 132 | Please specify the location where cuDNN 7 library is installed. Refer to README.md for more details. [Default is /usr/local/cuda-9.0/]: 133 | 134 | Do you wish to build TensorFlow with TensorRT support? [y/N]: y 135 | TensorRT support will be enabled for TensorFlow. 136 | 137 | Please specify the location where TensorRT is installed. [Default is /usr/lib/x86_64-linux-gnu]: 138 | ``` 139 | 140 | note : (tar install is needed to install `.whl` to inrcease max batch size 141 | with tensorrt), see : [#75](https://github.com/Tencent/PhoenixGo/issues/75) 142 | 143 | Final words : 144 | 145 | Remember that these settings are just an example, other settings or 146 | package versions or linux distributions are possible too, but this 147 | example has been tested to successfully work on ubuntu 16.04 LTS with 148 | deb install of cuda 9.0, deb install of cudnn 7.0.5, deb install of 149 | tensorrt 3.0.4, as well as .sh file install of bazel 0.11.1 150 | 151 | They are provided as a general help for linux compile and run, they 152 | are not an obligatory method to use, but will hopefully make using 153 | PhoenixGo on linux systems easier 154 | 155 | credits : 156 | - [mishra.thedeepak](https://medium.com/@mishra.thedeepak/cuda-and-cudnn-installation-for-tensorflow-gpu-79beebb356d2) 157 | - [kezulin.me](https://kezunlin.me/post/dacc4196/) 158 | - [wonderingabout](https://github.com/wonderingabout/nvidia-archives) 159 | 160 | minor sources : 161 | - [nvidia pdf install guide for cuda 9.0](http://developer.download.nvidia.com/compute/cuda/9.0/Prod/docs/sidebar/CUDA_Installation_Guide_Linux.pdf) 162 | - 
[medium.com/@zhanwenchen/](https://medium.com/@zhanwenchen/install-cuda-and-cudnn-for-tensorflow-gpu-on-ubuntu-79306e4ac04e) 163 | - [nvidia cudnn](https://developer.nvidia.com/rdp/cudnn-archive) 164 | - [nvidia tensorrt](https://developer.nvidia.com/nvidia-tensorrt3-download) 165 | -------------------------------------------------------------------------------- /docs/tested-versions.md: -------------------------------------------------------------------------------- 1 | Below is a list of other versions of cuda/cudnn/tensorrt/bazel tar/deb 2 | tested by independent contributors, that work with PhoenixGo AI 3 | 4 | compatibility for tensorrt : 5 | - 3.0.4 requires cudnn 7.0.x and + 6 | - 4.x requires cudnn 7.1.x and + 7 | - 5.x requires cudnn 7.3.x and + 8 | 9 | ### tests by [wonderingabout](https://github.com/wonderingabout/) : 10 | 11 | #### works 12 | - bazel : 0.11.1 , 0.17.2 , 0.18.1 13 | - ubuntu : 16.04 , 18.04 14 | - cuda : 9.0 deb , 10.0 deb (10.0 deb ubuntu 18.04 is with cudnn 15 | 7.4.x deb and no tensorrt), 9.0 for windows 10 and windows server 16 | 2016 17 | - cudnn : 7.0.5 deb, 7.1.4 deb , 7.4.2 deb, 7.1.4 on windows 10 18 | with cuda 9.0 19 | - tensorrt : 3.0.4 deb , 3.0.4 tar 20 | - pycuda : for tensorrt tar, pycuda with pyhton 2.7 21 | (needs [a modification](http://0561blue.tistory.com/m/13?category=627413), 22 | pycuda specific issue, see [#75](https://github.com/Tencent/PhoenixGo/issues/75) 23 | 24 | #### does NOT work : 25 | - cuda on windows 10 : cuda 10.0 26 | - cudnn for windows 10 : cudnn 7.4.x for cuda 10.0 , cudnn 7.4.x for 27 | cuda 9.0 28 | - tensorrt : 4.x deb , 4.x tar, 5.x deb 29 | - bazel : 0.19.x and superior 30 | (needs [a modification](https://github.com/tensorflow/tensorflow/issues/23401)) 31 | -------------------------------------------------------------------------------- /etc/mcts_1gpu.conf: -------------------------------------------------------------------------------- 1 | num_eval_threads: 1 2 | num_search_threads: 8 3 | 
max_children_per_node: 64 4 | max_search_tree_size: 400000000 5 | timeout_ms_per_step: 30000 6 | max_simulations_per_step: 0 7 | eval_batch_size: 4 8 | eval_wait_batch_timeout_us: 100 9 | model_config { 10 | train_dir: "ckpt" 11 | enable_tensorrt: 1 12 | tensorrt_model_path: "zero.ckpt-20b-v1.FP32.PLAN" 13 | } 14 | gpu_list: "0" 15 | c_puct: 2.5 16 | virtual_loss: 1.0 17 | enable_resign: 1 18 | v_resign: -0.9 19 | enable_dirichlet_noise: 0 20 | dirichlet_noise_alpha: 0.03 21 | dirichlet_noise_ratio: 0.25 22 | monitor_log_every_ms: 0 23 | get_best_move_mode: 0 24 | enable_background_search: 1 25 | enable_policy_temperature: 0 26 | policy_temperature: 0.67 27 | inherit_default_act: 1 28 | early_stop { 29 | enable: 1 30 | check_every_ms: 100 31 | sims_factor: 1.0 32 | sims_threshold: 2000 33 | } 34 | unstable_overtime { 35 | enable: 1 36 | time_factor: 0.3 37 | } 38 | behind_overtime { 39 | enable: 1 40 | act_threshold: 0.0 41 | time_factor: 0.3 42 | } 43 | time_control { 44 | enable: 1 45 | c_denom: 20 46 | c_maxply: 40 47 | reserved_time: 1.0 48 | } 49 | -------------------------------------------------------------------------------- /etc/mcts_1gpu_grp.conf: -------------------------------------------------------------------------------- 1 | num_eval_threads: 1 2 | num_search_threads: 8 3 | max_children_per_node: 64 4 | max_search_tree_size: 400000000 5 | timeout_ms_per_step: 0 6 | max_simulations_per_step: 3200 7 | eval_batch_size: 4 8 | eval_wait_batch_timeout_us: 100 9 | model_config { 10 | train_dir: "ckpt" 11 | enable_tensorrt: 1 12 | tensorrt_model_path: "zero.ckpt-20b-v1.FP32.PLAN" 13 | } 14 | gpu_list: "0" 15 | c_puct: 2.5 16 | virtual_loss: 1.0 17 | enable_resign: 1 18 | v_resign: -0.9 19 | enable_dirichlet_noise: 0 20 | dirichlet_noise_alpha: 0.03 21 | dirichlet_noise_ratio: 0.25 22 | monitor_log_every_ms: 0 23 | get_best_move_mode: 0 24 | enable_background_search: 0 25 | enable_policy_temperature: 0 26 | policy_temperature: 0.67 27 | inherit_default_act: 
1 28 | debugger { 29 | print_tree_depth: 20 30 | print_tree_width: 3 31 | } 32 | early_stop { 33 | enable: 0 34 | check_every_ms: 100 35 | sims_factor: 1.0 36 | sims_threshold: 2000 37 | } 38 | unstable_overtime { 39 | enable: 0 40 | time_factor: 0.3 41 | } 42 | behind_overtime { 43 | enable: 0 44 | act_threshold: 0.0 45 | time_factor: 0.3 46 | } 47 | time_control { 48 | enable: 0 49 | c_denom: 20 50 | c_maxply: 40 51 | reserved_time: 1.0 52 | } 53 | -------------------------------------------------------------------------------- /etc/mcts_1gpu_notensorrt.conf: -------------------------------------------------------------------------------- 1 | num_eval_threads: 1 2 | num_search_threads: 8 3 | max_children_per_node: 64 4 | max_search_tree_size: 400000000 5 | timeout_ms_per_step: 30000 6 | max_simulations_per_step: 0 7 | eval_batch_size: 4 8 | eval_wait_batch_timeout_us: 100 9 | model_config { 10 | train_dir: "ckpt" 11 | } 12 | gpu_list: "0" 13 | c_puct: 2.5 14 | virtual_loss: 1.0 15 | enable_resign: 1 16 | v_resign: -0.9 17 | enable_dirichlet_noise: 0 18 | dirichlet_noise_alpha: 0.03 19 | dirichlet_noise_ratio: 0.25 20 | monitor_log_every_ms: 0 21 | get_best_move_mode: 0 22 | enable_background_search: 1 23 | enable_policy_temperature: 0 24 | policy_temperature: 0.67 25 | inherit_default_act: 1 26 | early_stop { 27 | enable: 1 28 | check_every_ms: 100 29 | sims_factor: 1.0 30 | sims_threshold: 2000 31 | } 32 | unstable_overtime { 33 | enable: 1 34 | time_factor: 0.3 35 | } 36 | behind_overtime { 37 | enable: 1 38 | act_threshold: 0.0 39 | time_factor: 0.3 40 | } 41 | time_control { 42 | enable: 1 43 | c_denom: 20 44 | c_maxply: 40 45 | reserved_time: 1.0 46 | } 47 | -------------------------------------------------------------------------------- /etc/mcts_1gpu_notensorrt_grp.conf: -------------------------------------------------------------------------------- 1 | num_eval_threads: 1 2 | num_search_threads: 8 3 | max_children_per_node: 64 4 | max_search_tree_size: 
400000000 5 | timeout_ms_per_step: 0 6 | max_simulations_per_step: 3200 7 | eval_batch_size: 4 8 | eval_wait_batch_timeout_us: 100 9 | model_config { 10 | train_dir: "ckpt" 11 | } 12 | gpu_list: "0" 13 | c_puct: 2.5 14 | virtual_loss: 1.0 15 | enable_resign: 1 16 | v_resign: -0.9 17 | enable_dirichlet_noise: 0 18 | dirichlet_noise_alpha: 0.03 19 | dirichlet_noise_ratio: 0.25 20 | monitor_log_every_ms: 0 21 | get_best_move_mode: 0 22 | enable_background_search: 0 23 | enable_policy_temperature: 0 24 | policy_temperature: 0.67 25 | inherit_default_act: 1 26 | debugger { 27 | print_tree_depth: 20 28 | print_tree_width: 3 29 | } 30 | early_stop { 31 | enable: 0 32 | check_every_ms: 100 33 | sims_factor: 1.0 34 | sims_threshold: 2000 35 | } 36 | unstable_overtime { 37 | enable: 0 38 | time_factor: 0.3 39 | } 40 | behind_overtime { 41 | enable: 0 42 | act_threshold: 0.0 43 | time_factor: 0.3 44 | } 45 | time_control { 46 | enable: 0 47 | c_denom: 20 48 | c_maxply: 40 49 | reserved_time: 1.0 50 | } 51 | -------------------------------------------------------------------------------- /etc/mcts_2gpu.conf: -------------------------------------------------------------------------------- 1 | num_eval_threads: 2 2 | num_search_threads: 12 3 | max_children_per_node: 64 4 | max_search_tree_size: 400000000 5 | timeout_ms_per_step: 30000 6 | max_simulations_per_step: 0 7 | eval_batch_size: 4 8 | eval_wait_batch_timeout_us: 100 9 | model_config { 10 | train_dir: "ckpt" 11 | enable_tensorrt: 1 12 | tensorrt_model_path: "zero.ckpt-20b-v1.FP32.PLAN" 13 | } 14 | gpu_list: "0,1" 15 | c_puct: 2.5 16 | virtual_loss: 1.0 17 | enable_resign: 1 18 | v_resign: -0.9 19 | enable_dirichlet_noise: 0 20 | dirichlet_noise_alpha: 0.03 21 | dirichlet_noise_ratio: 0.25 22 | monitor_log_every_ms: 0 23 | get_best_move_mode: 0 24 | enable_background_search: 1 25 | enable_policy_temperature: 0 26 | policy_temperature: 0.67 27 | inherit_default_act: 1 28 | early_stop { 29 | enable: 1 30 | check_every_ms: 
100 31 | sims_factor: 1.0 32 | sims_threshold: 2000 33 | } 34 | unstable_overtime { 35 | enable: 1 36 | time_factor: 0.3 37 | } 38 | behind_overtime { 39 | enable: 1 40 | act_threshold: 0.0 41 | time_factor: 0.3 42 | } 43 | time_control { 44 | enable: 1 45 | c_denom: 20 46 | c_maxply: 40 47 | reserved_time: 1.0 48 | } 49 | -------------------------------------------------------------------------------- /etc/mcts_2gpu_grp.conf: -------------------------------------------------------------------------------- 1 | num_eval_threads: 2 2 | num_search_threads: 12 3 | max_children_per_node: 64 4 | max_search_tree_size: 400000000 5 | timeout_ms_per_step: 0 6 | max_simulations_per_step: 3200 7 | eval_batch_size: 4 8 | eval_wait_batch_timeout_us: 100 9 | model_config { 10 | train_dir: "ckpt" 11 | enable_tensorrt: 1 12 | tensorrt_model_path: "zero.ckpt-20b-v1.FP32.PLAN" 13 | } 14 | gpu_list: "0,1" 15 | c_puct: 2.5 16 | virtual_loss: 1.0 17 | enable_resign: 1 18 | v_resign: -0.9 19 | enable_dirichlet_noise: 0 20 | dirichlet_noise_alpha: 0.03 21 | dirichlet_noise_ratio: 0.25 22 | monitor_log_every_ms: 0 23 | get_best_move_mode: 0 24 | enable_background_search: 0 25 | enable_policy_temperature: 0 26 | policy_temperature: 0.67 27 | inherit_default_act: 1 28 | debugger { 29 | print_tree_depth: 20 30 | print_tree_width: 3 31 | } 32 | early_stop { 33 | enable: 0 34 | check_every_ms: 100 35 | sims_factor: 1.0 36 | sims_threshold: 2000 37 | } 38 | unstable_overtime { 39 | enable: 0 40 | time_factor: 0.3 41 | } 42 | behind_overtime { 43 | enable: 0 44 | act_threshold: 0.0 45 | time_factor: 0.3 46 | } 47 | time_control { 48 | enable: 0 49 | c_denom: 20 50 | c_maxply: 40 51 | reserved_time: 1.0 52 | } 53 | -------------------------------------------------------------------------------- /etc/mcts_2gpu_notensorrt.conf: -------------------------------------------------------------------------------- 1 | num_eval_threads: 2 2 | num_search_threads: 12 3 | max_children_per_node: 64 4 | 
max_search_tree_size: 400000000 5 | timeout_ms_per_step: 30000 6 | max_simulations_per_step: 0 7 | eval_batch_size: 4 8 | eval_wait_batch_timeout_us: 100 9 | model_config { 10 | train_dir: "ckpt" 11 | } 12 | gpu_list: "0,1" 13 | c_puct: 2.5 14 | virtual_loss: 1.0 15 | enable_resign: 1 16 | v_resign: -0.9 17 | enable_dirichlet_noise: 0 18 | dirichlet_noise_alpha: 0.03 19 | dirichlet_noise_ratio: 0.25 20 | monitor_log_every_ms: 0 21 | get_best_move_mode: 0 22 | enable_background_search: 1 23 | enable_policy_temperature: 0 24 | policy_temperature: 0.67 25 | inherit_default_act: 1 26 | early_stop { 27 | enable: 1 28 | check_every_ms: 100 29 | sims_factor: 1.0 30 | sims_threshold: 2000 31 | } 32 | unstable_overtime { 33 | enable: 1 34 | time_factor: 0.3 35 | } 36 | behind_overtime { 37 | enable: 1 38 | act_threshold: 0.0 39 | time_factor: 0.3 40 | } 41 | time_control { 42 | enable: 1 43 | c_denom: 20 44 | c_maxply: 40 45 | reserved_time: 1.0 46 | } 47 | -------------------------------------------------------------------------------- /etc/mcts_2gpu_notensorrt_grp.conf: -------------------------------------------------------------------------------- 1 | num_eval_threads: 2 2 | num_search_threads: 12 3 | max_children_per_node: 64 4 | max_search_tree_size: 400000000 5 | timeout_ms_per_step: 0 6 | max_simulations_per_step: 3200 7 | eval_batch_size: 4 8 | eval_wait_batch_timeout_us: 100 9 | model_config { 10 | train_dir: "ckpt" 11 | } 12 | gpu_list: "0,1" 13 | c_puct: 2.5 14 | virtual_loss: 1.0 15 | enable_resign: 1 16 | v_resign: -0.9 17 | enable_dirichlet_noise: 0 18 | dirichlet_noise_alpha: 0.03 19 | dirichlet_noise_ratio: 0.25 20 | monitor_log_every_ms: 0 21 | get_best_move_mode: 0 22 | enable_background_search: 0 23 | enable_policy_temperature: 0 24 | policy_temperature: 0.67 25 | inherit_default_act: 1 26 | debugger { 27 | print_tree_depth: 20 28 | print_tree_width: 3 29 | } 30 | early_stop { 31 | enable: 0 32 | check_every_ms: 100 33 | sims_factor: 1.0 34 | 
sims_threshold: 2000 35 | } 36 | unstable_overtime { 37 | enable: 0 38 | time_factor: 0.3 39 | } 40 | behind_overtime { 41 | enable: 0 42 | act_threshold: 0.0 43 | time_factor: 0.3 44 | } 45 | time_control { 46 | enable: 0 47 | c_denom: 20 48 | c_maxply: 40 49 | reserved_time: 1.0 50 | } 51 | -------------------------------------------------------------------------------- /etc/mcts_3gpu.conf: -------------------------------------------------------------------------------- 1 | num_eval_threads: 3 2 | num_search_threads: 16 3 | max_children_per_node: 64 4 | max_search_tree_size: 400000000 5 | timeout_ms_per_step: 30000 6 | max_simulations_per_step: 0 7 | eval_batch_size: 4 8 | eval_wait_batch_timeout_us: 100 9 | model_config { 10 | train_dir: "ckpt" 11 | enable_tensorrt: 1 12 | tensorrt_model_path: "zero.ckpt-20b-v1.FP32.PLAN" 13 | } 14 | gpu_list: "0,1,2" 15 | c_puct: 2.5 16 | virtual_loss: 1.0 17 | enable_resign: 1 18 | v_resign: -0.9 19 | enable_dirichlet_noise: 0 20 | dirichlet_noise_alpha: 0.03 21 | dirichlet_noise_ratio: 0.25 22 | monitor_log_every_ms: 0 23 | get_best_move_mode: 0 24 | enable_background_search: 1 25 | enable_policy_temperature: 0 26 | policy_temperature: 0.67 27 | inherit_default_act: 1 28 | early_stop { 29 | enable: 1 30 | check_every_ms: 100 31 | sims_factor: 1.0 32 | sims_threshold: 2000 33 | } 34 | unstable_overtime { 35 | enable: 1 36 | time_factor: 0.3 37 | } 38 | behind_overtime { 39 | enable: 1 40 | act_threshold: 0.0 41 | time_factor: 0.3 42 | } 43 | time_control { 44 | enable: 1 45 | c_denom: 20 46 | c_maxply: 40 47 | reserved_time: 1.0 48 | } 49 | -------------------------------------------------------------------------------- /etc/mcts_3gpu_notensorrt.conf: -------------------------------------------------------------------------------- 1 | num_eval_threads: 3 2 | num_search_threads: 16 3 | max_children_per_node: 64 4 | max_search_tree_size: 400000000 5 | timeout_ms_per_step: 30000 6 | max_simulations_per_step: 0 7 | 
eval_batch_size: 4 8 | eval_wait_batch_timeout_us: 100 9 | model_config { 10 | train_dir: "ckpt" 11 | } 12 | gpu_list: "0,1,2" 13 | c_puct: 2.5 14 | virtual_loss: 1.0 15 | enable_resign: 1 16 | v_resign: -0.9 17 | enable_dirichlet_noise: 0 18 | dirichlet_noise_alpha: 0.03 19 | dirichlet_noise_ratio: 0.25 20 | monitor_log_every_ms: 0 21 | get_best_move_mode: 0 22 | enable_background_search: 1 23 | enable_policy_temperature: 0 24 | policy_temperature: 0.67 25 | inherit_default_act: 1 26 | early_stop { 27 | enable: 1 28 | check_every_ms: 100 29 | sims_factor: 1.0 30 | sims_threshold: 2000 31 | } 32 | unstable_overtime { 33 | enable: 1 34 | time_factor: 0.3 35 | } 36 | behind_overtime { 37 | enable: 1 38 | act_threshold: 0.0 39 | time_factor: 0.3 40 | } 41 | time_control { 42 | enable: 1 43 | c_denom: 20 44 | c_maxply: 40 45 | reserved_time: 1.0 46 | } 47 | -------------------------------------------------------------------------------- /etc/mcts_4gpu.conf: -------------------------------------------------------------------------------- 1 | num_eval_threads: 4 2 | num_search_threads: 20 3 | max_children_per_node: 64 4 | max_search_tree_size: 400000000 5 | timeout_ms_per_step: 30000 6 | max_simulations_per_step: 0 7 | eval_batch_size: 4 8 | eval_wait_batch_timeout_us: 100 9 | model_config { 10 | train_dir: "ckpt" 11 | enable_tensorrt: 1 12 | tensorrt_model_path: "zero.ckpt-20b-v1.FP32.PLAN" 13 | } 14 | gpu_list: "0,1,2,3" 15 | c_puct: 2.5 16 | virtual_loss: 1.0 17 | enable_resign: 1 18 | v_resign: -0.9 19 | enable_dirichlet_noise: 0 20 | dirichlet_noise_alpha: 0.03 21 | dirichlet_noise_ratio: 0.25 22 | monitor_log_every_ms: 0 23 | get_best_move_mode: 0 24 | enable_background_search: 1 25 | enable_policy_temperature: 0 26 | policy_temperature: 0.67 27 | inherit_default_act: 1 28 | early_stop { 29 | enable: 1 30 | check_every_ms: 100 31 | sims_factor: 1.0 32 | sims_threshold: 2000 33 | } 34 | unstable_overtime { 35 | enable: 1 36 | time_factor: 0.3 37 | } 38 | 
behind_overtime { 39 | enable: 1 40 | act_threshold: 0.0 41 | time_factor: 0.3 42 | } 43 | time_control { 44 | enable: 1 45 | c_denom: 20 46 | c_maxply: 40 47 | reserved_time: 1.0 48 | } 49 | -------------------------------------------------------------------------------- /etc/mcts_4gpu_notensorrt.conf: -------------------------------------------------------------------------------- 1 | num_eval_threads: 4 2 | num_search_threads: 20 3 | max_children_per_node: 64 4 | max_search_tree_size: 400000000 5 | timeout_ms_per_step: 30000 6 | max_simulations_per_step: 0 7 | eval_batch_size: 4 8 | eval_wait_batch_timeout_us: 100 9 | model_config { 10 | train_dir: "ckpt" 11 | } 12 | gpu_list: "0,1,2,3" 13 | c_puct: 2.5 14 | virtual_loss: 1.0 15 | enable_resign: 1 16 | v_resign: -0.9 17 | enable_dirichlet_noise: 0 18 | dirichlet_noise_alpha: 0.03 19 | dirichlet_noise_ratio: 0.25 20 | monitor_log_every_ms: 0 21 | get_best_move_mode: 0 22 | enable_background_search: 1 23 | enable_policy_temperature: 0 24 | policy_temperature: 0.67 25 | inherit_default_act: 1 26 | early_stop { 27 | enable: 1 28 | check_every_ms: 100 29 | sims_factor: 1.0 30 | sims_threshold: 2000 31 | } 32 | unstable_overtime { 33 | enable: 1 34 | time_factor: 0.3 35 | } 36 | behind_overtime { 37 | enable: 1 38 | act_threshold: 0.0 39 | time_factor: 0.3 40 | } 41 | time_control { 42 | enable: 1 43 | c_denom: 20 44 | c_maxply: 40 45 | reserved_time: 1.0 46 | } 47 | -------------------------------------------------------------------------------- /etc/mcts_5gpu.conf: -------------------------------------------------------------------------------- 1 | num_eval_threads: 5 2 | num_search_threads: 24 3 | max_children_per_node: 64 4 | max_search_tree_size: 400000000 5 | timeout_ms_per_step: 30000 6 | max_simulations_per_step: 0 7 | eval_batch_size: 4 8 | eval_wait_batch_timeout_us: 100 9 | model_config { 10 | train_dir: "ckpt" 11 | enable_tensorrt: 1 12 | tensorrt_model_path: "zero.ckpt-20b-v1.FP32.PLAN" 13 | } 14 | 
gpu_list: "0,1,2,3,4" 15 | c_puct: 2.5 16 | virtual_loss: 1.0 17 | enable_resign: 1 18 | v_resign: -0.9 19 | enable_dirichlet_noise: 0 20 | dirichlet_noise_alpha: 0.03 21 | dirichlet_noise_ratio: 0.25 22 | monitor_log_every_ms: 0 23 | get_best_move_mode: 0 24 | enable_background_search: 1 25 | enable_policy_temperature: 0 26 | policy_temperature: 0.67 27 | inherit_default_act: 1 28 | early_stop { 29 | enable: 1 30 | check_every_ms: 100 31 | sims_factor: 1.0 32 | sims_threshold: 2000 33 | } 34 | unstable_overtime { 35 | enable: 1 36 | time_factor: 0.3 37 | } 38 | behind_overtime { 39 | enable: 1 40 | act_threshold: 0.0 41 | time_factor: 0.3 42 | } 43 | time_control { 44 | enable: 1 45 | c_denom: 20 46 | c_maxply: 40 47 | reserved_time: 1.0 48 | } 49 | -------------------------------------------------------------------------------- /etc/mcts_5gpu_notensorrt.conf: -------------------------------------------------------------------------------- 1 | num_eval_threads: 5 2 | num_search_threads: 24 3 | max_children_per_node: 64 4 | max_search_tree_size: 400000000 5 | timeout_ms_per_step: 30000 6 | max_simulations_per_step: 0 7 | eval_batch_size: 4 8 | eval_wait_batch_timeout_us: 100 9 | model_config { 10 | train_dir: "ckpt" 11 | } 12 | gpu_list: "0,1,2,3,4" 13 | c_puct: 2.5 14 | virtual_loss: 1.0 15 | enable_resign: 1 16 | v_resign: -0.9 17 | enable_dirichlet_noise: 0 18 | dirichlet_noise_alpha: 0.03 19 | dirichlet_noise_ratio: 0.25 20 | monitor_log_every_ms: 0 21 | get_best_move_mode: 0 22 | enable_background_search: 1 23 | enable_policy_temperature: 0 24 | policy_temperature: 0.67 25 | inherit_default_act: 1 26 | early_stop { 27 | enable: 1 28 | check_every_ms: 100 29 | sims_factor: 1.0 30 | sims_threshold: 2000 31 | } 32 | unstable_overtime { 33 | enable: 1 34 | time_factor: 0.3 35 | } 36 | behind_overtime { 37 | enable: 1 38 | act_threshold: 0.0 39 | time_factor: 0.3 40 | } 41 | time_control { 42 | enable: 1 43 | c_denom: 20 44 | c_maxply: 40 45 | reserved_time: 1.0 46 
| } 47 | -------------------------------------------------------------------------------- /etc/mcts_6gpu.conf: -------------------------------------------------------------------------------- 1 | num_eval_threads: 6 2 | num_search_threads: 28 3 | max_children_per_node: 64 4 | max_search_tree_size: 400000000 5 | timeout_ms_per_step: 30000 6 | max_simulations_per_step: 0 7 | eval_batch_size: 4 8 | eval_wait_batch_timeout_us: 100 9 | model_config { 10 | train_dir: "ckpt" 11 | enable_tensorrt: 1 12 | tensorrt_model_path: "zero.ckpt-20b-v1.FP32.PLAN" 13 | } 14 | gpu_list: "0,1,2,3,4,5" 15 | c_puct: 2.5 16 | virtual_loss: 1.0 17 | enable_resign: 1 18 | v_resign: -0.9 19 | enable_dirichlet_noise: 0 20 | dirichlet_noise_alpha: 0.03 21 | dirichlet_noise_ratio: 0.25 22 | monitor_log_every_ms: 0 23 | get_best_move_mode: 0 24 | enable_background_search: 1 25 | enable_policy_temperature: 0 26 | policy_temperature: 0.67 27 | inherit_default_act: 1 28 | early_stop { 29 | enable: 1 30 | check_every_ms: 100 31 | sims_factor: 1.0 32 | sims_threshold: 2000 33 | } 34 | unstable_overtime { 35 | enable: 1 36 | time_factor: 0.3 37 | } 38 | behind_overtime { 39 | enable: 1 40 | act_threshold: 0.0 41 | time_factor: 0.3 42 | } 43 | time_control { 44 | enable: 1 45 | c_denom: 20 46 | c_maxply: 40 47 | reserved_time: 1.0 48 | } 49 | -------------------------------------------------------------------------------- /etc/mcts_6gpu_notensorrt.conf: -------------------------------------------------------------------------------- 1 | num_eval_threads: 6 2 | num_search_threads: 28 3 | max_children_per_node: 64 4 | max_search_tree_size: 400000000 5 | timeout_ms_per_step: 30000 6 | max_simulations_per_step: 0 7 | eval_batch_size: 4 8 | eval_wait_batch_timeout_us: 100 9 | model_config { 10 | train_dir: "ckpt" 11 | } 12 | gpu_list: "0,1,2,3,4,5" 13 | c_puct: 2.5 14 | virtual_loss: 1.0 15 | enable_resign: 1 16 | v_resign: -0.9 17 | enable_dirichlet_noise: 0 18 | dirichlet_noise_alpha: 0.03 19 | 
dirichlet_noise_ratio: 0.25 20 | monitor_log_every_ms: 0 21 | get_best_move_mode: 0 22 | enable_background_search: 1 23 | enable_policy_temperature: 0 24 | policy_temperature: 0.67 25 | inherit_default_act: 1 26 | early_stop { 27 | enable: 1 28 | check_every_ms: 100 29 | sims_factor: 1.0 30 | sims_threshold: 2000 31 | } 32 | unstable_overtime { 33 | enable: 1 34 | time_factor: 0.3 35 | } 36 | behind_overtime { 37 | enable: 1 38 | act_threshold: 0.0 39 | time_factor: 0.3 40 | } 41 | time_control { 42 | enable: 1 43 | c_denom: 20 44 | c_maxply: 40 45 | reserved_time: 1.0 46 | } 47 | -------------------------------------------------------------------------------- /etc/mcts_7gpu.conf: -------------------------------------------------------------------------------- 1 | num_eval_threads: 7 2 | num_search_threads: 32 3 | max_children_per_node: 64 4 | max_search_tree_size: 400000000 5 | timeout_ms_per_step: 30000 6 | max_simulations_per_step: 0 7 | eval_batch_size: 4 8 | eval_wait_batch_timeout_us: 100 9 | model_config { 10 | train_dir: "ckpt" 11 | enable_tensorrt: 1 12 | tensorrt_model_path: "zero.ckpt-20b-v1.FP32.PLAN" 13 | } 14 | gpu_list: "0,1,2,3,4,5,6" 15 | c_puct: 2.5 16 | virtual_loss: 1.0 17 | enable_resign: 1 18 | v_resign: -0.9 19 | enable_dirichlet_noise: 0 20 | dirichlet_noise_alpha: 0.03 21 | dirichlet_noise_ratio: 0.25 22 | monitor_log_every_ms: 0 23 | get_best_move_mode: 0 24 | enable_background_search: 1 25 | enable_policy_temperature: 0 26 | policy_temperature: 0.67 27 | inherit_default_act: 1 28 | early_stop { 29 | enable: 1 30 | check_every_ms: 100 31 | sims_factor: 1.0 32 | sims_threshold: 2000 33 | } 34 | unstable_overtime { 35 | enable: 1 36 | time_factor: 0.3 37 | } 38 | behind_overtime { 39 | enable: 1 40 | act_threshold: 0.0 41 | time_factor: 0.3 42 | } 43 | time_control { 44 | enable: 1 45 | c_denom: 20 46 | c_maxply: 40 47 | reserved_time: 1.0 48 | } 49 | -------------------------------------------------------------------------------- 
/etc/mcts_7gpu_notensorrt.conf: -------------------------------------------------------------------------------- 1 | num_eval_threads: 7 2 | num_search_threads: 32 3 | max_children_per_node: 64 4 | max_search_tree_size: 400000000 5 | timeout_ms_per_step: 30000 6 | max_simulations_per_step: 0 7 | eval_batch_size: 4 8 | eval_wait_batch_timeout_us: 100 9 | model_config { 10 | train_dir: "ckpt" 11 | } 12 | gpu_list: "0,1,2,3,4,5,6" 13 | c_puct: 2.5 14 | virtual_loss: 1.0 15 | enable_resign: 1 16 | v_resign: -0.9 17 | enable_dirichlet_noise: 0 18 | dirichlet_noise_alpha: 0.03 19 | dirichlet_noise_ratio: 0.25 20 | monitor_log_every_ms: 0 21 | get_best_move_mode: 0 22 | enable_background_search: 1 23 | enable_policy_temperature: 0 24 | policy_temperature: 0.67 25 | inherit_default_act: 1 26 | early_stop { 27 | enable: 1 28 | check_every_ms: 100 29 | sims_factor: 1.0 30 | sims_threshold: 2000 31 | } 32 | unstable_overtime { 33 | enable: 1 34 | time_factor: 0.3 35 | } 36 | behind_overtime { 37 | enable: 1 38 | act_threshold: 0.0 39 | time_factor: 0.3 40 | } 41 | time_control { 42 | enable: 1 43 | c_denom: 20 44 | c_maxply: 40 45 | reserved_time: 1.0 46 | } 47 | -------------------------------------------------------------------------------- /etc/mcts_8gpu.conf: -------------------------------------------------------------------------------- 1 | num_eval_threads: 8 2 | num_search_threads: 36 3 | max_children_per_node: 64 4 | max_search_tree_size: 400000000 5 | timeout_ms_per_step: 30000 6 | max_simulations_per_step: 0 7 | eval_batch_size: 4 8 | eval_wait_batch_timeout_us: 100 9 | model_config { 10 | train_dir: "ckpt" 11 | enable_tensorrt: 1 12 | tensorrt_model_path: "zero.ckpt-20b-v1.FP32.PLAN" 13 | } 14 | gpu_list: "0,1,2,3,4,5,6,7" 15 | c_puct: 2.5 16 | virtual_loss: 1.0 17 | enable_resign: 1 18 | v_resign: -0.9 19 | enable_dirichlet_noise: 0 20 | dirichlet_noise_alpha: 0.03 21 | dirichlet_noise_ratio: 0.25 22 | monitor_log_every_ms: 0 23 | get_best_move_mode: 0 24 | 
enable_background_search: 1 25 | enable_policy_temperature: 0 26 | policy_temperature: 0.67 27 | inherit_default_act: 1 28 | early_stop { 29 | enable: 1 30 | check_every_ms: 100 31 | sims_factor: 1.0 32 | sims_threshold: 2000 33 | } 34 | unstable_overtime { 35 | enable: 1 36 | time_factor: 0.3 37 | } 38 | behind_overtime { 39 | enable: 1 40 | act_threshold: 0.0 41 | time_factor: 0.3 42 | } 43 | time_control { 44 | enable: 1 45 | c_denom: 20 46 | c_maxply: 40 47 | reserved_time: 1.0 48 | } 49 | -------------------------------------------------------------------------------- /etc/mcts_8gpu_notensorrt.conf: -------------------------------------------------------------------------------- 1 | num_eval_threads: 8 2 | num_search_threads: 36 3 | max_children_per_node: 64 4 | max_search_tree_size: 400000000 5 | timeout_ms_per_step: 30000 6 | max_simulations_per_step: 0 7 | eval_batch_size: 4 8 | eval_wait_batch_timeout_us: 100 9 | model_config { 10 | train_dir: "ckpt" 11 | } 12 | gpu_list: "0,1,2,3,4,5,6,7" 13 | c_puct: 2.5 14 | virtual_loss: 1.0 15 | enable_resign: 1 16 | v_resign: -0.9 17 | enable_dirichlet_noise: 0 18 | dirichlet_noise_alpha: 0.03 19 | dirichlet_noise_ratio: 0.25 20 | monitor_log_every_ms: 0 21 | get_best_move_mode: 0 22 | enable_background_search: 1 23 | enable_policy_temperature: 0 24 | policy_temperature: 0.67 25 | inherit_default_act: 1 26 | early_stop { 27 | enable: 1 28 | check_every_ms: 100 29 | sims_factor: 1.0 30 | sims_threshold: 2000 31 | } 32 | unstable_overtime { 33 | enable: 1 34 | time_factor: 0.3 35 | } 36 | behind_overtime { 37 | enable: 1 38 | act_threshold: 0.0 39 | time_factor: 0.3 40 | } 41 | time_control { 42 | enable: 1 43 | c_denom: 20 44 | c_maxply: 40 45 | reserved_time: 1.0 46 | } 47 | -------------------------------------------------------------------------------- /etc/mcts_async_dist.conf: -------------------------------------------------------------------------------- 1 | num_eval_threads: 64 2 | num_search_threads: 32 3 | 
enable_async: 1 4 | eval_task_queue_size: 128 5 | max_children_per_node: 64 6 | max_search_tree_size: 400000000 7 | timeout_ms_per_step: 30000 8 | max_simulations_per_step: 0 9 | eval_batch_size: 4 10 | eval_wait_batch_timeout_us: 100 11 | model_config { 12 | train_dir: "ckpt" 13 | enable_tensorrt: 1 14 | tensorrt_model_path: "zero.ckpt-20b-v1.FP32.PLAN" 15 | } 16 | enable_dist: 1 17 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 18 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 19 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 20 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 21 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 22 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 23 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 24 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 25 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 26 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 27 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 28 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 29 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 30 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 31 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 32 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 33 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 34 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 35 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 36 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 37 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 38 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 39 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 40 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 41 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 42 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 43 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 44 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 45 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 46 | dist_svr_addrs: 
"ip:port,ip:port,ip:port,ip:port" 47 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 48 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 49 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 50 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 51 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 52 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 53 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 54 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 55 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 56 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 57 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 58 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 59 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 60 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 61 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 62 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 63 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 64 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 65 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 66 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 67 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 68 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 69 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 70 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 71 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 72 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 73 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 74 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 75 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 76 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 77 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 78 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 79 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 80 | dist_svr_addrs: "ip:port,ip:port,ip:port,ip:port" 81 | dist_config { 82 | timeout_ms: 100 83 | enable_leaky_bucket: 1 84 | leaky_bucket_size: 3 
85 | leaky_bucket_refill_period_ms: 5000 86 | } 87 | c_puct: 2.5 88 | virtual_loss: 0.5 89 | enable_resign: 1 90 | v_resign: -0.9 91 | enable_dirichlet_noise: 0 92 | dirichlet_noise_alpha: 0.03 93 | dirichlet_noise_ratio: 0.25 94 | monitor_log_every_ms: 0 95 | get_best_move_mode: 0 96 | enable_background_search: 1 97 | enable_policy_temperature: 0 98 | policy_temperature: 0.67 99 | inherit_default_act: 1 100 | early_stop { 101 | enable: 1 102 | check_every_ms: 100 103 | sims_factor: 1.0 104 | sims_threshold: 100000 105 | } 106 | unstable_overtime { 107 | enable: 1 108 | time_factor: 0.3 109 | } 110 | behind_overtime { 111 | enable: 1 112 | act_threshold: 0.0 113 | time_factor: 0.3 114 | } 115 | time_control { 116 | enable: 1 117 | c_denom: 20 118 | c_maxply: 40 119 | reserved_time: 1.0 120 | } 121 | -------------------------------------------------------------------------------- /etc/mcts_cpu.conf: -------------------------------------------------------------------------------- 1 | num_eval_threads: 1 2 | num_search_threads: 8 3 | max_children_per_node: 64 4 | max_search_tree_size: 400000000 5 | timeout_ms_per_step: 30000 6 | max_simulations_per_step: 0 7 | eval_batch_size: 4 8 | eval_wait_batch_timeout_us: 100 9 | model_config { 10 | train_dir: "ckpt" 11 | } 12 | c_puct: 2.5 13 | virtual_loss: 1.0 14 | enable_resign: 1 15 | v_resign: -0.9 16 | enable_dirichlet_noise: 0 17 | dirichlet_noise_alpha: 0.03 18 | dirichlet_noise_ratio: 0.25 19 | monitor_log_every_ms: 0 20 | get_best_move_mode: 0 21 | enable_background_search: 1 22 | enable_policy_temperature: 0 23 | policy_temperature: 0.67 24 | inherit_default_act: 1 25 | early_stop { 26 | enable: 1 27 | check_every_ms: 100 28 | sims_factor: 1.0 29 | sims_threshold: 2000 30 | } 31 | unstable_overtime { 32 | enable: 1 33 | time_factor: 0.3 34 | } 35 | behind_overtime { 36 | enable: 1 37 | act_threshold: 0.0 38 | time_factor: 0.3 39 | } 40 | time_control { 41 | enable: 1 42 | c_denom: 20 43 | c_maxply: 40 44 | 
reserved_time: 1.0 45 | } 46 | -------------------------------------------------------------------------------- /etc/mcts_cpu_grp.conf: -------------------------------------------------------------------------------- 1 | num_eval_threads: 1 2 | num_search_threads: 8 3 | max_children_per_node: 64 4 | max_search_tree_size: 400000000 5 | timeout_ms_per_step: 0 6 | max_simulations_per_step: 1600 7 | eval_batch_size: 4 8 | eval_wait_batch_timeout_us: 100 9 | model_config { 10 | train_dir: "ckpt" 11 | } 12 | c_puct: 2.5 13 | virtual_loss: 1.0 14 | enable_resign: 1 15 | v_resign: -0.9 16 | enable_dirichlet_noise: 0 17 | dirichlet_noise_alpha: 0.03 18 | dirichlet_noise_ratio: 0.25 19 | monitor_log_every_ms: 0 20 | get_best_move_mode: 0 21 | enable_background_search: 0 22 | enable_policy_temperature: 0 23 | policy_temperature: 0.67 24 | inherit_default_act: 1 25 | debugger { 26 | print_tree_depth: 20 27 | print_tree_width: 3 28 | } 29 | early_stop { 30 | enable: 0 31 | check_every_ms: 100 32 | sims_factor: 1.0 33 | sims_threshold: 2000 34 | } 35 | unstable_overtime { 36 | enable: 0 37 | time_factor: 0.3 38 | } 39 | behind_overtime { 40 | enable: 0 41 | act_threshold: 0.0 42 | time_factor: 0.3 43 | } 44 | time_control { 45 | enable: 0 46 | c_denom: 20 47 | c_maxply: 40 48 | reserved_time: 1.0 49 | } 50 | -------------------------------------------------------------------------------- /etc/mcts_dist.conf: -------------------------------------------------------------------------------- 1 | num_eval_threads: 32 2 | num_search_threads: 132 3 | max_children_per_node: 64 4 | max_search_tree_size: 400000000 5 | timeout_ms_per_step: 30000 6 | max_simulations_per_step: 0 7 | eval_batch_size: 4 8 | eval_wait_batch_timeout_us: 100 9 | model_config { 10 | train_dir: "ckpt" 11 | enable_tensorrt: 1 12 | tensorrt_model_path: "zero.ckpt-20b-v1.FP32.PLAN" 13 | } 14 | enable_dist: 1 15 | dist_svr_addrs: "ip:port" 16 | dist_svr_addrs: "ip:port" 17 | dist_svr_addrs: "ip:port" 18 | 
dist_svr_addrs: "ip:port" 19 | dist_svr_addrs: "ip:port" 20 | dist_svr_addrs: "ip:port" 21 | dist_svr_addrs: "ip:port" 22 | dist_svr_addrs: "ip:port" 23 | dist_svr_addrs: "ip:port" 24 | dist_svr_addrs: "ip:port" 25 | dist_svr_addrs: "ip:port" 26 | dist_svr_addrs: "ip:port" 27 | dist_svr_addrs: "ip:port" 28 | dist_svr_addrs: "ip:port" 29 | dist_svr_addrs: "ip:port" 30 | dist_svr_addrs: "ip:port" 31 | dist_svr_addrs: "ip:port" 32 | dist_svr_addrs: "ip:port" 33 | dist_svr_addrs: "ip:port" 34 | dist_svr_addrs: "ip:port" 35 | dist_svr_addrs: "ip:port" 36 | dist_svr_addrs: "ip:port" 37 | dist_svr_addrs: "ip:port" 38 | dist_svr_addrs: "ip:port" 39 | dist_svr_addrs: "ip:port" 40 | dist_svr_addrs: "ip:port" 41 | dist_svr_addrs: "ip:port" 42 | dist_svr_addrs: "ip:port" 43 | dist_svr_addrs: "ip:port" 44 | dist_svr_addrs: "ip:port" 45 | dist_svr_addrs: "ip:port" 46 | dist_svr_addrs: "ip:port" 47 | dist_config { 48 | timeout_ms: 100 49 | enable_leaky_bucket: 1 50 | leaky_bucket_size: 3 51 | leaky_bucket_refill_period_ms: 5000 52 | } 53 | c_puct: 2.5 54 | virtual_loss: 1.0 55 | enable_resign: 1 56 | v_resign: -0.9 57 | enable_dirichlet_noise: 0 58 | dirichlet_noise_alpha: 0.03 59 | dirichlet_noise_ratio: 0.25 60 | monitor_log_every_ms: 0 61 | get_best_move_mode: 0 62 | enable_background_search: 1 63 | enable_policy_temperature: 0 64 | policy_temperature: 0.67 65 | inherit_default_act: 1 66 | early_stop { 67 | enable: 1 68 | check_every_ms: 100 69 | sims_factor: 1.0 70 | sims_threshold: 10000 71 | } 72 | unstable_overtime { 73 | enable: 1 74 | time_factor: 0.3 75 | } 76 | behind_overtime { 77 | enable: 1 78 | act_threshold: 0.0 79 | time_factor: 0.3 80 | } 81 | time_control { 82 | enable: 1 83 | c_denom: 20 84 | c_maxply: 40 85 | reserved_time: 1.0 86 | } 87 | -------------------------------------------------------------------------------- /images/logo.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Tencent/PhoenixGo/fbf67f9aec42531bff9569c44b85eb4c3f37b7be/images/logo.jpg -------------------------------------------------------------------------------- /mcts/BUILD: -------------------------------------------------------------------------------- 1 | load("//:rules.bzl", "cc_proto_library", "tf_cc_binary") 2 | 3 | tf_cc_binary( 4 | name = "mcts_main", 5 | srcs = ["mcts_main.cc"], 6 | deps = [ 7 | ":mcts_engine", 8 | "@boost//:asio", 9 | ], 10 | ) 11 | 12 | tf_cc_binary( 13 | name = "debug_tool", 14 | srcs = [ 15 | "debug_tool.cc", 16 | "mcts_config.h", 17 | "mcts_config.cc", 18 | ], 19 | deps = [ 20 | ":mcts_config_cc_proto", 21 | "//common:go_comm", 22 | "//common:go_state", 23 | "//common:timer", 24 | "//model:zero_model", 25 | "//model:trt_zero_model", 26 | ], 27 | ) 28 | 29 | cc_library( 30 | name = "mcts_engine", 31 | srcs = [ 32 | "mcts_engine.cc", 33 | "mcts_monitor.cc", 34 | "mcts_debugger.cc", 35 | "byo_yomi_timer.cc", 36 | ], 37 | hdrs = [ 38 | "mcts_engine.h", 39 | "mcts_monitor.h", 40 | "mcts_debugger.h", 41 | "byo_yomi_timer.h", 42 | ], 43 | deps = [ 44 | ":mcts_config", 45 | "//common:go_comm", 46 | "//common:go_state", 47 | "//common:task_queue", 48 | "//common:wait_group", 49 | "//common:thread_conductor", 50 | "//common:str_utils", 51 | "//common:timer", 52 | "//model:zero_model", 53 | "//model:trt_zero_model", 54 | "//dist:dist_zero_model_client", 55 | "//dist:async_dist_zero_model_client", 56 | "@com_github_google_glog//:glog", 57 | ], 58 | visibility = ["//visibility:public"], 59 | ) 60 | 61 | cc_library( 62 | name = "mcts_config", 63 | srcs = ["mcts_config.cc"], 64 | hdrs = ["mcts_config.h"], 65 | deps = [ 66 | ":mcts_config_cc_proto", 67 | "@com_github_google_glog//:glog", 68 | ], 69 | visibility = ["//visibility:public"], 70 | ) 71 | 72 | cc_proto_library( 73 | name = "mcts_config_cc_proto", 74 | srcs = ["mcts_config.proto"], 75 | deps = [ 76 | "//model:model_config_cc_proto", 77 | 
"//dist:dist_config_cc_proto", 78 | ], 79 | visibility = ["//visibility:public"], 80 | ) 81 | -------------------------------------------------------------------------------- /mcts/byo_yomi_timer.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | #include "byo_yomi_timer.h" 19 | 20 | #include 21 | 22 | ByoYomiTimer::ByoYomiTimer() 23 | : m_enable(false), 24 | m_remain_time{0.0f, 0.0f}, 25 | m_byo_yomi_time(0.0f), 26 | m_curr_player(GoComm::BLACK) 27 | { 28 | } 29 | 30 | void ByoYomiTimer::Set(float main_time, float byo_yomi_time) 31 | { 32 | m_enable = true; 33 | m_remain_time[0] = m_remain_time[1] = main_time; 34 | m_byo_yomi_time = byo_yomi_time; 35 | m_timer.Reset(); 36 | } 37 | 38 | void ByoYomiTimer::Reset() 39 | { 40 | m_enable = false; 41 | m_remain_time[0] = m_remain_time[1] = 0.0f; 42 | m_byo_yomi_time = 0.0f; 43 | m_curr_player = GoComm::BLACK; 44 | } 45 | 46 | bool ByoYomiTimer::IsEnable() 47 | { 48 | return m_enable; 49 | } 50 | 51 | void ByoYomiTimer::HandOff() 52 | { 53 | m_remain_time[m_curr_player == GoComm::BLACK ? 0 : 1] -= m_timer.fsec(); 54 | m_curr_player = m_curr_player == GoComm::BLACK ? 
GoComm::WHITE : GoComm::BLACK; 55 | m_timer.Reset(); 56 | } 57 | 58 | void ByoYomiTimer::SetRemainTime(GoStoneColor color, float time) 59 | { 60 | m_remain_time[color == GoComm::BLACK ? 0 : 1] = time; 61 | if (color == m_curr_player) { 62 | m_timer.Reset(); 63 | } 64 | } 65 | 66 | float ByoYomiTimer::GetRemainTime(GoStoneColor color) 67 | { 68 | float remain_time = m_remain_time[color == GoComm::BLACK ? 0 : 1]; 69 | if (color == m_curr_player) { 70 | remain_time -= m_timer.fsec(); 71 | } 72 | return std::max(remain_time, 0.0f); 73 | } 74 | 75 | float ByoYomiTimer::GetByoYomiTime() 76 | { 77 | return m_byo_yomi_time; 78 | } 79 | -------------------------------------------------------------------------------- /mcts/byo_yomi_timer.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 
17 | */ 18 | #pragma once 19 | 20 | #include "common/go_comm.h" 21 | #include "common/timer.h" 22 | 23 | class ByoYomiTimer 24 | { 25 | public: 26 | ByoYomiTimer(); 27 | void Set(float main_time, float byo_yomi_time); 28 | void Reset(); 29 | bool IsEnable(); 30 | void HandOff(); 31 | void SetRemainTime(GoStoneColor color, float time); 32 | float GetRemainTime(GoStoneColor color); 33 | float GetByoYomiTime(); 34 | 35 | private: 36 | bool m_enable; 37 | float m_remain_time[2]; 38 | float m_byo_yomi_time; 39 | GoStoneColor m_curr_player; 40 | Timer m_timer; 41 | }; 42 | -------------------------------------------------------------------------------- /mcts/debug_tool.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 
17 | */ 18 | #include 19 | #include 20 | 21 | #include "model/zero_model.h" 22 | #include "model/trt_zero_model.h" 23 | #include "common/go_state.h" 24 | #include "common/timer.h" 25 | 26 | #include "mcts_config.h" 27 | 28 | DEFINE_string(config_path, "", "Path of mcts config file."); 29 | DEFINE_string(init_moves, "", "Initialize Go board with init_moves."); 30 | DEFINE_int32(gpu, 0, "gpu used by neural network."); 31 | DEFINE_int32(intra_op_parallelism_threads, 0, "Number of tf's intra op threads"); 32 | DEFINE_int32(inter_op_parallelism_threads, 0, "Number of tf's inter op threads"); 33 | DEFINE_int32(transform, 0, "Transform features."); 34 | DEFINE_int32(num_iterations, 1, "How many iterations should run."); 35 | DEFINE_int32(batch_size, 1, "Batch size of each iterations."); 36 | 37 | void InitMove(GoState& board, std::string& moves) 38 | { 39 | for (size_t i = 0; i < moves.size(); i += 3) { 40 | int x = -1, y = -1; 41 | if (moves[i] != 'z') { 42 | x = moves[i] - 'a'; 43 | y = moves[i + 1] - 'a'; 44 | } 45 | board.Move(x, y); 46 | } 47 | } 48 | 49 | void TransformCoord(GoCoordId &x, GoCoordId &y, int mode, bool reverse) 50 | { 51 | if (reverse) { 52 | if (mode & 4) std::swap(x, y); 53 | if (mode & 2) y = GoComm::BORDER_SIZE - y - 1; 54 | if (mode & 1) x = GoComm::BORDER_SIZE - x - 1; 55 | } else { 56 | if (mode & 1) x = GoComm::BORDER_SIZE - x - 1; 57 | if (mode & 2) y = GoComm::BORDER_SIZE - y - 1; 58 | if (mode & 4) std::swap(x, y); 59 | } 60 | } 61 | 62 | template 63 | void TransformFeatures(T &features, int mode, bool reverse) 64 | { 65 | T ret(features.size()); 66 | int depth = features.size() / GoComm::GOBOARD_SIZE; 67 | for (int i = 0; i < GoComm::GOBOARD_SIZE; ++i) { 68 | GoCoordId x, y; 69 | GoFunction::IdToCoord(i, x, y); 70 | TransformCoord(x, y, mode, reverse); 71 | int j = GoFunction::CoordToId(x, y); 72 | for (int k = 0; k < depth; ++k) { 73 | ret[i * depth + k] = features[j * depth + k]; 74 | } 75 | } 76 | features = std::move(ret); 77 | } 78 | 
79 | int main(int argc, char* argv[]) 80 | { 81 | google::ParseCommandLineFlags(&argc, &argv, true); 82 | google::InitGoogleLogging(argv[0]); 83 | google::InstallFailureSignalHandler(); 84 | 85 | auto config = LoadConfig(FLAGS_config_path); 86 | CHECK(config != nullptr) << "Load mcts config file '" << FLAGS_config_path << "' failed"; 87 | 88 | if (FLAGS_intra_op_parallelism_threads > 0) { 89 | config->mutable_model_config()->set_intra_op_parallelism_threads(FLAGS_intra_op_parallelism_threads); 90 | } 91 | 92 | if (FLAGS_inter_op_parallelism_threads > 0) { 93 | config->mutable_model_config()->set_inter_op_parallelism_threads(FLAGS_inter_op_parallelism_threads); 94 | } 95 | 96 | if (config->model_config().enable_mkl()) { 97 | ZeroModel::SetMKLEnv(config->model_config()); 98 | } 99 | 100 | std::unique_ptr model(new ZeroModel(FLAGS_gpu)); 101 | #if HAVE_TENSORRT 102 | if (config->model_config().enable_tensorrt()) { 103 | model.reset(new TrtZeroModel(FLAGS_gpu)); 104 | } 105 | #endif 106 | CHECK_EQ(model->Init(config->model_config()), 0) << "Model Init Fail, config path " << FLAGS_config_path<< ", gpu " << FLAGS_gpu; 107 | 108 | GoState board; 109 | InitMove(board, FLAGS_init_moves); 110 | 111 | auto features = board.GetFeature(); 112 | TransformFeatures(features, FLAGS_transform, false); 113 | std::vector> inputs(FLAGS_batch_size, features); 114 | 115 | std::vector> policies; 116 | std::vector values; 117 | 118 | Timer timer; 119 | for (int i = 1; i <= FLAGS_num_iterations; ++i) { 120 | CHECK_EQ(model->Forward(inputs, policies, values), 0) << "Forward fail"; 121 | LOG_IF(INFO, i % 100 == 0) << i << "/" << FLAGS_num_iterations << " iterations"; 122 | } 123 | float avg_cost_ms = timer.fms() / FLAGS_num_iterations; 124 | LOG(INFO) << "Cost " << avg_cost_ms << "ms per iteration"; 125 | 126 | std::vector& policy = policies[0]; 127 | TransformFeatures(policy, FLAGS_transform, true); 128 | float value = values[0]; 129 | board.ShowBoard(); 130 | for (int i = 0; i < 
GoComm::BORDER_SIZE; ++i) { 131 | for (int j = 0; j < GoComm::BORDER_SIZE; ++j) { 132 | printf("%.4f ", policy[i * GoComm::BORDER_SIZE + j]); 133 | } 134 | puts(""); 135 | } 136 | printf("Value %.4f Pass %.4f\n", value, policy[361]); 137 | } 138 | -------------------------------------------------------------------------------- /mcts/mcts_config.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | #include "mcts_config.h" 19 | 20 | #include 21 | 22 | #include 23 | #include 24 | 25 | std::unique_ptr LoadConfig(const char *config_path) 26 | { 27 | auto config = std::unique_ptr(new MCTSConfig); 28 | std::ostringstream conf_ss; 29 | if (!(conf_ss << std::ifstream(config_path).rdbuf())) { 30 | PLOG(ERROR) << "read config file " << config_path << " error"; 31 | return nullptr; 32 | } 33 | if (!google::protobuf::TextFormat::ParseFromString(conf_ss.str(), config.get())) { 34 | LOG(ERROR) << "parse config file " << config_path << " error! 
buf=" << conf_ss.str(); 35 | return nullptr; 36 | } 37 | return config; 38 | } 39 | 40 | std::unique_ptr LoadConfig(const std::string &config_path) 41 | { 42 | return LoadConfig(config_path.c_str()); 43 | } 44 | -------------------------------------------------------------------------------- /mcts/mcts_config.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 
17 | */ 18 | #pragma once 19 | 20 | #include 21 | #include 22 | 23 | #include "mcts/mcts_config.pb.h" 24 | 25 | std::unique_ptr LoadConfig(const char *config_path); 26 | std::unique_ptr LoadConfig(const std::string &config_path); 27 | -------------------------------------------------------------------------------- /mcts/mcts_config.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "model/model_config.proto"; 4 | import "dist/dist_config.proto"; 5 | 6 | message MCTSConfig { 7 | int32 num_eval_threads = 1; 8 | int32 num_search_threads = 2; 9 | int32 max_search_tree_size = 3; 10 | int32 max_children_per_node = 4; 11 | int32 timeout_ms_per_step = 5; 12 | int32 max_simulations_per_step = 6; 13 | int32 eval_batch_size = 7; 14 | int32 eval_wait_batch_timeout_us = 8; 15 | ModelConfig model_config = 9; 16 | string gpu_list = 10; 17 | bool enable_dist = 11; 18 | repeated string dist_svr_addrs = 12; 19 | DistConfig dist_config = 13; 20 | float genmove_temperature = 14; 21 | float c_puct = 15; 22 | float virtual_loss = 16; 23 | int32 virtual_loss_mode = 17; // 0 - act+ucb, 1 - act, 2 - ucb, 3 - none 24 | bool enable_resign = 18; 25 | float v_resign = 19; 26 | int32 resign_mode = 20; 27 | bool enable_dirichlet_noise = 21; 28 | float dirichlet_noise_alpha = 22; 29 | float dirichlet_noise_ratio = 23; 30 | int32 monitor_log_every_ms = 24; 31 | int32 get_best_move_mode = 25; // 0 by mcts, 1 by policy, 2 by value 32 | bool enable_background_search = 26; 33 | bool enable_policy_temperature = 27; 34 | float policy_temperature = 28; 35 | bool disable_transform = 29; 36 | bool disable_pass = 30; 37 | float default_act = 31; 38 | bool inherit_default_act = 32; 39 | float inherit_default_act_factor = 33; 40 | bool clear_search_tree_per_move = 34; 41 | 42 | // async 43 | bool enable_async = 51; 44 | int32 eval_task_queue_size = 52; 45 | 46 | message DebuggerConfig { 47 | int32 print_tree_depth = 1; 48 | int32 
print_tree_width = 2; 49 | }; 50 | DebuggerConfig debugger = 60; 51 | 52 | // rules for foxwq 53 | bool disable_double_pass_scoring = 81; 54 | bool disable_positional_superko = 82; 55 | int32 max_gen_passes = 83; 56 | bool enable_pass_pass = 84; 57 | 58 | message EarlyStopConfig { 59 | bool enable = 1; 60 | int32 check_every_ms = 2; 61 | float sims_factor = 3; 62 | int32 sims_threshold = 4; 63 | }; 64 | EarlyStopConfig early_stop = 91; 65 | 66 | message UnstableOvertimeConfig { 67 | bool enable = 1; 68 | float time_factor = 2; 69 | }; 70 | UnstableOvertimeConfig unstable_overtime = 92; 71 | 72 | message BehindOvertimeConfig { 73 | bool enable = 1; 74 | float act_threshold = 2; 75 | float time_factor = 3; 76 | }; 77 | BehindOvertimeConfig behind_overtime = 93; 78 | 79 | message TimeControlConfig { 80 | bool enable = 1; 81 | float main_time = 2; 82 | float byo_yomi_time = 3; 83 | int32 c_denom = 4; 84 | int32 c_maxply = 5; 85 | float reserved_time = 6; 86 | float min_time = 7; 87 | int32 byo_yomi_after = 8; 88 | }; 89 | TimeControlConfig time_control = 94; 90 | } 91 | -------------------------------------------------------------------------------- /mcts/mcts_debugger.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | #include "mcts_debugger.h" 19 | 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | 27 | #include 28 | 29 | #include "common/go_comm.h" 30 | 31 | #include "mcts_engine.h" 32 | 33 | MCTSDebugger::MCTSDebugger(MCTSEngine *engine) 34 | : m_engine(engine) 35 | { 36 | } 37 | 38 | void MCTSDebugger::Debug() // call before move 39 | { 40 | if (VLOG_IS_ON(1)) { 41 | int ith = m_engine->m_num_moves + 1; 42 | std::string ith_str = std::to_string(ith) + "th move(" + "wb"[ith&1] + ")"; 43 | VLOG(1) << "========== debug info for " << ith_str << " begin =========="; 44 | VLOG(1) << "main move path: " << GetMainMovePath(0); 45 | VLOG(1) << "second move path: " << GetMainMovePath(1); 46 | VLOG(1) << "third move path: " << GetMainMovePath(2); 47 | int depth = m_engine->GetConfig().debugger().print_tree_depth(); 48 | int width = m_engine->GetConfig().debugger().print_tree_width(); 49 | PrintTree(depth ? depth : 1, width ? 
width : 10); 50 | VLOG(1) << "model global step: " << m_engine->m_model_global_step; 51 | VLOG(1) << "========== debug info for " << ith_str << " end =========="; 52 | } 53 | } 54 | 55 | std::string MCTSDebugger::GetDebugStr() 56 | { 57 | TreeNode *root = m_engine->m_root; 58 | int ith = m_engine->m_num_moves; 59 | std::string ith_str = std::to_string(ith) + "th move(" + "wb"[ith&1] + ")"; 60 | float root_action = (float)root->total_action / k_action_value_base / root->visit_count; 61 | std::string debug_str = 62 | ith_str + ": " + GoFunction::IdToStr(root->move) + 63 | ", winrate=" + std::to_string((root_action + 1) * 50) + "%" + 64 | ", N=" + std::to_string(root->visit_count) + 65 | ", Q=" + std::to_string(root_action) + 66 | ", p=" + std::to_string(root->prior_prob) + 67 | ", v=" + std::to_string(root->value); 68 | if (m_engine->m_simulation_counter > 0) { 69 | debug_str += 70 | ", cost " + std::to_string(m_engine->m_search_timer.fms()) + "ms" + 71 | ", sims=" + std::to_string(m_engine->m_simulation_counter) + 72 | ", height=" + std::to_string(m_engine->m_monitor.MaxSearchTreeHeight()) + 73 | ", avg_height=" + std::to_string(m_engine->m_monitor.AvgSearchTreeHeight()); 74 | } 75 | debug_str += ", global_step=" + std::to_string(m_engine->m_model_global_step); 76 | return debug_str; 77 | } 78 | 79 | std::string MCTSDebugger::GetLastMoveDebugStr() 80 | { 81 | return m_last_move_debug_str; 82 | } 83 | 84 | void MCTSDebugger::UpdateLastMoveDebugStr() 85 | { 86 | m_last_move_debug_str = GetDebugStr(); 87 | } 88 | 89 | std::string MCTSDebugger::GetMainMovePath(int rank) 90 | { 91 | std::string moves; 92 | TreeNode *node = m_engine->m_root; 93 | while (node->expand_state == k_expanded && node->ch_len > rank) { 94 | TreeNode *ch = node->ch; 95 | std::vector idx(node->ch_len); 96 | std::iota(idx.begin(), idx.end(), 0); 97 | std::nth_element(idx.begin(), idx.begin() + rank, idx.end(), 98 | [ch](int i, int j) { return ch[i].visit_count > ch[j].visit_count; }); 99 | TreeNode 
*best_ch = &ch[idx[rank]]; 100 | if (moves.size()) moves += ","; 101 | moves += GoFunction::IdToStr(best_ch->move); 102 | char buf[100]; 103 | snprintf(buf, sizeof(buf), "(%d,%.2f,%.2f,%.2f)", 104 | best_ch->visit_count.load(), 105 | (float)best_ch->total_action / k_action_value_base / best_ch->visit_count, 106 | best_ch->prior_prob.load(), 107 | best_ch->value.load()); 108 | moves += buf; 109 | node = best_ch; 110 | rank = 0; 111 | } 112 | return moves; 113 | } 114 | 115 | void MCTSDebugger::PrintTree(int depth, int topk, const std::string &prefix) 116 | { 117 | TreeNode *root = m_engine->m_root; 118 | std::queue> que; 119 | que.emplace(root, 1); 120 | while (!que.empty()) { 121 | TreeNode *node; int dep; 122 | std::tie(node, dep) = que.front(); 123 | que.pop(); 124 | 125 | TreeNode *ch = node->ch; 126 | std::vector idx(node->ch_len); 127 | std::iota(idx.begin(), idx.end(), 0); 128 | std::sort(idx.begin(), idx.end(), [ch](int i, int j) { return ch[i].visit_count > ch[j].visit_count; }); 129 | if (topk < (int)idx.size()) idx.erase(idx.begin() + topk, idx.end()); 130 | 131 | for (int i: idx) { 132 | if (ch[i].visit_count == 0) { 133 | break; 134 | } 135 | std::string moves; 136 | for (TreeNode *t = &ch[i]; t != root; t = t->fa) { 137 | if (moves.size()) moves = "," + moves; 138 | moves = GoFunction::IdToStr(t->move) + moves; 139 | } 140 | VLOG(1) << prefix << moves 141 | << ": N=" << ch[i].visit_count 142 | << ", W=" << (float)ch[i].total_action / k_action_value_base 143 | << ", Q=" << (float)ch[i].total_action / k_action_value_base / ch[i].visit_count 144 | << ", p=" << ch[i].prior_prob 145 | << ", v=" << ch[i].value; 146 | if (dep < depth) que.emplace(&ch[i], dep + 1); 147 | } 148 | } 149 | } 150 | -------------------------------------------------------------------------------- /mcts/mcts_debugger.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making 
PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | #pragma once 19 | 20 | #include 21 | 22 | class MCTSEngine; 23 | class MCTSDebugger 24 | { 25 | public: 26 | MCTSDebugger(MCTSEngine *engine); 27 | 28 | void Debug(); // call before move 29 | 30 | std::string GetDebugStr(); 31 | std::string GetLastMoveDebugStr(); // call after move 32 | void UpdateLastMoveDebugStr(); 33 | 34 | std::string GetMainMovePath(int rank); 35 | void PrintTree(int depth, int topk, const std::string &prefix = ""); 36 | 37 | private: 38 | MCTSEngine *m_engine; 39 | std::string m_last_move_debug_str; 40 | }; 41 | -------------------------------------------------------------------------------- /mcts/mcts_engine.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 
8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | #pragma once 19 | 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | #include "common/go_comm.h" 26 | #include "common/go_state.h" 27 | #include "common/task_queue.h" 28 | #include "common/wait_group.h" 29 | #include "common/thread_conductor.h" 30 | #include "common/timer.h" 31 | #include "model/zero_model_base.h" 32 | 33 | #include "mcts_config.h" 34 | #include "mcts_monitor.h" 35 | #include "mcts_debugger.h" 36 | #include "byo_yomi_timer.h" 37 | 38 | struct TreeNode 39 | { 40 | std::atomic fa; 41 | std::atomic ch; // child nodes must allocate contiguously 42 | std::atomic ch_len; 43 | std::atomic size; 44 | std::atomic expand_state; 45 | 46 | std::atomic move; 47 | std::atomic visit_count; 48 | std::atomic virtual_loss_count; 49 | std::atomic total_action; 50 | std::atomic prior_prob; 51 | std::atomic value; 52 | }; 53 | 54 | const int64_t k_action_value_base = 1 << 16; 55 | const int k_unexpanded = 0; 56 | const int k_expanding = 1; 57 | const int k_expanded = 2; 58 | 59 | typedef std::function, float)> EvalCallback; 60 | 61 | struct EvalTask 62 | { 63 | std::vector features; 64 | EvalCallback callback; 65 | }; 66 | 67 | class MCTSEngine 68 | { 69 | public: 70 | MCTSEngine(const MCTSConfig &config); 71 | ~MCTSEngine(); 72 | 73 | void Reset(const std::string &init_moves=""); 74 | void Move(GoCoordId x, GoCoordId y); 75 | void GenMove(GoCoordId &x, GoCoordId &y); 76 | void GenMove(GoCoordId &x, GoCoordId &y, std::vector &visit_count, float &v_resign); 77 | bool Undo(); 78 | const GoState 
&GetBoard(); 79 | MCTSConfig &GetConfig(); 80 | void SetPendingConfig(std::unique_ptr config); 81 | MCTSDebugger &GetDebugger(); 82 | int GetModelGlobalStep(); 83 | ByoYomiTimer &GetByoYomiTimer(); 84 | 85 | private: 86 | TreeNode *InitNode(TreeNode *node, TreeNode *fa, int move, float prior_prob); 87 | TreeNode *FindChild(TreeNode *node, int move); 88 | 89 | void Eval(const GoState &board, EvalCallback callback); 90 | void EvalRoutine(std::unique_ptr model); 91 | 92 | TreeNode *Select(GoState &board); 93 | TreeNode *SelectChild(TreeNode *node); 94 | int Expand(TreeNode *node, GoState &board, const std::vector &policy); 95 | void Backup(TreeNode *node, float value, int ch_len); 96 | void UndoVirtualLoss(TreeNode *node); 97 | 98 | bool CheckEarlyStop(int64_t timeout_us); 99 | bool CheckUnstable(); 100 | bool CheckBehind(); 101 | 102 | int64_t GetSearchTimeoutUs(); 103 | int64_t GetSearchOvertimeUs(int64_t timeout_us); 104 | 105 | void Search(); 106 | void SearchWait(int64_t timeout_us, bool is_overtime); 107 | void SearchResume(); 108 | void SearchPause(); 109 | void SearchRoutine(); 110 | 111 | void ChangeRoot(TreeNode *node); 112 | void InitRoot(); 113 | 114 | void DeleteRoutine(); 115 | int DeleteTree(TreeNode *node); 116 | 117 | int GetBestMove(float &v_resign); 118 | int GetSamplingMove(float temperature); 119 | std::vector GetVisitCount(TreeNode *node); 120 | 121 | template 122 | void TransformFeatures(T &features, int mode, bool reverse = false); 123 | void TransformCoord(GoCoordId &x, GoCoordId &y, int mode, bool reverse = false); 124 | 125 | void ApplyTemperature(std::vector &probs, float temperature); 126 | 127 | void TTableUpdate(uint64_t hash, int64_t value); 128 | void TTableSync(TreeNode *node); 129 | void TTableClear(); 130 | 131 | void EvalCacheInsert(uint64_t hash, const std::vector policy, float value); 132 | bool EvalCacheFind(uint64_t hash, std::vector &policy, float &value); 133 | 134 | bool IsPassDisable(); 135 | 136 | private: 137 | MCTSConfig 
m_config; 138 | std::unique_ptr m_pending_config; 139 | 140 | TreeNode *m_root; 141 | GoState m_board; 142 | 143 | std::vector m_eval_threads; 144 | TaskQueue m_eval_task_queue; 145 | WaitGroup m_eval_threads_init_wg; 146 | WaitGroup m_eval_tasks_wg; 147 | std::atomic m_model_global_step; 148 | 149 | std::vector m_search_threads; 150 | ThreadConductor m_search_threads_conductor; 151 | bool m_is_searching; 152 | 153 | std::thread m_delete_thread; 154 | TaskQueue m_delete_queue; 155 | 156 | std::atomic m_simulation_counter; 157 | Timer m_search_timer; 158 | 159 | int m_num_moves; 160 | std::string m_moves_str; 161 | 162 | int m_gen_passes; 163 | 164 | ByoYomiTimer m_byo_yomi_timer; 165 | 166 | MCTSMonitor m_monitor; 167 | MCTSDebugger m_debugger; 168 | 169 | friend class MCTSMonitor; 170 | friend class MCTSDebugger; 171 | }; 172 | -------------------------------------------------------------------------------- /mcts/mcts_monitor.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 
17 | */ 18 | #include "mcts_monitor.h" 19 | 20 | #include 21 | 22 | #include "mcts_engine.h" 23 | 24 | int MCTSMonitor::g_next_monitor_id = 0; 25 | MCTSMonitor *MCTSMonitor::g_global_monitors[k_max_monitor_instances]; 26 | std::mutex MCTSMonitor::g_global_monitors_mutex; 27 | thread_local std::shared_ptr MCTSMonitor::g_local_monitors[k_max_monitor_instances]; 28 | 29 | MCTSMonitor::MCTSMonitor(MCTSEngine *engine) 30 | : m_engine(engine), m_slot(0) 31 | { 32 | { 33 | std::lock_guard lock(g_global_monitors_mutex); 34 | m_id = g_next_monitor_id++; 35 | while (m_slot < k_max_monitor_instances && g_global_monitors[m_slot]) ++m_slot; 36 | CHECK(m_slot < k_max_monitor_instances) << "Too many MCTSMonitor instances"; 37 | g_global_monitors[m_slot] = this; 38 | } 39 | if (m_engine->GetConfig().monitor_log_every_ms() > 0) { 40 | m_monitor_thread = std::thread(&MCTSMonitor::MonitorRoutine, this); 41 | } 42 | } 43 | 44 | MCTSMonitor::~MCTSMonitor() 45 | { 46 | if (m_monitor_thread.joinable()) { 47 | m_monitor_thread_conductor.Terminate(); 48 | m_monitor_thread.join(); 49 | } 50 | g_local_monitors[m_slot] = nullptr; 51 | std::lock_guard lock(g_global_monitors_mutex); 52 | g_global_monitors[m_slot] = nullptr; 53 | } 54 | 55 | void MCTSMonitor::Pause() 56 | { 57 | m_monitor_thread_conductor.Pause(); 58 | m_monitor_thread_conductor.Join(); 59 | } 60 | 61 | void MCTSMonitor::Resume() 62 | { 63 | if (m_engine->GetConfig().monitor_log_every_ms() > 0) { 64 | if (!m_monitor_thread.joinable()) { 65 | m_monitor_thread = std::thread(&MCTSMonitor::MonitorRoutine, this); 66 | } 67 | m_monitor_thread_conductor.Resume(1); 68 | } 69 | } 70 | 71 | void MCTSMonitor::Reset() 72 | { 73 | std::lock_guard lock(m_local_monitors_mutex); 74 | for (auto &local_monitor: m_local_monitors) { 75 | local_monitor->Reset(); 76 | } 77 | m_local_monitors.erase( 78 | std::remove_if(m_local_monitors.begin(), m_local_monitors.end(), 79 | [](const std::shared_ptr &ptr) { return ptr.use_count() == 1; }), 80 | 
m_local_monitors.end()); 81 | } 82 | 83 | void MCTSMonitor::Log() 84 | { 85 | VLOG(0) << "MCTSMonitor: avg eval cost " << AvgEvalCostMs() << "ms"; 86 | VLOG(0) << "MCTSMonitor: max eval cost " << MaxEvalCostMs() << "ms"; 87 | VLOG(0) << "MCTSMonitor: avg eval cost " << AvgEvalCostMsPerBatch() << "ms per batch"; 88 | VLOG(0) << "MCTSMonitor: max eval cost " << MaxEvalCostMsPerBatch() << "ms per batch"; 89 | VLOG(0) << "MCTSMonitor: avg eval batch size " << AvgEvalBatchSize(); 90 | VLOG(0) << "MCTSMonitor: eval timeout " << EvalTimeout() << " times"; 91 | VLOG(0) << "MCTSMonitor: avg simulation cost " << AvgSimulationCostMs() << "ms"; 92 | VLOG(0) << "MCTSMonitor: max simulation cost " << MaxSimulationCostMs() << "ms"; 93 | VLOG(0) << "MCTSMonitor: avg select cost " << AvgSelectCostMs() << "ms"; 94 | VLOG(0) << "MCTSMonitor: max select cost " << MaxSelectCostMs() << "ms"; 95 | VLOG(0) << "MCTSMonitor: avg expand cost " << AvgExpandCostMs() << "ms"; 96 | VLOG(0) << "MCTSMonitor: max expand cost " << MaxExpandCostMs() << "ms"; 97 | VLOG(0) << "MCTSMonitor: avg backup cost " << AvgBackupCostMs() << "ms"; 98 | VLOG(0) << "MCTSMonitor: max backup cost " << MaxBackupCostMs() << "ms"; 99 | VLOG(0) << "MCTSMonitor: select same node " << SelectSameNode() << " times"; 100 | VLOG(0) << "MCTSMonitor: search tree height is " << MaxSearchTreeHeight(); 101 | VLOG(0) << "MCTSMonitor: avg height of nodes is " << AvgSearchTreeHeight(); 102 | VLOG(0) << "MCTSMonitor: avg eval task queue size is " << AvgTaskQueueSize(); 103 | 104 | 105 | if (m_engine->GetConfig().enable_async()) { 106 | VLOG(0) << "MCTSMonitor: avg rpc queue size is " << AvgRpcQueueSize(); 107 | } 108 | } 109 | 110 | void MCTSMonitor::MonitorRoutine() 111 | { 112 | m_monitor_thread_conductor.Wait(); 113 | for (;;) { 114 | if (!m_monitor_thread_conductor.IsRunning()) { 115 | m_monitor_thread_conductor.AckPause(); 116 | m_monitor_thread_conductor.Wait(); 117 | if (m_monitor_thread_conductor.IsTerminate()) { 118 | 
LOG(WARNING) << "MonitorRoutine: terminate"; 119 | return; 120 | } 121 | } 122 | m_monitor_thread_conductor.Sleep(m_engine->GetConfig().monitor_log_every_ms() * 1000LL); 123 | if (m_monitor_thread_conductor.IsRunning()) { 124 | Log(); 125 | } 126 | google::FlushLogFiles(google::GLOG_INFO); 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /mcts_main.filters: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | {4FC737F1-C7A5-4376-A066-2A32D752A2FF} 6 | cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx 7 | 8 | 9 | {93995380-89BD-4b04-88EB-625FBE52EBFB} 10 | h;hh;hpp;hxx;hm;inl;inc;xsd 11 | 12 | 13 | {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} 14 | rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav 15 | 16 | 17 | 18 | 19 | Source Files 20 | 21 | 22 | Source Files 23 | 24 | 25 | Source Files 26 | 27 | 28 | Source Files 29 | 30 | 31 | Source Files 32 | 33 | 34 | Source Files 35 | 36 | 37 | Source Files 38 | 39 | 40 | Source Files 41 | 42 | 43 | Source Files 44 | 45 | 46 | Source Files 47 | 48 | 49 | Source Files 50 | 51 | 52 | Source Files 53 | 54 | 55 | Source Files 56 | 57 | 58 | Source Files 59 | 60 | 61 | Source Files 62 | 63 | 64 | Source Files 65 | 66 | 67 | Source Files 68 | 69 | 70 | Source Files 71 | 72 | 73 | Source Files 74 | 75 | 76 | Source Files 77 | 78 | 79 | Source Files 80 | 81 | 82 | Source Files 83 | 84 | 85 | Source Files 86 | 87 | 88 | Source Files 89 | 90 | 91 | 92 | 93 | Header Files 94 | 95 | 96 | Header Files 97 | 98 | 99 | Header Files 100 | 101 | 102 | Header Files 103 | 104 | 105 | Header Files 106 | 107 | 108 | Header Files 109 | 110 | 111 | Header Files 112 | 113 | 114 | Header Files 115 | 116 | 117 | Header Files 118 | 119 | 120 | Header Files 121 | 122 | 123 | Header Files 124 | 125 | 126 | Header Files 127 | 128 | 129 | Header Files 130 | 131 | 132 | Header Files 133 | 134 | 135 | Header Files 136 | 137 | 138 | Header Files 
139 | 140 | 141 | Header Files 142 | 143 | 144 | Header Files 145 | 146 | 147 | Header Files 148 | 149 | 150 | Header Files 151 | 152 | 153 | Header Files 154 | 155 | 156 | Header Files 157 | 158 | 159 | Header Files 160 | 161 | 162 | Header Files 163 | 164 | 165 | Header Files 166 | 167 | 168 | Header Files 169 | 170 | 171 | Header Files 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | -------------------------------------------------------------------------------- /model/BUILD: -------------------------------------------------------------------------------- 1 | load("//:rules.bzl", "cc_proto_library") 2 | load("@org_tensorflow//tensorflow:tensorflow.bzl", "if_tensorrt") 3 | 4 | cc_binary( 5 | name = "build_tensorrt_model", 6 | srcs = ["build_tensorrt_model.cc"], 7 | deps = [ 8 | "@com_github_google_glog//:glog", 9 | ] + if_tensorrt([ 10 | "@local_config_tensorrt//:nv_infer", 11 | "@local_config_tensorrt//:nv_parsers", 12 | ]), 13 | copts = if_tensorrt(["-DGOOGLE_TENSORRT=1"]), 14 | ) 15 | 16 | cc_library( 17 | name = "zero_model", 18 | srcs = [ 19 | "zero_model.cc", 20 | ], 21 | hdrs = ["zero_model.h"], 22 | deps = [ 23 | ":zero_model_base", 24 | ":checkpoint_utils", 25 | "@com_github_google_glog//:glog", 26 | "@org_tensorflow//tensorflow/core:tensorflow", 27 | ], 28 | visibility = ["//visibility:public"], 29 | ) 30 | 31 | cc_library( 32 | name = "trt_zero_model", 33 | srcs = ["trt_zero_model.cc"], 34 | hdrs = ["trt_zero_model.h"], 35 | deps = [ 36 | ":zero_model_base", 37 | "@boost//:filesystem", 38 | "@com_github_google_glog//:glog", 39 | ] + if_tensorrt([ 40 | "@local_config_tensorrt//:nv_infer", 41 | ]), 42 | copts = if_tensorrt(["-DGOOGLE_TENSORRT=1"]), 43 | visibility = ["//visibility:public"], 44 | ) 45 | 46 | cc_library( 47 | name = "zero_model_base", 48 | hdrs = ["zero_model_base.h"], 49 | deps = [ 50 | ":model_config_cc_proto", 51 | "//common:errordef", 52 | ], 53 | visibility = ["//visibility:public"], 54 | ) 55 | 56 | cc_proto_library( 57 | 
name = "model_config_cc_proto", 58 | srcs = ["model_config.proto"], 59 | visibility = ["//visibility:public"], 60 | ) 61 | 62 | cc_library( 63 | name = "checkpoint_utils", 64 | srcs = ["checkpoint_utils.cc"], 65 | hdrs = ["checkpoint_utils.h"], 66 | deps = [ 67 | ":checkpoint_state_cc_proto", 68 | "@boost//:filesystem", 69 | "@com_github_google_glog//:glog", 70 | ], 71 | visibility = ["//visibility:public"], 72 | ) 73 | 74 | cc_proto_library( 75 | name = "checkpoint_state_cc_proto", 76 | srcs = ["checkpoint_state.proto"], 77 | ) 78 | -------------------------------------------------------------------------------- /model/build_tensorrt_model.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making Phoenix Go available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 
17 | */ 18 | #if GOOGLE_TENSORRT 19 | 20 | #include 21 | #include 22 | 23 | #include 24 | #include 25 | 26 | #include "cuda/include/cuda_runtime_api.h" 27 | #include "tensorrt/include/NvInfer.h" 28 | #include "tensorrt/include/NvUffParser.h" 29 | 30 | DEFINE_string(model_path, "", "Path of model."); 31 | DEFINE_string(data_type, "FP32", "Data type to build, FP32 or FP16."); 32 | DEFINE_int32(max_batch_size, 4, "Max size for input batch."); 33 | DEFINE_int32(max_workspace_size, 1<<30, "Parameter to control memory allocation (in bytes)."); 34 | DEFINE_int32(calib_iterations, 10000, "Num of iterations to run while calibration."); 35 | DEFINE_string(storage_address, "", "Address of wegostorage"); 36 | DEFINE_int32(gpu, 0, "Gpu used while building."); 37 | 38 | using namespace nvinfer1; 39 | using namespace nvuffparser; 40 | 41 | class Logger : public ILogger 42 | { 43 | void log(Severity severity, const char *msg) override 44 | { 45 | switch (severity) { 46 | case Severity::kINTERNAL_ERROR: LOG(ERROR) << msg; break; 47 | case Severity::kERROR: LOG(ERROR) << msg; break; 48 | case Severity::kWARNING: LOG(WARNING) << msg; break; 49 | case Severity::kINFO: LOG(INFO) << msg; break; 50 | } 51 | } 52 | } g_logger; 53 | 54 | int main(int argc, char *argv[]) 55 | { 56 | gflags::ParseCommandLineFlags(&argc, &argv, true); 57 | google::InitGoogleLogging(argv[0]); 58 | google::InstallFailureSignalHandler(); 59 | 60 | std::string uff_path = FLAGS_model_path + ".uff"; 61 | std::string calib_path = FLAGS_model_path + ".calib"; 62 | std::string output_path = FLAGS_model_path + "." 
+ FLAGS_data_type + ".PLAN"; 63 | 64 | IUffParser *parser = createUffParser(); 65 | parser = createUffParser(); 66 | parser->registerInput("inputs", DimsCHW(19, 19, 17)); 67 | parser->registerOutput("policy"); 68 | parser->registerOutput("value"); 69 | 70 | IBuilder *builder = createInferBuilder(g_logger); 71 | INetworkDefinition *network = builder->createNetwork(); 72 | 73 | CHECK(parser->parse(uff_path.c_str(), *network, DataType::kFLOAT)); 74 | 75 | builder->setMaxBatchSize(FLAGS_max_batch_size); 76 | builder->setMaxWorkspaceSize(FLAGS_max_workspace_size); 77 | builder->setHalf2Mode(FLAGS_data_type == "FP16"); 78 | 79 | ICudaEngine *engine = builder->buildCudaEngine(*network); 80 | CHECK_NOTNULL(engine); 81 | 82 | IHostMemory *serialized_engine = engine->serialize(); 83 | 84 | std::ofstream output(output_path, std::ios::binary); 85 | output.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); 86 | } 87 | 88 | #else // GOOGLE_TENSORRT 89 | 90 | #include 91 | 92 | int main() 93 | { 94 | fprintf(stderr, "TensorRT is not enable!\n"); 95 | return -1; 96 | } 97 | 98 | #endif // GOOGLE_TENSORRT 99 | -------------------------------------------------------------------------------- /model/checkpoint_state.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorflow; 4 | option cc_enable_arenas = true; 5 | 6 | // Protocol buffer representing the checkpoint state. 7 | // 8 | // TODO(touts): Add other attributes as needed. 9 | message CheckpointState { 10 | // Path to the most-recent model checkpoint. 11 | string model_checkpoint_path = 1; 12 | 13 | // Paths to all not-yet-deleted model checkpoints, sorted from oldest to 14 | // newest. 15 | // Note that the value of model_checkpoint_path should be the last item in 16 | // this list. 
17 | repeated string all_model_checkpoint_paths = 2; 18 | } 19 | -------------------------------------------------------------------------------- /model/checkpoint_utils.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 
17 | */ 18 | #include "checkpoint_utils.h" 19 | 20 | #include 21 | #include 22 | 23 | #include "model/checkpoint_state.pb.h" 24 | 25 | namespace fs = boost::filesystem; 26 | 27 | fs::path GetCheckpointPath(const fs::path &train_dir) 28 | { 29 | tensorflow::CheckpointState checkpoint_state; 30 | fs::path ckpt_state_path = train_dir / "checkpoint"; 31 | std::ostringstream ckpt_ss; 32 | if (!(ckpt_ss << std::ifstream(ckpt_state_path.string()).rdbuf())) { 33 | PLOG(ERROR) << "Error reading " << ckpt_state_path; 34 | return ""; 35 | } 36 | if (!google::protobuf::TextFormat::ParseFromString(ckpt_ss.str(), &checkpoint_state)) { 37 | LOG(ERROR) << "Error parsing " << ckpt_state_path << ", buf=" << ckpt_ss.str(); 38 | return ""; 39 | } 40 | fs::path checkpoint_path = checkpoint_state.model_checkpoint_path(); 41 | if (checkpoint_path.is_relative()) { 42 | checkpoint_path = train_dir / checkpoint_path; 43 | } 44 | return checkpoint_path; 45 | } 46 | 47 | bool CopyCheckpoint(const fs::path &from, const fs::path &to) 48 | { 49 | for (int i = 0; i < 3; ++i) { 50 | try { 51 | fs::path from_ckpt_path = GetCheckpointPath(from); 52 | if (from_ckpt_path.empty()) { 53 | LOG(ERROR) << "Error reading model path from " << from; 54 | continue; 55 | } 56 | 57 | fs::path ckpt_name = from_ckpt_path.filename(); 58 | fs::path to_ckpt_path = to / ckpt_name; 59 | 60 | fs::create_directories(to); 61 | for (std::string suffix: {".data-00000-of-00001", ".index"}) { 62 | fs::path from_file_path = from_ckpt_path.string() + suffix; 63 | fs::path to_file_path = to_ckpt_path.string() + suffix; 64 | LOG(INFO) << "Copying from " << from_file_path << " to " << to_file_path << suffix; 65 | fs::copy_file(from_file_path, to_file_path, fs::copy_option::overwrite_if_exists); 66 | fs::path symlink_path = to / ("zero.ckpt" + suffix); 67 | fs::remove(symlink_path); 68 | fs::create_symlink(to_file_path.filename(), symlink_path); 69 | } 70 | 71 | tensorflow::CheckpointState checkpoint_state; 72 | 
checkpoint_state.set_model_checkpoint_path(to_ckpt_path.string()); 73 | fs::path ckpt_state_path = to / "checkpoint"; 74 | LOG(INFO) << "Writing checkpoint file " << ckpt_state_path; 75 | if (!(std::ofstream(ckpt_state_path.string()) << checkpoint_state.DebugString())) { 76 | PLOG(ERROR) << "Error writing " << to << "/checkpoint"; 77 | continue; 78 | } 79 | 80 | fs::copy_file(from / "meta_graph", to / "meta_graph", fs::copy_option::overwrite_if_exists); 81 | LOG(INFO) << "Copy checkpoint from " << from << " to " << to << " succ."; 82 | return true; 83 | } catch (const std::exception &e) { 84 | LOG(ERROR) << e.what(); 85 | } 86 | } 87 | return false; 88 | } 89 | -------------------------------------------------------------------------------- /model/checkpoint_utils.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 
17 | */ 18 | #pragma once 19 | 20 | #include 21 | 22 | boost::filesystem::path GetCheckpointPath(const boost::filesystem::path &train_dir); 23 | 24 | bool CopyCheckpoint(const boost::filesystem::path &from, const boost::filesystem::path &to); 25 | -------------------------------------------------------------------------------- /model/model_config.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | message ModelConfig { 4 | string train_dir = 1; 5 | string checkpoint_path = 2; 6 | string meta_graph_path = 3; 7 | int32 intra_op_parallelism_threads = 6; 8 | int32 inter_op_parallelism_threads = 7; 9 | bool enable_mkl = 8; 10 | int32 kmp_blocktime = 9; 11 | bool kmp_settings = 10; 12 | string kmp_affinity = 11; 13 | bool enable_xla = 12; 14 | bool enable_tensorrt = 15; 15 | string tensorrt_model_path = 16; 16 | } 17 | -------------------------------------------------------------------------------- /model/trt_zero_model.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 
17 | */ 18 | #include "trt_zero_model.h" 19 | 20 | #if GOOGLE_TENSORRT 21 | 22 | #include 23 | #include 24 | 25 | #include 26 | #include 27 | 28 | #include "cuda/include/cuda_runtime_api.h" 29 | #include "tensorrt/include/NvInfer.h" 30 | 31 | namespace fs = boost::filesystem; 32 | 33 | class Logger : public nvinfer1::ILogger 34 | { 35 | void log(Severity severity, const char *msg) override 36 | { 37 | switch (severity) { 38 | case Severity::kINTERNAL_ERROR: LOG(ERROR) << msg; break; 39 | case Severity::kERROR: LOG(ERROR) << msg; break; 40 | case Severity::kWARNING: LOG(WARNING) << msg; break; 41 | case Severity::kINFO: LOG(INFO) << msg; break; 42 | } 43 | } 44 | } g_logger; 45 | 46 | TrtZeroModel::TrtZeroModel(int gpu) 47 | : m_engine(nullptr), m_runtime(nullptr), m_context(nullptr), m_gpu(gpu), m_global_step(0) 48 | { 49 | } 50 | 51 | TrtZeroModel::~TrtZeroModel() 52 | { 53 | if (m_context) { 54 | m_context->destroy(); 55 | } 56 | if (m_engine) { 57 | m_engine->destroy(); 58 | } 59 | if (m_runtime) { 60 | m_runtime->destroy(); 61 | } 62 | for (auto buf: m_cuda_buf) { 63 | int ret = cudaFree(buf); 64 | if (ret != 0) { 65 | LOG(ERROR) << "cuda free err " << ret; 66 | } 67 | } 68 | } 69 | 70 | int TrtZeroModel::Init(const ModelConfig &model_config) 71 | { 72 | cudaSetDevice(m_gpu); 73 | 74 | fs::path train_dir = model_config.train_dir(); 75 | 76 | fs::path tensorrt_model_path = model_config.tensorrt_model_path(); 77 | if (tensorrt_model_path.is_relative()) { 78 | tensorrt_model_path = train_dir / tensorrt_model_path; 79 | } 80 | 81 | std::ostringstream model_ss(std::ios::binary); 82 | if (!(model_ss << std::ifstream(tensorrt_model_path.string(), std::ios::binary).rdbuf())) { 83 | PLOG(ERROR) << "read tensorrt model '" << tensorrt_model_path << "' error"; 84 | return ERR_READ_TRT_MODEL; 85 | } 86 | std::string model_str = model_ss.str(); 87 | 88 | m_runtime = nvinfer1::createInferRuntime(g_logger); 89 | m_engine = m_runtime->deserializeCudaEngine(model_str.c_str(), 
model_str.size(), nullptr); 90 | if (m_engine == nullptr) { 91 | PLOG(ERROR) << "load cuda engine error"; 92 | return ERR_LOAD_TRT_ENGINE; 93 | } 94 | m_context = m_engine->createExecutionContext(); 95 | 96 | int batch_size = m_engine->getMaxBatchSize(); 97 | LOG(INFO) << "tensorrt max batch size: " << batch_size; 98 | for (int i = 0; i < m_engine->getNbBindings(); ++i) { 99 | auto dim = m_engine->getBindingDimensions(i); 100 | std::string dim_str = "("; 101 | int size = 1; 102 | for (int i = 0; i < dim.nbDims; ++i) { 103 | if (i) dim_str += ", "; 104 | dim_str += std::to_string(dim.d[i]); 105 | size *= dim.d[i]; 106 | } 107 | dim_str += ")"; 108 | LOG(INFO) << "tensorrt binding: " << m_engine->getBindingName(i) << " " << dim_str; 109 | 110 | void *buf; 111 | int ret = cudaMalloc(&buf, batch_size * size * sizeof(float)); 112 | if (ret != 0) { 113 | LOG(ERROR) << "cuda malloc err " << ret; 114 | return ERR_CUDA_MALLOC; 115 | } 116 | m_cuda_buf.push_back(buf); 117 | } 118 | 119 | if (!(std::ifstream(tensorrt_model_path.string() + ".step") >> m_global_step)) { 120 | LOG(WARNING) << "read global step from " << tensorrt_model_path << ".step failed"; 121 | } 122 | 123 | return 0; 124 | } 125 | 126 | int TrtZeroModel::Forward(const std::vector> &inputs, 127 | std::vector> &policy, std::vector &value) 128 | { 129 | int batch_size = inputs.size(); 130 | if (batch_size == 0) { 131 | LOG(ERROR) << "Error batch size can not be 0."; 132 | return ERR_INVALID_INPUT; 133 | } 134 | 135 | std::vector inputs_flat(batch_size * INPUT_DIM); 136 | for (int i = 0; i < batch_size; ++i) { 137 | if (inputs[i].size() != INPUT_DIM) { 138 | LOG(ERROR) << "Error input dim not match, need " << INPUT_DIM << ", got " << inputs[i].size(); 139 | return ERR_INVALID_INPUT; 140 | } 141 | for (int j = 0; j < INPUT_DIM; ++j) { 142 | inputs_flat[i * INPUT_DIM + j] = inputs[i][j]; 143 | } 144 | } 145 | 146 | int ret = cudaMemcpy(m_cuda_buf[0], inputs_flat.data(), inputs_flat.size() * sizeof(float), 
cudaMemcpyHostToDevice); 147 | if (ret != 0) { 148 | LOG(ERROR) << "cuda memcpy err " << ret; 149 | return ERR_CUDA_MEMCPY; 150 | } 151 | 152 | m_context->execute(batch_size, m_cuda_buf.data()); 153 | 154 | std::vector policy_flat(batch_size * OUTPUT_DIM); 155 | ret = cudaMemcpy(policy_flat.data(), m_cuda_buf[1], policy_flat.size() * sizeof(float), cudaMemcpyDeviceToHost); 156 | if (ret != 0) { 157 | LOG(ERROR) << "cuda memcpy err " << ret; 158 | return ERR_CUDA_MEMCPY; 159 | } 160 | policy.resize(batch_size); 161 | for (int i = 0; i < batch_size; ++i) { 162 | policy[i].resize(OUTPUT_DIM); 163 | for (int j = 0; j < OUTPUT_DIM; ++j) { 164 | policy[i][j] = policy_flat[i * OUTPUT_DIM + j]; 165 | } 166 | } 167 | 168 | value.resize(batch_size); 169 | ret = cudaMemcpy(value.data(), m_cuda_buf[2], value.size() * sizeof(float), cudaMemcpyDeviceToHost); 170 | if (ret != 0) { 171 | LOG(ERROR) << "cuda memcpy err " << ret; 172 | return ERR_CUDA_MEMCPY; 173 | } 174 | for (int i = 0; i < batch_size; ++i) { 175 | value[i] = -value[i]; 176 | } 177 | 178 | return 0; 179 | } 180 | 181 | int TrtZeroModel::GetGlobalStep(int &global_step) 182 | { 183 | global_step = m_global_step; 184 | return 0; 185 | } 186 | 187 | #else // GOOGLE_TENSORRT 188 | 189 | #include 190 | 191 | TrtZeroModel::TrtZeroModel(int gpu) 192 | { 193 | LOG(FATAL) << "TensorRT is not enable!"; 194 | } 195 | 196 | TrtZeroModel::~TrtZeroModel() 197 | { 198 | LOG(FATAL) << "TensorRT is not enable!"; 199 | } 200 | 201 | int TrtZeroModel::Init(const ModelConfig &model_config) 202 | { 203 | LOG(FATAL) << "TensorRT is not enable!"; 204 | return 0; 205 | } 206 | 207 | int TrtZeroModel::Forward(const std::vector> &inputs, 208 | std::vector> &policy, std::vector &value) 209 | { 210 | LOG(FATAL) << "TensorRT is not enable!"; 211 | return 0; 212 | } 213 | 214 | int TrtZeroModel::GetGlobalStep(int &global_step) 215 | { 216 | LOG(FATAL) << "TensorRT is not enable!"; 217 | return 0; 218 | } 219 | 220 | #endif // GOOGLE_TENSORRT 
221 | -------------------------------------------------------------------------------- /model/trt_zero_model.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | #pragma once 19 | 20 | #include 21 | 22 | #include "model/zero_model_base.h" 23 | #include "model/model_config.pb.h" 24 | 25 | namespace nvinfer1 { 26 | class ICudaEngine; 27 | class IRuntime; 28 | class IExecutionContext; 29 | } 30 | 31 | class TrtZeroModel final : public ZeroModelBase 32 | { 33 | public: 34 | TrtZeroModel(int gpu); 35 | ~TrtZeroModel(); 36 | 37 | int Init(const ModelConfig &model_config) override; 38 | 39 | // input [batch, 19 * 19 * 17] 40 | // policy [batch, 19 * 19 + 1] 41 | int Forward(const std::vector> &inputs, 42 | std::vector> &policy, std::vector &value) override; 43 | 44 | int GetGlobalStep(int &global_step) override; 45 | 46 | private: 47 | nvinfer1::ICudaEngine *m_engine; 48 | nvinfer1::IRuntime *m_runtime; 49 | nvinfer1::IExecutionContext *m_context; 50 | std::vector m_cuda_buf; 51 | int m_gpu; 52 | int m_global_step; 53 | }; 54 | -------------------------------------------------------------------------------- /model/zero_model.cc: 
-------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | #include "zero_model.h" 19 | 20 | #include 21 | 22 | #include 23 | 24 | #include "tensorflow/core/public/session.h" 25 | #include "tensorflow/core/protobuf/meta_graph.pb.h" 26 | 27 | #include "model/checkpoint_utils.h" 28 | 29 | namespace fs = boost::filesystem; 30 | namespace tf = tensorflow; 31 | 32 | const std::string input_tensor_name = "inputs"; 33 | const std::string policy_tensor_name = "policy"; 34 | const std::string value_tensor_name = "value"; 35 | 36 | ZeroModel::ZeroModel(int gpu) 37 | : m_session(nullptr), m_gpu(gpu) 38 | { 39 | } 40 | 41 | ZeroModel::~ZeroModel() 42 | { 43 | if (m_session != nullptr) { 44 | tf::Status status = m_session->Close(); 45 | if (!status.ok()) { 46 | LOG(ERROR) << "Error closing tf session: " << status.ToString(); 47 | } 48 | } 49 | } 50 | 51 | int ZeroModel::Init(const ModelConfig &model_config) 52 | { 53 | fs::path train_dir = model_config.train_dir(); 54 | 55 | fs::path meta_graph_path = model_config.meta_graph_path(); 56 | if (meta_graph_path.empty()) { 57 | meta_graph_path = train_dir / "meta_graph"; 58 | } else if (meta_graph_path.is_relative()) { 59 | 
meta_graph_path = train_dir / meta_graph_path; 60 | } 61 | 62 | fs::path checkpoint_path = model_config.checkpoint_path(); 63 | if (checkpoint_path.empty()) { 64 | checkpoint_path = GetCheckpointPath(train_dir); 65 | } else if (checkpoint_path.is_relative()) { 66 | checkpoint_path = train_dir / checkpoint_path; 67 | } 68 | 69 | if (checkpoint_path.empty()) { 70 | return ERR_READ_CHECKPOINT; 71 | } 72 | LOG(INFO) << "Read checkpoint state succ"; 73 | 74 | tf::MetaGraphDef meta_graph_def; 75 | tf::Status status = ReadBinaryProto(tf::Env::Default(), meta_graph_path.string(), &meta_graph_def); 76 | if (!status.ok()) { 77 | LOG(ERROR) << "Error reading graph definition from " << meta_graph_path << ": " << status.ToString(); 78 | return ERR_READ_CHECKPOINT; 79 | } 80 | LOG(INFO) << "Read meta graph succ"; 81 | 82 | for (auto &node: *meta_graph_def.mutable_graph_def()->mutable_node()) { 83 | node.set_device("/gpu:" + std::to_string(m_gpu)); 84 | } 85 | 86 | tf::SessionOptions options; 87 | options.config.set_allow_soft_placement(true); 88 | options.config.mutable_gpu_options()->set_per_process_gpu_memory_fraction(0.5); 89 | options.config.mutable_gpu_options()->set_allow_growth(true); 90 | options.config.set_intra_op_parallelism_threads(model_config.intra_op_parallelism_threads()); 91 | options.config.set_inter_op_parallelism_threads(model_config.inter_op_parallelism_threads()); 92 | if (model_config.enable_xla()) { 93 | options.config.mutable_graph_options()->mutable_optimizer_options()->set_global_jit_level(tf::OptimizerOptions::ON_1); 94 | } 95 | m_session = std::unique_ptr(tf::NewSession(options)); 96 | if (m_session == nullptr) { 97 | LOG(ERROR) << "Could not create Tensorflow session."; 98 | return ERR_CREATE_SESSION; 99 | } 100 | LOG(INFO) << "Create session succ"; 101 | 102 | status = m_session->Create(meta_graph_def.graph_def()); 103 | if (!status.ok()) { 104 | LOG(ERROR) << "Error creating graph: " << status.ToString(); 105 | return ERR_CREATE_GRAPH; 106 | } 107 
| LOG(INFO) << "Create graph succ"; 108 | 109 | tf::Tensor checkpoint_path_tensor(tf::DT_STRING, tf::TensorShape()); 110 | checkpoint_path_tensor.scalar()() = checkpoint_path.string(); 111 | status = m_session->Run({{meta_graph_def.saver_def().filename_tensor_name(), checkpoint_path_tensor}}, 112 | {}, /* fetches_outputs is empty */ 113 | {meta_graph_def.saver_def().restore_op_name()}, 114 | nullptr); 115 | if (!status.ok()) { 116 | LOG(ERROR) << "Error loading checkpoint from " << checkpoint_path << ": " << status.ToString(); 117 | return ERR_RESTORE_VAR; 118 | } 119 | LOG(INFO) << "Load checkpoint succ"; 120 | 121 | std::vector> inputs(1, std::vector(INPUT_DIM, false)); 122 | std::vector> policy; 123 | std::vector value; 124 | Forward(inputs, policy, value); 125 | 126 | return 0; 127 | } 128 | 129 | int ZeroModel::Forward(const std::vector> &inputs, 130 | std::vector> &policy, std::vector &value) 131 | { 132 | int batch_size = inputs.size(); 133 | if (batch_size == 0) { 134 | LOG(ERROR) << "Error batch size can not be 0."; 135 | return ERR_INVALID_INPUT; 136 | } 137 | 138 | tf::Tensor feature_tensor(tf::DT_BOOL, tf::TensorShape({batch_size, INPUT_DIM})); 139 | auto matrix = feature_tensor.matrix(); 140 | for (int i = 0; i < batch_size; ++i) { 141 | if (inputs[i].size() != INPUT_DIM) { 142 | LOG(ERROR) << "Error input dim not match, need " << INPUT_DIM << ", got " << inputs[i].size(); 143 | return ERR_INVALID_INPUT; 144 | } 145 | for (int j = 0; j < INPUT_DIM; ++j) { 146 | matrix(i, j) = inputs[i][j]; 147 | } 148 | } 149 | 150 | 151 | std::vector> network_inputs = {{input_tensor_name, feature_tensor}}; 152 | std::vector fetch_outputs = {policy_tensor_name, value_tensor_name}; 153 | std::vector network_outputs; 154 | tf::Status status = m_session->Run(network_inputs, fetch_outputs, {}, &network_outputs); 155 | if (!status.ok()) { 156 | LOG(ERROR) << "Error session run: " << status.ToString(); 157 | return ERR_SESSION_RUN; 158 | } 159 | 160 | auto policy_tensor = 
network_outputs[0].matrix(); 161 | auto value_tensor = network_outputs[1].flat(); 162 | policy.resize(batch_size); 163 | value.resize(batch_size); 164 | for (int i = 0; i < batch_size; ++i) { 165 | policy[i].resize(OUTPUT_DIM); 166 | for (int j = 0; j < OUTPUT_DIM; ++j) { 167 | policy[i][j] = policy_tensor(i, j); 168 | } 169 | value[i] = -value_tensor(i); 170 | } 171 | 172 | return 0; 173 | } 174 | 175 | int ZeroModel::GetGlobalStep(int &global_step) 176 | { 177 | std::vector network_outputs; 178 | tf::Status status = m_session->Run({}, {"global_step"}, {}, &network_outputs); 179 | if (!status.ok()) { 180 | LOG(ERROR) << "Error session run: " << status.ToString(); 181 | return ERR_SESSION_RUN; 182 | } 183 | 184 | global_step = network_outputs[0].scalar()(); 185 | return 0; 186 | } 187 | 188 | void ZeroModel::SetMKLEnv(const ModelConfig &model_config) 189 | { 190 | #if defined(_WIN32) || defined(_WIN64) 191 | _putenv_s("KMP_BLOCKTIME", std::to_string(model_config.kmp_blocktime()).c_str()); 192 | _putenv_s("KMP_SETTINGS", std::to_string(model_config.kmp_settings()).c_str()); 193 | _putenv_s("KMP_AFFINITY", model_config.kmp_affinity().c_str()); 194 | if (model_config.intra_op_parallelism_threads() > 0) { 195 | _putenv_s("OMP_NUM_THREADS", std::to_string(model_config.intra_op_parallelism_threads()).c_str()); 196 | } 197 | #else 198 | setenv("KMP_BLOCKTIME", std::to_string(model_config.kmp_blocktime()).c_str(), 0); 199 | setenv("KMP_SETTINGS", std::to_string(model_config.kmp_settings()).c_str(), 0); 200 | setenv("KMP_AFFINITY", model_config.kmp_affinity().c_str(), 0); 201 | if (model_config.intra_op_parallelism_threads() > 0) { 202 | setenv("OMP_NUM_THREADS", std::to_string(model_config.intra_op_parallelism_threads()).c_str(), 0); 203 | } 204 | #endif 205 | } 206 | -------------------------------------------------------------------------------- /model/zero_model.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is 
pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | #pragma once 19 | 20 | #include 21 | 22 | #include "model/zero_model_base.h" 23 | #include "model/model_config.pb.h" 24 | 25 | namespace tensorflow { class Session; } 26 | 27 | class ZeroModel final : public ZeroModelBase 28 | { 29 | public: 30 | ZeroModel(int gpu); 31 | ~ZeroModel(); 32 | 33 | int Init(const ModelConfig &model_config) override; 34 | 35 | // input [batch, 19 * 19 * 17] 36 | // policy [batch, 19 * 19 + 1] 37 | int Forward(const std::vector> &inputs, 38 | std::vector> &policy, std::vector &value) override; 39 | 40 | int GetGlobalStep(int &global_step) override; 41 | 42 | static void SetMKLEnv(const ModelConfig &model_config); 43 | 44 | private: 45 | std::unique_ptr m_session; 46 | int m_gpu; 47 | }; 48 | -------------------------------------------------------------------------------- /model/zero_model_base.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Tencent is pleased to support the open source community by making PhoenixGo available. 3 | * 4 | * Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 
5 | * 6 | * Licensed under the BSD 3-Clause License (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * https://opensource.org/licenses/BSD-3-Clause 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | #pragma once 19 | 20 | #include 21 | #include 22 | 23 | #include "common/errordef.h" 24 | 25 | #include "model/model_config.pb.h" 26 | 27 | class ZeroModelBase 28 | { 29 | public: 30 | typedef std::function>, std::vector)> callback_t; 31 | 32 | virtual ~ZeroModelBase() {} 33 | 34 | virtual int Init(const ModelConfig &model_config) = 0; 35 | 36 | virtual int Forward(const std::vector> &inputs, 37 | std::vector> &policy, std::vector &value) = 0; 38 | 39 | virtual void Forward(const std::vector> &inputs, callback_t callback) 40 | { 41 | std::vector> policy; 42 | std::vector value; 43 | int ret = Forward(inputs, policy, value); 44 | callback(ret, std::move(policy), std::move(value)); 45 | } 46 | 47 | virtual int GetGlobalStep(int &global_step) = 0; 48 | 49 | virtual int RpcQueueSize() { return 0; } 50 | 51 | virtual void Wait() {} 52 | 53 | enum { 54 | INPUT_DIM = 19 * 19 * 17, 55 | OUTPUT_DIM = 19 * 19 + 1, 56 | }; 57 | }; 58 | -------------------------------------------------------------------------------- /rules.bzl: -------------------------------------------------------------------------------- 1 | load("@protobuf_archive//:protobuf.bzl", protobuf_cc_proto_library="cc_proto_library") 2 | 3 | 4 | def cc_proto_library(name, srcs=[], deps=[], use_grpc_plugin=False, **kwargs): 5 | protobuf_cc_proto_library( 6 | name=name, 7 | srcs=srcs, 8 | deps=deps, 9 | 
use_grpc_plugin=use_grpc_plugin, 10 | cc_libs = ["@protobuf_archive//:protobuf"], 11 | protoc="@protobuf_archive//:protoc", 12 | default_runtime="@protobuf_archive//:protobuf", 13 | **kwargs 14 | ) 15 | 16 | 17 | def tf_cc_binary(name, srcs=[], deps=[], linkopts=[], **kwargs): 18 | native.cc_binary( 19 | name=name, 20 | srcs=srcs + ["@org_tensorflow//tensorflow:libtensorflow_framework.so"], 21 | deps=deps, 22 | **kwargs 23 | ) 24 | -------------------------------------------------------------------------------- /scripts/build_tensorrt_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | set -e 5 | 6 | base_dir=`dirname $0` 7 | 8 | train_dir="$1" 9 | checkpoint="$2" 10 | 11 | max_batch_size=4 12 | if [ $# -ge 3 ]; then 13 | max_batch_size="$3" 14 | fi 15 | 16 | python -m tensorflow.python.tools.freeze_graph --input_meta_graph="$train_dir/meta_graph" --input_checkpoint="$train_dir/$checkpoint" --input_binary --output_graph="$train_dir/$checkpoint.frozen.pb" --output_node_names=policy,value,global_step 17 | 18 | python $base_dir/graph_transform.py "$train_dir/$checkpoint.frozen.pb" "$train_dir/$checkpoint.transformed.pb" 19 | 20 | python -m tensorflow.python.tools.optimize_for_inference --input="$train_dir/$checkpoint.transformed.pb" --output="$train_dir/$checkpoint.optimized.pb" --frozen_graph=True --input_name=inputs --output_name=policy,value,global_step 21 | 22 | convert-to-uff tensorflow -o "$train_dir/$checkpoint.uff" --input-file "$train_dir/$checkpoint.optimized.pb" -O "policy" -O "value" 23 | 24 | $base_dir/../bazel-bin/model/build_tensorrt_model --logtostderr --model_path="$train_dir/$checkpoint" --data_type=FP32 --max_batch_size=$max_batch_size 25 | 26 | python $base_dir/get_global_step.py "$train_dir/$checkpoint.frozen.pb" > "$train_dir/$checkpoint.FP32.PLAN.step" 27 | -------------------------------------------------------------------------------- /scripts/get_global_step.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import sys 4 | 5 | import tensorflow as tf 6 | 7 | input_path = sys.argv[1] 8 | 9 | with tf.Session() as session: 10 | with open(input_path, "rb") as f: 11 | graph_def = tf.GraphDef() 12 | graph_def.ParseFromString(f.read()) 13 | tf.import_graph_def(graph_def, name="") 14 | global_step = session.run("global_step:0") 15 | print global_step 16 | -------------------------------------------------------------------------------- /scripts/graph_transform.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import tensorflow as tf 4 | 5 | input_path = sys.argv[1] 6 | output_path = sys.argv[2] 7 | 8 | with open(input_path, "rb") as f: 9 | graph_def = tf.GraphDef() 10 | graph_def.ParseFromString(f.read()) 11 | 12 | tf.import_graph_def(graph_def, name="") 13 | graph = tf.get_default_graph() 14 | 15 | output_graph_def = tf.GraphDef() 16 | for node in graph_def.node: 17 | replace_node = tf.NodeDef() 18 | replace_node.CopyFrom(node) 19 | if node.name == "value": 20 | continue 21 | if node.name == "zero/value_head/Reshape_1": 22 | replace_node.name = "value" 23 | if node.name == "inputs": 24 | replace_node.attr["dtype"].CopyFrom(tf.AttrValue(type=tf.float32.as_datatype_enum)) 25 | for i, inp in enumerate(node.input): 26 | if inp == "Cast": 27 | replace_node.input[i] = "inputs" 28 | output_graph_def.node.extend([replace_node]) 29 | 30 | with open(output_path, "wb") as f: 31 | f.write(output_graph_def.SerializeToString()) 32 | -------------------------------------------------------------------------------- /scripts/start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd "`dirname $0`/.." 
4 | echo "current directory: '$PWD'" >&2 5 | 6 | config="$1" 7 | 8 | if [[ -z $config ]]; then 9 | if [[ `uname` == Darwin ]]; then 10 | echo "running on macOS" >&2 11 | config="etc/mcts_cpu.conf" 12 | else 13 | ldd bazel-bin/mcts/mcts_main | grep libcuda > /dev/null 14 | has_cuda=$? 15 | if [[ $has_cuda == "0" ]]; then 16 | echo "mcts_main was built with CUDA support" >&2 17 | else 18 | echo "mcts_main wasn't built with CUDA support" >&2 19 | fi 20 | 21 | ldd bazel-bin/mcts/mcts_main | grep libnvinfer > /dev/null 22 | has_tensorrt=$? 23 | if [[ $has_tensorrt == "0" ]]; then 24 | echo "mcts_main was built with TensorRT support" >&2 25 | else 26 | echo "mcts_main wasn't built with TensorRT support" >&2 27 | fi 28 | 29 | num_gpu=0 30 | if [[ $has_cuda == "0" ]]; then 31 | num_gpu=`nvidia-smi -L | wc -l` 32 | echo "found $num_gpu GPU(s)" >&2 33 | fi 34 | 35 | if [[ $has_cuda == "0" && $num_gpu -gt 0 ]]; then 36 | if [[ $has_tensorrt == "0" ]]; then 37 | config="etc/mcts_${num_gpu}gpu.conf" 38 | else 39 | config="etc/mcts_${num_gpu}gpu_notensorrt.conf" 40 | fi 41 | else 42 | config="etc/mcts_cpu.conf" 43 | fi 44 | fi 45 | fi 46 | 47 | echo "use config file '$config'" >&2 48 | 49 | export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$PWD/bazel-bin/external/org_tensorflow/tensorflow" 50 | 51 | echo "log to '$PWD/log'" >&2 52 | mkdir -p log 53 | 54 | echo "start mcts_main" >&2 55 | exec bazel-bin/mcts/mcts_main --config_path="$config" --gtp --log_dir=log --v=1 56 | -------------------------------------------------------------------------------- /scripts/start_cpu.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | setlocal 4 | 5 | set config=etc\mcts_cpu.conf 6 | 7 | echo use config file '%config%' >&2 8 | 9 | pushd %~dp0.. 
10 | 11 | echo log to %CD%\log >&2 12 | md log 2>NUL 13 | 14 | echo start mcts_main >&2 15 | x64\Release\mcts_main --config_path=%config% --gtp --log_dir=log --v=1 16 | 17 | popd -------------------------------------------------------------------------------- /scripts/start_gpu.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | setlocal 4 | 5 | set "PATH=%PATH%;C:\Program Files\NVIDIA Corporation\NVSMI" 6 | for /f %%i in ('nvidia-smi -L ^| find /v /c ""') do set num_gpu=%%i 7 | echo found %num_gpu% GPU(s) >&2 8 | set config=etc\mcts_%num_gpu%gpu_notensorrt.conf 9 | 10 | echo use config file '%config%' >&2 11 | 12 | pushd %~dp0.. 13 | 14 | echo log to %CD%\log >&2 15 | md log 2>NUL 16 | 17 | echo start mcts_main >&2 18 | x64\Release\mcts_main --config_path=%config% --gtp --log_dir=log --v=1 19 | 20 | popd -------------------------------------------------------------------------------- /third_party/glog/BUILD: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent/PhoenixGo/fbf67f9aec42531bff9569c44b85eb4c3f37b7be/third_party/glog/BUILD -------------------------------------------------------------------------------- /third_party/glog/glog.patch: -------------------------------------------------------------------------------- 1 | diff --git bazel/glog.bzl bazel/glog.bzl 2 | index b33f1aa..502a61c 100644 3 | --- bazel/glog.bzl 4 | +++ bazel/glog.bzl 5 | @@ -69,6 +69,7 @@ def glog_library(namespace='google', with_gflags=1): 6 | '-DHAVE_SIGACTION', 7 | # For logging.cc. 8 | '-DHAVE_PREAD', 9 | + '-DHAVE_UNISTD_H', 10 | 11 | # Include generated header files. 12 | '-I%s/glog_internal' % gendir, 13 | -------------------------------------------------------------------------------- /third_party/tensorflow/.bazelrc: -------------------------------------------------------------------------------- 1 | # Android configs. 
Bazel needs to have --cpu and --fat_apk_cpu both set to the 2 | # target CPU to build transient dependencies correctly. See 3 | # https://docs.bazel.build/versions/master/user-manual.html#flag--fat_apk_cpu 4 | build:android --crosstool_top=//external:android/crosstool 5 | build:android --host_crosstool_top=@bazel_tools//tools/cpp:toolchain 6 | build:android_arm --config=android 7 | build:android_arm --cpu=armeabi-v7a 8 | build:android_arm --fat_apk_cpu=armeabi-v7a 9 | build:android_arm64 --config=android 10 | build:android_arm64 --cpu=arm64-v8a 11 | build:android_arm64 --fat_apk_cpu=arm64-v8a 12 | 13 | # Config to use a mostly-static build and disable modular op registration 14 | # support (this will revert to loading TensorFlow with RTLD_GLOBAL in Python). 15 | # By default, TensorFlow will build with a dependence on 16 | # //tensorflow:libtensorflow_framework.so. 17 | build:monolithic --define framework_shared_object=false 18 | 19 | # For projects which use TensorFlow as part of a Bazel build process, putting 20 | # nothing in a bazelrc will default to a monolithic build. The following line 21 | # opts in to modular op registration support by default. 22 | build --define framework_shared_object=true 23 | 24 | # Please note that MKL on MacOS or windows is still not supported. 25 | # If you would like to use a local MKL instead of downloading, please set the 26 | # environment variable "TF_MKL_ROOT" every time before build. 27 | build:mkl --define=build_with_mkl=true --define=enable_mkl=true 28 | build:mkl -c opt 29 | 30 | # This config option is used to enable MKL-DNN open source library only, 31 | # without depending on MKL binary version. 
32 | build:mkl_open_source_only --define=build_with_mkl_dnn_only=true 33 | build:mkl_open_source_only --define=build_with_mkl=true --define=enable_mkl=true 34 | 35 | build:download_clang --crosstool_top=@local_config_download_clang//:toolchain 36 | build:download_clang --define=using_clang=true 37 | # Instruct clang to use LLD for linking. 38 | # This only works with GPU builds currently, since Bazel sets -B/usr/bin in 39 | # auto-generated CPU crosstool, forcing /usr/bin/ld.lld to be preferred over 40 | # the downloaded one. 41 | build:download_clang_use_lld --linkopt='-fuse-ld=lld' 42 | 43 | build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain 44 | build:cuda --define=using_cuda=true --define=using_cuda_nvcc=true 45 | 46 | build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain 47 | build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true 48 | 49 | build:cuda_clang --crosstool_top=@local_config_cuda//crosstool:toolchain 50 | build:cuda_clang --define=using_cuda=true --define=using_cuda_clang=true --define=using_clang=true 51 | 52 | build:sycl --crosstool_top=@local_config_sycl//crosstool:toolchain 53 | build:sycl --define=using_sycl=true --define=using_trisycl=false 54 | 55 | build:sycl_nodouble --crosstool_top=@local_config_sycl//crosstool:toolchain 56 | build:sycl_nodouble --define=using_sycl=true --cxxopt -DTENSORFLOW_SYCL_NO_DOUBLE 57 | 58 | build:sycl_asan --crosstool_top=@local_config_sycl//crosstool:toolchain 59 | build:sycl_asan --define=using_sycl=true --define=using_trisycl=false --copt -fno-omit-frame-pointer --copt -fsanitize-coverage=3 --copt -DGPR_NO_DIRECT_SYSCALLS --linkopt -fPIC --linkopt -fsanitize=address 60 | 61 | build:sycl_trisycl --crosstool_top=@local_config_sycl//crosstool:toolchain 62 | build:sycl_trisycl --define=using_sycl=true --define=using_trisycl=true 63 | 64 | # Options extracted from configure script 65 | build:gdr --define=with_gdr_support=true 66 | build:ngraph 
--define=with_ngraph_support=true 67 | build:verbs --define=with_verbs_support=true 68 | 69 | # Options to disable default on features 70 | build:noaws --define=no_aws_support=true 71 | build:nogcp --define=no_gcp_support=true 72 | build:nohdfs --define=no_hdfs_support=true 73 | build:nokafka --define=no_kafka_support=true 74 | build:noignite --define=no_ignite_support=true 75 | build:nonccl --define=no_nccl_support=true 76 | 77 | build --define=use_fast_cpp_protos=true 78 | build --define=allow_oversize_protos=true 79 | 80 | build --spawn_strategy=standalone 81 | build --genrule_strategy=standalone 82 | build -c opt 83 | 84 | # Other build flags. 85 | build --define=grpc_no_ares=true 86 | 87 | # Modular TF build options 88 | build:dynamic_kernels --define=dynamic_loaded_kernels=true 89 | build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS 90 | 91 | # Default paths for TF_SYSTEM_LIBS 92 | build --define=PREFIX=/usr 93 | build --define=LIBDIR=$(PREFIX)/lib 94 | build --define=INCLUDEDIR=$(PREFIX)/include 95 | 96 | # Default options should come above this line 97 | 98 | # Options from ./configure 99 | try-import %workspace%/.tf_configure.bazelrc 100 | 101 | # Put user-specific options in .bazelrc.user 102 | try-import %workspace%/.bazelrc.user 103 | -------------------------------------------------------------------------------- /third_party/tensorflow/BUILD: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent/PhoenixGo/fbf67f9aec42531bff9569c44b85eb4c3f37b7be/third_party/tensorflow/BUILD -------------------------------------------------------------------------------- /third_party/tensorflow/tensorflow.patch: -------------------------------------------------------------------------------- 1 | diff --git tensorflow/workspace.bzl tensorflow/workspace.bzl 2 | index dff151246a..f365f557b8 100755 3 | --- tensorflow/workspace.bzl 4 | +++ tensorflow/workspace.bzl 5 | @@ -414,11 +414,11 @@ def 
tf_workspace(path_prefix = "", tf_repo_name = ""): 6 | 7 | tf_http_archive( 8 | name = "com_github_gflags_gflags", 9 | - sha256 = "ae27cdbcd6a2f935baa78e4f21f675649271634c092b1be01469440495609d0e", 10 | - strip_prefix = "gflags-2.2.1", 11 | + sha256 = "6e16c8bc91b1310a44f3965e616383dbda48f83e8c1eaa2370a215057b00cabe", 12 | + strip_prefix = "gflags-77592648e3f3be87d6c7123eb81cbad75f9aef5a", 13 | urls = [ 14 | - "https://mirror.bazel.build/github.com/gflags/gflags/archive/v2.2.1.tar.gz", 15 | - "https://github.com/gflags/gflags/archive/v2.2.1.tar.gz", 16 | + "https://mirror.bazel.build/github.com/gflags/gflags/archive/77592648e3f3be87d6c7123eb81cbad75f9aef5a.tar.gz", 17 | + "https://github.com/gflags/gflags/archive/77592648e3f3be87d6c7123eb81cbad75f9aef5a.tar.gz", 18 | ], 19 | ) 20 | 21 | diff --git third_party/tensorrt/BUILD.tpl third_party/tensorrt/BUILD.tpl 22 | index 57682e8735..b1c64f8477 100644 23 | --- third_party/tensorrt/BUILD.tpl 24 | +++ third_party/tensorrt/BUILD.tpl 25 | @@ -34,6 +34,38 @@ cc_library( 26 | visibility = ["//visibility:public"], 27 | ) 28 | 29 | +cc_library( 30 | + name = "nv_infer_plugin", 31 | + srcs = [%{nv_infer_plugin}], 32 | + data = [%{nv_infer_plugin}], 33 | + includes = [ 34 | + "include", 35 | + ], 36 | + copts= cuda_default_copts(), 37 | + deps = [ 38 | + "@local_config_cuda//cuda:cuda", 39 | + ":tensorrt_headers", 40 | + ], 41 | + linkstatic = 1, 42 | + visibility = ["//visibility:public"], 43 | +) 44 | + 45 | +cc_library( 46 | + name = "nv_parsers", 47 | + srcs = [%{nv_parsers}], 48 | + data = [%{nv_parsers}], 49 | + includes = [ 50 | + "include", 51 | + ], 52 | + copts= cuda_default_copts(), 53 | + deps = [ 54 | + "@local_config_cuda//cuda:cuda", 55 | + ":tensorrt_headers", 56 | + ], 57 | + linkstatic = 1, 58 | + visibility = ["//visibility:public"], 59 | +) 60 | + 61 | 62 | %{tensorrt_genrules} 63 | 64 | diff --git third_party/tensorrt/tensorrt_configure.bzl third_party/tensorrt/tensorrt_configure.bzl 65 | index 
9b946505a6..61b8f00bff 100644 66 | --- third_party/tensorrt/tensorrt_configure.bzl 67 | +++ third_party/tensorrt/tensorrt_configure.bzl 68 | @@ -19,8 +19,8 @@ load( 69 | _TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH" 70 | _TF_TENSORRT_VERSION = "TF_TENSORRT_VERSION" 71 | 72 | -_TF_TENSORRT_LIBS = ["nvinfer"] 73 | -_TF_TENSORRT_HEADERS = ["NvInfer.h", "NvUtils.h"] 74 | +_TF_TENSORRT_LIBS = ["nvinfer", "nvinfer_plugin", "nvparsers"] 75 | +_TF_TENSORRT_HEADERS = ["NvInfer.h", "NvUtils.h", "NvInferPlugin.h", "NvUffParser.h", "NvCaffeParser.h"] 76 | 77 | _DEFINE_TENSORRT_SONAME_MAJOR = "#define NV_TENSORRT_SONAME_MAJOR" 78 | _DEFINE_TENSORRT_SONAME_MINOR = "#define NV_TENSORRT_SONAME_MINOR" 79 | -------------------------------------------------------------------------------- /tools/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent/PhoenixGo/fbf67f9aec42531bff9569c44b85eb4c3f37b7be/tools/.keep --------------------------------------------------------------------------------