├── .autopep8
├── .clang-format
├── .githooks
├── autopep8.py
├── cpplint.py
└── pre-commit
├── .gitignore
├── CMakeLists.txt
├── README.md
├── cfg
├── atari.cfg
├── gridworld.cfg
└── gridworld_eval.cfg
├── docs
├── Console.md
├── Development.md
├── Evaluation.md
├── Training.md
└── imgs
│ ├── minizero-architecture.svg
│ ├── minizero_atari.svg
│ ├── minizero_go_9x9.svg
│ ├── minizero_othello_8x8.svg
│ └── optionzero_atari.svg
├── gridworld_maps
└── gridworld_000.txt
├── minizero
├── CMakeLists.txt
├── actor
│ ├── CMakeLists.txt
│ ├── actor_group.cpp
│ ├── actor_group.h
│ ├── base_actor.cpp
│ ├── base_actor.h
│ ├── create_actor.h
│ ├── gumbel_zero.cpp
│ ├── gumbel_zero.h
│ ├── mcts.cpp
│ ├── mcts.h
│ ├── search.h
│ ├── tree.h
│ ├── zero_actor.cpp
│ └── zero_actor.h
├── config
│ ├── CMakeLists.txt
│ ├── configuration.cpp
│ ├── configuration.h
│ ├── configure_loader.cpp
│ └── configure_loader.h
├── console
│ ├── CMakeLists.txt
│ ├── console.cpp
│ ├── console.h
│ ├── mode_handler.cpp
│ └── mode_handler.h
├── environment
│ ├── CMakeLists.txt
│ ├── amazons
│ │ ├── amazons.cpp
│ │ └── amazons.h
│ ├── atari
│ │ ├── atari.cpp
│ │ ├── atari.h
│ │ ├── obs_recover.cpp
│ │ ├── obs_recover.h
│ │ ├── obs_remover.cpp
│ │ └── obs_remover.h
│ ├── base
│ │ ├── base_env.cpp
│ │ └── base_env.h
│ ├── breakthrough
│ │ ├── breakthrough.cpp
│ │ └── breakthrough.h
│ ├── clobber
│ │ ├── clobber.cpp
│ │ └── clobber.h
│ ├── conhex
│ │ ├── conhex.cpp
│ │ ├── conhex.h
│ │ ├── conhex_graph.cpp
│ │ ├── conhex_graph.h
│ │ ├── conhex_graph_cell.cpp
│ │ ├── conhex_graph_cell.h
│ │ ├── conhex_graph_flag.h
│ │ ├── disjoint_set_union.cpp
│ │ └── disjoint_set_union.h
│ ├── connect6
│ │ ├── connect6.cpp
│ │ └── connect6.h
│ ├── dotsandboxes
│ │ ├── dotsandboxes.cpp
│ │ └── dotsandboxes.h
│ ├── environment.h
│ ├── go
│ │ ├── go.cpp
│ │ ├── go.h
│ │ ├── go_area.h
│ │ ├── go_block.h
│ │ ├── go_data_structure_check.cpp
│ │ ├── go_grid.h
│ │ └── go_unit.h
│ ├── gomoku
│ │ ├── gomoku.cpp
│ │ └── gomoku.h
│ ├── gridworld
│ │ ├── gridworld.cpp
│ │ └── gridworld.h
│ ├── havannah
│ │ ├── havannah.cpp
│ │ └── havannah.h
│ ├── hex
│ │ ├── hex.cpp
│ │ └── hex.h
│ ├── killallgo
│ │ ├── killallgo.cpp
│ │ ├── killallgo.h
│ │ ├── killallgo_7x7_bitboard.h
│ │ ├── killallgo_seki_7x7.cpp
│ │ └── killallgo_seki_7x7.h
│ ├── linesofaction
│ │ ├── linesofaction.cpp
│ │ └── linesofaction.h
│ ├── nogo
│ │ └── nogo.h
│ ├── othello
│ │ ├── othello.cpp
│ │ └── othello.h
│ ├── rubiks
│ │ ├── rubiks.cpp
│ │ └── rubiks.h
│ ├── santorini
│ │ ├── bitboard.h
│ │ ├── board.cpp
│ │ ├── board.h
│ │ ├── santorini.cpp
│ │ └── santorini.h
│ ├── stochastic
│ │ ├── puzzle2048
│ │ │ ├── bitboard.h
│ │ │ ├── puzzle2048.cpp
│ │ │ └── puzzle2048.h
│ │ └── stochastic_env.h
│ ├── surakarta
│ │ ├── surakarta.cpp
│ │ └── surakarta.h
│ └── tictactoe
│ │ ├── tictactoe.cpp
│ │ └── tictactoe.h
├── learner
│ ├── CMakeLists.txt
│ ├── data_loader.cpp
│ ├── data_loader.h
│ ├── pybind.cpp
│ └── train.py
├── minizero.cpp
├── network
│ ├── CMakeLists.txt
│ ├── alphazero_network.h
│ ├── create_network.h
│ ├── muzero_network.h
│ ├── network.cpp
│ ├── network.h
│ └── py
│ │ ├── alphazero_network.py
│ │ ├── create_network.py
│ │ ├── muzero_atari_network.py
│ │ ├── muzero_gridworld_network.py
│ │ ├── muzero_network.py
│ │ └── network_unit.py
├── utils
│ ├── CMakeLists.txt
│ ├── base_server.h
│ ├── color_message.h
│ ├── ostream_redirector.h
│ ├── paralleler.h
│ ├── random.cpp
│ ├── random.h
│ ├── rotation.h
│ ├── sgf_loader.cpp
│ ├── sgf_loader.h
│ ├── thread_pool.h
│ ├── time_system.h
│ ├── tqdm.h
│ ├── utils.h
│ └── vector_map.h
└── zero
│ ├── CMakeLists.txt
│ ├── zero_server.cpp
│ └── zero_server.h
├── scripts
├── build.sh
├── start-container.sh
├── zero-server.sh
└── zero-worker.sh
└── tools
├── analysis.py
├── count-moves.sh
├── count-options-depth-percentile.sh
├── count-options-in-tree.sh
├── dependency_graph_generator
├── README.md
└── dependency_graph_generator.py
├── eval.py
├── extract-moves-stats.sh
├── extract-repeated-options.sh
├── fetch-complete-latest.sh
├── fight-eval.sh
├── handle_obs.sh
├── option_analysis.py
├── plot_board.py
├── quick-run.sh
├── self-eval.sh
├── sgf_analysis.py
├── to-sgf.py
└── to-video.py
/.autopep8:
--------------------------------------------------------------------------------
1 | [pycodestyle]
2 | max-line-length = 200
3 |
--------------------------------------------------------------------------------
/.clang-format:
--------------------------------------------------------------------------------
1 | BasedOnStyle: LLVM
2 | BreakBeforeBraces: WebKit
3 | IndentWidth: 4
4 | PointerAlignment: Left
5 | AccessModifierOffset: -4
6 | ColumnLimit: 0
7 | AllowShortBlocksOnASingleLine: true
8 | AllowShortCaseLabelsOnASingleLine: true
9 | AllowShortFunctionsOnASingleLine: true
10 | AllowShortIfStatementsOnASingleLine: true
11 | AllowShortLoopsOnASingleLine: true
12 | NamespaceIndentation: Inner
13 | IndentCaseLabels: true
14 |
--------------------------------------------------------------------------------
/.githooks/pre-commit:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # Pre-commit hook: checks staged C/C++ files with cpplint and git-clang-format,
4 | # and staged Python files with autopep8. Called by "git commit" with no
5 | # arguments; exits with non-zero status (aborting the commit) if any format
6 | # warning is found.
7 | #
8 | # To enable this hook, run: git config core.hooksPath .githooks
9 |
10 | if git rev-parse --verify HEAD >/dev/null 2>&1
11 | then
12 | against=HEAD
13 | else
14 | # Initial commit: diff against an empty tree object
15 | against=$(git hash-object -t tree /dev/null)
16 | fi
17 |
18 | # Redirect output to stderr.
19 | exec 1>&2
20 |
21 | # Collect the staged file names once. --diff-filter=ACMR skips deleted files: they no
22 | # longer exist on disk, so the checkers below would fail on them and block the commit.
23 | staged_files=$(git diff --staged --name-only --diff-filter=ACMR $against --)
24 | cpp_files=$(echo "$staged_files" | grep -E '\.[ch](pp)?$')
25 | # '\.py$' rather than '\.py?$', which would also match files ending in ".p"
26 | py_files=$(echo "$staged_files" | grep -E '\.py$')
27 |
28 | # check cpp, reference: https://github.com/cpplint/cpplint & https://qiita.com/janus_wel/items/cfc6914d6b7b8bf185b6
29 | format_warning=0
30 | filters='-build/c++11,-build/include_subdir,-build/include_order,-build/namespaces,-legal/copyright,-readability/todo,-runtime/explicit,-runtime/references,-runtime/string,-whitespace/braces,-whitespace/comments,-whitespace/indent,-whitespace/line_length'
31 | for file in $cpp_files; do
32 | python3 .githooks/cpplint.py --filter=$filters $file > /dev/null
33 | format_warning=$(expr ${format_warning} + $?)
34 | done
35 |
36 | # check clang-format, reference: https://gist.github.com/alexeagle/c8ed91b14a407342d9a8e112b5ac7dab
37 | # when git-clang-format would change a file, print its diff and count one warning
38 | for file in $cpp_files; do
39 | clangformat_output=$(git-clang-format --diff -q $file)
40 | [[ "$clangformat_output" != *"no modified files to format"* ]] && [[ "$clangformat_output" != *"clang-format did not modify any files"* ]] && [[ ! -z "$clangformat_output" ]] && git-clang-format --diff -q $file && format_warning=$(expr ${format_warning} + 1)
41 | done
42 |
43 | # check python format, reference: https://github.com/hhatto/autopep8
44 | # install pycodestyle (used by autopep8) if pip does not list it
45 | [ $(pip list | grep pycodestyle | wc -l) -eq 1 ] || pip install pycodestyle >/dev/null 2>/dev/null
46 | for file in $py_files; do
47 | autopep8_output=$(python3 .githooks/autopep8.py -d -a -a --global-config .autopep8 $file 2>&1)
48 | [[ ! -z "$autopep8_output" ]] && python3 .githooks/autopep8.py -d -a -a --global-config .autopep8 $file && format_warning=$(expr ${format_warning} + 1)
49 | done
50 |
51 | # abort the commit if any check reported a warning
52 | [[ ${format_warning} -ne 0 ]] && echo && echo "Found ${format_warning} format warning(s): check format again!" && exit 1
53 |
54 | # return successfully
55 | exit 0
56 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # execute and library file
2 | build/
3 | __pycache__
4 |
5 | # configuration file
6 | # *.cfg
7 |
8 | # training folder
9 | *_az_*/
10 | *_gaz_*/
11 | *_mz_*/
12 | *_gmz_*/
13 |
14 | # others
15 | .vscode
16 | .container_root
17 |
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.16)
2 |
3 | project(minizero)
4 |
5 | # specify the C++ standard
6 | set(CMAKE_CXX_STANDARD 17)
7 | set(CMAKE_CXX_STANDARD_REQUIRED True)
8 | # third-party packages; NOTE(review): Boost is not REQUIRED, yet minizero/actor links ${Boost_LIBRARIES} - confirm this is intended
9 | find_package(Torch REQUIRED)
10 | find_package(Boost COMPONENTS system thread iostreams)
11 | find_package(ale REQUIRED)
12 | find_package(OpenCV REQUIRED)
13 | # put every executable and library directly in the build directory
14 | set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
15 | set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
16 | set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
17 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -g -Wall -mpopcnt -O3 -pthread") # used by every build type
18 | set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -Wall -Wno-unused-function -O0 -pthread") # appended after CMAKE_CXX_FLAGS in Debug builds, so the later -O0 overrides -O3
19 |
20 | # for git info
21 | include_directories(${PROJECT_BINARY_DIR}/git_info)
22 |
23 | add_subdirectory(minizero)
24 | add_subdirectory(minizero/actor)
25 | add_subdirectory(minizero/config)
26 | add_subdirectory(minizero/console)
27 | add_subdirectory(minizero/environment)
28 | add_subdirectory(minizero/learner)
29 | add_subdirectory(minizero/network)
30 | add_subdirectory(minizero/utils)
31 | add_subdirectory(minizero/zero)
32 | # executable name: lowercase "minizero_<GAME_TYPE>"; NOTE(review): if GAME_TYPE is not set (presumably via -DGAME_TYPE=...), the name becomes "minizero_"
33 | string(TOLOWER "${PROJECT_NAME}_${GAME_TYPE}" EXE_FILE_NAME)
34 | set_target_properties(${PROJECT_NAME} PROPERTIES OUTPUT_NAME ${EXE_FILE_NAME})
35 |
--------------------------------------------------------------------------------
/docs/Console.md:
--------------------------------------------------------------------------------
1 | # Console
2 |
3 | MiniZero supports the [Go Text Protocol (GTP)](http://www.lysator.liu.se/~gunnar/gtp/) and has a built-in console for easy communication with human operators or external programs.
4 |
5 | ```bash
6 | tools/quick-run.sh console GAME_TYPE FOLDER|MODEL_FILE [CONF_FILE] [OPTION]...
7 | ```
8 |
9 | * `GAME_TYPE` sets the target game, e.g., `tictactoe`.
10 | * `FOLDER` or `MODEL_FILE` sets either the folder or the model file (`*.pt`).
11 | * `CONF_FILE` sets the config file for console.
12 | * `OPTION` sets optional arguments, e.g., `-conf_str` sets additional configurations.
13 |
14 | For detailed arguments, run `tools/quick-run.sh console -h`.
15 |
16 | Sample commands:
17 |
18 | ```bash
19 | # run a console with the latest model inside "tictactoe_az_1bx256_n50-cb69d4" using config "tictactoe_play.cfg"
20 | tools/quick-run.sh console tictactoe tictactoe_az_1bx256_n50-cb69d4 tictactoe_play.cfg
21 |
22 | # run a console with a specified model file using config "tictactoe_play.cfg"
23 | tools/quick-run.sh console tictactoe tictactoe_az_1bx256_n50-cb69d4/model/weight_iter_25000.pt tictactoe_play.cfg
24 |
25 | # run a console with the latest model inside "tictactoe_az_1bx256_n50-cb69d4" using its default config file, and overwrite several settings for console
26 | tools/quick-run.sh console tictactoe tictactoe_az_1bx256_n50-cb69d4 -conf_str actor_select_action_by_count=true:actor_use_dirichlet_noise=false:actor_num_simulation=200
27 | ```
28 |
29 | Note that the console requires a trained network model.
30 |
31 | After the console starts successfully, a message "Successfully started console mode" will be displayed.
32 | Then, use [GTP commands](https://www.gnu.org/software/gnugo/gnugo_19.html) to interact with the program, e.g., `genmove`.
33 |
34 | ## Miscellaneous Console Tips
35 |
36 | ### Attach MiniZero to GoGui
37 |
38 | [GoGui](https://github.com/Remi-Coulom/gogui) provides a graphical interface for board game AI programs. It includes two tools, `gogui-server` and `gogui-client`, for attaching programs that support the GTP console.
39 |
40 | To attach MiniZero to GoGui, specify the `gogui-server` port via `-p`, which automatically starts the MiniZero console with `gogui-server`.
41 |
42 | ```bash
43 | # host the console at port 40000 using gogui-server
44 | tools/quick-run.sh console tictactoe tictactoe_az_1bx256_n50-cb69d4 tictactoe_play.cfg -p 40000
45 | ```
46 |
--------------------------------------------------------------------------------
/docs/Evaluation.md:
--------------------------------------------------------------------------------
1 | # Evaluation
2 |
3 | MiniZero currently supports two evaluation methods to evaluate program strength: [self-evaluation](#self-evaluation) and [fight-evaluation](#fight-evaluation).
4 |
5 | ## Self-Evaluation
6 |
7 | Self-evaluation evaluates the relative strengths between different iterations in a training session, i.e., it evaluates whether a network model is continuously improving during training.
8 |
9 | ```bash
10 | tools/quick-run.sh self-eval GAME_TYPE FOLDER [CONF_FILE] [INTERVAL] [GAME_NUM] [OPTION]...
11 | ```
12 |
13 | * `GAME_TYPE` sets the target game, e.g., `tictactoe`.
14 | * `FOLDER` sets the folder to be evaluated, which should contain the `model/` subfolder.
15 | * `CONF_FILE` sets the config file for evaluation.
16 | * `INTERVAL` sets the iteration interval between consecutively evaluated models, e.g., `10` pairs the 0th and the 10th models, then the 10th and 20th models, and so on.
17 | * `GAME_NUM` sets the number of games to play for each model pair, e.g., `100`.
18 | * `OPTION` sets optional arguments, e.g., `-conf_str` sets additional configurations.
19 |
20 | For detailed arguments, run `tools/quick-run.sh self-eval -h`.
21 |
22 | Sample commands:
23 | ```bash
24 | # evaluate a TicTacToe training session using "tictactoe_play.cfg", run 100 games for each model pair: 0th vs 10th, 10th vs 20th, ...
25 | tools/quick-run.sh self-eval tictactoe tictactoe_az_1bx256_n50-cb69d4 tictactoe_play.cfg 10 100
26 |
27 | # evaluate a TicTacToe training session using its training config, overwrite several settings for evaluation
28 | tools/quick-run.sh self-eval tictactoe tictactoe_az_1bx256_n50-cb69d4 tictactoe_az_1bx256_n50-cb69d4/*.cfg 10 100 -conf_str actor_select_action_by_count=true:actor_use_dirichlet_noise=false:actor_num_simulation=200
29 |
30 | # use more threads for faster evaluation
31 | tools/quick-run.sh self-eval tictactoe tictactoe_az_1bx256_n50-cb69d4 tictactoe_play.cfg 10 100 --num_threads 20
32 | ```
33 |
34 | Note that evaluation is unnecessary for Atari games.
35 |
36 | The evaluation results are stored inside `FOLDER`, in a subfolder named `self_eval` by default, which contains the following records:
37 | * `elo.csv` saves the evaluated model strength in Elo rating.
38 | * `elo.png` plots the Elo rating of `elo.csv`.
39 | * `5000_vs_0`, `10000_vs_5000`, and other folders keep game trajectory records for each evaluated model pair.
40 |
41 | ## Fight-Evaluation
42 |
43 | Fight-evaluation evaluates the relative strengths between the same iterations of two training sessions, i.e., it compares the learning results of two network models.
44 |
45 | ```bash
46 | tools/quick-run.sh fight-eval GAME_TYPE FOLDER1 FOLDER2 [CONF_FILE1] [CONF_FILE2] [INTERVAL] [GAME_NUM] [OPTION]...
47 | ```
48 |
49 | * `GAME_TYPE` sets the target game, e.g., `tictactoe`.
50 | * `FOLDER1` and `FOLDER2` set the two folders to be evaluated.
51 | * `CONF_FILE1` and `CONF_FILE2` set the config files for `FOLDER1` and `FOLDER2`, respectively; if `CONF_FILE2` is unspecified, `FOLDER2` uses `CONF_FILE1` for evaluation.
52 | * `INTERVAL` sets the iteration interval between evaluated model pairs, e.g., `10` matches the i-th models of both folders, then the (i+10)-th models, and so on.
53 | * `GAME_NUM` sets the number of games to play for each model pair, e.g., `100`.
54 | * `OPTION` sets optional arguments, e.g., `-conf_str` sets additional configurations.
55 |
56 | For detailed arguments, run `tools/quick-run.sh fight-eval -h`.
57 |
58 | Sample commands:
59 |
60 | ```bash
61 | # evaluate two training results using "tictactoe_play.cfg" for both programs, run 100 games for each model pair
62 | tools/quick-run.sh fight-eval tictactoe tictactoe_az_1bx256_n50-cb69d4 tictactoe_az_1bx256_n50-731a0f tictactoe_play.cfg 10 100
63 |
64 | # evaluate two training results using "tictactoe_cb69d4.cfg" and "tictactoe_731a0f.cfg" for the former and the latter, respectively
65 | tools/quick-run.sh fight-eval tictactoe tictactoe_az_1bx256_n50-cb69d4 tictactoe_az_1bx256_n50-731a0f tictactoe_cb69d4.cfg tictactoe_731a0f.cfg 10 100
66 | ```
67 |
68 | The evaluation results are stored inside `FOLDER1`, in a subfolder named `[FOLDER1]_vs_[FOLDER2]_eval` by default, which contains the following records:
69 | * `elo.csv` saves the evaluation statistics and strength comparisons of all evaluated model pairs.
70 | * `elo.png` plots the Elo rating comparisons reported in `elo.csv`.
71 | * `0`, `5000`, and other folders keep game trajectory records for each evaluated model pair.
72 |
73 | > **Note**
74 | > Before the fight-evaluation, it is suggested that a self-evaluation for `FOLDER1` be run first to generate a baseline strength, which is necessary for the strength comparison.
75 |
76 | ## Miscellaneous Evaluation Tips
77 |
78 | ### Configurations for evaluation
79 |
80 | Evaluation requires a different configuration from training, e.g., more simulations and no noise, so that the best action is always selected:
81 | ```
82 | actor_num_simulation=400
83 | actor_select_action_by_count=true
84 | actor_select_action_by_softmax_count=false
85 | actor_use_dirichlet_noise=false
86 | actor_use_gumbel_noise=false
87 | ```
88 |
89 | In addition, the played games sometimes become too similar.
90 | To prevent this, use random rotation (for AlphaZero only) or even add softmax/noise back.
91 | ```
92 | actor_use_random_rotation_features=true
93 | actor_select_action_by_count=false
94 | actor_select_action_by_softmax_count=true
95 | ```
96 |
--------------------------------------------------------------------------------
/gridworld_maps/gridworld_000.txt:
--------------------------------------------------------------------------------
1 | XXXXXXXXXXXXXXXXXXXX
2 | X XXXXXX X XXX
3 | X XXXXXX XX X XXX
4 | XXX XXXXXX XX X XXX
5 | X X XX X XXX
6 | X X XX XXX XX X XXX
7 | X XXX XX X XXX
8 | X XXXXXX XX X XXX
9 | X X X XXX
10 | XXXXXXXXXXXXXXXXXXXX
11 |
--------------------------------------------------------------------------------
/minizero/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | file(GLOB SRCS *.cpp) # globbed at configure time: re-run cmake after adding a .cpp file here
2 | # the minizero executable, linked against the module libraries and libtorch
3 | add_executable(minizero ${SRCS})
4 | target_link_libraries(
5 | minizero
6 | actor
7 | config
8 | console
9 | environment
10 | network
11 | utils
12 | zero
13 | ${TORCH_LIBRARIES}
14 | )
--------------------------------------------------------------------------------
/minizero/actor/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | file(GLOB_RECURSE SRCS *.cpp) # every .cpp under this directory, collected at configure time
2 | # the "actor" library
3 | add_library(actor ${SRCS})
4 | target_include_directories(actor PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) # PUBLIC: targets linking actor also get this directory on their include path
5 | target_link_libraries(
6 | actor
7 | config
8 | environment
9 | network
10 | utils
11 | ${Boost_LIBRARIES}
12 | ${TORCH_LIBRARIES}
13 | )
--------------------------------------------------------------------------------
/minizero/actor/actor_group.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "base_actor.h"
4 | #include "network.h"
5 | #include "paralleler.h"
6 | #include
7 | #include
8 | #include
9 | #include
10 | #include
11 | #include
12 | #include
13 |
14 | namespace minizero::actor {
15 | // State shared by the SlaveThreads of an ActorGroup. NOTE(review): field roles marked "presumably" are inferred from names - confirm in actor_group.cpp
16 | class ThreadSharedData : public utils::BaseSharedData {
17 | public:
18 | int getAvailableActorIndex();
19 | void outputGame(const std::shared_ptr& actor);
20 | std::pair calculateTrainingDataRange(const std::shared_ptr& actor);
21 |
22 | bool do_cpu_job_; // presumably selects whether slaves run doCPUJob() or doGPUJob()
23 | int actor_index_; // presumably the next index handed out by getAvailableActorIndex()
24 | std::mutex mutex_; // presumably guards actor_index_ and other cross-thread updates
25 | std::vector> actors_;
26 | std::vector> networks_;
27 | std::vector>> network_outputs_; // presumably one set of outputs per entry of networks_
28 | };
29 | // Worker thread of an ActorGroup; declares doCPUJob()/doGPUJob() hooks, defined out of line (presumably in actor_group.cpp)
30 | class SlaveThread : public utils::BaseSlaveThread {
31 | public:
32 | SlaveThread(int id, std::shared_ptr shared_data)
33 | : BaseSlaveThread(id, shared_data) {}
34 |
35 | void initialize() override;
36 | void runJob() override;
37 | bool isDone() override { return false; } // always false: the thread never reports completion by itself
38 |
39 | protected:
40 | virtual bool doCPUJob();
41 | virtual void doGPUJob();
42 | virtual void handleSearchDone(int actor_id);
43 | inline std::shared_ptr getSharedData() { return std::static_pointer_cast(shared_data_); } // typed view of the base-class shared_data_
44 | };
45 | // Parallel driver for a group of actors, with hooks to create networks/actors, spawn SlaveThreads, and handle commands
46 | class ActorGroup : public utils::BaseParalleler {
47 | public:
48 | ActorGroup() {}
49 |
50 | void run();
51 | void initialize() override;
52 | void summarize() override {} // no-op
53 |
54 | protected:
55 | virtual void createNeuralNetworks();
56 | virtual void createActors();
57 | virtual void handleIO();
58 | virtual void handleCommand();
59 | virtual void handleCommand(const std::string& command_prefix, const std::string& command);
60 |
61 | void createSharedData() override { shared_data_ = std::make_shared(); }
62 | std::shared_ptr newSlaveThread(int id) override { return std::make_shared(id, shared_data_); } // every slave receives the same shared_data_
63 | inline std::shared_ptr getSharedData() { return std::static_pointer_cast(shared_data_); }
64 |
65 | bool running_; // presumably keeps run() looping while true
66 | std::deque commands_; // presumably pending commands collected by handleIO() for handleCommand()
67 | std::unordered_set ignored_commands_;
68 | };
69 |
70 | } // namespace minizero::actor
71 |
--------------------------------------------------------------------------------
/minizero/actor/base_actor.cpp:
--------------------------------------------------------------------------------
1 | #include "base_actor.h"
2 | #include "configuration.h"
3 | #include
4 | #include
5 |
6 | namespace minizero::actor {
7 | // Start a new game: reset the environment, clear the per-move action info, and reset the search.
8 | void BaseActor::reset()
9 | {
10 | env_.reset();
11 | action_info_history_.clear();
12 | resetSearch();
13 | }
14 | // Set the NN batch index to -1 and reset the search, creating it with createSearch() on first use.
15 | void BaseActor::resetSearch()
16 | {
17 | nn_evaluation_batch_id_ = -1;
18 | if (!search_) { search_ = createSearch(); }
19 | search_->reset();
20 | }
21 | // Play `action`; if the environment accepts it, store getActionInfo() for that move. Returns whether the action was accepted.
22 | bool BaseActor::act(const Action& action)
23 | {
24 | bool can_act = env_.act(action);
25 | if (can_act) {
26 | action_info_history_.resize(env_.getActionHistory().size()); // keep one info entry per move in the action history
27 | action_info_history_.back() = getActionInfo();
28 | }
29 | return can_act;
30 | }
31 | // Same as act(const Action&), but the action is given as string arguments forwarded to the environment.
32 | bool BaseActor::act(const std::vector& action_string_args)
33 | {
34 | bool can_act = env_.act(action_string_args);
35 | if (can_act) {
36 | action_info_history_.resize(env_.getActionHistory().size());
37 | action_info_history_.back() = getActionInfo();
38 | }
39 | return can_act;
40 | }
41 | // Serialize the game and its per-move action info, tagged with EV (model file name without directory), RE (unfinished games only), and the caller's `tags`.
42 | std::string BaseActor::getRecord(const std::unordered_map& tags /* = {} */) const
43 | {
44 | EnvironmentLoader env_loader;
45 | env_loader.loadFromEnvironment(env_, action_info_history_);
46 | env_loader.addTag("EV", config::nn_file_name.substr(config::nn_file_name.find_last_of('/') + 1)); // npos + 1 wraps to 0, so a name without '/' is kept whole
47 |
48 | // if the game is not ended, then treat the game as a resign game, where the next player is the losing side
49 | if (!isEnvTerminal()) {
50 | float result = env_.getEvalScore(true);
51 | std::ostringstream oss;
52 | oss << result;
53 | env_loader.addTag("RE", oss.str());
54 | }
55 | for (const auto& tag : tags) { env_loader.addTag(tag.first, tag.second); } // bind by const reference instead of copying each key/value pair
56 | return env_loader.toString();
57 | }
58 | // Info stored with each move by act(): P = MCTS policy, V = MCTS value, R = environment reward.
59 | std::vector> BaseActor::getActionInfo() const
60 | {
61 | std::vector> action_info;
62 | action_info.push_back({"P", getMCTSPolicy()});
63 | action_info.push_back({"V", getMCTSValue()});
64 | action_info.push_back({"R", getEnvReward()});
65 | return action_info;
66 | }
67 |
68 | } // namespace minizero::actor
69 |
--------------------------------------------------------------------------------
/minizero/actor/base_actor.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "environment.h"
4 | #include "network.h"
5 | #include "search.h"
6 | #include
7 | #include
8 | #include
9 | #include
10 | #include
11 |
12 | namespace minizero::actor {
13 |
14 | using namespace minizero;
15 |
16 | class BaseActor {
17 | public:
18 | BaseActor() {}
19 | virtual ~BaseActor() = default;
20 |
21 | virtual void reset();
22 | virtual void resetSearch();
23 | bool act(const Action& action);
24 | bool act(const std::vector& action_string_args);
25 | virtual std::string getRecord(const std::unordered_map& tags = {}) const;
26 |
27 | inline bool isEnvTerminal() const { return env_.isTerminal(); }
28 | inline const float getEvalScore() const { return env_.getEvalScore(); }
29 | inline Environment& getEnvironment() { return env_; }
30 | inline const Environment& getEnvironment() const { return env_; }
31 | inline const int getNNEvaluationBatchIndex() const { return nn_evaluation_batch_id_; }
32 | inline std::vector>>& getActionInfoHistory() { return action_info_history_; }
33 | inline const std::vector>>& getActionInfoHistory() const { return action_info_history_; }
34 |
35 | virtual Action think(bool with_play = false, bool display_board = false) = 0;
36 | virtual void beforeNNEvaluation() = 0;
37 | virtual void afterNNEvaluation(const std::shared_ptr& network_output) = 0;
38 | virtual bool isSearchDone() const = 0;
39 | virtual Action getSearchAction() const = 0;
40 | virtual bool isResign() const = 0;
41 | virtual std::string getSearchInfo() const = 0;
42 | virtual void setNetwork(const std::shared_ptr& network) = 0;
43 | virtual std::shared_ptr createSearch() = 0;
44 |
45 | protected:
46 | virtual std::vector> getActionInfo() const;
47 | virtual std::string getMCTSPolicy() const = 0;
48 | virtual std::string getMCTSValue() const = 0;
49 | virtual std::string getEnvReward() const = 0;
50 |
51 | int nn_evaluation_batch_id_;
52 | Environment env_;
53 | std::shared_ptr search_;
54 | std::vector>> action_info_history_;
55 | };
56 |
57 | } // namespace minizero::actor
58 |
--------------------------------------------------------------------------------
/minizero/actor/create_actor.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "base_actor.h"
4 | #include "configuration.h"
5 | #include "zero_actor.h"
6 | #include
7 |
8 | namespace minizero::actor {
9 | // Create an actor (tree_node_size is forwarded to its constructor), attach `network`, and reset it to a fresh game.
10 | inline std::shared_ptr createActor(uint64_t tree_node_size, const std::shared_ptr& network)
11 | {
12 | auto actor = std::make_shared(tree_node_size);
13 | actor->setNetwork(network);
14 | actor->reset();
15 | return actor;
16 | }
17 |
18 | } // namespace minizero::actor
19 |
--------------------------------------------------------------------------------
/minizero/actor/gumbel_zero.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "mcts.h"
4 | #include
5 | #include
6 | #include
7 |
8 | namespace minizero::actor {
9 | // Gumbel search helpers; ZeroActor reports getMCTSPolicy() instead of the plain search distribution when config::actor_use_gumbel is set.
10 | class GumbelZero {
11 | public:
12 | std::string getMCTSPolicy(const std::shared_ptr& mcts) const;
13 | MCTSNode* decideActionNode(const std::shared_ptr& mcts);
14 | std::vector selection(const std::shared_ptr& mcts);
15 | void sequentialHalving(const std::shared_ptr& mcts); // presumably the Sequential Halving step of Gumbel MuZero - confirm in gumbel_zero.cpp
16 | void sortCandidatesByScore(const std::shared_ptr& mcts); // presumably reorders candidates_ - confirm
17 | // NOTE(review): the members below have no initializers here; presumably set in gumbel_zero.cpp before use
18 | private:
19 | int sample_size_;
20 | int simulation_budget_;
21 | std::vector candidates_;
22 | };
23 |
24 | } // namespace minizero::actor
25 |
--------------------------------------------------------------------------------
/minizero/actor/search.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | namespace minizero::actor {
4 | // Abstract search algorithm held by BaseActor (created by createSearch(), reset by resetSearch()).
5 | class Search {
6 | public:
7 | Search() {}
8 | virtual ~Search() = default; // virtual so derived searches can be destroyed through a Search pointer
9 | // called by BaseActor::resetSearch()
10 | virtual void reset() = 0;
11 | };
12 |
13 | } // namespace minizero::actor
14 |
--------------------------------------------------------------------------------
/minizero/actor/tree.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include
4 | #include
5 | #include
6 |
7 | namespace minizero::actor {
8 | // Append-only store of Data values, addressed by the int index that store() returns.
9 | template
10 | class TreeData {
11 | public:
12 | TreeData() { reset(); }
13 |
14 | inline void reset() { data_.clear(); }
15 | inline int store(const Data& data)
16 | {
17 | int index = data_.size();
18 | data_.push_back(data);
19 | return index;
20 | }
21 | inline const Data& getData(int index) const
22 | {
23 | assert(index >= 0 && index < size());
24 | return data_[index];
25 | }
26 | inline int size() const { return data_.size(); }
27 |
28 | private:
29 | std::vector data_;
30 | };
31 | // Abstract search-tree node; the children of a node are stored contiguously starting at first_child_.
32 | class TreeNode {
33 | public:
34 | TreeNode() {}
35 | virtual ~TreeNode() = default;
36 |
37 | virtual void reset() = 0;
38 | virtual std::string toString() const = 0;
39 | virtual bool displayInTreeLog() const { return true; } // whether Tree::getTreeInfo_r prints this node
40 |
41 | inline bool isLeaf() const { return (num_children_ == 0); }
42 | inline void setAction(Action action) { action_ = action; }
43 | inline void setNumChildren(int num_children) { num_children_ = num_children; }
44 | inline void setFirstChild(TreeNode* first_child) { first_child_ = first_child; }
45 | inline void setOptionChild(TreeNode* option_child) { option_child_ = option_child; }
46 | inline void setParent(TreeNode* parent) { parent_ = parent; }
47 | inline Action getAction() const { return action_; }
48 | inline int getNumChildren() const { return num_children_; }
49 | inline virtual TreeNode* getChild(int index) const { return (index >= 0 && index < num_children_ ? first_child_ + index : nullptr); } // TreeNode is abstract, so children are derived objects: node types larger than TreeNode must override this pointer arithmetic
50 | inline virtual TreeNode* getOptionChild() const { return option_child_; }
51 | inline virtual TreeNode* getParent() const { return parent_; }
52 |
53 | protected:
54 | Action action_;
55 | int num_children_;
56 | TreeNode* first_child_;
57 | TreeNode* option_child_;
58 | TreeNode* parent_;
59 | };
60 | // Fixed-capacity node pool: 1 + tree_node_size nodes are created once; node 0 is the root and allocateNodes() hands out consecutive ranges after it.
61 | class Tree {
62 | public:
63 | Tree(uint64_t tree_node_size)
64 | : tree_node_size_(tree_node_size),
65 | current_node_size_(0), nodes_(nullptr)
66 | {
67 | // tree_node_size is unsigned, so it needs no lower-bound check
68 | }
69 | // Create the pool on first use, then keep only a freshly reset root (all other nodes become free again).
70 | inline void reset()
71 | {
72 | if (!nodes_) { nodes_ = createTreeNodes(1 + tree_node_size_); }
73 | current_node_size_ = 1;
74 | getRootNode()->reset();
75 | }
76 | // Reserve `size` consecutive nodes and return the first one; capacity is enforced only by assert.
77 | inline TreeNode* allocateNodes(int size)
78 | {
79 | assert(current_node_size_ + size <= 1 + tree_node_size_);
80 | TreeNode* node = getNodeIndex(current_node_size_);
81 | current_node_size_ += size;
82 | return node;
83 | }
84 | // Return env_string (which must end with ')') with the root info and the rendered tree inserted before that ')'.
85 | std::string toString(const std::string& env_string)
86 | {
87 | assert(!env_string.empty() && env_string.back() == ')');
88 | std::ostringstream oss;
89 | TreeNode* pRoot = getRootNode();
90 | std::string env_prefix = env_string.substr(0, env_string.size() - 1);
91 | oss << env_prefix << "C[" << pRoot->toString() << "]" << getTreeInfo_r(pRoot) << ")";
92 | return oss.str();
93 | }
94 | // Recursively render node's children as player[action_id]C[info] entries, wrapped in parentheses when more than one branch is counted.
95 | std::string getTreeInfo_r(const TreeNode* node) const
96 | {
97 | std::ostringstream oss;
98 |
99 | int numChildren = 0;
100 | for (int i = 0; i < node->getNumChildren(); ++i) {
101 | TreeNode* child = node->getChild(i);
102 | if (child->isLeaf()) { continue; }
103 | ++numChildren;
104 | }
105 | // NOTE(review): branches are counted by !isLeaf() but printed by displayInTreeLog(); confirm the two criteria agree for the node types in use
106 | for (int i = 0; i < node->getNumChildren(); ++i) {
107 | TreeNode* child = node->getChild(i);
108 | if (!child->displayInTreeLog()) { continue; }
109 | if (numChildren > 1) { oss << "("; }
110 | oss << playerToChar(child->getAction().getPlayer())
111 | << "[" << child->getAction().getActionID() << "]"
112 | << "C[" << child->toString() << "]" << getTreeInfo_r(child);
113 | if (numChildren > 1) { oss << ")"; }
114 | }
115 | return oss.str();
116 | }
117 |
118 | inline TreeNode* getRootNode() { return &nodes_[0]; }
119 | inline const TreeNode* getRootNode() const { return &nodes_[0]; }
120 |
121 | protected:
122 | virtual TreeNode* createTreeNodes(uint64_t tree_node_size) = 0;
123 | virtual TreeNode* getNodeIndex(int index) = 0;
124 | // tree_node_size_: pool capacity excluding the root; current_node_size_: nodes in use, root included; nodes_: the pool
125 | uint64_t tree_node_size_;
126 | uint64_t current_node_size_;
127 | TreeNode* nodes_;
128 | };
129 |
130 | } // namespace minizero::actor
131 |
--------------------------------------------------------------------------------
/minizero/actor/zero_actor.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "alphazero_network.h"
4 | #include "base_actor.h"
5 | #include "gumbel_zero.h"
6 | #include "mcts.h"
7 | #include "muzero_network.h"
8 | #include
9 | #include
10 | #include
11 | #include
12 | #include
13 | #include
14 |
15 | namespace minizero::actor {
16 |
17 | class MCTSSearchData {
18 | public:
19 | std::string search_info_;
20 | std::string selection_path_;
21 | MCTSNode* selected_node_;
22 | std::vector node_path_;
23 | void clear();
24 | };
25 |
// Actor that drives an MCTS search (optionally with Gumbel selection, see
// config::actor_use_gumbel) and evaluates nodes with either an AlphaZero or a
// MuZero network, on top of the BaseActor interface.
//
// NOTE(review): in this copy the template arguments of std::vector,
// std::shared_ptr, std::pair, std::make_shared and std::static_pointer_cast
// (and the targets of the bare #include lines above) appear to have been
// stripped, so these declarations do not compile as written -- restore them
// from the upstream header.
class ZeroActor : public BaseActor {
public:
    // tree_node_size is stored and later used to size the search created by
    // createSearch(). Both network pointers start out null (presumably filled
    // in by setNetwork() -- confirm in zero_actor.cpp).
    ZeroActor(uint64_t tree_node_size)
        : tree_node_size_(tree_node_size)
    {
        option_count_ = 0;
        alphazero_network_ = nullptr;
        muzero_network_ = nullptr;
    }

    void reset() override;
    void resetSearch() override;
    Action think(bool with_play = false, bool display_board = false) override;
    void beforeNNEvaluation() override;
    void afterNNEvaluation(const std::shared_ptr& network_output) override;
    // Search is done once the MCTS reports its simulation limit is reached.
    bool isSearchDone() const override { return getMCTS()->reachMaximumSimulation(); }
    // Action stored on the node selected by the search.
    Action getSearchAction() const override { return mcts_search_data_.selected_node_->getAction(); }
    // Resigning is only considered when enable_resign_ is set.
    bool isResign() const override { return enable_resign_ && getMCTS()->isResign(mcts_search_data_.selected_node_); }
    std::string getSearchInfo() const override { return mcts_search_data_.search_info_; }
    void setNetwork(const std::shared_ptr& network) override;
    // Builds the search object, sized by the constructor's tree_node_size.
    std::shared_ptr createSearch() override { return std::make_shared(tree_node_size_); }
    // Unchecked downcast of the inherited search_ member (not declared in this
    // class) to the MCTS type.
    std::shared_ptr getMCTS() { return std::static_pointer_cast(search_); }
    const std::shared_ptr getMCTS() const { return std::static_pointer_cast(search_); }

protected:
    std::vector> getActionInfo() const override;
    // Policy text comes from GumbelZero when config::actor_use_gumbel is set,
    // otherwise from the MCTS search distribution.
    std::string getMCTSPolicy() const override { return (config::actor_use_gumbel ? gumbel_zero_.getMCTSPolicy(getMCTS()) : getMCTS()->getSearchDistributionString()); }
    // Mean value of the MCTS root node, formatted with std::to_string.
    std::string getMCTSValue() const override { return std::to_string(getMCTS()->getRootNode()->getMean()); }
    std::string getEnvReward() const override;

    virtual void step();
    virtual void handleSearchDone();
    virtual MCTSNode* decideActionNode();
    virtual void addNoiseToNodeChildren(MCTSNode* node);
    // Node path for the next simulation: Gumbel selection when
    // config::actor_use_gumbel is set, plain MCTS select() otherwise.
    virtual std::vector selection() { return (config::actor_use_gumbel ? gumbel_zero_.selection(getMCTS()) : getMCTS()->select()); }

    std::vector calculateAlphaZeroActionPolicy(const Environment& env_transition, const std::shared_ptr& alphazero_output, const utils::Rotation& rotation);
    std::vector calculateMuZeroActionPolicy(MCTSNode* leaf_node, const std::shared_ptr& muzero_output);
    virtual Environment getEnvironmentTransition(const std::vector& node_path);

    bool enable_resign_;
    bool enable_option_;
    int option_count_;
    int p_max_id_;
    std::string used_options_;
    std::string depth_;
    std::string option_str_;
    GumbelZero gumbel_zero_;
    uint64_t tree_node_size_; // capacity handed to createSearch()
    MCTSSearchData mcts_search_data_;
    utils::Rotation feature_rotation_;
    std::shared_ptr alphazero_network_;
    std::shared_ptr muzero_network_;

private:
    // NOTE(review): the option-related helpers below take their container
    // arguments by const value, which copies them on every call; const
    // references would avoid that if no copy is needed -- confirm in the .cpp.
    void setAlphaZeroOptionInfo(MCTSNode* leaf_node, Environment& env_transition, const std::shared_ptr& alphazero_output, const utils::Rotation& rotation);
    void setMuZeroOptionInfo(MCTSNode* leaf_node, const std::shared_ptr& muzero_output);
    std::pair, std::vector>> calculateLegalOption(Environment& env_transition, env::Player turn, const std::vector> option, utils::Rotation rotation = utils::Rotation::kRotationNone);
    std::vector calculateOption(env::Player turn, const std::vector> option, utils::Rotation rotation = utils::Rotation::kRotationNone);
    void setOptionInfo(MCTSNode* leaf_node, const std::vector> option, const std::vector option_actions, const std::vector> option_legal_actions, utils::Rotation rotation = utils::Rotation::kRotationNone);
};
87 |
88 | } // namespace minizero::actor
89 |
--------------------------------------------------------------------------------
/minizero/config/CMakeLists.txt:
--------------------------------------------------------------------------------
# Library built from every *.cpp file in this directory. The glob is evaluated
# at configure time, so adding a new source file requires re-running CMake.
file(GLOB SRCS *.cpp)

add_library(config ${SRCS})
# Targets that link against `config` also get this directory on their include path.
target_include_directories(config PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
# GAME_TYPE is expected to be set by a parent CMakeLists (presumably selecting
# the game at compile time -- confirm in the top-level file). It is PUBLIC so
# dependents see the same definition.
# (The former `target_link_libraries(config)` named no libraries and was a
# no-op, so it has been removed.)
target_compile_definitions(config PUBLIC ${GAME_TYPE})
--------------------------------------------------------------------------------
/minizero/config/configuration.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "configure_loader.h"
4 | #include
5 |
namespace minizero::config {

// Global configuration values. Each is only declared here (extern); the
// definitions live in another translation unit (presumably configuration.cpp).
// setConfiguration() at the bottom receives the ConfigureLoader that is
// expected to bind these variables to config-file keys -- confirm there.

// program parameters
extern int program_seed;
extern bool program_auto_seed;
extern bool program_quiet;

// actor parameters
extern int actor_num_simulation;
extern float actor_mcts_puct_base;
extern float actor_mcts_puct_init;
extern float actor_mcts_reward_discount;
extern int actor_mcts_think_batch_size;
extern float actor_mcts_think_time_limit;
extern bool actor_mcts_value_rescale;
extern char actor_mcts_value_flipping_player;
extern bool actor_select_action_by_count;
extern bool actor_select_action_by_softmax_count;
extern float actor_select_action_softmax_temperature;
extern bool actor_select_action_softmax_temperature_decay;
extern bool actor_use_random_rotation_features;
extern bool actor_use_dirichlet_noise;
extern float actor_dirichlet_noise_alpha;
extern float actor_dirichlet_noise_epsilon;
extern bool actor_use_gumbel;
extern bool actor_use_gumbel_noise;
extern int actor_gumbel_sample_size;
extern float actor_gumbel_sigma_visit_c;
extern float actor_gumbel_sigma_scale_c;
extern float actor_resign_threshold;

// zero parameters
extern int zero_num_threads;
extern int zero_num_parallel_games;
extern int zero_server_port;
extern std::string zero_training_directory;
extern int zero_num_games_per_iteration;
extern int zero_start_iteration;
extern int zero_end_iteration;
extern int zero_replay_buffer;
extern float zero_disable_resign_ratio;
extern float zero_disable_option_ratio;
extern int zero_actor_intermediate_sequence_length;
extern std::string zero_actor_ignored_command;
extern bool zero_server_accept_different_model_games;

// learner parameters
extern bool learner_use_per;
extern float learner_per_alpha;
extern float learner_per_init_beta;
extern bool learner_per_beta_anneal;
extern int learner_training_step;
extern int learner_training_display_step;
extern int learner_batch_size;
extern int learner_muzero_unrolling_step;
extern int learner_n_step_return;
extern float learner_learning_rate;
extern float learner_momentum;
extern float learner_weight_decay;
extern float learner_value_loss_scale;
extern float learner_option_loss_scale;
extern int learner_num_thread;

// network parameters
extern std::string nn_file_name;
extern int nn_num_blocks;
extern int nn_num_hidden_channels;
extern int nn_num_value_hidden_channels;
extern std::string nn_type_name;
extern bool nn_use_consistency;

// option parameters
extern int option_seq_length;

// environment parameters
extern int env_board_size;

// environment parameters for specific game
extern std::string env_atari_rom_dir;
extern std::string env_atari_name;
extern bool env_conhex_use_swap_rule;
extern float env_go_komi;
extern std::string env_go_ko_rule;
extern std::string env_gomoku_rule;
extern bool env_gomoku_exactly_five_stones;
extern bool env_havannah_use_swap_rule;
extern bool env_hex_use_swap_rule;
extern bool env_killallgo_use_seki;
extern int env_rubiks_scramble_rotate;
extern int env_surakarta_no_capture_plies;
extern std::string env_gridworld_maps_dir;
// NOTE(review): fix_player / fix_goal do not follow the env_<game>_ naming of
// the surrounding entries; from their placement they look gridworld-specific
// -- confirm before relying on that.
extern bool fix_player;
extern bool fix_goal;

// Presumably registers every parameter above with `cl` so that it can be loaded
// from / dumped to a configuration file -- verify in configuration.cpp.
void setConfiguration(ConfigureLoader& cl);

} // namespace minizero::config
103 |
--------------------------------------------------------------------------------
/minizero/config/configure_loader.cpp:
--------------------------------------------------------------------------------
1 | #include "configure_loader.h"
2 | #include
3 | #include
4 | #include
5 |
6 | namespace minizero::config {
7 |
8 | template <>
9 | bool setParameter(bool& ref, const std::string& value)
10 | {
11 | std::string tmp = value;
12 | transform(tmp.begin(), tmp.end(), tmp.begin(), ::toupper);
13 | if (tmp != "TRUE" && tmp != "1" && tmp != "FALSE" && tmp != "0") { return false; }
14 |
15 | ref = (tmp == "TRUE" || tmp == "1");
16 | return true;
17 | }
18 |
19 | template <>
20 | bool setParameter(std::string& ref, const std::string& value)
21 | {
22 | ref = value;
23 | return true;
24 | }
25 |
26 | template <>
27 | std::string getParameter(bool& ref)
28 | {
29 | std::ostringstream oss;
30 | oss << (ref == true ? "true" : "false");
31 | return oss.str();
32 | }
33 |
34 | bool ConfigureLoader::loadFromFile(std::string conf_file)
35 | {
36 | if (conf_file.empty()) { return false; }
37 |
38 | std::string line;
39 | std::ifstream file(conf_file);
40 | if (file.fail()) {
41 | file.close();
42 | return false;
43 | }
44 | while (std::getline(file, line)) {
45 | if (!setValue(line)) { return false; }
46 | }
47 |
48 | return true;
49 | }
50 |
51 | bool ConfigureLoader::loadFromString(std::string conf_string)
52 | {
53 | if (conf_string.empty()) { return false; }
54 |
55 | std::string line;
56 | std::istringstream iss(conf_string);
57 | while (std::getline(iss, line, ':')) {
58 | if (!setValue(line)) { return false; }
59 | }
60 |
61 | return true;
62 | }
63 |
64 | std::string ConfigureLoader::toString() const
65 | {
66 | std::ostringstream oss;
67 | for (const auto& group_name : group_name_order_) {
68 | oss << "# " << group_name << std::endl;
69 | for (auto parameter : parameter_groups_.at(group_name)) { oss << parameter->toString(); }
70 | oss << std::endl;
71 | }
72 | return oss.str();
73 | }
74 |
75 | std::string ConfigureLoader::getConfig(std::string key) const
76 | {
77 | for (const auto& group_name : group_name_order_) {
78 | for (auto parameter : parameter_groups_.at(group_name)) {
79 | if (parameter->getKey() == key) { return parameter->toString(); }
80 | }
81 | }
82 | return "";
83 | }
84 |
85 | void ConfigureLoader::trim(std::string& s)
86 | {
87 | if (s.empty()) { return; }
88 | s.erase(0, s.find_first_not_of(" \t"));
89 | s.erase(s.find_last_not_of(" \t") + 1);
90 | }
91 |
92 | bool ConfigureLoader::setValue(std::string line)
93 | {
94 | if (line.empty() || line[0] == '#') { return true; }
95 |
96 | std::string key = line.substr(0, line.find("="));
97 | std::string value = line.substr(line.find("=") + 1);
98 | if (value.find("#") != std::string::npos) { value = value.substr(0, value.find("#")); }
99 | std::string group_name = line.substr(line.find("#") + 1);
100 |
101 | trim(key);
102 | trim(value);
103 | trim(group_name);
104 |
105 | if (parameters_.count(key) == 0) {
106 | std::cerr << "Invalid key \"" + key + "\" and value \"" << value << "\"" << std::endl;
107 | return false;
108 | }
109 |
110 | if (!(*parameters_[key])(value)) {
111 | std::cerr << "Unsatisfiable value \"" + value + "\" for option \"" + key + "\"" << std::endl;
112 | return false;
113 | }
114 |
115 | return true;
116 | }
117 |
118 | } // namespace minizero::config
119 |
--------------------------------------------------------------------------------
/minizero/config/configure_loader.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include