├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── PATENTS ├── README.md ├── board ├── .gitignore ├── README.md ├── board.c ├── board.h ├── board.lua ├── comm.lua ├── dcnn_utils.lua ├── default_policy.c ├── default_policy.h ├── default_policy.lua ├── default_policy_common.c ├── default_policy_common.h ├── default_policy_dcnn.lua ├── dumpWinrate.lua ├── ownermap.c ├── ownermap.h ├── ownermap.lua ├── pattern.c ├── pattern.h ├── pattern_v2.c ├── pattern_v2.h ├── pattern_v2.lua ├── sample_one_pattern_v2.lua └── sample_pattern_v2.c ├── cnnPlayerV2 ├── cnnPlayerMCTSV2.lua ├── cnnPlayerV2.lua ├── cnnPlayerV2Framework.lua ├── cnnPlayerV3.lua ├── cnnPlayerV3SelfPlay.lua ├── win_rate_game1_LeeSedal_AlphaGo.txt ├── win_rate_game2_LeeSedal_AlphaGo.txt └── win_rate_game3_LeeSedal_AlphaGo.txt ├── common ├── comm.c ├── comm.h ├── comm_constant.h ├── comm_pipe.c ├── comm_pipe.h ├── common.c ├── common.h ├── common.lua ├── package.h └── util_package.lua ├── compile.sh ├── figure.png ├── libs └── README.md ├── local_evaluator ├── cnn_evaluator.lua ├── cnn_evaluator.sh ├── cnn_evaluator_run1.lua ├── cnn_exchanger.h ├── cnn_local_exchanger.c ├── cnn_local_exchanger.h └── kill_evaluator.sh ├── mctsv2 ├── FollyEventCount.h ├── event_count.cpp ├── event_count.h ├── playout_callbacks.c ├── playout_callbacks.h ├── playout_common.h ├── playout_multithread.c ├── playout_multithread.h ├── playout_multithread.lua ├── playout_params.h ├── test_playout_multithread.c ├── test_tree_multithread.c ├── test_tsumego.lua ├── tree.c ├── tree.h ├── tree_search.c ├── tree_search.h └── tree_search_internal.h ├── pachi_tactics ├── PACHI_LICENSE ├── README.md ├── board_interface.c ├── board_interface.h ├── fixp.h ├── moggy.c ├── moggy.h ├── moggy.lua ├── moggy_test.c ├── mq.h ├── tactics │ ├── 1lib.c │ ├── 1lib.h │ ├── 2lib.c │ ├── 2lib.h │ ├── TARGETS │ ├── ladder.c │ ├── ladder.h │ ├── nakade.c │ ├── nakade.h │ ├── nlib.c │ ├── nlib.h │ ├── selfatari.c │ └── selfatari.h └── util.h ├── sgfs └── 
alphago_leesedol_1.sgf ├── train.lua ├── train.sh ├── train ├── README.md └── rl_framework │ ├── examples │ └── go │ │ ├── ParallelCriterion2.lua │ │ ├── fm_go.lua │ │ └── models │ │ └── model-12-parallel-384-n-output-bn.lua │ └── infra │ ├── agent.lua │ ├── bundle.lua │ ├── dataset.lua │ ├── engine.lua │ ├── env.lua │ ├── forwardmodel.lua │ └── framework.lua ├── tsumego ├── rank_move.c ├── rank_move.h ├── solver.c ├── solver.h ├── solver.lua ├── test_solver.c └── test_solver.lua └── utils ├── goutils.lua ├── nnutils.lua ├── sgf.lua ├── test.sgf └── utils.lua /.gitignore: -------------------------------------------------------------------------------- 1 | *.o 2 | *.so 3 | *.swp 4 | *.bin 5 | models 6 | test_playout_multithread 7 | cnnPlayerV2/game-*.sgf 8 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to DarkForest Go engine 2 | We want to make contributing to this project as easy and transparent as possible. 3 | We also welcome forks of this project. 4 | 5 | ## Our Development Process 6 | DarkForest Go engine has been developed mainly by Yuandong Tian and Yan Zhu from Facebook AI Research since May 2015. 7 | Now that it is open to the public, its future development will take place in this open-source branch. 8 | 9 | ## Bug reports 10 | Please follow these steps when you find a bug: 11 | 12 | 1. Use the GitHub issue search to check whether the issue has been reported before. 13 | 2. Check whether the issue has already been fixed by syncing to the most recent branch of the repository. 14 | 3. Write a simple and easy-to-understand test case that reveals the issue. We use GitHub issues to track public bugs. Please ensure your description is clear and has sufficient instructions to reproduce the issue. 15 | 16 | ## Pull Requests 17 | We actively welcome your pull requests. 18 | 19 | 1.
Fork the repo and create your branch from `master`. 20 | 2. If you've added code that should be tested, add tests. 21 | 3. If you've changed APIs, update the documentation. 22 | 4. Ensure the test suite passes. 23 | 5. Make sure your code lints. 24 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 25 | 26 | ## Contributor License Agreement ("CLA") 27 | In order to accept your pull request, we need you to submit a CLA. You only need 28 | to do this once to work on any of Facebook's open source projects. 29 | 30 | Complete your CLA here: 31 | 32 | ## Coding Style 33 | * 2 spaces for indentation rather than tabs 34 | * 80 character line length 35 | 36 | ## License 37 | By contributing to DarkForest Go engine, you agree that your contributions will be licensed under its BSD license. 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For DarkForest Go software 4 | 5 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 
20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /PATENTS: -------------------------------------------------------------------------------- 1 | Additional Grant of Patent Rights Version 2 2 | 3 | "Software" means the DarkForest Go software distributed by Facebook, Inc. 4 | 5 | Facebook, Inc. ("Facebook") hereby grants to each recipient of the Software 6 | ("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable 7 | (subject to the termination provision below) license under any Necessary 8 | Claims, to make, have made, use, sell, offer to sell, import, and otherwise 9 | transfer the Software. For avoidance of doubt, no license is granted under 10 | Facebook’s rights in any patent claims that are infringed by (i) modifications 11 | to the Software made by you or any third party or (ii) the Software in 12 | combination with any software or other technology. 
13 | 14 | The license granted hereunder will terminate, automatically and without notice, 15 | if you (or any of your subsidiaries, corporate affiliates or agents) initiate 16 | directly or indirectly, or take a direct financial interest in, any Patent 17 | Assertion: (i) against Facebook or any of its subsidiaries or corporate 18 | affiliates, (ii) against any party if such Patent Assertion arises in whole or 19 | in part from any software, technology, product or service of Facebook or any of 20 | its subsidiaries or corporate affiliates, or (iii) against any party relating 21 | to the Software. Notwithstanding the foregoing, if Facebook or any of its 22 | subsidiaries or corporate affiliates files a lawsuit alleging patent 23 | infringement against you in the first instance, and you respond by filing a 24 | patent infringement counterclaim in that lawsuit against that party that is 25 | unrelated to the Software, the license granted hereunder will not terminate 26 | under section (i) of this paragraph due to such counterclaim. 27 | 28 | A "Necessary Claim" is a claim of a patent owned by Facebook that is 29 | necessarily infringed by the Software standing alone. 30 | 31 | A "Patent Assertion" is any lawsuit or other action alleging direct, indirect, 32 | or contributory infringement or inducement to infringe any patent, including a 33 | cross-claim or counterclaim. 34 | -------------------------------------------------------------------------------- /board/.gitignore: -------------------------------------------------------------------------------- 1 | *.txt 2 | ctags 3 | tags 4 | -------------------------------------------------------------------------------- /board/README.md: -------------------------------------------------------------------------------- 1 | Board library 2 | ================= 3 | 4 | Default policy 5 | -------------- 6 | 7 | In DarkForest, there are 3 default policies that could be chosen, specified by `--playout_policy`: 8 | 9 | 1. 
`v2` (default) 10 | DarkForest uses its own trained default policy. The training set is Tygem (thanks to [Ling Wang](mailto:1160071998@qq.com) for providing it). `./models/playout-model.bin` is the trained model. See `pattern_v2.c` for the source code. 11 | 12 | 2. `simple` 13 | DarkForest uses a simple default policy (saving groups in atari, putting opponent groups into atari, 3x3 pattern matching from Pachi, etc.). 14 | 15 | 3. `pachi` 16 | DarkForest uses Pachi's default policy. Most of the code is in `./pachi_tactics`. 17 | 18 | Utilities 19 | --------- 20 | 21 | Dump default policy: 22 | 23 | ```bash 24 | th sample_one_pattern_v2.lua -p ../models/playout-policy.bin --temperature 0.5 --sgf_file [your_sgf_file] --move_from 230 --num_games 10 --num_moves 200 --save_prefix moves 25 | ``` 26 | 27 | This dumps the games played by the playout policy and visualizes them. If `save_prefix` is set, the move sequence of each trial is also saved. 28 | -------------------------------------------------------------------------------- /board/comm.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016-present, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory.
--

local ffi = require 'ffi'
local pl = require 'pl.import_into'()
local utils = require('utils.utils')
local common = require("common.common")

-- comm.h and the shared library are located relative to this script.
local here = common.script_path()
-- Declare the Comm* prototypes from comm.h to the FFI; the returned values are not needed.
local _symbols, _s = utils.ffi_include(paths.concat(here, "comm.h"))

local comm = {}

local lib = ffi.load(paths.concat(here, "../libs/libcomm.so"))

-- Every message is transferred with its full cdata size.
local function nbytes(m)
  return ffi.sizeof(m)
end

--- Open communication channel `id`.
-- Only the boolean `true` requests creation (flag 1); any other value passes 0.
-- @return the result of CommInit.
function comm.init(id, is_create_new)
  local flag = 0
  if is_create_new == true then
    flag = 1
  end
  return lib.CommInit(id, flag)
end

--- Blocking send of cdata message `m` on `channel_id`.
function comm.send(channel_id, m)
  lib.CommSend(channel_id, m, nbytes(m))
end

--- Non-blocking send.
-- @return true iff CommSendNoBlock returned 0.
function comm.send_no_block(channel_id, m)
  local status = lib.CommSendNoBlock(channel_id, m, nbytes(m))
  return status == 0
end

--- Blocking receive into cdata message `m` from `channel_id`.
function comm.receive(channel_id, m)
  lib.CommReceive(channel_id, m, nbytes(m))
end

--- Non-blocking receive.
-- @return true iff CommReceiveNoBlock returned 0.
function comm.receive_no_block(channel_id, m)
  local status = lib.CommReceiveNoBlock(channel_id, m, nbytes(m))
  return status == 0
end

--- Tear down channel `channel_id`.
function comm.destroy(channel_id)
  lib.CommDestroy(channel_id)
end

return comm

--------------------------------------------------------------------------------
/board/dcnn_utils.lua:
--------------------------------------------------------------------------------
--
-- Copyright (c) 2016-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--

local common = require("common.common")
local board = require("board.board")
local goutils = require 'utils.goutils'
local utils = require 'utils.utils'
local pl = require 'pl.import_into'()

local dcnn_utils = { }

-- Options read by dcnn_utils.init (the caller's table is deep-copied, not modified).
-- mandatory:
--   shuffle_top_n       Sample among the top-n candidates; if 1, only the best legal move is played (default 1).
--   codename            Key into common.codenames; "" (or nil) means `input` / `feature_type` are used directly.
--   input, feature_type Model path and feature type for a customized model; overridden when codename is set.
--   presample_codename  Codename of the model used before sample_step.
--   temperature         If > 1, the presample model is goutils.getDistillModel(model, temperature) (default 1).
--   sample_step         Compared with b._ply in dcnn_utils.sample (default -1): after it the normal model is
--                       used, before it the presample model, and exactly at it a uniform random move.
-- optional:
--   usecpu              Evaluate on CPU instead of loading cutorch.
--   use_local_model     Load the model by file basename (relative to the current directory).
--   rank                Rank fed to the network (default '9d').
--   valueModel          Path of a value network to load ("" or nil to skip).
--   verbose             Print loading progress.
function dcnn_utils.init(options)
  -- Work on a copy so the caller's table is left untouched.
  local opt = pl.tablex.deepcopy(options)
  opt.sample_step = opt.sample_step or -1
  opt.temperature = opt.temperature or 1
  opt.shuffle_top_n = opt.shuffle_top_n or 1
  opt.rank = opt.rank or '9d'

  if not opt.usecpu then
    utils.require_cutorch()
  else
    -- Global flag consulted elsewhere (presumably by the nn utilities) to force CPU-only evaluation.
    g_nnutils_only_cpu = true
  end

  -- opt.feature_type and opt.userank are necessary for the game to be played.
  opt.userank = true
  assert(opt.shuffle_top_n >= 1)

  -- Resolve model path and feature type. An explicit `if` (not `and/or`) so that a nil
  -- codename means "custom model" and an unknown codename gives a clear error.
  if opt.codename ~= nil and opt.codename ~= "" then
    local code = common.codenames[opt.codename]
    assert(code, "dcnn_utils.init: unknown codename " .. tostring(opt.codename))
    opt.input = code.model_name
    opt.feature_type = code.feature_type
  end
  assert(opt.input, "dcnn_utils.init: no model specified (set codename or input)")
  opt.attention = { 1, 1, common.board_size, common.board_size }

  local model_name = opt.use_local_model and pl.path.basename(opt.input) or opt.input
  if opt.verbose then print("Load model " .. model_name) end
  local model = torch.load(model_name)
  if opt.verbose then print("Load model complete") end

  -- Model and options used for plies before sample_step.
  local preSampleModel
  local preSampleOpt = pl.tablex.deepcopy(opt)

  if opt.temperature > 1 then
    if opt.verbose then print("temperature: " , opt.temperature) end
    preSampleModel = goutils.getDistillModel(model, opt.temperature)
  elseif opt.presample_codename ~= nil and opt.presample_codename ~= false then
    local code = common.codenames[opt.presample_codename]
    assert(code, "dcnn_utils.init: unknown presample_codename " .. tostring(opt.presample_codename))
    if opt.verbose then print("Load preSampleModel " .. code.model_name) end
    preSampleModel = torch.load(code.model_name)
    preSampleOpt.feature_type = code.feature_type
  else
    preSampleModel = model
  end

  opt.preSampleModel = preSampleModel
  opt.preSampleOpt = preSampleOpt
  opt.model = model
  if opt.valueModel and opt.valueModel ~= "" then opt.valueModel = torch.load(opt.valueModel) end
  if opt.verbose then print("dcnn ready!") end

  return opt
end

function dcnn_utils.dbg_set()
  utils.dbg_set()
end

function dcnn_utils.play(opt, b, player)
  -- It will return sortProb, sortInd, value, output
  return goutils.play_with_cnn(b, player, opt, opt.rank, opt.model)
end

function dcnn_utils.batch_play(opt, bs)
  return goutils.batch_play_with_cnn(bs, opt, opt.rank, opt.model)
end

--- Pick a move for `player` on board `b`.
-- The model is chosen by comparing b._ply with opt.sample_step (see dcnn_utils.init).
-- @return x, y of the chosen move, as returned by goutils.tryplay_candidates(_sample).
function dcnn_utils.sample(opt, b, player)
  local sortProb, sortInd, value
  if b._ply > opt.sample_step then -- after sample_step, sample from the normal model
    if opt.debug then print("normal model") end
    sortProb, sortInd, value = goutils.play_with_cnn(b, player, opt, opt.rank, opt.model)
  elseif b._ply < opt.sample_step then -- before sample_step, encourage more diverse moves
    if opt.debug then print("presample model") end
    sortProb, sortInd = goutils.play_with_cnn(b, player, opt.preSampleOpt, opt.preSampleOpt.rank, opt.preSampleModel)
  else
    if opt.debug then print("uniform sample") end
    sortProb, sortInd = goutils.randomPlay(b, player, opt, opt.rank, opt.model)
  end
  if opt.debug then
    print("ply: ", b._ply)
    local j = 1
    for k = 1, 20 do
      local x, y = goutils.moveIdx2xy(sortInd[k][j])
      local check_res, comments = goutils.check_move(b, x, y, player)
      if check_res then
        -- The move is all right.
        utils.dprint(" Move (%d, %d), ind = %d, move = %s, conf = (%f)",
                     x, y, sortInd[k][j], goutils.compose_move_gtp(x, y, tonumber(player)), sortProb[k][j])
      else
        utils.dprint(" Skipped Move (%d, %d), ind = %d, move = %s, conf = (%f), Reason = %s",
                     x, y, sortInd[k][j], goutils.compose_move_gtp(x, y, tonumber(player)), sortProb[k][j], comments)
      end
    end
    -- Debug-only (plain print goes to stdout); only shown when the network returned a numeric value.
    if opt.valueModel and type(value) == "number" then
      print("current value: " .. string.format("%.3f", value))
    end
  end

  -- Apply the moves until we have seen a valid one.
  local xf, yf
  if opt.shuffle_top_n == 1 then
    xf, yf = goutils.tryplay_candidates(b, player, sortProb, sortInd)
  else
    xf, yf = goutils.tryplay_candidates_sample(b, player, sortProb, sortInd, opt.shuffle_top_n)
  end
  return xf, yf
end

function dcnn_utils.get_value(opt, b, player)
  return goutils.get_value(b, player, opt)
end
return dcnn_utils

--------------------------------------------------------------------------------
/board/default_policy.h:
--------------------------------------------------------------------------------
//
// Copyright (c) 2016-present, Facebook, Inc.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree. An additional grant
// of patent rights can be found in the PATENTS file in the same directory.
8 | // 9 | 10 | #ifndef _DEFEAULT_POLICY_H_ 11 | #define _DEFEAULT_POLICY_H_ 12 | 13 | #include "default_policy_common.h" 14 | 15 | #ifdef __cplusplus 16 | extern "C" { 17 | #endif 18 | 19 | // The parameters for default policy. 20 | typedef struct { 21 | BOOL switches[NUM_MOVE_TYPE]; 22 | // Try to save our group in atari if its size is >= thres_save_atari. 23 | int thres_save_atari; 24 | // Allow self-atari move if the group size is smaller than thres_allow_atari_stone (before the new move is put). 25 | int thres_allow_atari_stone; 26 | // Reduce opponent liberties if its liberties <= thres_opponent_libs and #stones >= thres_opponent_stones. 27 | int thres_opponent_libs; 28 | int thres_opponent_stones; 29 | } DefPolicyParams; 30 | 31 | void *InitDefPolicy(); 32 | void DestroyDefPolicy(void *); 33 | void DefPolicyParamsPrint(void *hh); 34 | 35 | // Set the inital value of default policy params. 36 | void InitDefPolicyParams(DefPolicyParams *params); 37 | 38 | // Set policy parameters. If not called, then the default policy will use the default parameters. 39 | BOOL SetDefPolicyParams(void *h, const DefPolicyParams *params); 40 | 41 | // Utilities for playing default policy. Referenced from Pachi's code. 42 | void ComputeDefPolicy(void *h, DefPolicyMoves *m, const Region *r); 43 | 44 | // Sample the default policy, if ids != NULL, then only sample valid moves and save the ids information for the next play. 
45 | BOOL SampleDefPolicy(void *h, DefPolicyMoves *ms, void *context, RandFunc rand_func, BOOL verbose, GroupId4 *ids, DefPolicyMove *m); 46 | BOOL SimpleSampleDefPolicy(void *h, const DefPolicyMoves *ms, void *context, RandFunc rand_func, GroupId4 *ids, DefPolicyMove *m); 47 | 48 | // Run the default policy 49 | DefPolicyMove RunOldDefPolicy(void *def_policy, void *context, RandFunc rand_func, Board* board, const Region *r, int max_depth, BOOL verbose); 50 | DefPolicyMove RunDefPolicy(void *def_policy, void *context, RandFunc rand_func, Board* board, const Region *r, int max_depth, BOOL verbose); 51 | 52 | #ifdef __cplusplus 53 | } 54 | #endif 55 | 56 | #endif 57 | -------------------------------------------------------------------------------- /board/default_policy.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016-present, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 
--

local ffi = require 'ffi'
local pl = require 'pl.import_into'()
local utils = require('utils.utils')
local common = require("common.common")
local goutils = require('utils.goutils')
local board = require('board.board')

-- Header and shared library are resolved relative to this script's directory.
local script_path = common.script_path()
-- Declare the prototypes in default_policy.h to the FFI.
local symbols, s = utils.ffi_include(paths.concat(script_path, "default_policy.h"))
local C = ffi.load(paths.concat(script_path, "../libs/libdefault_policy.so"))
local dp = {}

-- Names of the default-policy move types, indexed like the MoveType enum in
-- default_policy_common.h (0-based), i.e. like DefPolicyParams.switches.
dp.default_typename = {
  [0] = "normal",
  [1] = "ko_fight",
  [2] = "opponent_in_danger",
  [3] = "our_atari",
  [4] = "nakade",
  [5] = "pattern",
  [6] = 'no_move'
}

-- Reverse lookup: type name -> switch index.
-- pairs() is required: index 0 lies outside the sequence part that `#` / ipairs see,
-- so a `for i = 1, #t` loop would silently drop "normal".
dp.default_typename_hash = { }
for idx, name in pairs(dp.default_typename) do
  dp.default_typename_hash[name] = idx
end

--- Create a default-policy handle.
-- @param rule scoring rule used by dp.run / dp.run_old (default board.chinese_rule).
--   NOTE: stored on the module table, so it is shared by all handles.
-- @return opaque handle from InitDefPolicy; release it with dp.free.
function dp.new(rule)
  dp.rule = rule or board.chinese_rule
  return C.InitDefPolicy()
end

--- Release a handle created by dp.new / dp.new_with_params.
function dp.free(def_policy)
  C.DestroyDefPolicy(def_policy)
end

--- Create a handle and enable/disable move types by name.
-- @param params_table maps names from dp.default_typename to truthy/falsy values;
--   unknown keys are ignored; nil is treated as an empty table.
-- @param rule see dp.new.
-- @return the handle, or nil if SetDefPolicyParams rejected the parameters.
function dp.new_with_params(params_table, rule)
  local policy = dp.new(rule)
  local params = ffi.new("DefPolicyParams")
  C.InitDefPolicyParams(params)

  for k, v in pairs(params_table or {}) do
    local idx = dp.default_typename_hash[k]
    if idx then
      params.switches[idx] = v and common.TRUE or common.FALSE
    end
  end

  if dp.set_params(policy, params) then
    C.DefPolicyParamsPrint(policy)
    return policy
  end
  -- Parameters rejected: free the handle instead of leaking it.
  dp.free(policy)
  return nil
end

--- Build a DefPolicyParams initialized by InitDefPolicyParams.
-- @param use_for_server if true, apply the server-side overrides below.
function dp.new_params(use_for_server)
  local res = ffi.new("DefPolicyParams")
  C.InitDefPolicyParams(res)
  -- A hack in the server side
  if use_for_server then
    res.switches[0] = common.FALSE  -- normal
    res.switches[1] = common.FALSE  -- ko_fight
    -- Only enable extending on our atari, kill opponent if they are big and in atari, and nakade points (important).
    res.switches[2] = common.TRUE   -- opponent_in_danger
    res.switches[3] = common.TRUE   -- our_atari
    res.switches[4] = common.TRUE   -- nakade
    res.switches[5] = common.FALSE  -- pattern
    res.switches[6] = common.FALSE  -- no_move
    -- Save our groups if they are big and in atari.
    res.thres_save_atari = 5
    -- Kill opponent groups if they are big and in atari.
    res.thres_opponent_libs = 1
    res.thres_opponent_stones = 5
  end
  return res
end

--- Apply `def_params` to `def_policy`; true on success.
function dp.set_params(def_policy, def_params)
  return C.SetDefPolicyParams(def_policy, def_params) == common.TRUE
end

--- Name of a MoveType (the const char* returned by GetDefMoveType; NULL if out of range).
function dp.typename(move_type)
  return C.GetDefMoveType(move_type)
end

-- Shared scratch buffer for dp.get_candidate_moves (reused across calls; not reentrant).
local moves = ffi.new("DefPolicyMoves")

--- List the candidate moves the default policy proposes on board `b`.
-- @return array of {x, y} pairs.
function dp.get_candidate_moves(def_policy, b)
  moves.board = b
  C.ComputeDefPolicy(def_policy, moves, nil)
  -- Then dump all moves from def_policy (the loop is empty when num_moves <= 0).
  local moves_table = { }
  for i = 0, moves.num_moves - 1 do
    local x, y = common.coord2xy(moves.moves[i].m)
    moves_table[#moves_table + 1] = { x, y }
  end
  return moves_table
end

--- Run the default policy on `b`; returns the fast score under dp.rule and the DefPolicyMove from RunDefPolicy.
function dp.run(def_policy, b, max_depth, verbose)
  local def_move = C.RunDefPolicy(def_policy, nil, nil, b, nil, max_depth, verbose and common.TRUE or common.FALSE)
  return board.get_fast_score(b, dp.rule), def_move
end

--- Same as dp.run but calls RunOldDefPolicy.
function dp.run_old(def_policy, b, max_depth, verbose)
  local def_move = C.RunOldDefPolicy(def_policy, nil, nil, b, nil, max_depth, verbose and common.TRUE or common.FALSE)
  return board.get_fast_score(b, dp.rule), def_move
end

return dp

--------------------------------------------------------------------------------
/board/default_policy_common.c:
--------------------------------------------------------------------------------
//
// Copyright (c) 2016-present, Facebook, Inc.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree. An additional grant
// of patent rights can be found in the PATENTS file in the same directory.
8 | // 9 | 10 | #include "default_policy_common.h" 11 | 12 | static const char *g_move_prompts[NUM_MOVE_TYPE] = { 13 | "NORMAL", 14 | "KO_FIGHT", 15 | "OPPONENT_IN_DANGER", 16 | "OUR_ATARI", 17 | "NAKADE", 18 | "PATTERN", 19 | "NO_MOVE" 20 | }; 21 | 22 | const char *GetDefMoveType(MoveType type) { 23 | if (type < 0 || type >= NUM_MOVE_TYPE) return NULL; 24 | else return g_move_prompts[type]; 25 | } 26 | 27 | DefPolicyMove c_mg(Coord m, MoveType t, int gamma) { 28 | DefPolicyMove move; 29 | move.m = m; 30 | move.type = t; 31 | move.gamma = gamma; 32 | move.game_ended = FALSE; 33 | return move; 34 | } 35 | 36 | DefPolicyMove c_m(Coord m, MoveType t) { 37 | DefPolicyMove move; 38 | move.m = m; 39 | move.type = t; 40 | move.gamma = 100; 41 | move.game_ended = FALSE; 42 | return move; 43 | } 44 | 45 | void add_move(DefPolicyMoves *m, DefPolicyMove move) { 46 | if (m->num_moves < MACRO_BOARD_SIZE*MACRO_BOARD_SIZE) { 47 | // char buf[30]; 48 | // printf("#move = %d. Add move %s. type = %d\n", m->num_moves, get_move_str(move.m, S_EMPTY, buf), move.type); 49 | m->moves[m->num_moves ++] = move; 50 | #ifdef SHOW_PROMPT 51 | printf(g_move_prompts[move.type]); 52 | printf("\n"); 53 | #endif 54 | } else { 55 | printf("#moves is out of bound!! num_moves = %d\n", m->num_moves); 56 | error(""); 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /board/default_policy_common.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 3 | // All rights reserved. 4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory. 
8 | // 9 | 10 | #ifndef _DEFAULT_POLICY_COMMON_ 11 | #define _DEFAULT_POLICY_COMMON_ 12 | 13 | #include "board.h" 14 | 15 | // #define SHOW_PROMPT 16 | typedef enum { NORMAL = 0, KO_FIGHT, OPPONENT_IN_DANGER, OUR_ATARI, NAKADE, PATTERN, NO_MOVE, NUM_MOVE_TYPE } MoveType; 17 | 18 | typedef struct { 19 | Coord m; 20 | int gamma; 21 | MoveType type; 22 | BOOL game_ended; 23 | } DefPolicyMove; 24 | 25 | // A queue for adding candidate moves. 26 | typedef struct { 27 | const Board *board; 28 | // Move sequence. 29 | DefPolicyMove moves[MACRO_BOARD_SIZE*MACRO_BOARD_SIZE]; 30 | int num_moves; 31 | } DefPolicyMoves; 32 | 33 | // Get a constant string that describes the type of default policy. 34 | const char *GetDefMoveType(MoveType type); 35 | 36 | // Simple constructor of default policy move. 37 | DefPolicyMove c_mg(Coord m, MoveType t, int gamma); 38 | DefPolicyMove c_m(Coord m, MoveType t); 39 | 40 | // Add moves to def policy queue. 41 | void add_move(DefPolicyMoves *m, DefPolicyMove move); 42 | 43 | #endif 44 | -------------------------------------------------------------------------------- /board/default_policy_dcnn.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016-present, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 
--

-- Default policy with pure DCNN
local common = require("common.common")
local board = require("board.board")
local dcnn_utils = require 'board.dcnn_utils'
local utils = require 'utils.utils'

local ffi = require 'ffi'
-- Declare DefPolicyMove and related types from default_policy_common.h to the FFI.
local symbols, s = utils.ffi_include(paths.concat(common.script_path(), "default_policy_common.h"))

local dp = { }

--- Initialize with a given codename and rule.
-- @param options forwarded to dcnn_utils.init; options.rule selects the scoring rule
--   (default board.chinese_rule).
-- @return the opt table from dcnn_utils.init with `rule` set; pass it as `def_policy` to dp.run.
function dp.init(options)
  local opt = dcnn_utils.init(options)
  opt.rule = options.rule or board.chinese_rule
  return opt
end

--- Play DCNN-sampled moves on `b` until dcnn_utils.sample returns no move or the depth limit is hit.
-- @param def_policy the table returned by dp.init.
-- @param max_depth if > 0, stop once more than max_depth moves were played (the check runs
--   before each move, so up to max_depth + 1 moves are played); <= 0 means no limit.
-- @param verbose unused.
-- @return the fast score of the final board under def_policy.rule, and a DefPolicyMove whose
--   `m` is the last move played (zero-initialized by ffi.new if no move was played).
function dp.run(def_policy, b, max_depth, verbose)
  local counter = 0
  local def_move = ffi.new("DefPolicyMove")
  while true do
    if max_depth > 0 and counter > max_depth then break end

    local x, y = dcnn_utils.sample(def_policy, b, b._next_player)
    if x == nil or y == nil then break end
    def_move.m = common.xy2coord(x, y)

    board.play(b, x, y, b._next_player)
    counter = counter + 1
  end
  -- Score the final board once and return it together with the last move.
  local score = board.get_fast_score(b, def_policy.rule)
  return score, def_move
end

return dp

--------------------------------------------------------------------------------
/board/dumpWinrate.lua:
--------------------------------------------------------------------------------
--
-- Copyright (c) 2016-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
8 | -- 9 | 10 | local pl = require 'pl.import_into'() 11 | local goutils = require 'utils.goutils' 12 | local sgf = require 'utils.sgf' 13 | local playoutv2 = require('mctsv2.playout_multithread') 14 | 15 | local opt = pl.lapp[[ 16 | -s,--sgf (default "") Sgf file to load 17 | -s,--start_n (default -1) Start from 18 | -e,--end_n (default -1) End to 19 | ]] 20 | 21 | assert(opt.sgf) 22 | local content = io.open(opt.sgf):read("*a") 23 | local game = assert(sgf.parse(content)) 24 | print(game.sgf[1].PW) 25 | print(game.sgf[1].PB) 26 | print(game.sgf[1].RE) 27 | 28 | local b = board.new() 29 | board.clear(b) 30 | 31 | goutils.apply_handicaps(b, game, true) 32 | 33 | local n = game:get_total_moves() 34 | game:play(function (move, counter) 35 | local x, y, player = sgf.parse_move(move, false, true) 36 | if x and y and player then 37 | board.play(b, x, y, player) 38 | -- board.show(b, 'last_move') 39 | return true 40 | end 41 | end, opt.start_n) 42 | 43 | local tr = playoutv2.new(opt.rollout) 44 | 45 | 46 | -- Then we start the dumping. 47 | game:play(function (move, counter) 48 | local x, y, player = sgf.parse_move(move, false, true) 49 | if x and y and player then 50 | board.play(b, x, y, player) 51 | -- board.show(b, 'last_move') 52 | return true 53 | end 54 | end, opt.end_n, true) 55 | 56 | 57 | -------------------------------------------------------------------------------- /board/ownermap.c: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 3 | // All rights reserved. 4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory. 8 | // 9 | 10 | #include "ownermap.h" 11 | // Ownermap 12 | typedef struct { 13 | // Ownermap 14 | int total_ownermap_count; 15 | // Histogram. 
S_EMPTY (S_UNKNOWN), S_BLACK, S_WHITE, S_OFF_BOARD (S_DAME) 16 | int ownermap[MACRO_BOARD_SIZE][MACRO_BOARD_SIZE][4]; 17 | } Handle; 18 | 19 | void *InitOwnermap() { 20 | Handle *hh = (Handle *)malloc(sizeof(Handle)); 21 | if (hh == NULL) error("Ownermap handle cannot be initialized."); 22 | return hh; 23 | } 24 | 25 | void FreeOwnermap(void *hh) { 26 | free(hh); 27 | } 28 | 29 | void ClearOwnermap(void *hh) { 30 | Handle *h = (Handle *)hh; 31 | // Ownermap 32 | h->total_ownermap_count = 0; 33 | memset(h->ownermap, 0, sizeof(h->ownermap)); 34 | } 35 | 36 | void AccuOwnermap(void *hh, const Board *board) { 37 | // Accumulate the ownermap with the board situation. 38 | // Usually the current board situation is after the default policy is applied. 39 | Handle *h = (Handle *)hh; 40 | // Accumulate the ownermap 41 | for (int i = 0; i < BOARD_SIZE; ++i) { 42 | for (int j = 0; j < BOARD_SIZE; ++j) { 43 | Coord c = OFFSETXY(i, j); 44 | Stone s = board->_infos[c].color; 45 | if (s == S_EMPTY) { 46 | s = GetEyeColor(board, c); 47 | } 48 | h->ownermap[i][j][s] ++; 49 | } 50 | } 51 | h->total_ownermap_count ++; 52 | } 53 | 54 | float OwnermapFloatOne(Handle *h, int i, int j, Stone player) { 55 | return ((float) h->ownermap[i][j][player]) / h->total_ownermap_count; 56 | } 57 | 58 | Stone OwnermapJudgeOne(Handle *h, int i, int j, float ratio) { 59 | int empty = h->ownermap[i][j][S_EMPTY]; 60 | int black = h->ownermap[i][j][S_BLACK]; 61 | int white = h->ownermap[i][j][S_WHITE]; 62 | int n = h->total_ownermap_count; 63 | 64 | int thres = (int)(n * ratio); 65 | 66 | if (empty >= thres) return S_DAME; 67 | if (empty + black >= thres) return S_BLACK; 68 | if (empty + white >= thres) return S_WHITE; 69 | return S_UNKNOWN; 70 | } 71 | 72 | void GetDeadStones(void *hh, const Board *board, float ratio, Stone *livedead, Stone *group_stats) { 73 | // Threshold the ownermap and determine. 
74 | Handle *h = (Handle *)hh; 75 | 76 | Stone *internal_group_stats = NULL; 77 | 78 | if (group_stats == NULL) { 79 | internal_group_stats = (Stone *)malloc(board->_num_groups * sizeof(Stone)); 80 | group_stats = internal_group_stats; 81 | } 82 | 83 | memset(group_stats, S_EMPTY, board->_num_groups * sizeof(Stone)); 84 | 85 | for (int i = 0; i < BOARD_SIZE; ++i) { 86 | for (int j = 0; j < BOARD_SIZE; ++j) { 87 | Coord c = OFFSETXY(i, j); 88 | Stone s = board->_infos[c].color; 89 | Stone owner = OwnermapJudgeOne(h, i, j, ratio); 90 | 91 | // printf("owner at (%d, %d) = %d\n", i, j, owner); 92 | short id = board->_infos[c].id; 93 | if (owner == S_UNKNOWN) { 94 | group_stats[id] = s | S_UNKNOWN; 95 | } else if (! (group_stats[id] & S_UNKNOWN)) { 96 | // The group has deterministic state or empty. 97 | Stone stat = s; 98 | if (owner == s) stat |= S_ALIVE; 99 | else if (owner == OPPONENT(s)) stat |= S_DEAD; 100 | else stat |= S_UNKNOWN; 101 | 102 | if (group_stats[id] == S_EMPTY) group_stats[id] = stat; 103 | else if (group_stats[id] != stat) group_stats[id] = s | S_UNKNOWN; 104 | } 105 | } 106 | } 107 | // Once we get the group stats, we thus can fill the ownermap. 108 | if (livedead != NULL) { 109 | // Zero out everything else. 110 | memset(livedead, S_EMPTY, BOARD_SIZE * BOARD_SIZE * sizeof(Stone)); 111 | for (int i = 1; i < board->_num_groups; ++i) { 112 | TRAVERSE(board, i, c) { 113 | livedead[EXPORT_OFFSET(c)] = group_stats[i]; 114 | } ENDTRAVERSE 115 | } 116 | } 117 | 118 | if (internal_group_stats != NULL) free(internal_group_stats); 119 | } 120 | 121 | void GetOwnermap(void *hh, float ratio, Stone *ownermap) { 122 | // Threshold the ownermap and determine. 
123 | Handle *h = (Handle *)hh; 124 | 125 | for (int i = 0; i < BOARD_SIZE; ++i) { 126 | for (int j = 0; j < BOARD_SIZE; ++j) { 127 | Coord c = OFFSETXY(i, j); 128 | ownermap[EXPORT_OFFSET(c)] = OwnermapJudgeOne(h, i, j, ratio); 129 | } 130 | } 131 | } 132 | 133 | void GetOwnermapFloat(void *hh, Stone player, float *ownermap) { 134 | Handle *h = (Handle *)hh; 135 | for (int i = 0; i < BOARD_SIZE; ++i) { 136 | for (int j = 0; j < BOARD_SIZE; ++j) { 137 | Coord c = OFFSETXY(i, j); 138 | ownermap[EXPORT_OFFSET(c)] = OwnermapFloatOne(h, i, j, player); 139 | } 140 | } 141 | } 142 | 143 | float GetTTScoreOwnermap(void *hh, const Board *board, Stone *livedead, Stone *territory) { 144 | Stone group_stats[MAX_GROUP]; 145 | GetDeadStones(hh, board, 0.5, livedead, group_stats); 146 | return GetTrompTaylorScore(board, group_stats, territory); 147 | } 148 | 149 | void ShowDeadStones(const Board *board, const Stone *stones) { 150 | // Show the board with ownership 151 | char buf[2000]; 152 | int len = 0; 153 | len += sprintf(buf + len, " A B C D E F G H J K L M N O P Q R S T\n"); 154 | char stone[3]; 155 | stone[2] = 0; 156 | for (int j = BOARD_SIZE - 1; j >= 0; --j) { 157 | len += sprintf(buf + len, "%2d ", j + 1); 158 | for (int i = 0; i < BOARD_SIZE; ++i) { 159 | Coord c = OFFSETXY(i, j); 160 | Stone s = board->_infos[c].color; 161 | if (HAS_STONE(s)) { 162 | char ss = (s == S_BLACK ? 'X' : 'O'); 163 | // Make it lower case if the stone are dead. 164 | Stone stat = stones[EXPORT_OFFSET(c)]; 165 | if (stat & S_DEAD) ss |= 0x20; 166 | stone[0] = ss; 167 | stone[1] = ( (stat & S_UNKNOWN) ? '?' : (c == board->_last_move ? ')' : ' ')); 168 | 169 | } else if (s == S_EMPTY) { 170 | if (STAR_ON19(i, j)) 171 | strcpy(stone, "+ "); 172 | else 173 | strcpy(stone, ". 
"); 174 | } else strcpy(stone, "# "); 175 | len += sprintf(buf + len, stone); 176 | } 177 | len += sprintf(buf + len, "%d\n", j + 1); 178 | } 179 | len += sprintf(buf + len, " A B C D E F G H J K L M N O P Q R S T"); 180 | // Finally print 181 | printf(buf); 182 | } 183 | 184 | void ShowStonesProb(void *hh, Stone player) { 185 | // Show the board with ownership 186 | char buf[20000]; 187 | int len = 0; 188 | Handle *h = (Handle *)hh; 189 | 190 | const char *prompt = " A B C D E F G H J K L M N O P Q R S T\n"; 191 | len += sprintf(buf + len, prompt); 192 | for (int j = BOARD_SIZE - 1; j >= 0; --j) { 193 | len += sprintf(buf + len, "%2d ", j + 1); 194 | for (int i = 0; i < BOARD_SIZE; ++i) { 195 | Coord c = OFFSETXY(i, j); 196 | float val = OwnermapFloatOne(h, i, j, player); 197 | len += sprintf(buf + len, "%.3f ", val); 198 | } 199 | len += sprintf(buf + len, "%d\n", j + 1); 200 | } 201 | len += sprintf(buf + len, prompt); 202 | // Finally print 203 | printf(buf); 204 | } 205 | 206 | 207 | -------------------------------------------------------------------------------- /board/ownermap.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 3 | // All rights reserved. 4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory. 
8 | // 9 | 10 | #ifndef _OWNERMAP_H_ 11 | #define _OWNERMAP_H_ 12 | 13 | #include "board.h" 14 | 15 | #ifdef __cplusplus 16 | extern "C" { 17 | #endif 18 | 19 | void *InitOwnermap(); 20 | void FreeOwnermap(void *); 21 | 22 | // Accumulating Ownermap 23 | void ClearOwnermap(void *hh); 24 | void AccuOwnermap(void *hh, const Board *board); 25 | void GetDeadStones(void *hh, const Board *board, float ratio, Stone *livedead, Stone *group_stats); 26 | void GetOwnermap(void *hh, float ratio, Stone *ownermap); 27 | 28 | // Get ownermap probability. 29 | void GetOwnermapFloat(void *hh, Stone player, float *ownermap); 30 | 31 | // Get Trompy-Taylor score directly. 32 | // If livedead != NULL, then livedead is a BOARD_SIZE * BOARD_SIZE array. Otherwise this output is ignored. 33 | // If territory != NULL, then it is also a BOARD_SIZE * BOARD_SIZE array. Otherwise this output is ignored. 34 | float GetTTScoreOwnermap(void *hh, const Board *board, Stone *livedead, Stone *territory); 35 | 36 | // Visulize DeadStones 37 | void ShowDeadStones(const Board *board, const Stone *stones); 38 | void ShowStonesProb(void *hh, Stone player); 39 | 40 | #ifdef __cplusplus 41 | } 42 | #endif 43 | 44 | #endif 45 | -------------------------------------------------------------------------------- /board/ownermap.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016-present, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 
8 | -- 9 | 10 | local ffi = require 'ffi' 11 | local pl = require 'pl.import_into'() 12 | local utils = require('utils.utils') 13 | local common = require("common.common") 14 | local goutils = require('utils.goutils') 15 | local board = require('board.board') 16 | 17 | local script_path = common.script_path() 18 | local symbols, s = utils.ffi_include(paths.concat(script_path, "ownermap.h")) 19 | -- local symbols, s = utils.ffi_include(paths.concat(common.lib_path, "board/ownermap.h")) 20 | local C = ffi.load(paths.concat(script_path, "../libs/libownermap.so")) 21 | -- local C = ffi.load("libexperimental_deeplearning_yuandong_go_board_default_policy_c.so") 22 | local om = {} 23 | 24 | -- S_DEAD = 8 25 | -- om.dead_white = 8 + common.white 26 | -- om.dead_black = 8 + common.black 27 | -- om.dame_empty = 3 28 | om.dead_white = tonumber(symbols.S_DEAD) + common.white 29 | om.dead_black = tonumber(symbols.S_DEAD) + common.black 30 | om.dame_empty = tonumber(symbols.S_DAME) 31 | 32 | -- print("Dead white = " .. om.dead_white) 33 | -- print("Dead white = " .. om.dead_black) 34 | -- print("Dame empty = " .. 
om.dame_empty) 35 | 36 | -- Utilities for Ownermap 37 | function om.new() 38 | return C.InitOwnermap() 39 | end 40 | 41 | function om.free(ownermap) 42 | C.FreeOwnermap(ownermap) 43 | end 44 | 45 | function om.clear_ownermap(ownermap) 46 | C.ClearOwnermap(ownermap) 47 | end 48 | 49 | function om.accu_ownermap(ownermap, board_after_dp) 50 | C.AccuOwnermap(ownermap, board_after_dp); 51 | end 52 | 53 | function om.get_ownermap(ownermap, ratio) 54 | local ownermap = torch.CharTensor(common.board_size, common.board_size) 55 | C.GetOwnermap(ownermap, ratio, ownermap:data()) 56 | return ownermap 57 | end 58 | 59 | function om.get_deadstones(ownermap, b, ratio) 60 | local livedead = torch.CharTensor(common.board_size, common.board_size) 61 | C.GetDeadStones(ownermap, b, ratio, livedead:data(), nil) 62 | return livedead 63 | end 64 | 65 | function om.get_deadlist(livedead) 66 | -- From livedead we can get back the location of deadstones 67 | local livedead2 = livedead:view(-1) 68 | local dead_whites = { } 69 | local dead_whites_str = { } 70 | local dead_blacks = { } 71 | local dead_blacks_str = { } 72 | for i = 1, common.board_size * common.board_size do 73 | local x, y = goutils.moveIdx2xy(i) 74 | local s = goutils.compose_move_gtp(x, y) 75 | 76 | if livedead2[i] == om.dead_white then 77 | table.insert(dead_whites, {x, y}) 78 | table.insert(dead_whites_str, s) 79 | elseif livedead2[i] == om.dead_black then 80 | table.insert(dead_blacks, {x, y}) 81 | table.insert(dead_blacks_str, s) 82 | end 83 | end 84 | return { 85 | b = dead_blacks, 86 | w = dead_whites, 87 | b_str = dead_blacks_str, 88 | w_str = dead_whites_str, 89 | dames = dames, 90 | dames_str = dames_str 91 | } 92 | end 93 | 94 | function om.get_territorylist(territory) 95 | local territory2 = territory:view(-1) 96 | local whites = { } 97 | local whites_str = { } 98 | local blacks = { } 99 | local blacks_str = { } 100 | local dames = { } 101 | local dames_str = { } 102 | for i = 1, common.board_size * 
common.board_size do 103 | local x, y = goutils.moveIdx2xy(i) 104 | local s = goutils.compose_move_gtp(x, y) 105 | 106 | if territory2[i] == common.white then 107 | table.insert(whites, {x, y}) 108 | table.insert(whites_str, s) 109 | elseif territory2[i] == common.black then 110 | table.insert(blacks, {x, y}) 111 | table.insert(blacks_str, s) 112 | elseif territory2[i] == om.dame_empty then 113 | -- Dame 114 | table.insert(dames, {x, y}) 115 | table.insert(dames_str, s) 116 | end 117 | end 118 | return { 119 | b = blacks, 120 | w = whites, 121 | b_str = blacks_str, 122 | w_str = whites_str, 123 | dames = dames, 124 | dames_str = dames_str 125 | } 126 | end 127 | 128 | function om.get_ownermap_float(ownermap_ptr, player) 129 | local ownermap = torch.FloatTensor(common.board_size, common.board_size) 130 | C.GetOwnermapFloat(ownermap_ptr, player, ownermap:data()) 131 | return ownermap 132 | end 133 | 134 | function om.show_deadstones(b, stones) 135 | C.ShowDeadStones(b, stones:data()) 136 | end 137 | 138 | function om.show_stones_prob(ownermap, player) 139 | C.ShowStonesProb(ownermap, player) 140 | end 141 | 142 | function om.get_ttscore_ownermap(ownermap, b) 143 | local livedead = torch.CharTensor(common.board_size, common.board_size) 144 | local territory = torch.CharTensor(common.board_size, common.board_size) 145 | local score = C.GetTTScoreOwnermap(ownermap, b, livedead:data(), territory:data()) 146 | return score, livedead, territory 147 | end 148 | 149 | function om.print_list(t) 150 | local s = "" 151 | for _, c in ipairs(t) do 152 | s = s .. c .. 
" " 153 | end 154 | print(s) 155 | end 156 | 157 | function om.util_compute_final_score(ownermap, b, komi, trial, def_policy_func) 158 | local new_ownermap 159 | if not ownermap then 160 | ownermap = om.new() 161 | new_ownermap = true 162 | end 163 | assert(ownermap) 164 | assert(def_policy_func) 165 | 166 | trial = trial or 1000 167 | komi = komi or 6.5 168 | 169 | om.clear_ownermap(ownermap) 170 | local scores = torch.Tensor(trial) 171 | for i = 1, trial do 172 | local b2 = board.copyfrom(b) 173 | scores[i] = def_policy_func(b2, -1) - komi 174 | om.accu_ownermap(ownermap, b2) 175 | end 176 | local score, livedead, territory = om.get_ttscore_ownermap(ownermap, b) 177 | 178 | --[[ 179 | local territorylist = om.get_territorylist(territory) 180 | print(string.format("#W = %d, #B = %d, #dame = %d", 181 | #territorylist.w, #territorylist.b, #territorylist.dames)) 182 | 183 | print("White = ") 184 | print_list(territorylist.w_str) 185 | print("Black = ") 186 | print_list(territorylist.b_str) 187 | print("Dame = ") 188 | print_list(territorylist.dames_str) 189 | 190 | print("Score before komi = " .. 
score) 191 | ]] 192 | 193 | if new_ownermap then 194 | om.free(ownermap) 195 | end 196 | return score - komi, livedead, territory, scores 197 | end 198 | 199 | -- Some utilities for dp 200 | function om.compute_stats(scores) 201 | if scores == nil then 202 | return { 203 | std = 0, 204 | max = 0, 205 | min = 0, 206 | mean = 0, 207 | advantage = '-' 208 | } 209 | else 210 | if type(scores) == 'table' then 211 | scores = torch.FloatTensor(scores) 212 | end 213 | return { 214 | mean = scores:mean(), 215 | std = scores:std(), 216 | max = scores:max(), 217 | min = scores:min(), 218 | advantage = scores:mean() > 0 and 'B' or 'W' 219 | } 220 | end 221 | end 222 | 223 | return om 224 | -------------------------------------------------------------------------------- /board/pattern.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 3 | // All rights reserved. 4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory. 8 | // 9 | // This file is inspired by Pachi's engine (https://github.com/pasky/pachi). 10 | // The main DarkForest engine (when specified with `--playout_policy v2`) does not depend on it. 11 | // However, the simple policy opened with `--playout_policy simple` will use this library. 12 | 13 | #ifndef _PATTERN_H_ 14 | #define _PATTERN_H_ 15 | 16 | #include 17 | 18 | #include "board.h" 19 | #include "default_policy_common.h" 20 | 21 | /* hash3_t pattern: ignore middle point, 2 bits per intersection (color) 22 | * plus 1 bit per each direct neighbor => 8*2 + 4 bits. 
Bitmap point order: 23 | * 7 6 5 b 24 | * 4 3 a 9 25 | * 2 1 0 8 */ 26 | /* Value bit 0: black pattern; bit 1: white pattern */ 27 | 28 | typedef uint64_t hash_t; 29 | typedef uint32_t hash3_t; // 3x3 pattern hash 30 | 31 | // Conceal all the interfaces. 32 | void *InitPatternDB(); 33 | // The hash pattern is extracted from Board. 34 | hash3_t GetHash(const Board *b, Coord m); 35 | BOOL QueryPatternDB(void *pp, hash3_t pat, Stone color, int* gamma); 36 | void DestroyPatternDB(void *); 37 | 38 | // Get pattern moves from the board and put them to default policy move queue. 39 | void CheckPatternFromLastMove(void *, DefPolicyMoves *m); 40 | 41 | #endif 42 | -------------------------------------------------------------------------------- /board/sample_one_pattern_v2.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016-present, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | 10 | -- Try compare with three playout tactics. 11 | local pat = require 'board.pattern_v2' 12 | local board = require 'board.board' 13 | local common = require 'common.common' 14 | local goutils = require 'utils.goutils' 15 | local utils = require 'utils.utils' 16 | local sgf = require 'utils.sgf' 17 | local pl = require 'pl.import_into'() 18 | 19 | -- Load a sgf file and try havest its pattern. 20 | local opt = pl.lapp[[ 21 | -p,--pattern_file (default "") The pattern file to load. 22 | -v,--verbose (default 1) Verbose level 23 | --save_prefix (default "game") The prefix of the game. 24 | --sample_topn (default -1) Sample from topn move. 25 | --temperature (default 1.0) Temperature 26 | --sgf_file (default "") If not empty, then we sample after move_from moves. 
27 | --move_from (default 0) Sample since move_from 28 | --num_moves (default 200) The number of moves to simulate. 29 | --num_games (default 10) The number of games to simulate. 30 | --stats Whether we compute the statistics. 31 | ]] 32 | 33 | -- pat.params.verbose = opt.verbose 34 | local pat_h = pat.init(opt.pattern_file) 35 | pat.set_verbose(pat_h, opt.verbose) 36 | -- pat.update_params(pat_h) 37 | 38 | pat.set_sample_params(pat_h, opt.sample_topn, opt.temperature) 39 | pat.print(pat_h) 40 | 41 | local b = board.new() 42 | board.clear(b) 43 | 44 | print("Load sgf: " .. opt.sgf_file) 45 | local f = assert(io.open(opt.sgf_file, "r")) 46 | local game = sgf.parse(f:read("*a")) 47 | if game == nil then 48 | error("Game " .. opt.sgf_file .. " cannot be loaded") 49 | end 50 | 51 | goutils.apply_handicaps(b, game, true) 52 | 53 | if opt.move_from > 0 then 54 | game:play(function (move, counter) 55 | local x, y, player = sgf.parse_move(move, false, true) 56 | if x and y and player then 57 | board.play(b, x, y, player) 58 | -- board.show(b, 'last_move') 59 | return true 60 | end 61 | end, opt.move_from) 62 | end 63 | 64 | --[[ 65 | print("Starting board situation") 66 | board.show(b, 'last_move') 67 | ]] 68 | 69 | local blacks, whites = board.get_black_white_stones(b) 70 | local header = { 71 | blacks = blacks, 72 | whites = whites 73 | } 74 | 75 | local duration = 0 76 | local total_num_moves = 0 77 | local stats = { } 78 | 79 | print("Current situation!") 80 | board.show_fancy(b, 'all') 81 | 82 | local summary = pat.init_sample_summary() 83 | 84 | for i = 1, opt.num_games do 85 | local be = pat.new(pat_h, b) 86 | local moves 87 | if opt.num_moves > 0 then 88 | moves, this_summary = pat.sample_many(be, opt.num_moves, nil, true) 89 | else 90 | this_summary = pat.sample_until(be) 91 | end 92 | local score = board.get_fast_score(pat.get_board(be)) 93 | 94 | if opt.stats then 95 | -- Compute the final score. 
96 | table.insert(stats, score) 97 | else 98 | print(string.format("[%d/%d] Score = %f", i, opt.num_games, score)) 99 | pat.print_sample_summary(this_summary) 100 | board.show_fancy(pat.get_board(be)) 101 | 102 | -- Save to file 103 | if opt.save_prefix ~= "" and moves ~= nil then 104 | local filename = opt.save_prefix .. '-' .. i .. ".sgf" 105 | local f = assert(io.open(filename, "w")) 106 | f:write(sgf.sgf_string(header, moves)) 107 | f:close() 108 | end 109 | end 110 | 111 | pat.free(be) 112 | pat.combine_sample_summary(summary, this_summary) 113 | end 114 | 115 | pat.print_sample_summary(summary) 116 | 117 | if opt.stats then 118 | local mean = 0 119 | local max = -1000 120 | local min = 1000 121 | for i = 1, #stats do 122 | mean = mean + stats[i] 123 | max = math.max(max, stats[i]) 124 | min = math.min(min, stats[i]) 125 | end 126 | mean = mean / #stats 127 | print(string.format("#sample = %d, mean = %f, min = %f, max = %f", #stats, mean, min, max)) 128 | end 129 | 130 | pat.destroy(pat_h) 131 | -------------------------------------------------------------------------------- /board/sample_pattern_v2.c: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 3 | // All rights reserved. 4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory. 8 | // 9 | 10 | #include "pattern_v2.h" 11 | 12 | void str2play(const char *str, Coord *m, Stone *player) { 13 | if (str[0] == 'W') *player = S_WHITE; 14 | else *player = S_BLACK; 15 | 16 | if (! strcmp(str + 2, "PASS")) { 17 | *m = M_PASS; 18 | return; 19 | } 20 | if (! 
strcmp(str + 2, "RESIGN")) { 21 | *m = M_RESIGN; 22 | return; 23 | } 24 | 25 | int x = str[2] - 'A'; 26 | if (x >= 8) x --; 27 | int y; 28 | sscanf(str + 3, "%d", &y); 29 | y --; 30 | *m = OFFSETXY(x, y); 31 | } 32 | 33 | void simple_play(Board *b, const char *move_str) { 34 | GroupId4 ids; 35 | Coord m; 36 | Stone player; 37 | str2play(move_str, &m, &player); 38 | 39 | if (TryPlay(b, X(m), Y(m), player, &ids)) { 40 | Play(b, &ids); 41 | } 42 | } 43 | 44 | int main(int argc, char *argv[]) { 45 | // Load the pattern library and sample from it. 46 | if (argc < 5) { 47 | printf("Usage: sample_pattern_v2 pattern_file num_moves num_games verbose"); 48 | } 49 | const char *pattern_file = argv[1]; 50 | int num_moves, num_games, verbose; 51 | sscanf(argv[2], "%d", &num_moves); 52 | sscanf(argv[3], "%d", &num_games); 53 | sscanf(argv[4], "%d", &verbose); 54 | 55 | void *pat = InitPatternV2(pattern_file, NULL, FALSE); 56 | PatternV2Params params = *PatternV2GetParams(pat); 57 | params.verbose = verbose; 58 | PatternV2UpdateParams(pat, ¶ms); 59 | 60 | Board b; 61 | ClearBoard(&b); 62 | simple_play(&b, "B Q4"); 63 | simple_play(&b, "B Q16"); 64 | simple_play(&b, "B D4"); 65 | simple_play(&b, "B D16"); 66 | simple_play(&b, "W F3"); 67 | 68 | SampleSummary summary; 69 | AllMovesExt *move_ext = InitAllMovesExt(num_moves); 70 | 71 | double total_duration = 0.0; 72 | int total_moves = 0; 73 | 74 | for (int i = 0; i < num_games; ++i) { 75 | // Sample it. 76 | void *be = PatternV2InitBoardExtra(pat, &b); 77 | double start = wallclock(); 78 | PatternV2SampleMany(be, move_ext, NULL, &summary); 79 | total_duration += wallclock() - start; 80 | total_moves += summary.n; 81 | 82 | // After sampling, send the summary. 
83 | printf("Game %d: moves [%d], random/top-k: %d/%d/%d/%d/%d, counter: %d/%d/%d/%d/%d\n", i, summary.n, 84 | summary.num_topn[0], summary.num_topn[1], summary.num_topn[2], summary.num_topn[3], summary.num_topn[4], 85 | summary.num_counters[1], summary.num_counters[2], summary.num_counters[3], summary.num_counters[4], summary.num_counters[5]); 86 | 87 | PatternV2DestroyBoardExtra(be); 88 | } 89 | 90 | printf("Time: %lf usec (%lf/%d)\n", total_duration / total_moves * 1e6, total_duration, total_moves); 91 | 92 | DestroyAllMovesExt(move_ext); 93 | DestroyPatternV2(pat); 94 | } 95 | -------------------------------------------------------------------------------- /cnnPlayerV2/cnnPlayerV2.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016-present, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | 10 | local utils = require 'utils.utils' 11 | 12 | utils.require_torch() 13 | utils.require_cutorch() 14 | 15 | local goutils = require 'utils.goutils' 16 | local common = require("common.common") 17 | local sgfloader = require 'utils.sgf' 18 | local pl = require 'pl.import_into'() 19 | local board = require 'board.board' 20 | 21 | -- Let's follow the gtp protocol. 22 | -- Load a model and wait for the input. 23 | 24 | local opt = pl.lapp[[ 25 | -i,--input (default "./models/df2.bin") Input CNN models. 26 | -f,--feature_type (default "old") By default we only test old features: 27 | -r,--rank (default "9d") We play in the level of rank. 28 | -c,--usecpu Whether we use cpu to run the program. 29 | ]] 30 | 31 | -- opt.feature_type and opt.userank are necessary for the game to be played. 
32 | opt.userank = true 33 | 34 | local b = board.new() 35 | local board_initialized = false 36 | -- Load the trained CNN models. 37 | -- Send the signature 38 | -- print("Loading model = " .. opt.input) 39 | local model = torch.load(opt.input) 40 | 41 | io.stderr:write("CNNPlayerV2") 42 | 43 | -- Adhoc strategy, if they pass, we pass. 44 | local enemy_pass = false 45 | -- 46 | -- not supporting final_score in this version 47 | -- Return format: 48 | -- command correct: true/false 49 | -- output string: 50 | -- whether we need to quit the program. 51 | local commands = { 52 | boardsize = function (board_size) 53 | local s = tonumber(board_size) 54 | if s ~= board.get_board_size(b) then 55 | error(string.format("Board size %d is not supported!", s)) 56 | end 57 | return true 58 | end, 59 | clear_board = function () 60 | board.clear(b) 61 | enemy_pass = false 62 | board_initialized = true 63 | return true 64 | end, 65 | komi = function(komi) 66 | io.stderr:write("The current algorithm has no awareness of komi. Nevertheless, it can still play the game.") 67 | -- return board.set_komi(b) 68 | return true 69 | end, 70 | play = function(p, coord) 71 | -- Receive what the opponent plays and update the board. 72 | -- Alpha + number 73 | if not board_initialized then error("Board should be initialized!!") end 74 | local x, y, player = goutils.parse_move_gtp(coord, p) 75 | if not board.play(b, x, y, player) then 76 | error("Illegal move from the opponent!") 77 | end 78 | board.show(b) 79 | print(" ") 80 | 81 | if goutils.is_pass(x, y) then enemy_pass = true end 82 | return true 83 | end, 84 | genmove = function(player) 85 | if not board_initialized then error("Board should be initialized!!") end 86 | -- If enemy pass then we pass. 87 | if enemy_pass then 88 | return true, "pass" 89 | end 90 | 91 | -- Call CNN to get the move. 
92 | -- First extract features 93 | player = (player:lower() == 'w') and common.white or common.black 94 | local sortProb, sortInd = goutils.play_with_cnn(b, player, opt, opt.rank, model) 95 | 96 | -- Apply the moves until we have seen a valid one. 97 | local xf, yf, idx = goutils.tryplay_candidates(b, player, sortProb, sortInd) 98 | 99 | local move 100 | if xf == nil then 101 | io.stderr:write("Error! No move is valid!") 102 | -- We just pass here. 103 | move = "pass" 104 | -- Play pass here. 105 | board.play(b, 1, 1, player) 106 | else 107 | move = goutils.compose_move_gtp(xf, yf) 108 | -- Don't use any = signs. 109 | io.stderr:write(string.format("idx: %d, x: %d, y: %d, movestr: %s", idx, xf, yf, move)) 110 | -- Actual play this move 111 | if not board.play(b, xf, yf, player) then 112 | io.stderr:write("Illegal move from move_predictor! move = " .. move) 113 | end 114 | end 115 | 116 | -- Show the current board 117 | board.show(b) 118 | print(" ") 119 | 120 | -- Tell the GTP server we have chosen this move 121 | return true, move 122 | end, 123 | name = function () return true, "go_player_v2" end, 124 | version = function () return true, "version 1.0" end, 125 | tsdebug = function () return true, "not supported yet" end, 126 | protocol_version = function () return true, "0.1" end, 127 | quit = function () return true, "Byebye!", true end, 128 | -- final_score = function () return 0 end, 129 | } 130 | 131 | -- Add list_commands and known_command 132 | local all_commands = {} 133 | for k, _ in pairs(commands) do table.insert(all_commands, k) end 134 | local all_commands_str = table.concat(all_commands, "\n") 135 | 136 | commands.list_commands = function () 137 | return true, all_commands_str 138 | end 139 | 140 | commands.known_command = function (c) return true, type(commands[c]) == 'function' and "true" or "false" end 141 | 142 | -- Begin the main loop 143 | while true do 144 | local line = io.read() 145 | local content = pl.utils.split(line) 146 | 147 | local 
cmdid = '' 148 | if string.match(content[1], "%d+") then 149 | cmdid = table.remove(content, 1) 150 | end 151 | 152 | local command = table.remove(content, 1) 153 | local successful, outputstr, quit 154 | 155 | if commands[command] == nil then 156 | print("Warning: Ignoring unknown command - " .. line) 157 | else 158 | successful, outputstr, quit = commands[command](unpack(content)) 159 | end 160 | 161 | if successful then 162 | if outputstr == nil then outputstr = '' end 163 | print(string.format("=%s %s\n\n\n", cmdid, outputstr)) 164 | else 165 | print(string.format("?%s ???\n\n\n", cmdid)) 166 | end 167 | io.flush() 168 | if quit then break end 169 | end 170 | 171 | -- Remove models and perform a few garbage collections 172 | model = nil 173 | collectgarbage() 174 | collectgarbage() 175 | collectgarbage() 176 | collectgarbage() 177 | collectgarbage() 178 | collectgarbage() 179 | 180 | -------------------------------------------------------------------------------- /cnnPlayerV2/cnnPlayerV3.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016-present, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | 10 | package.path = package.path .. ';../?.lua' 11 | 12 | local pl = require 'pl.import_into'() 13 | local dcnn_utils = require 'board.dcnn_utils' 14 | local opt = pl.lapp[[ 15 | --codename (default "darkfores2") Code name for models. 16 | -f,--feature_type (default "old") By default we only test old features. If codename is specified, this is omitted. 17 | -r,--rank (default "9d") We play in the level of rank. 18 | --use_local_model Whether we just load local model from the current path 19 | -c,--usecpu Whether we use cpu to run the program. 
20 | --shuffle_top_n (default 1) We random choose one of the first n move and play it. 21 | --debug Wehther we use debug mode 22 | --exec (default "") Whether we run an initial script 23 | --setup_board (default "") Setup board. The argument is "sgfname moveto" 24 | --win_rate_thres (default 0.0) If the win rate is lower than that, resign. 25 | --sample_step (default -1) Sample at a particular step. 26 | --temperature (default 1) 27 | --presample_codename (default "darkforest") 28 | --presample_ft (default "old") 29 | --valueModel (default "../models/value_model.bin") 30 | --verbose Whether we print more information 31 | ]] 32 | --for k,v in pairs(opt) do 33 | -- print(k, v) 34 | --end 35 | 36 | local common = require("common.common") 37 | local CNNPlayerV2 = require 'cnnPlayerV2.cnnPlayerV2Framework' 38 | 39 | if opt.debug then 40 | dcnn_utils.dbg_set() 41 | end 42 | 43 | local dcnn_opt = dcnn_utils.init(opt) 44 | local callbacks = { } 45 | function callbacks.move_predictor(b, player) 46 | return dcnn_utils.sample(dcnn_opt, b, player) 47 | end 48 | 49 | function callbacks.get_value(b, player) 50 | local value = -1 51 | if dcnn_opt.valueModel then 52 | value= dcnn_utils.get_value(dcnn_opt, b, player) 53 | end 54 | print("value: ".. 
string.format("%.3f", value)) 55 | end 56 | 57 | function callbacks.new_game() 58 | collectgarbage() 59 | collectgarbage() 60 | end 61 | 62 | function callbacks.set_attention(x_left, y_top, x_right, y_bottom) 63 | -- Set the attention region if the feature_type is extended_with_attention 64 | if opt.feature_type == "extended_with_attention" then 65 | opt.attention = { x_left, y_top, x_right, y_bottom } 66 | end 67 | end 68 | 69 | local opt2 = { 70 | rule = opt.rule, 71 | exec = opt.exec, 72 | setup_board = opt.setup_board, 73 | win_rate_thres = opt.win_rate_thres, 74 | } 75 | 76 | local cnnplayer = CNNPlayerV2("CNNPlayerV2", "go_player_v2", "1.0", callbacks, opt2) 77 | cnnplayer:mainloop() 78 | 79 | model = nil 80 | collectgarbage() 81 | collectgarbage() 82 | collectgarbage() 83 | collectgarbage() 84 | collectgarbage() 85 | collectgarbage() 86 | -------------------------------------------------------------------------------- /common/comm.c: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 3 | // All rights reserved. 4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory. 8 | // 9 | 10 | #include "comm.h" 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | /*#include */ 19 | 20 | #define MESSAGE_HEAD_SIZE sizeof(long) 21 | 22 | /* Simple communication.*/ 23 | int CommInit(int id, int create_new) { 24 | key_t key; 25 | if ((key = ftok("/home/yuandong/.bashrc", id)) == -1) { 26 | printf("Error! generate key failed!\n"); 27 | exit(1); 28 | } 29 | int channel_id = -1; 30 | int attr = 0644; 31 | if (create_new == 1) attr |= IPC_CREAT; 32 | if ((channel_id = msgget(key, attr)) == -1) { 33 | printf("Error! create message queue failed! 
\n"); 34 | exit(1); 35 | } 36 | return channel_id; 37 | } 38 | 39 | /* Send a message through the message queue. size is the total size of the buffer in bytes.*/ 40 | void CommSend(int channel_id, void *message, int size) { 41 | if (msgsnd(channel_id, message, size - MESSAGE_HEAD_SIZE, 0) == -1) { 42 | printf("Error! Send message wrong!\n"); 43 | exit(1); 44 | } 45 | } 46 | 47 | /* Send a message through the message queue. If the queue is full, return -1, else return 0 */ 48 | int CommSendNoBlock(int channel_id, void *message, int size) { 49 | if (msgsnd(channel_id, message, size - MESSAGE_HEAD_SIZE, IPC_NOWAIT) == -1) { 50 | return -1; 51 | } 52 | return 0; 53 | } 54 | 55 | /* Receive a message through the queue. If no message, wait until there is one. */ 56 | void CommReceive(int channel_id, void *message, int size) { 57 | long type = * (long *)message; 58 | if (msgrcv(channel_id, message, size - MESSAGE_HEAD_SIZE, type, 0) == -1) { 59 | printf("Error! Receive message wrong!\n"); 60 | exit(1); 61 | } 62 | } 63 | 64 | /* Receive a message through the queue. If no message, wait until there is one. */ 65 | int CommReceiveNoBlock(int channel_id, void *message, int size) { 66 | long type = * (long *)message; 67 | if (msgrcv(channel_id, message, size - MESSAGE_HEAD_SIZE, type, IPC_NOWAIT) == -1) { 68 | return -1; 69 | } 70 | return 0; 71 | } 72 | 73 | /* Destory the message queue. */ 74 | void CommDestroy(int channel_id) { 75 | if (msgctl(channel_id, IPC_RMID, NULL) == -1) { 76 | printf("Error! Failed to destroy message queue! \n"); 77 | exit(1); 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /common/comm.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 3 | // All rights reserved. 4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. 
An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory. 8 | // 9 | 10 | #ifndef _COMM_H_ 11 | #define _COMM_H_ 12 | 13 | /* Simple communication. Used for communication between two processes. */ 14 | 15 | /* if create_new == 1, create the channel; otherwise just open it */ 16 | int CommInit(int id, int create_new); 17 | 18 | /* Send a message through the message queue. If the queue is full, wait until it is not full.*/ 19 | void CommSend(int channel_id, void *message, int size); 20 | 21 | /* Send a message through the message queue. If the queue is full, return -1, else return 0 */ 22 | int CommSendNoBlock(int channel_id, void *message, int size); 23 | 24 | /* Receive a message through the queue. If no message, wait until there is one. */ 25 | void CommReceive(int channel_id, void *message, int size); 26 | 27 | /* Receive a message through the queue. If no message, return -1, else return 0 */ 28 | int CommReceiveNoBlock(int channel_id, void *message, int size); 29 | 30 | /* Destory the message queue. */ 31 | void CommDestroy(int channel_id); 32 | 33 | #endif 34 | -------------------------------------------------------------------------------- /common/comm_constant.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 3 | // All rights reserved. 4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory. 8 | // 9 | 10 | #ifndef _COMM_CONSTANT_H_ 11 | #define _COMM_CONSTANT_H_ 12 | 13 | #define MOVE_NORMAL 0 14 | #define MOVE_SIMPLE_KO 1 15 | // Tactics moves provided by default_playout (e.g., nakade point) 16 | // They are arranged before the actual CNN moves. 17 | // Sometime NN missed it. 
(So we need to train a better model) 18 | #define MOVE_TACTICAL 2 19 | 20 | // Move used for life and death situations. They are more silly moves but on a local region. 21 | #define MOVE_LD 3 22 | 23 | #endif 24 | -------------------------------------------------------------------------------- /common/comm_pipe.c: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 3 | // All rights reserved. 4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory. 8 | // 9 | 10 | #ifndef _GNU_SOURCE 11 | #define _GNU_SOURCE 1 12 | #endif 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include "comm_pipe.h" 21 | 22 | #define PIPE_SIZE 1048576 23 | 24 | // Hack here 25 | #define F_SETPIPE_SZ 1031 26 | #define F_GETPIPE_SZ 1032 27 | 28 | int PipeInit(const char *name, int create_pipe, Pipe *p) { 29 | if (strlen(name) >= sizeof(p->filename)) { 30 | printf("Input filename %s is too long!\n", name); 31 | return -1; 32 | } 33 | strcpy(p->filename, name); 34 | 35 | if (create_pipe) { 36 | mkfifo(name, 0666); 37 | p->is_server = 1; 38 | 39 | p->fd = open(name, O_RDWR); 40 | if (p->fd == -1) { 41 | printf("Cannot open pipe %s (server) !", name); 42 | return -1; 43 | } 44 | } else { 45 | // Load it from the global file. 
46 | p->is_server = 0; 47 | p->fd = open(name, O_RDWR); 48 | if (p->fd == -1) { 49 | printf("Cannot open pipe %s (client) !\n", name); 50 | return -1; 51 | } 52 | } 53 | if (fcntl(p->fd, F_SETPIPE_SZ, PIPE_SIZE) != PIPE_SIZE) { 54 | printf("Cannot resize pipe %s to %d", name, PIPE_SIZE); 55 | return -1; 56 | } 57 | if (fcntl(p->fd, F_SETFL, O_NONBLOCK) == -1) { 58 | printf("Cannot set to nonblocking model\n"); 59 | return -1; 60 | } 61 | 62 | return 0; 63 | } 64 | 65 | int PipeRead(Pipe *p, void *buffer, int size) { 66 | if (read(p->fd, buffer, size) == -1) { 67 | return -1; 68 | } 69 | return 0; 70 | } 71 | 72 | int PipeWrite(Pipe *p, void *buffer, int size) { 73 | if (write(p->fd, buffer, size) == -1) { 74 | return -1; 75 | } 76 | return 0; 77 | } 78 | 79 | void PipeClose(Pipe *p) { 80 | close(p->fd); 81 | if (p->is_server) unlink(p->filename); 82 | } 83 | -------------------------------------------------------------------------------- /common/comm_pipe.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 3 | // All rights reserved. 4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory. 8 | // 9 | 10 | #ifndef _COMM_PIPE_H_ 11 | #define _COMM_PIPE_H_ 12 | 13 | #define PIPE_READ 0 14 | #define PIPE_WRITE 1 15 | 16 | typedef struct { 17 | int fd; 18 | char filename[1000]; 19 | int is_server; 20 | } Pipe; 21 | 22 | // Create pipe, if create_pipe == 0, then load the fid from an existing file. 23 | int PipeInit(const char *filename, int create_pipe, Pipe *p); 24 | 25 | // Nonblocking read/write. 
return -1 if failed, else return 0 26 | int PipeRead(Pipe *p, void *buffer, int size); 27 | int PipeWrite(Pipe *p, void *buffer, int size); 28 | 29 | void PipeClose(Pipe *p); 30 | 31 | #endif 32 | -------------------------------------------------------------------------------- /common/common.c: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 3 | // All rights reserved. 4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory. 8 | // 9 | 10 | #include "common.h" 11 | #include 12 | #include 13 | 14 | double __attribute__ ((noinline)) wallclock(void) { 15 | struct timeval t; 16 | gettimeofday(&t, NULL); 17 | return (1.0e-6*t.tv_usec + t.tv_sec); 18 | } 19 | 20 | uint64_t __attribute__ ((noinline)) wallclock64() { 21 | return (uint64_t)(wallclock() * 1e6); 22 | } 23 | 24 | void dbg_printf(const char *format, ...) { 25 | #ifdef DEBUG 26 | va_list argptr; 27 | va_start(argptr, format); 28 | printf("INFO: "); 29 | vprintf(format, argptr); 30 | va_end(argptr); 31 | printf("\n"); 32 | fflush(stdout); 33 | #endif 34 | } 35 | 36 | void error(const char *format, ...) { 37 | va_list argptr; 38 | va_start(argptr, format); 39 | printf("ERROR: "); 40 | vprintf(format, argptr); 41 | va_end(argptr); 42 | printf("\n"); 43 | fflush(stdout); 44 | // Make an easy sev. 45 | char *a = NULL; 46 | *a = 1; 47 | exit(1); 48 | } 49 | 50 | 51 | -------------------------------------------------------------------------------- /common/common.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 3 | // All rights reserved. 
4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory. 8 | // 9 | 10 | #ifndef _COMMON_H_ 11 | #define _COMMON_H_ 12 | 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | #ifdef __cplusplus 19 | extern "C" { 20 | #endif 21 | 22 | double __attribute__ ((noinline)) wallclock(void); 23 | uint64_t __attribute__ ((noinline)) wallclock64(); 24 | 25 | #ifdef __cplusplus 26 | } 27 | #endif 28 | 29 | #define S_EMPTY 0 30 | #define S_BLACK 1 31 | #define S_WHITE 2 32 | #define S_OFF_BOARD 3 33 | 34 | typedef unsigned short Coord; 35 | typedef unsigned char Stone; 36 | typedef unsigned char BOOL; 37 | //static const BOOL TRUE = 1; 38 | //static const BOOL FALSE = 0; 39 | #define TRUE 1 40 | #define FALSE 0 41 | 42 | #define STR_BOOL(s) ((s) ? "true" : "false") 43 | #define STR_STONE(s) ((s) == S_BLACK ? "B" : ((s) == S_WHITE ? 
"W" : "U")) 44 | 45 | #define timeit { \ 46 | double __start = wallclock(); \ 47 | 48 | #define endtime \ 49 | double __duration = wallclock() - __start; \ 50 | printf("Time spent = %lf\n", __duration); \ 51 | } 52 | 53 | #define endtime2(t) \ 54 | t = wallclock() - __start; \ 55 | } 56 | 57 | /* 58 | #define timeit { \ 59 | struct timespec __start, __finish; \ 60 | clock_gettime(CLOCK_MONOTONIC, &__start); 61 | 62 | #define endtime \ 63 | clock_gettime(CLOCK_MONOTONIC, &__finish); \ 64 | double __elapsed = (__finish.tv_sec - __start.tv_sec); \ 65 | __elapsed += (__finish.tv_nsec - __start.tv_nsec) / 1000000000.0; \ 66 | printf("Time spent = %f\n", __elapsed); \ 67 | } 68 | 69 | #define endtime2(t) \ 70 | clock_gettime(CLOCK_MONOTONIC, &__finish); \ 71 | double __elapsed = (__finish.tv_sec - __start.tv_sec); \ 72 | __elapsed += (__finish.tv_nsec - __start.tv_nsec) / 1000000000.0; \ 73 | t = __elapsed; \ 74 | } 75 | */ 76 | 77 | typedef unsigned int (* RandFunc)(void *context, unsigned int max_value); 78 | typedef float (* RandFuncF)(void *context, float max_value); 79 | 80 | void dbg_printf(const char *format, ...); 81 | void error(const char *format, ...); 82 | 83 | extern inline float load_atomic_float(const float *loc) { 84 | // sizeof(float) == sizeof(int) 85 | const int *p = (const int *)loc; 86 | int val = __atomic_load_n(p, __ATOMIC_ACQUIRE); 87 | void *pp1 = (void *)&val; 88 | const float *pp2 = (const float *)pp1; 89 | return *pp2; 90 | } 91 | 92 | extern inline void save_atomic_float(float v, float *loc) { 93 | // sizeof(float) == sizeof(int) 94 | int val; 95 | void *pp1 = (void *)&val; 96 | float *pp2 = (float *)pp1; 97 | *pp2 = v; 98 | __atomic_store_n((int *)loc, val, __ATOMIC_RELAXED); 99 | } 100 | 101 | extern inline void inc_atomic_float(float *loc, float inc) { 102 | int *p = (int *)loc; 103 | int val = __atomic_load_n(p, __ATOMIC_ACQUIRE); 104 | void *pp1 = (void *)&val; 105 | float *pp2 = (float *)pp1; 106 | *pp2 += inc; 107 | __atomic_store_n(p, 
val, __ATOMIC_RELAXED); 108 | } 109 | 110 | // ============================== Utility ===================================== 111 | // You need have own random generator, the official one (rand()) has built-in thread lock. 112 | extern inline uint16_t fast_random(unsigned long *pmseed, unsigned int max) { 113 | unsigned long hi, lo; 114 | lo = 16807 * (*pmseed & 0xffff); 115 | hi = 16807 * (*pmseed >> 16); 116 | lo += (hi & 0x7fff) << 16; 117 | lo += hi >> 15; 118 | *pmseed = (lo & 0x7fffffff) + (lo >> 31); 119 | return ((*pmseed & 0xffff) * max) >> 16; 120 | } 121 | 122 | // Generate a number for uint64_t. 123 | extern inline uint64_t fast_random64(uint64_t *pmseed) { 124 | uint64_t hi, lo; 125 | uint64_t v = 0; 126 | for (int i = 0; i < 4; ++i) { 127 | lo = 16807 * (*pmseed & 0xffff); 128 | hi = 16807 * (*pmseed >> 16); 129 | lo += (hi & 0x7fff) << 16; 130 | lo += hi >> 15; 131 | *pmseed = (lo & 0x7fffffff) + (lo >> 31); 132 | v <<= 16; 133 | v |= *pmseed & 0xffff; 134 | } 135 | return v; 136 | } 137 | 138 | #endif 139 | -------------------------------------------------------------------------------- /common/common.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016-present, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 
8 | -- 9 | 10 | local common = {} 11 | local ffi = require 'ffi' 12 | local utils = require 'utils.utils' 13 | require 'paths' 14 | 15 | common.res_unknown = 0 16 | common.empty = 0 17 | common.black = 1 18 | common.white = 2 19 | common.board_size = 19 20 | 21 | common.TRUE = 1 22 | common.FALSE = 0 23 | 24 | common.player_name = { [0] = 'U', [1] = 'B', [2] = 'W' } 25 | 26 | -- local symbols, s = utils.ffi_include(paths.concat(common.lib_path, "common/common.h")) 27 | function common.script_path() 28 | local str = debug.getinfo(2, "S").source:sub(2) 29 | return str:match("(.*/)") or "./" 30 | end 31 | 32 | local script_path = common.script_path() 33 | 34 | -- Codename for models and their path 35 | common.codenames = { 36 | darkforest = { 37 | model_name = paths.concat(script_path, "../models/df.bin"), 38 | feature_type = 'old' 39 | }, 40 | darkfores1 = { 41 | model_name = paths.concat(script_path, "../models/df1.bin"), 42 | feature_type = 'extended' 43 | }, 44 | darkfores2 = { 45 | model_name = paths.concat(script_path, "../models/df2.bin"), 46 | feature_type = 'extended' 47 | }, 48 | df2_cpu = { 49 | model_name = paths.concat(script_path, "../models/df2_cpu.bin"), 50 | feature_type = 'extended' 51 | }, 52 | } 53 | 54 | -- 55 | local symbols, s = utils.ffi_include(paths.concat(common.script_path(), "common.h")) 56 | local C = ffi.load(paths.concat(script_path, "../libs/libcommon.so")) 57 | 58 | function common.opponent(p) return 3 - p end 59 | function common.wallclock() return C.wallclock() end 60 | 61 | -- From move x, y (starting from 1) to Coord 62 | -- #define OFFSETXY(x, y) ( ((y) + BOARD_MARGIN) * MACRO_BOARD_EXPAND_SIZE + (x) + BOARD_MARGIN ) 63 | function common.xy2coord(x, y) 64 | -- BOARD_MARGIN = 1 65 | -- BOARD_EXPAND_SIZE = 21 66 | return x + y * 21 67 | end 68 | 69 | -- From Coord to x, y (starting from 1) 70 | function common.coord2xy(m) 71 | -- BOARD_MARGIN = 1 72 | -- BOARD_EXPAND_SIZE = 21 73 | return m % 21, math.floor(m / 21) 74 | end 75 | 76 
| return common 77 | -------------------------------------------------------------------------------- /common/package.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 3 | // All rights reserved. 4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory. 8 | // 9 | 10 | #ifndef _PACKAGE_H_ 11 | #define _PACKAGE_H_ 12 | 13 | #include 14 | #include 15 | #include 16 | #include "../common/common.h" 17 | #include "../common/comm_constant.h" 18 | #include "../board/board.h" 19 | 20 | #define NUM_FIRST_MOVES 20 21 | #define MAX_CUSTOM_DATA 500 22 | #define SIG_OK 0 23 | #define SIG_RESTART 1 24 | #define SIG_FINISHSOON 2 25 | #define SIG_NOPKG 3 26 | #define SIG_ACK 100 27 | 28 | // Several kind of messages. The first element has to be long 29 | // Message 1: board 30 | typedef struct { 31 | // sequence number. 32 | long seq; 33 | uint64_t b; 34 | // Send time (in microsecond). 35 | double t_sent; 36 | // Board configuration 37 | Board board; 38 | } MBoard; 39 | 40 | // Messsage 2: move information 41 | typedef struct { 42 | long seq; 43 | uint64_t b; 44 | 45 | // Send time, received time and reply time. 46 | double t_sent, t_received, t_replied; 47 | char hostname[30]; 48 | 49 | Stone player; 50 | BOOL error; 51 | char xs[NUM_FIRST_MOVES]; 52 | char ys[NUM_FIRST_MOVES]; 53 | float probs[NUM_FIRST_MOVES]; 54 | // Use for types of moves, can be MOVE_SIMPLE_KO or MOVE_NORMAL 55 | char types[NUM_FIRST_MOVES]; 56 | 57 | // Custom data. E.g., feature for the current board. 58 | char extra[MAX_CUSTOM_DATA]; 59 | 60 | // The board hash for the board used. 61 | // uint64_t board_hash; 62 | 63 | // Score if there is any prediction. 
64 | BOOL has_score; 65 | float score; 66 | } MMove; 67 | 68 | #endif 69 | -------------------------------------------------------------------------------- /compile.sh: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2016-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. An additional grant 7 | # of patent rights can be found in the PATENTS file in the same directory. 8 | # 9 | 10 | #!/bin/bash 11 | 12 | CPP_FLAGS="-O4 -fPIC -std=c++11 -I/usr/include/malloc/ -Wl,-export-dynamic" 13 | CXX=g++ 14 | 15 | echo Compiling 16 | $CXX $CPP_FLAGS -I./common -c common/common.c common/comm.c common/comm_pipe.c 17 | $CXX $CPP_FLAGS -I./common -I./board -c board/board.c board/default_policy.c board/default_policy_common.c board/pattern.c board/pattern_v2.c board/ownermap.c board/sample_pattern_v2.c 18 | $CXX $CPP_FLAGS -I./common -I./board -c tsumego/rank_move.c 19 | 20 | $CXX $CPP_FLAGS -fpermissive -I./common -I./board -c pachi_tactics/moggy.c pachi_tactics/board_interface.c 21 | $CXX $CPP_FLAGS -fpermissive -I./common -I./board -c pachi_tactics/tactics/1lib.c pachi_tactics/tactics/2lib.c pachi_tactics/tactics/ladder.c pachi_tactics/tactics/nakade.c pachi_tactics/tactics/nlib.c pachi_tactics/tactics/selfatari.c 22 | 23 | echo Create moggy 24 | $CXX -shared -o libmoggy.so moggy.o board.o board_interface.o 1lib.o 2lib.o ladder.o nakade.o nlib.o selfatari.o pattern.o 25 | 26 | $CXX $CPP_FLAGS -I./common -I./board -I./mctsv2 -c mctsv2/tree.c mctsv2/playout_multithread.c mctsv2/playout_callbacks.c mctsv2/event_count.cpp mctsv2/tree_search.c 27 | $CXX $CPP_FLAGS -I./common -c ./local_evaluator/cnn_local_exchanger.c 28 | 29 | echo Create libboard and libcomm 30 | $CXX -shared -Wl,-export-dynamic -o libcommon.so common.o 31 | $CXX -shared -Wl,-export-dynamic -o libboard.so board.o 
common.o 32 | $CXX -shared -Wl,-export-dynamic -o libdefault_policy.so default_policy_common.o default_policy.o board.o common.o pattern.o pattern_v2.o 33 | $CXX -shared -Wl,-export-dynamic -o libownermap.so board.o common.o ownermap.o 34 | $CXX -shared -Wl,-export-dynamic -o libpattern_v2.so pattern_v2.o board.o common.o ownermap.o 35 | $CXX -shared -Wl,-export-dynamic -o libcomm.so comm.o 36 | 37 | echo Create libplayout_multithread.so 38 | $CXX -shared -o libplayout_multithread.so tree.o playout_multithread.o board.o tree_search.o playout_callbacks.o common.o cnn_local_exchanger.o comm_pipe.o default_policy.o pattern.o pattern_v2.o default_policy_common.o rank_move.o event_count.o moggy.o board_interface.o 1lib.o 2lib.o ladder.o nakade.o nlib.o selfatari.o -lm 39 | 40 | echo Create liblocalexchanger.so 41 | $CXX -shared -o liblocalexchanger.so comm_pipe.o cnn_local_exchanger.o board.o common.o -lm 42 | 43 | echo Compile all test codes 44 | $CXX $CPP_FLAGS -lm -pthread mctsv2/test_playout_multithread.c tree.o playout_multithread.o board.o common.o playout_callbacks.o comm_pipe.o event_count.o tree_search.o cnn_local_exchanger.o default_policy.o default_policy_common.o pattern.o pattern_v2.o rank_move.o moggy.o board_interface.o 1lib.o 2lib.o ladder.o nakade.o nlib.o selfatari.o -I./common -I./board -o test_playout_multithread 45 | 46 | echo Put all .so file into directory so that lua could load 47 | DEST_DIR=./libs 48 | 49 | cp libboard.so $DEST_DIR 50 | cp libownermap.so $DEST_DIR 51 | cp libpattern_v2.so $DEST_DIR 52 | cp libdefault_policy.so $DEST_DIR 53 | cp libcomm.so $DEST_DIR 54 | cp libcommon.so $DEST_DIR 55 | cp libplayout_multithread.so $DEST_DIR 56 | cp libmoggy.so $DEST_DIR 57 | cp liblocalexchanger.so $DEST_DIR 58 | -------------------------------------------------------------------------------- /figure.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/facebookresearch/darkforestGo/ef1885ed5004dac8cbea2cbd3644706565af0876/figure.png -------------------------------------------------------------------------------- /libs/README.md: -------------------------------------------------------------------------------- 1 | All the .so files will be saved here after compilation. 2 | -------------------------------------------------------------------------------- /local_evaluator/cnn_evaluator.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016-present, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | 10 | package.path = package.path .. ';../?.lua' 11 | 12 | local pl = require 'pl.import_into'() 13 | local utils = require('utils.utils') 14 | 15 | utils.require_torch() 16 | utils.require_cutorch() 17 | 18 | local ffi = require 'ffi' 19 | local utils = require('utils.utils') 20 | local common = require("common.common") 21 | local threads = require 'threads' 22 | threads.serialization('threads.sharedserialize') 23 | 24 | --local ctrl_restart = tonumber(symbols.CTRL_RESTART) 25 | --local ctrl_remove = tonumber(symbols.CTRL_REMOVE) 26 | --local ex = ffi.new("Exchanger") 27 | --C.Init(ex, true, true) 28 | local opt = pl.lapp[[ 29 | -g,--gpu (default 1) GPU id to use. 30 | --async Make it asynchronized. 31 | --pipe_path (default "./") Path for pipe file. Default is in the current directory, i.e., go/mcts 32 | --codename (default "darkfores2") Code name for the model to load. 33 | --use_local_model If true, load the local model. 34 | ]] 35 | 36 | print("GPU used: " .. opt.gpu) 37 | opt_internal = opt 38 | 39 | -- Start 4 GPUs. 
40 | -- pool = threads.Threads(#gpus, function () end, function (idx) gpu = tonumber(gpus[idx]) end) 41 | -- for i = 1, #gpus do 42 | -- pool:addjob(function () paths.dofile("cnn_evaluator_run1.lua") end) 43 | -- end 44 | -- pool:synchronize() 45 | paths.dofile("cnn_evaluator_run1.lua") 46 | 47 | 48 | -------------------------------------------------------------------------------- /local_evaluator/cnn_evaluator.sh: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2016-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. An additional grant 7 | # of patent rights can be found in the PATENTS file in the same directory. 8 | # 9 | 10 | #!/bin/bash 11 | 12 | NUM_GPU=$1 13 | OUTPUT_PATH=$2 14 | OTHER_OPTS=$3 15 | 16 | echo num of gpu used = $NUM_GPU 17 | echo other parameters = $OTHER_OPTS 18 | echo output path = $OUTPUT_PATH 19 | 20 | OUTPUT=$OUTPUT_PATH/cnn_eval 21 | for i in `seq 1 $NUM_GPU`; do 22 | echo "" > $OUTPUT-${i}.log 23 | th cnn_evaluator.lua -g $i $OTHER_OPTS --pipe_path $OUTPUT_PATH >> $OUTPUT-${i}.log 2>&1 & 24 | echo $! 25 | done 26 | 27 | # Wait until they are ready. 28 | for i in `seq 1 $NUM_GPU`; do 29 | while true; do 30 | if grep -q "ready" $OUTPUT-${i}.log; then 31 | break 32 | fi 33 | sleep 1 34 | done 35 | done 36 | 37 | -------------------------------------------------------------------------------- /local_evaluator/cnn_evaluator_run1.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016-present, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 
8 | -- 9 | 10 | local pl = require 'pl.import_into'() 11 | local utils = require('utils.utils') 12 | 13 | utils.require_torch() 14 | utils.require_cutorch() 15 | 16 | local ffi = require 'ffi' 17 | local utils = require('utils.utils') 18 | local board = require('board.board') 19 | local common = require("common.common") 20 | local util_pkg = require 'common.util_package' 21 | 22 | -- local symbols, s = utils.ffi_include(paths.concat(common.lib_path, "local_evaluator/cnn_local_exchanger.h")) 23 | local script_path = common.script_path() 24 | local symbols, s = utils.ffi_include(paths.concat(script_path, "cnn_local_exchanger.h")) 25 | local C = ffi.load(paths.concat(script_path, "../libs/liblocalexchanger.so")) 26 | 27 | local sig_ok = tonumber(symbols.SIG_OK) 28 | local max_batch = opt_internal.async and 128 or 32 29 | 30 | -- number of attempt before wait_board gave up and return nil. 31 | -- previously this number is indefinite, i.e., wait until there is a package (which might cause deadlock). 32 | local num_attempt = 10 33 | 34 | cutorch.setDevice(opt_internal.gpu) 35 | local model_filename = common.codenames[opt_internal.codename].model_name 36 | local feature_type = common.codenames[opt_internal.codename].feature_type 37 | assert(model_filename, "opt.codename [" .. opt_internal.codename .. "] not found!") 38 | 39 | if opt_internal.use_local_model then 40 | model_filename = pl.path.basename(model_filename) 41 | end 42 | print("Loading model = " .. model_filename) 43 | local model = torch.load(model_filename) 44 | print("Loading complete") 45 | 46 | -- Server side. 47 | local ex = C.ExLocalInit(opt_internal.pipe_path, opt_internal.gpu - 1, common.TRUE) 48 | print("CNN Exchanger initialized.") 49 | print("Size of MBoard: " .. ffi.sizeof('MBoard')) 50 | print("Size of MMove: " .. 
ffi.sizeof('MMove')) 51 | board.print_info() 52 | 53 | -- [board_idx, received time] 54 | local block_ids = torch.DoubleTensor(max_batch) 55 | local sortProb = torch.FloatTensor(max_batch, common.board_size * common.board_size) 56 | local sortInd = torch.FloatTensor(max_batch, common.board_size * common.board_size) 57 | 58 | util_pkg.init(max_batch, feature_type) 59 | -- util_pkg.dbg_set() 60 | 61 | print("ready") 62 | io.flush() 63 | 64 | -- Preallocate the cuda tensors. 65 | local probs_cuda, sortProb_cuda, sortInd_cuda 66 | 67 | -- Feature for the batch. 68 | local all_features 69 | while true do 70 | -- Get data 71 | block_ids:zero() 72 | if all_features then 73 | all_features:zero() 74 | end 75 | 76 | local num_valid = 0 77 | 78 | -- Start the cycle. 79 | -- local start = common.wallclock() 80 | for i = 1, max_batch do 81 | local mboard = util_pkg.boards[i - 1] 82 | -- require 'fb.debugger'.enter() 83 | local ret = C.ExLocalServerGetBoard(ex, mboard, num_attempt) 84 | -- require 'fb.debugger'.enter() 85 | if ret == sig_ok and mboard.seq ~= 0 and mboard.b ~= 0 then 86 | local feature = util_pkg.extract_board_feature(i) 87 | if feature ~= nil then -- NOTE(review): a board whose feature is nil is never answered below; confirm clients do not wait on it. 88 | local nplane, h, w = unpack(feature:size():totable()) 89 | if all_features == nil then 90 | all_features = torch.CudaTensor(max_batch, nplane, h, w):zero() 91 | probs_cuda = torch.CudaTensor(max_batch, h*w) 92 | sortProb_cuda = torch.CudaTensor(max_batch, h*w) 93 | sortInd_cuda = torch.CudaLongTensor(max_batch, h*w) 94 | end 95 | num_valid = num_valid + 1 96 | all_features[num_valid]:copy(feature) 97 | block_ids[num_valid] = i 98 | end 99 | end 100 | end 101 | -- print(string.format("Collect data = %f", common.wallclock() - start)) 102 | -- Now all data are ready, run the model.
103 | if C.ExLocalServerIsRestarting(ex) == common.FALSE and all_features ~= nil and num_valid > 0 then 104 | print(string.format("Valid sample = %d / %d", num_valid, max_batch)) 105 | util_pkg.dprint("Start evaluation...") 106 | local start = common.wallclock() 107 | local output = model:forward(all_features:sub(1, num_valid)) 108 | local territory 109 | util_pkg.dprint("End evaluation...") 110 | -- If the output is multitask, only take the first one. 111 | -- require 'fb.debugger'.enter() 112 | if type(output) == 'table' then 113 | -- Territory 114 | -- require 'fb.debugger'.enter() 115 | if #output == 4 then 116 | territory = output[4] 117 | end 118 | output = output[1] 119 | end 120 | 121 | local probs_cuda_sel = probs_cuda:sub(1, num_valid) 122 | local sortProb_cuda_sel = sortProb_cuda:sub(1, num_valid) 123 | local sortInd_cuda_sel = sortInd_cuda:sub(1, num_valid) 124 | 125 | torch.exp(probs_cuda_sel, output:view(num_valid, -1)) 126 | torch.sort(sortProb_cuda_sel, sortInd_cuda_sel, probs_cuda_sel, 2, true) 127 | -- local sortProb_cuda = torch.CudaTensor(num_valid, 19*19):fill(0.5) 128 | -- local sortInd_cuda = torch.CudaTensor(num_valid, 19*19):fill(23) 129 | 130 | sortProb:sub(1, num_valid):copy(sortProb_cuda_sel) 131 | sortInd:sub(1, num_valid):copy(sortInd_cuda_sel) 132 | 133 | local score 134 | if territory then 135 | -- Compute score: count, per sample, the points where territory plane 1 >= plane 2 (the > 0.6 threshold variant below is disabled). 136 | -- score = territory:ge(0.6):sum(3):float() 137 | local diff = territory[{{}, {1}, {}}] - territory[{{}, {2}, {}}] 138 | score = diff:ge(0):typeAs(diff):sum(3) -- ge() may return a byte mask (newer cutorch); summing in diff's float type keeps counts above 255 from wrapping. 139 | -- score = territory:sum(3):float() 140 | end 141 | -- sortProb:copy(sortProb_cuda[{{}, {1, num_first_move}}]) 142 | -- sortInd:copy(sortInd_cuda[{{}, {1, num_first_move}}]) 143 | print(string.format("Computation = %f", common.wallclock() - start)) 144 | 145 | local send_start = common.wallclock() 146 | -- Send them back.
147 | for k = 1, num_valid do 148 | local mmove = util_pkg.prepare_move(block_ids[k], sortProb[k], sortInd[k], score and score[k]) 149 | util_pkg.dprint("Actually send move") 150 | C.ExLocalServerSendMove(ex, mmove) 151 | util_pkg.dprint("After send move") 152 | end 153 | print(string.format("Send back = %f", common.wallclock() - send_start)) 154 | 155 | end 156 | 157 | util_pkg.sparse_gc() 158 | 159 | -- Send control message if necessary. 160 | if C.ExLocalServerSendAckIfNecessary(ex) == common.TRUE then 161 | print("Ack signal sent!") 162 | end 163 | end 164 | 165 | C.ExLocalDestroy(ex) 166 | -------------------------------------------------------------------------------- /local_evaluator/cnn_exchanger.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 3 | // All rights reserved. 4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory. 8 | // 9 | 10 | #ifndef _CNN_EXCHANGER_ 11 | #define _CNN_EXCHANGER_ 12 | 13 | #include 14 | #include "../common/comm_pipe.h" 15 | #include "../common/common.h" 16 | 17 | // Dummy functions for distributed version.
18 | // NOTE: these stubs are full (non-inline, non-static) definitions, so this header must be included by only one translation unit or the link will see duplicate symbols. 19 | void *ExClientInit(const char tier_name[100]) { return NULL; } 20 | void ExClientDestroy(void *ctx) { } 21 | 22 | // Set the maximum wait count; this stub ignores n and always returns 0. 23 | int ExClientSetMaxWaitCount(void *ctx, int n) { return 0; } 24 | 25 | // Send board (not blocked) 26 | BOOL ExClientSendBoard(void *ctx, MBoard *board) { return TRUE; } 27 | 28 | // Receive move (not blocked) 29 | BOOL ExClientGetMove(void *ctx, MMove *move) { return TRUE; } 30 | 31 | // Send restart signal (in block mode) once the search is over 32 | BOOL ExClientSendRestart(void *ctx) { return TRUE; } 33 | 34 | BOOL ExClientIncWaitCount(void *ctx, BOOL send_if_needed) { return TRUE; } 35 | 36 | BOOL ExClientDecWaitCount(void *ctx) { return TRUE; } 37 | 38 | BOOL ExClientSendFinishSoon(void *ctx) { return TRUE; } 39 | 40 | // Blocked wait until ack is received. 41 | BOOL ExClientWaitAck(void *ctx) { return TRUE; } 42 | 43 | void ExClientStopReceivers(void *ctx) { } 44 | 45 | #endif 46 | -------------------------------------------------------------------------------- /local_evaluator/cnn_local_exchanger.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 3 | // All rights reserved. 4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory. 8 | // 9 | 10 | #ifndef _CNN_LOCAL_EXCHANGER_H_ 11 | #define _CNN_LOCAL_EXCHANGER_H_ 12 | 13 | #include "../common/package.h" 14 | 15 | #ifdef __cplusplus 16 | extern "C" { 17 | #endif 18 | 19 | // Init exchanger. 20 | // pipe_path: the path of the pipe 21 | // id: the id of the pipe. 22 | // is_server: whether this opened pipe is a server. 
23 | void *ExLocalInit(const char *pipe_path, int id, BOOL is_server); 24 | void ExLocalDestroy(void *ctx); 25 | 26 | // Server side, three cases 27 | // 1.
Block on message with exit value = SIG_OK on a new board. 28 | // 2. Return immediately with exit value = SIG_RESTART 29 | // 3. Return immediately with exit value = SIG_HIGH_PR 30 | // If num_attempt == 0, then try indefinitely, otherwise try num_attempt times. 31 | int ExLocalServerGetBoard(void *ctx, MBoard *board, int num_attempt); 32 | // Blocking send of the moves, once the CNN finishes evaluation. 33 | // If done is set, don't send anything. 34 | BOOL ExLocalServerSendMove(void *ctx, MMove *move); 35 | // Send ack for any unusual signal received. 36 | BOOL ExLocalServerSendAckIfNecessary(void *ctx); 37 | // Check whether the server is restarting. 38 | BOOL ExLocalServerIsRestarting(void *ctx); 39 | 40 | // Client side 41 | // Set the maximum wait count. Return the previous maximum. 42 | int ExLocalClientSetMaxWaitCount(void *ctx, int n); 43 | // Send board (not blocked) 44 | BOOL ExLocalClientSendBoard(void *ctx, MBoard *board); 45 | // Receive move (not blocked) 46 | BOOL ExLocalClientGetMove(void *ctx, MMove *move); 47 | // Increment the wait count. If the count is >= wait_count_max (set by ExLocalClientSetMaxWaitCount) and send_if_needed is true, 48 | // then send SIG_FINISHSOON. 49 | // This means that n threads are already waiting on the results, so please respond soon. 50 | // Return TRUE if we have sent SIG_FINISHSOON, return FALSE if not. 51 | BOOL ExLocalClientIncWaitCount(void *ctx, BOOL send_if_needed); 52 | // Decrease the wait count. 53 | // Return TRUE if we have done the operation, FALSE if the count is < 0 (this should be treated as an error). 54 | BOOL ExLocalClientDecWaitCount(void *ctx); 55 | 56 | // Send restart signal (in block mode) once the search is over 57 | BOOL ExLocalClientSendRestart(void *ctx); 58 | // Send finish soon signal. Server will evaluate all currently existing 59 | // board situations and then return. It won't wait for a missing board indefinitely. 60 | BOOL ExLocalClientSendFinishSoon(void *ctx); 61 | // Blocking wait until the ack is received.
62 | BOOL ExLocalClientWaitAck(void *ctx); 63 | 64 | #ifdef __cplusplus 65 | } 66 | #endif 67 | 68 | #endif 69 | -------------------------------------------------------------------------------- /local_evaluator/kill_evaluator.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright (c) 2016-present, Facebook, Inc. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. An additional grant 8 | # of patent rights can be found in the PATENTS file in the same directory. 9 | # 10 | 11 | # Send SIGKILL (-9) to every process named cnn_local_evaluator_lua_main. 12 | killall -9 cnn_local_evaluator_lua_main 13 | 14 | -------------------------------------------------------------------------------- /mctsv2/FollyEventCount.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 3 | // All rights reserved. 4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory. 8 | // 9 | // @author Tudor Bosman (tudorb@fb.com) 10 | 11 | #pragma once 12 | 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | 23 | 24 | namespace _folly { 25 | 26 | namespace detail { 27 | 28 | inline int futex(int* uaddr, int op, int val, const timespec* timeout, 29 | int* uaddr2, int val3) noexcept { 30 | return syscall(SYS_futex, uaddr, op, val, timeout, uaddr2, val3); 31 | } 32 | 33 | } // namespace detail 34 | 35 | /** 36 | * Event count: a condition variable for lock free algorithms. 37 | * 38 | * See http://www.1024cores.net/home/lock-free-algorithms/eventcounts for 39 | * details.
40 | * 41 | * Event counts allow you to convert a non-blocking lock-free / wait-free 42 | * algorithm into a blocking one, by isolating the blocking logic. You call 43 | * prepareWait() before checking your condition and then either cancelWait() 44 | * or wait() depending on whether the condition was true. When another 45 | * thread makes the condition true, it must call notify() / notifyAll() just 46 | * like a regular condition variable. 47 | * 48 | * If "<" denotes the happens-before relationship, consider 2 threads (T1 and 49 | * T2) and 3 events: 50 | * - E1: T1 returns from prepareWait 51 | * - E2: T1 calls wait 52 | * (obviously E1 < E2, intra-thread) 53 | * - E3: T2 calls notifyAll 54 | * 55 | * If E1 < E3, then E2's wait will complete (and T1 will either wake up, 56 | * or not block at all) 57 | * 58 | * This means that you can use an EventCount in the following manner: 59 | * 60 | * Waiter: 61 | * if (!condition()) { // handle fast path first 62 | * for (;;) { 63 | * auto key = eventCount.prepareWait(); 64 | * if (condition()) { 65 | * eventCount.cancelWait(); 66 | * break; 67 | * } else { 68 | * eventCount.wait(key); 69 | * } 70 | * } 71 | * } 72 | * 73 | * (This pattern is encapsulated in await()) 74 | * 75 | * Poster: 76 | * make_condition_true(); 77 | * eventCount.notifyAll(); 78 | * 79 | * Note that, just like with regular condition variables, the waiter needs to 80 | * be tolerant of spurious wakeups and needs to recheck the condition after 81 | * being woken up. Also, as there is no mutual exclusion implied, "checking" 82 | * the condition likely means attempting an operation on an underlying 83 | * data structure (push into a lock-free queue, etc) and returning true on 84 | * success and false on failure. 
85 | */ 86 | class EventCount { 87 | public: 88 | EventCount() noexcept : val_(0) { } 89 | 90 | class Key { 91 | friend class EventCount; 92 | explicit Key(uint32_t e) noexcept : epoch_(e) { } 93 | uint32_t epoch_; 94 | }; 95 | 96 | void notify() noexcept; 97 | void notifyAll() noexcept; 98 | Key prepareWait() noexcept; 99 | void cancelWait() noexcept; 100 | void wait(Key key) noexcept; 101 | 102 | /** 103 | * Wait for condition() to become true. Will clean up appropriately if 104 | * condition() throws, and then rethrow. 105 | */ 106 | template 107 | void await(Condition condition); 108 | 109 | private: 110 | void doNotify(int n) noexcept; 111 | EventCount(const EventCount&) = delete; 112 | EventCount(EventCount&&) = delete; 113 | EventCount& operator=(const EventCount&) = delete; 114 | EventCount& operator=(EventCount&&) = delete; 115 | 116 | // This requires 64-bit 117 | static_assert(sizeof(int) == 4, "bad platform"); 118 | static_assert(sizeof(uint32_t) == 4, "bad platform"); 119 | static_assert(sizeof(uint64_t) == 8, "bad platform"); 120 | 121 | #if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ 122 | static constexpr size_t kEpochOffset = 1; 123 | #elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ 124 | static constexpr size_t kEpochOffset = 0; // in units of sizeof(int) 125 | #else 126 | # error Your machine uses a weird endianness! 127 | #endif 128 | 129 | // val_ stores the epoch in the most significant 32 bits and the 130 | // waiter count in the least significant 32 bits. 
131 | std::atomic val_; 132 | 133 | static constexpr uint64_t kAddWaiter = uint64_t(1); 134 | static constexpr uint64_t kSubWaiter = uint64_t(-1); 135 | static constexpr size_t kEpochShift = 32; 136 | static constexpr uint64_t kAddEpoch = uint64_t(1) << kEpochShift; 137 | static constexpr uint64_t kWaiterMask = kAddEpoch - 1; 138 | }; 139 | 140 | inline void EventCount::notify() noexcept { 141 | doNotify(1); 142 | } 143 | 144 | inline void EventCount::notifyAll() noexcept { 145 | doNotify(INT_MAX); 146 | } 147 | 148 | inline void EventCount::doNotify(int n) noexcept { 149 | uint64_t prev = val_.fetch_add(kAddEpoch, std::memory_order_acq_rel); 150 | if (prev & kWaiterMask) { 151 | detail::futex(reinterpret_cast(&val_) + kEpochOffset, 152 | FUTEX_WAKE, n, nullptr, nullptr, 0); 153 | } 154 | } 155 | 156 | inline EventCount::Key EventCount::prepareWait() noexcept { 157 | uint64_t prev = val_.fetch_add(kAddWaiter, std::memory_order_acq_rel); 158 | return Key(prev >> kEpochShift); 159 | } 160 | 161 | inline void EventCount::cancelWait() noexcept { 162 | // memory_order_relaxed would suffice for correctness, but the faster 163 | // #waiters gets to 0, the less likely it is that we'll do spurious wakeups 164 | // (and thus system calls). 
165 | uint64_t prev = val_.fetch_add(kSubWaiter, std::memory_order_seq_cst); 166 | assert((prev & kWaiterMask) != 0); 167 | } 168 | 169 | inline void EventCount::wait(Key key) noexcept { 170 | while ((val_.load(std::memory_order_acquire) >> kEpochShift) == key.epoch_) { 171 | detail::futex(reinterpret_cast(&val_) + kEpochOffset, 172 | FUTEX_WAIT, key.epoch_, nullptr, nullptr, 0); 173 | } 174 | // memory_order_relaxed would suffice for correctness, but the faster 175 | // #waiters gets to 0, the less likely it is that we'll do spurious wakeups 176 | // (and thus system calls) 177 | uint64_t prev = val_.fetch_add(kSubWaiter, std::memory_order_seq_cst); 178 | assert((prev & kWaiterMask) != 0); 179 | } 180 | 181 | template 182 | void EventCount::await(Condition condition) { 183 | if (condition()) return; // fast path 184 | 185 | // condition() is the only thing that may throw, everything else is 186 | // noexcept, so we can hoist the try/catch block outside of the loop 187 | try { 188 | for (;;) { 189 | auto key = prepareWait(); 190 | if (condition()) { 191 | cancelWait(); 192 | break; 193 | } else { 194 | wait(key); 195 | } 196 | } 197 | } catch (...) { 198 | cancelWait(); 199 | throw; 200 | } 201 | } 202 | 203 | } // namespace _folly 204 | -------------------------------------------------------------------------------- /mctsv2/event_count.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 3 | // All rights reserved. 4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory. 
8 | // 9 | // @author Tudor Bosman (tudorb@fb.com) 10 | 11 | #include "event_count.h" 12 | 13 | #include 14 | #include "FollyEventCount.h" 15 | 16 | static_assert(sizeof(EventCount) == sizeof(_folly::EventCount), 17 | "EventCount size mismatch"); 18 | 19 | static_assert(alignof(EventCount) == alignof(_folly::EventCount), 20 | "EventCount alignment mismatch"); 21 | 22 | static_assert(std::is_standard_layout<_folly::EventCount>::value, 23 | "_folly::EventCount must be standard layout"); 24 | 25 | 26 | static_assert(sizeof(EventCountKey) == sizeof(_folly::EventCount::Key), 27 | "EventCountKey size mismatch"); 28 | 29 | static_assert(alignof(EventCountKey) == alignof(_folly::EventCount::Key), 30 | "EventCountKey alignment mismatch"); 31 | 32 | static_assert(std::is_standard_layout<_folly::EventCount::Key>::value, 33 | "_folly::EventCount::Key must be standard layout"); 34 | 35 | namespace { 36 | 37 | inline _folly::EventCount* EV(EventCount* ev) { 38 | return reinterpret_cast<_folly::EventCount*>(reinterpret_cast(ev)); 39 | } 40 | 41 | inline _folly::EventCount::Key* EVK(EventCountKey* key) { 42 | return reinterpret_cast<_folly::EventCount::Key*>( 43 | reinterpret_cast(key)); 44 | } 45 | 46 | } // namespace 47 | 48 | void event_count_init(EventCount* ev) { 49 | new (ev) _folly::EventCount(); 50 | } 51 | 52 | void event_count_destroy(EventCount* ev) { 53 | EV(ev)->~EventCount(); 54 | } 55 | 56 | void event_count_notify(EventCount* ev) { 57 | EV(ev)->notify(); 58 | } 59 | 60 | void event_count_broadcast(EventCount* ev) { 61 | EV(ev)->notifyAll(); 62 | } 63 | 64 | EventCountKey event_count_prepare(EventCount* ev) { 65 | auto key = EV(ev)->prepareWait(); 66 | EventCountKey ek; 67 | memcpy(&ek, &key, sizeof(ek)); 68 | return ek; 69 | } 70 | 71 | void event_count_cancel(EventCount* ev) { 72 | EV(ev)->cancelWait(); 73 | } 74 | 75 | void event_count_wait(EventCount* ev, EventCountKey key) { 76 | EV(ev)->wait(*EVK(&key)); 77 | } 78 | 79 | void event_count_await(EventCount* ev, 
EventCountWaitCallback cb, void* ctx) { 80 | EV(ev)->await([cb, ctx] () { return cb(ctx); }); 81 | } 82 | -------------------------------------------------------------------------------- /mctsv2/event_count.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 3 | // All rights reserved. 4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory. 8 | // 9 | // @author Tudor Bosman (tudorb@fb.com) 10 | 11 | #pragma once 12 | 13 | #include 14 | 15 | #ifdef __cplusplus 16 | extern "C" { 17 | #endif 18 | 19 | typedef struct _EventCount { 20 | uint64_t _x; 21 | } EventCount; 22 | 23 | typedef struct _EventCountKey { 24 | uint32_t _x; 25 | } EventCountKey; 26 | 27 | void event_count_init(EventCount* ev); 28 | void event_count_destroy(EventCount* ev); 29 | 30 | void event_count_notify(EventCount* ev); 31 | void event_count_broadcast(EventCount* ev); 32 | 33 | EventCountKey event_count_prepare(EventCount* ev); 34 | void event_count_cancel(EventCount* ev); 35 | void event_count_wait(EventCount* ev, EventCountKey key); 36 | 37 | // Wait until cb(ctx) returns non-zero 38 | typedef int (*EventCountWaitCallback)(void*); 39 | void event_count_await(EventCount* ev, EventCountWaitCallback cb, 40 | void* ctx); 41 | 42 | #ifdef __cplusplus 43 | } // extern "C" 44 | #endif 45 | -------------------------------------------------------------------------------- /mctsv2/playout_callbacks.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 3 | // All rights reserved. 4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. 
An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory. 8 | // 9 | 10 | #ifndef _PLAYOUT_CALLBACKS_H_ 11 | #define _PLAYOUT_CALLBACKS_H_ 12 | 13 | #include "tree_search_internal.h" 14 | 15 | // Normal mode. 16 | BOOL cnn_policy(ThreadInfo *info, TreeBlock *bl, const Board *board, BlockOffset *offset, TreeBlock **child_chosen); 17 | void threaded_run_bp(ThreadInfo *info, float black_moku, Stone next_player, int end_ply, BOOL board_on_child, BlockOffset child_offset, TreeBlock *b); 18 | float threaded_compute_score(ThreadInfo *info, const Board *board); 19 | BOOL dcnn_leaf_expansion(ThreadInfo *info, const Board *board, TreeBlock *b); 20 | 21 | BOOL async_policy(ThreadInfo *info, TreeBlock *bl, const Board *board, BlockOffset *offset, TreeBlock **child_chosen); 22 | 23 | // Def policy using fast rollout. 24 | DefPolicyMove fast_rollout_def_policy(void *def_policy, void *context, RandFunc rand_func, Board* board, const Region *r, int max_depth, BOOL verbose); 25 | 26 | // Tsumego mode. 27 | BOOL ld_policy(ThreadInfo *info, TreeBlock *bl, const Board *board, BlockOffset *offset, TreeBlock **child_chosen); 28 | void threaded_run_tsumego_bp(ThreadInfo *info, float black_moku, Stone next_player, int end_ply, BOOL board_on_child, BlockOffset child_offset, TreeBlock *b); 29 | BOOL tsumego_dcnn_leaf_expansion(ThreadInfo *info, const Board *board, TreeBlock *bl); 30 | BOOL tsumego_rule_leaf_expansion(ThreadInfo *info, const Board *board, TreeBlock *b); 31 | 32 | #endif 33 | -------------------------------------------------------------------------------- /mctsv2/playout_common.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 3 | // All rights reserved. 4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. 
An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory. 8 | // 9 | 10 | #ifndef _PLAYOUT_COMMON_H_ 11 | #define _PLAYOUT_COMMON_H_ 12 | 13 | #include "../board/board.h" 14 | 15 | #define PRINT_INFO(...) do { if (s->params.verbose >= V_INFO) { printf(__VA_ARGS__); fflush(stdout); } } while(0) 16 | #define PRINT_CRITICAL(...) do { if (s->params.verbose >= V_CRITICAL) { printf(__VA_ARGS__); fflush(stdout); } } while(0) 17 | #define PRINT_DEBUG(...) do { if (s->params.verbose >= V_DEBUG) { printf(__VA_ARGS__); fflush(stdout); } } while(0) 18 | 19 | // Define Move so that we could communicate between LUA and C. 20 | typedef struct { 21 | int x; 22 | int y; 23 | Coord m; 24 | Stone player; 25 | float win_rate; 26 | float win_games; 27 | int total_games; 28 | } Move; 29 | 30 | typedef struct { 31 | Move moves[MACRO_BOARD_SIZE*MACRO_BOARD_SIZE]; 32 | int num_moves; 33 | } Moves; 34 | 35 | #endif 36 | -------------------------------------------------------------------------------- /mctsv2/playout_multithread.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 3 | // All rights reserved. 4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory. 8 | // 9 | 10 | #ifndef _PLAYOUT_MULTITHREAD_H_ 11 | #define _PLAYOUT_MULTITHREAD_H_ 12 | 13 | #include "playout_params.h" 14 | #include "playout_common.h" 15 | #include "tree_search.h" 16 | 17 | #ifdef __cplusplus 18 | extern "C" { 19 | #endif 20 | 21 | // Compose the move from x, y and player. 
22 | Move compose_move(int x, int y, Stone player); 23 | Move compose_move2(Coord m, Stone player); 24 | 25 | void ts_v2_print_params(void *ctx); 26 | void ts_v2_init_params(SearchParamsV2 *params); 27 | 28 | // Initialize the tree search, return a search handle (void *). 29 | void* ts_v2_init(const SearchParamsV2 *params, const TreeParams *tree_params, const Board *board); 30 | 31 | // Set the board, also can be used to restart the game. 32 | void ts_v2_setboard(void *ctx, const Board *init_board); 33 | void ts_v2_add_move_history(void *ctx, Coord m, Stone player, BOOL actual_play); 34 | 35 | // Change the parameters on the fly. Be extra careful about this function. 36 | // Depending on different setting, we might need to clear up different internal status. 37 | BOOL ts_v2_set_params(void *ctx, const SearchParamsV2 *new_params, const TreeParams *new_tree_params); 38 | 39 | // Set time left (in second). 40 | // Note that this will not block all threads. 41 | BOOL ts_v2_set_time_left(void *ctx, unsigned int time_left, unsigned int num_moves); 42 | 43 | // Perform tree search given the current board and player. 44 | void ts_v2_search_start(void *h); 45 | void ts_v2_search_stop(void *h); 46 | 47 | // Turn on all threads. Return true if succeed. 48 | // It will not clear all statistics. 49 | void ts_v2_thread_on(void *h); 50 | void ts_v2_thread_off(void *h); 51 | 52 | // Return the best move as a result of the current search. 53 | // Move_seq must not be NULL and will store the move sequence (if l&d mode is on). 54 | Move ts_v2_pick_best(void *h, AllMoves *move_seq, const Board *verify_board); 55 | 56 | // Peek the topk move and save it to moves. 57 | void ts_v2_peek(void *h, int topk, Moves *moves, const Board *verify_board); 58 | 59 | // Output the current tree to a json file. 60 | void ts_v2_tree_to_json(void *h, const char *json_prefix); 61 | 62 | // Output the feature to a text file, one feature a line. L&D mode only. 
63 | void ts_v2_tree_to_feature(void *ctx, const char *feature_prefix); 64 | 65 | // Once a move is picked, prune the tree accordingly. 66 | void ts_v2_prune_opponent(void *ctx, Coord m); 67 | void ts_v2_prune_ours(void *ctx, Coord m); 68 | 69 | // Undo the most recent pass. If no pass is the recent pass, do nothing. 70 | int ts_v2_undo_pass(void *h, const Board *before_board); 71 | 72 | // Free tree search handle. 73 | void ts_v2_free(void *h); 74 | 75 | #ifdef __cplusplus 76 | } 77 | #endif 78 | 79 | #endif 80 | -------------------------------------------------------------------------------- /mctsv2/playout_params.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 3 | // All rights reserved. 4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory. 8 | // 9 | 10 | #ifndef _PLAYOUT_PARAMS_H_ 11 | #define _PLAYOUT_PARAMS_H_ 12 | 13 | #include "../board/board.h" 14 | #include "../common/common.h" 15 | #include "../common/package.h" 16 | 17 | // Verbose level 18 | #define V_SLIENT 0 19 | #define V_CRITICAL 1 20 | #define V_INFO 2 21 | #define V_DEBUG 3 22 | 23 | #define SERVER_LOCAL 0 24 | #define SERVER_CLUSTER 1 25 | 26 | #define THREAD_NEW_BLOCKED 0 27 | #define THREAD_ALREADY_BLOCKED 1 28 | #define THREAD_NEW_RESUMED 2 29 | #define THREAD_ALREADY_RESUMED 3 30 | #define THREAD_STILL_BLOCKED 4 31 | 32 | // Choice of default policies. 33 | #define DP_SIMPLE 0 34 | #define DP_PACHI 1 35 | #define DP_V2 2 36 | 37 | // Used for maximal time spent. 
38 | #define THRES_PLY1 60 39 | #define THRES_PLY2 200 40 | #define THRES_PLY3 260 41 | #define THRES_TIME_CLOSE 180 42 | #define MIN_TIME_SPENT 1 43 | 44 | typedef struct { 45 | char pipe_path[200]; 46 | char tier_name[200]; 47 | 48 | // Whether we use local server or global server, could be SERVER_LOCAL or SERVER_CLUSTER 49 | int server_type; 50 | 51 | // Go rule, rule = RULE_CHINESE (default) or RULE_JAPANESE 52 | int rule; 53 | 54 | // Komi (This also includes handicap). 55 | float komi; 56 | 57 | // It seems that MCTS really slacks off when the estimated win rate is too high. So we need to use dynkomi. 58 | // Set dynkomi_factor = 0.0 would disable dynkomi. If used, usually we set it to 1.0. 59 | float dynkomi_factor; 60 | 61 | // Verbose level. 62 | int verbose; 63 | 64 | // #gpu we used. 65 | int num_gpu; 66 | 67 | // Only use cpu-based rollout. 68 | BOOL cpu_only; 69 | 70 | // Print search tree. 71 | BOOL print_search_tree; 72 | 73 | // If total_time == 0, then do not use heuristic time management. 74 | // Otherwise we assume total_time is all the time (in second) we have and we need to plan for it. 75 | int heuristic_tm_total_time; 76 | // Maximum time spent and minimal time spent, will be computed when heuristic_tm_total_time is set. 77 | float max_time_spent, min_time_spent; 78 | 79 | // Set the time_left, unit is second. 80 | // This is a bit special since we don't need to lock all threads and resume afterwards. 81 | // We just need to change the number, and the threads will take care of it. If time left is 0, then there is no 82 | // constraints on the time left. 83 | unsigned int time_left; 84 | } SearchParamsV2; 85 | 86 | typedef struct { 87 | float dynkomi; 88 | } SearchVariants; 89 | 90 | // Parameters for each monte carlo tree search. 91 | typedef struct { 92 | // The number of rollouts the root node should achieve per move. 93 | int num_rollout; 94 | // The number of rollouts each move should at least search. 
95 | int num_rollout_per_move; 96 | 97 | // The number of dcnn evaluated required per move. 98 | int num_dcnn_per_move; 99 | 100 | // Expand leaf only if total >= expand_n_thres. 101 | int expand_n_thres; 102 | 103 | int verbose; 104 | 105 | // #move receivers. 106 | int num_receiver; 107 | 108 | int max_depth_default_policy; 109 | int max_send_attempts; 110 | 111 | // Number of CPU threads for MCTS trees. 112 | int num_tree_thread; 113 | 114 | // Whether we put noise during UCT. 115 | // The noise is to speed up the performance of MCTS+DCNN. If scores are deterministic, then MCTS will block on one node. 116 | // Which might be, actually a good thing. 117 | // When num_virtual_games > 0, then both sigma and sigma_over_n is not used. 118 | float sigma; 119 | 120 | // Whether we use sigma * sqrt(n_parent) / n. This will reduce sigma gradually when we are confidence on one node's win rate. 121 | BOOL use_sigma_over_n; 122 | 123 | // Decision mixture ratio between cnn_prediction_confidence and mcts count / parent count. 124 | // Final score = mcts_count_ratio + decision_mixture_ratio * cnn_confidence. 125 | float decision_mixture_ratio; 126 | 127 | // Receiver parameters. 128 | // Accumulated probability threshold (in percent). 129 | int rcv_acc_percent_thres; 130 | 131 | // Maximum number of move to pick. 132 | // Minimum number of move to pick. 133 | int rcv_max_num_move; 134 | int rcv_min_num_move; 135 | 136 | // Use pondering 137 | BOOL use_pondering; 138 | 139 | // Time limit for each move (in sec). 140 | long time_limit; 141 | 142 | // Immediate return if CNN only gives one best move. 143 | BOOL single_move_return; 144 | 145 | // Which default policy we are using. 146 | int default_policy_choice; 147 | // The name of pattern file. 148 | char pattern_filename[1000]; 149 | int default_policy_sample_topn; 150 | double default_policy_temperature; 151 | 152 | // Define minimal rollout so that the search procedure can be peekable. 
153 | int min_rollout_peekable; 154 | 155 | // Use RAVE heuristics. 156 | BOOL use_rave; 157 | 158 | // Use sync/async model. 159 | // In async model, we will use fast rollout to fill in the moves first, when DCNN move returns, we append the moves. 160 | // In sync model, we just wait until DCNN moves return. 161 | BOOL use_async; 162 | int fast_rollout_max_move; 163 | 164 | // Tsumego mode. In this mode, we focus on a small region and generate a lot of moves to determine the life and death situation. 165 | BOOL life_and_death_mode; 166 | 167 | // The rectangle used for move generation (I don't think that is a good idea, but try that for now.) 168 | Region ld_region; 169 | 170 | // Whether we use hand-crafted features or from Tsumego DCNN model. 171 | BOOL use_tsumego_dcnn; 172 | 173 | // Specify which side is the defender. 174 | Stone defender; 175 | 176 | // Build an online model for search. 177 | BOOL use_online_model; 178 | // Learning rate. 179 | float online_model_alpha; 180 | 181 | // Online prior mixture ratio when the online model is open. 182 | float online_prior_mixture_ratio; 183 | 184 | // Use the win rate prediction to replace playout. 185 | BOOL use_cnn_final_score; 186 | 187 | // We use win rate prediction for ply >= min_ply_to_use_cnn_final_score. 188 | int min_ply_to_use_cnn_final_score; 189 | 190 | // Once we use win rate prediction, the mixture ratio between win rate prediction and actual playout score. 191 | // Final score = final_mixture_ratio * win_rate_prediction + (1.0 - final_mixture_ratio) * playout_result. 192 | float final_mixture_ratio; 193 | 194 | // Whether we use virtual game. If num_virtual_games == 0, then we will use sigma. 195 | int num_virtual_games; 196 | 197 | // Whether we run playout if a node is waiting for expansion. 198 | // If it is 0, then all threads will be blocked when waiting for CNN evaluation. 199 | // If it is 100, then all threads will continue (run playout and restart). 
200 | // Ideally we only want a few node to wait so that they can be the next batch to expand child nodes. 201 | int percent_playout_in_expansion; 202 | 203 | // Run a few playout and takes their mean. Default is 1, but could be higher. 204 | int num_playout_per_rollout; 205 | 206 | // Whether we use PUCT and previous UCT 207 | BOOL use_old_uct; 208 | } TreeParams; 209 | 210 | #endif 211 | 212 | -------------------------------------------------------------------------------- /mctsv2/test_playout_multithread.c: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 3 | // All rights reserved. 4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory. 8 | // 9 | 10 | #include "playout_multithread.h" 11 | #include "tree_search.h" 12 | static BOOL check_correct = FALSE; 13 | 14 | int seed; 15 | void onexit() { 16 | printf("Random seed = %d\n", seed); 17 | } 18 | 19 | int main(int argc, char *argv[]) { 20 | int K, R; 21 | int nthread = 16; 22 | int num_gpu = 4; 23 | 24 | char server_type[100]; 25 | strcpy(server_type, "local"); 26 | 27 | if (check_correct) { 28 | K = 1000; 29 | R = 10; 30 | } else { 31 | K = 100000; 32 | R = 50; 33 | } 34 | 35 | SearchParamsV2 search_params; 36 | TreeParams tree_params; 37 | 38 | ts_v2_init_params(&search_params); 39 | tree_search_init_params(&tree_params); 40 | 41 | search_params.verbose = V_INFO; 42 | tree_params.verbose = V_INFO; 43 | // tree_params.verbose = V_DEBUG; 44 | // 45 | K = 1000; 46 | if (argc >= 2) sscanf(argv[1], "%s", server_type); 47 | if (argc >= 3) sscanf(argv[2], "%d", &K); 48 | if (argc >= 4) sscanf(argv[3], "%d", &nthread); 49 | if (argc >= 5) sscanf(argv[4], "%d", &num_gpu); 50 | if (argc >= 6) sscanf(argv[5], "%d", &R); 51 | 52 | if (! 
strcmp(server_type, "local")) { 53 | printf("Use local server\n"); 54 | search_params.server_type = SERVER_LOCAL; 55 | search_params.num_gpu = num_gpu; 56 | strcpy(search_params.pipe_path, "/data/local/go/"); 57 | } else { 58 | printf("Use cluster server = %s\n", server_type); 59 | search_params.server_type = SERVER_CLUSTER; 60 | search_params.num_gpu = num_gpu; 61 | strcpy(search_params.tier_name, server_type); 62 | } 63 | 64 | // search_params.print_search_tree = TRUE; 65 | 66 | tree_params.use_async = FALSE; 67 | search_params.cpu_only = FALSE; 68 | tree_params.expand_n_thres = 0; 69 | 70 | /* 71 | search_params.cpu_only = TRUE; 72 | tree_params.use_async = TRUE; 73 | tree_params.expand_n_thres = 40; 74 | */ 75 | 76 | tree_params.num_rollout = K; 77 | tree_params.num_rollout_per_move = K; 78 | tree_params.num_dcnn_per_move = K; 79 | tree_params.num_receiver = num_gpu; 80 | tree_params.num_tree_thread = nthread; 81 | tree_params.sigma = 0.05; 82 | // tree_params.num_virtual_games = 10; 83 | tree_params.decision_mixture_ratio = 5.0; 84 | tree_params.rcv_max_num_move = 20; 85 | tree_params.use_rave = FALSE; 86 | tree_params.use_online_model = FALSE; 87 | tree_params.online_model_alpha = 0.001; 88 | tree_params.online_prior_mixture_ratio = 5.0; 89 | tree_params.rcv_acc_percent_thres = 80; 90 | // tree_params.sigma = 0.00; 91 | tree_params.use_pondering = TRUE; 92 | strcpy(tree_params.pattern_filename, "../models/playout-model.bin"); 93 | tree_params.default_policy_choice = DP_V2; 94 | tree_params.default_policy_temperature = 0.125; 95 | 96 | // seed = time(NULL); 97 | seed = 1441648459; 98 | srand(seed); 99 | atexit(onexit); 100 | printf("K = %d, R = %d, nthread = %d, num_gpu = %d\n", K, R, nthread, num_gpu); 101 | 102 | // Random expand a few nodes. 
103 | // 104 | double t; 105 | Board board; 106 | ClearBoard(&board); 107 | GroupId4 ids; 108 | AllMoves move_seq; 109 | 110 | void *tree_handle = ts_v2_init(&search_params, &tree_params, &board); 111 | ts_v2_print_params(tree_handle); 112 | ts_v2_search_start(tree_handle); 113 | 114 | char dbuf[200]; 115 | strcpy(dbuf, "/tmp/test_playout_multithread.XXXXXX"); 116 | if (!mkdtemp(dbuf)) { 117 | error("Could not create temporary directory"); 118 | } 119 | 120 | printf("Saving JSON tree dumps in %s\n", dbuf); 121 | char* buf = dbuf + strlen(dbuf); 122 | *buf++ = '/'; 123 | 124 | timeit 125 | for (int i = 0; i < R; ++i) { 126 | timeit 127 | // dprintf("======================= Start round %d out of %d ========================\n", i, R); 128 | Move m = ts_v2_pick_best(tree_handle, &move_seq, NULL); 129 | sprintf(buf, "mcts_tree_%d", i); 130 | ts_v2_tree_to_json(tree_handle, buf); 131 | ts_v2_prune_ours(tree_handle, m.m); 132 | 133 | // Save the current tree. 134 | // printf("Saving current tree to %s...\n", buf); 135 | // tree_print_out_cnn(buf, ts_get_tree_pool(tree_handle)); 136 | 137 | if (! 
TryPlay(&board, m.x, m.y, board._next_player, &ids) ) error("The move given by expansion should never fail!!"); 138 | Play(&board, &ids); 139 | 140 | ShowBoard(&board, SHOW_LAST_MOVE); 141 | printf("\n"); 142 | 143 | if (check_correct) { 144 | // dprintf("======================= Finish round %d out of %d ========================\n", i, R); 145 | // tree_pool_check(ts_get_tree_pool(tree_handle)); 146 | } 147 | endtime 148 | } 149 | endtime2(t) 150 | 151 | printf("Freeing\n"); 152 | timeit 153 | ts_v2_search_stop(tree_handle); 154 | ts_v2_free(tree_handle); 155 | endtime 156 | 157 | printf("Time used for mcts = %lf\n", t); 158 | printf("rollout rate = %f\n", K * R / t); 159 | return 0; 160 | } 161 | 162 | -------------------------------------------------------------------------------- /mctsv2/test_tree_multithread.c: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 3 | // All rights reserved. 4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory. 8 | // 9 | 10 | #include "tree.h" 11 | // #include "block_queue.h" 12 | #include 13 | #include "../common/common.h" 14 | #include 15 | 16 | #define CHECK_CORRECT 17 | 18 | #define NUM_THREAD 16 19 | // #define NUM_THREAD 1 20 | 21 | typedef struct { 22 | pthread_mutex_t mutex_alloc[NUM_THREAD]; 23 | // Seed shared by all threads (not a good design, but this is not important here). 
24 | unsigned long seed; 25 | TreePool *p; 26 | int num_rollout; 27 | int num_rollout_per_thread; 28 | // Queue q; 29 | } SearchInfo; 30 | 31 | int thread_rand(void *context, int max_value) { 32 | SearchInfo *info = (SearchInfo *)context; 33 | return fast_random(&info->seed, max_value); 34 | } 35 | 36 | void init_callback(TreePool *p, TreeBlock *bl, void *context, void *context2) { 37 | // Here we just set the number of free to be active. 38 | bl->n = thread_rand(context, BLOCK_SIZE - 1) + 1; 39 | } 40 | 41 | void* thread_random_expansion(void *ctx) { 42 | // Copy a new board. No pointer in board. 43 | SearchInfo *info = (SearchInfo* )ctx; 44 | 45 | TreePool *p = info->p; 46 | 47 | for (int i = 0; i < info->num_rollout_per_thread; ++i) { 48 | // Random traverse down the tree and expand a node 49 | TreeBlock *b = p->root->children[0].child; 50 | // printf("---Start playout %d/%d [Round %d]---\n", i, K, round); 51 | BOOL expanded = FALSE; 52 | while (!expanded) { // Once expanded, leave the loop. 53 | // Pick a random children. 54 | if (b == TP_NULL) error("We should never visit TP_NULL."); 55 | // tree_show_block(p, b); 56 | // printf("Total children = %d\n", total_children); 57 | BlockOffset child_idx = thread_rand(ctx, b->n); 58 | TreeBlock *c = b->children[child_idx].child; 59 | 60 | // Expand it or go downward 61 | // printf("[Explore %d]: pick idx = %d/%d\n", ID(b), child_idx, b->n); 62 | if (c == TP_NULL) { 63 | int res = tree_simple_begin_expand(b, child_idx, &c); 64 | if (c == TP_NULL) { 65 | // expand the current tree. 66 | c = tree_simple_g_alloc(p, ctx, NULL, init_callback, b, child_idx); 67 | if (c == TP_NULL) { 68 | error("allocation error, b = 0"); 69 | } 70 | } 71 | expanded = TRUE; 72 | } 73 | b = c; 74 | } 75 | // Back prop. 
76 | int black_count = thread_rand(ctx, 2); 77 | while (b != p->root) { 78 | TreeBlock * parent = b->parent; 79 | BlockOffset parent_offset = b->parent_offset; 80 | Stat* stat = &b->parent->data.stats[parent_offset]; 81 | 82 | // Add total first, otherwise the winning rate might go over 1. 83 | // stat->total += 1; 84 | // stat->black_win += black_count; 85 | __sync_fetch_and_add(&stat->total, 1); 86 | inc_atomic_float(&stat->black_win, (float)black_count); 87 | // __sync_fetch_and_add(&stat->black_win, black_count); 88 | 89 | b = parent; 90 | } 91 | } 92 | // printf("thread done. #rollout = %d\n", info->num_rollout_per_thread); 93 | return NULL; 94 | } 95 | 96 | void *thread_tree_search_daemon(void *ctx) { 97 | SearchInfo *info = (SearchInfo *)ctx; 98 | info->num_rollout_per_thread = (info->num_rollout + NUM_THREAD - 1) / NUM_THREAD; 99 | info->seed = 324; 100 | 101 | for (int i = 0; i < NUM_THREAD; ++i) pthread_mutex_init(&info->mutex_alloc[i], NULL); 102 | // Initialize queue. 103 | // queue_init(&info->q); 104 | 105 | // Split the task into nthread, and make a sparate thread for tree-leaf allocation. 106 | pthread_t explorers[NUM_THREAD]; 107 | TreePool *p = info->p; 108 | 109 | for (int i = 0; i < NUM_THREAD; ++i) { 110 | pthread_attr_t attr; 111 | pthread_attr_init(&attr); 112 | pthread_attr_setstacksize(&attr, 1048576); 113 | if (i == 0 && p->root->children[i].child == TP_NULL) { 114 | // We need to initialize root. 115 | p->root->children[i].child = 116 | tree_simple_g_alloc(p, info, NULL, init_callback, p->root, 0); 117 | } 118 | pthread_create(&explorers[i], &attr, thread_random_expansion, info); 119 | } 120 | 121 | // Wait until all finished. 
122 | for (int i = 0; i < NUM_THREAD; ++i) { 123 | pthread_join(explorers[i], NULL); 124 | } 125 | 126 | for (int i = 0; i < NUM_THREAD; ++i) pthread_mutex_destroy(&info->mutex_alloc[i]); 127 | return NULL; 128 | // queue_release(&info->q); 129 | } 130 | 131 | int seed; 132 | void onexit() { 133 | printf("Random seed = %d\n", seed); 134 | } 135 | 136 | int main() { 137 | #ifdef CHECK_CORRECT 138 | const int K = 1000; 139 | const int R = 100; 140 | #else 141 | const int K = 100000; 142 | const int R = 100; 143 | #endif 144 | 145 | atexit(onexit); 146 | printf("K = %d, R = %d\n", K, R); 147 | 148 | TreePool p; 149 | timeit 150 | tree_simple_pool_init(&p); 151 | endtime 152 | 153 | // Random expand a few nodes. 154 | // 155 | double t; 156 | timeit 157 | for (int i = 0; i < R; ++i) { 158 | // printf("======================= Start round %d out of %d ========================\n", i, R); 159 | // random_expansion(&p, i, K); 160 | pthread_t daemon; 161 | SearchInfo info = { .p = &p, .num_rollout = K }; 162 | pthread_create(&daemon, NULL, thread_tree_search_daemon, &info); 163 | pthread_join(daemon, NULL); 164 | 165 | // First child. 166 | BlockOffset offset = FIRST_NONLEAF(p.root->children[0].child); 167 | tree_simple_free_except(&p, p.root->children[offset].child); 168 | #ifdef CHECK_CORRECT 169 | printf("======================= Finish round %d out of %d ========================\n", i, R); 170 | tree_simple_pool_check(&p); 171 | #endif 172 | } 173 | endtime2(t) 174 | 175 | printf("Freeing tree pool\n"); 176 | timeit 177 | tree_simple_pool_free(&p); 178 | endtime 179 | 180 | printf("rollout rate = %f\n", K * R / t); 181 | return 0; 182 | } 183 | -------------------------------------------------------------------------------- /mctsv2/test_tsumego.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016-present, Facebook, Inc. 3 | -- All rights reserved. 
4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | 10 | local sgf = require("utils.sgf") 11 | local goutils = require 'utils.goutils' 12 | local utils = require('utils.utils') 13 | local board = require("board.board") 14 | local common = require("common.common") 15 | local playout = require 'mctsv2.playout_multithread' 16 | 17 | local pl = require 'pl.import_into'() 18 | 19 | -- Tsumego example 20 | -- local example1 = '/home/yuandong/test/tsumego-1.sgf' 21 | local solver = { } 22 | 23 | -- Solve one tsumego with MCTS. opt: verbose, json_filename. Returns { solved = (m.win_rate == 1.0), move_seq = {{x, y, player}, ...}, final_board, move_seq_str }. 24 | function solver.solve(tr, game, opt) 25 | local b = board.new() 26 | board.clear(b) 27 | 28 | -- Assume the board is cleared. And we will setup everything given this game. 29 | -- First put all the existing stones there. (apply_handicap is not a good name). 30 | goutils.apply_handicaps(b, game, true) 31 | 32 | if opt.verbose then board.show(b, 'last_move') end 33 | 34 | playout.set_tsumego_mode(tr, b, 1) 35 | if opt.verbose then 36 | playout.print_params(tr) 37 | print("Start playing! Defender: " .. (playout.tree_params.defender == common.white and 'W' or 'B')) 38 | end 39 | 40 | -- local filename = "mcts_tsumego.json" 41 | local m, move_seq = playout.play_rollout(tr, opt.json_filename) 42 | 43 | local s = "" 44 | local res = { } 45 | for i = 1, #move_seq do 46 | local x, y = common.coord2xy(move_seq[i]) 47 | table.insert(res, {x, y, b._next_player}) 48 | 49 | local coord, player = goutils.compose_move_gtp(x, y, b._next_player) 50 | s = s .. player .. " " .. coord ..
" " 51 | 52 | board.play2(b, move_seq[i]) 53 | end 54 | 55 | return { 56 | solved = (m.win_rate == 1.0), 57 | move_seq = res, 58 | final_board = b, 59 | move_seq_str = s 60 | } 61 | end 62 | -- Replay the SGF's first variation (the ground truth) on a fresh board; returns the same fields as solver.solve, with solved always true. 63 | function solver.print_gt(game, opt) 64 | local board_stack = { } 65 | local first_variation = { } 66 | local first_variation_collected = false 67 | 68 | local b = board.new() 69 | board.clear(b) 70 | 71 | -- Assume the board is cleared. And we will setup everything given this game. 72 | -- First put all the existing stones there. (apply_handicap is not a good name). 73 | goutils.apply_handicaps(b, game, true) 74 | 75 | game:play_start( 76 | function () 77 | -- table.insert(board_stack, board.copyfrom(b)) 78 | end, 79 | function () 80 | -- b = table.remove(board_stack) 81 | if #first_variation > 0 then 82 | first_variation_collected = true 83 | end 84 | end) 85 | 86 | if opt.verbose then 87 | print("Initial board position = ") 88 | board.show(b, "last_move") 89 | end 90 | 91 | local s = "" 92 | while not first_variation_collected do 93 | local move = game:play_current() 94 | -- require 'fb.debugger'.enter() 95 | -- print("Move:") 96 | -- print(move) 97 | local x, y, player = sgf.parse_move(move, false, true) 98 | if x ~= nil then 99 | table.insert(first_variation, {x, y, player}) 100 | 101 | local c, player_str = goutils.compose_move_gtp(x, y, player) 102 | s = s .. player_str .. " " .. c .. " " 103 | if not board.play(b, x, y, player) then 104 | error("The move " .. player_str .. " " .. c .. " cannot be played!") 105 | end 106 | end 107 | 108 | -- Move to next.
109 | if not game:play_next() then break end 110 | end 111 | 112 | return { 113 | solved = true, 114 | move_seq = first_variation, 115 | final_board = b, 116 | move_seq_str = s 117 | } 118 | end 119 | -- Driver: solve every SGF listed in sgf_list and report per-depth move-match rates against the ground truth. 120 | local num_rollout = 300000 121 | -- local num_rollout = 100000 122 | -- playout.params.print_search_tree = common.TRUE 123 | -- playout.params.decision_mixture_ratio = 1.0 124 | playout.tree_params.sigma = 0.05 125 | playout.tree_params.use_online_model = common.TRUE 126 | playout.tree_params.online_model_alpha = 0.0001 127 | -- playout.params.verbose = 3 128 | -- playout.tree_params.verbose = 3 129 | playout.tree_params.num_tree_thread = 64 130 | 131 | local tr = playout.new(num_rollout) 132 | -- local sgf_list = '/home/yuandong/test/tsumego/tsumego.lst' 133 | local sgf_list = '/home/yuandong/test/tsumego/tsumego_sample.lst' 134 | local dirname = pl.path.dirname(sgf_list) 135 | local lines = assert(pl.utils.readlines(sgf_list)) 136 | 137 | local num_solved = 0 138 | local num_total = 0 139 | local topn = { } 140 | 141 | -- opt.save_feature writes feature-<idx>.feature, opt.output_json writes mcts-<idx>.json. 142 | local opt = { verbose = true, save_feature = true, output_json = true } 143 | 144 | for idx, f in ipairs(lines) do 145 | print(f) 146 | local game = sgf.parse(assert(pl.utils.readfile(pl.path.join(dirname, f)))) -- readfile closes the handle; assert names an unreadable file. 147 | 148 | if opt.output_json then 149 | opt.json_filename = string.format("mcts-%d.json", idx) 150 | end 151 | 152 | local sol = solver.solve(tr, game, opt) 153 | local sol_gt = solver.print_gt(game, opt) 154 | 155 | if opt.save_feature then 156 | local filename = string.format("feature-%d.feature", idx) 157 | playout.save_tree_feature(tr, filename) 158 | end 159 | 160 | if opt.verbose then 161 | print(string.format("%s, Move: %s", (sol.solved and "Solved" or "Unsolved"), sol.move_seq_str)) 162 | print("GT Move: " ..
sol_gt.move_seq_str) 163 | -- board.show(sol.final_board, 'last_move') 164 | -- board.show(sol_gt.final_board, 'last_move') 165 | end 166 | 167 | -- Per-depth accuracy: topn[i] = {matches, comparisons} for move i over the common prefix of both lines (sol.move_seq is usually longer than sol_gt.move_seq). 168 | for i = 1, math.min(#sol.move_seq, #sol_gt.move_seq) do 169 | local xmatch = (sol.move_seq[i][1] == sol_gt.move_seq[i][1]) 170 | local ymatch = (sol.move_seq[i][2] == sol_gt.move_seq[i][2]) 171 | if not topn[i] then 172 | topn[i] = { 0, 0 } 173 | end 174 | if xmatch and ymatch then 175 | topn[i][1] = topn[i][1] + 1 176 | end 177 | topn[i][2] = topn[i][2] + 1 178 | end 179 | 180 | if sol.solved then num_solved = num_solved + 1 end 181 | num_total = num_total + 1 182 | end 183 | 184 | print(string.format("Summary: solved/total: %d/%d", num_solved, num_total)) 185 | for i, c in ipairs(topn) do 186 | print(string.format("Top %d: %.2f (%d/%d)", i, 100 * c[1] / c[2], c[1], c[2])) 187 | end 188 | 189 | playout.free(tr) 190 | -------------------------------------------------------------------------------- /mctsv2/tree_search.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 3 | // All rights reserved. 4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory. 8 | // 9 | 10 | #ifndef _TREE_SEARCH_H_ 11 | #define _TREE_SEARCH_H_ 12 | 13 | #include "playout_params.h" 14 | #include "playout_common.h" 15 | 16 | #ifdef __cplusplus 17 | extern "C" { 18 | #endif 19 | 20 | // Send/Receive callback. 21 | typedef BOOL (* func_send_board)(void *context, int, MBoard *b); 22 | // Receive the move from exchanger.
23 | typedef BOOL (* func_receive_move)(void *context, int, MMove *mmove); 24 | typedef int (* func_receiver_discard_move)(void *context, int); 25 | typedef void (* func_receiver_restart)(void *context); 26 | // ExCallbacks bundles the exchanger callbacks with an opaque context (presumably handed back as each callback's first argument; confirm in tree_search.c). 27 | typedef struct { 28 | void *context; 29 | // Callbacks. 30 | func_send_board callback_send_board; 31 | func_receive_move callback_receive_move; 32 | func_receiver_discard_move callback_receiver_discard_move; 33 | func_receiver_restart callback_receiver_restart; 34 | } ExCallbacks; 35 | 36 | // ================================================ 37 | // APIs for tree search. 38 | void tree_search_init_params(TreeParams *params); 39 | // tree_search_init returns an opaque handle; every function below takes it back as `ctx`. 40 | void *tree_search_init(const SearchParamsV2 *common_params, const SearchVariants *variants, const ExCallbacks *callbacks, const TreeParams *params, const Board *board_init); 41 | void tree_search_free(void *ctx); 42 | 43 | void tree_search_print_params(void *ctx); 44 | BOOL tree_search_set_params(void *ctx, const TreeParams *new_params); 45 | 46 | // Start and stop the search. 47 | void tree_search_start(void *ctx); 48 | void tree_search_stop(void *ctx); 49 | 50 | // Stop and resume the threads. 51 | void tree_search_thread_off(void *ctx); 52 | void tree_search_thread_on(void *ctx); 53 | 54 | // Reset the entire tree. This happens when we setboard/setkomi etc. 55 | BOOL tree_search_reset_tree(void *ctx); 56 | BOOL tree_search_undo_pass(void *ctx, const Board *before_board); 57 | BOOL tree_search_set_board(void *ctx, const Board *new_board); 58 | 59 | void tree_search_print_tree(void *ctx); 60 | 61 | // Save the tree to json format. 62 | void tree_search_to_json(void *ctx, const Move *prev_moves, int num_prev_moves, const char *output_filename); 63 | 64 | // Save the tree to feature file (ARFF format). 65 | void tree_search_to_feature(void *ctx, const char *output_filename); 66 | 67 | // Return the best move. Remember to call tree_search_prune_ours after the decision is made.
68 | Move tree_search_pick_best(void *ctx, AllMoves *all_moves, const Board *verify_board); 69 | 70 | // Peek the top few moves, topk = moves->num_moves. 71 | BOOL tree_search_peek(void *ctx, Moves *moves, const Board *verify_board); 72 | 73 | // Prune the tree given the move. 74 | void tree_search_prune_opponent(void *ctx, Coord m); 75 | void tree_search_prune_ours(void *ctx, Coord m); 76 | 77 | #ifdef __cplusplus 78 | } 79 | #endif 80 | 81 | #endif 82 | -------------------------------------------------------------------------------- /mctsv2/tree_search_internal.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 3 | // All rights reserved. 4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory. 8 | // 9 | 10 | #ifndef _TREE_SEARCH_INTERNAL_H_ 11 | #define _TREE_SEARCH_INTERNAL_H_ 12 | 13 | #include 14 | #include 15 | 16 | #include "tree_search.h" 17 | #include "tree.h" 18 | #include "../board/default_policy_common.h" 19 | 20 | // ========================== Data Structure ============================= 21 | struct __TreeHandle; 22 | // ReceiverParams: per-receiver state; TreeHandle.move_params holds one entry per move receiver. 23 | typedef struct { 24 | struct __TreeHandle *s; 25 | int receiver_id; 26 | 27 | // Notification for receivers. 28 | pthread_mutex_t lock; 29 | 30 | int cnn_move_valid; 31 | int cnn_move_received; 32 | int cnn_move_discarded; 33 | int cnn_move_seq_mismatched; 34 | int cnn_move_board_hash_mismatched; 35 | } ReceiverParams; 36 | 37 | // This is one for each thread. 38 | typedef struct { 39 | // A pointer to search info common. 40 | struct __TreeHandle *s; 41 | // The exchanger id we should use for this thread. 42 | int ex_id; 43 | 44 | // Random seed used for this thread. 45 | unsigned long seed; 46 | 47 | // For each thread, count #loops.
48 | int counter; 49 | 50 | // Counters. 51 | int num_policy_failed; 52 | int num_expand_failed; 53 | int leaf_expanded; 54 | int cnn_send_infunc; 55 | int cnn_send_attempt; 56 | int cnn_send_success; 57 | int use_ucb, use_cnn, use_async; 58 | int max_depth; 59 | // Count for preempt-expanding 60 | int preempt_playout_count; 61 | } ThreadInfo; 62 | 63 | // Some callback functions. 64 | typedef DefPolicyMove (* func_def_policy)(void *def_policy, void *context, RandFunc rand_func, Board* board, const Region *r, int max_depth, BOOL verbose); 65 | typedef float (* func_compute_score)(ThreadInfo *info, const Board *board); 66 | // When board_on_child is false, then the bl is at the same situation as board. 67 | // When board_on_child is true, then board points to the child that has been played on the board, and child_offset is that child's offset. 68 | typedef void (* func_back_prop)(ThreadInfo *info, float black_moku, Stone next_player, int end_ply, BOOL board_on_child, BlockOffset child_offset, TreeBlock *bl); 69 | typedef BOOL (* func_policy)(ThreadInfo *info, TreeBlock *bl, const Board *board, BlockOffset *offset, TreeBlock **child_chosen); 70 | typedef BOOL (* func_expand)(ThreadInfo *info, const Board *board, TreeBlock *b); 71 | // SC_* codes: presumably the reason a search stops (SC_NOT_YET = not stopped yet); meanings inferred from names, confirm in tree_search.c. 72 | #define SC_NOT_YET 0 73 | #define SC_TIME_OUT 1 74 | #define SC_DCNN_ROLLOUT_REACHED 2 75 | #define SC_TOTAL_ROLLOUT_REACHED 3 76 | #define SC_NO_NEW_DCNN_EVAL 4 77 | #define SC_SINGLE_MOVE_RETURN 5 78 | #define SC_NO_VALID_MOVE 6 79 | #define SC_TIME_LEFT_CLOSE 7 80 | #define SC_TIME_HEURISTIC_STAGE1 8 81 | #define SC_TIME_HEURISTIC_STAGE2 9 82 | #define SC_TIME_HEURISTIC_STAGE3 10 83 | #define SC_TIME_HEURISTIC_STAGE4 11 84 | 85 | typedef struct __TreeHandle { 86 | TreeParams params; 87 | ExCallbacks callbacks; 88 | const SearchParamsV2 *common_params; 89 | const SearchVariants *common_variants; 90 | 91 | // The sequence number of this search, used for CNN communication. 92 | long seq; 93 | 94 | // The internal board.
95 | Board board; 96 | 97 | volatile BOOL search_done; 98 | volatile BOOL receiver_done; 99 | 100 | // Tree Pool used among all threads. 101 | TreePool p; 102 | 103 | // Semaphore for all threads. 104 | // If all_threads_blocking_count == 0, then all threads are running, 105 | // Each call of block_all_threads will increase it by 1, if all_threads_blocking_count > 0, then all threads are blocked. 106 | // Each call of resume_all_threads will decrease it by 1, if all_threads_blocking_count == 0, then all threads are resumed. 107 | // Calling resume_all_threads when all_threads_blocking_count == 0 has no effect. 108 | int all_threads_blocking_count; 109 | sem_t sem_all_threads_unblocked, sem_all_threads_blocked; 110 | int threads_count; 111 | 112 | // Total rollout count. 113 | int rollout_count; 114 | // Total dcnn evaluation received. 115 | int dcnn_count; 116 | int prev_dcnn_count; 117 | BOOL all_stats_cleared; 118 | 119 | // The timestamp when the search starts. It is updated when resume_all_threads is called. 120 | long ts_search_start; 121 | 122 | // The timestamp when command "genmove" is called. Used for time control. 123 | long ts_search_genmove_called; 124 | 125 | // Notification with search complete signal. 126 | pthread_mutex_t mutex_search_complete; 127 | sem_t sem_search_complete; 128 | int flag_search_complete; 129 | 130 | // Callbacks. 131 | func_def_policy callback_def_policy; 132 | func_compute_score callback_compute_score; 133 | func_back_prop callback_backprop; 134 | func_policy callback_policy; 135 | func_expand callback_expand; 136 | 137 | // Threads for searching. # = number of tree threads. 138 | pthread_t *explorers; 139 | ThreadInfo *infos; 140 | 141 | // For default policy. 142 | void *def_policy; 143 | 144 | // Fast rollout interface. 145 | void *fast_rollout_policy; 146 | 147 | // Move receivers. # = number of gpus 148 | pthread_t *move_receivers; 149 | ReceiverParams *move_params; 150 | 151 | // Whether the bot is pondering.
(Think when the opponent is thinking) 152 | BOOL is_pondering; 153 | 154 | // Online linear model. Weights and bias here. For now we use one float for each location of the board. 155 | // w . x + b will predict a score of the current board. sigmoid(w . x + b) should give the winrate between [0, 1]. 156 | // Atomic operation is used to read/write the data. 157 | // We train the model online so that the best move is better than other moves. 158 | pthread_mutex_t mutex_online_model; 159 | float model_weights[MACRO_BOARD_SIZE * MACRO_BOARD_SIZE]; 160 | float model_bias; 161 | float model_acc_err; 162 | int model_count_err; 163 | 164 | // Model for which move to search first. 165 | // If a move causes win, put positive, if a move causes loss, put negative. 166 | int move_scores_black[BOUND_COORD]; 167 | int move_scores_white[BOUND_COORD]; 168 | } TreeHandle; 169 | // NOTE(review): under C99/C11 inline rules, `extern inline` in a header emits an external definition in every translation unit that includes it; this presumably relies on gnu89 inline semantics (or a single includer). Confirm the build flags. 170 | // Some utilities. 171 | // ============================== Utility ===================================== 172 | extern inline unsigned int thread_rand(void *context, unsigned int max_value) { 173 | ThreadInfo *info = (ThreadInfo *)context; 174 | return fast_random(&info->seed, max_value); 175 | } 176 | // normal_rand ignores context and draws from the process-wide rand(); max_value must be non-zero (it is used as a modulus). 177 | extern inline unsigned int normal_rand(void *context, unsigned int max_value) { 178 | return rand() % max_value; 179 | } 180 | // thread_randf: fast_random(seed, 32768) scaled by 1/32768 using the thread's own seed (in [0, 1) assuming fast_random returns [0, max)). 181 | extern inline float thread_randf(ThreadInfo *info) { 182 | const int max_for_float = 32768; 183 | return ((float)fast_random(&info->seed, max_for_float)) / max_for_float; 184 | } 185 | 186 | #endif 187 | -------------------------------------------------------------------------------- /pachi_tactics/PACHI_LICENSE: -------------------------------------------------------------------------------- 1 | License 2 | ------- 3 | 4 | Pachi is distributed under the GPLv2 licence (see the COPYING file for 5 | details and full text of the licence); you are welcome to tweak it as 6 | you wish (contributing back upstream is welcome) and distribute 7 | it freely, but only together with
the source code. You are welcome 8 | to make private modifications to the code (e.g. try new algorithms and 9 | approaches), use them internally or even to have your bot play on the 10 | internet and enter competitions, but as soon as you want to release it 11 | to the public, you need to release the source code as well. 12 | 13 | One exception is the Autotest framework, which is licenced under the 14 | terms of the MIT licence (close to public domain) - you are free to 15 | use it any way you wish. 16 | -------------------------------------------------------------------------------- /pachi_tactics/README.md: -------------------------------------------------------------------------------- 1 | Readme 2 | ======== 3 | 4 | This folder contains [Pachi](https://github.com/pasky/pachi)'s playout policy. DarkForest does not depend on it but uses it as a benchmark. 5 | When running cnnPlayerV2/cnnPlayerMCTSV2.lua, use the switch `--default_policy pachi` to use the Pachi playout policy. 6 | -------------------------------------------------------------------------------- /pachi_tactics/board_interface.c: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 3 | // All rights reserved. 4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory.
8 | // 9 | 10 | #include "board_interface.h" 11 | 12 | Coord get_nlibs_of_group(const Board *b, group_t id, int k, Coord *libs) { 13 | if (k > 1 && libs == NULL) error("To get >= 2 libs, libs must not be NULL!"); 14 | 15 | int count = 0; 16 | Coord last = M_PASS; 17 | int lib_count = b->_groups[id].liberties; 18 | if (lib_count < k) error("The liberty count is %d and cannot get %d liberty points!\n", lib_count, k); 19 | TRAVERSE(b, id, c) { 20 | FOR4(c, _, cc) { 21 | if (b->_infos[cc].color == S_EMPTY) { 22 | last = cc; 23 | if (libs != NULL) { 24 | BOOL dup = FALSE; 25 | for (int j = 0; j < count; ++j) { 26 | if (libs[j] == cc) { 27 | dup = TRUE; 28 | break; 29 | } 30 | } 31 | if (!dup) { 32 | libs[count++] = cc; 33 | } 34 | } else { 35 | // Do not check duplicate. 36 | count ++; 37 | } 38 | if (count == k) return last; 39 | } 40 | } ENDFOR4 41 | } ENDTRAVERSE 42 | error("This should never be reached!"); 43 | return last; 44 | } 45 | 46 | Coord board_group_other_lib(const Board *b, group_t id, Coord to) { 47 | int count = 0; 48 | int lib_count = b->_groups[id].liberties; 49 | if (lib_count < 2) error("The liberty count is %d and cannot get the other liberty point!\n", lib_count); 50 | TRAVERSE(b, id, c) { 51 | FOR4(c, _, cc) { 52 | if (cc != to && b->_infos[cc].color == S_EMPTY) return cc; 53 | } ENDFOR4 54 | } ENDTRAVERSE 55 | error("This should never be reached!"); 56 | return M_PASS; 57 | } 58 | 59 | BOOL check_loc_adjacent_group(const Board *b, Coord loc, group_t group) { 60 | // Check if loc is adjacent to the group. 61 | FOR4(loc, _, c) { 62 | if (b->_infos[c].id == group) return TRUE; 63 | } ENDFOR4 64 | return FALSE; 65 | } 66 | 67 | int group_stone_count(Board *b, group_t group, int max_val) { 68 | int stone_count = b->_groups[group].stones; 69 | return stone_count > max_val ? max_val : stone_count; 70 | } 71 | 72 | int neighbor_count_at(const Board *b, Coord c, Stone player) { 73 | // Count the number of stones of that color at a particular location. 
74 | int count = 0; 75 | FOR4(c, _, cc) { 76 | if (b->_infos[cc].color == player) count ++; 77 | } ENDFOR4 78 | return count; 79 | } 80 | 81 | int immediate_liberty_count(const Board *b, Coord c) { 82 | return neighbor_count_at(b, c, S_EMPTY); 83 | } 84 | 85 | group_t board_get_atari_neighbor(Board *b, Coord coord, Stone group_color) 86 | { 87 | FOR4(coord, _, c) { 88 | group_t g_id = b->_infos[c].id; 89 | const Group *g = &b->_groups[g_id]; 90 | if (g_id && board_at(b, c) == group_color && g->liberties == 1) 91 | return g_id; 92 | /* We return first match. */ 93 | } ENDFOR4 94 | return 0; 95 | } 96 | 97 | bool board_is_valid_move(const Board *b, struct move* m) { 98 | GroupId4 ids; 99 | return TryPlay(b, X(m->coord), Y(m->coord), m->color, &ids); 100 | } 101 | 102 | bool board_is_valid_play(const Board *b, Stone player, Coord m) { 103 | GroupId4 ids; 104 | return TryPlay(b, X(m), Y(m), player, &ids); 105 | } 106 | 107 | int board_play(Board *b, struct move *m) { 108 | GroupId4 ids; 109 | if (TryPlay(b, X(m->coord), Y(m->coord), m->color, &ids)) { 110 | Play(b, &ids); 111 | return 0; 112 | } else { 113 | return -1; 114 | } 115 | } 116 | 117 | 118 | -------------------------------------------------------------------------------- /pachi_tactics/board_interface.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 3 | // All rights reserved. 4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory. 8 | // 9 | 10 | #ifndef _BOARD_INTERFACE_H_ 11 | #define _BOARD_INTERFACE_H_ 12 | 13 | #include "../board/board.h" 14 | 15 | // Some functions and macros for pachi code. 
16 | #define group_at(b, c) (&(b)->_groups[b->_infos[(c)].id]) 17 | #define board_at(b, c) (b)->_infos[(c)].color 18 | 19 | #define board_atxy(b, x, y) (b)->_infos[OFFSETXY((x), (y))].color 20 | #define group_atxy(b, x, y) (b)->_infos[OFFSETXY((x), (y))].id 21 | // board_group_info(b, g). is replaced as g-> 22 | // 23 | #define S_OFFBOARD S_OFF_BOARD 24 | #define S_NONE S_EMPTY 25 | #define S_MAX 4 26 | 27 | // Handling true/false 28 | #define true TRUE 29 | #define false FALSE 30 | #define bool BOOL 31 | 32 | typedef short group_t; 33 | typedef Coord coord_t; 34 | 35 | struct move { 36 | Coord coord; 37 | Stone color; 38 | }; 39 | 40 | // Assume we have at least k slots in libs. 41 | // Return the last libs we found. When libs == NULL, don't put libs. 42 | // If we just want to get one, use Coord lib = get_nlib_of_group(b, group, 1, NULL); 43 | Coord get_nlibs_of_group(const Board *b, group_t group, int k, Coord *libs); 44 | Coord board_group_other_lib(const Board *b, group_t group, Coord to); 45 | bool check_loc_adjacent_group(const Board *b, Coord loc, group_t group); 46 | int group_stone_count(Board *b, group_t group, int max); 47 | 48 | int neighbor_count_at(const Board *b, Coord c, Stone player); 49 | int immediate_liberty_count(const Board *b, Coord c); 50 | 51 | group_t board_get_atari_neighbor(Board *b, Coord coord, Stone group_color); 52 | bool board_is_valid_move(const Board *b, struct move* m); 53 | bool board_is_valid_play(const Board *b, Stone player, Coord m); 54 | int board_play(Board *board, struct move *m); 55 | 56 | extern inline int board_size(const Board *b) { return 19; } 57 | extern inline bool is_pass(Coord m) { return m == M_PASS; } 58 | extern inline bool is_resign(Coord m) { return m == M_RESIGN; } 59 | // inline bool board_large(const Board *b) { return true; } 60 | extern inline bool group_is_onestone(const Board *b, group_t group) { return b->_groups[group].stones == 1; } 61 | 62 | #define PLDEBUGL(n) false 63 | #define DEBUGL(n) false 
64 | 65 | #define likely(x) __builtin_expect(!!(x), 1) 66 | #define unlikely(x) __builtin_expect((x), 0) 67 | 68 | inline char *coord2sstr(Coord c, const Board *board) { return (char *)""; } 69 | 70 | #define board_is_one_point_eye IsTrueEye 71 | #define board_is_eyelike IsEye 72 | #define stone_other OPPONENT 73 | #define board_large(b) true 74 | 75 | #endif 76 | 77 | -------------------------------------------------------------------------------- /pachi_tactics/fixp.h: -------------------------------------------------------------------------------- 1 | #ifndef PACHI_FIXP_H 2 | #define PACHI_FIXP_H 3 | 4 | /* Tools for counting fixed-point numbers. */ 5 | 6 | /* We implement a simple fixed-point number type, with fixed number of 7 | * fractional binary digits after the radix point. */ 8 | 9 | #include 10 | 11 | typedef uint_fast32_t fixp_t; 12 | 13 | /* We should accomodate at least 0..131072 (17bits) in the whole number 14 | * portion; assuming at least 32bit integer, that leaves us with 15-bit 15 | * fractional part. Thankfully, we need only unsigned values. */ 16 | #define FIXP_BITS 15 17 | 18 | #define FIXP_SCALE (1< 12 | 13 | int main() { 14 | char buf[100]; 15 | // After that we use moggy to analyze it. (With default parameters). 16 | void* policy = playout_moggy_init(NULL); 17 | 18 | // Put a few random moves and ask mogo to play it out. 19 | AllMoves all_moves; 20 | GroupId4 ids; 21 | const int num_round = 10000; 22 | 23 | Board b; 24 | timeit 25 | for (int j = 0; j < num_round; j++) { 26 | printf("Round = %d/%d\n", j, num_round); 27 | ClearBoard(&b); 28 | //for (int i = 0; i < rand() % 100 + 1; ++i) { 29 | for (int i = 0; i < 0; ++i) { 30 | FindAllCandidateMoves(&b, b._next_player, 0, &all_moves); 31 | // Randomly find one move. 
32 | int idx = rand() % all_moves.num_moves; 33 | if (TryPlay2(&b, all_moves.moves[idx], &ids)) { 34 | Play(&b, &ids); 35 | } 36 | } 37 | 38 | ShowBoard(&b, SHOW_LAST_MOVE); 39 | play_random_game(policy, NULL, NULL, &b, -1, FALSE); 40 | ShowBoard(&b, SHOW_LAST_MOVE); 41 | } 42 | endtime 43 | 44 | playout_moggy_destroy(policy); 45 | return 0; 46 | } 47 | -------------------------------------------------------------------------------- /pachi_tactics/mq.h: -------------------------------------------------------------------------------- 1 | #ifndef PACHI_MQ_H 2 | #define PACHI_MQ_H 3 | 4 | /* Move queues; in fact, they are more like move lists, usually used 5 | * to accumulate equally good move candidates, then choosing from them 6 | * randomly. But they are also used to juggle group lists (using the 7 | * fact that Coord == group_t). */ 8 | 9 | #include "../board/board.h" 10 | #include "board_interface.h" 11 | #include "fixp.h" 12 | 13 | #define MQL 512 /* XXX: On larger board this might not be enough. */ 14 | struct move_queue { 15 | unsigned int moves; 16 | Coord move[MQL]; 17 | /* Each move can have an optional tag or set of tags. 18 | * The usage of these is user-dependent. */ 19 | unsigned char tag[MQL]; 20 | }; 21 | 22 | /* Pick a random move from the queue. */ 23 | static Coord mq_pick(void *context, RandFunc randfunc, struct move_queue *q); 24 | 25 | /* Add a move to the queue. */ 26 | static void mq_add(struct move_queue *q, Coord c, unsigned char tag); 27 | 28 | /* Cat two queues together. */ 29 | static void mq_append(struct move_queue *qd, struct move_queue *qs); 30 | 31 | /* Check if the last move in queue is not a dupe, and remove it 32 | * in that case. */ 33 | static void mq_nodup(struct move_queue *q); 34 | 35 | /* Print queue contents on stderr. */ 36 | static void mq_print(struct move_queue *q, Board *b, char *label); 37 | 38 | 39 | /* Variations of the above that allow move weighting. 
*/ 40 | /* XXX: The "kinds of move queue" issue (it's even worse in some other 41 | * branches) is one of the few good arguments for C++ in Pachi... 42 | * At least rewrite it to be less hacky and maybe make a move_gamma_queue 43 | * that encapsulates move_queue. */ 44 | 45 | static Coord mq_gamma_pick(void *context, RandFunc randfunc, struct move_queue *q, fixp_t *gammas); 46 | static void mq_gamma_add(struct move_queue *q, fixp_t *gammas, Coord c, double gamma, unsigned char tag); 47 | static void mq_gamma_print(struct move_queue *q, fixp_t *gammas, Board *b, char *label); 48 | 49 | /* Use this one if you want larger numbers. */ 50 | static inline uint32_t 51 | fast_irandom(void *context, RandFunc randfunc, unsigned int max) 52 | { 53 | if (max <= 65536) 54 | return randfunc(context, max); 55 | int himax = (max - 1) / 65536; 56 | uint16_t hi = randfunc(context, himax + 1); 57 | return ((uint32_t)hi << 16) | randfunc(context, hi < himax ? 65536 : max % 65536); 58 | } 59 | 60 | static inline Coord 61 | mq_pick(void *context, RandFunc func, struct move_queue *q) 62 | { 63 | return q->moves ? 
q->move[func(context, q->moves)] : M_PASS; 64 | } 65 | 66 | static inline void 67 | mq_add(struct move_queue *q, Coord c, unsigned char tag) 68 | { 69 | assert(q->moves < MQL); 70 | q->tag[q->moves] = tag; 71 | q->move[q->moves++] = c; 72 | } 73 | 74 | static inline void 75 | mq_append(struct move_queue *qd, struct move_queue *qs) 76 | { 77 | assert(qd->moves + qs->moves < MQL); 78 | memcpy(&qd->tag[qd->moves], qs->tag, qs->moves * sizeof(*qs->tag)); 79 | memcpy(&qd->move[qd->moves], qs->move, qs->moves * sizeof(*qs->move)); 80 | qd->moves += qs->moves; 81 | } 82 | 83 | static inline void 84 | mq_nodup(struct move_queue *q) 85 | { 86 | for (unsigned int i = 1; i < 4; i++) { 87 | if (q->moves <= i) 88 | return; 89 | if (q->move[q->moves - 1 - i] == q->move[q->moves - 1]) { 90 | q->tag[q->moves - 1 - i] |= q->tag[q->moves - 1]; 91 | q->moves--; 92 | return; 93 | } 94 | } 95 | } 96 | 97 | static inline void 98 | mq_print(struct move_queue *q, Board *b, char *label) 99 | { 100 | fprintf(stderr, "%s candidate moves: ", label); 101 | for (unsigned int i = 0; i < q->moves; i++) { 102 | fprintf(stderr, "%s ", coord2sstr(q->move[i], b)); 103 | } 104 | fprintf(stderr, "\n"); 105 | } 106 | 107 | static inline Coord 108 | mq_gamma_pick(void *context, RandFunc func, struct move_queue *q, fixp_t *gammas) 109 | { 110 | if (!q->moves) 111 | return M_PASS; 112 | fixp_t total = 0; 113 | for (unsigned int i = 0; i < q->moves; i++) { 114 | total += gammas[i]; 115 | } 116 | if (!total) 117 | return M_PASS; 118 | fixp_t stab = fast_irandom(context, func, total); 119 | for (unsigned int i = 0; i < q->moves; i++) { 120 | if (stab < gammas[i]) 121 | return q->move[i]; 122 | stab -= gammas[i]; 123 | } 124 | assert(0); 125 | return M_PASS; 126 | } 127 | 128 | static inline void 129 | mq_gamma_add(struct move_queue *q, fixp_t *gammas, Coord c, double gamma, unsigned char tag) 130 | { 131 | mq_add(q, c, tag); 132 | gammas[q->moves - 1] = double_to_fixp(gamma); 133 | } 134 | 135 | static inline 
void 136 | mq_gamma_print(struct move_queue *q, fixp_t *gammas, Board *b, char *label) 137 | { 138 | fprintf(stderr, "%s candidate moves: ", label); 139 | for (unsigned int i = 0; i < q->moves; i++) { 140 | fprintf(stderr, "%s(%.3f) ", coord2sstr(q->move[i], b), fixp_to_double(gammas[i])); 141 | } 142 | fprintf(stderr, "\n"); 143 | } 144 | 145 | #endif 146 | -------------------------------------------------------------------------------- /pachi_tactics/tactics/1lib.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #define DEBUG 6 | #include "../../board/board.h" 7 | #include "../mq.h" 8 | #include "1lib.h" 9 | #include "ladder.h" 10 | #include "selfatari.h" 11 | 12 | /* Whether to avoid capturing/atariing doomed groups (this is big 13 | * performance hit and may reduce playouts balance; it does increase 14 | * the strength, but not quite proportionally to the performance). */ 15 | //#define NO_DOOMED_GROUPS 16 | // 17 | static bool 18 | can_play_on_lib(Board *b, group_t group, Stone to_play) 19 | { 20 | Coord capture = get_nlibs_of_group(b, group, 1, NULL); 21 | /* 22 | if (DEBUGL(6)) 23 | fprintf(stderr, "can capture group %d (%s)?\n", 24 | g, coord2sstr(capture, b)); 25 | */ 26 | /* Does playing on the liberty usefully capture the group? */ 27 | GroupId4 ids; 28 | if (TryPlay2(b, capture, &ids) && !is_bad_selfatari(b, to_play, capture)) 29 | return true; 30 | 31 | return false; 32 | } 33 | 34 | /* For given position @c, decide if this is a group that is in danger from 35 | * @capturer and @to_play can do anything about it (play at the last 36 | * liberty to either capture or escape). */ 37 | /* Note that @to_play is important; e.g. consider snapback, it's good 38 | * to play at the last liberty by attacker, but not defender. 
*/ 39 | static inline __attribute__((always_inline)) bool 40 | capturable_group(Board *b, Stone capturer, Coord c, Stone to_play) 41 | { 42 | group_t g_id = b->_infos[c].id; 43 | int libs = b->_groups[g_id].liberties; 44 | if (board_at(b, c) != OPPONENT(capturer) || libs > 1) 45 | return false; 46 | 47 | return can_play_on_lib(b, g_id, to_play); 48 | } 49 | 50 | bool can_countercapture(Board *b, Stone owner, group_t id, Stone to_play, struct move_queue *q, int tag) { 51 | // [TODO]: Need to fix this. 52 | //if (b->clen < 2) 53 | // return false; 54 | 55 | unsigned int qmoves_prev = q ? q->moves : 0; 56 | Group *g = &b->_groups[id]; 57 | 58 | TRAVERSE(b, id, c) { 59 | FOR4(c, _, cc) { 60 | if (!capturable_group(b, owner, c, to_play)) 61 | continue; 62 | 63 | if (!q) { 64 | return true; 65 | } 66 | // mq_add(q, board_group_info(b, group_at(b, c)).lib[0], tag); 67 | mq_add(q, get_nlibs_of_group(b, id, 1, NULL), tag); 68 | mq_nodup(q); 69 | } ENDFOR4 70 | } ENDTRAVERSE 71 | 72 | bool can = q ? q->moves > qmoves_prev : false; 73 | return can; 74 | } 75 | 76 | #ifdef NO_DOOMED_GROUPS 77 | static bool can_be_rescued(Board *b, group_t group, Stone color, int tag) 78 | { 79 | /* Does playing on the liberty rescue the group? */ 80 | if (can_play_on_lib(b, group, color)) 81 | return true; 82 | 83 | /* Then, maybe we can capture one of our neighbors? 
*/ 84 | return can_countercapture(b, color, group, color, NULL, tag); 85 | } 86 | #endif 87 | 88 | void 89 | group_atari_check(void *context, RandFunc randfunc, unsigned int alwaysccaprate, Board *b, group_t group, Stone to_play, 90 | struct move_queue *q, Coord *ladder, bool middle_ladder, int tag) 91 | { 92 | Group *g = &b->_groups[group]; 93 | Stone color = g->color; 94 | // Coord lib = board_group_info(b, group).lib[0]; 95 | Coord lib = get_nlibs_of_group(b, group, 1, NULL); 96 | 97 | assert(color != S_OFFBOARD && color != S_NONE); 98 | /* 99 | if (DEBUGL(5)) 100 | fprintf(stderr, "[%s] atariiiiiiiii %s of color %d\n", 101 | coord2sstr(group, b), coord2sstr(lib, b), color); 102 | */ 103 | assert(board_at(b, lib) == S_NONE); 104 | 105 | if (to_play != color) { 106 | /* We are the attacker! In that case, do not try defending 107 | * our group, since we can capture the culprit. */ 108 | #ifdef NO_DOOMED_GROUPS 109 | /* Do not remove group that cannot be saved by the opponent. */ 110 | if (!can_be_rescued(b, group, color, tag)) 111 | return; 112 | #endif 113 | if (can_play_on_lib(b, group, to_play)) { 114 | mq_add(q, lib, tag); 115 | mq_nodup(q); 116 | } 117 | return; 118 | } 119 | 120 | /* Can we capture some neighbor? */ 121 | bool ccap = can_countercapture(b, color, group, to_play, q, tag); 122 | if (ccap && !ladder && alwaysccaprate > randfunc(context, 100)) 123 | return; 124 | 125 | /* Otherwise, do not save kos. */ 126 | if (g->stones == 1 && neighbor_count_at(b, lib, color) + neighbor_count_at(b, lib, S_OFFBOARD) == 4) { 127 | /* Except when the ko is for an eye! */ 128 | bool eyeconnect = false; 129 | FORDIAG4(lib, _, c) { 130 | if (board_at(b, c) == S_NONE && neighbor_count_at(b, c, color) + neighbor_count_at(b, c, S_OFFBOARD) == 4) { 131 | eyeconnect = true; 132 | break; 133 | } 134 | } ENDFORDIAG4 135 | if (!eyeconnect) 136 | return; 137 | } 138 | 139 | /* Do not suicide... 
*/ 140 | if (!can_play_on_lib(b, group, to_play)) 141 | return; 142 | /* 143 | if (DEBUGL(6)) 144 | fprintf(stderr, "...escape route valid\n"); 145 | */ 146 | 147 | /* ...or play out ladders (unless we can counter-capture anytime). */ 148 | if (!ccap) { 149 | if (is_ladder(b, lib, group, middle_ladder)) { 150 | /* Sometimes we want to keep the ladder move in the 151 | * queue in order to discourage it. */ 152 | if (!ladder) 153 | return; 154 | else 155 | *ladder = lib; 156 | } 157 | } 158 | 159 | mq_add(q, lib, tag); 160 | mq_nodup(q); 161 | } 162 | -------------------------------------------------------------------------------- /pachi_tactics/tactics/1lib.h: -------------------------------------------------------------------------------- 1 | #ifndef PACHI_TACTICS_1LIB_H 2 | #define PACHI_TACTICS_1LIB_H 3 | 4 | /* One-liberty tactical checks (i.e. dealing with atari situations). */ 5 | 6 | #include "../../board/board.h" 7 | #include "../board_interface.h" 8 | #include "../mq.h" 9 | 10 | /* For given atari group @group owned by @owner, decide if @to_play 11 | * can save it / keep it in danger by dealing with one of the 12 | * neighboring groups. */ 13 | bool can_countercapture(Board *b, Stone owner, group_t g, 14 | Stone to_play, struct move_queue *q, int tag); 15 | 16 | /* Examine given group in atari, suggesting suitable moves for player 17 | * @to_play to deal with it (rescuing or capturing it). */ 18 | /* ladder != NULL implies to always enqueue all relevant moves. 
*/ 19 | void group_atari_check(void *context, RandFunc randfunc, unsigned int alwaysccaprate, Board *b, group_t group, Stone to_play, 20 | struct move_queue *q, Coord *ladder, bool middle_ladder, int tag); 21 | 22 | #endif 23 | -------------------------------------------------------------------------------- /pachi_tactics/tactics/2lib.h: -------------------------------------------------------------------------------- 1 | #ifndef PACHI_TACTICS_2LIB_H 2 | #define PACHI_TACTICS_2LIB_H 3 | 4 | /* Two-liberty tactical checks (i.e. dealing with two-step capturing races, 5 | * preventing atari). */ 6 | 7 | #include "../../board/board.h" 8 | #include "../board_interface.h" 9 | 10 | struct move_queue; 11 | 12 | void can_atari_group(void *context, RandFunc randfunc, Board *b, group_t group, Stone owner, Stone to_play, struct move_queue *q, int tag, bool use_def_no_hopeless); 13 | void group_2lib_check(void *context, RandFunc randfunc, Board *b, group_t group, Stone to_play, struct move_queue *q, int tag, bool use_miaisafe, bool use_def_no_hopeless); 14 | 15 | #endif 16 | -------------------------------------------------------------------------------- /pachi_tactics/tactics/TARGETS: -------------------------------------------------------------------------------- 1 | cpp_library( 2 | name = 'pachi_tactics_c', 3 | srcs = [ 4 | '1lib.c', 5 | '2lib.c', 6 | 'nlib.c', 7 | 'selfatari.c', 8 | 'ladder.c', 9 | 'nakade.c' 10 | ], 11 | deps = [ 12 | '@/experimental/deeplearning/yuandong/go/board:board_c', 13 | ], 14 | ) 15 | 16 | 17 | -------------------------------------------------------------------------------- /pachi_tactics/tactics/ladder.h: -------------------------------------------------------------------------------- 1 | #ifndef PACHI_TACTICS_LADDER_H 2 | #define PACHI_TACTICS_LADDER_H 3 | 4 | /* Reading ladders. 
*/ 5 | #include "../../board/board.h" 6 | #include "../board_interface.h" 7 | 8 | /* Check if escaping on this liberty by given group in atari would play out 9 | * a simple ladder. */ 10 | /* Two ways of ladder reading can be enabled separately; simple first-line 11 | * ladders and trivial middle-board ladders. */ 12 | static bool is_ladder(Board *b, Coord coord, group_t laddered, bool test_middle); 13 | 14 | /* Check if a 2-lib group of color @lcolor escaping at @escapelib would be 15 | * caught in a ladder given opponent stone at @chaselib. */ 16 | bool wouldbe_ladder(Board *b, group_t group, Coord escapelib, Coord chaselib, Stone lcolor); 17 | 18 | bool is_border_ladder(Board *b, Coord coord, Stone lcolor); 19 | bool is_middle_ladder(Board *b, Coord coord, group_t group, Stone lcolor); 20 | static inline bool 21 | is_ladder(Board *b, Coord coord, group_t laddered, bool test_middle) 22 | { 23 | Stone lcolor = b->_groups[laddered].color; 24 | 25 | /* 26 | if (DEBUGL(6)) 27 | fprintf(stderr, "ladder check - does %s play out %s's laddered group %s?\n", 28 | coord2sstr(coord, b), stone2str(lcolor), coord2sstr(laddered, b)); 29 | */ 30 | 31 | /* First, special-case first-line "ladders". This is a huge chunk 32 | * of ladders we actually meet and want to play. 
*/ 33 | if (neighbor_count_at(b, coord, S_OFFBOARD) == 1 34 | && neighbor_count_at(b, coord, lcolor) == 1) { 35 | bool l = is_border_ladder(b, coord, lcolor); 36 | // if (DEBUGL(6)) fprintf(stderr, "border ladder solution: %d\n", l); 37 | return l; 38 | } 39 | 40 | bool l = test_middle && is_middle_ladder(b, coord, laddered, lcolor); 41 | // if (DEBUGL(6)) fprintf(stderr, "middle ladder solution: %d\n", l); 42 | return l; 43 | } 44 | 45 | #endif 46 | -------------------------------------------------------------------------------- /pachi_tactics/tactics/nakade.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #define DEBUG 6 | #include "nakade.h" 7 | 8 | Coord 9 | nakade_point(Board *b, Coord around, Stone color) 10 | { 11 | /* First, examine the nakade area. For sure, it must be at most 12 | * six points. And it must be within color group(s). */ 13 | #define NAKADE_MAX 6 14 | Coord area[NAKADE_MAX]; int area_n = 0; 15 | 16 | area[area_n++] = around; 17 | 18 | for (int i = 0; i < area_n; i++) { 19 | FOR4(area[i], _, c) { 20 | if (board_at(b, c) == OPPONENT(color)) 21 | return M_PASS; 22 | if (board_at(b, c) == S_NONE) { 23 | bool dup = false; 24 | for (int j = 0; j < area_n; j++) 25 | if (c == area[j]) { 26 | dup = true; 27 | break; 28 | } 29 | if (dup) continue; 30 | 31 | if (area_n >= NAKADE_MAX) { 32 | /* Too large nakade area. */ 33 | return M_PASS; 34 | } 35 | area[area_n++] = c; 36 | } 37 | } ENDFOR4 38 | } 39 | 40 | /* We also collect adjecency information - how many neighbors 41 | * we have for each area point, and histogram of this. This helps 42 | * us verify the appropriate bulkiness of the shape. 
*/ 43 | int neighbors[area_n]; int ptbynei[9] = {area_n, 0}; 44 | memset(neighbors, 0, sizeof(neighbors)); 45 | for (int i = 0; i < area_n; i++) { 46 | for (int j = i + 1; j < area_n; j++) 47 | if (NEIGHBOR4(area[i], area[j])) { 48 | ptbynei[neighbors[i]]--; 49 | neighbors[i]++; 50 | ptbynei[neighbors[i]]++; 51 | ptbynei[neighbors[j]]--; 52 | neighbors[j]++; 53 | ptbynei[neighbors[j]]++; 54 | } 55 | } 56 | 57 | /* For each given neighbor count, arbitrary one coordinate 58 | * featuring that. */ 59 | Coord coordbynei[9]; 60 | for (int i = 0; i < area_n; i++) 61 | coordbynei[neighbors[i]] = area[i]; 62 | 63 | switch (area_n) { 64 | case 1: return M_PASS; 65 | case 2: return M_PASS; 66 | case 3: assert(ptbynei[2] == 1); 67 | return coordbynei[2]; // middle point 68 | case 4: if (ptbynei[3] != 1) return M_PASS; // long line 69 | return coordbynei[3]; // tetris four 70 | case 5: if (ptbynei[3] == 1 && ptbynei[1] == 1) return coordbynei[3]; // bulky five 71 | if (ptbynei[4] == 1) return coordbynei[4]; // cross five 72 | return M_PASS; // long line 73 | case 6: if (ptbynei[4] == 1 && ptbynei[2] == 3) 74 | return coordbynei[4]; // rabbity six 75 | return M_PASS; // anything else 76 | default: assert(0); 77 | } 78 | // This should never be reached. 79 | return M_PASS; 80 | } 81 | -------------------------------------------------------------------------------- /pachi_tactics/tactics/nakade.h: -------------------------------------------------------------------------------- 1 | #ifndef PACHI_TACTICS_NAKADE_H 2 | #define PACHI_TACTICS_NAKADE_H 3 | 4 | /* Piercing eyes. */ 5 | 6 | #include "../../board/board.h" 7 | #include "../board_interface.h" 8 | 9 | /* Find an eye-piercing point within the @around area of empty board 10 | * internal to group of color @color. 11 | * Returns pass if the area is not a nakade shape or not internal. 
*/ 12 | Coord nakade_point(Board *b, Coord around, Stone color); 13 | 14 | #endif 15 | -------------------------------------------------------------------------------- /pachi_tactics/tactics/nlib.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #define DEBUG 6 | #include "../../board/board.h" 7 | #include "../mq.h" 8 | #include "2lib.h" 9 | #include "nlib.h" 10 | #include "selfatari.h" 11 | 12 | 13 | void 14 | group_nlib_defense_check(void *context, RandFunc randfunc, Board *b, group_t group, Stone to_play, struct move_queue *q, int tag) 15 | { 16 | Stone color = to_play; 17 | assert(color != S_OFFBOARD && color != S_NONE && color == b->_groups[group].color); 18 | 19 | /* 20 | if (DEBUGL(5)) 21 | fprintf(stderr, "[%s] nlib defense check of color %d\n", 22 | coord2sstr(group, b), color); 23 | */ 24 | 25 | #if 0 26 | /* XXX: The code below is specific for 3-liberty groups. Its impact 27 | * needs to be tested first, and possibly moved to a more appropriate 28 | * place. */ 29 | 30 | /* First, look at our liberties. */ 31 | int continuous = 0, enemy = 0, spacy = 0, eyes = 0; 32 | for (int i = 0; i < 3; i++) { 33 | coord_t c = board_group_info(b, group).lib[i]; 34 | eyes += board_is_one_point_eye(b, c, to_play); 35 | continuous += coord_is_adjecent(c, board_group_info(b, group).lib[(i + 1) % 3], b); 36 | enemy += neighbor_count_at(b, c, stone_other(color)); 37 | spacy += immediate_liberty_count(b, c) > 1; 38 | } 39 | 40 | /* Safe groups are boring. */ 41 | if (eyes > 1) 42 | return; 43 | 44 | /* If all our liberties are in single line and they are internal, 45 | * this is likely a tiny three-point eyespace that we rather want 46 | * to live at! 
*/ 47 | assert(continuous < 3); 48 | if (continuous == 2 && !enemy && spacy == 1) { 49 | assert(!eyes); 50 | int i; 51 | for (i = 0; i < 3; i++) 52 | if (immediate_liberty_count(b, board_group_info(b, group).lib[i]) == 2) 53 | break; 54 | /* Play at middle point. */ 55 | mq_add(q, board_group_info(b, group).lib[i], tag); 56 | mq_nodup(q); 57 | return; 58 | } 59 | #endif 60 | 61 | /* "Escaping" (gaining more liberties) with many-liberty group 62 | * is difficult. Do not even try. */ 63 | 64 | /* There is another way to gain safety - through winning semeai 65 | * with another group. */ 66 | /* We will not look at taking liberties of enemy n-groups, since 67 | * we do not try to gain liberties for own n-groups. That would 68 | * be really unbalanced (and most of our liberty-taking moves 69 | * would be really stupid, most likely). */ 70 | 71 | /* However, it is possible that we must start capturing a 2-lib 72 | * neighbor right now, because of approach liberties. Therefore, 73 | * we will check for this case. If we take a liberty of a group 74 | * even if we could have waited another move, no big harm done 75 | * either. */ 76 | 77 | TRAVERSE(b, group, gg) { 78 | FOR4(gg, _, c) { 79 | if (board_at(b, c) != OPPONENT(color)) 80 | continue; 81 | 82 | group_t g2_id = b->_infos[c].id; 83 | Group *g2 = &b->_groups[g2_id]; 84 | if (g2->liberties != 2) continue; 85 | can_atari_group(context, randfunc, b, g2_id, OPPONENT(color), to_play, q, tag, true /* XXX */); 86 | } ENDFOR4 87 | } ENDTRAVERSE 88 | } 89 | -------------------------------------------------------------------------------- /pachi_tactics/tactics/nlib.h: -------------------------------------------------------------------------------- 1 | #ifndef PACHI_TACTICS_NLIB_H 2 | #define PACHI_TACTICS_NLIB_H 3 | 4 | /* N-liberty semeai defense tactical checks. 
*/ 5 | 6 | #include "../../board/board.h" 7 | 8 | struct move_queue; 9 | 10 | void group_nlib_defense_check(void *context, RandFunc func, Board *b, group_t group, Stone to_play, struct move_queue *q, int tag); 11 | 12 | #endif 13 | -------------------------------------------------------------------------------- /pachi_tactics/tactics/selfatari.h: -------------------------------------------------------------------------------- 1 | #ifndef PACHI_TACTICS_SELFATARI_H 2 | #define PACHI_TACTICS_SELFATARI_H 3 | 4 | /* A fairly reliable elf-atari detector. */ 5 | 6 | #include "../../board/board.h" 7 | 8 | /* Check if this move is undesirable self-atari (resulting group would have 9 | * only single liberty and not capture anything; ko is allowed); we mostly 10 | * want to avoid these moves. The function actually does a rather elaborate 11 | * tactical check, allowing self-atari moves that are nakade, eye falsification 12 | * or throw-ins. */ 13 | static bool is_bad_selfatari(Board *b, Stone color, Coord to); 14 | 15 | /* Move (color, coord) is a selfatari; this means that it puts a group of 16 | * ours in atari; i.e., the group has two liberties now. Return the other 17 | * liberty of such a troublesome group (optionally stored at *bygroup) 18 | * if that one is not a self-atari. 19 | * (In case (color, coord) is a multi-selfatari, consider a randomly chosen 20 | * candidate.) */ 21 | Coord selfatari_cousin(void *context, RandFunc randfunc, Board *b, Stone color, Coord coord, group_t *bygroup); 22 | 23 | 24 | bool is_bad_selfatari_slow(Board *b, Stone color, Coord to); 25 | static inline bool 26 | is_bad_selfatari(Board *b, Stone color, Coord to) 27 | { 28 | /* More than one immediate liberty, thumbs up! 
*/ 29 | if (immediate_liberty_count(b, to) > 1) 30 | return false; 31 | 32 | return is_bad_selfatari_slow(b, color, to); 33 | } 34 | 35 | #endif 36 | -------------------------------------------------------------------------------- /pachi_tactics/util.h: -------------------------------------------------------------------------------- 1 | #ifndef PACHI_TACTICS_UTIL_H 2 | #define PACHI_TACTICS_UTIL_H 3 | 4 | /* Advanced tactical checks non-essential to the board implementation. */ 5 | 6 | #include "../board/board.h" 7 | #include "board_interface.h" 8 | 9 | #define coord_dx(c1, c2) (X(c1) - X(c2)) 10 | #define coord_dy(c1, c2) (Y(c1) - Y(c2)) 11 | 12 | struct move_queue; 13 | typedef float floating_t; 14 | 15 | /* Measure various distances on the board: */ 16 | /* Distance from the edge; on edge returns 0. */ 17 | static int coord_edge_distance(Coord c, Board *b); 18 | /* Distance of two points in gridcular metric - this metric defines 19 | * circle-like structures on the square grid. */ 20 | static int coord_gridcular_distance(Coord c1, Coord c2, Board *b); 21 | 22 | /* Construct a "common fate graph" from given coordinate; that is, a weighted 23 | * graph of intersections where edges between all neighbors have weight 1, 24 | * but edges between neighbors of same color have weight 0. Thus, this is 25 | * "stone chain" metric in a sense. */ 26 | /* The output are distanes from start stored in given [board_size2()] array; 27 | * intersections further away than maxdist have all distance maxdist+1 set. */ 28 | void cfg_distances(Board *b, Coord start, int *distances, int maxdist); 29 | 30 | /* Compute an extra komi describing the "effective handicap" black receives 31 | * (returns 0 for even game with 7.5 komi). @stone_value is value of single 32 | * handicap stone, 7 is a good default. */ 33 | /* This is just an approximation since in reality, handicap seems to be usually 34 | * non-linear. 
*/ 35 | floating_t board_effective_handicap(Board *b, int first_move_value); 36 | 37 | /* To avoid running out of time, assume we always have at least 30 more moves 38 | * to play if we don't have more precise information from gtp time_left: */ 39 | #define MIN_MOVES_LEFT 30 40 | 41 | /* Tactical evaluation of move @coord by color @color, given 42 | * simulation end position @b. I.e., a move is tactically good 43 | * if the resulting group stays on board until the game end. 44 | * The value is normalized to [0,1]. */ 45 | /* We can also take into account surrounding stones, e.g. to 46 | * encourage taking off external liberties during a semeai. */ 47 | static double board_local_value(bool scan_neis, Board *b, Coord coord, Stone color); 48 | 49 | 50 | static inline int 51 | coord_edge_distance(Coord c, Board *b) 52 | { 53 | int x = X(c), y = Y(c); 54 | int dx = x > board_size(b) / 2 ? board_size(b) - 1 - x : x; 55 | int dy = y > board_size(b) / 2 ? board_size(b) - 1 - y : y; 56 | return (dx < dy ? dx : dy) - 1 /* S_OFFBOARD */; 57 | } 58 | 59 | static inline int 60 | coord_gridcular_distance(Coord c1, Coord c2, Board *b) 61 | { 62 | int dx = abs(coord_dx(c1, c2)), dy = abs(coord_dy(c1, c2)); 63 | return dx + dy + (dx > dy ? dx : dy); 64 | } 65 | 66 | static inline double 67 | board_local_value(bool scan_neis, Board *b, Coord coord, Stone color) 68 | { 69 | if (scan_neis) { 70 | /* Count surrounding friendly stones and our eyes. */ 71 | int friends = 0; 72 | FOR4(coord, _, c) { 73 | friends += board_at(b, c) == color || board_at(b, c) == S_OFFBOARD || IsTrueEye(b, c, color); 74 | } ENDFOR4 75 | return (double) (2 * (board_at(b, coord) == color) + friends) / 6.f; 76 | } else { 77 | return (board_at(b, coord) == color) ? 
1.f : 0.f; 78 | } 79 | } 80 | 81 | #endif 82 | -------------------------------------------------------------------------------- /train.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016-present, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | 10 | require 'torch' 11 | require 'cutorch' 12 | require 'nn' 13 | require 'cunn' 14 | require 'cudnn' 15 | require 'nngraph' 16 | 17 | require 'xlua' 18 | 19 | local framework = require 'train.rl_framework.infra.framework' 20 | local rl = require 'train.rl_framework.infra.env' 21 | local pl = require 'pl.import_into'() 22 | 23 | require 'train.rl_framework.infra.bundle' 24 | require 'train.rl_framework.infra.agent' 25 | 26 | local tnt = require 'torchnet' 27 | 28 | -- cutorch.setDevice(3) 29 | 30 | -- Build simple models. 31 | function build_policy_model(opt) 32 | local network_maker = require('train.rl_framework.examples.go.models.' .. opt.model_name) 33 | local network, crit, outputdim, monitor_list = network_maker({1, 25, 19, 19}, opt) 34 | return network:cuda(), crit:cuda() 35 | end 36 | 37 | local opt = pl.lapp[[ 38 | --actor (default "policy") 39 | --sampling (default "replay") 40 | --optim (default "supervised") 41 | --loss (default 'policy') 42 | --alpha (default 0.1) 43 | --nthread (default 8) 44 | --batchsize (default 256) 45 | --num_forward_models (default 4096) Number of forward models. 46 | --progress Whether to print the progress 47 | --epoch_size (default 12800) Epoch size 48 | --epoch_size_test (default 128000) Epoch size for test. 49 | --data_augmentation Whether to use data_augmentation 50 | 51 | --nGPU (default 1) Number of GPUs to use. 52 | --nstep (default 3) Number of steps. 
53 | --model_name (default 'model-12-parallel-384-n-output-bn') 54 | --datasource (default 'kgs') 55 | --feature_type (default 'extended') 56 | ]] 57 | 58 | opt.userank = true 59 | opt.intermediate_step = opt.epoch_size / opt.batchsize / 10 60 | print(pl.pretty.write(opt)) 61 | 62 | local model, crits = build_policy_model(opt) 63 | 64 | local bundle = rl.Bundle{ 65 | models = { 66 | policy = model, 67 | }, 68 | crits = crits 69 | } 70 | 71 | local agent = rl.Agent{ 72 | bundle = bundle, 73 | opt = opt 74 | } 75 | 76 | local stats = { 77 | sgf_idx = { }, 78 | board_freq = torch.FloatTensor(19, 19):zero(), 79 | ply = { }, 80 | count = 0 81 | } 82 | 83 | local callbacks = { 84 | thread_init = function() 85 | require 'train.rl_framework.examples.go.ParallelCriterion2' 86 | end, 87 | forward_model_init = function(partition) 88 | local tnt = require 'torchnet' 89 | return tnt.IndexedDataset{ 90 | fields = { opt.datasource .. "_" .. partition }, 91 | path = './dataset' 92 | } 93 | end, 94 | forward_model_generator = function(dataset, partition) 95 | local fm_go = require 'train.rl_framework.examples.go.fm_go' 96 | return fm_go.FMGo(dataset, partition, opt) 97 | end, 98 | onSample = function(state) 99 | -- Compute the stats. 
100 | --[[ 101 | if state.signature == 'train' then return end 102 | for i = 1, state.sample.sgf_idx:size(1) do 103 | local idx = state.sample.sgf_idx[i] 104 | if stats.sgf_idx[idx] == nil then stats.sgf_idx[idx] = 0 end 105 | stats.sgf_idx[idx] = stats.sgf_idx[idx] + 1 106 | 107 | local xy = state.sample.xy[i] 108 | local x = xy[1] 109 | local y = xy[2] 110 | 111 | stats.board_freq[x][y] = stats.board_freq[x][y] + 1 112 | stats.count = stats.count + 1 113 | 114 | local ply = state.sample.ply[i] 115 | if stats.ply[ply] == nil then stats.ply[ply] = 0 end 116 | stats.ply[ply] = stats.ply[ply] + 1 117 | end 118 | 119 | if stats.count % (2000 * opt.batchsize) == 0 then 120 | print(stats.board_freq:clone():mul(1.0 / stats.count)) 121 | require 'fb.debugger'.enter() 122 | end 123 | ]] 124 | end, 125 | --[[ 126 | onStartEpoch = function() 127 | print("In onStartEpoch") 128 | end, 129 | onStart = function() 130 | print("In onStart") 131 | end, 132 | onSample = function() 133 | print("In onSample") 134 | end, 135 | onUpdate = function() 136 | print("In onUpdate") 137 | end, 138 | onEndEpoch = function() 139 | print("In onEndEpoch") 140 | end 141 | ]] 142 | } 143 | 144 | -- callbacks: 145 | -- forward_model_generator 146 | -- checkpoint_filename(state, err): Get checkpoint filename 147 | -- tune_lr(state): tune the learning rate 148 | -- print(log, state): print the current state 149 | -- (All the remaining functions take state as input) 150 | -- onStartEpoch 151 | -- onStart 152 | -- onSample 153 | -- onUpdate 154 | -- For now just shortcut the trainloss/testloss. 155 | framework.run_rl(agent, callbacks, opt) 156 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2016-present, Facebook, Inc. 3 | # All rights reserved. 
4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. An additional grant 7 | # of patent rights can be found in the PATENTS file in the same directory. 8 | # 9 | 10 | #!/bin/bash -e 11 | th train.lua --nGPU 1 --datasource kgs --num_forward_models 2048 --nthread 4 --alpha 0.05 --epoch_size 128000 --data_augmentation 12 | -------------------------------------------------------------------------------- /train/README.md: -------------------------------------------------------------------------------- 1 | Training the policy network 2 | ================= 3 | 4 | In this directory, we implement a simple reinforcement learning framework using Lua/torch, and use 5 | it for training the policy network. 6 | 7 | 1. `./rl_framework/infra` A simple framework for reinforcement learning. 8 | 9 | 2. `./rl_framework/examples/go` The training code for the policy network. 10 | 11 | For simple usage, the main program for training is under the root directory. Simply copy the data 12 | from [here](https://www.dropbox.com/sh/ihzvzajywmfvbhm/AACIgYxew4daP1LXY_HCKwNla?dl=0) to `./dataset` and run `./train.sh` 13 | to start the training procedure, which is an implementation of our [paper](http://arxiv.org/abs/1511.06410) plus a few modifications, including adding batch-normalization layers. 14 | 15 | With 4 GPUs, the training procedure reaches 56.1% top-1 accuracy on the KGS dataset in 3.5 days, and 57.1% top-1 in 6.5 days (see the simple log below). The parameters used are the following: `--epoch_size 256000 --nGPU 4 --data_augmentation --alpha 0.1 --nthread 4` 16 | 17 |
18 | | Sun Aug 21 21:54:15 2016 | epoch 0001 | ms/batch 721 | train [1pi@1]: 11.230860 [1pi@5]: 30.617970 [3pi@1]: 3.099219 [3pi@5]: 14.042188 [2pi@5]: 18.935938 [2pi@1]: 4.482813 [policy]: 8.361849
19 | | test [1pi@1]: 27.767189 [1pi@5]: 59.403130 [3pi@1]: 5.380469 [3pi@5]: 24.729689 [2pi@5]: 34.030472 [2pi@1]: 8.382812 [policy]: 6.558414 | saved *
20 | 
21 | | Thu Aug 25 10:35:11 2016 | epoch 0381 | ms/batch 719 | train [1pi@1]: 56.226566 [1pi@5]: 87.523834 [3pi@1]: 21.542580 [3pi@5]: 51.992970 [2pi@5]: 68.728127 [2pi@1]: 34.199612 [policy]: 3.736506
22 | | test [1pi@1]: 56.124222 [1pi@5]: 87.432816 [3pi@1]: 21.600000 [3pi@5]: 52.107815 [2pi@5]: 68.922661 [2pi@1]: 34.421875 [policy]: 3.737540  | saved *
23 | 
24 | | Sun Aug 28 00:49:32 2016 | epoch 0661 | ms/batch 721 | train [1pi@1]: 57.075783 [1pi@5]: 88.215240 [3pi@1]: 22.512892 [3pi@5]: 53.472267 [2pi@5]: 70.093361 [2pi@1]: 35.576565 [policy]: 3.638625
25 | | test [1pi@1]: 57.101566 [1pi@5]: 88.271095 [3pi@1]: 22.295313 [3pi@5]: 53.226566 [2pi@5]: 70.085938 [2pi@1]: 35.185940 [policy]: 3.646803  | saved
26 | 
27 | 28 | Note that this is a general framework and could be used to train other tasks (e.g, value networks) in the future. If you have used our engine or training procedure, please cite the following paper: 29 | 30 | ``` 31 | Better Computer Go Player with Neural Network and Long-term Prediction, ICLR 2016 32 | Yuandong Tian, Yan Zhu 33 | 34 | @article{tian2015better, 35 | title={Better Computer Go Player with Neural Network and Long-term Prediction}, 36 | author={Tian, Yuandong and Zhu, Yan}, 37 | journal={arXiv preprint arXiv:1511.06410}, 38 | year={2015} 39 | } 40 | ``` 41 | 42 | 43 | -------------------------------------------------------------------------------- /train/rl_framework/examples/go/ParallelCriterion2.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016-present, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | 10 | local ParallelCriterion, parent = torch.class('nn.ParallelCriterion2', 'nn.Criterion') 11 | 12 | function ParallelCriterion:__init(rep_count) 13 | parent.__init(self) 14 | self.criterions = {} 15 | self.weights = {} 16 | self.gradInput = {} 17 | -- If rep_count == 2, and #criterions = 10, we expect 5 targets. 
18 | -- The 10 criterions will receive the following targets in order: t1, t2, t3, t4, t5, t1, t2, t3, t4, t5 19 | self.rep_count = rep_count or 1 20 | end 21 | 22 | function ParallelCriterion:add(criterion, weight) 23 | weight = weight or 1 24 | table.insert(self.criterions, criterion) 25 | table.insert(self.weights, weight) 26 | return self 27 | end 28 | 29 | function ParallelCriterion:updateOutput(input, target) 30 | self.output = 0 31 | assert(#self.criterions == target:size(2) * self.rep_count, 32 | string.format("ParallelCriterion2: #criterions [%d] != #target [%d] * #repcount [%d]", #self.criterions, target:size(2), self.rep_count)) 33 | for i,criterion in ipairs(self.criterions) do 34 | -- Target size is nbatch x #targets, which is more suitable for torchnet setting. 35 | local target_idx = (i - 1) % target:size(2) + 1 36 | local target = target:select(2, target_idx) 37 | self.output = self.output + self.weights[i]*criterion:updateOutput(input[i],target) 38 | end 39 | return self.output 40 | end 41 | 42 | function ParallelCriterion:updateGradInput(input, target) 43 | self.gradInput = nn.utils.recursiveResizeAs(self.gradInput, input) 44 | nn.utils.recursiveFill(self.gradInput, 0) 45 | for i,criterion in ipairs(self.criterions) do 46 | local target_idx = (i - 1) % target:size(2) + 1 47 | local target = target:select(2, target_idx) 48 | nn.utils.recursiveAdd(self.gradInput[i], self.weights[i], criterion:updateGradInput(input[i], target)) 49 | end 50 | return self.gradInput 51 | end 52 | 53 | function ParallelCriterion:type(type) 54 | self.gradInput = {} 55 | return parent.type(self, type) 56 | end 57 | 58 | -------------------------------------------------------------------------------- /train/rl_framework/examples/go/models/model-12-parallel-384-n-output-bn.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'cudnn' 3 | local pl = require('pl.import_into')() 4 | local nnutils = require
'utils.nnutils' 5 | require 'train.rl_framework.examples.go.ParallelCriterion2' 6 | 7 | -- Specification. 8 | local function get_network_spec(n) 9 | return { 10 | -- input is 19x19x? 11 | {type='conv', kw=5, dw=1, pw=2, nop=92}, 12 | {type='relu'}, 13 | {type='spatialbn'}, 14 | -- No max pooling, does not make sense. 15 | -- {type='maxp', kw=3, dw=2}, 16 | 17 | {type='conv', kw=3, dw=1, pw=1, nop=384}, 18 | {type='relu'}, 19 | {type='spatialbn'}, 20 | 21 | -- {type='maxp', kw=3, dw=2}, 22 | {type='conv', kw=3, dw=1, pw=1, nop=384}, 23 | {type='relu'}, 24 | {type='spatialbn'}, 25 | 26 | {type='conv', kw=3, dw=1, pw=1, nop=384}, 27 | {type='relu'}, 28 | {type='spatialbn'}, 29 | 30 | -- {type='maxp', kw=3, dw=2}, 31 | {type='conv', kw=3, dw=1, pw=1, nop=384}, 32 | {type='relu'}, 33 | {type='spatialbn'}, 34 | 35 | {type='conv', kw=3, dw=1, pw=1, nop=384}, 36 | {type='relu'}, 37 | {type='spatialbn'}, 38 | 39 | {type='conv', kw=3, dw=1, pw=1, nop=384}, 40 | {type='relu'}, 41 | {type='spatialbn'}, 42 | 43 | {type='conv', kw=3, dw=1, pw=1, nop=384}, 44 | {type='relu'}, 45 | {type='spatialbn'}, 46 | 47 | -- {type='maxp', kw=3, dw=2}, 48 | {type='conv', kw=3, dw=1, pw=1, nop=384}, 49 | {type='relu'}, 50 | {type='spatialbn'}, 51 | 52 | {type='conv', kw=3, dw=1, pw=1, nop=384}, 53 | {type='relu'}, 54 | {type='spatialbn'}, 55 | 56 | {type='conv', kw=3, dw=1, pw=1, nop=384}, 57 | {type='relu'}, 58 | {type='spatialbn'}, 59 | 60 | {type='conv', kw=3, dw=1, pw=1, nop=n}, 61 | } 62 | end 63 | 64 | return function(inputdim, config) 65 | assert(inputdim[3] == 19) 66 | assert(inputdim[4] == 19) 67 | 68 | local net, outputdim = nnutils.make_network(get_network_spec(config.nstep), inputdim) 69 | if config.nGPU>1 then 70 | require 'cutorch' 71 | assert(config.nGPU <= cutorch.getDeviceCount(), 'number of GPUs less than config.nGPU specified') 72 | local net_single = net 73 | net = nn.DataParallel(1) 74 | for i=1, config.nGPU do 75 | cutorch.withDevice(i, function() 76 | 
net:add(net_single:clone()) 77 | end) 78 | end 79 | end 80 | 81 | local model = nn.Sequential() 82 | model:add(net):add(nn.View(config.nstep, 19*19):setNumInputDims(3)):add(nn.SplitTable(1, 2)) 83 | local softmax = nn.Sequential() 84 | -- softmax:add(nn.Reshape(19*19, true)) 85 | softmax:add(nn.LogSoftMax()) 86 | -- )View(-1):setNumInputDims(2)) 87 | 88 | local softmaxs = nn.ParallelTable() 89 | -- Use self-defined parallel criterion 2, which can handle targets of the format nbatch * #target 90 | local criterions = nn.ParallelCriterion2() 91 | for k = 1, config.nstep do 92 | softmaxs:add(softmax:clone()) 93 | local w = 1.0 / k 94 | criterions:add(nn.ClassNLLCriterion(), w) 95 | end 96 | model:add(softmaxs) 97 | return model, criterions, outputdim 98 | end 99 | -------------------------------------------------------------------------------- /train/rl_framework/infra/engine.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016-present, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | local doc = require 'argcheck.doc' 13 | 14 | doc[[ 15 | 16 | ### rl.Engine 17 | 18 | The Engine module implements the training procedure for reinforcement learning. 19 | 20 | The training procedure in `train` covers data sampling, forward prop, back prop, and 21 | parameter updates. It also calls user-registered hooks (e.g. to update some sort of `tnt.Meter`) 22 | at events such as 'onStart', 'onStartEpoch', 'onSample', 'onUpdate', 23 | 'onEndEpoch', 'onEnd', and 'onForward' (called by `test`).
24 | 25 | Accordingly, `train` requires an agent (rl.Agent), a dataset iterator 26 | (tnt.DatasetIterator), and an options table (`opt`; its `maxepoch` and `lr` fields 27 | are read), at the minimum. The `test` function allows for simple evaluation 28 | of an agent on a dataset. 29 | 30 | A `state` is maintained for external access to outputs and parameters of modules 31 | as well as sampled data. 32 | ]] 33 | 34 | require 'nn' 35 | 36 | local rl = require 'train.rl_framework.infra.env' 37 | local RLEngine, Engine = torch.class('rl.Engine', 'tnt.Engine', rl) 38 | 39 | RLEngine.__init = argcheck{ 40 | {name="self", type="rl.Engine"}, 41 | call = 42 | function(self) 43 | Engine.__init(self, { 44 | "onStart", "onStartEpoch", "onSample", "onForward", 45 | "onEndEpoch", "onUpdate", "onEnd" 46 | }) 47 | end 48 | } 49 | 50 | local function clear_errs(state) 51 | state.errs = nil 52 | state.errs_count = nil 53 | end 54 | 55 | local function accumulate_errs(state, errs) 56 | if not state.errs then 57 | state.errs = { } 58 | state.errs_count = 0 59 | end 60 | -- require 'fb.debugger'.enter() 61 | for k, e in pairs(errs) do 62 | if type(e) == 'number' then 63 | e = torch.FloatTensor({e}) 64 | end 65 | 66 | if state.errs[k] == nil then 67 | state.errs[k] = e 68 | else 69 | state.errs[k]:add(e) 70 | end 71 | end 72 | state.errs_count = state.errs_count + 1 73 | end 74 | 75 | RLEngine.train = argcheck{ 76 | {name="self", type="rl.Engine"}, 77 | {name="agent", type="rl.Agent"}, 78 | {name="iterator", type="tnt.DatasetIterator"}, 79 | {name="opt", type="table"}, 80 | call = 81 | function(self, agent, iterator, opt) 82 | -- assert(opt.lr, "Learning rate has to be set") 83 | local state = { 84 | agent = agent, 85 | iterator = iterator, 86 | maxepoch = opt.maxepoch or 1000, 87 | sample = {}, 88 | epoch = 0, -- epoch done so far 89 | lr = opt.lr, 90 | training = true, 91 | } 92 | 93 | local function update_sampling_before() 94 | -- If the sampler uses multiple threads, we need
to call synchronize() to make sure 95 | -- when the sampling model is being updated, all the samplers are not using it. 96 | if iterator.__threads then 97 | iterator.__threads:synchronize() 98 | end 99 | end 100 | 101 | self.hooks("onStart", state) 102 | while state.epoch < state.maxepoch do 103 | state.agent:training() 104 | clear_errs(state) 105 | state.t = 0 106 | self.hooks("onStartEpoch", state) 107 | 108 | for sample in state.iterator() do 109 | state.sample = sample 110 | self.hooks("onSample", state) 111 | 112 | -- This includes forward/backward and parameter update. 113 | -- Different RL will use different approaches. 114 | local errs = state.agent:optimize(sample) 115 | accumulate_errs(state, errs) 116 | 117 | state.t = state.t + 1 118 | self.hooks("onUpdate", state) 119 | end 120 | 121 | -- Update the sampling model. 122 | -- state.agent:update_sampling_model(update_sampling_before) 123 | state.agent:update_sampling_model() 124 | 125 | state.epoch = state.epoch + 1 126 | self.hooks("onEndEpoch", state) 127 | end 128 | self.hooks("onEnd", state) 129 | return state 130 | end 131 | } 132 | 133 | RLEngine.test = argcheck{ 134 | {name="self", type="rl.Engine"}, 135 | {name="agent", type="rl.Agent"}, 136 | {name="iterator", type="tnt.DatasetIterator"}, 137 | call = function(self, agent, iterator) 138 | local state = { 139 | agent = agent, 140 | iterator = iterator, 141 | sample = {}, 142 | t = 0, -- samples seen so far 143 | training = false 144 | } 145 | 146 | self.hooks("onStart", state) 147 | state.agent:evaluate() 148 | for sample in state.iterator() do 149 | state.sample = sample 150 | self.hooks("onSample", state) 151 | local errs = state.agent:optimize(sample) 152 | accumulate_errs(state, errs) 153 | 154 | state.t = state.t + 1 155 | self.hooks("onForward", state) 156 | end 157 | self.hooks("onEnd", state) 158 | return state 159 | end 160 | } 161 | -------------------------------------------------------------------------------- 
/train/rl_framework/infra/env.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016-present, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | 10 | local rl = { } 11 | 12 | require 'torch' 13 | 14 | function rl.func_lookup(v, func_table) 15 | return type(v) == 'function' and v or func_table[v] 16 | end 17 | 18 | return rl 19 | -------------------------------------------------------------------------------- /train/rl_framework/infra/forwardmodel.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016-present, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | 10 | local argcheck = require 'argcheck' 11 | local doc = require 'argcheck.doc' 12 | 13 | local pl = require 'pl.import_into'() 14 | 15 | doc[[ 16 | ### rl.ForwardModel 17 | The ForwardModel module implements the forward model. The application should just override this class to define behavior. 18 | The model could be: 19 | * A simulated environment (any action is possible, real experience) 20 | * A fitted forward model (any action possible) 21 | * Previous experience (arbitrary action is not possible) 22 | 23 | Users should inherit from this class and override its methods.
24 | ]] 25 | 26 | require 'nn' 27 | 28 | local rl = require 'train.rl_framework.infra.env' 29 | local ForwardModel = torch.class('rl.ForwardModel', rl) 30 | 31 | ForwardModel.__init = argcheck{ 32 | {name="self", type="rl.ForwardModel"}, 33 | call = function(self) 34 | end 35 | } 36 | 37 | -- Reset to a (new, maybe random) state. 38 | ForwardModel.reset = argcheck{ 39 | {name="self", type="rl.ForwardModel"}, 40 | call = function(self) 41 | error("ForwardModel.reset is not implemented") 42 | end 43 | } 44 | 45 | -- forward(a) -> s', r 46 | ForwardModel.forward = argcheck{ 47 | {name="self", type="rl.ForwardModel"}, 48 | {name="action", type="number"}, 49 | call = function(self, action) 50 | error("ForwardModel.forward is not implemented") 51 | end 52 | } 53 | 54 | ForwardModel.add_sample = argcheck{ 55 | {name="self", type="rl.ForwardModel"}, 56 | {name="entry", type="table"}, 57 | call = function(self, entry) 58 | error("ForwardModel.add_sample is not implemented") 59 | end 60 | } 61 | 62 | -- get_actions() -> get available actions for current state (by returning a list) 63 | ForwardModel.get_actions = argcheck{ 64 | {name="self", type="rl.ForwardModel"}, 65 | call = function(self) 66 | error("ForwardModel.get_actions is not implemented") 67 | end 68 | } 69 | 70 | -- Get the representation of the current state. 71 | ForwardModel.get_curr_state_rep = argcheck{ 72 | {name="self", type="rl.ForwardModel"}, 73 | call = function(self) 74 | error("ForwardModel.get_curr_state_rep is not implemented") 75 | end 76 | } 77 | 78 | --[[ 79 | ForwardModel.clone = argcheck{ 80 | {name="self", type="rl.ForwardModel"}, 81 | call = function(self) 82 | error("ForwardModel.clone is not implemented") 83 | end 84 | } 85 | ]] 86 | -------------------------------------------------------------------------------- /tsumego/rank_move.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 
3 | // All rights reserved. 4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory. 8 | // 9 | 10 | #ifndef _RANK_MOVE_H_ 11 | #define _RANK_MOVE_H_ 12 | 13 | #include "../board/board.h" 14 | #include 15 | 16 | void GetRankedMoves(const Board* board, Stone defender, const Region *r, int max_num_moves, AllMoves *all_moves); 17 | 18 | // Save the candidate move to a file. 19 | // The features of each move are printed on one line, separated by commas. 20 | // Usually a good move has score 1 and a bad move has score 0. 21 | void SaveMoveFeatureName(FILE *fp); 22 | BOOL SaveMoveWithFeature(const Board *board, Stone defender, Coord m, int score, FILE *fp); 23 | 24 | #endif 25 | -------------------------------------------------------------------------------- /tsumego/solver.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 3 | // All rights reserved. 4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory. 8 | // 9 | 10 | #ifndef _SOLVER_H_ 11 | #define _SOLVER_H_ 12 | 13 | #include "../board/board.h" 14 | 15 | // Currently we only have one type. 16 | #define TG_LIVE_DIES 0 17 | 18 | // Several criteria for tsumego. 19 | // 1. w/b lives. 20 | // There is one w/b group which has 2 liberties and no candidate move for b/w can touch the group. 21 | // 2. w/b dead. 22 | // w/b loses too many stones (more than the threshold). 23 | 24 | typedef struct { 25 | // Goal of this search. 26 | // E.g., player = WHITE, then if w lives, -10, w dead, 10, otherwise keep searching.
27 | // player = BLACK, then if b dead, -10, b lives, +10. otherwise keep searching. 28 | Stone target_player; 29 | 30 | // The threshold for dead. 31 | int dead_thres; 32 | 33 | Region region; 34 | 35 | // Maximum count of the search. -1 means infinite. 36 | int max_count; 37 | } TGCriterion; 38 | 39 | // Solve a given tsumego, located at Region, by doing an exhaustive search. 40 | int TsumegoSearch(const Board *board, const TGCriterion *criterion, AllMoves *move_seq); 41 | 42 | #endif 43 | -------------------------------------------------------------------------------- /tsumego/solver.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016-present, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 
8 | -- 9 | 10 | local ffi = require 'ffi' 11 | local pl = require 'pl.import_into'() 12 | local utils = require('utils.utils') 13 | local common = require("common.common") 14 | local board = require("board.board") 15 | 16 | local image = require 'image' 17 | 18 | -- local symbols, s = utils.ffi_include(paths.concat(common.lib_path, "tsumego/solver.h")) 19 | local symbols, s = utils.ffi_include(paths.concat(common.script_path(), "solver.h")) 20 | local C = ffi.load("libexperimental_deeplearning_yuandong_go_tsumego_solver_c.so") 21 | 22 | -- print(s) 23 | local tg = {} 24 | 25 | local black_lives = tonumber(symbols.BLACK_LIVES) 26 | local black_dies = tonumber(symbols.BLACK_DIES) 27 | local white_lives = tonumber(symbols.WHITE_LIVES) 28 | local white_dies = tonumber(symbols.WHITE_DIES) 29 | local die_at_loc = tonumber(symbols.DIE_AT_LOC) 30 | 31 | local all_moves = ffi.new("AllMoves") 32 | 33 | --[[ 34 | function tg.set_die_at_crit(x, y) 35 | -- Find a way so that stone die at x, y. 36 | local crit = ffi.new("TGCriterion") 37 | crit.goal = die_at_loc 38 | crit.critical_loc = common.xy2coord(x, y) 39 | crit.max_depth = 10 40 | return crit 41 | end 42 | ]] 43 | 44 | function tg.set_target(aspect) 45 | -- Build a search criterion whose target player is black when aspect == 'b', and white otherwise.
46 | local crit = ffi.new("TGCriterion") 47 | crit.max_count = -1 -- 1000000 48 | crit.dead_thres = 3 49 | crit.target_player = aspect == 'b' and common.black or common.white 50 | return crit 51 | end 52 | 53 | function tg.solve(b, crit) 54 | crit.region = board.get_stones_bbox(b) 55 | 56 | local res = C.TsumegoSearch(b, crit, all_moves) 57 | -- Convert all moves to (x, y, player) 58 | local player = b._next_player 59 | local moves = { } 60 | for i = 0, all_moves.num_moves - 1 do 61 | local x, y = common.coord2xy(all_moves.moves[i]) 62 | table.insert(moves, {x, y, player}) 63 | player = common.opponent(player) 64 | end 65 | return moves 66 | end 67 | 68 | return tg 69 | -------------------------------------------------------------------------------- /tsumego/test_solver.c: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2016-present, Facebook, Inc. 3 | // All rights reserved. 4 | // 5 | // This source code is licensed under the BSD-style license found in the 6 | // LICENSE file in the root directory of this source tree. An additional grant 7 | // of patent rights can be found in the PATENTS file in the same directory. 8 | // 9 | 10 | #include "solver.h" 11 | 12 | Coord gtp2coord(const char *s) { 13 | } 14 | 15 | void PlayStones(const Board *b, const char *blacks, const char *whites) { 16 | } 17 | 18 | int main() { 19 | Board b; 20 | ClearBoard(&b); 21 | // Put stones here. 22 | const char *whites[] = { 23 | "L19", "L18", "L17", "L16", "M16", "N16", "O16", 24 | "P16", "R19", "R16", "R18", "S17", "S16", "T18" 25 | }; 26 | const char *blacks[] = { 27 | "M18", "M17", "N17", "O17", "P18", "Q18", "Q17", 28 | "R17", "S18" 29 | }; 30 | 31 | TGCriterion crit; 32 | TGRegion *r = &crit.region; 33 | GetBoardBBox(&b, &r->left, &r->top, &r->right, &r->bottom); 34 | 35 | // Black should live. 36 | crit.w_cap_upper_bound = 4; 37 | crit.w_crit_loc = M_PASS; 38 | 39 | // Get moves. 
40 | AllMoves moves; 41 | TsumegoSearch(&b, &crit, &moves); 42 | } 43 | -------------------------------------------------------------------------------- /tsumego/test_solver.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016-present, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | 10 | local sgf = require("utils.sgf") 11 | local goutils = require 'utils.goutils' 12 | local utils = require('utils.utils') 13 | local board = require("board.board") 14 | local common = require("common.common") 15 | local tg = require 'tsumego.solver' 16 | 17 | -- Tsumego example 18 | local example1 = '/home/yuandong/test/tsumego-1.sgf' 19 | local game = sgf.parse(io.open(example1):read("*a")) 20 | 21 | local sgf_play = { } 22 | 23 | local b = board.new() 24 | board.clear(b) 25 | 26 | -- Assume the board is cleared. And we will setup everything given this game. 27 | -- First put all the existing stones there. (apply_handicap is not a good name). 
28 | goutils.apply_handicaps(b, game) 29 | board.play(b, goutils.parse_move_gtp('T16', 'B')) 30 | board.play(b, goutils.parse_move_gtp('O19', 'W')) 31 | board.play(b, goutils.parse_move_gtp('N19', 'B')) 32 | board.play(b, goutils.parse_move_gtp('M19', 'W')) 33 | board.play(b, goutils.parse_move_gtp('O18', 'B')) 34 | board.play(b, goutils.parse_move_gtp('N18', 'W')) 35 | 36 | board.play(b, goutils.parse_move_gtp('P19', 'B')) 37 | board.play(b, goutils.parse_move_gtp('N19', 'W')) 38 | 39 | 40 | board.show(b, 'last_move') 41 | 42 | local crit = tg.set_target('b') 43 | local moves = tg.solve(b, crit) 44 | 45 | print("Best sequence:") 46 | for i = 1, #moves do 47 | local x, y, player = unpack(moves[i]) 48 | local c, player_str = goutils.compose_move_gtp(x, y, player) 49 | print("Move: " .. c .. " " .. player_str) 50 | if not board.play(b, x, y, player) then 51 | error("The move " .. c .. " " .. player_str .. " cannot be played!") 52 | end 53 | end 54 | 55 | board.show(b, 'last_move') 56 | 57 | 58 | -------------------------------------------------------------------------------- /utils/test.sgf: -------------------------------------------------------------------------------- 1 | (;GM[1]FF[4]CA[UTF-8]AP[CGoban:3]ST[2] 2 | RU[Chinese]SZ[9]KM[7.00]TM[240]OT[10/30 Canadian] 3 | PW[Zen19S]PB[AyaMC]BR[2d]DT[2014-06-01]PC[The KGS Go Server at http://www.gokgs.com/]C[AyaMC [2d\]: GTP Engine for AyaMC (black): Aya version 7.73e : If life-death is wrong, push 'undo' and remove dead stones. Private chat 'wr' gets winrate. 
4 | Zen19S [?\]: GTP Engine for Zen19S (white): Zen version 10.3d13 5 | ]RE[W+Resign] 6 | ;B[cd]BL[239.548] 7 | ;W[gf]WL[234.64] 8 | ;B[fe]BL[230.617] 9 | ;W[ff]WL[228.344] 10 | ;B[ge]BL[213.782] 11 | ;W[ee]WL[221.953] 12 | ;B[ef]BL[205.33] 13 | ;W[ed]WL[214.83] 14 | ;B[hf]BL[191.51] 15 | ;W[df]WL[207.814] 16 | ;B[eg]BL[185.509] 17 | ;W[dg]WL[200.478] 18 | ;B[eh]BL[185.026] 19 | ;W[dh]WL[192.372] 20 | ;B[gg]BL[172.396] 21 | ;W[dc]WL[184.731] 22 | ;B[cc]BL[170.638] 23 | ;W[be]WL[177.666] 24 | ;B[db]BL[170.132] 25 | ;W[eb]WL[170.459] 26 | ;B[ce]BL[157.002] 27 | ;W[bf]WL[163.868] 28 | ;B[cb]BL[154.423] 29 | ;W[gd]WL[155.161] 30 | ;B[fb]BL[138.276] 31 | ;W[ea]WL[148.9] 32 | ;B[cf]BL[128.132] 33 | ;W[cg]WL[143.157] 34 | ;B[de]BL[119.658] 35 | ;W[fd]WL[137.606] 36 | ;B[he]BL[118.777] 37 | ;W[bd]WL[132.253] 38 | ;B[ag]BL[103.379] 39 | ;W[gb]WL[126.901] 40 | ;B[bc]BL[92.299] 41 | ;W[bg]WL[122.126] 42 | ;B[bh]BL[89.489] 43 | ;W[da]WL[117.182] 44 | ;B[fc]BL[79.358] 45 | ;W[gc]WL[112.607] 46 | ;B[ab]BL[71.489] 47 | ;W[ba]WL[108.102]C[gghideki [?\]: thx 48 | ]) 49 | --------------------------------------------------------------------------------