├── .gitignore ├── LICENSE ├── bin ├── rsync-game ├── setup.sh └── tcpbridge.py ├── doc └── install.rst ├── readme.org ├── requirements.txt └── src ├── cpp ├── Makefile ├── events.h ├── gdltransformer.cpp ├── gdltransformer.h ├── ggpzero_interface.cpp ├── player.cpp ├── player.h ├── puct │ ├── config.h │ ├── evaluator.cpp │ ├── evaluator.h │ ├── minimax.cpp │ ├── minimax.h │ ├── node.cpp │ └── node.h ├── pyobjects │ ├── common.cpp │ ├── gdltransformer_impl.cpp │ ├── player_impl.cpp │ └── supervisor_impl.cpp ├── pyref.h ├── sample.h ├── scheduler.cpp ├── scheduler.h ├── selfplay.cpp ├── selfplay.h ├── selfplay_v2.cpp ├── selfplaymanager.cpp ├── selfplaymanager.h ├── supervisor.cpp ├── supervisor.h └── uniquestates.h ├── ggpzero ├── Makefile ├── __init__.py ├── battle │ ├── README.txt │ ├── __init__.py │ ├── amazons.py │ ├── bt.py │ ├── chess.py │ ├── common.py │ ├── connect6.py │ ├── draughts.py │ ├── hex.py │ ├── hex2.py │ └── reversi.py ├── defs │ ├── __init__.py │ ├── confs.py │ ├── datadesc.py │ ├── gamedesc.py │ ├── msgs.py │ └── templates.py ├── distributed │ ├── __init__.py │ ├── server.py │ └── worker.py ├── nn │ ├── __init__.py │ ├── bases.py │ ├── datacache.py │ ├── manager.py │ ├── model.py │ ├── network.py │ └── train.py ├── player │ ├── __init__.py │ └── puctplayer.py ├── scripts │ ├── __init__.py │ ├── cleanup_nnfiles.py │ ├── findbases.py │ ├── shownn.py │ └── supervised_train.py └── util │ ├── __init__.py │ ├── attrutil.py │ ├── broker.py │ ├── cppinterface.py │ ├── func.py │ ├── keras.py │ ├── main.py │ ├── runprocs.py │ ├── state.py │ └── symmetry.py └── test ├── cpp ├── __init__.py └── test_interface.py ├── nn ├── test_datacache.py ├── test_model.py ├── test_model_draws.py ├── test_model_external.py ├── test_new_transformer.py ├── test_speed.py └── test_templates.py ├── player └── test_player.py ├── test_state.py ├── test_symmetry.py └── test_util.py /.gitignore: -------------------------------------------------------------------------------- 1 | 
# temporary/backup files 2 | *~ 3 | 4 | # c++ 5 | *.d 6 | *.o 7 | *.so 8 | 9 | # python 10 | *.pyc 11 | *__pycache__/ 12 | *.cache/ 13 | *.pytest_cache/ 14 | 15 | # log files 16 | *.log 17 | 18 | # game files 19 | *.sgf 20 | 21 | bin/install 22 | 23 | # spurious stuff in my repo XXX 24 | src/scripts 25 | data 26 | confs/ 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | =============== 3 | Copyright © 2015-2019 Richard Emslie 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and 6 | associated documentation files (the “Software”), to deal in the Software without restriction, 7 | including without limitation the rights to use, copy, modify, merge, publish, distribute, 8 | sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is 9 | furnished to do so, subject to the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be included in all copies or 12 | substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT 15 | NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 16 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 17 | DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT 18 | OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
19 | -------------------------------------------------------------------------------- /bin/rsync-game: -------------------------------------------------------------------------------- 1 | GAME=$1 2 | 3 | DATAPATH_FROM=rxe@hz1:/home/rxe/working/ggpzero/data 4 | DATAPATH_TO=/home/rxe/working/ggpzero/data 5 | 6 | watch -n 15 "rsync -vaz -e ssh $DATAPATH_FROM/$GAME/generations/ $DATAPATH_TO/$GAME/generations && rsync -vaz -e ssh $DATAPATH_FROM/$GAME/models/ $DATAPATH_TO/$GAME/models && rsync -vaz -e ssh $DATAPATH_FROM/$GAME/weights/ $DATAPATH_TO/$GAME/weights" 7 | -------------------------------------------------------------------------------- /bin/setup.sh: -------------------------------------------------------------------------------- 1 | if [ -z "$GGPLIB_PATH" ]; then 2 | echo "Please set \$GGPLIB_PATH" 3 | 4 | else 5 | echo "Ensuring ggplib is set up..." 6 | . $GGPLIB_PATH/bin/setup.sh 7 | 8 | export GGPZERO_PATH=`python2 -c "import os.path as p; print p.dirname(p.dirname(p.abspath('$BASH_SOURCE')))"` 9 | echo "Automatically setting \$GGPZERO_PATH to $GGPZERO_PATH" 10 | 11 | export PYTHONPATH=$GGPZERO_PATH/src:$GGPZERO_PATH/src/cpp:$PYTHONPATH 12 | export LD_LIBRARY_PATH=$GGPZERO_PATH/src/cpp:$LD_LIBRARY_PATH 13 | export PATH=$GGPZERO_PATH/bin:$PATH 14 | 15 | cd $GGPZERO_PATH/src/ggpzero 16 | fi 17 | -------------------------------------------------------------------------------- /bin/tcpbridge.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | ''' simple as it gets: bridge from stdio <---> tcp. 
''' 3 | 4 | import os 5 | import sys 6 | 7 | from twisted.protocols import basic 8 | from twisted.internet import protocol 9 | 10 | from twisted.internet import stdio, reactor 11 | 12 | try: 13 | port = int(sys.argv[1]) 14 | except Exception as _: 15 | port = 2222 16 | 17 | 18 | class TCPBridge(basic.LineReceiver): 19 | def __init__(self, client): 20 | self.stdio_client = client 21 | self.stdio_client.tcp_bridge = self 22 | 23 | def connectionMade(self): 24 | for l in self.stdio_client.buf: 25 | self.sendLine(l) 26 | 27 | def lineReceived(self, line): 28 | self.stdio_client.requestSendLine(line) 29 | 30 | def requestSendLine(self, line): 31 | self.sendLine(line) 32 | 33 | 34 | class Factory(protocol.ClientFactory): 35 | def __init__(self, client): 36 | self.client = client 37 | 38 | def buildProtocol(self, addr): 39 | return TCPBridge(self.client) 40 | 41 | 42 | class StdioBridgeClient(basic.LineReceiver): 43 | delimiter = os.linesep.encode("ascii") 44 | buf = [] 45 | tcp_bridge = None 46 | 47 | def connectionMade(self): 48 | reactor.connectTCP("localhost", 49 | port, 50 | Factory(self)) 51 | 52 | def lineReceived(self, line): 53 | # print >>sys.stderr, "HERE/lineReceived()", self.tcp_bridge, line 54 | if self.tcp_bridge is None: 55 | self.buf.append(line) 56 | else: 57 | self.tcp_bridge.requestSendLine(line) 58 | 59 | def requestSendLine(self, line): 60 | self.sendLine(line) 61 | 62 | 63 | if __name__ == "__main__": 64 | stdio.StandardIO(StdioBridgeClient()) 65 | reactor.run() 66 | -------------------------------------------------------------------------------- /doc/install.rst: -------------------------------------------------------------------------------- 1 | There are no instructions - left this way since project is in fluid state. 2 | 3 | But here is quick sense of what to do. 4 | 5 | 1. First install ggplib (which also means installing k273). 6 | 2. init the environments (. bin/setup.sh) 7 | 3. cd src/cpp && make 8 | 4. 
create a virtual python environment 9 | 10 | 11 | pypy versus cpython 12 | ------------------- 13 | With ggplib it is recommended to use pypy - as it provides the fastest access to the statemachine (but this 14 | only really 15 | matters for the simplest of games, like c4/ttt). For ggpzero we use python2.7. Some games will take 16 | ages to optimise the propnet in python2.7, but once created it will be cached and then there is no speed 17 | difference. I'd recommend creating/caching the propnet of games you are interested in with pypy, 18 | then switching to python2.7. 19 | 20 | 21 | virtual environment 22 | ------------------- 23 | 24 | You'll need to do this twice if you want to support both CPU and GPU: 25 | 26 | 1. Follow the `tensorflow instructions <https://www.tensorflow.org/install/install_linux>`_ to install 27 | python2.7 in a virtual environment. 28 | 29 | 2. Activate the virtualenv. Check tensorflow works (whether you use CPU or GPU is up to you - 30 | the whole self learning environment is optimised to do batching on GPU - so it might be a bit 31 | forlorn to use ggp-zero for training without a decent GPU). 32 | 33 | 3. install python packages. 34 | 35 | .. code-block:: shell 36 | 37 | pip install -r requirements.txt 38 | 39 | 40 | other 41 | ----- 42 | 43 | Training uses a client/server model. The server can run on an install without a GPU. Clients can be 44 | remote or local to the machine. For examples of config see the `gzero_data <https://github.com/richemslie/gzero_data/>`_ repo. 45 | 46 | .. code-block:: shell 47 | 48 | cd src/ggpzero/distributed 49 | python server.py 50 | python worker.py 51 | 52 | 53 | 5. Running a model: 54 | 55 | .. code-block:: shell 56 | 57 | cd src/ggpzero/player 58 | python puctplayer.py 59 | 60 | -------------------------------------------------------------------------------- /readme.org: -------------------------------------------------------------------------------- 1 | * What?
2 | galvanise is a [[https://en.wikipedia.org/wiki/General_game_playing][General Game Player]], where games are written in [[https://en.wikipedia.org/wiki/Game_Description_Language][GDL]]. The original galvanise code 3 | was converted to a library [[https://github.com/richemslie/ggplib][ggplib]] and galvanise_zero adds AlphaZero style learning. Much 4 | inspiration came from DeepMind's related papers, and the excellent Expert Iteration [[https://arxiv.org/abs/1705.08439][paper]]. A 5 | number of Alpha*Zero open source projects were also inspirational: [[https://github.com/leela-zero/leela-zero][Leela Zero]] and 6 | [[https://github.com/lightvector/KataGo][KataGo]]. 7 | 8 | * Features 9 | - there is *no* game specific code other than the GDL description of the games and a high level 10 | python configuration file describing the GDL symbols to state mapping and symmetries (see 11 | [[https://github.com/richemslie/galvanise_zero/issues/1][here]] for more information). 12 | - multiple policies - can train asymmetric games 13 | - fully automated: put it in the oven and a strong model is baked 14 | - network is replaced during self play games 15 | - training is very fast using proper coroutines at the C level. 1000s of concurrent games are 16 | trained using large batch sizes on GPU (for small networks). 17 | - uses a post processed replay buffer, which uses the excellent [[https://github.com/Blosc/bcolz][bcolz]] project. Training 18 | can allow arbitrary sampling from the buffer (giving emphasis to the most recent data). 19 | - initially the project used expert iteration. This was deprecated in favour of oscillating sampling 20 | (similar to KataGo). 21 | - 3 value heads for games with draws 22 | 23 | * Training 24 | - used the same settings for training all game types (cpuct 0.85, fpu 0.25). 25 | - uses a smaller number of evaluations (200) than A0, with oscillating sampling during training (75% of 26 | moves are skipped, using far fewer evals to do so).
27 | - policy squashing and extra noise to prevent overfitting 28 | - models use dropout, global average pooling and squeeze excite blocks (these are optional) 29 | - in general, many of the trained game types below take 3-5 days to reach superhuman strength 30 | 31 | See [[http://littlegolem.net/jsp/info/player.jsp?plid=58835][gzero_bot]] for how to play on Little Golem. 32 | 33 | * Status 34 | Games with significant training, links to elo graphs and models: 35 | 36 | - [[https://github.com/richemslie/gzero_data/tree/master/data/chess][chess]] 37 | - [[https://github.com/richemslie/gzero_data/tree/master/data/connect6][connect6]] 38 | - [[https://github.com/richemslie/gzero_data/tree/master/data/hexLG13][hex13]] 39 | - [[https://github.com/richemslie/gzero_data/tree/master/data/reversi_10x10][reversi10]] 40 | - [[https://github.com/richemslie/gzero_data/tree/master/data/reversi_8x8][reversi8]] 41 | - [[https://github.com/richemslie/gzero_data/tree/master/data/amazons_10x10][amazons]] 42 | - [[https://github.com/richemslie/gzero_data/tree/master/data/breakthrough][breakthrough]] 43 | - [[https://github.com/richemslie/gzero_data/tree/master/data/hexLG11][hex11]] 44 | - [[https://github.com/richemslie/gzero_data/tree/master/data/draughts_killer][International Draughts (killer mode)]] 45 | - [[https://github.com/richemslie/gzero_data/tree/master/data/hex19][hex19]] 46 | 47 | Little Golem Champion in the last attempts at Connect6, Hex13, Amazons and Breakthrough, winning all 48 | matches. Retired from further Championships. Connect6 and Hex13 are currently rated 1st and 49 | 2nd respectively among active users. 50 | 51 | Amazons and Breakthrough won gold medals at the ICGA 2018 Computer Olympiad. :clap: :clap: 52 | 53 | Reversi is also strong relative to humans on LG, yet performs a bit worse than top AB programs 54 | (about ntest level 20 the last time I tested). 55 | 56 | Also trained Baduk 9x9; it had a rating of ~2900 elo on CGOS after 2-3 weeks of training.

* Running
The code is in fairly good shape, but could do with some refactoring and documentation (especially
a how-to guide on training a game). It would definitely be good to have an extra pair of eyes
on it. I'll welcome and support anyone willing to try training a game for themselves. Some notes:

1. python is 2.7
2. requires a GPU/tensorflow
3. a good starting point is https://github.com/richemslie/ggp-zero/blob/dev/src/ggpzero/defs

How to run and install instructions coming soon!

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
attrs==17.4.0
bcolz==1.2.1
cffi==1.11.5
colorama==0.3.7
docutils==0.14
fire==0.1.3
future==0.16.0
h5py==2.8.0
Keras==2.2.0
numpy==1.14.5
py==1.5.3
pytest==3.6.1
six==1.11.0
tensorflow==1.12.0
Twisted==19.7.0
--------------------------------------------------------------------------------
/src/cpp/Makefile:
--------------------------------------------------------------------------------
include $(K273_PATH)/src/cpp/Makefile.in

LIBS = -L $(K273_PATH)/build/lib -lk273 -lk273_greenlet -L$(GGPLIB_PATH)/src/cpp -lggplib_cpp -lpython2.7

# needs numpy installed in virtualenv, if not found set THE_PYTHONPATH manually.
# print(...) parses under both python 2 and python 3; the old print-statement form produced an
# empty path (and hence a bogus numpy include dir) whenever `python` resolved to python 3.
THE_PYTHONPATH = $(shell python -c "import sys; print(sys.exec_prefix)")
INCLUDE_PATHS += -I $(GGPLIB_PATH)/src/cpp -I. -I$(THE_PYTHONPATH)/lib/python2.7/site-packages/numpy/core/include

# since dll
CFLAGS += -fPIC

# python specific compile flags
CFLAGS += -Wno-register -Wno-strict-aliasing $(shell python2-config --includes)

SRCS = puct/node.cpp puct/evaluator.cpp player.cpp puct/minimax.cpp

SRCS += gdltransformer.cpp scheduler.cpp selfplay.cpp selfplaymanager.cpp
SRCS += supervisor.cpp ggpzero_interface.cpp

OBJS = $(patsubst %.cpp, %.o, $(SRCS))

DEPS = $(SRCS:.cpp=.d)

# Top level
all: $(OBJS) ggpzero_interface.so

ggpzero_interface.so: $(OBJS)
	$(CPP) -shared $(LDFLAGS) $(OBJS) $(LIBS) -o $@

%.o : %.cpp
	$(CPP) $(INCLUDE_PATHS) $(CFLAGS) -c -o $@ $<

# Cleans
clean :
	$(RM) ggpzero_interface.so $(OBJS) $(DEPS)

-include $(DEPS)
.PHONY: all clean

--------------------------------------------------------------------------------
/src/cpp/events.h:
--------------------------------------------------------------------------------
#pragma once

#include

namespace GGPZero {

    // network outputs for one batch: filled in by Player::poll() and handed to the scheduler
    struct PredictDoneEvent {
        int pred_count;
        std::vector policies;
        float* final_scores;
    };

    // the input channels the scheduler wants predicted next (returned by Player::poll())
    struct ReadyEvent {
        // how much of the buffer is used (must be an exact multiple of channels*channel_size)
        int buf_count;
        float* channel_buf;
    };
}
--------------------------------------------------------------------------------
/src/cpp/gdltransformer.cpp:
--------------------------------------------------------------------------------
#include "gdltransformer.h"

#include
#include

#include

#include
#include

using namespace GGPZero;


// Write 1.0f into local_buf for every board base that is set in bs.  local_buf is the start of
// the group of channels belonging to one state (see toChannels()).
void GdlBasesTransformer::setForState(float* local_buf, const GGPLib::BaseState* bs) const {
    for (const auto& b : this->board_space) {
        if (b.check(bs)) {
            b.set(local_buf);
        }
    }
}
21 | 22 | void GdlBasesTransformer::toChannels(const GGPLib::BaseState* the_base_state, 23 | const std::vector & prev_states, 24 | float* buf) const { 25 | 26 | // NOTE: only supports 'channel first'. If we really want to have 'channel last', 27 | // transformation will need to be via python code on numpy arrays. Keeping this code lean and 28 | // fast. 29 | 30 | // zero out channels for board states 31 | for (int ii=0; iitotalSize(); ii++) { 32 | *(buf + ii) = 0.0f; 33 | } 34 | 35 | // the_base_state 36 | this->setForState(buf, the_base_state); 37 | 38 | // prev_states 39 | int count = 1; 40 | for (const GGPLib::BaseState* b : prev_states) { 41 | float* local_buf = buf + (this->channels_per_state * this->channel_size * count); 42 | this->setForState(local_buf, b); 43 | } 44 | 45 | // set the control states 46 | float* control_buf_start = buf + this->controlStatesStart(); 47 | for (const auto& c : this->control_space) { 48 | if (c.check(the_base_state)) { 49 | c.floodFill(control_buf_start, this->channel_size); 50 | } 51 | } 52 | } 53 | 54 | GGPLib::BaseState::ArrayType* GdlBasesTransformer::createHashMask(GGPLib::BaseState* bs) const { 55 | // first we set bs true for everything we are interested in 56 | for (int ii=0; iisize; ii++) { 57 | bs->set(ii, this->interested_set.find(ii) != this->interested_set.end()); 58 | } 59 | 60 | GGPLib::BaseState::ArrayType* buf = (GGPLib::BaseState::ArrayType*) malloc(bs->byte_count); 61 | memcpy(buf, bs->data, bs->byte_count); 62 | return buf; 63 | } 64 | -------------------------------------------------------------------------------- /src/cpp/gdltransformer.h: -------------------------------------------------------------------------------- 1 | /* Note: there is no shape sizes here. That is because we just return a continuous numpy array to 2 | python and let the python code reshape it how it likes. 
*/ 3 | 4 | #pragma once 5 | 6 | // ggplib includes 7 | #include 8 | 9 | // std includes 10 | #include 11 | #include 12 | 13 | namespace GGPZero { 14 | 15 | class BaseToBoardSpace { 16 | public: 17 | BaseToBoardSpace(int base_indx, int buf_incr) : 18 | base_indx(base_indx), 19 | buf_incr(buf_incr) { 20 | } 21 | 22 | public: 23 | bool check(const GGPLib::BaseState* bs) const { 24 | return bs->get(this->base_indx); 25 | } 26 | 27 | void set(float* buf) const { 28 | *(buf + this->buf_incr) = 1.0f; 29 | } 30 | 31 | private: 32 | int base_indx; 33 | 34 | // increment into buffer ... channels[b_info.channel][b_info.y_idx][b_info.x_idx] 35 | int buf_incr; 36 | }; 37 | 38 | class BaseToChannelSpace { 39 | public: 40 | BaseToChannelSpace(int base_indx, int channel_id, float value) : 41 | base_indx(base_indx), 42 | channel_id(channel_id), 43 | value(value) { 44 | } 45 | 46 | public: 47 | bool check(const GGPLib::BaseState* bs) const { 48 | return bs->get(this->base_indx); 49 | } 50 | 51 | void floodFill(float* buf, int channel_size) const { 52 | buf += channel_size * this->channel_id; 53 | for (int ii=0; iivalue; 55 | } 56 | } 57 | 58 | private: 59 | // which base it is 60 | int base_indx; 61 | 62 | // which channel it is (relative to end of states) 63 | int channel_id; 64 | 65 | // the value to set the entire channel (flood fill) 66 | float value; 67 | }; 68 | 69 | class GdlBasesTransformer { 70 | public: 71 | GdlBasesTransformer(int channel_size, 72 | int channels_per_state, 73 | int num_control_channels, 74 | int num_prev_states, 75 | int num_rewards, 76 | std::vector & expected_policy_sizes) : 77 | channel_size(channel_size), 78 | channels_per_state(channels_per_state), 79 | num_control_channels(num_control_channels), 80 | num_prev_states(num_prev_states), 81 | num_rewards(num_rewards), 82 | expected_policy_sizes(expected_policy_sizes) { 83 | } 84 | 85 | public: 86 | // builder methods 87 | void addBoardBase(int base_indx, int buf_incr) { 88 | 
this->board_space.emplace_back(base_indx, buf_incr); 89 | this->interested_set.emplace(base_indx); 90 | } 91 | 92 | void addControlBase(int base_indx, int channel_id, float value) { 93 | this->control_space.emplace_back(base_indx, channel_id, value); 94 | this->interested_set.emplace(base_indx); 95 | } 96 | 97 | private: 98 | void setForState(float* local_buf, const GGPLib::BaseState* bs) const; 99 | 100 | int controlStatesStart() const { 101 | return this->channel_size * (this->channels_per_state * (this->num_prev_states + 1)); 102 | } 103 | 104 | public: 105 | // client side methods 106 | int totalSize() const { 107 | return this->channel_size * (this->channels_per_state * (this->num_prev_states + 1) + 108 | this->num_control_channels); 109 | } 110 | 111 | void toChannels(const GGPLib::BaseState* bs, 112 | const std::vector & prev_states, 113 | float* buf) const; 114 | 115 | int getNumberPrevStates() const { 116 | return this->num_prev_states; 117 | } 118 | 119 | // whether policy info should be on this, i dunno. guess following python's lead. XXX 120 | int getNumberPolicies() const { 121 | return this->expected_policy_sizes.size(); 122 | } 123 | 124 | int getPolicySize(int i) const { 125 | return this->expected_policy_sizes[i]; 126 | } 127 | 128 | // XXX wip: 129 | int getNumberRewards() const { 130 | // XXX currently we have one reward/value head per role. Same as policy. So we can 131 | // abuse that for now. In the future, we want to be able to modify these 132 | // independently. 
133 | return this->num_rewards; 134 | } 135 | 136 | GGPLib::BaseState::ArrayType* createHashMask(GGPLib::BaseState* bs) const; 137 | 138 | private: 139 | const int channel_size; 140 | const int channels_per_state; 141 | const int num_control_channels; 142 | const int num_prev_states; 143 | const int num_rewards; 144 | std::vector board_space; 145 | std::vector control_space; 146 | 147 | std::vector expected_policy_sizes; 148 | std::set interested_set; 149 | }; 150 | 151 | } 152 | -------------------------------------------------------------------------------- /src/cpp/ggpzero_interface.cpp: -------------------------------------------------------------------------------- 1 | // python includes 2 | #include 3 | #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION 4 | #include 5 | 6 | /////////////////////////////////////////////////////////////////////////////// 7 | // global variables 8 | 9 | PyObject* ggpzero_interface_error; 10 | 11 | /////////////////////////////////////////////////////////////////////////////// 12 | 13 | // these are python objects... we include cpp files since it is hard to do any 14 | // other way (please if you know a way, let me know) 15 | 16 | #include "pyobjects/common.cpp" 17 | #include "pyobjects/gdltransformer_impl.cpp" 18 | #include "pyobjects/player_impl.cpp" 19 | #include "pyobjects/supervisor_impl.cpp" 20 | 21 | 22 | static PyObject* gi_buf_to_tuple_reverse_bytes(PyObject* self, PyObject* args) { 23 | PyObject* buf_object = nullptr; 24 | if (! 
::PyArg_ParseTuple(args, "O", &buf_object)) { 25 | return nullptr; 26 | } 27 | 28 | ASSERT(PyString_Check(buf_object)); 29 | const char* ptbuf = PyString_AsString(buf_object); 30 | 31 | const int buf_size = PyString_GET_SIZE(buf_object); 32 | PyObject* tup = PyTuple_New(buf_size * 8); 33 | 34 | for (int ii=0; ii 10 | #include 11 | 12 | #include 13 | 14 | using namespace GGPZero; 15 | 16 | Player::Player(GGPLib::StateMachineInterface* sm, 17 | const GdlBasesTransformer* transformer, 18 | PuctConfig* conf) : 19 | transformer(transformer), 20 | config(conf), 21 | evaluator(nullptr), 22 | scheduler(nullptr), 23 | first_play(false), 24 | on_next_move_choice(nullptr) { 25 | 26 | ASSERT(conf->batch_size >= 1); 27 | 28 | // first create a scheduler 29 | this->scheduler = new GGPZero::NetworkScheduler(transformer, conf->batch_size); 30 | 31 | // ... and then the evaluator... 32 | // dupe statemachine here, as the PuctEvaluator thinks it is sharing a statemachine (ie it 33 | // doesn't dupe the statemachine itself) 34 | this->evaluator = new PuctEvaluator(sm->dupe(), this->scheduler, transformer); 35 | this->evaluator->updateConf(conf); 36 | } 37 | 38 | 39 | Player::~Player() { 40 | delete this->evaluator; 41 | delete this->scheduler; 42 | } 43 | 44 | void Player::updateConfig(float think_time, int converged_visits, bool verbose) { 45 | this->config->think_time = think_time; 46 | this->config->converged_visits = converged_visits; 47 | this->config->verbose = verbose; 48 | 49 | this->evaluator->updateConf(this->config); 50 | } 51 | 52 | void Player::puctPlayerReset(int game_depth) { 53 | K273::l_verbose("V2 Player::puctPlayerReset()"); 54 | this->evaluator->reset(game_depth); 55 | this->first_play = true; 56 | } 57 | 58 | 59 | void Player::puctApplyMove(const GGPLib::JointMove* move) { 60 | this->scheduler->createMainLoop(); 61 | 62 | if (this->first_play) { 63 | this->first_play = false; 64 | auto f = [this, move]() { 65 | this->evaluator->establishRoot(nullptr); 66 | 
this->evaluator->applyMove(move); 67 | }; 68 | 69 | this->scheduler->addRunnable(f); 70 | 71 | } else { 72 | auto f = [this, move]() { 73 | this->evaluator->applyMove(move); 74 | }; 75 | 76 | this->scheduler->addRunnable(f); 77 | } 78 | } 79 | 80 | void Player::puctPlayerMove(const GGPLib::BaseState* state, int evaluations, double end_time) { 81 | this->on_next_move_choice = nullptr; 82 | this->scheduler->createMainLoop(); 83 | 84 | K273::l_verbose("V2 Player::puctPlayerMove() - %d", evaluations); 85 | 86 | // this should only happen as first move in the game 87 | if (this->first_play) { 88 | this->first_play = false; 89 | auto f = [this, state, evaluations, end_time]() { 90 | this->evaluator->establishRoot(state); 91 | this->on_next_move_choice = this->evaluator->onNextMove(evaluations, end_time); 92 | }; 93 | 94 | this->scheduler->addRunnable(f); 95 | 96 | } else { 97 | auto f = [this, evaluations, end_time]() { 98 | this->on_next_move_choice = this->evaluator->onNextMove(evaluations, end_time); 99 | }; 100 | 101 | this->scheduler->addRunnable(f); 102 | } 103 | } 104 | 105 | std::tuple Player::puctPlayerGetMove(int lead_role_index) { 106 | if (this->on_next_move_choice == nullptr) { 107 | return std::make_tuple(-1, -1.0f, -1); 108 | } 109 | 110 | float probability = -1; 111 | const PuctNode* node = this->on_next_move_choice->to_node; 112 | if (node != nullptr) { 113 | probability = node->getCurrentScore(lead_role_index); 114 | } 115 | 116 | return std::make_tuple(this->on_next_move_choice->move.get(lead_role_index), 117 | probability, 118 | this->evaluator->nodeCount()); 119 | } 120 | 121 | void Player::balanceNode(int max_count) { 122 | // ask the evaluator to balance the first 'max_count' moves 123 | K273::l_verbose("ask the evaluator to balance the first %d moves", max_count); 124 | this->scheduler->createMainLoop(); 125 | 126 | const PuctNode* root = this->evaluator->getRootNode(); 127 | if (root == nullptr) { 128 | return; 129 | } 130 | 131 | max_count = 
std::min((int) root->num_children, max_count); 132 | 133 | auto f = [this, max_count]() { this->evaluator->balanceFirstMoves(max_count); }; 134 | this->scheduler->addRunnable(f); 135 | } 136 | 137 | std::vector Player::treeDebugInfo(int max_count) { 138 | std::vector res; 139 | const PuctNode* root = this->evaluator->getRootNode(); 140 | 141 | if (root == nullptr) { 142 | return res; 143 | } 144 | 145 | max_count = std::min((int) root->num_children, max_count); 146 | for (int ii = 0; ii < max_count; ii++) { 147 | PuctNodeDebug info; 148 | PuctNode::debug(root, ii, 10, info); 149 | res.push_back(info); 150 | } 151 | 152 | return res; 153 | } 154 | 155 | const GGPZero::ReadyEvent* Player::poll(int predict_count, std::vector & data) { 156 | // when pred_count == 0, it is used to bootstrap the main loop in scheduler 157 | this->predict_done_event.pred_count = predict_count; 158 | 159 | // XXX holds pointers to data - maybe we should just copy it like in supervisor case. It isn't 160 | // like this is an optimisation, I am just being lazy. 
161 | 162 | int index = 0; 163 | this->predict_done_event.policies.resize(this->transformer->getNumberPolicies()); 164 | for (int ii=0; iitransformer->getNumberPolicies(); ii++) { 165 | this->predict_done_event.policies[ii] = data[index++]; 166 | } 167 | 168 | this->predict_done_event.final_scores = data[index++]; 169 | 170 | this->scheduler->poll(&this->predict_done_event, &this->ready_event); 171 | 172 | return &this->ready_event; 173 | } 174 | -------------------------------------------------------------------------------- /src/cpp/player.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "events.h" 4 | #include "scheduler.h" 5 | #include "gdltransformer.h" 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | namespace GGPZero { 12 | 13 | // forwards 14 | class PuctEvaluator; 15 | struct PuctConfig; 16 | struct PuctNodeChild; 17 | struct PuctNodeDebug; 18 | 19 | // this is a bit of hack, wasnt really designed to actually play from c++ 20 | class Player { 21 | 22 | public: 23 | Player(GGPLib::StateMachineInterface* sm, 24 | const GGPZero::GdlBasesTransformer* transformer, 25 | PuctConfig* conf); 26 | ~Player(); 27 | 28 | public: 29 | // python side 30 | void updateConfig(float think_time, int converge_relaxed, bool verbose); 31 | 32 | void puctPlayerReset(int game_depth); 33 | void puctApplyMove(const GGPLib::JointMove* move); 34 | void puctPlayerMove(const GGPLib::BaseState* state, int iterations, double end_time); 35 | std::tuple puctPlayerGetMove(int lead_role_index); 36 | 37 | void balanceNode(int max_count); 38 | std::vector treeDebugInfo(int max_count); 39 | 40 | const ReadyEvent* poll(int predict_count, std::vector & data); 41 | 42 | private: 43 | const GdlBasesTransformer* transformer; 44 | PuctConfig* config; 45 | 46 | PuctEvaluator* evaluator; 47 | NetworkScheduler* scheduler; 48 | 49 | bool first_play; 50 | 51 | // store the choice of onNextMove()... 
52 | const PuctNodeChild* on_next_move_choice; 53 | 54 | // Events 55 | ReadyEvent ready_event; 56 | PredictDoneEvent predict_done_event; 57 | }; 58 | 59 | } 60 | -------------------------------------------------------------------------------- /src/cpp/puct/config.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | 4 | namespace GGPZero { 5 | 6 | 7 | enum class ChooseFn { 8 | choose_top_visits, choose_temperature 9 | }; 10 | 11 | struct PuctConfig { 12 | bool verbose; 13 | 14 | float puct_constant; 15 | float puct_constant_root; 16 | 17 | float dirichlet_noise_pct; 18 | float noise_policy_squash_pct; 19 | float noise_policy_squash_prob; 20 | 21 | ChooseFn choose; 22 | int max_dump_depth; 23 | 24 | float random_scale; 25 | float temperature; 26 | int depth_temperature_start; 27 | float depth_temperature_increment; 28 | int depth_temperature_stop; 29 | float depth_temperature_max; 30 | 31 | float fpu_prior_discount; 32 | float fpu_prior_discount_root; 33 | 34 | // < 0, off 35 | float top_visits_best_guess_converge_ratio; 36 | 37 | float think_time; 38 | int converged_visits; 39 | 40 | int batch_size; 41 | 42 | // <= 0, off (XXX unused currently) 43 | int use_legals_count_draw; 44 | 45 | // MCTS prover 46 | bool backup_finalised; 47 | 48 | // turn on transposition 49 | bool lookup_transpositions; 50 | 51 | // when think time, multiples time. when iterations, multiples iterations. 
52 | float evaluation_multiplier_to_convergence; 53 | }; 54 | 55 | } 56 | -------------------------------------------------------------------------------- /src/cpp/puct/evaluator.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "puct/node.h" 4 | #include "puct/config.h" 5 | 6 | #include "scheduler.h" 7 | 8 | #include 9 | #include 10 | #include 11 | 12 | #include 13 | 14 | #include 15 | 16 | 17 | namespace GGPZero { 18 | 19 | struct PathElement { 20 | PathElement(PuctNode* node, PuctNodeChild* choice, PuctNodeChild* best); 21 | 22 | PuctNode* node; 23 | PuctNodeChild* choice; 24 | PuctNodeChild* best; 25 | }; 26 | 27 | using Path = std::vector ; 28 | 29 | /////////////////////////////////////////////////////////////////////////////// 30 | 31 | class PuctEvaluator { 32 | public: 33 | PuctEvaluator(GGPLib::StateMachineInterface* sm, NetworkScheduler* scheduler, 34 | const GGPZero::GdlBasesTransformer* transformer); 35 | virtual ~PuctEvaluator(); 36 | 37 | public: 38 | // called after creation 39 | void updateConf(const PuctConfig* conf); 40 | 41 | void setDirichletNoise(PuctNode* node); 42 | float priorScore(PuctNode* node, int depth) const; 43 | void setPuctConstant(PuctNode* node, int depth) const; 44 | float getTemperature(int depth) const; 45 | 46 | const PuctNodeChild* choose(const PuctNode* node); 47 | bool converged(int count) const; 48 | 49 | void checkDrawStates(const PuctNode* node, PuctNode* next); 50 | PuctNode* expandChild(PuctNode* parent, PuctNodeChild* child); 51 | 52 | void balanceFirstMoves(int max_moves); 53 | 54 | private: 55 | // tree manangement 56 | void removeNode(PuctNode*); 57 | void releaseNodes(PuctNode*); 58 | PuctNode* lookupNode(const GGPLib::BaseState* bs, int depth); 59 | PuctNode* createNode(PuctNode* parent, const GGPLib::BaseState* state); 60 | 61 | PuctNodeChild* selectChild(PuctNode* node, Path& path); 62 | 63 | void backUpMiniMax(float* new_scores, const 
PathElement& cur); 64 | void backup(float* new_scores, const Path& path); 65 | 66 | int treePlayout(PuctNode* current, std::vector &path); 67 | 68 | void playoutWorker(int worker_id); 69 | void playoutMain(int max_evaluations, double end_time); 70 | 71 | void logDebug(const PuctNodeChild* choice_root); 72 | 73 | public: 74 | void reset(int game_depth); 75 | PuctNode* fastApplyMove(const PuctNodeChild* next); 76 | PuctNode* establishRoot(const GGPLib::BaseState* current_state); 77 | 78 | void resetRootNode(); 79 | const PuctNodeChild* onNextMove(int max_evaluations, double end_time=-1); 80 | void applyMove(const GGPLib::JointMove* move); 81 | 82 | const PuctNodeChild* chooseTopVisits(const PuctNode* node) const; 83 | const PuctNodeChild* chooseTemperature(const PuctNode* node); 84 | 85 | Children getProbabilities(PuctNode* node, float temperature, bool use_linger=true); 86 | void dumpNode(const PuctNode* node, const PuctNodeChild* choice) const; 87 | 88 | int nodeCount() const { 89 | return this->number_of_nodes; 90 | } 91 | 92 | GGPLib::StateMachineInterface* getSM() const { 93 | return this->sm; 94 | } 95 | 96 | const PuctNode* getRootNode() const { 97 | return this->root; 98 | } 99 | 100 | private: 101 | struct PlayoutStats { 102 | PlayoutStats() { 103 | this->reset(); 104 | } 105 | 106 | void reset() { 107 | this->num_blocked = 0; 108 | this->num_tree_playouts = 0; 109 | this->num_evaluations = 0; 110 | this->num_transpositions_attached = 0; 111 | 112 | this->playouts_total_depth = 0; 113 | this->playouts_max_depth = 0; 114 | this->playouts_finals = 0; 115 | } 116 | 117 | int num_blocked; 118 | int num_tree_playouts; 119 | int num_evaluations; 120 | int num_transpositions_attached; 121 | 122 | int playouts_total_depth; 123 | int playouts_max_depth; 124 | int playouts_finals; 125 | }; 126 | 127 | private: 128 | const PuctConfig* conf; 129 | 130 | GGPLib::StateMachineInterface* sm; 131 | GGPLib::BaseState* basestate_expand_node; 132 | NetworkScheduler* scheduler; 
133 | 134 | int game_depth; 135 | 136 | // tree for the entire game 137 | PuctNode* initial_root; 138 | 139 | // root of the tree 140 | PuctNode* root; 141 | 142 | // lookup table to tree 143 | GGPLib::BaseState::HashMapMasked * lookup; 144 | 145 | // when releasing nodes from tree, puts them to delete afterwards 146 | std::vector garbage; 147 | 148 | // tree info 149 | int number_of_nodes; 150 | long node_allocated_memory; 151 | 152 | // used by workers to indicate work to do 153 | bool do_playouts; 154 | 155 | // stats collected during playouts 156 | PlayoutStats stats; 157 | 158 | // random number generator 159 | K273::xoroshiro128plus32 rng; 160 | }; 161 | 162 | } 163 | -------------------------------------------------------------------------------- /src/cpp/puct/minimax.cpp: -------------------------------------------------------------------------------- 1 | #include "minimax.h" 2 | 3 | using namespace GGPZero; 4 | 5 | PuctNodeChild* MiniMaxer::minimaxExpanded(PuctNode* node) { 6 | const int role_count = this->sm->getRoleCount(); 7 | 8 | // returns the best child, or nullptr if there was no more minimax 9 | if (node->is_finalised) { 10 | return nullptr; 11 | } 12 | 13 | const int ri = node->lead_role_index; 14 | PuctNodeChild* best = nullptr; 15 | float best_score = -1; 16 | for (int ii=0; iinum_children; ii++) { 17 | PuctNodeChild* child = node->getNodeChild(role_count, ii); 18 | 19 | if (child->use_minimax) { 20 | ASSERT(child->to_node != nullptr); 21 | 22 | PuctNodeChild* mc = minimaxExpanded(child->to_node); 23 | float score = -1; 24 | if (mc == nullptr) { 25 | score = child->to_node->getFinalScore(ri, true); 26 | } else { 27 | score = child->to_node->getCurrentScore(ri); 28 | } 29 | 30 | if (score > best_score) { 31 | best = child; 32 | best_score = score; 33 | } 34 | } 35 | } 36 | 37 | if (best != nullptr) { 38 | node->setCurrentScore(ri, best_score); 39 | } 40 | 41 | return best; 42 | } 43 | 44 | typedef std::vector MiniChildren; 45 | 46 | static 
MiniChildren sortTreeByPolicy(PuctNode* node, int role_count) { 47 | MiniChildren children; 48 | for (int ii=0; iinum_children; ii++) { 49 | PuctNodeChild* child = node->getNodeChild(role_count, ii); 50 | children.push_back(child); 51 | } 52 | 53 | auto f = [](const PuctNodeChild* a, const PuctNodeChild* b) { 54 | return a->policy_prob > b->policy_prob; 55 | }; 56 | 57 | std::sort(children.begin(), children.end(), f); 58 | return children; 59 | } 60 | 61 | void MiniMaxer::expandTree(PuctNode* node, int depth) { 62 | const int role_count = this->sm->getRoleCount(); 63 | 64 | if (depth == this->conf->minimax_specifier.size()) { 65 | return; 66 | } 67 | 68 | if (node->is_finalised) { 69 | return; 70 | } 71 | 72 | // get the specifier 73 | int policy_count = this->conf->minimax_specifier[depth].add_policy_count; 74 | 75 | // sort tree by policy 76 | auto children = ::sortTreeByPolicy(node, role_count); 77 | 78 | auto expand = [node, this](PuctNodeChild* child, int next_depth) { 79 | if (child->to_node == nullptr) { 80 | this->evaluator->expandChild(node, child); 81 | } 82 | 83 | ASSERT(child->to_node != nullptr); 84 | this->expandTree(child->to_node, next_depth); 85 | }; 86 | 87 | int unexpanded_children = node->num_children; 88 | if (unexpanded_children <= this->conf->follow_max) { 89 | for (PuctNodeChild* c : children) { 90 | // note depth not decremented 91 | expand(c, depth); 92 | c->use_minimax = true; 93 | } 94 | 95 | return; 96 | } 97 | 98 | // add best from policy 99 | for (PuctNodeChild* c : children) { 100 | if (policy_count > 0) { 101 | expand(c, depth - 1); 102 | c->use_minimax = true; 103 | policy_count--; 104 | unexpanded_children--; 105 | } else { 106 | c->use_minimax = false; 107 | } 108 | } 109 | 110 | int random_count = std::min(unexpanded_children, 111 | this->conf->minimax_specifier[depth].add_random_count); 112 | 113 | float chance_to_random_play_move = 1.0f / (unexpanded_children + 0.001f); 114 | 115 | while (random_count > 0) { 116 | for 
(PuctNodeChild* c : children) { 117 | if (c->use_minimax) { 118 | continue; 119 | } 120 | 121 | if (this->rng.get() < chance_to_random_play_move) { 122 | expand(c, depth - 1); 123 | c->use_minimax = true; 124 | random_count--; 125 | } 126 | } 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /src/cpp/puct/minimax.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "puct/node.h" 4 | #include "puct/config.h" 5 | #include "puct/evaluator.h" 6 | 7 | #include 8 | 9 | #include 10 | 11 | #include 12 | 13 | 14 | namespace GGPZero { 15 | 16 | struct Specifier { 17 | int add_policy_count; 18 | int add_random_count; 19 | }; 20 | 21 | struct MiniMaxConfig { 22 | bool verbose; 23 | 24 | float random_scale; 25 | float temperature; 26 | int depth_temperature_stop; 27 | 28 | // if the legals <= follow_max, add them in without affecting the minimax. 29 | int follow_max; 30 | 31 | std::vector minimax_specifier; 32 | }; 33 | 34 | class MiniMaxer { 35 | public: 36 | MiniMaxer(const MiniMaxConfig* conf, 37 | PuctEvaluator* evaluator, 38 | GGPLib::StateMachineInterface* sm) : 39 | conf(conf), 40 | evaluator(evaluator), 41 | sm(sm) { 42 | } 43 | 44 | private: 45 | PuctNodeChild* minimaxExpanded(PuctNode* node); 46 | void expandTree(PuctNode* node, int depth); 47 | 48 | private: 49 | const MiniMaxConfig* conf; 50 | PuctEvaluator* evaluator; 51 | GGPLib::StateMachineInterface* sm; 52 | 53 | // random number generator 54 | K273::xoroshiro128plus32 rng; 55 | }; 56 | 57 | } // namespace GGPZero 58 | -------------------------------------------------------------------------------- /src/cpp/pyobjects/gdltransformer_impl.cpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "gdltransformer.h" 4 | 5 | // k273 includes 6 | #include 7 | #include 8 | 9 | // ggplib imports 10 | #include 11 | #include 12 | #include 13 | 
#include 14 | #include 15 | 16 | #include 17 | 18 | using namespace GGPZero; 19 | 20 | /////////////////////////////////////////////////////////////////////////////// 21 | 22 | struct PyObject_GdlBasesTransformerWrapper { 23 | PyObject_HEAD 24 | GdlBasesTransformer* impl; 25 | }; 26 | 27 | static PyObject* GdlBasesTransformerWrapper_addBoardBase(PyObject_GdlBasesTransformerWrapper* self, PyObject* args) { 28 | int arg0, arg1; 29 | if (! ::PyArg_ParseTuple(args, "ii", &arg0, &arg1)) { 30 | return nullptr; 31 | } 32 | 33 | self->impl->addBoardBase(arg0, arg1); 34 | 35 | return Py_None; 36 | } 37 | 38 | static PyObject* GdlBasesTransformerWrapper_addControlBase(PyObject_GdlBasesTransformerWrapper* self, PyObject* args) { 39 | int arg0, arg1; 40 | float arg2; 41 | if (! ::PyArg_ParseTuple(args, "iif", &arg0, &arg1, &arg2)) { 42 | return nullptr; 43 | } 44 | 45 | self->impl->addControlBase(arg0, arg1, arg2); 46 | 47 | return Py_None; 48 | } 49 | 50 | static PyObject* GdlBasesTransformerWrapper_test(PyObject_GdlBasesTransformerWrapper* self, PyObject* args) { 51 | 52 | const int MOVES = 4096; 53 | static float* array_buf = nullptr; 54 | 55 | if (array_buf == nullptr) { 56 | array_buf = (float*) malloc(sizeof(float) * self->impl->totalSize() * MOVES); 57 | } 58 | 59 | ssize_t ptr = 0; 60 | if (! 
::PyArg_ParseTuple(args, "n", &ptr)) { 61 | return nullptr; 62 | } 63 | 64 | GGPLib::StateMachine* sm = reinterpret_cast (ptr); 65 | 66 | GGPLib::JointMove* joint_move = sm->getJointMove(); 67 | GGPLib::BaseState* other = sm->newBaseState(); 68 | const GGPLib::BaseState* bs = sm->getInitialState(); 69 | other->assign(bs); 70 | sm->updateBases(bs); 71 | 72 | std::vector dummy; 73 | 74 | float* pt_array_buf = array_buf; 75 | 76 | // four random moves 77 | K273::xoroshiro128plus32 random; 78 | 79 | int total_depth = 0; 80 | for (int jj=0; jj<2; jj++) { 81 | sm->reset(); 82 | 83 | for (int kk=0; kkisTerminal()) { 86 | break; 87 | } 88 | 89 | // populate joint move 90 | for (int ii=0; iigetRoleCount(); ii++) { 91 | const GGPLib::LegalState* ls = sm->getLegalState(ii); 92 | int x = random.getWithMax(ls->getCount()); 93 | int choice = ls->getLegal(x); 94 | joint_move->set(ii, choice); 95 | } 96 | 97 | sm->nextState(joint_move, other); 98 | sm->updateBases(other); 99 | 100 | int prev_states = self->impl->getNumberPrevStates(); 101 | if (prev_states) { 102 | std::vector prevs; 103 | // not excactly valid states, but doesnt matter 104 | for (int ii=0; iiimpl->toChannels(other, prevs, pt_array_buf); 109 | 110 | } else { 111 | self->impl->toChannels(other, dummy, pt_array_buf); 112 | } 113 | 114 | pt_array_buf += self->impl->totalSize(); 115 | total_depth++; 116 | } 117 | } 118 | 119 | 120 | const int ND = 1; 121 | npy_intp dims[1]{self->impl->totalSize() * total_depth}; 122 | 123 | return PyArray_SimpleNewFromData(ND, dims, NPY_FLOAT, array_buf); 124 | } 125 | 126 | static struct PyMethodDef GdlBasesTransformerWrapper_methods[] = { 127 | {"add_board_base", (PyCFunction) GdlBasesTransformerWrapper_addBoardBase, METH_VARARGS, "addBoardBase"}, 128 | {"add_control_base", (PyCFunction) GdlBasesTransformerWrapper_addControlBase, METH_VARARGS, "addControlBase"}, 129 | {"test", (PyCFunction) GdlBasesTransformerWrapper_test, METH_VARARGS, "test"}, 130 | {nullptr, nullptr} /* Sentinel 
*/ 131 | }; 132 | 133 | static void GdlBasesTransformerWrapper_dealloc(PyObject* ptr); 134 | 135 | 136 | static PyTypeObject PyType_GdlBasesTransformerWrapper = { 137 | PyVarObject_HEAD_INIT(nullptr, 0) 138 | "GdlBasesTransformerWrapper", /*tp_name*/ 139 | sizeof(PyObject_GdlBasesTransformerWrapper), /*tp_size*/ 140 | 0, /*tp_itemsize*/ 141 | 142 | /* methods */ 143 | GdlBasesTransformerWrapper_dealloc, /*tp_dealloc*/ 144 | 0, /*tp_print*/ 145 | 0, /*tp_getattr*/ 146 | 0, /*tp_setattr*/ 147 | 0, /*tp_compare*/ 148 | 0, /*tp_repr*/ 149 | 0, /*tp_as_number*/ 150 | 0, /*tp_as_sequence*/ 151 | 0, /*tp_as_mapping*/ 152 | 0, /*tp_hash*/ 153 | 0, /*tp_call*/ 154 | 0, /*tp_str*/ 155 | 0, /*tp_getattro*/ 156 | 0, /*tp_setattro*/ 157 | 0, /*tp_as_buffer*/ 158 | Py_TPFLAGS_DEFAULT, /*tp_flags*/ 159 | 0, /*tp_doc*/ 160 | 0, /*tp_traverse*/ 161 | 0, /*tp_clear*/ 162 | 0, /*tp_richcompare*/ 163 | 0, /*tp_weaklistoffset*/ 164 | 0, /*tp_iter*/ 165 | 0, /*tp_iternext*/ 166 | GdlBasesTransformerWrapper_methods, /* tp_methods */ 167 | 0, /* tp_members */ 168 | 0, /* tp_getset */ 169 | }; 170 | 171 | static PyObject_GdlBasesTransformerWrapper* PyType_GdlBasesTransformerWrapper_new(GdlBasesTransformer* impl) { 172 | PyObject_GdlBasesTransformerWrapper* res = PyObject_New(PyObject_GdlBasesTransformerWrapper, 173 | &PyType_GdlBasesTransformerWrapper); 174 | res->impl = impl; 175 | return res; 176 | } 177 | 178 | 179 | static void GdlBasesTransformerWrapper_dealloc(PyObject* ptr) { 180 | K273::l_debug("--> GdlBasesTransformerWrapper_dealloc"); 181 | ::PyObject_Del(ptr); 182 | } 183 | 184 | /////////////////////////////////////////////////////////////////////////////// 185 | 186 | static PyObject* gi_GdlBasesTransformer(PyObject* self, PyObject* args) { 187 | int channel_size, channels_per_state, num_control_channels; 188 | int num_prev_states, num_rewards; 189 | PyObject* expected_policy_sizes; 190 | 191 | if (! 
::PyArg_ParseTuple(args, "iiiiiO!", 192 | &channel_size, 193 | &channels_per_state, 194 | &num_control_channels, 195 | &num_prev_states, 196 | &num_rewards, 197 | &PyList_Type, &expected_policy_sizes)) { 198 | return nullptr; 199 | } 200 | 201 | auto asInt = [expected_policy_sizes] (int index) { 202 | PyObject* borrowed = PyList_GET_ITEM(expected_policy_sizes, index); 203 | return PyInt_AsLong(borrowed); 204 | }; 205 | 206 | std::vector policy_sizes; 207 | for (int ii=0; iiimpl->puctPlayerReset(game_depth); 15 | Py_RETURN_NONE; 16 | } 17 | 18 | static PyObject* Player_apply_move(PyObject_Player* self, PyObject* args) { 19 | ssize_t ptr = 0; 20 | if (! ::PyArg_ParseTuple(args, "n", &ptr)) { 21 | return nullptr; 22 | } 23 | 24 | GGPLib::JointMove* move = reinterpret_cast (ptr); 25 | self->impl->puctApplyMove(move); 26 | 27 | Py_RETURN_NONE; 28 | } 29 | 30 | static PyObject* Player_move(PyObject_Player* self, PyObject* args) { 31 | ssize_t ptr = 0; 32 | int evaluations = 0; 33 | double end_time = 0.0; 34 | if (! ::PyArg_ParseTuple(args, "nid", &ptr, &evaluations, &end_time)) { 35 | return nullptr; 36 | } 37 | 38 | GGPLib::BaseState* basestate = reinterpret_cast (ptr); 39 | self->impl->puctPlayerMove(basestate, evaluations, end_time); 40 | Py_RETURN_NONE; 41 | } 42 | 43 | static PyObject* Player_get_move(PyObject_Player* self, PyObject* args) { 44 | int lead_role_index = 0; 45 | if (! ::PyArg_ParseTuple(args, "i", &lead_role_index)) { 46 | return nullptr; 47 | } 48 | 49 | int a, c; 50 | float b; 51 | std::tie(a, b, c) = self->impl->puctPlayerGetMove(lead_role_index); 52 | return ::Py_BuildValue("ifi", a, b, c); 53 | } 54 | 55 | static PyObject* Player_balance_moves(PyObject_Player* self, PyObject* args) { 56 | int max_count = 0; 57 | if (! 
::PyArg_ParseTuple(args, "i", &max_count)) { 58 | return nullptr; 59 | } 60 | 61 | self->impl->balanceNode(max_count); 62 | 63 | Py_RETURN_NONE; 64 | } 65 | 66 | #include 67 | static PyObject* Player_tree_debug(PyObject_Player* self, PyObject* args) { 68 | int max_count = 0; 69 | if (! ::PyArg_ParseTuple(args, "i", &max_count)) { 70 | return nullptr; 71 | } 72 | 73 | std::vector debug_list = self->impl->treeDebugInfo(max_count); 74 | 75 | PyObject* result = PyTuple_New(debug_list.size()); 76 | 77 | for (int ii=0; iiimpl->updateConfig(think_time, converge_relaxed, verbose); 109 | Py_RETURN_NONE; 110 | } 111 | 112 | static PyObject* Player_poll(PyObject_Player* self, PyObject* args) { 113 | return doPoll(self->impl, args); 114 | } 115 | 116 | static struct PyMethodDef Player_methods[] = { 117 | {"player_reset", (PyCFunction) Player_reset, METH_VARARGS, "player_reset"}, 118 | {"player_update_config", (PyCFunction) Player_updateConfig, METH_VARARGS, "player_update_config"}, 119 | {"player_apply_move", (PyCFunction) Player_apply_move, METH_VARARGS, "player_apply_move"}, 120 | {"player_move", (PyCFunction) Player_move, METH_VARARGS, "player_move"}, 121 | {"player_get_move", (PyCFunction) Player_get_move, METH_VARARGS, "player_get_move"}, 122 | 123 | {"player_balance_moves", (PyCFunction) Player_balance_moves, METH_VARARGS, "player_balance_moves"}, 124 | {"player_tree_debug", (PyCFunction) Player_tree_debug, METH_VARARGS, "player_get_move"}, 125 | 126 | {"poll", (PyCFunction) Player_poll, METH_VARARGS, "poll"}, 127 | 128 | {nullptr, nullptr} /* Sentinel */ 129 | }; 130 | 131 | static void Player_dealloc(PyObject* ptr); 132 | 133 | static PyTypeObject PyType_Player = { 134 | PyVarObject_HEAD_INIT(nullptr, 0) 135 | "Player", /*tp_name*/ 136 | sizeof(PyObject_Player), /*tp_size*/ 137 | 0, /*tp_itemsize*/ 138 | 139 | /* methods */ 140 | Player_dealloc, /*tp_dealloc*/ 141 | 0, /*tp_print*/ 142 | 0, /*tp_getattr*/ 143 | 0, /*tp_setattr*/ 144 | 0, /*tp_compare*/ 145 | 0, 
/*tp_repr*/ 146 | 0, /*tp_as_number*/ 147 | 0, /*tp_as_sequence*/ 148 | 0, /*tp_as_mapping*/ 149 | 0, /*tp_hash*/ 150 | 0, /*tp_call*/ 151 | 0, /*tp_str*/ 152 | 0, /*tp_getattro*/ 153 | 0, /*tp_setattro*/ 154 | 0, /*tp_as_buffer*/ 155 | Py_TPFLAGS_DEFAULT, /*tp_flags*/ 156 | 0, /*tp_doc*/ 157 | 0, /*tp_traverse*/ 158 | 0, /*tp_clear*/ 159 | 0, /*tp_richcompare*/ 160 | 0, /*tp_weaklistoffset*/ 161 | 0, /*tp_iter*/ 162 | 0, /*tp_iternext*/ 163 | Player_methods, /* tp_methods */ 164 | 0, /* tp_members */ 165 | 0, /* tp_getset */ 166 | }; 167 | 168 | static PyObject_Player* PyType_Player_new(GGPZero::Player* impl) { 169 | PyObject_Player* res = PyObject_New(PyObject_Player, 170 | &PyType_Player); 171 | res->impl = impl; 172 | return res; 173 | } 174 | 175 | 176 | static void Player_dealloc(PyObject* ptr) { 177 | K273::l_debug("--> Player_dealloc"); 178 | ::PyObject_Del(ptr); 179 | } 180 | 181 | /////////////////////////////////////////////////////////////////////////////// 182 | 183 | static PyObject* gi_Player(PyObject* self, PyObject* args) { 184 | ssize_t ptr = 0; 185 | PyObject_GdlBasesTransformerWrapper* py_transformer = nullptr; 186 | PyObject* dict = nullptr; 187 | 188 | // sm, transformer, batch_size, expected_policy_size, role_1_index 189 | if (! 
::PyArg_ParseTuple(args, "nO!O!", &ptr, 190 | &PyType_GdlBasesTransformerWrapper, &py_transformer, 191 | &PyDict_Type, &dict)) { 192 | return nullptr; 193 | } 194 | 195 | GGPLib::StateMachine* sm = reinterpret_cast (ptr); 196 | GGPZero::PuctConfig* conf = createPuctConfig(dict); 197 | 198 | // create the c++ object 199 | GGPZero::Player* player = new GGPZero::Player(sm, py_transformer->impl, conf); 200 | return (PyObject*) PyType_Player_new(player); 201 | } 202 | -------------------------------------------------------------------------------- /src/cpp/pyref.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | // python includes 4 | #include 5 | 6 | struct PyCleanupRef { 7 | PyCleanupRef(PyObject* o) : 8 | o(o) { 9 | } 10 | 11 | ~PyCleanupRef() { 12 | Py_DECREF(o); 13 | } 14 | 15 | PyObject* o; 16 | }; 17 | 18 | #define PYCLEANUPREF(x) PyCleanupRef cleanup##x(x); 19 | -------------------------------------------------------------------------------- /src/cpp/sample.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | namespace GGPZero { 11 | 12 | struct Sample { 13 | typedef std::vector > Policy; 14 | 15 | GGPLib::BaseState* state; 16 | std::vector prev_states; 17 | std::vector policies; 18 | std::vector final_score; 19 | int depth; 20 | int game_length; 21 | std::string match_identifier; 22 | bool has_resigned; 23 | bool resign_false_positive; 24 | int starting_sample_depth; 25 | std::vector resultant_puct_score; 26 | int resultant_puct_visits; 27 | 28 | // keep the node around - just until we add the sample 29 | int lead_role_index; 30 | }; 31 | 32 | } 33 | -------------------------------------------------------------------------------- /src/cpp/scheduler.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "sample.h" 4 
| #include "events.h" 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | 14 | namespace GGPZero { 15 | 16 | // forwards 17 | class GdlBasesTransformer; 18 | 19 | /////////////////////////////////////////////////////////////////////////////// 20 | 21 | class ModelResult { 22 | public: 23 | ModelResult() : 24 | basestate(nullptr) { 25 | } 26 | 27 | void set(const GGPLib::BaseState* basestate, 28 | int idx, const PredictDoneEvent* evt, 29 | const GdlBasesTransformer* transformer); 30 | 31 | const float* getPolicy(int index) const { 32 | return this->policies[index]; 33 | } 34 | 35 | float getReward(int index) const { 36 | return this->rewards[index]; 37 | } 38 | 39 | private: 40 | // XXX better if all of this was one chunk of memory... but maybe later 41 | 42 | std::vector policies; 43 | std::vector rewards; 44 | 45 | // could follow leela here and just store a hash 46 | const GGPLib::BaseState* basestate; 47 | }; 48 | 49 | using ModelResultList = K273::InplaceList ; 50 | 51 | /////////////////////////////////////////////////////////////////////////////// 52 | // pure abstract interface 53 | 54 | class ModelRequestInterface { 55 | public: 56 | ModelRequestInterface() { 57 | } 58 | 59 | virtual ~ModelRequestInterface() { 60 | } 61 | 62 | public: 63 | // called to check if in NN cache 64 | virtual const GGPLib::BaseState* getBaseState() const = 0; 65 | 66 | // low level adds info to buffer 67 | virtual void add(float* buf, const GdlBasesTransformer* transformer) = 0; 68 | 69 | // given a result, populated 70 | virtual void reply(const ModelResult& result, 71 | const GdlBasesTransformer* transformer) = 0; 72 | }; 73 | 74 | /////////////////////////////////////////////////////////////////////////////// 75 | // XXX finish LRU NN cache 76 | 77 | class NetworkScheduler { 78 | public: 79 | NetworkScheduler(const GdlBasesTransformer* transformer, 80 | int batch_size, int lru_cache_size=1000); 81 | ~NetworkScheduler(); 82 | 83 | 
public: 84 | // called an evaluator engine 85 | void evaluate(ModelRequestInterface* request); 86 | void yield(); 87 | 88 | public: 89 | template 90 | void addRunnable(Callable& f) { 91 | ASSERT(this->main_loop != nullptr); 92 | greenlet_t* g = createGreenlet (f, this->main_loop); 93 | this->runnables.push_back(g); 94 | } 95 | 96 | void createMainLoop() { 97 | ASSERT(this->main_loop == nullptr); 98 | this->main_loop = createGreenlet([this]() { 99 | return this->mainLoop(); 100 | }); 101 | } 102 | 103 | // called directly/indirectly from python, sending events to/fro: 104 | void poll(const PredictDoneEvent* predict_done_event, 105 | ReadyEvent* ready_event); 106 | 107 | private: 108 | void mainLoop(); 109 | 110 | private: 111 | const GdlBasesTransformer* transformer; 112 | const unsigned int batch_size; 113 | 114 | std::vector requestors; 115 | std::vector yielders; 116 | std::deque runnables; 117 | 118 | // the main looper 119 | greenlet_t* main_loop; 120 | 121 | // exit in and of the main_loop (and is parent of main_loop) 122 | greenlet_t* top; 123 | 124 | // outbound predictions (we malloc/own this memory - although it will end up in 125 | // python/tensorflow for predictions, but that point we will be in a preserved state.) 126 | float* channel_buf; 127 | int channel_buf_indx; 128 | 129 | // size of the lru cachce 130 | int lru_cache_size; 131 | 132 | ModelResultList free_list; 133 | ModelResultList lru_list; 134 | GGPLib::BaseState::HashMap < ModelResultList::Node*> lru_lookup; 135 | 136 | // set via poll(). Don't own this memory. However, it won't change under feet. 
137 | const PredictDoneEvent* predict_done_event; 138 | }; 139 | } 140 | -------------------------------------------------------------------------------- /src/cpp/selfplay.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include 7 | 8 | #include 9 | 10 | namespace GGPZero { 11 | // forwards 12 | struct Sample; 13 | class NetworkScheduler; 14 | class SelfPlayManager; 15 | struct PuctNode; 16 | struct PuctConfig; 17 | class PuctEvaluator; 18 | 19 | struct SelfPlayConfig { 20 | float oscillate_sampling_pct; 21 | 22 | float temperature_for_policy; 23 | 24 | PuctConfig* puct_config; 25 | int evals_per_move; 26 | 27 | float resign0_score_probability; 28 | float resign0_pct; 29 | float resign1_score_probability; 30 | float resign1_pct; 31 | 32 | int abort_max_length; 33 | int number_repeat_states_draw; 34 | float repeat_states_score; 35 | 36 | float run_to_end_pct; 37 | int run_to_end_evals; 38 | PuctConfig* run_to_end_puct_config; 39 | float run_to_end_early_score; 40 | int run_to_end_minimum_game_depth; 41 | }; 42 | 43 | class SelfPlay { 44 | public: 45 | SelfPlay(SelfPlayManager* manager, const SelfPlayConfig* conf, 46 | PuctEvaluator* pe, const GGPLib::BaseState* initial_state, 47 | int role_count, std::string identifier); 48 | ~SelfPlay(); 49 | 50 | private: 51 | bool resign(const PuctNode* node); 52 | PuctNode* collectSamples(PuctNode* node); 53 | int runToEnd(PuctNode* node, std::vector & final_scores); 54 | void addSamples(const std::vector & final_scores, 55 | int starting_sample_depth, int game_depth); 56 | 57 | bool checkFalsePositive(const std::vector & false_positive_check_scores, 58 | float resign_probability, float final_score, 59 | int role_index); 60 | 61 | public: 62 | void playOnce(); 63 | void playGamesForever(); 64 | 65 | private: 66 | SelfPlayManager* manager; 67 | const SelfPlayConfig* conf; 68 | 69 | // only one evaluator - allow to swap in/out config 70 | 
PuctEvaluator* pe; 71 | 72 | const GGPLib::BaseState* initial_state; 73 | const int role_count; 74 | const std::string identifier; 75 | 76 | int match_count; 77 | 78 | // collect samples per game - need to be scored at the end of game 79 | std::vector game_samples; 80 | 81 | // resignation during self play 82 | bool has_resigned; 83 | bool can_resign0; 84 | bool can_resign1; 85 | 86 | std::vector resign0_false_positive_check_scores; 87 | std::vector resign1_false_positive_check_scores; 88 | 89 | // random number generator 90 | K273::xoroshiro128plus32 rng; 91 | }; 92 | 93 | } 94 | -------------------------------------------------------------------------------- /src/cpp/selfplay_v2.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Idea for self play 2. 4 | 5 | First off - switch to puctplus. This gives us a shared MCTS tree, with transpositions, and 6 | traversals on children (need traversals if using transpositions - currently puct evaluator uses 7 | visits on nodes). 8 | 9 | roll: At the start of each roll will clear the tree. 10 | 11 | Every self play does 100 iterations per move, up to n moves. (n=10 in reversi?) 12 | 13 | Then go back to the root and select a starting point for this by walking the tree (using select 14 | parameters). And return the first node it finds where: 15 | 16 | * node depth > n OR has less "evaluations/iterations" than 800 17 | (or whatever it is configured to). 18 | 19 | Then evaluate as per normal and create samples, all the way to the end of game (or until it 20 | resigns). Perform backprop as per normal - no need to backprop all the way up the tree. The 21 | selection process next time around will take advantage of any nicely fleshed out downwind nodes, 22 | and update itself accordingly. 23 | 24 | At roll time, go through all nodes in tree from depth from 0-n, and emit a sample for any nodes 25 | over 800 (or what ever sample_iterations is set to). 
For the reward, return the MCTS score (which will 26 | be the most accurate). Alternatively, we could clamp the value. 27 | 28 | Then we clear the tree for the next neural network (goto roll). If feeling brave, we could 29 | possibly skip clearing the tree between generations - however, we would run into issues where the 30 | tree wouldn't reflect the current network's evaluations. 31 | 32 | ---- 33 | 34 | Conceptually, the idea is that we avoid massive numbers of duplicate states for the first n moves. The 35 | selection process should effectively choose an interesting starting point, whereas the current approach 36 | makes little sense. 37 | 38 | Most importantly, it eliminates bad rewards due to random selection for a sample. Since samples 39 | will for the most part be taken from the top-visited moves, this will give the most accurate score. 40 | 41 | Currently, the number of players per self-play manager is 1024, which would give 10k iterations on 42 | the first move. Since we could run more than one self-play per generation, this could grow well beyond 43 | 10k.
44 | 45 | */ 46 | -------------------------------------------------------------------------------- /src/cpp/selfplaymanager.cpp: -------------------------------------------------------------------------------- 1 | #include "supervisor.h" 2 | 3 | #include "sample.h" 4 | #include "selfplay.h" 5 | #include "scheduler.h" 6 | #include "selfplaymanager.h" 7 | 8 | #include "puct/config.h" 9 | #include "puct/evaluator.h" 10 | 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | #include 17 | 18 | #include 19 | 20 | using namespace GGPZero; 21 | 22 | SelfPlayManager::SelfPlayManager(GGPLib::StateMachineInterface* sm, 23 | const GdlBasesTransformer* transformer, 24 | int batch_size, 25 | UniqueStates* unique_states, 26 | std::string identifier) : 27 | sm(sm->dupe()), 28 | transformer(transformer), 29 | batch_size(batch_size), 30 | unique_states(unique_states), 31 | identifier(identifier), 32 | saw_dupes(0), 33 | no_samples_taken(0), 34 | false_positive_resigns0(0), 35 | false_positive_resigns1(0), 36 | number_early_run_to_ends(0), 37 | number_resigns(0), 38 | number_aborts_game_length(0) { 39 | 40 | this->scheduler = new NetworkScheduler(this->transformer, this->batch_size); 41 | 42 | // allocate buffers for predict_done_event 43 | for (int ii=0; iitransformer->getNumberPolicies(); ii++) { 44 | const int num_of_floats_policies = (this->transformer->getPolicySize(ii) * 45 | this->batch_size); 46 | float* mem = new float[num_of_floats_policies]; 47 | this->predict_done_event.policies.push_back(mem); 48 | } 49 | 50 | const int num_of_floats_final_scores = (this->transformer->getNumberRewards() * 51 | this->batch_size); 52 | 53 | this->predict_done_event.final_scores = new float[num_of_floats_final_scores]; 54 | this->predict_done_event.pred_count = 0; 55 | } 56 | 57 | SelfPlayManager::~SelfPlayManager() { 58 | delete this->sm; 59 | delete this->scheduler; 60 | 61 | for (float* mem : this->predict_done_event.policies) { 62 | delete[] mem; 63 | } 64 | 65 | delete[] 
this->predict_done_event.final_scores; 66 | } 67 | 68 | 69 | /////////////////////////////////////////////////////////////////////////////// 70 | 71 | // will create a new sample based on the root tree 72 | Sample* SelfPlayManager::createSample(const PuctEvaluator* pe, 73 | const PuctNode* node) { 74 | Sample* sample = new Sample; 75 | sample->state = this->sm->newBaseState(); 76 | sample->state->assign(node->getBaseState()); 77 | 78 | // Add previous states 79 | const PuctNode* cur = node->parent; 80 | for (int ii=0; iitransformer->getNumberPrevStates(); ii++) { 81 | if (cur == nullptr) { 82 | break; 83 | } 84 | 85 | GGPLib::BaseState* bs = this->sm->newBaseState(); 86 | bs->assign(cur->getBaseState()); 87 | sample->prev_states.push_back(bs); 88 | 89 | cur = cur->parent; 90 | } 91 | 92 | // create empty vectors 93 | sample->policies.resize(this->sm->getRoleCount()); 94 | 95 | for (int ri=0; rism->getRoleCount(); ri++) { 96 | Sample::Policy& policy = sample->policies[ri]; 97 | for (int ii=0; iinum_children; ii++) { 98 | const PuctNodeChild* child = node->getNodeChild(this->sm->getRoleCount(), ii); 99 | if (ri == node->lead_role_index) { 100 | policy.emplace_back(child->move.get(ri), 101 | child->next_prob); 102 | } else { 103 | // XXX huge hack to make it work (for now) 104 | policy.emplace_back(child->move.get(ri), 1.0); 105 | break; 106 | } 107 | } 108 | } 109 | 110 | sample->resultant_puct_visits = node->visits; 111 | for (int ii=0; iism->getRoleCount(); ii++) { 112 | sample->resultant_puct_score.push_back(node->getCurrentScore(ii)); 113 | } 114 | 115 | sample->depth = node->game_depth; 116 | sample->lead_role_index = node->lead_role_index; 117 | 118 | return sample; 119 | } 120 | 121 | void SelfPlayManager::addSample(Sample* sample) { 122 | this->samples.push_back(sample); 123 | } 124 | 125 | /////////////////////////////////////////////////////////////////////////////// 126 | 127 | void SelfPlayManager::startSelfPlayers(const SelfPlayConfig* config) { 128 | 
K273::l_info("SelfPlayManager::startSelfPlayers - starting %d players", this->batch_size); 129 | 130 | this->scheduler->createMainLoop(); 131 | 132 | // create a bunch of self plays 133 | for (int ii=0; iibatch_size; ii++) { 134 | // the statemachine is shared between all puctevaluators of this mananger. Just be careful. 135 | 136 | PuctEvaluator* pe = new PuctEvaluator(this->sm, this->scheduler, this->transformer); 137 | pe->updateConf(config->puct_config); 138 | 139 | std::string self_play_identifier = this->identifier + K273::fmtString("_%d", ii); 140 | SelfPlay* sp = new SelfPlay(this, config, pe, this->sm->getInitialState(), 141 | this->sm->getRoleCount(), self_play_identifier); 142 | this->self_plays.push_back(sp); 143 | 144 | auto f = [sp]() { 145 | sp->playGamesForever(); 146 | }; 147 | 148 | this->scheduler->addRunnable(f); 149 | } 150 | } 151 | 152 | void SelfPlayManager::poll() { 153 | // VERY IMPORTANT: This must be called in the thread that the scheduler resides (along with its 154 | // co-routines) 155 | // To make this super clear, the selfplaymanger should have a greenlet and we should assert it 156 | // is the correct one before continueing. 
157 | 158 | this->scheduler->poll(&this->predict_done_event, &this->ready_event); 159 | } 160 | 161 | void SelfPlayManager::reportAndResetStats() { 162 | 163 | // XXX report every 5 minutes 164 | 165 | if (this->saw_dupes) { 166 | K273::l_info("Number of dupe states seen %d", this->saw_dupes); 167 | this->saw_dupes = 0; 168 | } 169 | 170 | if (this->no_samples_taken) { 171 | K273::l_info("Number of plays where no samples were taken %d", this->no_samples_taken); 172 | this->no_samples_taken = 0; 173 | } 174 | 175 | if (this->false_positive_resigns0) { 176 | K273::l_info("Number of false positive resigns (0) seen %d", this->false_positive_resigns0); 177 | this->false_positive_resigns0 = 0; 178 | } 179 | 180 | if (this->false_positive_resigns1) { 181 | K273::l_info("Number of false positive resigns (1) seen %d", this->false_positive_resigns1); 182 | this->false_positive_resigns1 = 0; 183 | } 184 | 185 | if (this->number_early_run_to_ends) { 186 | K273::l_info("Number of early run to ends %d", this->number_early_run_to_ends); 187 | this->number_early_run_to_ends = 0; 188 | } 189 | 190 | if (this->number_resigns) { 191 | K273::l_info("Number of resigns %d", this->number_resigns); 192 | this->number_resigns = 0; 193 | } 194 | 195 | if (this->number_aborts_game_length) { 196 | K273::l_info("Number of aborts (game length exceeded) %d", 197 | this->number_aborts_game_length); 198 | this->number_aborts_game_length = 0; 199 | } 200 | } 201 | -------------------------------------------------------------------------------- /src/cpp/selfplaymanager.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "events.h" 4 | #include "uniquestates.h" 5 | #include "gdltransformer.h" 6 | 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | 13 | namespace GGPZero { 14 | 15 | // forwards 16 | class SelfPlay; 17 | struct SelfPlayConfig; 18 | class PuctEvaluator; 19 | 20 | class SelfPlayManager { 21 | public: 22 | 
SelfPlayManager(GGPLib::StateMachineInterface* sm, 23 | const GdlBasesTransformer* transformer, 24 | int batch_size, 25 | UniqueStates* unique_states, 26 | std::string identifier); 27 | ~SelfPlayManager(); 28 | 29 | public: 30 | // the following are only called from self player 31 | Sample* createSample(const PuctEvaluator* pe, const PuctNode* node); 32 | 33 | void addSample(Sample* sample); 34 | 35 | UniqueStates* getUniqueStates() const { 36 | return this->unique_states; 37 | }; 38 | 39 | void incrDupes() { 40 | this->saw_dupes++; 41 | } 42 | 43 | void incrNoSamples() { 44 | this->no_samples_taken++; 45 | } 46 | 47 | void incrResign0FalsePositives() { 48 | this->false_positive_resigns0++; 49 | } 50 | 51 | void incrResign1FalsePositives() { 52 | this->false_positive_resigns1++; 53 | } 54 | 55 | void incrEarlyRunToEnds() { 56 | this->number_early_run_to_ends++; 57 | } 58 | 59 | void incrResigns() { 60 | this->number_resigns++; 61 | } 62 | 63 | void incrAbortsGameLength() { 64 | this->number_aborts_game_length++; 65 | } 66 | 67 | public: 68 | void startSelfPlayers(const SelfPlayConfig* config); 69 | 70 | void poll(); 71 | 72 | void reportAndResetStats(); 73 | 74 | std::vector & getSamples() { 75 | return this->samples; 76 | } 77 | 78 | ReadyEvent* getReadyEvent() { 79 | return &this->ready_event; 80 | } 81 | 82 | PredictDoneEvent* getPredictDoneEvent() { 83 | return &this->predict_done_event; 84 | } 85 | 86 | private: 87 | GGPLib::StateMachineInterface* sm; 88 | const GdlBasesTransformer* transformer; 89 | int batch_size; 90 | 91 | std::vector self_plays; 92 | 93 | // local scheduler 94 | NetworkScheduler* scheduler; 95 | 96 | std::vector samples; 97 | UniqueStates* unique_states; 98 | std::string identifier; 99 | 100 | std::vector states_allocated; 101 | 102 | // Events 103 | ReadyEvent ready_event; 104 | PredictDoneEvent predict_done_event; 105 | 106 | // stats 107 | int saw_dupes; 108 | int no_samples_taken; 109 | int false_positive_resigns0; 110 | int 
false_positive_resigns1; 111 | int number_early_run_to_ends; 112 | int number_resigns; 113 | int number_aborts_game_length; 114 | }; 115 | } 116 | -------------------------------------------------------------------------------- /src/cpp/supervisor.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "uniquestates.h" 4 | 5 | #include "events.h" 6 | 7 | #include "puct/node.h" 8 | #include "puct/config.h" 9 | #include "puct/evaluator.h" 10 | #include "gdltransformer.h" 11 | 12 | #include 13 | #include 14 | #include 15 | 16 | #include 17 | #include 18 | 19 | #include 20 | 21 | namespace GGPZero { 22 | 23 | // forwards 24 | class PuctEvaluator; 25 | class SelfPlay; 26 | class SelfPlayManager; 27 | struct SelfPlayConfig; 28 | 29 | typedef K273::LockedQueue ReadyQueue; 30 | typedef K273::LockedQueue PredictDoneQueue; 31 | 32 | class SelfPlayWorker : public K273::WorkerInterface { 33 | public: 34 | SelfPlayWorker(SelfPlayManager* man0, SelfPlayManager* man1, 35 | const SelfPlayConfig* config); 36 | virtual ~SelfPlayWorker(); 37 | 38 | private: 39 | void doWork(); 40 | 41 | public: 42 | // supervisor side: 43 | SelfPlayManager* pull() { 44 | if (!this->outbound_queue.empty()) { 45 | return this->outbound_queue.pop(); 46 | } 47 | 48 | return nullptr; 49 | } 50 | 51 | // supervisor side: 52 | void push(SelfPlayManager* manager) { 53 | this->inbound_queue.push(manager); 54 | this->getThread()->promptWorker(); 55 | } 56 | 57 | private: 58 | // worker pulls from here when no more workers available 59 | PredictDoneQueue inbound_queue; 60 | 61 | // worker pushes on here when done 62 | ReadyQueue outbound_queue; 63 | 64 | bool enter_first_time; 65 | const SelfPlayConfig* config; 66 | 67 | // will be two 68 | SelfPlayManager* man0; 69 | SelfPlayManager* man1; 70 | }; 71 | 72 | class Supervisor { 73 | public: 74 | Supervisor(GGPLib::StateMachineInterface* sm, 75 | const GdlBasesTransformer* transformer, 76 | int 
batch_size, 77 | std::string identifier); 78 | ~Supervisor(); 79 | 80 | private: 81 | void slowPoll(SelfPlayManager* manager); 82 | 83 | public: 84 | void createInline(const SelfPlayConfig* config); 85 | void createWorkers(const SelfPlayConfig* config); 86 | 87 | std::vector getSamples(); 88 | 89 | const ReadyEvent* poll(int predict_count, std::vector & data); 90 | 91 | void addUniqueState(const GGPLib::BaseState* bs); 92 | void clearUniqueStates(); 93 | 94 | private: 95 | GGPLib::StateMachineInterface* sm; 96 | const GdlBasesTransformer* transformer; 97 | const int batch_size; 98 | const std::string identifier; 99 | 100 | int slow_poll_counter; 101 | 102 | SelfPlayManager* inline_sp_manager; 103 | 104 | SelfPlayManager* in_progress_manager; 105 | SelfPlayWorker* in_progress_worker; 106 | std::vector self_play_workers; 107 | 108 | std::vector samples; 109 | UniqueStates unique_states; 110 | }; 111 | } 112 | -------------------------------------------------------------------------------- /src/cpp/uniquestates.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "gdltransformer.h" 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | 11 | namespace GGPZero { 12 | class UniqueStates { 13 | public: 14 | UniqueStates(const GGPLib::StateMachineInterface* sm, 15 | const GGPZero::GdlBasesTransformer* transformer, 16 | int max_num_dupes=1) : 17 | sm(sm->dupe()), 18 | max_num_dupes(max_num_dupes) { 19 | 20 | this->lookup = GGPLib::BaseState::makeMaskedMap (transformer->createHashMask(this->sm->newBaseState())); 21 | } 22 | 23 | ~UniqueStates() { 24 | delete this->sm; 25 | } 26 | 27 | public: 28 | void add(const GGPLib::BaseState* bs) { 29 | std::lock_guard lk(this->mut); 30 | 31 | auto it = this->lookup->find(bs); 32 | if (it != this->lookup->end()) { 33 | if (it->second < max_num_dupes) { 34 | it->second += 1; 35 | } 36 | 37 | return; 38 | } 39 | 40 | // create a new basestate... 
41 | GGPLib::BaseState* new_bs = this->sm->newBaseState(); 42 | new_bs->assign(bs); 43 | this->states_allocated.push_back(new_bs); 44 | 45 | this->lookup->emplace(new_bs, 1); 46 | } 47 | 48 | bool isUnique(GGPLib::BaseState* bs, int depth) { 49 | std::lock_guard lk(this->mut); 50 | const auto it = this->lookup->find(bs); 51 | if (it != this->lookup->end()) { 52 | int allowed_dupes = std::max(2, (this->max_num_dupes - 5 * depth)); 53 | if (it->second >= allowed_dupes) { 54 | return false; 55 | } 56 | } 57 | 58 | return true; 59 | } 60 | 61 | void clear() { 62 | std::lock_guard lk(this->mut); 63 | 64 | for (GGPLib::BaseState* bs : this->states_allocated) { 65 | ::free(bs); 66 | } 67 | 68 | this->lookup->clear(); 69 | this->states_allocated.clear(); 70 | } 71 | 72 | private: 73 | const GGPLib::StateMachineInterface* sm; 74 | const int max_num_dupes; 75 | 76 | std::mutex mut; 77 | GGPLib::BaseState::HashMapMasked * lookup; 78 | std::vector states_allocated; 79 | }; 80 | } 81 | -------------------------------------------------------------------------------- /src/ggpzero/Makefile: -------------------------------------------------------------------------------- 1 | 2 | PY_FILES := $(shell find $(GGPZERO_PATH)/src/ggpzero -name '*.py') 3 | TEST_FILES := $(shell find $(GGPZERO_PATH)/src/test -name '*.py') 4 | 5 | all: 6 | flake8 $(PY_FILES) 7 | pylint $(PY_FILES) 8 | 9 | test: 10 | flake8 $(TEST_FILES) 11 | pylint $(TEST_FILES) 12 | -------------------------------------------------------------------------------- /src/ggpzero/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/richemslie/galvanise_zero/52164bcd6f43d648736e1ae9e556a7f6412339d1/src/ggpzero/__init__.py -------------------------------------------------------------------------------- /src/ggpzero/battle/README.txt: -------------------------------------------------------------------------------- 1 | 2 | * Game specific code for testing 
outside GGP. 3 | * Ability to connect to 3rd party engines or run small touranments between games. 4 | * No learning code here. 5 | -------------------------------------------------------------------------------- /src/ggpzero/battle/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/richemslie/galvanise_zero/52164bcd6f43d648736e1ae9e556a7f6412339d1/src/ggpzero/battle/__init__.py -------------------------------------------------------------------------------- /src/ggpzero/battle/amazons.py: -------------------------------------------------------------------------------- 1 | from builtins import super 2 | 3 | from ggplib.util.symbols import SymbolFactory 4 | from ggplib.db import lookup 5 | 6 | from ggpzero.battle.common import MatchGameInfo 7 | 8 | 9 | class MatchInfo(MatchGameInfo): 10 | def __init__(self, cross=False, match_cb=None): 11 | game = "amazonsLGcross" if cross else "amazons_10x10" 12 | self.cross = cross 13 | self.match_cb = match_cb 14 | 15 | game_info = lookup.by_name(game) 16 | super().__init__(game_info) 17 | 18 | def play_cb(self, players, match_depth): 19 | if self.match_cb: 20 | self.match_cb(players, match_depth) 21 | 22 | def convert_move_to_gdl(self, move): 23 | def lg_to_ggp(k): 24 | return ("jihgfedcba".index(k[0]) + 1), int(k[1:]) 25 | 26 | # actually 2 moves 27 | amazon, fire = move.split("/") 28 | from_pos, to_pos = amazon.split("-") 29 | from_pos, to_pos, fire = map(lg_to_ggp, (from_pos, to_pos, fire)) 30 | 31 | yield "(move %s %s %s %s)" % (from_pos + to_pos) 32 | yield "(fire %s %s)" % fire 33 | 34 | def ggp_to_sgf(self, move): 35 | amazon_move, fire_move = move 36 | amazon_move = amazon_move.replace("(move", "").replace(")", "") 37 | fire_move = fire_move.replace("(fire", "").replace(")", "") 38 | cords = map(int, amazon_move.split()) + map(int, fire_move.split()) 39 | 40 | def ggp_to_cord(x, y): 41 | return "%s%s" % ("abcdefghij"[10 - x], y) 42 | 43 | move = 
"%s-%s/%s" % (ggp_to_cord(cords[0], cords[1]), 44 | ggp_to_cord(cords[2], cords[3]), 45 | ggp_to_cord(cords[4], cords[5])) 46 | 47 | return move 48 | 49 | def gdl_to_lg(self, move): 50 | amazon_move, fire_move = move 51 | amazon_move = amazon_move.replace("(move", "").replace(")", "") 52 | fire_move = fire_move.replace("(fire", "").replace(")", "") 53 | cords = map(int, amazon_move.split()) + map(int, fire_move.split()) 54 | 55 | def ggp_to_lg(x, y): 56 | return "%s%s" % (10 - x, y - 1) 57 | 58 | move = "%s%s%s" % (ggp_to_lg(cords[0], cords[1]), 59 | ggp_to_lg(cords[2], cords[3]), 60 | ggp_to_lg(cords[4], cords[5])) 61 | 62 | return move 63 | 64 | def print_board(self, sm): 65 | as_str = self.game_info.model.basestate_to_str(sm.get_current_state()) 66 | print as_str 67 | 68 | sf = SymbolFactory() 69 | states = sf.to_symbols(as_str) 70 | 71 | control = None 72 | board_map = {} 73 | 74 | for s in list(states): 75 | base = s[1] 76 | if base[0] == "control": 77 | control = base[1] 78 | elif base[0] == "cell": 79 | key = int(base[1]), int(base[2]) 80 | board_map[key] = base[3] 81 | 82 | def row(i): 83 | yield ' ' 84 | for j in range(10, 0, -1): 85 | key = j, i 86 | if key in board_map: 87 | if board_map[key] == "arrow": 88 | yield " %s " % u"\u25C8" 89 | 90 | elif board_map[key] == "black": 91 | yield " B " 92 | 93 | else: 94 | assert board_map[key] == "white" 95 | yield " W " 96 | 97 | else: 98 | yield ' . 
' 99 | 100 | def lines(): 101 | for i in range(10, 0, -1): 102 | yield "".join(row(i)) 103 | 104 | print 105 | print 106 | print "\n".join(lines()) 107 | print "Control:", control 108 | -------------------------------------------------------------------------------- /src/ggpzero/battle/bt.py: -------------------------------------------------------------------------------- 1 | 2 | from builtins import super 3 | 4 | from ggplib.db import lookup 5 | 6 | from ggpzero.battle.common import MatchGameInfo 7 | 8 | 9 | def get_game_info(board_size): 10 | # add players 11 | if board_size == 8: 12 | game = "breakthrough" 13 | elif board_size == 7: 14 | game = "bt_7" 15 | elif board_size == 6: 16 | game = "breakthroughSmall" 17 | else: 18 | assert "board_size not supported" 19 | 20 | return lookup.by_name(game) 21 | 22 | 23 | def pretty_board(board_size, sm): 24 | ' pretty print board current state of match ' 25 | 26 | from ggplib.util.symbols import SymbolFactory 27 | as_str = get_game_info(board_size).model.basestate_to_str(sm.get_current_state()) 28 | sf = SymbolFactory() 29 | states = sf.to_symbols(as_str) 30 | mapping = {} 31 | control = None 32 | for s in list(states): 33 | if s[1][0] == "control": 34 | control = s[1][1] 35 | else: 36 | if board_size != 6: 37 | assert s[1][0] == "cellHolds" 38 | else: 39 | assert s[1][0] == "cell" 40 | 41 | key = int(s[1][1]), int(s[1][2]) 42 | mapping[key] = s[1][3] 43 | 44 | lines = [] 45 | line_len = board_size * 4 + 1 46 | lines.append(" +" + "-" * (line_len - 2) + "+") 47 | for i in reversed(range(1, board_size + 1)): 48 | ll = [" %s |" % i] 49 | for j in reversed(range(1, board_size + 1)): 50 | key = j, i 51 | if key in mapping: 52 | if mapping[key] == "black": 53 | ll.append(" %s |" % u"\u2659") 54 | else: 55 | assert mapping[key] == "white" 56 | ll.append(" %s |" % u"\u265F") 57 | else: 58 | ll.append(" |") 59 | 60 | lines.append("".join(ll)) 61 | if i > 1: 62 | lines.append(" " + "-" * line_len) 63 | 64 | lines.append(" +" + "-" 
* (line_len - 2) + "+") 65 | if board_size == 8: 66 | lines.append(" " + ' '.join(' %s ' % c for c in 'abcdefgh')) 67 | else: 68 | lines.append(" " + ' '.join(' %s ' % c for c in 'abcdef')) 69 | 70 | print 71 | print 72 | print "\n".join(lines) 73 | print "Control:", control 74 | 75 | 76 | def parse_sgf(txt): 77 | ''' not actually sgf, but whatever ''' 78 | moves = [] 79 | for line in txt.splitlines(): 80 | line = line.strip() 81 | if not line: 82 | continue 83 | if line.startswith("1"): 84 | expect = 1 85 | moves = [] 86 | for count, token in enumerate(line.split()): 87 | if token == "*": 88 | break 89 | if count % 3 == 0: 90 | expect_str = "%s." % expect 91 | assert token == expect_str, "expected '%s', got '%s'" % (expect_str, token) 92 | expect += 1 93 | else: 94 | moves.append(token) 95 | return moves 96 | 97 | 98 | class MatchInfo(MatchGameInfo): 99 | def __init__(self, board_size): 100 | self.board_size = board_size 101 | game_info = get_game_info(board_size) 102 | super().__init__(game_info) 103 | 104 | def print_board(self, sm): 105 | return pretty_board(self.board_size, sm) 106 | 107 | def convert_move_to_gdl(self, move): 108 | def to_cords(s): 109 | if self.board_size == 8: 110 | mapping_x_cord = {x0 : x1 for x0, x1 in zip('abcdefgh', '87654321')} 111 | elif self.board_size == 7: 112 | mapping_x_cord = {x0 : x1 for x0, x1 in zip('abcdefg', '7654321')} 113 | else: 114 | mapping_x_cord = {x0 : x1 for x0, x1 in zip('abcdef', '654321')} 115 | return mapping_x_cord[s[0]], s[1] 116 | 117 | move = move.lower() 118 | split_chr = '-' if "-" in move else 'x' 119 | from_, to_ = map(to_cords, move.split(split_chr)) 120 | yield "(move %s %s %s %s)" % (from_[0], from_[1], to_[0], to_[1]) 121 | 122 | def gdl_to_sgf(self, move): 123 | # XXX captures How? 
124 | # XXX move = move[lead_role_index] 125 | move = move.replace("(move", "").replace(")", "") 126 | a, b, c, d = move.split() 127 | if self.board_size == 8: 128 | mapping_x_cord = {x0 : x1 for x0, x1 in zip('87654321', 'abcdefgh')} 129 | elif self.board_size == 7: 130 | mapping_x_cord = {x0 : x1 for x0, x1 in zip('7654321', 'abcdefg')} 131 | else: 132 | mapping_x_cord = {x0 : x1 for x0, x1 in zip('654321', 'abcdef')} 133 | 134 | return "%s%s-%s%s" % (mapping_x_cord[a], b, mapping_x_cord[c], d) 135 | 136 | def gdl_to_lg(self, move): 137 | move = move.replace("(move", "").replace(")", "") 138 | a, b, c, d = move.split() 139 | a = self.board_size - int(a) 140 | b = int(b) - 1 141 | c = self.board_size - int(c) 142 | d = int(d) - 1 143 | return "%s%s%s%s" % (a, b, c, d) 144 | 145 | def parse_sgf(self, sgf): 146 | return parse_sgf(sgf) 147 | 148 | def print_board(self, sm): 149 | pretty_board(self.board_size, sm) 150 | -------------------------------------------------------------------------------- /src/ggpzero/battle/chess.py: -------------------------------------------------------------------------------- 1 | from builtins import super 2 | from pprint import pprint 3 | 4 | from ggplib.db import lookup 5 | 6 | from ggpzero.battle.common import MatchGameInfo 7 | 8 | 9 | def print_board(game_info, sm): 10 | basestate = sm.get_current_state() 11 | 12 | def valid_bases(): 13 | for ii in range(basestate.len()): 14 | if basestate.get(ii): 15 | base = game_info.model.bases[ii] 16 | base = base.replace("(true (", "").replace(")", "") 17 | yield base.split() 18 | 19 | # pprint(list(valid_bases())) 20 | 21 | board_map = {} 22 | 23 | to_unicode = {} 24 | to_unicode["black", "king"] = u"\u2654" 25 | to_unicode["black", "queen"] = u"\u2655" 26 | to_unicode["black", "rook"] = u"\u2656" 27 | to_unicode["black", "bishop"] = u"\u2657" 28 | to_unicode["black", "knight"] = u"\u2658" 29 | to_unicode["black", "pawn"] = u"\u2659" 30 | 31 | to_unicode["white", "king"] = u"\u265A" 32 | 
to_unicode["white", "queen"] = u"\u265B" 33 | to_unicode["white", "rook"] = u"\u265C" 34 | to_unicode["white", "bishop"] = u"\u265D" 35 | to_unicode["white", "knight"] = u"\u265E" 36 | to_unicode["white", "pawn"] = u"\u265F" 37 | 38 | for base in valid_bases(): 39 | if base[0] == "cell": 40 | key = "_abcdefgh".index(base[1]), int(base[2]) 41 | board_map[key] = base[3], base[4] 42 | elif base[0] in ("step", "kingHasMoved", "aRookHasMoved", "control", 43 | "hRookHasMoved", "canEnPassantCapture"): 44 | print "control", base 45 | continue 46 | else: 47 | assert False, "what is this?: %s" % base 48 | 49 | # pprint(board_map) 50 | 51 | board_size = 8 52 | lines = [] 53 | line_len = board_size * 2 + 1 54 | lines.append(" +" + "-" * line_len + "+") 55 | 56 | for i in range(board_size): 57 | y = board_size - i 58 | ll = [" %2d |" % y] 59 | for j in range(board_size): 60 | x = j + 1 61 | key = x, y 62 | if key in board_map: 63 | what = board_map[key] 64 | c = to_unicode[what] 65 | ll.append(" %s" % c) 66 | 67 | else: 68 | ll.append(" .") 69 | 70 | lines.append("".join(ll) + " |") 71 | 72 | lines.append(" +" + "-" * line_len + "+") 73 | lines.append(" " + ' '.join('%s' % c for c in 'abcdefgh')) 74 | 75 | print 76 | print 77 | print "\n".join(lines) 78 | print 79 | 80 | 81 | class MatchInfo(MatchGameInfo): 82 | def __init__(self, short_50=False): 83 | if short_50: 84 | game = "chess_15d" 85 | else: 86 | game = "chess_50d" 87 | 88 | self.game_info = lookup.by_name(game) 89 | super().__init__(self.game_info) 90 | 91 | def print_board(self, sm): 92 | print_board(self.game_info, sm) 93 | -------------------------------------------------------------------------------- /src/ggpzero/battle/connect6.py: -------------------------------------------------------------------------------- 1 | from builtins import super 2 | 3 | import re 4 | 5 | from ggplib.db import lookup 6 | 7 | from ggpzero.battle.common import MatchGameInfo 8 | 9 | 10 | class MatchInfo(MatchGameInfo): 11 | def 
__init__(self, match_cb=None): 12 | game_info = lookup.by_name("connect6") 13 | super().__init__(game_info) 14 | 15 | self.pattern = re.compile('[a-s]\d+') 16 | self.match_cb = match_cb 17 | 18 | def play_cb(self, players, match_depth): 19 | if self.match_cb: 20 | self.match_cb(players, match_depth) 21 | 22 | def convert_move_to_gdl(self, move): 23 | move = move.lower() 24 | if move == "j10": 25 | return 26 | 27 | def lg_to_ggp(k): 28 | return ("abcdefghijklmnopqrs".index(k[0]) + 1), int(k[1:]) 29 | 30 | # always 2 moves 31 | a, b = self.pattern.findall(move) 32 | 33 | yield "(place %s %s)" % lg_to_ggp(a) 34 | yield "(place %s %s)" % lg_to_ggp(b) 35 | 36 | def gdl_to_lg(self, move): 37 | move_a, move_b = move 38 | move_a = move_a.replace("(place", "").replace(")", "").split() 39 | move_b = move_b.replace("(place", "").replace(")", "").split() 40 | 41 | def to_cord(x): 42 | return "_abcdefghijklmnopqrs"[int(x)] 43 | 44 | return "%s%s%s%s" % tuple(to_cord(x) for x in move_a + move_b) 45 | 46 | def print_board(self, sm): 47 | from ggplib.util.symbols import SymbolFactory 48 | 49 | as_str = self.game_info.model.basestate_to_str(sm.get_current_state()) 50 | 51 | sf = SymbolFactory() 52 | states = sf.to_symbols(as_str) 53 | 54 | control = None 55 | board_map = {} 56 | 57 | for s in list(states): 58 | base = s[1] 59 | if base[0] == "control": 60 | control = base[1] 61 | elif base[0] == "cell": 62 | key = int(base[1]), int(base[2]) 63 | board_map[key] = base[3] 64 | 65 | board_size = 19 66 | lines = [] 67 | line_len = board_size * 2 + 1 68 | lines.append(" +" + "-" * line_len + "+") 69 | 70 | for y in range(board_size, 0, -1): 71 | ll = [" %2d |" % y] 72 | for j in range(board_size): 73 | x = j + 1 74 | key = x, y 75 | if key in board_map: 76 | if board_map[key] == "black": 77 | ll.append(" %s" % u"\u26C0") 78 | else: 79 | assert board_map[key] == "white" 80 | ll.append(" %s" % u"\u26C2") 81 | else: 82 | ll.append(" .") 83 | 84 | lines.append("".join(ll) + " |") 85 | 86 | 
lines.append(" +" + "-" * line_len + "+") 87 | lines.append(" " + ' '.join('%s' % c for c in 'abcdefghijklmnopqrs')) 88 | 89 | print 90 | print 91 | print "\n".join(lines) 92 | print "Control:", control 93 | -------------------------------------------------------------------------------- /src/ggpzero/battle/draughts.py: -------------------------------------------------------------------------------- 1 | from builtins import super 2 | 3 | from ggplib.db import lookup 4 | 5 | from ggpzero.battle.common import MatchGameInfo 6 | from ggplib.non_gdl_games.draughts import desc 7 | 8 | 9 | class Draughts_MatchInfo(MatchGameInfo): 10 | def __init__(self, killer=False): 11 | if killer: 12 | game_info = lookup.by_name("draughts_killer_10x10") 13 | else: 14 | game_info = lookup.by_name("draughts_10x10") 15 | 16 | super().__init__(game_info) 17 | 18 | self.board_desc = desc.BoardDesc(10) 19 | 20 | def print_board(self, sm): 21 | self.board_desc.print_board_sm(sm) 22 | -------------------------------------------------------------------------------- /src/ggpzero/battle/hex.py: -------------------------------------------------------------------------------- 1 | from builtins import super 2 | 3 | import time 4 | import hashlib 5 | 6 | from ggplib.db import lookup 7 | 8 | from ggpzero.battle.common import MatchGameInfo 9 | 10 | 11 | def write_hexgui_sgf(black_name, white_name, moves, game_size): 12 | game_id = hashlib.md5(hashlib.sha1("%.5f" % time.time()).hexdigest()).hexdigest()[:6] 13 | 14 | with open("hex%s_%s_%s_%s.sgf" % (game_size, 15 | black_name, 16 | white_name, 17 | game_id), "w") as f: 18 | f.write("(;FF[4]EV[null]PB[%s]PW[%s]SZ[%s]GC[game#%s];" % (black_name, 19 | white_name, 20 | game_size, 21 | game_id)) 22 | 23 | # Note: piece colours are swapped for hexgui from LG 24 | for ri, m in moves: 25 | f.write("%s[%s];" % ("B" if ri == 0 else "W", m)) 26 | f.write(")\n") 27 | 28 | 29 | def dump_trmph_url(moves, game_size): 30 | s = "http://www.trmph.com/hex/board#%d," % 
game_size 31 | for _, m in moves: 32 | if m == "swap": 33 | continue 34 | s += m 35 | 36 | print s 37 | 38 | 39 | def convert_trmph(ex): 40 | alpha = "_abcdefghijklmnop" 41 | moves = ex.split(",")[-1] 42 | first = None 43 | while moves: 44 | c0 = moves[0] 45 | assert c0 in alpha 46 | try: 47 | ii = int(moves[1] + moves[2]) 48 | c1 = alpha[ii] 49 | move = c0 + c1 50 | moves = moves[3:] 51 | except: 52 | ii = int(moves[1]) 53 | c1 = alpha[ii] 54 | move = c0 + c1 55 | moves = moves[2:] 56 | 57 | if first is None: 58 | first = move 59 | else: 60 | if first == move: 61 | move = "swap" 62 | first = -42 63 | 64 | yield move 65 | 66 | class MatchInfo(MatchGameInfo): 67 | def __init__(self, size=None): 68 | game = "hexLG%s" % size 69 | self.size = size 70 | 71 | game_info = lookup.by_name(game) 72 | super().__init__(game_info) 73 | 74 | def convert_move_to_gdl(self, move): 75 | if move == "resign": 76 | return 77 | 78 | if move == "swap": 79 | yield move 80 | else: 81 | gdl_role_move = "(place %s %s)" % (move[0], 82 | "abcdefghijklmnop".index(move[1]) + 1) 83 | yield gdl_role_move 84 | 85 | def gdl_to_lg(self, move): 86 | if move != "swap": 87 | move = move.replace("(place", "").replace(")", "") 88 | parts = move.split() 89 | move = parts[0] + "_abcdefghijklm"[int(parts[1])] 90 | return move 91 | 92 | def export(self, players, result): 93 | sgf_moves = [] 94 | 95 | def remove_gdl(m): 96 | return m.replace("(place ", "").replace(")", "").strip().replace(' ', '') 97 | 98 | def swapaxis(s): 99 | mapping_x = {x1 : x0 for x0, x1 in zip('abcdefghi', '123456789')} 100 | mapping_y = {x0 : x1 for x0, x1 in zip('abcdefghi', '123456789')} 101 | for x0, x1 in zip(('j', 'k', 'l', 'm'), ('10', '11', '12', '13')): 102 | mapping_x[x1] = x0 103 | mapping_y[x0] = x1 104 | 105 | return "%s%s" % (mapping_x[s[1]], mapping_y[s[0]]) 106 | 107 | def add_sgf_move(ri, m): 108 | sgf_moves.append((ri, m)) 109 | if m == "swap": 110 | assert ri == 1 111 | assert len(sgf_moves) == 2 112 | 113 | # hexgui 
does do swap like LG. This is a (double) hack. 114 | moved_move = swapaxis(sgf_moves[0][1]) 115 | sgf_moves[0] = (0, moved_move) 116 | sgf_moves.append((1, moved_move)) 117 | 118 | for match_depth, move, move_info in result: 119 | ri = 1 if move[0] == "noop" else 0 120 | str_move = remove_gdl(move[ri]) 121 | add_sgf_move(ri, str_move) 122 | 123 | player0, player1 = players 124 | dump_trmph_url(sgf_moves, self.size) 125 | write_hexgui_sgf(player0.get_name(), player1.get_name(), sgf_moves, self.size) 126 | 127 | def print_board(self, sm): 128 | from ggplib.util.symbols import SymbolFactory 129 | 130 | as_str = self.game_info.model.basestate_to_str(sm.get_current_state()) 131 | 132 | sf = SymbolFactory() 133 | states = sf.to_symbols(as_str) 134 | 135 | control = None 136 | board_map = {} 137 | 138 | for s in list(states): 139 | base = s[1] 140 | if base[0] == "control": 141 | control = base[1] 142 | elif base[0] == "cell": 143 | key = "abcdefghijklmnop".index(base[1]) + 1, int(base[2]), 144 | board_map[key] = base[3] 145 | 146 | board_size = self.size 147 | lines = [] 148 | line_len = board_size * 2 + 1 149 | 150 | def indent(y): 151 | return y * ' ' 152 | 153 | lines.append(" %s +" % indent(0) + "-" * line_len + "+") 154 | for i in range(board_size): 155 | y = i + 1 156 | ll = [" %2d %s \\" % (y, indent(y))] 157 | 158 | for j in range(board_size): 159 | x = j + 1 160 | key = x, y 161 | if key in board_map: 162 | if board_map[key] == "black": 163 | ll.append(" %s" % u"\u26C0") 164 | else: 165 | assert board_map[key] == "white" 166 | ll.append(" %s" % u"\u26C2") 167 | else: 168 | ll.append(" .") 169 | 170 | lines.append("".join(ll) + " \\") 171 | 172 | lines.append(" %s +" % indent(board_size + 1) + "-" * line_len + "+") 173 | lines.append(" %s " % indent(board_size + 1) + ' '.join('%s' % c for c in 'abcdefghijklmnopqrs'[:board_size])) 174 | 175 | print 176 | print 177 | print "\n".join(lines) 178 | print "Control:", control 179 | 
-------------------------------------------------------------------------------- /src/ggpzero/battle/hex2.py: -------------------------------------------------------------------------------- 1 | from builtins import super 2 | 3 | import time 4 | import hashlib 5 | 6 | import colorama 7 | 8 | from ggplib.db import lookup 9 | 10 | from ggpzero.battle.common import MatchGameInfo 11 | 12 | alphabet = "abcdefghijklmnopqrstuvwxyz" 13 | _alphabet = "_" + alphabet 14 | 15 | class MatchInfo(MatchGameInfo): 16 | def __init__(self, size=None): 17 | game = "hex_lg_%s" % size 18 | self.size = size 19 | 20 | game_info = lookup.by_name(game) 21 | super().__init__(game_info) 22 | 23 | def convert_move_to_gdl(self, move): 24 | if move == "resign": 25 | return 26 | 27 | if move == "swap": 28 | yield move 29 | else: 30 | gdl_role_move = "(place %s %s)" % (move[0], 31 | alphabet.index(move[1]) + 1) 32 | yield gdl_role_move 33 | 34 | def gdl_to_lg(self, move): 35 | if move != "swap": 36 | move = move.replace("(place", "").replace(")", "") 37 | parts = move.split() 38 | move = parts[0] + _alphabet[int(parts[1])] 39 | return move 40 | 41 | def print_board(self, sm): 42 | from ggplib.util.symbols import SymbolFactory 43 | 44 | as_str = self.game_info.model.basestate_to_str(sm.get_current_state()) 45 | 46 | sf = SymbolFactory() 47 | states = sf.to_symbols(as_str) 48 | 49 | control = None 50 | board_map = {} 51 | board_map_colours = {} 52 | 53 | for s in list(states): 54 | base = s[1] 55 | # print base 56 | if base[0] == "control": 57 | control = base[1] 58 | elif base[0] == "cell": 59 | key = alphabet.index(base[2]) + 1, int(base[3]), 60 | if base[1] == "white" or base[1] == "black": 61 | board_map[key] = base[1] 62 | else: 63 | board_map_colours[key] = base 64 | 65 | board_size = self.size 66 | lines = [] 67 | line_len = board_size * 2 + 1 68 | 69 | def indent(y): 70 | return y * ' ' 71 | 72 | lines.append(" %s +" % indent(0) + "-" * line_len + "+") 73 | for i in range(board_size): 74 | y 
= i + 1 75 | ll = [" %2d %s \\" % (y, indent(y))] 76 | 77 | for j in range(board_size): 78 | x = j + 1 79 | key = x, y 80 | if key in board_map: 81 | if key in board_map_colours: 82 | ll.append(colorama.Fore.GREEN) 83 | if board_map[key] == "black": 84 | ll.append(" %s" % u"\u26C0") 85 | else: 86 | assert board_map[key] == "white" 87 | ll.append(" %s" % u"\u26C2") 88 | 89 | if key in board_map_colours: 90 | ll.append(colorama.Style.RESET_ALL) 91 | 92 | else: 93 | ll.append(" .") 94 | 95 | lines.append("".join(ll) + " \\") 96 | 97 | lines.append(" %s +" % indent(board_size + 1) + "-" * line_len + "+") 98 | lines.append(" %s " % indent(board_size + 1) + ' '.join('%s' % c for c in alphabet[:board_size])) 99 | 100 | print 101 | print 102 | print "\n".join(lines) 103 | print "Control:", control 104 | -------------------------------------------------------------------------------- /src/ggpzero/battle/reversi.py: -------------------------------------------------------------------------------- 1 | from builtins import super 2 | 3 | from ggplib.db import lookup 4 | 5 | from ggpzero.battle.common import MatchGameInfo 6 | 7 | 8 | def pretty_board(board_size, sm): 9 | assert board_size == 8 or board_size == 10 10 | game_info = lookup.by_name("reversi") if board_size == 8 else lookup.by_name("reversi_10x10") 11 | 12 | from ggplib.util.symbols import SymbolFactory 13 | as_str = game_info.model.basestate_to_str(sm.get_current_state()) 14 | sf = SymbolFactory() 15 | print list(sf.to_symbols(as_str)) 16 | 17 | mapping = {} 18 | control = None 19 | for s in sf.to_symbols(as_str): 20 | if s[1][0] == "control": 21 | control = s[1][1] 22 | else: 23 | assert s[1][0] == "cell" 24 | 25 | key = int(s[1][1]), int(s[1][2]) 26 | mapping[key] = s[1][3] 27 | 28 | lines = [] 29 | line_len = board_size * 4 + 1 30 | lines.append(" +" + "-" * (line_len - 2) + "+") 31 | for i in reversed(range(1, board_size + 1)): 32 | ll = [" %2s |" % i] 33 | for j in reversed(range(1, board_size + 1)): 34 | key = 
j, i 35 | if key in mapping: 36 | if mapping[key] == "black": 37 | ll.append(" %s |" % u"\u2659") 38 | else: 39 | assert mapping[key] in ("red", "white") 40 | ll.append(" %s |" % u"\u265F") 41 | else: 42 | ll.append(" |") 43 | 44 | lines.append("".join(ll)) 45 | if i > 1: 46 | lines.append(" " + "-" * line_len) 47 | 48 | lines.append(" +" + "-" * (line_len - 2) + "+") 49 | if board_size == 8: 50 | lines.append(" " + ' '.join(' %s ' % c for c in 'abcdefgh')) 51 | else: 52 | lines.append(" " + ' '.join(' %s ' % c for c in 'abcdef')) 53 | 54 | print 55 | print 56 | print "\n".join(lines) 57 | print "Control:", control 58 | 59 | 60 | class MatchInfo8(MatchGameInfo): 61 | def __init__(self): 62 | game = "reversi" 63 | game_info = lookup.by_name(game) 64 | super().__init__(game_info) 65 | 66 | def convert_move_to_gdl(self, move): 67 | if move == "pass": 68 | yield "noop" 69 | return 70 | 71 | assert len(move) == 2 72 | 73 | def cord_x(c): 74 | return "hgfedcba".index(c) + 1 75 | 76 | def cord_y(c): 77 | return "abcdefgh".index(c) + 1 78 | 79 | yield "(move %s %s)" % (cord_x(move[0]), cord_y(move[1])) 80 | 81 | def gdl_to_lg(self, move): 82 | if move == "noop": 83 | return "pass" 84 | 85 | move = move.replace("(move", "").replace(")", "") 86 | 87 | def cord_x(c): 88 | return "hgfedcba"[int(c) - 1] 89 | 90 | def cord_y(c): 91 | return "abcdefgh"[int(c) - 1] 92 | 93 | x, y = move.split() 94 | return "%s%s" % (cord_x(x), cord_y(y)) 95 | 96 | def print_board(self, sm): 97 | pretty_board(8, sm) 98 | 99 | 100 | class MatchInfo10(MatchGameInfo): 101 | def __init__(self): 102 | game = "reversi_10x10" 103 | game_info = lookup.by_name(game) 104 | super().__init__(game_info) 105 | 106 | def convert_move_to_gdl(self, move): 107 | if move == "pass": 108 | yield "noop" 109 | return 110 | 111 | assert len(move) == 2 112 | 113 | def cord(c): 114 | return "abcdefghij".index(c) + 1 115 | 116 | yield "(move %s %s)" % tuple(cord(c) for c in move) 117 | 118 | def gdl_to_lg(self, move): 119 | 
if move == "noop": 120 | return "pass" 121 | 122 | move = move.replace("(move", "").replace(")", "") 123 | 124 | def cord(c): 125 | return "abcdefghij"[int(c) - 1] 126 | 127 | return "%s%s" % tuple(cord(c) for c in move.split()) 128 | 129 | def print_board(self, sm): 130 | pretty_board(10, sm) 131 | -------------------------------------------------------------------------------- /src/ggpzero/defs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/richemslie/galvanise_zero/52164bcd6f43d648736e1ae9e556a7f6412339d1/src/ggpzero/defs/__init__.py -------------------------------------------------------------------------------- /src/ggpzero/defs/datadesc.py: -------------------------------------------------------------------------------- 1 | import attr 2 | 3 | from ggpzero.util.attrutil import register_attrs 4 | 5 | 6 | # XXX rename? 7 | @register_attrs 8 | class Sample(object): 9 | # state policy trained on. This is a tuple of 0/1s. Effectively a bit array. 10 | state = attr.ib([0, 0, 0, 1]) 11 | 12 | # list of previous state (first element is immediate parent of 'state') 13 | prev_states = attr.ib([1, 0, 0, 1]) 14 | 15 | # list of policy distributions - all should sum to 1. 16 | policies = attr.ib([[0, 0, 0.5, 0.5], [0, 0, 0.5, 0.5]]) 17 | 18 | # list of final scores for value head of network - list has same number as number of roles 19 | final_score = attr.ib([0, 1]) 20 | 21 | # game depth at which point sample is taken 22 | depth = attr.ib(42) 23 | 24 | # total length of game 25 | game_length = attr.ib(42) 26 | 27 | # these are for debug. The match_identifier can be used to extract contigous samples from the 28 | # same match. 
29 | match_identifier = attr.ib("agame_421") 30 | has_resigned = attr.ib(False) 31 | resign_false_positive = attr.ib(False) 32 | starting_sample_depth = attr.ib(42) 33 | 34 | # the results after running the puct iterations 35 | resultant_puct_score = attr.ib(attr.Factory(list)) 36 | resultant_puct_visits = attr.ib(800) 37 | 38 | 39 | # XXX rename? 40 | @register_attrs 41 | class GenerationSamples(object): 42 | game = attr.ib("game") 43 | date_created = attr.ib('2018-01-24 22:28') 44 | 45 | # trained with this generation 46 | with_generation = attr.ib("v6_123") 47 | 48 | # number of samples in this generation 49 | num_samples = attr.ib(1024) 50 | 51 | # the samples (of type Sample) 52 | samples = attr.ib(attr.Factory(list)) 53 | 54 | 55 | @register_attrs 56 | class GenerationDescription(object): 57 | ''' this describes the inputs/output to the network, provide information how the gdl 58 | transformations to input/outputs. and other meta information. It does not describe the 59 | internals of the neural network, which is provided by the keras json model file. 60 | 61 | It will ripple through everything: 62 | * network model creation 63 | * reloading and using a trained network 64 | * the inputs/outputs from GdlTransformer 65 | * the channel ordering 66 | * network model creation 67 | * loading 68 | ''' 69 | 70 | game = attr.ib("breakthrough") 71 | name = attr.ib("v6_123") 72 | date_created = attr.ib('2018-01-24 22:28') 73 | 74 | # whether the network expects channel inputs to have channel last format 75 | channel_last = attr.ib(False) 76 | 77 | # whether the network uses multiple policy heads (False - there is one) 78 | multiple_policy_heads = attr.ib(False) 79 | 80 | # number of previous states expected (default is 0). 
81 | num_previous_states = attr.ib(0) 82 | 83 | # XXX todo 84 | transformer_description = attr.ib(None) 85 | 86 | draw_head = attr.ib(False) 87 | 88 | # the training config attributes - for debugging, historical purposes 89 | # the number of samples trained on, etc 90 | # the number losses, validation losses, accurcacy 91 | trained_losses = attr.ib('not set') 92 | trained_validation_losses = attr.ib('not set') 93 | trained_policy_accuracy = attr.ib('not set') 94 | trained_value_accuracy = attr.ib('not set') 95 | 96 | 97 | @register_attrs 98 | class StepSummary(object): 99 | step = attr.ib(42) 100 | filename = attr.ib("gendata_hexLG11_6.json.gz") 101 | with_generation = attr.ib("h_5") 102 | num_samples = attr.ib(50000) 103 | 104 | md5sum = attr.ib("93d6ce4b812d353c73f4a8ca5b605d37") 105 | 106 | stats_unique_matches = attr.ib(2200) 107 | stats_draw_ratio = attr.ib(0.03) 108 | 109 | # when all policies lengths = 1 110 | stats_bare_policies_ratio = attr.ib(0.03) 111 | 112 | stats_av_starting_depth = attr.ib(5.5) 113 | stats_av_ending_depth = attr.ib(5.5) 114 | stats_av_resigns = attr.ib(0.05) 115 | stats_av_resign_false_positive = attr.ib(0.2) 116 | 117 | stats_av_puct_visits = attr.ib(2000) 118 | 119 | # if len(policy dist) > 1 and len(other_policy dist) == 1: +1 / #samples 120 | # [0.45, 0.55] 121 | stats_ratio_of_roles = attr.ib(attr.Factory(list)) 122 | 123 | # average score by role 124 | stats_av_final_scores = attr.ib(attr.Factory(list)) 125 | 126 | # score up to for lead_role_index (or role_index 0), number of samples 127 | # [(0.1, 200), (0.1, 200),... 
(0.9, 200)] 128 | stats_av_puct_score_dist = attr.ib(attr.Factory(list)) 129 | 130 | 131 | @register_attrs 132 | class GenDataSummary(object): 133 | game = attr.ib("game") 134 | gen_prefix = attr.ib("x1") 135 | last_updated = attr.ib('2018-01-24 22:28') 136 | total_samples = attr.ib(10**10) 137 | 138 | # isinstance StepSummary 139 | step_summaries = attr.ib(attr.Factory(list)) 140 | -------------------------------------------------------------------------------- /src/ggpzero/defs/msgs.py: -------------------------------------------------------------------------------- 1 | from ggpzero.util.attrutil import register_attrs, attribute, attr_factory 2 | 3 | from ggpzero.defs import confs, datadesc 4 | 5 | 6 | @register_attrs 7 | class Ping(object): 8 | pass 9 | 10 | 11 | @register_attrs 12 | class Pong(object): 13 | pass 14 | 15 | 16 | @register_attrs 17 | class Ok(object): 18 | message = attribute("ok") 19 | 20 | 21 | @register_attrs 22 | class RequestConfig(object): 23 | pass 24 | 25 | 26 | @register_attrs 27 | class WorkerConfigMsg(object): 28 | conf = attribute(default=attr_factory(confs.WorkerConfig)) 29 | 30 | 31 | @register_attrs 32 | class ConfigureSelfPlay(object): 33 | game = attribute("game") 34 | generation_name = attribute("gen0") 35 | self_play_conf = attribute(default=attr_factory(confs.SelfPlayConfig)) 36 | 37 | 38 | @register_attrs 39 | class RequestSamples(object): 40 | # list of states (0/1 tuples) - to reduce duplicates 41 | new_states = attribute(default=attr_factory(list)) 42 | 43 | 44 | @register_attrs 45 | class RequestSampleResponse(object): 46 | # list of def.confs.Sample 47 | samples = attribute(default=attr_factory(list)) 48 | duplicates_seen = attribute(0) 49 | 50 | 51 | @register_attrs 52 | class RequestNetworkTrain(object): 53 | game = attribute("game") 54 | train_conf = attribute(default=attr_factory(confs.TrainNNConfig)) 55 | network_model = attribute(default=attr_factory(confs.NNModelConfig)) 56 | generation_description = 
attribute(default=attr_factory(datadesc.GenerationDescription)) 57 | -------------------------------------------------------------------------------- /src/ggpzero/defs/templates.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | from ggpzero.defs import confs, datadesc 4 | 5 | 6 | def default_generation_desc(game, name="default", **kwds): 7 | desc = datadesc.GenerationDescription(game) 8 | 9 | desc.name = name 10 | desc.date_created = datetime.now().strftime("%Y/%m/%d %H:%M") 11 | 12 | desc.channel_last = False 13 | desc.multiple_policy_heads = True 14 | desc.num_previous_states = 0 15 | for k, v in kwds.items(): 16 | setattr(desc, k, v) 17 | 18 | return desc 19 | 20 | 21 | def nn_model_config_template(game, network_size_hint, transformer, features=False): 22 | ' helper for creating NNModelConfig templates ' 23 | 24 | conf = confs.NNModelConfig() 25 | 26 | # from transformer 27 | conf.role_count = transformer.role_count 28 | 29 | conf.input_rows = transformer.num_rows 30 | conf.input_columns = transformer.num_cols 31 | conf.input_channels = transformer.num_channels 32 | 33 | # policy distribution head 34 | conf.policy_dist_count = transformer.policy_dist_count 35 | assert isinstance(conf.policy_dist_count, list) and len(conf.policy_dist_count) > 0 36 | 37 | # normal defaults 38 | conf.cnn_kernel_size = 3 39 | conf.dropout_rate_policy = 0.25 40 | conf.dropout_rate_value = 0.5 41 | 42 | if network_size_hint == "small": 43 | conf.cnn_filter_size = 64 44 | conf.residual_layers = 5 45 | conf.value_hidden_size = 256 46 | 47 | elif network_size_hint == "medium": 48 | conf.cnn_filter_size = 96 49 | conf.residual_layers = 5 50 | conf.value_hidden_size = 512 51 | 52 | elif network_size_hint == "large": 53 | conf.cnn_filter_size = 96 54 | conf.residual_layers = 10 55 | conf.value_hidden_size = 512 56 | 57 | else: 58 | assert False, "network_size_hint %s, not recognised" % network_size_hint 59 | 60 | 
conf.leaky_relu = False 61 | if features: 62 | conf.resnet_v2 = True 63 | conf.squeeze_excite_layers = True 64 | conf.global_pooling_value = True 65 | else: 66 | conf.resnet_v2 = False 67 | conf.squeeze_excite_layers = False 68 | conf.global_pooling_value = False 69 | 70 | return conf 71 | 72 | 73 | def base_puct_config(**kwds): 74 | config = confs.PUCTEvaluatorConfig(verbose=False, 75 | backup_finalised=False, 76 | batch_size=1, 77 | 78 | dirichlet_noise_pct=-1, 79 | 80 | puct_constant=0.85, 81 | puct_constant_root=0.85, 82 | 83 | fpu_prior_discount=0.25, 84 | fpu_prior_discount_root=0.25, 85 | 86 | choose="choose_temperature", 87 | temperature=1.0, 88 | depth_temperature_max=5.0, 89 | depth_temperature_start=2, 90 | depth_temperature_increment=0.2, 91 | depth_temperature_stop=6, 92 | random_scale=0.95, 93 | 94 | think_time=-1, 95 | max_dump_depth=0, 96 | top_visits_best_guess_converge_ratio=0.85, 97 | converged_visits=1, 98 | evaluation_multiplier_to_convergence=2.0) 99 | # ZZZ everyting else? 
100 | 101 | for k, v in kwds.items(): 102 | setattr(config, k, v) 103 | 104 | return config 105 | 106 | 107 | def selfplay_config_template(): 108 | conf = confs.SelfPlayConfig() 109 | conf.oscillate_sampling_pct = 0.25 110 | conf.temperature_for_policy = 1.0 111 | 112 | conf.puct_config = base_puct_config(dirichlet_noise_pct=0.25) 113 | conf.evals_per_move = 100 114 | 115 | conf.resign0_score_probability = 0.1 116 | conf.resign0_pct = 0.99 117 | conf.resign1_score_probability = 0.025 118 | conf.resign1_pct = 0.95 119 | 120 | conf.run_to_end_pct = 0.01 121 | conf.run_to_end_evals = 32 122 | conf.run_to_end_puct_config = base_puct_config(dirichlet_noise_pct=0.15, 123 | random_scale=0.75) 124 | conf.run_to_end_early_score = 0.01 125 | conf.run_to_end_minimum_game_depth = 30 126 | 127 | conf.abort_max_length = -1 128 | 129 | return conf 130 | 131 | 132 | def train_config_template(game, gen_prefix): 133 | conf = confs.TrainNNConfig(game) 134 | 135 | conf.generation_prefix = gen_prefix 136 | 137 | conf.next_step = 0 138 | conf.starting_step = 0 139 | conf.use_previous = True 140 | conf.validation_split = 0.95 141 | conf.overwrite_existing = False 142 | 143 | conf.epochs = 1 144 | conf.batch_size = 256 145 | conf.compile_strategy = "SGD" 146 | conf.l2_regularisation = 0.00002 147 | conf.learning_rate = 0.01 148 | 149 | conf.initial_value_weight = 1.0 150 | conf.max_epoch_size = 1024 * 1024 151 | 152 | conf.resample_buckets = [ 153 | [ 154 | 100, 155 | 1.00000 156 | ]] 157 | 158 | return conf 159 | 160 | 161 | def server_config_template(game, generation_prefix, prev_states): 162 | conf = confs.ServerConfig() 163 | 164 | conf.game = game 165 | conf.generation_prefix = generation_prefix 166 | 167 | conf.port = 9000 168 | 169 | conf.current_step = 0 170 | 171 | conf.num_samples_to_train = 20000 172 | conf.max_samples_growth = 0.8 173 | 174 | conf.base_generation_description = default_generation_desc(game, 175 | generation_prefix, 176 | multiple_policy_heads=True, 177 | 
num_previous_states=prev_states) 178 | 179 | from ggpzero.nn.manager import get_manager 180 | man = get_manager() 181 | transformer = man.get_transformer(game, conf.base_generation_description) 182 | conf.base_network_model = nn_model_config_template(game, "small", transformer) 183 | 184 | conf.base_training_config = train_config_template(game, generation_prefix) 185 | 186 | conf.self_play_config = selfplay_config_template() 187 | return conf 188 | -------------------------------------------------------------------------------- /src/ggpzero/distributed/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/richemslie/galvanise_zero/52164bcd6f43d648736e1ae9e556a7f6412339d1/src/ggpzero/distributed/__init__.py -------------------------------------------------------------------------------- /src/ggpzero/nn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/richemslie/galvanise_zero/52164bcd6f43d648736e1ae9e556a7f6412339d1/src/ggpzero/nn/__init__.py -------------------------------------------------------------------------------- /src/ggpzero/nn/manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | from ggplib.db import lookup 5 | 6 | from ggpzero.util import attrutil 7 | from ggpzero.util.keras import keras_models 8 | 9 | from ggpzero.defs import confs, datadesc 10 | 11 | from ggpzero.nn.network import NeuralNetwork 12 | from ggpzero.nn.model import get_network_model 13 | from ggpzero.defs import templates 14 | 15 | the_manager = None 16 | 17 | 18 | def ensure_directory_exists(path): 19 | if not os.path.exists(path): 20 | os.makedirs(path) 21 | 22 | 23 | class Manager(object): 24 | def __init__(self, data_path=None): 25 | 26 | if data_path is None: 27 | data_path = os.path.join(os.environ["GGPZERO_PATH"], "data") 28 | self.data_path = data_path 29 | 30 | # 
instantiated transformers, lazy constructed 31 | self.transformers = {} 32 | 33 | def samples_path(self, game, generation_prefix): 34 | p = os.path.join(self.data_path, game, generation_prefix) 35 | ensure_directory_exists(p) 36 | return p 37 | 38 | def generation_path(self, game, generation_name=None): 39 | p = os.path.join(self.data_path, game, "generations") 40 | ensure_directory_exists(p) 41 | if generation_name is not None: 42 | filename = "%s.json" % generation_name 43 | p = os.path.join(p, filename) 44 | return p 45 | 46 | def model_path(self, game, generation_name=None): 47 | p = os.path.join(self.data_path, game, "models") 48 | ensure_directory_exists(p) 49 | if generation_name is not None: 50 | filename = "%s.json" % generation_name 51 | p = os.path.join(p, filename) 52 | return p 53 | 54 | def weights_path(self, game, generation_name=None): 55 | p = os.path.join(self.data_path, game, "weights") 56 | ensure_directory_exists(p) 57 | if generation_name is not None: 58 | filename = "%s.h5" % generation_name 59 | p = os.path.join(p, filename) 60 | return p 61 | 62 | def get_transformer(self, game, generation_descr=None): 63 | from ggpzero.nn.bases import GdlBasesTransformer, GdlBasesTransformer_Draws 64 | 65 | if generation_descr is None: 66 | generation_descr = templates.default_generation_desc(game) 67 | 68 | assert isinstance(generation_descr, datadesc.GenerationDescription) 69 | 70 | desc = generation_descr 71 | key = (game, desc.channel_last, desc.multiple_policy_heads, desc.num_previous_states, desc.draw_head) 72 | 73 | transformer = self.transformers.get(key) 74 | 75 | if transformer is None: 76 | # looks up the game in the ggplib database 77 | game_info = lookup.by_name(game) 78 | transformer_clz = GdlBasesTransformer_Draws if generation_descr.draw_head else GdlBasesTransformer 79 | transformer = transformer_clz(game_info, generation_descr) 80 | self.transformers[key] = transformer 81 | 82 | return transformer 83 | 84 | def create_new_network(self, 
game, nn_model_conf=None, generation_descr=None): 85 | if generation_descr is None: 86 | generation_descr = templates.default_generation_desc(game) 87 | 88 | transformer = self.get_transformer(game, generation_descr) 89 | 90 | if isinstance(nn_model_conf, str): 91 | nn_model_conf = templates.nn_model_config_template(game, 92 | network_size_hint=nn_model_conf, 93 | transformer=transformer) 94 | 95 | elif nn_model_conf is None: 96 | nn_model_conf = templates.nn_model_config_template(game, 97 | network_size_hint="small", 98 | transformer=transformer) 99 | 100 | assert isinstance(nn_model_conf, confs.NNModelConfig) 101 | assert isinstance(generation_descr, datadesc.GenerationDescription) 102 | 103 | keras_model = get_network_model(nn_model_conf, generation_descr) 104 | return NeuralNetwork(transformer, keras_model, generation_descr) 105 | 106 | def save_network(self, nn, generation_name=None): 107 | game = nn.generation_descr.game 108 | if generation_name is None: 109 | generation_name = nn.generation_descr.name 110 | else: 111 | nn.generation_descr.name = generation_name 112 | 113 | # save model / weights 114 | with open(self.model_path(game, generation_name), "w") as f: 115 | f.write(nn.get_model().to_json()) 116 | 117 | nn.get_model().save_weights(self.weights_path(game, generation_name), 118 | overwrite=True) 119 | 120 | with open(self.generation_path(game, generation_name), "w") as f: 121 | f.write(attrutil.attr_to_json(nn.generation_descr, pretty=True)) 122 | 123 | def load_network(self, game, generation_name): 124 | json_str = open(self.generation_path(game, generation_name)).read() 125 | generation_descr = attrutil.json_to_attr(json_str) 126 | 127 | json_str = open(self.model_path(game, generation_name)).read() 128 | keras_model = keras_models.model_from_json(json_str) 129 | 130 | keras_model.load_weights(self.weights_path(game, generation_name)) 131 | transformer = self.get_transformer(game, generation_descr) 132 | return NeuralNetwork(transformer, 
keras_model, generation_descr) 133 | 134 | def can_load(self, game, generation_name): 135 | exists = os.path.exists 136 | return (exists(self.model_path(game, generation_name)) and 137 | exists(self.weights_path(game, generation_name)) and 138 | exists(self.generation_path(game, generation_name))) 139 | 140 | 141 | ############################################################################### 142 | 143 | def get_manager(): 144 | ' singleton for Manager ' 145 | global the_manager 146 | if the_manager is None: 147 | the_manager = Manager() 148 | 149 | return the_manager 150 | -------------------------------------------------------------------------------- /src/ggpzero/nn/network.py: -------------------------------------------------------------------------------- 1 | ''' mostly just an interface to keras... hoping to try other frameworks too. ''' 2 | 3 | import numpy as np 4 | 5 | from ggplib.util import log 6 | 7 | from ggpzero.util.keras import SGD, Adam, keras_metrics, keras_regularizers, keras_models 8 | 9 | 10 | class HeadResult(object): 11 | def __init__(self, transformer, policies, values): 12 | assert len(transformer.policy_dist_count) == len(policies) 13 | self.policies = policies 14 | self.scores = values 15 | 16 | def __repr__(self): 17 | return "HeadResult(policies=%s, scores=%s" % (self.policies, self.scores) 18 | 19 | 20 | class NeuralNetwork(object): 21 | ''' combines a keras model and gdl bases transformer to give a clean interface to use as a 22 | network. ''' 23 | 24 | def __init__(self, gdl_bases_transformer, keras_model, generation_descr): 25 | self.gdl_bases_transformer = gdl_bases_transformer 26 | self.keras_model = keras_model 27 | self.generation_descr = generation_descr 28 | 29 | def summary(self): 30 | ' log keras nn summary ' 31 | 32 | # one way to get print_summary to output string! 
33 | lines = [] 34 | self.keras_model.summary(print_fn=lines.append) 35 | for l in lines: 36 | log.verbose(l) 37 | 38 | def predict_n(self, states, prev_states=None): 39 | ' this is for testing purposes. We use C++ normally to access network ' 40 | # prev_states -> list of list of states 41 | 42 | to_channels = self.gdl_bases_transformer.state_to_channels 43 | if prev_states: 44 | assert len(prev_states) == len(states) 45 | X = np.array([to_channels(s, prevs) 46 | for s, prevs in zip(states, prev_states)]) 47 | else: 48 | X = np.array([to_channels(s) for s in states]) 49 | 50 | Y = self.keras_model.predict(X, batch_size=len(states)) 51 | 52 | result = [] 53 | for i in range(len(states)): 54 | heads = HeadResult(self.gdl_bases_transformer, 55 | [Y[k][i] for k in range(len(Y) - 1)], 56 | Y[-1][i]) 57 | result.append(heads) 58 | 59 | return result 60 | 61 | def predict_1(self, state, prev_states=None): 62 | ' this is for testing purposes. We use C++ normally to access network ' 63 | if prev_states: 64 | return self.predict_n([state], [prev_states])[0] 65 | else: 66 | return self.predict_n([state])[0] 67 | 68 | def compile(self, compile_strategy, learning_rate=None, value_weight=1.0, 69 | l2_loss=None, l2_non_residual=True): 70 | # XXX allow l2_loss on final layers. 
71 | 72 | value_objective = "mean_squared_error" 73 | policy_objective = 'categorical_crossentropy' 74 | if compile_strategy == "SGD": 75 | if learning_rate is None: 76 | learning_rate = 0.01 77 | optimizer = SGD(lr=learning_rate, momentum=0.9) 78 | 79 | elif compile_strategy == "adam": 80 | if learning_rate: 81 | optimizer = Adam(lr=learning_rate) 82 | else: 83 | optimizer = Adam() 84 | 85 | elif compile_strategy == "amsgrad": 86 | if learning_rate: 87 | optimizer = Adam(lr=learning_rate, amsgrad=True) 88 | else: 89 | optimizer = Adam(amsgrad=True) 90 | 91 | else: 92 | log.error("UNKNOWN compile strategy %s" % compile_strategy) 93 | raise Exception("UNKNOWN compile strategy %s" % compile_strategy) 94 | 95 | num_policies = len(self.gdl_bases_transformer.policy_dist_count) 96 | 97 | loss = [policy_objective] * num_policies 98 | loss.append(value_objective) 99 | loss_weights = [1.0] * num_policies 100 | loss_weights.append(value_weight) 101 | 102 | if learning_rate is not None: 103 | msg = "Compiling with %s (learning_rate=%.4f, value_weight=%.3f)" 104 | log.warning(msg % (optimizer, learning_rate, value_weight)) 105 | else: 106 | log.warning("Compiling with %s (value_weight=%.3f)" % (optimizer, value_weight)) 107 | 108 | def top_3_acc(y_true, y_pred): 109 | return keras_metrics.top_k_categorical_accuracy(y_true, y_pred, k=3) 110 | 111 | if l2_loss is not None: 112 | log.warning("Applying l2 loss (%.5f)" % l2_loss) 113 | l2_loss = keras_regularizers.l2(l2_loss) 114 | 115 | rebuild_model = False 116 | for layer in self.keras_model.layers: 117 | # To get global weight decay in keras regularizers have to be added to every layer 118 | # in the model. 
119 | 120 | if hasattr(layer, 'kernel_regularizer'): 121 | 122 | ignore = False 123 | if l2_non_residual: 124 | ignore = True 125 | 126 | if "policy" in layer.name or "value" in layer.name: 127 | if "flatten" not in layer.name: 128 | ignore = False 129 | else: 130 | ignore = "_se_" in layer.name 131 | 132 | if ignore: 133 | if layer.kernel_regularizer is not None: 134 | log.warning("Ignoring but regularizer was set @ %s/%s. Unsetting." % (layer.name, layer)) 135 | layer.kernel_regularizer = None 136 | rebuild_model = True 137 | 138 | continue 139 | 140 | if l2_loss is not None and layer.kernel_regularizer is None: 141 | rebuild_model = True 142 | log.info("Applying l2 loss to %s/%s" % (layer.name, layer)) 143 | layer.kernel_regularizer = l2_loss 144 | 145 | if layer.kernel_regularizer is not None and l2_loss is None: 146 | log.info("Unsetting l2 loss on %s/%s" % (layer.name, layer)) 147 | rebuild_model = True 148 | layer.kernel_regularizer = l2_loss 149 | 150 | # This ensures a fresh build of the network (there is no API to do this in keras, hence 151 | # this hacky workaround). Furthermore, needing to rebuild the network here, before 152 | # compiling, is somewhat buggy/idiosyncrasy of keras. 
153 | if rebuild_model: 154 | config = self.keras_model.get_config() 155 | weights = self.keras_model.get_weights() 156 | self.keras_model = keras_models.Model.from_config(config) 157 | self.keras_model.set_weights(weights) 158 | 159 | self.keras_model.compile(loss=loss, optimizer=optimizer, 160 | loss_weights=loss_weights, 161 | metrics=["acc", top_3_acc]) 162 | 163 | def get_model(self): 164 | assert self.keras_model is not None 165 | return self.keras_model 166 | -------------------------------------------------------------------------------- /src/ggpzero/player/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/richemslie/galvanise_zero/52164bcd6f43d648736e1ae9e556a7f6412339d1/src/ggpzero/player/__init__.py -------------------------------------------------------------------------------- /src/ggpzero/player/puctplayer.py: -------------------------------------------------------------------------------- 1 | from builtins import super 2 | 3 | from ggplib.util import log 4 | from ggplib.player.base import MatchPlayer 5 | 6 | from ggpzero.defs import confs 7 | 8 | from ggpzero.util.cppinterface import joint_move_to_ptr, basestate_to_ptr, PlayPoller 9 | 10 | from ggpzero.nn.manager import get_manager 11 | 12 | 13 | class PUCTPlayer(MatchPlayer): 14 | poller = None 15 | last_probability = -1 16 | last_node_count = -1 17 | 18 | def __init__(self, conf): 19 | assert isinstance(conf, (confs.PUCTPlayerConfig, confs.PUCTEvaluatorConfig)) 20 | 21 | self.conf = conf 22 | if conf.playouts_per_iteration > 0: 23 | self.identifier = "%s_%s_%s" % (self.conf.name, conf.playouts_per_iteration, conf.generation) 24 | else: 25 | self.identifier = "%s_%s" % (self.conf.name, conf.generation) 26 | 27 | super().__init__(self.identifier) 28 | self.sm = None 29 | 30 | def cleanup(self): 31 | log.info("PUCTPlayer.cleanup() called") 32 | if self.poller is not None: 33 | self.poller.player_reset(0) 34 | 35 | def 
on_meta_gaming(self, finish_time): 36 | if self.conf.verbose: 37 | log.info("PUCTPlayer, match id: %s" % self.match.match_id) 38 | 39 | if self.sm is None or "*" in self.conf.generation: 40 | if "*" in self.conf.generation: 41 | log.warning("Using recent generation %s" % self.conf.generation) 42 | 43 | game_info = self.match.game_info 44 | self.sm = game_info.get_sm() 45 | 46 | man = get_manager() 47 | gen = self.conf.generation 48 | 49 | self.nn = man.load_network(game_info.game, gen) 50 | self.poller = PlayPoller(self.sm, self.nn, self.conf.evaluator_config) 51 | 52 | def get_noop_idx(actions): 53 | for idx, a in enumerate(actions): 54 | if "noop" in a: 55 | return idx 56 | assert False, "did not find noop" 57 | 58 | self.role0_noop_legal, self.role1_noop_legal = map(get_noop_idx, game_info.model.actions) 59 | 60 | self.poller.player_reset(self.match.game_depth) 61 | 62 | def on_apply_move(self, joint_move): 63 | self.poller.player_apply_move(joint_move_to_ptr(joint_move)) 64 | self.poller.poll_loop() 65 | 66 | def on_next_move(self, finish_time): 67 | log.info("PUCTPlayer.on_next_move(), %s" % self.get_name()) 68 | current_state = self.match.get_current_state() 69 | self.sm.update_bases(current_state) 70 | 71 | if (self.sm.get_legal_state(0).get_count() == 1 and 72 | self.sm.get_legal_state(0).get_legal(0) == self.role0_noop_legal): 73 | lead_role_index = 1 74 | 75 | else: 76 | assert (self.sm.get_legal_state(1).get_count() == 1 and 77 | self.sm.get_legal_state(1).get_legal(0) == self.role1_noop_legal) 78 | lead_role_index = 0 79 | 80 | if lead_role_index == self.match.our_role_index: 81 | max_iterations = self.conf.playouts_per_iteration 82 | else: 83 | max_iterations = self.conf.playouts_per_iteration_noop 84 | 85 | current_state = self.match.get_current_state() 86 | 87 | self.poller.player_move(basestate_to_ptr(current_state), max_iterations, finish_time) 88 | self.poller.poll_loop() 89 | 90 | move, prob, node_count = 
self.poller.player_get_move(self.match.our_role_index) 91 | self.last_probability = prob 92 | self.last_node_count = node_count 93 | return move 94 | 95 | def balance_moves(self, max_count): 96 | self.poller.player_balance_moves(max_count) 97 | self.poller.poll_loop() 98 | 99 | def tree_debug(self, max_count): 100 | return self.poller.player_tree_debug(max_count) 101 | 102 | def update_config(self, *args, **kwds): 103 | self.poller.player_update_config(*args, **kwds) 104 | 105 | def __repr__(self): 106 | return self.get_name() 107 | -------------------------------------------------------------------------------- /src/ggpzero/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/richemslie/galvanise_zero/52164bcd6f43d648736e1ae9e556a7f6412339d1/src/ggpzero/scripts/__init__.py -------------------------------------------------------------------------------- /src/ggpzero/scripts/cleanup_nnfiles.py: -------------------------------------------------------------------------------- 1 | ''' go through all data for all games, and ask to remove spurious data files. We want to keep 2 | every 7th iteration for evaluation. 
''' 3 | 4 | import os 5 | from pathlib import Path 6 | 7 | 8 | def go(game_path): 9 | game_name = game_path.name 10 | 11 | models_path = weights_path = None 12 | for p in game_path.iterdir(): 13 | if p.name == 'weights' and p.is_dir(): 14 | weights_path = p 15 | elif p.name == 'models' and p.is_dir(): 16 | models_path = p 17 | 18 | if models_path and weights_path: 19 | print game_path, "is valid game" 20 | 21 | # go through each of the models_path, weights_path 22 | marked_for_deletion = [] 23 | valid_generations_by_prefix = {} 24 | for root in (models_path, weights_path): 25 | for p in root.iterdir(): 26 | if game_name not in p.name: 27 | print "**NOT** a nn file", p 28 | continue 29 | 30 | if not p.name.startswith(game_name): 31 | print "**NOT** a valid nn filename", p 32 | continue 33 | 34 | name = p.name.replace(game_name + "_", "") 35 | parts = name.split(".") 36 | 37 | try: 38 | generation = parts[0] 39 | 40 | gen_split = generation.split("_") 41 | if gen_split[-1] == "prev": 42 | marked_for_deletion.append(p) 43 | continue 44 | 45 | step = int(gen_split[-1]) 46 | generation_prefix = "_".join(generation.split("_")[:-1]) 47 | 48 | valid_generations_by_prefix.setdefault(generation_prefix, []).append((p, step)) 49 | 50 | except Exception as exc: 51 | print exc 52 | print "**NOT** a valid nn filename", p 53 | continue 54 | 55 | print "_____________" 56 | 57 | for gp, gens in valid_generations_by_prefix.items(): 58 | max_step = max(s for _, s in gens) 59 | print "FOUND:", gp, "max_step", max_step 60 | 61 | keep_step_gt = (max_step / 5) * 5 62 | for p, s in gens: 63 | if s < keep_step_gt and s % 5 != 0: 64 | marked_for_deletion.append(p) 65 | else: 66 | print "WILL KEEP", p 67 | 68 | if marked_for_deletion: 69 | print "" 70 | print "game:", game_name 71 | print "marked_for_deletion", [m.name for m in marked_for_deletion] 72 | print "" 73 | print "" 74 | print "delete?" 
75 | raw = raw_input() 76 | if raw == "Y": 77 | for p in marked_for_deletion: 78 | print 'bye', p 79 | p.unlink() 80 | 81 | print "======================" 82 | print 83 | 84 | 85 | def main(): 86 | data_path = Path(os.path.join(os.environ["GGPZERO_PATH"], "data")) 87 | game_paths = [] 88 | for p in data_path.iterdir(): 89 | if p.is_dir() and p.name not in "confs tournament": 90 | game_paths.append(p) 91 | 92 | # check we have models and weights in there 93 | for game_path in game_paths: 94 | go(game_path) 95 | 96 | 97 | if __name__ == "__main__": 98 | main() 99 | -------------------------------------------------------------------------------- /src/ggpzero/scripts/shownn.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from ggpzero.nn.manager import get_manager 4 | 5 | if __name__ == "__main__": 6 | def main(args): 7 | game = args[0] 8 | gen = args[1] 9 | 10 | man = get_manager() 11 | nn = man.load_network(game, gen) 12 | nn.summary() 13 | 14 | from ggpzero.util.main import main_wrap 15 | main_wrap(main) 16 | -------------------------------------------------------------------------------- /src/ggpzero/scripts/supervised_train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from ggpzero.defs import confs, templates 4 | 5 | from ggpzero.nn.manager import get_manager 6 | from ggpzero.nn import train 7 | 8 | 9 | def get_train_config(game, gen_prefix, next_step, starting_step): 10 | config = confs.TrainNNConfig(game) 11 | 12 | config.next_step = next_step 13 | config.starting_step = starting_step 14 | 15 | config.generation_prefix = gen_prefix 16 | config.batch_size = 1024 17 | config.compile_strategy = "SGD" 18 | config.epochs = 10 19 | 20 | config.learning_rate = 0.01 21 | 22 | config.overwrite_existing = False 23 | config.use_previous = False 24 | config.validation_split = 0.95000 25 | config.resample_buckets = [[200, 1.0]] 26 | config.max_epoch_size 
= 1048576 * 2 27 | 28 | return config 29 | 30 | 31 | def get_nn_model(game, transformer, size="small"): 32 | config = templates.nn_model_config_template(game, size, transformer, features=True) 33 | 34 | config.cnn_filter_size = 96 35 | config.residual_layers = 6 36 | config.value_hidden_size = 512 37 | 38 | config.dropout_rate_policy = 0.25 39 | config.dropout_rate_value = 0.5 40 | 41 | # config.concat_all_layers = True 42 | # config.global_pooling_value = False 43 | 44 | config.concat_all_layers = False 45 | config.global_pooling_value = True 46 | 47 | return config 48 | 49 | 50 | def do_training(game, gen_prefix, next_step, starting_step, num_previous_states, 51 | gen_prefix_next, do_data_augmentation=False): 52 | 53 | man = get_manager() 54 | 55 | # create a transformer 56 | generation_descr = templates.default_generation_desc(game, 57 | multiple_policy_heads=True, 58 | num_previous_states=num_previous_states) 59 | transformer = man.get_transformer(game, generation_descr) 60 | 61 | # create train_config 62 | train_config = get_train_config(game, gen_prefix, next_step, starting_step) 63 | trainer = train.TrainManager(train_config, transformer, do_data_augmentation=do_data_augmentation) 64 | trainer.update_config(train_config, next_generation_prefix=gen_prefix_next) 65 | 66 | # get the nn model and set on trainer 67 | nn_model_config = get_nn_model(train_config.game, transformer) 68 | trainer.get_network(nn_model_config, generation_descr) 69 | 70 | trainer.do_epochs() 71 | trainer.save() 72 | 73 | 74 | if __name__ == "__main__": 75 | 76 | def main(args): 77 | gen_prefix_next = sys.argv[1] 78 | 79 | # modify these >>> 80 | game = "hex_lg_19" 81 | gen_prefix = "h2" 82 | 83 | next_step = 220 84 | starting_step = 0 85 | num_previous_states = 1 86 | 87 | do_training(game, gen_prefix, next_step, starting_step, 88 | num_previous_states, gen_prefix_next, do_data_augmentation=True) 89 | 90 | from ggpzero.util.main import main_wrap 91 | main_wrap(main) 92 | 
-------------------------------------------------------------------------------- /src/ggpzero/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/richemslie/galvanise_zero/52164bcd6f43d648736e1ae9e556a7f6412339d1/src/ggpzero/util/__init__.py -------------------------------------------------------------------------------- /src/ggpzero/util/attrutil.py: -------------------------------------------------------------------------------- 1 | ''' XXX move out of this repo ''' 2 | 3 | import sys 4 | import json 5 | 6 | import attr 7 | 8 | # register classes is to add a tiny bit of security, otherwise could end up executing any old code 9 | _registered_clz = set() 10 | 11 | 12 | class SerialiseException(Exception): 13 | pass 14 | 15 | 16 | def register_clz(clz): 17 | reg = clz.__module__, clz.__name__ 18 | _registered_clz.add(reg) 19 | 20 | 21 | def get_clz(mod, name): 22 | # XXX these look terrible monkey patching hacks :( 23 | if mod == 'ggpzero.defs.confs' and name == 'Generation': 24 | mod = 'ggpzero.defs.datadesc' 25 | name = 'GenerationSamples' 26 | if mod == 'ggpzero.defs.confs' and name == 'Sample': 27 | mod = 'ggpzero.defs.datadesc' 28 | if (mod, name) not in _registered_clz: 29 | raise SerialiseException("Attempt to create an unregistered class: %s / %s" % (mod, name)) 30 | return getattr(sys.modules[mod], name) 31 | 32 | 33 | class AttrDict(dict): 34 | def __init__(self, *args, **kwds): 35 | dict.__init__(self, *args, **kwds) 36 | self._enabled = True 37 | 38 | def _add_clz_info(self, name, obj): 39 | clz = obj.__class__ 40 | key = "%s__clz__" % name 41 | value = clz.__module__, clz.__name__ 42 | 43 | if value not in _registered_clz: 44 | raise SerialiseException("Attempt to serialise unregistered class: %s / %s" % value) 45 | 46 | self[key] = value 47 | 48 | def _add_clz_info_list(self, name, obj): 49 | clz = obj.__class__ 50 | key = "%s__clzlist__" % name 51 | value = 
clz.__module__, clz.__name__ 52 | 53 | if value not in _registered_clz: 54 | raise SerialiseException("Attempt to serialise unregistered class: %s / %s" % value) 55 | 56 | self[key] = value 57 | 58 | def __setitem__(self, k, v): 59 | if self._enabled: 60 | if isinstance(v, (list, tuple)): 61 | self._do_list(k, v) 62 | return 63 | 64 | if attr.has(v): 65 | self._add_clz_info(k, v) 66 | 67 | # this recurses via AttrDict 68 | as_dict = attr.asdict(v, recurse=False, dict_factory=AttrDict) 69 | dict.__setitem__(self, k, as_dict) 70 | return 71 | 72 | dict.__setitem__(self, k, v) 73 | 74 | def _do_list(self, k, v): 75 | assert isinstance(v, (list, tuple)) 76 | 77 | # makes a shallow copy 78 | v = v.__class__(v) 79 | 80 | # anything to do? 81 | if not any(attr.has(i) for i in v): 82 | dict.__setitem__(self, k, v) 83 | return 84 | 85 | # check all the same type or not mixed 86 | if sum(issubclass(type(i), type(v[0])) for i in v) != len(v): 87 | raise Exception("Bad list %s" % v) 88 | 89 | self._add_clz_info_list(k, v[0]) 90 | as_list = [attr.asdict(i, recurse=False, dict_factory=AttrDict) for i in v] 91 | dict.__setitem__(self, k, as_list) 92 | 93 | 94 | def asdict_plus(obj): 95 | res = AttrDict() 96 | res['obj'] = obj 97 | return res 98 | 99 | 100 | def _fromdict_plus(d): 101 | # disable or we end up adding back in the ...__clz__ keys 102 | if isinstance(d, AttrDict): 103 | d._enabled = False 104 | 105 | for k in d.keys(): 106 | if "__clz__" in k: 107 | # get clz and remove ...__clz__ key from dict 108 | mod, clz_name = d.pop(k) 109 | clz = get_clz(mod, clz_name) 110 | 111 | # recurse 112 | k = k.replace('__clz__', '') 113 | new_v = _fromdict_plus(d[k]) 114 | 115 | # build object and replace in current dict 116 | d[k] = clz(**new_v) 117 | 118 | if "__clzlist__" in k: 119 | # get clz and remove ...__clz__ key from dict 120 | mod, clz_name = d.pop(k) 121 | clz = get_clz(mod, clz_name) 122 | 123 | # recurse 124 | k = k.replace('__clzlist__', '') 125 | value = d[k] 126 | 127 
| assert isinstance(value, (list, tuple)) 128 | recurse_v = [_fromdict_plus(i) for i in value] 129 | 130 | # build object and replace in current dict 131 | d[k] = [clz(**i) for i in recurse_v] 132 | 133 | return d 134 | 135 | 136 | def fromdict_plus(d): 137 | res = _fromdict_plus(d) 138 | assert 'obj' in res 139 | assert len(res) == 1 140 | return res['obj'] 141 | 142 | 143 | def attr_to_json(obj, **kwds): 144 | assert attr.has(obj) 145 | 146 | if kwds.pop("pretty", False): 147 | kwds.update(sort_keys=True, 148 | separators=(',', ': '), 149 | indent=4) 150 | 151 | return json.dumps(asdict_plus(obj), **kwds) 152 | 153 | 154 | def json_to_attr(buf, **kwds): 155 | d = json.loads(buf, **kwds) 156 | return fromdict_plus(d) 157 | 158 | 159 | def pprint(obj): 160 | assert attr.has(obj) 161 | from pprint import pprint 162 | pprint(attr.asdict(obj)) 163 | 164 | 165 | def pformat(obj): 166 | assert attr.has(obj) 167 | from pprint import pformat 168 | return pformat(attr.asdict(obj)) 169 | 170 | 171 | def register_attrs(clz): 172 | clz = attr.s(clz, slots=True) 173 | register_clz(clz) 174 | return clz 175 | 176 | 177 | def clone(attr_object): 178 | # this is kind of horrible - but at least we are sure it works 179 | return fromdict_plus(asdict_plus(attr_object)) 180 | 181 | 182 | def has(inst, key): 183 | return key in attr.asdict(inst).keys() 184 | 185 | 186 | attribute = attr.ib 187 | attr_factory = attr.Factory 188 | -------------------------------------------------------------------------------- /src/ggpzero/util/broker.py: -------------------------------------------------------------------------------- 1 | ''' XXX move out of this repo ''' 2 | 3 | import zlib 4 | import codecs 5 | import random 6 | import string 7 | import struct 8 | import traceback 9 | 10 | import attr 11 | 12 | from twisted.internet import protocol, reactor 13 | 14 | from ggplib.util import log 15 | from ggpzero.util import attrutil, func 16 | 17 | 18 | @attrutil.register_attrs 19 | class 
Message(object): 20 | name = attr.ib() 21 | payload = attr.ib() 22 | 23 | 24 | def challenge(n): 25 | return "".join(random.choice(string.printable) for i in range(n)) 26 | 27 | 28 | def response(s): 29 | ' ok this arbitrary - just a lot of port scanning on my server, and this is a detterent ' 30 | buf = [] 31 | res = codecs.encode(s, 'rot_13') 32 | swap_me = True 33 | for c0, c1 in func.chunks(res, 2): 34 | x = (ord(c0) + ord(c1)) % 100 35 | while True: 36 | if chr(x) in string.printable: 37 | break 38 | x -= 1 39 | if x < 20: 40 | if swap_me: 41 | x = ord(c0) 42 | swap_me = False 43 | else: 44 | x = ord(c1) 45 | swap_me = True 46 | 47 | buf.append(chr(x)) 48 | res = "".join(buf) 49 | return res + res[::-1] 50 | 51 | 52 | def clz_to_name(clz): 53 | return "%s.%s" % (clz.__module__, clz.__name__) 54 | 55 | 56 | class Broker(object): 57 | def __init__(self): 58 | self.handlers = {} 59 | 60 | def register(self, attr_clz, cb): 61 | ' registers a attr class to a callback ' 62 | assert attr.has(attr_clz) 63 | self.handlers[clz_to_name(attr_clz)] = cb 64 | 65 | def onMessage(self, caller, msg): 66 | if msg.name not in self.handlers: 67 | log.error("%s : unknown msg %s" % (caller, str(msg.name))) 68 | caller.disconnect() 69 | return 70 | 71 | try: 72 | cb = self.handlers[msg.name] 73 | res = cb(caller, msg.payload) 74 | 75 | # doesn't necessarily need to have a response 76 | if res is not None: 77 | caller.send_msg(res) 78 | 79 | except Exception as e: 80 | log.error("%s : exception calling method %s. " % (caller, str(msg.name))) 81 | log.error("%s" % e) 82 | log.error(traceback.format_exc()) 83 | 84 | # do this last as might raise also... 
85 | caller.disconnect() 86 | 87 | def start(self): 88 | reactor.run() 89 | 90 | 91 | class Client(protocol.Protocol): 92 | CHALLENGE_SIZE = 512 93 | 94 | def __init__(self, broker): 95 | self.broker = broker 96 | self.logical_connection = False 97 | 98 | # this buffer is only used until logical_connection made 99 | self.start_buf = "" 100 | 101 | self.rxd = [] 102 | self.header = struct.Struct("=i") 103 | 104 | def disconnect(self): 105 | self.transport.loseConnection() 106 | 107 | def connectionMade(self): 108 | self.logical_connection = False 109 | log.debug("Client::connectionMade()") 110 | 111 | def connectionLost(self, reason=""): 112 | self.logical_connection = False 113 | log.debug("Client::connectionLost() : %s" % reason) 114 | 115 | def unbuffer_data(self): 116 | # flatten 117 | buf = ''.join(self.rxd) 118 | 119 | while True: 120 | buf_len = len(buf) 121 | if buf_len < self.header.size: 122 | break 123 | 124 | payload_len = self.header.unpack_from(buf[:self.header.size])[0] 125 | if buf_len < payload_len + self.header.size: 126 | break 127 | 128 | # good, we have a message 129 | offset = self.header.size 130 | compressed_data = buf[offset:offset + payload_len] 131 | offset += payload_len 132 | 133 | data = zlib.decompress(compressed_data) 134 | msg = attrutil.json_to_attr(data) 135 | yield msg 136 | 137 | buf = buf[offset:] 138 | 139 | # compact 140 | self.rxd = [] 141 | if len(buf): 142 | self.rxd.append(buf) 143 | 144 | def init_data_rxd(self, data): 145 | raise NotImplementedError 146 | 147 | def dataReceived(self, data): 148 | if self.logical_connection: 149 | self.rxd.append(data) 150 | for msg in self.unbuffer_data(): 151 | self.broker.onMessage(self, msg) 152 | else: 153 | self.init_data_rxd(data) 154 | 155 | def format_msg(self, payload): 156 | assert attr.has(payload) 157 | name = clz_to_name(payload.__class__) 158 | 159 | msg = Message(name, payload) 160 | 161 | data = attrutil.attr_to_json(msg) 162 | compressed_data = zlib.compress(data) 163 | 
164 | preamble = self.header.pack(len(compressed_data)) 165 | assert len(preamble) == self.header.size 166 | return preamble + compressed_data 167 | 168 | def send_msg(self, payload): 169 | self.transport.write(self.format_msg(payload)) 170 | 171 | 172 | class BrokerClient(Client): 173 | def init_data_rxd(self, data): 174 | self.start_buf += data 175 | if len(self.start_buf) == self.CHALLENGE_SIZE: 176 | self.transport.write(response(self.start_buf)) 177 | self.logical_connection = True 178 | log.info("Logical connection established") 179 | 180 | 181 | class BrokerClientFactory(protocol.ReconnectingClientFactory): 182 | ' client side factory, connects to server ' 183 | 184 | # maximum number of seconds between connection attempts 185 | maxDelay = 30 186 | 187 | # delay for the first reconnection attempt 188 | initialDelay = 2 189 | 190 | # a multiplicitive factor by which the delay grows 191 | factor = 1.5 192 | 193 | def __init__(self, broker): 194 | self.broker = broker 195 | 196 | def buildProtocol(self, addr): 197 | log.debug("Connection made to: %s" % addr) 198 | return BrokerClient(self.broker) 199 | 200 | 201 | class ServerClient(Client): 202 | 203 | def init_data_rxd(self, data): 204 | self.start_buf += data 205 | if len(self.start_buf) == self.CHALLENGE_SIZE: 206 | if self.expected_response == self.start_buf: 207 | self.logical_connection = True 208 | log.info("Logical connection made") 209 | self.broker.new_broker_client(self) 210 | else: 211 | self.logical_connection = True 212 | log.error("Logical connection failed") 213 | self.disconnect() 214 | 215 | def connectionMade(self): 216 | Client.connectionMade(self) 217 | msg = challenge(self.CHALLENGE_SIZE) 218 | self.transport.write(msg) 219 | self.expected_response = response(msg) 220 | 221 | def connectionLost(self, reason=""): 222 | if self.logical_connection: 223 | self.broker.remove_broker_client(self) 224 | Client.connectionLost(self, reason) 225 | 226 | 227 | class ServerFactory(protocol.Factory): 
228 | def __init__(self, broker): 229 | self.broker = broker 230 | 231 | def buildProtocol(self, addr): 232 | log.debug("Connection made from: %s" % addr) 233 | return ServerClient(self.broker) 234 | -------------------------------------------------------------------------------- /src/ggpzero/util/cppinterface.py: -------------------------------------------------------------------------------- 1 | from builtins import super 2 | 3 | import time 4 | 5 | import attr 6 | import numpy as np 7 | 8 | import ggpzero_interface 9 | from ggpzero.defs import confs, datadesc 10 | 11 | 12 | def sm_to_ptr(sm): 13 | from ggplib.interface import ffi 14 | cffi_ptr = sm.c_statemachine 15 | ptr_as_long = int(ffi.cast("intptr_t", cffi_ptr)) 16 | return ptr_as_long 17 | 18 | 19 | def joint_move_to_ptr(joint_move): 20 | from ggplib.interface import ffi 21 | cffi_ptr = joint_move.c_joint_move 22 | ptr_as_long = int(ffi.cast("intptr_t", cffi_ptr)) 23 | return ptr_as_long 24 | 25 | 26 | def basestate_to_ptr(basestate): 27 | from ggplib.interface import ffi 28 | cffi_ptr = basestate.c_base_state 29 | ptr_as_long = int(ffi.cast("intptr_t", cffi_ptr)) 30 | return ptr_as_long 31 | 32 | 33 | def create_c_transformer(transformer): 34 | TransformerClz = ggpzero_interface.GdlBasesTransformer 35 | c_transformer = TransformerClz(transformer.channel_size, 36 | transformer.raw_channels_per_state, 37 | transformer.num_of_controls_channels, 38 | transformer.num_previous_states, 39 | transformer.num_rewards, 40 | transformer.policy_dist_count) 41 | 42 | # build it up 43 | for b in transformer.board_space: 44 | index = transformer.channel_size * b.channel_id + b.y_idx * transformer.num_rows + b.x_idx 45 | c_transformer.add_board_base(b.base_indx, index) 46 | 47 | for c in transformer.control_space: 48 | c_transformer.add_control_base(c.base_indx, c.channel_id, c.value) 49 | 50 | return c_transformer 51 | 52 | 53 | class PollerBase(object): 54 | POLL_AGAIN = "poll_again" 55 | 56 | def __init__(self, sm, nn, 
batch_size=1024, sleep_between_poll=-1): 57 | self.sm = sm 58 | self.nn = nn 59 | 60 | # maximum number of batches we can do 61 | self.batch_size = batch_size 62 | 63 | # if defined, will sleep between polls 64 | self.sleep_between_poll = sleep_between_poll 65 | 66 | self.poll_last = None 67 | self.reset_stats() 68 | 69 | def _get_poller(self): 70 | raise NotImplemented 71 | 72 | def reset_stats(self): 73 | self.num_predictions_calls = 0 74 | self.total_predictions = 0 75 | self.acc_time_polling = 0 76 | self.acc_time_prediction = 0 77 | 78 | def poll(self, do_stats=False): 79 | ''' POLL_AGAIN is returned, to indicate we need to call poll() again. ''' 80 | 81 | transformer = self.nn.gdl_bases_transformer 82 | expect_num_arrays = len(transformer.policy_dist_count) + 1 83 | 84 | if self.poll_last is None: 85 | dummy = np.zeros(0) 86 | arrays = [dummy for _ in range(expect_num_arrays)] 87 | else: 88 | arrays = list(self.poll_last) 89 | 90 | assert len(arrays) == expect_num_arrays 91 | 92 | if do_stats: 93 | s0 = time.time() 94 | 95 | pred_array = self._get_poller().poll(len(arrays[0]), arrays) 96 | if pred_array is None: 97 | self.poll_last = None 98 | return 99 | 100 | if do_stats: 101 | s1 = time.time() 102 | 103 | t = self.nn.gdl_bases_transformer 104 | num_predictions = len(pred_array) / (t.num_channels * t.channel_size) 105 | assert num_predictions <= self.batch_size 106 | 107 | # make sure array is correct shape for keras/tensorflow (no memory is allocated) 108 | pred_array = pred_array.reshape(num_predictions, t.num_channels, t.num_cols, t.num_rows) 109 | self.poll_last = self.nn.get_model().predict_on_batch(pred_array) 110 | 111 | if do_stats: 112 | s2 = time.time() 113 | 114 | self.num_predictions_calls += 1 115 | self.total_predictions += num_predictions 116 | self.acc_time_polling += s1 - s0 117 | self.acc_time_prediction += s2 - s1 118 | 119 | return self.POLL_AGAIN 120 | 121 | def poll_loop(self, cb=None, do_stats=False): 122 | ''' will poll until we are 
done ''' 123 | 124 | # calls back every n times 125 | cb_every_n = 100 126 | count = 1 127 | while self.poll(do_stats=do_stats) == self.POLL_AGAIN: 128 | if count % cb_every_n == 0: 129 | if cb is not None: 130 | if cb(): 131 | break 132 | count += 1 133 | if self.sleep_between_poll > 0: 134 | time.sleep(self.sleep_between_poll) 135 | 136 | def update_nn(self, nn): 137 | self.nn = nn 138 | 139 | def dump_stats(self): 140 | print "num of prediction calls", self.num_predictions_calls 141 | print "predictions", self.total_predictions 142 | print "acc_time_polling", self.acc_time_polling 143 | print "acc_time_prediction", self.acc_time_prediction 144 | 145 | 146 | class PlayPoller(PollerBase): 147 | def __init__(self, sm, nn, conf): 148 | assert isinstance(conf, confs.PUCTEvaluatorConfig) 149 | super().__init__(sm, nn, batch_size=conf.batch_size) 150 | transformer = nn.gdl_bases_transformer 151 | self.c_transformer = create_c_transformer(transformer) 152 | self.c_player = ggpzero_interface.Player(sm_to_ptr(sm), 153 | self.c_transformer, 154 | attr.asdict(conf)) 155 | 156 | for name in "reset apply_move move get_move update_config balance_moves tree_debug".split(): 157 | name = "player_" + name 158 | setattr(self, name, getattr(self.c_player, name)) 159 | 160 | def _get_poller(self): 161 | return self.c_player 162 | 163 | 164 | class Supervisor(PollerBase): 165 | def __init__(self, sm, nn, batch_size=1024, 166 | sleep_between_poll=-1, workers=None, 167 | identifier=""): 168 | 169 | transformer = nn.gdl_bases_transformer 170 | self.c_transformer = create_c_transformer(transformer) 171 | 172 | self.c_supervisor = ggpzero_interface.Supervisor(sm_to_ptr(sm), 173 | self.c_transformer, 174 | batch_size, 175 | identifier) 176 | if workers: 177 | self.c_supervisor.set_num_workers(workers) 178 | 179 | self.bs_for_unique_states = sm.new_base_state() 180 | 181 | super().__init__(sm, nn, batch_size=batch_size, sleep_between_poll=sleep_between_poll) 182 | 183 | def 
_get_poller(self): 184 | return self.c_supervisor 185 | 186 | def start_self_play(self, conf, num_workers): 187 | assert isinstance(conf, confs.SelfPlayConfig) 188 | return self.c_supervisor.start_self_play(num_workers, attr.asdict(conf)) 189 | 190 | def fetch_samples(self): 191 | res = self.c_supervisor.fetch_samples() 192 | if res: 193 | return [datadesc.Sample(**d) for d in res] 194 | else: 195 | return [] 196 | 197 | def add_unique_state(self, s): 198 | return # ZZZ remove this 199 | assert isinstance(s, str) 200 | assert len(s) == self.bs_for_unique_states.num_bytes 201 | self.bs_for_unique_states.from_string(s) 202 | self.c_supervisor.add_unique_state(basestate_to_ptr(self.bs_for_unique_states)) 203 | 204 | def clear_unique_states(self): 205 | self.c_supervisor.clear_unique_states() 206 | -------------------------------------------------------------------------------- /src/ggpzero/util/func.py: -------------------------------------------------------------------------------- 1 | ''' XXX move out of this repo ''' 2 | 3 | import os 4 | import json 5 | 6 | 7 | def chunks(l, n): 8 | for i in range(0, len(l), n): 9 | yield l[i:i + n] 10 | 11 | 12 | def get_from_json(path, includes=None, excludes=None): 13 | includes = includes or [] 14 | excludes = excludes or [] 15 | files = os.listdir(path) 16 | for the_file in files: 17 | if not the_file.endswith(".json"): 18 | continue 19 | 20 | if len([ii for ii in includes if ii not in the_file]): 21 | continue 22 | 23 | if len([ii for ii in excludes if ii in the_file]): 24 | continue 25 | 26 | for ii in excludes: 27 | if ii in the_file: 28 | continue 29 | 30 | filename = os.path.join(path, the_file) 31 | buf = open(filename).read() 32 | yield json.loads(buf), filename 33 | -------------------------------------------------------------------------------- /src/ggpzero/util/keras.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from ggplib.util 
import log 3 | 4 | 5 | from keras.optimizers import SGD, Adam 6 | from keras.utils.generic_utils import Progbar 7 | import keras.callbacks as keras_callbacks 8 | from keras import metrics as keras_metrics 9 | import keras.backend as K 10 | 11 | from keras import models as keras_models 12 | from keras import layers as keras_layers 13 | from keras import regularizers as keras_regularizers 14 | 15 | 16 | def _bla(): 17 | ' i am here to confuse flake8 ' 18 | print SGD, Adam, Progbar, keras_callbacks, keras_metrics 19 | print keras_models, keras_layers, keras_regularizers 20 | 21 | 22 | def is_channels_first(): 23 | ' NCHW is cuDNN default, and what tf wants for GPU. ' 24 | return K.image_data_format() == "channels_first" 25 | 26 | 27 | def antirectifier(inputs): 28 | inputs -= K.mean(inputs, axis=1, keepdims=True) 29 | inputs = K.l2_normalize(inputs, axis=1) 30 | pos = K.relu(inputs) 31 | neg = K.relu(-inputs) 32 | return K.concatenate([pos, neg], axis=1) 33 | 34 | 35 | def antirectifier_output_shape(input_shape): 36 | shape = list(input_shape) 37 | assert len(shape) == 2 # only valid for 2D tensors 38 | shape[-1] *= 2 39 | return tuple(shape) 40 | 41 | 42 | def get_antirectifier(name): 43 | # output_shape=antirectifier_output_shape 44 | return keras_layers.Lambda(antirectifier, name=name) 45 | 46 | 47 | def constrain_resources_tf(): 48 | ' constrain resource as tensorflow likes to assimilate your machine rendering it useless ' 49 | 50 | import tensorflow as tf 51 | from tensorflow.python.client import device_lib 52 | 53 | local_device_protos = device_lib.list_local_devices() 54 | gpu_available = [x.name for x in local_device_protos if x.device_type == 'GPU'] 55 | 56 | if not gpu_available: 57 | # this doesn't strictly use just one cpu... 
but seems it is the best one can do 58 | config = tf.ConfigProto(device_count=dict(CPU=1), 59 | allow_soft_placement=False, 60 | log_device_placement=False, 61 | intra_op_parallelism_threads=1, 62 | inter_op_parallelism_threads=1) 63 | else: 64 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.25, 65 | allow_growth=True) 66 | 67 | config = tf.ConfigProto(gpu_options=gpu_options) 68 | 69 | sess = tf.Session(config=config) 70 | 71 | K.set_session(sess) 72 | 73 | 74 | def init(data_format='channels_first'): 75 | assert K.backend() == "tensorflow" 76 | 77 | if K.image_data_format() != data_format: 78 | was = K.image_data_format() 79 | K.set_image_data_format(data_format) 80 | log.warning("Changing image_data_format: %s -> %s" % (was, K.image_data_format())) 81 | 82 | constrain_resources_tf() 83 | -------------------------------------------------------------------------------- /src/ggpzero/util/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pdb 3 | import sys 4 | import traceback 5 | 6 | from ggplib.util.init import setup_once 7 | from ggplib.util import log 8 | from ggpzero.util.keras import init 9 | 10 | 11 | def main_wrap(main_fn, logfile_name=None, **kwds): 12 | if logfile_name is None: 13 | # if logfile_name not set, derive it from main_fn 14 | fn = main_fn.func_code.co_filename 15 | logfile_name = os.path.splitext(os.path.basename(fn))[0] 16 | 17 | setup_once(logfile_name) 18 | 19 | try: 20 | # we might be running under python with no keras/numpy support 21 | init(**kwds) 22 | 23 | except ImportError as exc: 24 | log.warning("ImportError: %s" % exc) 25 | 26 | try: 27 | if main_fn.func_code.co_argcount == 0: 28 | return main_fn() 29 | else: 30 | return main_fn(sys.argv[1:]) 31 | 32 | except Exception as exc: 33 | print exc 34 | _, _, tb = sys.exc_info() 35 | traceback.print_exc() 36 | pdb.post_mortem(tb) 37 | 
-------------------------------------------------------------------------------- /src/ggpzero/util/runprocs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import shlex 4 | from signal import SIGKILL, SIGTERM 5 | from subprocess import PIPE, Popen 6 | 7 | from twisted.internet import reactor 8 | 9 | from ggplib.util import log 10 | 11 | 12 | class RunCmds(object): 13 | def __init__(self, cmds, cb_on_completion=None, max_time=2.0): 14 | assert len(cmds) == len(set(cmds)), "cmds not unique: %s" % cmds 15 | self.cmds = cmds 16 | self.cb_on_completion = cb_on_completion 17 | self.max_time = max_time 18 | 19 | self.timeout_time = None 20 | self.killing = set() 21 | self.terminating = set() 22 | 23 | def spawn(self): 24 | self.procs = [(cmd, Popen(shlex.split(cmd), 25 | shell=False, stdout=PIPE, stderr=PIPE)) for cmd in self.cmds] 26 | self.timeout_time = time.time() + self.max_time 27 | reactor.callLater(0.1, self.check_running_processes) 28 | 29 | def check_running_processes(self): 30 | procs, self.procs = self.procs, [] 31 | for cmd, proc in procs: 32 | retcode = proc.poll() 33 | if retcode is not None: 34 | log.debug("cmd '%s' exited with return code: %s" % (cmd, retcode)) 35 | stdout, stderr = proc.stdout.read().strip(), proc.stderr.read().strip() 36 | if stdout: 37 | log.verbose("stdout:%s" % stdout) 38 | if stderr: 39 | log.warning("stderr:%s" % stderr) 40 | continue 41 | 42 | self.procs.append((cmd, proc)) 43 | 44 | if time.time() > self.timeout_time: 45 | for cmd, proc in self.procs: 46 | if cmd not in self.killing: 47 | self.killing.add(cmd) 48 | log.warning("cmd '%s' taking too long, terminating" % cmd) 49 | os.kill(proc.pid, SIGTERM) 50 | 51 | if time.time() > self.timeout_time + 1: 52 | for cmd, proc in self.procs: 53 | if cmd not in self.terminating: 54 | self.terminating.add(cmd) 55 | log.warning("cmd '%s' didn't terminate gracefully, killing" % cmd) 56 | os.kill(proc.pid, SIGKILL) 57 | 
58 | if self.procs: 59 | reactor.callLater(0.1, self.check_running_processes) 60 | else: 61 | self.cb_on_completion() 62 | -------------------------------------------------------------------------------- /src/ggpzero/util/state.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import numpy as np 3 | 4 | import ggpzero_interface 5 | 6 | 7 | def encode_state(s): 8 | assert isinstance(s, (list, tuple)) 9 | a = np.array(s) 10 | aa = np.packbits(a) 11 | s = aa.tostring() 12 | return base64.encodestring(s) 13 | 14 | 15 | def decode_state(s): 16 | if isinstance(s, tuple): 17 | return s 18 | elif isinstance(s, list): 19 | return tuple(s) 20 | 21 | s = base64.decodestring(s) 22 | aa = np.fromstring(s, dtype=np.uint8) 23 | 24 | # if the state is not a multiple of 8, will grow by that 25 | # XXX horrible. We really should have these functions as methods to do encode/decode on some 26 | # smart Basestate object... ) 27 | a = np.unpackbits(aa) 28 | return tuple(a) 29 | 30 | 31 | def fast_decode_state(s): 32 | if isinstance(s, tuple): 33 | return s 34 | elif isinstance(s, list): 35 | return tuple(s) 36 | 37 | return ggpzero_interface.buf_to_tuple_reverse_bytes(base64.decodestring(s)) 38 | -------------------------------------------------------------------------------- /src/test/cpp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/richemslie/galvanise_zero/52164bcd6f43d648736e1ae9e556a7f6412339d1/src/test/cpp/__init__.py -------------------------------------------------------------------------------- /src/test/cpp/test_interface.py: -------------------------------------------------------------------------------- 1 | ''' note these tests really need a GPU. XXX add a skip, or CPU versions of test. 
''' 2 | 3 | import os 4 | import time 5 | 6 | import numpy as np 7 | import py.test 8 | 9 | import tensorflow as tf 10 | 11 | from ggplib.db import lookup 12 | 13 | from ggpzero.nn.manager import get_manager 14 | from ggpzero.util import cppinterface 15 | from ggpzero.defs import confs, templates 16 | 17 | 18 | def float_formatter0(x): 19 | return "%.0f" % x 20 | 21 | 22 | def float_formatter1(x): 23 | return "%.2f" % x 24 | 25 | 26 | def setup(): 27 | from ggplib.util.init import setup_once 28 | setup_once() 29 | 30 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 31 | tf.logging.set_verbosity(tf.logging.ERROR) 32 | 33 | np.set_printoptions(threshold=100000) 34 | 35 | 36 | def do_transformer(num_previous_states): 37 | game = "breakthrough" 38 | 39 | game_info = lookup.by_name(game) 40 | sm = game_info.get_sm() 41 | 42 | man = get_manager() 43 | 44 | # only multiple_policy_heads supported in c++ 45 | generation_descr = templates.default_generation_desc(game, 46 | multiple_policy_heads=True, 47 | num_previous_states=num_previous_states) 48 | 49 | t = man.get_transformer(game, generation_descr) 50 | 51 | # create transformer wrapper object 52 | c_transformer = cppinterface.create_c_transformer(t) 53 | 54 | nn = man.create_new_network(game, "small", generation_descr) 55 | verbose = True 56 | 57 | total_predictions = 0 58 | total_s0 = 0 59 | total_s1 = 0 60 | total_s2 = 0 61 | for ii in range(10): 62 | print ii 63 | start = time.time() 64 | array = c_transformer.test(cppinterface.sm_to_ptr(sm)) 65 | total_s0 += time.time() - start 66 | 67 | sz = len(array) / (t.num_channels * t.channel_size) 68 | 69 | total_predictions += sz 70 | 71 | array = array.reshape(sz, t.num_channels, t.num_cols, t.num_rows) 72 | total_s1 += time.time() - start 73 | 74 | if verbose: 75 | np.set_printoptions(threshold=np.inf, formatter={'float_kind' : float_formatter0}) 76 | print array 77 | 78 | # test we can actually predict 79 | res = nn.get_model().predict(array, batch_size=sz) 80 | # print 
res[0].shape 81 | # print res[1].shape 82 | 83 | total_s2 += time.time() - start 84 | 85 | if verbose: 86 | np.set_printoptions(threshold=np.inf, formatter={'float_kind' : float_formatter1}) 87 | print res 88 | 89 | print total_predictions, "time taken", [s * 1000 for s in (total_s0, total_s1, total_s2)] 90 | 91 | 92 | def test_transformer(): 93 | do_transformer(0) 94 | 95 | 96 | def test_transformer_with_prev_states(): 97 | do_transformer(3) 98 | 99 | 100 | def test_inline_supervisor_creation(): 101 | games = "breakthrough reversi breakthroughSmall connectFour".split() 102 | 103 | man = get_manager() 104 | 105 | for game in games: 106 | game_info = lookup.by_name(game) 107 | 108 | # get statemachine 109 | sm = game_info.get_sm() 110 | 111 | # only multiple_policy_heads supported in c++ 112 | generation_descr = templates.default_generation_desc(game, 113 | multiple_policy_heads=True) 114 | 115 | nn = man.create_new_network(game, "small", generation_descr) 116 | 117 | for batch_size in (1, 128, 1024): 118 | supervisor = cppinterface.Supervisor(sm, nn, batch_size=batch_size) 119 | supervisor = supervisor 120 | continue 121 | 122 | 123 | def setup_c4(batch_size=1024): 124 | game = "connectFour" 125 | man = get_manager() 126 | 127 | # only multiple_policy_heads supported in c++ 128 | generation_descr = templates.default_generation_desc(game, 129 | multiple_policy_heads=True) 130 | 131 | nn = man.create_new_network(game, "small", generation_descr) 132 | 133 | game_info = lookup.by_name(game) 134 | 135 | supervisor = cppinterface.Supervisor(game_info.get_sm(), nn, batch_size=batch_size) 136 | 137 | conf = templates.selfplay_config_template() 138 | return supervisor, conf 139 | 140 | 141 | def get_ctx(): 142 | class X: 143 | pass 144 | return X() 145 | 146 | 147 | def do_test(batch_size, get_sample_count, num_workers=0): 148 | supervisor, conf = setup_c4(batch_size=batch_size) 149 | supervisor.start_self_play(conf, num_workers) 150 | 151 | nonlocal = get_ctx 152 | 
nonlocal.samples = [] 153 | 154 | def cb(): 155 | new_samples = supervisor.fetch_samples() 156 | if new_samples: 157 | nonlocal.samples += new_samples 158 | print "Total rxd", len(nonlocal.samples) 159 | 160 | if len(nonlocal.samples) > get_sample_count: 161 | return True 162 | 163 | supervisor.poll_loop(do_stats=True, cb=cb) 164 | supervisor.dump_stats() 165 | 166 | # should be resumable 167 | nonlocal.samples = [] 168 | supervisor.reset_stats() 169 | supervisor.poll_loop(do_stats=True, cb=cb) 170 | supervisor.dump_stats() 171 | 172 | 173 | def test_inline_one(): 174 | do_test(batch_size=1, get_sample_count=42, num_workers=0) 175 | 176 | 177 | def test_inline_batched(): 178 | do_test(batch_size=1024, get_sample_count=5000, num_workers=1) 179 | 180 | 181 | def test_workers_batched(): 182 | do_test(batch_size=1024, get_sample_count=5000, num_workers=2) 183 | 184 | 185 | def test_inline_unique_states(): 186 | py.test.skip("too slow without GPU") 187 | get_sample_count = 250 188 | supervisor, conf = setup_c4(batch_size=1) 189 | supervisor.start_self_play(conf, num_workers=0) 190 | 191 | nonlocal = get_ctx 192 | nonlocal.samples = [] 193 | nonlocal.added_unique_states = False 194 | 195 | def cb(): 196 | new_samples = supervisor.fetch_samples() 197 | if new_samples: 198 | nonlocal.samples += new_samples 199 | print "Total rxd", len(nonlocal.samples) 200 | 201 | if len(nonlocal.samples) > get_sample_count: 202 | return True 203 | 204 | if not nonlocal.added_unique_states: 205 | if len(nonlocal.samples) > 100: 206 | nonlocal.added_unique_states = True 207 | for s in nonlocal.samples: 208 | supervisor.add_unique_state(s.state) 209 | 210 | elif len(nonlocal.samples) > 20: 211 | print 'clearing unique states' 212 | supervisor.clear_unique_states() 213 | 214 | supervisor.poll_loop(do_stats=True, cb=cb) 215 | supervisor.dump_stats() 216 | 217 | supervisor.clear_unique_states() 218 | nonlocal.samples = [] 219 | nonlocal.added_unique_states = False 220 | 221 | 
supervisor.poll_loop(do_stats=True, cb=cb) 222 | supervisor.dump_stats() 223 | -------------------------------------------------------------------------------- /src/test/nn/test_datacache.py: -------------------------------------------------------------------------------- 1 | 2 | # ggplib imports 3 | from ggplib.util import log 4 | 5 | # ggpzero imports 6 | from ggplib.db import lookup 7 | from ggpzero.nn import datacache 8 | from ggpzero.defs import templates 9 | 10 | from ggpzero.nn.manager import get_manager 11 | 12 | 13 | def setup_and_get_cache(game, prev_states, gen): 14 | # cp some files to a test area 15 | 16 | lookup.get_database() 17 | 18 | generation_descr = templates.default_generation_desc(game, 19 | multiple_policy_heads=True, 20 | num_previous_states=prev_states) 21 | man = get_manager() 22 | transformer = man.get_transformer(game, generation_descr) 23 | return datacache.DataCache(transformer, gen) 24 | 25 | 26 | def test_summary(): 27 | from ggplib.util.init import setup_once 28 | setup_once() 29 | 30 | game = "amazons_10x10" 31 | # game = "hexLG13" 32 | 33 | cache = setup_and_get_cache(game, 1, "h3") 34 | 35 | for x in cache.list_files(): 36 | print x 37 | 38 | cache.sync() 39 | 40 | # good to see some outputs 41 | for index in (10, 420, 42): 42 | channels = cache.db[index]["channels"] 43 | log.info('train input, shape: %s. Example: %s' % (channels.shape, channels)) 44 | 45 | for name in cache.db.names[1:]: 46 | log.info("Outputs: %s" % name) 47 | output = cache.db[index]["channels"] 48 | log.info('train output, shape: %s. 
Example: %s' % (output.shape, output)) 49 | 50 | 51 | def test_chunking(): 52 | game = "breakthroughSmall" 53 | cache = setup_and_get_cache(game, 1, "t1") 54 | cache.sync() 55 | 56 | buckets_def = [(1, 1.0), (3, 0.75), (6, 0.5), (-1, 0.1)] 57 | buckets = datacache.Buckets(buckets_def) 58 | 59 | # max_training_count=None, max_validation_count=None 60 | indexer = cache.create_chunk_indexer(buckets) 61 | 62 | # gen0 63 | first = len(cache.summary.step_summaries) - 1 64 | # gen most recent 65 | last = 0 66 | 67 | print 'here!' 68 | print indexer.create_indices_for_level(first, validation=False, max_size=42) 69 | print indexer.create_indices_for_level(first, validation=True, max_size=42) 70 | 71 | z = indexer.get_indices(max_size=100000) 72 | #z.sort() 73 | #print z 74 | 75 | 76 | def test_include_size(): 77 | game = "breakthroughSmall" 78 | cache = setup_and_get_cache(game, 1, "t1") 79 | cache.sync() 80 | 81 | buckets_def = [(1, 1.00), (3, 0.75), (6, 0.5), (-1, 0.1)] 82 | buckets = datacache.Buckets(buckets_def) 83 | 84 | # max_training_count=None, max_validation_count=None 85 | indexer = cache.create_chunk_indexer(buckets) 86 | 87 | z = indexer.get_indices(max_size=40000) 88 | z = indexer.get_indices(max_size=40000, include_all=2) 89 | #z.sort() 90 | #print z 91 | -------------------------------------------------------------------------------- /src/test/nn/test_model_draws.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | from ggplib.util.init import setup_once 8 | from ggplib.db import lookup 9 | from ggpzero.util import keras 10 | 11 | from ggpzero.nn.manager import get_manager 12 | 13 | 14 | def setup(): 15 | # set up ggplib 16 | setup_once() 17 | 18 | # ensure we have database with ggplib 19 | from gzero_games.ggphack import addgame 20 | lookup.get_database() 21 | 22 | # initialise keras/tf 23 | keras.init() 24 | 25 | # just ensures we 
have the manager ready 26 | get_manager() 27 | 28 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 29 | tf.logging.set_verbosity(tf.logging.ERROR) 30 | 31 | np.set_printoptions(threshold=100000) 32 | 33 | 34 | def advance_state(sm, basestate): 35 | sm.update_bases(basestate) 36 | 37 | # leaks, but who cares, it is a test 38 | joint_move = sm.get_joint_move() 39 | base_state = sm.new_base_state() 40 | 41 | for role_index in range(len(sm.get_roles())): 42 | ls = sm.get_legal_state(role_index) 43 | choice = ls.get_legal(random.randrange(0, ls.get_count())) 44 | joint_move.set(role_index, choice) 45 | 46 | # play move, the base_state will be new state 47 | sm.next_state(joint_move, base_state) 48 | return base_state 49 | 50 | 51 | def test_baduk(): 52 | def show(pred): 53 | win_0, win_1, draw = list(pred.scores) 54 | print "wins/draw", win_0, win_1, draw 55 | win_0 += draw / 2 56 | win_1 += draw / 2 57 | print "wins only", win_0, win_1, win_0 + win_1 58 | 59 | man = get_manager() 60 | 61 | game = "baduk_9x9" 62 | 63 | # create a nn 64 | game_info = lookup.by_name(game) 65 | sm = game_info.get_sm() 66 | 67 | nn = man.load_network(game, "h3_11") 68 | nn.summary() 69 | 70 | basestate = sm.get_initial_state() 71 | 72 | predictions = nn.predict_1(basestate.to_list()) 73 | print predictions.policies, predictions.scores 74 | 75 | predictions = nn.predict_n([basestate.to_list(), basestate.to_list()]) 76 | assert len(predictions) == 2 and len(predictions[0].policies) == 2 and len(predictions[0].scores) == 3 77 | show(predictions[0]) 78 | 79 | prevs = [] 80 | for i in range(4): 81 | prevs.append(basestate) 82 | basestate = advance_state(game_info.get_sm(), basestate) 83 | 84 | prediction = nn.predict_1(basestate.to_list(), [p.to_list() for p in prevs]) 85 | show(prediction) 86 | -------------------------------------------------------------------------------- /src/test/nn/test_model_external.py: -------------------------------------------------------------------------------- 1 | ''' 2 
| testing with games not defined in GDL 3 | ''' 4 | 5 | import os 6 | import random 7 | 8 | import numpy as np 9 | 10 | import tensorflow as tf 11 | 12 | from ggplib.util.init import setup_once 13 | from ggplib.db import lookup 14 | 15 | from ggpzero.util import keras 16 | from ggpzero.defs import templates 17 | 18 | from ggpzero.nn.manager import get_manager 19 | 20 | def setup(): 21 | # set up ggplib 22 | setup_once() 23 | 24 | # ensure we have database with ggplib 25 | lookup.get_database() 26 | 27 | # initialise keras/tf 28 | keras.init() 29 | 30 | # just ensures we have the manager ready 31 | get_manager() 32 | 33 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 34 | tf.logging.set_verbosity(tf.logging.ERROR) 35 | 36 | np.set_printoptions(threshold=100000) 37 | 38 | 39 | #games = ["draughts_10x10", "draughts_killer_10x10"] 40 | games = ["hex_lg_19"] 41 | 42 | 43 | def advance_state(sm, basestate, do_swap=False): 44 | sm.update_bases(basestate) 45 | 46 | # leaks, but who cares, it is a test 47 | joint_move = sm.get_joint_move() 48 | base_state = sm.new_base_state() 49 | 50 | for role_index in range(len(sm.get_roles())): 51 | ls = sm.get_legal_state(role_index) 52 | choice = ls.get_legal(random.randrange(0, ls.get_count())) 53 | 54 | if do_swap and role_index == 1: 55 | joint_move.set(role_index, 1) 56 | else: 57 | joint_move.set(role_index, choice) 58 | 59 | # play move, the base_state will be new state 60 | sm.next_state(joint_move, base_state) 61 | return base_state 62 | 63 | 64 | def test_basic_config(): 65 | man = get_manager() 66 | 67 | for game in games: 68 | # look game from database 69 | game_info = lookup.by_name(game) 70 | assert game == game_info.game 71 | 72 | sm = game_info.get_sm() 73 | basestate = sm.get_initial_state() 74 | 75 | # lookup game in manager 76 | transformer = man.get_transformer(game) 77 | 78 | print "rows x cols", transformer.num_rows, transformer.num_cols 79 | print transformer.x_cords 80 | print transformer.y_cords 81 | 82 | basestate 
= advance_state(game_info.get_sm(), basestate) 83 | print "1" 84 | print "=" * 50 85 | 86 | print transformer.state_to_channels(basestate.to_list()) 87 | 88 | print "2" 89 | print "=" * 50 90 | basestate = advance_state(game_info.get_sm(), basestate, do_swap=True) 91 | print transformer.state_to_channels(basestate.to_list()) 92 | 93 | for ii in range(20): 94 | basestate = advance_state(game_info.get_sm(), basestate) 95 | 96 | print "3" 97 | print "=" * 50 98 | print transformer.state_to_channels(basestate.to_list()) 99 | -------------------------------------------------------------------------------- /src/test/nn/test_new_transformer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | import tensorflow as tf 6 | 7 | from ggplib.util.init import setup_once 8 | 9 | from ggplib.db import lookup 10 | 11 | from ggpzero.util import keras 12 | from ggpzero.defs import gamedesc, templates 13 | from ggpzero.nn.bases import GdlBasesTransformer 14 | 15 | from ggpzero.nn.manager import get_manager 16 | 17 | 18 | def setup(): 19 | # set up ggplib 20 | setup_once() 21 | 22 | # ensure we have database with ggplib 23 | lookup.get_database() 24 | 25 | # initialise keras/tf 26 | keras.init() 27 | 28 | # just ensures we have the manager ready 29 | get_manager() 30 | 31 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 32 | tf.logging.set_verbosity(tf.logging.ERROR) 33 | 34 | np.set_printoptions(threshold=100000, precision=2) 35 | 36 | 37 | def test_game_descriptions(): 38 | game_descs = gamedesc.Games() 39 | names = [name for name in dir(game_descs) if name[0] != "_"] 40 | 41 | names = ["breakthroughSmall", "breakthrough", "englishDraughts"] 42 | names = ["connect6"] 43 | 44 | for name in names: 45 | print 46 | print "=" * 80 47 | print name 48 | print "=" * 80 49 | 50 | meth = getattr(game_descs, name) 51 | game_description = meth() 52 | 53 | print name, game_description.game 54 | print game_description 55 | 
print "-" * 80 56 | 57 | game_info = lookup.by_name(game_description.game) 58 | 59 | # create GenerationDescription 60 | generation_descr = templates.default_generation_desc(game_description.game) 61 | 62 | transformer = GdlBasesTransformer(game_info, generation_descr, game_description) 63 | transformer = transformer 64 | 65 | sm = game_info.get_sm() 66 | basestate = sm.get_initial_state() 67 | 68 | from test_model import advance_state 69 | 70 | for i in range(25): 71 | print "move made", i 72 | print game_info.model.basestate_to_str(basestate) 73 | print transformer.state_to_channels(basestate.to_list()) 74 | 75 | sm.update_bases(basestate) 76 | if sm.is_terminal(): 77 | break 78 | 79 | basestate = advance_state(sm, basestate) 80 | -------------------------------------------------------------------------------- /src/test/nn/test_speed.py: -------------------------------------------------------------------------------- 1 | ''' tests raw speed of predictions on GPU. ''' 2 | 3 | import gc 4 | import time 5 | 6 | import numpy as np 7 | 8 | from ggpzero.defs import confs, templates 9 | 10 | from ggpzero.nn.manager import get_manager 11 | from ggpzero.nn import train 12 | 13 | 14 | def config(): 15 | conf = confs.TrainNNConfig("reversi") 16 | conf.generation_prefix = "x4" 17 | conf.overwrite_existing = True 18 | conf.next_step = 20 19 | conf.validation_split = 1.0 20 | conf.starting_step = 10 21 | return conf 22 | 23 | 24 | class Runner: 25 | def __init__(self, data, keras_model): 26 | self.data = data 27 | self.keras_model = keras_model 28 | self.sample_count = len(data.inputs) 29 | 30 | def warmup(self): 31 | res = [] 32 | for i in range(2): 33 | idx, end_idx = i * batch_size, (i + 1) * batch_size 34 | print i, idx, end_idx 35 | inputs = np.array(data.inputs[idx:end_idx]) 36 | res.append(keras_model.predict(inputs, batch_size=batch_size)) 37 | print res 38 | 39 | 40 | def perf_test(batch_size): 41 | gc.collect() 42 | print 'Starting speed run' 43 | num_batches = 
sample_count / batch_size + 1 44 | print "batches %s, batch_size %s, inputs: %s" % (num_batches, 45 | batch_size, 46 | len(data.inputs)) 47 | 48 | def go(): 49 | ITERATIONS = 3 50 | 51 | man = get_manager() 52 | 53 | # get data 54 | train_config = config() 55 | 56 | # get nn to test speed on 57 | transformer = man.get_transformer(train_config.game) 58 | trainer = train.TrainManager(train_config, transformer) 59 | 60 | nn_model_config = templates.nn_model_config_template(train_config.game, "small", transformer) 61 | generation_descr = templates.default_generation_desc(train_config.game) 62 | trainer.get_network(nn_model_config, generation_descr) 63 | 64 | data = trainer.gather_data() 65 | 66 | r = Runner(trainer.gather_data(), trainer.nn.get_model()) 67 | r.warmup() 68 | 69 | 70 | def speed_test(): 71 | ITERATIONS = 3 72 | 73 | man = get_manager() 74 | 75 | # get data 76 | train_config = config() 77 | 78 | # get nn to test speed on 79 | transformer = man.get_transformer(train_config.game) 80 | trainer = train.TrainManager(train_config, transformer) 81 | 82 | nn_model_config = templates.nn_model_config_template(train_config.game, "small", transformer) 83 | generation_descr = templates.default_generation_desc(train_config.game) 84 | trainer.get_network(nn_model_config, generation_descr) 85 | 86 | data = trainer.gather_data() 87 | 88 | res = [] 89 | 90 | batch_size = 4096 91 | sample_count = len(data.inputs) 92 | keras_model = trainer.nn.get_model() 93 | 94 | # warm up 95 | for i in range(2): 96 | idx, end_idx = i * batch_size, (i + 1) * batch_size 97 | print i, idx, end_idx 98 | inputs = np.array(data.inputs[idx:end_idx]) 99 | res.append(keras_model.predict(inputs, batch_size=batch_size)) 100 | print res[0] 101 | 102 | for _ in range(ITERATIONS): 103 | res = [] 104 | times = [] 105 | gc.collect() 106 | 107 | print 'Starting speed run' 108 | num_batches = sample_count / batch_size + 1 109 | print "batches %s, batch_size %s, inputs: %s" % (num_batches, 110 | batch_size, 
111 | len(data.inputs)) 112 | for i in range(num_batches): 113 | idx, end_idx = i * batch_size, (i + 1) * batch_size 114 | inputs = np.array(data.inputs[idx:end_idx]) 115 | print "inputs", len(inputs) 116 | s = time.time() 117 | Y = keras_model.predict(inputs, batch_size=batch_size) 118 | times.append(time.time() - s) 119 | print "outputs", len(Y[0]) 120 | 121 | print "times taken", times 122 | print "total_time taken", sum(times) 123 | print "predictions per second", sample_count / float(sum(times)) 124 | 125 | 126 | if __name__ == "__main__": 127 | from ggpzero.util.main import main_wrap 128 | main_wrap(go) 129 | -------------------------------------------------------------------------------- /src/test/nn/test_templates.py: -------------------------------------------------------------------------------- 1 | from ggplib.util.init import setup_once 2 | from ggplib.db import lookup 3 | 4 | from ggpzero.util import attrutil 5 | 6 | from ggpzero.defs import confs, templates 7 | from ggpzero.nn import train 8 | from ggpzero.nn.manager import get_manager 9 | from ggpzero.nn.model import get_network_model 10 | from ggpzero.nn.network import NeuralNetwork 11 | 12 | man = get_manager() 13 | 14 | 15 | def setup(): 16 | # set up ggplib 17 | setup_once() 18 | 19 | # ensure we have database with ggplib 20 | lookup.get_database() 21 | 22 | 23 | def test_generation_desc(): 24 | game = "breakthrough" 25 | gen_prefix = "x1" 26 | prev_states = 1 27 | gen_desc = templates.default_generation_desc(game, 28 | gen_prefix, 29 | multiple_policy_heads=True, 30 | num_previous_states=prev_states) 31 | attrutil.pprint(gen_desc) 32 | 33 | 34 | def test_nn_model_config_template(): 35 | game = "breakthrough" 36 | gen_prefix = "x1" 37 | prev_states = 1 38 | 39 | gen_desc = templates.default_generation_desc(game, 40 | gen_prefix, 41 | multiple_policy_heads=True, 42 | num_previous_states=prev_states) 43 | transformer = man.get_transformer(game, gen_desc) 44 | 45 | model = 
templates.nn_model_config_template("breakthrough", "small", transformer) 46 | attrutil.pprint(model) 47 | 48 | keras_model = get_network_model(model, gen_desc) 49 | network = NeuralNetwork(transformer, keras_model, gen_desc) 50 | print network 51 | network.summary() 52 | 53 | 54 | def test_nn_model_config_template2(): 55 | game = "breakthrough" 56 | gen_prefix = "x1" 57 | prev_states = 1 58 | 59 | gen_desc = templates.default_generation_desc(game, 60 | gen_prefix, 61 | multiple_policy_heads=True, 62 | num_previous_states=prev_states) 63 | transformer = man.get_transformer(game, gen_desc) 64 | 65 | model = templates.nn_model_config_template("breakthrough", "small", 66 | transformer, features=True) 67 | attrutil.pprint(model) 68 | 69 | keras_model = get_network_model(model, gen_desc) 70 | network = NeuralNetwork(transformer, keras_model, gen_desc) 71 | print network 72 | network.summary() 73 | 74 | 75 | def test_train_config_template(): 76 | game = "breakthrough" 77 | gen_prefix = "x1" 78 | 79 | train_config = templates.train_config_template(game, gen_prefix) 80 | attrutil.pprint(train_config) 81 | 82 | 83 | def test_base_puct_config(): 84 | config = templates.base_puct_config(dirichlet_noise_pct=0.5) 85 | attrutil.pprint(config) 86 | 87 | 88 | def test_selfplay_config(): 89 | config = templates.selfplay_config_template() 90 | attrutil.pprint(config) 91 | 92 | 93 | def test_server_config_template(): 94 | game = "breakthrough" 95 | gen_prefix = "x1" 96 | prev_states = 1 97 | config = templates.server_config_template(game, gen_prefix, prev_states) 98 | attrutil.pprint(config) 99 | -------------------------------------------------------------------------------- /src/test/player/test_player.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from ggplib.player import get 4 | from ggplib.player.gamemaster import GameMaster 5 | from ggplib.db import lookup 6 | 7 | from ggpzero.util import attrutil 8 | from ggpzero.defs 
import confs, templates
from ggpzero.nn.manager import get_manager

from ggpzero.player.puctplayer import PUCTPlayer
from ggpzero.battle.bt import pretty_board


# Board dimension handed to pretty_board(); assumed to match GAME ("bt_7").
BOARD_SIZE = 7
GAME = "bt_7"
# Generation name under which an untrained network is created/saved by setup().
RANDOM_GEN = "rand_0"

# Generation name of a previously trained network, loaded by test_trained().
GOOD_GEN1 = "x1_132"


def setup():
    """Initialise ggplib/keras and make sure a RANDOM_GEN network exists.

    Presumably run by the test runner as module-level setup.  If no network
    can be loaded for (GAME, RANDOM_GEN), a new one is created and saved
    under RANDOM_GEN so test_random() has something to load.
    """
    # Set before tensorflow is imported so the C++ log level is already in
    # effect while TF initialises (setting it after the import may be too late).
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    import tensorflow as tf

    from ggplib.util.init import setup_once
    setup_once()

    from ggpzero.util.keras import init
    init()

    tf.logging.set_verbosity(tf.logging.ERROR)

    import numpy as np
    np.set_printoptions(threshold=100000)

    man = get_manager()
    if not man.can_load(GAME, RANDOM_GEN):
        network = man.create_new_network(GAME)
        man.save_network(network, RANDOM_GEN)


def play(player_white, player_black, move_time=0.5):
    """Play one complete GAME match, printing the board before every move.

    player_white is registered for role "white" and player_black for "black".
    move_time is forwarded to GameMaster.start(); meta time is fixed at 15.
    """
    gm = GameMaster(lookup.by_name(GAME), verbose=True)
    gm.add_player(player_white, "white")
    gm.add_player(player_black, "black")

    gm.start(meta_time=15, move_time=move_time)

    move = None
    while not gm.finished():
        # print out the board
        pretty_board(BOARD_SIZE, gm.sm)

        move = gm.play_single_move(last_move=move)

    gm.finalise_match(move)


def test_random():
    """simplemcts (white) vs a PUCTPlayer using the RANDOM_GEN network (black)."""
    pymcs = get.get_player("simplemcts")
    pymcs.max_run_time = 0.25

    eval_config = templates.base_puct_config(verbose=True,
                                             max_dump_depth=1)
    puct_config = confs.PUCTPlayerConfig("gzero",
                                         True,
                                         100,
                                         0,
                                         RANDOM_GEN,
                                         eval_config)

    attrutil.pprint(puct_config)

    puct_player = PUCTPlayer(puct_config)

    play(pymcs, puct_player)


def test_trained():
    """simplemcts (white) vs a PUCTPlayer using the GOOD_GEN1 network (black)."""
    simple = get.get_player("simplemcts")
    simple.max_run_time = 0.5

    eval_config = confs.PUCTEvaluatorConfig(verbose=True,
                                            puct_constant=0.85,
                                            puct_constant_root=3.0,

                                            dirichlet_noise_pct=-1,

                                            fpu_prior_discount=0.25,
                                            fpu_prior_discount_root=0.15,

                                            choose="choose_temperature",
                                            temperature=2.0,
                                            depth_temperature_max=10.0,
                                            depth_temperature_start=0,
                                            depth_temperature_increment=0.75,
                                            depth_temperature_stop=1,
                                            random_scale=1.0,
                                            batch_size=1,
                                            max_dump_depth=1)

    puct_config = confs.PUCTPlayerConfig("gzero",
                                         True,
                                         200,
                                         0,
                                         GOOD_GEN1,
                                         eval_config)
    attrutil.pprint(puct_config)

    puct_player = PUCTPlayer(puct_config)

    play(simple, puct_player)
    # play(puct_player, simple)  # same match with colours swapped

# --------------------------------------------------------------------------------
# /src/test/test_state.py:
# --------------------------------------------------------------------------------
import base64
import random

from ggplib.db import lookup

from ggpzero.util.state import encode_state, decode_state, fast_decode_state

games = ["breakthrough", "hex", "hexLG13"]


def advance_state(sm, basestate):
    """Play one random joint move from basestate and return the successor.

    The successor is written into a freshly allocated base state, which is
    returned; callers must use the return value to actually advance.
    """
    sm.update_bases(basestate)

    # leaks, but who cares, it is a test
    joint_move = sm.get_joint_move()
    base_state = sm.new_base_state()

    for role_index in range(len(sm.get_roles())):
        ls = sm.get_legal_state(role_index)
        choice = ls.get_legal(random.randrange(0, ls.get_count()))
        joint_move.set(role_index, choice)

    # play move, the base_state will be new state
    sm.next_state(joint_move, base_state)
    return base_state


def _advanced_copy(sm, bs_0, num_moves):
    """Return a copy of bs_0 advanced by num_moves random joint moves (bs_0 untouched)."""
    bs = sm.new_base_state()
    bs.assign(bs_0)
    for _ in range(num_moves):
        # advance_state does not modify its argument; keep the returned state
        bs = advance_state(sm, bs)
    return bs


def _assert_same_state(decoded, original):
    """Assert two base states are equal, hash equally and stringify equally."""
    assert decoded == original
    assert decoded.hash_code() == original.hash_code()
    assert decoded.to_string() == original.to_string()


def test_simple():
    """Round-trip initial and advanced states through encode_state/decode_state."""
    for game in games:
        sm = lookup.by_name(game).get_sm()
        bs_0 = sm.get_initial_state()
        bs_1 = _advanced_copy(sm, bs_0, 3)

        # Compare with == explicitly: on Python 2, != may fall back to an
        # identity check if the base-state type does not define __ne__.
        assert not bs_0 == bs_1

        for original in (bs_0, bs_1):
            decoded = sm.new_base_state()
            decoded.from_list(decode_state(encode_state(original.to_list())))
            _assert_same_state(decoded, original)


def test_more():
    """Check decoding via decode_state() and via raw base64 both reproduce the state."""
    for game in games:
        print("doing %s" % game)
        sm = lookup.by_name(game).get_sm()
        bs_0 = sm.get_initial_state()
        bs_1 = _advanced_copy(sm, bs_0, 5)

        assert not bs_0 == bs_1

        for original in (bs_0, bs_1):
            encoded = encode_state(original.to_list())

            # decode through the list representation
            via_list = sm.new_base_state()
            via_list.from_list(decode_state(encoded))
            _assert_same_state(via_list, original)

            # decode the base64 payload directly into the state
            # (b64decode: same as the deprecated decodestring, available on py2 and py3)
            direct = sm.new_base_state()
            direct.from_string(base64.b64decode(encoded))
            _assert_same_state(direct, original)

        print("good %s" % game)


def test_speed():
    """Check decode_state and fast_decode_state agree, then time each of them."""
    import time

    for game in games:
        print("doing %s" % game)
        sm = lookup.by_name(game).get_sm()

        # a couple of states
        bs_0 = sm.get_initial_state()
        bs_1 = _advanced_copy(sm, bs_0, 5)

        # encode states
        encoded_0 = encode_state(bs_0.to_list())
        encoded_1 = encode_state(bs_1.to_list())

        assert decode_state(encoded_0) == fast_decode_state(encoded_0)
        assert decode_state(encoded_1) == fast_decode_state(encoded_1)

        for decoder in (decode_state, fast_decode_state):
            start = time.time()
            for _ in range(10000):
                decoder(encoded_0)
                decoder(encoded_1)

            print("time taken %.3f msecs" % ((time.time() - start) * 1000))

# --------------------------------------------------------------------------------
# /src/test/test_util.py:
# --------------------------------------------------------------------------------
from pprint import pprint

import attr

from ggpzero.util import attrutil, func, broker, runprocs


def setup():
    """Run ggplib's one-time initialisation before the tests in this module."""
    from ggplib.util.init import setup_once
    setup_once()


def test_chunks():
    res = list(func.chunks(range(10), 2))
    assert res[0] == [0, 1]
    assert res[4] == [8, 9]


def test_challenge():
    """Both challenge() and response() produce 128-character messages."""
    m = broker.challenge(128)

    print(m)
    assert len(m) == 128

    m = broker.response(m)

    print(m)
    assert len(m) == 128


@attrutil.register_attrs
class DummyMsg(object):
    what = attr.ib()


class BrokerX(broker.Broker):
    """Broker that records the client and payload of the last on_call_me() call."""
    the_client = None
    got = None

    def on_call_me(self, client, msg):
        self.the_client = client
        self.got = msg.what


def test_broker():
    """A registered handler receives the client id and message payload."""
    b = BrokerX()

    b.register(DummyMsg, b.on_call_me)
    m = DummyMsg("hello")
    b.onMessage(42, broker.Message(broker.clz_to_name(DummyMsg), payload=m))

    assert b.the_client == 42
    assert b.got == "hello"


def test_client():
    """A formatted message fed back into the receive buffer unbuffers intact."""
    client = broker.Client(None)
    client.logical_connection = True

    data = client.format_msg(DummyMsg("hello world!"))
    client.rxd.append(data)

    msg = list(client.unbuffer_data())[0]

    assert msg.name == "test_util.DummyMsg"
    assert msg.payload.what == "hello world!"


def test_broker_frag_msg():
    """A message split across two receive chunks is only yielded once complete."""
    client = broker.Client(None)
    client.logical_connection = True

    data = client.format_msg(DummyMsg("hello world2!"))
    assert len(data) > 14
    client.rxd.append(data[:10])

    empty = list(client.unbuffer_data())
    assert empty == []

    client.rxd.append(data[10:])
    buflens = [len(x) for x in client.rxd]
    assert buflens[0] == 10
    assert buflens[1] > 10

    msg = list(client.unbuffer_data())[0]
    assert msg is not None

    assert msg.name == "test_util.DummyMsg"
    assert msg.payload.what == "hello world2!"
@attrutil.register_attrs
class Container(object):
    """Three-slot attrs record used to exercise nested (de)serialisation."""
    x = attr.ib()
    y = attr.ib()
    z = attr.ib()


def test_attrs_recursive():
    """A Container nested inside a DummyMsg survives dict and JSON round trips."""
    print('test_attrs_recursive.1')

    inner = Container(*[DummyMsg(tag) for tag in ('a', 'b', 'c')])
    outer = Container(DummyMsg('o'), DummyMsg('p'), DummyMsg(inner))

    as_dict = attrutil.asdict_plus(outer)
    pprint(as_dict)

    from_dict = attrutil.fromdict_plus(as_dict)
    assert isinstance(from_dict, Container)
    assert from_dict.x.what == 'o'
    assert from_dict.z.what.x.what == 'a'

    encoded = attrutil.attr_to_json(outer, indent=4)
    print(encoded)

    from_json = attrutil.json_to_attr(encoded)
    assert from_json.x.what == 'o'
    assert from_json.z.what.x.what == 'a'


@attrutil.register_attrs
class Sample(object):
    """Named record carrying a list of data (empty list by default)."""
    name = attr.ib()
    data = attr.ib(default=attr.Factory(list))


@attrutil.register_attrs
class Samples(object):
    """Container holding a key and a list of Sample records."""
    k = attr.ib()
    samples = attr.ib(default=attr.Factory(list))


def test_attrs_listof():
    """A list of registered attrs objects inside another one round-trips via dicts."""
    first = Sample('s0', [1, 2, 3, 4, 5])
    second = Sample('s1', [5, 4, 3, 2, 1])
    bundle = Samples(42, [first, second])

    as_dict = attrutil.asdict_plus(bundle)
    pprint(as_dict)

    restored = attrutil.fromdict_plus(as_dict)
    pprint(restored)

    assert isinstance(restored, Samples)
    assert len(restored.samples) == 2
    assert restored.samples[0].name == "s0"
    assert restored.samples[1].data[1] == 4

    print(attrutil.clone(first))


def test_attrs_clone():
    """Mutating a clone must leave the original (including its list data) untouched."""
    original = Sample('s0', [1, 2, 3, 4, 5])
    duplicate = attrutil.clone(original)
    duplicate.name = "asd"
    duplicate.data.append(6)

    assert duplicate.name == "asd"
    assert original.name == "s0"
    assert duplicate.data[-1] == 6
    assert original.data[-1] == 5

    from_tuple = Sample('s0', (1, 2, 3, 4, 5))
    from_tuple.data = list(from_tuple.data)
    tuple_copy = attrutil.clone(from_tuple)
    tuple_copy.data[2] = 42

    assert tuple_copy.data[2] == 42
    assert from_tuple.data[2] == 3


def test_runcmds():
    """Run a few shell commands under the twisted reactor until completion."""
    from twisted.internet import reactor

    def on_done():
        print("SUCCESS")
        reactor.crash()

    commands = [
        "ls -l",
        "sleep 3",
        "python2 -c 'import sys; print >>sys.stderr, 123'",
    ]
    runner = runprocs.RunCmds(commands, cb_on_completion=on_done)

    reactor.callLater(0.1, runner.spawn)
    reactor.run()


def test_runcmds2():
    """RunCmds rejects a repeated command list; then run a SIGTERM-ignoring process."""
    from twisted.internet import reactor

    def on_done():
        print("SUCCESS")
        reactor.crash()

    try:
        runprocs.RunCmds(["ls -l", "ls -l"], cb_on_completion=on_done)
    except AssertionError:
        pass
    else:
        raise Exception("Should not get here")

    # child process sleeps in its SIGTERM handler instead of exiting promptly
    stubborn = ('python -c "import time, signal; '
                'signal.signal(signal.SIGTERM, lambda a,b: time.sleep(5)); '
                'time.sleep(5)"')
    runner = runprocs.RunCmds([stubborn], cb_on_completion=on_done)

    reactor.callLater(0.1, runner.spawn)
    reactor.run()


@attr.s
class BadClass(object):
    """Plain attrs class deliberately NOT registered with attrutil."""
    k = attr.ib()
    z = attr.ib()


def test_attrs_bad():
    """Serialising an unregistered attrs class must raise SerialiseException."""
    unregistered = BadClass(1, 2)

    try:
        result = attrutil.asdict_plus(unregistered)
        assert False, "Do not get here %s" % result
    except attrutil.SerialiseException as exc:
        # this is what we want
        print(exc)

# --------------------------------------------------------------------------------