├── .gitignore ├── LICENSE ├── README.md ├── notebooks ├── Examples.ipynb ├── MatchgameRule50.ipynb ├── TrainingGameValidate_14801114.ipynb ├── lcz_training │ ├── TrainLeela.ipynb │ └── lcztools.ini └── lcztools.ini ├── scripts ├── download_match_games │ ├── download_match_games.py │ └── make_training_data.py ├── opening │ ├── data │ │ ├── bookfish_opening_seqs_10.txt │ │ ├── bookfish_opening_seqs_12.txt │ │ ├── bookfish_opening_seqs_2.txt │ │ ├── bookfish_opening_seqs_4.txt │ │ ├── bookfish_opening_seqs_6.txt │ │ ├── bookfish_opening_seqs_8.txt │ │ └── readme.MD │ └── polyglot_to_seqs.py ├── train_to_pgn.py └── wip │ └── analyse_games.py ├── setup.py ├── src └── lcztools │ ├── __init__.py │ ├── _leela_board.py │ ├── _old_leela_board.py │ ├── _uci_to_idx.py │ ├── backend │ ├── __init__.py │ ├── _leela_client_net.py │ ├── _leela_net.py │ ├── _leela_tf_net.py │ ├── _leela_torch_eval_net.py │ ├── _leela_torch_net.py │ └── net_server │ │ ├── __init__.py │ │ ├── client_example.py │ │ └── server.py │ ├── config.py │ ├── testing │ ├── __init__.py │ ├── _archive_unused │ │ ├── __init__.py │ │ └── leela_engine.py │ ├── lczero_web │ │ ├── __init__.py │ │ ├── networks.py │ │ └── web_game.py │ ├── leela_engine_lc0.py │ └── train_parser.py │ ├── util │ ├── __init__.py │ └── _shuffle_buffer.py │ └── weights │ ├── __init__.py │ └── _weights_file.py ├── tests ├── Untitled.ipynb ├── _archive_unused │ └── test_net_eq_engine.py ├── download_latest_network.py ├── lcztools.ini ├── test.py ├── test_net_eq_engine_lc0.py ├── train_to_pgn.py └── update_rule50_weights.py └── wip_archive └── leela_train_to_pgn.py /.gitignore: -------------------------------------------------------------------------------- 1 | .* 2 | !/.gitignore 3 | leelalogs 4 | __pycache__ 5 | *.pyc 6 | *.egg-info/ 7 | lczero.stderr.txt 8 | .ipynb_checkpoints 9 | lczero_log.txt 10 | lczero.stderr.txt -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # lczero-tools 2 | Python utilities for experimenting with Leela Chess Zero, a neural-network-based chess engine: https://github.com/glinscott/leela-chess/ 3 | 4 | ## IMPORTANT - DEFUNCT / NO LONGER BEING MAINTAINED 5 | ## Please go to https://github.com/LeelaChessZero/lc0/tree/master/scripts/pybind for the official python bindings! 6 | I have not actively worked on this in years. Many things have changed in Leela Chess Zero's network architecture in the meantime, including a different weights file format. I have intended to come back and update it, and still would like to, but for now this project is where it is. This will not work with the latest versions of Leela's network without significant changes. However, maybe somebody sees something useful here or wants to update this project. 7 | 8 | #### Note: This is primarily for looking at the Leela Chess neural network itself, outside of search/MCTS (although search may be added eventually). 9 | 10 | This makes heavy use of python-chess, located at https://github.com/niklasf/python-chess 11 | 12 | The current implementation is primarily geared towards pytorch, but tensorflow is possible using the training/tf portion of leela-chess. 13 | 14 | Example usage (also see /tests/*.py and [Examples.ipynb](https://github.com/so-much-meta/lczero_tools/blob/master/notebooks/Examples.ipynb)): 15 | ```python 16 | >>> from lcztools import load_network, LeelaBoard 17 | >>> # Note: use pytorch_cuda for cuda support 18 | >>> net = load_network('pytorch_cuda', 'weights.txt.gz') 19 | >>> board = LeelaBoard() 20 | >>> # Many of python-chess's methods are passed through, along with board representation 21 | >>> board.push_uci('e2e4') 22 | >>> print(board) 23 | r n b q k b n r 24 | p p p p p p p p 25 | . . . . . . . . 26 | . . . . . . . . 27 | . . . . P . . . 28 | . . . . . . . . 29 | P P P P . P P P 30 | R N B Q K B N R 31 | Turn: Black 32 | >>> policy, value = net.evaluate(board) 33 | >>> print(policy) 34 | OrderedDict([('c7c5', 0.5102739), ('e7e5', 0.16549255), ('e7e6', 0.11846365), ('c7c6', 0.034872748), 35 | ('d7d6', 0.025344277), ('a7a6', 0.02313047), ('g8f6', 0.021814445), ('g7g6', 0.01614216), ('b8c6', 0.013772337), 36 | ('h7h6', 0.013361011), ('b7b6', 0.01300134), ('d7d5', 0.010980369), ('a7a5', 0.008497312), ('b8a6', 0.0048270077), 37 | ('g8h6', 0.004309486), ('f7f6', 0.0040882644), ('h7h5', 0.003910391), ('b7b5', 0.0027878743), ('f7f5', 0.0025032777), 38 | ('g7g5', 0.0024271626)]) 39 | >>> print(value) 40 | 0.4715215042233467 41 | ``` 42 | 43 | ## Create network server 44 | 45 | It is possible to load the network (or multiple different networks) once in a network server and access it from multiple clients. This does not add much overhead, and it creates a significant speedup by batching GPU operations when multiple clients are connected simultaneously. 46 | 47 | Inter-process communication is via ZeroMQ. 48 | 49 | ```bash 50 | python -m lcztools.backend.net_server.server weights_file1.txt.gz weights_file2.txt.gz 51 | ``` 52 | After the server starts, clients can access it through the load_network interface like so: 53 | 54 | ```python 55 | >>> from lcztools import load_network, LeelaBoard 56 | >>> net0 = load_network(backend='net_client', network_id=0) 57 | >>> net1 = load_network(backend='net_client', network_id=1) 58 | >>> board = LeelaBoard() 59 | >>> policy0, value0 = net0.evaluate(board) 60 | >>> policy1, value1 = net1.evaluate(board) 61 | ``` 62 | 63 | Max batch size can be configured by entering it after the weights file; the default is 32. Batch sizes are generally powers of 2 (starting at 1), and it can help to set this to the batch size that will actually be used if fewer than 32 clients are connected.
When a new client connects, it sends a "hi" message to the server, which causes the server to reset the batch size to the max_batch_size (this message can also be sent at any point during a program's execution via net.hi()). The network server will block for up to 1 second if the batch is not filled, at which time the batch size is reset to the greatest power of 2 less than or equal to the number of items currently in the batch. 64 | 65 | ```bash 66 | python -m lcztools.backend.net_server.server weights_file1.txt.gz 8 weights_file2.txt.gz 8 67 | ``` 68 | 69 | ## INSTALL 70 | ``` 71 | # With both torch and util dependencies for NN evaluation 72 | pip install git+https://github.com/so-much-meta/lczero_tools.git#egg=lczero-tools[torch,util] 73 | # Or just util extras (parse training games, run lczero engine, etc.) 74 | pip install git+https://github.com/so-much-meta/lczero_tools.git#egg=lczero-tools[util] 75 | 76 | # Or from source tree... 77 | git clone https://github.com/so-much-meta/lczero_tools 78 | cd lczero_tools 79 | # Note: Creating and using a virtualenv or Conda environment before install is suggested, as always 80 | pip install .[torch,util] 81 | # Or for a developer/editable install, to make in-place changes: 82 | # pip install -e .[torch,util] 83 | ``` 84 | 85 | ## TODO 86 | 1. [x] **DONE:** Implement testing to verify position evaluations match the lczero engine. 87 | * [ ] Using /tests/test_net_eq_engine.py, results look good. But specific PGNs might be helpful too. 88 | 2. [x] **DONE:** Add config mechanism and Jupyter notebook examples 89 | 3. [x] **DONE:** Add training data parser module. Use cases are: 90 | * [x] **DONE:** Training data to PGN 91 | * [ ] Verification of training data correctness. 92 | * [ ] Loss calculation - allow comparison between networks on same data 93 | 4. [x] **DONE:** lczero web scraping *(NOT FOR HEAVY USE)* 94 | * [x] **DONE:** Convert individual match and training games to PGN (URL => PGN) 95 | * [x] **DONE:** Download weights files 96 | 5. [ ] OpenCL support! This should be possible with https://github.com/plaidml/plaidml 97 | 6. [ ] Investigate optimizations (CUDA, multiprocessing, etc.). The goal is to eventually have a fast enough python-based implementation to do MCTS and get decent nodes/second comparable to Leela's engine -- in cases where neural network eval speed is the bottleneck. 98 | * [ ] However, no optimizations should get (too much) in the way of clarity or ease of changing code to do experiments. 99 | 7. [ ] Possible MCTS implementation 100 | -------------------------------------------------------------------------------- /notebooks/lcz_training/TrainLeela.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "Loading network using backend=pytorch_train_cuda, policy_softmax_temp=2.2\n", 13 | "Channels 256\n", 14 | "Blocks 20\n", 15 | "Enabling CUDA!\n", 16 | "+ conv_val.conv1.weight\n", 17 | "+ conv_val.conv1_bn.bias\n", 18 | "+ affine_val_1.weight\n", 19 | "+ affine_val_1.bias\n", 20 | "+ affine_val_2.weight\n", 21 | "+ affine_val_2.bias\n", 22 | "0.33122138580307364\n", 23 | "0.32327708806842564\n", 24 | "0.342647999804467\n", 25 | "0.30650920746847987\n", 26 | "0.34121590862050655\n", 27 | "0.31692144978791476\n", 28 | "0.2926337251719087\n", 29 | "0.3612888185679913\n", 30 | "0.32078073009848596\n", 31 | "0.33477764517068864\n", 32 | "0.33724299166351557\n", 33 | "0.29047099124640224\n", 34 | "0.3376015657559037\n", 35 | "0.3247099100984633\n", 36 | "0.3170356016419828\n", 37 | "0.2968645826354623\n", 38 | "0.3251196685619652\n", 39 | "0.33669298788532614\n", 40 |
"0.26937713043764233\n", 41 | "0.33238239696249366\n", 42 | "0.31816340910270813\n", 43 | "0.32765887467190624\n", 44 | "0.3277877380698919\n", 45 | "0.3299694628827274\n", 46 | "0.33974271538667383\n", 47 | "0.33874722172040495\n", 48 | "0.28303455530665816\n", 49 | "0.29582444167695937\n", 50 | "0.3109586057346314\n", 51 | "0.31915072560776026\n", 52 | "0.26685509014409037\n", 53 | "0.31726893690880387\n", 54 | "0.28230360097251833\n", 55 | "0.2911954136285931\n", 56 | "0.31076860865578054\n", 57 | "0.3014677645917982\n", 58 | "0.33730487290304156\n", 59 | "0.25808153552934526\n", 60 | "0.3813514027744532\n", 61 | "0.3000475941039622\n", 62 | "0.31898987393826245\n", 63 | "0.3520307368272915\n", 64 | "0.3076914825686254\n", 65 | "0.2620176869956776\n", 66 | "0.27852552672382447\n", 67 | "0.2695530009781942\n", 68 | "0.3187382895802148\n", 69 | "0.2566404694085941\n", 70 | "0.3254203283647075\n", 71 | "0.26879885351751\n", 72 | "0.31229888861533256\n", 73 | "0.34966941135236995\n", 74 | "0.3004613875132054\n", 75 | "0.2827500160690397\n", 76 | "0.36580940315965566\n", 77 | "0.27884005995234473\n", 78 | "0.3543519354728051\n", 79 | "0.28202596956864\n", 80 | "0.2855419478751719\n", 81 | "0.31479672213084997\n", 82 | "0.2737147694430314\n", 83 | "0.2906757100787945\n", 84 | "0.3016007874789648\n", 85 | "0.3095072785508819\n", 86 | "0.2670763988723047\n", 87 | "0.30047072959365323\n", 88 | "0.3007216412178241\n", 89 | "0.30039629626087844\n", 90 | "0.304792027217336\n", 91 | "0.366527035953477\n", 92 | "0.29760916242259555\n", 93 | "0.2903023031866178\n", 94 | "0.2535957191651687\n", 95 | "0.31904821934411304\n", 96 | "0.24747976335929706\n", 97 | "0.3230142588587478\n", 98 | "0.305334740083199\n", 99 | "0.311342480762396\n", 100 | "0.3861961679300293\n", 101 | "0.3430443228641525\n", 102 | "0.3525827567675151\n", 103 | "0.31915661280741914\n", 104 | "0.3353808896802366\n", 105 | "0.33310799511382355\n", 106 | "0.3366224261629395\n", 107 | "0.32511442466173324\n", 
108 | "0.26830084920627995\n", 109 | "0.2803230711631477\n", 110 | "0.30956072024535386\n", 111 | "0.2910797866503708\n", 112 | "0.28283236584393306\n", 113 | "0.30750197074841706\n", 114 | "0.27632841889513654\n", 115 | "0.28175996818346905\n", 116 | "0.3250257240596693\n", 117 | "0.2843242620374076\n", 118 | "0.3407336264417972\n", 119 | "0.2996929039095994\n", 120 | "0.3189259368716739\n", 121 | "0.2754355598753318\n", 122 | "0.2811162417801097\n", 123 | "0.29241135413292796\n", 124 | "0.3055539311701432\n", 125 | "0.33056803280021996\n", 126 | "0.24958291918504982\n", 127 | "0.2659967349621002\n", 128 | "0.3105535579600837\n", 129 | "0.25698168139439076\n", 130 | "0.26725212325342\n", 131 | "0.25075584913254717\n", 132 | "0.281782717872411\n", 133 | "0.3161420700128656\n", 134 | "0.3248244007397443\n", 135 | "0.2912295384472236\n", 136 | "0.24813229275634513\n", 137 | "0.30048654023790733\n", 138 | "0.32964418082381597\n", 139 | "0.26272039765375665\n", 140 | "0.3506261654710397\n", 141 | "0.27671427248278635\n", 142 | "0.32100102566881106\n", 143 | "0.2975786504847929\n", 144 | "0.29398134926217606\n", 145 | "0.3583666497457307\n", 146 | "0.2859612186846789\n", 147 | "0.36839267740491777\n", 148 | "0.317303850708995\n", 149 | "0.27428580562234856\n", 150 | "0.28598992243176324\n", 151 | "0.3094337029586313\n", 152 | "0.24691568039241246\n", 153 | "0.28369945848709904\n", 154 | "0.2891576769796666\n", 155 | "0.3117331481282599\n", 156 | "0.2648585990839638\n", 157 | "0.27846844331012105\n", 158 | "0.29358619465725494\n", 159 | "0.32549676625989377\n", 160 | "0.3275492696545552\n", 161 | "0.2644423476688098\n", 162 | "0.31238163021218496\n", 163 | "0.2792712036613375\n", 164 | "0.3063756470987573\n", 165 | "0.33435537169571034\n", 166 | "0.29879326241556553\n", 167 | "0.2701051006105263\n", 168 | "0.2859174999187235\n", 169 | "0.3723205540073104\n", 170 | "0.323884713210864\n", 171 | "0.25719603803823704\n", 172 | "0.29744810216128825\n", 173 | 
"0.2432499035645742\n", 174 | "0.28229812500532714\n", 175 | "0.29820196301094254\n", 176 | "0.29949649508460424\n", 177 | "0.23422659648931585\n", 178 | "0.31855869347637056\n", 179 | "0.33534564868779854\n", 180 | "0.3102129570953548\n", 181 | "0.2732491894660052\n", 182 | "0.32253215591888873\n", 183 | "0.3147825984249357\n", 184 | "0.32301999544841237\n", 185 | "0.34437846184941007\n", 186 | "0.29896002348395995\n", 187 | "0.2868724218604621\n", 188 | "0.3278878434060607\n", 189 | "0.3374404431518633\n", 190 | "0.3371820159163326\n", 191 | "0.3147764347144403\n", 192 | "0.2855677987704985\n", 193 | "0.23246080566314048\n", 194 | "0.2747319133405108\n", 195 | "0.2856643143354449\n", 196 | "0.2863041346904356\n", 197 | "0.2976805143430829\n", 198 | "0.2875572463357821\n", 199 | "0.3259528050501831\n", 200 | "0.3117244970391039\n", 201 | "0.26975210695294666\n", 202 | "0.32246552520897237\n", 203 | "0.25196910928585564\n", 204 | "0.274266328853555\n", 205 | "0.3027299148240127\n", 206 | "0.2755828746047337\n", 207 | "0.34116057070204986\n", 208 | "0.26580794156179766\n", 209 | "0.3143021793360822\n", 210 | "0.2719938613264821\n", 211 | "0.2531212016241625\n", 212 | "0.22731876538135112\n", 213 | "0.40862841307534836\n", 214 | "0.2607629435451236\n", 215 | "0.2890829382557422\n", 216 | "0.3487825653760228\n", 217 | "0.3312167233973742\n", 218 | "0.28688744620536455\n", 219 | "0.3162757219537161\n", 220 | "0.29184330118587243\n", 221 | "0.3219680774072185\n", 222 | "0.2802078399434686\n", 223 | "0.30263635813142176\n", 224 | "0.28415583640453407\n", 225 | "0.29365012045716865\n", 226 | "0.36803921759943475\n", 227 | "0.28537550766719505\n", 228 | "0.2923389035940636\n", 229 | "0.2705648097989615\n", 230 | "0.28317774552619085\n", 231 | "0.26816349643515425\n", 232 | "0.362670331816189\n", 233 | "0.2975506695196964\n", 234 | "0.3028185239760205\n", 235 | "0.3283600237010978\n", 236 | "0.2777108939830214\n", 237 | "0.3025348131218925\n", 238 | "0.29859000536380337\n", 
239 | "0.31122938438667913\n", 240 | "0.31550328350858764\n", 241 | "0.2333105406339746\n", 242 | "0.320980086941272\n", 243 | "0.2828723770368379\n", 244 | "0.25524520606035367\n", 245 | "Saving weights file:\n", 246 | ".................................................................................................................Done saving weights!\n" 247 | ] 248 | } 249 | ], 250 | "source": [ 251 | "# Training example\n", 252 | "# \n", 253 | "# Using the match game training_data.tgz as created by running scripts/download_match_games/download_match_games.py\n", 254 | "# followed by scripts/download_match_games/make_training_data.py\n", 255 | "#\n", 256 | "# TODO: Need to shuffle data, add a test/train split.\n", 257 | "\n", 258 | "import pickle\n", 259 | "import tarfile\n", 260 | "import chess.pgn\n", 261 | "import io\n", 262 | "from lcztools import LeelaBoard, load_network\n", 263 | "from torch import optim, nn\n", 264 | "import numpy as np\n", 265 | "import torch\n", 266 | "\n", 267 | "net = load_network()\n", 268 | "\n", 269 | "for name, param in net.model.named_parameters():\n", 270 | " # print(name)\n", 271 | " if ('_val' in name) and not ('conv1_bn.weight' in name):\n", 272 | " print('+', name)\n", 273 | " param.requires_grad = True\n", 274 | " else:\n", 275 | " # print('-', name)\n", 276 | " param.requires_grad = False\n", 277 | "\n", 278 | "optimizer = optim.Adam(net.model.parameters(), lr = 0.0002, weight_decay=1e-5)\n", 279 | "criterion = nn.MSELoss()\n", 280 | "\n", 281 | "losses = []\n", 282 | "\n", 283 | "def do_step(training_game):\n", 284 | " if training_game[1] == '1/2-1/2':\n", 285 | " result = 0\n", 286 | " elif training_game[1] == '1-0':\n", 287 | " result = 1\n", 288 | " elif training_game[1] == '0-1':\n", 289 | " result = -1\n", 290 | " elif training_game[1] == '*':\n", 291 | " return\n", 292 | " features_stack = []\n", 293 | " results = []\n", 294 | " for compressed_features in training_game[0]:\n", 295 | " 
features_stack.append(LeelaBoard.decompress_features(compressed_features))\n", 296 | " results.append(result)\n", 297 | " result *= -1\n", 298 | " features_stack = np.stack(features_stack)\n", 299 | " \n", 300 | " optimizer.zero_grad()\n", 301 | " pols, vals = net.model(features_stack)\n", 302 | " results = torch.Tensor(results).view(-1, 1).cuda()\n", 303 | " loss = criterion(vals, results)\n", 304 | " losses.append(loss.item())\n", 305 | " if len(losses)==100:\n", 306 | " print(sum(losses)/len(losses))\n", 307 | " losses.clear()\n", 308 | " loss.backward()\n", 309 | " optimizer.step() \n", 310 | "\n", 311 | "with tarfile.open('training_data.tgz') as f:\n", 312 | " for idx, member in enumerate(f, 1):\n", 313 | " if member.isfile():\n", 314 | " training_game = pickle.load(f.extractfile(member))\n", 315 | " do_step(training_game)\n", 316 | " if idx%1000 == 0:\n", 317 | " print(\"Training games done:\", idx)\n", 318 | "\n", 319 | "net.model.save_weights_file('/home/trevor/projects/lczero/weights/txt/test_weights_3.txt')" 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": 3, 325 | "metadata": {}, 326 | "outputs": [], 327 | "source": [ 328 | "!gzip '/home/trevor/projects/lczero/weights/txt/test_weights_3.txt'" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": null, 334 | "metadata": {}, 335 | "outputs": [], 336 | "source": [] 337 | } 338 | ], 339 | "metadata": { 340 | "kernelspec": { 341 | "display_name": "Python [conda env:luna]", 342 | "language": "python", 343 | "name": "conda-env-luna-py" 344 | }, 345 | "language_info": { 346 | "codemirror_mode": { 347 | "name": "ipython", 348 | "version": 3 349 | }, 350 | "file_extension": ".py", 351 | "mimetype": "text/x-python", 352 | "name": "python", 353 | "nbconvert_exporter": "python", 354 | "pygments_lexer": "ipython3", 355 | "version": "3.7.0" 356 | } 357 | }, 358 | "nbformat": 4, 359 | "nbformat_minor": 2 360 | } 361 | 
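A note on the training targets in TrainLeela.ipynb above: `do_step` flips the sign of the game result on every ply, so each position's value target is always from the side to move's perspective. Here is a minimal standalone sketch of just that targeting logic in plain Python (no torch or lcztools needed); the function name `value_targets` is illustrative, not part of the lcztools API:

```python
def value_targets(result, num_positions):
    """Per-position value targets, from the side to move's perspective.

    Mirrors the sign-flipping loop in the notebook above: for a '1-0' game
    the position with White to move is labeled +1, the next position -1,
    and so on. Unfinished games ('*') yield no targets.
    """
    outcomes = {'1/2-1/2': 0, '1-0': 1, '0-1': -1}
    if result == '*':
        return None
    value = outcomes[result]
    targets = []
    for _ in range(num_positions):
        targets.append(value)
        value = -value  # flip perspective for the next ply
    return targets

# A decisive 4-ply game: White-to-move positions score +1, Black's -1
print(value_targets('1-0', 4))  # -> [1, -1, 1, -1]
```

These targets correspond to the `results` list that the notebook turns into a tensor and feeds to `MSELoss` against the network's value head output.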
-------------------------------------------------------------------------------- /notebooks/lcz_training/lcztools.ini: -------------------------------------------------------------------------------- 1 | [default] 2 | # This is the directory where the network weights are stored 3 | weights_dir = ~/projects/lczero/weights/txt 4 | 5 | # This is the default weights filename to use, if none provided 6 | weights_file = weights_run1_21754.txt.gz 7 | 8 | # This is where raw training files are stored 9 | training_raw_dir = /Volumes/SeagateExternal/leela_data/training_raw/ 10 | 11 | # This is the default backend to use, if none provided 12 | # Choices are: ['pytorch_eval_cpu', 'pytorch_eval_cuda', 'pytorch_cpu', 'pytorch_cuda', 'pytorch_train_cpu', 'pytorch_train_cuda', 'tensorflow'] 13 | backend = pytorch_train_cuda 14 | 15 | ### This is the lczero engine to use, only currently used for testing/validation 16 | ## NOTE: lczero no longer supported!!! 17 | ## lczero_engine = ~/git/leela-chess/release/lczero 18 | 19 | 20 | # This is the lc0 engine to use, only currently used for testing/validation 21 | lc0_engine = ~/.local/bin/lc0 22 | 23 | 24 | ### This is, e.g. ~/sompath/leela-chess/training/tf -- currently only used as a hackish tensorflow 25 | ### mechanism. Not needed for the pytorch backend 26 | ## leela_training_tf_dir = ~/git/leela-chess/training/tf/ 27 | 28 | # This is the policy softmax temp, like lc0's option 29 | policy_softmax_temp = 2.2 30 | -------------------------------------------------------------------------------- /notebooks/lcztools.ini: -------------------------------------------------------------------------------- 1 | [default] 2 | # This is the directory where the network weights are stored 3 | weights_dir = /Volumes/SeagateExternal/leela_data/weights/ 4 | 5 | # This is the default weights filename to use, if none provided 6 | weights_file = weights.txt.gz 7 | 8 | # This is where raw training files are stored 9 | training_raw_dir = /Volumes/SeagateExternal/leela_data/training_raw/ 10 | 11 | # This is the default backend to use, if none provided 12 | # Choices are: ['pytorch', 'pytorch_cuda', 'pytorch_orig', 'tensorflow', 'pytorch_train_cpu', 'pytorch_train_cuda'] 13 | backend = pytorch_train_cuda 14 | 15 | # This is the lczero engine to use, only currently used for testing/validation 16 | lczero_engine = ~/git/leela-chess/release/lczero 17 | 18 | # This is, e.g. ~/sompath/leela-chess/training/tf -- currently only used as a hackish tensorflow 19 | # mechanism. Not needed for the pytorch backend 20 | leela_training_tf_dir = ~/git/leela-chess/training/tf/ 21 | -------------------------------------------------------------------------------- /scripts/download_match_games/download_match_games.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from bs4 import BeautifulSoup 3 | import pandas as pd 4 | from lcztools.testing import WebMatchGame 5 | from collections import OrderedDict 6 | import shelve 7 | import sys 8 | import pickle 9 | import os 10 | import shutil 11 | 12 | if os.path.exists('web_pgn.data'): 13 | with open('web_pgn.data', 'rb') as f: 14 | df_matches, match_dfs, df_pgn = pickle.load(f) 15 | 16 | if 'df_pgn' in dir(): 17 | print("Currently {} games in web_pgn.data pickle".format(len(df_pgn))) 18 | else: 19 | print("No games yet in shelf...") 20 | # Every time I run this, I'll grab an extra GAMES_TO_GRAB games 21 | GAMES_TO_GRAB = 100 22 | 23 | 24 | def get_table(url): 25 | global soup 26 | soup = BeautifulSoup(requests.get(url).text, 'html.parser') 27 | head = soup.thead.tr 28 | head_cols = [c.text for c in head.find_all('th')] 29 | head_cols.append('hrefs') 30 | 31 | rows = [] 32 | for tr in soup.tbody.find_all('tr'): 33 | cells = [c.text for c in tr.find_all('td')] 34 | cells.append([c['href'] for c in tr.find_all('a')]) 35 | rows.append(cells) 36 | df = pd.DataFrame(rows, columns=head_cols) 37 | return df 38 | 39 | df_matches = get_table('http://www.lczero.org/matches/1') 40 | df_matches = df_matches.astype({'Id':int, 'Run':int, 'Candidate':int, 'Current':int}).set_index('Id') 41 | df_matches = df_matches.sort_index(ascending=False) 42 | 43 | 44 | if 'match_dfs' not in dir(): 45 | match_dfs = OrderedDict() 46 | 47 | if 'df_pgn' not in dir(): 48 | print("Creating df_pgn") 49 | df_pgn = pd.DataFrame(index=pd.Index([], name='match_game_id', dtype=int), columns=['match_id', 'pgn']) 50 | df_pgn = df_pgn.astype({'match_id':int}) 51 | 52 | all_games_received =
set(df_pgn.index) 53 | num_pgns_grabbed = 0 54 | for match_id, row in df_matches.iterrows(): 55 | href = row.hrefs[0] 56 | if match_id in match_dfs: 57 | print("Already retrieved", href) 58 | else: 59 | print(href) 60 | match_dfs[match_id] = get_table(f'http://www.lczero.org/{href.lstrip("/")}') 61 | match_dfs[match_id] = match_dfs[match_id].astype({'Game Id':int}).set_index('Game Id') 62 | match_dfs[match_id] = match_dfs[match_id].sort_index(ascending=False) 63 | df = match_dfs[match_id] 64 | for game_id, row in df.iterrows(): 65 | if game_id in all_games_received: 66 | continue 67 | href = row.hrefs[0] 68 | print(game_id, end=', ') 69 | sys.stdout.flush() 70 | wmg = WebMatchGame(href) 71 | pgn = wmg.pgn 72 | df_pgn.loc[game_id] = match_id, pgn 73 | all_games_received.add(game_id) 74 | num_pgns_grabbed += 1 75 | if num_pgns_grabbed >= GAMES_TO_GRAB: 76 | break 77 | if num_pgns_grabbed >= GAMES_TO_GRAB: 78 | break 79 | 80 | print("Done... Downloaded {} files. Saving back to shelf".format(num_pgns_grabbed)) 81 | 82 | if os.path.exists('web_pgn.data'): 83 | shutil.copy('web_pgn.data', 'web_pgn.data.bup') 84 | 85 | print("Writing web_pgn.data") 86 | with open('web_pgn.data', 'wb') as f: 87 | pickle.dump((df_matches, match_dfs, df_pgn), f) 88 | 89 | print("Currently {} games in web_pgn.data pickle".format(len(df_pgn))) 90 | -------------------------------------------------------------------------------- /scripts/download_match_games/make_training_data.py: -------------------------------------------------------------------------------- 1 | import shelve 2 | from lcztools import LeelaBoard 3 | import chess.pgn 4 | import io 5 | import pickle 6 | import sys 7 | 8 | # with shelve.open('web_pgn_data.shelf') as s: 9 | # df_matches = s['df_matches'] 10 | # match_dfs = s['match_dfs'] 11 | # df_pgn = s['df_pgn'] 12 | 13 | with open('web_pgn.data', 'rb') as f: 14 | df_matches, match_dfs, df_pgn = pickle.load(f) 15 | 16 | 17 | for i, (idx, row) in enumerate(df_pgn.iterrows(), 1): 
18 | lcz_board = LeelaBoard() 19 | pgn_game = chess.pgn.read_game(io.StringIO(row.pgn)) 20 | moves = [move for move in pgn_game.main_line()] 21 | compressed = [] 22 | features = lcz_board.lcz_features() 23 | compressed.append(lcz_board.compress_features(features)) 24 | for move in moves[:-1]: 25 | # print(move) 26 | lcz_board.push(move) 27 | features = lcz_board.lcz_features() 28 | compressed_features = lcz_board.compress_features(features) 29 | compressed.append(compressed_features) 30 | # assert(check_compressed_features(lcz_board, compressed_features)) 31 | training_game = (compressed, pgn_game.headers['Result'], row.pgn) 32 | with open(f'./training_data/match_game_{idx}.data', 'wb') as f: 33 | f.write(pickle.dumps(training_game)) 34 | print('.', end='') 35 | if i%100==0: 36 | print() 37 | if i%1000==0: 38 | print(i) 39 | sys.stdout.flush() 40 | 41 | -------------------------------------------------------------------------------- /scripts/opening/data/bookfish_opening_seqs_2.txt: -------------------------------------------------------------------------------- 1 | 2.704884289e-01 d2d4 g8f6 2 | 2.046996251e-01 d2d4 d7d5 3 | 1.630055324e-01 e2e4 c7c5 4 | 9.388722100e-02 e2e4 e7e5 5 | 6.376465784e-02 e2e4 e7e6 6 | 3.621977973e-02 e2e4 c7c6 7 | 2.397577573e-02 e2e4 d7d6 8 | 2.291826391e-02 e2e4 d7d5 9 | 1.967957731e-02 g1f3 d7d5 10 | 1.817544403e-02 d2d4 d7d6 11 | 1.596917626e-02 d2d4 f7f5 12 | 1.581310639e-02 e2e4 g8f6 13 | 1.507753078e-02 c2c4 e7e5 14 | 1.436950201e-02 c2c4 c7c5 15 | 1.064364567e-02 c2c4 g8f6 16 | 6.408682577e-03 d2d4 e7e6 17 | 2.522224307e-03 g1f3 g8f6 18 | 9.341571509e-04 g1f3 c7c5 19 | 8.357060941e-04 c2c4 b7b6 20 | 3.133897853e-04 c2c4 c7c6 21 | 2.521448883e-04 d2d4 c7c6 22 | 4.642811634e-05 c2c4 e7e6 23 | -------------------------------------------------------------------------------- /scripts/opening/data/bookfish_opening_seqs_4.txt: -------------------------------------------------------------------------------- 1 | 1.169119118e-01 d2d4 d7d5 
c2c4 c7c6 2 | 1.016008808e-01 d2d4 g8f6 c2c4 g7g6 3 | 7.062450247e-02 d2d4 g8f6 c2c4 c7c5 4 | 6.927146411e-02 d2d4 g8f6 c2c4 e7e6 5 | 6.565051618e-02 e2e4 c7c5 g1f3 d7d6 6 | 6.376465784e-02 e2e4 e7e6 d2d4 d7d5 7 | 5.975635841e-02 e2e4 c7c5 g1f3 b8c6 8 | 5.605509705e-02 e2e4 e7e5 g1f3 b8c6 9 | 4.444657102e-02 d2d4 d7d5 c2c4 d5c4 10 | 3.621977973e-02 e2e4 c7c6 d2d4 d7d5 11 | 2.697296810e-02 d2d4 d7d5 c2c4 e7e6 12 | 2.397577573e-02 e2e4 d7d6 d2d4 g8f6 13 | 2.325591958e-02 e2e4 c7c5 g1f3 e7e6 14 | 2.151441453e-02 e2e4 d7d5 e4d5 d8d5 15 | 1.930936587e-02 e2e4 e7e5 g1f3 g8f6 16 | 1.721884171e-02 d2d4 d7d6 c2c4 e7e5 17 | 1.596917626e-02 d2d4 f7f5 g2g3 g8f6 18 | 1.581310639e-02 e2e4 g8f6 e4e5 f6d5 19 | 1.463078913e-02 c2c4 e7e5 g2g3 g8f6 20 | 1.152583993e-02 d2d4 g8f6 c1g5 f6e4 21 | 1.032729150e-02 d2d4 d7d5 c2c4 b8c6 22 | 9.095398394e-03 g1f3 d7d5 c2c4 c7c6 23 | 9.093200489e-03 c2c4 c7c5 g2g3 b8c6 24 | 7.074804187e-03 d2d4 g8f6 c1g5 e7e6 25 | 6.219827319e-03 e2e4 e7e5 b1c3 g8f6 26 | 5.781163918e-03 d2d4 g8f6 c2c4 e7e5 27 | 5.166624406e-03 e2e4 e7e5 g1f3 f7f5 28 | 5.108949922e-03 c2c4 g8f6 g2g3 c7c5 29 | 4.924513693e-03 d2d4 e7e6 c2c4 b7b6 30 | 4.912072808e-03 e2e4 c7c5 c2c3 d7d5 31 | 3.870507631e-03 e2e4 c7c5 c2c3 g8f6 32 | 3.654046516e-03 g1f3 d7d5 c2c4 e7e6 33 | 3.277241085e-03 g1f3 d7d5 b2b3 c8g4 34 | 3.031066830e-03 c2c4 c7c5 g2g3 g8f6 35 | 2.613439576e-03 e2e4 e7e5 f2f4 e5f4 36 | 2.554474961e-03 c2c4 g8f6 g2g3 c7c6 37 | 2.529234683e-03 e2e4 c7c5 b1c3 b8c6 38 | 2.494285796e-03 e2e4 e7e5 f2f4 b8c6 39 | 2.342650393e-03 d2d4 g8f6 c1g5 d7d5 40 | 2.291388622e-03 d2d4 d7d5 g1f3 g8f6 41 | 2.115826754e-03 d2d4 d7d5 c1g5 c7c6 42 | 1.818586767e-03 c2c4 c7c5 b1c3 b8c6 43 | 1.403849376e-03 e2e4 d7d5 e4d5 g8f6 44 | 1.358486536e-03 d2d4 e7e6 c2c4 f8b4 45 | 1.310690598e-03 g1f3 d7d5 c2c4 d5d4 46 | 1.271036039e-03 d2d4 g8f6 c2c4 d7d6 47 | 1.191536908e-03 g1f3 d7d5 c2c4 d5c4 48 | 1.068665571e-03 g1f3 d7d5 b2b3 g8f6 49 | 9.590032261e-04 c2c4 g8f6 b1c3 e7e6 50 | 9.566023173e-04 d2d4 
d7d6 e2e4 g8f6 51 | 9.129768831e-04 e2e4 c7c5 c2c3 e7e6 52 | 9.038793078e-04 g1f3 g8f6 c2c4 b7b6 53 | 8.357060941e-04 c2c4 b7b6 d2d4 e7e6 54 | 8.015807451e-04 e2e4 e7e5 g1f3 d7d6 55 | 7.832070775e-04 c2c4 g8f6 g1f3 b7b6 56 | 7.583287003e-04 e2e4 e7e5 f1c4 g8f6 57 | 7.467623837e-04 d2d4 d7d5 c1g5 h7h6 58 | 6.815179550e-04 e2e4 c7c5 c2c3 d7d6 59 | 6.756421567e-04 d2d4 d7d5 c1g5 f7f6 60 | 6.558002963e-04 e2e4 c7c5 c2c3 d8a5 61 | 5.990361169e-04 d2d4 g8f6 g1f3 e7e6 62 | 5.213501809e-04 g1f3 g8f6 c2c4 e7e6 63 | 4.722611394e-04 g1f3 g8f6 d2d4 e7e6 64 | 4.517474270e-04 c2c4 g8f6 g1f3 e7e6 65 | 4.371516737e-04 e2e4 c7c5 b2b3 b8c6 66 | 4.203707179e-04 g1f3 c7c5 c2c4 g8f6 67 | 4.159817238e-04 c2c4 c7c5 b1c3 g8f6 68 | 3.655393463e-04 c2c4 g8f6 b1c3 d7d5 69 | 3.540815150e-04 c2c4 e7e5 b1c3 g8f6 70 | 3.434763150e-04 e2e4 c7c5 b2b3 b7b6 71 | 3.256869988e-04 e2e4 e7e5 f2f4 f8c5 72 | 3.023214067e-04 g1f3 g8f6 c2c4 c7c5 73 | 2.619600464e-04 c2c4 g8f6 g1f3 c7c5 74 | 2.112627808e-04 d2d4 d7d5 g1f3 b8c6 75 | 2.062651233e-04 g1f3 c7c5 e2e4 d7d6 76 | 1.994298634e-04 c2c4 c7c6 g2g3 d7d5 77 | 1.952414011e-04 d2d4 g8f6 g1f3 g7g6 78 | 1.877464695e-04 g1f3 c7c5 e2e4 b8c6 79 | 1.539221491e-04 g1f3 g8f6 d2d4 g7g6 80 | 1.429845361e-04 e2e4 e7e5 f2f4 d8h4 81 | 1.260724441e-04 d2d4 c7c6 c2c4 d7d5 82 | 1.260724441e-04 d2d4 c7c6 e2e4 d7d5 83 | 1.109326142e-04 d2d4 g8f6 g1f3 d7d5 84 | 8.745576655e-05 g1f3 g8f6 d2d4 d7d5 85 | 8.322964385e-05 d2d4 e7e6 e2e4 d7d5 86 | 8.200232508e-05 d2d4 g8f6 c2c4 c7c6 87 | 7.507630955e-05 g1f3 d7d5 d2d4 g8f6 88 | 7.306698256e-05 g1f3 c7c5 e2e4 e7e6 89 | 6.880740635e-05 c2c4 g8f6 b1c3 g7g6 90 | 6.271294495e-05 c2c4 e7e5 b1c3 b8c6 91 | 5.697996096e-05 c2c4 c7c6 d2d4 d7d5 92 | 5.697996096e-05 c2c4 c7c6 g1f3 d7d5 93 | 4.010386007e-05 g1f3 g8f6 c2c4 g7g6 94 | 3.772284452e-05 c2c4 e7e6 g2g3 d7d5 95 | 3.474980207e-05 c2c4 g8f6 g1f3 g7g6 96 | 2.776421082e-05 g1f3 g8f6 c2c4 c7c6 97 | 2.405755528e-05 c2c4 g8f6 g1f3 c7c6 98 | 2.122635212e-05 d2d4 e7e6 c2c4 g8f6 99 | 
2.122635212e-05 d2d4 e7e6 c2c4 d7d5 100 | 1.937759535e-05 c2c4 e7e5 b1c3 f8b4 101 | 1.648512619e-05 g1f3 c7c5 c2c4 b8c6 102 | 1.648512619e-05 g1f3 c7c5 c2c4 g7g6 103 | 9.864957632e-06 c2c4 e7e5 b1c3 d7d6 104 | 8.874609139e-06 d2d4 g8f6 g1f3 c7c6 105 | 8.600925794e-06 c2c4 g8f6 b1c3 d7d6 106 | 8.600925794e-06 c2c4 g8f6 b1c3 c7c6 107 | 8.600925794e-06 c2c4 g8f6 b1c3 c7c5 108 | 6.996461324e-06 g1f3 g8f6 d2d4 c7c6 109 | 6.921929249e-06 g1f3 d7d5 d2d4 b8c6 110 | 6.169824627e-06 g1f3 g8f6 c2c4 d7d6 111 | 5.495042064e-06 g1f3 c7c5 c2c4 b7b6 112 | 5.495042064e-06 g1f3 c7c5 c2c4 e7e6 113 | 5.346123396e-06 c2c4 g8f6 g1f3 d7d6 114 | 5.333099023e-06 c2c4 c7c5 b1c3 g7g6 115 | 5.333099023e-06 c2c4 c7c5 b1c3 b7b6 116 | 3.627196589e-06 c2c4 e7e6 g1f3 g8f6 117 | 3.627196589e-06 c2c4 e7e6 g1f3 b7b6 118 | 2.747521032e-06 g1f3 c7c5 c2c4 d7d6 119 | 1.129543099e-06 c2c4 e7e6 d2d4 b7b6 120 | 7.046398309e-07 c2c4 e7e5 b1c3 f7f5 121 | 3.115980962e-07 c2c4 e7e6 d2d4 f8b4 122 | 4.868720254e-09 c2c4 e7e6 d2d4 g8f6 123 | 4.868720254e-09 c2c4 e7e6 d2d4 d7d5 124 | -------------------------------------------------------------------------------- /scripts/opening/data/readme.MD: -------------------------------------------------------------------------------- 1 | All of the text files here have been generated with polyglot_to_seqs.py using bookfish.bin 2 | 3 | bookfish.bin is available here: http://rebel13.nl/download/books.html 4 | -------------------------------------------------------------------------------- /scripts/opening/polyglot_to_seqs.py: -------------------------------------------------------------------------------- 1 | """This will convert a polyglot book to a list of move sequences and corresponding probabilities 2 | given that moves are selected in proportion to their polyglot weight (i.e., win-rate) 3 | 4 | E.g.: 5 | """ 6 | 7 | 8 | import chess 9 | import chess.polyglot 10 | from collections import namedtuple 11 | import fire 12 | import math 13 | import numpy as np 14 | 15 | Entry 
= namedtuple('Entry', 'move weight') 16 | 17 | def dfs(reader, board, depth, ll=0): 18 | """Yield a list of move-sequences and corresponding log-likelihoods""" 19 | total_weight = 0 20 | entries = [] 21 | for pentry in reader.find_all(board): 22 | total_weight += pentry.weight 23 | entries.append(Entry(pentry.move(), pentry.weight)) 24 | for entry in entries: 25 | board.push(entry.move) 26 | ll_cur = ll + math.log(entry.weight/total_weight) 27 | if depth<=1: 28 | moves = ' '.join(move.uci() for move in board.move_stack) 29 | yield((moves, ll_cur)) 30 | else: 31 | yield from dfs(reader, board, depth-1, ll_cur) 32 | board.pop() 33 | 34 | def main(book_filename, length): 35 | """Output a list of fixed-length move-sequences and corresponding probabilities 36 | in a polyglot opening book 37 | 38 | book_filename: Filename of polyglot opening book 39 | length: Length of move sequences to extract 40 | """ 41 | assert(length>=1) 42 | with chess.polyglot.open_reader(book_filename) as reader: 43 | moves, lls = zip(*dfs(reader, chess.Board(), length)) 44 | lls = np.array(lls) # log likelihoods 45 | ls = np.exp(lls - max(lls)) # likelihoods 46 | ps = ls/sum(ls) # probabilities 47 | seqs = zip(ps, moves) 48 | seqs = sorted(seqs, key=lambda it: it[0], reverse=True) 49 | for item in seqs: 50 | print("{:.9e} {}".format(*item)) 51 | 52 | # main("bookfish.bin", 8) 53 | fire.Fire(main) -------------------------------------------------------------------------------- /scripts/train_to_pgn.py: -------------------------------------------------------------------------------- 1 | '''This will convert a Leela Chess tar training file to PGN.
2 | 3 | Usage example: 4 | python train_to_pgn.py games14620000.tar.gz 5 | 6 | -- This will output a file called games14620000.pgn 7 | ''' 8 | 9 | 10 | from lcztools.testing import TarTrainingFile 11 | from lcztools.util import tqdm 12 | import fire 13 | 14 | 15 | def train_to_pgn(train_filename): 16 | TarTrainingFile(train_filename).to_pgn() 17 | 18 | if __name__ == '__main__': 19 | fire.Fire(train_to_pgn) 20 | -------------------------------------------------------------------------------- /scripts/wip/analyse_games.py: -------------------------------------------------------------------------------- 1 | import zipfile 2 | import chess.pgn 3 | import io 4 | import tqdm 5 | import chess 6 | import chess.uci 7 | import re 8 | import numpy as np 9 | import os 10 | 11 | try: 12 | engine.quit() 13 | except Exception: 14 | pass 15 | 16 | 17 | if True: 18 | command = [os.path.expanduser('~/Downloads/stockfish-9-mac/Mac/stockfish-9-popcnt')] 19 | # command = ['/Applications/chessx.app/Contents/MacOS/data/engines-mac/uci/stockfish-8-64'] 20 | engine = chess.uci.popen_engine(command) 21 | 22 | engine.uci() 23 | 24 | DEPTH = 9 25 | ZEROS = [0]*DEPTH 26 | 27 | def _info(self, arg): 28 | desc, score = regex.search(arg).groups() 29 | score = int(score) 30 | if desc == 'cp': 31 | self._v_cps[self._v_depth] = score 32 | elif desc == 'mate': 33 | self._v_mates[self._v_depth] = score 34 | else: 35 | raise Exception("oops") 36 | self._v_depth += 1 37 | 38 | stripfen = str.maketrans('', '', '0123456789/') 39 | 40 | engine.__class__._info = _info 41 | engine._v_cps = ZEROS.copy() 42 | engine._v_mates = ZEROS.copy() 43 | engine._v_depth = 0 44 | 45 | regex = re.compile('score ([^ ]*) ([^ ]*) ') 46 | 47 | 48 | columns = ['name', 'fen', 'turn', 'castling', 'en_passant', 'halfmove', 'fullmove', 'npieces'] 49 | columns += ['cp_{}'.format(n) for n in range(1, DEPTH+1)] 50 | columns += ['mate_{}'.format(n) for n in range(1, DEPTH+1)] 51 | columns += ['uci'] 52 | columns += ['finished', 'score', 'threefold',
'fifty', 'insufficient', 'stalemate'] 53 | all_records = [] 54 | def scoreit(game, name): 55 | global board 56 | board = chess.Board() 57 | records = [] 58 | for idx, move in enumerate(game.main_line()): 59 | engine.position(board) 60 | engine._v_cps[:] = ZEROS 61 | engine._v_mates[:] = ZEROS 62 | engine._v_depth = 0 63 | engine.go(depth=DEPTH) 64 | fensplit = board.fen().split() 65 | npieces = len(fensplit[0].translate(stripfen)) 66 | records.append([name] + fensplit + [npieces] + \ 67 | engine._v_cps + engine._v_mates + \ 68 | [move.uci()]) 69 | board.push(move) 70 | if board.can_claim_threefold_repetition(): 71 | finished = 1 72 | score = 0 73 | threefold = 1 74 | fifty = 0 75 | insufficient = 0 76 | stalemate = 0 77 | elif board.can_claim_fifty_moves(): 78 | finished = 1 79 | score = 0 80 | threefold = 0 81 | fifty = 1 82 | insufficient = 0 83 | stalemate = 0 84 | elif board.is_insufficient_material(): 85 | finished = 1 86 | score = 0 87 | threefold = 0 88 | fifty = 0 89 | insufficient = 1 90 | stalemate = 0 91 | elif board.is_stalemate(): 92 | finished = 1 93 | score = 0 94 | threefold = 0 95 | fifty = 0 96 | insufficient = 0 97 | stalemate = 1 98 | elif board.is_checkmate(): 99 | finished = 1 100 | score = [1,0][board.turn] 101 | threefold = 0 102 | fifty = 0 103 | insufficient = 0 104 | stalemate = 0 105 | elif board.can_claim_draw(): 106 | # Don't think this can happen... 
107 | raise Exception("Why") 108 | finished = 1 109 | score = 0 110 | threefold = 0 111 | fifty = 0 112 | insufficient = 0 113 | stalemate = 0 114 | elif not board.is_game_over(): 115 | finished = 0 116 | score = 0 117 | threefold = 0 118 | fifty = 0 119 | insufficient = 0 120 | stalemate = 0 121 | else: 122 | raise Exception("I don't know") 123 | result = [finished, score, threefold, fifty, insufficient, stalemate] 124 | for record in records: 125 | all_records.append(record + result) 126 | 127 | 128 | # with open('/Volumes/SeagateExternal/leela_data/match_games/LeelaMatchGamesPgn_200000-300000.zip', mode='r') as f: 129 | with zipfile.ZipFile('/Volumes/SeagateExternal/leela_data/match_games/FixedLeelaMatchGamesPgn_200000-300000.zip') as z: 130 | namelist = z.namelist() 131 | for name in tqdm.tqdm(namelist[:20]): 132 | if name.endswith('.pgn'): 133 | game = chess.pgn.read_game(io.TextIOWrapper(z.open(name))) 134 | scoreit(game, name.split('/')[-1].split('.')[0]) 135 | 136 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Package setup script.""" 2 | import setuptools 3 | 4 | setuptools.setup( 5 | name='lczero-tools', 6 | version='0.2', 7 | packages=setuptools.find_packages('src'), 8 | package_dir={'': 'src'}, 9 | install_requires=[ 10 | 'numpy', 11 | 'python-chess', 12 | ], 13 | extras_require={ 14 | 'tf': ['tensorflow'], 15 | 'tf-gpu': ['tensorflow-gpu'], 16 | 'torch': ['torch'], 17 | 'util': ['tqdm', 'requests', 'BeautifulSoup4', 'fire'], 18 | }, 19 | setup_requires=[], 20 | tests_require=[], 21 | ) 22 | -------------------------------------------------------------------------------- /src/lcztools/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from lcztools._leela_board import LeelaBoard 4 | from lcztools.backend import load_network, LeelaNet, 
list_backends 5 | from . import testing, backend 6 | 7 | 8 | -------------------------------------------------------------------------------- /src/lcztools/_leela_board.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import numpy as np 3 | import chess 4 | from chess import Move 5 | import struct 6 | from lcztools._uci_to_idx import uci_to_idx as _uci_to_idx 7 | import zlib 8 | 9 | flat_planes = [] 10 | for i in range(256): 11 | flat_planes.append(np.ones((8,8), dtype=np.uint8)*i) 12 | 13 | LeelaBoardData = collections.namedtuple('LeelaBoardData', 14 | 'plane_bytes repetition ' 15 | 'transposition_key us_ooo us_oo them_ooo them_oo ' 16 | 'side_to_move rule50_count') 17 | 18 | def pc_board_property(propertyname): 19 | '''Create a property based on self.pc_board''' 20 | def prop(self): 21 | return getattr(self.pc_board, propertyname) 22 | return property(prop) 23 | 24 | class LeelaBoard: 25 | turn = pc_board_property('turn') 26 | move_stack = pc_board_property('move_stack') 27 | _plane_bytes_struct = struct.Struct('>Q') 28 | 29 | def __init__(self, leela_board = None, *args, **kwargs): 30 | '''Create a new board. Note: the leela_board argument is currently unused; use copy() to duplicate a board''' 31 | self.pc_board = chess.Board(*args, **kwargs) 32 | self.lcz_stack = [] 33 | self._lcz_transposition_counter = collections.Counter() 34 | self._lcz_push() 35 | self.is_game_over = self.pc_method('is_game_over') 36 | self.can_claim_draw = self.pc_method('can_claim_draw') 37 | self.generate_legal_moves = self.pc_method('generate_legal_moves') 38 | 39 | def copy(self, history=7): 40 | """Note!
Currently the copy constructor uses pc_board.copy(stack=False), which makes pops impossible""" 41 | cls = type(self) 42 | copied = cls.__new__(cls) 43 | copied.pc_board = self.pc_board.copy(stack=False) 44 | copied.pc_board.stack[:] = self.pc_board.stack[-history:] 45 | copied.pc_board.move_stack[:] = self.pc_board.move_stack[-history:] 46 | copied.lcz_stack = self.lcz_stack[-history:] 47 | copied._lcz_transposition_counter = self._lcz_transposition_counter.copy() 48 | copied.is_game_over = copied.pc_method('is_game_over') 49 | copied.can_claim_draw = copied.pc_method('can_claim_draw') 50 | copied.generate_legal_moves = copied.pc_method('generate_legal_moves') 51 | return copied 52 | 53 | def pc_method(self, methodname): 54 | '''Return attribute of self.pc_board, useful for copying method bindings''' 55 | return getattr(self.pc_board, methodname) 56 | 57 | def is_threefold(self): 58 | transposition_key = self.pc_board._transposition_key() 59 | return self._lcz_transposition_counter[transposition_key] >= 3 60 | 61 | def is_fifty_moves(self): 62 | return self.pc_board.halfmove_clock >= 100 63 | 64 | def is_draw(self): 65 | return self.is_threefold() or self.is_fifty_moves() 66 | 67 | def push(self, move): 68 | self.pc_board.push(move) 69 | self._lcz_push() 70 | 71 | def push_uci(self, uci): 72 | # don't check for legality - it takes much longer to run... 73 | # self.pc_board.push_uci(uci) 74 | self.pc_board.push(Move.from_uci(uci)) 75 | self._lcz_push() 76 | 77 | def push_san(self, san): 78 | self.pc_board.push_san(san) 79 | self._lcz_push() 80 | 81 | def pop(self): 82 | result = self.pc_board.pop() 83 | _lcz_data = self.lcz_stack.pop() 84 | self._lcz_transposition_counter.subtract((_lcz_data.transposition_key,)) 85 | return result 86 | 87 | def _plane_bytes_iter(self): 88 | """Get plane bytes... 
used for _lcz_push""" 89 | pack = self._plane_bytes_struct.pack 90 | pieces_mask = self.pc_board.pieces_mask 91 | for color in (True, False): 92 | for piece_type in range(1,7): 93 | byts = pack(pieces_mask(piece_type, color)) 94 | yield byts 95 | 96 | def _lcz_push(self): 97 | """Push data onto the lcz data stack after pushing board moves""" 98 | transposition_key = self.pc_board._transposition_key() 99 | self._lcz_transposition_counter.update((transposition_key,)) 100 | repetitions = self._lcz_transposition_counter[transposition_key] - 1 101 | # side_to_move = 0 if we're white, 1 if we're black 102 | side_to_move = 0 if self.pc_board.turn else 1 103 | rule50_count = self.pc_board.halfmove_clock 104 | # Figure out castling rights 105 | if not side_to_move: 106 | # we're white 107 | _c = self.pc_board.castling_rights 108 | us_ooo, us_oo = (_c>>chess.A1) & 1, (_c>>chess.H1) & 1 109 | them_ooo, them_oo = (_c>>chess.A8) & 1, (_c>>chess.H8) & 1 110 | else: 111 | # We're black 112 | _c = self.pc_board.castling_rights 113 | us_ooo, us_oo = (_c>>chess.A8) & 1, (_c>>chess.H8) & 1 114 | them_ooo, them_oo = (_c>>chess.A1) & 1, (_c>>chess.H1) & 1 115 | # Create 13 planes... 
6 us, 6 them, repetitions>=1 116 | plane_bytes = b''.join(self._plane_bytes_iter()) 117 | repetition = (repetitions>=1) 118 | lcz_data = LeelaBoardData( 119 | plane_bytes, repetition=repetition, 120 | transposition_key=transposition_key, 121 | us_ooo=us_ooo, us_oo=us_oo, them_ooo=them_ooo, them_oo=them_oo, 122 | side_to_move=side_to_move, rule50_count=rule50_count 123 | ) 124 | self.lcz_stack.append(lcz_data) 125 | 126 | def serialize_features(self): 127 | '''Get compacted bytes representation of input planes''' 128 | planes = [] 129 | curdata = self.lcz_stack[-1] 130 | bytes_false_true = bytes([False]), bytes([True]) 131 | bytes_per_history = 97 132 | total_plane_bytes = bytes_per_history * 8 133 | def bytes_iter(): 134 | plane_bytes_yielded = 0 135 | for data in self.lcz_stack[-1:-9:-1]: 136 | yield data.plane_bytes 137 | yield bytes_false_true[data.repetition] 138 | plane_bytes_yielded += bytes_per_history 139 | # 104 total piece planes... fill in missing with 0s 140 | yield bytes(total_plane_bytes - plane_bytes_yielded) 141 | # Yield the rest of the constant planes 142 | yield np.packbits((curdata.us_ooo, 143 | curdata.us_oo, 144 | curdata.them_ooo, 145 | curdata.them_oo, 146 | curdata.side_to_move)).tobytes() 147 | yield bytes((curdata.rule50_count,)) # single byte; chr().encode() would yield 2 bytes for values >= 128 148 | return b''.join(bytes_iter()) 149 | 150 | @classmethod 151 | def deserialize_features(cls, serialized): 152 | planes_stack = [] 153 | rule50_count = serialized[-1] # last byte is rule 50 154 | board_attrs = np.unpackbits(memoryview(serialized[-2:-1])) # second to last byte 155 | us_ooo, us_oo, them_ooo, them_oo, side_to_move = board_attrs[:5] 156 | bytes_per_history = 97 157 | for history_idx in range(0, bytes_per_history*8, bytes_per_history): 158 | plane_bytes = serialized[history_idx:history_idx+96] 159 | repetition = serialized[history_idx+96] 160 | if not side_to_move: 161 | # we're white 162 | planes = (np.unpackbits(memoryview(plane_bytes))[::-1] 163 | .reshape(12, 8, 8)[::-1]) 164 | else: 165 |
# We're black 166 | planes = (np.unpackbits(memoryview(plane_bytes))[::-1] 167 | .reshape(12, 8, 8)[::-1] 168 | .reshape(2,6,8,8)[::-1,:,::-1] 169 | .reshape(12, 8,8)) 170 | planes_stack.append(planes) 171 | planes_stack.append([flat_planes[repetition]]) 172 | planes_stack.append([flat_planes[us_ooo], 173 | flat_planes[us_oo], 174 | flat_planes[them_ooo], 175 | flat_planes[them_oo], 176 | flat_planes[side_to_move], 177 | flat_planes[rule50_count], 178 | flat_planes[0], 179 | flat_planes[1]]) 180 | planes = np.concatenate(planes_stack) 181 | return planes 182 | 183 | def lcz_features(self): 184 | '''Get neural network input planes as uint8''' 185 | # print(list(self._planes_iter())) 186 | planes_stack = [] 187 | curdata = self.lcz_stack[-1] 188 | planes_yielded = 0 189 | for data in self.lcz_stack[-1:-9:-1]: 190 | plane_bytes = data.plane_bytes 191 | if not curdata.side_to_move: 192 | # we're white 193 | planes = (np.unpackbits(memoryview(plane_bytes))[::-1] 194 | .reshape(12, 8, 8)[::-1]) 195 | else: 196 | # We're black 197 | planes = (np.unpackbits(memoryview(plane_bytes))[::-1] 198 | .reshape(12, 8, 8)[::-1] 199 | .reshape(2,6,8,8)[::-1,:,::-1] 200 | .reshape(12, 8,8)) 201 | planes_stack.append(planes) 202 | planes_stack.append([flat_planes[data.repetition]]) 203 | planes_yielded += 13 204 | empty_planes = [flat_planes[0] for _ in range(104-planes_yielded)] 205 | if empty_planes: 206 | planes_stack.append(empty_planes) 207 | # Yield the rest of the constant planes 208 | planes_stack.append([flat_planes[curdata.us_ooo], 209 | flat_planes[curdata.us_oo], 210 | flat_planes[curdata.them_ooo], 211 | flat_planes[curdata.them_oo], 212 | flat_planes[curdata.side_to_move], 213 | flat_planes[curdata.rule50_count], 214 | flat_planes[0], 215 | flat_planes[1]]) 216 | planes = np.concatenate(planes_stack) 217 | return planes 218 | 219 | def lcz_uci_to_idx(self, uci_list): 220 | # Return list of NN policy output indexes for this board position, given uci_list 221 | 222 | # 
TODO: Perhaps it's possible to just add the uci knight promotion move to the index dict 223 | # currently knight promotions are not in the dict 224 | uci_list = [uci.rstrip('n') for uci in uci_list] 225 | 226 | data = self.lcz_stack[-1] 227 | # uci_to_idx_index = 228 | # White, no-castling => 0 229 | # White, castling => 1 230 | # Black, no-castling => 2 231 | # Black, castling => 3 232 | uci_to_idx_index = (data.us_ooo | data.us_oo) + 2*data.side_to_move 233 | uci_idx_dct = _uci_to_idx[uci_to_idx_index] 234 | return [uci_idx_dct[m] for m in uci_list] 235 | 236 | @classmethod 237 | def compress_features(cls, features): 238 | """Compress a features array as returned from lcz_features method""" 239 | features_8 = features.astype(np.uint8) 240 | # Simple compression would do this... 241 | # return zlib.compress(features_8) 242 | piece_plane_bytes = np.packbits(features_8[:-8]).tobytes() 243 | scalar_bytes = features_8[-8:][:,0,0].tobytes() 244 | compressed = zlib.compress(piece_plane_bytes + scalar_bytes) 245 | return compressed 246 | 247 | @classmethod 248 | def decompress_features(cls, compressed_features): 249 | """Decompress a compressed features array from compress_features""" 250 | decompressed = zlib.decompress(compressed_features) 251 | # Simple decompression would do this 252 | # return np.frombuffer(decompressed, dtype=np.uint8).astype(np.float32).reshape(-1,8,8) 253 | piece_plane_bytes = decompressed[:-8] 254 | scalar_bytes = decompressed[-8:] 255 | piece_plane_arr = np.unpackbits(memoryview(piece_plane_bytes)) 256 | scalar_arr = np.frombuffer(scalar_bytes, dtype=np.uint8).repeat(64) 257 | result = np.concatenate((piece_plane_arr, scalar_arr)).astype(np.float32).reshape(-1,8,8) 258 | return result 259 | 260 | def unicode(self): 261 | if self.pc_board.is_game_over() or self.is_draw(): 262 | result = self.pc_board.result(claim_draw=True) 263 | turnstring = 'Result: {}'.format(result) 264 | else: 265 | turnstring = 'Turn: {}'.format('White' if 
self.pc_board.turn else 'Black') 266 | boardstr = self.pc_board.unicode() + "\n" + turnstring 267 | return boardstr 268 | 269 | def __repr__(self): 270 | return "LeelaBoard('{}')".format(self.pc_board.fen()) 271 | 272 | def _repr_svg_(self): 273 | return self.pc_board._repr_svg_() 274 | 275 | def __str__(self): 276 | if self.pc_board.is_game_over() or self.is_draw(): 277 | result = self.pc_board.result(claim_draw=True) 278 | turnstring = 'Result: {}'.format(result) 279 | else: 280 | turnstring = 'Turn: {}'.format('White' if self.pc_board.turn else 'Black') 281 | boardstr = self.pc_board.__str__() + "\n" + turnstring 282 | return boardstr 283 | 284 | def __eq__(self, other): 285 | return self.get_hash_key() == other.get_hash_key() 286 | 287 | def __hash__(self): 288 | return hash(self.get_hash_key()) 289 | 290 | def get_hash_key(self): 291 | transposition_key = self.pc_board._transposition_key() 292 | return (transposition_key + 293 | (self._lcz_transposition_counter[transposition_key], self.pc_board.halfmove_clock) + 294 | tuple(self.pc_board.move_stack[-7:]) 295 | ) 296 | 297 | # lb = LeelaBoard() 298 | # lb.push_uci('c2c4') 299 | #lb.push_uci('c7c5') 300 | #lb.push_uci('d2d3') 301 | #lb.push_uci('c2c4') 302 | #lb.push_uci('b8c6') 303 | # saved_planes = planes 304 | # planes = lb.features() 305 | # output = leela_net(torch.from_numpy(planes).unsqueeze(0)) 306 | # output 307 | -------------------------------------------------------------------------------- /src/lcztools/_old_leela_board.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import numpy as np 3 | import chess 4 | import struct 5 | from lcztools._uci_to_idx import uci_to_idx as _uci_to_idx 6 | import zlib 7 | 8 | flat_planes = [] 9 | for i in range(256): 10 | flat_planes.append(np.ones((8,8), dtype=np.uint8)*i) 11 | 12 | OldLeelaBoardData = collections.namedtuple('OldLeelaBoardData', 13 | 'white_planes black_planes rep_planes ' 14 | 
'transposition_key us_ooo us_oo them_ooo them_oo ' 15 | 'side_to_move rule50_count') 16 | 17 | def pc_board_property(propertyname): 18 | '''Create a property based on self.pc_board''' 19 | def prop(self): 20 | return getattr(self.pc_board, propertyname) 21 | return property(prop) 22 | 23 | class OldLeelaBoard: 24 | turn = pc_board_property('turn') 25 | move_stack = pc_board_property('move_stack') 26 | 27 | def __init__(self, leela_board = None, *args, **kwargs): 28 | '''If leela_board is passed as an argument, return a copy''' 29 | if leela_board: 30 | # Copy 31 | self.pc_board = leela_board.pc_board.copy(stack=False) 32 | self.lcz_stack = leela_board.lcz_stack[:] 33 | self._lcz_transposition_counter = leela_board._lcz_transposition_counter.copy() 34 | else: 35 | self.pc_board = chess.Board(*args, **kwargs) 36 | self.lcz_stack = [] 37 | self._lcz_transposition_counter = collections.Counter() 38 | self._lcz_push() 39 | self.is_game_over = self.pc_method('is_game_over') 40 | self.can_claim_draw = self.pc_method('can_claim_draw') 41 | self.generate_legal_moves = self.pc_method('generate_legal_moves') 42 | 43 | def copy(self): 44 | """Note! 
Currently the copy constructor uses pc_board.copy(stack=False), which makes pops impossible""" 45 | return self.__class__(leela_board=self) 46 | 47 | def pc_method(self, methodname): 48 | '''Return attribute of self.pc_board, useful for copying method bindings''' 49 | return getattr(self.pc_board, methodname) 50 | 51 | def is_threefold(self): 52 | transposition_key = self.pc_board._transposition_key() 53 | return self._lcz_transposition_counter[transposition_key] >= 3 54 | 55 | def is_fifty_moves(self): 56 | return self.pc_board.halfmove_clock >= 100 57 | 58 | def is_draw(self): 59 | return self.is_threefold() or self.is_fifty_moves() 60 | 61 | def _lcz_push(self): 62 | # print("_lcz_push") 63 | # Push data onto the lcz data stack after pushing board moves 64 | transposition_key = self.pc_board._transposition_key() 65 | self._lcz_transposition_counter.update((transposition_key,)) 66 | repetitions = self._lcz_transposition_counter[transposition_key] - 1 67 | # side_to_move = 0 if we're white, 1 if we're black 68 | side_to_move = 0 if self.pc_board.turn else 1 69 | rule50_count = self.pc_board.halfmove_clock 70 | # Figure out castling rights 71 | if not side_to_move: 72 | # we're white 73 | _c = self.pc_board.castling_rights 74 | us_ooo, us_oo = (_c>>chess.A1) & 1, (_c>>chess.H1) & 1 75 | them_ooo, them_oo = (_c>>chess.A8) & 1, (_c>>chess.H8) & 1 76 | else: 77 | # We're black 78 | _c = self.pc_board.castling_rights 79 | us_ooo, us_oo = (_c>>chess.A8) & 1, (_c>>chess.H8) & 1 80 | them_ooo, them_oo = (_c>>chess.A1) & 1, (_c>>chess.H1) & 1 81 | # Create 13 planes... 
6 us, 6 them, repetitions>=1 82 | white_planes = [] 83 | black_planes = [] 84 | for color, planes in ((True, white_planes), (False, black_planes)): 85 | for piece_type in range(1,7): 86 | byts = struct.pack('>Q', self.pc_board.pieces_mask(piece_type, color)) 87 | arr = np.unpackbits(bytearray(byts))[::-1].reshape(8,8).astype(np.float32) 88 | planes.append(arr) 89 | # planes.append(flat_planes[repetitions>=1]) 90 | white_planes = np.stack(white_planes) 91 | black_planes = np.stack(black_planes) 92 | rep_planes = np.stack([flat_planes[repetitions>=1]]) 93 | lcz_data = OldLeelaBoardData( 94 | white_planes=white_planes, black_planes=black_planes, rep_planes=rep_planes, 95 | transposition_key=transposition_key, 96 | us_ooo=us_ooo, us_oo=us_oo, them_ooo=them_ooo, them_oo=them_oo, 97 | side_to_move=side_to_move, rule50_count=rule50_count 98 | ) 99 | self.lcz_stack.append(lcz_data) 100 | 101 | def push(self, move): 102 | self.pc_board.push(move) 103 | self._lcz_push() 104 | 105 | def push_uci(self, uci): 106 | self.pc_board.push_uci(uci) 107 | self._lcz_push() 108 | 109 | def push_san(self, san): 110 | self.pc_board.push_san(san) 111 | self._lcz_push() 112 | 113 | def pop(self): 114 | result = self.pc_board.pop() 115 | _lcz_data = self.lcz_stack.pop() 116 | self._lcz_transposition_counter.subtract((_lcz_data.transposition_key,)) 117 | return result 118 | 119 | def lcz_features(self): 120 | '''Get neural network input planes''' 121 | planes = [] 122 | curdata = self.lcz_stack[-1] 123 | for data in self.lcz_stack[-1:-9:-1]: 124 | if not curdata.side_to_move: 125 | # We're white 126 | planes.append(data.white_planes) 127 | planes.append(data.black_planes) 128 | planes.append(data.rep_planes) 129 | else: 130 | # We're black 131 | planes.append(data.black_planes[:,::-1]) 132 | planes.append(data.white_planes[:,::-1]) 133 | planes.append(data.rep_planes) 134 | planes = np.concatenate(planes) 135 | planes.resize((112,8,8), refcheck=False) 136 | planes[-8] = curdata.us_ooo 137 | 
planes[-7] = curdata.us_oo 138 | planes[-6] = curdata.them_ooo 139 | planes[-5] = curdata.them_oo 140 | planes[-4] = curdata.side_to_move 141 | planes[-3] = curdata.rule50_count 142 | planes[-2] = 0 143 | planes[-1] = 1 144 | return planes 145 | 146 | def lcz_features_debug(self, fake_history=False, no_history=False, real_history=7, rule50=None, allones=None): 147 | '''Get neural network input planes, with ability to modify based on parameters''' 148 | planes = [] 149 | curdata = self.lcz_stack[-1] 150 | if no_history: 151 | real_history = 0 152 | num_filled = 0 153 | for data in self.lcz_stack[-1:-9:-1]: 154 | if not curdata.side_to_move: 155 | # We're white 156 | planes.append(data.white_planes) 157 | planes.append(data.black_planes) 158 | planes.append(data.rep_planes) 159 | else: 160 | # We're black 161 | planes.append(data.black_planes[:,::-1]) 162 | planes.append(data.white_planes[:,::-1]) 163 | planes.append(data.rep_planes) 164 | num_filled +=1 165 | real_history -= 1 166 | if real_history<0: 167 | break 168 | # Augment with fake history, reusing last data 169 | if fake_history: 170 | for _ in range(8 - num_filled): 171 | if not curdata.side_to_move: 172 | # We're white 173 | planes.append(data.white_planes) 174 | planes.append(data.black_planes) 175 | planes.append(data.rep_planes) 176 | else: 177 | # We're black 178 | planes.append(data.black_planes[:,::-1]) 179 | planes.append(data.white_planes[:,::-1]) 180 | planes.append(data.rep_planes) 181 | planes = np.concatenate(planes) 182 | planes.resize(112,8,8) 183 | planes[-8] = curdata.us_ooo 184 | planes[-7] = curdata.us_oo 185 | planes[-6] = curdata.them_ooo 186 | planes[-5] = curdata.them_oo 187 | planes[-4] = curdata.side_to_move 188 | if rule50 is not None: 189 | planes[-3] = rule50 190 | else: 191 | planes[-3] = curdata.rule50_count 192 | planes[-2] = 0 193 | if allones is not None: 194 | planes[-1] = allones 195 | else: 196 | planes[-1] = 1 197 | return planes 198 | 199 | def lcz_uci_to_idx(self, 
uci_list): 200 | # Return list of NN policy output indexes for this board position, given uci_list 201 | 202 | # TODO: Perhaps it's possible to just add the uci knight promotion move to the index dict 203 | # currently knight promotions are not in the dict 204 | uci_list = [uci.rstrip('n') for uci in uci_list] 205 | 206 | data = self.lcz_stack[-1] 207 | # uci_to_idx_index = 208 | # White, no-castling => 0 209 | # White, castling => 1 210 | # Black, no-castling => 2 211 | # Black, castling => 3 212 | uci_to_idx_index = (data.us_ooo | data.us_oo) + 2*data.side_to_move 213 | uci_idx_dct = _uci_to_idx[uci_to_idx_index] 214 | return [uci_idx_dct[m] for m in uci_list] 215 | 216 | @classmethod 217 | def compress_features(cls, features): 218 | """Compress a features array as returned from lcz_features method""" 219 | features_8 = features.astype(np.uint8) 220 | # Simple compression would do this... 221 | # return zlib.compress(features_8) 222 | piece_plane_bytes = np.packbits(features_8[:-8]).tobytes() 223 | scalar_bytes = features_8[-8:][:,0,0].tobytes() 224 | compressed = zlib.compress(piece_plane_bytes + scalar_bytes) 225 | return compressed 226 | 227 | @classmethod 228 | def decompress_features(cls, compressed_features): 229 | """Decompress a compressed features array from compress_features""" 230 | decompressed = zlib.decompress(compressed_features) 231 | # Simple decompression would do this 232 | # return np.frombuffer(decompressed, dtype=np.uint8).astype(np.float32).reshape(-1,8,8) 233 | piece_plane_bytes = decompressed[:-8] 234 | scalar_bytes = decompressed[-8:] 235 | piece_plane_arr = np.unpackbits(bytearray(piece_plane_bytes)) 236 | scalar_arr = np.frombuffer(scalar_bytes, dtype=np.uint8).repeat(64) 237 | result = np.concatenate((piece_plane_arr, scalar_arr)).astype(np.float32).reshape(-1,8,8) 238 | return result 239 | 240 | 241 | def __repr__(self): 242 | return "OldLeelaBoard('{}')".format(self.pc_board.fen()) 243 | 244 | def _repr_svg_(self): 245 | return 
self.pc_board._repr_svg_() 246 | 247 | def __str__(self): 248 | boardstr = self.pc_board.__str__() + \ 249 | '\nTurn: {}'.format('White' if self.pc_board.turn else 'Black') 250 | return boardstr 251 | 252 | def __eq__(self, other): 253 | return self.get_hash_key() == other.get_hash_key() 254 | 255 | def __hash__(self): 256 | return hash(self.get_hash_key()) 257 | 258 | def get_hash_key(self): 259 | transposition_key = self.pc_board._transposition_key() 260 | return (transposition_key + 261 | (self._lcz_transposition_counter[transposition_key], self.pc_board.halfmove_clock) + 262 | tuple(self.pc_board.move_stack[-8:]) 263 | ) 264 | 265 | # lb = LeelaBoard() 266 | # lb.push_uci('c2c4') 267 | #lb.push_uci('c7c5') 268 | #lb.push_uci('d2d3') 269 | #lb.push_uci('c2c4') 270 | #lb.push_uci('b8c6') 271 | # saved_planes = planes 272 | # planes = lb.features() 273 | # output = leela_net(torch.from_numpy(planes).unsqueeze(0)) 274 | # output 275 | -------------------------------------------------------------------------------- /src/lcztools/_uci_to_idx.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # uci_to_idx is a list of four dicts {uci -> NN policy index} 4 | # 0 = white, no-castling 5 | # 1 = white, castling 6 | # 2 = black, no-castling 7 | # 3 = black, castling 8 | # Black moves are flipped, and castling moves are mapped to the e8a8, e8h8, e1a1, e1h1 indexes 9 | # from their respective UCI names 10 | 11 | uci_to_idx = [] 12 | 13 | # The index-to-uci list originates from here: 14 | # https://github.com/glinscott/leela-chess/blob/master/lc0/src/chess/bitboard.cc 15 | 16 | # White, no-castling 17 | _idx_to_move_wn = [ 18 | 'a1b1', 'a1c1', 'a1d1', 'a1e1', 'a1f1', 'a1g1', 'a1h1', 19 | 'a1a2', 'a1b2', 'a1c2', 'a1a3', 'a1b3', 'a1c3', 'a1a4', 20 | 'a1d4', 'a1a5', 'a1e5', 'a1a6', 'a1f6', 'a1a7', 'a1g7', 21 | 'a1a8', 'a1h8', 'b1a1', 'b1c1', 'b1d1', 'b1e1', 'b1f1', 22 | 'b1g1', 'b1h1', 'b1a2', 'b1b2', 'b1c2', 
'b1d2', 'b1a3', 23 | 'b1b3', 'b1c3', 'b1d3', 'b1b4', 'b1e4', 'b1b5', 'b1f5', 24 | 'b1b6', 'b1g6', 'b1b7', 'b1h7', 'b1b8', 'c1a1', 'c1b1', 25 | 'c1d1', 'c1e1', 'c1f1', 'c1g1', 'c1h1', 'c1a2', 'c1b2', 26 | 'c1c2', 'c1d2', 'c1e2', 'c1a3', 'c1b3', 'c1c3', 'c1d3', 27 | 'c1e3', 'c1c4', 'c1f4', 'c1c5', 'c1g5', 'c1c6', 'c1h6', 28 | 'c1c7', 'c1c8', 'd1a1', 'd1b1', 'd1c1', 'd1e1', 'd1f1', 29 | 'd1g1', 'd1h1', 'd1b2', 'd1c2', 'd1d2', 'd1e2', 'd1f2', 30 | 'd1b3', 'd1c3', 'd1d3', 'd1e3', 'd1f3', 'd1a4', 'd1d4', 31 | 'd1g4', 'd1d5', 'd1h5', 'd1d6', 'd1d7', 'd1d8', 'e1a1', 32 | 'e1b1', 'e1c1', 'e1d1', 'e1f1', 'e1g1', 'e1h1', 'e1c2', 33 | 'e1d2', 'e1e2', 'e1f2', 'e1g2', 'e1c3', 'e1d3', 'e1e3', 34 | 'e1f3', 'e1g3', 'e1b4', 'e1e4', 'e1h4', 'e1a5', 'e1e5', 35 | 'e1e6', 'e1e7', 'e1e8', 'f1a1', 'f1b1', 'f1c1', 'f1d1', 36 | 'f1e1', 'f1g1', 'f1h1', 'f1d2', 'f1e2', 'f1f2', 'f1g2', 37 | 'f1h2', 'f1d3', 'f1e3', 'f1f3', 'f1g3', 'f1h3', 'f1c4', 38 | 'f1f4', 'f1b5', 'f1f5', 'f1a6', 'f1f6', 'f1f7', 'f1f8', 39 | 'g1a1', 'g1b1', 'g1c1', 'g1d1', 'g1e1', 'g1f1', 'g1h1', 40 | 'g1e2', 'g1f2', 'g1g2', 'g1h2', 'g1e3', 'g1f3', 'g1g3', 41 | 'g1h3', 'g1d4', 'g1g4', 'g1c5', 'g1g5', 'g1b6', 'g1g6', 42 | 'g1a7', 'g1g7', 'g1g8', 'h1a1', 'h1b1', 'h1c1', 'h1d1', 43 | 'h1e1', 'h1f1', 'h1g1', 'h1f2', 'h1g2', 'h1h2', 'h1f3', 44 | 'h1g3', 'h1h3', 'h1e4', 'h1h4', 'h1d5', 'h1h5', 'h1c6', 45 | 'h1h6', 'h1b7', 'h1h7', 'h1a8', 'h1h8', 'a2a1', 'a2b1', 46 | 'a2c1', 'a2b2', 'a2c2', 'a2d2', 'a2e2', 'a2f2', 'a2g2', 47 | 'a2h2', 'a2a3', 'a2b3', 'a2c3', 'a2a4', 'a2b4', 'a2c4', 48 | 'a2a5', 'a2d5', 'a2a6', 'a2e6', 'a2a7', 'a2f7', 'a2a8', 49 | 'a2g8', 'b2a1', 'b2b1', 'b2c1', 'b2d1', 'b2a2', 'b2c2', 50 | 'b2d2', 'b2e2', 'b2f2', 'b2g2', 'b2h2', 'b2a3', 'b2b3', 51 | 'b2c3', 'b2d3', 'b2a4', 'b2b4', 'b2c4', 'b2d4', 'b2b5', 52 | 'b2e5', 'b2b6', 'b2f6', 'b2b7', 'b2g7', 'b2b8', 'b2h8', 53 | 'c2a1', 'c2b1', 'c2c1', 'c2d1', 'c2e1', 'c2a2', 'c2b2', 54 | 'c2d2', 'c2e2', 'c2f2', 'c2g2', 'c2h2', 'c2a3', 'c2b3', 55 | 'c2c3', 'c2d3', 'c2e3', 
'c2a4', 'c2b4', 'c2c4', 'c2d4', 56 | 'c2e4', 'c2c5', 'c2f5', 'c2c6', 'c2g6', 'c2c7', 'c2h7', 57 | 'c2c8', 'd2b1', 'd2c1', 'd2d1', 'd2e1', 'd2f1', 'd2a2', 58 | 'd2b2', 'd2c2', 'd2e2', 'd2f2', 'd2g2', 'd2h2', 'd2b3', 59 | 'd2c3', 'd2d3', 'd2e3', 'd2f3', 'd2b4', 'd2c4', 'd2d4', 60 | 'd2e4', 'd2f4', 'd2a5', 'd2d5', 'd2g5', 'd2d6', 'd2h6', 61 | 'd2d7', 'd2d8', 'e2c1', 'e2d1', 'e2e1', 'e2f1', 'e2g1', 62 | 'e2a2', 'e2b2', 'e2c2', 'e2d2', 'e2f2', 'e2g2', 'e2h2', 63 | 'e2c3', 'e2d3', 'e2e3', 'e2f3', 'e2g3', 'e2c4', 'e2d4', 64 | 'e2e4', 'e2f4', 'e2g4', 'e2b5', 'e2e5', 'e2h5', 'e2a6', 65 | 'e2e6', 'e2e7', 'e2e8', 'f2d1', 'f2e1', 'f2f1', 'f2g1', 66 | 'f2h1', 'f2a2', 'f2b2', 'f2c2', 'f2d2', 'f2e2', 'f2g2', 67 | 'f2h2', 'f2d3', 'f2e3', 'f2f3', 'f2g3', 'f2h3', 'f2d4', 68 | 'f2e4', 'f2f4', 'f2g4', 'f2h4', 'f2c5', 'f2f5', 'f2b6', 69 | 'f2f6', 'f2a7', 'f2f7', 'f2f8', 'g2e1', 'g2f1', 'g2g1', 70 | 'g2h1', 'g2a2', 'g2b2', 'g2c2', 'g2d2', 'g2e2', 'g2f2', 71 | 'g2h2', 'g2e3', 'g2f3', 'g2g3', 'g2h3', 'g2e4', 'g2f4', 72 | 'g2g4', 'g2h4', 'g2d5', 'g2g5', 'g2c6', 'g2g6', 'g2b7', 73 | 'g2g7', 'g2a8', 'g2g8', 'h2f1', 'h2g1', 'h2h1', 'h2a2', 74 | 'h2b2', 'h2c2', 'h2d2', 'h2e2', 'h2f2', 'h2g2', 'h2f3', 75 | 'h2g3', 'h2h3', 'h2f4', 'h2g4', 'h2h4', 'h2e5', 'h2h5', 76 | 'h2d6', 'h2h6', 'h2c7', 'h2h7', 'h2b8', 'h2h8', 'a3a1', 77 | 'a3b1', 'a3c1', 'a3a2', 'a3b2', 'a3c2', 'a3b3', 'a3c3', 78 | 'a3d3', 'a3e3', 'a3f3', 'a3g3', 'a3h3', 'a3a4', 'a3b4', 79 | 'a3c4', 'a3a5', 'a3b5', 'a3c5', 'a3a6', 'a3d6', 'a3a7', 80 | 'a3e7', 'a3a8', 'a3f8', 'b3a1', 'b3b1', 'b3c1', 'b3d1', 81 | 'b3a2', 'b3b2', 'b3c2', 'b3d2', 'b3a3', 'b3c3', 'b3d3', 82 | 'b3e3', 'b3f3', 'b3g3', 'b3h3', 'b3a4', 'b3b4', 'b3c4', 83 | 'b3d4', 'b3a5', 'b3b5', 'b3c5', 'b3d5', 'b3b6', 'b3e6', 84 | 'b3b7', 'b3f7', 'b3b8', 'b3g8', 'c3a1', 'c3b1', 'c3c1', 85 | 'c3d1', 'c3e1', 'c3a2', 'c3b2', 'c3c2', 'c3d2', 'c3e2', 86 | 'c3a3', 'c3b3', 'c3d3', 'c3e3', 'c3f3', 'c3g3', 'c3h3', 87 | 'c3a4', 'c3b4', 'c3c4', 'c3d4', 'c3e4', 'c3a5', 'c3b5', 88 | 'c3c5', 
'c3d5', 'c3e5', 'c3c6', 'c3f6', 'c3c7', 'c3g7', 89 | 'c3c8', 'c3h8', 'd3b1', 'd3c1', 'd3d1', 'd3e1', 'd3f1', 90 | 'd3b2', 'd3c2', 'd3d2', 'd3e2', 'd3f2', 'd3a3', 'd3b3', 91 | 'd3c3', 'd3e3', 'd3f3', 'd3g3', 'd3h3', 'd3b4', 'd3c4', 92 | 'd3d4', 'd3e4', 'd3f4', 'd3b5', 'd3c5', 'd3d5', 'd3e5', 93 | 'd3f5', 'd3a6', 'd3d6', 'd3g6', 'd3d7', 'd3h7', 'd3d8', 94 | 'e3c1', 'e3d1', 'e3e1', 'e3f1', 'e3g1', 'e3c2', 'e3d2', 95 | 'e3e2', 'e3f2', 'e3g2', 'e3a3', 'e3b3', 'e3c3', 'e3d3', 96 | 'e3f3', 'e3g3', 'e3h3', 'e3c4', 'e3d4', 'e3e4', 'e3f4', 97 | 'e3g4', 'e3c5', 'e3d5', 'e3e5', 'e3f5', 'e3g5', 'e3b6', 98 | 'e3e6', 'e3h6', 'e3a7', 'e3e7', 'e3e8', 'f3d1', 'f3e1', 99 | 'f3f1', 'f3g1', 'f3h1', 'f3d2', 'f3e2', 'f3f2', 'f3g2', 100 | 'f3h2', 'f3a3', 'f3b3', 'f3c3', 'f3d3', 'f3e3', 'f3g3', 101 | 'f3h3', 'f3d4', 'f3e4', 'f3f4', 'f3g4', 'f3h4', 'f3d5', 102 | 'f3e5', 'f3f5', 'f3g5', 'f3h5', 'f3c6', 'f3f6', 'f3b7', 103 | 'f3f7', 'f3a8', 'f3f8', 'g3e1', 'g3f1', 'g3g1', 'g3h1', 104 | 'g3e2', 'g3f2', 'g3g2', 'g3h2', 'g3a3', 'g3b3', 'g3c3', 105 | 'g3d3', 'g3e3', 'g3f3', 'g3h3', 'g3e4', 'g3f4', 'g3g4', 106 | 'g3h4', 'g3e5', 'g3f5', 'g3g5', 'g3h5', 'g3d6', 'g3g6', 107 | 'g3c7', 'g3g7', 'g3b8', 'g3g8', 'h3f1', 'h3g1', 'h3h1', 108 | 'h3f2', 'h3g2', 'h3h2', 'h3a3', 'h3b3', 'h3c3', 'h3d3', 109 | 'h3e3', 'h3f3', 'h3g3', 'h3f4', 'h3g4', 'h3h4', 'h3f5', 110 | 'h3g5', 'h3h5', 'h3e6', 'h3h6', 'h3d7', 'h3h7', 'h3c8', 111 | 'h3h8', 'a4a1', 'a4d1', 'a4a2', 'a4b2', 'a4c2', 'a4a3', 112 | 'a4b3', 'a4c3', 'a4b4', 'a4c4', 'a4d4', 'a4e4', 'a4f4', 113 | 'a4g4', 'a4h4', 'a4a5', 'a4b5', 'a4c5', 'a4a6', 'a4b6', 114 | 'a4c6', 'a4a7', 'a4d7', 'a4a8', 'a4e8', 'b4b1', 'b4e1', 115 | 'b4a2', 'b4b2', 'b4c2', 'b4d2', 'b4a3', 'b4b3', 'b4c3', 116 | 'b4d3', 'b4a4', 'b4c4', 'b4d4', 'b4e4', 'b4f4', 'b4g4', 117 | 'b4h4', 'b4a5', 'b4b5', 'b4c5', 'b4d5', 'b4a6', 'b4b6', 118 | 'b4c6', 'b4d6', 'b4b7', 'b4e7', 'b4b8', 'b4f8', 'c4c1', 119 | 'c4f1', 'c4a2', 'c4b2', 'c4c2', 'c4d2', 'c4e2', 'c4a3', 120 | 'c4b3', 'c4c3', 'c4d3', 'c4e3', 
'c4a4', 'c4b4', 'c4d4', 121 | 'c4e4', 'c4f4', 'c4g4', 'c4h4', 'c4a5', 'c4b5', 'c4c5', 122 | 'c4d5', 'c4e5', 'c4a6', 'c4b6', 'c4c6', 'c4d6', 'c4e6', 123 | 'c4c7', 'c4f7', 'c4c8', 'c4g8', 'd4a1', 'd4d1', 'd4g1', 124 | 'd4b2', 'd4c2', 'd4d2', 'd4e2', 'd4f2', 'd4b3', 'd4c3', 125 | 'd4d3', 'd4e3', 'd4f3', 'd4a4', 'd4b4', 'd4c4', 'd4e4', 126 | 'd4f4', 'd4g4', 'd4h4', 'd4b5', 'd4c5', 'd4d5', 'd4e5', 127 | 'd4f5', 'd4b6', 'd4c6', 'd4d6', 'd4e6', 'd4f6', 'd4a7', 128 | 'd4d7', 'd4g7', 'd4d8', 'd4h8', 'e4b1', 'e4e1', 'e4h1', 129 | 'e4c2', 'e4d2', 'e4e2', 'e4f2', 'e4g2', 'e4c3', 'e4d3', 130 | 'e4e3', 'e4f3', 'e4g3', 'e4a4', 'e4b4', 'e4c4', 'e4d4', 131 | 'e4f4', 'e4g4', 'e4h4', 'e4c5', 'e4d5', 'e4e5', 'e4f5', 132 | 'e4g5', 'e4c6', 'e4d6', 'e4e6', 'e4f6', 'e4g6', 'e4b7', 133 | 'e4e7', 'e4h7', 'e4a8', 'e4e8', 'f4c1', 'f4f1', 'f4d2', 134 | 'f4e2', 'f4f2', 'f4g2', 'f4h2', 'f4d3', 'f4e3', 'f4f3', 135 | 'f4g3', 'f4h3', 'f4a4', 'f4b4', 'f4c4', 'f4d4', 'f4e4', 136 | 'f4g4', 'f4h4', 'f4d5', 'f4e5', 'f4f5', 'f4g5', 'f4h5', 137 | 'f4d6', 'f4e6', 'f4f6', 'f4g6', 'f4h6', 'f4c7', 'f4f7', 138 | 'f4b8', 'f4f8', 'g4d1', 'g4g1', 'g4e2', 'g4f2', 'g4g2', 139 | 'g4h2', 'g4e3', 'g4f3', 'g4g3', 'g4h3', 'g4a4', 'g4b4', 140 | 'g4c4', 'g4d4', 'g4e4', 'g4f4', 'g4h4', 'g4e5', 'g4f5', 141 | 'g4g5', 'g4h5', 'g4e6', 'g4f6', 'g4g6', 'g4h6', 'g4d7', 142 | 'g4g7', 'g4c8', 'g4g8', 'h4e1', 'h4h1', 'h4f2', 'h4g2', 143 | 'h4h2', 'h4f3', 'h4g3', 'h4h3', 'h4a4', 'h4b4', 'h4c4', 144 | 'h4d4', 'h4e4', 'h4f4', 'h4g4', 'h4f5', 'h4g5', 'h4h5', 145 | 'h4f6', 'h4g6', 'h4h6', 'h4e7', 'h4h7', 'h4d8', 'h4h8', 146 | 'a5a1', 'a5e1', 'a5a2', 'a5d2', 'a5a3', 'a5b3', 'a5c3', 147 | 'a5a4', 'a5b4', 'a5c4', 'a5b5', 'a5c5', 'a5d5', 'a5e5', 148 | 'a5f5', 'a5g5', 'a5h5', 'a5a6', 'a5b6', 'a5c6', 'a5a7', 149 | 'a5b7', 'a5c7', 'a5a8', 'a5d8', 'b5b1', 'b5f1', 'b5b2', 150 | 'b5e2', 'b5a3', 'b5b3', 'b5c3', 'b5d3', 'b5a4', 'b5b4', 151 | 'b5c4', 'b5d4', 'b5a5', 'b5c5', 'b5d5', 'b5e5', 'b5f5', 152 | 'b5g5', 'b5h5', 'b5a6', 'b5b6', 'b5c6', 'b5d6', 
'b5a7', 153 | 'b5b7', 'b5c7', 'b5d7', 'b5b8', 'b5e8', 'c5c1', 'c5g1', 154 | 'c5c2', 'c5f2', 'c5a3', 'c5b3', 'c5c3', 'c5d3', 'c5e3', 155 | 'c5a4', 'c5b4', 'c5c4', 'c5d4', 'c5e4', 'c5a5', 'c5b5', 156 | 'c5d5', 'c5e5', 'c5f5', 'c5g5', 'c5h5', 'c5a6', 'c5b6', 157 | 'c5c6', 'c5d6', 'c5e6', 'c5a7', 'c5b7', 'c5c7', 'c5d7', 158 | 'c5e7', 'c5c8', 'c5f8', 'd5d1', 'd5h1', 'd5a2', 'd5d2', 159 | 'd5g2', 'd5b3', 'd5c3', 'd5d3', 'd5e3', 'd5f3', 'd5b4', 160 | 'd5c4', 'd5d4', 'd5e4', 'd5f4', 'd5a5', 'd5b5', 'd5c5', 161 | 'd5e5', 'd5f5', 'd5g5', 'd5h5', 'd5b6', 'd5c6', 'd5d6', 162 | 'd5e6', 'd5f6', 'd5b7', 'd5c7', 'd5d7', 'd5e7', 'd5f7', 163 | 'd5a8', 'd5d8', 'd5g8', 'e5a1', 'e5e1', 'e5b2', 'e5e2', 164 | 'e5h2', 'e5c3', 'e5d3', 'e5e3', 'e5f3', 'e5g3', 'e5c4', 165 | 'e5d4', 'e5e4', 'e5f4', 'e5g4', 'e5a5', 'e5b5', 'e5c5', 166 | 'e5d5', 'e5f5', 'e5g5', 'e5h5', 'e5c6', 'e5d6', 'e5e6', 167 | 'e5f6', 'e5g6', 'e5c7', 'e5d7', 'e5e7', 'e5f7', 'e5g7', 168 | 'e5b8', 'e5e8', 'e5h8', 'f5b1', 'f5f1', 'f5c2', 'f5f2', 169 | 'f5d3', 'f5e3', 'f5f3', 'f5g3', 'f5h3', 'f5d4', 'f5e4', 170 | 'f5f4', 'f5g4', 'f5h4', 'f5a5', 'f5b5', 'f5c5', 'f5d5', 171 | 'f5e5', 'f5g5', 'f5h5', 'f5d6', 'f5e6', 'f5f6', 'f5g6', 172 | 'f5h6', 'f5d7', 'f5e7', 'f5f7', 'f5g7', 'f5h7', 'f5c8', 173 | 'f5f8', 'g5c1', 'g5g1', 'g5d2', 'g5g2', 'g5e3', 'g5f3', 174 | 'g5g3', 'g5h3', 'g5e4', 'g5f4', 'g5g4', 'g5h4', 'g5a5', 175 | 'g5b5', 'g5c5', 'g5d5', 'g5e5', 'g5f5', 'g5h5', 'g5e6', 176 | 'g5f6', 'g5g6', 'g5h6', 'g5e7', 'g5f7', 'g5g7', 'g5h7', 177 | 'g5d8', 'g5g8', 'h5d1', 'h5h1', 'h5e2', 'h5h2', 'h5f3', 178 | 'h5g3', 'h5h3', 'h5f4', 'h5g4', 'h5h4', 'h5a5', 'h5b5', 179 | 'h5c5', 'h5d5', 'h5e5', 'h5f5', 'h5g5', 'h5f6', 'h5g6', 180 | 'h5h6', 'h5f7', 'h5g7', 'h5h7', 'h5e8', 'h5h8', 'a6a1', 181 | 'a6f1', 'a6a2', 'a6e2', 'a6a3', 'a6d3', 'a6a4', 'a6b4', 182 | 'a6c4', 'a6a5', 'a6b5', 'a6c5', 'a6b6', 'a6c6', 'a6d6', 183 | 'a6e6', 'a6f6', 'a6g6', 'a6h6', 'a6a7', 'a6b7', 'a6c7', 184 | 'a6a8', 'a6b8', 'a6c8', 'b6b1', 'b6g1', 'b6b2', 'b6f2', 185 | 
'b6b3', 'b6e3', 'b6a4', 'b6b4', 'b6c4', 'b6d4', 'b6a5', 186 | 'b6b5', 'b6c5', 'b6d5', 'b6a6', 'b6c6', 'b6d6', 'b6e6', 187 | 'b6f6', 'b6g6', 'b6h6', 'b6a7', 'b6b7', 'b6c7', 'b6d7', 188 | 'b6a8', 'b6b8', 'b6c8', 'b6d8', 'c6c1', 'c6h1', 'c6c2', 189 | 'c6g2', 'c6c3', 'c6f3', 'c6a4', 'c6b4', 'c6c4', 'c6d4', 190 | 'c6e4', 'c6a5', 'c6b5', 'c6c5', 'c6d5', 'c6e5', 'c6a6', 191 | 'c6b6', 'c6d6', 'c6e6', 'c6f6', 'c6g6', 'c6h6', 'c6a7', 192 | 'c6b7', 'c6c7', 'c6d7', 'c6e7', 'c6a8', 'c6b8', 'c6c8', 193 | 'c6d8', 'c6e8', 'd6d1', 'd6d2', 'd6h2', 'd6a3', 'd6d3', 194 | 'd6g3', 'd6b4', 'd6c4', 'd6d4', 'd6e4', 'd6f4', 'd6b5', 195 | 'd6c5', 'd6d5', 'd6e5', 'd6f5', 'd6a6', 'd6b6', 'd6c6', 196 | 'd6e6', 'd6f6', 'd6g6', 'd6h6', 'd6b7', 'd6c7', 'd6d7', 197 | 'd6e7', 'd6f7', 'd6b8', 'd6c8', 'd6d8', 'd6e8', 'd6f8', 198 | 'e6e1', 'e6a2', 'e6e2', 'e6b3', 'e6e3', 'e6h3', 'e6c4', 199 | 'e6d4', 'e6e4', 'e6f4', 'e6g4', 'e6c5', 'e6d5', 'e6e5', 200 | 'e6f5', 'e6g5', 'e6a6', 'e6b6', 'e6c6', 'e6d6', 'e6f6', 201 | 'e6g6', 'e6h6', 'e6c7', 'e6d7', 'e6e7', 'e6f7', 'e6g7', 202 | 'e6c8', 'e6d8', 'e6e8', 'e6f8', 'e6g8', 'f6a1', 'f6f1', 203 | 'f6b2', 'f6f2', 'f6c3', 'f6f3', 'f6d4', 'f6e4', 'f6f4', 204 | 'f6g4', 'f6h4', 'f6d5', 'f6e5', 'f6f5', 'f6g5', 'f6h5', 205 | 'f6a6', 'f6b6', 'f6c6', 'f6d6', 'f6e6', 'f6g6', 'f6h6', 206 | 'f6d7', 'f6e7', 'f6f7', 'f6g7', 'f6h7', 'f6d8', 'f6e8', 207 | 'f6f8', 'f6g8', 'f6h8', 'g6b1', 'g6g1', 'g6c2', 'g6g2', 208 | 'g6d3', 'g6g3', 'g6e4', 'g6f4', 'g6g4', 'g6h4', 'g6e5', 209 | 'g6f5', 'g6g5', 'g6h5', 'g6a6', 'g6b6', 'g6c6', 'g6d6', 210 | 'g6e6', 'g6f6', 'g6h6', 'g6e7', 'g6f7', 'g6g7', 'g6h7', 211 | 'g6e8', 'g6f8', 'g6g8', 'g6h8', 'h6c1', 'h6h1', 'h6d2', 212 | 'h6h2', 'h6e3', 'h6h3', 'h6f4', 'h6g4', 'h6h4', 'h6f5', 213 | 'h6g5', 'h6h5', 'h6a6', 'h6b6', 'h6c6', 'h6d6', 'h6e6', 214 | 'h6f6', 'h6g6', 'h6f7', 'h6g7', 'h6h7', 'h6f8', 'h6g8', 215 | 'h6h8', 'a7a1', 'a7g1', 'a7a2', 'a7f2', 'a7a3', 'a7e3', 216 | 'a7a4', 'a7d4', 'a7a5', 'a7b5', 'a7c5', 'a7a6', 'a7b6', 217 | 'a7c6', 'a7b7', 
'a7c7', 'a7d7', 'a7e7', 'a7f7', 'a7g7', 218 | 'a7h7', 'a7a8', 'a7b8', 'a7c8', 'b7b1', 'b7h1', 'b7b2', 219 | 'b7g2', 'b7b3', 'b7f3', 'b7b4', 'b7e4', 'b7a5', 'b7b5', 220 | 'b7c5', 'b7d5', 'b7a6', 'b7b6', 'b7c6', 'b7d6', 'b7a7', 221 | 'b7c7', 'b7d7', 'b7e7', 'b7f7', 'b7g7', 'b7h7', 'b7a8', 222 | 'b7b8', 'b7c8', 'b7d8', 'c7c1', 'c7c2', 'c7h2', 'c7c3', 223 | 'c7g3', 'c7c4', 'c7f4', 'c7a5', 'c7b5', 'c7c5', 'c7d5', 224 | 'c7e5', 'c7a6', 'c7b6', 'c7c6', 'c7d6', 'c7e6', 'c7a7', 225 | 'c7b7', 'c7d7', 'c7e7', 'c7f7', 'c7g7', 'c7h7', 'c7a8', 226 | 'c7b8', 'c7c8', 'c7d8', 'c7e8', 'd7d1', 'd7d2', 'd7d3', 227 | 'd7h3', 'd7a4', 'd7d4', 'd7g4', 'd7b5', 'd7c5', 'd7d5', 228 | 'd7e5', 'd7f5', 'd7b6', 'd7c6', 'd7d6', 'd7e6', 'd7f6', 229 | 'd7a7', 'd7b7', 'd7c7', 'd7e7', 'd7f7', 'd7g7', 'd7h7', 230 | 'd7b8', 'd7c8', 'd7d8', 'd7e8', 'd7f8', 'e7e1', 'e7e2', 231 | 'e7a3', 'e7e3', 'e7b4', 'e7e4', 'e7h4', 'e7c5', 'e7d5', 232 | 'e7e5', 'e7f5', 'e7g5', 'e7c6', 'e7d6', 'e7e6', 'e7f6', 233 | 'e7g6', 'e7a7', 'e7b7', 'e7c7', 'e7d7', 'e7f7', 'e7g7', 234 | 'e7h7', 'e7c8', 'e7d8', 'e7e8', 'e7f8', 'e7g8', 'f7f1', 235 | 'f7a2', 'f7f2', 'f7b3', 'f7f3', 'f7c4', 'f7f4', 'f7d5', 236 | 'f7e5', 'f7f5', 'f7g5', 'f7h5', 'f7d6', 'f7e6', 'f7f6', 237 | 'f7g6', 'f7h6', 'f7a7', 'f7b7', 'f7c7', 'f7d7', 'f7e7', 238 | 'f7g7', 'f7h7', 'f7d8', 'f7e8', 'f7f8', 'f7g8', 'f7h8', 239 | 'g7a1', 'g7g1', 'g7b2', 'g7g2', 'g7c3', 'g7g3', 'g7d4', 240 | 'g7g4', 'g7e5', 'g7f5', 'g7g5', 'g7h5', 'g7e6', 'g7f6', 241 | 'g7g6', 'g7h6', 'g7a7', 'g7b7', 'g7c7', 'g7d7', 'g7e7', 242 | 'g7f7', 'g7h7', 'g7e8', 'g7f8', 'g7g8', 'g7h8', 'h7b1', 243 | 'h7h1', 'h7c2', 'h7h2', 'h7d3', 'h7h3', 'h7e4', 'h7h4', 244 | 'h7f5', 'h7g5', 'h7h5', 'h7f6', 'h7g6', 'h7h6', 'h7a7', 245 | 'h7b7', 'h7c7', 'h7d7', 'h7e7', 'h7f7', 'h7g7', 'h7f8', 246 | 'h7g8', 'h7h8', 'a8a1', 'a8h1', 'a8a2', 'a8g2', 'a8a3', 247 | 'a8f3', 'a8a4', 'a8e4', 'a8a5', 'a8d5', 'a8a6', 'a8b6', 248 | 'a8c6', 'a8a7', 'a8b7', 'a8c7', 'a8b8', 'a8c8', 'a8d8', 249 | 'a8e8', 'a8f8', 'a8g8', 'a8h8', 
'b8b1', 'b8b2', 'b8h2', 250 | 'b8b3', 'b8g3', 'b8b4', 'b8f4', 'b8b5', 'b8e5', 'b8a6', 251 | 'b8b6', 'b8c6', 'b8d6', 'b8a7', 'b8b7', 'b8c7', 'b8d7', 252 | 'b8a8', 'b8c8', 'b8d8', 'b8e8', 'b8f8', 'b8g8', 'b8h8', 253 | 'c8c1', 'c8c2', 'c8c3', 'c8h3', 'c8c4', 'c8g4', 'c8c5', 254 | 'c8f5', 'c8a6', 'c8b6', 'c8c6', 'c8d6', 'c8e6', 'c8a7', 255 | 'c8b7', 'c8c7', 'c8d7', 'c8e7', 'c8a8', 'c8b8', 'c8d8', 256 | 'c8e8', 'c8f8', 'c8g8', 'c8h8', 'd8d1', 'd8d2', 'd8d3', 257 | 'd8d4', 'd8h4', 'd8a5', 'd8d5', 'd8g5', 'd8b6', 'd8c6', 258 | 'd8d6', 'd8e6', 'd8f6', 'd8b7', 'd8c7', 'd8d7', 'd8e7', 259 | 'd8f7', 'd8a8', 'd8b8', 'd8c8', 'd8e8', 'd8f8', 'd8g8', 260 | 'd8h8', 'e8e1', 'e8e2', 'e8e3', 'e8a4', 'e8e4', 'e8b5', 261 | 'e8e5', 'e8h5', 'e8c6', 'e8d6', 'e8e6', 'e8f6', 'e8g6', 262 | 'e8c7', 'e8d7', 'e8e7', 'e8f7', 'e8g7', 'e8a8', 'e8b8', 263 | 'e8c8', 'e8d8', 'e8f8', 'e8g8', 'e8h8', 'f8f1', 'f8f2', 264 | 'f8a3', 'f8f3', 'f8b4', 'f8f4', 'f8c5', 'f8f5', 'f8d6', 265 | 'f8e6', 'f8f6', 'f8g6', 'f8h6', 'f8d7', 'f8e7', 'f8f7', 266 | 'f8g7', 'f8h7', 'f8a8', 'f8b8', 'f8c8', 'f8d8', 'f8e8', 267 | 'f8g8', 'f8h8', 'g8g1', 'g8a2', 'g8g2', 'g8b3', 'g8g3', 268 | 'g8c4', 'g8g4', 'g8d5', 'g8g5', 'g8e6', 'g8f6', 'g8g6', 269 | 'g8h6', 'g8e7', 'g8f7', 'g8g7', 'g8h7', 'g8a8', 'g8b8', 270 | 'g8c8', 'g8d8', 'g8e8', 'g8f8', 'g8h8', 'h8a1', 'h8h1', 271 | 'h8b2', 'h8h2', 'h8c3', 'h8h3', 'h8d4', 'h8h4', 'h8e5', 272 | 'h8h5', 'h8f6', 'h8g6', 'h8h6', 'h8f7', 'h8g7', 'h8h7', 273 | 'h8a8', 'h8b8', 'h8c8', 'h8d8', 'h8e8', 'h8f8', 'h8g8', 274 | 'a7a8q', 'a7a8r', 'a7a8b', 'a7b8q', 'a7b8r', 'a7b8b', 'b7a8q', 275 | 'b7a8r', 'b7a8b', 'b7b8q', 'b7b8r', 'b7b8b', 'b7c8q', 'b7c8r', 276 | 'b7c8b', 'c7b8q', 'c7b8r', 'c7b8b', 'c7c8q', 'c7c8r', 'c7c8b', 277 | 'c7d8q', 'c7d8r', 'c7d8b', 'd7c8q', 'd7c8r', 'd7c8b', 'd7d8q', 278 | 'd7d8r', 'd7d8b', 'd7e8q', 'd7e8r', 'd7e8b', 'e7d8q', 'e7d8r', 279 | 'e7d8b', 'e7e8q', 'e7e8r', 'e7e8b', 'e7f8q', 'e7f8r', 'e7f8b', 280 | 'f7e8q', 'f7e8r', 'f7e8b', 'f7f8q', 'f7f8r', 'f7f8b', 'f7g8q', 281 
| 'f7g8r', 'f7g8b', 'g7f8q', 'g7f8r', 'g7f8b', 'g7g8q', 'g7g8r', 282 | 'g7g8b', 'g7h8q', 'g7h8r', 'g7h8b', 'h7g8q', 'h7g8r', 'h7g8b', 283 | 'h7h8q', 'h7h8r', 'h7h8b' 284 | ] 285 | 286 | # White, no castling 287 | _uci_to_idx_wn = dict((uci, idx) for idx, uci in enumerate(_idx_to_move_wn)) 288 | 289 | # White, castling 290 | _uci_to_idx_wc = dict((uci, idx) for idx, uci in enumerate(_idx_to_move_wn)) 291 | _uci_to_idx_wc['e1g1'], _uci_to_idx_wc['e1h1'] = _uci_to_idx_wc['e1h1'], _uci_to_idx_wc['e1g1'] 292 | _uci_to_idx_wc['e1c1'], _uci_to_idx_wc['e1a1'] = _uci_to_idx_wc['e1a1'], _uci_to_idx_wc['e1c1'] 293 | 294 | 295 | # Black, no castling 296 | _idx_to_move_bn = [] 297 | for move in _idx_to_move_wn: 298 | c0,r0,c1,r1,p = move[0],int(move[1]),move[2],int(move[3]),move[4:] 299 | r0 = 9 - r0 300 | r1 = 9 - r1 301 | _idx_to_move_bn.append('{}{}{}{}{}'.format(c0,r0,c1,r1,p)) 302 | _uci_to_idx_bn = dict((uci, idx) for idx, uci in enumerate(_idx_to_move_bn)) 303 | 304 | # Black, castling 305 | _uci_to_idx_bc = dict((uci, idx) for idx, uci in enumerate(_idx_to_move_bn)) 306 | _uci_to_idx_bc['e8g8'], _uci_to_idx_bc['e8h8'] = _uci_to_idx_bc['e8h8'], _uci_to_idx_bc['e8g8'] 307 | _uci_to_idx_bc['e8c8'], _uci_to_idx_bc['e8a8'] = _uci_to_idx_bc['e8a8'], _uci_to_idx_bc['e8c8'] 308 | 309 | uci_to_idx = [_uci_to_idx_wn, _uci_to_idx_wc, _uci_to_idx_bn, _uci_to_idx_bc] 310 | -------------------------------------------------------------------------------- /src/lcztools/backend/__init__.py: -------------------------------------------------------------------------------- 1 | from lcztools.backend._leela_net import load_network 2 | from lcztools.backend._leela_net import LeelaNet, LeelaNetBase 3 | from lcztools.backend._leela_net import list_backends -------------------------------------------------------------------------------- /src/lcztools/backend/_leela_client_net.py: -------------------------------------------------------------------------------- 1 | """A client of 
net_server.server""" 2 | from lcztools import LeelaBoard 3 | import zmq 4 | import sys 5 | import threading 6 | import time 7 | from random import randint, random 8 | import numpy as np 9 | import os 10 | import pathlib 11 | from itertools import count 12 | 13 | from lcztools.backend import LeelaNetBase 14 | 15 | lcztools_tmp_path = pathlib.Path("/tmp/lcztools") 16 | 17 | class LeelaClientNet(LeelaNetBase): 18 | registered_clients = 0 19 | _lock = threading.Lock() 20 | 21 | @classmethod 22 | def register_client_id(cls): 23 | """Get a client ID based on PID and an incrementing value""" 24 | pid = os.getpid() 25 | with cls._lock: 26 | cls.registered_clients += 1 27 | return '{}-{}'.format(pid, cls.registered_clients) 28 | 29 | def __init__(self, policy_softmax_temp = 1.0, network_id=None, hi=True): 30 | super().__init__(policy_softmax_temp=policy_softmax_temp) 31 | network_id = 0 if network_id is None else network_id 32 | self.identity = self.register_client_id() 33 | self.context = zmq.Context() 34 | self.socket = self.context.socket(zmq.DEALER) 35 | self.socket.identity = self.identity.encode('ascii') 36 | socket_path = lcztools_tmp_path.joinpath('network_{}'.format(network_id)) 37 | if not pathlib.Path(socket_path).exists(): 38 | for cnt in count(): 39 | if cnt%30==0: 40 | print("Socket {} does not exist\nPlease start network server for network_id = {}".format(socket_path, network_id)) 41 | time.sleep(0.1) 42 | if pathlib.Path(socket_path).exists(): 43 | break 44 | self.socket.connect('ipc://{}'.format(socket_path)) 45 | if hi: 46 | self.hi() 47 | print("Connected to network server {}".format(network_id)) 48 | 49 | def call_model_eval(self, leela_board): 50 | message = leela_board.serialize_features() 51 | self.socket.send(message) 52 | response = self.socket.recv() 53 | # print("Got response!") 54 | response = memoryview(response) 55 | if len(response)==7436: # single precision 56 | value = np.frombuffer(response[:4], dtype=np.float32) 57 | policy = 
np.frombuffer(response[4:], dtype=np.float32) 58 | elif len(response)==3718: # half precision 59 | value = np.frombuffer(response[:2], dtype=np.float16) 60 | policy = np.frombuffer(response[2:], dtype=np.float16) 61 | return policy, value 62 | 63 | def hi(self): 64 | """Tell the server we're here, and it should be expecting some messages""" 65 | self.socket.send(bytes([1])) 66 | response = self.socket.recv() 67 | 68 | def bye(self): 69 | """Tell the server we're going away for a bit or forever, until we hi again""" 70 | self.socket.send(bytes([255])) 71 | response = self.socket.recv() 72 | 73 | def close(self): 74 | self.bye() 75 | self.socket.close() 76 | self.context.term() 77 | -------------------------------------------------------------------------------- /src/lcztools/backend/_leela_net.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | 4 | from lcztools.config import get_global_config 5 | from collections import OrderedDict 6 | 7 | def _softmax(x, softmax_temp): 8 | e_x = np.exp((x - np.max(x))/softmax_temp) 9 | return e_x / e_x.sum(axis=0) 10 | 11 | class LeelaNetBase: 12 | def __init__(self, policy_softmax_temp = 1.0): 13 | self.policy_softmax_temp = policy_softmax_temp 14 | 15 | def call_model_eval(self, leela_board): 16 | """Get policy and value from model - this needs to be implemented in subclasses""" 17 | raise NotImplementedError 18 | 19 | def evaluate(self, leela_board): 20 | policy, value = self.call_model_eval(leela_board) 21 | return self._evaluate(leela_board, policy, value) 22 | 23 | def _evaluate(self, leela_board, policy, value): 24 | """This is separated from evaluate so that subclasses can evaluate based on raw policy/value""" 25 | if not isinstance(policy, np.ndarray): 26 | # Assume it's a torch tensor 27 | policy = policy.cpu().numpy() 28 | value = value.cpu().numpy() 29 | # Results must be converted to float because operations 30 | # on numpy scalars 
can be very slow 31 | value = float(value[0]) 32 | # Knight promotions are represented without a suffix in leela-chess 33 | # ==> the transformation is done in lcz_uci_to_idx 34 | legal_uci = [m.uci() for m in leela_board.generate_legal_moves()] 35 | if legal_uci: 36 | legal_indexes = leela_board.lcz_uci_to_idx(legal_uci) 37 | softmaxed = _softmax(policy[legal_indexes], self.policy_softmax_temp) 38 | softmaxed_aspython = map(float, softmaxed) 39 | policy_legal = OrderedDict(sorted(zip(legal_uci, softmaxed_aspython), 40 | key = lambda mp: (mp[1], mp[0]), 41 | reverse=True)) 42 | else: 43 | policy_legal = OrderedDict() 44 | value = value/2 + 0.5 45 | return policy_legal, value 46 | 47 | 48 | class LeelaNet(LeelaNetBase): 49 | def __init__(self, model=None, policy_softmax_temp = 1.0, half=False): 50 | super().__init__(policy_softmax_temp=policy_softmax_temp) 51 | if half: 52 | self.dtype = np.float16 53 | else: 54 | self.dtype = np.float32 55 | self.model = model 56 | 57 | def evaluate_batch(self, leela_boards): 58 | # TODO/Not implemented 59 | raise NotImplementedError 60 | features = [] 61 | for board in leela_boards: 62 | features.append(board.features()) 63 | features = np.stack(features) 64 | policy, value = self.model(features) 65 | if not isinstance(policy[0], np.ndarray): 66 | # Assume it's a torch tensor 67 | policy = policy.numpy() 68 | value = value.numpy() 69 | policy, value = policy[0], value[0][0] 70 | legal_uci = [m.uci() for m in leela_board.generate_legal_moves()] 71 | if legal_uci: 72 | legal_indexes = leela_board.lcz_uci_to_idx(legal_uci) 73 | softmaxed = _softmax(policy[legal_indexes]) 74 | policy_legal = OrderedDict(sorted(zip(legal_uci, softmaxed), 75 | key = lambda mp: (mp[1], mp[0]), 76 | reverse=True)) 77 | else: 78 | policy_legal = OrderedDict() 79 | value = value/2 + 0.5 80 | return policy_legal, value 81 | 82 | 83 | def call_model_eval(self, leela_board): 84 | features = leela_board.lcz_features() 85 | features = features.astype(self.dtype) 
86 | policy, value = self.model(features) 87 | return policy[0], value[0] 88 | 89 | 90 | 91 | def list_backends(): 92 | return ['pytorch_eval_cpu', 'pytorch_eval_cuda', 'pytorch_cpu', 'pytorch_cuda', 'tensorflow', 93 | 'pytorch_train_cpu', 'pytorch_train_cuda', 'net_client'] 94 | 95 | def load_network(filename=None, backend=None, policy_softmax_temp=None, network_id=None, half=None): 96 | # Config will handle filename in read_weights_file 97 | config = get_global_config() 98 | backend = backend or config.backend 99 | policy_softmax_temp = policy_softmax_temp or config.policy_softmax_temp 100 | backends = list_backends() 101 | 102 | print("Loading network using backend={}, policy_softmax_temp={}".format(backend, policy_softmax_temp)) 103 | if backend not in backends: 104 | raise Exception("Supported backends are {}".format(backends)) 105 | 106 | kwargs = {} 107 | if backend=='net_client': 108 | from lcztools.backend._leela_client_net import LeelaClientNet 109 | if filename is not None: 110 | raise Exception('Weights file not allowed for net_client') 111 | if half is not None: 112 | print("Warning: half has no effect for LeelaClientNet -- this is done on server") 113 | return LeelaClientNet(policy_softmax_temp=policy_softmax_temp, network_id=network_id) 114 | 115 | half = False if half is None else half  # default to single precision 116 | if network_id is not None: 117 | raise Exception("Network ID only for net_client backend") 118 | if backend == 'tensorflow': 119 | raise Exception("Tensorflow temporarily disabled, untested since latest changes") # Temporarily 120 | from lcztools.backend._leela_tf_net import LeelaLoader 121 | elif backend == 'pytorch_eval_cpu': 122 | from lcztools.backend._leela_torch_eval_net import LeelaLoader 123 | elif backend == 'pytorch_eval_cuda': 124 | from lcztools.backend._leela_torch_eval_net import LeelaLoader 125 | kwargs['cuda'] = True 126 | elif backend == 'pytorch_cpu': 127 | from lcztools.backend._leela_torch_net import LeelaLoader 128 | elif backend ==
'pytorch_cuda': 129 | from lcztools.backend._leela_torch_net import LeelaLoader 130 | kwargs['cuda'] = True 131 | elif backend == 'pytorch_train_cpu': 132 | from lcztools.backend._leela_torch_net import LeelaLoader 133 | kwargs['train'] = True 134 | elif backend == 'pytorch_train_cuda': 135 | from lcztools.backend._leela_torch_net import LeelaLoader 136 | kwargs['cuda'] = True 137 | kwargs['train'] = True 138 | kwargs['half'] = half 139 | return LeelaNet(LeelaLoader.from_weights_file(filename, **kwargs), policy_softmax_temp=policy_softmax_temp, half=half) 140 | -------------------------------------------------------------------------------- /src/lcztools/backend/_leela_tf_net.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Note: yaml config and code to setup tfrprocess borrowed from here: 3 | # https://github.com/glinscott/leela-chess/tree/master/training/tf 4 | 5 | import tensorflow as tf 6 | import sys 7 | import os 8 | import sys 9 | import yaml 10 | import textwrap 11 | import gzip 12 | from lcztools.config import get_global_config 13 | 14 | # TODO: Hack!!! (This whole file is a hack) 15 | config = get_global_config() 16 | sys.path.append(os.path.expanduser(config.leela_training_tf_dir)) 17 | 18 | import tfprocess 19 | import tarfile 20 | import numpy as np 21 | 22 | from lcztools.weights import read_weights_file 23 | 24 | YAMLCFG = """ 25 | %YAML 1.2 26 | --- 27 | name: 'online-64x6' 28 | gpu: 0 29 | dataset: 30 | num_chunks: 200000 31 | train_ratio: 0.90 32 | training: 33 | batch_size: 2048 34 | total_steps: 60000 35 | shuffle_size: 1048576 36 | lr_values: 37 | - 0.04 38 | - 0.002 39 | lr_boundaries: 40 | - 35000 41 | policy_loss_weight: 1.0 42 | value_loss_weight: 1.0 43 | path: /dev/null 44 | model: 45 | filters: 64 46 | residual_blocks: 6 47 | ... 
48 | """ 49 | YAMLCFG = textwrap.dedent(YAMLCFG).strip() 50 | 51 | 52 | 53 | class LeelaModel: 54 | def __init__(self, weights_filename): 55 | filters, blocks, weights = read_weights_file(weights_filename) 56 | cfg = yaml.safe_load(YAMLCFG) 57 | cfg['model']['filters'] = filters 58 | cfg['model']['residual_blocks'] = blocks 59 | cfg['name'] = 'online-{}x{}'.format(filters, blocks) 60 | print(yaml.dump(cfg, default_flow_style=False)) 61 | 62 | x = [ 63 | tf.placeholder(tf.float32, [None, 112, 8*8]), 64 | tf.placeholder(tf.float32, [None, 1858]), 65 | tf.placeholder(tf.float32, [None, 1]) 66 | ] 67 | 68 | self.tfp = tfprocess.TFProcess(cfg) 69 | self.tfp.init_net(x) 70 | self.tfp.replace_weights(weights) 71 | def __call__(self, input_planes): 72 | input_planes = input_planes.reshape(-1, 112, 8*8) 73 | policy, value = self.tfp.session.run([self.tfp.y_conv, self.tfp.z_conv], 74 | {self.tfp.x: input_planes, self.tfp.training:False}) 75 | # print("Policy:", policy) 76 | # print("Value:", value) 77 | return policy, value 78 | 79 | 80 | class LeelaLoader: 81 | @staticmethod 82 | def from_weights_file(filename, train=False): 83 | return LeelaModel(filename) 84 | 85 | -------------------------------------------------------------------------------- /src/lcztools/backend/_leela_torch_eval_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import gzip 5 | import zlib, base64 6 | import numpy as np 7 | import math 8 | 9 | from lcztools.weights import read_weights_file 10 | 11 | # This pytorch implementation slightly optimizes the original by using a simplified "Normalization" layer instead 12 | # of BatchNorm2d, with precalculated normalization/variance divisors: w = 1/torch.sqrt(w + 1e-5). 13 | # Without BatchNorm, this is only useful for eval, never training. 
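The folding described in the comment above — absorbing the per-channel `(x - mean) * stddiv` transform into the preceding convolution's weight and bias, as done by `prenormalize` below — can be sketched in a few lines of NumPy. A 1x1 convolution is just a per-pixel matrix multiply, so a small matrix stands in for the conv; all names here are illustrative, not part of this library:

```python
import numpy as np

# norm(conv(x)) = (W @ x + b - mean) * stddiv can be folded into a single conv:
#   W' = W * stddiv (scaled per output channel),  b' = (b - mean) * stddiv
rng = np.random.default_rng(0)
out_ch, in_ch = 4, 3
W = rng.normal(size=(out_ch, in_ch))      # 1x1 conv kernel == matrix
b = rng.normal(size=out_ch)
mean = rng.normal(size=out_ch)
stddiv = 1.0 / np.sqrt(rng.uniform(0.5, 2.0, size=out_ch) + 1e-5)

x = rng.normal(size=in_ch)
y_separate = (W @ x + b - mean) * stddiv  # conv followed by normalization
W_folded = W * stddiv[:, None]
b_folded = (b - mean) * stddiv
y_folded = W_folded @ x + b_folded        # prenormalized conv alone
# y_separate and y_folded agree to floating-point precision
```

Because the normalization is affine and the convolution is linear, the fold is exact, which is why the eval-only network can delete its `Normalization` layers entirely after `prenormalize` runs.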
14 | 15 | class Normalization(nn.Module): 16 | r"""Applies per-channel transformation (x - mean)*stddiv 17 | """ 18 | 19 | def __init__(self, channels): 20 | super().__init__() 21 | self.channels = channels 22 | self.mean = nn.Parameter(torch.Tensor(channels).unsqueeze(-1).unsqueeze(-1)) 23 | self.stddiv = nn.Parameter(torch.Tensor(channels).unsqueeze(-1).unsqueeze(-1)) 24 | 25 | def forward(self, x): 26 | return x.sub_(self.mean).mul_(self.stddiv) 27 | 28 | def extra_repr(self): 29 | return 'channels={}'.format( 30 | self.channels 31 | ) 32 | 33 | 34 | def prenormalize(conv, bn): 35 | new_weight = conv.weight * bn.stddiv.unsqueeze(-1) 36 | new_bias = ((conv.bias * bn.stddiv.flatten()) - 37 | (bn.mean.flatten() * bn.stddiv.flatten())) 38 | conv.weight.copy_(new_weight) 39 | conv.bias.copy_(new_bias) 40 | 41 | 42 | class ConvBlock(nn.Module): 43 | def __init__(self, kernel_size, input_channels, output_channels=None): 44 | super().__init__() 45 | if output_channels is None: 46 | output_channels = input_channels 47 | padding = kernel_size // 2 48 | self.conv1 = nn.Conv2d(input_channels, output_channels, kernel_size, stride=1, padding=padding, bias=True) 49 | self.conv1_bn = Normalization(output_channels) 50 | 51 | def prenormalize(self): 52 | prenormalize(self.conv1, self.conv1_bn) 53 | del self.conv1_bn 54 | 55 | def forward(self, x): 56 | out = self.conv1(x) 57 | out = F.relu(out, inplace=True) 58 | return out 59 | 60 | def _forward_(self, x): 61 | """Old forward, without prenormalization""" 62 | out = self.conv1_bn(self.conv1(x)) 63 | out = F.relu(out, inplace=True) 64 | return out 65 | 66 | 67 | class ResidualBlock(nn.Module): 68 | def __init__(self, channels): 69 | super().__init__() 70 | self.conv1 = nn.Conv2d(channels, channels, 3, stride=1, padding=1, bias=True) 71 | self.conv1_bn = Normalization(channels) 72 | self.conv2 = nn.Conv2d(channels, channels, 3, stride=1, padding=1, bias=True) 73 | self.conv2_bn = Normalization(channels) 74 | 75 | def
prenormalize(self): 76 | prenormalize(self.conv1, self.conv1_bn) 77 | prenormalize(self.conv2, self.conv2_bn) 78 | del self.conv1_bn 79 | del self.conv2_bn 80 | 81 | def forward(self, x): 82 | """Forward using prenormalized convolutions""" 83 | out = self.conv1(x) 84 | out = F.relu(out, inplace=True) 85 | out = self.conv2(out) 86 | out += x 87 | out = F.relu(out, inplace=True) 88 | return out 89 | 90 | def _forward_(self, x): 91 | """Old forward, without prenormalization""" 92 | out = self.conv1_bn(self.conv1(x)) 93 | out = F.relu(out, inplace=True) 94 | out = self.conv2_bn(self.conv2(out)) 95 | out += x 96 | out = F.relu(out, inplace=True) 97 | return out 98 | 99 | 100 | class LeelaModel(nn.Module): 101 | def __init__(self, channels, blocks): 102 | super().__init__() 103 | # 112 input channels 104 | self.conv_in = ConvBlock(kernel_size=3, 105 | input_channels=112, 106 | output_channels=channels) 107 | self.residual_blocks = [] 108 | for idx in range(blocks): 109 | block = ResidualBlock(channels) 110 | self.residual_blocks.append(block) 111 | self.add_module('residual_block{}'.format(idx+1), block) 112 | self.conv_pol = ConvBlock(kernel_size=1, 113 | input_channels=channels, 114 | output_channels=32) 115 | self.affine_pol = nn.Linear(32*8*8, 1858) 116 | self.conv_val = ConvBlock(kernel_size=1, 117 | input_channels=channels, 118 | output_channels=32) 119 | self.affine_val_1 = nn.Linear(32*8*8, 128) 120 | self.affine_val_2 = nn.Linear(128, 1) 121 | self.prenormalized = False 122 | 123 | def prenormalize(self): 124 | self.conv_in.prenormalize() 125 | self.conv_pol.prenormalize() 126 | self.conv_val.prenormalize() 127 | for block in self.residual_blocks: 128 | block.prenormalize() 129 | self.prenormalized = True 130 | 131 | def forward(self, x): 132 | if not self.prenormalized: 133 | raise Exception("Must call prenormalize first!") 134 | if isinstance(x, np.ndarray): 135 | x = torch.from_numpy(x) 136 | if next(self.parameters()).is_cuda: 137 | x = x.cuda() 138 | if
x.ndimension() == 3: 139 | x = x.unsqueeze(0) 140 | out = self.conv_in(x) 141 | for block in self.residual_blocks: 142 | out = block(out) 143 | out_pol = self.conv_pol(out).view(-1, 32*8*8) 144 | out_pol = self.affine_pol(out_pol) 145 | out_val = self.conv_val(out).view(-1, 32*8*8) 146 | out_val = F.relu(self.affine_val_1(out_val), inplace=True) 147 | out_val = self.affine_val_2(out_val).tanh() 148 | return out_pol, out_val 149 | 150 | 151 | class LeelaLoader: 152 | @staticmethod 153 | def from_weights_file(filename, train=False, cuda=False, half=False): 154 | if cuda: 155 | torch.backends.cudnn.benchmark=True 156 | filters, blocks, weights = read_weights_file(filename) 157 | net = LeelaModel(filters, blocks) 158 | if not train: 159 | net.eval() 160 | for p in net.parameters(): 161 | p.requires_grad = False 162 | parameters = [] 163 | for module_name, module in net.named_modules(): 164 | class_name = module.__class__.__name__ 165 | for typ in ('weight', 'bias', 'mean', 'stddiv'): 166 | param = getattr(module, typ, None) 167 | if param is not None: 168 | parameters.append((module_name, class_name, typ, param)) 169 | for i, w in enumerate(weights): 170 | w = torch.Tensor(w) 171 | module_name, class_name, typ, param = parameters[i] 172 | # print(f"{tuple(w.size())} -- {module_name} - {class_name} - {typ}: {tuple(param.size())}") 173 | if class_name == 'Normalization' and typ=='stddiv': 174 | # print('NStddiv') 175 | w = 1/torch.sqrt(w + 1e-5) 176 | param.data.copy_(w.view_as(param)) 177 | if half: 178 | net.half() 179 | if cuda: 180 | print("Enabling CUDA!") 181 | net.cuda() 182 | net.prenormalize() 183 | return net 184 | -------------------------------------------------------------------------------- /src/lcztools/backend/_leela_torch_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import gzip 5 | import zlib, base64 6 | import numpy as np 7 | 8 
| from lcztools.weights import read_weights_file 9 | 10 | 11 | class CenteredBatchNorm2d(nn.BatchNorm2d): 12 | """It appears the only way to get a trainable model with beta (bias) but not scale (weight) 13 | is by keeping the weight data, even though it's not used""" 14 | 15 | def __init__(self, channels): 16 | super().__init__(channels, affine=True) 17 | self.weight.data.fill_(1) 18 | self.weight.requires_grad = False 19 | 20 | 21 | class ConvBlock(nn.Module): 22 | def __init__(self, kernel_size, input_channels, output_channels=None): 23 | super().__init__() 24 | if output_channels is None: 25 | output_channels = input_channels 26 | padding = kernel_size // 2 27 | self.conv1 = nn.Conv2d(input_channels, output_channels, kernel_size, stride=1, padding=padding, bias=False) 28 | self.conv1_bn = CenteredBatchNorm2d(output_channels) 29 | 30 | def forward(self, x): 31 | out = self.conv1_bn(self.conv1(x)) 32 | out = F.relu(out, inplace=True) 33 | return out 34 | 35 | 36 | class ResidualBlock(nn.Module): 37 | def __init__(self, channels): 38 | super().__init__() 39 | self.conv1 = nn.Conv2d(channels, channels, 3, stride=1, padding=1, bias=False) 40 | self.conv1_bn = CenteredBatchNorm2d(channels) 41 | self.conv2 = nn.Conv2d(channels, channels, 3, stride=1, padding=1, bias=False) 42 | self.conv2_bn = CenteredBatchNorm2d(channels) 43 | 44 | def forward(self, x): 45 | out = self.conv1_bn(self.conv1(x)) 46 | out = F.relu(out, inplace=True) 47 | out = self.conv2_bn(self.conv2(out)) 48 | out += x 49 | out = F.relu(out, inplace=True) 50 | return out 51 | 52 | 53 | class LeelaModel(nn.Module): 54 | def __init__(self, channels, blocks): 55 | super().__init__() 56 | # 112 input channels 57 | self.conv_in = ConvBlock(kernel_size=3, 58 | input_channels=112, 59 | output_channels=channels) 60 | self.residual_blocks = [] 61 | for idx in range(blocks): 62 | block = ResidualBlock(channels) 63 | self.residual_blocks.append(block) 64 | self.add_module('residual_block{}'.format(idx+1), block) 65
| self.conv_pol = ConvBlock(kernel_size=1, 66 | input_channels=channels, 67 | output_channels=32) 68 | self.affine_pol = nn.Linear(32*8*8, 1858) 69 | self.conv_val = ConvBlock(kernel_size=1, 70 | input_channels=channels, 71 | output_channels=32) 72 | self.affine_val_1 = nn.Linear(32*8*8, 128) 73 | self.affine_val_2 = nn.Linear(128, 1) 74 | 75 | def forward(self, x): 76 | if isinstance(x, np.ndarray): 77 | x = torch.from_numpy(x) 78 | if next(self.parameters()).is_cuda: 79 | x = x.cuda() 80 | x = x.view(-1, 112, 8, 8) 81 | out = self.conv_in(x) 82 | for block in self.residual_blocks: 83 | out = block(out) 84 | out_pol = self.conv_pol(out).view(-1, 32*8*8) 85 | out_pol = self.affine_pol(out_pol) 86 | out_val = self.conv_val(out).view(-1, 32*8*8) 87 | out_val = F.relu(self.affine_val_1(out_val), inplace=True) 88 | out_val = self.affine_val_2(out_val).tanh() 89 | return out_pol, out_val 90 | 91 | def save_weights_file(self, filename): 92 | LEELA_WEIGHTS_VERSION = '2' 93 | lines = [LEELA_WEIGHTS_VERSION] 94 | print("Saving weights file:") 95 | for module_name, module in self.named_modules(): 96 | print('.', end='') 97 | class_name = module.__class__.__name__ 98 | for typ in ('weight', 'bias', 'running_mean', 'running_var'): 99 | param = getattr(module, typ, None) 100 | if param is not None: 101 | # print(module_name, class_name, typ) 102 | if class_name == 'CenteredBatchNorm2d' and typ == 'weight': 103 | continue 104 | elif class_name == 'CenteredBatchNorm2d' and typ == 'bias': 105 | # print('-- updating') 106 | std = torch.sqrt(getattr(module, 'running_var').cpu().detach() + 1e-5) 107 | param_data = param.cpu().detach() * std 108 | else: 109 | param_data = param.cpu().detach() 110 | lines.append(' '.join(map(str, param_data.flatten().tolist()))) 111 | # lines.append('') 112 | with open(filename, 'w') as f: 113 | for line in lines: 114 | f.write(line) 115 | f.write('\n') 116 | print("Done saving weights!") 117 | 118 | 119 | class LeelaLoader: 120 | @staticmethod 121 | 
def from_weights_file(filename, train=False, cuda=False, half=False): 122 | if cuda: 123 | torch.backends.cudnn.benchmark=True 124 | filters, blocks, weights = read_weights_file(filename) 125 | net = LeelaModel(filters, blocks) 126 | if not train: 127 | net.eval() 128 | for p in net.parameters(): 129 | p.requires_grad = False 130 | parameters = [] 131 | for module_name, module in net.named_modules(): 132 | class_name = module.__class__.__name__ 133 | for typ in ('weight', 'bias', 'running_mean', 'running_var'): 134 | param = getattr(module, typ, None) 135 | if param is not None: 136 | if class_name == 'CenteredBatchNorm2d' and typ == 'weight': 137 | continue 138 | parameters.append((module_name, class_name, typ, param)) 139 | for i, w in enumerate(weights): 140 | w = torch.Tensor(w) 141 | module_name, class_name, typ, param = parameters[i] 142 | # print(param.shape) 143 | # print(f"{tuple(w.size())} -- {module_name} - {class_name} - {typ}: {tuple(param.size())}") 144 | if class_name == 'CenteredBatchNorm2d' and typ == 'bias': 145 | # Remember bias so it can be updated to a BatchNorm beta when the running_var is seen 146 | bn_bias_param = param 147 | if class_name == 'CenteredBatchNorm2d' and typ == 'running_var': 148 | # print("Updating bias") 149 | std = torch.sqrt(w + 1e-5) 150 | bn_bias_param.detach().div_(std.view_as(param)) 151 | param.data.copy_(w.view_as(param)) 152 | if half: 153 | net.half() 154 | if cuda: 155 | print("Enabling CUDA!") 156 | net.cuda() 157 | return net 158 | 159 | 160 | # Simple test to verify saving weights... 
161 | # net = load_network() 162 | # 163 | # net.model.save_weights_file('test_weights_2.txt') 164 | # 165 | # import numpy as np 166 | # 167 | # with open('test_weights_2.txt') as f1: 168 | # with open('weights_run1_21754.txt') as f2: 169 | # for l1, l2 in zip(f1, f2): 170 | # l1 = np.array([float(s) for s in l1.strip().split()]) 171 | # l2 = np.array([float(s) for s in l2.strip().split()]) 172 | # s = sum((l1 - l2) ** 2) / len(l1) 173 | # print(s) -------------------------------------------------------------------------------- /src/lcztools/backend/net_server/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trevor-graffa/lczero_tools/b02b2476bfdb8aba911bd83b7bba66424d31d016/src/lcztools/backend/net_server/__init__.py -------------------------------------------------------------------------------- /src/lcztools/backend/net_server/client_example.py: -------------------------------------------------------------------------------- 1 | from lcztools import LeelaBoard 2 | import zmq 3 | import sys 4 | import threading 5 | import time 6 | from random import randint, random 7 | import numpy as np 8 | 9 | 10 | class ClientTask(threading.Thread): 11 | """ClientTask""" 12 | def __init__(self, id): 13 | self.id = id 14 | threading.Thread.__init__ (self) 15 | self.board = LeelaBoard() 16 | self.board.push_uci('d2d4') 17 | 18 | def run(self): 19 | context = zmq.Context() 20 | socket = context.socket(zmq.DEALER) 21 | identity = u'client-%d' % self.id 22 | socket.identity = identity.encode('ascii') 23 | socket.connect('ipc:///tmp/lcztools/network_0') 24 | socket.send(bytes([1])) # Hi message 25 | socket.recv() 26 | message = self.board.serialize_features() 27 | for _ in range(20000): 28 | # print("Client {} sending message".format(identity)) 29 | # message = self.board.serialize_features() 30 | for _ in range(32): 31 | socket.send(message) 32 | # print("Send:", self.id) 33 | 34 | for _ in range(32): 
35 | response = memoryview(socket.recv()) 36 | # print("Response:", self.id) 37 | # print ("Client {} received message of length {}".format(identity, len(response))) 38 | if len(response)==7436: # single precision 39 | value = np.frombuffer(response[:4], dtype=np.float32) 40 | policy = np.frombuffer(response[4:], dtype=np.float32) 41 | elif len(response)==3718: # half precision 42 | value = np.frombuffer(response[:2], dtype=np.float16) 43 | policy = np.frombuffer(response[2:], dtype=np.float16) 44 | # time.sleep(0.5) 45 | socket.close() 46 | context.term() 47 | 48 | def main(): 49 | """main function""" 50 | for i in range(32): 51 | client = ClientTask(i) 52 | client.start() 53 | 54 | if __name__ == "__main__": 55 | main() 56 | -------------------------------------------------------------------------------- /src/lcztools/backend/net_server/server.py: -------------------------------------------------------------------------------- 1 | import zmq 2 | import sys 3 | import threading 4 | import time 5 | from random import randint, random 6 | import numpy as np 7 | import os 8 | import contextlib 9 | from lcztools import load_network, LeelaBoard 10 | import math 11 | import pathlib 12 | import signal 13 | import queue 14 | 15 | 16 | 17 | FINISHED = False 18 | 19 | def signal_handler(signal, frame): 20 | """Signal handler started after threads start""" 21 | global FINISHED 22 | print('\nCtrl+C Pressed. 
Exiting') 23 | FINISHED = True 24 | 25 | lcztools_tmp_path = pathlib.Path("/tmp/lcztools") 26 | 27 | def clean_uds(): 28 | """Remove any network_* - unix domain sockets""" 29 | lcztools_tmp_path.mkdir(parents=True, exist_ok=True) 30 | for it in lcztools_tmp_path.glob('network_*'): 31 | it.unlink() 32 | 33 | 34 | class BatchProcessor(threading.Thread): 35 | """This is responsible for processing batches and sending responses to the clients""" 36 | _cuda_lock = threading.Lock() 37 | 38 | def __init__(self, zmq_context, model, network_id, queue_in, queue_out): 39 | super().__init__() 40 | self.model = model 41 | self.network_id = network_id 42 | socket_path = 'inproc://{}'.format(network_id) ## only for telling the receiver that a batch is ready 43 | self.context = zmq_context 44 | self.socket = self.context.socket(zmq.PUSH) 45 | self.socket.bind(socket_path) 46 | self.queue_in = queue_in 47 | self.queue_out = queue_out 48 | 49 | def run(self): 50 | while not FINISHED: 51 | item = self.queue_in.get() 52 | if item=='POISON': 53 | break 54 | features_stack, batch_ident = item 55 | with self._cuda_lock: 56 | pol, val = self.model(features_stack) 57 | pol = pol.cpu().numpy() 58 | val = val.cpu().numpy() 59 | self.queue_out.put((batch_ident, pol, val)) 60 | self.socket.send(b'') 61 | print("Network {} - Exiting batch processor".format(self.network_id)) 62 | self.socket.close() 63 | 64 | class DummyBatchProcessor(BatchProcessor): 65 | """This doesn't actually run the model. 
Just for testing speed""" 66 | 67 | def __init__(self, *args, **kwargs): 68 | super().__init__(*args, **kwargs) 69 | self.dummy_pol = np.random.rand(1858).astype(np.float32) 70 | self.dummy_val = np.random.rand(1).astype(np.float32) 71 | 72 | def run(self): 73 | while not FINISHED: 74 | item = self.queue_in.get() 75 | if item=='POISON': 76 | break 77 | features_stack, batch_ident = item 78 | # Dummy mode: no forward pass here (self.model may be None); 79 | # just return precomputed random outputs 80 | length = len(batch_ident) 81 | pol = np.stack([self.dummy_pol]*length) 82 | val = np.stack([self.dummy_val]*length) 83 | self.queue_out.put((batch_ident, pol, val)) 84 | self.socket.send(b'') 85 | print("Network {} - Exiting batch processor".format(self.network_id)) 86 | self.socket.close() 87 | 88 | 89 | 90 | class Receiver(threading.Thread): 91 | """This is responsible for receiving messages and throwing them into a queue for the 92 | BatchProcessor""" 93 | def __init__(self, zmq_context, network_id, batch_queue, response_queue, max_batch_size, dtype): 94 | super().__init__() 95 | self.network_id = network_id 96 | self.batch_queue = batch_queue 97 | self.response_queue = response_queue 98 | self.dtype = dtype 99 | self.context = zmq_context 100 | self.socket = self.context.socket(zmq.ROUTER) 101 | socket_path = lcztools_tmp_path.joinpath('network_{}'.format(network_id)) 102 | self.socket.bind('ipc://{}'.format(socket_path)) 103 | batch_socket_path = 'inproc://{}'.format(network_id) 104 | self.batch_socket = self.context.socket(zmq.PULL) 105 | self.batch_socket.connect(batch_socket_path) 106 | self.poller = zmq.Poller() 107 | self.poller.register(self.socket, zmq.POLLIN) 108 | self.poller.register(self.batch_socket, zmq.POLLIN) 109 | self.batch_ident = [] 110 | self.batch_features = [] 111 | # self.batch_size = 1 112 | self.batches_processed = 0 113 | self.batches_processed_start = time.time() 114 | self.max_batch_size = 32 if not max_batch_size else max_batch_size 115 | assert(1 <= self.max_batch_size <=
2048) 116 | self.messages_received = 0 117 | self.messages_sent = 0 118 | self.elapsed_items_processed = 0 # Count of responses in current duration 119 | 120 | def queue_batch(self): 121 | """Put a batch into the queue for the BatchProcessor""" 122 | if not self.batch_ident: 123 | return 124 | # cur_features_stack = np.stack(self.batch_features).astype(self.dtype) 125 | # The following optimization is about twice as fast as np.stack 126 | item_shape = self.batch_features[0].shape 127 | cur_features_stack = np.frombuffer(b''.join(f for f in self.batch_features), dtype=np.uint8).reshape(-1,*item_shape) 128 | cur_features_stack = cur_features_stack.astype(self.dtype) 129 | cur_batch_ident = self.batch_ident[:] 130 | while True: 131 | try: 132 | self.batch_queue.put((cur_features_stack, cur_batch_ident), timeout=1) 133 | except queue.Full: 134 | if FINISHED: 135 | self.batch_ident.clear() 136 | self.batch_features.clear() 137 | return 138 | else: 139 | continue 140 | break 141 | self.batch_ident.clear() 142 | self.batch_features.clear() 143 | self.batches_processed += 1 144 | 145 | def run(self): 146 | blocked = False 147 | poll = self.poller.poll 148 | recv_multipart = self.socket.recv_multipart 149 | batch_ident_append = self.batch_ident.append 150 | deserialize_features = LeelaBoard.deserialize_features 151 | batch_features_append = self.batch_features.append 152 | batch_ident = self.batch_ident 153 | while not FINISHED: 154 | elapsed = time.time() - self.batches_processed_start 155 | if elapsed >= 5: # Print information about every 5 seconds 156 | # print('Network {} -- batch_size {}: {} bps, {} sps'.format(self.network_id, self.batch_size, 400/elapsed, 400*self.batch_size/elapsed)) 157 | if not blocked or self.elapsed_items_processed>0: 158 | print('Network {} -- {} evals/s'.format(self.network_id, self.elapsed_items_processed/elapsed)) 159 | sys.stdout.flush() 160 | self.batches_processed_start = time.time() 161 | self.elapsed_items_processed = 0 162 | socks = 
dict(poll(1000)) 163 | if not socks: 164 | # Poll timed out (1 second) with nothing to do 165 | if not blocked or len(self.batch_ident): 166 | print("Network {} -- BLOCKED with {} items in batch".format(self.network_id, len(self.batch_ident))) 167 | if not len(self.batch_ident): 168 | blocked = True 169 | continue 170 | # Read any requests 171 | if self.socket in socks: 172 | try: 173 | while True: 174 | # Check for a message; this will not block 175 | *ident, msg = recv_multipart(flags=zmq.NOBLOCK) 176 | self.messages_received += 1 177 | if len(msg)==1: 178 | if msg[0] == 1: # hi message 179 | # client_ident_set.add(ident) 180 | self.batch_size = self.max_batch_size 181 | elif msg[0] == 255: # bye message 182 | # client_ident_set.add(ident) 183 | pass 184 | self.socket.send_multipart([*ident, bytes([255])]) 185 | self.messages_sent += 1 186 | continue 187 | batch_ident_append(ident) 188 | batch_features_append(deserialize_features(msg)) 189 | if len(self.batch_ident)>=self.max_batch_size: 190 | self.queue_batch() 191 | except zmq.Again as e: 192 | pass 193 | if self.batch_queue.empty() and len(self.batch_ident): 194 | self.queue_batch() # Queue the rest of the items if the current batch queue is empty 195 | # Respond as needed 196 | if self.batch_socket in socks: 197 | message_in = self.batch_socket.recv() # throw away...
198 | response_batch_ident, pol, val = self.response_queue.get() 199 | for ident, policy, value in zip(response_batch_ident, pol, val): 200 | result = value.tobytes() + policy.tobytes() 201 | self.socket.send_multipart([*ident, result]) 202 | self.messages_sent += 1 203 | self.elapsed_items_processed += 1 204 | try: 205 | while True: # Empty batch queue just in case, so that batch processor doesn't hang 206 | self.batch_queue.get(False) 207 | except queue.Empty: 208 | pass 209 | self.batch_queue.put('POISON') # get the batch queue thread to stop 210 | try: 211 | while True: # Empty response queue just in case, so that batch processor doesn't hang 212 | self.response_queue.get(False) 213 | except queue.Empty: 214 | pass 215 | print("Network {} - Exiting receiver".format(self.network_id)) 216 | self.socket.close() 217 | self.batch_socket.close() 218 | 219 | 220 | 221 | class NetworkServer: 222 | """This class is responsible for binding and proxying the frontend, 223 | managing the network model, and managing the BatchProcessor and Receiver""" 224 | def __init__(self, context, weights_file, network_id, max_batch_size=None, half=False, dummy=False): 225 | super().__init__() 226 | self.context = context 227 | self.network_id = network_id 228 | self.weights_file = weights_file 229 | self.model = None # won't load until run 230 | self.half = half 231 | if self.half: 232 | dtype = np.float16 233 | else: 234 | dtype = np.float32 235 | self.dummy = dummy # Use a stubbed out batch processor that doesn't actually compute 236 | self.batch_queue = queue.Queue(4) 237 | self.response_queue = queue.Queue(256) # Responses are in batches, so it's not important if this is big 238 | self.receiver = Receiver(self.context, network_id, self.batch_queue, self.response_queue, max_batch_size, dtype) 239 | self.batch_processor = None # won't load until run 240 | 241 | def load(self): 242 | if self.dummy: 243 | self.model = None 244 | self.batch_processor = DummyBatchProcessor(self.context,
self.model, self.network_id, self.batch_queue, self.response_queue) 245 | else: 246 | with BatchProcessor._cuda_lock: 247 | net = load_network(backend='pytorch_cuda', filename=self.weights_file, half=self.half) 248 | self.model = net.model 249 | self.batch_processor = BatchProcessor(self.context, self.model, self.network_id, self.batch_queue, self.response_queue) 250 | 251 | def start(self): 252 | self.batch_processor.daemon = True # To prevent it from hanging after ctrl-c 253 | self.batch_processor.start() 254 | self.receiver.start() 255 | 256 | def join(self): 257 | self.receiver.join() 258 | time.sleep(1.5) # Give the batch processor a second to cleanup 259 | 260 | if __name__ == '__main__': 261 | if len(sys.argv) == 1: 262 | print("Usage: python -m lcztools.backend.net_server ARGS") 263 | print("  ARGS: [--half] [--dummy] <weights_file_0> [max_batch_size_0] [<weights_file_1> [max_batch_size_1] ...]") 264 | exit(1) 265 | clean_uds() 266 | context = zmq.Context() 267 | tasks = [] 268 | cur_weights_file = None 269 | args = sys.argv[1:] 270 | if '--half' in args: 271 | print("Using half precision") 272 | half = True 273 | else: 274 | print("Using single precision") 275 | half = False 276 | if '--dummy' in args: 277 | print("Using dummy processor") 278 | dummy = True 279 | else: 280 | dummy = False 281 | args = [arg for arg in args if arg not in ('--half', '--dummy')] 282 | max_batch_size = None 283 | for arg in args: 284 | if not cur_weights_file: 285 | if not pathlib.Path(arg).is_file(): 286 | raise ValueError("Bad filename: {}".format(arg)) 287 | cur_weights_file = arg 288 | continue 289 | max_batch_size = None 290 | try: 291 | max_batch_size = int(arg) 292 | except ValueError: 293 | pass 294 | tasks.append(NetworkServer(context, cur_weights_file, len(tasks), max_batch_size, half=half, dummy=dummy)) 295 | cur_weights_file = None 296 | if max_batch_size is None: 297 | if not pathlib.Path(arg).is_file(): 298 | raise ValueError("Bad filename: {}".format(arg)) 299 | cur_weights_file = arg 300 | continue 301 | if cur_weights_file: 302 |
print(cur_weights_file) 303 | print(args) 304 | tasks.append(NetworkServer(context, cur_weights_file, len(tasks), max_batch_size, half=half, dummy=dummy)) 305 | print("Loading networks and starting server tasks") 306 | for task in tasks: 307 | task.load() 308 | signal.signal(signal.SIGINT, signal_handler) 309 | for task in tasks: 310 | task.start() 311 | print() 312 | print("STARTED: PRESS CTRL-C TO EXIT.") 313 | print() 314 | for task in tasks: 315 | task.join() 316 | context.term() 317 | -------------------------------------------------------------------------------- /src/lcztools/config.py: -------------------------------------------------------------------------------- 1 | '''Config handling 2 | 3 | Config is entered into an 'lcztools.ini' file, which is searched for in: 4 | 1. Script path (via sys.argv[0]) 5 | 2. Current-working-directory 6 | 3. Home directory 7 | ''' 8 | 9 | import configparser 10 | import sys 11 | import os 12 | 13 | class LCZToolsConfig: 14 | def __init__(self, dct={}): 15 | # This is the directory where the network weights are stored 16 | self.weights_dir = dct.get('weights_dir') or '.' 17 | 18 | # This is the default weights filename to use, if none provided 19 | self.weights_file = dct.get('weights_file') or 'weights.txt.gz' 20 | 21 | # This is the default backend to use, if none provided 22 | self.backend = dct.get('backend') or 'pytorch' 23 | 24 | ## No longer supported 25 | ## This is the lczero engine to use, only currently used for testing/validation 26 | # self.lczero_engine = dct.get('lczero_engine') 27 | 28 | # This is the lc0 engine to use, only currently used for testing/validation 29 | self.lc0_engine = dct.get('lc0_engine') 30 | 31 | # This is, e.g.
~/sompath/leela-chess/training/tf -- currently only used as a hackish tensorflow 32 | # mechanism, which is not used with the pytorch backend 33 | self.leela_training_tf_dir = dct.get('leela_training_tf_dir') 34 | 35 | # Policy softmax temp 36 | self.policy_softmax_temp = dct.get('policy_softmax_temp') or 1.0 37 | self.policy_softmax_temp = float(self.policy_softmax_temp) 38 | 39 | def get_weights_filename(self, basename=None): 40 | '''Get weights filename given a base filename, or use default self.weights_file 41 | If filename provided is absolute, just return that. 42 | This way, lcztools.load_network('myweights.txt.gz') returns relative to weights_dir, 43 | but user may still use a full path name. 44 | ''' 45 | if basename: 46 | basename = os.path.expanduser(basename) 47 | if os.path.isabs(basename): 48 | return basename 49 | basename = basename or self.weights_file 50 | return os.path.join(os.path.expanduser(self.weights_dir), basename) 51 | 52 | 53 | def find_config_file(): 54 | '''Search for an lcztools.ini file, and return full pathname''' 55 | # 1. Search in main script directory 56 | filename = 'lcztools.ini' 57 | dirname = os.path.dirname(sys.argv[0]) 58 | dirname = os.path.abspath(dirname) 59 | fname = os.path.join(dirname, filename) 60 | if os.path.isfile(fname): 61 | return fname 62 | # 2. Search in cwd 63 | dirname = os.path.abspath('.') 64 | fname = os.path.join(dirname, filename) 65 | if os.path.isfile(fname): 66 | return fname 67 | # 3. Search in homedir 68 | dirname = os.path.expanduser('~') 69 | fname = os.path.join(dirname, filename) 70 | if os.path.isfile(fname): 71 | return fname 72 | # No config file found.. 73 | return None 74 | 75 | 76 | def set_global_config(filename=None, section='default'): 77 | '''Set global config.
If filename is not provided, then attempt to find an lcztools.ini file''' 78 | global _global_config 79 | if filename is None: 80 | filename = find_config_file() 81 | if filename is None: 82 | # Empty config 83 | _global_config = LCZToolsConfig() 84 | return 85 | config = configparser.ConfigParser() 86 | config.read(filename) 87 | _global_config = LCZToolsConfig(config[section]) 88 | 89 | 90 | def get_global_config(): 91 | global _global_config 92 | if _global_config is None: 93 | set_global_config() 94 | return _global_config 95 | 96 | _global_config = None 97 | -------------------------------------------------------------------------------- /src/lcztools/testing/__init__.py: -------------------------------------------------------------------------------- 1 | from .leela_engine_lc0 import LC0Engine 2 | import importlib 3 | if importlib.find_loader('requests'): 4 | from .lczero_web import WebMatchGame 5 | from .lczero_web import WeightsDownloader 6 | from .train_parser import TarTrainingFile 7 | -------------------------------------------------------------------------------- /src/lcztools/testing/_archive_unused/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trevor-graffa/lczero_tools/b02b2476bfdb8aba911bd83b7bba66424d31d016/src/lcztools/testing/_archive_unused/__init__.py -------------------------------------------------------------------------------- /src/lcztools/testing/_archive_unused/leela_engine.py: -------------------------------------------------------------------------------- 1 | # This module will interact with the Leela chess engine to get position evaluations 2 | # Primary use case is threads=1 "go nodes 1" which is (usually) sufficient to extract policy and value 3 | # from info string...
In the case of only one possible move, a depth 1 search will be done at that node, 4 | # so it may not be possible to get the position's value (the root node's value is usually available 5 | # at child-nodes with 0 visits) 6 | 7 | 8 | import chess.uci 9 | from lcztools import LeelaBoard 10 | from collections import namedtuple, OrderedDict 11 | import re 12 | from operator import itemgetter 13 | from lcztools.config import get_global_config 14 | import os 15 | 16 | 17 | 18 | info_pattern = r' *(?P<san>[\w\+\-\#\=]*) ->' \ 19 | r' *(?P<visits>\d*)' \ 20 | r' *\(V: *(?P<value>[\d\.]*)\%\)' \ 21 | r' *\(N: *(?P<policy>[\d\.]*)\%\)' \ 22 | r' *PV: *(?P<pv>.*)$' 23 | 24 | info_regex = re.compile(info_pattern) 25 | 26 | InfoTuple = namedtuple('InfoTuple', 'san visits value policy pv') 27 | 28 | 29 | class LCZInfoHandler(chess.uci.InfoHandler): 30 | def __init__(self): 31 | self.lcz_strings = [] 32 | self.lcz_move_info = OrderedDict() 33 | super().__init__() 34 | def string(self, string): 35 | # Called whenever a complete info line has been processed.
36 | self.lcz_strings.append(string) 37 | match = info_regex.match(string) 38 | if match: 39 | info_tuple = InfoTuple(match.group('san'), 40 | int(match.group('visits')), 41 | float(match.group('value'))/100, 42 | float(match.group('policy'))/100, 43 | match.group('pv')) 44 | self.lcz_move_info[info_tuple.san] = info_tuple 45 | return super().string(string) 46 | def lcz_clear(self): 47 | self.lcz_strings.clear() 48 | self.lcz_move_info.clear() 49 | 50 | class LCZeroEngine: 51 | def __init__(self, engine_path=None, weights_file=None, 52 | threads=1, visits=1, nodes=1, start=True, 53 | logfile='lczero_log.txt', stderr='lczero.stderr.txt', 54 | ): 55 | config = get_global_config() 56 | engine_path = engine_path or config.lczero_engine 57 | engine_path = os.path.expanduser(engine_path) 58 | self.engine_path = engine_path 59 | self.weights_file = config.get_weights_filename(weights_file) 60 | self.threads = threads 61 | self.visits = visits 62 | self.nodes = nodes 63 | self.info_handler = LCZInfoHandler() 64 | self.engine = None 65 | self.logfile = logfile 66 | self.stderrfile = stderr 67 | self.stderr = None 68 | if start: 69 | self.start() 70 | def start(self): 71 | print("lczero outputting stderr to:", self.stderrfile) 72 | self.stderr = open(self.stderrfile, 'w') 73 | command = [self.engine_path] 74 | weights_file = os.path.expanduser(self.weights_file) 75 | command.extend(['-w', weights_file]) 76 | if self.threads is not None: 77 | command.extend(['-t', self.threads]) 78 | if self.visits is not None: 79 | command.extend(['-v', self.visits]) 80 | if self.logfile: 81 | print("lczero logging to: {}".format(self.logfile)) 82 | command.extend(['-l', self.logfile]) 83 | command = map(str, command) 84 | self.nodes = self.nodes 85 | self.engine = chess.uci.popen_engine(command, stderr=self.stderr) 86 | print("Leela engine started") 87 | self.engine.uci() 88 | self.engine.info_handlers.append(self.info_handler) 89 | def stop(self): 90 | self.engine.quit() 91 | try: 92 | 
self.stderr.close() 93 | except: 94 | pass 95 | def evaluate(self, board): 96 | '''Returns a (bestmove, policy, value) tuple given a python-chess board where: 97 | policy is a mapping UCI=>value, sorted highest to lowest 98 | value is a float''' 99 | if isinstance(board, LeelaBoard): 100 | board = board.pc_board 101 | 102 | # Note/Important - this 'is_irreversible' override is needed to make sure that the engine 103 | # gets all history when calling self.engine.position(board) 104 | board = board.copy() 105 | board.is_irreversible = lambda move: False 106 | 107 | self.info_handler.lcz_clear() 108 | self.engine.position(board) 109 | bestmove = self.engine.go(nodes=self.nodes) 110 | value = None 111 | policy = {} 112 | for move in board.legal_moves: 113 | san = board.san(move) 114 | try: 115 | move_info = self.info_handler.lcz_move_info[san] 116 | except KeyError: 117 | san = san.replace('#', '+') 118 | move_info = self.info_handler.lcz_move_info[san] 119 | if value is None and move_info.visits==0: 120 | value = move_info.value 121 | # This is done so it matches leela-chess 122 | # Necessary???
uci = move.uci().rstrip('n') 123 | uci = move.uci() 124 | policy[uci] = move_info.policy 125 | return bestmove, OrderedDict(sorted(policy.items(), key=itemgetter(1), reverse=True)), value 126 | -------------------------------------------------------------------------------- /src/lcztools/testing/lczero_web/__init__.py: -------------------------------------------------------------------------------- 1 | from .web_game import WebMatchGame, WebTrainingGame 2 | from .networks import WeightsDownloader 3 | -------------------------------------------------------------------------------- /src/lcztools/testing/lczero_web/networks.py: -------------------------------------------------------------------------------- 1 | import bs4 2 | import requests 3 | import posixpath 4 | import os 5 | from collections import OrderedDict 6 | import shutil 7 | 8 | from lcztools.config import get_global_config 9 | from lcztools.util import lazy_property 10 | 11 | 12 | URL_BASE = 'http://lczero.org' 13 | 14 | class WeightsDownloader: 15 | def __init__(self, weights_dir = None, logging = True): 16 | '''Weights downloader 17 | 18 | weights_dir: leave empty to use default config (location for weights files)''' 19 | if weights_dir is None: 20 | cfg = get_global_config() 21 | weights_dir = cfg.weights_dir 22 | self.weights_dir = os.path.expanduser(weights_dir) 23 | if logging: 24 | self.log = lambda *args, **kwargs: print(*args, **kwargs) 25 | else: 26 | self.log = lambda *args, **kwargs: None 27 | 28 | @lazy_property 29 | def weights_urls(self): 30 | '''An OrderedDict, latest last''' 31 | self.log("Getting weights URLs") 32 | networks_page = requests.get(posixpath.join(URL_BASE, 'networks')).content 33 | soup = bs4.BeautifulSoup(networks_page, 'html.parser') 34 | pairs = [] 35 | for it in soup.tbody.find_all('a', download=True): 36 | url = posixpath.join(URL_BASE, it['href'].lstrip('/')) 37 | filename = it['download'] 38 | pairs.append((filename, url)) 39 | result = OrderedDict(reversed(pairs)) 40 
| self.log("==> Success (Getting weights URLs)") 41 | return result 42 | 43 | @lazy_property 44 | def latest(self): 45 | '''Name of latest weights file''' 46 | return next(reversed(self.weights_urls)) 47 | 48 | def is_already_downloaded(self, filename): 49 | fullpath = os.path.join(self.weights_dir, filename) 50 | return os.path.isfile(os.path.expanduser(fullpath)) 51 | 52 | def download_latest_n(self, num, skip_already_downloaded = True): 53 | '''Download latest n weights file''' 54 | if num<1: 55 | raise Exception("Expected a positive number") 56 | for filename in list(self.weights_urls)[-num:]: 57 | self.download(filename, skip_already_downloaded) 58 | 59 | def download_latest(self, skip_already_downloaded = True): 60 | '''Download latest weights file''' 61 | filename = self.latest 62 | self.download(filename, skip_already_downloaded) 63 | 64 | def download_all(self, skip_already_downloaded = True): 65 | '''Download all weights files 66 | 67 | skip_already_downloaded: Skip files that have already been downloaded (default True)''' 68 | for filename in self.weights_urls: 69 | self.download(filename, skip_already_downloaded) 70 | 71 | def download(self, filename, skip_already_downloaded = True): 72 | '''Downlaod weights file''' 73 | self.log("Downloading weights file: {}".format(filename)) 74 | if filename not in self.weights_urls: 75 | raise Exception("Unknown file! 
{}".format(filename)) 76 | url = self.weights_urls[filename] 77 | if skip_already_downloaded: 78 | if self.is_already_downloaded(filename): 79 | self.log("==> Already downloaded (Downloading weights file: {})".format(filename)) 80 | return False 81 | fullpath = os.path.join(self.weights_dir, filename) 82 | fullpath_tmp = fullpath + '_download' 83 | with requests.get(url, stream=True) as r: 84 | with open(os.path.expanduser(fullpath_tmp), 'wb') as f: 85 | shutil.copyfileobj(r.raw, f) 86 | os.rename(fullpath_tmp, fullpath) 87 | self.log("==> Success (Downloading weights file: {})".format(filename)) 88 | -------------------------------------------------------------------------------- /src/lcztools/testing/lczero_web/web_game.py: -------------------------------------------------------------------------------- 1 | import posixpath 2 | import requests 3 | import bs4 4 | import re 5 | from lcztools import LeelaBoard 6 | from lcztools.util import lazy_property 7 | import chess 8 | import chess.pgn 9 | 10 | class WebGame: 11 | def __init__(self, url): 12 | '''Create a web match game object. 13 | 14 | URL may be the full URL, such as 'http://www.lczero.org/match_game/298660' 15 | or just a portion, like '298660'. Only the last portion is used''' 16 | self.url = str(url) 17 | 18 | @lazy_property 19 | def text(self): 20 | return requests.get(self.url).text 21 | 22 | @lazy_property 23 | def soup(self): 24 | return bs4.BeautifulSoup(self.text, 'html.parser') 25 | 26 | @lazy_property 27 | def movelist(self): 28 | movelist = re.search(r"pgnString: '(.*)'", self.text).group(1) \ 29 | .replace(r'\n', ' ') \ 30 | .replace(r'\x2b', '+') \ 31 | .replace(r'.', '. 
') \ 32 | .strip() \ 33 | .split() 34 | return movelist 35 | 36 | @lazy_property 37 | def sans(self): 38 | '''This returns a list of san moves''' 39 | # Filter out move numbers and result 40 | sans = [m for m in self.movelist if re.match(r'^[^\d\*]', m)] 41 | return sans 42 | 43 | @lazy_property 44 | def result(self): 45 | return self.movelist[-1].replace('\\','') 46 | 47 | @lazy_property 48 | def board(self): 49 | board = chess.Board() 50 | for san in self.sans: 51 | board.push_san(san) 52 | return board 53 | 54 | @lazy_property 55 | def leela_board(self): 56 | board = LeelaBoard() 57 | for san in self.sans: 58 | board.push_san(san) 59 | return board 60 | 61 | @lazy_property 62 | def pgn_game(self): 63 | pgn_game = chess.pgn.Game.from_board(self.board) 64 | if pgn_game.headers['Result'] == '*': 65 | if self.board.can_claim_draw(): 66 | pgn_game.headers['Result'] = '1/2-1/2' 67 | elif len(self.board.move_stack) > 400: 68 | # 450 move rule.. We'll just adjudicate it as a draw if no result and over 400 moves 69 | pgn_game.headers['Result'] = '1/2-1/2' 70 | return pgn_game 71 | 72 | @lazy_property 73 | def pgn(self): 74 | return str(self.pgn_game) 75 | 76 | def get_leela_board_at(self, movenum=1, halfmoves=0): 77 | '''Get Leela board at given move number (*prior* to move) 78 | 79 | get_leela_board_at(12, 0): This will get the board on the 12th move, at white's turn 80 | get_leela_board_at(12, 1): This will get the board on the 12th move, at black's turn 81 | get_leela_board_at(halfmoves=3): This will return the 4th position (after 3 half-moves) 82 | ''' 83 | halfmoves = 2*(movenum-1) + halfmoves 84 | if halfmoves > len(self.sans): 85 | raise Exception('Not that many moves in game') 86 | board = LeelaBoard() 87 | for idx, san in enumerate(self.sans): 88 | if idx<halfmoves: 89 | board.push_san(san) 90 | else: 91 | break 92 | return board -------------------------------------------------------------------------------- /src/lcztools/testing/leela_engine_lc0.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from collections import namedtuple 4 | from collections import OrderedDict 5 | from operator import itemgetter 6 | 7 | import chess.uci 8 | 9 | from lcztools import LeelaBoard 10 | from lcztools.config import get_global_config 11 | 12 | 13 | # Regex for lc0's '--verbose-move-stats' info strings; the named groups 14 | # below are read back via match.group(...) in LC0InfoHandler.string() 15 | 16 | 17 | 18 | info_pattern = (r'(?P<uci>\w*) +' # The UCI move 19 | r'\((?P<move_index>[\d]+) *\) +' # Move index 20 | r'N: *(?P<visits>\d+) +' # Visits 21 | r'\((?P<unknown>[^\)]*)\) +' # Unknown/unused 22 | r'\(P: *(?P<policy>[\d\.]+)%\) +' # Policy percentage 23 | r'\(Q: *(?P<q_value>[\d\.-]+)\) +' # Q 
value 24 | r'\(U: *(?P<u_value>[\d\.-]+)\) +' # U value 25 | r'\(Q\+U: *(?P<q_u_value>[\d\.-]+)\) +' # Q+U value 26 | r'\(V: *(?P<value_str>[^ ]+)\)') # Value 27 | 28 | 29 | ## Example info strings/tests 30 | # t1 = 'info string g2g4 (378 ) N: 0 (+ 0) (P: 2.50%) (Q: -1.03320) (U: 0.85851) (Q+U: -0.17469) (V: -.----)' 31 | # t1 = t1.replace('info string ', '') 32 | # t2 = 'info string e2e4 (322 ) N: 2 (+ 8) (P: 14.04%) (Q: 0.04032) (U: 0.08678) (Q+U: 0.12710) (V: 0.0587)' 33 | # t2 = t2.replace('info string ', '') 34 | 35 | 36 | info_regex = re.compile(info_pattern) 37 | 38 | InfoTuple = namedtuple('InfoTuple', 'uci move_index visits policy q_value u_value q_u_value value_str') 39 | 40 | 41 | class LC0InfoHandler(chess.uci.InfoHandler): 42 | def __init__(self): 43 | self.lcz_strings = [] 44 | self.lcz_move_info = OrderedDict() 45 | super().__init__() 46 | 47 | def string(self, string): 48 | # Called whenever a complete info line has been processed. 49 | self.lcz_strings.append(string) 50 | match = info_regex.match(string) 51 | if match: 52 | info_tuple = InfoTuple(match.group('uci'), 53 | int(match.group('move_index')), 54 | int(match.group('visits')), 55 | float(match.group('policy'))/100, 56 | float(match.group('q_value')), 57 | float(match.group('u_value')), 58 | float(match.group('q_u_value')), 59 | match.group('value_str')) 60 | self.lcz_move_info[info_tuple.uci] = info_tuple 61 | return super().string(string) 62 | 63 | def lcz_clear(self): 64 | self.lcz_strings.clear() 65 | self.lcz_move_info.clear() 66 | 67 | 68 | class LC0Engine: 69 | def __init__(self, engine_path=None, weights_file=None, 70 | threads=1, nodes=1, backend='cudnn', start=True, 71 | policy_softmax_temp=None, 72 | logfile='lc0_log.txt', stderr='lc0.stderr.txt', 73 | ): 74 | config = get_global_config() 75 | engine_path = engine_path or config.lc0_engine 76 | engine_path = os.path.expanduser(engine_path) 77 | self.policy_softmax_temp = policy_softmax_temp or config.policy_softmax_temp 78 | self.engine_path = engine_path 79 | 
self.weights_file = config.get_weights_filename(weights_file) 80 | self.threads = threads 81 | self.nodes = nodes 82 | self.backend = backend 83 | self.info_handler = LC0InfoHandler() 84 | self.engine = None 85 | self.logfile = logfile 86 | self.stderrfile = stderr 87 | self.stderr = None 88 | if start: 89 | self.start() 90 | 91 | def start(self): 92 | print("lc0 outputting stderr to:", self.stderrfile) 93 | self.stderr = open(self.stderrfile, 'w') 94 | command = [self.engine_path, '--verbose-move-stats'] 95 | weights_file = os.path.expanduser(self.weights_file) 96 | command.extend(['-w', weights_file]) 97 | if self.threads is not None: 98 | command.extend(['-t', self.threads]) 99 | if self.policy_softmax_temp is not None: 100 | command.extend(['--policy-softmax-temp={}'.format(self.policy_softmax_temp)]) 101 | if self.logfile: 102 | print("lc0 logging to: {}".format(self.logfile)) 103 | command.extend(['-l', self.logfile]) 104 | command = list(map(str, command)) 105 | self.engine = chess.uci.popen_engine(command, stderr=self.stderr) 106 | print("Leela lc0 engine started with command: {}".format(' '.join(command))) 107 | self.engine.info_handlers.append(self.info_handler) 108 | self.engine.uci() 109 | 110 | 111 | def stop(self): 112 | self.engine.quit() 113 | try: 114 | self.stderr.close() 115 | except: 116 | pass 117 | 118 | def newgame(self): 119 | self.engine.ucinewgame() 120 | 121 | def evaluate(self, board): 122 | '''returns a (bestmove, policy, value) tuple given a python-chess board where: 123 | policy is a mapping UCI=>value, sorted highest to lowest 124 | value is a float''' 125 | if isinstance(board, LeelaBoard): 126 | board = board.pc_board 127 | 128 | # Note/Important - this 'is_irreversible' fix is needed to make sure that the engine 129 | # gets all history when calling self.engine.position(board) 130 | board = board.copy() 131 | board.is_irreversible = lambda move: False 132 | 133 | self.info_handler.lcz_clear() 134 | self.engine.position(board) 135 | 
bestmove = self.engine.go(nodes=self.nodes) 136 | value = None 137 | policy = {} 138 | for move in board.legal_moves: 139 | uci = board.uci(move) 140 | move_info = self.info_handler.lcz_move_info[uci] 141 | if value is None and move_info.visits==0: 142 | value = (move_info.q_value + 1)/2 143 | policy[uci] = move_info.policy 144 | return bestmove, OrderedDict(sorted(policy.items(), key=itemgetter(1), reverse=True)), value 145 | -------------------------------------------------------------------------------- /src/lcztools/testing/train_parser.py: -------------------------------------------------------------------------------- 1 | '''Parse train data according to: 2 | https://github.com/glinscott/leela-chess/blob/master/training/tf/chunkparser.py#L115 3 | ''' 4 | import numpy as np 5 | import tarfile 6 | from collections import OrderedDict 7 | import struct 8 | import chess 9 | import chess.pgn 10 | import os 11 | from lcztools.util import tqdm, lazy_property 12 | 13 | try: 14 | from lcztools._uci_to_idx import uci_to_idx as _uci_to_idx 15 | _idx_to_uci_wn = {v: k for k, v in _uci_to_idx[0].items()} 16 | _idx_to_uci_wc = {v: k for k, v in _uci_to_idx[1].items()} 17 | _idx_to_uci_bn = {v: k for k, v in _uci_to_idx[2].items()} 18 | _idx_to_uci_bc = {v: k for k, v in _uci_to_idx[3].items()} 19 | IDX_TO_UCI = [ 20 | _idx_to_uci_wn, # White no castling 21 | _idx_to_uci_wc, # White castling 22 | _idx_to_uci_bn, # Black no castling 23 | _idx_to_uci_bc, # Black castling 24 | ] 25 | except: 26 | print("No lcztools") 27 | 28 | 29 | COLUMNS = 'abcdefgh' 30 | INDEXED_PIECES = list(enumerate(['P', 'N', 'B', 'R', 'Q', 'K'])) 31 | 32 | 33 | class TrainingRecord: 34 | SCALARS_STRUCT = struct.Struct('<7B1b') 35 | PROBS_STRUCT = struct.Struct('<1858f') 36 | PIECES_STRUCT = struct.Struct('<6Q') 37 | NP_FLOAT32 = np.dtype(' prob)''' 70 | # Global only for debug 71 | # global pmoves 72 | probs = TrainingRecord.PROBS_STRUCT.unpack_from(self.data[4:]) 73 | idx_to_uci_idx = 
2*self.side_to_move + (self.us_ooo | self.us_oo) 74 | idx_to_uci = IDX_TO_UCI[idx_to_uci_idx] 75 | pmoves = [] 76 | for idx, prob in enumerate(probs): 77 | if prob > 0: 78 | pmoves.append((idx_to_uci[idx], prob)) 79 | pmoves = sorted(pmoves, key=lambda mp: (-mp[1], mp[0])) 80 | return OrderedDict(pmoves) 81 | 82 | def get_piece_plane(self, history_index, side_index, piece_index): 83 | '''Get a piece plane as an 8x8 numpy array, correctly flipped to normal (W-first) orientation 84 | 85 | history_index: range(0,8) 86 | side_index: 0 => our piece plane, 1 => their piece plane 87 | piece_index: index of piece''' 88 | # version_length + probs_length 89 | offset = 4 + 7432 90 | # + len(INDEXED_PIECES) * num_sides * history_index 91 | offset += 8 * (6 * 2 + 1) * history_index 92 | # + len(INDEXED_PIECES) * side_index 93 | offset += 8 * 6 * side_index 94 | # + piece_index 95 | offset += 8 * piece_index 96 | 97 | planebytes = self.data[offset:offset+8] 98 | if self.side_to_move == 0: 99 | return np.unpackbits(bytearray(planebytes)).reshape(8, 8) 100 | else: 101 | return np.unpackbits(bytearray(planebytes)).reshape(8, 8)[::-1] 102 | 103 | 104 | 105 | 106 | class TrainingGame: 107 | '''Parse training data bytes''' 108 | RECORD_SIZE = 8276 109 | 110 | def __init__(self, databytes, name): 111 | s = self.RECORD_SIZE 112 | self.records = [TrainingRecord(databytes[i:i+s]) for i in range(0, len(databytes),s)] 113 | self._cache = {} 114 | self.name = name 115 | 116 | def push_final_move(self, pc_board): 117 | '''Push the most likely final move, given python chess board and results in final record''' 118 | 119 | def test_final_move(): 120 | if (result==1) and pc_board.is_checkmate(): 121 | return True 122 | if (result==0) and (moveidx==449) and not pc_board.is_checkmate(): 123 | # 450 move rule... 
124 | return True 125 | if (result==0) and (pc_board.is_insufficient_material() or 126 | pc_board.is_stalemate() or 127 | pc_board.can_claim_draw()): 128 | return True 129 | return False 130 | 131 | record = self.records[-1] 132 | result = record.result 133 | if result == -1: 134 | print("Error -- Last position has score of -1?") 135 | return 136 | uci_probs = record.get_probabilities() 137 | promotion_save = [] 138 | moveidx = len(pc_board.move_stack) 139 | 140 | for uci in uci_probs: 141 | try: 142 | pc_board.push_uci(uci) 143 | except ValueError: 144 | # This should be a knight promotion... 145 | # (But can be a queen promotion from older buggy engines) 146 | promotion_save.append(uci) 147 | uci = uci + 'n' 148 | pc_board.push_uci(uci) 149 | if test_final_move(): 150 | break 151 | pc_board.pop() 152 | else: 153 | # Just in case we can't find the final move... We should never get here on 154 | # new data 155 | for uci in promotion_save: 156 | uci = uci + 'q' 157 | if uci not in uci_probs: 158 | pc_board.push_uci(uci) 159 | if test_final_move(): 160 | break 161 | pc_board.pop() 162 | else: 163 | print("Error ({}) - no final move found!".format(self.name)) 164 | 165 | def get_move(self, move_index): 166 | '''Get UCI move, appropriately flipped from piece plane comparison 167 | 168 | This will not get the final move 169 | Returns: (piece, UCI) tuple''' 170 | record = self.records[move_index+1] 171 | piece_index = record._get_last_moved_piece_index() 172 | piece = INDEXED_PIECES[piece_index][1] 173 | 174 | arr1 = record.get_piece_plane(1, 1, piece_index) 175 | arr2 = record.get_piece_plane(0, 1, piece_index) 176 | 177 | rowfrom, colfrom = np.where(arr1 & ~arr2) 178 | rowto, colto = np.where(~arr1 & arr2) 179 | 180 | promotion = '' 181 | if not len(colfrom)==len(rowfrom)==len(colto)==len(rowto)==1: 182 | # This must be a pawn promotion... 
183 | assert (len(colfrom)==len(rowfrom)==0) 184 | # Find where the pawn came from 185 | p_arr1 = record.get_piece_plane(1, 1, 0) 186 | p_arr2 = record.get_piece_plane(0, 1, 0) 187 | rowfrom, colfrom = np.where(p_arr1 & ~p_arr2) 188 | promotion = piece.lower() 189 | assert len(colfrom)==len(rowfrom)==len(colto)==len(rowto)==1 190 | rowfrom, colfrom = rowfrom[0], colfrom[0] 191 | rowto, colto = rowto[0], colto[0] 192 | uci = '{}{}{}{}{}'.format(COLUMNS[colfrom], rowfrom+1, 193 | COLUMNS[colto], rowto+1, promotion) 194 | return piece, uci 195 | 196 | 197 | def _get_move_orig(self, move_index): 198 | '''Get UCI move, appropriately flipped from piece plane comparison 199 | - This is a slower version of get_move() 200 | 201 | This will not get the final move 202 | Returns: (piece, UCI) tuple''' 203 | record1 = self.records[move_index] 204 | record2 = self.records[move_index+1] 205 | for piece_index, piece in reversed(INDEXED_PIECES): 206 | arr1 = record1.get_piece_plane(0, 0, piece_index) 207 | arr2 = record2.get_piece_plane(0, 1, piece_index) 208 | if not np.array_equal(arr1, arr2): 209 | arr1 = record1.get_piece_plane(0, 0, piece_index) 210 | arr2 = record2.get_piece_plane(0, 1, piece_index) 211 | rowfrom, colfrom = np.where(arr1 & ~arr2) 212 | rowto, colto = np.where(~arr1 & arr2) 213 | promotion = '' 214 | if not len(colfrom)==len(rowfrom)==len(colto)==len(rowto)==1: 215 | # This must be a pawn promotion... 
216 | assert (len(colfrom)==len(rowfrom)==0) 217 | # Find where the pawn came from 218 | p_arr1 = record1.get_piece_plane(0, 0, 0) 219 | p_arr2 = record2.get_piece_plane(0, 1, 0) 220 | rowfrom, colfrom = np.where(p_arr1 & ~p_arr2) 221 | promotion = piece.lower() 222 | assert len(colfrom)==len(rowfrom)==len(colto)==len(rowto)==1 223 | rowfrom, colfrom = rowfrom[0], colfrom[0] 224 | rowto, colto = rowto[0], colto[0] 225 | uci = '{}{}{}{}{}'.format(COLUMNS[colfrom], rowfrom+1, 226 | COLUMNS[colto], rowto+1, promotion) 227 | return piece, uci 228 | else: 229 | raise Exception("I shouldn't be here") 230 | 231 | def get_all_moves(self): 232 | # TODO: Need to be able to get final move... 233 | all_moves = [self.get_move(move_index) for move_index in range(len(self.records)-1)] 234 | return all_moves 235 | 236 | def get_pc_board(self, with_final_move = True): 237 | cache_key = ('get_pc_board', with_final_move) 238 | if cache_key in self._cache: 239 | return self._cache[cache_key].copy() 240 | pc_board = chess.Board() 241 | for _piece, uci in self.get_all_moves(): 242 | pc_board.push_uci(uci) 243 | if with_final_move: 244 | self.push_final_move(pc_board) 245 | self._cache[cache_key] = pc_board 246 | return pc_board 247 | 248 | def get_pgn(self, with_final_move = True): 249 | cache_key = ('get_pgn', with_final_move) 250 | if cache_key in self._cache: 251 | return str(self._cache[cache_key]) 252 | pc_board = self.get_pc_board(with_final_move) 253 | pgn_game = chess.pgn.Game.from_board(pc_board) 254 | white_result = self.records[0].result 255 | pgn_game.headers["Event"] = self.name 256 | if white_result==1: 257 | pgn_game.headers["Result"] = "1-0" 258 | elif white_result==-1: 259 | pgn_game.headers["Result"] = "0-1" 260 | elif white_result==0: 261 | pgn_game.headers["Result"] = "1/2-1/2" 262 | else: 263 | print(white_result) 264 | raise Exception("Bad result") 265 | self._cache[cache_key] = pgn_game 266 | return str(pgn_game) 267 | 268 | 269 | class TarTrainingFile: 270 | 
'''Parse training data''' 271 | def __init__(self, filename): 272 | self.filename = filename 273 | 274 | def __iter__(self): 275 | '''Generator to iterate through data''' 276 | def generator(): 277 | with tarfile.open(self.filename) as f: 278 | for idx, member in enumerate(f): 279 | databytes = f.extractfile(member).read() 280 | yield TrainingGame(databytes, member.name) 281 | return generator() 282 | 283 | @lazy_property 284 | def archive_names(self): 285 | '''Read all the names from the training archive''' 286 | with tarfile.open(self.filename) as f: 287 | return f.getnames() 288 | 289 | def read_game(self, name): 290 | '''Read a single game from the archive''' 291 | name = str(name) 292 | names = self.archive_names 293 | with tarfile.open(self.filename) as f: 294 | # Search for any names that contain name 295 | names = [n for n in names if name in n] 296 | if len(names)==0: 297 | raise Exception("{} not found in {}".format(name, self.filename)) 298 | elif len(names)>1: 299 | raise Exception("Multiple occurrences of {} found in {}".format(name, self.filename)) 300 | databytes = f.extractfile(names[0]).read() 301 | return TrainingGame(databytes, names[0]) 302 | 303 | def to_pgn(self, filename=None, progress=True): 304 | if progress: 305 | progress = tqdm 306 | else: 307 | progress = lambda it: it 308 | if filename is None: 309 | dirname = os.path.dirname(self.filename) 310 | basename = os.path.basename(self.filename) 311 | basename = basename.split('.')[0] + '.pgn' 312 | filename = os.path.join(dirname, basename) 313 | assert(os.path.abspath(self.filename) != os.path.abspath(filename)) 314 | with open(filename, 'w') as pgn_file: 315 | for game in progress(self): 316 | pgn = game.get_pgn() 317 | pgn_file.write(pgn) 318 | pgn_file.write('\n\n\n') 319 | pgn_file.flush() 320 | -------------------------------------------------------------------------------- /src/lcztools/util/__init__.py: -------------------------------------------------------------------------------- 
1 | try: 2 | from tqdm import tqdm 3 | except: 4 | def tqdm(iterator): 5 | print("Please install the real tqdm") 6 | yield from progress(iterator) 7 | 8 | def progress(iterator): 9 | for idx, stuff in enumerate(iterator): 10 | if idx%50==0: 11 | if idx>0: 12 | print() 13 | print('{:6}: '.format(idx), end='') 14 | yield stuff 15 | print('.', end='') 16 | if idx%50!=0: 17 | print() 18 | print('{:6}: '.format(idx+1), end='') 19 | print("DONE") 20 | 21 | 22 | def lazy_property(fn): 23 | '''Decorator that makes a property lazy-evaluated. 24 | From: https://stevenloria.com/lazy-properties/ 25 | ''' 26 | attr_name = '_lazy_' + fn.__name__ 27 | 28 | @property 29 | def _lazy_property(self): 30 | if not hasattr(self, attr_name): 31 | setattr(self, attr_name, fn(self)) 32 | return getattr(self, attr_name) 33 | return _lazy_property 34 | 35 | -------------------------------------------------------------------------------- /src/lcztools/util/_shuffle_buffer.py: -------------------------------------------------------------------------------- 1 | '''Maybe use this to help with training...''' 2 | 3 | import random as _random 4 | 5 | class ShuffleBufferEmptyException(Exception): 6 | pass 7 | 8 | 9 | class ShuffleBufferFullException(Exception): 10 | pass 11 | 12 | 13 | class ShuffleBuffer: 14 | '''A buffer that pops random items''' 15 | 16 | def __init__(self, length): 17 | self.buffer = [None]*length 18 | self.used = 0 19 | 20 | def full(self): 21 | return self.used == len(self.buffer) 22 | 23 | def push(self, item): 24 | if self.used == len(self.buffer): 25 | raise ShuffleBufferFullException 26 | self.buffer[self.used] = item 27 | self.used += 1 28 | 29 | def pop(self): 30 | if self.used == 0: 31 | raise ShuffleBufferEmptyException 32 | index = _random.randrange(self.used) 33 | item = self.buffer[index] 34 | self.used -= 1 35 | if self.used == index: 36 | # Just remove the last element 37 | self.buffer[self.used] = None 38 | else: 39 | # Move the last element to index 40 | 
self.buffer[index] = self.buffer[self.used] 41 | self.buffer[self.used] = None 42 | return item 43 | -------------------------------------------------------------------------------- /src/lcztools/weights/__init__.py: -------------------------------------------------------------------------------- 1 | from ._weights_file import read_weights_file 2 | -------------------------------------------------------------------------------- /src/lcztools/weights/_weights_file.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import os 3 | from lcztools.config import get_global_config 4 | 5 | # Note: Weight loading code taken from 6 | # https://github.com/glinscott/leela-chess/blob/master/training/tf/net_to_model.py 7 | 8 | 9 | LEELA_WEIGHTS_VERSION = '2' 10 | 11 | def read_weights_file(filename): 12 | config = get_global_config() 13 | filename = config.get_weights_filename(filename) 14 | filename = os.path.expanduser(filename) 15 | if '.gz' in filename: 16 | opener = gzip.open 17 | else: 18 | opener = open 19 | with opener(filename, 'r') as f: 20 | version = f.readline().decode('ascii') 21 | if version != '{}\n'.format(LEELA_WEIGHTS_VERSION): 22 | raise ValueError("Invalid version {}".format(version.strip())) 23 | weights = [] 24 | e = 0 25 | for line in f: 26 | line = line.decode('ascii').strip() 27 | if not line: 28 | continue 29 | e += 1 30 | weight = list(map(float, line.split(' '))) 31 | weights.append(weight) 32 | if e == 2: 33 | filters = len(line.split(' ')) 34 | print("Channels", filters) 35 | blocks = e - (4 + 14) 36 | if blocks % 8 != 0: 37 | raise ValueError("Inconsistent number of weights in the file - e = {}".format(e)) 38 | blocks //= 8 39 | print("Blocks", blocks) 40 | return (filters, blocks, weights) -------------------------------------------------------------------------------- /tests/Untitled.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | 
{ 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "Test pytorch\n", 13 | "Loading network using backend=pytorch_cuda, policy_softmax_temp=2.2\n", 14 | "Channels 256\n", 15 | "Blocks 20\n", 16 | "Enabling CUDA!\n", 17 | "r n b q k b n r\n", 18 | "p p p p p p p p\n", 19 | ". . . . . . . .\n", 20 | ". . . . . . . .\n", 21 | ". . . . . . . .\n", 22 | ". . . . . . . .\n", 23 | "P P P P P P P P\n", 24 | "R N B Q K B N R\n", 25 | "Turn: White\n", 26 | "Policy: {\n", 27 | " \"d2d4\": 0.11119775474071503,\n", 28 | " \"c2c4\": 0.09062506258487701,\n", 29 | " \"e2e4\": 0.0897584781050682,\n", 30 | " \"g1f3\": 0.07733798772096634,\n", 31 | " \"e2e3\": 0.05697071924805641,\n", 32 | " \"g2g3\": 0.05233848839998245,\n", 33 | " \"c2c3\": 0.046211354434490204,\n", 34 | " \"a2a3\": 0.04585567116737366,\n", 35 | " \"d2d3\": 0.04334000498056412,\n", 36 | " \"h2h3\": 0.043053533881902695,\n", 37 | " \"b1c3\": 0.042753130197525024,\n", 38 | " \"a2a4\": 0.03928161412477493,\n", 39 | " \"b2b3\": 0.03883354365825653,\n", 40 | " \"f2f4\": 0.03694486990571022,\n", 41 | " \"h2h4\": 0.03529586270451546,\n", 42 | " \"b2b4\": 0.03448859602212906,\n", 43 | " \"b1a3\": 0.03385066241025925,\n", 44 | " \"g1h3\": 0.02947886288166046,\n", 45 | " \"f2f3\": 0.0273781456053257,\n", 46 | " \"g2g4\": 0.025005634874105453\n", 47 | "}\n", 48 | "Value: 0.5136711187660694\n", 49 | "r n b q k b n r\n", 50 | "p p p p p p p p\n", 51 | ". . . . . . . .\n", 52 | ". . . . . . . .\n", 53 | ". . . . P . . .\n", 54 | ". . . . . . . .\n", 55 | "P P P P . 
P P P\n", 56 | "R N B Q K B N R\n", 57 | "Turn: Black\n", 58 | "Policy: {\n", 59 | " \"c7c5\": 0.16294705867767334,\n", 60 | " \"e7e6\": 0.07480539381504059,\n", 61 | " \"d7d6\": 0.05852973461151123,\n", 62 | " \"e7e5\": 0.05767476186156273,\n", 63 | " \"g7g6\": 0.05732332542538643,\n", 64 | " \"b8c6\": 0.05544520914554596,\n", 65 | " \"h7h6\": 0.04756679758429527,\n", 66 | " \"a7a6\": 0.04726424440741539,\n", 67 | " \"g8f6\": 0.044885676354169846,\n", 68 | " \"d7d5\": 0.04398905485868454,\n", 69 | " \"c7c6\": 0.042789630591869354,\n", 70 | " \"a7a5\": 0.041947197169065475,\n", 71 | " \"b7b6\": 0.041355591267347336,\n", 72 | " \"b8a6\": 0.038210757076740265,\n", 73 | " \"h7h5\": 0.036929767578840256,\n", 74 | " \"g8h6\": 0.03256430849432945,\n", 75 | " \"f7f6\": 0.030394941568374634,\n", 76 | " \"g7g5\": 0.028698161244392395,\n", 77 | " \"f7f5\": 0.028635870665311813,\n", 78 | " \"b7b5\": 0.028042595833539963\n", 79 | "}\n", 80 | "Value: 0.47063818387687206\n" 81 | ] 82 | } 83 | ], 84 | "source": [ 85 | "import json\n", 86 | "import sys\n", 87 | "import os\n", 88 | "import numpy as np\n", 89 | "\n", 90 | "\n", 91 | "\n", 92 | "import lcztools\n", 93 | "\n", 94 | "def json_default(obj):\n", 95 | " # Serialize numpy floats\n", 96 | " if isinstance(obj, np.floating):\n", 97 | " return float(obj)\n", 98 | " raise TypeError\n", 99 | "\n", 100 | "print(\"Test pytorch\")\n", 101 | "lcz_net = lcztools.load_network()\n", 102 | "lcz_board = lcztools.LeelaBoard()\n", 103 | "print(lcz_board)\n", 104 | "\n", 105 | "# import torch\n", 106 | "# features = lcz_board.lcz_features()\n", 107 | "# half_features = torch.HalfTensor(features)\n", 108 | "# half_model = lcz_net.model.half()\n", 109 | "# # lcz_net.model(half_features)\n", 110 | "# half_model(half_features)\n", 111 | "\n", 112 | "# blah\n", 113 | "\n", 114 | "policy, value = lcz_net.evaluate(lcz_board)\n", 115 | "print('Policy: {}'.format(json.dumps(policy, default=json_default, indent=3)))\n", 116 | "print('Value: 
{}'.format(value))\n", 117 | "\n", 118 | "lcz_board.push_uci('e2e4')\n", 119 | "print(lcz_board)\n", 120 | "policy, value = lcz_net.evaluate(lcz_board)\n", 121 | "print('Policy: {}'.format(json.dumps(policy, default=json_default, indent=3)))\n", 122 | "print('Value: {}'.format(value))\n" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "model = lcz_net.model\n", 132 | "model.is_half = True\n", 133 | "model.half()\n", 134 | "features = lcz_board.lcz_features()\n", 135 | "# policy, value = lcz_net.model(features)" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "%%timeit\n", 145 | "policy, value = lcz_net.model(features)" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "%%timeit\n", 155 | "policy, value = lcz_net.evaluate(lcz_board)" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "lcz_net.model(np.stack([features]*10))" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "import torch\n", 174 | "features_stack = np.stack([features]*256)\n", 175 | "features_stack = torch.HalfTensor(features_stack).cuda()" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "%%timeit\n", 185 | "lcz_net.model(features_stack)" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": null, 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "features = lcz_board.lcz_features()" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 
200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "%%timeit\n", 204 | "lcz_board.push_uci('c7c5')\n", 205 | "lcz_board.pop()" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "256/37.9*1000" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "1000/4.63" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": null, 229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [ 232 | "stuff = half_features.flatten()" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": null, 238 | "metadata": {}, 239 | "outputs": [], 240 | "source": [ 241 | "s = half_features.unsqueeze(0)" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": null, 247 | "metadata": {}, 248 | "outputs": [], 249 | "source": [ 250 | "s.numpy().shape" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": null, 256 | "metadata": {}, 257 | "outputs": [], 258 | "source": [ 259 | "half_features.ndimension()" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": null, 265 | "metadata": {}, 266 | "outputs": [], 267 | "source": [ 268 | "s.flatten().numpy().shape" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": null, 274 | "metadata": {}, 275 | "outputs": [], 276 | "source": [ 277 | "import torch\n", 278 | "features = lcz_board.lcz_features()\n", 279 | "half_features = torch.HalfTensor(features)\n", 280 | "half_model = lcz_net.model.half()\n", 281 | "# lcz_net.model(half_features)\n", 282 | "half_model(half_features)\n" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": null, 288 | "metadata": {}, 289 | "outputs": [], 290 | "source": [ 291 | "import torch\n", 292 | "t = torch.HalfTensor(features)" 293 | ] 294 | }, 295 | { 296 | 
"cell_type": "code", 297 | "execution_count": null, 298 | "metadata": {}, 299 | "outputs": [], 300 | "source": [ 301 | "# t.view(-1,112,8,8).shape\n", 302 | "t.cuda().view(-1,112,8,8).half()" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": null, 308 | "metadata": {}, 309 | "outputs": [], 310 | "source": [ 311 | "lcz_net.model.half??" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": null, 317 | "metadata": {}, 318 | "outputs": [], 319 | "source": [] 320 | } 321 | ], 322 | "metadata": { 323 | "kernelspec": { 324 | "display_name": "Python [conda env:luna]", 325 | "language": "python", 326 | "name": "conda-env-luna-py" 327 | }, 328 | "language_info": { 329 | "codemirror_mode": { 330 | "name": "ipython", 331 | "version": 3 332 | }, 333 | "file_extension": ".py", 334 | "mimetype": "text/x-python", 335 | "name": "python", 336 | "nbconvert_exporter": "python", 337 | "pygments_lexer": "ipython3", 338 | "version": "3.7.0" 339 | } 340 | }, 341 | "nbformat": 4, 342 | "nbformat_minor": 2 343 | } 344 | -------------------------------------------------------------------------------- /tests/_archive_unused/test_net_eq_engine.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This script just ensures that the engine is equal to the python implementation in evaluation 3 | If no exception is thrown, it has passed 4 | 5 | 6 | ''' 7 | 8 | import lcztools 9 | from lcztools.testing._archive_unused.leela_engine import LCZeroEngine 10 | import numpy as np 11 | import chess.pgn 12 | import time 13 | import json 14 | import collections 15 | 16 | 17 | engine = LCZeroEngine() 18 | board = lcztools.LeelaBoard() 19 | # engine.evaluate(board()) 20 | net = lcztools.load_network() 21 | 22 | def fix_policy_float(policy): 23 | '''Numpy to normal python float, for json dumps''' 24 | return collections.OrderedDict((k, float(v)) for k, v in policy.items()) 25 | 26 | def eval_equal(neteval, engineeval, 
tolerance=.00006): 27 | npol, nv = neteval 28 | epol, ev = engineeval 29 | for uci in npol: 30 | if abs(npol[uci] - epol[uci]) > tolerance: 31 | return False 32 | if (ev is not None) and (abs(nv - ev) > tolerance): 33 | return False 34 | return True 35 | 36 | net_eval_time = 0 37 | engine_eval_time = 0 38 | totalevals = 0 39 | numgames = 200 40 | for gamenum in range(numgames): 41 | print("Playing game {}/{}".format(gamenum+1,numgames)) 42 | board = lcztools.LeelaBoard() 43 | counter = 500 44 | while (counter>0) and (not board.is_game_over()): 45 | counter -= 1 46 | if board.can_claim_draw(): 47 | counter = max(counter, 10) 48 | 49 | clock = time.time() 50 | policy, value = net.evaluate(board) 51 | net_eval_time += time.time() - clock 52 | 53 | clock = time.time() 54 | bestmove, epolicy, evalue = engine.evaluate(board) 55 | engine_eval_time += time.time() - clock 56 | 57 | totalevals += 1 58 | 59 | if not eval_equal((policy, value), (epolicy, evalue)): 60 | print("Not equal...") 61 | print("Policy:", json.dumps(fix_policy_float(policy), indent=3)) 62 | print("Value:", value) 63 | print("Engine Bestmove:", bestmove) 64 | print("Engine Policy:", json.dumps(epolicy, indent=3)) 65 | print("Engine Value:", evalue) 66 | raise Exception("Not equal:", ' '.join(m.uci() for m in board.move_stack)) 67 | ucis = list(policy) 68 | pol_values = np.fromiter(policy.values(), dtype=np.float32) 69 | pol_values = pol_values/pol_values.sum() 70 | pol_index = np.random.choice(len(pol_values), p=pol_values) 71 | uci = ucis[pol_index] 72 | board.push_uci(uci) 73 | print() 74 | game = chess.pgn.Game.from_board(board.pc_board) 75 | print(game) 76 | print(board) 77 | print("Average net eval time : {:.6f}".format(net_eval_time/totalevals)) 78 | print("Average engine eval time: {:.6f}".format(engine_eval_time/totalevals)) 79 | print("All Evals Equal: PASS") 80 | print("Network looks good!") --------------------------------------------------------------------------------
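The test loop above selects each move by renormalizing the network's policy dict and sampling from it (lines 67-72 of the file). That step can be sketched standalone; `sample_move` and the example policy dict below are illustrative helpers for this note, not part of the lcztools API:

```python
import numpy as np

def sample_move(policy, rng=None):
    """Sample a UCI move from a {uci: probability} policy dict.

    Probabilities are renormalized first, as in the test loop,
    so the dict need not sum exactly to 1.
    """
    rng = rng or np.random.default_rng()
    ucis = list(policy)
    p = np.fromiter(policy.values(), dtype=np.float32)
    p = p / p.sum()  # renormalize before sampling
    return ucis[rng.choice(len(p), p=p)]

# Hypothetical policy over three opening moves
policy = {'e2e4': 0.5, 'd2d4': 0.3, 'g1f3': 0.2}
print(sample_move(policy))
```

Sampling (rather than always playing the argmax move) is what lets the equality test visit a diverse set of positions across its 200 games.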
/tests/download_latest_network.py: -------------------------------------------------------------------------------- 1 | from lcztools.testing import WeightsDownloader 2 | 3 | weights_downloader = WeightsDownloader() 4 | 5 | weights_downloader.download_latest() 6 | 7 | -------------------------------------------------------------------------------- /tests/lcztools.ini: -------------------------------------------------------------------------------- 1 | [default] 2 | # This is directory where the network weights are stored 3 | weights_dir = ~/projects/lczero/weights/txt 4 | 5 | # This is the default weights filename to use, if none provided 6 | weights_file = weights_run1_21754.txt.gz 7 | 8 | # This is where raw training files are stored 9 | training_raw_dir = /Volumes/SeagateExternal/leela_data/training_raw/ 10 | 11 | # This is the default backend to use, if none provided 12 | # Choices are: ['pytorch_eval_cpu', 'pytorch_eval_cuda', 'pytorch_cpu', 'pytorch_cuda', 'pytorch_train_cpu', 'pytorch_train_cuda', 'tensorflow'] 13 | backend = pytorch_cuda 14 | 15 | ### This is the lczero engine to use, only currently used for testing/validation 16 | ## NOTE: lczero no longer supported!!! 17 | ## lczero_engine = ~/git/leela-chess/release/lczero 18 | 19 | 20 | # This is the lc0 engine to use, only currently used for testing/validation 21 | lc0_engine = ~/.local/bin/lc0 22 | 23 | 24 | ### This is, e.g. ~/sompath/leela-chess/training/tf -- currently only used as hackish tensorflow 25 | ### mechanism. 
Not needed for pytorch backend 26 | ## leela_training_tf_dir = ~/git/leela-chess/training/tf/ 27 | 28 | # This is the policy softmax temp, like lc0's option 29 | policy_softmax_temp = 2.2 30 | -------------------------------------------------------------------------------- /tests/test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | import sys 4 | import os 5 | import numpy as np 6 | 7 | 8 | 9 | import lcztools 10 | 11 | def json_default(obj): 12 | # Serialize numpy floats 13 | if isinstance(obj, np.floating): 14 | return float(obj) 15 | raise TypeError 16 | 17 | print("Test pytorch") 18 | lcz_net = lcztools.load_network() 19 | lcz_board = lcztools.LeelaBoard() 20 | print(lcz_board) 21 | policy, value = lcz_net.evaluate(lcz_board) 22 | print('Policy: {}'.format(json.dumps(policy, default=json_default, indent=3))) 23 | print('Value: {}'.format(value)) 24 | 25 | lcz_board.push_uci('e2e4') 26 | print(lcz_board) 27 | policy, value = lcz_net.evaluate(lcz_board) 28 | print('Policy: {}'.format(json.dumps(policy, default=json_default, indent=3))) 29 | print('Value: {}'.format(value)) 30 | 31 | 32 | # print("Test tensorflow") 33 | # lcz_net = lcztools.load_network(backend='tensorflow') 34 | # lcz_board = lcztools.LeelaBoard() 35 | # print(lcz_board) 36 | # policy, value = lcz_net.evaluate(lcz_board) 37 | # print('Policy: {}'.format(json.dumps(policy, default=json_default, indent=3))) 38 | # print('Value: {}'.format(value)) 39 | # 40 | # lcz_board.push_uci('e2e4') 41 | # print(lcz_board) 42 | # policy, value = lcz_net.evaluate(lcz_board) 43 | # print('Policy: {}'.format(json.dumps(policy, default=json_default, indent=3))) 44 | # print('Value: {}'.format(value)) 45 | 46 | -------------------------------------------------------------------------------- /tests/test_net_eq_engine_lc0.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This script just 
ensures that the engine's evaluation matches the python implementation's 3 | If no exception is thrown, the test has passed 4 | 5 | 6 | ''' 7 | 8 | import lcztools 9 | from lcztools.testing.leela_engine_lc0 import LC0Engine 10 | import numpy as np 11 | import chess.pgn 12 | import time 13 | import json 14 | import collections 15 | 16 | 17 | # NOTE: Policy values seem to be a tiny bit off from lc0... 18 | # The reason for this seems to be usage of src/mcts/node.cc::Edge::SetP in lc0 19 | # This function applies some rounding to policy values 20 | # 21 | # Changing tolerance from 0.00006 to 0.0005 22 | TOLERANCE = 0.0005 23 | 24 | 25 | 26 | 27 | engine = LC0Engine() 28 | board = lcztools.LeelaBoard() 29 | # engine.evaluate(board()) 30 | net = lcztools.load_network() 31 | 32 | def fix_policy_float(policy): 33 | '''Numpy to normal python float, for json dumps''' 34 | return collections.OrderedDict((k, float(v)) for k, v in policy.items()) 35 | 36 | 37 | g_max_policy_error = 0 38 | g_max_value_error = 0 39 | g_mse_policy = 0 40 | g_mse_value = 0 41 | g_se_policy_sum = 0 42 | g_se_value_sum = 0 43 | g_policy_samples = 0 44 | g_value_samples = 0 45 | 46 | def eval_equal(neteval, engineeval, tolerance=TOLERANCE): 47 | global g_max_policy_error, g_max_value_error, g_mse_policy, g_mse_value, g_se_policy_sum, g_se_value_sum 48 | global g_policy_samples, g_value_samples 49 | npol, nv = neteval 50 | epol, ev = engineeval 51 | for uci in npol: 52 | policy_error = abs(npol[uci] - epol[uci]) 53 | g_max_policy_error = max(policy_error, g_max_policy_error) 54 | g_policy_samples += 1 55 | g_se_policy_sum += policy_error**2 56 | if policy_error > tolerance: 57 | print("Policy not equal: {}, {}, {}".format(uci, npol[uci], epol[uci])) 58 | return False 59 | g_mse_policy = g_se_policy_sum / g_policy_samples 60 | value_error = abs(nv - ev) 61 | g_max_value_error = max(value_error, g_max_value_error) 62 | g_value_samples += 1 63 | g_se_value_sum += value_error**2 64 | g_mse_value =
g_se_value_sum / g_value_samples 65 | if value_error > tolerance: 66 | print("Value not equal: {}, {}".format(nv, ev)) 67 | return False 68 | return True 69 | 70 | net_eval_time = 0 71 | engine_eval_time = 0 72 | totalevals = 0 73 | numgames = 200 74 | for gamenum in range(numgames): 75 | print("Playing game {}/{}".format(gamenum+1,numgames)) 76 | board = lcztools.LeelaBoard() 77 | counter = 500 78 | engine.newgame() # TODO -- It looks like without this, lc0 gives bad values. 79 | while (counter>0) and (not board.is_game_over()): 80 | counter -= 1 81 | if board.can_claim_draw(): 82 | counter = max(counter, 10) 83 | 84 | clock = time.time() 85 | policy, value = net.evaluate(board) 86 | net_eval_time += time.time() - clock 87 | 88 | clock = time.time() 89 | bestmove, epolicy, evalue = engine.evaluate(board) 90 | engine_eval_time += time.time() - clock 91 | 92 | totalevals += 1 93 | 94 | if not eval_equal((policy, value), (epolicy, evalue)): 95 | print("Not equal...") 96 | print("Policy:", json.dumps(fix_policy_float(policy), indent=3)) 97 | print("Value:", value) 98 | print("Engine Bestmove:", bestmove) 99 | print("Engine Policy:", json.dumps(epolicy, indent=3)) 100 | print("Engine Value:", evalue) 101 | raise Exception("Not equal:", ' '.join(m.uci() for m in board.move_stack)) 102 | ucis = list(policy) 103 | pol_values = np.fromiter(policy.values(), dtype=np.float32) 104 | pol_values = pol_values/pol_values.sum() 105 | pol_index = np.random.choice(len(pol_values), p=pol_values) 106 | uci = ucis[pol_index] 107 | board.push_uci(uci) 108 | print() 109 | game = chess.pgn.Game.from_board(board.pc_board) 110 | print(game) 111 | print(board) 112 | print("Average net eval time : {:.6f}".format(net_eval_time/totalevals)) 113 | print("Average engine eval time: {:.6f}".format(engine_eval_time/totalevals)) 114 | print("Max policy error: {:.7f}".format(g_max_policy_error)) 115 | print("Max value error: {:.7f}".format(g_max_value_error)) 116 | print("Policy MSE: 
{}".format(g_mse_policy)) 117 | print("Value MSE: {}".format(g_mse_value)) 118 | print("All Evals Equal: PASS") 119 | print("Network looks good!") 120 | -------------------------------------------------------------------------------- /tests/train_to_pgn.py: -------------------------------------------------------------------------------- 1 | from lcztools.testing import TarTrainingFile 2 | from lcztools.util import tqdm 3 | import fire 4 | 5 | def train_to_pgn(filename): 6 | TarTrainingFile(filename).to_pgn() 7 | 8 | if __name__ == '__main__': 9 | fire.Fire(train_to_pgn) 10 | -------------------------------------------------------------------------------- /tests/update_rule50_weights.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Update Leela Chess's network's rule50 input weights by multiplying them by a 3 | constant coefficient 4 | 5 | Requires: 6 | pip install fire 7 | 8 | Usage: 9 | python update_rule50_weights.py filename rule50_multiplier 10 | 11 | Creates a new filename of the form 12 | basename__rule50__mult.ext 13 | 14 | E.g., 15 | python update_rule50_weights.py weights_345.txt.gz 0 16 | ''' 17 | 18 | 19 | import gzip 20 | import os 21 | from itertools import chain 22 | import fire 23 | 24 | 25 | LEELA_WEIGHTS_VERSION = '2' 26 | 27 | 28 | def _update_rule50_weights(input_weights_line, rule50_multiplier): 29 | '''Given input weights: 192x112x3x3, modify only the following... 
30 | 31 | [range(112*9*i + 109*9, 112*9*i + 110*9) for i in range(output_channels)] 32 | ''' 33 | weights = input_weights_line.strip().split() 34 | output_channels = len(weights)//(112*9) 35 | assert(output_channels == len(weights)/(112*9)) 36 | update_indices = chain(*(range(112*9*i + 109*9, 112*9*i + 110*9) for i in range(output_channels))) 37 | for i in update_indices: 38 | weights[i] = str(float(weights[i])*rule50_multiplier) 39 | return ' '.join(weights) + '\n' 40 | 41 | 42 | 43 | def update_r50_weights(filename, rule50_multiplier): 44 | rule50_multiplier = float(rule50_multiplier) 45 | dirname = os.path.dirname(filename) 46 | basename, ext = os.path.basename(filename).split('.', 1) 47 | outbase = '{}__rule50__{}'.format(basename, str(rule50_multiplier).replace('.', '_')) 48 | out_filename = os.path.join(dirname, '{}.{}'.format(outbase, ext)) 49 | print(out_filename) 50 | if '.gz' in filename: 51 | opener = gzip.open 52 | else: 53 | opener = open 54 | output_lines = [] 55 | with opener(filename, 'r') as f: 56 | version = f.readline().decode('ascii') 57 | output_lines.append(version) 58 | version = version.strip() 59 | if version != '{}'.format(LEELA_WEIGHTS_VERSION): 60 | print(filename) 61 | raise ValueError("Invalid version {}".format(version.strip())) 62 | for idx, line in enumerate(f): 63 | line = line.decode('ascii') 64 | if idx==0: 65 | line = _update_rule50_weights(line, rule50_multiplier) 66 | output_lines.append(line) 67 | with opener(out_filename, 'w') as f: 68 | for line in output_lines: 69 | f.write(line.encode()) 70 | print("DONE!") 71 | 72 | 73 | if __name__ == '__main__': 74 | fire.Fire(update_r50_weights) 75 | -------------------------------------------------------------------------------- /wip_archive/leela_train_to_pgn.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Get PGN data from Leela training data, according to format at: 3 | 
https://github.com/glinscott/leela-chess/blob/master/training/tf/chunkparser.py#L115 4 | ''' 5 | import tarfile 6 | import os 7 | import numpy as np 8 | import chess 9 | from collections import namedtuple, defaultdict 10 | from chess import pgn 11 | import zlib, base64 12 | import struct 13 | 14 | # filename = './games7580000.tar.gz' 15 | 16 | filename = './games9030000.tar.gz' 17 | outputdir = './pgn' 18 | ensure_legal = False # Check for move legality (requires generating moves - slow) 19 | include_header = True # include header in PGN output? 20 | single_file = True # output to one file 21 | global_stats = False # Create big dictionary of all possible moves, and how often they are legal/played 22 | add_final_move = True # Add final move, based on training move probabilities and W/L/D result 23 | 24 | 25 | 26 | # This is just a zlib-compressed -> base85 representation of: 27 | # https://github.com/glinscott/leela-chess/blob/master/lc0/src/chess/bitboard.cc#L26 28 | # To decode... 29 | # import zlib, base64 30 | # zlib.decompress(base64.b85decode(idx_to_move.strip().replace('\n', ''))).decode().split() 31 | idx_to_move = ''' 32 | c-ke|Np|bV4h6uiye)h+B)8*i08H)u4?ZGma*u4Pbi>RfROI{l-}v9ZpZ`s<#1ZnWTlrSL{rOw@R=$;Q 35 | y<6|ryY==_(Y;&m*1Prg0nz7vJNNwoqE~-g{fAi{<)J+EhThN{di#Lr-q0I*LvJ4tecHom4+|d_KJ2=%>pE`n 36 | {7dDjJe8;NRG$9)RGxZMZ|Y6GsWMfj%2b)ob2`syrPE5My_WV`+URMccRc?}Z|N<)rML8!-tw21%2HV>OIv9xZRNz46I)ibtZLa; 37 | WnYz@U3PZ4tjcBe^Xkfw6q$uG^ez}rguH=_1`8C88F6Y-0F6Y-4F6UPWm-DOt)S-9i9eT&hHT`n^Dj&*+@}Ybv 38 | AIgXFp?oOs_lXt^f1<4v$ue0c%Ve1>lV!3@mVTdT&;*)56Q~cIAHP0DxY%DQ 39 | TQ6bZE*~XKftM-3vRO9EX4x+T0B8ekpzSnmpbfNv 40 | Hp^z&ESsf|ineT)&9YhgfarDG)~!E4^d8#wP{%ELGOz?oumnrI%mbhR3ZMWArzwB}C|H6eSc0Vwh_(bvumnpV5WQq!$-+ho8!7DP 41 | u%A16TCgNbvLs8gBula+OR}Uhlq`vn7>SV>Nst6dkOXOY((9M4M<6t=}iw 44 | H~A*tf^TMV^n>Ns*h3i@jo^n^XfCttB+U0eadHx 48 | 
5GhI+RUf13V^n>Ns?QKpm{*@A%&SlDm+>*GKJHzV59LGo_?S`zgJ2N!?}V|0F)#+kz!(?^HK``mr0N60SaO;s+C-aZ6K$eRw23y+CR)Ewv~TiFzRA}I&X2cG5$4sW_sg2A-~N>IinMpww#eHe?~j`#(oMQaH|hSESEQSClWx*JAey^zH|}!%QAbv>AFk2vOfq)FifDFj^7+R148IS?l2SiJQG)RLqoZTP|(tr%efDFh!AQ~Bv0U3~e 52 | K=c*}TOjPyuus$Kr`1osuY>{jjylq^&MggGwj=f 54 | @ioJ~9=&{xudmxQSVCmnTA6a?>&#xKw^*H8ha(&&V0ZpI@G=b*V(4v}D 55 | lWJ1+0bx`*O%rXRO|*$N(I(nNn`jfQ-zVBP`6l1w>jURUzTUv=YkYk@v9V;!lC7^>H^>FKAQ$BN8d{JGazQT0J|Nt+kuGO)kuK6j 56 | x=0t~f?SXbvfm;axgZzhg6sq0=O9II?`?-|J8WCzZISoaJss&L-K3jze+@0tO}a@pX&(@t@!)RU?d)#cjk|F-=_cKzo3xLLmTuBb 57 | x=H(h7?4}^1rhdS*pp!kge?$|0U3}18DB#SG9UvoAp3x5X^;kKkcP7xq(K^x0U3}1*#|@;12P~3vJZ&KXNK9QX`iOmPphAN$(MY| 58 | muQKWXo;3+iI!-ImT2irCR$P@RZ=BY0wquaB~V)Aw8&*omOWY4x2$im6icxbMqw02VH8GT6h>hbMmbGIPy|I#%1e1EFH2UI?B_DW 59 | E)R__GwkxH=`zDEH;R|>bs1lm@pbt>t;-C%=6Q8_CEOWao-$pY9A7C)7+;t1bs1lm@pTQ+>oG1f?DFLJGQ%!UnJ(k&a%ZS~C?Cql 60 | WpWV=f%5<4rms>%g2{eHw(7X&Ss!27eCRHC0MwQbv(I(nNn`jel 61 | qD{1kHqrWh;sje`^KHJ(xB2?0m_X~O#PRi&7dC&`{9*rv{TG%yEO+n)U+@KA 63 | Tm}?h@C9G+^#S22He^5sWH_Ax8IZvje8CrdeLyTA|0l>7|M~g~O&3YJNYb86donF|TJGdazT`{3L`$?pOSD8wv_wm^L`x?z(UK~u 64 | k}9bZD1j0vfyzZvE|Rhd%O)&KTb8z1iltZzqc94iFbbnE3ZpOzqnxH9D1ss=<)yro|J>%)?ccX<^Xm3g{5G#{&$Mpy>UN`e8&$Va 65 | bsJT;|Es#qt9zbTw^zcQ;q7tO?a}z{pS$i7M%8Up-NwUhRNel4>o%%x&%$q`>h}EVHl=Pig?fkHp?BN{l=7i`DDU416Nz9D41z&0 66 | 2nN9*7zE=#fdOzEP`BB1d-#2uO}8ghw~2JS?IV~3lVB3e+c3hI7!zY+^Z{W6IVF>2vP_oAGFc|eWSK1UKZ}d~HjHk==r)3G&!DWr 67 | vI^^VW5=>s7RzE;x52}*SQg7-=>x(&9nf-y7SIA(KnrNGESANx{u*7Ykumnr61WO+f9)1G_PymI~6hHwKEWr{i@t?&&aS5eMC~f7m 69 | mD5H_8!4?@TDJsBpae>wWJ#7}NtR?umSjnmWJzZzSrQ{L5+gCn6;-aNvY*R-E_RU59#5Gb6XbE5_81B6Ops@uS3Y!#CAJ8W!oQk7hR0)WJpTFYF*hEMydQJp@f_(f 71 | G9EW#+J?5FZ9Jxh%As3|0Nr4-lcavrhwk1cj;Yv`+#thqV+uT`z+(Vd)ooR`^Sqts{kY}PyY+6p 73 | Tkn2c@p`x3t#|9~1Huy<g6=ncK0w+{$Ula+_^P#*sLP#((jaberp 74 | X=kUsmiAg&__Xlp+^2J&%2RnNPvxmM^`_p`n|f1k>P@|=H~rOJ%7n 75 | m8G)IwbfrrzbF4rj`Mx}d-Ct(IA5Z_C;v{4^Nsp@^6%t0U$4I>|4xqcP4|29@8mdNd%q|DPLA^(_0] 120 | pmoves = sorted(pmoves, key=lambda mp: 
(-mp[1], mp[0])) 121 | 122 | if move_idx%2 == 1: # Flip black's move 123 | flipc = lambda nstr: str(ord('9') - ord(nstr)) 124 | flip = lambda uci: uci[0]+flipc(uci[1])+uci[2]+flipc(uci[3])+uci[4:] 125 | pmoves = [(flip(uci), p) for uci, p in pmoves] 126 | # print(pmoves) 127 | return pmoves 128 | 129 | def get_sorted_final_moves(data): 130 | '''For final move.. 131 | Return a list of (UCI move, probability) tuples, sorted descending based on probs 132 | ''' 133 | return get_training_probabilities(data, -1) 134 | 135 | 136 | def getbps_result(data): 137 | '''Get my-move bit-planes: 138 | returns an array (number of move elements) of arrays (each with 2 elements, 1 per side) 139 | of arrays (1 per piece, 6 total) of 8-bytes''' 140 | chunksize = 8276 141 | offset = 4+7432 142 | numplanes = 6 143 | bps = [] 144 | assert len(data)%chunksize == 0 145 | for i in range(len(data)//chunksize): 146 | sidebps = [] 147 | for side in range(2): 148 | planes = data[i*chunksize+offset + side*numplanes*8 149 | :i*chunksize + offset + numplanes*8 + side*numplanes*8] 150 | piecebbs = [] 151 | for plane in range(numplanes): 152 | piecebbs.append(planes[plane*8:(plane+1)*8]) 153 | sidebps.append(piecebbs) 154 | bps.append(sidebps) 155 | result_offset = chunksize - 1 156 | result = np.int8(data[i*chunksize + result_offset]) 157 | if i%2==0: 158 | # I am white 159 | white_result = result 160 | else: 161 | white_result = -result 162 | return bps, white_result 163 | 164 | def bp_to_array(bp, flip): 165 | '''Given an 8-byte bit-plane, convert to uint8 numpy array''' 166 | if not flip: 167 | return np.unpackbits(bytearray(bp)).reshape(8, 8) 168 | else: 169 | return np.unpackbits(bytearray(bp)).reshape(8, 8)[::-1] 170 | 171 | 172 | def convert_to_move(planes1, planes2, move_index): 173 | '''Given two arrays of 8-byte bit-planes, convert to a move 174 | Also need the index of the first move (0-indexed) to determine how to flip the board. 
175 | ''' 176 | current_player = move_index % 2 177 | 178 | # Check for K moves first bc of castling, pawns last for promotions... 179 | for idx, piece in reversed(indexed_pieces): 180 | arr1 = bp_to_array(planes1[idx], current_player==1) 181 | arr2 = bp_to_array(planes2[idx], current_player==0) 182 | if not np.array_equal(arr1, arr2): 183 | rowfrom, colfrom = np.where(arr1 & ~arr2) 184 | rowto, colto = np.where(~arr1 & arr2) 185 | promotion = '' 186 | if not len(colfrom)==len(rowfrom)==len(colto)==len(rowto)==1: 187 | # This must be a pawn promotion... 188 | assert (len(colfrom)==len(rowfrom)==0) 189 | # Find where the pawn came from 190 | p_arr1 = bp_to_array(planes1[0], current_player==1) 191 | p_arr2 = bp_to_array(planes2[0], current_player==0) 192 | rowfrom, colfrom = np.where(p_arr1 & ~p_arr2) 193 | promotion = piece.lower() 194 | assert len(colfrom)==len(rowfrom)==len(colto)==len(rowto)==1 195 | rowfrom, colfrom = rowfrom[0], colfrom[0] 196 | rowto, colto = rowto[0], colto[0] 197 | uci = '{}{}{}{}{}'.format(columns[colfrom], rowfrom+1, 198 | columns[colto], rowto+1, promotion) 199 | return piece, uci 200 | else: 201 | raise Exception("I shouldn't be here") 202 | 203 | def getpgn(data, name): 204 | # These are global just for debug 205 | global game, node, white_result, legal_moves, board, uci, move, moveidx, final_moves, final_move_probs 206 | game = chess.pgn.Game() 207 | game.headers["Event"] = name 208 | node = game 209 | bps, white_result = getbps_result(data) 210 | if white_result==1: 211 | game.headers["Result"] = "1-0" 212 | elif white_result==-1: 213 | game.headers["Result"] = "0-1" 214 | elif white_result==0: 215 | game.headers["Result"] = "1/2-1/2" 216 | else: 217 | print(white_result) 218 | raise Exception("Bad result") 219 | for moveidx in range(len(bps)-1): 220 | piece, uci = convert_to_move(bps[moveidx][0], bps[moveidx+1][1], moveidx) 221 | move = chess.Move.from_uci(uci) 222 | if ensure_legal: 223 | legal_moves = 
list(node.board().generate_legal_moves()) 224 | assert move in legal_moves 225 | if global_stats: 226 | color = 'wb'[moveidx%2] 227 | move_stats[color + piece.lower() + uci[:4]].play_count += 1 228 | for lmove in legal_moves: # Legal moves 229 | lboard = node.board() 230 | lpiece = lboard.piece_at(lmove.from_square).symbol() 231 | luci = lmove.uci() 232 | move_stats[color + lpiece.lower() + luci[:4]].legal_count += 1 233 | node = node.add_variation(move) 234 | if add_final_move: 235 | promotion_debug = False 236 | board = node.board() 237 | final_move_probs = get_sorted_final_moves(data) 238 | final_moves = [mp[0] for mp in final_move_probs] 239 | # print(final_moves) 240 | if board.turn == chess.WHITE: 241 | result = white_result 242 | else: 243 | result = -white_result 244 | for idx, uci in enumerate(final_moves): 245 | try: 246 | board.push_uci(uci) 247 | except ValueError: 248 | # Assume this is a promotion 249 | if (uci+'q') not in final_moves: 250 | print() 251 | print(name, ": ", 'Queen promotion on final move not in training probabilities???') 252 | uci = uci + 'q' 253 | print("Trying move:", uci) 254 | promotion_debug = True 255 | else: 256 | # Must be a knight promotion... 
257 | print() 258 | print(name, ": ", 'Knight promotion on final move') 259 | uci = uci + 'n' 260 | print("Trying move:", uci) 261 | promotion_debug = True 262 | board.push_uci(uci) 263 | if board.is_checkmate() and result==1: 264 | break 265 | if (board.is_insufficient_material() or 266 | board.is_stalemate() or 267 | board.can_claim_draw()) and result==0: 268 | break 269 | board.pop() 270 | else: 271 | print() 272 | print("Error: Can't find final move") 273 | print("White result is:", white_result) 274 | print("Choices for final move:") 275 | for it in final_move_probs: 276 | print(it) 277 | print(game) 278 | uci = None 279 | # raise Exception("Can't find final move!") 280 | if uci is not None: 281 | move = chess.Move.from_uci(uci) 282 | node = node.add_variation(move) 283 | if promotion_debug: 284 | print(game) 285 | print("Final move choices were:") 286 | for it in final_move_probs: 287 | print(it) 288 | print("Selected:", uci) 289 | return str(game) 290 | 291 | 292 | def write_pgn(pgn, name, pgnfile): 293 | if pgnfile: 294 | print(pgn, file=pgnfile) 295 | print('\n', file=pgnfile) 296 | else: 297 | pgnfilename = name + '.pgn' 298 | pgnfilename = os.path.join(outputdir, pgnfilename) 299 | with open(pgnfilename, 'w') as pgnfile: 300 | print(pgn, file=pgnfile) 301 | 302 | if __name__ == '__main__': 303 | allnames = set() 304 | games_in_file = 10000 305 | pgnfile = None 306 | try: 307 | if single_file: 308 | pgnfilename = os.path.basename(filename).split('.', 1)[0] + '.pgn' 309 | pgnfilename = os.path.join(outputdir, pgnfilename) 310 | pgnfile = open(pgnfilename, 'w') 311 | with tarfile.open(filename) as f: 312 | for idx, member in enumerate(f): 313 | if member.name in allnames: 314 | raise Exception("Duplicate name in training data???") 315 | allnames.add(member.name) 316 | if idx%50==0: 317 | print('\n{:5}/{} '.format(idx, games_in_file), end='') 318 | if single_file: 319 | pgnfile.flush() 320 | print('.', end='') 321 | data = f.extractfile(member).read() 322 | 
pgn = getpgn(data, member.name) 323 | if not include_header: 324 | # This chops off the header of the pgn string 325 | pgn = pgn.rsplit('\n', 1)[-1] 326 | write_pgn(pgn, member.name, pgnfile) 327 | finally: 328 | if pgnfile: pgnfile.close() 329 | 330 | --------------------------------------------------------------------------------
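The core trick in `convert_to_move` above is recovering a move by diffing two consecutive 8-byte piece bit-planes: decode each plane to an 8x8 array with `np.unpackbits`, then the square set before but cleared after is the origin, and the square newly set is the destination. A simplified, self-contained sketch of just that diff (ignoring castling, promotions, and the side-to-move board flip handled in the real code); the hand-crafted planes below model a lone knight hopping g1 to f3:

```python
import numpy as np

def bp_to_array(bp):
    """Decode an 8-byte bit-plane into an 8x8 0/1 array (row 0 = rank 1)."""
    return np.unpackbits(np.frombuffer(bp, dtype=np.uint8)).reshape(8, 8)

def diff_move(before, after):
    """Recover a UCI move string by diffing two consecutive bit-planes."""
    a1, a2 = bp_to_array(before), bp_to_array(after)
    rf, cf = np.where(a1 & ~a2)  # set before, cleared after -> origin
    rt, ct = np.where(~a1 & a2)  # cleared before, set after -> destination
    cols = 'abcdefgh'
    return '{}{}{}{}'.format(cols[cf[0]], rf[0] + 1, cols[ct[0]], rt[0] + 1)

# Hypothetical planes: one byte per rank, MSB = file a.
before = bytes([0x02, 0, 0, 0, 0, 0, 0, 0])  # rank 1: bit for file g set
after  = bytes([0, 0, 0x04, 0, 0, 0, 0, 0])  # rank 3: bit for file f set
print(diff_move(before, after))  # -> g1f3
```

Note `a1 & ~a2` works on the uint8 0/1 arrays because `~0` is 255 and `~1` is 254, so the bitwise AND is nonzero exactly where `a1` is set and `a2` is not, which is what `np.where` picks up.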