├── waf ├── .gitignore ├── README.md ├── tbb.py ├── include └── mcts │ ├── macros.hpp │ ├── parallel.hpp │ ├── defaults.hpp │ └── uct.hpp ├── wscript └── src ├── benchmarks └── trap.cpp ├── toy_sim.cpp └── uct.cpp /waf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/resibots/mcts/HEAD/waf -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | *# 3 | .DS_Store 4 | .project 5 | .settings 6 | .classpath 7 | .metadata 8 | *.o 9 | *.a 10 | *.pyc 11 | *.err 12 | *.log 13 | .lock-waf* 14 | .waf-* 15 | .waf3-* 16 | build 17 | waf_xcode.sh 18 | exp 19 | src/tests/combinations 20 | 21 | # Ignored folders for the documentation 22 | _build 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Monte Carlo Tree Search 2 | ======================== 3 | 4 | A lightweight and generic C++14 implementation of the Monte Carlo Tree Search algorithm. 5 | 6 | Authors 7 | ------- 8 | - Konstantinos Chatzilygeroudis (Inria) 9 | 10 | Main references 11 | --------------- 12 | 13 | - **UCT**: Levente Kocsis and Csaba Szepesvari (2006). Bandit based Monte-Carlo Planning. *Machine Learning: ECML* 14 | - **Continuous-MCTS**: Adrien Couetoux (2013). Monte Carlo Tree Search for Continuous and Stochastic Sequential Decision Making Problems. *Ph.D. dissertation - Universite Paris Sud - Paris XI* 15 | - **Survey**: Cameron Browne, Edward Powley, Daniel Whitehouse, Simon Lucas, Peter I. Cowling, Philipp Rohlfshagen, Stephen Tavener, Diego Perez, Spyridon Samothrakis and Simon Colton (2012). A Survey of Monte Carlo Tree Search Methods. 
*Transactions on Computational Intelligence and AI in Games, IEEE* 16 | -------------------------------------------------------------------------------- /tbb.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | #| Konstantinos Chatzilygeroudis 2018-2023 4 | 5 | """ 6 | Quick n dirty tbb detection 7 | """ 8 | 9 | from waflib.Configure import conf 10 | 11 | def options(opt): 12 | opt.add_option('--tbb', type='string', help='path to Intel TBB', dest='tbb') 13 | 14 | # check if a lib exists for both osx (darwin) and GNU/linux 15 | def check_lib(self, name, path): 16 | if self.env['DEST_OS'] == 'darwin': 17 | libname = name + '.dylib' 18 | else: 19 | libname = name + '.so' 20 | res = self.find_file(libname, path) 21 | lib = res[:-len(libname)-1] 22 | return lib 23 | 24 | @conf 25 | def check_tbb(self, *k, **kw): 26 | def get_directory(filename, dirs): 27 | res = self.find_file(filename, dirs) 28 | return res[:-len(filename)-1] 29 | 30 | required = kw.get('required', False) 31 | 32 | if self.options.tbb: 33 | includes_tbb = [self.options.tbb + '/include'] 34 | libpath_tbb = [self.options.tbb + '/lib'] 35 | else: 36 | includes_tbb = ['/usr/local/include/oneapi', '/usr/include/oneapi', '/usr/local/include', '/usr/include', '/opt/intel/tbb/include'] 37 | libpath_tbb = ['/usr/local/lib/', '/usr/lib', '/opt/intel/tbb/lib', '/usr/lib/x86_64-linux-gnu/'] 38 | 39 | self.start_msg('Checking Intel TBB includes') 40 | incl = '' 41 | lib = '' 42 | try: 43 | incl = get_directory('tbb/parallel_for.h', includes_tbb) 44 | self.end_msg(incl) 45 | except: 46 | if required: 47 | self.fatal('Not found in %s' % str(includes_tbb)) 48 | self.end_msg('Not found in %s' % str(includes_tbb), 'YELLOW') 49 | return 50 | 51 | check_oneapi = False 52 | try: 53 | incl = get_directory('tbb/version.h', includes_tbb) 54 | check_oneapi = True 55 | except: 56 | pass 57 | 58 | self.start_msg('Checking Intel TBB libs') 59 | 
try: 60 | lib = check_lib(self, 'libtbb', libpath_tbb) 61 | self.end_msg(lib) 62 | except: 63 | if required: 64 | self.fatal('Not found in %s' % str(libpath_tbb)) 65 | self.end_msg('Not found in %s' % str(libpath_tbb), 'YELLOW') 66 | return 67 | 68 | self.env.LIBPATH_TBB = [lib] 69 | self.env.LIB_TBB = ['tbb'] 70 | self.env.INCLUDES_TBB = [incl] 71 | self.env.DEFINES_TBB = ['USE_TBB'] 72 | if check_oneapi: 73 | self.env.DEFINES_TBB.append('USE_TBB_ONEAPI') 74 | -------------------------------------------------------------------------------- /include/mcts/macros.hpp: -------------------------------------------------------------------------------- 1 | #ifndef MCTS_MACROS_HPP 2 | #define MCTS_MACROS_HPP 3 | 4 | #define MCTS_PARAM(Type, Name, Value) \ 5 | static constexpr Type Name() { return Value; } 6 | 7 | #define MCTS_REQUIRED_PARAM(Type, Name) \ 8 | static const Type Name() \ 9 | { \ 10 | static_assert(false, "You need to define the parameter:" #Name " !"); \ 11 | return Type(); \ 12 | } 13 | 14 | #define MCTS_DYN_PARAM(Type, Name) \ 15 | static Type _##Name; \ 16 | static Type Name() { return _##Name; } \ 17 | static void set_##Name(const Type& v) { _##Name = v; } 18 | 19 | #define MCTS_DECLARE_DYN_PARAM(Type, Namespace, Name) Type Namespace::_##Name; 20 | 21 | #define __VA_NARG__(...) (__VA_NARG_(_0, ##__VA_ARGS__, __RSEQ_N()) - 1) 22 | #define __VA_NARG_(...) __VA_ARG_N(__VA_ARGS__) 23 | #define __VA_ARG_N( \ 24 | _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, \ 25 | _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, \ 26 | _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, \ 27 | _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, \ 28 | _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, \ 29 | _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, \ 30 | _61, _62, _63, N, ...) 
N 31 | #define __RSEQ_N() \ 32 | 63, 62, 61, 60, \ 33 | 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, \ 34 | 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, \ 35 | 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, \ 36 | 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, \ 37 | 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, \ 38 | 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 39 | /* MCTS_PARAM_ARRAY(Type, Name, v0, v1, ...): Name(i) returns the i-th value (bounds-checked with assert), Name_size() returns the number of values (counted with __VA_NARG__) and Name_t is the element type. NOTE(review): this header includes nothing, so <cassert> must already be included where the macro is expanded. */ 40 | #define MCTS_PARAM_ARRAY(Type, Name, ...) \ 41 | static Type Name(size_t i) \ 42 | { \ 43 | assert(i < __VA_NARG__(__VA_ARGS__)); \ 44 | static constexpr Type _##Name[] = {__VA_ARGS__}; \ 45 | return _##Name[i]; \ 46 | } \ 47 | static constexpr size_t Name##_size() \ 48 | { \ 49 | return __VA_NARG__(__VA_ARGS__); \ 50 | } \ 51 | typedef Type Name##_t; 52 | /* MCTS_PARAM_STRING: compile-time C-string parameter. */ 53 | #define MCTS_PARAM_STRING(Name, Value) \ 54 | static constexpr const char* Name() { return Value; } 55 | 56 | #endif 57 | -------------------------------------------------------------------------------- /wscript: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | VERSION = '0.0.1' 5 | APPNAME = 'mcts' 6 | 7 | srcdir = '.' 
8 | blddir = 'build' 9 | 10 | from waflib.Build import BuildContext 11 | import tbb 12 | 13 | 14 | def options(opt): 15 | opt.load('compiler_cxx') 16 | opt.load('compiler_c') 17 | opt.load('tbb') 18 | 19 | 20 | def configure(conf): 21 | conf.load('compiler_cxx') 22 | conf.load('compiler_c') 23 | conf.load('tbb') 24 | 25 | conf.check_tbb() 26 | 27 | if conf.env.CXX_NAME in ["icc", "icpc"]: 28 | common_flags = "-Wall -std=c++14" 29 | opt_flags = " -O3 -xHost -march=native -mtune=native -unroll -fma -g" 30 | elif conf.env.CXX_NAME in ["clang"]: 31 | common_flags = "-Wall -std=c++14" 32 | opt_flags = " -O3 -march=native -g" 33 | else: 34 | common_flags = "-Wall -std=c++14" 35 | opt_flags = " -O3 -march=native -g" 36 | 37 | all_flags = common_flags + opt_flags 38 | conf.env['CXXFLAGS'] = conf.env['CXXFLAGS'] + all_flags.split(' ') 39 | print(conf.env['CXXFLAGS']) 40 | 41 | 42 | def build(bld): 43 | bld.program(features = 'cxx', 44 | uselib = "TBB", 45 | install_path = None, 46 | source='src/uct.cpp', 47 | includes = './include', 48 | target='uct') 49 | 50 | bld.program(features = 'cxx', 51 | uselib = "TBB", 52 | install_path = None, 53 | source='src/benchmarks/trap.cpp', 54 | includes = './include', 55 | defines = ['SINGLE'], 56 | target='src/benchmarks/trap') 57 | 58 | bld.program(features = 'cxx', 59 | uselib = "TBB", 60 | install_path = None, 61 | source='src/benchmarks/trap.cpp', 62 | includes = './include', 63 | target='src/benchmarks/trap_parallel') 64 | 65 | bld.program(features = 'cxx', 66 | uselib = "TBB", 67 | install_path = None, 68 | source='src/benchmarks/trap.cpp', 69 | includes = './include', 70 | defines = ['SIMPLE', 'SINGLE'], 71 | target='src/benchmarks/trap_simple') 72 | 73 | bld.program(features = 'cxx', 74 | uselib = "TBB", 75 | install_path = None, 76 | source='src/benchmarks/trap.cpp', 77 | includes = './include', 78 | defines = ['SIMPLE'], 79 | target='src/benchmarks/trap_simple_parallel') 80 | 81 | bld.program(features = 'cxx', 82 | uselib = 
"TBB", 83 | install_path = None, 84 | source='src/toy_sim.cpp', 85 | includes = './include', 86 | target='toy_sim') 87 | 88 | bld.program(features = 'cxx', 89 | uselib = "TBB", 90 | install_path = None, 91 | source='src/toy_sim.cpp', 92 | includes = './include', 93 | defines = 'SINGLE', 94 | target='toy_sim_single') 95 | 96 | bld.install_files('${PREFIX}/include/mcts', 'include/mcts/uct.hpp') 97 | bld.install_files('${PREFIX}/include/mcts', 'include/mcts/defaults.hpp') 98 | bld.install_files('${PREFIX}/include/mcts', 'include/mcts/macros.hpp') 99 | bld.install_files('${PREFIX}/include/mcts', 'include/mcts/parallel.hpp') 100 | -------------------------------------------------------------------------------- /src/benchmarks/trap.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | struct Params { 6 | struct uct { 7 | MCTS_PARAM(double, c, 50.0); 8 | }; 9 | 10 | struct spw { 11 | MCTS_PARAM(double, a, 0.5); 12 | }; 13 | 14 | struct cont_outcome { 15 | MCTS_PARAM(double, b, 0.6); 16 | }; 17 | 18 | struct mcts_node { 19 | #ifdef SINGLE 20 | MCTS_PARAM(size_t, parallel_roots, 1); 21 | #else 22 | MCTS_PARAM(size_t, parallel_roots, 4); 23 | #endif 24 | }; 25 | }; 26 | 27 | namespace global { 28 | double a = 70; 29 | double h = 100; 30 | double l = 1; 31 | double w = 0.7; 32 | } 33 | 34 | struct SimpleState { 35 | double _x, _R; 36 | int _time; 37 | const double _epsilon = 1e-6; 38 | 39 | SimpleState() 40 | { 41 | _x = 0; 42 | _R = 0.01; 43 | _time = 0; 44 | } 45 | 46 | SimpleState(double x, int t = 0, double R = 0.01) 47 | { 48 | _x = x; 49 | _R = R; 50 | _time = t; 51 | } 52 | 53 | double next_action() const 54 | { 55 | return random_action(); 56 | } 57 | 58 | double random_action() const 59 | { 60 | return (std::rand() / double(RAND_MAX)); 61 | } 62 | 63 | SimpleState move(double d) const 64 | { 65 | double x_new = _x + d + _R * (std::rand() / double(RAND_MAX)); 66 | return SimpleState(x_new, 
_time + 1, _R); 67 | } 68 | 69 | bool terminal() const 70 | { 71 | return (_time >= 2); 72 | } 73 | 74 | bool operator==(const SimpleState& other) const 75 | { 76 | double dx = _x - other._x; 77 | return ((dx * dx) < _epsilon); 78 | } 79 | }; 80 | 81 | struct RewardFunction { 82 | template 83 | double operator()(std::shared_ptr from_state, double action, std::shared_ptr to_state) 84 | { 85 | if (to_state->_x < global::l) 86 | return global::a; 87 | else if (to_state->_x < (global::l + global::w)) 88 | return 0.0; 89 | else if (to_state->_x > (global::l + global::w)) 90 | return global::h; 91 | assert(false); 92 | return 0.0; 93 | } 94 | }; 95 | 96 | int main() 97 | { 98 | std::srand(std::time(0)); 99 | mcts::par::init(); 100 | 101 | RewardFunction world; 102 | SimpleState init; 103 | 104 | #ifdef SIMPLE 105 | auto tree = std::make_shared, mcts::SimpleValueInit, mcts::UCTValue, mcts::UniformRandomPolicy, double, mcts::SPWSelectPolicy, mcts::SimpleOutcomeSelect>>(init, 2, 1.0); 106 | #else 107 | auto tree = std::make_shared, mcts::SimpleValueInit, mcts::UCTValue, mcts::UniformRandomPolicy, double, mcts::SPWSelectPolicy, mcts::ContinuousOutcomeSelect>>(init, 2, 1.0); 108 | #endif 109 | 110 | #ifdef SINGLE 111 | const int n_iter = 50000; 112 | #else 113 | const int n_iter = 18000; 114 | #endif 115 | 116 | auto t1 = std::chrono::steady_clock::now(); 117 | 118 | tree->compute(world, n_iter); 119 | 120 | auto time_running = std::chrono::duration_cast(std::chrono::steady_clock::now() - t1).count(); 121 | std::cout << "Time in sec: " << time_running / 1000.0 << std::endl; 122 | 123 | auto best = tree->best_action(); 124 | if (best != nullptr) { 125 | std::cout << best->action() << std::endl; 126 | std::cout << best->value() / best->visits() << std::endl; 127 | auto new_state = init.move(best->action()); 128 | std::cout << "Moving to: " << new_state._x << std::endl; 129 | 130 | // tree = std::make_shared, mcts::SimpleValueInit, mcts::UCTValue, mcts::UniformRandomPolicy, 
double, mcts::SPWSelectPolicy, mcts::ContinuousOutcomeSelect>>(new_state, 2); 131 | // 132 | // best = tree->best_action(); 133 | // if (best != nullptr) 134 | // std::cout << best->action() << std::endl; 135 | } 136 | 137 | return 0; 138 | } 139 | -------------------------------------------------------------------------------- /include/mcts/parallel.hpp: -------------------------------------------------------------------------------- 1 | #ifndef MCTS_PARALLEL_HPP 2 | #define MCTS_PARALLEL_HPP 3 | 4 | #include 5 | #include 6 | 7 | #ifdef USE_TBB 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #ifndef USE_TBB_ONEAPI 16 | #include 17 | #endif 18 | #endif 19 | 20 | namespace mcts { 21 | namespace par { 22 | #ifdef USE_TBB 23 | template 24 | using vector = tbb::concurrent_vector; // Template alias (for GCC 4.7 and later) 25 | 26 | /// @ingroup par_tools 27 | /// convert any container providing size()/begin()/end() (e.g. a par::vector, i.e. a tbb::concurrent_vector) into a std::vector 28 | template 29 | std::vector convert_vector(const V& v) 30 | { 31 | std::vector v2(v.size()); 32 | std::copy(v.begin(), v.end(), v2.begin()); 33 | return v2; 34 | } 35 | #else 36 | template 37 | using vector = std::vector; // Template alias (for GCC 4.7 and later) 38 | 39 | template 40 | V convert_vector(const V& v) 41 | { 42 | return v; 43 | } 44 | 45 | #endif 46 | // This is a header: every non-template function defined here must be inline, otherwise including it from two translation units breaks the link with duplicate symbols. 47 | #if (defined USE_TBB) && !(defined USE_TBB_ONEAPI) 48 | inline void init() 49 | { 50 | static tbb::task_scheduler_init init; 51 | } 52 | #else 53 | /// @ingroup par_tools 54 | /// init TBB (if activated) for multi-core computing; nothing to do without TBB or with oneTBB (USE_TBB_ONEAPI) 55 | inline void init() 56 | { 57 | } 58 | #endif 59 | 60 | ///@ingroup par_tools 61 | /// parallel for 62 | template 63 | inline void loop(size_t begin, size_t end, const F& f) 64 | { 65 | #ifdef USE_TBB 66 | tbb::parallel_for(size_t(begin), end, size_t(1), [&](size_t i) { 67 | // clang-format off 68 | f(i); 69 | // clang-format on 70 | }); 71 | #else 72 | for (size_t i = begin; i < end; ++i) 73 | f(i); 74 | 
#endif 75 | } 76 | 77 | /// @ingroup par_tools 78 | /// parallel for_each 79 | template 80 | inline void for_each(Iterator begin, Iterator end, const F& f) 81 | { 82 | #ifdef USE_TBB 83 | tbb::parallel_for_each(begin, end, f); 84 | #else 85 | for (Iterator i = begin; i != end; ++i) 86 | f(*i); 87 | #endif 88 | } 89 | 90 | /// @ingroup par_tools 91 | /// parallel arg-max: returns the preferred value among init and f(0), ..., f(num_steps - 1), where comp(a, b) returns true when a should replace b 92 | template 93 | T max(const T& init, int num_steps, const F& f, const C& comp) 94 | { 95 | #ifdef USE_TBB 96 | auto body = [&](const tbb::blocked_range& r, T current_max) -> T { 97 | // clang-format off 98 | for (size_t i = r.begin(); i != r.end(); ++i) 99 | { 100 | T v = f(i); 101 | if (comp(v, current_max)) 102 | current_max = v; 103 | } 104 | return current_max; 105 | // clang-format on 106 | }; 107 | auto joint = [&](const T& p1, const T& p2) -> T { 108 | // clang-format off 109 | if (comp(p1, p2)) 110 | return p1; 111 | return p2; 112 | // clang-format on 113 | }; 114 | return tbb::parallel_reduce(tbb::blocked_range(0, num_steps), init, 115 | body, joint); 116 | #else 117 | T current_max = init; 118 | for (size_t i = 0; i < size_t(num_steps); ++i) { // explicit conversion: same comparison as before, without the -Wall signed/unsigned warning 119 | T v = f(i); 120 | if (comp(v, current_max)) 121 | current_max = v; 122 | } 123 | return current_max; 124 | #endif 125 | } 126 | /// @ingroup par_tools 127 | /// parallel sort 128 | template 129 | inline void sort(T1 i1, T2 i2, T3 comp) 130 | { 131 | #ifdef USE_TBB 132 | tbb::parallel_sort(i1, i2, comp); 133 | #else 134 | std::sort(i1, i2, comp); 135 | #endif 136 | } 137 | 138 | /// @ingroup par_tools 139 | /// replicate a function nb times 140 | template 141 | inline void replicate(size_t nb, const F& f) 142 | { 143 | #ifdef USE_TBB 144 | tbb::parallel_for(size_t(0), nb, size_t(1), [&](size_t i) { 145 | // clang-format off 146 | f(); 147 | // clang-format on 148 | }); 149 | #else 150 | for (size_t i = 0; i < nb; ++i) 151 | f(); 152 | #endif 153 | } 154 | } // namespace par 155 | } // namespace mcts 156 | 157 | #endif 158 | 
-------------------------------------------------------------------------------- /src/toy_sim.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include 6 | // Draw one sample from a normal distribution with mean m and standard deviation v. 7 | template 8 | inline T gaussian_rand(T m = 0.0, T v = 1.0) 9 | { 10 | // One generator per thread, seeded once: building a std::random_device and a std::mt19937 on every call was expensive, and this may presumably run concurrently when several parallel roots are searched. 11 | static thread_local std::mt19937 gen(std::random_device{}()); 12 | // NOTE: std::normal_distribution takes (mean, stddev), so v is a standard deviation, not a variance. 13 | std::normal_distribution gaussian(m, v); 14 | 15 | return gaussian(gen); 16 | } 17 | 18 | struct Params { 19 | struct uct { 20 | MCTS_PARAM(double, c, 50.0); 21 | }; 22 | 23 | struct spw { 24 | MCTS_PARAM(double, a, 0.5); 25 | }; 26 | 27 | struct cont_outcome { 28 | MCTS_PARAM(double, b, 0.6); 29 | }; 30 | 31 | struct mcts_node { 32 | #ifdef SINGLE 33 | MCTS_PARAM(size_t, parallel_roots, 1); 34 | #else 35 | MCTS_PARAM(size_t, parallel_roots, 4); 36 | #endif 37 | }; 38 | }; 39 | 40 | namespace global { 41 | double goal_x, goal_y; 42 | } 43 | 44 | struct SimpleState { 45 | double _x, _y; 46 | const double _epsilon = 1e-6; 47 | 48 | SimpleState() 49 | { 50 | _x = _y = 0; 51 | } 52 | 53 | SimpleState(double x, double y) 54 | { 55 | _x = x; 56 | _y = y; 57 | } 58 | 59 | double next_action() const 60 | { 61 | // using domain knowledge - have to check literature 62 | double th = gaussian_rand(best_action(), 0.3); 63 | if (th > M_PI) 64 | th -= 2 * M_PI; 65 | if (th < -M_PI) 66 | th += 2 * M_PI; 67 | return th; 68 | } 69 | 70 | double random_action() const 71 | { 72 | return (std::rand() * 2.0 * M_PI / double(RAND_MAX) - M_PI); 73 | } 74 | 75 | double best_action() const 76 | { 77 | double th = std::atan2(global::goal_y - _y, global::goal_x - _x); 78 | if (th > M_PI) 79 | th -= 2 * M_PI; 80 | if (th < -M_PI) 81 | th += 2 * M_PI; 82 | return th; 83 | } 84 | 85 | SimpleState move(double theta, bool prob = true) const 86 | { 87 | double r = 0.1; 88 | double th = theta; 89 | if (prob) { 90 | double p = std::rand() / double(RAND_MAX); 91 | if (p < 0.2) { 92 | th += 0.1; 93 | if (th > M_PI) 94 | th -= 2 * M_PI; 
95 | if (th < -M_PI) 96 | th += 2 * M_PI; 97 | } 98 | } 99 | double s = std::sin(th), c = std::cos(th); 100 | double x_new = r * c + _x, y_new = r * s + _y; 101 | /* NOTE(review): the std::rand() call above is not guaranteed to be thread-safe, yet move() can run concurrently when several parallel roots are searched (parallel_roots is 4 unless SINGLE is defined). */ 102 | return SimpleState(x_new, y_new); 103 | } 104 | /* terminal: within 0.1 of the goal (squared distance below 0.01). */ 105 | bool terminal() const 106 | { 107 | double dx = _x - global::goal_x; 108 | double dy = _y - global::goal_y; 109 | 110 | if ((dx * dx + dy * dy) < 0.01) 111 | return true; 112 | return false; 113 | } 114 | 115 | bool operator==(const SimpleState& other) const 116 | { 117 | double dx = _x - other._x; 118 | double dy = _y - other._y; 119 | return ((dx * dx + dy * dy) < _epsilon); 120 | } 121 | }; 122 | /* Reward: 10 when the reached state is terminal (the goal), -1 for every other step. */ 123 | struct RewardFunction { 124 | template 125 | double operator()(std::shared_ptr from_state, double action, std::shared_ptr to_state) 126 | { 127 | if (to_state->terminal()) 128 | return 10.0; 129 | return -1.0; 130 | } 131 | }; 132 | /* Policy returning the straight-line heading to the goal (State::best_action). */ 133 | namespace mcts { 134 | template 135 | struct BestHeuristicPolicy { 136 | Action operator()(const std::shared_ptr& state) 137 | { 138 | return state->best_action(); 139 | } 140 | }; 141 | } // namespace mcts 142 | 143 | int main() 144 | { 145 | std::srand(std::time(0)); 146 | mcts::par::init(); 147 | 148 | global::goal_x = 2.0; 149 | global::goal_y = 2.0; 150 | 151 | RewardFunction world; 152 | SimpleState init(0.0, 0.0); 153 | 154 | auto tree = std::make_shared, mcts::SimpleValueInit, mcts::UCTValue, mcts::BestHeuristicPolicy, double, mcts::SPWSelectPolicy, mcts::ContinuousOutcomeSelect>>(init, 2000); 155 | #ifdef SINGLE 156 | const int n_iter = 400000; 157 | #else 158 | const int n_iter = 200000; 159 | #endif 160 | 161 | auto t1 = std::chrono::steady_clock::now(); 162 | 163 | tree->compute(world, n_iter); 164 | 165 | auto time_running = std::chrono::duration_cast(std::chrono::steady_clock::now() - t1).count(); 166 | std::cout << "Time in sec: " << time_running / 1000.0 << std::endl; 167 | 168 | auto best = tree->best_action(); 169 | if (best == nullptr) 170 | std::cout << init._x << " " << init._y << ": 
Terminal!" << std::endl; 171 | else 172 | std::cout << init._x << " " << init._y << ": " << best->action() << " -> " << init.move(best->action(), false)._x << " " << init.move(best->action(), false)._y << std::endl; 173 | 174 | return 0; 175 | } 176 | -------------------------------------------------------------------------------- /include/mcts/defaults.hpp: -------------------------------------------------------------------------------- 1 | #ifndef MCTS_DEFAULTS_HPP 2 | #define MCTS_DEFAULTS_HPP 3 | 4 | #include 5 | 6 | namespace mcts { 7 | /* Default functors passed as template arguments to the MCTS node type: state/value initialisation, action-selection policies, value estimators and outcome selection. */ 8 | template 9 | struct SimpleStateInit { 10 | std::shared_ptr operator()() 11 | { 12 | // assumes the default constructor of State is the init state 13 | return std::make_shared(); 14 | } 15 | }; 16 | /* SimpleValueInit: returns 0.0 for every state. */ 17 | struct SimpleValueInit { 18 | template 19 | double operator()(const std::shared_ptr& state) 20 | { 21 | return 0.0; 22 | } 23 | }; 24 | /* SimpleSelectPolicy: returns true for every node (presumably meaning "always try a new action"; confirm in MCTSNode::_expand). */ 25 | struct SimpleSelectPolicy { 26 | template 27 | bool operator()(const std::shared_ptr& node) 28 | { 29 | return true; 30 | } 31 | }; 32 | /* SimpleOutcomeSelect: samples the next state with State::move and reuses an existing child whose state compares equal (operator==); otherwise the new node is attached as a child and returned. */ 33 | struct SimpleOutcomeSelect { 34 | template 35 | auto operator()(const std::shared_ptr& action) -> std::shared_ptrparent()))>::type> 36 | { 37 | using NodeType = typename std::remove_referenceparent()))>::type; 38 | auto st = action->parent()->state()->move(action->action()); 39 | auto to_add = std::make_shared(st, action->parent()->rollout_depth(), action->parent()->gamma()); 40 | auto it = std::find_if(action->children().begin(), action->children().end(), [&](std::shared_ptr const& p) { return *(p->state()) == *(to_add->state()); }); 41 | if (action->children().size() == 0 || it == action->children().end()) { 42 | to_add->parent() = action; 43 | action->children().push_back(to_add); 44 | return to_add; 45 | } 46 | 47 | return (*it); 48 | } 49 | }; 50 | /* UCTValue: mean value plus the exploration bonus 2 * c * sqrt(ln(parent visits + 1) / visits), with c = Params::uct::c(); _epsilon avoids division by zero. */ 51 | template 52 | struct UCTValue { 53 | // c parameter in Params struct 54 | const double _epsilon = 1e-6; 55 | 56 | template 57 | double operator()(const std::shared_ptr& action) 58 | { 59 | //
return action->value() / (double(action->visits()) + _epsilon) + _c * std::sqrt(2.0 * std::log(action->parent()->visits() + 1.0) / (double(action->visits()) + _epsilon)); 60 | return action->value() / (double(action->visits()) + _epsilon) + 2.0 * Params::uct::c() * std::sqrt(std::log(action->parent()->visits() + 1.0) / (double(action->visits()) + _epsilon)); 61 | } 62 | }; 63 | /* GreedyValue: mean value only, without exploration bonus. */ 64 | struct GreedyValue { 65 | const double _epsilon = 1e-6; 66 | 67 | template 68 | double operator()(const std::shared_ptr& action) 69 | { 70 | return action->value() / (double(action->visits()) + _epsilon); 71 | } 72 | }; 73 | /* UniformRandomPolicy: delegates to State::random_action(). */ 74 | template 75 | struct UniformRandomPolicy { 76 | Action operator()(const std::shared_ptr& state) 77 | { 78 | return state->random_action(); 79 | } 80 | }; 81 | /* SPWSelectPolicy (simple progressive widening): true on the first visit or while visits^a exceeds the number of children, with a = Params::spw::a(). */ 82 | template 83 | struct SPWSelectPolicy { 84 | // a parameter in Params struct 85 | 86 | template 87 | bool operator()(const std::shared_ptr& node) 88 | { 89 | if (node->visits() == 0 || std::pow((double)node->visits(), Params::spw::a()) > node->children().size()) 90 | return true; 91 | return false; 92 | } 93 | }; 94 | /* ContinuousOutcomeSelect: widening over outcomes. While visits^b exceeds the number of children (b = Params::cont_outcome::b()) a new next state is sampled; otherwise an existing child is drawn with probability proportional to its visit count. */ 95 | template 96 | struct ContinuousOutcomeSelect { 97 | // b parameter in Params struct 98 | 99 | template 100 | auto operator()(const std::shared_ptr& action) -> std::shared_ptrparent()))>::type> 101 | { 102 | using NodeType = typename std::remove_referenceparent()))>::type; 103 | 104 | if (action->visits() == 0 || std::pow((double)action->visits(), Params::cont_outcome::b()) > action->children().size()) { 105 | auto st = action->parent()->state()->move(action->action()); 106 | auto to_add = std::make_shared(st, action->parent()->rollout_depth(), action->parent()->gamma()); 107 | auto it = std::find_if(action->children().begin(), action->children().end(), [&](std::shared_ptr const& p) { return *(p->state()) == *(to_add->state()); }); 108 | if (action->children().size() == 0 || it == action->children().end()) { 109 | to_add->parent() = action; 110 | action->children().push_back(to_add); 
111 | return to_add; 112 | } 113 | 114 | return (*it); 115 | } 116 | 117 | // Choose child with probability n(c)/Sum(n(c')): draw r uniformly in [0, sum) and return the first child whose cumulative visit count exceeds r (the old draw in [0, sum] with r <= p over-weighted the first child) 118 | size_t sum = 0; 119 | for (size_t i = 0; i < action->children().size(); i++) { 120 | sum += action->children()[i]->visits(); 121 | } 122 | if (sum == 0) return action->children().front(); /* nothing to weight by yet; children is non-empty here because visits^b > 0 once visits > 0 */ size_t r = size_t(std::rand() / (double(RAND_MAX) + 1.0) * double(sum)); 123 | size_t p = 0; 124 | for (auto child : action->children()) { 125 | p += child->visits(); 126 | if (r < p) 127 | return child; 128 | } 129 | 130 | // we should never reach here 131 | assert(false); 132 | return nullptr; 133 | } 134 | }; 135 | } // namespace mcts 136 | 137 | #endif 138 | -------------------------------------------------------------------------------- /src/uct.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | // Goal cell is (GOAL - 1, GOAL - 1); main() sets GOAL to the grid size of each run. 7 | size_t GOAL; 8 | 9 | struct Params { 10 | struct uct { 11 | MCTS_PARAM(double, c, 10.0); 12 | }; 13 | 14 | struct mcts_node { 15 | MCTS_PARAM(size_t, parallel_roots, 1); 16 | }; 17 | }; 18 | 19 | struct GridState { 20 | size_t _x, _y, _N; 21 | double _prob; 22 | std::vector _used_actions; 23 | 24 | GridState() 25 | { 26 | _x = _y = 0; 27 | _N = 10; 28 | _prob = 0.0; 29 | } 30 | 31 | GridState(size_t x, size_t y, size_t N, double prob) 32 | { 33 | _x = x; 34 | _y = y; 35 | _N = N; 36 | _prob = prob; 37 | } 38 | // valid: true when the action keeps the agent inside the N x N grid (0 up, 1 down, 2 right, 3 left); any other id is rejected 39 | bool valid(size_t action) const 40 | { 41 | int x_new = _x, y_new = _y; 42 | if (action == 0) // up 43 | { 44 | y_new++; 45 | if (y_new >= (int)_N) 46 | return false; 47 | } 48 | else if (action == 1) // down 49 | { 50 | y_new--; 51 | if (y_new < 0) 52 | return false; 53 | } 54 | else if (action == 2) // right 55 | { 56 | x_new++; 57 | if (x_new >= (int)_N) 58 | return false; 59 | } 60 | else if (action == 3) // left 61 | { 62 | x_new--; 63 | if (x_new < 0) 64 | return false; 65 | } else { return false; } // ids outside 0-3 are not actions (they used to be reported as valid) 66 | return true; 67 | } 68 | 69 | size_t next_action() 70 | { 71 | return random_action(); 72 | } 73 | 74 | 
GridState move(size_t action, bool prob = true) const 75 | { 76 | int x_new = _x, y_new = _y; 77 | 78 | double r = std::rand() / (double)RAND_MAX; 79 | if ((r - _prob) < 0 && prob) 80 | action = (action + 1) % 4; 81 | 82 | if (action == 0) // up 83 | { 84 | y_new++; 85 | if (y_new >= (int)_N) 86 | y_new--; 87 | } 88 | else if (action == 1) // down 89 | { 90 | y_new--; 91 | if (y_new < 0) 92 | y_new++; 93 | } 94 | else if (action == 2) // right 95 | { 96 | x_new++; 97 | if (x_new >= (int)_N) 98 | x_new--; 99 | } 100 | else if (action == 3) // left 101 | { 102 | x_new--; 103 | if (x_new < 0) 104 | x_new++; 105 | } 106 | // if (r < PROB) 107 | // return GridState(_x, _y, _N); 108 | 109 | return GridState(x_new, y_new, _N, _prob); 110 | } 111 | // random_action: uniformly random valid action (rejection sampling over the ids 0-3) 112 | size_t random_action() const 113 | { 114 | size_t act; 115 | do { 116 | act = size_t(std::rand() / (double(RAND_MAX) + 1.0) * 4.0); // uniform in {0, 1, 2, 3}; the old rand() * 4.0 / RAND_MAX yielded 4 when rand() == RAND_MAX 117 | } while (!valid(act)); 118 | 119 | return act; 120 | } 121 | // best_action: valid action whose deterministic successor is closest (squared Euclidean distance) to the goal cell (GOAL - 1, GOAL - 1) 122 | size_t best_action() const 123 | { 124 | size_t act = 0; 125 | double v = std::numeric_limits::max(); 126 | for (size_t i = 0; i < 4; i++) { 127 | if (!valid(i)) 128 | continue; 129 | GridState tmp = move(i, false); 130 | double dx = double(tmp._x) - (double(GOAL) - 1.0); // convert before subtracting: in size_t, tmp._x - GOAL + 1 wraps around for every cell left of the goal 131 | double dy = double(tmp._y) - (double(GOAL) - 1.0); 132 | double d = dx * dx + dy * dy; 133 | if (d < v) { 134 | act = i; 135 | v = d; 136 | } 137 | } 138 | 139 | return act; 140 | } 141 | 142 | bool terminal() const 143 | { 144 | if (_x == (GOAL - 1) && _y == (GOAL - 1)) 145 | return true; 146 | return false; 147 | } 148 | 149 | bool operator==(const GridState& other) const 150 | { 151 | assert(_N == other._N); 152 | return (_x == other._x && _y == other._y); 153 | } 154 | }; 155 | 156 | struct GridWorld { 157 | template 158 | double operator()(std::shared_ptr from_state, size_t action, std::shared_ptr to_state) 159 | { 160 | if (to_state->_x == (GOAL - 1) && to_state->_y == (GOAL - 1)) 161 | return max_reward(); 162 | 163 | return min_reward(); 164 | } 165 | 166 | double max_reward() 167 |
{ 168 | return 1.0; 169 | } 170 | 171 | double min_reward() 172 | { 173 | return 0.0; 174 | } 175 | }; 176 | /* Default policy that always returns State::best_action(). */ 177 | template 178 | struct BestHeuristicPolicy { 179 | Action operator()(const std::shared_ptr& state) 180 | { 181 | return state->best_action(); 182 | } 183 | }; 184 | /* Benchmark: for grid sizes s = 5, 10, ..., 40 and slip probabilities p, plan from every start cell (goal at (s - 1, s - 1)) and write one line per p to results_s.txt: error count, mean iterations and mean run time per cell. */ 185 | int main() 186 | { 187 | std::srand(std::time(0)); 188 | mcts::par::init(); 189 | 190 | GridWorld world; 191 | 192 | for (size_t s = 5; s <= 40; s += 5) { 193 | 194 | GOAL = s; 195 | 196 | std::ofstream file("results_" + std::to_string(s) + ".txt"); 197 | 198 | for (double p = 0.0; p <= 0.4; p += 0.1) { 199 | /* NOTE(review): p is a floating-point loop counter (0.1 is inexact; the fourth value is 0.30000000000000004), an integer counter would be more robust. c counts start cells whose recommended action is neither up (0) nor right (2), or is missing for a non-terminal cell; avg and avg_time accumulate iterations and run time and are divided by s * s when written. */ 200 | size_t c = 0; 201 | size_t avg = 0; 202 | double avg_time = 0.0; 203 | 204 | for (size_t i = 0; i < s; i++) { 205 | for (size_t j = 0; j < s; j++) { 206 | auto t1 = std::chrono::steady_clock::now(); 207 | GridState init(i, j, s, p); 208 | auto tree = std::make_shared, mcts::SimpleValueInit, mcts::UCTValue, BestHeuristicPolicy, size_t, mcts::SimpleSelectPolicy, mcts::SimpleOutcomeSelect>>(init, 10000); 209 | const int N_ITERATIONS = 10000; 210 | const int MIN_ITERATIONS = 1000; 211 | int k; 212 | for (k = 0; k < N_ITERATIONS; ++k) { 213 | tree->iterate(world); 214 | if (k >= MIN_ITERATIONS) { 215 | auto best = tree->best_action(); 216 | if (best != nullptr && (best->action() == 0 || best->action() == 2)) { 217 | if (!(init._x == (s - 1) && best->action() != 0) && !(init._y == (s - 1) && best->action() != 2)) 218 | break; 219 | } 220 | } 221 | } 222 | auto time_running = std::chrono::duration_cast(std::chrono::steady_clock::now() - t1).count(); 223 | avg_time += time_running / 1000.0; 224 | avg += k; 225 | // tree->print(); 226 | // std::cout << "------------------------" << std::endl; 227 | auto best = tree->best_action(); 228 | if (best == nullptr && !init.terminal()) 229 | c++; 230 | if (best != nullptr && best->action() != 0 && best->action() != 2) 231 | c++; 232 | // if (best == nullptr) 233 | // std::cout << init._x << " " << init._y << ": Terminal!" 
<< std::endl; 234 | // else { 235 | // if (best->action() != 0 && best->action() != 2) 236 | // c++; 237 | // std::cout << init._x << " " << init._y << ": " << best->action() << std::endl; 238 | // } 239 | // std::cin.get(); 240 | } 241 | } 242 | 243 | file << c << " " << double(avg) / double(s * s) << " " << avg_time / double(s * s) << std::endl; 244 | // std::cout << "Errors: " << c << std::endl; 245 | } 246 | file.close(); 247 | } 248 | return 0; 249 | } 250 | -------------------------------------------------------------------------------- /include/mcts/uct.hpp: -------------------------------------------------------------------------------- 1 | #ifndef MCTS_UCT_HPP 2 | #define MCTS_UCT_HPP 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | namespace mcts { 16 | /* MCTSAction: an action edge of the search tree. It keeps the node it was taken from (_parent), the outcome nodes reached through it (_children) and accumulated value/visit statistics. */ 17 | template 18 | class MCTSAction : public std::enable_shared_from_this> { 19 | public: 20 | using action_type = MCTSAction; 21 | using node_ptr = std::shared_ptr; 22 | 23 | MCTSAction(const ActionType& action, const node_ptr& parent, double value) : _parent(parent), _action(action), _value(value), _visits(0) {} 24 | /* NOTE(review): outcome nodes are held here as shared_ptrs (_children), and each of them points back to this action through its own shared_ptr parent (the outcome selectors do to_add->parent() = action). Action and child therefore own each other, and shared_ptr alone never releases such cycles. Consider std::weak_ptr for the back-pointers. */ 25 | node_ptr parent() 26 | { 27 | return _parent; 28 | } 29 | 30 | std::vector children() const 31 | { 32 | return _children; 33 | } 34 | 35 | std::vector& children() 36 | { 37 | return _children; 38 | } 39 | 40 | ActionType action() const 41 | { 42 | return _action; 43 | } 44 | 45 | size_t visits() const 46 | { 47 | return _visits; 48 | } 49 | 50 | size_t& visits() 51 | { 52 | return _visits; 53 | } 54 | 55 | double value() const 56 | { 57 | return _value; 58 | } 59 | 60 | double& value() 61 | { 62 | return _value; 63 | } 64 | 65 | bool operator==(const MCTSAction& other) const 66 | { 67 | return _action == other._action; 68 | } 69 | /* node(): delegates to the OutcomeSelection functor, which returns the next outcome node and may create and attach a new child. */ 70 | node_ptr node() 71 | { 72 | return OutcomeSelection()(this->shared_from_this()); 73 | } 74 | 75 | void update_stats(double value) 76 | { 77 | _value += value; 78 | 
_visits++; 79 | } 80 | 81 | protected: 82 | node_ptr _parent; 83 | std::vector _children; 84 | ActionType _action; 85 | double _value; 86 | size_t _visits; 87 | }; 88 | 89 | template 90 | class MCTSNode : public std::enable_shared_from_this> { 91 | public: 92 | using node_type = MCTSNode; 93 | using action_type = MCTSAction; 94 | using action_ptr = std::shared_ptr; 95 | using node_ptr = std::shared_ptr; 96 | using state_ptr = std::shared_ptr; 97 | 98 | MCTSNode(size_t rollout_depth = 1000, double gamma = 0.9) : _gamma(gamma), _visits(0), _rollout_depth(rollout_depth) 99 | { 100 | _state = StateInit()(); 101 | } 102 | 103 | MCTSNode(State state, size_t rollout_depth = 1000, double gamma = 0.9) : _gamma(gamma), _visits(0), _rollout_depth(rollout_depth) 104 | { 105 | _state = std::make_shared(state); 106 | } 107 | 108 | action_ptr parent() const 109 | { 110 | return _parent; 111 | } 112 | 113 | action_ptr& parent() 114 | { 115 | return _parent; 116 | } 117 | 118 | std::vector children() const 119 | { 120 | return _children; 121 | } 122 | 123 | state_ptr state() const 124 | { 125 | return _state; 126 | } 127 | 128 | size_t visits() const 129 | { 130 | return _visits; 131 | } 132 | 133 | size_t& visits() 134 | { 135 | return _visits; 136 | } 137 | 138 | size_t rollout_depth() const 139 | { 140 | return _rollout_depth; 141 | } 142 | 143 | double gamma() const 144 | { 145 | return _gamma; 146 | } 147 | 148 | template 149 | void compute(RewardFunc rfun, size_t iterations) 150 | { 151 | if (Params::mcts_node::parallel_roots() > 1) { 152 | par::vector roots; 153 | par::replicate(Params::mcts_node::parallel_roots(), [&]() { 154 | node_ptr to_ret = std::make_shared(*this->_state, this->_rollout_depth, this->_gamma); 155 | for (size_t k = 0; k < iterations; ++k) { 156 | to_ret->iterate(rfun); 157 | } 158 | 159 | roots.push_back(to_ret); 160 | }); 161 | 162 | node_ptr cur_node = this->shared_from_this(); 163 | for (size_t i = 0; i < roots.size(); i++) { 164 | 
cur_node->merge_inplace(roots[i]); 165 | } 166 | } 167 | else { 168 | for (size_t k = 0; k < iterations; ++k) { 169 | this->iterate(rfun); 170 | } 171 | } 172 | } 173 | 174 | template 175 | void iterate(RewardFunc rfun) 176 | { 177 | std::vector visited; 178 | std::vector rewards; 179 | 180 | node_ptr cur_node = this->shared_from_this(); 181 | visited.push_back(cur_node); 182 | rewards.push_back(0.0); 183 | // std::cout << "Iterate!" << std::endl; 184 | 185 | do { 186 | node_ptr prev_node = cur_node; 187 | // std::cout << "(" << cur_node->_state->_x << ", " << cur_node->_state->_y << ")" << std::endl; 188 | action_ptr next_action = cur_node->_expand(); 189 | if (!next_action) 190 | break; 191 | // std::cout << "Selected action: " << next_action->action() << std::endl; 192 | cur_node = next_action->node(); 193 | rewards.push_back(rfun(prev_node->_state, next_action->action(), cur_node->_state)); 194 | // std::cout << "TO: (" << cur_node->_state->_x << ", " << cur_node->_state->_y << ")" << std::endl; 195 | visited.push_back(cur_node); 196 | } while (!cur_node->_state->terminal() && cur_node->visits() > 0); 197 | 198 | double value; 199 | if (cur_node->_state->terminal()) { 200 | value = 0.0; 201 | } 202 | else { 203 | // std::cout << "Simulating: (" << cur_node->_state->_x << ", " << cur_node->_state->_y << ")" << std::endl; 204 | value = cur_node->_simulate(rfun); 205 | } 206 | 207 | for (int i = visited.size() - 1; i >= 0; i--) { 208 | value = rewards[i] + _gamma * value; 209 | visited[i]->_visits++; 210 | if (visited[i]->_parent != nullptr) 211 | visited[i]->_parent->update_stats(value); 212 | } 213 | } 214 | 215 | size_t max_depth(size_t parent_depth = 0) 216 | { 217 | if (this->_children.size() == 0) { 218 | return parent_depth + 1; 219 | } 220 | 221 | size_t maxDepth = 0; 222 | for (size_t k = 0; k < this->_children.size(); ++k) { 223 | for (size_t j = 0; j < this->_children[k]->children().size(); j++) { 224 | size_t curDepth = 
this->_children[k]->children()[j]->max_depth(parent_depth + 1); 225 | if (maxDepth < curDepth) { 226 | maxDepth = curDepth; 227 | } 228 | } 229 | } 230 | 231 | return maxDepth; 232 | } 233 | 234 | template 235 | action_ptr best_action() 236 | { 237 | if (_state->terminal()) 238 | return nullptr; 239 | double v = -std::numeric_limits::max(); 240 | action_ptr best_action = nullptr; 241 | 242 | for (auto child : _children) { 243 | double d = Value()(child); 244 | 245 | if (d > v) { 246 | v = d; 247 | best_action = child; 248 | } 249 | } 250 | 251 | return best_action; 252 | } 253 | 254 | node_ptr merge_with(const node_ptr& other) 255 | { 256 | node_ptr to_ret = std::make_shared(*this->_state, this->_rollout_depth, this->_gamma); 257 | to_ret->merge_inplace(other); 258 | 259 | return to_ret; 260 | } 261 | 262 | void merge_inplace(const node_ptr& other) 263 | { 264 | node_ptr to_ret = this->shared_from_this(); 265 | 266 | for (auto child : other->_children) { 267 | auto it = std::find_if(to_ret->_children.begin(), to_ret->_children.end(), [&](action_ptr const& p) { return *p == *child; }); 268 | if (it == to_ret->_children.end()) 269 | to_ret->_children.push_back(child); 270 | else { 271 | (*it)->value() += child->value(); 272 | (*it)->visits() += child->visits(); 273 | } 274 | } 275 | } 276 | 277 | // void print(size_t d = 0) const 278 | // { 279 | // std::cout << d << ": " << _state->_x << " " << _state->_y << " -> " << _value << ", " << _visits; // << std::endl; 280 | // if (_parent != nullptr) 281 | // std::cout << " act: " << _parent->action(); 282 | // std::cout << std::endl; 283 | // for (size_t i = 0; i < _children.size(); i++) { 284 | // for (size_t k = 0; k < _children[i]->children().size(); k++) { 285 | // _children[i]->children()[k]->print(d + 1); 286 | // } 287 | // } 288 | // } 289 | 290 | protected: 291 | action_ptr _parent; 292 | std::vector _children; 293 | state_ptr _state; 294 | double _gamma; 295 | size_t _visits, _rollout_depth; 296 | 297 | 
action_ptr _expand() 298 | { 299 | if (SelectionPolicy()(this->shared_from_this())) { 300 | Action act = _state->next_action(); 301 | action_ptr next_action = std::make_shared(act, this->shared_from_this(), ValueInit()(_state)); 302 | auto it = std::find_if(_children.begin(), _children.end(), [&](action_ptr const& p) { return *p == *next_action; }); 303 | if (_children.size() == 0 || it == _children.end()) { 304 | _children.push_back(next_action); 305 | return next_action; 306 | } 307 | 308 | return (*it); 309 | } 310 | 311 | return _select_action(); 312 | } 313 | 314 | action_ptr _select_action() 315 | { 316 | if (_state->terminal()) 317 | return nullptr; 318 | double v = -std::numeric_limits::max(); 319 | action_ptr best_action = nullptr; 320 | 321 | for (auto child : _children) { 322 | double d = ActionValue()(child); 323 | 324 | if (d > v) { 325 | v = d; 326 | best_action = child; 327 | } 328 | } 329 | 330 | return best_action; 331 | } 332 | 333 | template 334 | double _simulate(RewardFunc rfun) 335 | { 336 | double discount = 1.0; 337 | double reward = 0.0; 338 | 339 | state_ptr cur_state = _state; 340 | 341 | for (size_t k = 0; k < _rollout_depth; ++k) { 342 | // Choose action according to default policy 343 | Action action = DefaultPolicy()(cur_state); 344 | state_ptr prev_state = cur_state; 345 | 346 | // Update state 347 | cur_state = std::make_shared(cur_state->move(action)); 348 | 349 | // Get value from (PO)MDP 350 | reward += discount * rfun(prev_state, action, cur_state); 351 | 352 | // Check if terminal state 353 | if (cur_state->terminal()) 354 | break; 355 | discount *= _gamma; 356 | } 357 | 358 | return reward; 359 | } 360 | }; 361 | } // namespace mcts 362 | 363 | #endif 364 | --------------------------------------------------------------------------------