├── .github └── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── LICENSE.txt ├── Makefile ├── README.md ├── README_CN.md ├── README_JP.md ├── config.txt ├── engine └── .gitkeep ├── log └── .gitkeep ├── obj └── .gitkeep ├── prob ├── prob_ptn3x3.txt └── prob_ptn_rsp.txt └── src ├── bitboard.h ├── board.cc ├── board.h ├── config.h ├── eval_cache.h ├── eval_worker.h ├── feature.cc ├── feature.h ├── gtp.cc ├── gtp.h ├── main.cc ├── network.cc ├── network.h ├── node.h ├── option.cc ├── option.h ├── pattern.cc ├── pattern.h ├── route_queue.h ├── search.cc ├── search.h ├── sgf.cc ├── sgf.h ├── test.cc ├── test.h ├── timer.h ├── types.cc └── types.h /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Title** 11 | Make it possible to grasp the content of the issue by title only. 12 | 13 | Good: 'Resigns immediately on handicap games' 14 | Bad : 'some questions', 'a bug' 15 | 16 | **Describe the bug** 17 | A clear and concise description of what the bug is. 18 | Do not include multiple topics in one issue. 19 | 20 | **Your environment** 21 | Describe the version of AQ, OS, GPU, etc. 22 | 23 | **To reproduce** 24 | Write as much detail as you can about how you're going to reproduce it. 25 | If you have an idea for a cause or solution, please fill in the form. 26 | 27 | **Additional context** 28 | Add any other context about the problem here. 29 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? 
Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # 2 | # 1. General Compiler Settings 3 | # 4 | COMPILER = gcc 5 | CFLAGS = -std=c++11 -Wextra -fpermissive -fmessage-length=0 -mbmi2 -mavx2 -MMD -MP -Wno-deprecated-declarations 6 | LDFLAGS = -lstdc++ -lm 7 | INCLUDES = 8 | 9 | # 10 | # 2. Target Specific Settings 11 | # 12 | 13 | # 2.1 Linux / Windows 14 | ifeq ($(shell uname),Linux) 15 | # TensorRT 16 | LDFLAGS += -L/usr/local/cuda/targets/x86_64-linux/lib/ -lpthread -lcudart -lnvinfer -lnvonnxparser -lnvparsers 17 | INCLUDES += -I/usr/local/cuda/include -I/usr/local/cuda/targets/x86_64-linux/include 18 | OUTFILE = AQ 19 | else 20 | echo 'TensorRT7 on Windows deos not support MinGW. Use MSVC instead.' 21 | endif 22 | 23 | # 2.2 Set FLAGS 24 | ifeq ($(TARGET),) 25 | CFLAGS += -Ofast -fno-fast-math 26 | endif 27 | ifeq ($(TARGET),debug) 28 | CFLAGS += -g -Og 29 | endif 30 | 31 | # 32 | # 3. Default Settings 33 | # 34 | OUTFILE = AQ 35 | OBJDIR = ./obj 36 | SRCDIR = ./src 37 | SOURCES = $(wildcard $(SRCDIR)/*.cc) 38 | 39 | OBJECTS = $(addprefix $(OBJDIR)/, $(SOURCES:.cc=.o)) 40 | DEPENDS = $(OBJECTS:.o=.d) 41 | 42 | # 43 | # 4. 
Public Targets 44 | # 45 | .PHONY: all debug clean 46 | all: 47 | $(MAKE) executable 48 | 49 | debug: 50 | $(MAKE) TARGET=$@ executable 51 | 52 | clean: 53 | rm -f $(OBJECTS) $(DEPENDS) $(OUTFILE) ${OBJECTS:.o=.gcda} 54 | 55 | # 56 | # 5. Private Targets 57 | # 58 | .PHONY: executable 59 | executable: $(OBJECTS) 60 | $(COMPILER) -o $(OUTFILE) $^ $(LDFLAGS) $(CFLAGS) 61 | 62 | $(OBJDIR)/%.o: %.cc Makefile 63 | @[ -d $(dir $@) ] || mkdir -p $(dir $@) 64 | $(COMPILER) $(CFLAGS) $(INCLUDES) -o $@ -c $< 65 | 66 | -include $(DEPENDS) 67 | 68 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GLOBIS-AQZ 2 | 3 | GLOBIS-AQZ is a Go game engine that uses Deep Learning technology. 4 | It features support for both the Japanese rule with Komi 6.5 and the Chinese rule with Komi 7.5. 5 | 6 | This program utilizes the results of the GLOBIS-AQZ project. 7 | 8 | > GLOBIS-AQZ is a joint project developed by GLOBIS Corporation, Mr. Yu Yamaguchi, and Tripleize Co., Ltd., provided by the National Institute of Advanced Industrial Science and Technology (AIST), and cooperated by the Nihon Ki-in. This program uses the result of GLOBIS-AQZ. 9 | 10 | Since it is open source software, anyone can use it for free. 11 | This program is for playing and analyzing games, so please set it to GUI software such as [Lizzie](https://github.com/featurecat/lizzie), [Sabaki](https://github.com/SabakiHQ/Sabaki) and [GoGui](https://sourceforge.net/projects/gogui/). 12 | 13 | 日本語の説明は[こちら](https://github.com/ymgaq/AQ/blob/master/README_JP.md)をご覧ください。 14 | 请看[这里的](https://github.com/ymgaq/AQ/blob/master/README_CN.md)中文解释. 15 | 16 | ## 1. Downloads 17 | Download executable files from [Releases](https://github.com/ymgaq/AQ/releases). 18 | The executable files built on Windows 10 or Linux (Ubuntu 18.04) are available. 
19 | 20 | If it does not work as it is in other environments, please consider building it for each environment. (for developers) 21 | 22 | ## 2. Requirements 23 | + OS : Windows 10, Linux (64-bit) 24 | + GPU : Nvidia's GPU ([Compute Capability](https://developer.nvidia.com/cuda-gpus) >3.0) 25 | + [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit) 10.0 or 10.2 + [cuDNN](https://developer.nvidia.com/cudnn) 7.6.5 26 | + [TensorRT 7.0.0](https://docs.nvidia.com/deeplearning/sdk/tensorrt-archived/tensorrt-700/tensorrt-install-guide/index.html) 27 | + [Visual C++ redistributable packages for Visual Studio 2015-2019](https://support.microsoft.com/en-us/help/2977003/the-latest-supported-visual-c-downloads) (for Windows only) 28 | 29 | It has been tested in the following environment. 30 | + Ubuntu 18.04 / RTX2080Ti / CUDA10.0 / TensorRT7.0.0 31 | + Windows 10 Pro (64bit) / RTX2080Ti / CUDA10.2 / TensorRT7.0.0 32 | 33 | ## 3. How to use 34 | For example, if you want to start GTP mode in the case of Japanese rule and with time settings of 20 minutes and 30-seconds byoyomi: 35 | ``` 36 | $ AQ.exe --rule=1 --komi=6.5 --main_time=1200 --byoyomi=30 37 | ``` 38 | With Chinese rule and Komi 7.5 (default), the number of searches (playouts) is fixed at 800 without ponder: 39 | ``` 40 | $ AQ.exe --search_limit=800 --use_ponder=off 41 | ``` 42 | With Tromp-Taylor rule and Komi 7.5, 15 minutes sudden death such as games on [CGOS](http://www.yss-aya.com/cgos/): 43 | ``` 44 | $ AQ.exe --rule=2 --repetition_rule=2 --main_time=900 --byoyomi=0 45 | ``` 46 | 47 | ### 3-1. Setting environment variables 48 | In the case of Windows, the following path must be registered in the PATH environment variable. 49 | ``` 50 | {your_cuda_path}\NVIDIA GPU Computing Toolkit\CUDA\v10.{x}\bin 51 | {your_tensorrt_path}\TensorRT-7.0.0.{xx}\lib 52 | ``` 53 | 54 | ### 3-2. 
Generating engine files 55 | The first time it starts up, it generates a network engine optimized for your environment from a file in UFF (Universal File Format) format. 56 | It may take a few minutes to generate this engine. 57 | The serialized engine files are saved in the `engine` folder, so it will start immediately the second time around. 58 | 59 | ### 3-3. Register with Lizzie 60 | For Windows, add `{your_aq_folder}/AQ.exe --lizzie` to the engine command. 61 | To change settings (for example, to analyze under Japanese rules), edit the config.txt file in the AQ folder. 62 | 63 | ## 4. Options 64 | Here's a description of the main options. 65 | Each option can be specified as a command line argument or changed by editing config.txt. 66 | For example, `--komi=6.5`. 67 | 68 | ### 4-1. Game options 69 | | Option | default | description | 70 | | :--- | :--- | :--- | 71 | | --num_gpus | 1 | The number of GPUs to use. | 72 | | --num_threads | 16 | The number of threads to be used for searching. | 73 | | --main_time | 0.0 | Main time of search (in seconds). | 74 | | --byoyomi | 3.0 | Byoyomi (in seconds). | 75 | | --rule | 0 | The rule of the game. 0: Chinese rule 1: Japanese rule 2: Tromp-Taylor rule | 76 | | --komi | 7.5 | Number of Komi. In the case of Japanese rule, please specify 6.5. | 77 | | --batch_size | 8 | The batch size for a single evaluation. | 78 | | --search_limit | -1 | The number of searches (playouts). -1 means this option is disabled. | 79 | | --node_size | 65536 | Maximum number of nodes of the search. When this number of nodes is reached, the search is terminated. | 80 | | --use_ponder | on | Whether or not to think ahead during the opponent's turn. You must turn it on when using it in Lizzie. | 81 | | --resign_value | 0.05 | The winning rate below which AQ resigns. | 82 | | --save_log | off | Whether or not to save the game's thought logs and sgf files. | 83 | 84 | ### 4-2. Launch modes 85 | Mainly for debugging. 
Please do not use any mode other than `--lizzie` for normal games and analysis. 86 | They are only recognized as command line arguments. 87 | 88 | | Option | Launch mode | 89 | | :--- | :--- | 90 | | (not specified) | GTP communication mode | 91 | | --lizzie | In addition to GTP communication, it outputs information for Lizzie. | 92 | | --self | AQ starts a self game. | 93 | | --policy_self | AQ starts a self game with the best move in policy networks. | 94 | | --test | Tests the consistency of the board data structure, etc. | 95 | | --benchmark | Measures the computational speed of rollouts and neural networks. | 96 | 97 | ## 5. Compilation method 98 | The following is an explanation for developers. 99 | The source code is implemented only for games and analysis, and does not include any learning functions. 100 | 101 | AQ is written so that it can be compiled with C++11/C++14, and the coding conventions generally follow the page below. 102 | + [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html) 103 | 104 | ### 5-1. Linux 105 | Requirements 106 | + gcc 107 | + make 108 | + CUDA Toolkit 10.x 109 | + TensorRT 7.0.0 110 | 111 | Check the include path and library path of CUDA and TensorRT in the Makefile, then run make. 112 | 113 | ``` 114 | $ make 115 | ``` 116 | 117 | ### 5-2. 
Windows 118 | Requirements 119 | + Visual Studio 2019 (MSVC v142) 120 | + CUDA Toolkit 10.x 121 | + TensorRT 7.0.0 122 | 123 | Additional include directories: 124 | ``` 125 | {your_cuda_path}\NVIDIA GPU Computing Toolkit\CUDA\v10.x\include 126 | {your_tensorrt_path}\TensorRT-7.0.0.xx\include 127 | ``` 128 | 129 | Additional library directories: 130 | ``` 131 | {your_cuda_path}\NVIDIA GPU Computing Toolkit\CUDA\v10.x\lib\x64 132 | {your_tensorrt_path}\TensorRT-7.0.0.xx\lib 133 | ``` 134 | 135 | Additional library files: 136 | ``` 137 | cudart.lib 138 | nvparsers.lib 139 | nvonnxparser.lib 140 | nvinfer.lib 141 | ``` 142 | 143 | Add each of the above and build it. 144 | 145 | ## 6. License 146 | [GPLv3](https://github.com/ymgaq/AQ/blob/master/LICENSE.txt) 147 | Author: [Yu Yamaguchi](https://twitter.com/ymg_aq) 148 | -------------------------------------------------------------------------------- /README_CN.md: -------------------------------------------------------------------------------- 1 | # GLOBIS-AQZ (AQ) 2 | 3 | GLOBIS-AQZ是一个使用深度学习技术的围棋引擎。 4 | 它的特点是既支持日本规则,也支持中国规则。 5 | 6 | 该项目利用GLOBIS-AQZ项目的结果。 7 | 8 | > GLOBIS-AQZ is a joint project developed by GLOBIS Corporation, Mr. Yu Yamaguchi, and Tripleize Co., Ltd., provided by the National Institute of Advanced Industrial Science and Technology (AIST), and cooperated by the Nihon Ki-in. This program uses the result of GLOBIS-AQZ. 9 | 10 | 由于它是开源软件,任何人都可以免费使用。 11 | 本程序是用来玩游戏和分析游戏的,请将其设置为[Lizzie](https://github.com/featurecat/lizzie)、[Sabaki](https://github.com/SabakiHQ/Sabaki)、[GoGui](https://sourceforge.net/projects/gogui/)等GUI软件。 12 | 13 | 请注意,此描述为机器翻译,因此可能存在不准确的地方。 14 | 15 | Please see [here](https://github.com/ymgaq/AQ/blob/master/README.md) for an explanation in English. 16 | 日本語の説明は[こちら](https://github.com/ymgaq/AQ/blob/master/README_JP.md)をご覧ください。 17 | 18 | ## 1. 下载 19 | 从[Releases](https://github.com/ymgaq/AQ/releases)中下载. 
20 | Windows 10和Linux(Ubuntu 18.04)上构建的可执行文件。 21 | 22 | 如果它在其他环境中无法正常运行,请考虑为每个环境构建它(针对开发者)。 23 | 24 | ## 2. 动作环境要求 25 | + OS : Windows 10, Linux 26 | + GPU : Nvidia's GPU ([Compute Capability](https://developer.nvidia.com/cuda-gpus) >3.0) 27 | + [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit) 10.0 or 10.2 + [cuDNN](https://developer.nvidia.com/cudnn) 7.6.5 28 | + [TensorRT 7.0.0](https://docs.nvidia.com/deeplearning/sdk/tensorrt-archived/tensorrt-700/tensorrt-install-guide/index.html) 29 | + [适用于 Visual Studio 2015、2017 和 2019 的 Microsoft Visual C++ 可再发行软件包](https://support.microsoft.com/zh-cn/help/2977003/the-latest-supported-visual-c-downloads) (仅Windows) 30 | 31 | 它在以下环境中进行了测试: 32 | + Ubuntu 18.04 / RTX2080Ti / CUDA10.0 / TensorRT7.0.0 33 | + Windows 10 Pro (64bit) / RTX2080Ti / CUDA10.2 / TensorRT7.0.0 34 | 35 | ## 3. 如何使用 36 | 例如,根据日本的规则,如果你在GTP模式下开始的时间是20分钟+30秒: 37 | ``` 38 | $ AQ.exe --rule=1 --komi=6.5 --main_time=1200 --byoyomi=30 39 | ``` 40 | 用中国规则(默认),要将出局数(playouts)固定为800,并在没有"ponder"的情况下开始: 41 | ``` 42 | $ AQ.exe --search_limit=800 --use_ponder=off 43 | ``` 44 | 用Tromp-Taylor规则,时间定在15分钟(这是一个[CGOS](http://www.yss-aya.com/cgos/)的设置): 45 | ``` 46 | $ AQ.exe --rule=2 --repetition_rule=2 --main_time=900 --byoyomi=0 47 | ``` 48 | 49 | ### 3-1. 设置环境变量 50 | 在Windows的情况下,必须在PATH环境变量中注册以下路径。 51 | ``` 52 | {your_cuda_path}\NVIDIA GPU Computing Toolkit\CUDA\v10.{x}\bin 53 | {your_tensorrt_path}\TensorRT-7.0.0.{xx}\lib 54 | ``` 55 | 56 | ### 3-2. 生成引擎文件 57 | 第一次启动时,它会从UFF(Universal File Format)格式的文件中生成一个为您的环境优化的网络引擎。 58 | 可能需要几分钟的时间来生成这个引擎。 59 | 序列化的引擎文件被保存在`engine`文件夹中,所以它将会立即启动第二次。 60 | 61 | ### 3-3. 向Lizzie注册 62 | 对于Windows,在引擎命令中添加`{your_aq_folder}/AQ.exe --lizzie`。 63 | 例如,如果你想用日本规则分析,请修改AQ文件夹中的config.txt文件,使用各种设置。 64 | 65 | ## 4. 选项 66 | 以下是对主要选项的描述。 67 | 它可以作为命令行参数指定,也可以通过编辑config.txt来改变。 68 | 例如,`--komi=6.5`。 69 | 70 | ### 4-1. 
游戏选项 71 | | 选项 | 缺省 | 描述 | 72 | | :--- | :--- | :--- | 73 | | --num_gpus | 1 | 要使用的GPU数量。 | 74 | | --num_threads | 16 | 用于搜索的线程数量。 | 75 | | --main_time | 0.0 | 搜索的主要时间(单位:秒)。 | 76 | | --byoyomi | 3.0 | 倒计时时间(单位:秒)。 | 77 | | --rule | 0 | 的游戏规则。 0:中国规则 1:日本规则 2:Tromp-Taylor规则 | 78 | | --komi | 7.5 | Komi的数量。在日本规定的情况下,请注明6.5。 | 79 | | --batch_size | 8 | 一次评价的批次数。 | 80 | | --search_limit | -1 | 搜索次数(playouts)。-1表示该选项被禁用。 | 81 | | --node_size | 65536 | 搜索的最大节点数。当达到这个节点数时,搜索结束。 | 82 | | --use_ponder | on | 是否要在对手的回合中提前阅读。 在Lizzie中使用时必须打开它。 | 83 | | --resign_value | 0.05 | 放弃的胜率。 | 84 | | --save_log | off | 是否保存游戏中的思想记录和sgf文件。 | 85 | 86 | ### 4-2. 启动模式 87 | 主要是用来调试的。请不要使用`--lizzie`以外的任何其他游戏进行正常的游戏和分析。 88 | 它们只被认可为命令行参数。 89 | 90 | | 选项 | 启动模式 | 91 | | :--- | :--- | 92 | | (不详) | GTP通信模式 | 93 | | --lizzie | 除了GTP通讯外,它还能为Lizzie输出信息。 | 94 | | --self | 开始自我匹配。 | 95 | | --policy_self | 它以policy network的最大手笔开始自我匹配。 | 96 | | --test | 测试板式数据结构的一致性等。 | 97 | | --benchmark | 它可以衡量推出和神经网络的计算速度。 | 98 | 99 | ## 5. 汇编方法 100 | 以下是对开发者的解释。 101 | 源代码只实现了游戏和分析,不包含任何学习功能。 102 | 103 | AQ的编写是为了能用C++11/C++14进行编译,编码约定一般参考下面的页面。 104 | + [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html) 105 | 106 | ### 5-1. Linux 107 | Requirements 108 | + gcc 109 | + make 110 | + CUDA Toolkit 10.x 111 | + TensorRT 7.0.0 112 | 113 | 在Makefile中检查CUDA和TensorRT的include路径和库路径。 114 | 115 | ``` 116 | $ make 117 | ``` 118 | 119 | ### 5-2. 
Windows 120 | Requirements 121 | + Visual Studio 2019 (MSVC v142) 122 | + CUDA Toolkit 10.x 123 | + TensorRT 7.0.0 124 | 125 | Additional include directories: 126 | ``` 127 | {your_cuda_path}\NVIDIA GPU Computing Toolkit\CUDA\v10.x\include 128 | {your_tensorrt_path}\TensorRT-7.0.0.xx\include 129 | ``` 130 | 131 | Additional library directories: 132 | ``` 133 | {your_cuda_path}\NVIDIA GPU Computing Toolkit\CUDA\v10.x\lib\x64 134 | {your_tensorrt_path}\TensorRT-7.0.0.xx\lib 135 | ``` 136 | 137 | Additional library files: 138 | ``` 139 | cudart.lib 140 | nvparsers.lib 141 | nvonnxparser.lib 142 | nvinfer.lib 143 | ``` 144 | 145 | 执行上述设置并编译。 146 | 147 | ## 6. License 148 | [GPLv3](https://github.com/ymgaq/AQ/blob/master/LICENSE.txt) 149 | 作者: [山口 祐](https://twitter.com/ymg_aq) 150 | -------------------------------------------------------------------------------- /README_JP.md: -------------------------------------------------------------------------------- 1 | # 囲碁AI 「GLOBIS-AQZ」 2 | 3 | 「GLOBIS-AQZ」はDeep Learning技術を利用した囲碁の思考エンジンです。 4 | 日本ルール6目半と中国ルール7目半の両方に対応していることが特徴です。 5 | 6 | このプログラムはGLOBIS-AQZプロジェクトの成果を利用しています。 7 | 8 | > GLOBIS-AQZは、開発:株式会社グロービス、山口祐氏、株式会社トリプルアイズ、開発環境の提供:国立研究開発法人 産業技術総合研究所、協力:公益財団法人日本棋院のメンバーによって取り組んでいる共同プロジェクトです。このプログラムは、GLOBIS-AQZでの試算を活用しています。 9 | 10 | オープンソース・ソフトウェアですので、どなたでも無料で使用することができます。 11 | 対局・解析のためのプログラムですので、「[Lizzie](https://github.com/featurecat/lizzie)」「[Sabaki](https://github.com/SabakiHQ/Sabaki)」「[GoGui](https://sourceforge.net/projects/gogui/)」といったGUIソフトに設定して利用してください。 12 | 13 | Please see [here](https://github.com/ymgaq/AQ/blob/master/README.md) for an explanation in English. 14 | 请看[这里的](https://github.com/ymgaq/AQ/blob/master/README_CN.md)中文解释. 15 | 16 | ## 1. ダウンロード 17 | [Releases](https://github.com/ymgaq/AQ/releases)からダウンロードしてください。 18 | Windows 10、 Linuxでビルドした実行ファイルが利用できます。 19 | 20 | それ以外の環境でそのままでは動作しない場合、5.ビルド方法を参考に各環境ごとにビルドを検討してください。(開発者向け) 21 | 22 | ## 2. 
動作環境 23 | + OS : Windows 10, Linux 24 | + GPU : Nvidia製GPU ([Compute Capability](https://developer.nvidia.com/cuda-gpus) >3.0) 25 | + [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit) 10.0 or 10.2 + [cuDNN](https://developer.nvidia.com/cudnn) 7.6.5 26 | + [TensorRT 7.0.0](https://docs.nvidia.com/deeplearning/sdk/tensorrt-archived/tensorrt-700/tensorrt-install-guide/index.html) 27 | + [Visual Studio 2015、2017、および 2019 用 Microsoft Visual C++ 再頒布可能パッケージ](https://support.microsoft.com/ja-jp/help/2977003/the-latest-supported-visual-c-downloads) (Windowsのみ) 28 | 29 | 下記の環境で動作確認をしています。 30 | + Ubuntu 18.04 / RTX2080Ti / CUDA10.0 / TensorRT7.0.0 31 | + Windows 10 Pro (64bit) / RTX2080Ti / CUDA10.2 / TensorRT7.0.0 32 | 33 | ## 3. 使い方 34 | 例えば、日本ルール・コミ6目半で持ち時間20分、切れたら30秒でGTPモードを起動する場合: 35 | ``` 36 | $ AQ.exe --rule=1 --komi=6.5 --main_time=1200 --byoyomi=30 37 | ``` 38 | 中国ルール・コミ7目半(デフォルト)で探索数(playouts)800固定、ポンダーなしの場合: 39 | ``` 40 | $ AQ.exe --search_limit=800 --use_ponder=off 41 | ``` 42 | Tromp-Taylorルール・コミ7目半で15分切れ負けの場合 ([CGOS](http://www.yss-aya.com/cgos/)の設定です): 43 | ``` 44 | $ AQ.exe --rule=2 --repetition_rule=2 --main_time=900 --byoyomi=0 45 | ``` 46 | 47 | ### 3-1. 環境変数の設定 48 | Windowsの場合、環境変数のPATHに以下のようなパスが登録されている必要があります。 49 | ``` 50 | {your_cuda_path}\NVIDIA GPU Computing Toolkit\CUDA\v10.{x}\bin 51 | {your_tensorrt_path}\TensorRT-7.0.0.{xx}\lib 52 | ``` 53 | 54 | ### 3-2. エンジンファイルの生成 55 | 初回起動時に、UFF(Universal File Format)形式のファイルからお手持ちの環境に最適化されたネットワークエンジンを生成します。 56 | このエンジン生成には数分程度かかることがあります。 57 | シリアライズ化されたエンジンファイルが`engine`フォルダに保存されるので、2回目以降はすぐに起動します。 58 | 59 | ### 3-3. Lizzieへの登録 60 | engineコマンドにWindowsの場合は `{your_aq_folder}/AQ.exe --lizzie` を登録してください。 61 | 日本ルールで解析させたい場合など、各種設定はAQフォルダ内のconfig.txtを修正してご利用ください。 62 | 63 | ## 4. オプション 64 | 主なオプションについての説明です。 65 | コマンドライン引数として指定できる他、config.txtを編集することでも変更可能です。 66 | `--komi=6.5`のように指定してください。 67 | 68 | ### 4-1. 
対局オプション 69 | | オプション | デフォルト値 | 説明 | 70 | | :--- | :--- | :--- | 71 | | --num_gpus | 1 | 使用するGPU数です。 | 72 | | --num_threads | 16 | 探索に使用するスレッド数です。 | 73 | | --main_time | 0.0 | 探索の持ち時間(秒)です。 | 74 | | --byoyomi | 3.0 | 秒読みの時間(秒)です。 | 75 | | --rule | 0 | 対局のルールです。 0:中国ルール 1:日本ルール 2:Tromp-Taylorルール | 76 | | --komi | 7.5 | コミ数です。日本ルールの場合は6.5を指定してください。 | 77 | | --batch_size | 8 | 局面の評価を行うバッチ数です。 | 78 | | --search_limit | -1 | 探索回数(playout)。-1で無制限になります。 | 79 | | --node_size | 65536 | 探索の最大ノード数。このノード数に達すると探索を終了します。 | 80 | | --use_ponder | on | 相手の手番で先読みをします。Lizzieで使用するときはonにしてください。 | 81 | | --resign_value | 0.05 | 投了する勝率です。 | 82 | | --save_log | off | 対局の思考ログ・棋譜を保存するかの設定です。 | 83 | 84 | ### 4-2. 起動モード 85 | 主にデバッグ用の機能です。`--lizzie`以外は通常の対局・解析用途には使用しないでください。 86 | コマンドライン引数としてのみ認識されます。 87 | 88 | | オプション | 起動モード | 89 | | :--- | :--- | 90 | | (指定なし) | GTP通信モード | 91 | | --lizzie | GTP通信に加え、Lizzie用の情報を出力します | 92 | | --self | 自己対局を行います。 | 93 | | --policy_self | ポリシーネットワークの最大の手で自己対局を行います。 | 94 | | --test | 盤面データ構造の整合性などをテストします。 | 95 | | --benchmark | ロールアウトやニューラルネットワークの計算速度を測定します。 | 96 | 97 | ## 5. ビルド方法 98 | 以下は開発者向けの説明になります。 99 | なお、公開しているソースコードは対局・解析のみの実装で、学習に関する機能は含まれていません。 100 | 101 | AQは、C++11/C++14でコンパイルできるように書かれており、コーディング規約等は概ね以下のページを参考にしています。 102 | + コーディング規約: [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html) 103 | 104 | ### 5-1. Linux 105 | Requirements 106 | + gcc 107 | + make 108 | + CUDA Toolkit 10.x 109 | + TensorRT 7.0.0 110 | 111 | Makefile内のCUDA・TensorRTのインクルードパス・ライブラリパスを確認し、makeしてください。 112 | 113 | ``` 114 | make 115 | ``` 116 | 117 | ### 5-2. 
Windows 118 | Requirements 119 | + Visual Studio 2019 (MSVC v142) 120 | + CUDA Toolkit 10.x 121 | + TensorRT 7.0.0 122 | 123 | インクルードディレクトリに 124 | ``` 125 | {your_cuda_path}\NVIDIA GPU Computing Toolkit\CUDA\v10.x\include 126 | {your_tensorrt_path}\TensorRT-7.0.0.xx\include 127 | ``` 128 | 129 | 追加のライブラリディレクトリに 130 | ``` 131 | {your_cuda_path}\NVIDIA GPU Computing Toolkit\CUDA\v10.x\lib\x64 132 | {your_tensorrt_path}\TensorRT-7.0.0.xx\lib 133 | ``` 134 | 135 | 追加のライブラリに 136 | ``` 137 | cudart.lib 138 | nvparsers.lib 139 | nvonnxparser.lib 140 | nvinfer.lib 141 | ``` 142 | 143 | をそれぞれ追加してビルドしてください。 144 | 145 | ## 6. ライセンス 146 | [GPLv3](https://github.com/ymgaq/AQ/blob/master/LICENSE.txt) 147 | 開発者: [山口 祐](https://twitter.com/ymg_aq) 148 | -------------------------------------------------------------------------------- /config.txt: -------------------------------------------------------------------------------- 1 | #### --- Hardware --- #### 2 | 3 | # Number of GPUs. [1-8] 4 | --num_gpus=1 5 | 6 | # Number of threads for search. [1-512] 7 | # Typically, when using batch_size=n, 2*n*num_gpus threads 8 | # are used for NN searches, and the rest (at least 1) are 9 | # used for rollouts. 10 | # It is recommended to set this to 2 * batch_size * num_gpus 11 | # unless there is a reason not to. 12 | --num_threads=16 13 | 14 | #### --- Rule --- #### 15 | 16 | # Rule of game. 17 | # 0: Chinese, 1: Japanese, 2: Tromp-Taylor 18 | --rule=0 19 | 20 | # Repetition rule. 21 | # This is a judgment method when the same board is repeated. 22 | # 0: Draw, 1: Super Ko, 2: Tromp-Taylor 23 | --repetition_rule=0 24 | 25 | # Komi. Use 6.5 for the Japanese rule. 26 | --komi=7.5 27 | 28 | #### --- Time control --- #### 29 | 30 | # Main time. (in seconds) 31 | --main_time=0.0 32 | 33 | # Japanese byoyomi time. (in seconds) 34 | --byoyomi=3.0 35 | 36 | # Number of thinking-time extensions of byoyomi. 37 | --num_extensions=0 38 | 39 | # Margin time when thinking in byoyomi. 
(in seconds) 40 | # Use if you want to account for network delay. 41 | --byoyomi_margin=0.0 42 | 43 | # Threshold of remaining time that AQ returns 44 | # a move without search. (in seconds) 45 | # Used in 'sudden death' time setting. 46 | --emergency_time=15.0 47 | 48 | #### --- Search --- #### 49 | 50 | # Batch size for evaluation in search. [1, 8] 51 | --batch_size=8 52 | 53 | # Searching limit of evaluation. [-1, 100000] 54 | # AQ stops thinking when the number of evaluated boards 55 | # reaches search_limit. 56 | # '--search_limit=-1' means that this option is disabled. 57 | --search_limit=-1 58 | 59 | # Maximum number of nodes. 60 | # When this number of nodes is reached, the search is terminated. 61 | # AQ uses about 1.3GB of memory per 100000 nodes. 62 | --node_size=65536 63 | 64 | # Whether using pondering. 65 | --use_ponder=on 66 | 67 | # Threshold of winning rate that AQ resigns. 68 | --resign_value=0.05 69 | 70 | # Save the thought log file in the log directory. 71 | --save_log=off 72 | -------------------------------------------------------------------------------- /engine/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ymgaq/AQ/d0e07e9822e1aec0ab8289f11193c741d9223fb0/engine/.gitkeep -------------------------------------------------------------------------------- /log/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ymgaq/AQ/d0e07e9822e1aec0ab8289f11193c741d9223fb0/log/.gitkeep -------------------------------------------------------------------------------- /obj/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ymgaq/AQ/d0e07e9822e1aec0ab8289f11193c741d9223fb0/obj/.gitkeep -------------------------------------------------------------------------------- /src/bitboard.h: 
-------------------------------------------------------------------------------- 1 | /* 2 | * AQ, a Go playing engine. 3 | * Copyright (C) 2017-2020 Yu Yamaguchi 4 | * except where otherwise indicated. 5 | * 6 | * This program is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 18 | */ 19 | 20 | #ifndef BITBOARD_H_ 21 | #define BITBOARD_H_ 22 | 23 | #include 24 | #include 25 | 26 | #include "./types.h" 27 | 28 | #ifdef __GNUC__ 29 | /** 30 | * Function for finding an unsigned 64-integer NTZ. 31 | * Returns closest position of 1 from lower bit. 
32 | * 33 | * 0b01011000 -> 3 34 | * 0b0 -> 64 35 | */ 36 | static constexpr int ntz(uint64_t x) noexcept { return __builtin_ctzll(x); } 37 | 38 | static constexpr int PopCount(uint64_t x) noexcept { 39 | return __builtin_popcountll(x); 40 | } 41 | #else // Linux 42 | constexpr uint64_t kNtzMagic64 = 0x03F0A933ADCBD8D1ULL; 43 | constexpr int kNtzTable64[127] = { 44 | 64, 0, -1, 1, -1, 12, -1, 2, 60, -1, 13, -1, -1, 53, -1, 3, 61, -1, -1, 45 | 21, -1, 14, -1, 42, -1, 24, 54, -1, -1, 28, -1, 4, 62, -1, 58, -1, 19, -1, 46 | 22, -1, -1, 17, 15, -1, -1, 33, -1, 43, -1, 50, -1, 25, 55, -1, -1, 35, -1, 47 | 38, 29, -1, -1, 45, -1, 5, 63, -1, 11, -1, 59, -1, 52, -1, -1, 20, -1, 41, 48 | 23, -1, 27, -1, -1, 57, 18, -1, 16, -1, 32, -1, 49, -1, -1, 34, 37, -1, 44, 49 | -1, -1, 10, -1, 51, -1, 40, -1, 26, 56, -1, -1, 31, 48, -1, 36, -1, 9, -1, 50 | 39, -1, -1, 30, 47, -1, 8, -1, -1, 46, 7, -1, 6, 51 | }; 52 | /** 53 | * Function for finding an unsigned 64-integer NTZ. 54 | * Returns closest position of 1 from lower bit. 55 | * 56 | * 0b01011000 -> 3 57 | * 0b0 -> 64 58 | */ 59 | static constexpr int ntz(uint64_t x) noexcept { 60 | return kNtzTable64[static_cast(kNtzMagic64 * 61 | static_cast(x & -x)) >> 62 | 57]; 63 | } 64 | 65 | static int PopCount(uint64_t x) noexcept { return __popcnt64(x); } 66 | #endif 67 | 68 | /** 69 | * Returns the index of the bitboard containing v. 70 | */ 71 | inline int v2bb_idx(Vertex v) { return kCoordTable.v2bb_idx_table[v]; } 72 | 73 | /** 74 | * Returns the bit where v is located on the bitboard. 75 | */ 76 | inline uint64_t v2bb_bit(Vertex v) { return kCoordTable.v2bb_bit_table[v]; } 77 | 78 | /** 79 | * Return a vertex from the bitboard. 80 | */ 81 | inline Vertex bb2v(int bb_idx, int bit) { 82 | return kCoordTable.bb2v_table[bb_idx][bit]; 83 | } 84 | 85 | // -------------------- 86 | // Bitboard 87 | // -------------------- 88 | 89 | /** 90 | * @class Bitboard 91 | * Bitboard class is composed of kNumBBs of 64-bit integers. 
92 | * (kNumBBs=6 in 19x19 board, 2 in 9x9 board) 93 | * 94 | * Mainly, the board coordinates are handled in a one-dimensional 95 | * system, but sometimes Bitboard may be faster, such as when 96 | * checking liberty verteces of stones. 97 | */ 98 | class Bitboard { 99 | public: 100 | // Constructor 101 | Bitboard() : p_{0}, num_bits_(0) {} 102 | 103 | Bitboard(const Bitboard& rhs) : num_bits_(rhs.num_bits_) { 104 | for (int i = 0; i < kNumBBs; ++i) p_[i] = rhs.p_[i]; 105 | } 106 | 107 | Bitboard& operator=(const Bitboard& rhs) { 108 | for (int i = 0; i < kNumBBs; ++i) p_[i] = rhs.p_[i]; 109 | num_bits_ = rhs.num_bits_; 110 | 111 | return *this; 112 | } 113 | 114 | bool operator==(const Bitboard& rhs) const { 115 | bool is_equal = (num_bits_ == rhs.num_bits_); 116 | for (int i = 0; i < kNumBBs; ++i) is_equal &= (p_[i] == rhs.p_[i]); 117 | 118 | return is_equal; 119 | } 120 | 121 | int num_bits() const { return num_bits_; } 122 | 123 | void set_num_bits(int val) { num_bits_ = val; } 124 | 125 | uint64_t p(int idx) const { return p_[idx]; } 126 | 127 | void set_p(int idx, uint64_t val) { p_[idx] = val; } 128 | 129 | void Init() { 130 | for (int i = 0; i < kNumBBs; ++i) p_[i] = 0; 131 | num_bits_ = 0; 132 | } 133 | 134 | /** 135 | * Adds a vertex to bitboard. 136 | */ 137 | void Add(Vertex v) { 138 | ASSERT_LV2(kVtZero <= v && v < kNumVts); 139 | ASSERT_LV2(0 <= v2bb_idx(v) && v2bb_idx(v) < kNumBBs); 140 | ASSERT_LV2(((v2bb_bit(v) - 1) & v2bb_bit(v)) == 0); // single bit 141 | 142 | auto& p_op = p_[v2bb_idx(v)]; 143 | uint64_t bit_v = v2bb_bit(v); 144 | 145 | num_bits_ += static_cast((p_op & bit_v) == 0); 146 | p_op |= bit_v; 147 | } 148 | 149 | /** 150 | * Deletes a vertex from bitboard. 
151 | */ 152 | void Remove(Vertex v) { 153 | ASSERT_LV2(kVtZero <= v && v < kNumVts); 154 | ASSERT_LV2(0 <= v2bb_idx(v) && v2bb_idx(v) < kNumBBs); 155 | ASSERT_LV2(((v2bb_bit(v) - 1) & v2bb_bit(v)) == 0); // single bit 156 | 157 | auto& p_op = p_[v2bb_idx(v)]; 158 | uint64_t bit_v = v2bb_bit(v); 159 | 160 | num_bits_ -= static_cast((p_op & bit_v) != 0); 161 | p_op &= ~bit_v; 162 | } 163 | 164 | /** 165 | * Merges with another bitboard. 166 | */ 167 | void Merge(const Bitboard& rhs) { 168 | num_bits_ = 0; 169 | for (int i = 0; i < kNumBBs; ++i) { 170 | p_[i] |= rhs.p_[i]; 171 | 172 | if (p_[i]) num_bits_ += PopCount(p_[i]); 173 | } 174 | } 175 | 176 | /** 177 | * Gets list of vertices. 178 | */ 179 | std::vector Vertices() const { 180 | std::vector vs; 181 | int num_op = num_bits_; 182 | int bb_idx = 0; 183 | uint64_t p_op = p_[0]; 184 | 185 | while (num_op > 0) { 186 | if (p_op == 0) { 187 | p_op = p_[++bb_idx]; 188 | continue; 189 | } 190 | 191 | int bit_idx = ntz(p_op); 192 | 193 | ASSERT_LV2(0 <= bb_idx && bb_idx < kNumBBs); 194 | ASSERT_LV2(0 <= bit_idx && bit_idx < 64); 195 | ASSERT_LV2((p_op & (0x1ULL << bit_idx)) != 0); 196 | 197 | vs.push_back(bb2v(bb_idx, bit_idx)); 198 | p_op ^= (0x1ULL << bit_idx); 199 | --num_op; 200 | } 201 | 202 | ASSERT_LV2(static_cast(vs.size()) == num_bits_); 203 | 204 | return std::move(vs); 205 | } 206 | 207 | /** 208 | * Returns the first vertex. 209 | */ 210 | Vertex FirstVertex() const { 211 | for (int i = 0; i < kNumBBs; ++i) 212 | if (p_[i] != 0) return bb2v(i, ntz(p_[i])); 213 | 214 | return kVtNull; 215 | } 216 | 217 | /** 218 | * Outputs bitboard information. (for debug) 219 | */ 220 | friend std::ostream& operator<<(std::ostream& os, const Bitboard& bb) { 221 | os << "num_bits_=" << bb.num_bits_; 222 | os << " p_: "; 223 | for (auto& rs : bb.Vertices()) os << rs << " "; 224 | os << std::endl; 225 | return os; 226 | } 227 | 228 | private: 229 | uint64_t p_[kNumBBs]; // 64-bit integers that make up the bitboard. 
230 | int num_bits_; // Number of vertex bits in the bitboard. 231 | }; 232 | 233 | // -------------------- 234 | // StoneGroup 235 | // -------------------- 236 | 237 | /** 238 | * @class StoneGroup 239 | * StoneGroup class contains information in adjacent stones, 240 | * such as liberty points or stone count. 241 | * 242 | * Stones are taken when adjacent liberties becomes 0. 243 | */ 244 | class StoneGroup { 245 | public: 246 | // Constructor 247 | StoneGroup() : liberty_atari_(kVtNull), num_stones_(1) {} 248 | 249 | StoneGroup(const StoneGroup& rhs) 250 | : liberty_atari_(rhs.liberty_atari_), 251 | num_stones_(rhs.num_stones_), 252 | bb_liberties_(rhs.bb_liberties_) {} 253 | 254 | StoneGroup& operator=(const StoneGroup& rhs) { 255 | num_stones_ = rhs.num_stones_; 256 | liberty_atari_ = rhs.liberty_atari_; 257 | bb_liberties_ = rhs.bb_liberties_; 258 | 259 | return *this; 260 | } 261 | 262 | bool operator==(const StoneGroup& rhs) const { 263 | return num_stones_ == rhs.num_stones_ && 264 | liberty_atari_ == rhs.liberty_atari_ && 265 | bb_liberties_ == rhs.bb_liberties_; 266 | } 267 | 268 | /** 269 | * Returns number of stones. 
270 | */ 271 | int size() const { return num_stones_; } 272 | 273 | Vertex liberty_atari() const { return liberty_atari_; } 274 | 275 | Bitboard bb_liberties() const { return bb_liberties_; } 276 | 277 | int num_liberties() const { return bb_liberties_.num_bits(); } 278 | 279 | std::vector lib_vertices() const { 280 | return std::move(bb_liberties_.Vertices()); 281 | } 282 | 283 | bool captured() const { return bb_liberties_.num_bits() == 0; } 284 | 285 | bool atari() const { return bb_liberties_.num_bits() == 1; } 286 | 287 | bool pre_atari() const { return bb_liberties_.num_bits() == 2; } 288 | 289 | void Init() { 290 | liberty_atari_ = kVtNull; 291 | num_stones_ = 1; 292 | bb_liberties_.Init(); 293 | } 294 | 295 | void SetNull() { 296 | liberty_atari_ = kVtNull; 297 | bb_liberties_.Init(); 298 | bb_liberties_.set_num_bits(int{kVtNull}); 299 | num_stones_ = int{kVtNull}; 300 | } 301 | 302 | /** 303 | * Adds a stone at v to this group. 304 | */ 305 | void Add(Vertex v) { 306 | if (bb_liberties_.num_bits() == kVtNull) return; // wall 307 | 308 | bb_liberties_.Add(v); 309 | // liberty_atari_ is called only when num_liberties == 1 310 | liberty_atari_ = v; 311 | } 312 | 313 | /** 314 | * Removes a stone at v from this group. 315 | */ 316 | void Remove(Vertex v) { 317 | if (bb_liberties_.num_bits() == kVtNull) return; // wall 318 | 319 | bb_liberties_.Remove(v); 320 | if (bb_liberties_.num_bits() == 1) 321 | liberty_atari_ = bb_liberties_.FirstVertex(); 322 | } 323 | 324 | /** 325 | * Merges with another stone group. 326 | */ 327 | void Merge(const StoneGroup& rhs) { 328 | bb_liberties_.Merge(rhs.bb_liberties_); 329 | if (bb_liberties_.num_bits() == 1) 330 | liberty_atari_ = bb_liberties_.FirstVertex(); 331 | num_stones_ += rhs.num_stones_; 332 | } 333 | 334 | /** 335 | * Outputs StoneGroup information. 
(for debug) 336 | */ 337 | friend std::ostream& operator<<(std::ostream& os, const StoneGroup& sg) { 338 | os << "liberty_atari_=" << sg.liberty_atari_ 339 | << " num_stones_=" << sg.num_stones_ << std::endl; 340 | os << "bb_liberties_: " << sg.bb_liberties_ << std::endl; 341 | return os; 342 | } 343 | 344 | private: 345 | Bitboard bb_liberties_; // Bitboard of liberties. 346 | Vertex liberty_atari_; // Vertex of liberty when in Atari. 347 | int num_stones_; // Number of stones in this stone group. 348 | }; 349 | 350 | #endif // BITBOARD_H_ 351 | -------------------------------------------------------------------------------- /src/config.h: -------------------------------------------------------------------------------- 1 | /* 2 | * AQ, a Go playing engine. 3 | * Copyright (C) 2017-2020 Yu Yamaguchi 4 | * except where otherwise indicated. 5 | * 6 | * This program is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 
18 | */ 19 | 20 | #ifndef CONFIG_H_ 21 | #define CONFIG_H_ 22 | 23 | // -------------------- 24 | // include 25 | // -------------------- 26 | 27 | #ifdef _WIN32 28 | #define COMPILER_MSVC 29 | #ifndef NOMINMAX 30 | #define NOMINMAX 31 | #endif 32 | #include 33 | #include 34 | #ifndef UINT64_MAX 35 | #define UINT64_MAX 0xffffffffffffffffULL 36 | #endif 37 | #else 38 | #include 39 | #endif 40 | 41 | // -------------------- 42 | // board size 43 | // -------------------- 44 | 45 | #ifndef BOARD_SIZE 46 | #define BOARD_SIZE 19 47 | #endif 48 | 49 | // -------------------- 50 | // assertion 51 | // -------------------- 52 | 53 | // #define USE_DEBUG_ASSERT 54 | 55 | /** 56 | * ASSERT, which will not be disabled even if it is not a DEBUG build (since 57 | * normal asserts will be disabled). 58 | * 59 | * Deliberately causing a memory access violation. 60 | * When USE_DEBUG_ASSERT is enabled, wait 3 seconds after outputting the 61 | * contents of ASSERT before executing the code that causes an access violation. 
62 | * 63 | * This code has been based on the following link code as a reference; 64 | * https://github.com/yaneurao/YaneuraOu/blob/master/source/config.h 65 | * Revision date: 5/1/2020 66 | */ 67 | #if !defined(USE_DEBUG_ASSERT) 68 | #define ASSERT(X) \ 69 | { \ 70 | if (!(X)) *reinterpret_cast(1) = 0; \ 71 | } 72 | #else 73 | #define ASSERT(X) \ 74 | { \ 75 | if (!(X)) { \ 76 | std::cout << "\nError : ASSERT(" << #X << ")" << std::endl; \ 77 | std::this_thread::sleep_for(std::chrono::microseconds(3000)); \ 78 | *reinterpret_cast(1) = 0; \ 79 | } \ 80 | } 81 | #endif 82 | 83 | #ifndef ASSERT_LV 84 | #define ASSERT_LV 0 85 | #endif 86 | 87 | #define ASSERT_LV_EX(L, X) \ 88 | { \ 89 | if (L <= ASSERT_LV) ASSERT(X); \ 90 | } 91 | #define ASSERT_LV1(X) ASSERT_LV_EX(1, X) 92 | #define ASSERT_LV2(X) ASSERT_LV_EX(2, X) 93 | #define ASSERT_LV3(X) ASSERT_LV_EX(3, X) 94 | 95 | #endif // CONFIG_H_ 96 | -------------------------------------------------------------------------------- /src/eval_cache.h: -------------------------------------------------------------------------------- 1 | /* 2 | * AQ, a Go playing engine. 3 | * Copyright (C) 2017-2020 Yu Yamaguchi 4 | * except where otherwise indicated. 5 | * 6 | * This program is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 
18 | */ 19 | 20 | #ifndef EVAL_CACHE_H_ 21 | #define EVAL_CACHE_H_ 22 | 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | 31 | #include "./node.h" 32 | 33 | /** 34 | * @struct ValueAndProb 35 | * Structure that holds value and policy. 36 | */ 37 | struct ValueAndProb { 38 | double value; 39 | std::array prob; // Doesn't include kPass. 40 | 41 | // Constructor 42 | ValueAndProb() : value(0.0), prob{0.0} {} 43 | 44 | ValueAndProb(const ValueAndProb& rhs) : value(rhs.value), prob(rhs.prob) {} 45 | 46 | ValueAndProb& operator=(const ValueAndProb& rhs) { 47 | value = rhs.value; 48 | prob = rhs.prob; 49 | return *this; 50 | } 51 | }; 52 | 53 | /** 54 | * @struct SyncedEntry 55 | * Feature and ValueAndProb structures for synchronization in evaluation. 56 | */ 57 | struct SyncedEntry { 58 | std::mutex mx; 59 | std::condition_variable cv; 60 | Feature ft; 61 | ValueAndProb vp; 62 | 63 | // Constructor 64 | explicit SyncedEntry(const Feature& ft_) : ft(ft_) {} 65 | }; 66 | 67 | /** 68 | * @class EvalCache 69 | * EvalCache class exclusively manages caches of ValueAndProb. 
70 | * 71 | * @code 72 | * ValueAndProb vp; 73 | * bool found = eval_cache.Probe(b, &vp); 74 | * if(!found) { 75 | * engine.Infer(b.get_feature(), &vp); 76 | * eval_cache.Insert(b.key(), vp); 77 | * } 78 | * @endcode 79 | */ 80 | class EvalCache { 81 | public: 82 | // Constructor 83 | explicit EvalCache(int size = 10000) { max_size_ = size; } 84 | 85 | void Resize(int size) { 86 | max_size_ = size; 87 | while (order_.size() > max_size_) { 88 | vp_map_.erase(order_.front()); 89 | order_.pop_front(); 90 | } 91 | } 92 | 93 | void Init() { 94 | vp_map_.clear(); 95 | order_.clear(); 96 | } 97 | 98 | bool Probe(const Board& b, ValueAndProb* vp, bool check_sym = true) { 99 | Key key = b.key(); 100 | bool found = false; 101 | 102 | std::lock_guard lk(mx_); 103 | 104 | if (check_sym && b.game_ply() < kNumRvts / 12) { 105 | for (int i = 0; i < 8; ++i) { 106 | Key sym_hash = b.key(i); 107 | auto itr = vp_map_.find(sym_hash); 108 | found = (itr != vp_map_.end()); 109 | if (found) { 110 | if (i == 0) { 111 | *vp = *itr->second; 112 | } else { 113 | vp->value = itr->second->value; 114 | for (int j = 0; j < kNumRvts; ++j) 115 | vp->prob[rv2sym(j, i)] = itr->second->prob[j]; 116 | 117 | AddCache(key, *vp); 118 | } 119 | return true; 120 | } 121 | } 122 | 123 | return false; 124 | } else { 125 | auto itr = vp_map_.find(key); 126 | if (itr == vp_map_.end()) return false; 127 | *vp = *itr->second; 128 | } 129 | 130 | return true; 131 | } 132 | 133 | bool Probe(Key key, ValueAndProb* vp) { 134 | std::lock_guard lk(mx_); 135 | auto itr = vp_map_.find(key); 136 | if (itr == vp_map_.end()) return false; 137 | *vp = *itr->second; 138 | return true; 139 | } 140 | 141 | void Insert(Key key, const ValueAndProb& vp) { 142 | std::lock_guard lk(mx_); 143 | if (vp_map_.find(key) == vp_map_.end()) AddCache(key, vp); 144 | } 145 | 146 | private: 147 | std::mutex mx_; 148 | size_t max_size_; 149 | std::unordered_map> vp_map_; 150 | std::deque order_; 151 | 152 | void AddCache(Key key, const 
ValueAndProb& vp) { 153 | vp_map_.emplace(key, 154 | std::unique_ptr(new ValueAndProb(vp))); 155 | order_.push_back(key); 156 | 157 | if (order_.size() > max_size_) { 158 | vp_map_.erase(order_.front()); 159 | order_.pop_front(); 160 | } 161 | } 162 | }; 163 | 164 | #endif // EVAL_CACHE_H_ 165 | -------------------------------------------------------------------------------- /src/eval_worker.h: -------------------------------------------------------------------------------- 1 | /* 2 | * AQ, a Go playing engine. 3 | * Copyright (C) 2017-2020 Yu Yamaguchi 4 | * except where otherwise indicated. 5 | * 6 | * This program is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 18 | */ 19 | 20 | #ifndef EVAL_WORKER_H_ 21 | #define EVAL_WORKER_H_ 22 | 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | 30 | #include "./network.h" 31 | #include "./option.h" 32 | 33 | /** 34 | * @class EvalWorker 35 | * The EvalWorker class waits for features to be evaluated, and when the queue 36 | * piles up, it performs inference asynchronously on the GPU for each batch 37 | * size. 38 | * 39 | * This class is implemented with reference to OpenCLScheduler of LeelaZero. 
40 | * https://github.com/leela-zero/leela-zero/blob/next/src/OpenCLScheduler.cpp 41 | * Revision date: 5/1/2020 42 | */ 43 | class EvalWorker { 44 | public: 45 | ~EvalWorker() { 46 | { 47 | std::unique_lock lk(mx_); 48 | running_ = false; 49 | } 50 | cv_.notify_all(); 51 | if (workers_.size() > 0) 52 | for (auto& th : workers_) th.join(); 53 | } 54 | 55 | EvalWorker() { 56 | running_ = true; 57 | wait_time_millisec_ = 10; 58 | in_single_eval_.store(false); 59 | batch_size_ = Options["batch_size"].get_int(); 60 | use_full_features_ = Options["use_full_features"].get_bool(); 61 | value_from_black_ = Options["value_from_black"].get_bool(); 62 | } 63 | 64 | void Init(std::vector gpu_ids, std::string model_path = "") { 65 | int num_threads = 2; 66 | if (gpu_ids.empty()) { 67 | int num_gpus = Options["num_gpus"].get_int(); 68 | for (int i = 0; i < num_gpus; ++i) gpu_ids.push_back(i); 69 | } 70 | 71 | for (auto gpu_id : gpu_ids) { 72 | for (int i = 0; i < num_threads; ++i) { 73 | auto th = 74 | std::thread(&EvalWorker::BatchWorker, this, gpu_id, model_path); 75 | workers_.push_back(std::move(th)); 76 | std::this_thread::sleep_for(std::chrono::milliseconds(50)); // 50 msec 77 | } 78 | } 79 | } 80 | 81 | void ReplaceModel(std::vector gpu_ids, std::string model_path = "") { 82 | running_ = false; 83 | cv_.notify_all(); 84 | for (auto& th : workers_) th.join(); 85 | workers_.clear(); 86 | 87 | running_ = true; 88 | wait_time_millisec_ = 10; 89 | in_single_eval_.store(false); 90 | 91 | int num_threads = 2; 92 | for (auto gpu_id : gpu_ids) { 93 | for (int i = 0; i < num_threads; ++i) { 94 | auto th = 95 | std::thread(&EvalWorker::BatchWorker, this, gpu_id, model_path); 96 | workers_.push_back(std::move(th)); 97 | } 98 | } 99 | } 100 | 101 | void Evaluate(const Feature& ft, ValueAndProb* vp) { 102 | auto entry = std::make_shared(ft); 103 | 104 | std::unique_lock lk(entry->mx); 105 | 106 | { 107 | std::lock_guard lk(mx_); 108 | synced_queue_.push_back(entry); 109 | 110 | if 
(in_single_eval_.load() && wait_time_millisec_ < 15) 111 | wait_time_millisec_ += 2; 112 | } 113 | 114 | cv_.notify_one(); 115 | entry->cv.wait(lk); 116 | 117 | *vp = entry->vp; 118 | } 119 | 120 | std::mutex* get_mutex() { return &mx_; } 121 | 122 | private: 123 | std::atomic running_; 124 | std::mutex mx_; 125 | std::condition_variable cv_; 126 | bool use_full_features_; 127 | bool value_from_black_; 128 | int wait_time_millisec_; 129 | std::atomic in_single_eval_; 130 | int batch_size_; 131 | std::deque> synced_queue_; 132 | std::vector workers_; 133 | 134 | std::vector> PickupEntry() { 135 | std::vector> entry_queue; 136 | int num_entries = 0; 137 | 138 | std::unique_lock lk(mx_); 139 | while (true) { 140 | if (!running_) return std::move(entry_queue); 141 | 142 | num_entries = synced_queue_.size(); 143 | if (num_entries >= batch_size_) { 144 | num_entries = batch_size_; 145 | break; 146 | } 147 | 148 | bool timeout = !cv_.wait_for( 149 | lk, std::chrono::milliseconds(wait_time_millisec_), [this]() { 150 | return !running_ || 151 | static_cast(synced_queue_.size()) >= batch_size_; 152 | }); 153 | 154 | if (!synced_queue_.empty()) { 155 | if (timeout && in_single_eval_.exchange(true) == false) { 156 | if (wait_time_millisec_ > 1) { 157 | wait_time_millisec_--; 158 | } 159 | num_entries = 1; 160 | break; 161 | } 162 | } 163 | } 164 | 165 | auto end = synced_queue_.begin(); 166 | std::advance(end, num_entries); 167 | std::move(synced_queue_.begin(), end, std::back_inserter(entry_queue)); 168 | synced_queue_.erase(synced_queue_.begin(), end); 169 | 170 | return std::move(entry_queue); 171 | } 172 | 173 | void BatchWorker(const int gpu_id, std::string model_path) { 174 | TensorEngine engine(gpu_id, batch_size_); 175 | 176 | { 177 | std::lock_guard lock(mx_); 178 | engine.Init(model_path, use_full_features_, value_from_black_); 179 | } 180 | 181 | while (true) { 182 | auto entry_queue = PickupEntry(); 183 | int num_entries = entry_queue.size(); 184 | 185 | if 
(!running_) return; 186 | engine.Infer(&entry_queue, kNumSymmetry); 187 | 188 | for (auto& entry : entry_queue) { 189 | std::lock_guard lk(entry->mx); 190 | entry->cv.notify_all(); 191 | } 192 | 193 | if (num_entries == 1) in_single_eval_ = false; 194 | } 195 | } 196 | }; 197 | 198 | #endif // EVAL_WORKER_H_ 199 | -------------------------------------------------------------------------------- /src/feature.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * AQ, a Go playing engine. 3 | * Copyright (C) 2017-2020 Yu Yamaguchi 4 | * except where otherwise indicated. 5 | * 6 | * This program is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 18 | */ 19 | 20 | #include "./feature.h" 21 | #include "./board.h" 22 | 23 | void Feature::Update(const Board& b) { 24 | // 1. Initializes features. 
25 | std::fill(ladder_esc_.begin(), ladder_esc_.end(), float{0.0}); 26 | std::fill(sensibleness_.begin(), sensibleness_.end(), float{0.0}); 27 | for (int i = 0; i < kFeatureSize; ++i) { 28 | std::fill(liberty_[i].begin(), liberty_[i].end(), float{0.0}); 29 | std::fill(cap_size_[i].begin(), cap_size_[i].end(), float{0.0}); 30 | std::fill(self_atari_[i].begin(), self_atari_[i].end(), float{0.0}); 31 | std::fill(liberty_after_[i].begin(), liberty_after_[i].end(), float{0.0}); 32 | } 33 | Color c_us = b.side_to_move(); 34 | 35 | for (RawVertex rv = kRvtZero; rv < kNumRvts; ++rv) { 36 | Vertex v = rv2v(rv); 37 | 38 | if (b.color_at(v) != kEmpty) { 39 | liberty_[std::min(kFeatureSize - 1, b.sg_num_liberties_at(v) - 1)][rv] = 40 | 1.0; 41 | } else if (b.IsLegal(v)) { 42 | // 2. Updates sensibleness. 43 | if (!b.IsEyeShape(v) && !b.IsSeki(v)) sensibleness_[rv] = 1.0; 44 | 45 | // 3. Checks sg_id of surrounding stone groups. 46 | std::vector our_sg_ids; 47 | 48 | Bitboard libs; 49 | for (Direction d = kDirZero; d < kNumDir4; ++d) { 50 | Vertex v_nbr = v + dir2v(d); 51 | if (b.color_at(v_nbr) == kEmpty) { 52 | libs.Add(v_nbr); 53 | } else if (b.color_at(v_nbr) == c_us) { 54 | our_sg_ids.push_back(b.sg_id(v_nbr)); 55 | } 56 | } 57 | sort(our_sg_ids.begin(), our_sg_ids.end()); 58 | our_sg_ids.erase(unique(our_sg_ids.begin(), our_sg_ids.end()), 59 | our_sg_ids.end()); 60 | 61 | // 4. Counts size and liberty of the neighboring groups. 62 | int num_captured = 0; 63 | int num_our_stones = 1; 64 | std::vector checked_ids; 65 | 66 | for (Direction d = kDirZero; d < kNumDir4; ++d) { 67 | Vertex v_nbr = v + dir2v(d); 68 | 69 | if (b.color_at(v_nbr) == ~c_us) { // 4-1. Opponent's stone. 70 | // Adds to num_captured if it is in Atari and not yet checked. 
71 | if (b.sg_atari_at(v_nbr) && 72 | find(checked_ids.begin(), checked_ids.end(), b.sg_id(v_nbr)) == 73 | checked_ids.end()) { 74 | checked_ids.push_back(b.sg_id(v_nbr)); 75 | libs.Add(v_nbr); 76 | num_captured += b.sg_size_at(v_nbr); 77 | 78 | Vertex v_tmp = v_nbr; 79 | do { 80 | for (Direction d2 = kDirZero; d2 < kNumDir4; ++d2) { 81 | Vertex v_tmp_nbr = v_tmp + dir2v(d2); 82 | if (b.color_at(v_tmp_nbr) == c_us && 83 | find(our_sg_ids.begin(), our_sg_ids.end(), 84 | b.sg_id(v_tmp_nbr)) != our_sg_ids.end()) 85 | libs.Add(v_tmp); 86 | } 87 | v_tmp = b.next_v(v_tmp); 88 | } while (v_tmp != v_nbr); 89 | } 90 | } else if (b.color_at(v_nbr) == c_us) { // 4-2. Player's stone. 91 | // Adds to num_our_stones if it is not yet checked. 92 | if (find(checked_ids.begin(), checked_ids.end(), b.sg_id(v_nbr)) == 93 | checked_ids.end()) { 94 | checked_ids.push_back(b.sg_id(v_nbr)); 95 | num_our_stones += b.sg_size_at(v_nbr); 96 | libs.Merge(b.sg_liberties_at(v_nbr)); 97 | } 98 | } 99 | } 100 | 101 | // 5. Updates capture size. 102 | if (num_captured != 0) 103 | cap_size_[std::min(kFeatureSize - 1, num_captured - 1)][rv] = 1.0; 104 | 105 | libs.Remove(v); 106 | int num_liberties = libs.num_bits(); 107 | 108 | // 6. Updates self-atari size. 109 | if (num_liberties == 1) 110 | self_atari_[std::min(kFeatureSize - 1, num_our_stones - 1)][rv] = 1.0; 111 | // 7. Updates liberties after the move. 112 | liberty_after_[std::min(kFeatureSize - 1, num_liberties - 1)][rv] = 1.0; 113 | } 114 | } 115 | 116 | // 8. Updates vertices escaping from ladder. 117 | constexpr int num_escapes = kBSize == 9 ? 
3 : 4; 118 | Board b_cpy = b; 119 | auto escape_vertices = b_cpy.LadderEscapes(num_escapes); 120 | for (auto& v_esc : escape_vertices) ladder_esc_[v2rv(v_esc)] = 1.0; 121 | } 122 | 123 | float* Feature::Copy(float* oi, bool use_full, int symmetry_idx) const { 124 | auto copy_n_symmetry = 125 | [symmetry_idx](std::vector::const_iterator input, float* output) { 126 | if (symmetry_idx == 0) { 127 | output = std::copy_n(input, kNumRvts, output); 128 | } else { 129 | for (int j = 0; j < kNumRvts; ++j) { 130 | *output = *(input + rv2sym(j, symmetry_idx)); 131 | ++output; 132 | } 133 | } 134 | return output; 135 | }; 136 | 137 | for (int i = 0; i < kNumHistory; ++i) 138 | oi = copy_n_symmetry(stones_[next_side_][i].begin(), oi); 139 | for (int i = 0; i < kNumHistory; ++i) 140 | oi = copy_n_symmetry(stones_[~next_side_][i].begin(), oi); 141 | 142 | if (next_side_ == kWhite) { 143 | oi = std::fill_n(oi, kNumRvts, float{0.0}); 144 | oi = std::fill_n(oi, kNumRvts, float{1.0}); 145 | } else { 146 | oi = std::fill_n(oi, kNumRvts, float{1.0}); 147 | oi = std::fill_n(oi, kNumRvts, float{0.0}); 148 | } 149 | 150 | if (use_full) { 151 | for (int i = 0; i < kFeatureSize; ++i) 152 | oi = copy_n_symmetry(liberty_[i].begin(), oi); 153 | for (int i = 0; i < kFeatureSize; ++i) 154 | oi = copy_n_symmetry(cap_size_[i].begin(), oi); 155 | for (int i = 0; i < kFeatureSize; ++i) 156 | oi = copy_n_symmetry(self_atari_[i].begin(), oi); 157 | for (int i = 0; i < kFeatureSize; ++i) 158 | oi = copy_n_symmetry(liberty_after_[i].begin(), oi); 159 | oi = copy_n_symmetry(ladder_esc_.begin(), oi); 160 | oi = copy_n_symmetry(sensibleness_.begin(), oi); 161 | } 162 | 163 | return oi; 164 | } 165 | -------------------------------------------------------------------------------- /src/feature.h: -------------------------------------------------------------------------------- 1 | /* 2 | * AQ, a Go playing engine. 3 | * Copyright (C) 2017-2020 Yu Yamaguchi 4 | * except where otherwise indicated. 
5 | * 6 | * This program is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 18 | */ 19 | 20 | #ifndef FEATURE_H_ 21 | #define FEATURE_H_ 22 | 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | 29 | #include "./bitboard.h" 30 | #include "./config.h" 31 | #include "./pattern.h" 32 | #include "./types.h" 33 | 34 | // -------------------- 35 | // Diff 36 | // -------------------- 37 | 38 | /** 39 | * @struct LightMap 40 | * A lightweight map that stores variables of T with an integer key of 41 | * upper limit Size. 42 | */ 43 | template 44 | struct LightMap { 45 | std::vector> entry; 46 | bool flag[Size]; 47 | 48 | LightMap() : flag{false} {} 49 | 50 | void Insert(int idx, T val) { 51 | if (!flag[idx]) { 52 | flag[idx] = true; 53 | entry.push_back({idx, val}); 54 | } 55 | } 56 | }; 57 | 58 | /** 59 | * @struct Diff 60 | * Diff class keeps history of board properties. 61 | * See Board class members for detail. 
62 | */ 63 | struct Diff { 64 | LightMap empty; 65 | LightMap empty_id; 66 | LightMap sg; 67 | LightMap sg_id; 68 | LightMap next_v; 69 | LightMap ptn; 70 | LightMap prob[kNumPlayers]; 71 | LightMap sum_prob_rank[kNumPlayers]; 72 | 73 | Key key; 74 | Vertex prev_ko; 75 | Bitboard removed_stones; 76 | Pattern prev_ptn; 77 | double prev_rsp_prob; 78 | Vertex response_move[4]; 79 | 80 | Vertex features_add; 81 | Bitboard features_sub; 82 | 83 | Diff() 84 | : key(UINT64_MAX), 85 | prev_ko(kVtNull), 86 | removed_stones{}, 87 | prev_ptn(0xffffffff), 88 | prev_rsp_prob(0.0), 89 | response_move{kVtNull}, 90 | features_add(kVtNull), 91 | features_sub{} {} 92 | }; 93 | 94 | // -------------------- 95 | // Feature 96 | // -------------------- 97 | 98 | class Board; 99 | constexpr int kNumHistory = 8; 100 | constexpr int kFeatureSize = 8; 101 | 102 | /** 103 | * @class Feature 104 | * Feature class contains input features for neural network. 105 | * 106 | * [0]-[15] : stones 0->my(t) 1->her(t) 2->my(t-1) ... 
107 | * [16]-[17]: color 108 | * [18]-[25]: liberty 109 | * [26]-[33]: capture size 110 | * [34]-[41]: self Atari size 111 | * [42]-[49]: liberty after 112 | * [50] : ladder escape 113 | * [51] : sensibleness 114 | */ 115 | class Feature { 116 | public: 117 | Feature() 118 | : next_side_(kBlack), 119 | stones_(kNumPlayers, std::vector>( 120 | kNumHistory, std::vector(kNumRvts, 0))), 121 | add_history_(kNumHistory, kVtNull), 122 | sub_history_(kNumHistory), 123 | liberty_(kFeatureSize, std::vector(kNumRvts, 0)), 124 | cap_size_(kFeatureSize, std::vector(kNumRvts, 0)), 125 | self_atari_(kFeatureSize, std::vector(kNumRvts, 0)), 126 | liberty_after_(kFeatureSize, std::vector(kNumRvts, 0)), 127 | ladder_esc_(kNumRvts, 0), 128 | sensibleness_(kNumRvts, 0) {} 129 | 130 | Feature(const Feature& rhs) 131 | : next_side_(rhs.next_side_), 132 | stones_(rhs.stones_), 133 | add_history_(rhs.add_history_), 134 | sub_history_(rhs.sub_history_), 135 | liberty_(rhs.liberty_), 136 | cap_size_(rhs.cap_size_), 137 | self_atari_(rhs.self_atari_), 138 | liberty_after_(rhs.liberty_after_), 139 | ladder_esc_(rhs.ladder_esc_), 140 | sensibleness_(rhs.sensibleness_) {} 141 | 142 | void Init() { 143 | next_side_ = kBlack; 144 | std::fill(add_history_.begin(), add_history_.end(), kVtNull); 145 | 146 | for (int i = 0; i < kNumHistory; ++i) { 147 | sub_history_[i].Init(); 148 | std::fill(stones_[kWhite][i].begin(), stones_[kWhite][i].end(), 149 | float{0.0}); 150 | std::fill(stones_[kBlack][i].begin(), stones_[kBlack][i].end(), 151 | float{0.0}); 152 | } 153 | 154 | for (int i = 0; i < kFeatureSize; ++i) { 155 | std::fill(liberty_[i].begin(), liberty_[i].end(), float{0.0}); 156 | std::fill(cap_size_[i].begin(), cap_size_[i].end(), float{0.0}); 157 | std::fill(self_atari_[i].begin(), self_atari_[i].end(), float{0.0}); 158 | std::fill(liberty_after_[i].begin(), liberty_after_[i].end(), float{0.0}); 159 | } 160 | 161 | std::fill(ladder_esc_.begin(), ladder_esc_.end(), float{0.0}); 162 | 
std::fill(sensibleness_.begin(), sensibleness_.end(), float{0.0}); 163 | } 164 | 165 | Feature& operator=(const Feature& rhs) { 166 | stones_ = rhs.stones_; 167 | next_side_ = rhs.next_side_; 168 | add_history_ = rhs.add_history_; 169 | sub_history_ = rhs.sub_history_; 170 | liberty_ = rhs.liberty_; 171 | cap_size_ = rhs.cap_size_; 172 | self_atari_ = rhs.self_atari_; 173 | liberty_after_ = rhs.liberty_after_; 174 | ladder_esc_ = rhs.ladder_esc_; 175 | sensibleness_ = rhs.sensibleness_; 176 | 177 | return *this; 178 | } 179 | 180 | bool operator==(const Feature& rhs) const { 181 | return next_side_ == rhs.next_side_ && stones_ == rhs.stones_; 182 | } 183 | 184 | float stones(Color c, int t, int rv) const { return stones_[c][t][rv]; } 185 | 186 | Color next_side() const { return next_side_; } 187 | 188 | Vertex last_add() const { return add_history_[kNumHistory - 1]; } 189 | 190 | Bitboard last_sub() const { return sub_history_[kNumHistory - 1]; } 191 | 192 | float liberty(int t, int rv) const { return liberty_[t][rv]; } 193 | 194 | float cap_size(int t, int rv) const { return cap_size_[t][rv]; } 195 | 196 | float self_atari(int t, int rv) const { return self_atari_[t][rv]; } 197 | 198 | float liberty_after(int t, int rv) const { return liberty_after_[t][rv]; } 199 | 200 | float ladder_esc(int rv) const { return ladder_esc_[rv]; } 201 | 202 | float sensibleness(int rv) const { return sensibleness_[rv]; } 203 | 204 | void DoNullMove() { 205 | Color c = ~next_side_; 206 | 207 | for (int i = 1; i < kNumHistory; ++i) { 208 | if (add_history_[i - 1] < kPass) 209 | stones_[c][i][v2rv(add_history_[i - 1])] = 1.0; 210 | for (auto& rs : sub_history_[i - 1].Vertices()) 211 | stones_[~c][i][v2rv(rs)] = 0.0; 212 | c = ~c; 213 | } 214 | 215 | for (int i = kNumHistory - 1; i > 0; --i) { 216 | add_history_[i] = add_history_[i - 1]; 217 | sub_history_[i] = sub_history_[i - 1]; 218 | } 219 | 220 | add_history_[0] = kVtNull; 221 | sub_history_[0].Init(); 222 | 223 | next_side_ = 
~next_side_; 224 | } 225 | 226 | void Undo(Vertex v, Bitboard bb) { 227 | Color c = ~next_side_; 228 | 229 | for (int i = 0; i < kNumHistory; ++i) { 230 | if (add_history_[i] < kPass) stones_[c][i][v2rv(add_history_[i])] = 0.0; 231 | for (auto& rs : sub_history_[i].Vertices()) 232 | stones_[~c][i][v2rv(rs)] = 1.0; 233 | 234 | c = ~c; 235 | } 236 | 237 | for (int i = 0; i < kNumHistory - 1; ++i) { 238 | add_history_[i] = add_history_[i + 1]; 239 | sub_history_[i] = sub_history_[i + 1]; 240 | } 241 | 242 | add_history_[kNumHistory - 1] = v; 243 | sub_history_[kNumHistory - 1] = bb; 244 | 245 | next_side_ = ~next_side_; 246 | } 247 | 248 | void Add(Color c, Vertex v) { 249 | ASSERT_LV2(kColorZero <= c && c < kNumPlayers); 250 | ASSERT_LV2(is_ok(v) && !in_wall(v)); 251 | ASSERT_LV2(c == ~next_side_); 252 | 253 | stones_[c][0][v2rv(v)] = 1.0; 254 | add_history_[0] = v; 255 | } 256 | 257 | void Remove(Color c, Vertex v) { 258 | ASSERT_LV2(kColorZero <= c && c < kNumPlayers); 259 | ASSERT_LV2(is_ok(v) && !in_wall(v)); 260 | ASSERT_LV2(c == next_side_); 261 | 262 | stones_[c][0][v2rv(v)] = 0.0; 263 | sub_history_[0].Add(v); 264 | } 265 | 266 | void Update(const Board& b); 267 | 268 | float* Copy(float* oi, bool use_full = true, int symmetry_idx = 0) const; 269 | 270 | /** 271 | * Outputs Feature information. 
(for debug) 272 | */ 273 | friend std::ostream& operator<<(std::ostream& os, const Feature& ft) { 274 | os << "next_side_=" << ft.next_side_ << std::endl; 275 | for (int i = 1; i < kNumHistory; ++i) { 276 | os << "diff(" << i << "): kWhite "; 277 | for (int j = 0; j < kNumRvts; ++j) { 278 | if (ft.stones_[kWhite][i - 1][j] != ft.stones_[kWhite][i][j]) 279 | os << rv2v(RawVertex(j)) << " "; 280 | } 281 | os << "kBlack "; 282 | for (int j = 0; j < kNumRvts; ++j) { 283 | if (ft.stones_[kBlack][i - 1][j] != ft.stones_[kBlack][i][j]) 284 | os << rv2v(RawVertex(j)) << " "; 285 | } 286 | os << std::endl; 287 | } 288 | return os; 289 | } 290 | 291 | private: 292 | std::vector>> stones_; 293 | Color next_side_; 294 | std::vector add_history_; 295 | std::vector sub_history_; 296 | std::vector> liberty_; 297 | std::vector> cap_size_; 298 | std::vector> self_atari_; 299 | std::vector> liberty_after_; 300 | std::vector ladder_esc_; 301 | std::vector sensibleness_; 302 | }; 303 | 304 | #endif // FEATURE_H_ 305 | -------------------------------------------------------------------------------- /src/gtp.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * AQ, a Go playing engine. 3 | * Copyright (C) 2017-2020 Yu Yamaguchi 4 | * except where otherwise indicated. 5 | * 6 | * This program is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 
18 | */ 19 | 20 | #include "./gtp.h" 21 | 22 | const char kVersion[] = "4.0.0"; 23 | 24 | const std::vector kListCommands = {"board_size", 25 | "list_commands", 26 | "clear_board", 27 | "genmove", 28 | "play", 29 | "quit", 30 | "time_left", 31 | "time_settings", 32 | "name", 33 | "protocol_version", 34 | "version", 35 | "komi", 36 | "final_score", 37 | "kgs-time_settings", 38 | "kgs-game_over", 39 | "place_free_handicap", 40 | "set_free_handicap", 41 | "gogui-play_sequence", 42 | "lz-analyze"}; 43 | 44 | std::string GTPConnector::OnClearBoardCommand() { 45 | StopLizzieAnalysis(); 46 | b_.Init(); // Initializes the board. 47 | AllocateGPU(); // Allocates memory. 48 | tree_.InitRoot(); 49 | tree_.UpdateRoot(b_); 50 | sgf_.Init(); 51 | c_engine_ = kEmpty; 52 | go_ponder_ = false; 53 | 54 | // Resumes from SGF file. 55 | if ((std::string)Options["resume_file_name"] != "") { 56 | sgf_.Read((std::string)Options["working_dir"] + 57 | (std::string)Options["resume_file_name"]); 58 | sgf_.ReconstructBoard(&b_, sgf_.game_ply()); 59 | tree_.UpdateRoot(b_); 60 | 61 | Options["resume_file_name"] = ""; 62 | } 63 | 64 | if (save_log_) { 65 | std::ifstream ifs(sgf_path_); 66 | if (ifs.is_open()) { 67 | time_t t = time(NULL); 68 | char date[64]; 69 | strftime(date, sizeof(date), "%Y%m%d_%H%M%S", localtime(&t)); 70 | std::string date_str = date; 71 | 72 | log_path_ = JoinPath(Options["working_dir"], "log", date_str + ".txt"); 73 | sgf_path_ = JoinPath(Options["working_dir"], "log", date_str + ".sgf"); 74 | 75 | if (tree_.log_file()) tree_.log_file()->close(); 76 | tree_.SetLogFile(log_path_); 77 | ifs.close(); 78 | } 79 | } 80 | 81 | fprintf(stderr, "cleared board.\n"); 82 | return ""; 83 | } 84 | 85 | /** 86 | * Searches and send the next move. 87 | * e.g. "=genmove b", "=genmove white", ... 88 | */ 89 | std::string GTPConnector::OnGenmoveCommand() { 90 | auto t0 = std::chrono::system_clock::now(); 91 | StopLizzieAnalysis(); 92 | std::string response(""); 93 | 94 | // a. 
Allocates memory. 95 | AllocateGPU(); 96 | 97 | // b. Returns error if side to move is not consistent. 98 | Color c_arg = FindString(args_[0], "B", "b") ? kBlack : kWhite; 99 | if (c_arg != b_.side_to_move()) { 100 | success_handle_ = false; 101 | response = "genmove command passed wrong color."; 102 | fprintf(stderr, "? %s\n", response.c_str()); 103 | if (tree_.log_file()) *tree_.log_file() << "? " << response << std::endl; 104 | 105 | return response; 106 | } 107 | 108 | c_engine_ = b_.side_to_move(); 109 | go_ponder_ = true; 110 | tree_.PrepareToThink(); 111 | 112 | // c. Searches for the best move. 113 | double winning_rate = 0.5; 114 | Vertex next_move = tree_.Search(b_, 0.0, &winning_rate, true, false); 115 | 116 | bool resign = false; 117 | if (next_move != kPass && 118 | winning_rate < Options["resign_value"].get_double()) { 119 | resign = true; 120 | next_move = kPass; 121 | } 122 | 123 | // d. Plays the move. 124 | b_.MakeMove(next_move); 125 | tree_.UpdateRoot(b_); 126 | 127 | // e. Updates logs. 128 | sgf_.Add(next_move); 129 | if (save_log_) sgf_.Write(sgf_path_); 130 | tree_.PrintBoardLog(b_); 131 | if (b_.double_pass()) PrintFinalResult(b_); 132 | 133 | // f. Sends response of the next move. 134 | if (resign) 135 | response = "resign"; 136 | else if (next_move == kPass) 137 | response = "pass"; 138 | else 139 | response = tree_.v2str(next_move); 140 | 141 | // g. Updates remaining time. 142 | if (Options["need_time_control"]) { 143 | double elapsed_time = tree_.ElapsedTime(t0); 144 | tree_.set_left_time(std::max(0.0, tree_.left_time() - elapsed_time)); 145 | } 146 | 147 | return response; 148 | } 149 | 150 | /** 151 | * Receives the opponent's move and reflect on the board. 152 | * e.g. "=play w D4", "play b pass", ... 153 | */ 154 | std::string GTPConnector::OnPlayCommand() { 155 | StopLizzieAnalysis(); 156 | std::string response(""); 157 | 158 | go_ponder_ = false; // Because 'genmove' command comes soon. 
159 | 160 | // Returns error if side to move is not consistent. 161 | Color c_arg = FindString(args_[0], "B", "b") ? kBlack : kWhite; 162 | if (c_arg != b_.side_to_move()) { 163 | success_handle_ = false; 164 | response = "play command passed wrong color."; 165 | fprintf(stderr, "? %s\n", response.c_str()); 166 | if (tree_.log_file()) *tree_.log_file() << "? " << response << std::endl; 167 | 168 | return response; 169 | } 170 | 171 | // a. Analyzes received string. 172 | Vertex next_move; 173 | if (FindString(args_[1], "pass", "Pass", "PASS")) { 174 | next_move = kPass; 175 | } else if (FindString(args_[1], "resign", "Resign", "RESIGN")) { 176 | next_move = kPass; 177 | } else { 178 | std::string str_x = args_[1].substr(0, 1); 179 | std::string str_y = args_[1].substr(1); 180 | std::string x_list = "ABCDEFGHJKLMNOPQRSTabcdefghjklmnopqrst"; 181 | 182 | int x = static_cast(x_list.find(str_x)) % 19 + 1; 183 | int y = stoi(str_y); 184 | next_move = xy2v(x, y); 185 | } 186 | 187 | if (save_log_) { 188 | if (b_.game_ply() == 0) tree_.UpdateRoot(b_); 189 | std::stringstream ss; 190 | tree_.PrintCandidates(tree_.root_node(), next_move, ss, true); 191 | if (tree_.log_file()) *(tree_.log_file()) << ss.str(); 192 | std::cerr << ss.str(); 193 | } 194 | 195 | // c. Plays the move. 196 | b_.MakeMove(next_move); 197 | tree_.UpdateRoot(b_); 198 | 199 | // d. Updates logs. 200 | sgf_.Add(next_move); 201 | if (save_log_) sgf_.Write(sgf_path_); 202 | tree_.PrintBoardLog(b_); 203 | if (b_.double_pass()) PrintFinalResult(b_); 204 | 205 | return response; 206 | } 207 | 208 | /** 209 | * Undoes the previous move. 
210 | */ 211 | std::string GTPConnector::OnUndoCommand() { 212 | StopLizzieAnalysis(); 213 | std::vector move_history = b_.move_history(); 214 | if (!move_history.empty()) move_history.pop_back(); 215 | double left_time = tree_.left_time(); 216 | 217 | int num_passes[2] = {b_.num_passes(kWhite), b_.num_passes(kBlack)}; 218 | if (b_.move_before() == kPass) --num_passes[~b_.side_to_move()]; 219 | 220 | // a. Initializes board. 221 | b_.Init(); 222 | tree_.InitRoot(); 223 | sgf_.Init(); 224 | 225 | // b. Advances the board to the previous state. 226 | for (auto v_hist : move_history) { 227 | b_.MakeMove(v_hist); 228 | sgf_.Add(v_hist); 229 | } 230 | tree_.UpdateRoot(b_); 231 | tree_.set_left_time(left_time); 232 | 233 | b_.set_num_passes(kWhite, num_passes[kWhite]); 234 | b_.set_num_passes(kBlack, num_passes[kBlack]); 235 | 236 | // c. Updates logs. 237 | if (save_log_) sgf_.Write(sgf_path_); 238 | tree_.PrintBoardLog(b_); 239 | 240 | return ""; 241 | } 242 | 243 | /** 244 | * Returns Lizzie information. 245 | */ 246 | std::string GTPConnector::OnLzAnalyzeCommand() { 247 | lizzie_interval_ = (args_.size() >= 1 ? stoi(args_[0]) * 10 : 100); // millisec 248 | if (!tree_.has_eval_worker()) { 249 | AllocateGPU(); // Allocates memory. 250 | b_.Init(); 251 | tree_.InitRoot(); 252 | tree_.UpdateRoot(b_); 253 | sgf_.Init(); 254 | c_engine_ = kEmpty; 255 | } 256 | go_ponder_ = true; 257 | tree_.PrepareToThink(); 258 | 259 | return ""; 260 | } 261 | 262 | /** 263 | * Sets main and byoyomi time. 264 | * e.g. "=kgs-time_settings byoyomi 30 60 3", ... 
265 | */ 266 | std::string GTPConnector::OnKgsTimeSettingsCommand() { 267 | if (FindString(args_[0], "byoyomi") && args_.size() >= 3) { 268 | Options["main_time"] = args_[1]; 269 | tree_.set_main_time(stod(args_[1])); 270 | tree_.set_left_time(tree_.main_time()); 271 | Options["byoyomi"] = args_[2]; 272 | tree_.set_byoyomi(stod(args_[2])); 273 | } else { 274 | Options["main_time"] = args_[1]; 275 | tree_.set_main_time(stod(args_[1])); 276 | tree_.set_left_time(tree_.main_time()); 277 | } 278 | 279 | std::fprintf(stderr, "main time=%.1f[sec], byoyomi=%.1f[sec], extension=%d\n", 280 | Options["main_time"].get_double(), 281 | Options["byoyomi"].get_double(), 282 | Options["num_extensions"].get_int()); 283 | 284 | return ""; 285 | } 286 | 287 | /** 288 | * Sets main and byoyomi time. 289 | * e.g. "=time_settings 30 60 3", ... 290 | */ 291 | std::string GTPConnector::OnTimeSettingsCommand() { 292 | Options["main_time"] = args_[0]; 293 | tree_.set_main_time(stod(args_[0])); 294 | tree_.set_left_time(tree_.main_time()); 295 | Options["byoyomi"] = args_[1]; 296 | tree_.set_byoyomi(stod(args_[1])); 297 | // if (args_.size() >= 3) { 298 | // Options["num_extensions"] = 299 | // std::to_string(std::max(0, stoi(args_[2]) - 1)); 300 | // tree_.set_num_extensions(std::max(0, stoi(args_[2]) - 1)); 301 | // } 302 | 303 | std::fprintf(stderr, "main time=%.1f[sec], byoyomi=%.1f[sec], extension=%d\n", 304 | Options["main_time"].get_double(), 305 | Options["byoyomi"].get_double(), 306 | Options["num_extensions"].get_int()); 307 | 308 | return ""; 309 | } 310 | 311 | /** 312 | * Places handicap stones. 313 | * e.g. "=set_free_handicap D4 ..." 
314 | */ 315 | std::string GTPConnector::OnSetFreeHandicapCommand() { 316 | if (args_.size() >= 1) { 317 | int i_max = args_.size(); 318 | for (int i = 0; i < i_max; ++i) { 319 | std::string str_x = args_[i].substr(0, 1); 320 | std::string str_y = args_[i].substr(1); 321 | 322 | std::string x_list = "ABCDEFGHJKLMNOPQRSTabcdefghjklmnopqrst"; 323 | 324 | int x = static_cast(x_list.find(str_x)) % 19 + 1; 325 | int y = stoi(str_y); 326 | 327 | Vertex next_move = xy2v(x, y); 328 | b_.MakeMove(next_move); 329 | sgf_.Add(next_move); 330 | 331 | // Add a white pass except at the end to adjust the turn. 332 | // In the case of handicapped games, start with white. 333 | if (i != i_max - 1) { 334 | b_.MakeMove(kPass); 335 | sgf_.Add(kPass); 336 | b_.decrement_passes(kWhite); 337 | } 338 | } 339 | } 340 | 341 | std::fprintf(stderr, "set free handicap.\n"); 342 | return ""; 343 | } 344 | 345 | /** 346 | * Places fixed handicap stones. 347 | * e.g. "=fixed_handicap 2" 348 | */ 349 | std::string GTPConnector::OnFixedHandicapCommand() { 350 | if (args_.size() >= 1) { 351 | int x_[9] = {4, 16, 4, 16, 4, 16, 10, 10, 10}; 352 | int y_[9] = {4, 16, 16, 4, 10, 10, 4, 16, 10}; 353 | int stones[8][9] = {{0, 1}, 354 | {0, 1, 2}, 355 | {0, 1, 2, 3}, 356 | {0, 1, 2, 3, 8}, 357 | {0, 1, 2, 3, 4, 5}, 358 | {0, 1, 2, 3, 4, 5, 8}, 359 | {0, 1, 2, 3, 4, 5, 6, 7}, 360 | {0, 1, 2, 3, 4, 5, 6, 7, 8}}; 361 | int num_handicaps = stoi(args_[0]); 362 | for (int i = 0; i < num_handicaps; ++i) { 363 | int stone_idx = stones[num_handicaps - 2][i]; 364 | Vertex v = xy2v(x_[stone_idx], y_[stone_idx]); 365 | b_.MakeMove(v); 366 | sgf_.Add(v); 367 | 368 | // Add a white pass except at the end to adjust the turn. 369 | // In the case of handicapped games, start with white. 
370 | if (i != num_handicaps - 1) { 371 | b_.MakeMove(kPass); 372 | sgf_.Add(kPass); 373 | b_.decrement_passes(kWhite); 374 | } 375 | } 376 | } 377 | 378 | std::fprintf(stderr, "placed handicap stones.\n"); 379 | return ""; 380 | } 381 | 382 | /** 383 | * Receives all moves from start and reconstructs board. 384 | * e.g. "=gogui-play_sequence B R16 W D16 B Q3 W D3 ..." 385 | */ 386 | std::string GTPConnector::OnGoguiPlaySequenceCommand() { 387 | int i_max = args_.size(); 388 | for (int i = 1; i < i_max; i = i + 2) { 389 | Color c = (FindString(args_[i - 1], "B", "b")) ? kBlack : kWhite; 390 | Vertex next_move = kPass; 391 | if (b_.side_to_move() != c) { 392 | b_.MakeMove(kPass); 393 | sgf_.Add(kPass); 394 | b_.decrement_passes(c); 395 | } 396 | if (!FindString(args_[i], "PASS", "Pass", "pass")) { 397 | std::string str_x = args_[i].substr(0, 1); 398 | std::string str_y = args_[i].substr(1); 399 | 400 | std::string x_list = "ABCDEFGHJKLMNOPQRSTabcdefghjklmnopqrst"; 401 | 402 | int x = static_cast(x_list.find(str_x)) % 19 + 1; 403 | int y = stoi(str_y); 404 | 405 | next_move = xy2v(x, y); 406 | } 407 | 408 | // Plays the move. 409 | b_.MakeMove(next_move); 410 | // Updates logs. 411 | sgf_.Add(next_move); 412 | tree_.PrintBoardLog(b_); 413 | } 414 | 415 | tree_.UpdateRoot(b_); 416 | if (save_log_) sgf_.Write(sgf_path_); 417 | 418 | std::fprintf(stderr, "sequence loaded.\n"); 419 | return ""; 420 | } 421 | -------------------------------------------------------------------------------- /src/gtp.h: -------------------------------------------------------------------------------- 1 | /* 2 | * AQ, a Go playing engine. 3 | * Copyright (C) 2017-2020 Yu Yamaguchi 4 | * except where otherwise indicated. 5 | * 6 | * This program is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 
10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 18 | */ 19 | 20 | #ifndef GTP_H_ 21 | #define GTP_H_ 22 | 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | 29 | #include "./board.h" 30 | #include "./search.h" 31 | #include "./sgf.h" 32 | 33 | extern const char kVersion[]; 34 | 35 | /** 36 | * List of supported commands. 37 | * (Just what match servers or GUI requires) 38 | */ 39 | extern const std::vector kListCommands; 40 | 41 | /** 42 | * @class GTPConnector 43 | * Manages GTP communication and loops until the quit command is sent. 44 | * 45 | * See the following link for detail of GTP (Go Text Protocol) communication. 46 | * https://www.lysator.liu.se/~gunnar/gtp/gtp2-spec-draft2/gtp2-spec.html 47 | * 48 | * It is implemented so that it receives a query from the server and returns 49 | * information such as a move. Since standard input/output is used, stdout 50 | * should not be used except for GTP. (Error in GTP communication.) 51 | */ 52 | class GTPConnector { 53 | public: 54 | // Constructor. 55 | GTPConnector() 56 | : c_engine_(kEmpty), 57 | go_ponder_(false), 58 | success_handle_(true), 59 | lizzie_interval_(-1) { 60 | // Log settings. 61 | if(Options["lizzie"]) Options["save_log"] = false; 62 | save_log_ = Options["save_log"].get_bool(); 63 | time_t t = time(NULL); 64 | char date[64]; 65 | std::strftime(date, sizeof(date), "%Y%m%d_%H%M%S", localtime(&t)); 66 | std::string date_str = date; 67 | log_path_ = JoinPath(Options["working_dir"], "log", date_str + ".txt"); 68 | sgf_path_ = JoinPath(Options["working_dir"], "log", date_str + ".sgf"); 69 | // Sets file path of log. 
70 | if (save_log_) tree_.SetLogFile(log_path_); 71 | 72 | // Sends command list for a kind of matching server. 73 | if (Options["send_list"]) { 74 | std::string response(""); 75 | for (auto cmd : kListCommands) response += cmd + "\n"; 76 | response += "= "; 77 | SendGTPCommand("= %s\n\n", response); 78 | } 79 | 80 | // Allocates gpu in advance. 81 | if (Options["allocate_gpu"]) AllocateGPU(); 82 | } 83 | 84 | void Start() { 85 | // Starts communication with the GTP protocol. 86 | bool running = true; 87 | while (running) { 88 | std::string command(""); 89 | bool start_pondering = Options["use_ponder"].get_bool() && go_ponder_ && 90 | b_.move_before() != kPass && 91 | (tree_.left_time() > 10.0 || tree_.byoyomi() != 0); 92 | if (start_pondering) AllocateGPU(); 93 | 94 | // Thread that monitors GTP commands during pondering. 95 | std::thread read_th([this, &command, start_pondering]() { 96 | while (command == "") { 97 | ReceiveGTPCommand(&command); 98 | if (command != "" && start_pondering) { 99 | tree_.StopToThink(); 100 | break; 101 | } 102 | // Interval of checking command strings. 103 | std::this_thread::sleep_for(std::chrono::milliseconds(1)); // 1 msec 104 | } 105 | // Waits until SearchTree class stops thinking. 106 | std::this_thread::sleep_for(std::chrono::milliseconds(10)); // 10 msec 107 | }); 108 | 109 | // Goes pondering until the next command is received. 110 | if (start_pondering) { 111 | double winning_rate = 0.5; 112 | double time_limit = 100.0; 113 | if (Options["lizzie"]) 114 | time_limit = 86400.0; 115 | else if (tree_.byoyomi() > 0 && tree_.main_time() > 0 && 116 | tree_.left_time() < tree_.byoyomi() * 2) 117 | time_limit = tree_.byoyomi() * 2; 118 | 119 | tree_.Search(b_, time_limit, &winning_rate, false, true, lizzie_interval_); 120 | } 121 | 122 | read_th.join(); 123 | tree_.PrepareToThink(); 124 | 125 | // Processes GTP command. 126 | if (command == "" || command == "\n") continue; 127 | // Executes each command. 
128 | // Stops when 'quit' command is send. 129 | running = ExecuteCommand(command); 130 | } 131 | } 132 | 133 | bool ExecuteCommand(std::string command) { 134 | // 1. Print command into the log file. 135 | if (tree_.log_file()) *tree_.log_file() << command << std::endl; 136 | 137 | // 2. Parses command. 138 | int command_id = -1; 139 | std::string type = ParseCommand(command, &command_id, &args_); 140 | std::string response = ""; 141 | success_handle_ = true; 142 | 143 | if (FindString(command, "protocol_version")) { 144 | response = "2"; 145 | } else if (type == "name") { 146 | StopLizzieAnalysis(); 147 | response = "AQ"; 148 | } else if (type == "version") { 149 | response = Options["lizzie"] ? "0.16" : std::string(kVersion); 150 | } else if (type == "known_command") { 151 | if (args_.size() >= 1 && 152 | std::find(kListCommands.begin(), kListCommands.end(), args_[0]) != 153 | kListCommands.end()) 154 | response = "true"; 155 | else 156 | response = "false"; 157 | } else if (type == "list_commands") { 158 | for (auto cmd : kListCommands) response += cmd + "\n"; 159 | response += "= "; 160 | } else if (type == "boardsize") { 161 | // Board size setting. (only corresponding to 19 size) 162 | // e.g. "=boardsize 19", "=boardsize 13", ... 163 | if (stoi(args_[0]) != kBSize) { 164 | success_handle_ = false; 165 | response = "This build is allowed to play in only " + 166 | std::to_string(int{kBSize}) + " board."; 167 | fprintf(stderr, "? %s\n", response.c_str()); 168 | } 169 | } else if (type == "clear_board") { 170 | response = OnClearBoardCommand(); 171 | } else if (type == "komi") { 172 | double komi = stod(args_[0]); 173 | Options["komi"] = komi; 174 | tree_.set_komi(komi); 175 | fprintf(stderr, "set komi=%.1f.\n", komi); 176 | } else if (type == "time_left") { 177 | // Sets remaining time. 178 | // e.g. "=time_left B 944", "=time_left white 300", ... 179 | Color c = FindString(args_[0], "B", "b") 180 | ? kBlack 181 | : FindString(args_[0], "W", "w") ? 
kWhite : kEmpty; 182 | double left_time = stod(args_[1]); 183 | if (c_engine_ == kEmpty || c_engine_ == c) tree_.set_left_time(left_time); 184 | Options["need_time_control"] = "false"; 185 | } else if (type == "genmove") { 186 | response = OnGenmoveCommand(); 187 | } else if (type == "play") { 188 | response = OnPlayCommand(); 189 | } else if (type == "undo") { 190 | response = OnUndoCommand(); 191 | } else if (type == "final_score") { 192 | response = PrintFinalResult(b_); 193 | } else if (type == "lz-analyze") { 194 | response = OnLzAnalyzeCommand(); 195 | } else if (type == "kgs-time_settings") { 196 | response = OnKgsTimeSettingsCommand(); 197 | } else if (type == "time_settings") { 198 | response = OnTimeSettingsCommand(); 199 | } else if (type == "set_free_handicap") { 200 | response = OnSetFreeHandicapCommand(); 201 | } else if (type == "fixed_handicap" || type == "place_free_handicap") { 202 | response = OnFixedHandicapCommand(); 203 | } else if (type == "gogui-play_sequence") { 204 | response = OnGoguiPlaySequenceCommand(); 205 | } else if (type == "kgs-game_over") { 206 | go_ponder_ = false; 207 | } else if (type == "quit") { 208 | StopLizzieAnalysis(); 209 | PrintFinalResult(b_); 210 | } else { 211 | success_handle_ = false; 212 | response = "unknown command."; 213 | fprintf(stderr, "? %s\n", response.c_str()); 214 | } 215 | 216 | std::string head_str = success_handle_ ? "=" : "?"; 217 | if (command_id >= 0) head_str += std::to_string(command_id); 218 | SendGTPCommand("%s %s\n\n", head_str.c_str(), response.c_str()); 219 | 220 | return (type != "quit"); 221 | } 222 | 223 | /** 224 | * Checks if an arbitrary string is included. 
225 | */ 226 | bool FindString(std::string str, std::string s1, std::string s2 = "", 227 | std::string s3 = "") { 228 | bool found = false; 229 | found |= (s1 != "" && str.find(s1) != std::string::npos); 230 | found |= (s2 != "" && str.find(s2) != std::string::npos); 231 | found |= (s3 != "" && str.find(s3) != std::string::npos); 232 | return found; 233 | } 234 | 235 | /** 236 | * Parses GTP command to type and arguments. 237 | */ 238 | std::string ParseCommand(const std::string& command, int* command_id, 239 | std::vector* args) { 240 | *command_id = -1; 241 | args->clear(); 242 | std::string type = ""; 243 | 244 | std::istringstream iss(command); 245 | std::string s; 246 | 247 | while (iss >> s) { 248 | if (type == "") { 249 | if (s.substr(0, 1) == "=") s = s.substr(1); 250 | if (s == "") continue; 251 | 252 | if (std::all_of(s.cbegin(), s.cend(), isdigit)) { 253 | *command_id = std::stoi(s); 254 | } else { 255 | type = s; 256 | } 257 | } else { 258 | args->push_back(s); 259 | } 260 | } 261 | 262 | return type; 263 | } 264 | 265 | /** 266 | * Returns a GTP response using standard output. 267 | */ 268 | void SendGTPCommand(const char* output_str, ...) { 269 | va_list args; 270 | va_start(args, output_str); 271 | vfprintf(stdout, output_str, args); 272 | va_end(args); 273 | } 274 | 275 | /** 276 | * Reads a line of the input GTP command. 277 | */ 278 | void ReceiveGTPCommand(std::string* input_str) { 279 | std::getline(std::cin, *input_str); 280 | } 281 | 282 | /** 283 | * Allocates GPUs. 284 | * Allocating GPU memory in the constructor of the SearchTree class may take 285 | * several tens of seconds, so it is allocated at clear_board to avoid 286 | * timeouts in game servers and GUI. 287 | */ 288 | void AllocateGPU() { 289 | if (!tree_.has_eval_worker()) { 290 | std::cerr << "allocating memory...\n"; 291 | // Waits 5s when rating measurement. 
292 | if (!save_log_ && !Options["use_ponder"]) 293 | std::this_thread::sleep_for(std::chrono::seconds(5)); // 5s 294 | tree_.SetGPUAndMemory(); 295 | } 296 | } 297 | 298 | /** 299 | * Return the final score. 300 | * If the log file is specified, the dead stone information is output. 301 | */ 302 | std::string PrintFinalResult(const Board& b_) { 303 | std::vector os_list; 304 | os_list.push_back(&std::cerr); 305 | if (tree_.log_file()) os_list.push_back(tree_.log_file()); 306 | 307 | Board::OwnerMap owner = {0}; 308 | double s = tree_.FinalScore(b_, kVtNull, -1, 1024, &owner); 309 | b_.PrintOwnerMap(s, 1024, owner, os_list); 310 | 311 | if (s == 0) return "0"; 312 | std::stringstream ss; 313 | ss << (s > 0 ? "B+" : "W+"); 314 | ss << std::fixed << std::setprecision(1) << std::abs(s); 315 | 316 | return ss.str(); 317 | } 318 | 319 | /** 320 | * Stops analysis for Lizzie. 321 | */ 322 | void StopLizzieAnalysis() { 323 | tree_.StopToThink(); 324 | lizzie_interval_ = -1; 325 | } 326 | 327 | std::string OnClearBoardCommand(); 328 | std::string OnGenmoveCommand(); 329 | std::string OnPlayCommand(); 330 | std::string OnUndoCommand(); 331 | std::string OnLzAnalyzeCommand(); 332 | std::string OnKgsTimeSettingsCommand(); 333 | std::string OnTimeSettingsCommand(); 334 | std::string OnSetFreeHandicapCommand(); 335 | std::string OnFixedHandicapCommand(); 336 | std::string OnGoguiPlaySequenceCommand(); 337 | 338 | private: 339 | Board b_; 340 | SearchTree tree_; 341 | Color c_engine_; 342 | bool go_ponder_; 343 | bool save_log_; 344 | SgfData sgf_; 345 | std::string log_path_; 346 | std::string sgf_path_; 347 | std::vector args_; 348 | bool success_handle_; 349 | int lizzie_interval_; 350 | }; 351 | 352 | #endif // GTP_H_ 353 | -------------------------------------------------------------------------------- /src/main.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * AQ, a Go playing engine. 
3 | * Copyright (C) 2017-2020 Yu Yamaguchi 4 | * except where otherwise indicated. 5 | * 6 | * This program is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 18 | */ 19 | 20 | #include "./board.h" 21 | #include "./gtp.h" 22 | #include "./option.h" 23 | #include "./test.h" 24 | 25 | int main(int argc, char **argv) { 26 | std::string mode = ReadConfiguration(argc, argv); 27 | 28 | if (mode == "--benchmark") { 29 | BenchMark(); 30 | NetworkBench(); 31 | } else if (mode == "--test") { 32 | TestBoard(); 33 | } else if (mode == "--self") { 34 | SelfMatch(); 35 | } else if (mode == "--policy_self") { 36 | PolicySelf(); 37 | } else { 38 | GTPConnector gtp_connector; 39 | gtp_connector.Start(); 40 | } 41 | 42 | return 0; 43 | } 44 | -------------------------------------------------------------------------------- /src/network.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * AQ, a Go playing engine. 3 | * Copyright (C) 2017-2020 Yu Yamaguchi 4 | * except where otherwise indicated. 5 | * 6 | * This program is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 
10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 18 | */ 19 | 20 | #include "./network.h" 21 | 22 | void TensorEngine::Init(std::string model_path, bool use_full_features, 23 | bool value_from_black) { 24 | cudaSetDevice(gpu_id_); 25 | 26 | if (model_path == "") { 27 | if (Options["model_path"].get_string() == "default" || 28 | Options["model_path"].get_string() == "update") { 29 | std::string engine_name = Options["rule"].get_int() == kChinese 30 | ? "model_cn.engine" 31 | : "model_jp.engine"; 32 | model_path = JoinPath(Options["working_dir"], "engine", engine_name); 33 | } else { 34 | model_path = Options["model_path"].get_string(); 35 | } 36 | } 37 | 38 | LoadEngine(model_path); 39 | int max_batch_size = engine_->getMaxBatchSize(); 40 | for (int i = 0; i < engine_->getNbBindings(); ++i) { 41 | auto dim = engine_->getBindingDimensions(i); 42 | std::string dim_str = "("; 43 | int size = 1; 44 | for (int i = 0; i < dim.nbDims; ++i) { 45 | if (i) dim_str += ", "; 46 | dim_str += std::to_string(dim.d[i]); 47 | if (dim.d[i] > 0) size *= dim.d[i]; 48 | } 49 | dim_str += ")"; 50 | 51 | void *buf; 52 | cudaMalloc(&buf, max_batch_size * size * sizeof(float)); 53 | device_bufs_.push_back(buf); 54 | } 55 | 56 | use_full_features_ = use_full_features; 57 | value_from_black_ = value_from_black; 58 | feature_size_ = use_full_features_ ? kInputFeatures : 18; 59 | host_buf_.resize(max_batch_size_ * int{kNumRvts} * feature_size_); 60 | } 61 | 62 | void TensorEngine::BuildFromOnnx(std::string onnx_path, std::string save_path) { 63 | std::cerr << "building engine ... 
"; 64 | 65 | Logger g_logger; 66 | nvinfer1::IBuilder *builder = nvinfer1::createInferBuilder(g_logger); 67 | const auto explicit_batch = 68 | 1U << static_cast( 69 | nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); 70 | nvinfer1::INetworkDefinition *network = 71 | builder->createNetworkV2(explicit_batch); 72 | auto parser = nvonnxparser::createParser(*network, g_logger); 73 | bool parse_ok = parser->parseFromFile( 74 | onnx_path.c_str(), 75 | static_cast(nvinfer1::ILogger::Severity::kERROR) /* kWARNING */); 76 | if (!parse_ok) { 77 | std::cerr << "[ERROR] File not found: " << onnx_path << std::endl; 78 | if (Options["rule"].get_int() == kJapanese) { 79 | std::cerr << "Use '--model_path' and '--validate_model_path' option to " 80 | "specify a non-default engine file." 81 | << std::endl; 82 | } else { 83 | std::cerr 84 | << "Use '--model_path' option to specify a non-default engine file." 85 | << std::endl; 86 | } 87 | exit(1); 88 | } 89 | 90 | int batch_size = Options["batch_size"].get_int(); 91 | builder->setMaxBatchSize(batch_size); 92 | nvinfer1::IBuilderConfig *config = builder->createBuilderConfig(); 93 | config->setMaxWorkspaceSize(uint32_t{1} << 28); 94 | if (builder->platformHasFastFp16()) { 95 | config->setFlag(nvinfer1::BuilderFlag::kFP16); 96 | } 97 | 98 | nvinfer1::Dims input_dims = network->getInput(0)->getDimensions(); 99 | std::string inputs_name = network->getInput(0)->getName(); 100 | 101 | nvinfer1::IOptimizationProfile *profile = 102 | builder->createOptimizationProfile(); 103 | profile->setDimensions(inputs_name.c_str(), 104 | nvinfer1::OptProfileSelector::kMIN, 105 | nvinfer1::Dims3(batch_size, kInputFeatures, kNumRvts)); 106 | profile->setDimensions(inputs_name.c_str(), 107 | nvinfer1::OptProfileSelector::kMAX, 108 | nvinfer1::Dims3(batch_size, kInputFeatures, kNumRvts)); 109 | profile->setDimensions(inputs_name.c_str(), 110 | nvinfer1::OptProfileSelector::kOPT, 111 | nvinfer1::Dims3(batch_size, kInputFeatures, kNumRvts)); 112 | 
config->addOptimizationProfile(profile); 113 | engine_ = builder->buildEngineWithConfig(*network, *config); 114 | SetBufferIndex(); 115 | 116 | context_ = engine_->createExecutionContext(); 117 | if (input_dims.d[0] < 0) { 118 | input_dims.d[0] = batch_size; 119 | context_->setBindingDimensions(inputs_idx_, input_dims); 120 | } 121 | 122 | SaveSerializedEngine(save_path); 123 | 124 | builder->destroy(); 125 | network->destroy(); 126 | parser->destroy(); 127 | config->destroy(); 128 | 129 | std::cerr << "completed." << std::endl; 130 | } 131 | 132 | void TensorEngine::BuildFromUff(std::string uff_path, std::string save_path) { 133 | std::cerr << "building engine ... "; 134 | 135 | Logger g_logger; 136 | nvinfer1::IBuilder *builder = nvinfer1::createInferBuilder(g_logger); 137 | nvinfer1::INetworkDefinition *network = builder->createNetwork(); 138 | 139 | auto parser = nvuffparser::createUffParser(); 140 | parser->registerInput("inputs", 141 | nvinfer1::DimsCHW(kInputFeatures, kBSize, kBSize), 142 | nvuffparser::UffInputOrder::kNCHW); 143 | auto dtype = builder->platformHasFastFp16() ? nvinfer1::DataType::kHALF 144 | : nvinfer1::DataType::kFLOAT; 145 | 146 | bool parse_ok = parser->parse(uff_path.c_str(), *network, dtype); 147 | if (!parse_ok) { 148 | std::cerr << "[ERROR] File not found: " << uff_path << std::endl; 149 | if (Options["rule"].get_int() == kJapanese) { 150 | std::cerr << "Use '--model_path' and '--validate_model_path' option to " 151 | "specify a non-default engine file." 152 | << std::endl; 153 | } else { 154 | std::cerr 155 | << "Use '--model_path' option to specify a non-default engine file." 
156 | << std::endl; 157 | } 158 | exit(1); 159 | } 160 | int batch_size = Options["batch_size"].get_int(); 161 | builder->setMaxBatchSize(std::max(batch_size, 8)); 162 | nvinfer1::IBuilderConfig *config = builder->createBuilderConfig(); 163 | config->setMaxWorkspaceSize(uint32_t{1} << 28); 164 | if (builder->platformHasFastFp16()) { 165 | config->setFlag(nvinfer1::BuilderFlag::kFP16); 166 | } 167 | 168 | nvinfer1::Dims input_dims = network->getInput(0)->getDimensions(); 169 | std::string inputs_name = network->getInput(0)->getName(); 170 | 171 | engine_ = builder->buildEngineWithConfig(*network, *config); 172 | SetBufferIndex(); 173 | SaveSerializedEngine(save_path); 174 | context_ = engine_->createExecutionContext(); 175 | 176 | builder->destroy(); 177 | network->destroy(); 178 | parser->destroy(); 179 | config->destroy(); 180 | 181 | std::cerr << "completed." << std::endl; 182 | } 183 | 184 | void TensorEngine::LoadEngine(std::string model_path) { 185 | std::string base_path = model_path; 186 | if (base_path.find(".engine") != std::string::npos) { 187 | base_path = base_path.substr(0, base_path.size() - 7); 188 | } else if (base_path.find(".onnx") != std::string::npos) { 189 | base_path = base_path.substr(0, base_path.size() - 5); 190 | } else if (base_path.find(".uff") != std::string::npos) { 191 | base_path = base_path.substr(0, base_path.size() - 4); 192 | } 193 | std::string onnx_path = base_path + ".onnx"; 194 | std::string uff_path = base_path + ".uff"; 195 | std::string engine_path = base_path + ".engine"; 196 | 197 | std::ifstream ifs(engine_path, std::ios::binary); 198 | 199 | if (ifs.is_open()) { 200 | std::ostringstream model_ss(std::ios::binary); 201 | model_ss << ifs.rdbuf(); 202 | std::string model_str = model_ss.str(); 203 | 204 | Logger g_logger; 205 | runtime_ = nvinfer1::createInferRuntime(g_logger); 206 | engine_ = runtime_->deserializeCudaEngine(model_str.c_str(), 207 | model_str.size(), nullptr); 208 | 209 | if (engine_->getMaxBatchSize() < 
Options["batch_size"].get_int()) { 210 | std::cerr << "Max batch size of the sirialized model" 211 | "is smaller than batch_size." 212 | << std::endl; 213 | if (use_uff_) 214 | BuildFromUff(uff_path, engine_path); 215 | else 216 | BuildFromOnnx(onnx_path, engine_path); 217 | } 218 | 219 | SetBufferIndex(); 220 | context_ = engine_->createExecutionContext(); 221 | 222 | auto input_dims = engine_->getBindingDimensions(inputs_idx_); 223 | if (input_dims.d[0] < 0) { 224 | input_dims.d[0] = Options["batch_size"].get_int(); 225 | context_->setBindingDimensions(0, input_dims); 226 | } 227 | } else { 228 | if (use_uff_) 229 | BuildFromUff(uff_path, engine_path); 230 | else 231 | BuildFromOnnx(onnx_path, engine_path); 232 | } 233 | } 234 | 235 | bool TensorEngine::Infer(const Feature &ft, ValueAndProb *vp, 236 | int symmetry_idx) { 237 | int inputs_size = kNumRvts * feature_size_; 238 | float *inputs_itr = host_buf_.data(); 239 | 240 | if (symmetry_idx == kNumSymmetry) symmetry_idx = RandSymmetry(); 241 | ft.Copy(inputs_itr, use_full_features_, symmetry_idx); 242 | 243 | cudaMemcpy(device_bufs_[inputs_idx_], host_buf_.data(), 244 | inputs_size * sizeof(float), cudaMemcpyHostToDevice); 245 | 246 | if (use_uff_) 247 | context_->execute(max_batch_size_, device_bufs_.data()); 248 | else 249 | context_->executeV2(device_bufs_.data()); 250 | 251 | std::vector policy(kNumRvts); 252 | std::vector value(1); 253 | cudaMemcpy(policy.data(), device_bufs_[policy_idx_], 254 | policy.size() * sizeof(float), cudaMemcpyDeviceToHost); 255 | cudaMemcpy(value.data(), device_bufs_[value_idx_], 256 | value.size() * sizeof(float), cudaMemcpyDeviceToHost); 257 | 258 | if (symmetry_idx == 0) { 259 | std::copy_n(policy.begin(), kNumRvts, vp->prob.begin()); 260 | } else { 261 | auto p_itr = policy.begin(); 262 | for (int j = 0; j < kNumRvts; ++j) { 263 | vp->prob[rv2sym(j, symmetry_idx)] = *p_itr; 264 | ++p_itr; 265 | } 266 | } 267 | if (value_from_black_ && ft.next_side() == kWhite) value[0] *= -1; 
268 | vp->value = value[0]; 269 | 270 | return true; 271 | } 272 | 273 | bool TensorEngine::Infer(std::vector> *entries, 274 | int symmetry_idx) { 275 | int batch_size = entries->size(); 276 | 277 | if (batch_size <= 0 || batch_size > max_batch_size_) return false; 278 | 279 | int inputs_size = batch_size * kNumRvts * feature_size_; 280 | float *inputs_itr = host_buf_.data(); 281 | 282 | std::vector symmetries(batch_size); 283 | if (symmetry_idx != kNumSymmetry) { 284 | for (int i = 0; i < batch_size; ++i) symmetries[i] = symmetry_idx; 285 | } else { // symmetry_idx == kNumSymmetry 286 | for (int i = 0; i < batch_size; ++i) symmetries[i] = RandSymmetry(); 287 | } 288 | 289 | for (int i = 0; i < batch_size; ++i) 290 | inputs_itr = 291 | (*entries)[i]->ft.Copy(inputs_itr, use_full_features_, symmetries[i]); 292 | 293 | cudaMemcpy(device_bufs_[inputs_idx_], host_buf_.data(), 294 | inputs_size * sizeof(float), cudaMemcpyHostToDevice); 295 | 296 | if (use_uff_) 297 | context_->execute(max_batch_size_, device_bufs_.data()); 298 | else 299 | context_->executeV2(device_bufs_.data()); 300 | 301 | std::vector policy(batch_size * kNumRvts); 302 | std::vector value(batch_size); 303 | cudaMemcpy(policy.data(), device_bufs_[policy_idx_], 304 | policy.size() * sizeof(float), cudaMemcpyDeviceToHost); 305 | cudaMemcpy(value.data(), device_bufs_[value_idx_], 306 | value.size() * sizeof(float), cudaMemcpyDeviceToHost); 307 | 308 | auto p_itr = policy.begin(); 309 | 310 | for (int i = 0; i < batch_size; ++i) { 311 | if (symmetry_idx == 0) { 312 | std::copy_n(p_itr, kNumRvts, (*entries)[i]->vp.prob.begin()); 313 | std::advance(p_itr, kNumRvts); 314 | } else { 315 | for (int j = 0; j < kNumRvts; ++j) { 316 | (*entries)[i]->vp.prob[rv2sym(j, symmetries[i])] = *p_itr; 317 | ++p_itr; 318 | } 319 | } 320 | 321 | if (value_from_black_ && (*entries)[i]->ft.next_side() == kWhite) 322 | value[i] *= -1; 323 | (*entries)[i]->vp.value = value[i]; 324 | } 325 | 326 | return true; 327 | } 328 | 329 
| bool TensorEngine::Infer(std::vector *entries, int symmetry_idx) { 330 | int batch_size = entries->size(); 331 | if (batch_size <= 0) return false; 332 | 333 | ASSERT_LV3(batch_size <= max_batch_size_); 334 | 335 | int inputs_size = batch_size * kNumRvts * feature_size_; 336 | float *inputs_itr = host_buf_.data(); 337 | 338 | std::vector symmetries(batch_size); 339 | if (symmetry_idx != kNumSymmetry) { 340 | for (int i = 0; i < batch_size; ++i) symmetries[i] = symmetry_idx; 341 | } else { // symmetry_idx == 8 342 | for (int i = 0; i < batch_size; ++i) symmetries[i] = RandSymmetry(); 343 | } 344 | 345 | for (int i = 0; i < batch_size; ++i) 346 | inputs_itr = 347 | (*entries)[i].ft.Copy(inputs_itr, use_full_features_, symmetries[i]); 348 | 349 | cudaMemcpy(device_bufs_[inputs_idx_], host_buf_.data(), 350 | inputs_size * sizeof(float), cudaMemcpyHostToDevice); 351 | 352 | if (use_uff_) 353 | context_->execute(max_batch_size_, device_bufs_.data()); 354 | else 355 | context_->executeV2(device_bufs_.data()); 356 | 357 | std::vector policy(batch_size * kNumRvts); 358 | std::vector value(batch_size); 359 | cudaMemcpy(policy.data(), device_bufs_[policy_idx_], 360 | policy.size() * sizeof(float), cudaMemcpyDeviceToHost); 361 | cudaMemcpy(value.data(), device_bufs_[value_idx_], 362 | value.size() * sizeof(float), cudaMemcpyDeviceToHost); 363 | 364 | auto p_itr = policy.begin(); 365 | 366 | for (int i = 0; i < batch_size; ++i) { 367 | if (symmetry_idx == 0) { 368 | std::copy_n(p_itr, kNumRvts, (*entries)[i].vp.prob.begin()); 369 | std::advance(p_itr, kNumRvts); 370 | } else { 371 | for (int j = 0; j < kNumRvts; ++j) { 372 | (*entries)[i].vp.prob[rv2sym(j, symmetries[i])] = *p_itr; 373 | ++p_itr; 374 | } 375 | } 376 | 377 | if (value_from_black_ && (*entries)[i].ft.next_side() == kWhite) 378 | value[i] *= -1; 379 | (*entries)[i].vp.value = value[i]; 380 | } 381 | 382 | return true; 383 | } 384 | -------------------------------------------------------------------------------- 
/src/network.h: -------------------------------------------------------------------------------- 1 | /* 2 | * AQ, a Go playing engine. 3 | * Copyright (C) 2017-2020 Yu Yamaguchi 4 | * except where otherwise indicated. 5 | * 6 | * This program is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 18 | */ 19 | 20 | #ifndef NETWORK_H_ 21 | #define NETWORK_H_ 22 | 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | #include 31 | #include 32 | #include 33 | 34 | #include "./eval_cache.h" 35 | #include "./option.h" 36 | #include "./route_queue.h" 37 | 38 | constexpr int kInputFeatures = 52; 39 | constexpr int kNumSymmetry = 8; 40 | 41 | /** 42 | * @class Logger 43 | * Logger class of nvinfer. 44 | * Only outputs Error and ingnores Warning and Info messages. 45 | */ 46 | class Logger : public nvinfer1::ILogger { 47 | void log(Severity severity, const char *msg) override { 48 | switch (severity) { 49 | case Severity::kINTERNAL_ERROR: 50 | std::cerr << msg << std::endl; 51 | exit(1); 52 | case Severity::kERROR: 53 | std::cerr << msg << std::endl; 54 | exit(1); 55 | case Severity::kWARNING: 56 | break; 57 | case Severity::kINFO: 58 | break; 59 | } 60 | } 61 | }; 62 | 63 | /** 64 | * @class TensorEngine 65 | * TensorEngine class handles inference order by GPU via TensorRT APIs. 
66 | * If you want to use the same network structure as AlphaGoZero, you can make 67 | * (18 * 19 * 19) inferences by setting use_full_feature = false. 68 | * NOTE: Don't call Init() in different programs or threads at the same time. 69 | */ 70 | class TensorEngine { 71 | public: 72 | TensorEngine(int gpu_id, int batch_size) 73 | : engine_(nullptr), 74 | runtime_(nullptr), 75 | context_(nullptr), 76 | gpu_id_(gpu_id), 77 | use_uff_(true), 78 | max_batch_size_(batch_size) {} 79 | 80 | ~TensorEngine() { 81 | if (context_) context_->destroy(); 82 | if (engine_) engine_->destroy(); 83 | if (runtime_) runtime_->destroy(); 84 | for (auto buf : device_bufs_) cudaFree(buf); 85 | } 86 | 87 | /** 88 | * Serializes the engine and save it to a file. 89 | * Serialized engines are incompatible across TensorRT versions and devices. 90 | */ 91 | void SaveSerializedEngine(std::string save_path) { 92 | auto serialized_engine = engine_->serialize(); 93 | std::ofstream ofs(save_path, std::ios::binary); 94 | ofs.write(reinterpret_cast(serialized_engine->data()), 95 | serialized_engine->size()); 96 | serialized_engine->destroy(); 97 | } 98 | 99 | /** 100 | * Assigns an index to the device buffer. 101 | */ 102 | void SetBufferIndex() { 103 | if (engine_) { 104 | std::string inputs_name = engine_->getBindingName(0); 105 | std::string policy_name = engine_->getBindingName(1); 106 | std::string value_name = engine_->getBindingName(2); 107 | if (policy_name.find("value") != std::string::npos) { 108 | std::string tmp = policy_name; 109 | policy_name = value_name; 110 | value_name = tmp; 111 | } 112 | 113 | inputs_idx_ = engine_->getBindingIndex(inputs_name.c_str()); 114 | policy_idx_ = engine_->getBindingIndex(policy_name.c_str()); 115 | value_idx_ = engine_->getBindingIndex(value_name.c_str()); 116 | } 117 | } 118 | 119 | /** 120 | * Builds the TensorRT engine from an ONNX file. 
121 | * This method is called when a serialized engine is not found or does not 122 | * fit, and generates a serialized file. 123 | */ 124 | void BuildFromOnnx(std::string onnx_path, std::string save_path); 125 | 126 | /** 127 | * Builds the TensorRT engine from an UFF file. 128 | * This method is called when a serialized engine is not found or does not 129 | * fit, and generates a serialized file. 130 | */ 131 | void BuildFromUff(std::string uff_path, std::string save_path); 132 | 133 | /** 134 | * Loads a serialized engine file. 135 | * If the file is not found, build it from a UFF format file and save it. 136 | */ 137 | void LoadEngine(std::string model_path); 138 | 139 | void Init(std::string model_path = "", bool use_full_features = true, 140 | bool value_from_black = false); 141 | 142 | /** 143 | * Infers a single board. 144 | */ 145 | bool Infer(const Feature &ft, ValueAndProb *vp, int symmetry_idx = 0); 146 | 147 | /** 148 | * Infers boards from a list of SyncedEntries. 149 | */ 150 | bool Infer(std::vector> *entries, 151 | int symmetry_idx = 0); 152 | 153 | /** 154 | * Infers boards from a list of RouteEntires. 155 | */ 156 | bool Infer(std::vector *entries, int symmetry_idx = 0); 157 | 158 | private: 159 | nvinfer1::ICudaEngine *engine_; 160 | nvinfer1::IRuntime *runtime_; 161 | nvinfer1::IExecutionContext *context_; 162 | std::vector device_bufs_; 163 | std::vector host_buf_; 164 | int gpu_id_; 165 | int max_batch_size_; 166 | int feature_size_; 167 | bool use_full_features_; 168 | bool value_from_black_; 169 | int inputs_idx_; 170 | int policy_idx_; 171 | int value_idx_; 172 | bool use_uff_; 173 | }; 174 | 175 | #endif // NETWORK_H_ 176 | -------------------------------------------------------------------------------- /src/node.h: -------------------------------------------------------------------------------- 1 | /* 2 | * AQ, a Go playing engine. 3 | * Copyright (C) 2017-2020 Yu Yamaguchi 4 | * except where otherwise indicated. 
5 | * 6 | * This program is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 18 | */ 19 | 20 | #ifndef NODE_H_ 21 | #define NODE_H_ 22 | 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | 31 | #include "./board.h" 32 | 33 | /** 34 | * Function to perform arithmetic addition for atomic, atomic, 35 | * and other atomic variables for which there is no += operator available. 36 | */ 37 | template 38 | inline T FetchAdd(std::atomic* obj, T arg) { 39 | T expected = obj->load(); 40 | while (!atomic_compare_exchange_weak(obj, &expected, expected + arg)) { 41 | } 42 | return expected; 43 | } 44 | 45 | // -------------------- 46 | // RateStat 47 | // -------------------- 48 | 49 | /** 50 | * @class RateStat 51 | * Base class that holds the number of visits to a node and the cumulative value 52 | * of the results, which is inherited by the ChildNode and Node class. 
53 | */ 54 | class RateStat { 55 | public: 56 | RateStat() { 57 | num_rollouts_.store(0, std::memory_order_relaxed); 58 | num_values_.store(0, std::memory_order_relaxed); 59 | win_rollouts_.store(0.0, std::memory_order_relaxed); 60 | win_values_.store(0.0, std::memory_order_relaxed); 61 | } 62 | RateStat(const RateStat& rhs) { *this = rhs; } 63 | 64 | void Init() { 65 | num_rollouts_.store(0, std::memory_order_relaxed); 66 | num_values_.store(0, std::memory_order_relaxed); 67 | win_rollouts_.store(0.0, std::memory_order_relaxed); 68 | win_values_.store(0.0, std::memory_order_relaxed); 69 | } 70 | 71 | RateStat& operator=(const RateStat& rhs) { 72 | num_rollouts_.store(rhs.num_rollouts_.load(std::memory_order_relaxed), 73 | std::memory_order_relaxed); 74 | num_values_.store(rhs.num_values_.load(std::memory_order_relaxed), 75 | std::memory_order_relaxed); 76 | win_rollouts_.store(rhs.win_rollouts_.load(std::memory_order_relaxed), 77 | std::memory_order_relaxed); 78 | win_values_.store(rhs.win_values_.load(std::memory_order_relaxed), 79 | std::memory_order_relaxed); 80 | 81 | return *this; 82 | } 83 | 84 | RateStat& operator+=(const RateStat& rhs) { 85 | num_rollouts_ += rhs.num_rollouts_; 86 | num_values_ += rhs.num_values_; 87 | FetchAdd(&win_rollouts_, rhs.win_rollouts_.load()); 88 | FetchAdd(&win_values_, rhs.win_values_.load()); 89 | 90 | return *this; 91 | } 92 | 93 | int num_rollouts() const { return num_rollouts_.load(); } 94 | int num_values() const { return num_values_.load(); } 95 | double win_rollouts() const { return win_rollouts_.load(); } 96 | double win_values() const { return win_values_.load(); } 97 | 98 | double rollout_rate() const { 99 | int rc = num_rollouts_.load(); 100 | return rc == 0 ? 0.0 : win_rollouts_.load() / rc; 101 | } 102 | 103 | double value_rate() const { 104 | int vc = num_values_.load(); 105 | return vc == 0 ? 
0.0 : win_values_.load() / vc; 106 | } 107 | 108 | double winning_rate(double lambda_) const { 109 | return (1 - lambda_) * rollout_rate() + lambda_ * value_rate(); 110 | } 111 | 112 | void InitValueStat() { 113 | num_values_.store(0, std::memory_order_relaxed); 114 | win_values_.store(0.0, std::memory_order_relaxed); 115 | } 116 | 117 | void AddFlipedStat(const RateStat& rs) { 118 | num_rollouts_ += rs.num_rollouts_; 119 | num_values_ += rs.num_values_; 120 | FetchAdd(&win_rollouts_, rs.win_rollouts_.load()); 121 | FetchAdd(&win_values_, rs.win_values_.load()); 122 | } 123 | 124 | void AddValueOnce(float win) { 125 | num_values_ += 1; 126 | FetchAdd(&win_values_, win); 127 | } 128 | 129 | protected: 130 | std::atomic num_rollouts_; // The number of rollout execution. 131 | std::atomic num_values_; // The number of board evaluation. 132 | std::atomic win_rollouts_; // Sum of rollout results. 133 | std::atomic win_values_; // Sum of evaluation values. 134 | }; 135 | 136 | // -------------------- 137 | // ChildNode 138 | // -------------------- 139 | 140 | class Node; 141 | 142 | /** 143 | * @enum CreateState 144 | * Creation status of ChildNode. 145 | */ 146 | enum CreateState : uint8_t { 147 | kInitial = 0, 148 | kCreating, 149 | kComplete, 150 | }; 151 | 152 | /** 153 | * @class ChildNode 154 | * Class to store moves and probabilities on child nodes and pointers to the 155 | * next node. ChildNodes are expanded as Node::children when a Node is created. 
156 | */ 157 | class ChildNode : public RateStat { 158 | public: 159 | ChildNode() { 160 | move_.store(kPass, std::memory_order_relaxed); 161 | prob_.store(0.0, std::memory_order_relaxed); 162 | next_ptr_.reset(); 163 | create_state_.store(kInitial, std::memory_order_relaxed); 164 | } 165 | 166 | ChildNode(const ChildNode& rhs) : RateStat(rhs) { 167 | move_.store(rhs.move_.load()); 168 | prob_.store(rhs.prob_.load()); 169 | // next_ptr_ = std::move(rhs.next_ptr_); 170 | next_ptr_.reset(); 171 | create_state_.store(rhs.create_state_.load()); 172 | } 173 | 174 | ChildNode(ChildNode&& rhs) : RateStat(rhs) { 175 | move_.store(rhs.move_.load(std::memory_order_relaxed), 176 | std::memory_order_relaxed); 177 | prob_.store(rhs.prob_.load(std::memory_order_relaxed), 178 | std::memory_order_relaxed); 179 | next_ptr_ = std::move(rhs.next_ptr_); 180 | create_state_.store(rhs.create_state_.load(std::memory_order_relaxed), 181 | std::memory_order_relaxed); 182 | } 183 | 184 | ChildNode(Vertex v, float p) { 185 | move_.store(v, std::memory_order_relaxed); 186 | prob_.store(p, std::memory_order_relaxed); 187 | next_ptr_.reset(); 188 | create_state_.store(kInitial, std::memory_order_relaxed); 189 | } 190 | 191 | ~ChildNode() { next_ptr_.reset(); } 192 | 193 | bool operator<(const ChildNode& rhs) const { 194 | return prob_.load() < rhs.prob_.load(); 195 | } 196 | 197 | bool operator>(const ChildNode& rhs) const { 198 | return prob_.load() > rhs.prob_.load(); 199 | } 200 | 201 | Vertex move() const { return move_.load(); } 202 | 203 | float prob() const { return prob_.load(); } 204 | 205 | Node* next_ptr() const { return next_ptr_.get(); } 206 | 207 | int num_entries() const; 208 | 209 | bool has_next() const { return static_cast(next_ptr_); } 210 | 211 | void set_move(Vertex v) { move_.store(v); } 212 | 213 | void set_prob(float val) { prob_.store(val); } 214 | 215 | void set_next_ptr(std::unique_ptr* pnd_) { 216 | next_ptr_ = std::move(*pnd_); 217 | } 218 | 219 | bool 
SetCreatingState() { 220 | uint8_t expected = kInitial; 221 | uint8_t desired = kCreating; 222 | return create_state_.compare_exchange_strong(expected, desired); 223 | } 224 | 225 | void SetCompleteState() { create_state_.exchange(kComplete); } 226 | 227 | void WaitForComplete() { 228 | while (create_state_.load() == kCreating) { 229 | } 230 | } 231 | 232 | friend class Node; 233 | friend class RootNode; 234 | 235 | private: 236 | std::atomic move_; // Move to the child board. 237 | std::atomic prob_; // Probability of the move. 238 | std::unique_ptr next_ptr_; // Pointer of the next node. 239 | std::atomic create_state_; 240 | }; 241 | 242 | // -------------------- 243 | // Node 244 | // -------------------- 245 | 246 | /** 247 | * @class Node 248 | * This node represents a board state, initialized in the Board class, and 249 | * expands the legal moves to an array of child nodes. It is necessary to update 250 | * the results of GPU evaluation to value_ and children prob_. 251 | */ 252 | class Node : public RateStat { 253 | public: 254 | std::vector children; 255 | 256 | // Constructor 257 | Node() 258 | : ply_(0), 259 | num_total_values_(1), 260 | num_total_rollouts_(1), 261 | value_(0.0), 262 | key_(UINT64_MAX), 263 | num_entries_(1) {} 264 | 265 | Node(const Node& rhs) : RateStat(rhs) { 266 | ply_.store(rhs.ply_.load()); 267 | num_total_values_.store(rhs.num_total_values_.load()); 268 | num_total_rollouts_.store(rhs.num_total_rollouts_.load()); 269 | value_.store(rhs.value_.load()); 270 | key_.store(rhs.key_.load()); 271 | children.clear(); 272 | for (auto& ch : rhs.children) children.emplace_back(ChildNode(ch)); 273 | num_entries_.store(rhs.num_entries_.load()); 274 | } 275 | 276 | explicit Node(const Board& b) { *this = b; } 277 | 278 | ~Node() { children.clear(); } 279 | 280 | /** 281 | * Generates node from board. 
282 | */ 283 | Node& operator=(const Board& b) { 284 | ply_ = b.game_ply(); 285 | num_total_values_ = 1; 286 | num_total_rollouts_ = 1; 287 | key_ = b.key(); 288 | children.clear(); 289 | RateStat::Init(); 290 | value_ = 0.0; 291 | 292 | Board b_cpy = b; 293 | constexpr int num_escapes = kBSize == 9 ? 3 : 4; 294 | std::vector esc_list = b_cpy.LadderEscapes(num_escapes); 295 | 296 | std::vector legals; 297 | 298 | for (Vertex v : b.empties()) { 299 | bool is_sensible = 300 | (b.IsLegal(v) && !b.IsEyeShape(b.side_to_move(), v, true) && 301 | !b.IsSeki(v)); 302 | if (!is_sensible || b.CheckRepetition(v) == kRepetitionLose) continue; 303 | legals.push_back(v); 304 | } 305 | legals.push_back(kPass); 306 | 307 | int num_childern = legals.size(); 308 | children.resize(num_childern); 309 | for (int i = 0; i < num_childern; ++i) { 310 | Vertex v = legals[i]; 311 | children[i].set_move(v); 312 | if (!esc_list.empty() && 313 | std::find(esc_list.begin(), esc_list.end(), v) != esc_list.end()) 314 | children[i].set_prob(-0.1); 315 | } 316 | 317 | num_entries_ = 1; 318 | 319 | return *this; 320 | } 321 | 322 | int num_children() const { return children.size(); } 323 | 324 | int game_ply() const { return ply_.load(); } 325 | 326 | int num_total_values() const { return num_total_values_.load(); } 327 | 328 | int num_total_rollouts() const { return num_total_rollouts_.load(); } 329 | 330 | float value() const { return value_.load(); } 331 | 332 | Key key() const { return key_.load(); } 333 | 334 | int num_entries() const { return num_entries_.load(); } 335 | 336 | std::mutex& mutex() { return mx_; } 337 | 338 | void set_num_total_values(int val) { num_total_values_.store(val); } 339 | 340 | void set_num_total_rollouts(int val) { num_total_rollouts_.store(val); } 341 | 342 | void set_value(float val) { value_.store(val); } 343 | 344 | void increment_entries() { ++num_entries_; } 345 | 346 | template 347 | void VirtualLoss(int child_id, float virtual_loss) { 348 | if (NNSearch) { 349 
| FetchAdd(&(children[child_id].win_values_), -virtual_loss); 350 | children[child_id].num_values_ += virtual_loss; 351 | FetchAdd(&(win_values_), -virtual_loss); 352 | num_values_ += virtual_loss; 353 | num_total_values_ += virtual_loss; 354 | } else { 355 | FetchAdd(&(children[child_id].win_rollouts_), -virtual_loss); 356 | children[child_id].num_rollouts_ += virtual_loss; 357 | FetchAdd(&(win_rollouts_), -virtual_loss); 358 | num_rollouts_ += virtual_loss; 359 | num_total_rollouts_ += virtual_loss; 360 | } 361 | } 362 | 363 | template 364 | void VirtualWin(int child_id, float virtual_loss, int num_requests, 365 | float win = 0.0) { 366 | if (NNSearch) { 367 | FetchAdd(&(children[child_id].win_values_), 368 | (win + virtual_loss) * num_requests); 369 | children[child_id].num_values_ -= (virtual_loss - 1) * num_requests; 370 | FetchAdd(&(win_values_), (win + virtual_loss) * num_requests); 371 | num_values_ -= (virtual_loss - 1) * num_requests; 372 | num_total_values_ -= (virtual_loss - 1) * num_requests; 373 | } else { 374 | FetchAdd(&(children[child_id].win_rollouts_), 375 | (win + virtual_loss) * num_requests); 376 | children[child_id].num_rollouts_ -= (virtual_loss - 1) * num_requests; 377 | FetchAdd(&(win_rollouts_), (win + virtual_loss) * num_requests); 378 | num_rollouts_ -= (virtual_loss - 1) * num_requests; 379 | num_total_rollouts_ -= (virtual_loss - 1) * num_requests; 380 | } 381 | } 382 | 383 | private: 384 | std::atomic ply_; // Number of moves in the game. 385 | // Sum of evaluation visits of all child nodes. 386 | std::atomic num_total_values_; 387 | // Sum of rollout visits of all child nodes. 388 | std::atomic num_total_rollouts_; 389 | std::atomic value_; // Evaluated value of the child board. 390 | std::atomic key_; // Board hash of the node. 391 | std::atomic num_entries_; // Total node number under this node. 392 | std::mutex mx_; // Mutex for lock of this node. 
393 | }; 394 | 395 | inline int ChildNode::num_entries() const { 396 | return has_next() ? next_ptr_->num_entries() : 0; 397 | } 398 | 399 | // -------------------- 400 | // RootNode 401 | // -------------------- 402 | 403 | /** 404 | * @class RootNode 405 | * The RootNode class holds a pointer to the root node of the search tree. When 406 | * the board is advanced, it transitions to the corresponding child node. 407 | */ 408 | class RootNode { 409 | public: 410 | RootNode() : max_num_entries_(0), num_entries_(0), pnd_(nullptr) {} 411 | 412 | RootNode(const RootNode& rhs) = delete; 413 | 414 | int num_entries() const { return num_entries_.load(); } 415 | 416 | double entry_rate() const { 417 | return max_num_entries_ == 0 ? 0.0 : num_entries_.load() / max_num_entries_; 418 | } 419 | 420 | Node* node() const { return pnd_.get(); } 421 | 422 | void increment_entries() { ++num_entries_; } 423 | 424 | void set_node(std::unique_ptr* pnd) { pnd_ = std::move(*pnd); } 425 | 426 | void Init() { 427 | num_entries_ = 0; 428 | pnd_.reset(); 429 | } 430 | 431 | void Resize(int max_size) { 432 | max_num_entries_ = max_size; 433 | num_entries_ = 0; 434 | pnd_.reset(); 435 | } 436 | 437 | bool ShiftRootNode(Vertex v, const Board& b, bool create_if_not_found = true); 438 | 439 | private: 440 | int max_num_entries_; 441 | std::atomic num_entries_; 442 | std::unique_ptr pnd_; 443 | }; 444 | 445 | inline bool RootNode::ShiftRootNode(Vertex v, const Board& b, 446 | bool create_if_not_found) { 447 | // Already updated. 
448 | if (static_cast(pnd_) && pnd_->game_ply() == b.game_ply() && 449 | pnd_->key() == b.key()) 450 | return true; 451 | 452 | bool found_next = false; 453 | 454 | if (static_cast(pnd_) && pnd_->game_ply() + 1 == b.game_ply()) { 455 | for (int i = 0, n = pnd_->num_children(); i < n; ++i) { 456 | ChildNode* cn = &pnd_->children[i]; 457 | if (cn->move() == v && cn->has_next()) { 458 | found_next = true; 459 | num_entries_ = std::max(1, cn->num_entries()); 460 | 461 | auto prev_pnd = std::move(pnd_); 462 | pnd_ = std::move(prev_pnd->children[i].next_ptr_); 463 | 464 | // Deletes previous pnd_ in another thread, which 465 | // takes 400,000 nodes per sec. 466 | auto p = prev_pnd.release(); 467 | auto th = std::thread([p]() { delete p; }); 468 | th.detach(); 469 | 470 | break; 471 | } 472 | } 473 | } else if (static_cast(pnd_) && pnd_->game_ply() + 2 == b.game_ply()) { 474 | Vertex v_prev = b.move_before_2(); 475 | 476 | for (size_t i = 0, n = pnd_->num_children(); i < n; ++i) { 477 | ChildNode* cn = &pnd_->children[i]; 478 | if (cn->move() == v_prev && cn->has_next()) { 479 | Node* nnd = cn->next_ptr(); 480 | for (size_t j = 0, n = nnd->num_children(); j < n; ++j) { 481 | ChildNode* gcn = &nnd->children[j]; 482 | if (gcn->move() == v && gcn->has_next()) { 483 | found_next = true; 484 | num_entries_ = std::max(1, gcn->num_entries()); 485 | 486 | auto prev_pnd = std::move(pnd_); 487 | gcn = &(prev_pnd->children[i].next_ptr()->children[j]); 488 | pnd_ = std::move(gcn->next_ptr_); 489 | 490 | // Deletes previous pnd_ in another thread, which 491 | // takes 400,000 nodes per sec. 
492 | auto p = prev_pnd.release(); 493 | auto th = std::thread([p]() { delete p; }); 494 | th.detach(); 495 | 496 | break; 497 | } 498 | } 499 | } 500 | 501 | if (found_next) break; 502 | } 503 | } 504 | 505 | if (!found_next) { 506 | if (static_cast(pnd_)) { 507 | auto prev_pnd = std::move(pnd_); 508 | auto p = prev_pnd.release(); 509 | auto th = std::thread([p]() { delete p; }); 510 | th.detach(); 511 | } 512 | 513 | if (create_if_not_found) { 514 | pnd_ = std::move(std::unique_ptr(new Node(b))); 515 | num_entries_ = 1; 516 | } else { 517 | pnd_.reset(); 518 | num_entries_ = 0; 519 | } 520 | } 521 | 522 | return found_next; 523 | } 524 | 525 | #endif // NODE_H_ 526 | -------------------------------------------------------------------------------- /src/option.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * AQ, a Go playing engine. 3 | * Copyright (C) 2017-2020 Yu Yamaguchi 4 | * except where otherwise indicated. 5 | * 6 | * This program is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 
18 | */ 19 | 20 | #include "./option.h" 21 | #include "./pattern.h" 22 | 23 | Option::OptionsMap Options; 24 | 25 | std::string JoinPath(const std::string s1, const std::string s2, 26 | const std::string s3) { 27 | std::string s = s1; 28 | #ifdef _WIN32 29 | std::string split_str = "\\"; 30 | #else 31 | std::string split_str = "/"; 32 | #endif 33 | size_t n = split_str.size(); 34 | if (s.size() >= n && s.substr(s.size() - n) != split_str) s += split_str; 35 | s += s2; 36 | if (s3 != "") { 37 | if (s.substr(s.size() - n) != split_str) s += split_str; 38 | s += s3; 39 | } 40 | 41 | return s; 42 | } 43 | 44 | namespace { 45 | /** 46 | * Initialize an OptionMap to the default values. 47 | */ 48 | void InitOptions(Option::OptionsMap* o) { 49 | (*o)["num_threads"] << Option(16, 1, 512); 50 | (*o)["num_gpus"] << Option(1, 1, 32); 51 | 52 | #if BOARD_SIZE == 19 53 | (*o)["komi"] << Option(7.5); 54 | #else 55 | (*o)["komi"] << Option(7.0); 56 | #endif 57 | (*o)["rule"] << Option(0, 0, 2); 58 | (*o)["repetition_rule"] << Option(0, 0, 2); 59 | (*o)["resign_value"] << Option(0.1); 60 | (*o)["use_ponder"] << Option(true); 61 | (*o)["allocate_gpu"] << Option(false); 62 | 63 | (*o)["main_time"] << Option(0.0); 64 | (*o)["byoyomi"] << Option(3.0); 65 | (*o)["byoyomi_margin"] << Option(0.0); 66 | (*o)["num_extensions"] << Option(0, 0, 100); 67 | (*o)["emergency_time"] << Option(15.0); 68 | (*o)["need_time_control"] << Option(true); 69 | 70 | (*o)["working_dir"] << Option(""); 71 | (*o)["sgf_dir"] << Option(""); 72 | (*o)["model_path"] << Option("default"); 73 | (*o)["validate_model_path"] << Option("default"); 74 | (*o)["node_size"] << Option(65536, 4096, 67108864); 75 | 76 | (*o)["save_log"] << Option(true); 77 | (*o)["resume_file_name"] << Option(""); 78 | (*o)["send_list"] << Option(false); 79 | 80 | (*o)["lizzie"] << Option(false); 81 | 82 | // Seach parameters. 
83 | (*o)["batch_size"] << Option(8, 1, 256); 84 | (*o)["lambda_init"] << Option(0.95); 85 | (*o)["lambda_delta"] << Option(0.2); 86 | 87 | #if BOARD_SIZE == 19 88 | (*o)["lambda_move_start"] << Option(240, 0, 400); 89 | (*o)["lambda_move_end"] << Option(360, 1, 401); 90 | #elif BOARD_SIZE == 13 91 | (*o)["lambda_move_start"] << Option(120, 0, 400); 92 | (*o)["lambda_move_end"] << Option(180, 1, 401); 93 | #else // BOARD_SIZE == 9 94 | (*o)["lambda_move_start"] << Option(60, 0, 400); 95 | (*o)["lambda_move_end"] << Option(90, 1, 401); 96 | #endif // BOARD_SIZE == 19 97 | 98 | (*o)["cp_init"] << Option(0.75); 99 | (*o)["cp_base"] << Option(20000.0); 100 | (*o)["use_dirichlet_noise"] << Option(false); 101 | (*o)["dirichlet_noise"] << Option(0.03); 102 | (*o)["search_limit"] << Option(-1, -1, 100000); 103 | (*o)["virtual_loss"] << Option(1, 0, 64); 104 | (*o)["ladder_reduction"] << Option(0.1); 105 | 106 | (*o)["num_games"] << Option(1, 1, 1000000); 107 | (*o)["use_full_features"] << Option(true); 108 | (*o)["value_from_black"] << Option(false); 109 | 110 | #if defined(LEARN) 111 | (*o)["opening_model_path"] << Option(""); 112 | (*o)["run_id"] << Option(-1, -1, 1000000); 113 | (*o)["param_id"] << Option(-1, -1, 1000000); 114 | (*o)["update_each"] << Option(false); 115 | (*o)["model_interval"] << Option(10, 1, 10000); 116 | (*o)["num_agents"] << Option(-1, 2, 40); 117 | 118 | (*o)["result_dir"] << Option(""); 119 | (*o)["stop_flag_dir"] << Option(""); 120 | (*o)["use_rating_model"] << Option(false); 121 | 122 | (*o)["db_host"] << Option("127.0.0.1"); 123 | (*o)["db_user"] << Option("user"); 124 | (*o)["db_pwd"] << Option("pwd"); 125 | (*o)["db_name"] << Option("learn"); 126 | (*o)["db_port"] << Option(3306, 0, 65535); 127 | #endif 128 | } 129 | } // namespace 130 | 131 | /** 132 | * Parse config.txt and command line arguments and reflect them in Options. 133 | */ 134 | std::string ReadConfiguration(int argc, char** argv) { 135 | InitOptions(&Options); 136 | 137 | // 1. 
Gets working directory path. 138 | char buf[1024] = {}; 139 | #ifdef _WIN32 140 | // GetModuleFileName needs a char* argument in MinGW-x64. 141 | GetModuleFileName(NULL, buf, sizeof(buf)); 142 | std::string split_str = "\\"; 143 | #else 144 | auto sz = readlink("/proc/self/exe", buf, sizeof(buf) - 1); 145 | std::string split_str = "/"; 146 | #endif 147 | std::string path_(buf); 148 | // Deletes file name. 149 | auto pos_filename = path_.rfind(split_str); 150 | if (pos_filename != std::string::npos) { 151 | path_ = path_.substr(0, pos_filename + 1); 152 | // Uses current directory if there is no configure file. 153 | std::ifstream ifs(path_ + "config.txt"); 154 | if (ifs.is_open()) Options["working_dir"] = path_; 155 | } 156 | 157 | // 2. Import prob table. 158 | std::string prob_dir = JoinPath(Options["working_dir"], "prob"); 159 | Pattern::Init(prob_dir); 160 | 161 | std::string config_path = JoinPath(Options["working_dir"], "config.txt"); 162 | 163 | // 3. Sets configuration file path. 164 | for (int i = 0; i < argc; ++i) { 165 | std::string arg_i = argv[i]; 166 | if (arg_i.find("--config=") != std::string::npos) { 167 | config_path = arg_i.substr(9); 168 | std::cerr << "Set configuration file: " << config_path << std::endl; 169 | } 170 | } 171 | 172 | // 4. Sets alies for old options. 
173 | std::unordered_map alies_options; 174 | { 175 | alies_options["gpu_cnt"] = "num_gpus"; 176 | alies_options["thread_cnt"] = "num_threads"; 177 | alies_options["extension_cnt"] = "num_extensions"; 178 | alies_options["batch_cnt"] = "batch_size"; 179 | alies_options["game_cnt"] = "num_games"; 180 | alies_options["vloss_cnt"] = "virtual_loss"; 181 | alies_options["agent_cnt"] = "num_agents"; 182 | } 183 | 184 | std::unordered_set executable_modes{ 185 | "--benchmark", "--test", "--self", 186 | "--policy_self", "--learn", "--rating"}; 187 | 188 | auto trim_str = [](const std::string& str, 189 | const char* trim_chars = " \t\v\r\n") { 190 | std::string trimmed_str; 191 | auto left = str.find_first_not_of(trim_chars); 192 | 193 | if (left != std::string::npos) { 194 | auto right = str.find_last_not_of(trim_chars); 195 | trimmed_str = str.substr(left, right - left + 1); 196 | } 197 | return trimmed_str; 198 | }; 199 | 200 | auto flag_str = [](const std::string& str) { 201 | return (str == "on" || str == "On" || str == "ON") ? "true" : "false"; 202 | }; 203 | 204 | // 5. Open the configuration file. 205 | std::ifstream ifs(config_path); 206 | std::string str; 207 | 208 | // Read line by line. 209 | int num_lines = 0; 210 | while (ifs && getline(ifs, str)) { 211 | ++num_lines; 212 | // Exclude comments. 213 | auto cmt_pos = str.find("#"); 214 | if (cmt_pos != std::string::npos) { 215 | str = str.substr(0, cmt_pos); 216 | } 217 | str = trim_str(str, " \t\v\r\n"); 218 | if (str.length() == 0) continue; 219 | 220 | // Finds position after '='. 221 | auto eql_pos = str.find("="); 222 | if (eql_pos == std::string::npos) { 223 | std::cerr << "Failed to parse config:" << config_path << ":" << num_lines 224 | << " " << str << ". 
'=' not found.\n"; 225 | continue; 226 | } 227 | 228 | std::string key = trim_str(str.substr(0, eql_pos)); 229 | std::string val = trim_str(str.substr(eql_pos + 1)); 230 | 231 | if (key.find("--") == std::string::npos) { 232 | std::cerr << "Set '--' before option: [" << key << "]" << std::endl; 233 | } else { 234 | key = key.substr(2); // Removes '--' 235 | if (alies_options.count(key) > 0) { 236 | // Convert from options of old version. 237 | key = alies_options[key]; 238 | } 239 | 240 | if (Options.find(key) == Options.end()) { 241 | std::cerr << "Unknown option: [--" << key << "]" << std::endl; 242 | exit(1); 243 | } 244 | 245 | if (val == "on" || val == "off") 246 | Options[key] = flag_str(val); 247 | else 248 | Options[key] = val; 249 | } 250 | } 251 | ifs.close(); 252 | 253 | // 6. Reads command line options. 254 | // Overwrites options from configure file. 255 | std::string mode = ""; 256 | bool set_batch_size = Options["batch_size"].get_int() != 8; 257 | for (int i = 0; i < argc; ++i) { 258 | std::string arg_i = argv[i]; 259 | if (executable_modes.count(arg_i) > 0) { 260 | mode = arg_i; 261 | } else if (arg_i == "--lizzie") { 262 | Options["lizzie"] = "true"; 263 | } else if (arg_i.find("--config=") != std::string::npos) { 264 | continue; 265 | } else if (arg_i.find("--") != std::string::npos) { 266 | std::string str = arg_i.substr(2); 267 | auto eql_pos = str.find("="); 268 | if (eql_pos == std::string::npos) { 269 | std::cerr << "Failed to parse command line option: [--" << str 270 | << "]. '=' not found.\n"; 271 | exit(1); 272 | } 273 | std::string key = str.substr(0, eql_pos); 274 | std::string val = str.substr(eql_pos + 1); 275 | if (alies_options.count(key) > 0) { 276 | // Convert from options of old version. 
277 | key = alies_options[key]; 278 | } 279 | 280 | if (Options.find(key) == Options.end()) { 281 | std::cerr << "Unknown option: [--" << key << "]" << std::endl; 282 | exit(1); 283 | } 284 | 285 | if (val == "on" || val == "off") 286 | Options[key] = flag_str(val); 287 | else 288 | Options[key] = val; 289 | 290 | if (key == "batch_size") set_batch_size = true; 291 | } 292 | } 293 | // Set 5 batch size when using search_limit option. 294 | if (!set_batch_size && Options["search_limit"].get_int() > 0) 295 | Options["batch_size"] = 5; 296 | 297 | std::cerr << "Configuration is loaded.\n"; 298 | 299 | return mode; 300 | } 301 | -------------------------------------------------------------------------------- /src/option.h: -------------------------------------------------------------------------------- 1 | /* 2 | * AQ, a Go playing engine. 3 | * Copyright (C) 2017-2020 Yu Yamaguchi 4 | * except where otherwise indicated. 5 | * 6 | * This program is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 18 | */ 19 | 20 | #ifndef OPTION_H_ 21 | #define OPTION_H_ 22 | 23 | #include 24 | #include 25 | 26 | #include "./config.h" 27 | 28 | // -------------------- 29 | // Option 30 | // -------------------- 31 | 32 | /** 33 | * @enum OptionType 34 | * Data type to be stored in the Option class. 
35 | */ 36 | enum OptionType { 37 | kOptionNone, 38 | kOptionString, 39 | kOptionBool, 40 | kOptionInt, 41 | kOptionDouble, 42 | }; 43 | 44 | /** 45 | * @class Option 46 | * Option class holds the options specified by the command line argument, which 47 | * are bool, int, double, and string for each option type. The data type is 48 | * determined by the constructor. 49 | */ 50 | class Option { 51 | public: 52 | typedef std::unordered_map OptionsMap; 53 | 54 | // Constructor 55 | Option() : type_(kOptionNone), min_(0), max_(0) {} 56 | 57 | explicit Option(const char* v) 58 | : type_(kOptionString), val_(v), min_(0), max_(0) {} 59 | 60 | explicit Option(bool v) 61 | : type_(kOptionBool), val_(v ? "true" : "false"), min_(0), max_(0) {} 62 | 63 | Option(int v, int min_v, int max_v) 64 | : type_(kOptionInt), val_(std::to_string(v)), min_(min_v), max_(max_v) {} 65 | 66 | explicit Option(double v) 67 | : type_(kOptionDouble), val_(std::to_string(v)), min_(0), max_(0) {} 68 | 69 | // Copy 70 | Option& operator=(std::string v) { 71 | ASSERT_LV1(type_ != kOptionNone); 72 | 73 | // Out of range or invalid argument. 74 | if (type_ != kOptionString && v.empty() || 75 | (type_ == kOptionBool && v != "true" && v != "false") || 76 | (type_ == kOptionInt && (stoll(v) < min_ || stoll(v) > max_))) 77 | return *this; 78 | 79 | val_ = v; 80 | return *this; 81 | } 82 | 83 | Option& operator=(const char* ptr) { return *this = std::string(ptr); } 84 | 85 | Option& operator=(int v) { 86 | ASSERT_LV1(type_ == kOptionInt || type_ == kOptionDouble); 87 | return *this = std::to_string(v); 88 | } 89 | 90 | Option& operator=(bool v) { 91 | ASSERT_LV1(type_ == kOptionBool); 92 | return *this = (v ? "true" : "false"); 93 | } 94 | 95 | Option& operator=(double v) { 96 | ASSERT_LV1(type_ == kOptionDouble); 97 | return *this = std::to_string(v); 98 | } 99 | 100 | void operator<<(const Option& o) { *this = o; } 101 | 102 | // Accessor for each data type. 
103 | int get_int() const { 104 | ASSERT_LV1(type_ == kOptionInt); 105 | return std::stoi(val_); 106 | } 107 | 108 | bool get_bool() const { 109 | ASSERT_LV1(type_ == kOptionBool); 110 | return (val_ == "true"); 111 | } 112 | 113 | double get_double() const { 114 | ASSERT_LV1(type_ == kOptionDouble); 115 | return std::stod(val_); 116 | } 117 | 118 | std::string get_string() const { 119 | ASSERT_LV1(type_ != kOptionNone); 120 | return val_; 121 | } 122 | 123 | // Implicit type conversion. 124 | // Used for variable initialization and conditional branching by options of 125 | // kOptionBool. 126 | operator int() const { 127 | ASSERT_LV1(type_ == kOptionInt); 128 | return std::stoi(val_); 129 | } 130 | 131 | operator bool() const { 132 | ASSERT_LV1(type_ == kOptionBool); 133 | return (val_ == "true"); 134 | } 135 | 136 | operator double() const { 137 | ASSERT_LV1(type_ == kOptionDouble); 138 | return std::stod(val_); 139 | } 140 | 141 | operator std::string() const { 142 | ASSERT_LV1(type_ != kOptionNone); 143 | return val_; 144 | } 145 | 146 | private: 147 | OptionType type_; 148 | std::string val_; 149 | int min_; 150 | int max_; 151 | }; 152 | 153 | /** 154 | * Map to store options with command line arguments. 155 | */ 156 | extern Option::OptionsMap Options; 157 | 158 | /** 159 | * Combine directory and file paths. 160 | */ 161 | std::string JoinPath(const std::string s1, const std::string s2, 162 | const std::string s3 = ""); 163 | 164 | /** 165 | * Parse config.txt and command line arguments and reflect them in Options. 166 | */ 167 | std::string ReadConfiguration(int argc, char** argv); 168 | 169 | #endif // OPTION_H_ 170 | -------------------------------------------------------------------------------- /src/pattern.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * AQ, a Go playing engine. 3 | * Copyright (C) 2017-2020 Yu Yamaguchi 4 | * except where otherwise indicated. 
5 | * 6 | * This program is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 18 | */ 19 | 20 | #include "./pattern.h" 21 | #include "./option.h" 22 | 23 | double Pattern::prob_ptn3x3_[65536][256][2][2]; 24 | bool Pattern::legal_ptn_[256][256][2]; 25 | int Pattern::count_ptn_[256][4]; 26 | std::unordered_map> Pattern::prob_ptn_rsp_; 27 | 28 | void Pattern::Init(std::string prob_dir) { 29 | // 1. Initializes pattern tables. 30 | Pattern ptn; 31 | 32 | for (int j = 0; j < 256; ++j) { 33 | ptn.set_stones(0x00aaaa00 | j); 34 | for (Color c = kColorZero; c < kNumColors; ++c) { 35 | count_ptn_[j][c] = ptn.CountImpl(c); 36 | } 37 | } 38 | 39 | for (uint32_t j = 0; j < 65536; ++j) { 40 | for (uint32_t k = 0; k < 256; ++k) { 41 | ptn.set_stones(0x00aa0000 | j | (k << 24)); 42 | 43 | for (Color c = kColorZero; c < kNumPlayers; ++c) { 44 | if (ptn.LegalImpl(c)) { 45 | prob_ptn3x3_[j][k][c][0] = 1.0; 46 | prob_ptn3x3_[j][k][c][1] = 1.0; 47 | 48 | legal_ptn_[j & 0xff][k][c] = true; 49 | } else { 50 | prob_ptn3x3_[j][k][c][0] = 0.0; 51 | prob_ptn3x3_[j][k][c][1] = 1.0; 52 | 53 | legal_ptn_[j & 0xff][k][c] = false; 54 | } 55 | } 56 | } 57 | } 58 | 59 | prob_ptn_rsp_.clear(); 60 | 61 | // 2. Imports pattern probability from files. 62 | 63 | std::ifstream ifs; 64 | std::string str; 65 | 66 | // 2-1. 3x3 patterns. 
67 | ifs.open(JoinPath(prob_dir, "prob_ptn3x3.txt")); 68 | if (ifs.fail()) 69 | std::cerr << "file could not be opened: prob_ptn3x3.txt" << std::endl; 70 | 71 | while (getline(ifs, str)) { 72 | std::string line_str; 73 | std::istringstream iss(str); 74 | 75 | getline(iss, line_str, ','); 76 | uint32_t stones = stoul(line_str); 77 | // Swaps color bits. 78 | stones = (stones & 0xff000000) | (stones ^ 0x00aaaaaa); 79 | 80 | std::array bf_prob; 81 | for (int i = 0; i < 4; ++i) { 82 | getline(iss, line_str, ','); 83 | bf_prob[i] = stod(line_str); 84 | } 85 | 86 | int stone_bf = stones & 0xffff; 87 | int atari_bf = stones >> 24; 88 | for (int j = 0; j < 4; ++j) { 89 | int color_id = j % 2; 90 | int restore_id = j < 2 ? 0 : 1; 91 | prob_ptn3x3_[stone_bf][atari_bf][color_id][restore_id] = bf_prob[j]; 92 | } 93 | } 94 | ifs.close(); 95 | 96 | // 2-2. Response patterns. 97 | ifs.open(JoinPath(prob_dir, "prob_ptn_rsp.txt")); 98 | if (ifs.fail()) 99 | std::cerr << "file could not be opened: prob_ptn_rsp.txt" << std::endl; 100 | 101 | while (getline(ifs, str)) { 102 | std::string line_str; 103 | std::istringstream iss(str); 104 | 105 | getline(iss, line_str, ','); 106 | uint32_t stones = stoul(line_str); 107 | // Swaps color bits. 108 | stones = (stones & 0xff000000) | (stones ^ 0x00aaaaaa); 109 | 110 | std::array bf_prob; 111 | for (int i = 0; i < 2; ++i) { 112 | getline(iss, line_str, ','); 113 | bf_prob[i] = stod(line_str); 114 | } 115 | 116 | prob_ptn_rsp_.insert(std::make_pair(stones, bf_prob)); 117 | } 118 | ifs.close(); 119 | } 120 | -------------------------------------------------------------------------------- /src/pattern.h: -------------------------------------------------------------------------------- 1 | /* 2 | * AQ, a Go playing engine. 3 | * Copyright (C) 2017-2020 Yu Yamaguchi 4 | * except where otherwise indicated. 
5 | * 6 | * This program is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 18 | */ 19 | 20 | #ifndef PATTERN_H_ 21 | #define PATTERN_H_ 22 | 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | 30 | #include "./types.h" 31 | 32 | /** 33 | * @class Pattern 34 | * A structure for recognizing a pattern of stone arrangement 35 | * at a certain coordinate and surrounding 8 points quickly. 36 | * For example, the judgment of a legal hand can be judged 37 | * simply by referring to the legal_ptn_ table. 38 | * 39 | * [bit field] 40 | * 0-15 : 3x3 colors. U/R/D/L/RU/RD/LD/LU (2 bits each) 41 | * 16-23 : extra 4 colors. UU/RR/DD/LL (2 bits each) 42 | * 24-31 : atari/pre-atari state. 
U/R/D/L (2 bits each) 43 | * 44 | * color types : kWhite(0b00) / kBlack(0b01) / kEmpty(0b10) / kWall(0b11) 45 | * atari types : atari(0b01) / pre-atari(0b10) / others(0b00) 46 | */ 47 | class Pattern { 48 | public: 49 | // Constructor 50 | Pattern() { stones_ = 0x00aaaaaa; } // all empty 51 | 52 | Pattern(const Pattern& rhs) : stones_(rhs.stones_) {} 53 | 54 | explicit Pattern(const uint32_t st_) { stones_ = st_; } 55 | 56 | Pattern& operator=(const Pattern& rhs) { 57 | stones_ = rhs.stones_; 58 | return *this; 59 | } 60 | 61 | bool operator==(const Pattern& rhs) const { return stones_ == rhs.stones_; } 62 | 63 | /** 64 | * Initialize the legal_ptn_ and count_ptn_ tables, and read out the 65 | * probability distribution for each pattern. 66 | */ 67 | static void Init(std::string prob_dir); 68 | 69 | /** 70 | * Initializes bits. 71 | */ 72 | void SetEmpty() { stones_ = 0x00aaaaaa; } 73 | 74 | /** 75 | * Sets null value (= UINT_MAX) 76 | */ 77 | void SetNull() { stones_ = 0xffffffff; } 78 | 79 | uint32_t stones() const { return stones_; } 80 | 81 | void set_stones(uint32_t val) { stones_ = val; } 82 | 83 | /** 84 | * Returns stone color in a direction. 85 | * 86 | * d: 0(U),1(R),2(D),3(L), 87 | * 4(LU),5(RU),6(RD),7(LD), 88 | * 8(UU),9(RR),10(DD),11(LL) 89 | */ 90 | Color color_at(Direction d) const { return Color((stones_ >> (2 * d)) & 3); } 91 | 92 | bool is_stone(Direction d) const { return ((stones_ >> (2 * d)) & 3) < 2; } 93 | 94 | /** 95 | * Updates stone color in a direction. 96 | * 97 | * c: 0(kWhite), 1(kBlack), 2(kEmpty), 3(kWall) 98 | */ 99 | void set_color(Direction d, Color c) { 100 | stones_ &= ~(3 << (2 * d)); 101 | stones_ |= c << (2 * d); 102 | } 103 | 104 | /** 105 | * Flips color of all black/white stones. 106 | */ 107 | void FlipColor() { 108 | // 0xa = 0b1010 109 | // 1. ~stones_ & 0x00aaaaaa: flag of black or white 110 | // 2. >> 1: make lower bit mask 111 | // 3. 
take xor with stone and convert 0b00(kWhite) <-> 0b01(kBlack) 112 | stones_ ^= (((~stones_) & 0x00aaaaaa) >> 1); 113 | } 114 | 115 | int CountImpl(Color c) const { 116 | return static_cast(Color((stones_)&3) == c) + 117 | static_cast(Color((stones_ >> 2) & 3) == c) + 118 | static_cast(Color((stones_ >> 4) & 3) == c) + 119 | static_cast(Color((stones_ >> 6) & 3) == c); 120 | } 121 | 122 | /** 123 | * Returns number of stones/empties/walls. 124 | */ 125 | int count(Color c) const { 126 | ASSERT_LV2(kColorZero <= c && c < kNumColors); 127 | return count_ptn_[stones_ & 0xff][c]; 128 | } 129 | 130 | /** 131 | * Returns whether it is surrounded by stones of c. 132 | */ 133 | bool enclosed_by(Color c) const { 134 | ASSERT_LV2(c < kNumPlayers); 135 | return (count(c) + count(kWall)) == 4; 136 | } 137 | 138 | /** 139 | * Sets atari in each direction (URDL). 140 | */ 141 | void set_atari(bool bn, bool be, bool bs, bool bw) { 142 | // 1. eliminate pre-atari bit (0b10) of the stone to be atari 143 | stones_ &= ~((bn << 25) | (be << 27) | (bs << 29) | (bw << 31)); 144 | // 2. add atari (0b01) 145 | stones_ |= (bn << 24) | (be << 26) | (bs << 28) | (bw << 30); 146 | } 147 | 148 | /** 149 | * Clear atari in a single direction (URDL). 150 | */ 151 | void cancel_atari(bool bn, bool be, bool bs, bool bw) { 152 | stones_ &= ~((bn << 24) | (be << 26) | (bs << 28) | (bw << 30)); 153 | } 154 | 155 | /** 156 | * Clear atari in all directions. 157 | */ 158 | void clear_atari() { 159 | // 0xa = 0b1010, 0xf = 0b1111 160 | stones_ &= 0xaaffffff; 161 | } 162 | 163 | /** 164 | * Returns whether the neighbor stone is in atari. 165 | */ 166 | bool atari_at(Direction d) const { return (stones_ >> (24 + 2 * d)) & 1; } 167 | 168 | /** 169 | * Returns whether any of neighbor stones is atari. 170 | */ 171 | bool atari() const { 172 | // 0x55 = 0b01010101 173 | return (stones_ >> 24) & 0x55; 174 | } 175 | 176 | /** 177 | * Sets pre-atari (liberty = 2) in each direction (URDL). 
178 | */ 179 | void set_pre_atari(bool bn, bool be, bool bs, bool bw) { 180 | // 1. eliminate atari bit (01) of the stone to be pre-atari 181 | stones_ &= ~((bn << 24) | (be << 26) | (bs << 28) | (bw << 30)); 182 | // 2. add pre-atari (10) 183 | stones_ |= (bn << 25) | (be << 27) | (bs << 29) | (bw << 31); 184 | } 185 | 186 | /** 187 | * Clears pre-atari in a single direction (URDL). 188 | */ 189 | void cancel_pre_atari(bool bn, bool be, bool bs, bool bw) { 190 | stones_ &= ~((bn << 25) | (be << 27) | (bs << 29) | (bw << 31)); 191 | } 192 | 193 | /** 194 | * Clears pre-atari in all directions. 195 | */ 196 | void clear_pre_atari() { 197 | // 0x5 = 0b0101, 0xf = 0b1111 198 | stones_ &= 0x55ffffff; 199 | } 200 | 201 | /** 202 | * Returns whether the neighbor stone is pre-atari. 203 | */ 204 | bool pre_atari_at(Direction d) const { return (stones_ >> (24 + 2 * d)) & 2; } 205 | 206 | /** 207 | * Returns whether any of next stones is pre-atari. 208 | */ 209 | bool pre_atari() const { 210 | // 0xaa = 0b10101010 211 | return (stones_ >> 24) & 0xaa; 212 | } 213 | 214 | /** 215 | * Returns whether player's move into this is legal. 216 | */ 217 | bool LegalImpl(Color c) const { 218 | ASSERT_LV2(c < kNumPlayers); 219 | 220 | // 1. Legal if blank vertexes exist in neighbor. 221 | if (count(kEmpty) != 0) return true; 222 | 223 | int num_stones[2] = {0, 0}; // 0: white, 1: black 224 | int num_atari[2] = {0, 0}; 225 | 226 | // 2. Counts neighbor stones and atari. 227 | for (Direction d = kDirZero; d < kNumDir4; ++d) { 228 | Color ci = color_at(d); 229 | if (ci < kNumPlayers) { 230 | ++num_stones[ci]; 231 | if (atari_at(d)) ++num_atari[ci]; 232 | } 233 | } 234 | 235 | // 3. Legal if opponent's stone is atari, 236 | // or any of her stones_ is not atari. 237 | return (num_atari[~c] != 0 || num_atari[c] < num_stones[c]); 238 | } 239 | 240 | /** 241 | * Returns whether player's move into this is legal. 
242 | */ 243 | bool legal(Color c) const { 244 | ASSERT_LV2(c < kNumPlayers); 245 | ASSERT_LV2((stones_ >> 24) == ((stones_ >> 24) & 0xff)); 246 | return legal_ptn_[stones_ & 0xff][stones_ >> 24][c]; 247 | } 248 | 249 | /** 250 | * Returns probability of this pettern. 251 | */ 252 | double prob(Color c, bool restore) const { 253 | ASSERT_LV2(c < kNumPlayers); 254 | ASSERT_LV2((stones_ >> 24) == ((stones_ >> 24) & 0xff)); 255 | return prob_ptn3x3_[stones_ & 0xffff][stones_ >> 24][c] 256 | [static_cast(restore)]; 257 | } 258 | 259 | /** 260 | * Returns response probability of this pettern. 261 | */ 262 | void ResponseProb(double* ptn_prob, double* inv_prob) const { 263 | if (prob_ptn_rsp_.find(stones_) != prob_ptn_rsp_.end()) { 264 | auto p = prob_ptn_rsp_.at(stones_); 265 | *ptn_prob = p[0]; 266 | *inv_prob = p[1]; 267 | } else { 268 | *ptn_prob = *inv_prob = -1; 269 | } 270 | } 271 | 272 | /** 273 | * Returns Pattern which is rotated clockwise by 90 degrees. 274 | */ 275 | Pattern Rotate() const { 276 | // 0x3 = 0b0011, 0xc = 0b1100, 0xf = 0b1111 277 | Pattern rot_ptn(((stones_ << 2) & 0xfcfcfcfc) | 278 | ((stones_ >> 6) & 0x03030303)); 279 | 280 | return rot_ptn; 281 | } 282 | 283 | /** 284 | * Returns Pattern which is horizontally inverted. 285 | */ 286 | Pattern Invert() const { 287 | // 0x3 = 0b0011, 0xc = 0b1100 288 | Pattern mir_ptn((stones_ & 0x33330033) | ((stones_ << 4) & 0xc0c000c0) | 289 | ((stones_ >> 4) & 0x0c0c000c) | 290 | ((stones_ << 2) & 0x0000cc00) | 291 | ((stones_ >> 2) & 0x00003300)); 292 | 293 | return mir_ptn; 294 | } 295 | 296 | /** 297 | * Returns Pattern which has the minimum number. 
298 | */ 299 | Pattern MinimumSym() const { 300 | Pattern tmp_ptn(stones_); 301 | Pattern min_ptn = tmp_ptn; 302 | 303 | for (int i = 0; i < 2; ++i) { 304 | for (int j = 0; j < 4; ++j) { 305 | if (tmp_ptn.stones_ < min_ptn.stones_) min_ptn = tmp_ptn; 306 | 307 | tmp_ptn = tmp_ptn.Rotate(); 308 | } 309 | tmp_ptn = tmp_ptn.Invert(); 310 | } 311 | 312 | return min_ptn; 313 | } 314 | 315 | /** 316 | * Outputs Pattern information. (for debug) 317 | */ 318 | friend std::ostream& operator<<(std::ostream& os, const Pattern& ptn) { 319 | auto cl = [&ptn](Direction d) { 320 | std::string str_[] = {"O", "X", ".", "%"}; 321 | return str_[ptn.color_at(d)]; 322 | }; 323 | 324 | auto ap = [&ptn](Direction d) { 325 | return (ptn.atari_at(d) ? "a" : ptn.pre_atari_at(d) ? "p" : "."); 326 | }; 327 | 328 | os << " " << cl(kDirUU) << std::endl; 329 | os << " " << cl(kDirLU) << cl(kDirU) << cl(kDirRU) << " " << ap(kDirU) 330 | << std::endl; 331 | os << cl(kDirLL) << cl(kDirL) << "." << cl(kDirR) << cl(kDirRR) << " "; 332 | os << ap(kDirL) << " " << ap(kDirR) << std::endl; 333 | os << " " << cl(kDirLD) << cl(kDirD) << cl(kDirRD) << " " << ap(kDirD) 334 | << std::endl; 335 | os << " " << cl(kDirDD) << std::endl; 336 | 337 | return os; 338 | } 339 | 340 | friend void PrintPatternProb(); 341 | 342 | private: 343 | uint32_t stones_; 344 | // Static tables are initialized in Pattern::Init(). 345 | static double prob_ptn3x3_[65536][256][2][2]; 346 | static bool legal_ptn_[256][256][2]; 347 | static int count_ptn_[256][4]; 348 | static std::unordered_map> prob_ptn_rsp_; 349 | }; 350 | 351 | #endif // PATTERN_H_ 352 | -------------------------------------------------------------------------------- /src/route_queue.h: -------------------------------------------------------------------------------- 1 | /* 2 | * AQ, a Go playing engine. 3 | * Copyright (C) 2017-2020 Yu Yamaguchi 4 | * except where otherwise indicated. 
5 | * 6 | * This program is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 18 | */ 19 | 20 | #ifndef ROUTE_QUEUE_H_ 21 | #define ROUTE_QUEUE_H_ 22 | 23 | #include 24 | #include 25 | #include 26 | #include 27 | 28 | #include "./eval_cache.h" 29 | #include "./node.h" 30 | 31 | /** 32 | * @enum LeafType 33 | * States at the end of a search. 34 | */ 35 | enum LeafType { kEvaluated, kWaitEval, kFailToPush, kReachEnd, kLeafNone }; 36 | 37 | /** 38 | * @struct SearchRoute 39 | * Record of the paths taken in the search tree. 
40 | */ 41 | struct SearchRoute { 42 | int tree_id; 43 | int depth; 44 | int num_requests; 45 | LeafType leaf; 46 | std::vector moves; 47 | std::vector child_ids; 48 | 49 | SearchRoute() : depth(0), num_requests(1), tree_id(0), leaf(kLeafNone) {} 50 | 51 | SearchRoute(const SearchRoute& rhs) 52 | : depth(rhs.depth), 53 | num_requests(rhs.num_requests), 54 | tree_id(rhs.tree_id), 55 | leaf(rhs.leaf), 56 | moves(rhs.moves), 57 | child_ids(rhs.child_ids) {} 58 | 59 | bool operator==(const SearchRoute& rhs) const { 60 | return depth == rhs.depth && tree_id == rhs.tree_id && moves == rhs.moves && 61 | child_ids == rhs.child_ids; 62 | } 63 | 64 | void Add(Vertex v, int child_id) { 65 | moves.push_back(v); 66 | child_ids.push_back(child_id); 67 | ++depth; 68 | } 69 | }; 70 | 71 | /** 72 | * @struct RouteEntry 73 | * Structure that holds the path and terminal node information (pnd, ft, vp, 74 | * key) of the search. 75 | */ 76 | struct RouteEntry { 77 | std::vector routes; 78 | std::unique_ptr pnd; 79 | Feature ft; 80 | ValueAndProb vp; 81 | Key key; 82 | 83 | RouteEntry() : key(UINT64_MAX) {} 84 | 85 | RouteEntry(const Board& b, const SearchRoute& sr) 86 | : ft(b.get_feature()), key(b.key()) { 87 | pnd = std::move(std::unique_ptr(new Node(b))); 88 | routes.push_back(sr); 89 | } 90 | 91 | bool has_node_ptr() const { return static_cast(pnd); } 92 | 93 | void AddRoute(const SearchRoute& sr) { 94 | auto itr = find(routes.begin(), routes.end(), sr); 95 | if (itr != routes.end()) 96 | itr->num_requests++; 97 | else 98 | routes.push_back(sr); 99 | } 100 | }; 101 | 102 | /** 103 | * @class RouteQueue 104 | * Queue class that exclusively stores a RouteEntry. 105 | * Used for managing common node information while searching in multiple search 106 | * trees during training. 
107 | */ 108 | class RouteQueue { 109 | public: 110 | RouteQueue() {} 111 | 112 | void clear() { entries_.clear(); } 113 | 114 | int size() const { return entries_.size(); } 115 | 116 | void push(const Board& b, const SearchRoute& route) { 117 | Key key = b.key(); 118 | 119 | std::lock_guard lock(mx_); 120 | auto itr = 121 | find_if(entries_.begin(), entries_.end(), [key](const RouteEntry& e) { 122 | return e.key == key && e.has_node_ptr(); 123 | }); 124 | 125 | if (itr != entries_.end()) { 126 | itr->AddRoute(route); 127 | } else { 128 | entries_.emplace_back(b, route); 129 | } 130 | } 131 | 132 | std::vector* get_entries() { return &entries_; } 133 | 134 | private: 135 | std::mutex mx_; 136 | std::vector entries_; 137 | }; 138 | 139 | #endif // ROUTE_QUEUE_H_ 140 | -------------------------------------------------------------------------------- /src/search.h: -------------------------------------------------------------------------------- 1 | /* 2 | * AQ, a Go playing engine. 3 | * Copyright (C) 2017-2020 Yu Yamaguchi 4 | * except where otherwise indicated. 5 | * 6 | * This program is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 
18 | */ 19 | 20 | #ifndef SEARCH_H_ 21 | #define SEARCH_H_ 22 | 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | #include 31 | #include 32 | #include 33 | #include 34 | 35 | #include "./board.h" 36 | #include "./eval_cache.h" 37 | #include "./eval_worker.h" 38 | #include "./network.h" 39 | #include "./node.h" 40 | #include "./option.h" 41 | #include "./timer.h" 42 | 43 | /** 44 | * @class SearchTree 45 | * SearchTree class manages search trees and performs parallel search and output 46 | * search information. 47 | * Need to call SetGPUAndMemory() to allocate cache memory and initialize the 48 | * GPU inference engine before you can search. 49 | */ 50 | class SearchTree : public RootNode, public Timer, public SearchParameter { 51 | public: 52 | // Constructor. GPU and node table are not yet set. 53 | SearchTree() : RootNode(), Timer(), SearchParameter() { Init(); } 54 | 55 | // Treats it the same as the default constructor, since the copy constructor 56 | // is called in std::vector::resize() of C++11. 
57 | SearchTree(const SearchTree& rhs) : RootNode(), Timer(), SearchParameter() { 58 | Init(); 59 | } 60 | 61 | double lambda() const { return lambda_; } 62 | 63 | double num_virtual_loss() const { return virtual_loss_; } 64 | 65 | double ladder_reduction() const { return ladder_reduction_; } 66 | 67 | double komi() const { return komi_; } 68 | 69 | int num_reach_ends() const { return num_reach_ends_; } 70 | 71 | bool reflesh_root() const { return reflesh_root_; } 72 | 73 | bool consider_pass() const { return consider_pass_; } 74 | 75 | std::ofstream* log_file() const { return log_file_.get(); } 76 | 77 | bool has_eval_worker() const { return static_cast(eval_worker_); } 78 | 79 | Node* root_node() const { return RootNode::node(); } 80 | 81 | void set_komi(double val) { komi_ = val; } 82 | 83 | void set_num_reach_ends(int val) { num_reach_ends_.store(val); } 84 | 85 | void UpdateLambda(int ply) { 86 | lambda_ = 87 | lambda_init_ - 88 | lambda_delta_ * 89 | std::min( 90 | 1.0, 91 | std::max(0.0, static_cast(ply - lambda_move_start_) / 92 | (lambda_move_end_ - lambda_move_start_))); 93 | } 94 | 95 | void StopToThink() { stop_think_.store(true); } 96 | 97 | void PrepareToThink() { stop_think_.store(false); } 98 | 99 | void SetLogFile(std::string log_path) { 100 | if (log_file_) { 101 | log_file_->close(); 102 | log_file_.reset(); 103 | } 104 | log_file_ = std::move(std::unique_ptr( 105 | new std::ofstream(log_path, std::ofstream::out))); 106 | } 107 | 108 | void PrintBoardLog(const Board& b) { 109 | if (Options["save_log"]) std::cerr << b; 110 | if (log_file_) *(log_file_.get()) << b; 111 | } 112 | 113 | void InitEvalCache() { eval_cache_.Init(); } 114 | 115 | void ReplaceModel(std::vector gpu_ids, std::string model_path = "") { 116 | eval_worker_->ReplaceModel(gpu_ids, model_path); 117 | } 118 | 119 | void InitEvalWorker(std::vector list_gpus, std::string model_path = "") { 120 | if (list_gpus.empty()) { 121 | for (int i = 0; i < num_gpus_; ++i) { 122 | 
list_gpus.push_back(i); 123 | } 124 | } 125 | 126 | eval_worker_ = std::move(std::unique_ptr(new EvalWorker())); 127 | eval_worker_->Init(list_gpus, model_path); 128 | eval_cache_.Resize(300000); 129 | 130 | if (Options["rule"].get_int() == kJapanese && 131 | Options["validate_model_path"].get_string() != "") { 132 | std::lock_guard(*(eval_worker_->get_mutex())); 133 | std::string validate_model_path = 134 | Options["validate_model_path"].get_string(); 135 | if (validate_model_path == "default") { 136 | validate_model_path = 137 | JoinPath(Options["working_dir"], "engine", "model_cn.engine"); 138 | } 139 | validate_engine_ = std::move(std::unique_ptr( 140 | new TensorEngine(list_gpus[0], Options["batch_size"].get_int()))); 141 | validate_engine_->Init(validate_model_path); 142 | } 143 | 144 | stop_think_ = false; 145 | } 146 | 147 | void SetGPUAndMemory() { 148 | std::vector list_gpus; 149 | for (int i = 0; i < num_gpus_; ++i) list_gpus.push_back(i); 150 | 151 | InitEvalWorker(list_gpus); 152 | RootNode::Resize(Options["node_size"].get_int()); 153 | } 154 | 155 | void Init() { 156 | num_gpus_ = Options["num_gpus"].get_int(); 157 | num_threads_ = Options["num_threads"].get_int(); 158 | komi_ = Options["komi"].get_double(); 159 | use_dirichlet_noise_ = Options["use_dirichlet_noise"].get_bool(); 160 | reflesh_root_ = false; // (bool)Options["use_dirichlet_noise"]; 161 | consider_pass_ = Options["rule"].get_int() == kJapanese; 162 | log_file_.reset(); 163 | stop_think_ = false; 164 | InitRoot(); 165 | } 166 | 167 | void InitRoot() { 168 | lambda_ = lambda_init_; 169 | num_evaluated_ = 0; 170 | num_reach_ends_ = 0; 171 | RootNode::Init(); 172 | Timer::Init(); 173 | } 174 | 175 | void UpdateNodeVP(Node* nd, const ValueAndProb& vp); 176 | 177 | /** 178 | * Creates node in node table from board or another node. 
179 | */ 180 | void CreateNode(Node* nd, const Board& b, const ValueAndProb& vp) { 181 | *nd = b; 182 | UpdateNodeVP(nd, vp); 183 | increment_entries(); 184 | } 185 | 186 | /** 187 | * Sets a pointer to the next node to a child node. 188 | */ 189 | void SetNextNode(Node* nd, int child_id, std::unique_ptr* pnd, 190 | const ValueAndProb& vp) { 191 | UpdateNodeVP(pnd->get(), vp); 192 | { 193 | std::lock_guard lock(nd->mutex()); 194 | if (!nd->children[child_id].has_next()) 195 | nd->children[child_id].set_next_ptr(pnd); 196 | } 197 | increment_entries(); 198 | } 199 | 200 | /** 201 | * Adds Dirichlet noise to a node. 202 | */ 203 | void AddDirichletNoise(Node* nd) { 204 | std::vector dirichlet_list; 205 | double sum_noises = 0.0; 206 | int imax = nd->num_children(); 207 | for (int i = 0; i < imax; ++i) { 208 | double noise = RandNoise(); 209 | sum_noises += noise; 210 | dirichlet_list.emplace_back(noise); 211 | } 212 | 213 | ChildNode* child; 214 | double sum_probs = 0.0; 215 | for (int i = 0; i < imax; ++i) { 216 | child = &nd->children[i]; 217 | child->set_prob(0.75 * child->prob() + 218 | 0.25 * dirichlet_list[i] / sum_noises); 219 | sum_probs += child->prob(); 220 | if (reflesh_root_) child->InitValueStat(); 221 | } 222 | 223 | double inv_sum = sum_probs > 0 ? 1.0 / sum_probs : 1.0; 224 | for (int i = 0; i < imax; ++i) 225 | nd->children[i].set_prob(nd->children[i].prob() * inv_sum); 226 | if (reflesh_root_) nd->set_num_total_values(1); 227 | } 228 | 229 | /** 230 | * Updates the root node. 
231 | */ 232 | void UpdateRoot(const Board& b, TensorEngine* engine = nullptr) { 233 | Vertex v = b.move_before(); 234 | bool has_child = RootNode::ShiftRootNode(v, b); 235 | if (!has_child) { 236 | ValueAndProb vp; 237 | Feature ft(b.get_feature()); 238 | 239 | if (engine) { 240 | engine->Infer(ft, &vp); 241 | } else { 242 | eval_worker_->Evaluate(ft, &vp); 243 | } 244 | 245 | CreateNode(root_node(), b, vp); 246 | } 247 | 248 | if (use_dirichlet_noise_) AddDirichletNoise(root_node()); 249 | } 250 | 251 | /** 252 | * The search tree is searched once to the end according to the Q value, and 253 | * evaluated or rolled out. 254 | * 255 | * NNSearch == true : Evaluation with GPU 256 | * NNSearch == false: Rollout 257 | */ 258 | template 259 | double SearchBranch(Node* nd, Board* b, SearchRoute* route, 260 | RouteQueue* eq = nullptr, EvalCache* cache = nullptr); 261 | 262 | /** 263 | * Writes out text to the log file. 264 | * Outputs to standard error output when save_log mode. 265 | */ 266 | void PrintLog(const char* output_text, ...) { 267 | va_list args; 268 | char buf[1024]; 269 | 270 | va_start(args, output_text); 271 | vsprintf(buf, output_text, args); 272 | va_end(args); 273 | 274 | if (Options["save_log"]) std::cerr << buf; 275 | if (log_file_) *(log_file_.get()) << buf; 276 | } 277 | 278 | /** 279 | * Returns the maximum depth of the search. 280 | * Considering repetition, the seach is closed at 128 moves. 
281 | */ 282 | int MaxDepth(const Node& nd, Vertex prev_move, int depth) const { 283 | if (nd.num_children() == 0 || depth >= 128) return depth; 284 | 285 | ChildNode* best_child = SortChildren(nd).front(); 286 | int max_depth = depth; 287 | 288 | if (best_child->has_next()) { 289 | if (prev_move == kPass && best_child->move() == kPass) return max_depth; 290 | max_depth = std::max(max_depth, MaxDepth(*best_child->next_ptr(), 291 | best_child->move(), depth + 1)); 292 | } 293 | 294 | return max_depth; 295 | } 296 | 297 | /** 298 | * Sorts the child nodes by the number of visits. 299 | */ 300 | std::vector SortChildren(const Node& nd) const { 301 | std::vector sorted; 302 | if (nd.num_children() > 0) { 303 | for (auto& ch : nd.children) 304 | sorted.emplace_back(const_cast(&ch)); 305 | std::stable_sort(sorted.begin(), sorted.end(), 306 | [](const ChildNode* lhs, const ChildNode* rhs) { 307 | if (lhs->num_values() == rhs->num_values()) 308 | return lhs->prob() > rhs->prob(); 309 | return lhs->num_values() > rhs->num_values(); 310 | }); 311 | } 312 | 313 | return std::move(sorted); 314 | } 315 | 316 | /** 317 | * Returns winning rate of child node scaled to [0, 1]. 318 | */ 319 | double WinningRate(const ChildNode& child) const { 320 | double winning_rate = child.num_values() == 0 321 | ? child.rollout_rate() 322 | : child.num_rollouts() == 0 323 | ? child.value_rate() 324 | : child.winning_rate(lambda_); 325 | 326 | return (winning_rate + 1) / 2; 327 | } 328 | 329 | /** 330 | * Converts a vertex to GTP-style string. e.g. (1,1) -> A1 331 | */ 332 | std::string v2str(Vertex v) const { 333 | return v == kPass 334 | ? "PASS" 335 | : v > kPass ? "NULL" 336 | : std::string("ABCDEFGHJKLMNOPQRST")[x_of(v) - 1] + 337 | std::to_string(y_of(v)); 338 | } 339 | 340 | /** 341 | * Returns elapsed time from t0 in seconds. 
342 | */ 343 | double ElapsedTime(const std::chrono::system_clock::time_point& t0) { 344 | auto t1 = std::chrono::system_clock::now(); 345 | return std::chrono::duration_cast(t1 - t0) 346 | .count() / 347 | 1000.0; 348 | } 349 | 350 | /** 351 | * Performs a search with a time limit and return the best move and winning 352 | * rate. 353 | */ 354 | Vertex Search(const Board& b, double time_limit, double* winning_rate, 355 | bool is_errout, bool ponder, int lizzie_interval = -1); 356 | 357 | /** 358 | * Repeats searching with a single thread. 359 | */ 360 | void EvaluateWorker(const Board& b, double time_limit, bool ponder, 361 | int th_id); 362 | 363 | /** 364 | * Rollouts in a single thread. 365 | */ 366 | void RolloutWorker(const Board& b) { 367 | Board b_; 368 | while (!stop_think_) { 369 | b_ = b; 370 | SearchRoute route; 371 | SearchBranch(root_node(), &b_, &route); 372 | } 373 | } 374 | 375 | /** 376 | * Assigns rollout and evaluation workers to multiple threads. 377 | */ 378 | void AllocateThreads(const Board& b, double time_limit, bool ponder, 379 | int lizzie_interval = -1) { 380 | Node* nd = root_node(); 381 | if (nd->num_children() <= 1) { 382 | stop_think_ = true; 383 | return; 384 | } 385 | 386 | int num_rollout_threads = 1; 387 | int num_evaluate_threads = batch_size_ * num_gpus_ * 2; 388 | num_evaluate_threads = std::min(num_threads_, num_evaluate_threads); 389 | num_rollout_threads = 390 | std::max(num_rollout_threads, num_threads_ - num_evaluate_threads); 391 | int num_total_threads = num_evaluate_threads + num_rollout_threads; 392 | 393 | std::vector ths; 394 | for (int i = 0; i < num_total_threads; ++i) { 395 | if (i < num_evaluate_threads) 396 | ths.push_back(std::thread(&SearchTree::EvaluateWorker, this, b, 397 | time_limit, ponder, i)); 398 | else 399 | ths.push_back(std::thread(&SearchTree::RolloutWorker, this, b)); 400 | } 401 | 402 | if (ponder && lizzie_interval > 0) { 403 | do { 404 | LizzieInfo(nd, std::cout); 405 | 
std::this_thread::sleep_for(std::chrono::milliseconds(lizzie_interval)); 406 | } while (!stop_think_); 407 | } 408 | 409 | for (std::thread& th : ths) th.join(); 410 | } 411 | 412 | /** 413 | * Returns final score. 414 | */ 415 | double FinalScore(const Board& b, Vertex next_move, int num_policy_moves, 416 | int num_playouts, Board::OwnerMap* owner, 417 | TensorEngine* engine = nullptr, EvalCache* cache = nullptr); 418 | 419 | /** 420 | * Returns whether or not a pass should be made. 421 | * If it should not pass and there is a suitable move other than next_move, 422 | * returns it. 423 | */ 424 | Vertex ShouldPass(const Board& b, Vertex next_move, int num_policy_moves, 425 | int num_playouts, TensorEngine* engine = nullptr, 426 | EvalCache* cache = nullptr); 427 | 428 | /** 429 | * Returns string of best sequence. 430 | * e.g. D4 ->D16->Q16->Q4 ->... 431 | */ 432 | std::string PV(const Node* nd, Vertex head_move, int max_move = 6) const; 433 | 434 | /** 435 | * Outputs information on candidate moves. 436 | * The actual move is shown at the top. 437 | */ 438 | void PrintCandidates(const Node* nd, int next_move, std::ostream& ost, 439 | bool flip_value = false) const; 440 | 441 | /** 442 | * Outputs information on candidate moves for Lizzie. 
443 | */ 444 | void LizzieInfo(const Node* nd, std::ostream& ost) const; 445 | 446 | private: 447 | double lambda_; 448 | int num_threads_; 449 | int num_gpus_; 450 | double komi_; 451 | bool use_dirichlet_noise_; 452 | bool reflesh_root_; 453 | bool consider_pass_; 454 | std::atomic stop_think_; 455 | std::atomic num_evaluated_; 456 | std::atomic num_reach_ends_; 457 | 458 | EvalCache validate_cache_; 459 | EvalCache eval_cache_; 460 | 461 | std::unique_ptr log_file_; 462 | std::unique_ptr validate_engine_; 463 | std::unique_ptr eval_worker_; 464 | }; 465 | 466 | #endif // SEARCH_H_ 467 | -------------------------------------------------------------------------------- /src/sgf.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * AQ, a Go playing engine. 3 | * Copyright (C) 2017-2020 Yu Yamaguchi 4 | * except where otherwise indicated. 5 | * 6 | * This program is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 18 | */ 19 | 20 | #include "./sgf.h" 21 | 22 | void SgfData::Read(std::string file_path) { 23 | // Opens an sfg file. 
24 | std::ifstream ifs(file_path); 25 | std::string buf; 26 | const auto npos = std::string::npos; 27 | auto is_numeric = [](const std::string str) { 28 | if (str.empty()) return false; 29 | 30 | std::string str_ = str; 31 | auto pos = str_.find("."); 32 | if (pos != std::string::npos) str_.replace(pos, 1, ""); 33 | 34 | return std::find_if(str_.begin(), str_.end(), 35 | [](char c) { return !std::isdigit(c); }) == str_.end(); 36 | }; 37 | 38 | // Reads lines until eof. 39 | while (ifs && std::getline(ifs, buf)) { 40 | // Moves to the next line when remaining letters are less than 4. 41 | while (buf.size() > 3) { 42 | // Header is typically written as '(;GM[1]FF[4]CA[UTF-8]...'. 43 | // Read two-character tags and contents in [] in order. 44 | std::string tag, in_br; 45 | 46 | // Goes to the next line if [] is not found. 47 | auto open_br = buf.find("["); 48 | auto close_br = buf.find("]"); 49 | 50 | if (open_br == npos || close_br == npos) break; 51 | 52 | if (close_br == 0) { 53 | buf = buf.substr(close_br + 1); 54 | close_br = buf.find("]"); 55 | open_br = buf.find("["); 56 | } 57 | 58 | tag = buf.substr(0, open_br); 59 | in_br = buf.substr(open_br + 1, close_br - open_br - 1); 60 | 61 | // Removes semicolon from the tag. 62 | auto semicolon = tag.find(";"); 63 | if (semicolon != npos) tag = tag.substr(semicolon + 1); 64 | 65 | if (tag == "SZ") { // board size 66 | if (std::stoi(in_br) != kBSize) { 67 | Init(); 68 | return; 69 | } 70 | } else if (tag == "KM") { // komi 71 | // Checks whether in_br is negative. 72 | // [-6.5], ... 73 | int sign_num = 1; 74 | if (in_br.substr(0, 1).find("-") != npos) { 75 | sign_num = -1; 76 | in_br = in_br.substr(1); 77 | } 78 | 79 | // Checks whether in_br is numeric. 80 | if (is_numeric(in_br)) { 81 | // Checks whether it becomes an integer when doubled. 82 | // e.g. [3.75] in Chinese rule 83 | double tmp_floor = 2 * std::stod(in_br) - floor(2 * std::stod(in_br)); 84 | 85 | // Doubles komi for Chinese rule. 
86 | if (tmp_floor != 0) { // Chinese rule. 87 | komi_ = 2 * sign_num * std::stod(in_br); 88 | } else { // Japanese rule. 89 | komi_ = sign_num * std::stod(in_br); 90 | } 91 | } 92 | 93 | } else if (tag == "PW" || tag == "PB") { 94 | // Player name. 95 | Color c = tag == "PW" ? kWhite : kBlack; 96 | player_name_[c] = in_br; 97 | } else if (tag == "WR" || tag == "BR") { 98 | // Player rating. 99 | 100 | Color c = tag == "WR" ? kWhite : kBlack; 101 | 102 | // Excludes the trailing '?'. 103 | if (in_br.find("?") != npos) 104 | in_br = in_br.substr(0, in_br.length() - 1); 105 | 106 | // Inputs the rating if in_br is numeric. 107 | if (is_numeric(in_br)) { 108 | player_rating_[c] = std::stoi(in_br); 109 | } else if (in_br.length() == 2) { 110 | // Calculate player rating from the rank if the length of in_br 111 | // is 2. 112 | 113 | // 3000 if a professional player. 114 | if (in_br.find("p") != npos || in_br.find("P") != npos) 115 | player_rating_[c] = 3000; 116 | // 1d = 1580, 2d = 1760, ... 9d = 3020 117 | else if (in_br.find("d") != npos || in_br.find("D") != npos) 118 | player_rating_[c] = 1400 + std::stoi(in_br.substr(0, 1)) * 180; 119 | // 1k = 1450, 2k = 1350, ... 120 | else if (in_br.find("k") != npos || in_br.find("K") != npos) 121 | player_rating_[c] = 1550 - std::stoi(in_br.substr(0, 1)) * 100; 122 | } 123 | 124 | } else if (tag == "HA") { 125 | // Number of the handicap stones. 126 | // e.g. 2, 3, 4, ... 127 | 128 | // Checks whether in_br is numeric. 129 | if (is_numeric(in_br)) { 130 | handicap_ = std::stoi(in_br); 131 | } 132 | } else if (tag == "RE") { 133 | // Result. 134 | // e.g. W+R, B+Resign, W+6.5, B+Time, ... 
135 | 136 | auto b = in_br.find("B+"); 137 | auto w = in_br.find("W+"); 138 | 139 | if (b == npos && w == npos) { 140 | score_ = 0.0; 141 | } else { 142 | if (is_numeric(in_br.substr(2))) { // Won by score 143 | score_ = std::stod(in_br.substr(2)); 144 | if (2 * score_ - floor(2 * score_) != 0) score_ *= 2; 145 | if (w != npos) score_ = -score_; 146 | } else if (in_br.find("R") != npos) { // Won by resign 147 | score_ = w != npos ? -512 : 512; 148 | } else { // Won by time or illegal 149 | score_ = w != npos ? -1024 : 1024; 150 | } 151 | } 152 | 153 | } else if (tag == "W" || tag == "B") { 154 | // Move 155 | 156 | // Checks whether the game starts from white (i.e. a handicap match) 157 | if (tag == "W" && game_ply() == 0) komi_ = 0; 158 | move_history_.push_back(sgf2v(in_br)); 159 | } else if (tag == "AB" || tag == "AW") { 160 | // Handicap stones. 161 | // e.g. AB[dd][qq], ... 162 | 163 | handicap_stones_[static_cast(tag == "AB")].push_back(sgf2v(in_br)); 164 | std::string::size_type next_open_br = 165 | buf.substr(close_br + 1).find("["); 166 | std::string::size_type next_close_br = 167 | buf.substr(close_br + 1).find("]"); 168 | 169 | // Reads continuous [] 170 | while (next_open_br == 0) { 171 | open_br = close_br + 1 + next_open_br; 172 | close_br = close_br + 1 + next_close_br; 173 | 174 | std::stringstream ss_in_br; 175 | in_br = ""; 176 | ss_in_br << buf.substr(open_br + 1, close_br - open_br - 1); 177 | ss_in_br >> in_br; 178 | handicap_stones_[static_cast(tag == "AB")].push_back( 179 | sgf2v(in_br)); 180 | 181 | next_open_br = buf.substr(close_br + 1).find("["); 182 | next_close_br = buf.substr(close_br + 1).find("]"); 183 | } 184 | } 185 | 186 | // Excludes tag[in_br] from buf. 187 | buf = buf.substr(close_br + 1); 188 | } 189 | } 190 | } 191 | 192 | void SgfData::Write(std::string file_path, 193 | std::vector* comments) const { 194 | // Opens file. 195 | std::stringstream ss; 196 | std::string rule_str = 197 | Options["rule"].get_int() == kJapanese ? 
"Japanese" : "Chinese"; 198 | 199 | // Uses fixed header. 200 | ss << "(;GM[1]FF[4]CA[UTF-8]" << std::endl; 201 | ss << "RU[" << rule_str << "]SZ[" << kBSize << "]KM[" << komi_ << "]"; 202 | if (player_name_[kWhite] != "" && player_name_[kBlack] != "") { 203 | ss << "PB[" << player_name_[kBlack] << "]" 204 | << "PW[" << player_name_[kWhite] << "]"; 205 | } 206 | if (score_ != 0) { 207 | std::string winner = (score_ > 0) ? "B+" : "W+"; 208 | char buf[16]; 209 | std::snprintf(buf, sizeof(buf), "%.1f", std::abs(score_)); 210 | std::string score_str(buf); 211 | if (std::abs(score_) == 512) score_str = "R"; 212 | ss << "RE[" << winner << score_str << "]"; 213 | } else { 214 | ss << "RE[0]"; 215 | } 216 | ss << std::endl; 217 | 218 | std::string str = "abcdefghijklmnopqrs"; 219 | for (int i = 0, n = move_history_.size(); i < n; ++i) { 220 | Vertex v = move_history_[i]; 221 | int x = x_of(v) - 1; 222 | int y = kBSize - y_of(v); 223 | 224 | ss << (i % 2 == 0 ? ";B[" : ";W["); 225 | ss << (v < kPass ? 
str.substr(x, 1) + str.substr(y, 1) : "") << "]"; 226 | if (comments != nullptr && static_cast(comments->size()) > i) 227 | ss << "C[" << (*comments)[i] << "]"; 228 | if ((i + 1) % 8 == 0) ss << std::endl; 229 | } 230 | ss << ")" << std::endl; 231 | 232 | std::ofstream ofs(file_path.c_str()); 233 | ofs << ss.str(); 234 | ofs.close(); 235 | } 236 | 237 | bool SgfData::ReconstructBoard(Board* b, int move_idx) const { 238 | if (move_idx > game_ply()) return false; 239 | b->Init(); 240 | 241 | // place handicap stones 242 | const int i_max = std::max(handicap_stones_[kWhite].size(), 243 | handicap_stones_[kBlack].size()); 244 | for (int i = 0; i < i_max; ++i) { 245 | for (Color c = kColorZero; c < kNumPlayers; ++c) { 246 | if (static_cast(handicap_stones_[c].size()) > i) { 247 | if (!b->IsLegal(handicap_stones_[c][i])) return false; 248 | b->MakeMove(handicap_stones_[c][i]); 249 | } else { 250 | b->MakeMove(kPass); 251 | } 252 | } 253 | } 254 | 255 | if (i_max == 0 && handicap_ > 0) { 256 | int x_[9] = {4, 16, 4, 16, 4, 16, 10, 10, 10}; 257 | int y_[9] = {4, 16, 16, 4, 10, 10, 4, 16, 10}; 258 | int stones[8][9] = {{0, 1}, 259 | {0, 1, 2}, 260 | {0, 1, 2, 3}, 261 | {0, 1, 2, 3, 8}, 262 | {0, 1, 2, 3, 4, 5}, 263 | {0, 1, 2, 3, 4, 5, 8}, 264 | {0, 1, 2, 3, 4, 5, 6, 7}, 265 | {0, 1, 2, 3, 4, 5, 6, 7, 8}}; 266 | int hc_idx = handicap_ - 2; 267 | for (int i = 0; i < handicap_; ++i) { 268 | int stone_idx = stones[hc_idx][i]; 269 | Vertex v = xy2v(x_[stone_idx], y_[stone_idx]); 270 | b->MakeMove(v); 271 | b->MakeMove(kPass); 272 | } 273 | } 274 | 275 | if ((handicap_stones_[kWhite].size() == 0 && 276 | handicap_stones_[kBlack].size() > 0) || 277 | handicap_ > 0) 278 | b->MakeMove(kPass); // kBlack kPass 279 | 280 | // Resets game_ply. 281 | b->set_game_ply(0); 282 | b->set_num_passes(kWhite, 0); 283 | b->set_num_passes(kBlack, 0); 284 | 285 | // Initial board. 286 | if (move_idx == 0) return true; 287 | 288 | // Plays until move_idx. 
289 | for (int i = 0; i < move_idx; ++i) { 290 | if (!b->IsLegal(move_history_[i])) return false; 291 | b->MakeMove(move_history_[i]); 292 | } 293 | 294 | return true; 295 | } 296 | 297 | #ifdef _WIN32 298 | int SgfData::GetSgfFiles(std::string dir_path, 299 | std::vector* files) { 300 | int num_sgf_files = 0; 301 | HANDLE h_find; 302 | WIN32_FIND_DATA fd; 303 | 304 | if (dir_path.size() < 2) { 305 | dir_path = ".\\"; 306 | } else if (dir_path.substr(dir_path.size() - 2) != "\\") { 307 | dir_path += "\\"; 308 | } 309 | std::string file_path = dir_path + "*.sgf"; 310 | // FindFirstFile needs a char* argument in MinGW-x64. 311 | h_find = FindFirstFile(file_path.c_str(), &fd); 312 | 313 | // Fails to find. 314 | if (h_find == INVALID_HANDLE_VALUE) { 315 | return num_sgf_files; // 0 316 | } 317 | 318 | do { 319 | // Excludes directory. 320 | if (!(fd.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) && 321 | !(fd.dwFileAttributes & FILE_ATTRIBUTE_HIDDEN)) { 322 | // fd.cFileName is char* in MinGW-x64. 323 | std::string file_name = fd.cFileName; 324 | files->push_back(dir_path + file_name); 325 | ++num_sgf_files; 326 | } 327 | } while (FindNextFile(h_find, &fd)); // Next file. 
328 | 329 | FindClose(h_find); 330 | return num_sgf_files; 331 | } 332 | #else // Linux 333 | #include 334 | int SgfData::GetSgfFiles(std::string dir_path, 335 | std::vector* files) { 336 | files->clear(); 337 | int num_sgf_files = 0; 338 | DIR* dr = opendir(dir_path.c_str()); 339 | if (dr == NULL) return num_sgf_files; 340 | dirent* entry; 341 | do { 342 | entry = readdir(dr); 343 | if (entry != NULL) { 344 | std::string file_name = entry->d_name; 345 | if (file_name.find(".sgf") == std::string::npos) continue; 346 | files->push_back(file_name); 347 | num_sgf_files++; 348 | } 349 | } while (entry != NULL); 350 | closedir(dr); 351 | return num_sgf_files; 352 | } 353 | #endif // _WIN32 354 | -------------------------------------------------------------------------------- /src/sgf.h: -------------------------------------------------------------------------------- 1 | /* 2 | * AQ, a Go playing engine. 3 | * Copyright (C) 2017-2020 Yu Yamaguchi 4 | * except where otherwise indicated. 5 | * 6 | * This program is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 18 | */ 19 | 20 | #ifndef SGF_H_ 21 | #define SGF_H_ 22 | 23 | #include 24 | #include 25 | #include 26 | 27 | #include "./board.h" 28 | #include "./option.h" 29 | 30 | /** 31 | * @class SgfData 32 | * SgfData class has player information, game results, and moves. 33 | * It inputs and outputs them to the sgf format game record. 
34 | */ 35 | class SgfData { 36 | public: 37 | // Constructor 38 | SgfData() { Init(); } 39 | 40 | std::string player_name(Color c) const { return player_name_[c]; } 41 | 42 | double score() const { return score_; } 43 | 44 | void set_score(double val) { score_ = val; } 45 | 46 | Vertex move_at(int t) const { return move_history_[t]; } 47 | 48 | int game_ply() const { return move_history_.size(); } 49 | 50 | Color winner() const { 51 | return score_ == 0 ? kEmpty : score_ > 0 ? kBlack : kWhite; 52 | } 53 | 54 | bool resign_or_score() const { 55 | return game_ply() >= 12 && std::abs(score_) < 1024; 56 | } 57 | 58 | void Init() { 59 | komi_ = kBSize < 19 ? 7.0 : Options["komi"].get_double(); 60 | handicap_ = 0; 61 | for (Color c = kColorZero; c < kNumPlayers; ++c) { 62 | player_name_[c] = ""; 63 | player_rating_[c] = 2800; 64 | handicap_stones_[c].clear(); 65 | } 66 | move_history_.clear(); 67 | score_ = 0.0; 68 | } 69 | 70 | void Add(Vertex v) { 71 | ASSERT_LV3(v == kPass || (kRvtZero <= v2rv(v) && v2rv(v) < kNumRvts)); 72 | move_history_.push_back(v); 73 | } 74 | 75 | /** 76 | * Returns Vertex converted from sgf-style string. 77 | * 78 | * aa -> Vertex(22), i.e. (x,y) = (1,1) 79 | */ 80 | Vertex sgf2v(std::string aa) const { 81 | // Returns kPass if input size is not 2. 82 | if (aa.size() != 2) return kPass; 83 | 84 | // Converts aa to (x,y) of RawVertex. 85 | char a0 = aa[0]; 86 | char a1 = aa[1]; 87 | int rx = isupper(a0) ? a0 - 'A' : a0 - 'a'; 88 | int ry = isupper(a1) ? a1 - 'A' : a1 - 'a'; 89 | 90 | if (rx < 0 || ry < 0 || rx >= kBSize || ry >= kBSize) return kPass; 91 | 92 | return rv2v(xy2rv(rx, ry)); 93 | } 94 | 95 | void Read(std::string file_path); 96 | 97 | /** 98 | * Outputs the match information to an SGF file with comments. 99 | */ 100 | void Write(std::string file_path, 101 | std::vector* comments = nullptr) const; 102 | 103 | /** 104 | * Constructs board from sgf information. 
105 | */ 106 | bool ReconstructBoard(Board* b, int move_idx) const; 107 | 108 | /** 109 | * Imports all sgf files in the folder. 110 | */ 111 | static int GetSgfFiles(std::string dir_path, std::vector* files); 112 | 113 | private: 114 | double komi_; 115 | std::string player_name_[kNumPlayers]; 116 | int player_rating_[kNumPlayers]; 117 | int handicap_; 118 | std::vector handicap_stones_[kNumPlayers]; 119 | std::vector move_history_; 120 | double score_; 121 | }; // SgfData 122 | 123 | #endif // SGF_H_ 124 | -------------------------------------------------------------------------------- /src/test.h: -------------------------------------------------------------------------------- 1 | /* 2 | * AQ, a Go playing engine. 3 | * Copyright (C) 2017-2020 Yu Yamaguchi 4 | * except where otherwise indicated. 5 | * 6 | * This program is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 18 | */ 19 | 20 | #ifndef TEST_H_ 21 | #define TEST_H_ 22 | 23 | #include "./board.h" 24 | #include "./network.h" 25 | #include "./option.h" 26 | #include "./search.h" 27 | #include "./sgf.h" 28 | 29 | /** 30 | * Tests structure and transitions of Board class. 31 | */ 32 | void TestBoard(); 33 | 34 | /** 35 | * Checks if the board with symmetric operation is registered in EvalCache. 36 | */ 37 | void TestSymmetry(); 38 | 39 | /** 40 | * Plays a self match with the hand with the maximum probability of Policy head. 
41 | */ 42 | void PolicySelf(); 43 | 44 | /** 45 | * Plays a self match. 46 | */ 47 | void SelfMatch(); 48 | 49 | /** 50 | * Benchmark the inference speed of a neural network. 51 | */ 52 | void NetworkBench(); 53 | 54 | /** 55 | * Measures the execution speed of the rollout. 56 | */ 57 | void BenchMark(); 58 | 59 | /** 60 | * Measures the speed at which a tree node is freed from memory. 61 | */ 62 | void TestFreeMemory(); 63 | 64 | /** 65 | * Reads an SGF file and test the score of the last board. 66 | */ 67 | void ReadSgfFinalScore(int argc, char** argv); 68 | 69 | /** 70 | * Random rollouts are performed to display the seki. 71 | */ 72 | void TestSeki(); 73 | 74 | /** 75 | * Tests to see if you can pass properly under Japanese rules. 76 | */ 77 | void TestPassMove(); 78 | 79 | #endif // TEST_H_ 80 | -------------------------------------------------------------------------------- /src/timer.h: -------------------------------------------------------------------------------- 1 | /* 2 | * AQ, a Go playing engine. 3 | * Copyright (C) 2017-2020 Yu Yamaguchi 4 | * except where otherwise indicated. 5 | * 6 | * This program is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 18 | */ 19 | 20 | #ifndef TIMER_H_ 21 | #define TIMER_H_ 22 | 23 | #include 24 | #include "./option.h" 25 | 26 | /** 27 | * @class Timer 28 | * A class for control of holding time. 
29 | * Adjusts the maximum consideration time according to the time remaining and 30 | * the degree of progress. 31 | */ 32 | class Timer { 33 | public: 34 | Timer() { 35 | main_time_ = Options["main_time"].get_double(); 36 | byoyomi_ = Options["byoyomi"].get_double(); 37 | byoyomi_margin_ = Options["byoyomi_margin"].get_double(); 38 | num_extensions_ = Options["num_extensions"].get_int(); 39 | left_time_ = main_time_; 40 | } 41 | 42 | double main_time() const { return main_time_; } 43 | double byoyomi() const { return byoyomi_; } 44 | int num_extensions() const { return num_extensions_; } 45 | double left_time() const { return left_time_; } 46 | 47 | void set_main_time(double val) { main_time_ = val; } 48 | void set_num_extensions(int val) { num_extensions_ = val; } 49 | void set_byoyomi(double val) { byoyomi_ = val; } 50 | void set_left_time(double val) { left_time_ = val; } 51 | 52 | void Init() { 53 | main_time_ = Options["main_time"].get_double(); 54 | byoyomi_ = Options["byoyomi"].get_double(); 55 | byoyomi_margin_ = Options["byoyomi_margin"].get_double(); 56 | num_extensions_ = Options["num_extensions"].get_int(); 57 | left_time_ = main_time_; 58 | } 59 | 60 | double ThinkingTime(int ply, bool* extendable, double lost_time = 0.0) { 61 | double t = 1.0; 62 | *extendable = false; 63 | 64 | if (main_time_ == 0.0) { // Byoyomi only. 65 | // Takes margin. 66 | if (byoyomi_ >= 10) 67 | t = byoyomi_ - byoyomi_margin_; 68 | else 69 | t = std::max(byoyomi_, 0.1); 70 | *extendable = (num_extensions_ > 0); 71 | } else { // Main time + byoyomi or sudden death. 72 | if (left_time_ < byoyomi_ * 2.0) { 73 | t = std::max(byoyomi_ - byoyomi_margin_, 1.0); // Takes margin. 
74 | *extendable = (num_extensions_ > 0); 75 | } else { 76 | // Calculates from remaining time if sudden death 77 | // otherwise set that of 0.5-1 times of byoyomi 78 | // 79 | // 1-16 moves : 0.5 * byoyomi 80 | // 17-32 moves: 0.5-2.0 * byoyomi 81 | // > 32 moves : 2.0 * byoyomi 82 | t = std::max( 83 | left_time_ / (55.0 + std::max(50.0 - ply, 0.0)), 84 | byoyomi_ * 85 | (0.5 + 1.5 * std::min(1.0, std::max(0.0, (ply - 16) / 86 | (32.0 - 16.0))))); 87 | // Does not extend thinking time if the remaining time is 30% or less. 88 | *extendable = (left_time_ > main_time_ * 0.3) || (byoyomi_ >= 10); 89 | } 90 | } 91 | 92 | t = std::max(t - lost_time, 0.1); 93 | return t; 94 | } 95 | 96 | private: 97 | double main_time_; 98 | double byoyomi_; 99 | double byoyomi_margin_; 100 | int num_extensions_; 101 | double left_time_; 102 | }; 103 | 104 | /** 105 | * @class SearchParameter 106 | * Structure that stores hyperparameters for search. 107 | */ 108 | class SearchParameter { 109 | public: 110 | SearchParameter() { 111 | batch_size_ = Options["batch_size"].get_int(); 112 | lambda_init_ = Options["lambda_init"].get_double(); 113 | lambda_delta_ = Options["lambda_delta"].get_double(); 114 | lambda_move_start_ = Options["lambda_move_start"].get_int(); 115 | lambda_move_end_ = Options["lambda_move_end"].get_int(); 116 | cp_init_ = Options["cp_init"].get_double(); 117 | cp_base_ = Options["cp_base"].get_double(); 118 | virtual_loss_ = Options["virtual_loss"].get_int(); 119 | search_limit_ = Options["search_limit"].get_int(); 120 | ladder_reduction_ = Options["ladder_reduction"].get_double(); 121 | } 122 | 123 | SearchParameter& operator=(const SearchParameter& rhs) { 124 | batch_size_ = rhs.batch_size_; 125 | lambda_init_ = rhs.lambda_init_; 126 | lambda_delta_ = rhs.lambda_delta_; 127 | lambda_move_start_ = rhs.lambda_move_start_; 128 | lambda_move_end_ = rhs.lambda_move_end_; 129 | cp_init_ = rhs.cp_init_; 130 | cp_base_ = rhs.cp_base_; 131 | virtual_loss_ = rhs.virtual_loss_; 
132 | search_limit_ = rhs.search_limit_; 133 | ladder_reduction_ = rhs.ladder_reduction_; 134 | 135 | return *this; 136 | } 137 | 138 | #if defined(LEARN) 139 | friend class MySQLConnector; 140 | #endif 141 | 142 | protected: 143 | int batch_size_; 144 | double lambda_init_; 145 | double lambda_delta_; 146 | int lambda_move_start_; 147 | int lambda_move_end_; 148 | double cp_init_; 149 | double cp_base_; 150 | int virtual_loss_; 151 | int search_limit_; 152 | double ladder_reduction_; 153 | }; 154 | 155 | #endif // TIMER_H_ 156 | -------------------------------------------------------------------------------- /src/types.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * AQ, a Go playing engine. 3 | * Copyright (C) 2017-2020 Yu Yamaguchi 4 | * except where otherwise indicated. 5 | * 6 | * This program is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 
18 | */ 19 | 20 | #include "./types.h" 21 | 22 | // --- Random generator 23 | 24 | std::random_device RandomGenerator::rd_; 25 | std::mt19937 RandomGenerator::mt_(RandomGenerator::rd_()); 26 | std::uniform_real_distribution RandomGenerator::uniform_double_(0.0, 27 | 1.0); 28 | std::uniform_int_distribution RandomGenerator::uniform_int_(0, 7); 29 | std::gamma_distribution RandomGenerator::gamma_double_(0.03, 1.0); 30 | 31 | /** 32 | * @namespace 33 | * Obscure namespace for defining auxiliary functions for initialization of 34 | * CoordinateTable. 35 | */ 36 | namespace { 37 | template 38 | void AddNakadeHash(const Vertex (&space)[N][M], const Vertex (&vital)[N], 39 | const uint64_t (&zobrist)[8][4][kNumVts], 40 | std::unordered_map* nakade); 41 | 42 | void InitNakade(const uint64_t (&zobrist)[8][4][kNumVts], 43 | std::unordered_map* nakade, 44 | std::unordered_set* bent4); 45 | } // namespace 46 | 47 | /** 48 | * Initializes tables for coordinate transformation. 49 | */ 50 | CoordinateTable::CoordinateTable() { 51 | // Vertex -> x,y,rv 52 | for (int i = 0; i < kNumVts; ++i) { 53 | v2x_table[i] = i % int{kEBSize}; 54 | v2y_table[i] = i / int{kEBSize}; 55 | in_wall_table[i] = 56 | (v2x_table[i] == 0 || v2y_table[i] == 0 || 57 | v2x_table[i] == int{kEBSize} - 1 || v2y_table[i] == int{kEBSize} - 1); 58 | 59 | if (in_wall_table[i]) 60 | v2rv_table[i] = kRvtNull; 61 | else 62 | v2rv_table[i] = 63 | RawVertex((v2x_table[i] - 1) + (v2y_table[i] - 1) * int{kBSize}); 64 | } 65 | 66 | v2x_table[kNumVts] = kEBSize - 1; 67 | v2y_table[kNumVts] = kEBSize - 1; 68 | in_wall_table[kNumVts] = true; 69 | v2rv_table[kNumVts] = kRvtNull; 70 | 71 | // x,y -> Vertex 72 | for (int i = 0; i < kEBSize; ++i) { 73 | for (int j = 0; j < kEBSize; ++j) { 74 | xy2v_table[i][j] = Vertex(i + j * int{kEBSize}); 75 | } 76 | } 77 | 78 | // RawVertex -> x,y 79 | for (int i = 0; i < kNumRvts; ++i) { 80 | rv2x_table[i] = i % int{kBSize}; 81 | rv2y_table[i] = i / int{kBSize}; 82 | rv2v_table[i] = 83 | 
Vertex((rv2x_table[i] + 1) + (rv2y_table[i] + 1) * int{kEBSize}); 84 | } 85 | 86 | // x,y -> RawVertex 87 | for (int i = 0; i < kBSize; ++i) { 88 | for (int j = 0; j < kBSize; ++j) { 89 | xy2rv_table[i][j] = RawVertex(i + j * int{kBSize}); 90 | } 91 | } 92 | 93 | // v,rv -> sym 94 | for (int i = 0; i < 8; ++i) { 95 | for (Vertex v = kVtZero; v < kNumVts; ++v) { 96 | int x = kEBSize - 1 - v2x_table[v]; 97 | int y = v2y_table[v]; 98 | 99 | if (i == 0) 100 | v2sym_table[i][v] = v; 101 | else if (i == 4) 102 | v2sym_table[i][v] = v2sym_table[0][xy2v_table[x][y]]; // Inverts. 103 | else 104 | v2sym_table[i][v] = v2sym_table[i - 1][xy2v_table[y][x]]; // Rotates. 105 | } 106 | for (RawVertex rv = kRvtZero; rv < kNumRvts; ++rv) { 107 | int x = kBSize - 1 - rv2x_table[rv]; 108 | int y = rv2y_table[rv]; 109 | 110 | if (i == 0) 111 | rv2sym_table[i][rv] = int{rv}; 112 | else if (i == 4) 113 | rv2sym_table[i][rv] = rv2sym_table[0][xy2rv_table[x][y]]; // Inverts. 114 | else 115 | rv2sym_table[i][rv] = 116 | rv2sym_table[i - 1][xy2rv_table[y][x]]; // Rotates. 
117 | } 118 | v2sym_table[i][kNumVts] = kPass; 119 | } 120 | 121 | // Distance 122 | for (int i = 0; i < kNumVtsPlus1; ++i) { 123 | int dx_edge = 124 | (std::min)(v2x_table[i], std::abs(kEBSize - 1 - v2x_table[i])); 125 | int dy_edge = 126 | (std::min)(v2y_table[i], std::abs(kEBSize - 1 - v2y_table[i])); 127 | dist_edge_table[i] = (std::min)(dx_edge, dy_edge); 128 | 129 | for (int j = 0; j < kNumVtsPlus1; ++j) { 130 | if (i == kNumVts || j == kNumVts) { 131 | dist_table[i][j] = 3 * (kEBSize - 1); 132 | } else { 133 | int dx = std::abs(v2x_table[j] - v2x_table[i]); 134 | int dy = std::abs(v2y_table[j] - v2y_table[i]); 135 | dist_table[i][j] = dx + dy + (std::max)(dx, dy); 136 | } 137 | } 138 | } 139 | 140 | // Bitboard 141 | for (int i = 0; i < kNumVtsPlus1; ++i) { 142 | v2bb_idx_table[i] = v2rv_table[i] / 64; 143 | if (v2rv_table[i] == kRvtNull) 144 | v2bb_bit_table[i] = 0; 145 | else 146 | v2bb_bit_table[i] = 0x1ULL << (v2rv_table[i] % 64); 147 | } 148 | 149 | for (int i = 0; i < kNumBBs; ++i) { 150 | for (int j = 0; j < 64; ++j) { 151 | RawVertex rv = RawVertex(i * 64 + j); 152 | if (kRvtZero <= rv && rv < kNumRvts) 153 | bb2v_table[i][j] = rv2v_table[rv]; 154 | else 155 | bb2v_table[i][j] = kVtNull; 156 | } 157 | } 158 | 159 | // Zobrist hash 160 | std::mt19937_64 mt_64_(123); 161 | 162 | for (int i = 0; i < 8; ++i) { 163 | for (int j = 0; j < 4; ++j) { 164 | for (Vertex v = kVtZero; v < kNumVts; ++v) { 165 | int x = kEBSize - 1 - v2x_table[v]; 166 | int y = v2y_table[v]; 167 | 168 | if (i == 0) 169 | zobrist_table[i][j][v] = mt_64_(); 170 | else if (i == 4) 171 | zobrist_table[i][j][v] = 172 | zobrist_table[0][j][xy2v_table[x][y]]; // Inverts. 173 | else 174 | zobrist_table[i][j][v] = 175 | zobrist_table[i - 1][j][xy2v_table[y][x]]; // Rotates. 
176 | } 177 | } 178 | } 179 | 180 | // Nakade 181 | InitNakade(zobrist_table, &nakade_map, &bent4_set); 182 | } 183 | 184 | namespace { 185 | 186 | template 187 | void AddNakadeHash(const Vertex (&space)[N][M], const Vertex (&vital)[N], 188 | const uint64_t (&zobrist)[8][4][kNumVts], 189 | std::unordered_map* nakade) { 190 | Vertex sym_pos[8][kNumVts]; 191 | 192 | for (int i = 0; i < 8; ++i) { 193 | for (Vertex v = kVtZero; v < kNumVts; ++v) { 194 | int x = kEBSize - 1 - int{v} % int{kEBSize}; 195 | int y = int{v} / int{kEBSize}; 196 | 197 | if (i == 0) 198 | sym_pos[i][v] = v; 199 | else if (i == 4) 200 | sym_pos[i][v] = sym_pos[0][x + y * int{kEBSize}]; // Inverts. 201 | else 202 | sym_pos[i][v] = sym_pos[i - 1][y + x * int{kEBSize}]; // Rotates. 203 | } 204 | } 205 | 206 | auto is_in_wall = [](Vertex v) { 207 | int x = int{v} % int{kEBSize}; 208 | int y = int{v} / int{kEBSize}; 209 | return (x == 0 || x == int{kEBSize} - 1 || y == 0 || y == int{kEBSize} - 1); 210 | }; 211 | 212 | for (Vertex v = kVtZero; v < kNumVts; ++v) { 213 | if (is_in_wall(v)) continue; 214 | 215 | for (int i = 0; i < 8; ++i) { 216 | for (int j = 0; j < N; ++j) { 217 | uint64_t space_hash = 0; 218 | bool inside = true; 219 | 220 | for (int k = 0; k < M; ++k) { 221 | Vertex v_k = v + space[j][k]; 222 | if (0 <= v_k && v_k < kNumVts && !is_in_wall(v_k)) 223 | space_hash ^= zobrist[i][2][v_k]; 224 | else 225 | inside = false; 226 | } 227 | 228 | if (inside) nakade->insert({space_hash, sym_pos[i][v + vital[j]]}); 229 | } 230 | } 231 | } // for v 232 | } 233 | 234 | /** 235 | * Initializes nakade tables. 
236 | */ 237 | void InitNakade(const uint64_t (&zobrist)[8][4][kNumVts], 238 | std::unordered_map* nakade, 239 | std::unordered_set* bent4) { 240 | const Vertex space_3[4][3] = { 241 | {kVtL, kVtZero, kVtR}, // kVtZero <- vital position 242 | {kVtZero, kVtR, kVtU}, // kVtZero 243 | {kVtZero, kVtR, kVtRR}, // kVtR 244 | {kVtZero, kVtR, kVtRU} // kVtR 245 | }; 246 | 247 | const Vertex space_4[4][4] = { 248 | {kVtL, kVtZero, kVtR, kVtU}, // kVtZero 249 | {kVtZero, kVtR, kVtRD, kVtRU}, // kVtR 250 | {kVtZero, kVtR, kVtRR, kVtRU}, // kVtR 251 | {kVtZero, kVtR, kVtU, kVtRU} // kVtZero 252 | }; 253 | 254 | const Vertex space_5[7][5] = { 255 | {kVtL, kVtZero, kVtR, kVtLU, kVtU}, // kVtZero 256 | {kVtL, kVtZero, kVtR, kVtD, kVtU}, // kVtZero 257 | {kVtZero, kVtR, kVtRR, kVtU, kVtRU}, // kVtR 258 | {kVtZero, kVtR, kVtRR, kVtRU, kVtRU + kVtR}, // kVtR 259 | {kVtZero, kVtR, kVtRD, kVtU, kVtRU}, // kVtR 260 | {kVtZero, kVtR, kVtRR, kVtRD, kVtRU}, // kVtR 261 | {kVtZero, kVtR, kVtU, kVtRU, kVtRU + kVtR} // RU 262 | }; 263 | 264 | const Vertex space_6[4][6] = { 265 | {kVtLD, kVtD, kVtL, kVtZero, kVtR, kVtU}, // kVtZero 266 | {kVtZero, kVtR, kVtRR, kVtRD, kVtU, kVtRU}, // kVtR 267 | {kVtRD, kVtZero, kVtR, kVtRR, kVtRU, kVtRU + kVtR}, // kVtR 268 | {kVtZero, kVtR, kVtU, kVtRU, kVtRU + kVtR, kVtUU + kVtR} // RU 269 | }; 270 | 271 | const Vertex vital_3[4] = {kVtZero, kVtZero, kVtR, kVtR}; 272 | const Vertex vital_4[4] = {kVtZero, kVtR, kVtR, kVtZero}; 273 | const Vertex vital_5[7] = {kVtZero, kVtZero, kVtR, kVtR, kVtR, kVtR, kVtRU}; 274 | const Vertex vital_6[4] = {kVtZero, kVtR, kVtR, kVtRU}; 275 | 276 | AddNakadeHash<4, 3>(space_3, vital_3, zobrist, nakade); 277 | AddNakadeHash<4, 4>(space_4, vital_4, zobrist, nakade); 278 | AddNakadeHash<7, 5>(space_5, vital_5, zobrist, nakade); 279 | AddNakadeHash<4, 6>(space_6, vital_6, zobrist, nakade); 280 | 281 | Vertex v_corner = kVtZero + kVtRU; 282 | Vertex space_bend[2][3] = { 283 | {v_corner, v_corner + kVtR, v_corner + kVtU}, 284 | 
{v_corner, v_corner + kVtU, v_corner + kVtUU}, 285 | }; 286 | 287 | for (int i = 0; i < 8; ++i) { 288 | for (int j = 0; j < 2; ++j) { 289 | uint64_t space_hash = 0; 290 | for (int k = 0; k < 3; ++k) { 291 | space_hash ^= zobrist[i][2][space_bend[j][k]]; 292 | } 293 | bent4->insert(space_hash); 294 | } 295 | } 296 | } 297 | 298 | } // namespace 299 | --------------------------------------------------------------------------------