├── .autopep8 ├── .clang-format ├── .githooks ├── autopep8.py ├── cpplint.py └── pre-commit ├── .gitignore ├── CMakeLists.txt ├── README.md ├── cfg ├── atari.cfg ├── gridworld.cfg └── gridworld_eval.cfg ├── docs ├── Console.md ├── Development.md ├── Evaluation.md ├── Training.md └── imgs │ ├── minizero-architecture.svg │ ├── minizero_atari.svg │ ├── minizero_go_9x9.svg │ ├── minizero_othello_8x8.svg │ └── optionzero_atari.svg ├── gridworld_maps └── gridworld_000.txt ├── minizero ├── CMakeLists.txt ├── actor │ ├── CMakeLists.txt │ ├── actor_group.cpp │ ├── actor_group.h │ ├── base_actor.cpp │ ├── base_actor.h │ ├── create_actor.h │ ├── gumbel_zero.cpp │ ├── gumbel_zero.h │ ├── mcts.cpp │ ├── mcts.h │ ├── search.h │ ├── tree.h │ ├── zero_actor.cpp │ └── zero_actor.h ├── config │ ├── CMakeLists.txt │ ├── configuration.cpp │ ├── configuration.h │ ├── configure_loader.cpp │ └── configure_loader.h ├── console │ ├── CMakeLists.txt │ ├── console.cpp │ ├── console.h │ ├── mode_handler.cpp │ └── mode_handler.h ├── environment │ ├── CMakeLists.txt │ ├── amazons │ │ ├── amazons.cpp │ │ └── amazons.h │ ├── atari │ │ ├── atari.cpp │ │ ├── atari.h │ │ ├── obs_recover.cpp │ │ ├── obs_recover.h │ │ ├── obs_remover.cpp │ │ └── obs_remover.h │ ├── base │ │ ├── base_env.cpp │ │ └── base_env.h │ ├── breakthrough │ │ ├── breakthrough.cpp │ │ └── breakthrough.h │ ├── clobber │ │ ├── clobber.cpp │ │ └── clobber.h │ ├── conhex │ │ ├── conhex.cpp │ │ ├── conhex.h │ │ ├── conhex_graph.cpp │ │ ├── conhex_graph.h │ │ ├── conhex_graph_cell.cpp │ │ ├── conhex_graph_cell.h │ │ ├── conhex_graph_flag.h │ │ ├── disjoint_set_union.cpp │ │ └── disjoint_set_union.h │ ├── connect6 │ │ ├── connect6.cpp │ │ └── connect6.h │ ├── dotsandboxes │ │ ├── dotsandboxes.cpp │ │ └── dotsandboxes.h │ ├── environment.h │ ├── go │ │ ├── go.cpp │ │ ├── go.h │ │ ├── go_area.h │ │ ├── go_block.h │ │ ├── go_data_structure_check.cpp │ │ ├── go_grid.h │ │ └── go_unit.h │ ├── gomoku │ │ ├── gomoku.cpp │ │ └── gomoku.h │ ├── 
gridworld │ │ ├── gridworld.cpp │ │ └── gridworld.h │ ├── havannah │ │ ├── havannah.cpp │ │ └── havannah.h │ ├── hex │ │ ├── hex.cpp │ │ └── hex.h │ ├── killallgo │ │ ├── killallgo.cpp │ │ ├── killallgo.h │ │ ├── killallgo_7x7_bitboard.h │ │ ├── killallgo_seki_7x7.cpp │ │ └── killallgo_seki_7x7.h │ ├── linesofaction │ │ ├── linesofaction.cpp │ │ └── linesofaction.h │ ├── nogo │ │ └── nogo.h │ ├── othello │ │ ├── othello.cpp │ │ └── othello.h │ ├── rubiks │ │ ├── rubiks.cpp │ │ └── rubiks.h │ ├── santorini │ │ ├── bitboard.h │ │ ├── board.cpp │ │ ├── board.h │ │ ├── santorini.cpp │ │ └── santorini.h │ ├── stochastic │ │ ├── puzzle2048 │ │ │ ├── bitboard.h │ │ │ ├── puzzle2048.cpp │ │ │ └── puzzle2048.h │ │ └── stochastic_env.h │ ├── surakarta │ │ ├── surakarta.cpp │ │ └── surakarta.h │ └── tictactoe │ │ ├── tictactoe.cpp │ │ └── tictactoe.h ├── learner │ ├── CMakeLists.txt │ ├── data_loader.cpp │ ├── data_loader.h │ ├── pybind.cpp │ └── train.py ├── minizero.cpp ├── network │ ├── CMakeLists.txt │ ├── alphazero_network.h │ ├── create_network.h │ ├── muzero_network.h │ ├── network.cpp │ ├── network.h │ └── py │ │ ├── alphazero_network.py │ │ ├── create_network.py │ │ ├── muzero_atari_network.py │ │ ├── muzero_gridworld_network.py │ │ ├── muzero_network.py │ │ └── network_unit.py ├── utils │ ├── CMakeLists.txt │ ├── base_server.h │ ├── color_message.h │ ├── ostream_redirector.h │ ├── paralleler.h │ ├── random.cpp │ ├── random.h │ ├── rotation.h │ ├── sgf_loader.cpp │ ├── sgf_loader.h │ ├── thread_pool.h │ ├── time_system.h │ ├── tqdm.h │ ├── utils.h │ └── vector_map.h └── zero │ ├── CMakeLists.txt │ ├── zero_server.cpp │ └── zero_server.h ├── scripts ├── build.sh ├── start-container.sh ├── zero-server.sh └── zero-worker.sh └── tools ├── analysis.py ├── count-moves.sh ├── count-options-depth-percentile.sh ├── count-options-in-tree.sh ├── dependency_graph_generator ├── README.md └── dependency_graph_generator.py ├── eval.py ├── extract-moves-stats.sh ├── 
extract-repeated-options.sh ├── fetch-complete-latest.sh ├── fight-eval.sh ├── handle_obs.sh ├── option_analysis.py ├── plot_board.py ├── quick-run.sh ├── self-eval.sh ├── sgf_analysis.py ├── to-sgf.py └── to-video.py /.autopep8: -------------------------------------------------------------------------------- 1 | [pycodestyle] 2 | max-line-length = 200 3 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | BasedOnStyle: LLVM 2 | BreakBeforeBraces: WebKit 3 | IndentWidth: 4 4 | PointerAlignment: Left 5 | AccessModifierOffset: -4 6 | ColumnLimit: 0 7 | AllowShortBlocksOnASingleLine: true 8 | AllowShortCaseLabelsOnASingleLine: true 9 | AllowShortFunctionsOnASingleLine: true 10 | AllowShortIfStatementsOnASingleLine: true 11 | AllowShortLoopsOnASingleLine: true 12 | NamespaceIndentation: Inner 13 | IndentCaseLabels: true 14 | -------------------------------------------------------------------------------- /.githooks/pre-commit: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # An example hook script to verify what is about to be committed. 4 | # Called by "git commit" with no arguments. The hook should 5 | # exit with non-zero status after issuing an appropriate message if 6 | # it wants to stop the commit. 7 | # 8 | # To enable this hook, rename this file to "pre-commit". 9 | 10 | if git rev-parse --verify HEAD >/dev/null 2>&1 11 | then 12 | against=HEAD 13 | else 14 | # Initial commit: diff against an empty tree object 15 | against=$(git hash-object -t tree /dev/null) 16 | fi 17 | 18 | # Redirect output to stderr. 
19 | exec 1>&2 20 | # only staged files that still exist are checked (--diff-filter=ACMR: added/copied/modified/renamed); a staged deletion has no file on disk and would otherwise be counted as a format warning 21 | # check cpp, reference: https://github.com/cpplint/cpplint & https://qiita.com/janus_wel/items/cfc6914d6b7b8bf185b6 22 | format_warning=0 23 | filters='-build/c++11,-build/include_subdir,-build/include_order,-build/namespaces,-legal/copyright,-readability/todo,-runtime/explicit,-runtime/references,-runtime/string,-whitespace/braces,-whitespace/comments,-whitespace/indent,-whitespace/line_length' 24 | for file in $(git diff --staged --name-only --diff-filter=ACMR $against -- | grep -E '\.[ch](pp)?$'); do 25 | python3 .githooks/cpplint.py --filter="$filters" "$file" > /dev/null 26 | format_warning=$(expr ${format_warning} + $?) 27 | done 28 | 29 | # check clang-format, reference: https://gist.github.com/alexeagle/c8ed91b14a407342d9a8e112b5ac7dab 30 | for file in $(git diff --staged --name-only --diff-filter=ACMR $against -- | grep -E '\.[ch](pp)?$'); do 31 | clangformat_output=$(git-clang-format --diff -q "$file") 32 | [[ "$clangformat_output" != *"no modified files to format"* ]] && [[ "$clangformat_output" != *"clang-format did not modify any files"* ]] && [[ ! -z "$clangformat_output" ]] && git-clang-format --diff -q "$file" && format_warning=$(expr ${format_warning} + 1) 33 | done 34 | 35 | # check python format (files ending in .py only), reference: https://github.com/hhatto/autopep8 36 | [ $(pip list | grep pycodestyle | wc -l) -eq 1 ] || pip install pycodestyle >/dev/null 2>/dev/null 37 | for file in $(git diff --staged --name-only --diff-filter=ACMR $against -- | grep -E '\.py$'); do 38 | autopep8_output=$(python3 .githooks/autopep8.py -d -a -a --global-config .autopep8 "$file" 2>&1) 39 | [[ ! -z "$autopep8_output" ]] && python3 .githooks/autopep8.py -d -a -a --global-config .autopep8 "$file" && format_warning=$(expr ${format_warning} + 1) 40 | done 41 | 42 | # verify if the modified code pass the check 43 | [[ ${format_warning} -ne 0 ]] && echo && echo "Found ${format_warning} format warning(s): check format again!"
&& exit 1 44 | 45 | # return successfully 46 | exit 0 47 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # execute and library file 2 | build/ 3 | __pycache__ 4 | 5 | # configuration file 6 | # *.cfg 7 | 8 | # training folder 9 | *_az_*/ 10 | *_gaz_*/ 11 | *_mz_*/ 12 | *_gmz_*/ 13 | 14 | # others 15 | .vscode 16 | .container_root 17 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.16) 2 | 3 | project(minizero) 4 | 5 | # specify the C++ standard 6 | set(CMAKE_CXX_STANDARD 17) 7 | set(CMAKE_CXX_STANDARD_REQUIRED True) 8 | 9 | find_package(Torch REQUIRED) 10 | find_package(Boost COMPONENTS system thread iostreams) 11 | find_package(ale REQUIRED) 12 | find_package(OpenCV REQUIRED) 13 | 14 | set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) 15 | set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) 16 | set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) 17 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -g -Wall -mpopcnt -O3 -pthread") 18 | set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -Wall -Wno-unused-function -O0 -pthread") 19 | 20 | # for git info 21 | include_directories(${PROJECT_BINARY_DIR}/git_info) 22 | 23 | add_subdirectory(minizero) 24 | add_subdirectory(minizero/actor) 25 | add_subdirectory(minizero/config) 26 | add_subdirectory(minizero/console) 27 | add_subdirectory(minizero/environment) 28 | add_subdirectory(minizero/learner) 29 | add_subdirectory(minizero/network) 30 | add_subdirectory(minizero/utils) 31 | add_subdirectory(minizero/zero) 32 | 33 | string(TOLOWER "${PROJECT_NAME}_${GAME_TYPE}" EXE_FILE_NAME) 34 | set_target_properties(${PROJECT_NAME} PROPERTIES OUTPUT_NAME ${EXE_FILE_NAME}) 35 | 
-------------------------------------------------------------------------------- /docs/Console.md: -------------------------------------------------------------------------------- 1 | # Console 2 | 3 | MiniZero supports the [Go Text Protocol (GTP)](http://www.lysator.liu.se/~gunnar/gtp/) and has a built-in console for easy communication with human operators or external programs. 4 | 5 | ```bash 6 | tools/quick-run.sh console GAME_TYPE FOLDER|MODEL_FILE [CONF_FILE] [OPTION]... 7 | ``` 8 | 9 | * `GAME_TYPE` sets the target game, e.g., `tictactoe`. 10 | * `FOLDER` or `MODEL_FILE` sets either the folder or the model file (`*.pt`). 11 | * `CONF_FILE` sets the config file for console. 12 | * `OPTION` sets optional arguments, e.g., `-conf_str` sets additional configurations. 13 | 14 | For detailed arguments, run `tools/quick-run.sh console -h`. 15 | 16 | Sample commands: 17 | 18 | ```bash 19 | # run a console with the latest model inside "tictactoe_az_1bx256_n50-cb69d4" using config "tictactoe_play.cfg" 20 | tools/quick-run.sh console tictactoe tictactoe_az_1bx256_n50-cb69d4 tictactoe_play.cfg 21 | 22 | # run a console with a specified model file using config "tictactoe_play.cfg" 23 | tools/quick-run.sh console tictactoe tictactoe_az_1bx256_n50-cb69d4/model/weight_iter_25000.pt tictactoe_play.cfg 24 | 25 | # run a console with the latest model inside "tictactoe_az_1bx256_n50-cb69d4" using its default config file, and overwrite several settings for console 26 | tools/quick-run.sh console tictactoe tictactoe_az_1bx256_n50-cb69d4 -conf_str actor_select_action_by_count=true:actor_use_dirichlet_noise=false:actor_num_simulation=200 27 | ``` 28 | 29 | Note that the console requires a trained network model. 30 | 31 | After the console starts successfully, a message "Successfully started console mode" will be displayed. 32 | Then, use [GTP commands](https://www.gnu.org/software/gnugo/gnugo_19.html) to interact with the program, e.g., `genmove`. 
33 | 34 | ## Miscellaneous Console Tips 35 | 36 | ### Attach MiniZero to GoGui 37 | 38 | [GoGui](https://github.com/Remi-Coulom/gogui) provides a graphical interface for board game AI programs. It includes two tools, `gogui-server` and `gogui-client`, for attaching programs that support GTP. 39 | 40 | To attach MiniZero to GoGui, specify the `gogui-server` port via `-p`; this automatically starts the MiniZero console with the `gogui-server`. 41 | 42 | ```bash 43 | # host the console at port 40000 using gogui-server 44 | tools/quick-run.sh console tictactoe tictactoe_az_1bx256_n50-cb69d4 tictactoe_play.cfg -p 40000 45 | ``` 46 | -------------------------------------------------------------------------------- /docs/Evaluation.md: -------------------------------------------------------------------------------- 1 | # Evaluation 2 | 3 | MiniZero currently supports two evaluation methods to evaluate program strength: [self-evaluation](#Self-Evaluation) and [fight-evaluation](#Fight-Evaluation). 4 | 5 | ## Self-Evaluation 6 | 7 | Self-evaluation evaluates the relative strengths between different iterations in a training session, i.e., it evaluates whether a network model is continuously improving during training. 8 | 9 | ```bash 10 | tools/quick-run.sh self-eval GAME_TYPE FOLDER [CONF_FILE] [INTERVAL] [GAME_NUM] [OPTION]... 11 | ``` 12 | 13 | * `GAME_TYPE` sets the target game, e.g., `tictactoe`. 14 | * `FOLDER` sets the folder to be evaluated, which should contain the `model/` subfolder. 15 | * `CONF_FILE` sets the config file for evaluation. 16 | * `INTERVAL` sets the iteration interval between each model pair to be evaluated, e.g., `10` indicates to pair the 0th and the 10th models, then the 10th and 20th models, and so on. 17 | * `GAME_NUM` sets the number of games to play for each model pair, e.g., `100`. 18 | * `OPTION` sets optional arguments, e.g., `-conf_str` sets additional configurations.
19 | 20 | For detailed arguments, run `tools/quick-run.sh self-eval -h`. 21 | 22 | Sample commands: 23 | ```bash 24 | # evaluate a TicTacToe training session using "tictactoe_play.cfg", run 100 games for each model pair: 0th vs 10th, 10th vs 20th, ... 25 | tools/quick-run.sh self-eval tictactoe tictactoe_az_1bx256_n50-cb69d4 tictactoe_play.cfg 10 100 26 | 27 | # evaluate a TicTacToe training session using its training config, overwrite several settings for evaluation 28 | tools/quick-run.sh self-eval tictactoe tictactoe_az_1bx256_n50-cb69d4 tictactoe_az_1bx256_n50-cb69d4/*.cfg 10 100 -conf_str actor_select_action_by_count=true:actor_use_dirichlet_noise=false:actor_num_simulation=200 29 | 30 | # use more threads for faster evaluation 31 | tools/quick-run.sh self-eval tictactoe tictactoe_az_1bx256_n50-cb69d4 tictactoe_play.cfg 10 100 --num_threads 20 32 | ``` 33 | 34 | Note that evaluation is unnecessary for Atari games. 35 | 36 | The evaluation results are stored inside `FOLDER`, in a subfolder named `self_eval` by default, which contains the following records: 37 | * `elo.csv` saves the evaluated model strength in Elo rating. 38 | * `elo.png` plots the Elo ratings stored in `elo.csv`. 39 | * `5000_vs_0`, `10000_vs_5000`, and other folders keep game trajectory records for each evaluated model pair. 40 | 41 | ## Fight-Evaluation 42 | 43 | Fight-evaluation evaluates the relative strengths between the same iterations of two training sessions, i.e., it compares the learning results of two network models. 44 | 45 | ```bash 46 | tools/quick-run.sh fight-eval GAME_TYPE FOLDER1 FOLDER2 [CONF_FILE1] [CONF_FILE2] [INTERVAL] [GAME_NUM] [OPTION]... 47 | ``` 48 | 49 | * `GAME_TYPE` sets the target game, e.g., `tictactoe`. 50 | * `FOLDER1` and `FOLDER2` set the two folders to be evaluated. 51 | * `CONF_FILE1` and `CONF_FILE2` set the config files for both folders; if `CONF_FILE2` is unspecified, `FOLDER2` will use `CONF_FILE1` for evaluation.
52 | * `INTERVAL` sets the iteration interval between each model pair to be evaluated, e.g., `10` indicates to match the i-th models of both folders, then the (i+10)-th models, and so on. 53 | * `GAME_NUM` sets the number of games to play for each model pair, e.g., `100`. 54 | * `OPTION` sets optional arguments, e.g., `-conf_str` sets additional configurations. 55 | 56 | For detailed arguments, run `tools/quick-run.sh fight-eval -h`. 57 | 58 | Sample commands: 59 | 60 | ```bash 61 | # evaluate two training results using "tictactoe_play.cfg" for both programs, run 100 games for each model pair 62 | tools/quick-run.sh fight-eval tictactoe tictactoe_az_1bx256_n50-cb69d4 tictactoe_az_1bx256_n50-731a0f tictactoe_play.cfg 10 100 63 | 64 | # evaluate two training results using "tictactoe_cb69d4.cfg" and "tictactoe_731a0f.cfg" for the former and the latter, respectively 65 | tools/quick-run.sh fight-eval tictactoe tictactoe_az_1bx256_n50-cb69d4 tictactoe_az_1bx256_n50-731a0f tictactoe_cb69d4.cfg tictactoe_731a0f.cfg 10 100 66 | ``` 67 | 68 | The evaluation results are stored inside `FOLDER1`, in a subfolder named `[FOLDER1]_vs_[FOLDER2]_eval` by default, which contains the following records: 69 | * `elo.csv` saves the evaluation statistics and strength comparisons of all evaluated model pairs. 70 | * `elo.png` plots the Elo rating comparisons reported in `elo.csv`. 71 | * `0`, `5000`, and other folders keep game trajectory records for each evaluated model pair. 72 | 73 | > **Note** 74 | > Before the fight-evaluation, it is suggested that a self-evaluation for `FOLDER1` be run first to generate a baseline strength, which is necessary for the strength comparison. 75 | 76 | ## Miscellaneous Evaluation Tips 77 | 78 | ### Configurations for evaluation 79 | 80 | Evaluation requires a different configuration from training, e.g., use more simulations and disable noise to always select the best action.
81 | ``` 82 | actor_num_simulation=400 83 | actor_select_action_by_count=true 84 | actor_select_action_by_softmax_count=false 85 | actor_use_dirichlet_noise=false 86 | actor_use_gumbel_noise=false 87 | ``` 88 | 89 | In addition, sometimes played games become too similar. 90 | To prevent this, use random rotation (for AlphaZero only) or even add softmax/noise back. 91 | ``` 92 | actor_use_random_rotation_features=true 93 | actor_select_action_by_count=false 94 | actor_select_action_by_softmax_count=true 95 | ``` 96 | -------------------------------------------------------------------------------- /gridworld_maps/gridworld_000.txt: -------------------------------------------------------------------------------- 1 | XXXXXXXXXXXXXXXXXXXX 2 | X XXXXXX X XXX 3 | X XXXXXX XX X XXX 4 | XXX XXXXXX XX X XXX 5 | X X XX X XXX 6 | X X XX XXX XX X XXX 7 | X XXX XX X XXX 8 | X XXXXXX XX X XXX 9 | X X X XXX 10 | XXXXXXXXXXXXXXXXXXXX 11 | -------------------------------------------------------------------------------- /minizero/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | file(GLOB SRCS *.cpp) 2 | 3 | add_executable(minizero ${SRCS}) 4 | target_link_libraries( 5 | minizero 6 | actor 7 | config 8 | console 9 | environment 10 | network 11 | utils 12 | zero 13 | ${TORCH_LIBRARIES} 14 | ) -------------------------------------------------------------------------------- /minizero/actor/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | file(GLOB_RECURSE SRCS *.cpp) 2 | 3 | add_library(actor ${SRCS}) 4 | target_include_directories(actor PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) 5 | target_link_libraries( 6 | actor 7 | config 8 | environment 9 | network 10 | utils 11 | ${Boost_LIBRARIES} 12 | ${TORCH_LIBRARIES} 13 | ) -------------------------------------------------------------------------------- /minizero/actor/actor_group.h: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "base_actor.h" 4 | #include "network.h" 5 | #include "paralleler.h" 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | namespace minizero::actor { 15 | 16 | class ThreadSharedData : public utils::BaseSharedData { // state presumably shared by all SlaveThreads of one ActorGroup (both classes cast shared_data_ via getSharedData()) 17 | public: 18 | int getAvailableActorIndex(); 19 | void outputGame(const std::shared_ptr& actor); 20 | std::pair calculateTrainingDataRange(const std::shared_ptr& actor); 21 | 22 | bool do_cpu_job_ = false; // default member initializers: these scalars were indeterminate until first assigned 23 | int actor_index_ = 0; 24 | std::mutex mutex_; 25 | std::vector> actors_; 26 | std::vector> networks_; 27 | std::vector>> network_outputs_; 28 | }; 29 | 30 | class SlaveThread : public utils::BaseSlaveThread { 31 | public: 32 | SlaveThread(int id, std::shared_ptr shared_data) 33 | : BaseSlaveThread(id, shared_data) {} 34 | 35 | void initialize() override; 36 | void runJob() override; 37 | bool isDone() override { return false; } // slave threads never report completion on their own 38 | 39 | protected: 40 | virtual bool doCPUJob(); 41 | virtual void doGPUJob(); 42 | virtual void handleSearchDone(int actor_id); 43 | inline std::shared_ptr getSharedData() { return std::static_pointer_cast(shared_data_); } 44 | }; 45 | 46 | class ActorGroup : public utils::BaseParalleler { 47 | public: 48 | ActorGroup() {} 49 | 50 | void run(); 51 | void initialize() override; 52 | void summarize() override {} 53 | 54 | protected: 55 | virtual void createNeuralNetworks(); 56 | virtual void createActors(); 57 | virtual void handleIO(); 58 | virtual void handleCommand(); 59 | virtual void handleCommand(const std::string& command_prefix, const std::string& command); 60 | 61 | void createSharedData() override { shared_data_ = std::make_shared(); } 62 | std::shared_ptr newSlaveThread(int id) override { return std::make_shared(id, shared_data_); } 63 | inline std::shared_ptr getSharedData() { return std::static_pointer_cast(shared_data_); } 64 | 65 | bool running_ = false; // was indeterminate until first assigned 66 |
std::deque commands_; 67 | std::unordered_set ignored_commands_; 68 | }; 69 | 70 | } // namespace minizero::actor 71 | -------------------------------------------------------------------------------- /minizero/actor/base_actor.cpp: -------------------------------------------------------------------------------- 1 | #include "base_actor.h" 2 | #include "configuration.h" 3 | #include 4 | #include 5 | 6 | namespace minizero::actor { 7 | 8 | void BaseActor::reset() 9 | { 10 | env_.reset(); 11 | action_info_history_.clear(); 12 | resetSearch(); 13 | } 14 | 15 | void BaseActor::resetSearch() 16 | { 17 | nn_evaluation_batch_id_ = -1; 18 | if (!search_) { search_ = createSearch(); } 19 | search_->reset(); 20 | } 21 | 22 | bool BaseActor::act(const Action& action) 23 | { 24 | bool can_act = env_.act(action); 25 | if (can_act) { 26 | action_info_history_.resize(env_.getActionHistory().size()); 27 | action_info_history_.back() = getActionInfo(); 28 | } 29 | return can_act; 30 | } 31 | 32 | bool BaseActor::act(const std::vector& action_string_args) 33 | { 34 | bool can_act = env_.act(action_string_args); 35 | if (can_act) { 36 | action_info_history_.resize(env_.getActionHistory().size()); 37 | action_info_history_.back() = getActionInfo(); 38 | } 39 | return can_act; 40 | } 41 | 42 | std::string BaseActor::getRecord(const std::unordered_map& tags /* = {} */) const 43 | { 44 | EnvironmentLoader env_loader; 45 | env_loader.loadFromEnvironment(env_, action_info_history_); 46 | env_loader.addTag("EV", config::nn_file_name.substr(config::nn_file_name.find_last_of('/') + 1)); // basename of the network file; npos + 1 wraps to 0, so a name without '/' is kept whole 47 | 48 | // if the game is not ended, then treat the game as a resigned game, where the next player is the losing side 49 | if (!isEnvTerminal()) { 50 | float result = env_.getEvalScore(true); 51 | std::ostringstream oss; 52 | oss << result; 53 | env_loader.addTag("RE", oss.str()); 54 | } 55 | for (const auto& tag : tags) { env_loader.addTag(tag.first, tag.second); } // by reference: avoids copying every key/value string pair 56 | return env_loader.toString(); 57 | } 58 |
std::vector> BaseActor::getActionInfo() const 60 | { 61 | std::vector> action_info; // one (tag, value) entry each: MCTS policy "P", MCTS value "V", environment reward "R" 62 | action_info.push_back({"P", getMCTSPolicy()}); 63 | action_info.push_back({"V", getMCTSValue()}); 64 | action_info.push_back({"R", getEnvReward()}); 65 | return action_info; 66 | } 67 | 68 | } // namespace minizero::actor 69 | -------------------------------------------------------------------------------- /minizero/actor/base_actor.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "environment.h" 4 | #include "network.h" 5 | #include "search.h" 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | namespace minizero::actor { 13 | 14 | using namespace minizero; 15 | 16 | class BaseActor { // abstract actor: owns an Environment, a Search object and the per-move action info recorded by act() 17 | public: 18 | BaseActor() {} 19 | virtual ~BaseActor() = default; 20 | 21 | virtual void reset(); 22 | virtual void resetSearch(); 23 | bool act(const Action& action); 24 | bool act(const std::vector& action_string_args); 25 | virtual std::string getRecord(const std::unordered_map& tags = {}) const; 26 | 27 | inline bool isEnvTerminal() const { return env_.isTerminal(); } 28 | inline float getEvalScore() const { return env_.getEvalScore(); } // by-value return: a top-level const here was ignored (-Wignored-qualifiers) 29 | inline Environment& getEnvironment() { return env_; } 30 | inline const Environment& getEnvironment() const { return env_; } 31 | inline int getNNEvaluationBatchIndex() const { return nn_evaluation_batch_id_; } 32 | inline std::vector>>& getActionInfoHistory() { return action_info_history_; } 33 | inline const std::vector>>& getActionInfoHistory() const { return action_info_history_; } 34 | 35 | virtual Action think(bool with_play = false, bool display_board = false) = 0; 36 | virtual void beforeNNEvaluation() = 0; 37 | virtual void afterNNEvaluation(const std::shared_ptr& network_output) = 0; 38 | virtual bool isSearchDone() const = 0; 39 | virtual Action getSearchAction() const = 0; 40 | virtual bool isResign() const = 0; 41 | virtual
std::string getSearchInfo() const = 0; 42 | virtual void setNetwork(const std::shared_ptr& network) = 0; 43 | virtual std::shared_ptr createSearch() = 0; 44 | 45 | protected: 46 | virtual std::vector> getActionInfo() const; 47 | virtual std::string getMCTSPolicy() const = 0; 48 | virtual std::string getMCTSValue() const = 0; 49 | virtual std::string getEnvReward() const = 0; 50 | 51 | int nn_evaluation_batch_id_ = -1; // same value resetSearch() assigns; was indeterminate before the first reset 52 | Environment env_; 53 | std::shared_ptr search_; 54 | std::vector>> action_info_history_; 55 | }; 56 | 57 | } // namespace minizero::actor 58 | -------------------------------------------------------------------------------- /minizero/actor/create_actor.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "base_actor.h" 4 | #include "configuration.h" 5 | #include "zero_actor.h" 6 | #include 7 | 8 | namespace minizero::actor { 9 | 10 | inline std::shared_ptr createActor(uint64_t tree_node_size, const std::shared_ptr& network) // create an actor with the given tree size, attach the network, and reset it 11 | { 12 | auto actor = std::make_shared(tree_node_size); 13 | actor->setNetwork(network); 14 | actor->reset(); 15 | return actor; 16 | } 17 | 18 | } // namespace minizero::actor 19 | -------------------------------------------------------------------------------- /minizero/actor/gumbel_zero.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "mcts.h" 4 | #include 5 | #include 6 | #include 7 | 8 | namespace minizero::actor { 9 | 10 | class GumbelZero { 11 | public: 12 | std::string getMCTSPolicy(const std::shared_ptr& mcts) const; 13 | MCTSNode* decideActionNode(const std::shared_ptr& mcts); 14 | std::vector selection(const std::shared_ptr& mcts); 15 | void sequentialHalving(const std::shared_ptr& mcts); 16 | void sortCandidatesByScore(const std::shared_ptr& mcts); 17 | 18 | private: 19 | int sample_size_; 20 | int simulation_budget_; 21 | std::vector
candidates_; 22 | }; 23 | 24 | } // namespace minizero::actor 25 | -------------------------------------------------------------------------------- /minizero/actor/search.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace minizero::actor { 4 | 5 | class Search { // abstract search interface; BaseActor holds one in search_ and calls reset() on it from resetSearch() 6 | public: 7 | Search() {} 8 | virtual ~Search() = default; 9 | 10 | virtual void reset() = 0; 11 | }; 12 | 13 | } // namespace minizero::actor 14 | -------------------------------------------------------------------------------- /minizero/actor/tree.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace minizero::actor { 8 | 9 | template 10 | class TreeData { // append-only store: store() returns the index of the new element, getData() reads it back 11 | public: 12 | TreeData() { reset(); } 13 | 14 | inline void reset() { data_.clear(); } 15 | inline int store(const Data& data) 16 | { 17 | int index = data_.size(); // the slot the element will occupy after push_back 18 | data_.push_back(data); 19 | return index; 20 | } 21 | inline const Data& getData(int index) const 22 | { 23 | assert(index >= 0 && index < size()); 24 | return data_[index]; 25 | } 26 | inline int size() const { return data_.size(); } 27 | 28 | private: 29 | std::vector data_; 30 | }; 31 | 32 | class TreeNode { // abstract tree node; its children are num_children_ consecutive nodes starting at first_child_ 33 | public: 34 | TreeNode() {} 35 | virtual ~TreeNode() = default; 36 | 37 | virtual void reset() = 0; 38 | virtual std::string toString() const = 0; 39 | virtual bool displayInTreeLog() const { return true; } 40 | 41 | inline bool isLeaf() const { return (num_children_ == 0); } // true when no children have been attached 42 | inline void setAction(Action action) { action_ = action; } 43 | inline void setNumChildren(int num_children) { num_children_ = num_children; } 44 | inline void setFirstChild(TreeNode* first_child) { first_child_ = first_child; } 45 | inline void setOptionChild(TreeNode* option_child) { option_child_ = option_child; } 46 | inline void setParent(TreeNode* parent) { parent_ = parent; } 47 | inline Action getAction() const {
return action_; } 48 | inline int getNumChildren() const { return num_children_; } 49 | inline virtual TreeNode* getChild(int index) const { return (index < num_children_ ? first_child_ + index : nullptr); } // NOTE(review): first_child_ + index steps by sizeof(TreeNode); correct for a derived node array only if the derived class overrides this, confirm MCTSNode does 50 | inline virtual TreeNode* getOptionChild() const { return option_child_; } 51 | inline virtual TreeNode* getParent() const { return parent_; } 52 | 53 | protected: 54 | Action action_; 55 | int num_children_; 56 | TreeNode* first_child_; 57 | TreeNode* option_child_; 58 | TreeNode* parent_; 59 | }; 60 | 61 | class Tree { 62 | public: 63 | Tree(uint64_t tree_node_size) 64 | : tree_node_size_(tree_node_size), 65 | nodes_(nullptr) 66 | { 67 | // no lower-bound check: tree_node_size is unsigned, so an assert(tree_node_size >= 0) is always true 68 | } 69 | 70 | inline void reset() 71 | { 72 | if (!nodes_) { nodes_ = createTreeNodes(1 + tree_node_size_); } 73 | current_node_size_ = 1; 74 | getRootNode()->reset(); 75 | } 76 | 77 | inline TreeNode* allocateNodes(int size) 78 | { 79 | assert(current_node_size_ + size <= 1 + tree_node_size_); 80 | TreeNode* node = getNodeIndex(current_node_size_); 81 | current_node_size_ += size; 82 | return node; 83 | } 84 | 85 | std::string toString(const std::string& env_string) // append the root comment and its subtree just before env_string's closing ')' 86 | { 87 | assert(!env_string.empty() && env_string.back() == ')'); 88 | std::ostringstream oss; 89 | TreeNode* pRoot = getRootNode(); 90 | std::string env_prefix = env_string.substr(0, env_string.size() - 1); 91 | oss << env_prefix << "C[" << pRoot->toString() << "]" << getTreeInfo_r(pRoot) << ")"; 92 | return oss.str(); 93 | } 94 | 95 | std::string getTreeInfo_r(const TreeNode* node) const // recursively serialize the displayed children of node as player[action]C[comment] entries 96 | { 97 | std::ostringstream oss; 98 | 99 | int numChildren = 0; // number of children the loop below will print; each is wrapped in "(...)" only when there are two or more 100 | for (int i = 0; i < node->getNumChildren(); ++i) { 101 | TreeNode* child = node->getChild(i); 102 | if (!child->displayInTreeLog()) { continue; } 103 | ++numChildren; 104 | } 105 | 106 | for (int i = 0; i < node->getNumChildren(); ++i) { 107 | TreeNode* child = node->getChild(i); 108 | if (!child->displayInTreeLog()) { continue; } 109 | if (numChildren > 1) { oss << "("; } 110 |
oss << playerToChar(child->getAction().getPlayer()) 111 | << "[" << child->getAction().getActionID() << "]" 112 | << "C[" << child->toString() << "]" << getTreeInfo_r(child); 113 | if (numChildren > 1) { oss << ")"; } 114 | } 115 | return oss.str(); 116 | } 117 | 118 | inline TreeNode* getRootNode() { return &nodes_[0]; } // root is node 0; nodes_ is nullptr until reset() allocates it 119 | inline const TreeNode* getRootNode() const { return &nodes_[0]; } 120 | 121 | protected: 122 | virtual TreeNode* createTreeNodes(uint64_t tree_node_size) = 0; // allocate the node array (reset() asks for 1 + tree_node_size_ nodes) 123 | virtual TreeNode* getNodeIndex(int index) = 0; // address of the index-th node; presumably virtual so the derived node type's size is used 124 | 125 | uint64_t tree_node_size_; 126 | uint64_t current_node_size_; // nodes handed out so far, root included (reset() sets it to 1) 127 | TreeNode* nodes_; 128 | }; 129 | 130 | } // namespace minizero::actor 131 | -------------------------------------------------------------------------------- /minizero/actor/zero_actor.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "alphazero_network.h" 4 | #include "base_actor.h" 5 | #include "gumbel_zero.h" 6 | #include "mcts.h" 7 | #include "muzero_network.h" 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | namespace minizero::actor { 16 | 17 | class MCTSSearchData { // per-search bookkeeping; selected_node_ is what getSearchAction() and isResign() read 18 | public: 19 | std::string search_info_; 20 | std::string selection_path_; 21 | MCTSNode* selected_node_; 22 | std::vector node_path_; 23 | void clear(); 24 | }; 25 | 26 | class ZeroActor : public BaseActor { 27 | public: 28 | ZeroActor(uint64_t tree_node_size) // NOTE(review): enable_resign_, enable_option_ and p_max_id_ are not initialized here, presumably set in reset(); confirm 29 | : tree_node_size_(tree_node_size) 30 | { 31 | option_count_ = 0; 32 | alphazero_network_ = nullptr; 33 | muzero_network_ = nullptr; 34 | } 35 | 36 | void reset() override; 37 | void resetSearch() override; 38 | Action think(bool with_play = false, bool display_board = false) override; 39 | void beforeNNEvaluation() override; 40 | void afterNNEvaluation(const std::shared_ptr& network_output) override; 41 | bool isSearchDone() const override { return getMCTS()->reachMaximumSimulation(); } // done once the MCTS reports its maximum simulation count is reached 42 | Action
getSearchAction() const override { return mcts_search_data_.selected_node_->getAction(); } 43 | bool isResign() const override { return enable_resign_ && getMCTS()->isResign(mcts_search_data_.selected_node_); } // resign only when enabled and the MCTS judges the selected node a resign 44 | std::string getSearchInfo() const override { return mcts_search_data_.search_info_; } 45 | void setNetwork(const std::shared_ptr& network) override; 46 | std::shared_ptr createSearch() override { return std::make_shared(tree_node_size_); } // search object sized by tree_node_size_ 47 | std::shared_ptr getMCTS() { return std::static_pointer_cast(search_); } // unchecked downcast: relies on search_ having been built by createSearch() above 48 | const std::shared_ptr getMCTS() const { return std::static_pointer_cast(search_); } 49 | 50 | protected: 51 | std::vector> getActionInfo() const override; 52 | std::string getMCTSPolicy() const override { return (config::actor_use_gumbel ? gumbel_zero_.getMCTSPolicy(getMCTS()) : getMCTS()->getSearchDistributionString()); } // Gumbel policy when config::actor_use_gumbel, otherwise the MCTS search distribution 53 | std::string getMCTSValue() const override { return std::to_string(getMCTS()->getRootNode()->getMean()); } // root node mean, as text 54 | std::string getEnvReward() const override; 55 | 56 | virtual void step(); 57 | virtual void handleSearchDone(); 58 | virtual MCTSNode* decideActionNode(); 59 | virtual void addNoiseToNodeChildren(MCTSNode* node); 60 | virtual std::vector selection() { return (config::actor_use_gumbel ?
gumbel_zero_.selection(getMCTS()) : getMCTS()->select()); } // Gumbel or plain MCTS selection, mirroring getMCTSPolicy() 61 | 62 | std::vector calculateAlphaZeroActionPolicy(const Environment& env_transition, const std::shared_ptr& alphazero_output, const utils::Rotation& rotation); 63 | std::vector calculateMuZeroActionPolicy(MCTSNode* leaf_node, const std::shared_ptr& muzero_output); 64 | virtual Environment getEnvironmentTransition(const std::vector& node_path); 65 | 66 | bool enable_resign_; 67 | bool enable_option_; 68 | int option_count_; 69 | int p_max_id_; 70 | std::string used_options_; 71 | std::string depth_; 72 | std::string option_str_; 73 | GumbelZero gumbel_zero_; 74 | uint64_t tree_node_size_; 75 | MCTSSearchData mcts_search_data_; 76 | utils::Rotation feature_rotation_; 77 | std::shared_ptr alphazero_network_; 78 | std::shared_ptr muzero_network_; 79 | 80 | private: 81 | void setAlphaZeroOptionInfo(MCTSNode* leaf_node, Environment& env_transition, const std::shared_ptr& alphazero_output, const utils::Rotation& rotation); 82 | void setMuZeroOptionInfo(MCTSNode* leaf_node, const std::shared_ptr& muzero_output); 83 | std::pair, std::vector>> calculateLegalOption(Environment& env_transition, env::Player turn, const std::vector> option, utils::Rotation rotation = utils::Rotation::kRotationNone); // NOTE(review): `option` here and on lines 84-85 (plus option_actions/option_legal_actions) are const by-value containers copied on every call; consider const&, changing the zero_actor.cpp definitions too 84 | std::vector calculateOption(env::Player turn, const std::vector> option, utils::Rotation rotation = utils::Rotation::kRotationNone); 85 | void setOptionInfo(MCTSNode* leaf_node, const std::vector> option, const std::vector option_actions, const std::vector> option_legal_actions, utils::Rotation rotation = utils::Rotation::kRotationNone); 86 | }; 87 | 88 | } // namespace minizero::actor 89 | -------------------------------------------------------------------------------- /minizero/config/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | file(GLOB SRCS *.cpp) 2 | 3 | add_library(config ${SRCS}) 4 | target_include_directories(config PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) 5
| target_link_libraries(config) 6 | target_compile_definitions(config PUBLIC ${GAME_TYPE}) -------------------------------------------------------------------------------- /minizero/config/configuration.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "configure_loader.h" 4 | #include 5 | 6 | namespace minizero::config { 7 | 8 | // program parameters 9 | extern int program_seed; 10 | extern bool program_auto_seed; 11 | extern bool program_quiet; 12 | 13 | // actor parameters 14 | extern int actor_num_simulation; 15 | extern float actor_mcts_puct_base; 16 | extern float actor_mcts_puct_init; 17 | extern float actor_mcts_reward_discount; 18 | extern int actor_mcts_think_batch_size; 19 | extern float actor_mcts_think_time_limit; 20 | extern bool actor_mcts_value_rescale; 21 | extern char actor_mcts_value_flipping_player; 22 | extern bool actor_select_action_by_count; 23 | extern bool actor_select_action_by_softmax_count; 24 | extern float actor_select_action_softmax_temperature; 25 | extern bool actor_select_action_softmax_temperature_decay; 26 | extern bool actor_use_random_rotation_features; 27 | extern bool actor_use_dirichlet_noise; 28 | extern float actor_dirichlet_noise_alpha; 29 | extern float actor_dirichlet_noise_epsilon; 30 | extern bool actor_use_gumbel; 31 | extern bool actor_use_gumbel_noise; 32 | extern int actor_gumbel_sample_size; 33 | extern float actor_gumbel_sigma_visit_c; 34 | extern float actor_gumbel_sigma_scale_c; 35 | extern float actor_resign_threshold; 36 | 37 | // zero parameters 38 | extern int zero_num_threads; 39 | extern int zero_num_parallel_games; 40 | extern int zero_server_port; 41 | extern std::string zero_training_directory; 42 | extern int zero_num_games_per_iteration; 43 | extern int zero_start_iteration; 44 | extern int zero_end_iteration; 45 | extern int zero_replay_buffer; 46 | extern float zero_disable_resign_ratio; 47 | extern float 
zero_disable_option_ratio; 48 | extern int zero_actor_intermediate_sequence_length; 49 | extern std::string zero_actor_ignored_command; 50 | extern bool zero_server_accept_different_model_games; 51 | 52 | // learner parameters 53 | extern bool learner_use_per; 54 | extern float learner_per_alpha; 55 | extern float learner_per_init_beta; 56 | extern bool learner_per_beta_anneal; 57 | extern int learner_training_step; 58 | extern int learner_training_display_step; 59 | extern int learner_batch_size; 60 | extern int learner_muzero_unrolling_step; 61 | extern int learner_n_step_return; 62 | extern float learner_learning_rate; 63 | extern float learner_momentum; 64 | extern float learner_weight_decay; 65 | extern float learner_value_loss_scale; 66 | extern float learner_option_loss_scale; 67 | extern int learner_num_thread; 68 | 69 | // network parameters 70 | extern std::string nn_file_name; 71 | extern int nn_num_blocks; 72 | extern int nn_num_hidden_channels; 73 | extern int nn_num_value_hidden_channels; 74 | extern std::string nn_type_name; 75 | extern bool nn_use_consistency; 76 | 77 | // option parameters 78 | extern int option_seq_length; 79 | 80 | // environment parameters 81 | extern int env_board_size; 82 | 83 | // environment parameters for specific game 84 | extern std::string env_atari_rom_dir; 85 | extern std::string env_atari_name; 86 | extern bool env_conhex_use_swap_rule; 87 | extern float env_go_komi; 88 | extern std::string env_go_ko_rule; 89 | extern std::string env_gomoku_rule; 90 | extern bool env_gomoku_exactly_five_stones; 91 | extern bool env_havannah_use_swap_rule; 92 | extern bool env_hex_use_swap_rule; 93 | extern bool env_killallgo_use_seki; 94 | extern int env_rubiks_scramble_rotate; 95 | extern int env_surakarta_no_capture_plies; 96 | extern std::string env_gridworld_maps_dir; 97 | extern bool fix_player; 98 | extern bool fix_goal; 99 | 100 | void setConfiguration(ConfigureLoader& cl); 101 | 102 | } // namespace minizero::config 103 | 
-------------------------------------------------------------------------------- /minizero/config/configure_loader.cpp: -------------------------------------------------------------------------------- 1 | #include "configure_loader.h" 2 | #include 3 | #include 4 | #include 5 | 6 | namespace minizero::config { 7 | 8 | template <> 9 | bool setParameter(bool& ref, const std::string& value) 10 | { 11 | std::string tmp = value; 12 | transform(tmp.begin(), tmp.end(), tmp.begin(), ::toupper); 13 | if (tmp != "TRUE" && tmp != "1" && tmp != "FALSE" && tmp != "0") { return false; } 14 | 15 | ref = (tmp == "TRUE" || tmp == "1"); 16 | return true; 17 | } 18 | 19 | template <> 20 | bool setParameter(std::string& ref, const std::string& value) 21 | { 22 | ref = value; 23 | return true; 24 | } 25 | 26 | template <> 27 | std::string getParameter(bool& ref) 28 | { 29 | std::ostringstream oss; 30 | oss << (ref == true ? "true" : "false"); 31 | return oss.str(); 32 | } 33 | 34 | bool ConfigureLoader::loadFromFile(std::string conf_file) 35 | { 36 | if (conf_file.empty()) { return false; } 37 | 38 | std::string line; 39 | std::ifstream file(conf_file); 40 | if (file.fail()) { 41 | file.close(); 42 | return false; 43 | } 44 | while (std::getline(file, line)) { 45 | if (!setValue(line)) { return false; } 46 | } 47 | 48 | return true; 49 | } 50 | 51 | bool ConfigureLoader::loadFromString(std::string conf_string) 52 | { 53 | if (conf_string.empty()) { return false; } 54 | 55 | std::string line; 56 | std::istringstream iss(conf_string); 57 | while (std::getline(iss, line, ':')) { 58 | if (!setValue(line)) { return false; } 59 | } 60 | 61 | return true; 62 | } 63 | 64 | std::string ConfigureLoader::toString() const 65 | { 66 | std::ostringstream oss; 67 | for (const auto& group_name : group_name_order_) { 68 | oss << "# " << group_name << std::endl; 69 | for (auto parameter : parameter_groups_.at(group_name)) { oss << parameter->toString(); } 70 | oss << std::endl; 71 | } 72 | return 
oss.str(); 73 | } 74 | 75 | std::string ConfigureLoader::getConfig(std::string key) const 76 | { 77 | for (const auto& group_name : group_name_order_) { 78 | for (auto parameter : parameter_groups_.at(group_name)) { 79 | if (parameter->getKey() == key) { return parameter->toString(); } 80 | } 81 | } 82 | return ""; 83 | } 84 | 85 | void ConfigureLoader::trim(std::string& s) 86 | { 87 | if (s.empty()) { return; } 88 | s.erase(0, s.find_first_not_of(" \t")); 89 | s.erase(s.find_last_not_of(" \t") + 1); 90 | } 91 | 92 | bool ConfigureLoader::setValue(std::string line) 93 | { 94 | if (line.empty() || line[0] == '#') { return true; } 95 | 96 | std::string key = line.substr(0, line.find("=")); 97 | std::string value = line.substr(line.find("=") + 1); 98 | if (value.find("#") != std::string::npos) { value = value.substr(0, value.find("#")); } 99 | std::string group_name = line.substr(line.find("#") + 1); 100 | 101 | trim(key); 102 | trim(value); 103 | trim(group_name); 104 | 105 | if (parameters_.count(key) == 0) { 106 | std::cerr << "Invalid key \"" + key + "\" and value \"" << value << "\"" << std::endl; 107 | return false; 108 | } 109 | 110 | if (!(*parameters_[key])(value)) { 111 | std::cerr << "Unsatisfiable value \"" + value + "\" for option \"" + key + "\"" << std::endl; 112 | return false; 113 | } 114 | 115 | return true; 116 | } 117 | 118 | } // namespace minizero::config 119 | -------------------------------------------------------------------------------- /minizero/config/configure_loader.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace minizero::config { 9 | 10 | // parameter setter 11 | template 12 | bool setParameter(T& ref, const std::string& value) 13 | { 14 | std::istringstream iss(value); 15 | iss >> ref; 16 | return (iss && iss.rdbuf()->in_avail() == 0); 17 | } 18 | template <> 19 | bool setParameter(bool& ref, const std::string& 
value); 20 | template <> 21 | bool setParameter(std::string& ref, const std::string& value); 22 | 23 | // parameter getter 24 | template 25 | std::string getParameter(T& ref) 26 | { 27 | std::ostringstream oss; 28 | oss << ref; 29 | return oss.str(); 30 | } 31 | template <> 32 | std::string getParameter(bool& ref); 33 | 34 | // parameter container 35 | class BaseParameter { 36 | public: 37 | virtual bool operator()(const std::string& value) = 0; 38 | virtual std::string toString() const = 0; 39 | virtual std::string getKey() const = 0; 40 | virtual ~BaseParameter() {} 41 | }; 42 | 43 | template 44 | class Parameter : public BaseParameter { 45 | public: 46 | Parameter(const std::string key, T& ref, const std::string description, Setter setter, Getter getter) 47 | : key_(key), description_(description), ref_(ref), setter_(setter), getter_(getter) 48 | { 49 | } 50 | 51 | std::string toString() const override 52 | { 53 | std::ostringstream oss; 54 | if (description_.size() > 150) { oss << "# " << description_ << std::endl; } 55 | oss << key_ << "=" << getter_(ref_); 56 | if (!description_.empty() && description_.size() <= 150) { oss << " # " << description_; } 57 | oss << std::endl; 58 | return oss.str(); 59 | } 60 | 61 | inline bool operator()(const std::string& value) override { return setter_(ref_, value); } 62 | inline std::string getKey() const override { return key_; } 63 | inline std::string getDescription() const { return description_; } 64 | 65 | private: 66 | std::string key_; 67 | std::string description_; 68 | 69 | T& ref_; 70 | Setter setter_; 71 | Getter getter_; 72 | }; 73 | 74 | class ConfigureLoader { 75 | public: 76 | ConfigureLoader() {} 77 | 78 | virtual ~ConfigureLoader() 79 | { 80 | for (auto& group_name : group_name_order_) { 81 | for (auto parameter : parameter_groups_[group_name]) { delete parameter; } 82 | parameter_groups_[group_name].clear(); 83 | } 84 | } 85 | 86 | template 87 | inline void addParameter(const std::string& key, T& value, 
const std::string description, const std::string& group_name, Setter setter, Getter getter) 88 | { 89 | if (parameter_groups_.count(group_name) == 0) { group_name_order_.push_back(group_name); } 90 | parameters_[key] = new Parameter(key, value, description, setter, getter); 91 | parameter_groups_[group_name].push_back(parameters_[key]); 92 | } 93 | 94 | template 95 | inline void addParameter(const std::string& key, T& value, const std::string description, const std::string& group_name) 96 | { 97 | addParameter(key, value, description, group_name, setParameter, getParameter); 98 | } 99 | 100 | bool loadFromFile(std::string conf_file); 101 | bool loadFromString(std::string conf_string); 102 | std::string toString() const; 103 | std::string getConfig(std::string key) const; 104 | 105 | private: 106 | void trim(std::string& s); 107 | bool setValue(std::string sLine); 108 | 109 | std::vector group_name_order_; 110 | std::map parameters_; 111 | std::map> parameter_groups_; 112 | }; 113 | 114 | } // namespace minizero::config 115 | -------------------------------------------------------------------------------- /minizero/console/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | file(GLOB SRCS *.cpp) 2 | 3 | add_library(console ${SRCS}) 4 | target_include_directories(console PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) 5 | target_link_libraries( 6 | console 7 | actor 8 | config 9 | environment 10 | network 11 | utils 12 | zero 13 | ${TORCH_LIBRARIES} 14 | ) -------------------------------------------------------------------------------- /minizero/console/console.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "base_actor.h" 4 | #include "network.h" 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | namespace minizero::console { 11 | 12 | using namespace minizero; 13 | 14 | enum class ConsoleResponse : char { 15 | kFail = '?', 16 | kSuccess = '=' 17 | }; 
18 | 19 | class Console { 20 | public: 21 | Console(); 22 | virtual ~Console() = default; 23 | 24 | virtual void initialize(); 25 | virtual void executeCommand(std::string command); 26 | 27 | protected: 28 | class BaseFunction { 29 | public: 30 | virtual ~BaseFunction() = default; 31 | virtual void operator()(const std::vector& args) = 0; 32 | }; 33 | 34 | template 35 | class Function : public BaseFunction { 36 | public: 37 | Function(I* instance, F function) : instance_(instance), function_(function) {} 38 | void operator()(const std::vector& args) override { (*instance_.*function_)(args); } // forwards the console arguments to the registered member function 39 | 40 | I* instance_; 41 | F function_; 42 | }; 43 | 44 | template 45 | void RegisterFunction(const std::string& name, I* instance, F function) 46 | { 47 | function_map_[name] = std::make_shared>(instance, function); 48 | } 49 | 50 | void cmdGoguiAnalyzeCommands(const std::vector& args); 51 | void cmdListCommands(const std::vector& args); 52 | void cmdName(const std::vector& args); 53 | void cmdVersion(const std::vector& args); 54 | void cmdProtocalVersion(const std::vector& args); 55 | void cmdClearBoard(const std::vector& args); 56 | void cmdShowBoard(const std::vector& args); 57 | void cmdPlay(const std::vector& args); 58 | void cmdBoardSize(const std::vector& args); 59 | void cmdGenmove(const std::vector& args); 60 | void cmdFinalScore(const std::vector& args); 61 | void cmdPV(const std::vector& args); 62 | void cmdPVString(const std::vector& args); 63 | void cmdGameString(const std::vector& args); 64 | void cmdLoadModel(const std::vector& args); 65 | void cmdGetConfigString(const std::vector& args); 66 | 67 | virtual void calculatePolicyValue(std::vector& policy, float& value, utils::Rotation rotation = utils::Rotation::kRotationNone); 68 | bool checkArgument(const std::vector& args, int min_argc, int max_argc); 69 | void reply(ConsoleResponse response, const std::string& reply); 70 | 71 | std::string command_id_; 72 | std::shared_ptr network_; 73 | std::shared_ptr actor_; 74 | 
std::map> function_map_; 75 | }; 76 | 77 | } // namespace minizero::console 78 | -------------------------------------------------------------------------------- /minizero/console/mode_handler.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "configuration.h" 4 | #include 5 | #include 6 | #include 7 | 8 | namespace minizero::console { 9 | 10 | class ModeHandler { 11 | public: 12 | ModeHandler(); 13 | virtual ~ModeHandler() = default; 14 | 15 | void run(int argc, char* argv[]); 16 | 17 | protected: 18 | class BaseFunction { 19 | public: 20 | virtual ~BaseFunction() = default; 21 | virtual void operator()() = 0; 22 | }; 23 | 24 | template 25 | class Function : public BaseFunction { 26 | public: 27 | Function(I* instance, F function) : instance_(instance), function_(function) {} 28 | void operator()() override { (*instance_.*function_)(); } // invokes the registered mode handler member function 29 | 30 | I* instance_; 31 | F function_; 32 | }; 33 | 34 | template 35 | void RegisterFunction(const std::string& name, I* instance, F function) 36 | { 37 | function_map_[name] = std::make_shared>(instance, function); 38 | } 39 | 40 | void usage(); 41 | std::string getAllModesString(); 42 | virtual void setDefaultConfiguration(config::ConfigureLoader& cl) { config::setConfiguration(cl); } 43 | void genConfiguration(config::ConfigureLoader& cl, const std::string& sConfigFile); 44 | bool readConfiguration(config::ConfigureLoader& cl, const std::string& sConfigFile, const std::string& sConfigString); 45 | virtual void runConsole(); 46 | virtual void runSelfPlay(); 47 | virtual void runZeroServer(); 48 | virtual void runZeroTrainingName(); 49 | virtual void runEnvTest(); 50 | virtual void runRemoveObs(); 51 | virtual void runRecoverObs(); 52 | virtual void runDecompressStr(); 53 | 54 | std::map> function_map_; 55 | }; 56 | 57 | } // namespace minizero::console 58 | -------------------------------------------------------------------------------- 
/minizero/environment/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | file(GLOB_RECURSE SRCS *.cpp) 2 | 3 | add_library(environment ${SRCS}) 4 | target_include_directories( 5 | environment PUBLIC 6 | ${CMAKE_CURRENT_SOURCE_DIR} 7 | base 8 | amazons 9 | atari 10 | breakthrough 11 | clobber 12 | conhex 13 | connect6 14 | dotsandboxes 15 | go 16 | gomoku 17 | gridworld 18 | havannah 19 | hex 20 | killallgo 21 | linesofaction 22 | nogo 23 | othello 24 | rubiks 25 | santorini 26 | surakarta 27 | tictactoe 28 | stochastic 29 | stochastic/puzzle2048 30 | ) 31 | target_link_libraries( 32 | environment 33 | config 34 | utils 35 | ale::ale-lib 36 | ${OpenCV_LIBS} 37 | ) 38 | -------------------------------------------------------------------------------- /minizero/environment/atari/obs_recover.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "atari.h" 4 | #include "configuration.h" 5 | #include "paralleler.h" 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | namespace minizero::env::atari { 18 | 19 | struct EnvInfo { 20 | public: 21 | int seed_; 22 | std::vector sgf_ids_; 23 | AtariEnv env_; 24 | AtariEnvLoader env_loader_; // longest 25 | std::vector end_positions_; 26 | }; 27 | 28 | class ObsRecoverThreadSharedData : public utils::BaseSharedData { 29 | public: 30 | EnvInfo* getAvailableEnvInfoPtr(); 31 | int getEnvIndex(const AtariEnvLoader& env_loader); 32 | EnvInfo* getInitEnvInfoPtr(const int seed); 33 | int getInitEnvInfoPtrIndex(const int seed); 34 | void setSgf(std::string file_path); 35 | void addEnvInfo(AtariEnvLoader env_loader, int line_id); 36 | void addEnvInfoToRemove(EnvInfo* env_info_ptr); 37 | void resetMember(); 38 | std::pair getNextEnvPair(); 39 | 40 | std::ifstream original_file_; 41 | std::ofstream processed_file_; 42 | std::vector sgfs_; 
43 | std::vector::iterator sgfs_it_; 44 | std::mutex mutex_; 45 | std::map> seed_env_info_; 46 | std::map>::iterator seed_env_info_it_; 47 | std::vector::iterator env_info_it_; 48 | std::vector env_info_to_remove_; 49 | }; 50 | 51 | class ObsRecoverSlaveThread : public utils::BaseSlaveThread { 52 | public: 53 | ObsRecoverSlaveThread(int id, std::shared_ptr shared_data) 54 | : BaseSlaveThread(id, shared_data) {} 55 | 56 | void initialize() override {} 57 | void runJob() override; 58 | bool addEnvironmentLoader(); 59 | void handleEndPosition(std::string& sgf, EnvInfo* env_info_ptr); 60 | void recover(); 61 | bool isDone() override { return false; } 62 | 63 | inline std::shared_ptr getSharedData() { return std::static_pointer_cast(shared_data_); } 64 | }; 65 | 66 | class ObsRecover : public utils::BaseParalleler { 67 | public: 68 | ObsRecover() {} 69 | 70 | void run(std::string& obs_file_path); 71 | void initialize() override; 72 | void summarize() override {} 73 | 74 | inline std::shared_ptr getSharedData() { return std::static_pointer_cast(shared_data_); } 75 | 76 | protected: 77 | void removeEnvInfo(EnvInfo* env_info_ptr); 78 | std::vector getAllSgfPath(const std::string& dir_path); 79 | void runSingleSgf(const std::string& path); 80 | 81 | void createSharedData() override { shared_data_ = std::make_shared(); } 82 | std::shared_ptr newSlaveThread(int id) override { return std::make_shared(id, shared_data_); } 83 | }; 84 | 85 | } // namespace minizero::env::atari 86 | -------------------------------------------------------------------------------- /minizero/environment/atari/obs_remover.cpp: -------------------------------------------------------------------------------- 1 | #include "obs_remover.h" 2 | 3 | namespace minizero::env::atari { 4 | 5 | std::string ObsRemoverThreadSharedData::getAvailableSgfPath() 6 | { 7 | std::lock_guard lock(mutex_); 8 | 9 | if (all_sgf_path_it_ == all_sgf_path_.end()) { return ""; } 10 | 11 | std::string sgf_path = *all_sgf_path_it_; 
12 | all_sgf_path_it_++; 13 | 14 | return sgf_path; 15 | } 16 | 17 | void ObsRemoverSlaveThread::runJob() 18 | { 19 | while (removeSingleObs()) {} 20 | } 21 | 22 | bool ObsRemoverSlaveThread::removeSingleObs() 23 | { 24 | const std::string file_path = getSharedData()->getAvailableSgfPath(); 25 | if (file_path.empty()) { return false; } 26 | 27 | std::cout << "Removing obs: " << file_path << std::endl; 28 | std::string remove_obs_file_path = file_path.substr(0, file_path.size() - std::string(".sgf").length()) + "_remove_obs.sgf"; // run() only queues paths that end in ".sgf" 29 | std::ifstream original_file(file_path); 30 | 31 | 32 | if (!original_file.is_open()) { 33 | std::cerr << "Cannot open " << file_path << "!" << std::endl; 34 | exit(-1); 35 | } 36 | std::ofstream processed_file(remove_obs_file_path); // create the output only after the input is known to be readable, so no empty file is left behind 37 | if (!processed_file.is_open()) { 38 | std::cerr << "Cannot create " << remove_obs_file_path << "!" << std::endl; 39 | exit(-1); 40 | } 41 | 42 | std::string line; 43 | while (getline(original_file, line)) { 44 | size_t start = line.find("OBS["); 45 | size_t end = line.find("]", start); 46 | 47 | if (start == std::string::npos || end == std::string::npos) { 48 | std::cerr << "Wrong file format in " << file_path << std::endl; 49 | exit(-1); 50 | } 51 | 52 | line.replace(start, end - start + 1, "OBS[]"); 53 | processed_file << line << std::endl; 54 | } 55 | 56 | original_file.close(); 57 | processed_file.close(); 58 | 59 | return true; 60 | } 61 | 62 | bool ObsRemover::isProperGame(const std::string& path) 63 | { 64 | std::ifstream file(path); 65 | std::string sgf; 66 | getline(file, sgf); 67 | 68 | AtariEnvLoader env_loader; 69 | env_loader.loadFromString(sgf); 70 | std::string game = env_loader.getTag("GM"); 71 | 72 | return game.find("atari") != std::string::npos; // currently, only support for atari games 73 | } 74 | 75 | void ObsRemover::run(std::string& path) 76 | { 77 | const std::string sgf_ext = ".sgf"; const auto is_sgf = [&sgf_ext](const std::string& p) { return p.size() >= sgf_ext.size() && p.compare(p.size() - sgf_ext.size(), sgf_ext.size(), sgf_ext) == 0; }; if (is_sgf(path)) { 78 | // path is a sgf file 79 | 
getSharedData()->all_sgf_path_.push_back(path); 80 | } else { 81 | // path is a directory (NOTE(review): *_remove_obs.sgf outputs of an earlier run are queued again) 82 | for (const auto& entry : std::filesystem::directory_iterator(path)) { 83 | if (!entry.is_regular_file()) { continue; } 84 | const std::string entry_path = entry.path().string(); // distinct name: do not shadow the 'path' parameter 85 | if (!is_sgf(entry_path)) { continue; } // not a sgf file (suffix check, because removeSingleObs strips the last 4 characters) 86 | getSharedData()->all_sgf_path_.push_back(entry_path); 87 | } 88 | } 89 | if (getSharedData()->all_sgf_path_.empty()) { std::cerr << "No sgf file found in " << path << std::endl; exit(-1); } 90 | const std::string& file_path = getSharedData()->all_sgf_path_[0]; 91 | if (!isProperGame(file_path)) { 92 | std::cerr << "Currently, only support remove observation for atari games" << std::endl; 93 | exit(-1); 94 | } 95 | 96 | getSharedData()->all_sgf_path_it_ = getSharedData()->all_sgf_path_.begin(); 97 | for (auto& t : slave_threads_) { t->start(); } 98 | for (auto& t : slave_threads_) { t->finish(); } 99 | } 100 | 101 | void ObsRemover::initialize() 102 | { 103 | std::cout << "Using " << config::zero_num_threads << " threads to remove obs" << std::endl; 104 | createSlaveThreads(config::zero_num_threads); 105 | } 106 | 107 | } // namespace minizero::env::atari 108 | -------------------------------------------------------------------------------- /minizero/environment/atari/obs_remover.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "atari.h" 4 | #include "configuration.h" 5 | #include "paralleler.h" 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | namespace minizero::env::atari { 15 | 16 | class ObsRemoverThreadSharedData : public utils::BaseSharedData { 17 | public: 18 | std::string getAvailableSgfPath(); 19 | 20 | std::vector all_sgf_path_; 21 | std::vector::iterator all_sgf_path_it_; 22 | std::mutex mutex_; 23 | }; 24 | 25 | class ObsRemoverSlaveThread : public utils::BaseSlaveThread { 26 | public: 27 | ObsRemoverSlaveThread(int id, std::shared_ptr shared_data) 28 | : BaseSlaveThread(id, shared_data) {} 29 | 30 | void 
initialize() override {} 31 | void runJob() override; 32 | bool removeSingleObs(); 33 | 34 | bool isDone() override { return false; } 35 | 36 | inline std::shared_ptr getSharedData() { return std::static_pointer_cast(shared_data_); } 37 | }; 38 | 39 | class ObsRemover : public utils::BaseParalleler { 40 | public: 41 | ObsRemover() {} 42 | 43 | bool isProperGame(const std::string& path); 44 | void run(std::string& dir_path); 45 | void initialize() override; 46 | void summarize() override {} 47 | 48 | inline std::shared_ptr getSharedData() { return std::static_pointer_cast(shared_data_); } 49 | 50 | protected: 51 | void createSharedData() override { shared_data_ = std::make_shared(); } 52 | std::shared_ptr newSlaveThread(int id) override { return std::make_shared(id, shared_data_); } 53 | }; 54 | 55 | } // namespace minizero::env::atari 56 | -------------------------------------------------------------------------------- /minizero/environment/base/base_env.cpp: -------------------------------------------------------------------------------- 1 | #include "base_env.h" 2 | 3 | namespace minizero::env { 4 | 5 | char playerToChar(Player p) 6 | { 7 | switch (p) { 8 | case Player::kPlayerNone: return 'N'; 9 | case Player::kPlayer1: return 'B'; 10 | case Player::kPlayer2: return 'W'; 11 | default: return '?'; 12 | } 13 | } 14 | 15 | Player charToPlayer(char c) 16 | { 17 | switch (c) { 18 | case 'N': return Player::kPlayerNone; 19 | case 'B': 20 | case 'b': return Player::kPlayer1; 21 | case 'W': 22 | case 'w': return Player::kPlayer2; 23 | default: return Player::kPlayerSize; 24 | } 25 | } 26 | 27 | Player getNextPlayer(Player player, int num_player) 28 | { 29 | if (num_player == 1) { 30 | return player; 31 | } else if (num_player == 2) { 32 | return (player == Player::kPlayer1 ? 
Player::kPlayer2 : Player::kPlayer1); 33 | } 34 | 35 | return Player::kPlayerNone; 36 | } 37 | 38 | Player getPreviousPlayer(Player player, int num_player) 39 | { 40 | if (num_player <= 2) { return getNextPlayer(player, num_player); } 41 | return Player::kPlayerNone; 42 | } 43 | 44 | } // namespace minizero::env 45 | -------------------------------------------------------------------------------- /minizero/environment/breakthrough/breakthrough.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "base_env.h" 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | namespace minizero::env::breakthrough { 10 | 11 | typedef std::uint64_t BreakthroughBitBoard; 12 | 13 | const std::string kBreakthroughName = "breakthrough"; 14 | const int kBreakthroughMaxBoardSize = 8; // fit 64bits int 15 | const int kBreakthroughMinBoardSize = 5; 16 | const int kBreakthroughNumPlayer = 2; 17 | 18 | extern std::vector kBreakthroughPolicySize; 19 | extern std::vector kBreakthroughIdxToStr; 20 | extern std::vector kBreakthroughIdxToFromIdx; 21 | extern std::vector kBreakthroughIdxToDestIdx; 22 | extern std::unordered_map kBreakthroughStrToIdx; 23 | 24 | void initialize(); 25 | 26 | inline int getMoveIndex(const std::string move) 27 | { 28 | auto it = kBreakthroughStrToIdx.find(move); 29 | if (it == kBreakthroughStrToIdx.end()) { return -1; } 30 | return it->second; 31 | } 32 | 33 | class BreakthroughAction : public BaseBoardAction { 34 | public: 35 | BreakthroughAction() : BaseBoardAction() {} 36 | BreakthroughAction(int action_id, Player player) : BaseBoardAction(action_id, player) {} 37 | BreakthroughAction(const std::vector& action_string_args, int board_size = minizero::config::env_board_size) 38 | { 39 | assert(action_string_args.size() == 2); 40 | assert(action_string_args[0].size() == 1); 41 | player_ = charToPlayer(action_string_args[0][0]); 42 | assert(static_cast(player_) > 0 && static_cast(player_) <= 
kBreakthroughNumPlayer); // assume kPlayer1 == 1, kPlayer2 == 2, ... 43 | action_id_ = getMoveIndex(action_string_args[1]); 44 | } 45 | 46 | std::string toConsoleString() const override { return kBreakthroughIdxToStr[getActionID()]; } 47 | int getFromID(int board_size = minizero::config::env_board_size) const; 48 | int getDestID(int board_size = minizero::config::env_board_size) const; 49 | }; 50 | 51 | class BreakthroughEnv : public BaseBoardEnv { 52 | public: 53 | BreakthroughEnv() { reset(); } 54 | void reset() override; 55 | bool act(const BreakthroughAction& action) override; 56 | bool act(const std::vector& action_string_args) override; 57 | std::vector getLegalActions() const override; 58 | bool isLegalAction(const BreakthroughAction& action) const override; 59 | bool isTerminal() const override; 60 | float getReward() const override { return 0.0f; } 61 | float getEvalScore(bool is_resign = false) const override; 62 | std::vector getFeatures(utils::Rotation rotation = utils::Rotation::kRotationNone) const override; 63 | std::vector getActionFeatures(const BreakthroughAction& action, utils::Rotation rotation = utils::Rotation::kRotationNone) const override; 64 | inline int getNumInputChannels() const override { return 20; } 65 | inline int getPolicySize() const override { return kBreakthroughPolicySize[board_size_]; } 66 | std::string toString() const override; 67 | inline std::string name() const override { return kBreakthroughName + "_" + std::to_string(getBoardSize()) + "x" + std::to_string(getBoardSize()); } 68 | inline int getNumPlayer() const override { return kBreakthroughNumPlayer; } 69 | inline int getRotatePosition(int position, utils::Rotation rotation) const override { return position; } 70 | inline int getRotateAction(int action_id, utils::Rotation rotation) const override { return action_id; } 71 | 72 | private: 73 | BreakthroughBitBoard bitboard_rank1_; 74 | BreakthroughBitBoard bitboard_rank2_; 75 | BreakthroughBitBoard 
bitboard_reverse_rank2_; 76 | BreakthroughBitBoard bitboard_reverse_rank1_; 77 | 78 | bool isThreatPosition(Player color, int position) const; 79 | BreakthroughBitBoard getThreatSpace(Player color) const; 80 | Player getPlayerAtBoardPos(int pos) const; 81 | 82 | GamePair bitboard_; 83 | std::vector> bitboard_history_; 84 | }; 85 | 86 | class BreakthroughEnvLoader : public BaseBoardEnvLoader { 87 | public: 88 | std::vector getActionFeatures(const int pos, utils::Rotation rotation = utils::Rotation::kRotationNone) const override; 89 | inline std::vector getValue(const int pos) const { return {getReturn()}; } 90 | inline std::string name() const override { return kBreakthroughName + "_" + std::to_string(getBoardSize()) + "x" + std::to_string(getBoardSize()); } 91 | inline int getPolicySize() const override { return kBreakthroughPolicySize[board_size_]; } 92 | inline int getRotatePosition(int position, utils::Rotation rotation) const override { return position; } 93 | inline int getRotateAction(int action_id, utils::Rotation rotation) const override { return action_id; } 94 | }; 95 | 96 | } // namespace minizero::env::breakthrough 97 | -------------------------------------------------------------------------------- /minizero/environment/clobber/clobber.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "base_env.h" 4 | #include "configuration.h" 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | namespace minizero::env::clobber { 11 | 12 | const std::string kClobberName = "clobber"; 13 | const int kClobberNumPlayer = 2; 14 | const int kClobberMinBoardSize = 6; 15 | const int kClobberMaxBoardSize = 10; 16 | const int kNumDirections = 4; 17 | 18 | typedef std::bitset ClobberBitboard; 19 | 20 | class ClobberAction : public BaseAction { 21 | public: 22 | ClobberAction() : BaseAction() {} 23 | ClobberAction(int action_id, Player player) : BaseAction(action_id, player) {} 24 | ClobberAction(const 
std::vector& action_string_args) : BaseAction() 25 | { 26 | action_id_ = actionStringToID(action_string_args); 27 | player_ = charToPlayer(action_string_args[0][0]); 28 | } 29 | inline Player nextPlayer() const override { return getNextPlayer(getPlayer(), kClobberNumPlayer); } 30 | inline std::string toConsoleString() const override { return actionIDtoString(action_id_); } 31 | inline int getFromPos() const { return getFromPos(action_id_); } 32 | inline int getDestPos() const { return getDestPos(action_id_); } 33 | 34 | private: 35 | int board_size_ = minizero::config::env_board_size; 36 | 37 | int coordinateToID(int c1, int r1, int c2, int r2) const; 38 | int charToPos(char c) const; 39 | int getFromPos(int action_id) const; 40 | int getDestPos(int action_id) const; 41 | int actionStringToID(const std::vector& action_string_args) const; 42 | std::string actionIDtoString(int action_id) const; 43 | }; 44 | 45 | class ClobberEnv : public BaseBoardEnv { 46 | public: 47 | ClobberEnv() 48 | { 49 | assert(getBoardSize() <= kClobberMaxBoardSize && getBoardSize() >= kClobberMinBoardSize); 50 | reset(); 51 | } 52 | 53 | void reset() override; 54 | bool act(const ClobberAction& action) override; 55 | bool act(const std::vector& action_string_args) override; 56 | std::vector getLegalActions() const override; 57 | bool isLegalAction(const ClobberAction& action) const override; 58 | bool isTerminal() const override; 59 | float getReward() const override { return 0.0f; } 60 | 61 | float getEvalScore(bool is_resign = false) const override; 62 | std::vector getFeatures(utils::Rotation rotation = utils::Rotation::kRotationNone) const override; 63 | std::vector getActionFeatures(const ClobberAction& action, utils::Rotation rotation = utils::Rotation::kRotationNone) const override; 64 | inline int getNumInputChannels() const override { return 18; } 65 | inline int getNumActionFeatureChannels() const override { return 0; } 66 | inline int getInputChannelHeight() const override {
return getBoardSize(); } 67 | inline int getInputChannelWidth() const override { return getBoardSize(); } 68 | inline int getHiddenChannelHeight() const override { return getBoardSize(); } 69 | inline int getHiddenChannelWidth() const override { return getBoardSize(); } 70 | inline int getPolicySize() const override { return kNumDirections * getBoardSize() * getBoardSize(); } 71 | std::string toString() const override; 72 | inline std::string name() const override { return kClobberName; } 73 | inline int getNumPlayer() const override { return kClobberNumPlayer; } 74 | 75 | inline int getRotatePosition(int position, utils::Rotation rotation) const override { return position; }; 76 | inline int getRotateAction(int action_id, utils::Rotation rotation) const override { return action_id; }; 77 | 78 | private: 79 | Player eval() const; 80 | std::string getCoordinateString() const; 81 | Player getPlayerAtBoardPos(int position) const; 82 | 83 | GamePair bitboard_; 84 | std::vector> bitboard_history_; 85 | }; 86 | 87 | class ClobberEnvLoader : public BaseBoardEnvLoader { 88 | public: 89 | std::vector getActionFeatures(const int pos, utils::Rotation rotation = utils::Rotation::kRotationNone) const override; 90 | inline std::vector getValue(const int pos) const { return {getReturn()}; } 91 | inline std::string name() const override { return kClobberName; } 92 | inline int getPolicySize() const override { return kNumDirections * getBoardSize() * getBoardSize(); } 93 | inline int getRotatePosition(int position, utils::Rotation rotation) const override { return position; }; 94 | inline int getRotateAction(int action_id, utils::Rotation rotation) const override { return action_id; }; 95 | }; 96 | 97 | } // namespace minizero::env::clobber 98 | -------------------------------------------------------------------------------- /minizero/environment/conhex/conhex.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 
"base_env.h" 4 | #include "configuration.h" 5 | #include "conhex_graph.h" 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | namespace minizero::env::conhex { 12 | 13 | const std::string kConHexName = "conhex"; 14 | const int kConHexNumPlayer = 2; 15 | 16 | typedef BaseBoardAction ConHexAction; 17 | 18 | class ConHexEnv : public BaseBoardEnv { 19 | public: 20 | ConHexEnv(); 21 | 22 | void reset() override; 23 | bool act(const ConHexAction& action) override; 24 | bool act(const std::vector& action_string_args) override; 25 | std::vector getLegalActions() const override; 26 | bool isLegalAction(const ConHexAction& action) const override; 27 | bool isTerminal() const override; 28 | float getReward() const override { return 0.0f; } 29 | float getEvalScore(bool is_resign = false) const override; 30 | std::vector getFeatures(utils::Rotation rotation = utils::Rotation::kRotationNone) const override; 31 | std::vector getActionFeatures(const ConHexAction& action, utils::Rotation rotation = utils::Rotation::kRotationNone) const override; 32 | inline int getNumInputChannels() const override { return 6; } 33 | inline int getPolicySize() const override { return getBoardSize() * getBoardSize(); } 34 | std::string toString() const override; 35 | inline std::string name() const override { return kConHexName; } 36 | inline int getNumPlayer() const override { return kConHexNumPlayer; } 37 | inline int getRotatePosition(int position, utils::Rotation rotation) const override { return position; } 38 | inline int getRotateAction(int action_id, utils::Rotation rotation) const override { return action_id; } 39 | 40 | private: 41 | bool isPlaceable(int table_id) const; 42 | 43 | ConHexGraph conhex_graph_; 44 | ConHexBitboard invalid_actions_; 45 | }; 46 | 47 | class ConHexEnvLoader : public BaseBoardEnvLoader { 48 | public: 49 | std::vector getActionFeatures(const int pos, utils::Rotation rotation = utils::Rotation::kRotationNone) const override; 50 | inline std::vector 
getValue(const int pos) const { return {getReturn()}; } 51 | inline std::string name() const override { return kConHexName; } 52 | inline int getPolicySize() const override { return getBoardSize() * getBoardSize(); } 53 | inline int getRotatePosition(int position, utils::Rotation rotation) const override { return position; } 54 | inline int getRotateAction(int action_id, utils::Rotation rotation) const override { return action_id; } 55 | }; 56 | 57 | } // namespace minizero::env::conhex 58 | -------------------------------------------------------------------------------- /minizero/environment/conhex/conhex_graph.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "base_env.h" 4 | #include "conhex_graph_cell.h" 5 | #include "conhex_graph_flag.h" 6 | #include "disjoint_set_union.h" 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | namespace minizero::env::conhex { 13 | 14 | class ConHexGraph { 15 | public: 16 | ConHexGraph(); 17 | 18 | bool isCellCapturedByPlayer(int cell_id, Player player) const; 19 | void placeStone(int hole_idx, Player player); 20 | Player getPlayerAtPos(int hole_idx) const; 21 | inline Player checkWinner() const { return winner_; } 22 | void reset(); 23 | std::string toString() const; 24 | 25 | private: 26 | void initGraph(); 27 | void addCell(std::vector hole_indexes, ConHexGraphEdgeFlag cell_edge_flag); 28 | 29 | DisjointSetUnion graph_dsu_; 30 | std::vector> hole_to_cell_map_; // hole_idx* -> cell_id, on same hole id may have many cell 31 | std::vector> cell_adjacency_list_; // cell_id -> cell_id* , adj list 32 | 33 | std::vector cells_; 34 | std::vector holes_; 35 | Player winner_; 36 | 37 | static const int top_id_ = kConHexBoardSize * kConHexBoardSize; 38 | static const int left_id_ = kConHexBoardSize * kConHexBoardSize + 1; 39 | static const int right_id_ = kConHexBoardSize * kConHexBoardSize + 2; 40 | static const int bottom_id_ = kConHexBoardSize * 
kConHexBoardSize + 3; 41 | }; 42 | 43 | } // namespace minizero::env::conhex 44 | -------------------------------------------------------------------------------- /minizero/environment/conhex/conhex_graph_cell.cpp: -------------------------------------------------------------------------------- 1 | #include "conhex_graph_cell.h" 2 | 3 | namespace minizero::env::conhex { 4 | 5 | using namespace minizero::utils; 6 | 7 | ConHexGraphCell::ConHexGraphCell(int cell_id, ConHexGraphCellType cell_type) 8 | { 9 | cell_type_ = cell_type; 10 | cell_id_ = cell_id; 11 | capture_player_ = Player::kPlayerNone; // default player 12 | } 13 | 14 | void ConHexGraphCell::placeStone(int hole, Player player) 15 | { 16 | holes_.get(player).set(hole); 17 | ++captured_count_.get(player); 18 | 19 | if (capture_player_ != Player::kPlayerNone) { return; } // if already captured early return 20 | 21 | if ((cell_type_ == ConHexGraphCellType::OUTER && captured_count_.get(player) == 2) || 22 | (cell_type_ == ConHexGraphCellType::INNER && captured_count_.get(player) == 3) || 23 | (cell_type_ == ConHexGraphCellType::CENTER && captured_count_.get(player) == 3)) { 24 | capture_player_ = player; 25 | } 26 | } 27 | 28 | void ConHexGraphCell::setEdgeFlag(ConHexGraphEdgeFlag cell_edge_flag) 29 | { 30 | edge_flag_ = cell_edge_flag; 31 | } 32 | 33 | bool ConHexGraphCell::isEdgeFlag(ConHexGraphEdgeFlag edge_flag) 34 | { 35 | return static_cast(edge_flag_ & edge_flag); 36 | } 37 | 38 | int ConHexGraphCell::getCellId() 39 | { 40 | return cell_id_; 41 | } 42 | 43 | void ConHexGraphCell::reset() 44 | { 45 | holes_.get(Player::kPlayer1).reset(); 46 | holes_.get(Player::kPlayer2).reset(); 47 | captured_count_.get(Player::kPlayer1) = 0; 48 | captured_count_.get(Player::kPlayer2) = 0; 49 | capture_player_ = Player::kPlayerNone; 50 | } 51 | 52 | Player ConHexGraphCell::getCapturedPlayer() const 53 | { 54 | return capture_player_; 55 | } 56 | 57 | } // namespace minizero::env::conhex 58 | 
-------------------------------------------------------------------------------- /minizero/environment/conhex/conhex_graph_cell.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "base_env.h" 4 | #include "conhex_graph_flag.h" 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | namespace minizero::env::conhex { 12 | 13 | const int kConHexBoardSize = 9; 14 | typedef std::bitset ConHexBitboard; 15 | 16 | class ConHexGraphCell { 17 | public: 18 | ConHexGraphCell() = default; 19 | ConHexGraphCell(int cell_id, ConHexGraphCellType cell_type); 20 | 21 | int getCellId(); 22 | 23 | Player getCapturedPlayer() const; 24 | void placeStone(int hole_id, Player player); 25 | void setEdgeFlag(ConHexGraphEdgeFlag cell_edge_flag); 26 | bool isEdgeFlag(ConHexGraphEdgeFlag cell_edge_flag); 27 | void reset(); 28 | 29 | private: 30 | ConHexGraphEdgeFlag edge_flag_ = ConHexGraphEdgeFlag::NONE; // stays NONE until setEdgeFlag() is called 31 | ConHexGraphCellType cell_type_ = ConHexGraphCellType::NONE; 32 | Player capture_player_ = Player::kPlayerNone; 33 | GamePair holes_; 34 | GamePair captured_count_; 35 | int cell_id_ = -1; // -1 until assigned by the (id, type) constructor 36 | }; 37 | 38 | } // namespace minizero::env::conhex 39 | -------------------------------------------------------------------------------- /minizero/environment/conhex/conhex_graph_flag.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace minizero::env::conhex { 9 | 10 | enum class ConHexGraphEdgeFlag { 11 | NONE = 0x0, 12 | TOP = 0x1, 13 | RIGHT = 0x2, 14 | LEFT = 0x4, 15 | BOTTOM = 0x8, 16 | }; 17 | 18 | inline ConHexGraphEdgeFlag operator|(ConHexGraphEdgeFlag a, ConHexGraphEdgeFlag b) 19 | { 20 | return static_cast(static_cast(a) | static_cast(b)); 21 | } 22 | 23 | inline ConHexGraphEdgeFlag operator&(ConHexGraphEdgeFlag a, ConHexGraphEdgeFlag b) 24 | { 25 | return static_cast(static_cast(a) & static_cast(b)); 26 | } 27 | 28 | enum class ConHexGraphCellType { 29 | NONE = -1, 30
| OUTER = 3, 31 | INNER = 6, 32 | CENTER = 5, 33 | }; 34 | 35 | inline ConHexGraphCellType operator|(ConHexGraphCellType a, ConHexGraphCellType b) 36 | { 37 | return static_cast(static_cast(a) | static_cast(b)); 38 | } 39 | 40 | inline ConHexGraphCellType operator&(ConHexGraphCellType a, ConHexGraphCellType b) 41 | { 42 | return static_cast(static_cast(a) & static_cast(b)); 43 | } 44 | 45 | inline bool operator==(int a, ConHexGraphCellType b) 46 | { 47 | return a == static_cast(b); 48 | } 49 | 50 | inline bool operator==(ConHexGraphCellType a, int b) 51 | { 52 | return static_cast(a) == b; 53 | } 54 | 55 | } // namespace minizero::env::conhex 56 | -------------------------------------------------------------------------------- /minizero/environment/conhex/disjoint_set_union.cpp: -------------------------------------------------------------------------------- 1 | #include "disjoint_set_union.h" 2 | 3 | namespace minizero::env::conhex { 4 | 5 | using namespace minizero::utils; 6 | 7 | DisjointSetUnion::DisjointSetUnion(int size) 8 | { 9 | size_ = size; 10 | reset(); 11 | } 12 | 13 | void DisjointSetUnion::reset() 14 | { 15 | set_size_.resize(size_ + 4, 1); // +4 stands for top/left/right/bottom 16 | parent_.resize(size_ + 4); // +4 stands for top/left/right/bottom 17 | 18 | // DSU reset: every node is its own root and a singleton set 19 | for (int i = 0; i < size_ + 4; ++i) { 20 | parent_[i] = i; 21 | set_size_[i] = 1; 22 | } 23 | } 24 | 25 | int DisjointSetUnion::find(int index) // returns the root of index's set, compressing the path 26 | { 27 | if (parent_[index] == index) { return index; } 28 | return parent_[index] = find(parent_[index]); 29 | } 30 | 31 | void DisjointSetUnion::connect(int from_cell_id, int to_cell_id) 32 | { 33 | // same as Union in DSU (union by size; callers should compare find() results, not assume which root survives) 34 | int fa = find(from_cell_id), fb = find(to_cell_id); 35 | if (fa == fb) { return; } // already same 36 | if (set_size_[fa] < set_size_[fb]) { std::swap(fa, fb); } // make fa the larger root 37 | parent_[fb] = fa; // hang the smaller tree under the larger root 38 | set_size_[fa] += set_size_[fb]; 39 | } 40 | 41 | } // namespace minizero::env::conhex 42 |
-------------------------------------------------------------------------------- /minizero/environment/conhex/disjoint_set_union.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "base_env.h" 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | namespace minizero::env::conhex { 10 | 11 | class DisjointSetUnion { 12 | public: 13 | DisjointSetUnion(int size); 14 | int find(int index); // DSU 15 | void connect(int from_cell_id, int to_cell_id); // DSU 16 | void reset(); 17 | 18 | private: 19 | std::vector parent_; 20 | std::vector set_size_; 21 | int size_; 22 | }; 23 | 24 | } // namespace minizero::env::conhex 25 | -------------------------------------------------------------------------------- /minizero/environment/connect6/connect6.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "base_env.h" 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | namespace minizero::env::connect6 { 10 | 11 | const std::string kConnect6Name = "connect6"; 12 | const int kConnect6NumPlayer = 2; 13 | const int kConnect6NumWinConnectStone = 6; 14 | const int kMaxConnect6BoardSize = 19; 15 | 16 | typedef std::bitset Connect6Bitboard; 17 | 18 | class Connect6Action : public BaseBoardAction { 19 | public: 20 | Connect6Action() : BaseBoardAction() {} 21 | Connect6Action(int action_id, Player player) : BaseBoardAction(action_id, player) {} 22 | Connect6Action(const std::vector& action_string_args) : BaseBoardAction(action_string_args) {} 23 | inline Player nextPlayer() const { throw std::runtime_error{"MuZero does not support this game"}; } 24 | inline Player nextPlayer(int move_id) const { return move_id > 0 && move_id % 2 == 0 ? 
getPlayer() : BaseBoardAction::nextPlayer(); } 25 | }; 26 | 27 | class Connect6Env : public BaseBoardEnv { 28 | public: 29 | Connect6Env() 30 | { 31 | assert(getBoardSize() <= kMaxConnect6BoardSize); 32 | reset(); 33 | } 34 | 35 | void reset() override; 36 | bool act(const Connect6Action& action) override; 37 | bool act(const std::vector& action_string_args) override; 38 | std::vector getLegalActions() const override; 39 | bool isLegalAction(const Connect6Action& action) const override; 40 | bool isTerminal() const override; 41 | float getReward() const override { return 0.0f; } 42 | float getEvalScore(bool is_resign = false) const override; 43 | std::vector getFeatures(utils::Rotation rotation = utils::Rotation::kRotationNone) const override; 44 | std::vector getActionFeatures(const Connect6Action& action, utils::Rotation rotation = utils::Rotation::kRotationNone) const override; 45 | inline int getNumInputChannels() const override { return 24; } 46 | inline int getPolicySize() const override { return getBoardSize() * getBoardSize(); } 47 | std::string toString() const override; 48 | inline std::string name() const override { return kConnect6Name; } 49 | inline int getNumPlayer() const override { return kConnect6NumPlayer; } 50 | inline int getRotatePosition(int position, utils::Rotation rotation) const override { return utils::getPositionByRotating(rotation, position, getBoardSize()); }; 51 | inline int getRotateAction(int action_id, utils::Rotation rotation) const override { return getRotatePosition(action_id, rotation); }; 52 | 53 | private: 54 | Connect6Bitboard scanThreadSpace(Player p, int target) const; 55 | 56 | Player updateWinner(const Connect6Action& action); 57 | int calculateNumberOfConnection(const Connect6Action& action, std::pair direction); 58 | std::string getCoordinateString() const; 59 | Player getPlayerAtBoardPos(int position) const; 60 | 61 | Player winner_; 62 | GamePair bitboard_; 63 | std::vector> bitboard_history_; 64 | }; 65 | 66 | class 
Connect6EnvLoader : public BaseBoardEnvLoader { 67 | public: 68 | std::vector getActionFeatures(const int pos, utils::Rotation rotation = utils::Rotation::kRotationNone) const override; 69 | inline std::vector getValue(const int pos) const { return {getReturn()}; } 70 | inline std::string name() const override { return kConnect6Name; } 71 | inline int getPolicySize() const override { return getBoardSize() * getBoardSize(); } 72 | inline int getRotatePosition(int position, utils::Rotation rotation) const override { return utils::getPositionByRotating(rotation, position, getBoardSize()); }; 73 | inline int getRotateAction(int action_id, utils::Rotation rotation) const override { return getRotatePosition(action_id, rotation); }; 74 | }; 75 | 76 | } // namespace minizero::env::connect6 77 | -------------------------------------------------------------------------------- /minizero/environment/go/go_area.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "go_unit.h" 4 | 5 | namespace minizero::env::go { 6 | 7 | class GoArea { 8 | public: 9 | GoArea(int id) 10 | : id_(id) 11 | { 12 | reset(); 13 | } 14 | 15 | inline void reset() 16 | { 17 | num_grid_ = 0; 18 | player_ = Player::kPlayerNone; 19 | area_bitboard_.reset(); 20 | neighbor_block_id_bitboard_.reset(); 21 | } 22 | 23 | inline void combineWithArea(GoArea* area) 24 | { 25 | assert(area && player_ == area->getPlayer()); 26 | num_grid_ += area->getNumGrid(); 27 | area_bitboard_ |= area->getAreaBitboard(); 28 | neighbor_block_id_bitboard_ |= area->getNeighborBlockIDBitboard(); 29 | } 30 | 31 | // setter 32 | inline void setNumGrid(int num_grid) { num_grid_ = num_grid; } 33 | inline void setPlayer(Player p) { player_ = p; } 34 | inline void setAreaBitBoard(const GoBitboard& area_bitboard) { area_bitboard_ = area_bitboard; } 35 | inline void addNeighborBlockIDBitboard(int block_id) { neighbor_block_id_bitboard_.set(block_id); } 36 | inline void 
removeNeighborBlockIDBitboard(int block_id) { neighbor_block_id_bitboard_.reset(block_id); } 37 | 38 | // getter 39 | inline int getID() const { return id_; } 40 | inline int getNumGrid() const { return num_grid_; } 41 | inline Player getPlayer() const { return player_; } 42 | inline GoBitboard& getAreaBitboard() { return area_bitboard_; } 43 | inline const GoBitboard& getAreaBitboard() const { return area_bitboard_; } 44 | inline GoBitboard& getNeighborBlockIDBitboard() { return neighbor_block_id_bitboard_; } 45 | inline const GoBitboard& getNeighborBlockIDBitboard() const { return neighbor_block_id_bitboard_; } 46 | 47 | private: 48 | int id_; 49 | int num_grid_; 50 | Player player_; 51 | GoBitboard area_bitboard_; 52 | GoBitboard neighbor_block_id_bitboard_; 53 | }; 54 | 55 | } // namespace minizero::env::go 56 | -------------------------------------------------------------------------------- /minizero/environment/go/go_block.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "go_unit.h" 4 | 5 | namespace minizero::env::go { 6 | 7 | class GoBlock { 8 | public: 9 | GoBlock(int id) 10 | : id_(id) 11 | { 12 | reset(); 13 | } 14 | 15 | inline void reset() 16 | { 17 | num_grid_ = 0; 18 | num_liberty_ = 0; 19 | player_ = Player::kPlayerNone; 20 | hash_key_ = 0; 21 | grid_bitboard_.reset(); 22 | liberty_bitboard_.reset(); 23 | neighbor_area_id_bitboard_.reset(); 24 | } 25 | 26 | inline void combineWithBlock(GoBlock* block) 27 | { 28 | assert(block); 29 | addHashKey(block->getHashKey()); 30 | grid_bitboard_ |= block->getGridBitboard(); 31 | num_grid_ += block->getNumGrid(); 32 | liberty_bitboard_ |= block->getLibertyBitboard(); 33 | num_liberty_ = liberty_bitboard_.count(); 34 | neighbor_area_id_bitboard_ |= block->getNeighborAreaIDBitboard(); 35 | } 36 | 37 | // setter 38 | inline void setPlayer(Player p) { player_ = p; } 39 | inline void addHashKey(GoHashKey key) { hash_key_ ^= key; } 40 | inline void 
addGrid(int pos) 41 | { 42 | assert(!grid_bitboard_.test(pos)); 43 | ++num_grid_; 44 | grid_bitboard_.set(pos); 45 | } 46 | inline void addLiberty(int pos) 47 | { 48 | if (liberty_bitboard_.test(pos)) { return; } 49 | liberty_bitboard_.set(pos); 50 | ++num_liberty_; 51 | } 52 | inline void removeLiberty(int pos) 53 | { 54 | if (!liberty_bitboard_.test(pos)) { return; } 55 | liberty_bitboard_.reset(pos); 56 | --num_liberty_; 57 | } 58 | inline void addNeighborAreaIDBitboard(int area_id) { neighbor_area_id_bitboard_.set(area_id); } 59 | inline void removeNeighborAreaIDBitboard(int area_id) { neighbor_area_id_bitboard_.reset(area_id); } 60 | 61 | // getter 62 | inline int getID() const { return id_; } 63 | inline int getNumGrid() const { return num_grid_; } 64 | inline int getNumLiberty() const { return num_liberty_; } 65 | inline Player getPlayer() const { return player_; } 66 | inline GoHashKey getHashKey() const { return hash_key_; } 67 | inline GoBitboard& getGridBitboard() { return grid_bitboard_; } 68 | inline const GoBitboard& getGridBitboard() const { return grid_bitboard_; } 69 | inline GoBitboard& getLibertyBitboard() { return liberty_bitboard_; } 70 | inline const GoBitboard& getLibertyBitboard() const { return liberty_bitboard_; } 71 | inline GoBitboard& getNeighborAreaIDBitboard() { return neighbor_area_id_bitboard_; } 72 | inline const GoBitboard& getNeighborAreaIDBitboard() const { return neighbor_area_id_bitboard_; } 73 | 74 | private: 75 | int id_; 76 | int num_grid_; 77 | int num_liberty_; 78 | Player player_; 79 | GoHashKey hash_key_; 80 | GoBitboard grid_bitboard_; 81 | GoBitboard liberty_bitboard_; 82 | GoBitboard neighbor_area_id_bitboard_; 83 | }; 84 | 85 | } // namespace minizero::env::go 86 | -------------------------------------------------------------------------------- /minizero/environment/go/go_grid.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "go_area.h" 4 | 
#include "go_block.h" 5 | #include "go_unit.h" 6 | #include 7 | 8 | namespace minizero::env::go { 9 | 10 | class GoGrid { 11 | public: 12 | GoGrid(int position, int board_size) 13 | : position_(position) 14 | { 15 | reset(board_size); 16 | } 17 | 18 | inline void reset(int board_size) 19 | { 20 | player_ = Player::kPlayerNone; 21 | block_ = nullptr; 22 | area_pair_ = GamePair(nullptr, nullptr); 23 | initializeNeighbors(board_size); 24 | } 25 | 26 | // setter 27 | inline void setPlayer(Player p) { player_ = p; } 28 | inline void setArea(Player p, GoArea* a) { area_pair_.set(p, a); } 29 | inline void setBlock(GoBlock* b) { block_ = b; } 30 | 31 | // getter 32 | inline Player getPlayer() const { return player_; } 33 | inline int getPosition() const { return position_; } 34 | inline GoArea* getArea(Player p) { return area_pair_.get(p); } 35 | inline const GoArea* getArea(Player p) const { return area_pair_.get(p); } 36 | inline GamePair& getAreaPair() { return area_pair_; } 37 | inline const GamePair& getAreaPair() const { return area_pair_; } 38 | inline GoBlock* getBlock() { return block_; } 39 | inline const GoBlock* getBlock() const { return block_; } 40 | inline const std::vector& getNeighbors() const { return neighbors_; } 41 | 42 | private: 43 | void initializeNeighbors(int board_size) 44 | { 45 | const std::vector directions = {0, 1, 0, -1}; 46 | int x = position_ % board_size, y = position_ / board_size; 47 | neighbors_.clear(); 48 | for (size_t i = 0; i < directions.size(); ++i) { 49 | int new_x = x + directions[i]; 50 | int new_y = y + directions[(i + 1) % directions.size()]; 51 | if (!isInBoard(new_x, new_y, board_size)) { continue; } 52 | neighbors_.push_back(new_y * board_size + new_x); 53 | } 54 | } 55 | 56 | inline bool isInBoard(int x, int y, int board_size) 57 | { 58 | return (x >= 0 && x < board_size && y >= 0 && y < board_size); 59 | } 60 | 61 | int position_; 62 | Player player_; 63 | GoBlock* block_; 64 | GamePair area_pair_; 65 | std::vector 
neighbors_; 66 | }; 67 | 68 | } // namespace minizero::env::go 69 | -------------------------------------------------------------------------------- /minizero/environment/go/go_unit.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace minizero::env::go { 8 | 9 | const std::string kGoName = "go"; 10 | const int kGoNumPlayer = 2; 11 | const int kMaxGoBoardSize = 19; 12 | 13 | typedef uint64_t GoHashKey; 14 | typedef std::bitset GoBitboard; 15 | 16 | } // namespace minizero::env::go 17 | -------------------------------------------------------------------------------- /minizero/environment/gomoku/gomoku.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "base_env.h" 4 | #include "configuration.h" 5 | #include 6 | #include 7 | #include 8 | 9 | namespace minizero::env::gomoku { 10 | 11 | const std::string kGomokuName = "gomoku"; 12 | const int kGomokuNumPlayer = 2; 13 | const int kMaxGomokuBoardSize = 19; 14 | 15 | typedef BaseBoardAction GomokuAction; 16 | 17 | class GomokuEnv : public BaseBoardEnv { 18 | public: 19 | GomokuEnv() 20 | { 21 | assert(getBoardSize() <= kMaxGomokuBoardSize); 22 | reset(); 23 | } 24 | 25 | void reset() override; 26 | bool act(const GomokuAction& action) override; 27 | bool act(const std::vector& action_string_args) override; 28 | std::vector getLegalActions() const override; 29 | bool isLegalAction(const GomokuAction& action) const override; 30 | bool isTerminal() const override; 31 | float getReward() const override { return 0.0f; } 32 | float getEvalScore(bool is_resign = false) const override; 33 | std::vector getFeatures(utils::Rotation rotation = utils::Rotation::kRotationNone) const override; 34 | std::vector getActionFeatures(const GomokuAction& action, utils::Rotation rotation = utils::Rotation::kRotationNone) const override; 35 | inline int 
getNumInputChannels() const override { return 4; } 36 | inline int getPolicySize() const override { return getBoardSize() * getBoardSize(); } 37 | std::string toString() const override; 38 | inline std::string name() const override { return kGomokuName + (config::env_gomoku_rule == "outer_open" ? "_oo_" : "_") + std::to_string(getBoardSize()) + "x" + std::to_string(getBoardSize()); } 39 | inline int getNumPlayer() const override { return kGomokuNumPlayer; } 40 | 41 | inline int getRotatePosition(int position, utils::Rotation rotation) const override { return utils::getPositionByRotating(rotation, position, getBoardSize()); }; 42 | inline int getRotateAction(int action_id, utils::Rotation rotation) const override { return getRotatePosition(action_id, rotation); }; 43 | 44 | private: 45 | Player updateWinner(const GomokuAction& action); 46 | bool isNumberOfConnectionWins(int connection) { return config::env_gomoku_exactly_five_stones ? (connection == 5) : (connection >= 5); } 47 | int calculateNumberOfConnection(int start_pos, std::pair direction); 48 | std::string getCoordinateString() const; 49 | 50 | Player winner_; 51 | std::vector board_; 52 | }; 53 | 54 | class GomokuEnvLoader : public BaseBoardEnvLoader { 55 | public: 56 | std::vector getActionFeatures(const int pos, utils::Rotation rotation = utils::Rotation::kRotationNone) const override; 57 | inline std::vector getValue(const int pos) const { return {getReturn()}; } 58 | inline std::string name() const override { return kGomokuName + (config::env_gomoku_rule == "outer_open" ? 
"_oo_" : "_") + std::to_string(getBoardSize()) + "x" + std::to_string(getBoardSize()); } 59 | inline int getPolicySize() const override { return getBoardSize() * getBoardSize(); } 60 | inline int getRotatePosition(int position, utils::Rotation rotation) const override { return utils::getPositionByRotating(rotation, position, getBoardSize()); }; 61 | inline int getRotateAction(int action_id, utils::Rotation rotation) const override { return getRotatePosition(action_id, rotation); }; 62 | }; 63 | 64 | } // namespace minizero::env::gomoku 65 | -------------------------------------------------------------------------------- /minizero/environment/hex/hex.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "base_env.h" 4 | #include "configuration.h" 5 | #include 6 | #include 7 | 8 | namespace minizero::env::hex { 9 | 10 | const std::string kHexName = "hex"; 11 | const int kHexNumPlayer = 2; 12 | const int kMaxHexBoardSize = 19; 13 | 14 | typedef BaseBoardAction HexAction; 15 | 16 | enum class Flag { 17 | NONE = 0x0, 18 | EDGE1_CONNECTION = 0x1, // edge1 represents left and bottom for Black and White players, respectively 19 | EDGE2_CONNECTION = 0x2, // edge2 represents right and top for Black and White players, respectively 20 | }; 21 | inline Flag operator|(Flag a, Flag b) 22 | { 23 | return static_cast(static_cast(a) | static_cast(b)); 24 | } 25 | inline Flag operator&(Flag a, Flag b) 26 | { 27 | return static_cast(static_cast(a) & static_cast(b)); 28 | } 29 | struct Cell { 30 | Player player{}; 31 | Flag flags; 32 | }; 33 | 34 | class HexEnv : public BaseBoardEnv { 35 | public: 36 | HexEnv() 37 | { 38 | assert(getBoardSize() <= kMaxHexBoardSize); 39 | reset(); 40 | } 41 | 42 | void reset() override; 43 | bool act(const HexAction& action) override; 44 | bool act(const std::vector& action_string_args) override; 45 | std::vector getLegalActions() const override; 46 | bool isLegalAction(const HexAction& 
action) const override; 47 | bool isTerminal() const override; 48 | float getReward() const override { return 0.0f; } 49 | float getEvalScore(bool is_resign = false) const override; 50 | std::vector getFeatures(utils::Rotation rotation = utils::Rotation::kRotationNone) const override; 51 | std::vector getActionFeatures(const HexAction& action, utils::Rotation rotation = utils::Rotation::kRotationNone) const override; 52 | inline int getNumInputChannels() const override { return 4; } 53 | inline int getPolicySize() const override { return getBoardSize() * getBoardSize(); } 54 | std::string toString() const override; 55 | std::string toStringDebug() const; 56 | inline std::string name() const override { return kHexName + "_" + std::to_string(getBoardSize()) + "x" + std::to_string(getBoardSize()); } 57 | inline int getNumPlayer() const override { return kHexNumPlayer; } 58 | inline Player getWinner() const { return winner_; } 59 | inline const std::vector& getBoard() const { return board_; } 60 | std::vector getWinningStonesPosition() const; 61 | inline int getRotatePosition(int position, utils::Rotation rotation) const override { return position; } 62 | inline int getRotateAction(int action_id, utils::Rotation rotation) const override { return action_id; } 63 | 64 | private: 65 | Player updateWinner(int actionID); 66 | 67 | Player winner_; 68 | std::vector board_; 69 | }; 70 | 71 | class HexEnvLoader : public BaseBoardEnvLoader { 72 | public: 73 | std::vector getActionFeatures(const int pos, utils::Rotation rotation = utils::Rotation::kRotationNone) const override; 74 | inline std::vector getValue(const int pos) const { return {getReturn()}; } 75 | inline std::string name() const override { return kHexName + "_" + std::to_string(getBoardSize()) + "x" + std::to_string(getBoardSize()); } 76 | inline int getPolicySize() const override { return getBoardSize() * getBoardSize(); } 77 | inline int getRotatePosition(int position, utils::Rotation rotation) const override { 
return position; } 78 | inline int getRotateAction(int action_id, utils::Rotation rotation) const override { return action_id; } 79 | }; 80 | 81 | } // namespace minizero::env::hex 82 | -------------------------------------------------------------------------------- /minizero/environment/killallgo/killallgo.cpp: -------------------------------------------------------------------------------- 1 | #include "killallgo.h" 2 | #include "configuration.h" 3 | #include "killallgo_seki_7x7.h" 4 | #include 5 | 6 | namespace minizero::env::killallgo { 7 | 8 | Seki7x7Table g_seki_7x7_table; 9 | 10 | void initialize() 11 | { 12 | go::initialize(); 13 | 14 | constexpr int kSekiTableMinAreaSize = 5; 15 | constexpr int kSekiTableMaxAreaSize = 8; 16 | constexpr auto kSekiDBPath = "7x7_seki.db"; 17 | 18 | if (!g_seki_7x7_table.load(kSekiDBPath)) { 19 | SekiSearch::generateSekiTable(g_seki_7x7_table, kSekiTableMinAreaSize, kSekiTableMaxAreaSize); 20 | g_seki_7x7_table.save(kSekiDBPath); 21 | 22 | std::cerr << "Generate " << kSekiDBPath << " done!" 
<< std::endl; 23 | std::cerr << "Size: " << g_seki_7x7_table.size() << std::endl; 24 | } 25 | } 26 | 27 | bool KillAllGoEnv::isLegalAction(const KillAllGoAction& action) const 28 | { 29 | if (actions_.size() == 1) { return isPassAction(action); } 30 | if (actions_.size() < 3) { return !isPassAction(action) && go::GoEnv::isLegalAction(action); } 31 | return go::GoEnv::isLegalAction(action); 32 | } 33 | 34 | bool KillAllGoEnv::isTerminal() const 35 | { 36 | if (board_size_ == 7 && config::env_killallgo_use_seki && SekiSearch::isSeki(g_seki_7x7_table, *this)) { return true; } 37 | // all black's benson or any white's benson 38 | if (benson_bitboard_.get(Player::kPlayer1).count() == board_size_ * board_size_ || benson_bitboard_.get(Player::kPlayer2).count() > 0) { return true; } 39 | return go::GoEnv::isTerminal(); 40 | } 41 | 42 | float KillAllGoEnv::getEvalScore(bool is_resign) const 43 | { 44 | if (stone_bitboard_.get(Player::kPlayer2).count() == 0 || benson_bitboard_.get(Player::kPlayer1).count() == board_size_ * board_size_) 45 | return 1.0f; // player1 wins 46 | else 47 | return -1.0f; // player2 wins 48 | } 49 | 50 | } // namespace minizero::env::killallgo 51 | -------------------------------------------------------------------------------- /minizero/environment/killallgo/killallgo.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "go.h" 4 | #include 5 | 6 | namespace minizero::env::killallgo { 7 | 8 | const std::string kKillAllGoName = "killallgo"; 9 | const int kKillAllGoNumPlayer = 2; 10 | const int kKillAllGoBoardSize = 7; 11 | 12 | class Seki7x7Table; 13 | extern Seki7x7Table g_seki_7x7_table; 14 | 15 | void initialize(); 16 | 17 | typedef go::GoAction KillAllGoAction; 18 | 19 | class KillAllGoEnv : public go::GoEnv { 20 | public: 21 | KillAllGoEnv(int board_size = minizero::config::env_board_size) 22 | : go::GoEnv(board_size) 23 | { 24 | assert(kKillAllGoBoardSize == 
minizero::config::env_board_size); 25 | } 26 | 27 | bool isLegalAction(const KillAllGoAction& action) const override; 28 | bool isTerminal() const override; 29 | float getEvalScore(bool is_resign = false) const override; 30 | 31 | inline std::string name() const override { return kKillAllGoName + "_" + std::to_string(board_size_) + "x" + std::to_string(board_size_); } 32 | inline int getNumPlayer() const override { return kKillAllGoNumPlayer; } 33 | }; 34 | 35 | class KillAllGoEnvLoader : public go::GoEnvLoader { 36 | public: 37 | inline std::string name() const override { return kKillAllGoName + "_" + std::to_string(getBoardSize()) + "x" + std::to_string(getBoardSize()); } 38 | }; 39 | 40 | } // namespace minizero::env::killallgo 41 | -------------------------------------------------------------------------------- /minizero/environment/linesofaction/linesofaction.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "base_env.h" 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | namespace minizero::env::linesofaction { 10 | 11 | typedef std::uint64_t LinesOfActionHashKey; 12 | typedef std::uint64_t LinesOfActionBitBoard; 13 | 14 | extern std::unordered_map kLinesOfActionStrToIdx; 15 | extern std::vector kLinesOfActionIdxToFromIdx; 16 | extern std::vector kLinesOfActionIdxToDestIdx; 17 | extern std::vector kLinesOfActionIdxToStr; 18 | extern std::vector kLinesOfActionSquareToStr; 19 | 20 | const std::string kLinesOfActionName = "linesofaction"; 21 | const int kLinesOfActionNumPlayer = 2; 22 | const int kLinesOfActionBoardSize = 8; // fit 64bits int 23 | 24 | void initialize(); 25 | 26 | inline int getMoveIndex(const std::string move) 27 | { 28 | auto it = kLinesOfActionStrToIdx.find(move); 29 | if (it == kLinesOfActionStrToIdx.end()) { return -1; } 30 | return it->second; 31 | } 32 | 33 | class LinesOfActionAction : public BaseBoardAction { 34 | public: 35 | LinesOfActionAction() : 
BaseBoardAction() {} 36 | LinesOfActionAction(int action_id, Player player) : BaseBoardAction(action_id, player) {} 37 | LinesOfActionAction(const std::vector& action_string_args, int board_size = minizero::config::env_board_size) 38 | { 39 | assert(action_string_args.size() == 2); 40 | assert(action_string_args[0].size() == 1); 41 | player_ = charToPlayer(action_string_args[0][0]); 42 | assert(static_cast(player_) > 0 && static_cast(player_) <= kLinesOfActionNumPlayer); // assume kPlayer1 == 1, kPlayer2 == 2, ... 43 | action_id_ = getMoveIndex(action_string_args[1]); 44 | } 45 | 46 | std::string toConsoleString() const override { return kLinesOfActionIdxToStr[action_id_]; } 47 | inline int getFromID() const { return kLinesOfActionIdxToFromIdx[action_id_]; } 48 | inline int getDestID() const { return kLinesOfActionIdxToDestIdx[action_id_]; } 49 | }; 50 | 51 | class LinesOfActionEnv : public BaseBoardEnv { 52 | public: 53 | LinesOfActionEnv() : BaseBoardEnv(kLinesOfActionBoardSize) { reset(); } 54 | 55 | void reset() override; 56 | bool act(const LinesOfActionAction& action) override; 57 | bool act(const std::vector& action_string_args) override; 58 | std::vector getLegalActions() const override; 59 | bool isLegalAction(const LinesOfActionAction& action) const override; 60 | bool isTerminal() const override; 61 | float getReward() const override { return 0.0f; } 62 | float getEvalScore(bool is_resign = false) const override; 63 | std::vector getFeatures(utils::Rotation rotation = utils::Rotation::kRotationNone) const override; 64 | std::vector getActionFeatures(const LinesOfActionAction& action, utils::Rotation rotation = utils::Rotation::kRotationNone) const override; 65 | inline int getNumInputChannels() const override { return 22; } 66 | inline int getPolicySize() const override { return kLinesOfActionIdxToStr.size(); } 67 | std::string toString() const override; 68 | inline std::string name() const override { return kLinesOfActionName; } 69 | inline int 
getNumPlayer() const override { return kLinesOfActionNumPlayer; } 70 | inline int getRotatePosition(int position, utils::Rotation rotation) const override { return position; } 71 | inline int getRotateAction(int action_id, utils::Rotation rotation) const override { return action_id; } 72 | LinesOfActionHashKey computeHashKey() const; 73 | LinesOfActionHashKey computeHashKey(const GamePair& bitboard, Player turn) const; 74 | 75 | private: 76 | bool isLegalActionInternal(const LinesOfActionAction& action, bool forbid_circular) const; 77 | int getNumPiecesOnLine(int pos, int k) const; 78 | int getNumPiecesOnLine(int x, int y, int dk, int dy) const; 79 | bool isOnBoard(int x, int y) const; 80 | bool searchConnection(Player p) const; 81 | Player whoConnectAll(bool& end) const; 82 | bool isCycleAction(const LinesOfActionAction& action) const; 83 | Player getPlayerAtBoardPos(int pos) const; 84 | 85 | GamePair bitboard_; 86 | LinesOfActionHashKey hash_key_; 87 | 88 | std::vector> bitboard_history_; 89 | std::vector> direction_; 90 | std::vector hashkey_history_; 91 | }; 92 | 93 | class LinesOfActionEnvLoader : public BaseBoardEnvLoader { 94 | public: 95 | std::vector getActionFeatures(const int pos, utils::Rotation rotation = utils::Rotation::kRotationNone) const override; 96 | inline std::vector getValue(const int pos) const { return {getReturn()}; } 97 | inline std::string name() const override { return kLinesOfActionName; } 98 | inline int getPolicySize() const override { return kLinesOfActionIdxToStr.size(); } 99 | inline int getRotatePosition(int position, utils::Rotation rotation) const override { return position; } 100 | inline int getRotateAction(int action_id, utils::Rotation rotation) const override { return action_id; } 101 | }; 102 | 103 | } // namespace minizero::env::linesofaction 104 | -------------------------------------------------------------------------------- /minizero/environment/nogo/nogo.h: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "go.h" 4 | #include "go_block.h" 5 | #include "go_grid.h" 6 | #include 7 | 8 | namespace minizero::env::nogo { 9 | 10 | const std::string kNoGoName = "nogo"; 11 | const int kNoGoNumPlayer = 2; 12 | const int kNoGoBoardSize = 9; 13 | 14 | typedef go::GoAction NoGoAction; 15 | 16 | inline void initialize() { go::initialize(); } 17 | /* NoGoEnv: Go-board environment with NoGo rules; passing, playing on an occupied point, suicide and capturing are all illegal (see isLegalAction). */ class NoGoEnv : public go::GoEnv { 19 | public: 20 | NoGoEnv() : go::GoEnv() 21 | { 22 | assert(kNoGoBoardSize == minizero::config::env_board_size); 23 | } 24 | /* Legal only if the new stone would have a liberty (an empty neighbour, or a friendly neighbouring block with more than one liberty) and no opponent neighbouring block currently has exactly one liberty (the move would capture it). */ bool isLegalAction(const NoGoAction& action) const override 26 | { 27 | assert(action.getActionID() >= 0 && action.getActionID() <= board_size_ * board_size_); 28 | assert(action.getPlayer() == Player::kPlayer1 || action.getPlayer() == Player::kPlayer2); 29 | 30 | if (isPassAction(action)) { return false; } 31 | 32 | const int position = action.getActionID(); 33 | const Player player = action.getPlayer(); 34 | const go::GoGrid& grid = grids_[position]; 35 | if (grid.getPlayer() != Player::kPlayerNone) { return false; } 36 | 37 | // illegal when suicide or capture opponent's stones 38 | bool is_legal = false; 39 | go::GoBitboard check_neighbor_block_bitboard; /* marks block IDs already examined so each neighbouring block is checked once */ 40 | for (const auto& neighbor_pos : grid.getNeighbors()) { 41 | const go::GoGrid& neighbor_grid = grids_[neighbor_pos]; 42 | if (neighbor_grid.getPlayer() == Player::kPlayerNone) { 43 | is_legal = true; 44 | } else { 45 | const go::GoBlock* neighbor_block = neighbor_grid.getBlock(); 46 | if (check_neighbor_block_bitboard.test(neighbor_block->getID())) { continue; } 47 | 48 | check_neighbor_block_bitboard.set(neighbor_block->getID()); 49 | if (neighbor_block->getPlayer() == player) { 50 | if (neighbor_block->getNumLiberty() > 1) { is_legal = true; } 51 | } else { 52 | if (neighbor_block->getNumLiberty() == 1) { return false; } 53 | } 54 | } 55 | } 56 | return is_legal; 57 | } 58 | 59 | /* Terminal when no board point is a legal placement for the side to move (turn_); pass is never legal, so the position cannot continue. */ bool isTerminal()
const override 60 | { 61 | for (int pos = 0; pos < board_size_ * board_size_; ++pos) { 62 | NoGoAction action(pos, turn_); 63 | if (isLegalAction(action)) { return false; } 64 | } 65 | return true; 66 | } 67 | 68 | /* Score from kPlayer1's perspective: +1.0 if getNextPlayer(turn_, kNoGoNumPlayer) is kPlayer1, -1.0 if kPlayer2, else 0.0. Presumably getNextPlayer yields the player not to move, who wins when the side to move has no legal move - confirm. is_resign is not used. */ float getEvalScore(bool is_resign = false) const override 69 | { 70 | Player eval = getNextPlayer(turn_, kNoGoNumPlayer); 71 | switch (eval) { 72 | case Player::kPlayer1: return 1.0f; 73 | case Player::kPlayer2: return -1.0f; 74 | default: return 0.0f; 75 | } 76 | } 77 | 78 | inline std::string name() const override { return kNoGoName + "_" + std::to_string(board_size_) + "x" + std::to_string(board_size_); } 79 | inline int getNumPlayer() const override { return kNoGoNumPlayer; } 80 | }; 81 | 82 | class NoGoEnvLoader : public go::GoEnvLoader { 83 | public: 84 | inline std::string name() const override { return kNoGoName + "_" + std::to_string(getBoardSize()) + "x" + std::to_string(getBoardSize()); } 85 | }; 86 | 87 | } // namespace minizero::env::nogo 88 | -------------------------------------------------------------------------------- /minizero/environment/othello/othello.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "base_env.h" 4 | #include "configuration.h" 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | namespace minizero::env::othello { 12 | using namespace minizero::utils; 13 | const std::string kOthelloName = "othello"; 14 | const int kOthelloNumPlayer = 2; 15 | const int kMaxOthelloBoardSize = 16; 16 | typedef std::bitset OthelloBitboard; 17 | 18 | typedef BaseBoardAction OthelloAction; 19 | 20 | /* OthelloEnv: square-board Othello (board size asserted <= kMaxOthelloBoardSize); the policy includes one extra slot for the pass action, id getBoardSize() * getBoardSize(). */ class OthelloEnv : public BaseBoardEnv { 21 | public: 22 | OthelloEnv() 23 | { 24 | assert(getBoardSize() <= kMaxOthelloBoardSize); 25 | reset(); 26 | } 27 | 28 | void reset() override; 29 | bool act(const OthelloAction& action) override; 30 | bool act(const std::vector& action_string_args) override; 31 | std::vector getLegalActions() const override; 32 |
bool isLegalAction(const OthelloAction& action) const override; 33 | bool isTerminal() const override; 34 | float getReward() const override { return 0.0f; } 35 | float getEvalScore(bool is_resign = false) const override; 36 | std::vector getFeatures(utils::Rotation rotation = utils::Rotation::kRotationNone) const override; 37 | std::vector getActionFeatures(const OthelloAction& action, utils::Rotation rotation = utils::Rotation::kRotationNone) const override; 38 | inline int getNumInputChannels() const override { return 4; } 39 | inline int getPolicySize() const override { return getBoardSize() * getBoardSize() + 1; } 40 | std::string toString() const override; 41 | inline std::string name() const override { return kOthelloName + "_" + std::to_string(getBoardSize()) + "x" + std::to_string(getBoardSize()); } 42 | inline int getNumPlayer() const override { return kOthelloNumPlayer; } 43 | /* pass is encoded as the id one past the last board point */ inline bool isPassAction(const OthelloAction& action) const { return (action.getActionID() == getBoardSize() * getBoardSize()); } 44 | /* NOTE(review): getRotateAction forwards every id, including the pass id, to utils::getPositionByRotating - confirm that helper leaves out-of-board ids unchanged. */ 45 | inline int getRotatePosition(int position, utils::Rotation rotation) const override { return utils::getPositionByRotating(rotation, position, getBoardSize()); }; 46 | inline int getRotateAction(int action_id, utils::Rotation rotation) const override { return getRotatePosition(action_id, rotation); }; 47 | 48 | private: 49 | Player eval() const; 50 | OthelloBitboard getCanPutPoint( 51 | int direction, 52 | OthelloBitboard mask, 53 | OthelloBitboard empty_board, 54 | OthelloBitboard opponent_board, 55 | OthelloBitboard player_board); 56 | OthelloBitboard getFlipPoint( 57 | int direction, 58 | OthelloBitboard mask, 59 | OthelloBitboard placed_pos, 60 | OthelloBitboard opponent_board, 61 | OthelloBitboard player_board); 62 | OthelloBitboard getCandidateAlongDirectionBoard(int direction, OthelloBitboard candidate); 63 | std::string getCoordinateString() const; 64 | 65 | int dir_step_[8]; // 8 directions 66 | OthelloBitboard one_board_; 67 |
OthelloBitboard mask_[8]; // 8 directions 68 | GamePair legal_pass_; // store black/white legal pass 69 | GamePair legal_board_; // store black/white legal board 70 | GamePair board_; // store black/white board 71 | }; 72 | 73 | /* OthelloEnvLoader: mirrors OthelloEnv's pass id, policy size, name and rotation mapping. */ class OthelloEnvLoader : public BaseBoardEnvLoader { 74 | public: 75 | std::vector getActionFeatures(const int pos, utils::Rotation rotation = utils::Rotation::kRotationNone) const override; 76 | inline bool isPassAction(const OthelloAction& action) const { return (action.getActionID() == getBoardSize() * getBoardSize()); } 77 | inline std::vector getValue(const int pos) const { return {getReturn()}; } 78 | inline std::string name() const override { return kOthelloName + "_" + std::to_string(getBoardSize()) + "x" + std::to_string(getBoardSize()); } 79 | inline int getPolicySize() const override { return getBoardSize() * getBoardSize() + 1; } 80 | inline int getRotatePosition(int position, utils::Rotation rotation) const override { return utils::getPositionByRotating(rotation, position, getBoardSize()); }; 81 | inline int getRotateAction(int action_id, utils::Rotation rotation) const override { return getRotatePosition(action_id, rotation); }; 82 | }; 83 | 84 | } // namespace minizero::env::othello 85 | -------------------------------------------------------------------------------- /minizero/environment/santorini/bitboard.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | namespace minizero::env::santorini { 6 | 7 | /* Bitboard: a uint64_t mask over a 7x7 grid (kRowSize x kColSize, bit i = cell i); kValidSpace marks the inner 5x5 cells. */ class Bitboard { 8 | public: 9 | Bitboard() : raw_(0ULL) {} 10 | Bitboard(uint64_t v) : raw_(v) {} 11 | Bitboard(const Bitboard& b) = default; 12 | ~Bitboard() = default; 13 | Bitboard& operator=(const Bitboard& b) = default; 14 | operator uint64_t() const { return raw_; } 15 | 16 | Bitboard& operator|=(const Bitboard& b) 17 | { 18 | raw_ |= b.raw_; 19 | return (*this); 20 | } 21 | Bitboard operator|(const Bitboard& b) const 22 | { 23 | Bitboard
ret(raw_); 24 | ret |= b; 25 | return ret; 26 | } 27 | Bitboard& operator&=(const Bitboard& b) 28 | { 29 | raw_ &= b.raw_; 30 | return (*this); 31 | } 32 | Bitboard operator&(const Bitboard& b) const 33 | { 34 | Bitboard ret(raw_); 35 | ret &= b; 36 | return ret; 37 | } 38 | Bitboard& operator^=(const Bitboard& b) 39 | { 40 | raw_ ^= b.raw_; 41 | return (*this); 42 | } 43 | Bitboard operator^(const Bitboard& b) const 44 | { 45 | Bitboard ret(raw_); 46 | ret ^= b; 47 | return ret; 48 | } 49 | Bitboard& operator<<=(int shift) 50 | { 51 | raw_ <<= shift; 52 | return (*this); 53 | } 54 | Bitboard operator<<(int shift) const 55 | { 56 | Bitboard ret(raw_); 57 | ret <<= shift; 58 | return ret; 59 | } 60 | Bitboard& operator>>=(int shift) 61 | { 62 | raw_ >>= shift; 63 | return (*this); 64 | } 65 | Bitboard operator>>(int shift) const 66 | { 67 | Bitboard ret(raw_); 68 | ret >>= shift; 69 | return ret; 70 | } 71 | Bitboard operator~() const 72 | { 73 | Bitboard ret(~raw_); 74 | return ret; 75 | } 76 | 77 | /* Indices of the set bits among the first kRowSize * kColSize bits, in increasing order. */ std::vector toList() const 78 | { 79 | std::vector ret; 80 | uint64_t b = 1ULL; 81 | for (int i = 0; i < kRowSize * kColSize; ++i) { 82 | if (b & raw_) { ret.push_back(i); } 83 | b <<= 1; 84 | } 85 | return ret; 86 | } 87 | 88 | /* Cells around idx, clipped to kValidSpace: kNeighbor is the ring around bit 8, so it is shifted left by (idx - 8). NOTE(review): idx < 8 gives a negative shift (undefined behaviour); every kValidSpace cell has idx >= 8 - callers presumably pass only those. */ static Bitboard getNeighbor(int idx) 89 | { 90 | Bitboard ret(kNeighbor); 91 | ret <<= (idx - 8); 92 | ret &= kValidSpace; 93 | return ret; 94 | } 95 | 96 | static constexpr int kRowSize = 7; 97 | static constexpr int kColSize = 7; 98 | 99 | // valid space: 100 | // ------- 101 | // -*****- 102 | // -*****- 103 | // -*****- 104 | // -*****- 105 | // -*****- 106 | // ------- 107 | static constexpr uint64_t kValidSpace = 2147077824256ULL; 108 | 109 | // neighbor: 110 | // ***---- 111 | // *-*---- 112 | // ***---- 113 | // ------- 114 | // ------- 115 | // ------- 116 | // ------- 117 | static constexpr uint64_t kNeighbor = 115335ULL; 118 | 119 | private: 120 | uint64_t raw_ = 0; 121 | }; 122 | 123 | } // namespace minizero::env::santorini 124 |
-------------------------------------------------------------------------------- /minizero/environment/santorini/board.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "base_env.h" 3 | #include "bitboard.h" 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | namespace minizero::env::santorini { 10 | 11 | /* Board: Santorini state stored as Bitboard planes (plane_) plus each player's two piece positions (player_). */ class Board { 12 | public: 13 | Board(); 14 | Board(const Board& board_); 15 | ~Board() = default; 16 | 17 | std::vector> getLegalMove(int p_id) const; 18 | std::vector getLegalBuild(int idx) const; 19 | std::pair getPlayerIdx(int p_id) const; 20 | const std::array& getPlanes() const; 21 | bool isTerminal(int p_id) const; 22 | bool checkWin(int p_id) const; 23 | 24 | std::string toConsole() const; 25 | 26 | bool setPlayer(int p_id, int idx_1 = 0, int idx_2 = 0); 27 | bool movePiece(int from, int to); 28 | bool movePiece(Bitboard from, Bitboard to); 29 | bool buildTower(int idx); 30 | 31 | private: 32 | // One bitboard for the players' pieces and four bitboards for building. 33 | std::array plane_; 34 | 35 | // Each player has two pieces, so four bitboards are used in total.
36 | std::array, 2> player_; 37 | }; 38 | 39 | } // namespace minizero::env::santorini 40 | -------------------------------------------------------------------------------- /minizero/environment/santorini/santorini.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "base_env.h" 4 | #include "board.h" 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | namespace minizero::env::santorini { 11 | 12 | const std::string kSantoriniName = "santorini"; 13 | const int kSantoriniNumPlayer = 2; 14 | const int kSantoriniBoardSize = 5; 15 | const int kSantoriniLetterBoxSize = kSantoriniBoardSize + 2; 16 | const int kSantoriniPolicySize = 1900; 17 | const int kSantoriniHistoryLength = 8; 18 | const int kSantoriniHistorySize = 6; 19 | 20 | /* Convert between a 0-based position on the 5x5 board and its index in the 7x7 letter-boxed grid (one-cell border on each side). NOTE(review): "Boax" in letterBoaxIdxToposition is a typo, kept so callers do not break. */ inline int positionToLetterBoxIdx(int pos) { return (pos / kSantoriniBoardSize + 1) * kSantoriniLetterBoxSize + pos % kSantoriniBoardSize + 1; } 21 | inline int letterBoaxIdxToposition(int idx) { return (idx / kSantoriniLetterBoxSize - 1) * kSantoriniBoardSize + idx % kSantoriniLetterBoxSize - 1; } 22 | 23 | /* SantoriniAction: exposes the from/to/build squares of an action; presumably filled from action_id by parseAction - confirm. */ class SantoriniAction : public BaseBoardAction { 24 | public: 25 | SantoriniAction() : BaseBoardAction() {} 26 | SantoriniAction(int action_id, Player player); 27 | SantoriniAction(const std::vector& action_string_args); 28 | std::string toConsoleString() const override; 29 | 30 | inline int getFrom() const { return from_; } 31 | inline int getTo() const { return to_; } 32 | inline int getBuild() const { return build_; } 33 | 34 | private: 35 | int encodePlaced(int x, int y) const; 36 | std::string getSquareString(int pos) const; 37 | std::pair decodePlaced(int z) const; 38 | void parseAction(int action_id); 39 | 40 | int from_; 41 | int to_; 42 | int build_; 43 | }; 44 | 45 | class SantoriniEnv : public BaseBoardEnv { 46 | public: 47 | SantoriniEnv() { reset(); } 48 | 49 | void reset() override; 50 | bool act(const SantoriniAction& action) override; 51 | bool
act(const std::vector& action_string_args) override; 52 | std::vector getLegalActions() const override; 53 | bool isLegalAction(const SantoriniAction& action) const override; 54 | bool isTerminal() const override; 55 | float getReward() const override { return 0.0f; } 56 | float getEvalScore(bool is_resign = false) const override; 57 | std::vector getFeatures(utils::Rotation rotation = utils::Rotation::kRotationNone) const override; 58 | std::vector getActionFeatures(const SantoriniAction& action, utils::Rotation rotation = utils::Rotation::kRotationNone) const override; 59 | inline int getNumInputChannels() const override { return 50; } 60 | inline int getPolicySize() const override { return kSantoriniPolicySize; } 61 | inline int getNumActionFeatureChannels() const override { return 0; } 62 | std::string toString() const override; 63 | inline std::string name() const override { return kSantoriniName; } 64 | inline int getNumPlayer() const override { return kSantoriniNumPlayer; } 65 | /* no symmetry augmentation: rotations leave positions and action ids unchanged */ inline int getRotatePosition(int position, utils::Rotation rotation) const override { return position; }; 66 | inline int getRotateAction(int action_id, utils::Rotation rotation) const override { return action_id; }; 67 | 68 | private: 69 | Board board_; 70 | std::vector history_; 71 | }; 72 | 73 | /* SantoriniEnvLoader: same name, policy size and identity rotations as SantoriniEnv. */ class SantoriniEnvLoader : public BaseBoardEnvLoader { 74 | public: 75 | std::vector getActionFeatures(const int pos, utils::Rotation rotation = utils::Rotation::kRotationNone) const override; 76 | inline std::vector getValue(const int pos) const { return {getReturn()}; } 77 | inline std::string name() const override { return kSantoriniName; } 78 | inline int getPolicySize() const override { return kSantoriniPolicySize; } 79 | inline int getRotatePosition(int position, utils::Rotation rotation) const override { return position; } 80 | inline int getRotateAction(int action_id, utils::Rotation rotation) const override { return action_id; } 81 | }; 82 | 83 | } // namespace
minizero::env::santorini 84 | -------------------------------------------------------------------------------- /minizero/environment/stochastic/stochastic_env.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "base_env.h" 4 | #include "random.h" 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | namespace minizero::env { 11 | 12 | template 13 | class StochasticEnv : public BaseEnv { 14 | public: 15 | StochasticEnv() : BaseEnv(), seed_(0) {} 16 | virtual ~StochasticEnv() = default; 17 | 18 | void reset() override { reset(utils::Random::randInt()); } 19 | bool act(const Action& action) override { return act(action, true); } 20 | bool act(const std::vector& action_string_args) override { return act(action_string_args, true); } 21 | 22 | virtual void reset(int seed) = 0; 23 | virtual bool act(const Action& action, bool with_chance = true) = 0; 24 | virtual bool act(const std::vector& action_string_args, bool with_chance = true) = 0; 25 | virtual bool actChanceEvent(const Action& action) = 0; 26 | virtual std::vector getLegalChanceEvents() const = 0; 27 | virtual bool isLegalChanceEvent(const Action& action) const = 0; 28 | virtual int getMaxChanceEventSize() const = 0; 29 | virtual float getChanceEventProbability(const Action& action) const = 0; 30 | 31 | inline int getSeed() const { return seed_; } 32 | 33 | protected: 34 | int seed_; 35 | std::mt19937 random_; 36 | }; 37 | 38 | template 39 | class StochasticEnvLoader : public BaseEnvLoader { 40 | public: 41 | StochasticEnvLoader() : BaseEnvLoader() {} 42 | virtual ~StochasticEnvLoader() = default; 43 | 44 | void loadFromEnvironment(const Env& env, const std::vector>>& action_info_history = {}) override 45 | { 46 | BaseEnvLoader::loadFromEnvironment(env, action_info_history); 47 | BaseEnvLoader::addTag("SD", std::to_string(env.getSeed())); 48 | } 49 | 50 | inline int getSeed() const { return std::stoi(BaseEnvLoader::getTag("SD")); } 51 | 52 
| std::vector getFeatures(const int pos, utils::Rotation rotation = utils::Rotation::kRotationNone) const override 53 | { 54 | // a slow but naive method which simply replays the game again to get features 55 | Env env; 56 | env.reset(getSeed()); 57 | const auto& action_pairs_ = BaseEnvLoader::action_pairs_; 58 | for (int i = 0; i < std::min(pos, static_cast(action_pairs_.size())); ++i) { env.act(action_pairs_[i].first); } 59 | return env.getFeatures(rotation); 60 | } 61 | 62 | virtual std::vector getAfterstateFeatures(const int pos, utils::Rotation rotation) const 63 | { 64 | // a slow but naive method which simply replays the game again to get features 65 | Env env; 66 | env.reset(getSeed()); 67 | const auto& action_pairs_ = BaseEnvLoader::action_pairs_; 68 | for (int i = 0; i < std::min(pos, static_cast(action_pairs_.size())); ++i) { env.act(action_pairs_[i].first); } 69 | if (!env.isTerminal() && pos < static_cast(action_pairs_.size())) { env.act(action_pairs_[pos].first, false); } 70 | return env.getFeatures(rotation); 71 | } 72 | 73 | virtual std::vector getAfterstateValue(const int pos) const = 0; 74 | }; 75 | 76 | } // namespace minizero::env 77 | -------------------------------------------------------------------------------- /minizero/environment/tictactoe/tictactoe.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "base_env.h" 4 | #include 5 | #include 6 | 7 | namespace minizero::env::tictactoe { 8 | 9 | const std::string kTicTacToeName = "tictactoe"; 10 | const int kTicTacToeNumPlayer = 2; 11 | const int kTicTacToeBoardSize = 3; 12 | 13 | typedef BaseBoardAction TicTacToeAction; 14 | 15 | class TicTacToeEnv : public BaseBoardEnv { 16 | public: 17 | TicTacToeEnv() : BaseBoardEnv(kTicTacToeBoardSize) { reset(); } 18 | 19 | void reset() override; 20 | bool act(const TicTacToeAction& action) override; 21 | bool act(const std::vector& action_string_args) override; 22 | std::vector 
getLegalActions() const override; 23 | bool isLegalAction(const TicTacToeAction& action) const override; 24 | bool isTerminal() const override; 25 | float getReward() const override { return 0.0f; } 26 | float getEvalScore(bool is_resign = false) const override; 27 | std::vector getFeatures(utils::Rotation rotation = utils::Rotation::kRotationNone) const override; 28 | std::vector getActionFeatures(const TicTacToeAction& action, utils::Rotation rotation = utils::Rotation::kRotationNone) const override; 29 | inline int getNumInputChannels() const override { return 4; } 30 | inline int getPolicySize() const override { return getBoardSize() * getBoardSize(); } 31 | std::string toString() const override; 32 | inline std::string name() const override { return kTicTacToeName; } 33 | inline int getNumPlayer() const override { return kTicTacToeNumPlayer; } 34 | inline int getRotatePosition(int position, utils::Rotation rotation) const override { return utils::getPositionByRotating(rotation, position, getBoardSize()); }; 35 | inline int getRotateAction(int action_id, utils::Rotation rotation) const override { return getRotatePosition(action_id, rotation); }; 36 | 37 | private: 38 | Player eval() const; 39 | 40 | std::vector board_; 41 | }; 42 | 43 | class TicTacToeEnvLoader : public BaseBoardEnvLoader { 44 | public: 45 | std::vector getActionFeatures(const int pos, utils::Rotation rotation = utils::Rotation::kRotationNone) const override; 46 | inline std::vector getValue(const int pos) const { return {getReturn()}; } 47 | inline std::string name() const override { return kTicTacToeName; } 48 | inline int getPolicySize() const override { return getBoardSize() * getBoardSize(); } 49 | inline int getRotatePosition(int position, utils::Rotation rotation) const override { return utils::getPositionByRotating(rotation, position, getBoardSize()); }; 50 | inline int getRotateAction(int action_id, utils::Rotation rotation) const override { return getRotatePosition(action_id, 
rotation); }; 51 | }; 52 | 53 | } // namespace minizero::env::tictactoe 54 | -------------------------------------------------------------------------------- /minizero/learner/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | file(GLOB SRCS *.cpp) 2 | 3 | find_package(pybind11 REQUIRED) 4 | pybind11_add_module(minizero_py ${SRCS}) 5 | target_link_libraries( 6 | minizero_py 7 | PUBLIC 8 | config 9 | environment 10 | utils 11 | ) 12 | 13 | add_library(learner data_loader.cpp) 14 | target_include_directories( 15 | learner PUBLIC 16 | ${CMAKE_CURRENT_SOURCE_DIR} 17 | ) 18 | target_link_libraries( 19 | learner 20 | config 21 | environment 22 | utils 23 | ) -------------------------------------------------------------------------------- /minizero/learner/data_loader.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "environment.h" 4 | #include "paralleler.h" 5 | #include "rotation.h" 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | namespace minizero::learner { 14 | 15 | class BaseBatchDataPtr { 16 | public: 17 | BaseBatchDataPtr() {} 18 | virtual ~BaseBatchDataPtr() = default; 19 | }; 20 | 21 | class BatchDataPtr : public BaseBatchDataPtr { 22 | public: 23 | BatchDataPtr() {} 24 | virtual ~BatchDataPtr() = default; 25 | 26 | float* features_; 27 | float* action_features_; 28 | float* policy_; 29 | float* value_; 30 | float* option_; 31 | float* reward_; 32 | float* loss_scale_; 33 | float* option_loss_scale_; 34 | int* step_option_length_; 35 | int* step_unroll_length_; 36 | int* sampled_index_; 37 | std::vector rotation_ = std::vector(config::learner_batch_size, utils::Rotation::kRotationNone); 38 | }; 39 | 40 | class ReplayBuffer { 41 | public: 42 | ReplayBuffer(); 43 | 44 | std::mutex mutex_; 45 | int num_data_; 46 | float game_priority_sum_; 47 | std::deque game_priorities_; 48 | std::deque> 
position_priorities_; 49 | std::deque env_loaders_; 50 | 51 | void addData(const EnvironmentLoader& env_loader); 52 | std::pair sampleEnvAndPos(); 53 | int sampleIndex(const std::deque& weight); 54 | float getLossScale(const std::pair& p); 55 | }; 56 | 57 | class DataLoaderSharedData : public utils::BaseSharedData { 58 | public: 59 | std::string getNextEnvString(); 60 | int getNextBatchIndex(); 61 | 62 | virtual void createDataPtr() { data_ptr_ = std::make_shared(); } 63 | inline std::shared_ptr getDataPtr() { return std::static_pointer_cast(data_ptr_); } 64 | 65 | int batch_index_; 66 | ReplayBuffer replay_buffer_; 67 | std::mutex mutex_; 68 | std::deque env_strings_; 69 | std::shared_ptr data_ptr_; 70 | }; 71 | 72 | class DataLoaderThread : public utils::BaseSlaveThread { 73 | public: 74 | DataLoaderThread(int id, std::shared_ptr shared_data) 75 | : BaseSlaveThread(id, shared_data) {} 76 | 77 | void initialize() override; 78 | void runJob() override; 79 | bool isDone() override { return false; } 80 | 81 | protected: 82 | virtual bool addEnvironmentLoader(); 83 | virtual bool sampleData(); 84 | 85 | virtual void setAlphaZeroTrainingData(int batch_index); 86 | virtual void setMuZeroTrainingData(int batch_index); 87 | 88 | inline std::shared_ptr getSharedData() { return std::static_pointer_cast(shared_data_); } 89 | }; 90 | 91 | class DataLoader : public utils::BaseParalleler { 92 | public: 93 | DataLoader(const std::string& conf_file_name); 94 | 95 | void initialize() override; 96 | void summarize() override {} 97 | virtual void loadDataFromFile(const std::string& file_name); 98 | virtual void sampleData(); 99 | virtual void updatePriority(int* sampled_index, float* batch_values); 100 | virtual void updateMax(int* sampled_index, int* batch_max_ids); 101 | 102 | void createSharedData() override { shared_data_ = std::make_shared(); } 103 | std::shared_ptr newSlaveThread(int id) override { return std::make_shared(id, shared_data_); } 104 | inline std::shared_ptr 
getSharedData() { return std::static_pointer_cast(shared_data_); } 105 | }; 106 | 107 | } // namespace minizero::learner 108 | -------------------------------------------------------------------------------- /minizero/minizero.cpp: -------------------------------------------------------------------------------- 1 | #include "mode_handler.h" 2 | 3 | int main(int argc, char* argv[]) 4 | { 5 | minizero::console::ModeHandler mode_handler; 6 | mode_handler.run(argc, argv); 7 | return 0; 8 | } 9 | -------------------------------------------------------------------------------- /minizero/network/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | file(GLOB_RECURSE SRCS *.cpp) 2 | 3 | add_library(network ${SRCS}) 4 | target_include_directories( 5 | network PUBLIC 6 | ${CMAKE_CURRENT_SOURCE_DIR} 7 | ) 8 | target_link_libraries( 9 | network 10 | utils 11 | ${TORCH_LIBRARIES} 12 | ) -------------------------------------------------------------------------------- /minizero/network/alphazero_network.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "network.h" 4 | #include "utils.h" 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | namespace minizero::network { 12 | 13 | class AlphaZeroNetworkOutput : public NetworkOutput { 14 | public: 15 | float value_; 16 | std::vector policy_; 17 | std::vector policy_logits_; 18 | 19 | AlphaZeroNetworkOutput(int policy_size) 20 | { 21 | value_ = 0.0f; 22 | policy_.resize(policy_size, 0.0f); 23 | policy_logits_.resize(policy_size, 0.0f); 24 | } 25 | }; 26 | 27 | class AlphaZeroNetwork : public Network { 28 | public: 29 | AlphaZeroNetwork() 30 | { 31 | clear(); 32 | } 33 | 34 | void loadModel(const std::string& nn_file_name, const int gpu_id) override 35 | { 36 | assert(batch_size_ == 0); // should avoid loading model when batch size is not 0 37 | Network::loadModel(nn_file_name, gpu_id); 38 | clear(); 
39 | } 40 | 41 | std::string toString() const override 42 | { 43 | std::ostringstream oss; 44 | oss << Network::toString(); 45 | return oss.str(); 46 | } 47 | 48 | int pushBack(std::vector features) 49 | { 50 | assert(static_cast(features.size()) == getNumInputChannels() * getInputChannelHeight() * getInputChannelWidth()); 51 | assert(batch_size_ < kReserved_batch_size); 52 | 53 | int index; 54 | { 55 | std::lock_guard lock(mutex_); 56 | index = batch_size_++; 57 | tensor_input_.resize(batch_size_); 58 | } 59 | tensor_input_[index] = torch::from_blob(features.data(), {1, getNumInputChannels(), getInputChannelHeight(), getInputChannelWidth()}).clone(); 60 | return index; 61 | } 62 | 63 | std::vector> forward() 64 | { 65 | assert(batch_size_ > 0); 66 | auto forward_result = network_.forward(std::vector{torch::cat(tensor_input_).to(getDevice())}).toGenericDict(); 67 | 68 | auto policy_output = forward_result.at("policy").toTensor().to(at::kCPU); 69 | auto policy_logits_output = forward_result.at("policy_logit").toTensor().to(at::kCPU); 70 | auto value_output = forward_result.at("value").toTensor().to(at::kCPU); 71 | assert(policy_output.numel() == batch_size_ * getActionSize()); 72 | assert(policy_logits_output.numel() == batch_size_ * getActionSize()); 73 | assert(value_output.numel() == batch_size_ * getDiscreteValueSize()); 74 | 75 | const int policy_size = getActionSize(); 76 | std::vector> network_outputs; 77 | for (int i = 0; i < batch_size_; ++i) { 78 | network_outputs.emplace_back(std::make_shared(policy_size)); 79 | auto alphazero_network_output = std::static_pointer_cast(network_outputs.back()); 80 | 81 | // policy & policy logits 82 | std::copy(policy_output.data_ptr() + i * policy_size, 83 | policy_output.data_ptr() + (i + 1) * policy_size, 84 | alphazero_network_output->policy_.begin()); 85 | std::copy(policy_logits_output.data_ptr() + i * policy_size, 86 | policy_logits_output.data_ptr() + (i + 1) * policy_size, 87 | 
alphazero_network_output->policy_logits_.begin()); 88 | 89 | // value 90 | if (getDiscreteValueSize() == 1) { 91 | alphazero_network_output->value_ = value_output[i].item(); 92 | } else { 93 | int start_value = -getDiscreteValueSize() / 2; 94 | alphazero_network_output->value_ = std::accumulate(value_output.data_ptr() + i * getDiscreteValueSize(), 95 | value_output.data_ptr() + (i + 1) * getDiscreteValueSize(), 96 | 0.0f, 97 | [&start_value](const float& sum, const float& value) { return sum + value * start_value++; }); 98 | alphazero_network_output->value_ = utils::invertValue(alphazero_network_output->value_); 99 | } 100 | } 101 | 102 | clear(); 103 | return network_outputs; 104 | } 105 | 106 | inline int getBatchSize() const { return batch_size_; } 107 | 108 | protected: 109 | inline void clear() 110 | { 111 | batch_size_ = 0; 112 | tensor_input_.clear(); 113 | tensor_input_.reserve(kReserved_batch_size); 114 | } 115 | 116 | int batch_size_; 117 | std::mutex mutex_; 118 | std::vector tensor_input_; 119 | 120 | const int kReserved_batch_size = 4096; 121 | }; 122 | 123 | } // namespace minizero::network 124 | -------------------------------------------------------------------------------- /minizero/network/create_network.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "alphazero_network.h" 4 | #include "muzero_network.h" 5 | #include "network.h" 6 | #include 7 | #include 8 | 9 | namespace minizero::network { 10 | 11 | inline std::shared_ptr createNetwork(const std::string& nn_file_name, const int gpu_id) 12 | { 13 | // TODO: how to speed up? 
14 | Network base_network; 15 | base_network.loadModel(nn_file_name, -1); 16 | 17 | std::shared_ptr network; 18 | if (base_network.getNetworkTypeName() == "alphazero") { 19 | network = std::make_shared(); 20 | std::dynamic_pointer_cast(network)->loadModel(nn_file_name, gpu_id); 21 | } else if (base_network.getNetworkTypeName() == "muzero" || base_network.getNetworkTypeName() == "muzero_atari" || base_network.getNetworkTypeName() == "muzero_gridworld") { 22 | network = std::make_shared(); 23 | std::dynamic_pointer_cast(network)->loadModel(nn_file_name, gpu_id); 24 | } else { 25 | // should not be here 26 | assert(false); 27 | } 28 | 29 | return network; 30 | } 31 | 32 | } // namespace minizero::network 33 | -------------------------------------------------------------------------------- /minizero/network/network.cpp: -------------------------------------------------------------------------------- 1 | #include "network.h" 2 | 3 | namespace minizero::network { 4 | 5 | Network::Network() 6 | { 7 | gpu_id_ = -1; 8 | num_input_channels_ = input_channel_height_ = input_channel_width_ = -1; 9 | num_hidden_channels_ = hidden_channel_height_ = hidden_channel_width_ = -1; 10 | num_blocks_ = action_size_ = num_value_hidden_channels_ = discrete_value_size_ = -1; 11 | game_name_ = network_type_name_ = network_file_name_ = ""; 12 | } 13 | 14 | void Network::loadModel(const std::string& nn_file_name, const int gpu_id) 15 | { 16 | gpu_id_ = gpu_id; 17 | network_file_name_ = nn_file_name; 18 | 19 | // load model weights 20 | try { 21 | network_ = torch::jit::load(network_file_name_, getDevice()); 22 | network_.eval(); 23 | } catch (const c10::Error& e) { 24 | std::cerr << e.msg() << std::endl; 25 | assert(false); 26 | } 27 | 28 | // network hyper-parameter 29 | std::vector dummy; 30 | num_input_channels_ = network_.get_method("get_num_input_channels")(dummy).toInt(); 31 | input_channel_height_ = network_.get_method("get_input_channel_height")(dummy).toInt(); 32 | 
input_channel_width_ = network_.get_method("get_input_channel_width")(dummy).toInt(); 33 | num_hidden_channels_ = network_.get_method("get_num_hidden_channels")(dummy).toInt(); 34 | hidden_channel_height_ = network_.get_method("get_hidden_channel_height")(dummy).toInt(); 35 | hidden_channel_width_ = network_.get_method("get_hidden_channel_width")(dummy).toInt(); 36 | num_blocks_ = network_.get_method("get_num_blocks")(dummy).toInt(); 37 | action_size_ = network_.get_method("get_action_size")(dummy).toInt(); 38 | num_value_hidden_channels_ = network_.get_method("get_num_value_hidden_channels")(dummy).toInt(); 39 | discrete_value_size_ = network_.get_method("get_discrete_value_size")(dummy).toInt(); 40 | game_name_ = network_.get_method("get_game_name")(dummy).toString()->string(); 41 | network_type_name_ = network_.get_method("get_type_name")(dummy).toString()->string(); 42 | } 43 | 44 | std::string Network::toString() const 45 | { 46 | std::ostringstream oss; 47 | oss << "GPU ID: " << gpu_id_ << std::endl; 48 | oss << "Number of input channels: " << num_input_channels_ << std::endl; 49 | oss << "Input channel height: " << input_channel_height_ << std::endl; 50 | oss << "Input channel width: " << input_channel_width_ << std::endl; 51 | oss << "Number of hidden channels: " << num_hidden_channels_ << std::endl; 52 | oss << "Hidden channel height: " << hidden_channel_height_ << std::endl; 53 | oss << "Hidden channel width: " << hidden_channel_width_ << std::endl; 54 | oss << "Number of blocks: " << num_blocks_ << std::endl; 55 | oss << "Action size: " << action_size_ << std::endl; 56 | oss << "Number of value hidden channels: " << num_value_hidden_channels_ << std::endl; 57 | oss << "Discrete value size: " << discrete_value_size_ << std::endl; 58 | oss << "Game name: " << game_name_ << std::endl; 59 | oss << "Network type name: " << network_type_name_ << std::endl; 60 | oss << "Network file name: " << network_file_name_ << std::endl; 61 | return oss.str(); 62 | } 63 | 
64 | } // namespace minizero::network 65 | -------------------------------------------------------------------------------- /minizero/network/network.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace minizero::network { 9 | 10 | class NetworkOutput { 11 | public: 12 | virtual ~NetworkOutput() = default; 13 | }; 14 | 15 | class Network { 16 | public: 17 | Network(); 18 | virtual ~Network() = default; 19 | 20 | virtual void loadModel(const std::string& nn_file_name, const int gpu_id); 21 | virtual std::string toString() const; 22 | 23 | inline int getGPUID() const { return gpu_id_; } 24 | inline int getNumInputChannels() const { return num_input_channels_; } 25 | inline int getInputChannelHeight() const { return input_channel_height_; } 26 | inline int getInputChannelWidth() const { return input_channel_width_; } 27 | inline int getNumHiddenChannels() const { return num_hidden_channels_; } 28 | inline int getHiddenChannelHeight() const { return hidden_channel_height_; } 29 | inline int getHiddenChannelWidth() const { return hidden_channel_width_; } 30 | inline int getNumBlocks() const { return num_blocks_; } 31 | inline int getActionSize() const { return action_size_; } 32 | inline int getNumValueHiddenChannels() const { return num_value_hidden_channels_; } 33 | inline int getDiscreteValueSize() const { return discrete_value_size_; } 34 | inline std::string getGameName() const { return game_name_; } 35 | inline std::string getNetworkTypeName() const { return network_type_name_; } 36 | inline std::string getNetworkFileName() const { return network_file_name_; } 37 | 38 | protected: 39 | inline torch::Device getDevice() const { return (gpu_id_ == -1 ? 
torch::Device("cpu") : torch::Device(torch::kCUDA, gpu_id_)); } 40 | 41 | int gpu_id_; 42 | int num_input_channels_; 43 | int input_channel_height_; 44 | int input_channel_width_; 45 | int num_hidden_channels_; 46 | int hidden_channel_height_; 47 | int hidden_channel_width_; 48 | int num_blocks_; 49 | int action_size_; 50 | int num_value_hidden_channels_; 51 | int discrete_value_size_; 52 | std::string game_name_; 53 | std::string network_type_name_; 54 | std::string network_file_name_; 55 | torch::jit::script::Module network_; 56 | }; 57 | 58 | } // namespace minizero::network 59 | -------------------------------------------------------------------------------- /minizero/network/py/alphazero_network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .network_unit import ResidualBlock, PolicyNetwork, ValueNetwork, DiscreteValueNetwork 5 | 6 | 7 | class AlphaZeroNetwork(nn.Module): 8 | def __init__(self, 9 | game_name, 10 | num_input_channels, 11 | input_channel_height, 12 | input_channel_width, 13 | num_hidden_channels, 14 | hidden_channel_height, 15 | hidden_channel_width, 16 | num_blocks, 17 | action_size, 18 | num_value_hidden_channels, 19 | discrete_value_size): 20 | super(AlphaZeroNetwork, self).__init__() 21 | self.game_name = game_name 22 | self.num_input_channels = num_input_channels 23 | self.input_channel_height = input_channel_height 24 | self.input_channel_width = input_channel_width 25 | self.num_hidden_channels = num_hidden_channels 26 | self.hidden_channel_height = hidden_channel_height 27 | self.hidden_channel_width = hidden_channel_width 28 | self.num_blocks = num_blocks 29 | self.action_size = action_size 30 | self.num_value_hidden_channels = num_value_hidden_channels 31 | self.discrete_value_size = discrete_value_size 32 | 33 | self.conv = nn.Conv2d(num_input_channels, num_hidden_channels, kernel_size=3, padding=1) 34 | 
self.bn = nn.BatchNorm2d(num_hidden_channels) 35 | self.residual_blocks = nn.ModuleList([ResidualBlock(num_hidden_channels) for _ in range(num_blocks)]) 36 | self.policy = PolicyNetwork(num_hidden_channels, hidden_channel_height, hidden_channel_width, action_size) 37 | if self.discrete_value_size == 1: 38 | self.value = ValueNetwork(num_hidden_channels, hidden_channel_height, hidden_channel_width, num_value_hidden_channels) 39 | else: 40 | self.value = DiscreteValueNetwork(num_hidden_channels, hidden_channel_height, hidden_channel_width, num_value_hidden_channels, discrete_value_size) 41 | 42 | @torch.jit.export 43 | def get_type_name(self): 44 | return "alphazero" 45 | 46 | @torch.jit.export 47 | def get_game_name(self): 48 | return self.game_name 49 | 50 | @torch.jit.export 51 | def get_num_input_channels(self): 52 | return self.num_input_channels 53 | 54 | @torch.jit.export 55 | def get_input_channel_height(self): 56 | return self.input_channel_height 57 | 58 | @torch.jit.export 59 | def get_input_channel_width(self): 60 | return self.input_channel_width 61 | 62 | @torch.jit.export 63 | def get_num_hidden_channels(self): 64 | return self.num_hidden_channels 65 | 66 | @torch.jit.export 67 | def get_hidden_channel_height(self): 68 | return self.hidden_channel_height 69 | 70 | @torch.jit.export 71 | def get_hidden_channel_width(self): 72 | return self.hidden_channel_width 73 | 74 | @torch.jit.export 75 | def get_num_blocks(self): 76 | return self.num_blocks 77 | 78 | @torch.jit.export 79 | def get_action_size(self): 80 | return self.action_size 81 | 82 | @torch.jit.export 83 | def get_num_value_hidden_channels(self): 84 | return self.num_value_hidden_channels 85 | 86 | @torch.jit.export 87 | def get_discrete_value_size(self): 88 | return self.discrete_value_size 89 | 90 | def forward(self, state): 91 | x = self.conv(state) 92 | x = self.bn(x) 93 | x = F.relu(x) 94 | for residual_block in self.residual_blocks: 95 | x = residual_block(x) 96 | 97 | # policy 98 | 
policy_logit = self.policy(x) 99 | policy = torch.softmax(policy_logit, dim=1) 100 | 101 | # value 102 | if self.discrete_value_size == 1: 103 | value = self.value(x) 104 | return {"policy_logit": policy_logit, 105 | "policy": policy, 106 | "value": value} 107 | else: 108 | value_logit = self.value(x) 109 | value = torch.softmax(value_logit, dim=1) 110 | return {"policy_logit": policy_logit, 111 | "policy": policy, 112 | "value_logit": value_logit, 113 | "value": value} 114 | -------------------------------------------------------------------------------- /minizero/network/py/create_network.py: -------------------------------------------------------------------------------- 1 | from .alphazero_network import AlphaZeroNetwork 2 | from .muzero_network import MuZeroNetwork 3 | from .muzero_atari_network import MuZeroAtariNetwork 4 | from .muzero_gridworld_network import MuZeroGridWorldNetwork 5 | 6 | 7 | def create_network(game_name="tietactoe", 8 | num_input_channels=4, 9 | input_channel_height=3, 10 | input_channel_width=3, 11 | num_hidden_channels=16, 12 | hidden_channel_height=3, 13 | hidden_channel_width=3, 14 | num_action_feature_channels=1, 15 | num_blocks=1, 16 | action_size=9, 17 | option_seq_length=1, 18 | option_action_size=19, 19 | num_value_hidden_channels=256, 20 | discrete_value_size=601, 21 | network_type_name="alphazero"): 22 | 23 | network = None 24 | if network_type_name == "alphazero": 25 | network = AlphaZeroNetwork(game_name, 26 | num_input_channels, 27 | input_channel_height, 28 | input_channel_width, 29 | num_hidden_channels, 30 | hidden_channel_height, 31 | hidden_channel_width, 32 | num_blocks, 33 | action_size, 34 | num_value_hidden_channels, 35 | discrete_value_size) 36 | elif network_type_name == "muzero": 37 | if "atari" in game_name: 38 | network = MuZeroAtariNetwork(game_name, 39 | num_input_channels, 40 | input_channel_height, 41 | input_channel_width, 42 | num_hidden_channels, 43 | hidden_channel_height, 44 | hidden_channel_width, 45 
| num_action_feature_channels, 46 | num_blocks, 47 | action_size, 48 | option_seq_length, 49 | option_action_size, 50 | num_value_hidden_channels, 51 | discrete_value_size) 52 | elif "gridworld" in game_name: 53 | network = MuZeroGridWorldNetwork(game_name, 54 | num_input_channels, 55 | input_channel_height, 56 | input_channel_width, 57 | num_hidden_channels, 58 | hidden_channel_height, 59 | hidden_channel_width, 60 | num_action_feature_channels, 61 | num_blocks, 62 | action_size, 63 | option_seq_length, 64 | option_action_size, 65 | num_value_hidden_channels, 66 | discrete_value_size) 67 | else: 68 | network = MuZeroNetwork(game_name, 69 | num_input_channels, 70 | input_channel_height, 71 | input_channel_width, 72 | num_hidden_channels, 73 | hidden_channel_height, 74 | hidden_channel_width, 75 | num_action_feature_channels, 76 | num_blocks, 77 | action_size, 78 | option_seq_length, 79 | option_action_size, 80 | num_value_hidden_channels, 81 | discrete_value_size) 82 | else: 83 | assert False 84 | 85 | return network 86 | -------------------------------------------------------------------------------- /minizero/utils/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | file(GLOB SRCS *.cpp) 2 | 3 | add_library(utils ${SRCS}) 4 | target_include_directories(utils PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) 5 | target_link_libraries(utils ${Boost_LIBRARIES}) -------------------------------------------------------------------------------- /minizero/utils/color_message.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | namespace minizero::utils { 5 | 6 | enum class TextType { 7 | kNormal, 8 | kBold, 9 | kUnderLine, 10 | kSize 11 | }; 12 | 13 | enum class TextColor { 14 | kBlack, 15 | kRed, 16 | kGreen, 17 | kYellow, 18 | kBlue, 19 | kPurple, 20 | kCyan, 21 | kWhite, 22 | kSize 23 | }; 24 | 25 | inline std::string getColorText(std::string text, TextType 
text_type, TextColor text_color, TextColor text_background) 26 | { 27 | const int text_type_number[static_cast(TextType::kSize)] = {0, 1, 4}; 28 | return "\33[" + std::to_string(text_type_number[static_cast(text_type)]) + 29 | ";3" + std::to_string(static_cast(text_color)) + 30 | ";4" + std::to_string(static_cast(text_background)) + 31 | "m" + text + "\33[0m"; 32 | } 33 | 34 | } // namespace minizero::utils 35 | -------------------------------------------------------------------------------- /minizero/utils/ostream_redirector.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | namespace minizero::utils { 8 | 9 | class OstreamRedirector { 10 | public: 11 | OstreamRedirector(std::ostream& src, std::ostream& dst) : src(src), sbuf(src.rdbuf(dst.rdbuf())) {} 12 | ~OstreamRedirector() { src.rdbuf(sbuf); } 13 | 14 | public: 15 | static bool silence(std::ostream& src, bool silence = true) 16 | { 17 | static std::map> redirects; 18 | auto it = redirects.find(&src); 19 | if (silence && it == redirects.end()) { 20 | std::ostream* nullout = new std::ofstream; 21 | OstreamRedirector* redirect = new OstreamRedirector(src, *nullout); 22 | redirects.insert({&src, std::make_pair(nullout, redirect)}); 23 | return true; 24 | } else if (!silence && it != redirects.end()) { 25 | std::ostream* nullout = it->second.first; 26 | OstreamRedirector* redirect = it->second.second; 27 | delete redirect; 28 | delete nullout; 29 | redirects.erase(it); 30 | return true; 31 | } 32 | return false; 33 | } 34 | 35 | private: 36 | std::ostream& src; 37 | std::streambuf* const sbuf; 38 | }; 39 | 40 | } // namespace minizero::utils 41 | -------------------------------------------------------------------------------- /minizero/utils/paralleler.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | 
namespace minizero::utils { 8 | 9 | class BaseSharedData { 10 | public: 11 | BaseSharedData() {} 12 | virtual ~BaseSharedData() = default; 13 | }; 14 | 15 | class BaseSlaveThread { 16 | public: 17 | BaseSlaveThread(int id, std::shared_ptr shared_data) 18 | : id_(id), 19 | shared_data_(shared_data), 20 | start_barrier_(2), 21 | finish_barrier_(2) {} 22 | virtual ~BaseSlaveThread() = default; 23 | 24 | void run() 25 | { 26 | initialize(); 27 | while (!isDone()) { 28 | start_barrier_.wait(); 29 | runJob(); 30 | finish_barrier_.wait(); 31 | } 32 | } 33 | 34 | virtual void initialize() = 0; 35 | virtual void runJob() = 0; 36 | virtual bool isDone() = 0; 37 | 38 | inline void start() { start_barrier_.wait(); } 39 | inline void finish() { finish_barrier_.wait(); } 40 | 41 | protected: 42 | int id_; 43 | std::shared_ptr shared_data_; 44 | boost::barrier start_barrier_; 45 | boost::barrier finish_barrier_; 46 | }; 47 | 48 | class BaseParalleler { 49 | public: 50 | BaseParalleler() {} 51 | 52 | virtual ~BaseParalleler() 53 | { 54 | thread_groups_.interrupt_all(); 55 | thread_groups_.join_all(); 56 | } 57 | 58 | void run() 59 | { 60 | initialize(); 61 | for (auto& t : slave_threads_) { t->start(); } 62 | for (auto& t : slave_threads_) { t->finish(); } 63 | summarize(); 64 | } 65 | 66 | virtual void initialize() = 0; 67 | virtual void summarize() = 0; 68 | 69 | protected: 70 | void createSlaveThreads(int num_threads) 71 | { 72 | createSharedData(); 73 | for (int id = 0; id < num_threads; ++id) { 74 | slave_threads_.emplace_back(newSlaveThread(id)); 75 | thread_groups_.create_thread(boost::bind(&BaseSlaveThread::run, slave_threads_.back())); 76 | } 77 | } 78 | 79 | virtual void createSharedData() = 0; 80 | virtual std::shared_ptr newSlaveThread(int id) = 0; 81 | 82 | boost::thread_group thread_groups_; 83 | std::shared_ptr shared_data_; 84 | std::vector> slave_threads_; 85 | }; 86 | 87 | } // namespace minizero::utils 88 | 
-------------------------------------------------------------------------------- /minizero/utils/random.cpp: -------------------------------------------------------------------------------- 1 | #include "random.h" 2 | 3 | namespace minizero::utils { 4 | 5 | thread_local std::mt19937 Random::generator_; 6 | thread_local std::uniform_int_distribution Random::int_distribution_; 7 | thread_local std::uniform_real_distribution Random::real_distribution_; 8 | 9 | } // namespace minizero::utils 10 | -------------------------------------------------------------------------------- /minizero/utils/random.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace minizero::utils { 8 | 9 | class Random { 10 | public: 11 | static inline void seed(int seed) { generator_.seed(seed); } 12 | static inline int randInt() { return int_distribution_(generator_); } 13 | static inline double randReal(double range = 1.0f) { return real_distribution_(generator_) * range; } 14 | 15 | static inline std::vector randDirichlet(float alpha, int size) 16 | { 17 | std::vector dirichlet; 18 | std::gamma_distribution gamma_distribution(alpha); 19 | for (int i = 0; i < size; ++i) { dirichlet.emplace_back(gamma_distribution(generator_)); } 20 | float sum = std::accumulate(dirichlet.begin(), dirichlet.end(), 0.0f); 21 | if (sum < std::numeric_limits::min()) { return dirichlet; } 22 | for (int i = 0; i < size; ++i) { dirichlet[i] /= sum; } 23 | return dirichlet; 24 | } 25 | 26 | static inline std::vector randGumbel(int size) 27 | { 28 | std::extreme_value_distribution gumbel_distribution(0.0, 1.0); 29 | std::vector gumbel; 30 | for (int i = 0; i < size; ++i) { 31 | float value = gumbel_distribution(generator_); 32 | while (std::isinf(value)) { value = gumbel_distribution(generator_); } 33 | gumbel.emplace_back(value); 34 | } 35 | return gumbel; 36 | } 37 | 38 | static thread_local std::mt19937 
generator_; 39 | static thread_local std::uniform_int_distribution int_distribution_; 40 | static thread_local std::uniform_real_distribution real_distribution_; 41 | }; 42 | 43 | } // namespace minizero::utils 44 | -------------------------------------------------------------------------------- /minizero/utils/rotation.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace minizero::utils { 8 | 9 | enum class Rotation { 10 | kRotationNone, 11 | kRotation90, 12 | kRotation180, 13 | kRotation270, 14 | kHorizontalRotation, 15 | kHorizontalRotation90, 16 | kHorizontalRotation180, 17 | kHorizontalRotation270, 18 | kRotateSize 19 | }; 20 | 21 | const Rotation reversed_rotation[static_cast(Rotation::kRotateSize)] = { 22 | Rotation::kRotationNone, 23 | Rotation::kRotation270, 24 | Rotation::kRotation180, 25 | Rotation::kRotation90, 26 | Rotation::kHorizontalRotation, 27 | Rotation::kHorizontalRotation90, 28 | Rotation::kHorizontalRotation180, 29 | Rotation::kHorizontalRotation270}; 30 | 31 | const std::string rotation_string[static_cast(Rotation::kRotateSize)] = { 32 | "Rotation_None", 33 | "Rotation_90_Degree", 34 | "Rotation_180_Degree", 35 | "Rotation_270_Degree", 36 | "Horizontal_Rotation", 37 | "Horizontal_Rotation_90_Degree", 38 | "Horizontal_Rotation_180_Degree", 39 | "Horizontal_Rotation_270_Degree"}; 40 | 41 | inline std::string getRotationString(Rotation rotate) { return rotation_string[static_cast(rotate)]; } 42 | 43 | inline Rotation getRotationFromString(const std::string rotation_str) 44 | { 45 | for (int i = 0; i < static_cast(Rotation::kRotateSize); ++i) { 46 | if (rotation_str == rotation_string[i]) { return static_cast(i); } 47 | } 48 | return Rotation::kRotateSize; 49 | } 50 | 51 | inline int getPositionByRotating(Rotation rotation, int original_pos, int board_size) 52 | { 53 | assert(original_pos >= 0 && original_pos <= board_size * board_size); 
54 | if (original_pos >= board_size * board_size) { return original_pos; } 55 | 56 | const float center = (board_size - 1) / 2.0; 57 | float x = original_pos % board_size - center; 58 | float y = original_pos / board_size - center; 59 | float rotation_x = x, rotation_y = y; 60 | switch (rotation) { 61 | case Rotation::kRotationNone: 62 | rotation_x = x, rotation_y = y; 63 | break; 64 | case Rotation::kRotation90: 65 | rotation_x = y, rotation_y = -x; 66 | break; 67 | case Rotation::kRotation180: 68 | rotation_x = -x, rotation_y = -y; 69 | break; 70 | case Rotation::kRotation270: 71 | rotation_x = -y, rotation_y = x; 72 | break; 73 | case Rotation::kHorizontalRotation: 74 | rotation_x = x, rotation_y = -y; 75 | break; 76 | case Rotation::kHorizontalRotation90: 77 | rotation_x = -y, rotation_y = -x; 78 | break; 79 | case Rotation::kHorizontalRotation180: 80 | rotation_x = -x, rotation_y = y; 81 | break; 82 | case Rotation::kHorizontalRotation270: 83 | rotation_x = y, rotation_y = x; 84 | break; 85 | default: 86 | assert(false); 87 | break; 88 | } 89 | 90 | int new_pos = (rotation_y + center) * board_size + (rotation_x + center); 91 | assert(new_pos >= 0 && new_pos < board_size * board_size); 92 | return new_pos; 93 | } 94 | 95 | } // namespace minizero::utils 96 | -------------------------------------------------------------------------------- /minizero/utils/sgf_loader.cpp: -------------------------------------------------------------------------------- 1 | #include "sgf_loader.h" 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | namespace minizero::utils { 8 | 9 | bool SGFLoader::loadFromFile(const std::string& file_name) 10 | { 11 | std::ifstream fin(file_name.c_str()); 12 | if (!fin) { return false; } 13 | 14 | std::string line; 15 | std::string sgf_content; 16 | while (std::getline(fin, line)) { 17 | if (line.back() == '\r') { line.pop_back(); } 18 | if (line.empty()) { continue; } 19 | sgf_content += line; 20 | } 21 | return 
loadFromString(sgf_content); 22 | } 23 | 24 | bool SGFLoader::loadFromString(const std::string& content) 25 | { 26 | reset(); 27 | sgf_content_ = content; 28 | std::string key, value; 29 | int state = '('; 30 | bool accept_move = false; 31 | bool escape_next = false; 32 | int board_size = -1; 33 | for (char c : content) { 34 | switch (state) { 35 | case '(': // wait until record start 36 | if (!accept_move) { 37 | accept_move = (c == '('); 38 | } else { 39 | state = (c == ';') ? c : 'x'; 40 | accept_move = false; 41 | } 42 | break; 43 | case ';': // store key 44 | if (c == ';') { 45 | accept_move = true; 46 | } else if (c == '[' || c == ')') { 47 | state = c; 48 | } else if (std::isgraph(c)) { 49 | key += c; 50 | } 51 | break; 52 | case '[': // store value 53 | if (c == '\\' && !escape_next) { 54 | escape_next = true; 55 | } else if (c != ']' || escape_next) { 56 | value += c; 57 | escape_next = false; 58 | } else { // ready to store key-value pair 59 | if (accept_move) { 60 | if (board_size == -1) { return false; } 61 | actions_.emplace_back().first = SGFAction(key, actionIDToBoardCoordinateString(sgfStringToActionID(value, board_size), board_size)); 62 | accept_move = false; 63 | } else if (actions_.size()) { 64 | actions_.back().second[key] = std::move(value); 65 | } else { 66 | if (key == "SZ") { board_size = std::stoi(value); } 67 | tags_[key] = std::move(value); 68 | } 69 | key.clear(); 70 | value.clear(); 71 | state = ';'; 72 | } 73 | break; 74 | case ')': // end of record, do nothing 75 | break; 76 | } 77 | } 78 | return state == ')'; 79 | } 80 | 81 | void SGFLoader::reset() 82 | { 83 | file_name_.clear(); 84 | sgf_content_.clear(); 85 | tags_.clear(); 86 | actions_.clear(); 87 | } 88 | 89 | int SGFLoader::boardCoordinateStringToActionID(const std::string& board_coordinate_string, int board_size) 90 | { 91 | std::string tmp = board_coordinate_string; 92 | std::transform(tmp.begin(), tmp.end(), tmp.begin(), ::toupper); 93 | if (tmp == "PASS") { return 
board_size * board_size; } 94 | 95 | if (board_coordinate_string.size() < 2) { return -1; } 96 | int x = std::toupper(board_coordinate_string[0]) - 'A' + (std::toupper(board_coordinate_string[0]) > 'I' ? -1 : 0); 97 | int y = atoi(board_coordinate_string.substr(1).c_str()) - 1; 98 | return y * board_size + x; 99 | } 100 | 101 | std::string SGFLoader::actionIDToBoardCoordinateString(int action_id, int board_size) 102 | { 103 | assert(action_id >= 0 && action_id <= board_size * board_size); 104 | 105 | if (action_id == board_size * board_size) { return "PASS"; } 106 | int x = action_id % board_size; 107 | int y = action_id / board_size; 108 | std::ostringstream oss; 109 | oss << static_cast(x + 'A' + (x >= 8)) << y + 1; 110 | return oss.str(); 111 | } 112 | 113 | int SGFLoader::sgfStringToActionID(const std::string& sgf_string, int board_size) 114 | { 115 | if (sgf_string.size() != 2) { return board_size * board_size; } 116 | int x = std::toupper(sgf_string[0]) - 'A'; 117 | int y = (board_size - 1) - (std::toupper(sgf_string[1]) - 'A'); 118 | return y * board_size + x; 119 | } 120 | 121 | std::string SGFLoader::actionIDToSGFString(int action_id, int board_size) 122 | { 123 | assert(action_id >= 0 && action_id <= board_size * board_size); 124 | 125 | if (action_id == board_size * board_size) { return ""; } 126 | int x = action_id % board_size; 127 | int y = action_id / board_size; 128 | std::ostringstream oss; 129 | oss << static_cast(x + 'a') << static_cast(((board_size - 1) - y) + 'a'); 130 | return oss.str(); 131 | } 132 | 133 | std::string SGFLoader::trimSpace(const std::string& s) const 134 | { 135 | bool skip = false; 136 | std::string new_s; 137 | for (const auto& c : s) { 138 | skip = (c == '[' ? true : (c == ']' ? 
false : skip)); 139 | if (skip || c != ' ') { new_s += c; } 140 | } 141 | return new_s; 142 | } 143 | 144 | } // namespace minizero::utils 145 | -------------------------------------------------------------------------------- /minizero/utils/sgf_loader.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "vector_map.h" 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | namespace minizero::utils { 10 | 11 | class SGFLoader { 12 | public: 13 | class SGFAction : public std::vector { 14 | public: 15 | SGFAction(const std::string& player, const std::string& move) 16 | { 17 | push_back(player); 18 | push_back(move); 19 | } 20 | SGFAction(const std::vector& action) : std::vector(action) {} 21 | SGFAction() {} 22 | }; 23 | 24 | typedef minizero::utils::VectorMap SGFTags; 25 | typedef minizero::utils::VectorMap SGFActionInfo; 26 | 27 | public: 28 | virtual bool loadFromFile(const std::string& file_name); 29 | virtual bool loadFromString(const std::string& sgf_content); 30 | 31 | inline const std::string& getFileName() const { return file_name_; } 32 | inline const std::string& getSGFContent() const { return sgf_content_; } 33 | inline const SGFTags& getTags() const { return tags_; } 34 | inline const std::vector>& getActions() const { return actions_; } 35 | 36 | static int boardCoordinateStringToActionID(const std::string& board_coordinate_string, int board_size); 37 | static std::string actionIDToBoardCoordinateString(int action_id, int board_size); 38 | static int sgfStringToActionID(const std::string& sgf_string, int board_size); 39 | static std::string actionIDToSGFString(int action_id, int board_size); 40 | 41 | protected: 42 | virtual void reset(); 43 | std::string trimSpace(const std::string& s) const; 44 | 45 | std::string file_name_; 46 | std::string sgf_content_; 47 | SGFTags tags_; 48 | std::vector> actions_; 49 | }; 50 | 51 | } // namespace minizero::utils 52 | 
-------------------------------------------------------------------------------- /minizero/utils/thread_pool.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | namespace minizero::utils { 19 | 20 | class ThreadPool { 21 | public: 22 | /** 23 | * Starting a muti-thread task. 24 | * 25 | * @param work Callback function with two integer parameters `wid` and `tid`. User defined task. 26 | * `wid`: unique work(task) id from 0 to `n_works` - 1. 27 | * `tid`: unique thread id from 0 to `n_workers` - 1. 28 | * @param n_works n independent tasks to do. 29 | * @param n_workers n threads to use. 30 | * @return No return. 31 | */ 32 | void start(std::function work, int n_works, int n_workers) 33 | { 34 | n_works_ = n_works; 35 | n_workers_ = n_workers; 36 | works_count_ = 0; 37 | workers_.clear(); 38 | for (int i = 0; i < n_workers_; i++) 39 | workers_.push_back(std::thread(&ThreadPool::start_worker, this, work, i)); 40 | for (auto& w : workers_) 41 | w.join(); 42 | } 43 | 44 | private: 45 | void start_worker(std::function work, int tid) 46 | { 47 | while (true) { 48 | int work_id = works_count_++; 49 | if (work_id >= n_works_) break; 50 | work(work_id, tid); 51 | } 52 | } 53 | 54 | private: 55 | int n_works_; 56 | int n_workers_; 57 | std::atomic works_count_; 58 | std::vector workers_; 59 | }; 60 | 61 | } // namespace minizero::utils -------------------------------------------------------------------------------- /minizero/utils/time_system.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace minizero::utils { 9 | 10 | class TimeSystem { 11 | public: 12 | static boost::posix_time::ptime getLocalTime() { 
return boost::posix_time::microsec_clock::local_time(); } 13 | 14 | static std::string getTimeString(std::string format = "Y/m/d H:i:s", boost::posix_time::ptime local_time = getLocalTime()) 15 | { 16 | std::string time_string; 17 | bool is_escape = false; 18 | 19 | for (size_t i = 0; i < format.length(); ++i) { 20 | if (is_escape) { 21 | time_string += format.at(i); 22 | is_escape = false; 23 | } else { 24 | switch (format.at(i)) { 25 | case 'Y': time_string += translateIntToString(local_time.date().year()); break; 26 | case 'y': time_string += translateIntToString(local_time.date().year() % 100, 2); break; 27 | case 'm': time_string += translateIntToString(local_time.date().month(), 2); break; 28 | case 'd': time_string += translateIntToString(local_time.date().day(), 2); break; 29 | case 'H': time_string += translateIntToString(local_time.time_of_day().hours(), 2); break; 30 | case 'i': time_string += translateIntToString(local_time.time_of_day().minutes(), 2); break; 31 | case 's': time_string += translateIntToString(local_time.time_of_day().seconds(), 2); break; 32 | case 'f': time_string += translateIntToString(local_time.time_of_day().total_milliseconds() % 1000, 3); break; 33 | case 'u': time_string += translateIntToString(local_time.time_of_day().total_microseconds() % 1000000, 6); break; 34 | case '\\': is_escape = true; break; 35 | default: time_string += format.at(i); break; 36 | } 37 | } 38 | } 39 | return time_string; 40 | } 41 | 42 | private: 43 | static std::string translateIntToString(int value, int width = 0) 44 | { 45 | char buf[16]; 46 | static char zero_fill_format[] = "%0*d", non_zero_fill_format[] = "%*d"; 47 | 48 | if (width > 15) width = 15; 49 | if (width < 0) width = 0; 50 | 51 | char* format = (width ? 
zero_fill_format : non_zero_fill_format); 52 | snprintf(buf, sizeof(buf), format, width, value); 53 | return buf; 54 | } 55 | }; 56 | 57 | } // namespace minizero::utils 58 | -------------------------------------------------------------------------------- /minizero/utils/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | namespace minizero::utils { 15 | 16 | inline std::vector stringToVector(const std::string& s, const std::string& delim = " ", bool compress = true) 17 | { 18 | std::vector args; 19 | if (delim.size()) { 20 | args.reserve(std::count(s.begin(), s.end(), delim.front())); 21 | std::string::size_type pos = 0, end; 22 | while ((end = s.find(delim, pos)) != std::string::npos) { 23 | if (end > pos || !compress) { args.emplace_back(s.substr(pos, end - pos)); } 24 | pos = end + delim.size(); 25 | } 26 | if (s.size() > pos || !compress) { args.emplace_back(s.substr(pos)); } 27 | } else { 28 | if (!compress) { args.emplace_back(); } 29 | for (char ch : s) { args.emplace_back(1, ch); } 30 | if (!compress) { args.emplace_back(); } 31 | } 32 | return args; 33 | } 34 | 35 | inline std::string compressToBinaryString(const std::string& s) 36 | { 37 | if (s.empty()) { return s; } 38 | 39 | // use gzip to compress string 40 | std::stringstream compressed; 41 | boost::iostreams::filtering_streambuf out; 42 | out.push(boost::iostreams::gzip_compressor()); 43 | out.push(compressed); 44 | boost::iostreams::copy(boost::iostreams::basic_array_source{s.data(), s.size()}, out); 45 | boost::iostreams::close(out); 46 | return compressed.str(); 47 | } 48 | 49 | inline std::string binaryToHexString(const std::string& s) 50 | { 51 | // encode binary string to hex string 52 | std::ostringstream oss; 53 | for (size_t i = 0; i < s.size(); ++i) { 54 | oss << std::setfill('0') << 
std::setw(2) << std::hex << static_cast(static_cast(s[i])); 55 | } 56 | return oss.str(); 57 | } 58 | 59 | inline std::string hexToBinaryString(const std::string& s) 60 | { 61 | assert(s.size() % 2 == 0); 62 | 63 | // decode hex string to binary string 64 | std::string decompressed_string; 65 | for (size_t i = 0; i < s.size(); i += 2) { decompressed_string += static_cast(std::stoi(s.substr(i, 2), 0, 16)); } 66 | return decompressed_string; 67 | } 68 | 69 | inline std::string decompressBinaryString(const std::string& s) 70 | { 71 | if (s.empty()) { return s; } 72 | 73 | // use gzip to decompress binary string 74 | boost::iostreams::filtering_streambuf in; 75 | in.push(boost::iostreams::gzip_decompressor()); 76 | in.push(boost::iostreams::basic_array_source(s.data(), s.size())); 77 | std::stringstream decompressed; 78 | boost::iostreams::copy(in, decompressed); 79 | boost::iostreams::close(in); 80 | return decompressed.str(); 81 | } 82 | 83 | inline std::string compressString(const std::string& s) 84 | { 85 | return binaryToHexString(compressToBinaryString(s)); 86 | } 87 | 88 | inline std::string decompressString(const std::string& s) 89 | { 90 | return decompressBinaryString(hexToBinaryString(s)); 91 | } 92 | 93 | inline float transformValue(float value) 94 | { 95 | // reference: Observe and Look Further: Achieving Consistent Performance on Atari, page 11 96 | const float epsilon = 0.001; 97 | const float sign_value = (value > 0.0f ? 1.0f : (value == 0.0f ? 0.0f : -1.0f)); 98 | value = sign_value * (sqrt(fabs(value) + 1) - 1) + epsilon * value; 99 | return value; 100 | } 101 | 102 | inline float invertValue(float value) 103 | { 104 | // reference: Observe and Look Further: Achieving Consistent Performance on Atari, page 11 105 | const float epsilon = 0.001; 106 | const float sign_value = (value > 0.0f ? 1.0f : (value == 0.0f ? 
0.0f : -1.0f)); 107 | return sign_value * (powf((sqrt(1 + 4 * epsilon * (fabs(value) + 1 + epsilon)) - 1) / (2 * epsilon), 2.0f) - 1); 108 | } 109 | 110 | template 111 | float stddev(const std::vector& input) 112 | { 113 | if (input.size() <= 1) { return 0.0f; } 114 | 115 | double mean = std::accumulate(input.begin(), input.end(), 0.0) / input.size(); 116 | double variance = std::accumulate(input.begin(), input.end(), 0.0, [&](double sum, const T& value) { return sum + std::pow(static_cast(value) - mean, 2); }); 117 | return std::sqrt(variance / (input.size() - 1)); 118 | } 119 | 120 | } // namespace minizero::utils 121 | -------------------------------------------------------------------------------- /minizero/utils/vector_map.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | namespace minizero::utils { 10 | 11 | template 12 | class VectorMap { 13 | public: 14 | typedef std::pair Item; 15 | 16 | VectorMap() {} 17 | VectorMap(const VectorMap& info) : info_(info.info_) {} 18 | VectorMap(VectorMap&& info) : info_(std::move(info.info_)) {} 19 | VectorMap(const std::vector& info) 20 | { 21 | for (const Item& item : info) { info_.emplace_back(item); } 22 | } 23 | VectorMap(std::vector&& info) 24 | { 25 | std::vector buff = std::move(info); 26 | for (Item& item : buff) { info_.emplace_back(std::move(item)); } 27 | } 28 | 29 | public: 30 | operator std::vector &() { return info_; } 31 | operator const std::vector &() const { return info_; } 32 | VectorMap& operator=(const VectorMap& info) 33 | { 34 | info_ = info.info_; 35 | return *this; 36 | } 37 | VectorMap& operator=(VectorMap&& info) 38 | { 39 | info_ = std::move(info.info_); 40 | return *this; 41 | } 42 | Value& operator[](const Key& key) 43 | { 44 | auto it = find(key); 45 | return (it != end()) ? 
it->second : info_.emplace_back(key, Value()).second; 46 | } 47 | const Value& operator[](const Key& key) const 48 | { 49 | static Value npos; 50 | auto it = find(key); 51 | return (it != end()) ? it->second : npos; 52 | } 53 | 54 | public: 55 | Value& at(const Key& key) 56 | { 57 | auto it = find(key); 58 | if (it != end()) { return it->second; } 59 | throw std::out_of_range("key not found"); 60 | } 61 | const Value& at(const Key& key) const 62 | { 63 | auto it = find(key); 64 | if (it != end()) { return it->second; } 65 | throw std::out_of_range("key not found"); 66 | } 67 | bool empty() const { return info_.empty(); } 68 | size_t size() const { return info_.size(); } 69 | void reserve(size_t n) { info_.reserve(n); } 70 | 71 | auto insert(const Item& p) 72 | { 73 | auto it = find(p.first); 74 | return std::pair(it == end() ? info_.insert(it, p) : it, it == end()); 75 | } 76 | auto insert(Item&& p) 77 | { 78 | auto it = find(p.first); 79 | return std::pair(it == end() ? info_.insert(it, std::move(p)) : it, it == end()); 80 | } 81 | auto erase(const Key& key) 82 | { 83 | auto it = find(key); 84 | return (it != end()) ? info_.erase(it) : end(); 85 | } 86 | void clear() { info_.clear(); } 87 | 88 | auto find(const Key& key) 89 | { 90 | return std::find_if(begin(), end(), [&](const Item& p) { return p.first == key; }); 91 | } 92 | auto find(const Key& key) const 93 | { 94 | return std::find_if(begin(), end(), [&](const Item& p) { return p.first == key; }); 95 | } 96 | size_t count(const Key& key) const { return find(key) != end() ? 
1 : 0; } 97 | bool contains(const Key& key) const { return count(key) > 0; } 98 | 99 | public: 100 | auto begin() { return info_.begin(); } 101 | auto end() { return info_.end(); } 102 | auto begin() const { return info_.begin(); } 103 | auto end() const { return info_.end(); } 104 | 105 | private: 106 | std::deque info_; 107 | }; 108 | 109 | } // namespace minizero::utils 110 | -------------------------------------------------------------------------------- /minizero/zero/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | file(GLOB SRCS *.cpp) 2 | 3 | add_library(zero ${SRCS}) 4 | target_include_directories( 5 | zero PUBLIC 6 | ${CMAKE_CURRENT_SOURCE_DIR} 7 | ) 8 | target_link_libraries( 9 | zero 10 | config 11 | utils 12 | ${Boost_LIBRARIES} 13 | ) -------------------------------------------------------------------------------- /minizero/zero/zero_server.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "base_server.h" 4 | #include "configuration.h" 5 | #include "time_system.h" 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | namespace minizero::zero { 14 | 15 | class ZeroLogger { 16 | public: 17 | ZeroLogger() {} 18 | void createLog(); 19 | 20 | inline void addWorkerLog(const std::string& log_str) { addLog(log_str, worker_log_); } 21 | inline void addTrainingLog(const std::string& log_str) { addLog(log_str, training_log_); } 22 | inline std::fstream& getSelfPlayFileStream() { return self_play_game_; } 23 | 24 | private: 25 | void addLog(const std::string& log_str, std::fstream& log_file); 26 | 27 | std::fstream worker_log_; 28 | std::fstream training_log_; 29 | std::fstream self_play_game_; 30 | }; 31 | 32 | class ZeroSelfPlayData { 33 | public: 34 | bool is_terminal_; 35 | int data_length_; 36 | int game_length_; 37 | float return_; 38 | std::string game_record_; 39 | 40 | ZeroSelfPlayData() {} 41 | 
ZeroSelfPlayData(std::string input_data); 42 | }; 43 | 44 | class ZeroWorkerSharedData { 45 | public: 46 | ZeroWorkerSharedData(boost::mutex& worker_mutex) 47 | : worker_mutex_(worker_mutex) 48 | { 49 | } 50 | 51 | bool getSelfPlayData(ZeroSelfPlayData& sp_data); 52 | bool isOptimizationPahse(); 53 | int getModelIetration(); 54 | 55 | bool is_optimization_phase_; 56 | int num_op_worker_; 57 | int total_games_; 58 | int model_iteration_; 59 | ZeroLogger logger_; 60 | std::string updated_conf_str_; 61 | std::queue sp_data_queue_; 62 | boost::mutex mutex_; 63 | boost::mutex& worker_mutex_; 64 | }; 65 | 66 | class ZeroWorkerHandler : public utils::ConnectionHandler { 67 | public: 68 | ZeroWorkerHandler(boost::asio::io_service& io_service, ZeroWorkerSharedData& shared_data) 69 | : ConnectionHandler(io_service), 70 | is_idle_(false), 71 | shared_data_(shared_data) 72 | { 73 | } 74 | 75 | void handleReceivedMessage(const std::string& message) override; 76 | void close() override; 77 | void syncConfig(); 78 | 79 | inline bool isIdle() const { return is_idle_; } 80 | inline std::string getName() const { return name_; } 81 | inline std::string getType() const { return type_; } 82 | inline void setIdle(bool is_idle) { is_idle_ = is_idle; } 83 | 84 | private: 85 | bool is_idle_; 86 | std::string name_; 87 | std::string type_; 88 | ZeroWorkerSharedData& shared_data_; 89 | }; 90 | 91 | class ZeroServer : public utils::BaseServer { 92 | public: 93 | ZeroServer() 94 | : BaseServer(minizero::config::zero_server_port), 95 | shared_data_(worker_mutex_), 96 | keep_alive_timer_(io_service_) 97 | { 98 | startKeepAlive(); 99 | } 100 | 101 | virtual void run(); 102 | boost::shared_ptr handleAcceptNewConnection() override { return boost::make_shared(io_service_, shared_data_); } 103 | void sendInitialMessage(boost::shared_ptr connection) override {} 104 | 105 | protected: 106 | virtual void initialize(); 107 | virtual void selfPlay(); 108 | virtual void broadcastSelfPlayJob(); 109 | 
virtual void optimization(); 110 | virtual std::string getUpdatedConfig(); 111 | void syncConfig(); 112 | void stopJob(const std::string& job_type); 113 | void close(); 114 | void keepAlive(); 115 | void startKeepAlive(); 116 | 117 | int iteration_; 118 | ZeroWorkerSharedData shared_data_; 119 | boost::asio::deadline_timer keep_alive_timer_; 120 | }; 121 | 122 | } // namespace minizero::zero 123 | -------------------------------------------------------------------------------- /scripts/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | support_games=("atari" "gridworld") 5 | 6 | usage() { 7 | echo "Usage: $0 GAME_TYPE BUILD_TYPE" 8 | echo "" 9 | echo "Required arguments:" 10 | echo " GAME_TYPE: $(echo ${support_games[@]} | sed 's/ /, /g')" 11 | echo " BUILD_TYPE: release(default), debug" 12 | exit 1 13 | } 14 | 15 | build_game() { 16 | # check arguments is vaild 17 | game_type=${1,,} 18 | build_type=$2 19 | [[ " ${support_games[*]} " == *" ${game_type} "* ]] || usage 20 | [ "${build_type}" == "Debug" ] || [ "${build_type}" == "Release" ] || usage 21 | 22 | # check whether the build type and cache are consistent 23 | if [ -f "build/${game_type}/CMakeCache.txt" ]; then 24 | cache_build_type=$(grep -oP "CMAKE_BUILD_TYPE:STRING=\K\w+" build/${game_type}/CMakeCache.txt) 25 | if [ "${cache_build_type}" != "${build_type}" ]; then 26 | rm -rf build/${game_type} 27 | fi 28 | fi 29 | 30 | # build 31 | echo "game type: ${game_type}" 32 | echo "build type: ${build_type}" 33 | if [ ! 
-f "build/${game_type}/Makefile" ]; then 34 | mkdir -p build/${game_type} 35 | cd build/${game_type} 36 | cmake ../../ -DCMAKE_BUILD_TYPE=${build_type} -DGAME_TYPE=${game_type^^} 37 | else 38 | cd build/${game_type} 39 | fi 40 | 41 | # create git info file 42 | git_hash=$(git log -1 --format=%H) 43 | git_short_hash=$(git describe --abbrev=6 --dirty --always --exclude '*') 44 | mkdir -p git_info 45 | git_info=$(echo -e "#pragma once\n\n#define GIT_HASH \"${git_hash}\"\n#define GIT_SHORT_HASH \"${git_short_hash}\"") 46 | if [ ! -f git_info/git_info.h ] || [ $(diff -q <(echo "${git_info}") <(cat git_info/git_info.h) | wc -l 2>/dev/null) -ne 0 ]; then 47 | echo "${git_info}" > git_info/git_info.h 48 | fi 49 | 50 | # make 51 | make -j$(nproc --all) 52 | cd ../.. 53 | } 54 | 55 | # add environment settings 56 | git config core.hooksPath .githooks 57 | 58 | game_type=${1:-all} 59 | build_type=${2:-release} 60 | build_type=$(echo ${build_type:0:1} | tr '[:lower:]' '[:upper:]')$(echo ${build_type:1} | tr '[:upper:]' '[:lower:]') 61 | [ "${game_type}" == "all" ] && [ ! -d "build" ] && usage 62 | 63 | if [ "${game_type}" == "all" ]; then 64 | for game in build/* 65 | do 66 | build_game $(basename ${game}) ${build_type} 67 | done 68 | else 69 | build_game ${game_type} ${build_type} 70 | fi 71 | -------------------------------------------------------------------------------- /scripts/start-container.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | usage() 5 | { 6 | echo "Usage: $0 [OPTION]..." 
7 | echo "" 8 | echo "Optional arguments:" 9 | echo " -h, --help Give this help list" 10 | echo " --image Select the image name of the container" 11 | echo " -v, --volume Bind mount a volume into the container" 12 | echo " --name Assign a name to the container" 13 | echo " -d, --detach Run container in background and print container ID" 14 | echo " -H, --history Record the container bash history" 15 | exit 1 16 | } 17 | 18 | image_name=kds285/minizero:latest 19 | container_tool=$(basename $(which podman || which docker) 2>/dev/null) 20 | if [[ ! $container_tool ]]; then 21 | echo "Neither podman nor docker is installed." >&2 22 | exit 1 23 | fi 24 | container_volume="-v .:/workspace" 25 | container_argumenets="" 26 | record_history=false 27 | while :; do 28 | case $1 in 29 | -h|--help) shift; usage 30 | ;; 31 | --image) shift; image_name=${1} 32 | ;; 33 | -v|--volume) shift; container_volume="${container_volume} -v ${1}" 34 | ;; 35 | --name) shift; container_argumenets="${container_argumenets} --name ${1}" 36 | ;; 37 | -d|--detach) container_argumenets="${container_argumenets} -d" 38 | ;; 39 | -H|--history) record_history=true 40 | ;; 41 | "") break 42 | ;; 43 | *) echo "Unknown argument: $1"; usage 44 | ;; 45 | esac 46 | shift 47 | done 48 | 49 | if [ "$record_history" = true ]; then 50 | history_dir=".container_root" 51 | # if the history directory is not exist, create it and initialize the history directory 52 | if [ ! -d ${history_dir} ]; then 53 | mkdir -p ${history_dir} 54 | # start container with the history directory and copy the history to the history directory 55 | $container_tool run --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --rm -it -v "${history_dir}:/container_root" ${image_name} /bin/bash -c "cp -r /root/. 
/container_root && touch /container_root/.bash_history && exit" 56 | fi 57 | # add the history directory to the container volume 58 | container_volume="${container_volume} -v ${history_dir}:/root" 59 | fi 60 | 61 | container_argumenets=$(echo ${container_argumenets} | xargs) 62 | echo "$container_tool run ${container_argumenets} --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --network=host --ipc=host --rm -it ${container_volume} ${image_name}" 63 | $container_tool run ${container_argumenets} --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --network=host --ipc=host --rm -it ${container_volume} ${image_name} 64 | -------------------------------------------------------------------------------- /tools/count-moves.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # ID=N:F:U:R:L:D:UR:UL:DR:DL:UF:RF:LF:DF:URF:ULF:DRF:DLF 3 | 4 | shopt -s lastpipe 5 | ID=($(tr ':' ' ' <<< ${ID:-N:F:U:R:L:D:UR:UL:DR:DL:UF:RF:LF:DF:URF:ULF:DRF:DLF})) 6 | NUM=${#ID[@]} 7 | moves=$(tr "(;)" '\n' | grep "^B" | sed "s/$/OP1[$NUM]/"; echo) 8 | paste <(grep -Eo "^B\[[0-9]+\]" <<< $moves | tr -d "B[]") <(grep -Eo "\]OP1\[[0-9-]+\].*" <<< $moves | cut -b6- | cut -d']' -f1) | tr '\t' ' ' | sed "s/^$NUM //;s/ $NUM$//" | \ 9 | while IFS= read -r res; do 10 | res=($res) 11 | echo +${res[0]} 12 | echo :${res[-1]} 13 | done | sort | uniq -c | while read -r count move; do 14 | echo $move $count 15 | done | sort -V | while read -r move count; do 16 | type=${move:0:1} 17 | move=${move:1} 18 | echo $type $(for a in $(tr '-' ' ' <<< $move); do 19 | echo ${ID[$a]} 20 | done | xargs | tr ' ' '-') $count 21 | done 22 | -------------------------------------------------------------------------------- /tools/count-options-depth-percentile.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | which bc >/dev/null || { apt -qq update && apt -qq -y install bc || exit $?; } >/dev/null 2>&1 3 | shopt -s 
lastpipe 4 | 5 | sp_executable_file=$1; shift 6 | tmp1=$(mktemp) 7 | tmp2=$(mktemp) 8 | tmp3=$(mktemp) 9 | trap 'rm -f $tmp1 $tmp2 $tmp3' EXIT 10 | 11 | grep -Eo "\]SP\[[0-9a-f]*\]" | tr -d "SP[]" | ${sp_executable_file} -mode decompress_str -conf_str program_quiet=true | tr ':' ' ' | sed -E "s/[0-9]+/+/g;s/-//g;s/^ //" | awk '{ 12 | sum = 0 13 | max = 0 14 | for (i = 1; i <= NF; i++) { 15 | n = length($i) 16 | sum += n 17 | if (n > max) max = n 18 | } 19 | avg = sum / NF 20 | print avg, max 21 | }' > $tmp1 22 | cut -d' ' -f1 < $tmp1 | sort -n > $tmp2 # avg 23 | cut -d' ' -f2 < $tmp1 | sort -n > $tmp3 # max 24 | 25 | { 26 | for f in $tmp2 $tmp3; do # avg and max 27 | awk '{ sum += $1; count++ } END { print sum / count }' < $f 28 | ln=$(wc -l < $f) 29 | head -n1 $f # 0th percentile 30 | for idx in ${@:-"25 50 75 100"}; do 31 | ln_idx=$((idx*ln/100)) 32 | awk -v ln_idx=$ln_idx 'NR == ln_idx { print $0 }' < $f 33 | done 34 | done 35 | } | xargs 36 | -------------------------------------------------------------------------------- /tools/count-options-in-tree.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | which bc >/dev/null || { apt -qq update && apt -qq -y install bc || exit $?; } >/dev/null 2>&1 4 | shopt -s lastpipe 5 | 6 | sp_executable_file=$1 7 | Lsum=() 8 | Lmax=() 9 | N=() 10 | N_opt=() 11 | N_opt_sim=() 12 | N_type_opt=() 13 | while IFS= read -r sgf; do 14 | sgf=${sgf:1:-1} 15 | [[ $sgf =~ DLEN\[[0-9]+-([0-9]+)\] ]] || exit $? 
16 | len=$((BASH_REMATCH[1]+1)) 17 | echo $((++ID)): $len moves >&2 18 | 19 | moves=$(tr ';' '\n' <<< $sgf | grep "^B") 20 | moves_nl_SP=$(grep -oEn "\]SP\[[0-9a-f]*\]" <<< $moves) 21 | moves_nl_C=$(grep -oEn "\]C\[[0-9]*\]" <<< $moves | tr -d "C[]") 22 | moves_nl_OP=$(grep -oEn "\]OP\[[0-9:-]*\]" <<< $moves | tr -d "OP[]") 23 | [[ $(wc -l <<< $moves_nl_SP ) == $(wc -l <<< $moves_nl_C ) ]] || { echo mismatch SP/C >&2; exit 100; } 24 | [[ $(cut -d: -f1 <<< $moves_nl_SP | tr '\n' -) == $(cut -d: -f1 <<< $moves_nl_C | tr '\n' - ) ]] || { echo mismatch C/OP >&2; exit 100; } 25 | [[ $(wc -l <<< $moves_nl_C ) == $(wc -l <<< $moves_nl_OP ) ]] || { echo mismatch C/OP >&2; exit 100; } 26 | [[ $(cut -d: -f1 <<< $moves_nl_C | tr '\n' -) == $(cut -d: -f1 <<< $moves_nl_OP | tr '\n' - ) ]] || { echo mismatch C/OP >&2; exit 100; } 27 | 28 | paste \ 29 | <(cut -d: -f1 <<< $moves_nl_SP | sed "s/$/-1/" | bc) \ 30 | <(cut -d: -f2 <<< $moves_nl_SP | tr -d "SP[]" | ${sp_executable_file} -mode decompress_str -conf_str program_quiet=true | tr ':' ' ' | sed -E "s/[0-9]+/+/g;s/-//g" | awk '{ 31 | max = 0; 32 | for (i = 1; i <= NF; i++) { if (length($i) > max) { max = length($i); } } 33 | print max; 34 | }') | while read -r idx l; do 35 | Lsum[$idx]=$((Lsum[$idx]+$l)) 36 | (( l > Lmax[$idx] )) && Lmax[$idx]=$l 37 | N[$idx]=$((N[$idx]+1)) 38 | done 39 | 40 | paste \ 41 | <(cut -d: -f1 <<< $moves_nl_C | sed "s/$/-1/" | bc) \ 42 | <(cut -d: -f2 <<< $moves_nl_C) \ 43 | <(cut -d: -f2- <<< $moves_nl_OP | awk -F: '{ delete u; for (i = 1; i <= NF; i++) { u[$i] = 1; }; print length(u); }') | \ 44 | while read -r idx OP_count uniq_OP_num; do 45 | N_opt[$idx]=$((N_opt[$idx]+(OP_count>0?1:0))) 46 | N_opt_sim[$idx]=$((N_opt_sim[$idx]+OP_count)) 47 | N_type_opt[$idx]=$((N_type_opt[$idx]+uniq_OP_num)) 48 | done 49 | done 50 | 51 | [[ ${#Lsum[@]} == ${#Lmax[@]} && ${#Lmax[@]} == ${#N[@]} ]] || exit 101 52 | 53 | N_total=$(bc <<< $(printf "%s+" ${N[@]})0) 54 | Lsum_total=$(bc <<< $(printf "%s+" 
${Lsum[@]})0) 55 | Lavg=($(paste <(printf "%s\n" ${Lsum[@]}) <(printf "%s\n" ${N[@]}) | tr '\t' '/' | bc -l)) 56 | Lmax_max=$(printf "%s\n" ${Lmax[@]} | sort -n | tail -n1) 57 | N_opt_total=$(bc <<< $(printf "%s+" ${N_opt[@]})0 2>/dev/null) 58 | if (( N_opt_total )); then 59 | N_opt_sim_total=$(bc <<< $(printf "%s+" ${N_opt_sim[@]})0 2>/dev/null) 60 | N_type_opt_total=$(bc <<< $(printf "%s+" ${N_type_opt[@]})0 2>/dev/null) 61 | type_opt_avg=($(paste <(printf "%s\n" ${N_type_opt[@]}) <(printf "%s\n" ${N_opt[@]}) | tr '\t' '/' | sed -E "s/.+\/0$/0/" | bc -l 2>/dev/null)) 62 | paste \ 63 | <(printf "%s\n" idx "*" $(seq 0 $((${#N[@]}-1)))) \ 64 | <(printf "%s\n" num $N_total ${N[@]}) \ 65 | <(printf "%s\n" avg $(bc -l <<< $Lsum_total/$N_total) ${Lavg[@]}) \ 66 | <(printf "%s\n" max $Lmax_max ${Lmax[@]}) \ 67 | <(printf "%s\n" num_opt $N_opt_total ${N_opt[@]}) \ 68 | <(printf "%s\n" num_opt_sim $N_opt_sim_total ${N_opt_sim[@]}) \ 69 | <(printf "%s\n" avg_type_opt $(bc -l <<< $N_type_opt_total/$N_opt_total) ${type_opt_avg[@]}) 70 | else 71 | paste \ 72 | <(printf "%s\n" idx "*" $(seq 0 $((${#N[@]}-1)))) \ 73 | <(printf "%s\n" num $N_total ${N[@]}) \ 74 | <(printf "%s\n" avg $(bc -l <<< $Lsum_total/$N_total) ${Lavg[@]}) \ 75 | <(printf "%s\n" max $Lmax_max ${Lmax[@]}) 76 | fi -------------------------------------------------------------------------------- /tools/dependency_graph_generator/README.md: -------------------------------------------------------------------------------- 1 | 2 | What is the dependency_graph_generator? 3 | Accoring to good coding practise, one should not have cyclic dependencies in the code. 4 | A cylic dependency is when one class/module/service that should encapsulate some 5 | functionality is directly or indirectly dependent on itself. A service is dependent on 6 | another service when its .cpp or .h file has an #include of another .h file. 
You could 7 | also have a dependency by other types of couplings, but the dependency_graph_generator 8 | is not capable of detecting these. 9 | How to use the dependency_graph 10 | Step 1. 11 | Run in terminal: 12 | > python3 tools/dependency_graph_generator/dependency_graph_generator.py 13 | Step 2. 14 | A file "dependency_graph.graphml" has now been generated in 15 | tools/dependency_graph_generator/. Download this file by right clicking it. 16 | Step 3. 17 | Open website: https://www.yworks.com/yed-live/ 18 | Click "Open an Existing Document" 19 | Click "From Local File" 20 | Open "dependency_graph.graphml" that you downloaded. 21 | Alternatively you can also drag and drop the file into the browser window. 22 | Step 4. 23 | Auto format the graph by clicking the yellow round button in the bottom 24 | left corner. The boxes will show classes/services and arrows will show 25 | dependencies. -------------------------------------------------------------------------------- /tools/extract-repeated-options.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | which bc >/dev/null || { apt -qq update && apt -qq -y install bc || exit $?; } >/dev/null 2>&1 3 | shopt -s lastpipe 4 | 5 | total=0 6 | repeat=0 7 | nonrep=0 8 | grep -E "^\+.+-.+" | while read -r type option count; do 9 | if (( $(printf "%s\n" ${option//-/ } | sort -u | wc -l) > 1 )); then # options whose '-'-separated parts are not all identical count as non-repeated 10 | nonrep=$((nonrep+count)) 11 | else 12 | repeat=$((repeat+count)) 13 | fi 14 | total=$((total+count)) 15 | done 16 | { 17 | echo \#=$total 18 | echo \#repeat=$repeat 19 | echo \#nonrep=$nonrep 20 | echo %repeat=$(bc -l <<< $repeat/$total) 21 | echo %nonrep=$(bc -l <<< $nonrep/$total) 22 | } | stdbuf -o0 sed -E "s/=\./=0./g;s/[+-]?nan/0/" | while IFS= read -r stat; do 23 | item=${stat%=*} 24 | data=${stat#*=} 25 | [[ $data ]] && case $item in 26 | μ*|±*) data=$(printf "%.0f" $data); ;; 27 | %*) data=$(printf "%.2f%%" $(<<< $data*100 bc -l)); ;; 28 | esac 29 | [[ $item ]] &&
echo $item=$data || echo 30 | done -------------------------------------------------------------------------------- /tools/fetch-complete-latest.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | repo=${1?} # required: directory that must contain an sgf/ subdirectory (checked below) 3 | limit=${2:-100} # maximum number of records to emit 4 | output_dir=${3} # optional: write each record to $output_dir/<ID>.sgf instead of stdout 5 | num_jobs=${4:-10} # cap on concurrently running extraction jobs 6 | 7 | [ -d $repo/sgf ] || exit $? 8 | [[ $output_dir ]] && mkdir -p $output_dir 9 | shopt -s lastpipe # run the last stage of a pipeline in this shell so variables set inside '... | while read' loops persist 10 | tmp=$(mktemp -d -t fetch-complete-latest.XXXXXX) 11 | trap "cleanup" EXIT # on exit: kill running background jobs and remove $tmp 12 | cleanup() { 13 | jobs -r -p | xargs -r kill 14 | rm -rf $tmp 15 | } 16 | sgf() { # print the path of a cached copy of $repo/sgf/$1.sgf with OBS[...] contents blanked; print /dev/null and return 1 if the source file is missing 17 | if [ ! -s $tmp/$1.sgf ]; then 18 | if [ ! -e $repo/sgf/$1.sgf ]; then 19 | echo /dev/null 20 | return 1 21 | fi 22 | flock -n $tmp/$1.sgf.tmp -c "sed 's/OBS\[[0-9a-f]*\]/OBS[]/g' $repo/sgf/$1.sgf > $tmp/$1.sgf.tmp && mv $tmp/$1.sgf.tmp $tmp/$1.sgf" 23 | fi 24 | while [ ! -s $tmp/$1.sgf ]; do sleep 0.1; done # another job holding the lock may still be building the cached copy 25 | echo $tmp/$1.sgf 26 | return 0 27 | } 28 | sign() { # reduce stdin to its concatenated ;B[n]/;W[n] tokens, used to compare move sequences 29 | grep -Eo ";[BW]\[[0-9]+\]" | xargs | tr -d ' ' 30 | } 31 | extract_part() { # print nodes start..end (0-based, inclusive; node 0 is the one after the root ';') of record $1 32 | local part=$1 33 | local start=$2 34 | local end=$3 35 | #part=$(<<< $part sed -E "s/$(printf "%.0s;[^;]+" $(seq $((start+1))))//" | grep -Eo -m1 "$(printf "%.0s;[^;]+" $(seq $((end-start+1))))") 36 | part=${part:$(grep -o . <<< $part | grep -n -m$((start+1+1)) ";" | tail -n1 | cut -d: -f1)-1} 37 | part=${part:0:$(grep -o .
<<< "$part;" | grep -n -m$((end-start+1+1)) ";" | tail -n1 | cut -d: -f1)-1} 38 | echo "$part" 39 | } 40 | log() { # append "$@" to a pending message; with -n and num_jobs > 1 only buffer it, otherwise print it to stderr (without a newline if -n) and clear it 41 | local nf= 42 | [[ $1 == -n ]] && shift && nf=-n 43 | log+="$@" 44 | [[ $nf && ${num_jobs:-1} -gt 1 ]] && return 45 | echo $nf "$log" >&2 46 | log= 47 | } 48 | 49 | latest_iter=${latest_iter:-$(ls -t $repo/sgf | sort -rn | head -n1 | sed 's/\.sgf//')} # highest-numbered file in $repo/sgf unless latest_iter is preset in the environment 50 | for iter in $(seq $latest_iter -1 1); do 51 | (( count >= limit )) && break 52 | sgf $iter >/dev/null || continue 53 | nl $(sgf $iter) | tac | grep "#$" | while IFS= read -r res && (( count < limit )); do # numbered lines, last first, keeping only those ending in '#' (NOTE(review): presumably the complete-game records; confirm) 54 | count=$((count+1)) 55 | ID=$iter-$(<<< $res cut -f1 | xargs) 56 | complete=$(<<< $res cut -f2-) 57 | [[ $complete =~ DLEN\[([0-9]+)-([0-9]+)\] ]] 58 | start=${BASH_REMATCH[1]:-0} 59 | end=${BASH_REMATCH[2]:-0} 60 | header=$(grep -Eo "(GM|RE|SD|DLEN)\[[^]]+\]" <<< $complete | sed -E "s/DLEN\[[0-9]+-/DLEN[0-/" | xargs | tr -d ' ') 61 | if [[ $output_dir ]]; then 62 | output=$output_dir/$ID.sgf 63 | [ -e $output ] && log "#$ID $header: [-]" && continue 64 | touch $output 65 | fi 66 | ( # parallel execution block 67 | log -n "#$ID $header: ${start}-${end}" 68 | [[ $output ]] && trap "[ -s $output ] || rm -f $output" EXIT 69 | if (( start )); then # start > 0: prepend nodes 0..start-1 taken from other records (searching this and older iterations), then re-check the moves 70 | [[ $complete =~ SD\[[0-9]+\] ]] 71 | SD=${BASH_REMATCH[0]} 72 | sign_complete=$(sign <<< $complete) 73 | complete_new=$(extract_part "$complete" $start $end) 74 | search_iter=$iter 75 | buffer=() 76 | while (( start )); do 77 | end_target=$((start-1)) 78 | part= 79 | while [[ ! $part ]]; do 80 | while [[ !
$buffer ]] && sgf $search_iter >/dev/null; do 81 | tac $(sgf $search_iter) | grep -v "#$" | grep -F "${SD:-;}" | while IFS= read -r buf; do 82 | buffer+=("$buf") 83 | done 84 | search_iter=$((search_iter-1)) 85 | done 86 | part=${buffer[0]:-$(sed -E "s/DLEN\[[^]]+\]/DLEN[0-${end_target}]/" <<< $complete)} 87 | buffer=("${buffer[@]:1}") 88 | [[ $part =~ DLEN\[([0-9]+)-([0-9]+)\] ]] 89 | start=${BASH_REMATCH[1]:-0} 90 | end=${BASH_REMATCH[2]:-0} 91 | [[ $SD && $end == $end_target ]] && break 92 | [[ "${sign_complete}" != "$(sign <<< $part)"* ]] && part= && continue 93 | [[ $end == $end_target ]] && break 94 | buffer=("$part" "${buffer[@]}") 95 | start=$((end+1)) 96 | end=${end_target} 97 | part=$(sed -E "s/DLEN\[[^]]+\]/DLEN[${start}-${end}]/" <<< $complete) 98 | done 99 | complete_new=$(extract_part "$part" $start $end)${complete_new} 100 | log -n " ${start}-${end}" 101 | done 102 | complete_new=$(<<< $complete grep -Eo "\(;[^;]+" | sed -E "s/DLEN\[[0-9]+-/DLEN[0-/")${complete_new} 103 | complete=$complete_new 104 | if [[ "${sign_complete}" != "$(sign <<< $complete)" ]]; then # the stitched record must reproduce the original record's move tokens exactly 105 | log -n " [x]" 106 | complete= 107 | fi 108 | fi 109 | if (( $(<<< $complete tr ';' '\n' | grep "^[BW]" | grep -Fv "]P[" | wc -l) )); then # drop the record if any B/W node lacks "]P[" 110 | log -n " [!]" 111 | complete= 112 | fi 113 | log 114 | [[ $complete ]] && echo "$complete" >> ${output:-/dev/stdout} 115 | ) & 116 | while (( $(jobs -r -p | wc -l) >= ${num_jobs:-1}+1 )); do wait -n; done # throttle: block while more than num_jobs background jobs are running 117 | done 118 | done 119 | wait 120 | -------------------------------------------------------------------------------- /tools/fight-eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | usage() 4 | { 5 | echo "Usage: $0 GAME_TYPE FOLDER1 FOLDER2 CONF_FILE1 [CONF_FILE2] INTERVAL GAMENUM [OPTION]..." 6 | echo "Launch fight evaluation to evaluate the relative strengths between same iterations of two trained models."
7 | echo "" 8 | echo "Required arguments:" 9 | echo " GAME_TYPE: $(find ./ ../ -maxdepth 2 -name build.sh -exec grep -m1 support_games {} \; -quit | sed -E 's/.+\("|"\).*//g;s/" "/, /g')" 10 | echo " FOLDER1, FOLDER2: the two model folders to be evaluated" 11 | echo " CONF_FILE1, CONF_FILE2: the configure files (*.cfg) to use; if CONF_FILE2 is unspecified, CONF_FILE1 is used" 12 | echo " INTERVAL: the iteration interval between each evaluated model pair" 13 | echo " GAMENUM: the number of games to play for each model pair" 14 | echo "" 15 | echo "Optional arguments:" 16 | echo " -h, --help Give this help list" 17 | echo " -s Start from which file in the folder (default 0)" 18 | echo " -b Board size (default is env_board_size in CONF_FILE)" 19 | echo " -g, --gpu Assign available GPUs, e.g. 0123" 20 | echo " --num_threads Number of threads to play games" 21 | echo " -d Result Folder Name (default [Folder1]_vs_[Folder2]_eval)" 22 | echo " -conf_str Add additional configure string for programs" 23 | echo " --sp_executable_file Assign the path of executable file" 24 | exit 1 25 | } 26 | 27 | # check arguments 28 | if [ $# -lt 6 ]; 29 | then 30 | usage 31 | else 32 | GAME_TYPE=$1; shift 33 | FOLDER1=$1; shift 34 | FOLDER2=$1; shift 35 | CONF_FILE1=$1; shift 36 | [[ $1 == *.cfg ]] && { CONF_FILE2=$1; shift; } || CONF_FILE2=$CONF_FILE1 # an extra *.cfg argument is taken as CONF_FILE2 37 | INTERVAL=$1; shift 38 | GAMENUM=$1; shift 39 | fi 40 | 41 | # default arguments 42 | START=0 43 | NUM_GPU=$(nvidia-smi -L | wc -l) 44 | GPU_LIST=$(echo $NUM_GPU | awk '{for(i=0;i<$1;i++)printf i}') # all GPUs as a digit string, e.g. 0123 for four GPUs 45 | num_threads=2 46 | BOARD_SIZE=$({ grep env_board_size= $CONF_FILE1 || echo =9; } | sed -E "s/^[^=]*=| *[#].*$//g") # env_board_size from CONF_FILE1 (9 if absent), trailing '#' comment stripped 47 | NAME="$(basename ${FOLDER1})_vs_$(basename ${FOLDER2})_eval" 48 | sp_executable_file=build/${GAME_TYPE}/minizero_${GAME_TYPE} 49 | while :; do 50 | case $1 in 51 | -h|--help) shift; usage 52 | ;; 53 | -g|--gpu) shift; GPU_LIST=$1; NUM_GPU=${#GPU_LIST} 54 | ;; 55 | -f) shift; CONF_FILE2=$1 # NOTE(review): -f is accepted here but not listed in usage() 56 | ;; 57 | -b) shift;
BOARD_SIZE=$1 58 | ;; 59 | -s) shift; START=$1 60 | ;; 61 | -d) shift; NAME=$1 62 | ;; 63 | --num_threads) shift; num_threads=$1 64 | ;; 65 | -conf_str) shift; conf_str=$1 66 | ;; 67 | --sp_executable_file) shift; sp_executable_file=$1 68 | ;; 69 | "") break # end of the argument list 70 | ;; 71 | *) echo "Unknown argument: $1"; usage 72 | ;; 73 | esac 74 | shift 75 | done 76 | 77 | echo "$0 $GAME_TYPE $FOLDER1 $FOLDER2 $CONF_FILE1 $INTERVAL $GAMENUM -s $START -f $CONF_FILE2 -b $BOARD_SIZE -g $GPU_LIST -d $NAME --num_threads $num_threads ${conf_str:+-conf_str $conf_str} --sp_executable_file $sp_executable_file" # print the effective command line, defaults included 78 | 79 | if [ ! -d "${FOLDER1}" ] || [ ! -d "${FOLDER2}" ]; then 80 | echo "${FOLDER1} or ${FOLDER2} not exists!" 81 | exit 1 82 | fi 83 | 84 | if [ ! -d "${FOLDER1}/$NAME" ]; then 85 | mkdir "${FOLDER1}/$NAME" 86 | fi 87 | echo "FOLDERS: $FOLDER1 & $FOLDER2, CONF_FILES: $CONF_FILE1 & $CONF_FILE2 " 88 | function run_twogtp(){ # $1: GPU id; $2: model file name, looked up in both $FOLDER1/model and $FOLDER2/model 89 | BLACK="$sp_executable_file -conf_file $CONF_FILE1 -conf_str \"${conf_str:+$conf_str:}nn_file_name=$FOLDER1/model/$2\"" 90 | WHITE="$sp_executable_file -conf_file $CONF_FILE2 -conf_str \"${conf_str:+$conf_str:}nn_file_name=$FOLDER2/model/$2\"" 91 | EVAL_FOLDER="${FOLDER1}/$NAME/${2:12:-3}" # ${2:12:-3}: the model file name without its first 12 and last 3 characters 92 | SGFFILE="${EVAL_FOLDER}/${2:12:-3}" 93 | if [ -f "$SGFFILE.lock" ] || [ -f "${SGFFILE}-$((${GAMENUM}-1)).sgf" ] || [ ! -f "$FOLDER2/model/$2" ] || [ ! -f "$FOLDER1/model/$2" ] ; then # skip when locked, already finished (last game's SGF exists), or either model is missing 94 | return 95 | fi 96 | 97 | if [ !
-d "${EVAL_FOLDER}" ];then 98 | mkdir $EVAL_FOLDER 99 | fi 100 | KOMI=0 101 | if [[ $GAME_TYPE == go ]]; then 102 | KOMI=7 103 | fi 104 | echo "GPUID: $1, Current players: ${2:12:-3}, Game num $GAMENUM" 105 | CUDA_VISIBLE_DEVICES=$1 gogui-twogtp -black "$BLACK" -white "$WHITE" -games $GAMENUM -sgffile $SGFFILE -alternate -auto -size $BOARD_SIZE -komi $KOMI -threads $num_threads 106 | } 107 | function run_gpu(){ 108 | models=($(ls $FOLDER1/model | grep ".pt$" | sort -V)) 109 | for((i=$START;i<${#models[@]};i=i+$INTERVAL)) 110 | do 111 | run_twogtp $1 ${models[$i]} 112 | done 113 | echo "GPUID $1 done!" 114 | } 115 | for (( i=0; i < ${#GPU_LIST} ; i = i+1 )) 116 | do 117 | GPUID=${GPU_LIST:$i:1} 118 | run_gpu $GPUID & 119 | sleep 10 120 | done 121 | wait 122 | echo "All done!" 123 | -------------------------------------------------------------------------------- /tools/option_analysis.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import argparse 3 | import subprocess 4 | import os 5 | def execute_linux_cmd(cmd: str): 6 | ps = subprocess.Popen(cmd, executable='/bin/bash', shell=True, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 7 | output = ps.communicate()[0] 8 | ps.terminate() 9 | return output 10 | 11 | if __name__ == '__main__': 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('-in_dir', dest='in_dir', type=str, help='dir to analysis option') 14 | args = parser.parse_args() 15 | if not args.in_dir: 16 | parser.print_help() 17 | exit(1) 18 | in_dir=args.in_dir 19 | in_option_dir = os.path.join(in_dir, 'option_analysis') 20 | 21 | # option in games 22 | result={} 23 | output = execute_linux_cmd(f'cat {in_option_dir}/stats/moves-stats.txt | grep \'%o\'').decode().split('\n')[:-1] 24 | avg_len=0 25 | for i, o in enumerate(output): 26 | ratio=float(o.split('=')[1].replace('%', 'e-2')) if o[-1] != '=' else 0 27 | if i == 0: 28 | result[f'% a'] = [round(1 - ratio, 
4)] 29 | result[f'% o'] = [round(ratio, 4)] 30 | continue 31 | avg_len += ratio * i 32 | if i > 1: 33 | result[f'% {i}'] = [ratio] 34 | result['avg. l'] = [round(avg_len, 2)] 35 | output = execute_linux_cmd(f'cat {in_option_dir}/stats/repeated-options.txt | grep \'%repeat\'').decode().split('\n')[0] 36 | ratio=float(output.split('=')[1].replace('%', 'e-2')) if output[-1] != '=' else 1 37 | result['% Rpt.'] = [round(ratio, 4)] 38 | result['% NRpt.'] = [round(1 - ratio, 4)] 39 | result_df = pd.DataFrame(result) 40 | result_df.to_csv(os.path.join(in_option_dir, f'option_in_games.csv'), index=False) 41 | 42 | # option in trees 43 | result={} 44 | output = execute_linux_cmd(f'cat {in_option_dir}/stats/options-in-tree.txt | grep \'*\'').decode().split('\n')[0].split('\t') 45 | num = float(output[1]) 46 | num_opt = float(output[4]) if len(output) > 4 else 0 47 | num_opt_sim = float(output[5]) if len(output) > 5 else 0 48 | 49 | output = execute_linux_cmd(f'cat {in_dir}/$(basename \'{in_dir}\').cfg | grep -oE \'actor_num_simulation=[0-9]*\'') 50 | actor_num_simulation = int(output.decode().split('actor_num_simulation=')[1]) 51 | result['% in Tree'] = [round(num_opt / num if num > 0 else 0, 4)] 52 | result['% in Sim.'] = [round(num_opt_sim / (num * actor_num_simulation) if num > 0 else 0, 4)] 53 | output = execute_linux_cmd(f'cat {in_option_dir}/stats/options-depth-percentile.txt').decode().replace('\n','').split(' ')[4:] 54 | result['Avg. 
tree depth'] = [float(output[0])] 55 | result['Median tree depth'] = [float(output[2])] 56 | result['Max tree depth'] = [float(output[3])] 57 | result_df = pd.DataFrame(result) 58 | result_df.to_csv(os.path.join(in_option_dir, f'option_in_trees.csv'), index=False) 59 | -------------------------------------------------------------------------------- /tools/self-eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | usage() 4 | { 5 | echo "Usage: $0 GAME_TYPE FOLDER CONF_FILE INTERVAL GAMENUM [OPTION]..." 6 | echo "Launch self evaluation to evaluate the relative strengths between different iterations of trained model." 7 | echo "" 8 | echo "Required arguments:" 9 | echo " GAME_TYPE: $(find ./ ../ -maxdepth 2 -name build.sh -exec grep -m1 support_games {} \; -quit | sed -E 's/.+\("|"\).*//g;s/" "/, /g')" 10 | echo " FOLDER: the model folder, e.g., tictactoe_az_1bx256_n50-8c2433" 11 | echo " CONF_FILE: the configure file (*.cfg) to use" 12 | echo " INTERVAL: the iteration interval between each evaluated model pair" 13 | echo " GAMENUM: the number of games to play for each model pair" 14 | echo "" 15 | echo "Optional arguments:" 16 | echo " -h, --help Give this help list" 17 | echo " -s Start from which file in the folder (default 0)" 18 | echo " -b Board size (default is env_board_size in CONF_FILE)" 19 | echo " -g, --gpu Assign available GPUs, e.g.
0123" 20 | echo " --num_threads Number of threads to play games" 21 | echo " -d Result Folder Name (default self_eval)" 22 | echo " -conf_str Add additional configure string for program" 23 | echo " --sp_executable_file Assign the path of executable file" 24 | exit 1 25 | } 26 | 27 | # check arguments 28 | if [ $# -lt 5 ]; 29 | then 30 | usage 31 | else 32 | GAME_TYPE=$1; shift 33 | FOLDER=$1; shift 34 | CONF_FILE=$1; shift 35 | INTERVAL=$1; shift 36 | GAMENUM=$1; shift 37 | fi 38 | 39 | # default arguments 40 | START=0 41 | NUM_GPU=$(nvidia-smi -L | wc -l) 42 | GPU_LIST=$(echo $NUM_GPU | awk '{for(i=0;i<$1;i++)printf i}') # all GPUs as a digit string, e.g. 0123 for four GPUs 43 | num_threads=2 44 | BOARD_SIZE=$({ grep env_board_size= $CONF_FILE || echo =9; } | sed -E "s/^[^=]*=| *[#].*$//g") # env_board_size from CONF_FILE (9 if absent), trailing '#' comment stripped 45 | NAME="self_eval" 46 | sp_executable_file=build/${GAME_TYPE}/minizero_${GAME_TYPE} 47 | 48 | while :; do 49 | case $1 in 50 | -h|--help) shift; usage 51 | ;; 52 | -g|--gpu) shift; GPU_LIST=$1; NUM_GPU=${#GPU_LIST} 53 | ;; 54 | -b) shift; BOARD_SIZE=$1 55 | ;; 56 | -s) shift; START=$1 57 | ;; 58 | -d) shift; NAME=$1 59 | ;; 60 | --num_threads) shift; num_threads=$1 61 | ;; 62 | -conf_str) shift; conf_str=$1 63 | ;; 64 | --sp_executable_file) shift; sp_executable_file=$1 65 | ;; 66 | "") break 67 | ;; 68 | *) echo "Unknown argument: $1"; usage 69 | ;; 70 | esac 71 | shift 72 | done 73 | echo "$0 $GAME_TYPE $FOLDER $CONF_FILE $INTERVAL $GAMENUM -s $START -b $BOARD_SIZE -g $GPU_LIST -d $NAME --num_threads $num_threads ${conf_str:+-conf_str $conf_str} --sp_executable_file $sp_executable_file" 74 | if [ ! -d "${FOLDER}" ]; then 75 | echo "${FOLDER} not exists!" 76 | exit 1 77 | fi 78 | 79 | if [ !
-d "${FOLDER}/$NAME" ]; then 80 | mkdir "${FOLDER}/$NAME" 81 | fi 82 | 83 | function run_twogtp(){ # $1: GPU id; $2, $3: model file names under $FOLDER/model, $3 being the later one in version-sort order (passed as -black) 84 | BLACK="$sp_executable_file -conf_file $CONF_FILE -conf_str \"${conf_str:+$conf_str:}nn_file_name=$FOLDER/model/$3\"" 85 | WHITE="$sp_executable_file -conf_file $CONF_FILE -conf_str \"${conf_str:+$conf_str:}nn_file_name=$FOLDER/model/$2\"" 86 | EVAL_FOLDER="${FOLDER}/$NAME/${3:12:-3}_vs_${2:12:-3}" 87 | SGFFILE="${EVAL_FOLDER}/${3:12:-3}_vs_${2:12:-3}" 88 | if [ ! -d "${EVAL_FOLDER}" ];then 89 | mkdir $EVAL_FOLDER 90 | fi 91 | if [ -f "$SGFFILE.lock" ] || [ -f "${SGFFILE}-$((${GAMENUM}-1)).sgf" ] ; then # skip pairs that are locked or whose last game's SGF already exists 92 | return 93 | fi 94 | KOMI=0 95 | if [[ $GAME_TYPE == go ]]; then 96 | KOMI=7 97 | fi 98 | echo "GPUID: $1, Current players: ${3:12:-3} vs. ${2:12:-3}, Game num $GAMENUM" 99 | CUDA_VISIBLE_DEVICES=$1 gogui-twogtp -black "$BLACK" -white "$WHITE" -games $GAMENUM -sgffile $SGFFILE -alternate -auto -size $BOARD_SIZE -komi $KOMI -threads $num_threads 100 | } 101 | function run_gpu(){ 102 | models=($(ls $FOLDER/model | grep ".pt$" | sort -V)) 103 | for((i=$START;i<${#models[@]}-$INTERVAL;i=i+$INTERVAL)) # pair each model with the one INTERVAL positions later 104 | do 105 | run_twogtp $1 ${models[$i]} ${models[$(($i+$INTERVAL))]} 106 | done 107 | echo "GPUID $1 done!" 108 | } 109 | for (( i=0; i < ${#GPU_LIST} ; i = i+1 )) 110 | do 111 | GPUID=${GPU_LIST:$i:1} 112 | run_gpu $GPUID & 113 | sleep 10 114 | done 115 | wait 116 | echo "All done!"
117 | -------------------------------------------------------------------------------- /tools/to-sgf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import sys 5 | import os 6 | 7 | gm_map = {'go': '1', 8 | 'hex': '11'} 9 | 10 | 11 | def getGame(content): 12 | if content in gm_map: 13 | return gm_map[content] 14 | print(f'Not supported game type: {content}', file=sys.stderr) 15 | print(f'Supported game types: {list(gm_map.keys())}', file=sys.stderr) 16 | exit(1) 17 | 18 | 19 | def tosgf(source): 20 | trans_infos = '' 21 | for info in source: 22 | info = info[info.find('(') + 1:info.find(')')] # remove ( ) 23 | trans_info = '(;FF[4]' 24 | l_bracket = info.find('[') 25 | r_bracket = info.find(']') 26 | board_sz = 0 27 | while l_bracket != -1 and r_bracket != -1: 28 | label = info[:l_bracket] 29 | content = info[l_bracket + 1:r_bracket] 30 | if label[0] == ';': # new format 31 | label = label[1:] 32 | if label == 'B' or label == 'W': 33 | if board_sz <= 0: 34 | raise ValueError('Invalid board size!') 35 | label = ';' + label 36 | content = content.split('|')[0].split(']')[0] 37 | position = int(content) 38 | if position == board_sz**2: 39 | content = '' 40 | else: 41 | content = chr(ord('a') + position % board_sz) + \ 42 | chr(ord('a') + (board_sz - 1 - position // board_sz)) 43 | elif label == 'SZ': 44 | board_sz = int(content) 45 | elif label == 'GM': 46 | content = getGame(content.split('_')[0]) 47 | trans_info += f'{label}[{content}]' 48 | info = info[r_bracket + 1:] 49 | l_bracket = info.find('[') 50 | r_bracket = info.find(']') 51 | trans_info += ')\n' 52 | trans_infos += trans_info 53 | return trans_infos 54 | 55 | 56 | if __name__ == '__main__': 57 | parser = argparse.ArgumentParser() 58 | parser.add_argument('-in_file', dest='fin_name', type=str, 59 | help='input flie') 60 | parser.add_argument('-out_file', dest='fout_name', type=str, 61 | help='output flie') 62 | 
parser.add_argument('--force', action='store_true', 63 | dest='force', help='overwrite files') 64 | args = parser.parse_args() 65 | if args.fin_name: 66 | if os.path.isfile(args.fin_name): 67 | with open(args.fin_name, 'r') as fin: 68 | trans_infos = tosgf(fin.readlines()) 69 | else: 70 | print(f'\"{args.fin_name}\" does not exist!', file=sys.stderr) 71 | exit(1) 72 | else: 73 | trans_infos = tosgf(sys.stdin) 74 | if args.fout_name: 75 | if not args.force and os.path.isfile(args.fout_name): 76 | print( 77 | f'*** {args.fout_name} exists! Use --force to overwrite it. ***') 78 | exit(1) 79 | with open(args.fout_name, 'w') as fout: 80 | fout.write(trans_infos) 81 | print(f'Writed to {args.fout_name}.') 82 | else: 83 | print(trans_infos) 84 | --------------------------------------------------------------------------------