├── .gitignore ├── CITATION.cff ├── LICENSE ├── README.md ├── blunder_prediction ├── configs │ ├── leela_blunder_00001.yaml │ ├── leela_extra_blunder_00001.yaml │ ├── leela_grouped_0001.yaml │ └── leela_winrate_loss_00001.yaml ├── cp_to_winrate.py ├── maia_chess_backend │ ├── __init__.py │ ├── bat_files.py │ ├── boardTrees.py │ ├── chunkparser.py │ ├── data_prep.py │ ├── fen_to_vec.py │ ├── games.py │ ├── loaders.py │ ├── logging.py │ ├── maia │ │ ├── __init__.py │ │ ├── chunkparser.py │ │ ├── decode_training.py │ │ ├── lc0_az_policy_map.py │ │ ├── net.py │ │ ├── net_to_model.py │ │ ├── policy_index.py │ │ ├── proto │ │ │ ├── __init__.py │ │ │ ├── chunk_pb2.py │ │ │ └── net_pb2.py │ │ ├── shufflebuffer.py │ │ ├── tfprocess.py │ │ └── update_steps.py │ ├── model_loader.py │ ├── models_loader.py │ ├── new_model.py │ ├── plt_utils.py │ ├── torch │ │ ├── __init__.py │ │ ├── blocks.py │ │ ├── data_utils.py │ │ ├── dataset_loader.py │ │ ├── dataset_loader_old.py │ │ ├── model_loader.py │ │ ├── models.py │ │ ├── new_blocks.py │ │ ├── new_model.py │ │ ├── resnet.py │ │ ├── tensorboard_wrapper.py │ │ ├── utils.py │ │ └── wrapper.py │ ├── tourney.py │ ├── uci.py │ └── utils.py ├── make_csvs.sh ├── make_month_csv.py ├── mmap_csv.py ├── mmap_grouped_csv.py ├── model_weights │ ├── README.md │ ├── leela_blunder_00001-266000.pt │ ├── leela_extra_blunder_00001-234000.pt │ ├── leela_grouped_0001-1372000.pt │ └── leela_winrate_loss_00001-200000.pt └── train_model.py ├── data_generators ├── csv_grouper.py ├── extractELOrange.py ├── filter_csv.py ├── grouped_boards.py ├── grouped_train_test.py ├── make_batch_csv.py ├── make_batch_files.py ├── make_combined_csvs.py ├── make_csvs.sh ├── make_month_csv.py ├── pgnCPsToCSV_multi.py ├── pgnCPsToCSV_single.py ├── run_alternates.sh ├── run_batch_csv_makers.sh ├── run_batch_makers.sh └── run_singles.sh ├── images ├── CP_v_winrate_ELO.png ├── all_lineplot.png ├── delta_human_wr.png ├── delta_top2.png ├── delta_wr2.png ├── leela_lineplot.png ├── 
maia_lineplot.png ├── models_agreement.png ├── other_effects_lineplot.png └── sf_lineplot.png ├── maia_env.yml ├── maia_weights ├── maia-1100.pb.gz ├── maia-1200.pb.gz ├── maia-1300.pb.gz ├── maia-1400.pb.gz ├── maia-1500.pb.gz ├── maia-1600.pb.gz ├── maia-1700.pb.gz ├── maia-1800.pb.gz └── maia-1900.pb.gz ├── move_prediction ├── lczero-common │ ├── gen_proto_files.sh │ └── proto │ │ ├── chunk.proto │ │ └── net.proto ├── maia_chess_backend │ ├── __init__.py │ ├── bat_files.py │ ├── boardTrees.py │ ├── chunkparser.py │ ├── data_prep.py │ ├── fen_to_vec.py │ ├── games.py │ ├── loaders.py │ ├── logging.py │ ├── maia │ │ ├── __init__.py │ │ ├── chunkparser.py │ │ ├── decode_training.py │ │ ├── lc0_az_policy_map.py │ │ ├── net.py │ │ ├── net_to_model.py │ │ ├── policy_index.py │ │ ├── proto │ │ │ ├── __init__.py │ │ │ ├── chunk_pb2.py │ │ │ └── net_pb2.py │ │ ├── shufflebuffer.py │ │ ├── tfprocess.py │ │ └── update_steps.py │ ├── model_loader.py │ ├── models_loader.py │ ├── new_model.py │ ├── plt_utils.py │ ├── tourney.py │ ├── uci.py │ └── utils.py ├── maia_config.yaml ├── model_files │ ├── 1100 │ │ ├── config.yaml │ │ └── final_1100-40.pb.gz │ ├── 1200 │ │ ├── config.yaml │ │ └── final_1200-40.pb.gz │ ├── 1300 │ │ ├── config.yaml │ │ └── final_1300-40.pb.gz │ ├── 1400 │ │ ├── config.yaml │ │ └── final_1400-40.pb.gz │ ├── 1500 │ │ ├── config.yaml │ │ └── final_1500-40.pb.gz │ ├── 1600 │ │ ├── config.yaml │ │ └── final_1600-40.pb.gz │ ├── 1700 │ │ ├── config.yaml │ │ └── final_1700-40.pb.gz │ ├── 1800 │ │ ├── config.yaml │ │ └── final_1800-40.pb.gz │ └── 1900 │ │ ├── config.yaml │ │ └── final_1900-40.pb.gz ├── pgn_to_trainingdata.sh ├── replication-configs │ ├── defaults │ │ ├── 1200.yaml │ │ ├── 1500.yaml │ │ └── 1800.yaml │ ├── extras │ │ ├── 1000.yaml │ │ ├── 2000.yaml │ │ ├── 2300.yaml │ │ ├── all.yaml │ │ └── double.yaml │ ├── final │ │ ├── 1100_final.yaml │ │ ├── 1200_final.yaml │ │ ├── 1300_final.yaml │ │ ├── 1400_final.yaml │ │ ├── 1500_final.yaml │ │ ├── 
1600_final.yaml │ │ ├── 1700_final.yaml │ │ ├── 1800_final.yaml │ │ └── 1900_final.yaml │ ├── final_unfiltered │ │ ├── 1000_final.yaml │ │ ├── 1100_final.yaml │ │ ├── 1200_final.yaml │ │ ├── 1300_final.yaml │ │ ├── 1400_final.yaml │ │ ├── 1500_final.yaml │ │ ├── 1600_final.yaml │ │ ├── 1700_final.yaml │ │ ├── 1800_final.yaml │ │ ├── 1900_final.yaml │ │ └── 2000_final.yaml │ ├── leela_best │ │ ├── 1200.yaml │ │ ├── 1500.yaml │ │ └── 1800.yaml │ ├── new_LR │ │ ├── 1200_LR.yaml │ │ ├── 1200_LR_big_batch.yaml │ │ ├── 1500_LR.yaml │ │ ├── 1500_LR_big_batch.yaml │ │ ├── 1800_LR.yaml │ │ └── 1800_LR_big_batch.yaml │ ├── sweep │ │ ├── 1800_LR.yaml │ │ ├── 1800_policy_value.yaml │ │ ├── 1800_renorm.yaml │ │ └── 1800_swa.yaml │ └── testing │ │ └── example.yaml ├── replication-extractELOrange.py ├── replication-generate_pgns.sh ├── replication-make_leela_files.sh ├── replication-make_testing_pgns.sh ├── replication-move_training_set.py ├── replication-run_model_on_csv.py └── train_maia.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | training/ 2 | validation/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | reports/overleaf 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | 135 | ## Core latex/pdflatex auxiliary files: 136 | *.aux 137 | *.lof 138 | *.log 139 | *.lot 140 | *.fls 141 | *.out 142 | *.toc 143 | *.fmt 144 | *.fot 145 | *.cb 146 | *.cb2 147 | .*.lb 148 | 149 | ## Intermediate documents: 150 | *.dvi 151 | *.xdv 152 | *-converted-to.* 153 | # these rules might exclude image files for figures etc. 154 | # *.ps 155 | # *.eps 156 | # *.pdf 157 | 158 | ## Generated if empty string is given at "Please type another file name for output:" 159 | .pdf 160 | 161 | ## Bibliography auxiliary files (bibtex/biblatex/biber): 162 | *.bbl 163 | *.bcf 164 | *.blg 165 | *-blx.aux 166 | *-blx.bib 167 | *.run.xml 168 | 169 | ## Build tool auxiliary files: 170 | *.fdb_latexmk 171 | *.synctex 172 | *.synctex(busy) 173 | *.synctex.gz 174 | *.synctex.gz(busy) 175 | *.pdfsync 176 | 177 | ## Build tool directories for auxiliary files 178 | # latexrun 179 | latex.out/ 180 | 181 | ## Auxiliary and intermediate files from other packages: 182 | # algorithms 183 | *.alg 184 | *.loa 185 | 186 | # achemso 187 | acs-*.bib 188 | 189 | # amsthm 190 | *.thm 191 | 192 | # beamer 193 | *.nav 194 | *.pre 195 | *.snm 196 | *.vrb 197 | 198 | # changes 199 | *.soc 200 | 201 | # comment 202 | *.cut 203 | 204 | # cprotect 205 | *.cpt 206 | 207 | # elsarticle (documentclass of Elsevier journals) 208 | *.spl 209 | 210 | # endnotes 211 | *.ent 212 | 213 | # fixme 214 | *.lox 215 | 216 | # 
feynmf/feynmp 217 | *.mf 218 | *.mp 219 | *.t[1-9] 220 | *.t[1-9][0-9] 221 | *.tfm 222 | 223 | #(r)(e)ledmac/(r)(e)ledpar 224 | *.end 225 | *.?end 226 | *.[1-9] 227 | *.[1-9][0-9] 228 | *.[1-9][0-9][0-9] 229 | *.[1-9]R 230 | *.[1-9][0-9]R 231 | *.[1-9][0-9][0-9]R 232 | *.eledsec[1-9] 233 | *.eledsec[1-9]R 234 | *.eledsec[1-9][0-9] 235 | *.eledsec[1-9][0-9]R 236 | *.eledsec[1-9][0-9][0-9] 237 | *.eledsec[1-9][0-9][0-9]R 238 | 239 | # glossaries 240 | *.acn 241 | *.acr 242 | *.glg 243 | *.glo 244 | *.gls 245 | *.glsdefs 246 | *.lzo 247 | *.lzs 248 | 249 | # uncomment this for glossaries-extra (will ignore makeindex's style files!) 250 | # *.ist 251 | 252 | # gnuplottex 253 | *-gnuplottex-* 254 | 255 | # gregoriotex 256 | *.gaux 257 | *.gtex 258 | 259 | # htlatex 260 | *.4ct 261 | *.4tc 262 | *.idv 263 | *.lg 264 | *.trc 265 | *.xref 266 | 267 | # hyperref 268 | *.brf 269 | 270 | # knitr 271 | *-concordance.tex 272 | # TODO Uncomment the next line if you use knitr and want to ignore its generated tikz files 273 | # *.tikz 274 | *-tikzDictionary 275 | 276 | # listings 277 | *.lol 278 | 279 | # luatexja-ruby 280 | *.ltjruby 281 | 282 | # makeidx 283 | *.idx 284 | *.ilg 285 | *.ind 286 | 287 | # minitoc 288 | *.maf 289 | *.mlf 290 | *.mlt 291 | *.mtc[0-9]* 292 | *.slf[0-9]* 293 | *.slt[0-9]* 294 | *.stc[0-9]* 295 | 296 | # minted 297 | _minted* 298 | *.pyg 299 | 300 | # morewrites 301 | *.mw 302 | 303 | # nomencl 304 | *.nlg 305 | *.nlo 306 | *.nls 307 | 308 | # pax 309 | *.pax 310 | 311 | # pdfpcnotes 312 | *.pdfpc 313 | 314 | # sagetex 315 | *.sagetex.sage 316 | *.sagetex.py 317 | *.sagetex.scmd 318 | 319 | # scrwfile 320 | *.wrt 321 | 322 | # sympy 323 | *.sout 324 | *.sympy 325 | sympy-plots-for-*.tex/ 326 | 327 | # pdfcomment 328 | *.upa 329 | *.upb 330 | 331 | # pythontex 332 | *.pytxcode 333 | pythontex-files-*/ 334 | 335 | # tcolorbox 336 | *.listing 337 | 338 | # thmtools 339 | *.loe 340 | 341 | # TikZ & PGF 342 | *.dpth 343 | *.md5 344 | *.auxlock 345 | 346 | # 
todonotes 347 | *.tdo 348 | 349 | # vhistory 350 | *.hst 351 | *.ver 352 | 353 | # easy-todo 354 | *.lod 355 | 356 | # xcolor 357 | *.xcp 358 | 359 | # xmpincl 360 | *.xmpi 361 | 362 | # xindy 363 | *.xdy 364 | 365 | # xypic precompiled matrices and outlines 366 | *.xyc 367 | *.xyd 368 | 369 | # endfloat 370 | *.ttt 371 | *.fff 372 | 373 | # Latexian 374 | TSWLatexianTemp* 375 | 376 | ## Editors: 377 | # WinEdt 378 | *.bak 379 | *.sav 380 | 381 | # Texpad 382 | .texpadtmp 383 | 384 | # LyX 385 | *.lyx~ 386 | 387 | # Kile 388 | *.backup 389 | 390 | # gummi 391 | .*.swp 392 | 393 | # KBibTeX 394 | *~[0-9]* 395 | 396 | # TeXnicCenter 397 | *.tps 398 | 399 | # auto folder when using emacs and auctex 400 | ./auto/* 401 | *.el 402 | 403 | # expex forward references with \gathertags 404 | *-tags.tex 405 | 406 | # standalone packages 407 | *.sta 408 | 409 | # Makeindex log files 410 | *.lpz 411 | 412 | # REVTeX puts footnotes in the bibliography by default, unless the nofootinbib 413 | # option is specified. Footnotes are the stored in a file with suffix Notes.bib. 414 | # Uncomment the next line to have this generated file ignored. 415 | #*Notes.bib 416 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | # YAML 1.2 2 | --- 3 | abstract: | 4 | "As artificial intelligence becomes increasingly intelligent---in some cases, achieving superhuman performance---there is growing potential for humans to learn from and collaborate with algorithms. However, the ways in which AI systems approach problems are often different from the ways people do, and thus may be uninterpretable and hard to learn from. A crucial step in bridging this gap between human and artificial intelligence is modeling the granular actions that constitute human behavior, rather than simply matching aggregate human performance. 
5 | We pursue this goal in a model system with a long history in artificial intelligence: chess. The aggregate performance of a chess player unfolds as they make decisions over the course of a game. The hundreds of millions of games played online by players at every skill level form a rich source of data in which these decisions, and their exact context, are recorded in minute detail. Applying existing chess engines to this data, including an open-source implementation of AlphaZero, we find that they do not predict human moves well. 6 | We develop and introduce Maia, a customized version of Alpha-Zero trained on human chess games, that predicts human moves at a much higher accuracy than existing engines, and can achieve maximum accuracy when predicting decisions made by players at a specific skill level in a tuneable way. For a dual task of predicting whether a human will make a large mistake on the next move, we develop a deep neural network that significantly outperforms competitive baselines. Taken together, our results suggest that there is substantial promise in designing artificial intelligence systems with human collaboration in mind by first accurately modeling granular human decision-making. 
" 7 | authors: 8 | - 9 | affiliation: "Universiy of Toronto" 10 | family-names: "McIlroy-Young" 11 | given-names: Reid 12 | - 13 | affiliation: "Microsoft Research" 14 | family-names: Sen 15 | given-names: Siddhartha 16 | - 17 | affiliation: "Cornell University" 18 | family-names: Kleinberg 19 | given-names: Jon 20 | - 21 | affiliation: "Universiy of Toronto" 22 | family-names: Anderson 23 | given-names: Ashton 24 | cff-version: "1.2.0" 25 | date-released: 2020-08-23 26 | doi: "10.1145/3394486.3403219" 27 | license: "GPL-3.0" 28 | message: "If you use Maia, please cite the original paper" 29 | repository-code: "https://github.com/CSSLab/maia-chess" 30 | url: "https://maiachess.com" 31 | title: "Maia Chess" 32 | version: "1.0.0" 33 | preferred-citation: 34 | type: article 35 | authors: 36 | - 37 | affiliation: "Universiy of Toronto" 38 | family-names: "McIlroy-Young" 39 | given-names: Reid 40 | - 41 | affiliation: "Microsoft Research" 42 | family-names: Sen 43 | given-names: Siddhartha 44 | - 45 | affiliation: "Cornell University" 46 | family-names: Kleinberg 47 | given-names: Jon 48 | - 49 | affiliation: "Universiy of Toronto" 50 | family-names: Anderson 51 | given-names: Ashton 52 | doi: "10.1145/3394486.3403219" 53 | journal: "KDD '20: Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining" 54 | month: 9 55 | start: 1677 56 | end: 1687 57 | title: "Aligning Superhuman AI with Human Behavior: Chess as a Model System" 58 | year: 2020 59 | ... 
60 | -------------------------------------------------------------------------------- /blunder_prediction/configs/leela_blunder_00001.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | device: 0 4 | 5 | dataset: 6 | input_train: /datadrive/mmaps/train/ 7 | input_test: /datadrive/mmaps/test/ 8 | 9 | training: 10 | lr_intial: 0.0002 11 | lr_gamma: 0.1 12 | lr_steps: 13 | - 20000 14 | - 1000000 15 | - 1300000 16 | batch_size: 2000 17 | test_steps: 2000 18 | total_steps: 1400000 19 | test_size: 200 20 | 21 | model: 22 | type: leela 23 | outputs: 24 | - is_blunder_wr 25 | channels: 64 26 | blocks: 6 27 | ... 28 | -------------------------------------------------------------------------------- /blunder_prediction/configs/leela_extra_blunder_00001.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | device: 1 4 | 5 | dataset: 6 | input_train: /datadrive/mmaps/train/ 7 | input_test: /datadrive/mmaps/test/ 8 | 9 | training: 10 | lr_intial: 0.0002 11 | lr_gamma: 0.1 12 | lr_steps: 13 | - 20000 14 | - 1000000 15 | - 1300000 16 | batch_size: 2000 17 | test_steps: 2000 18 | total_steps: 1400000 19 | test_size: 200 20 | 21 | model: 22 | type: leela 23 | inputs: 24 | - cp_rel 25 | - clock_percent 26 | - move_ply 27 | - opponent_elo 28 | - active_elo 29 | outputs: 30 | - is_blunder_wr 31 | channels: 64 32 | blocks: 6 33 | ... 
34 | -------------------------------------------------------------------------------- /blunder_prediction/configs/leela_grouped_0001.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | device: 1 4 | 5 | dataset: 6 | input_train: /datadrive/mmaps/train_grouped/ 7 | input_test: /datadrive/mmaps/test_grouped/ 8 | 9 | training: 10 | lr_intial: 0.0001 11 | lr_gamma: 0.1 12 | lr_steps: 13 | - 100000 14 | - 1000000 15 | - 1300000 16 | batch_size: 200 17 | test_steps: 2000 18 | total_steps: 1400000 19 | test_size: 200 20 | 21 | model: 22 | type: leela 23 | outputs: 24 | - is_blunder_mean 25 | channels: 64 26 | blocks: 6 27 | ... 28 | -------------------------------------------------------------------------------- /blunder_prediction/configs/leela_winrate_loss_00001.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | device: 0 4 | 5 | dataset: 6 | input_train: /datadrive/mmaps/train/ 7 | input_test: /datadrive/mmaps/test/ 8 | 9 | training: 10 | lr_intial: 0.0002 11 | lr_gamma: 0.1 12 | lr_steps: 13 | - 20000 14 | - 1000000 15 | - 1300000 16 | batch_size: 2000 17 | test_steps: 2000 18 | total_steps: 1400000 19 | test_size: 200 20 | 21 | model: 22 | type: leela 23 | outputs: 24 | - winrate_loss 25 | channels: 64 26 | blocks: 6 27 | ... 
class GamesFile(collections.abc.Iterable):
    """Lazy iterator over the games in a bz2-compressed PGN file.

    Games are parsed one at a time with ``chess.pgn.read_game``. Set
    ``cacheGames`` to keep every parsed game in memory so the file can be
    iterated and indexed repeatedly.
    """

    def __init__(self, path, cacheGames = False):
        self.path = path
        # The PGN file is assumed to be bz2-compressed text.
        self.f = bz2.open(self.path, 'rt')

        self.cache = cacheGames
        self.games = []       # parsed games, only filled when cacheGames is True
        self.num_read = 0     # running count of games read from the file

    def __iter__(self):
        # Replay anything already cached, then keep reading from the file.
        for g in self.games:
            yield g
        while True:
            # loadNextGame() signals end-of-file by raising StopIteration.
            # Under PEP 479 a StopIteration escaping a generator body is
            # turned into a RuntimeError, so it must be caught here and
            # translated into a normal return (the original let it escape).
            try:
                yield self.loadNextGame()
            except StopIteration:
                return

    def loadNextGame(self):
        """Read and return one game from the file; raises StopIteration at EOF."""
        g = chess.pgn.read_game(self.f)
        if g is None:
            raise StopIteration
        if self.cache:
            self.games.append(g)
        self.num_read += 1
        return g

    def __getitem__(self, val):
        """Index or slice into the game stream.

        Cached games are returned directly; otherwise the file is read
        forward until the requested index is reached.
        """
        if isinstance(val, slice):
            return [self[i] for i in range(*val.indices(10**20))]
        elif isinstance(val, int):
            if val < 0:
                raise IndexError("negative indexing is not supported") from None
            elif val < len(self.games):
                # BUGFIX: the original tested `len(self.games) < val`, which
                # is inverted -- cached games were never actually returned.
                return self.games[val]
            else:
                g = self.loadNextGame()
                for i in range(val - len(self.games)):
                    g = self.loadNextGame()
                return g
        else:
            raise IndexError("{} is not a valid input".format(val)) from None

    def __del__(self):
        try:
            self.f.close()
        except AttributeError:
            # self.f may not exist if __init__ failed part-way.
            pass
107 | nl = self.f.readline() 108 | lines += nl 109 | if self.parseMoves: 110 | ret['moves'] = re.findall(moveRegex, nl) 111 | lines += self.f.readline() 112 | if len(lines) < 1: 113 | raise StopIteration 114 | return ret, lines 115 | 116 | def readBatch(self, n): 117 | ret = [] 118 | for i in range(n): 119 | try: 120 | ret.append(self.readNextGame()) 121 | except StopIteration: 122 | break 123 | return ret 124 | 125 | def getWinRates(self, extraKey = None): 126 | # Assumes same players in all games 127 | dat, _ = self.peekNextGame() 128 | p1, p2 = sorted((dat['White'], dat['Black'])) 129 | d = { 130 | 'name' : f"{p1} v {p2}", 131 | 'p1' : p1, 132 | 'p2' : p2, 133 | 'games' : 0, 134 | 'wins' : 0, 135 | 'ties' : 0, 136 | 'losses' : 0, 137 | } 138 | if extraKey is not None: 139 | d[extraKey] = {} 140 | for dat, _ in self: 141 | d['games'] += 1 142 | if extraKey is not None and dat[extraKey] not in d[extraKey]: 143 | d[extraKey][dat[extraKey]] = [] 144 | if p1 == dat['White']: 145 | if dat['Result'] == '1-0': 146 | d['wins'] += 1 147 | if extraKey is not None: 148 | d[extraKey][dat[extraKey]].append(1) 149 | elif dat['Result'] == '0-1': 150 | d['losses'] += 1 151 | if extraKey is not None: 152 | d[extraKey][dat[extraKey]].append(0) 153 | else: 154 | d['ties'] += 1 155 | if extraKey is not None: 156 | d[extraKey][dat[extraKey]].append(.5) 157 | else: 158 | if dat['Result'] == '0-1': 159 | d['wins'] += 1 160 | if extraKey is not None: 161 | d[extraKey][dat[extraKey]].append(1) 162 | elif dat['Result'] == '1-0': 163 | d['losses'] += 1 164 | if extraKey is not None: 165 | d[extraKey][dat[extraKey]].append(0) 166 | else: 167 | d['ties'] += 1 168 | if extraKey is not None: 169 | d[extraKey][dat[extraKey]].append(.5) 170 | return d 171 | 172 | def __del__(self): 173 | try: 174 | self.f.close() 175 | except AttributeError: 176 | pass 177 | 178 | def getBoardMoveMap(game, maxMoves = None): 179 | d = {} 180 | board = game.board() 181 | for i, move in enumerate(game.main_line()): 
LEELA_WEIGHTS_VERSION = 2

def read_weights_file(filename):
    """Parse a Leela-style plain-text network weights file.

    The file starts with a version line (must equal LEELA_WEIGHTS_VERSION),
    followed by one row of space-separated floats per weight tensor.

    Returns a tuple ``(filters, blocks, weights)`` where ``filters`` is the
    channel count (width of the second weight row), ``blocks`` the number of
    residual blocks, and ``weights`` the list of float rows.

    Raises ValueError on a version mismatch or an inconsistent row count.
    """
    # BUGFIX: only treat files whose name actually *ends* in '.gz' as
    # gzip-compressed. The original substring test ('.gz' in filename) also
    # matched names such as 'weights.gz2.txt' or paths containing '.gz/'.
    if filename.endswith('.gz'):
        opener = gzip.open
    else:
        opener = open
    with opener(filename, 'rb') as f:
        version = f.readline().decode('ascii')
        if version != '{}\n'.format(LEELA_WEIGHTS_VERSION):
            raise ValueError("Invalid version {}".format(version.strip()))
        weights = []
        e = 0
        for line in f:
            line = line.decode('ascii').strip()
            if not line:
                continue
            e += 1
            weight = list(map(float, line.split(' ')))
            weights.append(weight)
            if e == 2:
                # The width of the second row gives the channel count.
                filters = len(line.split(' '))
        # 4 header rows plus 14 trailing rows surround 8 rows per block.
        blocks = e - (4 + 14)
        if blocks % 8 != 0:
            raise ValueError("Inconsistent number of weights in the file - e = {}".format(e))
        blocks //= 8
        return (filters, blocks, weights)
# Re-stated module constants (also defined near the top of this module) so
# the log helpers below are self-contained.
min_run_time = 60 * 10 # 10 minutes
infos_dir_name = 'runinfos'

def makeLog(logs_prefix, start_time, tstart, is_error, *notes):
    """Write a short run/error summary file next to the stdout/stderr logs.

    logs_prefix -- file-name prefix produced by makelogNamesPrefix()
    start_time  -- timezone-aware datetime captured when the run began
    tstart      -- time.time() value captured when the run began
    is_error    -- write 'error.log' when True, 'run.log' otherwise
    notes       -- extra objects appended, str()-ed, one per line
    """
    fname = f'error.log' if is_error else f'run.log'
    with open(logs_prefix + fname, 'w') as f:
        f.write(f"start: {start_time.strftime('%Y-%m-%d-%H:%M:%S')}\n")
        f.write(f"stop: {datetime.datetime.now(tz).strftime('%Y-%m-%d-%H:%M:%S')}\n")
        # BUGFIX: record the elapsed runtime in seconds. The original wrote
        # int(tstart > min_run_time) -- a boolean compared against an epoch
        # timestamp, which always produced the literal "in: 1s".
        f.write(f"in: {int(time.time() - tstart)}s\n")
        f.write(f"dir: {os.path.abspath(os.getcwd())}\n")
        f.write(f"{' '.join(sys.argv)}\n")
        f.write('\n'.join([str(n) for n in notes]))

def makelogNamesPrefix(script_name, start_time):
    """Create runinfos/<script_name>/ and return a timestamped path prefix."""
    os.makedirs(infos_dir_name, exist_ok = True)
    os.makedirs(os.path.join(infos_dir_name, script_name), exist_ok = True)
    return os.path.join(infos_dir_name, script_name, f"{start_time.strftime('%Y-%m-%d-%H%M%S-%f')}_")
columns = 'abcdefgh'
rows = '12345678'
promotions = 'rbq' # N is encoded as normal move

col_index = {columns[i] : i for i in range(len(columns))}
row_index = {rows[i] : i for i in range(len(rows))}

# Unit direction vectors as (d_col, d_row) pairs. Factored out of the two
# move functions, whose bodies were otherwise identical.
_QUEEN_DIRS = {'N': (0, 1), 'NE': (1, 1), 'E': (1, 0), 'SE': (1, -1),
               'S': (0, -1), 'SW': (-1, -1), 'W': (-1, 0), 'NW': (-1, 1)}
_KNIGHT_DIRS = {'N': (1, 2), 'NE': (2, 1), 'E': (2, -1), 'SE': (1, -2),
                'S': (-1, -2), 'SW': (-2, -1), 'W': (-2, 1), 'NW': (-1, 2)}

def index_to_position(x):
    """Zero-based (col, row) pair -> algebraic square, e.g. (0, 0) -> 'a1'."""
    return columns[x[0]] + rows[x[1]]

def position_to_index(p):
    """Algebraic square -> zero-based (col, row) pair, e.g. 'a1' -> (0, 0)."""
    return col_index[p[0]], row_index[p[1]]

def valid_index(i):
    """Return True iff both coordinates of the (col, row) pair lie on the board."""
    return 0 <= i[0] <= 7 and 0 <= i[1] <= 7

def _offset_move(start, vector, steps):
    """Square reached by moving `steps` multiples of `vector` from `start`.

    Shared implementation for queen_move/knight_move; returns None when the
    destination falls off the 8x8 board.
    """
    c, r = position_to_index(start)
    dest = (c + vector[0] * steps, r + vector[1] * steps)
    if not valid_index(dest):
        return None
    return index_to_position(dest)

def queen_move(start, direction, steps):
    """Destination of a queen-style ray move, or None if it leaves the board."""
    return _offset_move(start, _QUEEN_DIRS[direction], steps)

def knight_move(start, direction, steps):
    """Destination of a knight move (steps is normally 1), or None if off-board."""
    return _offset_move(start, _KNIGHT_DIRS[direction], steps)
def make_map(kind='matrix'):
    """Map AlphaZero's 80x8x8 move planes onto lc0's policy_index ordering.

    kind='matrix' -- return a (5120, len(policy_index)) float32 one-hot
                     matrix mapping each plane entry to its policy slot.
    kind='index'  -- return a flat list giving the policy_index position of
                     each plane entry, or -1 for entries that can never be
                     a legal move.

    Raises ValueError for an unknown `kind` or if the generated move set
    does not cover policy_index exactly.
    """
    # Plane layout: 56 queen-move planes (8 directions x 7 step counts),
    # then 8 knight-move planes, then 9 underpromotion planes
    # (3 directions x r/b/q); knight promotion is encoded as a normal move.
    moves = []
    for direction in ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW']:
        for steps in range(1, 8):
            for r0 in rows:
                for c0 in columns:
                    start = c0 + r0
                    end = queen_move(start, direction, steps)
                    if end is None:
                        moves.append('illegal')
                    else:
                        moves.append(start + end)

    # 8 planes of knight moves
    for direction in ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW']:
        for r0 in rows:
            for c0 in columns:
                start = c0 + r0
                end = knight_move(start, direction, 1)
                if end is None:
                    moves.append('illegal')
                else:
                    moves.append(start + end)

    # 9 promotion planes
    for direction in ['NW', 'N', 'NE']:
        for promotion in promotions:
            for r0 in rows:
                for c0 in columns:
                    # Promotion only in the second last rank
                    if r0 != '7':
                        moves.append('illegal')
                        continue
                    start = c0 + r0
                    end = queen_move(start, direction, 1)
                    if end is None:
                        moves.append('illegal')
                    else:
                        moves.append(start + end + promotion)

    # Sanity check: every lc0 policy move must be produced by some plane.
    for m in policy_index:
        if m not in moves:
            raise ValueError('Missing move: {}'.format(m))

    az_to_lc0 = np.zeros((80*8*8, len(policy_index)), dtype=np.float32)
    indices = []
    legal_moves = 0
    for e, m in enumerate(moves):
        if m == 'illegal':
            indices.append(-1)
            continue
        legal_moves += 1
        # Check for missing moves
        if m not in policy_index:
            raise ValueError('Missing move: {}'.format(m))
        i = policy_index.index(m)
        indices.append(i)
        az_to_lc0[e][i] = 1

    assert legal_moves == len(policy_index)
    assert np.sum(az_to_lc0) == legal_moves
    # NOTE: the original ended with a 5120 x len(policy_index) nested loop
    # whose body was `pass` -- pure dead code doing ~9.5M empty iterations.
    # It has been removed.
    if kind == 'matrix':
        return az_to_lc0
    elif kind == 'index':
        return indices
    raise ValueError('Unknown map kind: {!r}'.format(kind))
#!/usr/bin/env python3
"""Convert a Leela network (.pb.gz / text) file into a TFProcess checkpoint."""
import argparse
import tensorflow as tf
import os
import yaml
from .tfprocess import TFProcess
from .net import Net

argparser = argparse.ArgumentParser(description='Convert net to model.')
argparser.add_argument('net', type=str,
                       help='Net file to be converted to a model checkpoint.')
argparser.add_argument('--start', type=int, default=0,
                       help='Offset to set global_step to.')
# required=True: previously omitting --cfg crashed with an opaque
# AttributeError on None.read(); now argparse reports a clean usage error.
argparser.add_argument('--cfg', type=argparse.FileType('r'), required=True,
                       help='yaml configuration with training parameters')
args = argparser.parse_args()
cfg = yaml.safe_load(args.cfg.read())
print(yaml.dump(cfg, default_flow_style=False))
START_FROM = args.start

net = Net()
net.parse_proto(args.net)

# The YAML architecture must match the network file being converted.
filters, blocks = net.filters(), net.blocks()
if cfg['model']['filters'] != filters:
    raise ValueError("Number of filters in YAML doesn't match the network")
if cfg['model']['residual_blocks'] != blocks:
    raise ValueError("Number of blocks in YAML doesn't match the network")
weights = net.get_weights()

tfp = TFProcess(cfg)
tfp.init_net_v2()
tfp.replace_weights_v2(weights)
tfp.global_step.assign(START_FROM)

root_dir = os.path.join(cfg['training']['path'], cfg['name'])
# exist_ok avoids the check-then-create race of the old exists()/makedirs().
os.makedirs(root_dir, exist_ok=True)
tfp.manager.save()
print("Wrote model to {}".format(tfp.manager.latest_checkpoint))
class ShuffleBuffer:
    """Fixed-capacity shuffle buffer over fixed-size byte elements.

    Holds up to 'elem_count' items of exactly 'elem_size' bytes each in a
    single preallocated bytearray, kept in shuffled order via incremental
    Fisher-Yates swaps on insert.
    """

    def __init__(self, elem_size, elem_count):
        assert elem_size > 0, elem_size
        assert elem_count > 0, elem_count
        # Per-element byte size and total capacity.
        self.elem_size = elem_size
        self.elem_count = elem_count
        # One flat buffer holding every slot back-to-back.
        self.buffer = bytearray(elem_size * elem_count)
        # Number of slots currently occupied.
        self.used = 0

    def _slot(self, idx):
        """Byte-range slice object for slot 'idx'."""
        begin = idx * self.elem_size
        return slice(begin, begin + self.elem_size)

    def extract(self):
        """Pop and return one item, or None if the buffer is empty.

        Slots are already shuffled, so taking the last occupied one is enough.
        """
        if self.used < 1:
            return None
        self.used -= 1
        return self.buffer[self._slot(self.used)]

    def insert_or_replace(self, item):
        """Insert 'item', returning a uniformly random displaced item.

        Returns None while the buffer is still filling up.
        """
        assert len(item) == self.elem_size, len(item)
        # Swap the incoming item with a random occupied slot; appending the
        # displaced one keeps the whole buffer uniformly shuffled
        # (incremental Fisher-Yates).
        if self.used > 0:
            victim = self._slot(random.randint(0, self.used - 1))
            item, self.buffer[victim] = self.buffer[victim], item
        if self.used < self.elem_count:
            # Still filling: park the displaced item in the next free slot.
            self.buffer[self._slot(self.used)] = item
            self.used += 1
            return None
        return item
112 | out.append(sb.extract()) 113 | out.append(sb.extract()) 114 | assert sorted(items) == sorted(out), (items, out) 115 | # Check that buffer is empty 116 | r = sb.extract() 117 | assert r is None, r 118 | 119 | 120 | if __name__ == '__main__': 121 | unittest.main() 122 | -------------------------------------------------------------------------------- /blunder_prediction/maia_chess_backend/maia/update_steps.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import os 4 | import yaml 5 | import sys 6 | import tensorflow as tf 7 | from .tfprocess import TFProcess 8 | 9 | START_FROM = 0 10 | 11 | def main(cmd): 12 | cfg = yaml.safe_load(cmd.cfg.read()) 13 | print(yaml.dump(cfg, default_flow_style=False)) 14 | 15 | root_dir = os.path.join(cfg['training']['path'], cfg['name']) 16 | if not os.path.exists(root_dir): 17 | os.makedirs(root_dir) 18 | 19 | tfprocess = TFProcess(cfg) 20 | tfprocess.init_net_v2() 21 | 22 | tfprocess.restore_v2() 23 | 24 | START_FROM = cmd.start 25 | 26 | tfprocess.global_step.assign(START_FROM) 27 | tfprocess.manager.save() 28 | 29 | if __name__ == "__main__": 30 | argparser = argparse.ArgumentParser(description=\ 31 | 'Convert current checkpoint to new step count.') 32 | argparser.add_argument('--cfg', type=argparse.FileType('r'), 33 | help='yaml configuration with training parameters') 34 | argparser.add_argument('--start', type=int, default=0, 35 | help='Offset to set global_step to.') 36 | 37 | main(argparser.parse_args()) 38 | -------------------------------------------------------------------------------- /blunder_prediction/maia_chess_backend/model_loader.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import os 3 | import os.path 4 | 5 | from .tourney import RandomEngine, StockfishEngine, LC0Engine 6 | 7 | def load_model_config(config_dir_path, lc0_depth = None, lc0Path = None, noise = False, 
temperature = 0, temp_decay = 0): 8 | with open(os.path.join(config_dir_path, 'config.yaml')) as f: 9 | config = yaml.safe_load(f.read()) 10 | 11 | if config['engine'] == 'stockfish': 12 | model = StockfishEngine(**config['options']) 13 | elif config['engine'] == 'random': 14 | model = RandomEngine() 15 | elif config['engine'] == 'torch': 16 | raise NotImplementedError("torch engines aren't working yet") 17 | elif config['engine'] in ['lc0', 'lc0_23']: 18 | kwargs = config['options'].copy() 19 | if lc0_depth is not None: 20 | kwargs['nodes'] = lc0_depth 21 | kwargs['movetime'] *= lc0_depth / 10 22 | kwargs['weightsPath'] = os.path.join(config_dir_path, config['options']['weightsPath']) 23 | model = LC0Engine(lc0Path = config['engine'] if lc0Path is None else lc0Path, noise = noise, temperature = temperature, temp_decay = temp_decay, **kwargs) 24 | else: 25 | raise NotImplementedError(f"{config['engine']} is not a known engine type") 26 | 27 | return model, config 28 | -------------------------------------------------------------------------------- /blunder_prediction/maia_chess_backend/models_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import yaml 4 | 5 | 6 | class Trained_Model(object): 7 | def __init__(self, path): 8 | self.path = path 9 | try: 10 | with open(os.path.join(path, 'config.yaml')) as f: 11 | self.config = yaml.safe_load(f.read()) 12 | except FileNotFoundError: 13 | raise FileNotFoundError(f"No config file found in: {path}") 14 | 15 | self.weights = {int(e.name.split('-')[-1].split('.')[0]) :e.path for e in os.scandir(path) if e.name.endswith('.txt') or e.name.endswith('.pb.gz')} 16 | 17 | def getMostTrained(self): 18 | return self.weights[max(self.weights.keys())] 19 | -------------------------------------------------------------------------------- /blunder_prediction/maia_chess_backend/new_model.py: 
-------------------------------------------------------------------------------- 1 | import chess 2 | import chess.engine 3 | 4 | class ChessEngine(object): 5 | def __init__(self, engine, limits): 6 | self.limits = chess.engine.Limit(**limits) 7 | self.engine = engine 8 | 9 | def getMove(self, board): 10 | try: 11 | results = self.engine.play( 12 | board, 13 | limit=self.limits, 14 | info = chess.engine.INFO_ALL 15 | ) 16 | 17 | if isinstance(board, str): 18 | board 19 | -------------------------------------------------------------------------------- /blunder_prediction/maia_chess_backend/plt_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import os 3 | import os.path 4 | import seaborn 5 | 6 | def multi_savefig(save_name, dir_name = 'images', save_types = ('pdf', 'png', 'svg')): 7 | os.makedirs(dir_name, exist_ok = True) 8 | for sType in save_types: 9 | dName = os.path.join(dir_name, sType) 10 | os.makedirs(dName, exist_ok = True) 11 | 12 | fname = f'{save_name}.{sType}' 13 | 14 | plt.savefig(os.path.join(dName, fname), format = sType, dpi = 300, transparent = True) 15 | 16 | def plot_pieces(board_a): 17 | fig, axes = plt.subplots(nrows=3, ncols=6, figsize = (16, 10)) 18 | axiter = iter(axes.flatten()) 19 | for i in range(17): 20 | seaborn.heatmap(board_a[i], ax = next(axiter), cbar = False, vmin=0, vmax=1, square = True) 21 | 22 | axes[-1,-1].set_axis_off() 23 | for i, n in enumerate(['Knights', 'Bishops', 'Rooks','Queen', 'King']): 24 | axes[0,i + 1].set_title(n) 25 | axes[1,i + 1].set_title(n) 26 | axes[0,0].set_title('Active Player Pieces\nPawns') 27 | axes[1,0].set_title('Opponent Pieces\nPawns') 28 | axes[2,0].set_title('Other Values\n Is White') 29 | for i in range(4): 30 | axes[2,i + 1].set_title('Castling') 31 | -------------------------------------------------------------------------------- /blunder_prediction/maia_chess_backend/torch/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .models import * 2 | from .new_model import * 3 | from .new_blocks import * 4 | from .dataset_loader import * 5 | from .dataset_loader_old import MmapIterLoaderMap_old 6 | from .tensorboard_wrapper import TB_wrapper 7 | from .utils import * 8 | from .wrapper import * 9 | from .model_loader import * 10 | -------------------------------------------------------------------------------- /blunder_prediction/maia_chess_backend/torch/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | import torch.nn.functional as F 4 | 5 | class CenteredBatchNorm2d(torch.nn.BatchNorm2d): 6 | """It appears the only way to get a trainable model with beta (bias) but not scale (weight 7 | is by keeping the weight data, even though it's not used""" 8 | 9 | def __init__(self, channels): 10 | super().__init__(channels, affine=True) 11 | self.weight.data.fill_(1) 12 | self.weight.requires_grad = False 13 | 14 | class ConvBlock(torch.nn.Module): 15 | def __init__(self, kernel_size, input_channels, output_channels=None): 16 | super().__init__() 17 | if output_channels is None: 18 | output_channels = input_channels 19 | padding = kernel_size // 2 20 | self.conv1 = torch.nn.Conv2d(input_channels, output_channels, kernel_size, stride=1, padding=padding, bias=False) 21 | self.conv1_bn = CenteredBatchNorm2d(output_channels) 22 | 23 | def forward(self, x): 24 | out = self.conv1_bn(self.conv1(x)) 25 | out = F.relu(out, inplace=True) 26 | return out 27 | 28 | class ResidualBlock(torch.nn.Module): 29 | def __init__(self, channels): 30 | super().__init__() 31 | self.conv1 = torch.nn.Conv2d(channels, channels, 3, stride=1, padding=1, bias=False) 32 | self.conv1_bn = CenteredBatchNorm2d(channels) 33 | self.conv2 = torch.nn.Conv2d(channels, channels, 3, stride=1, padding=1, bias=False) 34 | self.conv2_bn = CenteredBatchNorm2d(channels) 35 | 
# Regex over a full FEN: board ranks, side to move, castling, en passant,
# halfmove and fullmove counters.
boardRE = re.compile(r"(([^/]+)/([^/]+)/([^/]+)/([^/]+)/([^/]+)/([^/]+)/([^/]+)/([^/]+)) ((w)|(b)) ((-)|(K)?(Q)?(k)?(q)?) (((-)|(\w+))) (\d+) (\d+)")


replaceRE = re.compile(r'[1-8/]')

all_pieces = 'pPnNbBrRqQkK'

# One-hot piece encodings. pieceMapBlack swaps each lower/upper pair
# (p<->P, n<->N, ...) so the "active player" planes stay consistent when
# black is to move. 'E' marks an empty square.
pieceMapWhite = {'E' : [False] * 12}
pieceMapBlack = {'E' : [False] * 12}

for i, p in enumerate(all_pieces):
    mP = [False] * 12
    mP[i] = True
    pieceMapWhite[p] = mP
    mP = [False] * 12
    mP[i + (-1 if i % 2 else 1)] = True
    pieceMapBlack[p] = mP

# Digit strings used to expand runs of empty squares. FEN digits only go up
# to 8; the old code reused the leftover loop variable i (== 11) and built
# '1'..'11', which was harmless only because '1' was always replaced first.
iSs = [str(d + 1) for d in range(8)]

# (A dead, immediately-shadowed first fenToVec that referenced the undefined
# name 'pieceMap' was removed here.)

def fenToVec(fenstr):
    """Encode the board field of a FEN string as an (8, 8, 12) bool array.

    Rank 8 (the 'a8' corner) comes first, matching FEN order. Each square is
    a 12-way one-hot over piece types, perspective-swapped by side to move.
    NOTE(review): white to move selects pieceMapBlack here -- this preserves
    the original behavior; confirm the intended perspective convention.
    En passant, castling and the 50-move counter are not yet encoded.
    """
    r = boardRE.match(fenstr)
    # group(11) captures the literal 'w' of the side-to-move field.
    if r.group(11):
        pMap = pieceMapBlack
    else:
        pMap = pieceMapWhite
    rowsS = r.group(1).replace('/', '')
    # Expand digit runs into explicit empty squares.
    for d, iS in enumerate(iSs):
        if iS in rowsS:
            rowsS = rowsS.replace(iS, 'E' * (d + 1))
    rows = []
    for c in rowsS:
        rows += pMap[c]
    return np.array(rows, dtype='bool').reshape((8,8,12))
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | import torch.nn.functional 4 | 5 | import collections 6 | 7 | class ScalarToChannel(torch.nn.Module): 8 | def __init__(self, multiplier = 1.0): 9 | super().__init__() 10 | self.expander = torch.nn.Linear(1, 8*8, bias = False) 11 | self.expander.weight.requires_grad = False 12 | self.expander.weight.data.fill_(multiplier) 13 | 14 | def forward(self, x): 15 | return self.expander(x.unsqueeze(1)).reshape(-1, 1, 8, 8) 16 | 17 | class Flatten(torch.nn.Module): 18 | #https://stackoverflow.com/a/56771143 19 | def forward(self, input_x): 20 | return input_x.view(input_x.size(0), -1) 21 | 22 | class No_op(torch.nn.Module): 23 | def forward(self, input_x): 24 | return input_x 25 | 26 | 27 | class CenteredBatchNorm2d(torch.nn.BatchNorm2d): 28 | """Only apply bias, no scale like: 29 | tf.layers.batch_normalization( 30 | center=True, scale=False, 31 | ) 32 | """ 33 | 34 | def __init__(self, channels): 35 | super().__init__(channels, affine = True, eps=1e-5) 36 | #self.weight = 1 by default 37 | self.weight.requires_grad = False 38 | 39 | class ConvBlock(torch.nn.Module): 40 | def __init__(self, filter_size, input_channels, output_channels): 41 | super().__init__() 42 | layers = [ 43 | ('conv2d', torch.nn.Conv2d( 44 | input_channels, 45 | output_channels, 46 | filter_size, 47 | stride = 1, 48 | padding = filter_size // 2, 49 | bias = False, 50 | )), 51 | ('norm2d', CenteredBatchNorm2d(output_channels)), 52 | ('ReLU', torch.nn.ReLU()), 53 | ] 54 | self.seq = torch.nn.Sequential(collections.OrderedDict(layers)) 55 | 56 | def forward(self, x): 57 | return self.seq(x) 58 | 59 | class ResidualBlock(torch.nn.Module): 60 | def __init__(self, channels): 61 | super().__init__() 62 | 63 | layers = [ 64 | ('conv2d_1', torch.nn.Conv2d( 65 | channels, 66 | channels, 67 | 3, 68 | stride = 1, 69 | padding = 1, 70 | bias = False, 71 | )), 72 | ('norm2d_1', 
CenteredBatchNorm2d(channels)), 73 | ('ReLU', torch.nn.ReLU()), 74 | ('conv2d_2', torch.nn.Conv2d( 75 | channels, 76 | channels, 77 | 3, 78 | stride = 1, 79 | padding = 1, 80 | bias = False, 81 | )), 82 | ('norm2d_2', CenteredBatchNorm2d(channels)), 83 | ] 84 | self.seq = torch.nn.Sequential(collections.OrderedDict(layers)) 85 | 86 | def forward(self, x): 87 | y = self.seq(x) 88 | y += x 89 | y = torch.nn.functional.relu(y, inplace = True) 90 | return y 91 | -------------------------------------------------------------------------------- /blunder_prediction/maia_chess_backend/torch/tensorboard_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | #import torch.utils.tensorboard 3 | import tensorboardX 4 | import os.path 5 | 6 | import pytz 7 | tz = pytz.timezone('Canada/Eastern') 8 | 9 | class TB_wrapper(object): 10 | """Defers creating files until used""" 11 | 12 | def __init__(self, name, log_dir = 'runs'): 13 | self.log_dir = log_dir 14 | self.name = name 15 | self._tb = None 16 | 17 | @property 18 | def tb(self): 19 | if self._tb is None: 20 | tb_path = os.path.join(self.log_dir, f"{self.name}") 21 | if os.path.isdir(tb_path): 22 | i = 2 23 | tb_path = tb_path + f"_{i}" 24 | while os.path.isdir(tb_path): 25 | i += 1 26 | #only works to 10 27 | tb_path = tb_path[:-2] + f"_{i}" 28 | self._tb = tensorboardX.SummaryWriter( 29 | log_dir = tb_path 30 | ) 31 | #_{datetime.datetime.now(tz).strftime('%Y-%m-%d-H-%M')}")) 32 | return self._tb 33 | 34 | def add_scalar(self, *args, **kwargs): 35 | return self.tb.add_scalar(*args, **kwargs) 36 | 37 | def add_graph(self, model, input_to_model): 38 | self.tb.add_graph(model, input_to_model = input_to_model, verbose = False) 39 | 40 | def add_histogram(self, *args, **kwargs): 41 | return self.tb.add_histogram(*args, **kwargs) 42 | 43 | def flush(self): 44 | self.tb.flush() 45 | -------------------------------------------------------------------------------- 
/blunder_prediction/maia_chess_backend/torch/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..fen_to_vec import fenToVec 4 | from ..utils import fen_extend 5 | 6 | def fenToTensor(fenstr): 7 | try: 8 | t = torch.from_numpy(fenToVec(fenstr)) 9 | except AttributeError: 10 | t = torch.from_numpy(fenToVec(fen_extend(fenstr))) 11 | if torch.cuda.is_available(): 12 | t = t.cuda() 13 | return t.float() 14 | -------------------------------------------------------------------------------- /blunder_prediction/maia_chess_backend/torch/wrapper.py: -------------------------------------------------------------------------------- 1 | import pandas 2 | import os 3 | import os.path 4 | import re 5 | import io 6 | import torch 7 | 8 | from .utils import fenToTensor 9 | 10 | line_parse_re = re.compile(r"tensor\(([0-9.]+), .+?\)") 11 | 12 | default_vals_mean = { 13 | 'cp_rel' : 0.7, 14 | 'clock_percent' : 0.35, 15 | 'move_ply' : 30, 16 | 'opponent_elo' : 1500, 17 | 'active_elo' : 1500, 18 | } 19 | 20 | def parse_tensor(t): 21 | return t.group(1) 22 | 23 | line_parse_re = re.compile(r"tensor\(([0-9.]+), .+?\)") 24 | 25 | def parse_tensor(t): 26 | return t.group(1) 27 | 28 | class ModelWrapper(object): 29 | def __init__(self, path, default_vals = None): 30 | self.path = path.rstrip('/') 31 | model_paths = [fname for fname in os.listdir(self.path) if fname.startswith('net')] 32 | self.model_paths = sorted(model_paths, 33 | key = lambda x : int(x.split('-')[-1].split('.')[0]) 34 | if 'final' not in x else float('-inf')) 35 | self.newest_save = self.find_best_save() 36 | self.name = os.path.basename(self.path) 37 | self.net = torch.load(self.newest_save) 38 | if torch.cuda.is_available(): 39 | self.net = self.net.cuda() 40 | else: 41 | self.net = self.net.cpu() 42 | 43 | self.has_extras = False 44 | self.extras = None 45 | self.default_vals = {} 46 | self.defaults_dict = None 47 | 48 | if default_vals is not None: 49 
| self.default_vals = default_vals_mean.copy() 50 | try: 51 | if self.net.has_extras: 52 | self.has_extras = True 53 | self.extras = self.net.extra_inputs.copy() 54 | self.defaults_dict = {} 55 | for n in self.extras: 56 | self.defaults_dict[n] = torch.Tensor([self.default_vals.get(n, default_vals_mean[n])]) 57 | except AttributeError: 58 | pass 59 | 60 | def find_best_save(self): 61 | return os.path.join(self.path, self.model_paths[-1]) 62 | 63 | def __repr__(self): 64 | return f"" 65 | 66 | def run_batch(self, input_fens, new_defaults = None): 67 | input_tensors = torch.stack([fenToTensor(fen) for fen in input_fens]) 68 | extra_x = None 69 | if self.has_extras: 70 | extra_x = {} 71 | if new_defaults is None: 72 | for n in self.extras: 73 | extra_x[n] = self.defaults_dict[n] * torch.ones([input_tensors.shape[0]], dtype = torch.float32) 74 | else: 75 | for n in self.extras: 76 | extra_x[n] = new_defaults.get(n, self.defaults_dict[n]) * torch.ones([input_tensors.shape[0]], dtype = torch.float32) 77 | if torch.cuda.is_available(): 78 | input_tensors = input_tensors.cuda() 79 | if extra_x is not None: 80 | for n in self.extras: 81 | extra_x[n] = extra_x[n].cuda() 82 | ret_dat = self.net.dict_forward(input_tensors, extra_x = extra_x) 83 | 84 | for n in ['is_blunder_wr', 'is_blunder_mean']: 85 | try: 86 | return ret_dat[n].detach().cpu().numpy() 87 | except KeyError: 88 | pass 89 | for n in ['winrate_loss', 'is_blunder_wr_mean']: 90 | try: 91 | return ret_dat[n].detach().cpu().numpy() * 5 92 | except KeyError: 93 | pass 94 | raise KeyError(f"No known output types found in: {ret_dat.keys()}") 95 | -------------------------------------------------------------------------------- /blunder_prediction/maia_chess_backend/uci.py: -------------------------------------------------------------------------------- 1 | import chess.uci 2 | 3 | import collections 4 | import concurrent.futures 5 | import threading 6 | 7 | import re 8 | import os.path 9 | 10 | probRe = re.compile(r"\(P: 
+([^)]+)\)") 11 | 12 | 13 | class ProbInfoHandler(chess.uci.InfoHandler): 14 | def __init__(self): 15 | super().__init__() 16 | self.info["probs"] = [] 17 | 18 | def on_go(self): 19 | """ 20 | Notified when a *go* command is beeing sent. 21 | 22 | Since information about the previous search is invalidated, the 23 | dictionary with the current information will be cleared. 24 | """ 25 | with self.lock: 26 | self.info.clear() 27 | self.info["refutation"] = {} 28 | self.info["currline"] = {} 29 | self.info["pv"] = {} 30 | self.info["score"] = {} 31 | self.info["probs"] = [] 32 | 33 | def string(self, string): 34 | """Receives a string the engine wants to display.""" 35 | prob = re.search(probRe, string).group(1) 36 | self.info["probs"].append(string) 37 | 38 | class EngineHandler(object): 39 | def __init__(self, engine, weights, threads = 2): 40 | self.enginePath = os.path.normpath(engine) 41 | self.weightsPath = os.path.normpath(weights) 42 | 43 | self.engine = chess.uci.popen_engine([self.enginePath, "--verbose-move-stats", f"--threads={threads}", f"--weights={self.weightsPath}"]) 44 | 45 | self.info_handler = ProbInfoHandler() 46 | self.engine.info_handlers.append(self.info_handler) 47 | 48 | self.engine.uci() 49 | self.engine.isready() 50 | 51 | def __repr__(self): 52 | return f"" 53 | 54 | def getBoardProbs(self, board, movetime = 1000, nodes = 1000): 55 | self.engine.ucinewgame() 56 | self.engine.position(board) 57 | moves = self.engine.go(movetime = movetime, nodes = nodes) 58 | probs = self.info_handler.info['probs'] 59 | return moves, probs 60 | -------------------------------------------------------------------------------- /blunder_prediction/make_csvs.sh: -------------------------------------------------------------------------------- 1 | screen -S 2017-04 -dm bash -c 'source ~/.bashrc; python3 make_month_csv.py lichess_raw/lichess_db_standard_rated_2017-04.pgn.bz2 csvs' 2 | screen -S 2017-05 -dm bash -c 'source ~/.bashrc; python3 make_month_csv.py 
lichess_raw/lichess_db_standard_rated_2017-05.pgn.bz2 csvs' 3 | screen -S 2017-06 -dm bash -c 'source ~/.bashrc; python3 make_month_csv.py lichess_raw/lichess_db_standard_rated_2017-06.pgn.bz2 csvs' 4 | screen -S 2017-07 -dm bash -c 'source ~/.bashrc; python3 make_month_csv.py lichess_raw/lichess_db_standard_rated_2017-07.pgn.bz2 csvs' 5 | screen -S 2017-08 -dm bash -c 'source ~/.bashrc; python3 make_month_csv.py lichess_raw/lichess_db_standard_rated_2017-08.pgn.bz2 csvs' 6 | screen -S 2017-09 -dm bash -c 'source ~/.bashrc; python3 make_month_csv.py lichess_raw/lichess_db_standard_rated_2017-09.pgn.bz2 csvs' 7 | screen -S 2017-10 -dm bash -c 'source ~/.bashrc; python3 make_month_csv.py lichess_raw/lichess_db_standard_rated_2017-10.pgn.bz2 csvs' 8 | screen -S 2017-11 -dm bash -c 'source ~/.bashrc; python3 make_month_csv.py lichess_raw/lichess_db_standard_rated_2017-11.pgn.bz2 csvs' 9 | screen -S 2017-12 -dm bash -c 'source ~/.bashrc; python3 make_month_csv.py lichess_raw/lichess_db_standard_rated_2017-12.pgn.bz2 csvs' 10 | 11 | screen -S 2018-01 -dm bash -c 'source ~/.bashrc; python3 make_month_csv.py lichess_raw/lichess_db_standard_rated_2018-01.pgn.bz2 csvs' 12 | screen -S 2018-02 -dm bash -c 'source ~/.bashrc; python3 make_month_csv.py lichess_raw/lichess_db_standard_rated_2018-02.pgn.bz2 csvs' 13 | screen -S 2018-03 -dm bash -c 'source ~/.bashrc; python3 make_month_csv.py lichess_raw/lichess_db_standard_rated_2018-03.pgn.bz2 csvs' 14 | screen -S 2018-04 -dm bash -c 'source ~/.bashrc; python3 make_month_csv.py lichess_raw/lichess_db_standard_rated_2018-04.pgn.bz2 csvs' 15 | screen -S 2018-05 -dm bash -c 'source ~/.bashrc; python3 make_month_csv.py lichess_raw/lichess_db_standard_rated_2018-05.pgn.bz2 csvs' 16 | screen -S 2018-06 -dm bash -c 'source ~/.bashrc; python3 make_month_csv.py lichess_raw/lichess_db_standard_rated_2018-06.pgn.bz2 csvs' 17 | screen -S 2018-07 -dm bash -c 'source ~/.bashrc; python3 make_month_csv.py 
lichess_raw/lichess_db_standard_rated_2018-07.pgn.bz2 csvs' 18 | screen -S 2018-08 -dm bash -c 'source ~/.bashrc; python3 make_month_csv.py lichess_raw/lichess_db_standard_rated_2018-08.pgn.bz2 csvs' 19 | screen -S 2018-09 -dm bash -c 'source ~/.bashrc; python3 make_month_csv.py lichess_raw/lichess_db_standard_rated_2018-09.pgn.bz2 csvs' 20 | screen -S 2018-10 -dm bash -c 'source ~/.bashrc; python3 make_month_csv.py lichess_raw/lichess_db_standard_rated_2018-10.pgn.bz2 csvs' 21 | screen -S 2018-11 -dm bash -c 'source ~/.bashrc; python3 make_month_csv.py lichess_raw/lichess_db_standard_rated_2018-11.pgn.bz2 csvs' 22 | screen -S 2018-12 -dm bash -c 'source ~/.bashrc; python3 make_month_csv.py lichess_raw/lichess_db_standard_rated_2018-12.pgn.bz2 csvs' 23 | 24 | screen -S 2019-01 -dm bash -c 'source ~/.bashrc; python3 make_month_csv.py lichess_raw/lichess_db_standard_rated_2019-01.pgn.bz2 csvs' 25 | screen -S 2019-02 -dm bash -c 'source ~/.bashrc; python3 make_month_csv.py lichess_raw/lichess_db_standard_rated_2019-02.pgn.bz2 csvs' 26 | screen -S 2019-03 -dm bash -c 'source ~/.bashrc; python3 make_month_csv.py lichess_raw/lichess_db_standard_rated_2019-03.pgn.bz2 csvs' 27 | screen -S 2019-04 -dm bash -c 'source ~/.bashrc; python3 make_month_csv.py lichess_raw/lichess_db_standard_rated_2019-04.pgn.bz2 csvs' 28 | screen -S 2019-05 -dm bash -c 'source ~/.bashrc; python3 make_month_csv.py lichess_raw/lichess_db_standard_rated_2019-05.pgn.bz2 csvs' 29 | screen -S 2019-06 -dm bash -c 'source ~/.bashrc; python3 make_month_csv.py lichess_raw/lichess_db_standard_rated_2019-06.pgn.bz2 csvs' 30 | 31 | screen -S 2019-07 -dm bash -c 'source ~/.bashrc; python3 make_month_csv.py lichess_raw/lichess_db_standard_rated_2019-07.pgn.bz2 csvs' 32 | screen -S 2019-08 -dm bash -c 'source ~/.bashrc; python3 make_month_csv.py lichess_raw/lichess_db_standard_rated_2019-08.pgn.bz2 csvs' 33 | screen -S 2019-09 -dm bash -c 'source ~/.bashrc; python3 make_month_csv.py 
import haibrid_chess_utils

import bz2
import argparse
import os
import os.path
import multiprocessing
import time
import json
import chess

import numpy as np
import pandas

# Scalar per-board targets that each get their own memmap file.
mmap_columns = [
    'is_blunder_mean',
    'is_blunder_wr_mean',
    'active_elo_mean',
    'opponent_elo_mean',
    'active_won_mean',
    'cp_rel_mean',
    'cp_loss_mean',
    'num_ply_mean',
]

# Columns read from the CSV: the scalars plus the FEN and the two move targets.
target_columns = mmap_columns + ['board_extended', 'top_nonblunder', 'top_blunder']

def main():
    """Convert grouped-board CSVs into memmap files, then idle until Ctrl-C."""
    parser = argparse.ArgumentParser(description='Make mmapped version of csv', formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('inputs', nargs = '+', help='input csv')
    parser.add_argument('outputDir', help='output dir of mmapped files')
    parser.add_argument('--nrows', type=int, help='number of rows to read in, FOR TESTING', default = None)

    args = parser.parse_args()

    haibrid_chess_utils.printWithDate(f"Starting mmap of {', '.join(args.inputs)} writing to {args.outputDir} with {', '.join(mmap_columns)}")

    mmaps = {}
    for path in args.inputs:
        mmaps[path] = mmap_csv(
            path,
            load_csv(path, args.nrows),
            args.outputDir,
            args,
        )
    haibrid_chess_utils.printWithDate("All mmapped")
    # Stay alive holding the memmap handles; KeyboardInterrupt exits cleanly.
    try:
        while True:
            haibrid_chess_utils.printWithDate("Still alive", end = '\r')
            time.sleep(10 * 60)
    except KeyboardInterrupt:
        print()
        haibrid_chess_utils.printWithDate("Exiting")

def load_csv(target_path, nrows):
    """Load only target_columns from the CSV; nrows limits rows for testing."""
    haibrid_chess_utils.printWithDate(f"Loading: {target_path}", flush = True)
    return pandas.read_csv(target_path, usecols=target_columns, nrows = nrows)

def mmap_csv(target_path, df, outputDir, args):
    """Split df into shuffled blunder / non-blunder subsets and memmap each."""
    haibrid_chess_utils.printWithDate(f"Loading: {target_path}")
    name = os.path.basename(target_path).split('.')[0]

    df_blunder = df[df['is_blunder_mean']]
    haibrid_chess_utils.printWithDate(f"Found {len(df_blunder)} blunders")

    df_blunder = df_blunder.sample(frac=1).reset_index(drop=True)

    df_non_blunder = df[df['is_blunder_mean'].eq(False)]
    haibrid_chess_utils.printWithDate(f"Found {len(df_non_blunder)} non blunders")

    df_non_blunder = df_non_blunder.sample(frac=1).reset_index(drop=True)

    del df

    # NOTE(review): nothing is actually dropped here (unlike make_batch_csv's
    # nb_to_b_ratio cap), so this "Reduced" count equals the "Found" count.
    haibrid_chess_utils.printWithDate(f"Reduced to {len(df_non_blunder)} non blunders")

    haibrid_chess_utils.printWithDate(f"Starting mmaping")

    os.makedirs(outputDir, exist_ok = True)

    mmaps = {}

    mmaps['blunder'] = make_df_mmaps(df_blunder, name, os.path.join(outputDir, 'blunder'))

    del df_blunder
    mmaps['nonblunder'] = make_df_mmaps(df_non_blunder, name, os.path.join(outputDir, 'nonblunder'))
    return mmaps

def make_var_mmap(y_name, outputName, mmaps, df):
    """Write one scalar column to disk as a memmap (bools become int64 labels)."""
    a_c = df[y_name].values
    # np.bool and np.long were deprecated in NumPy 1.20 and removed in 1.24;
    # np.bool_ / np.int64 are the supported equivalents with the same layout,
    # and the generated file names ("int64") are unchanged.
    if a_c.dtype == np.bool_:
        a_c = a_c.astype(np.int64)
    else:
        a_c = a_c.astype(np.float32)
    mmaps[y_name] = np.memmap(f"{outputName}+{y_name}+{a_c.dtype}+{a_c.shape[0]}.mm", dtype=a_c.dtype, mode='w+', shape=a_c.shape)
    mmaps[y_name][:] = a_c[:]

def make_board_mmap(outputName, mmaps, df):
    """Vectorize every FEN in df into one boolean memmap of board planes."""
    # Shape probe: vectorize the starting position once to learn the plane dims.
    b_sample_shape = haibrid_chess_utils.fenToVec(chess.Board().fen()).shape

    mmap_vec = np.memmap(
        f"{outputName}+board+{len(df)}.mm",
        dtype=np.bool_,  # np.bool was removed from NumPy; np.bool_ is the dtype
        mode='w+',
        shape=(len(df), b_sample_shape[0], b_sample_shape[1], b_sample_shape[2]),
    )
    # Row-by-row fill keeps peak memory at a single board vector.
    for i, (_, row) in enumerate(df.iterrows()):
        mmap_vec[i, :] = haibrid_chess_utils.fenToVec(row['board_extended'])[:]
    mmaps['board'] = mmap_vec

def move_index_nansafe(move):
    """Map a move to its index; NaN/None (no such move recorded) becomes -1."""
    try:
        return haibrid_chess_utils.move_to_index(move)
    except TypeError:
        return -1

def make_move_mmap(outputName, mmaps, moves_name, df):
    """Write a move-index column (top_blunder / top_nonblunder) as a memmap."""
    a_moves = np.stack(df[moves_name].apply(move_index_nansafe))
    mmaps[moves_name] = np.memmap(f"{outputName}+{moves_name}+{a_moves.shape[0]}.mm", dtype=a_moves.dtype, mode='w+', shape=a_moves.shape)
    mmaps[moves_name][:] = a_moves[:]

def make_df_mmaps(df, name, output_dir):
    """Create all memmaps (scalars, moves, boards) for one subset in output_dir."""
    os.makedirs(output_dir, exist_ok = True)
    outputName = os.path.join(output_dir, name)

    mmaps = {}
    haibrid_chess_utils.printWithDate(f"Making y_vals mmaps for: {name} done:", end = ' ')
    for y_name in mmap_columns:
        make_var_mmap(y_name, outputName, mmaps, df)
        print(y_name, end = ' ', flush = True)

    haibrid_chess_utils.printWithDate(f"Making move array mmaps for: {name}")

    make_move_mmap(outputName, mmaps, 'top_blunder', df)
    make_move_mmap(outputName, mmaps, 'top_nonblunder', df)

    haibrid_chess_utils.printWithDate(f"Making boards array mmaps for: {name}")

    make_board_mmap(outputName, mmaps, df)

    return mmaps

if __name__ == '__main__':
    main()
import sys
sys.path.append("../move_prediction")

import maia_chess_backend

import argparse
import bz2

@maia_chess_backend.logged_main
def main():
    """Copy games where both players' ratings fall in (eloMin, eloMax] into one bz2 PGN."""
    parser = argparse.ArgumentParser(description='Process some integers.')
    parser.add_argument('eloMin', type=int, help='min ELO')
    parser.add_argument('eloMax', type=int, help='max ELO')
    parser.add_argument('output', help='output file')
    parser.add_argument('targets', nargs='+', help='target files')
    parser.add_argument('--remove_bullet', action='store_true', help='Remove bullet and ultrabullet games')
    parser.add_argument('--remove_low_time', action='store_true', help='Remove low time moves from games')

    args = parser.parse_args()
    games_written = 0
    print(f"Starting writing to: {args.output}")
    with bz2.open(args.output, 'wt') as out_file:
        for num_files, target in enumerate(sorted(args.targets)):
            print(f"{num_files} reading: {target}")
            reader = maia_chess_backend.LightGamesFile(target, parseMoves = False)
            for i, (dat, lines) in enumerate(reader):
                try:
                    white_elo = int(dat['WhiteElo'])
                    black_elo = int(dat['BlackElo'])
                except ValueError:
                    # Unparseable rating tags (e.g. "?") — skip the game.
                    continue
                # Keep only games with both ratings in range, a decisive or drawn
                # result, and (optionally) no Bullet/UltraBullet time control.
                elos_in_range = (args.eloMin < white_elo <= args.eloMax) and (args.eloMin < black_elo <= args.eloMax)
                valid_result = dat['Result'] in ['1-0', '0-1', '1/2-1/2']
                bullet_excluded = args.remove_bullet and 'Bullet' in dat['Event']
                if elos_in_range and valid_result and not bullet_excluded:
                    payload = maia_chess_backend.remove_low_time(lines) if args.remove_low_time else lines
                    out_file.write(payload)
                    games_written += 1
                if i % 1000 == 0:
                    print(f"{i}: written {games_written} files {num_files}: {target}".ljust(79), end = '\r')
            print(f"Done: {target} {i}".ljust(79))

if __name__ == '__main__':
    main()
import argparse
import pandas
import os
import os.path
import numpy as np
import haibrid_chess_utils

@haibrid_chess_utils.logged_main
def main():
    """Split the grouped-boards CSV into train/test CSVs under outputDir."""
    parser = argparse.ArgumentParser(description='Make train testr split of grouped boards', formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('input', help='input CSV')
    parser.add_argument('outputDir', help='output CSVs dir')
    parser.add_argument('--nrows', type=int, help='number of rows to read in, FOR TESTING', default = None)
    parser.add_argument('--blunder_ratio', type=float, help='ratio to declare a positive class', default = .1)
    parser.add_argument('--min_count', type=int, help='minimum observations of a board to keep it', default = 10)

    args = parser.parse_args()
    os.makedirs(os.path.join(args.outputDir, 'test'), exist_ok=True)
    os.makedirs(os.path.join(args.outputDir, 'train'), exist_ok=True)

    haibrid_chess_utils.printWithDate(f"Loading: {args.input}")

    # Fix: --nrows was previously parsed but never applied to the read.
    df = pandas.read_csv(args.input, nrows = args.nrows)

    haibrid_chess_utils.printWithDate(f"Filtering: {args.input}")
    df = df[df['count'] >= args.min_count].copy()

    # Fix: the positive-class threshold was hard-coded to .1, ignoring
    # --blunder_ratio (whose documented purpose is exactly this threshold).
    df['is_blunder_mean'] = df['is_blunder_wr_mean'] > args.blunder_ratio
    # The grouped boards store truncated FENs; pad with default castling /
    # ep / clock fields so fenToVec-style consumers can parse them.
    df['board_extended'] = df['board'].apply(lambda x : x + ' KQkq - 0 1')
    df['white_active'] = df['board'].apply(lambda x : x.endswith('w'))
    df['has_nonblunder_move'] = df['top_nonblunder'].notna()
    df['has_blunder_move'] = df['top_blunder'].notna()
    # NOTE(review): the test fraction reuses --blunder_ratio as in the original
    # code; if the two should differ, add a dedicated --test_ratio flag.
    df['is_test'] = np.random.random(len(df)) < args.blunder_ratio

    haibrid_chess_utils.printWithDate(f"Writing to: {args.outputDir}")
    # Fix: removed a duplicate write of the test split to the hard-coded,
    # machine-specific path /datadrive/group_csv/test/grouped_fens.csv.bz2.
    df[df['is_test']].to_csv(os.path.join(args.outputDir, 'test', 'grouped_fens.csv.bz2'), compression = 'bz2')
    df[df['is_test'] == False].to_csv(os.path.join(args.outputDir, 'train', 'grouped_fens.csv.bz2'), compression = 'bz2')

if __name__ == '__main__':
    main()
import multiprocessing
import bz2
import io
import os
import os.path
import re
import queue
import zipfile
import pandas

import numpy as np

import chess
import chess.pgn

import haibrid_chess_utils

# Columns pulled from the month CSV; everything else is dropped at read time.
target_columns = [
    'game_id',
    'move_ply',
    'cp_rel',
    'cp_loss',
    'is_blunder_cp',
    'winrate',
    'winrate_elo',
    'winrate_loss',
    'is_blunder_wr',
    'opp_winrate',
    'white_active',
    'active_elo',
    'opponent_elo',
    'active_won',
    'low_time',
    'board',
]

@haibrid_chess_utils.logged_main
def main():
    """Filter one month CSV and write shuffled blunder / non-blunder CSVs."""
    parser = argparse.ArgumentParser(description='Create two new csvs with select columns split by is_blunder_wr', formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('input', help='input CSV')
    parser.add_argument('outputDir', help='output CSV')

    parser.add_argument('--min_elo', type=int, help='min active elo', default = 1000)
    # Fix: help text said "min active elo" for the max bound.
    parser.add_argument('--max_elo', type=int, help='max active elo', default = 9999999999)
    # Fix: argparse's type=bool treats ANY non-empty string (even "False") as
    # True; store_true keeps the same False default with correct semantics.
    parser.add_argument('--allow_negative_loss', action='store_true', help='allow winrate losses below 0')
    parser.add_argument('--allow_low_time', action='store_true', help='Include low time moves')

    parser.add_argument('--min_ply', type=int, help='min move ply to consider', default = 6)

    parser.add_argument('--nrows', type=int, help='number of rows to read in', default = None)

    parser.add_argument('--nb_to_b_ratio', type=float, help='ratio of non blunders to blunders in dataset', default = 1.5)

    args = parser.parse_args()

    haibrid_chess_utils.printWithDate(f"Starting CSV split of {args.input} writing to {args.outputDir}")
    haibrid_chess_utils.printWithDate(f"Collecting {', '.join(target_columns)}")

    name = os.path.basename(args.input).split('.')[0]
    outputBlunder = os.path.join(args.outputDir, f"{name}_blunder.csv.bz2")
    outputNonBlunder = os.path.join(args.outputDir, f"{name}_nonblunder.csv.bz2")

    haibrid_chess_utils.printWithDate(f"Created outputs named {outputBlunder} and {outputNonBlunder}")

    os.makedirs(args.outputDir, exist_ok = True)

    haibrid_chess_utils.printWithDate(f"Starting read")
    with bz2.open(args.input, 'rt') as f:
        df = pandas.read_csv(f, usecols = target_columns, nrows = args.nrows)

    haibrid_chess_utils.printWithDate(f"Filtering data starting at {len(df)} rows")

    # Drop opening plies, then optionally low-time moves and negative losses.
    df = df[df['move_ply'] >= args.min_ply]

    if not args.allow_low_time:
        df = df[df['low_time'].eq(False)]

    if not args.allow_negative_loss:
        df = df[df['winrate_loss'] > 0]

    df = df[df['active_elo'] > args.min_elo]
    df = df[df['active_elo'] < args.max_elo]

    df = df.dropna()

    haibrid_chess_utils.printWithDate(f"Filtering down data to {len(df)} rows")

    df_blunder = df[df['is_blunder_wr']]
    haibrid_chess_utils.printWithDate(f"Found {len(df_blunder)} blunders")

    df_blunder = df_blunder.sample(frac=1).reset_index(drop=True)

    df_non_blunder = df[df['is_blunder_wr'].eq(False)]
    haibrid_chess_utils.printWithDate(f"Found {len(df_non_blunder)} non blunders")

    # Shuffle, then cap non-blunders at nb_to_b_ratio times the blunder count.
    df_non_blunder = df_non_blunder.sample(frac=1).reset_index(drop=True).iloc[:int(len(df_blunder) * args.nb_to_b_ratio)]

    haibrid_chess_utils.printWithDate(f"Reduced to {len(df_non_blunder)} non blunders")

    haibrid_chess_utils.printWithDate(f"Starting writing")

    with bz2.open(outputNonBlunder, 'wt') as fnb:
        df_non_blunder.to_csv(fnb, index = False)
    with bz2.open(outputBlunder, 'wt') as fb:
        df_blunder.to_csv(fb, index = False)

if __name__ == '__main__':
    main()
import argparse
import time
import humanize
import multiprocessing
import functools
import bz2
import io
import os
import os.path
import re
import queue
import zipfile
import pandas

import numpy as np

import chess
import chess.pgn

import haibrid_chess_utils

# Columns pulled from each month CSV; everything else is dropped at read time.
target_columns = [
    'game_id',
    'move_ply',
    'cp_rel',
    'cp_loss',
    'winrate',
    'winrate_elo',
    'winrate_loss',
    'is_blunder_wr',
    'opp_winrate',
    'white_active',
    'active_elo',
    'opponent_elo',
    'active_won',
    'low_time',
    'board',
]

def _read_target(path, nrows):
    """Read one bz2 month CSV; module-level so Pool workers can pickle it."""
    with bz2.open(path, 'rt') as f:
        return pandas.read_csv(f, usecols = target_columns, nrows = nrows)

@haibrid_chess_utils.logged_main
def main():
    """Read all target CSVs, merge them, and write shuffled blunder / non-blunder CSVs.

    This file previously did not parse (missing colon on the `with` block, a
    dangling `dfs =`, `type=64` on --pool, and references to undefined
    args.outputBlunder / args.input); reconstructed to do what the code and
    log messages clearly intend.
    """
    parser = argparse.ArgumentParser(description='Create two new csvs with select columns split by is_blunder_wr', formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('targets', nargs = '+', help='input CSVs')
    parser.add_argument('outputDir', help='output CSV')

    parser.add_argument('--min_elo', type=int, help='min active elo', default = 1000)
    parser.add_argument('--max_elo', type=int, help='max active elo', default = 9999999999)
    # argparse type=bool treats any non-empty string as True; use store_true.
    parser.add_argument('--allow_negative_loss', action='store_true', help='allow winrate losses below 0')
    parser.add_argument('--allow_low_time', action='store_true', help='Include low time moves')

    parser.add_argument('--min_ply', type=int, help='min move ply to consider', default = 6)
    # Was: type=64 (a TypeError) with a copy-pasted help string.
    parser.add_argument('--pool', type=int, help='number of parallel CSV reader processes', default = 6)

    parser.add_argument('--nrows', type=int, help='number of rows to read in', default = None)

    parser.add_argument('--nb_to_b_ratio', type=float, help='ratio of non blunders to blunders in dataset', default = 1.5)

    args = parser.parse_args()

    haibrid_chess_utils.printWithDate(f"Starting CSVs split of {len(args.targets)} targets writing to {args.outputDir}")
    haibrid_chess_utils.printWithDate(f"Collecting {', '.join(target_columns)}")

    # Multiple inputs feed one merged output pair.
    name = 'combined'
    outputBlunder = os.path.join(args.outputDir, f"{name}_blunder.csv.bz2")
    outputNonBlunder = os.path.join(args.outputDir, f"{name}_nonblunder.csv.bz2")

    haibrid_chess_utils.printWithDate(f"Created outputs named {outputBlunder} and {outputNonBlunder}")

    os.makedirs(args.outputDir, exist_ok = True)

    haibrid_chess_utils.printWithDate(f"Starting read")

    # Read every target in parallel, then concatenate into one frame.
    with multiprocessing.Pool(args.pool) as pool:
        dfs = pool.map(functools.partial(_read_target, nrows = args.nrows), args.targets)
    df = pandas.concat(dfs, ignore_index = True)
    del dfs

    haibrid_chess_utils.printWithDate(f"Filtering data starting at {len(df)} rows")

    df = df[df['move_ply'] >= args.min_ply]

    if not args.allow_low_time:
        df = df[df['low_time'].eq(False)]

    if not args.allow_negative_loss:
        df = df[df['winrate_loss'] > 0]

    df = df[df['active_elo'] > args.min_elo]
    df = df[df['active_elo'] < args.max_elo]

    df = df.dropna()

    haibrid_chess_utils.printWithDate(f"Filtering down data to {len(df)} rows")

    df_blunder = df[df['is_blunder_wr']]
    haibrid_chess_utils.printWithDate(f"Found {len(df_blunder)} blunders")

    df_blunder = df_blunder.sample(frac=1).reset_index(drop=True)

    df_non_blunder = df[df['is_blunder_wr'].eq(False)]
    haibrid_chess_utils.printWithDate(f"Found {len(df_non_blunder)} non blunders")

    # Shuffle, then cap non-blunders at nb_to_b_ratio times the blunder count.
    df_non_blunder = df_non_blunder.sample(frac=1).reset_index(drop=True).iloc[:int(len(df_blunder) * args.nb_to_b_ratio)]

    haibrid_chess_utils.printWithDate(f"Reduced to {len(df_non_blunder)} non blunders")

    haibrid_chess_utils.printWithDate(f"Starting writing")

    with bz2.open(outputNonBlunder, 'wt') as fnb:
        df_non_blunder.to_csv(fnb, index = False)
    with bz2.open(outputBlunder, 'wt') as fb:
        df_blunder.to_csv(fb, index = False)

if __name__ == '__main__':
    main()
/datadrive/raw_pgns/lichess_db_standard_rated_2017-11.pgn.bz2 /datadrive/new_board_csvs' 10 | screen -S 2017-12 -dm bash -c '/home/reidmcy/anaconda3/bin/python3 make_month_csv.py /datadrive/raw_pgns/lichess_db_standard_rated_2017-12.pgn.bz2 /datadrive/new_board_csvs' 11 | 12 | screen -S 2018-01 -dm bash -c '/home/reidmcy/anaconda3/bin/python3 make_month_csv.py /datadrive/raw_pgns/lichess_db_standard_rated_2018-01.pgn.bz2 /datadrive/new_board_csvs' 13 | screen -S 2018-02 -dm bash -c '/home/reidmcy/anaconda3/bin/python3 make_month_csv.py /datadrive/raw_pgns/lichess_db_standard_rated_2018-02.pgn.bz2 /datadrive/new_board_csvs' 14 | screen -S 2018-03 -dm bash -c '/home/reidmcy/anaconda3/bin/python3 make_month_csv.py /datadrive/raw_pgns/lichess_db_standard_rated_2018-03.pgn.bz2 /datadrive/new_board_csvs' 15 | screen -S 2018-04 -dm bash -c '/home/reidmcy/anaconda3/bin/python3 make_month_csv.py /datadrive/raw_pgns/lichess_db_standard_rated_2018-04.pgn.bz2 /datadrive/new_board_csvs' 16 | screen -S 2018-05 -dm bash -c '/home/reidmcy/anaconda3/bin/python3 make_month_csv.py /datadrive/raw_pgns/lichess_db_standard_rated_2018-05.pgn.bz2 /datadrive/new_board_csvs' 17 | screen -S 2018-06 -dm bash -c '/home/reidmcy/anaconda3/bin/python3 make_month_csv.py /datadrive/raw_pgns/lichess_db_standard_rated_2018-06.pgn.bz2 /datadrive/new_board_csvs' 18 | screen -S 2018-07 -dm bash -c '/home/reidmcy/anaconda3/bin/python3 make_month_csv.py /datadrive/raw_pgns/lichess_db_standard_rated_2018-07.pgn.bz2 /datadrive/new_board_csvs' 19 | screen -S 2018-08 -dm bash -c '/home/reidmcy/anaconda3/bin/python3 make_month_csv.py /datadrive/raw_pgns/lichess_db_standard_rated_2018-08.pgn.bz2 /datadrive/new_board_csvs' 20 | screen -S 2018-09 -dm bash -c '/home/reidmcy/anaconda3/bin/python3 make_month_csv.py /datadrive/raw_pgns/lichess_db_standard_rated_2018-09.pgn.bz2 /datadrive/new_board_csvs' 21 | screen -S 2018-10 -dm bash -c '/home/reidmcy/anaconda3/bin/python3 make_month_csv.py 
/datadrive/raw_pgns/lichess_db_standard_rated_2018-10.pgn.bz2 /datadrive/new_board_csvs' 22 | screen -S 2018-11 -dm bash -c '/home/reidmcy/anaconda3/bin/python3 make_month_csv.py /datadrive/raw_pgns/lichess_db_standard_rated_2018-11.pgn.bz2 /datadrive/new_board_csvs' 23 | screen -S 2018-12 -dm bash -c '/home/reidmcy/anaconda3/bin/python3 make_month_csv.py /datadrive/raw_pgns/lichess_db_standard_rated_2018-12.pgn.bz2 /datadrive/new_board_csvs' 24 | 25 | screen -S 2019-01 -dm bash -c '/home/reidmcy/anaconda3/bin/python3 make_month_csv.py /datadrive/raw_pgns/lichess_db_standard_rated_2019-01.pgn.bz2 /datadrive/new_board_csvs' 26 | screen -S 2019-02 -dm bash -c '/home/reidmcy/anaconda3/bin/python3 make_month_csv.py /datadrive/raw_pgns/lichess_db_standard_rated_2019-02.pgn.bz2 /datadrive/new_board_csvs' 27 | screen -S 2019-03 -dm bash -c '/home/reidmcy/anaconda3/bin/python3 make_month_csv.py /datadrive/raw_pgns/lichess_db_standard_rated_2019-03.pgn.bz2 /datadrive/new_board_csvs' 28 | screen -S 2019-04 -dm bash -c '/home/reidmcy/anaconda3/bin/python3 make_month_csv.py /datadrive/raw_pgns/lichess_db_standard_rated_2019-04.pgn.bz2 /datadrive/new_board_csvs' 29 | screen -S 2019-05 -dm bash -c '/home/reidmcy/anaconda3/bin/python3 make_month_csv.py /datadrive/raw_pgns/lichess_db_standard_rated_2019-05.pgn.bz2 /datadrive/new_board_csvs' 30 | screen -S 2019-06 -dm bash -c '/home/reidmcy/anaconda3/bin/python3 make_month_csv.py /datadrive/raw_pgns/lichess_db_standard_rated_2019-06.pgn.bz2 /datadrive/new_board_csvs' 31 | -------------------------------------------------------------------------------- /data_generators/pgnCPsToCSV_single.py: -------------------------------------------------------------------------------- 1 | #Most of the functions are imported from multi 2 | import argparse 3 | import time 4 | import humanize 5 | import multiprocessing 6 | import bz2 7 | import io 8 | import os 9 | import os.path 10 | import re 11 | import queue 12 | 13 | import chess 14 | import 
import chess.pgn

import haibrid_chess_utils

# processPGN and the shared per-game machinery come from the multi-file variant.
from pgnCPsToCSV_multi import *

def cleanup(pgnReaders, gameReaders, writers):
    # Drain the async results in pipeline order: PGN reader first, then the
    # worker pool, then the writer. .get() blocks until each stage finishes
    # (presumably these are multiprocessing AsyncResult handles from
    # processPGN — TODO confirm against pgnCPsToCSV_multi).
    pgnReaders.get()
    haibrid_chess_utils.printWithDate(f"Done reading")
    # Grace period so in-flight games drain before joining the workers.
    time.sleep(10)
    for r in gameReaders:
        r.get()
    haibrid_chess_utils.printWithDate(f"Done processing")

    writers.get()

@haibrid_chess_utils.logged_main
def main():
    # Convert a single stockfish-annotated PGN into one bz2 CSV, using a
    # worker pool for move processing and a small pool for I/O.
    parser = argparse.ArgumentParser(description='process PGN file with stockfish annotaions into a csv file', formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('input', help='input PGNs')
    parser.add_argument('outputDir', help='output CSVs dir')

    parser.add_argument('--pool', type=int, help='number of simultaneous jobs running per fil', default = 30)
    #parser.add_argument('--readers', type=int, help='number of simultaneous reader running per inputfile', default = 24)
    parser.add_argument('--queueSize', type=int, help='Max number of games to cache', default = 1000)

    args = parser.parse_args()

    haibrid_chess_utils.printWithDate(f"Starting CSV conversion of {args.input} writing to {args.outputDir}")

    os.makedirs(args.outputDir, exist_ok=True)

    # Output name mirrors the input basename: foo.pgn.bz2 -> foo.csv.bz2.
    name = os.path.basename(args.input).split('.')[0]
    outputName = os.path.join(args.outputDir, f"{name}.csv.bz2")
    #names[n] = (name, outputName)

    haibrid_chess_utils.printWithDate(f"Loading file: {name}")
    haibrid_chess_utils.printWithDate(f"Starting main loop")

    tstart = time.time()

    print(args)
    with multiprocessing.Manager() as manager:
        with multiprocessing.Pool(args.pool) as workers_pool, multiprocessing.Pool(3) as io_pool:
            # processPGN wires up reader -> workers -> writer through the two
            # queues; the queues are unused here but returned for parity with
            # the multi-file driver.
            pgnReader, gameReader, writer, unproccessedQueue, resultsQueue = processPGN(args.input, name, outputName, args.queueSize, args.pool, manager, workers_pool, io_pool)

            haibrid_chess_utils.printWithDate(f"Done loading Queues in {humanize.naturaldelta(time.time() - tstart)}, waiting for reading to finish")

            cleanup(pgnReader, gameReader, writer)

    haibrid_chess_utils.printWithDate(f"Done everything in {humanize.naturaldelta(time.time() - tstart)}, exiting")

if __name__ == '__main__':
    main()
../data/lichess_batch_csv --nrows 10000000000000' 12 | screen -S 2017-11 -dm bash -c 'source ~/.bashrc; python3 make_batch_csv.py /ada/data/haibrid-chess/lichess_board_csvs/lichess_db_standard_rated_2017-11.csv.bz2 ../data/lichess_batch_csv --nrows 10000000000000' 13 | screen -S 2017-12 -dm bash -c 'source ~/.bashrc; python3 make_batch_csv.py /ada/data/haibrid-chess/lichess_board_csvs/lichess_db_standard_rated_2017-12.csv.bz2 ../data/lichess_batch_csv --nrows 10000000000000' 14 | screen -S 2018-01 -dm bash -c 'source ~/.bashrc; python3 make_batch_csv.py /ada/data/haibrid-chess/lichess_board_csvs/lichess_db_standard_rated_2018-01.csv.bz2 ../data/lichess_batch_csv --nrows 10000000000000' 15 | screen -S 2018-02 -dm bash -c 'source ~/.bashrc; python3 make_batch_csv.py /ada/data/haibrid-chess/lichess_board_csvs/lichess_db_standard_rated_2018-02.csv.bz2 ../data/lichess_batch_csv --nrows 10000000000000' 16 | screen -S 2018-03 -dm bash -c 'source ~/.bashrc; python3 make_batch_csv.py /ada/data/haibrid-chess/lichess_board_csvs/lichess_db_standard_rated_2018-03.csv.bz2 ../data/lichess_batch_csv --nrows 10000000000000' 17 | screen -S 2018-04 -dm bash -c 'source ~/.bashrc; python3 make_batch_csv.py /ada/data/haibrid-chess/lichess_board_csvs/lichess_db_standard_rated_2018-04.csv.bz2 ../data/lichess_batch_csv --nrows 10000000000000' 18 | screen -S 2018-05 -dm bash -c 'source ~/.bashrc; python3 make_batch_csv.py /ada/data/haibrid-chess/lichess_board_csvs/lichess_db_standard_rated_2018-05.csv.bz2 ../data/lichess_batch_csv --nrows 10000000000000' 19 | screen -S 2018-06 -dm bash -c 'source ~/.bashrc; python3 make_batch_csv.py /ada/data/haibrid-chess/lichess_board_csvs/lichess_db_standard_rated_2018-06.csv.bz2 ../data/lichess_batch_csv --nrows 10000000000000' 20 | screen -S 2018-07 -dm bash -c 'source ~/.bashrc; python3 make_batch_csv.py /ada/data/haibrid-chess/lichess_board_csvs/lichess_db_standard_rated_2018-07.csv.bz2 ../data/lichess_batch_csv --nrows 10000000000000' 21 | screen -S 
2018-08 -dm bash -c 'source ~/.bashrc; python3 make_batch_csv.py /ada/data/haibrid-chess/lichess_board_csvs/lichess_db_standard_rated_2018-08.csv.bz2 ../data/lichess_batch_csv --nrows 10000000000000' 22 | screen -S 2018-09 -dm bash -c 'source ~/.bashrc; python3 make_batch_csv.py /ada/data/haibrid-chess/lichess_board_csvs/lichess_db_standard_rated_2018-09.csv.bz2 ../data/lichess_batch_csv --nrows 10000000000000' 23 | screen -S 2018-10 -dm bash -c 'source ~/.bashrc; python3 make_batch_csv.py /ada/data/haibrid-chess/lichess_board_csvs/lichess_db_standard_rated_2018-10.csv.bz2 ../data/lichess_batch_csv --nrows 10000000000000' 24 | screen -S 2018-11 -dm bash -c 'source ~/.bashrc; python3 make_batch_csv.py /ada/data/haibrid-chess/lichess_board_csvs/lichess_db_standard_rated_2018-11.csv.bz2 ../data/lichess_batch_csv --nrows 10000000000000' 25 | screen -S 2018-12 -dm bash -c 'source ~/.bashrc; python3 make_batch_csv.py /ada/data/haibrid-chess/lichess_board_csvs/lichess_db_standard_rated_2018-12.csv.bz2 ../data/lichess_batch_csv --nrows 10000000000000' 26 | screen -S 2019-01 -dm bash -c 'source ~/.bashrc; python3 make_batch_csv.py /ada/data/haibrid-chess/lichess_board_csvs/lichess_db_standard_rated_2019-01.csv.bz2 ../data/lichess_batch_csv --nrows 10000000000000' 27 | screen -S 2019-02 -dm bash -c 'source ~/.bashrc; python3 make_batch_csv.py /ada/data/haibrid-chess/lichess_board_csvs/lichess_db_standard_rated_2019-02.csv.bz2 ../data/lichess_batch_csv --nrows 10000000000000' 28 | screen -S 2019-03 -dm bash -c 'source ~/.bashrc; python3 make_batch_csv.py /ada/data/haibrid-chess/lichess_board_csvs/lichess_db_standard_rated_2019-03.csv.bz2 ../data/lichess_batch_csv --nrows 10000000000000' 29 | screen -S 2019-04 -dm bash -c 'source ~/.bashrc; python3 make_batch_csv.py /ada/data/haibrid-chess/lichess_board_csvs/lichess_db_standard_rated_2019-04.csv.bz2 ../data/lichess_batch_csv --nrows 10000000000000' 30 | screen -S 2019-05 -dm bash -c 'source ~/.bashrc; python3 make_batch_csv.py 
/ada/data/haibrid-chess/lichess_board_csvs/lichess_db_standard_rated_2019-05.csv.bz2 ../data/lichess_batch_csv --nrows 10000000000000' 31 | screen -S 2019-06 -dm bash -c 'source ~/.bashrc; python3 make_batch_csv.py /ada/data/haibrid-chess/lichess_board_csvs/lichess_db_standard_rated_2019-06.csv.bz2 ../data/lichess_batch_csv --nrows 10000000000000' 32 | 33 | for i in {01..06} 34 | do 35 | echo "2019-${i}" 36 | screen -S "2019-${i}" -X quit 37 | done 38 | -------------------------------------------------------------------------------- /data_generators/run_singles.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | screen -S 2017-04 -dm bash -c 'source ~/.bashrc; python3 pgnCPsToCSV_single.py /ada/data/chess/bz2/standard/lichess_db_standard_rated_2017-04.pgn.bz2 ../datasets/lichess_board_csvs' 4 | screen -S 2017-05 -dm bash -c 'source ~/.bashrc; python3 pgnCPsToCSV_single.py /ada/data/chess/bz2/standard/lichess_db_standard_rated_2017-05.pgn.bz2 ../datasets/lichess_board_csvs' 5 | screen -S 2017-06 -dm bash -c 'source ~/.bashrc; python3 pgnCPsToCSV_single.py /ada/data/chess/bz2/standard/lichess_db_standard_rated_2017-06.pgn.bz2 ../datasets/lichess_board_csvs' 6 | screen -S 2017-07 -dm bash -c 'source ~/.bashrc; python3 pgnCPsToCSV_single.py /ada/data/chess/bz2/standard/lichess_db_standard_rated_2017-07.pgn.bz2 ../datasets/lichess_board_csvs' 7 | screen -S 2017-08 -dm bash -c 'source ~/.bashrc; python3 pgnCPsToCSV_single.py /ada/data/chess/bz2/standard/lichess_db_standard_rated_2017-08.pgn.bz2 ../datasets/lichess_board_csvs' 8 | screen -S 2017-09 -dm bash -c 'source ~/.bashrc; python3 pgnCPsToCSV_single.py /ada/data/chess/bz2/standard/lichess_db_standard_rated_2017-09.pgn.bz2 ../datasets/lichess_board_csvs' 9 | screen -S 2017-10 -dm bash -c 'source ~/.bashrc; python3 pgnCPsToCSV_single.py /ada/data/chess/bz2/standard/lichess_db_standard_rated_2017-10.pgn.bz2 ../datasets/lichess_board_csvs' 10 | screen -S 
2017-11 -dm bash -c 'source ~/.bashrc; python3 pgnCPsToCSV_single.py /ada/data/chess/bz2/standard/lichess_db_standard_rated_2017-11.pgn.bz2 ../datasets/lichess_board_csvs' 11 | screen -S 2017-12 -dm bash -c 'source ~/.bashrc; python3 pgnCPsToCSV_single.py /ada/data/chess/bz2/standard/lichess_db_standard_rated_2017-12.pgn.bz2 ../datasets/lichess_board_csvs' 12 | 13 | screen -S 2018-01 -dm bash -c 'source ~/.bashrc; python3 pgnCPsToCSV_single.py /ada/data/chess/bz2/standard/lichess_db_standard_rated_2018-01.pgn.bz2 ../datasets/lichess_board_csvs' 14 | screen -S 2018-02 -dm bash -c 'source ~/.bashrc; python3 pgnCPsToCSV_single.py /ada/data/chess/bz2/standard/lichess_db_standard_rated_2018-02.pgn.bz2 ../datasets/lichess_board_csvs' 15 | screen -S 2018-03 -dm bash -c 'source ~/.bashrc; python3 pgnCPsToCSV_single.py /ada/data/chess/bz2/standard/lichess_db_standard_rated_2018-03.pgn.bz2 ../datasets/lichess_board_csvs' 16 | screen -S 2018-04 -dm bash -c 'source ~/.bashrc; python3 pgnCPsToCSV_single.py /ada/data/chess/bz2/standard/lichess_db_standard_rated_2018-04.pgn.bz2 ../datasets/lichess_board_csvs' 17 | screen -S 2018-05 -dm bash -c 'source ~/.bashrc; python3 pgnCPsToCSV_single.py /ada/data/chess/bz2/standard/lichess_db_standard_rated_2018-05.pgn.bz2 ../datasets/lichess_board_csvs' 18 | screen -S 2018-06 -dm bash -c 'source ~/.bashrc; python3 pgnCPsToCSV_single.py /ada/data/chess/bz2/standard/lichess_db_standard_rated_2018-06.pgn.bz2 ../datasets/lichess_board_csvs' 19 | screen -S 2018-07 -dm bash -c 'source ~/.bashrc; python3 pgnCPsToCSV_single.py /ada/data/chess/bz2/standard/lichess_db_standard_rated_2018-07.pgn.bz2 ../datasets/lichess_board_csvs' 20 | screen -S 2018-08 -dm bash -c 'source ~/.bashrc; python3 pgnCPsToCSV_single.py /ada/data/chess/bz2/standard/lichess_db_standard_rated_2018-08.pgn.bz2 ../datasets/lichess_board_csvs' 21 | screen -S 2018-09 -dm bash -c 'source ~/.bashrc; python3 pgnCPsToCSV_single.py 
/ada/data/chess/bz2/standard/lichess_db_standard_rated_2018-09.pgn.bz2 ../datasets/lichess_board_csvs' 22 | screen -S 2018-10 -dm bash -c 'source ~/.bashrc; python3 pgnCPsToCSV_single.py /ada/data/chess/bz2/standard/lichess_db_standard_rated_2018-10.pgn.bz2 ../datasets/lichess_board_csvs' 23 | screen -S 2018-11 -dm bash -c 'source ~/.bashrc; python3 pgnCPsToCSV_single.py /ada/data/chess/bz2/standard/lichess_db_standard_rated_2018-11.pgn.bz2 ../datasets/lichess_board_csvs' 24 | screen -S 2018-12 -dm bash -c 'source ~/.bashrc; python3 pgnCPsToCSV_single.py /ada/data/chess/bz2/standard/lichess_db_standard_rated_2018-12.pgn.bz2 ../datasets/lichess_board_csvs' 25 | 26 | screen -S 2019-01 -dm bash -c 'source ~/.bashrc; python3 pgnCPsToCSV_single.py /ada/data/chess/bz2/standard/lichess_db_standard_rated_2019-01.pgn.bz2 ../datasets/lichess_board_csvs' 27 | screen -S 2019-02 -dm bash -c 'source ~/.bashrc; python3 pgnCPsToCSV_single.py /ada/data/chess/bz2/standard/lichess_db_standard_rated_2019-02.pgn.bz2 ../datasets/lichess_board_csvs' 28 | screen -S 2019-03 -dm bash -c 'source ~/.bashrc; python3 pgnCPsToCSV_single.py /ada/data/chess/bz2/standard/lichess_db_standard_rated_2019-03.pgn.bz2 ../datasets/lichess_board_csvs' 29 | screen -S 2019-04 -dm bash -c 'source ~/.bashrc; python3 pgnCPsToCSV_single.py /ada/data/chess/bz2/standard/lichess_db_standard_rated_2019-04.pgn.bz2 ../datasets/lichess_board_csvs' 30 | screen -S 2019-05 -dm bash -c 'source ~/.bashrc; python3 pgnCPsToCSV_single.py /ada/data/chess/bz2/standard/lichess_db_standard_rated_2019-05.pgn.bz2 ../datasets/lichess_board_csvs' 31 | screen -S 2019-06 -dm bash -c 'source ~/.bashrc; python3 pgnCPsToCSV_single.py /ada/data/chess/bz2/standard/lichess_db_standard_rated_2019-06.pgn.bz2 ../datasets/lichess_board_csvs' 32 | 33 | screen -S 2019-07-val -dm bash -c 'source ~/.bashrc; python3 pgnCPsToCSV_single.py ../data/lichess_db_standard_rated_2019-07.pgn.bz2 ../datasets/lichess_board_csvs_validation' 34 | 35 | screen -S 
2019-08-val -dm bash -c 'source ~/.bashrc; python3 pgnCPsToCSV_single.py ../data/lichess_db_standard_rated_2019-08.pgn.bz2 ../datasets/lichess_board_csvs_validation' 36 | 37 | 38 | for i in {01..12} 39 | do 40 | echo "2017-${i}" 41 | screen -S "2017-${i}" -X quit 42 | done 43 | 44 | for i in {01..12} 45 | do 46 | echo "2018-${i}" 47 | screen -S "2018-${i}" -X quit 48 | done 49 | 50 | for i in {01..06} 51 | do 52 | echo "2019-${i}" 53 | screen -S "2019-${i}" -X quit 54 | done 55 | -------------------------------------------------------------------------------- /images/CP_v_winrate_ELO.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-chess/2e307a140de747c795a293d1d533ee315943e6bf/images/CP_v_winrate_ELO.png -------------------------------------------------------------------------------- /images/all_lineplot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-chess/2e307a140de747c795a293d1d533ee315943e6bf/images/all_lineplot.png -------------------------------------------------------------------------------- /images/delta_human_wr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-chess/2e307a140de747c795a293d1d533ee315943e6bf/images/delta_human_wr.png -------------------------------------------------------------------------------- /images/delta_top2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-chess/2e307a140de747c795a293d1d533ee315943e6bf/images/delta_top2.png -------------------------------------------------------------------------------- /images/delta_wr2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CSSLab/maia-chess/2e307a140de747c795a293d1d533ee315943e6bf/images/delta_wr2.png -------------------------------------------------------------------------------- /images/leela_lineplot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-chess/2e307a140de747c795a293d1d533ee315943e6bf/images/leela_lineplot.png -------------------------------------------------------------------------------- /images/maia_lineplot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-chess/2e307a140de747c795a293d1d533ee315943e6bf/images/maia_lineplot.png -------------------------------------------------------------------------------- /images/models_agreement.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-chess/2e307a140de747c795a293d1d533ee315943e6bf/images/models_agreement.png -------------------------------------------------------------------------------- /images/other_effects_lineplot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-chess/2e307a140de747c795a293d1d533ee315943e6bf/images/other_effects_lineplot.png -------------------------------------------------------------------------------- /images/sf_lineplot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-chess/2e307a140de747c795a293d1d533ee315943e6bf/images/sf_lineplot.png -------------------------------------------------------------------------------- /maia_env.yml: -------------------------------------------------------------------------------- 1 | name: maia 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _tflow_select=2.1.0=gpu 8 | - 
absl-py=0.9.0=py37_0 9 | - asn1crypto=1.3.0=py37_0 10 | - astor=0.8.0=py37_0 11 | - attrs=19.3.0=py_0 12 | - backcall=0.1.0=py37_0 13 | - blas=1.0=mkl 14 | - bleach=3.1.4=py_0 15 | - blinker=1.4=py37_0 16 | - c-ares=1.15.0=h7b6447c_1001 17 | - ca-certificates=2020.1.1=0 18 | - cachetools=3.1.1=py_0 19 | - cairo=1.14.12=h8948797_3 20 | - certifi=2020.4.5.1=py37_0 21 | - cffi=1.13.2=py37h2e261b9_0 22 | - chardet=3.0.4=py37_1003 23 | - click=7.0=py37_0 24 | - conda=4.8.3=py37_0 25 | - conda-package-handling=1.6.0=py37h7b6447c_0 26 | - cryptography=2.8=py37h1ba5d50_0 27 | - cudatoolkit=10.1.243=h6bb024c_0 28 | - cudnn=7.6.5=cuda10.1_0 29 | - cupti=10.1.168=0 30 | - dbus=1.13.12=h746ee38_0 31 | - decorator=4.4.1=py_0 32 | - defusedxml=0.6.0=py_0 33 | - entrypoints=0.3=py37_0 34 | - expat=2.2.6=he6710b0_0 35 | - fontconfig=2.13.0=h9420a91_0 36 | - freetype=2.9.1=h8a8886c_1 37 | - fribidi=1.0.5=h7b6447c_0 38 | - gast=0.2.2=py37_0 39 | - glib=2.63.1=h5a9c865_0 40 | - gmp=6.1.2=h6c8ec71_1 41 | - google-auth=1.11.2=py_0 42 | - google-auth-oauthlib=0.4.1=py_2 43 | - google-pasta=0.1.8=py_0 44 | - graphite2=1.3.13=h23475e2_0 45 | - graphviz=2.40.1=h21bd128_2 46 | - grpcio=1.27.2=py37hf8bcb03_0 47 | - gst-plugins-base=1.14.0=hbbd80ab_1 48 | - gstreamer=1.14.0=hb453b48_1 49 | - h5py=2.10.0=py37h7918eee_0 50 | - harfbuzz=1.8.8=hffaf4a1_0 51 | - hdf5=1.10.4=hb1b8bf9_0 52 | - icu=58.2=h9c2bf20_1 53 | - idna=2.8=py37_0 54 | - importlib_metadata=1.4.0=py37_0 55 | - intel-openmp=2020.0=166 56 | - ipykernel=5.1.4=py37h39e3cac_0 57 | - ipython=7.11.1=py37h39e3cac_0 58 | - ipython_genutils=0.2.0=py37_0 59 | - ipywidgets=7.5.1=py_0 60 | - jedi=0.16.0=py37_0 61 | - jinja2=2.11.1=py_0 62 | - joblib=0.14.1=py_0 63 | - jpeg=9b=h024ee3a_2 64 | - jsonschema=3.2.0=py37_0 65 | - jupyter=1.0.0=py37_7 66 | - jupyter_client=5.3.4=py37_0 67 | - jupyter_console=6.1.0=py_0 68 | - jupyter_core=4.6.1=py37_0 69 | - keras-applications=1.0.8=py_0 70 | - keras-preprocessing=1.1.0=py_1 71 | - 
libedit=3.1.20181209=hc058e9b_0 72 | - libffi=3.2.1=hd88cf55_4 73 | - libgcc-ng=9.1.0=hdf63c60_0 74 | - libgfortran-ng=7.3.0=hdf63c60_0 75 | - libpng=1.6.37=hbc83047_0 76 | - libprotobuf=3.11.4=hd408876_0 77 | - libsodium=1.0.16=h1bed415_0 78 | - libstdcxx-ng=9.1.0=hdf63c60_0 79 | - libtiff=4.1.0=h2733197_0 80 | - libuuid=1.0.3=h1bed415_2 81 | - libxcb=1.13=h1bed415_1 82 | - libxml2=2.9.9=hea5a465_1 83 | - markdown=3.1.1=py37_0 84 | - markupsafe=1.1.1=py37h7b6447c_0 85 | - meson=0.52.0=py_0 86 | - mistune=0.8.4=py37h7b6447c_0 87 | - mkl=2020.0=166 88 | - mkl-service=2.3.0=py37he904b0f_0 89 | - mkl_fft=1.0.15=py37ha843d7b_0 90 | - mkl_random=1.1.0=py37hd6b4f25_0 91 | - more-itertools=8.0.2=py_0 92 | - nb_conda_kernels=2.2.2=py37_0 93 | - nbconvert=5.6.1=py37_0 94 | - nbformat=5.0.4=py_0 95 | - ncurses=6.1=he6710b0_1 96 | - ninja=1.9.0=py37hfd86e86_0 97 | - notebook=6.0.3=py37_0 98 | - numpy=1.18.1=py37h4f9e942_0 99 | - numpy-base=1.18.1=py37hde5b4d6_1 100 | - oauthlib=3.1.0=py_0 101 | - olefile=0.46=py37_0 102 | - openssl=1.1.1f=h7b6447c_0 103 | - opt_einsum=3.1.0=py_0 104 | - pandoc=2.2.3.2=0 105 | - pandocfilters=1.4.2=py37_1 106 | - pango=1.42.4=h049681c_0 107 | - parso=0.6.0=py_0 108 | - pcre=8.43=he6710b0_0 109 | - pexpect=4.8.0=py37_0 110 | - pickleshare=0.7.5=py37_0 111 | - pillow=7.0.0=py37hb39fc2d_0 112 | - pip=20.0.2=py37_1 113 | - pixman=0.38.0=h7b6447c_0 114 | - prometheus_client=0.7.1=py_0 115 | - prompt_toolkit=3.0.3=py_0 116 | - protobuf=3.11.4=py37he6710b0_0 117 | - ptyprocess=0.6.0=py37_0 118 | - pyasn1=0.4.8=py_0 119 | - pyasn1-modules=0.2.7=py_0 120 | - pycosat=0.6.3=py37h7b6447c_0 121 | - pycparser=2.19=py37_0 122 | - pydot=1.4.1=py37_0 123 | - pygments=2.5.2=py_0 124 | - pyjwt=1.7.1=py37_0 125 | - pyopenssl=19.1.0=py37_0 126 | - pyparsing=2.4.6=py_0 127 | - pyqt=5.9.2=py37h05f1152_2 128 | - pyrsistent=0.15.7=py37h7b6447c_0 129 | - pysocks=1.7.1=py37_0 130 | - python=3.7.4=h265db76_1 131 | - python-dateutil=2.8.1=py_0 132 | - 
pytorch=1.4.0=py3.7_cuda10.1.243_cudnn7.6.3_0 133 | - pyyaml=5.3=py37h7b6447c_0 134 | - pyzmq=18.1.1=py37he6710b0_0 135 | - qt=5.9.7=h5867ecd_1 136 | - qtconsole=4.6.0=py_1 137 | - readline=7.0=h7b6447c_5 138 | - requests=2.22.0=py37_1 139 | - requests-oauthlib=1.3.0=py_0 140 | - rsa=4.0=py_0 141 | - ruamel_yaml=0.15.87=py37h7b6447c_0 142 | - scikit-learn=0.22.1=py37hd81dba3_0 143 | - scipy=1.4.1=py37h0b6359f_0 144 | - send2trash=1.5.0=py37_0 145 | - setuptools=45.1.0=py37_0 146 | - sip=4.19.8=py37hf484d3e_0 147 | - six=1.14.0=py37_0 148 | - sqlite=3.30.1=h7b6447c_0 149 | - tensorboard=2.1.0=py3_0 150 | - tensorflow=2.1.0=gpu_py37h7a4bb67_0 151 | - tensorflow-base=2.1.0=gpu_py37h6c5654b_0 152 | - tensorflow-estimator=2.1.0=pyhd54b08b_0 153 | - tensorflow-gpu=2.1.0=h0d30ee6_0 154 | - termcolor=1.1.0=py37_1 155 | - terminado=0.8.3=py37_0 156 | - testpath=0.4.4=py_0 157 | - tk=8.6.8=hbc83047_0 158 | - torchvision=0.5.0=py37_cu101 159 | - tornado=6.0.3=py37h7b6447c_0 160 | - tqdm=4.42.0=py_0 161 | - traitlets=4.3.3=py37_0 162 | - urllib3=1.25.8=py37_0 163 | - wcwidth=0.1.9=py_0 164 | - webencodings=0.5.1=py37_1 165 | - werkzeug=1.0.0=py_0 166 | - wheel=0.34.1=py37_0 167 | - widgetsnbextension=3.5.1=py37_0 168 | - wrapt=1.11.2=py37h7b6447c_0 169 | - xz=5.2.4=h14c3975_4 170 | - yaml=0.1.7=had09818_2 171 | - zeromq=4.3.1=he6710b0_3 172 | - zipp=2.2.0=py_0 173 | - zlib=1.2.11=h7b6447c_3 174 | - zstd=1.3.7=h0b5b093_0 175 | - pip: 176 | - cycler==0.10.0 177 | - humanize==2.4.0 178 | - kiwisolver==1.2.0 179 | - matplotlib==3.2.1 180 | - natsort==7.0.1 181 | - pandas==1.0.3 182 | - python-chess==0.30.1 183 | - pytz==2019.3 184 | - seaborn==0.10.0 185 | - tensorboardx==2.0 186 | 187 | -------------------------------------------------------------------------------- /maia_weights/maia-1100.pb.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CSSLab/maia-chess/2e307a140de747c795a293d1d533ee315943e6bf/maia_weights/maia-1100.pb.gz -------------------------------------------------------------------------------- /maia_weights/maia-1200.pb.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-chess/2e307a140de747c795a293d1d533ee315943e6bf/maia_weights/maia-1200.pb.gz -------------------------------------------------------------------------------- /maia_weights/maia-1300.pb.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-chess/2e307a140de747c795a293d1d533ee315943e6bf/maia_weights/maia-1300.pb.gz -------------------------------------------------------------------------------- /maia_weights/maia-1400.pb.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-chess/2e307a140de747c795a293d1d533ee315943e6bf/maia_weights/maia-1400.pb.gz -------------------------------------------------------------------------------- /maia_weights/maia-1500.pb.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-chess/2e307a140de747c795a293d1d533ee315943e6bf/maia_weights/maia-1500.pb.gz -------------------------------------------------------------------------------- /maia_weights/maia-1600.pb.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-chess/2e307a140de747c795a293d1d533ee315943e6bf/maia_weights/maia-1600.pb.gz -------------------------------------------------------------------------------- /maia_weights/maia-1700.pb.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CSSLab/maia-chess/2e307a140de747c795a293d1d533ee315943e6bf/maia_weights/maia-1700.pb.gz -------------------------------------------------------------------------------- /maia_weights/maia-1800.pb.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-chess/2e307a140de747c795a293d1d533ee315943e6bf/maia_weights/maia-1800.pb.gz -------------------------------------------------------------------------------- /maia_weights/maia-1900.pb.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-chess/2e307a140de747c795a293d1d533ee315943e6bf/maia_weights/maia-1900.pb.gz -------------------------------------------------------------------------------- /move_prediction/lczero-common/gen_proto_files.sh: -------------------------------------------------------------------------------- 1 | protoc --proto_path=. --python_out=../maia_chess_backend/maia/ proto/net.proto 2 | protoc --proto_path=. --python_out=../maia_chess_backend/maia/ proto/chunk.proto 3 | -------------------------------------------------------------------------------- /move_prediction/lczero-common/proto/chunk.proto: -------------------------------------------------------------------------------- 1 | /* 2 | This file is part of Leela Chess Zero. 3 | Copyright (C) 2018 The LCZero Authors 4 | 5 | Leela Chess is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | Leela Chess is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 
14 | 15 | You should have received a copy of the GNU General Public License 16 | along with Leela Chess. If not, see . 17 | 18 | Additional permission under GNU GPL version 3 section 7 19 | 20 | If you modify this Program, or any covered work, by linking or 21 | combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA 22 | Toolkit and the NVIDIA CUDA Deep Neural Network library (or a 23 | modified version of those libraries), containing parts covered by the 24 | terms of the respective license agreement, the licensors of this 25 | Program grant you additional permission to convey the resulting work. 26 | */ 27 | syntax = "proto2"; 28 | 29 | import "proto/net.proto"; 30 | 31 | package pblczero; 32 | 33 | message State { 34 | repeated fixed64 plane = 1 [packed=true]; 35 | optional uint32 us_ooo = 2; 36 | optional uint32 us_oo = 3; 37 | optional uint32 them_ooo = 4; 38 | optional uint32 them_oo = 5; 39 | optional uint32 side_to_move = 6; 40 | optional uint32 rule_50 = 7; 41 | } 42 | 43 | message Policy { 44 | repeated uint32 index = 1 [packed=true]; 45 | repeated float prior = 2 [packed=true]; 46 | } 47 | 48 | message Game { 49 | enum Result { 50 | WHITE = 0; 51 | BLACK = 1; 52 | DRAW = 2; 53 | } 54 | 55 | repeated State state = 1; 56 | repeated Policy policy = 2; 57 | repeated float value = 3 [packed=true]; 58 | repeated uint32 move = 4 [packed=true]; 59 | optional Result result = 5; 60 | } 61 | 62 | message Chunk { 63 | optional fixed32 magic = 1; 64 | optional string license = 2; 65 | optional EngineVersion version = 3; 66 | repeated Game game = 4; 67 | } 68 | -------------------------------------------------------------------------------- /move_prediction/lczero-common/proto/net.proto: -------------------------------------------------------------------------------- 1 | /* 2 | This file is part of Leela Chess Zero. 
3 | Copyright (C) 2018 The LCZero Authors 4 | 5 | Leela Chess is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | Leela Chess is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with Leela Chess. If not, see . 17 | 18 | Additional permission under GNU GPL version 3 section 7 19 | 20 | If you modify this Program, or any covered work, by linking or 21 | combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA 22 | Toolkit and the NVIDIA CUDA Deep Neural Network library (or a 23 | modified version of those libraries), containing parts covered by the 24 | terms of the respective license agreement, the licensors of this 25 | Program grant you additional permission to convey the resulting work. 26 | */ 27 | syntax = "proto2"; 28 | 29 | package pblczero; 30 | 31 | message EngineVersion { 32 | optional uint32 major = 1; 33 | optional uint32 minor = 2; 34 | optional uint32 patch = 3; 35 | } 36 | 37 | message Weights { 38 | message Layer { 39 | optional float min_val = 1; 40 | optional float max_val = 2; 41 | optional bytes params = 3; 42 | } 43 | 44 | message ConvBlock { 45 | optional Layer weights = 1; 46 | optional Layer biases = 2; 47 | optional Layer bn_means = 3; 48 | optional Layer bn_stddivs = 4; 49 | optional Layer bn_gammas = 5; 50 | optional Layer bn_betas = 6; 51 | } 52 | 53 | message SEunit { 54 | // Squeeze-excitation unit (https://arxiv.org/abs/1709.01507) 55 | // weights and biases of the two fully connected layers. 
56 | optional Layer w1 = 1; 57 | optional Layer b1 = 2; 58 | optional Layer w2 = 3; 59 | optional Layer b2 = 4; 60 | } 61 | 62 | message Residual { 63 | optional ConvBlock conv1 = 1; 64 | optional ConvBlock conv2 = 2; 65 | optional SEunit se = 3; 66 | } 67 | 68 | // Input convnet. 69 | optional ConvBlock input = 1; 70 | 71 | // Residual tower. 72 | repeated Residual residual = 2; 73 | 74 | // Policy head 75 | // Extra convolution for AZ-style policy head 76 | optional ConvBlock policy1 = 11; 77 | optional ConvBlock policy = 3; 78 | optional Layer ip_pol_w = 4; 79 | optional Layer ip_pol_b = 5; 80 | 81 | // Value head 82 | optional ConvBlock value = 6; 83 | optional Layer ip1_val_w = 7; 84 | optional Layer ip1_val_b = 8; 85 | optional Layer ip2_val_w = 9; 86 | optional Layer ip2_val_b = 10; 87 | 88 | // Moves left head 89 | optional ConvBlock moves_left = 12; 90 | optional Layer ip1_mov_w = 13; 91 | optional Layer ip1_mov_b = 14; 92 | optional Layer ip2_mov_w = 15; 93 | optional Layer ip2_mov_b = 16; 94 | } 95 | 96 | message TrainingParams { 97 | optional uint32 training_steps = 1; 98 | optional float learning_rate = 2; 99 | optional float mse_loss = 3; 100 | optional float policy_loss = 4; 101 | optional float accuracy = 5; 102 | optional string lc0_params = 6; 103 | } 104 | 105 | message NetworkFormat { 106 | // Format to encode the input planes with. Used by position encoder. 107 | enum InputFormat { 108 | INPUT_UNKNOWN = 0; 109 | INPUT_CLASSICAL_112_PLANE = 1; 110 | INPUT_112_WITH_CASTLING_PLANE = 2; 111 | } 112 | optional InputFormat input = 1; 113 | 114 | // Output format of the NN. Used by search code to interpret results. 115 | enum OutputFormat { 116 | OUTPUT_UNKNOWN = 0; 117 | OUTPUT_CLASSICAL = 1; 118 | OUTPUT_WDL = 2; 119 | } 120 | optional OutputFormat output = 2; 121 | 122 | // Network architecture. Used by backends to build the network. 
123 | enum NetworkStructure { 124 | // Networks without PolicyFormat or ValueFormat specified 125 | NETWORK_UNKNOWN = 0; 126 | NETWORK_CLASSICAL = 1; 127 | NETWORK_SE = 2; 128 | // Networks with PolicyFormat and ValueFormat specified 129 | NETWORK_CLASSICAL_WITH_HEADFORMAT = 3; 130 | NETWORK_SE_WITH_HEADFORMAT = 4; 131 | } 132 | optional NetworkStructure network = 3; 133 | 134 | // Policy head architecture 135 | enum PolicyFormat { 136 | POLICY_UNKNOWN = 0; 137 | POLICY_CLASSICAL = 1; 138 | POLICY_CONVOLUTION = 2; 139 | } 140 | optional PolicyFormat policy = 4; 141 | 142 | // Value head architecture 143 | enum ValueFormat { 144 | VALUE_UNKNOWN = 0; 145 | VALUE_CLASSICAL = 1; 146 | VALUE_WDL = 2; 147 | } 148 | optional ValueFormat value = 5; 149 | 150 | // Moves left head architecture 151 | enum MovesLeftFormat { 152 | MOVES_LEFT_NONE = 0; 153 | MOVES_LEFT_V1 = 1; 154 | } 155 | optional MovesLeftFormat moves_left = 6; 156 | } 157 | 158 | message Format { 159 | enum Encoding { 160 | UNKNOWN = 0; 161 | LINEAR16 = 1; 162 | } 163 | 164 | optional Encoding weights_encoding = 1; 165 | // If network_format is missing, it's assumed to have 166 | // INPUT_CLASSICAL_112_PLANE / OUTPUT_CLASSICAL / NETWORK_CLASSICAL format. 
167 | optional NetworkFormat network_format = 2; 168 | } 169 | 170 | message Net { 171 | optional fixed32 magic = 1; 172 | optional string license = 2; 173 | optional EngineVersion min_version = 3; 174 | optional Format format = 4; 175 | optional TrainingParams training_params = 5; 176 | optional Weights weights = 10; 177 | } 178 | -------------------------------------------------------------------------------- /move_prediction/maia_chess_backend/__init__.py: -------------------------------------------------------------------------------- 1 | #from .uci import * 2 | from .games import * 3 | from .utils import * 4 | from .tourney import * 5 | from .loaders import * 6 | from .models_loader import * 7 | from .logging import * 8 | from .fen_to_vec import * 9 | from .bat_files import * 10 | from .plt_utils import * 11 | from .model_loader import load_model_config 12 | #from .pickle4reducer import * 13 | #from .boardTrees import * 14 | #from .stockfishAnalysis import * 15 | 16 | #Tensorflow stuff 17 | try: 18 | from .tf_process import * 19 | from .tf_net import * 20 | from .tf_blocks import * 21 | except ImportError: 22 | pass 23 | 24 | fics_header = [ 25 | 'game_id', 26 | 'rated', 27 | 'name', 28 | 'opp_name', 29 | 'elo', 30 | 'oppelo', 31 | 'num_legal_moves', 32 | 'num_blunders', 33 | 'blunder', 34 | 'eval_before_move', 35 | 'eval_after_move', 36 | 'to_move', 37 | 'is_comp', 38 | 'opp_is_comp', 39 | 'time_control', 40 | 'ECO', 41 | 'result', 42 | 'time_left', 43 | 'opp_time_left', 44 | 'time_used', 45 | 'move_idx', 46 | 'move', 47 | 'material', 48 | 'position', 49 | 'stdpos', 50 | 'unkown' 51 | ] 52 | 53 | __version__ = '1.0.0' 54 | -------------------------------------------------------------------------------- /move_prediction/maia_chess_backend/games.py: -------------------------------------------------------------------------------- 1 | import bz2 2 | import collections.abc 3 | import re 4 | 5 | 6 | import chess.pgn 7 | 8 | moveRegex = re.compile(r'\d+[.][ \.](\S+) 
(?:{[^}]*} )?(\S+)') 9 | 10 | 11 | class GamesFile(collections.abc.Iterable): 12 | def __init__(self, path, cacheGames = False): 13 | self.path = path 14 | self.f = bz2.open(self.path, 'rt') 15 | 16 | self.cache = cacheGames 17 | self.games = [] 18 | self.num_read = 0 19 | 20 | def __iter__(self): 21 | for g in self.games: 22 | yield g 23 | while True: 24 | yield self.loadNextGame() 25 | 26 | def loadNextGame(self): 27 | g = chess.pgn.read_game(self.f) 28 | if g is None: 29 | raise StopIteration 30 | if self.cache: 31 | self.games.append(g) 32 | self.num_read += 1 33 | return g 34 | 35 | def __getitem__(self, val): 36 | if isinstance(val, slice): 37 | return [self[i] for i in range(*val.indices(10**20))] 38 | elif isinstance(val, int): 39 | if len(self.games) < val: 40 | return self.games[val] 41 | elif val < 0: 42 | raise IndexError("negative indexing is not supported") from None 43 | else: 44 | g = self.loadNextGame() 45 | for i in range(val - len(self.games)): 46 | g = self.loadNextGame() 47 | return g 48 | else: 49 | raise IndexError("{} is not a valid input".format(val)) from None 50 | 51 | def __del__(self): 52 | try: 53 | self.f.close() 54 | except AttributeError: 55 | pass 56 | 57 | class LightGamesFile(object): 58 | def __init__(self, path, parseMoves = True, just_games = False): 59 | if path.endswith('bz2'): 60 | self.f = bz2.open(path, 'rt') 61 | else: 62 | self.f = open(path, 'r') 63 | self.parseMoves = parseMoves 64 | self.just_games = just_games 65 | self._peek = None 66 | 67 | def __iter__(self): 68 | try: 69 | while True: 70 | yield self.readNextGame() 71 | except StopIteration: 72 | return 73 | 74 | def peekNextGame(self): 75 | if self._peek is None: 76 | self._peek = self.readNextGame() 77 | return self._peek 78 | 79 | def readNextGame(self): 80 | #self.f.readline() 81 | if self._peek is not None: 82 | g = self._peek 83 | self._peek = None 84 | return g 85 | ret = {} 86 | lines = '' 87 | if self.just_games: 88 | first_hit = False 89 | for l in 
self.f: 90 | lines += l 91 | if len(l) < 2: 92 | if first_hit: 93 | break 94 | else: 95 | first_hit = True 96 | else: 97 | for l in self.f: 98 | lines += l 99 | if len(l) < 2: 100 | if len(ret) >= 2: 101 | break 102 | else: 103 | raise RuntimeError(l) 104 | else: 105 | k, v, _ = l.split('"') 106 | ret[k[1:-1]] = v 107 | nl = self.f.readline() 108 | lines += nl 109 | if self.parseMoves: 110 | ret['moves'] = re.findall(moveRegex, nl) 111 | lines += self.f.readline() 112 | if len(lines) < 1: 113 | raise StopIteration 114 | return ret, lines 115 | 116 | def readBatch(self, n): 117 | ret = [] 118 | for i in range(n): 119 | try: 120 | ret.append(self.readNextGame()) 121 | except StopIteration: 122 | break 123 | return ret 124 | 125 | def getWinRates(self, extraKey = None): 126 | # Assumes same players in all games 127 | dat, _ = self.peekNextGame() 128 | p1, p2 = sorted((dat['White'], dat['Black'])) 129 | d = { 130 | 'name' : f"{p1} v {p2}", 131 | 'p1' : p1, 132 | 'p2' : p2, 133 | 'games' : 0, 134 | 'wins' : 0, 135 | 'ties' : 0, 136 | 'losses' : 0, 137 | } 138 | if extraKey is not None: 139 | d[extraKey] = {} 140 | for dat, _ in self: 141 | d['games'] += 1 142 | if extraKey is not None and dat[extraKey] not in d[extraKey]: 143 | d[extraKey][dat[extraKey]] = [] 144 | if p1 == dat['White']: 145 | if dat['Result'] == '1-0': 146 | d['wins'] += 1 147 | if extraKey is not None: 148 | d[extraKey][dat[extraKey]].append(1) 149 | elif dat['Result'] == '0-1': 150 | d['losses'] += 1 151 | if extraKey is not None: 152 | d[extraKey][dat[extraKey]].append(0) 153 | else: 154 | d['ties'] += 1 155 | if extraKey is not None: 156 | d[extraKey][dat[extraKey]].append(.5) 157 | else: 158 | if dat['Result'] == '0-1': 159 | d['wins'] += 1 160 | if extraKey is not None: 161 | d[extraKey][dat[extraKey]].append(1) 162 | elif dat['Result'] == '1-0': 163 | d['losses'] += 1 164 | if extraKey is not None: 165 | d[extraKey][dat[extraKey]].append(0) 166 | else: 167 | d['ties'] += 1 168 | if extraKey is 
not None: 169 | d[extraKey][dat[extraKey]].append(.5) 170 | return d 171 | 172 | def __del__(self): 173 | try: 174 | self.f.close() 175 | except AttributeError: 176 | pass 177 | 178 | def getBoardMoveMap(game, maxMoves = None): 179 | d = {} 180 | board = game.board() 181 | for i, move in enumerate(game.main_line()): 182 | d[board.fen()] = move.uci() 183 | board.push(move) 184 | if maxMoves is not None and i > maxMoves: 185 | break 186 | return d 187 | -------------------------------------------------------------------------------- /move_prediction/maia_chess_backend/loaders.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | 3 | LEELA_WEIGHTS_VERSION = 2 4 | 5 | def read_weights_file(filename): 6 | if '.gz' in filename: 7 | opener = gzip.open 8 | else: 9 | opener = open 10 | with opener(filename, 'rb') as f: 11 | version = f.readline().decode('ascii') 12 | if version != '{}\n'.format(LEELA_WEIGHTS_VERSION): 13 | raise ValueError("Invalid version {}".format(version.strip())) 14 | weights = [] 15 | e = 0 16 | for line in f: 17 | line = line.decode('ascii').strip() 18 | if not line: 19 | continue 20 | e += 1 21 | weight = list(map(float, line.split(' '))) 22 | weights.append(weight) 23 | if e == 2: 24 | filters = len(line.split(' ')) 25 | #print("Channels", filters) 26 | blocks = e - (4 + 14) 27 | if blocks % 8 != 0: 28 | raise ValueError("Inconsistent number of weights in the file - e = {}".format(e)) 29 | blocks //= 8 30 | #print("Blocks", blocks) 31 | return (filters, blocks, weights) 32 | -------------------------------------------------------------------------------- /move_prediction/maia_chess_backend/logging.py: -------------------------------------------------------------------------------- 1 | from .utils import printWithDate 2 | 3 | import functools 4 | import sys 5 | import time 6 | import datetime 7 | import os 8 | import os.path 9 | import traceback 10 | 11 | import pytz 12 | tz = 
tz = pytz.timezone('Canada/Eastern')

min_run_time = 60 * 10  # only runs longer than 10 minutes get a summary log
infos_dir_name = 'runinfos'

class Tee(object):
    """Duplicate stdout or stderr into an append-mode log file.

    Based on https://stackoverflow.com/a/616686
    """
    def __init__(self, fname, is_err = False):
        self.file = open(fname, 'a')
        self.is_err = is_err
        if is_err:
            self.stdstream = sys.stderr
            sys.stderr = self
        else:
            self.stdstream = sys.stdout
            sys.stdout = self
    def __del__(self):
        # Restore the original stream before closing the log file.
        if self.is_err:
            sys.stderr = self.stdstream
        else:
            sys.stdout = self.stdstream
        self.file.close()
    def write(self, data):
        self.file.write(data)
        self.stdstream.write(data)
    def flush(self):
        self.file.flush()

def makeLog(logs_prefix, start_time, tstart, is_error, *notes):
    """Write a run-summary log (run.log or error.log).

    Args:
        logs_prefix: path prefix the log file name is appended to.
        start_time: timezone-aware datetime when the run began.
        tstart: time.time() timestamp taken when the main function started.
        is_error: True writes error.log, False writes run.log.
        *notes: extra objects appended to the log, one per line.
    """
    fname = 'error.log' if is_error else 'run.log'
    with open(logs_prefix + fname, 'w') as f:
        f.write(f"start: {start_time.strftime('%Y-%m-%d-%H:%M:%S')}\n")
        f.write(f"stop: {datetime.datetime.now(tz).strftime('%Y-%m-%d-%H:%M:%S')}\n")
        # Bug fix: this used to write int(tstart > min_run_time) — a 0/1
        # flag that is always 1 (tstart is an epoch timestamp, far larger
        # than 600) — instead of the intended elapsed runtime in seconds.
        f.write(f"in: {int(time.time() - tstart)}s\n")
        f.write(f"dir: {os.path.abspath(os.getcwd())}\n")
        f.write(f"{' '.join(sys.argv)}\n")
        f.write('\n'.join([str(n) for n in notes]))

def makelogNamesPrefix(script_name, start_time):
    """Create runinfos/<script_name>/ and return a timestamped file prefix."""
    os.makedirs(infos_dir_name, exist_ok = True)
    os.makedirs(os.path.join(infos_dir_name, script_name), exist_ok = True)
    return os.path.join(infos_dir_name, script_name, f"{start_time.strftime('%Y-%m-%d-%H%M%S-%f')}_")

def logged_main(mainFunc):
    """Decorator: tee stdout/stderr to log files and summarize long runs."""
    @functools.wraps(mainFunc)
    def wrapped_main(*args, **kwds):
        start_time = datetime.datetime.now(tz)
        script_name = os.path.basename(sys.argv[0])[:-3]  # strip ".py"
        logs_prefix = makelogNamesPrefix(script_name, start_time)
        tee_out = Tee(logs_prefix + 'stdout.log', is_err = False)
        tee_err = Tee(logs_prefix + 'stderr.log', is_err = True)
        printWithDate(' '.join(sys.argv), colour = 'blue')
        printWithDate(f"Starting {script_name}", colour = 'blue')
        # Bug fix: tstart was assigned inside the try block, so an early
        # failure could NameError in the except handler; assign it first.
        tstart = time.time()
        try:
            val = mainFunc(*args, **kwds)
        except (Exception, KeyboardInterrupt) as e:
            printWithDate(f"Error encountered", colour = 'blue')
            if (time.time() - tstart) > min_run_time:
                makeLog(logs_prefix, start_time, tstart, True, 'Error', e, traceback.format_exc())
            raise
        else:
            printWithDate(f"Run completed", colour = 'blue')
            if (time.time() - tstart) > min_run_time:
                makeLog(logs_prefix, start_time, tstart, False, 'Successful')
            tee_out.flush()
            tee_err.flush()
        return val
    return wrapped_main
def knight_move(start, direction, steps):
    """Return the destination square of a knight move from `start`, or
    None if it leaves the board. `steps` scales the offset (always 1 for
    real knight moves)."""
    i = position_to_index(start)
    dir_vectors = {'N': (1, 2), 'NE': (2, 1), 'E': (2, -1), 'SE': (1, -2),
                   'S': (-1, -2), 'SW': (-2, -1), 'W': (-2, 1), 'NW': (-1, 2)}
    v = dir_vectors[direction]
    i = i[0] + v[0] * steps, i[1] + v[1] * steps
    if not valid_index(i):
        return None
    return index_to_position(i)

def make_map(kind='matrix'):
    """Build the mapping from AZ-style move planes (80 x 8 x 8) to lc0's
    flat policy_index.

    Returns a one-hot float32 matrix for kind='matrix', or a flat list of
    policy_index positions (-1 for illegal plane entries) for kind='index'.
    """
    moves = []
    # 56 planes of queen moves (8 directions x 7 distances)
    for direction in ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW']:
        for steps in range(1, 8):
            for r0 in rows:
                for c0 in columns:
                    start = c0 + r0
                    end = queen_move(start, direction, steps)
                    moves.append('illegal' if end is None else start + end)

    # 8 planes of knight moves
    for direction in ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW']:
        for r0 in rows:
            for c0 in columns:
                start = c0 + r0
                end = knight_move(start, direction, 1)
                moves.append('illegal' if end is None else start + end)

    # 9 promotion planes (NW/N/NE x r,b,q); N-promotion is a normal move.
    for direction in ['NW', 'N', 'NE']:
        for promotion in promotions:
            for r0 in rows:
                for c0 in columns:
                    # Promotion only from the second-to-last rank.
                    if r0 != '7':
                        moves.append('illegal')
                        continue
                    start = c0 + r0
                    end = queen_move(start, direction, 1)
                    moves.append('illegal' if end is None else start + end + promotion)

    # Perf fix: use a set / dict for membership and position lookups instead
    # of repeated O(n) list scans (`m in moves`, `policy_index.index(m)`).
    generated = set(moves)
    for m in policy_index:
        if m not in generated:
            raise ValueError('Missing move: {}'.format(m))
    index_of = {m: i for i, m in enumerate(policy_index)}

    az_to_lc0 = np.zeros((80*8*8, len(policy_index)), dtype=np.float32)
    indices = []
    legal_moves = 0
    for e, m in enumerate(moves):
        if m == 'illegal':
            indices.append(-1)
            continue
        legal_moves += 1
        i = index_of.get(m)
        if i is None:
            raise ValueError('Missing move: {}'.format(m))
        indices.append(i)
        az_to_lc0[e][i] = 1

    assert legal_moves == len(policy_index)
    assert np.sum(az_to_lc0) == legal_moves
    # Bug fix: removed a dead doubly-nested `pass` loop over
    # 80*8*8 x len(policy_index) iterations that did nothing.
    if kind == 'matrix':
        return az_to_lc0
    elif kind == 'index':
        return indices
139 | */ 140 | 141 | #pragma once 142 | 143 | namespace lczero { 144 | """ 145 | line_length = 12 146 | with open(sys.argv[1], 'w') as f: 147 | f.write(header+'\n') 148 | f.write('const short kConvPolicyMap[] = {\\\n') 149 | for e, i in enumerate(az_to_lc0): 150 | if e % line_length == 0 and e > 0: 151 | f.write('\n') 152 | f.write(str(i).rjust(5)) 153 | if e != len(az_to_lc0)-1: 154 | f.write(',') 155 | f.write('};\n\n') 156 | f.write('} // namespace lczero') 157 | -------------------------------------------------------------------------------- /move_prediction/maia_chess_backend/maia/net_to_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import tensorflow as tf 4 | import os 5 | import yaml 6 | from .tfprocess import TFProcess 7 | from .net import Net 8 | 9 | argparser = argparse.ArgumentParser(description='Convert net to model.') 10 | argparser.add_argument('net', type=str, 11 | help='Net file to be converted to a model checkpoint.') 12 | argparser.add_argument('--start', type=int, default=0, 13 | help='Offset to set global_step to.') 14 | argparser.add_argument('--cfg', type=argparse.FileType('r'), 15 | help='yaml configuration with training parameters') 16 | args = argparser.parse_args() 17 | cfg = yaml.safe_load(args.cfg.read()) 18 | print(yaml.dump(cfg, default_flow_style=False)) 19 | START_FROM = args.start 20 | net = Net() 21 | net.parse_proto(args.net) 22 | 23 | filters, blocks = net.filters(), net.blocks() 24 | if cfg['model']['filters'] != filters: 25 | raise ValueError("Number of filters in YAML doesn't match the network") 26 | if cfg['model']['residual_blocks'] != blocks: 27 | raise ValueError("Number of blocks in YAML doesn't match the network") 28 | weights = net.get_weights() 29 | 30 | tfp = TFProcess(cfg) 31 | tfp.init_net_v2() 32 | tfp.replace_weights_v2(weights) 33 | tfp.global_step.assign(START_FROM) 34 | 35 | root_dir = 
os.path.join(cfg['training']['path'], cfg['name']) 36 | if not os.path.exists(root_dir): 37 | os.makedirs(root_dir) 38 | tfp.manager.save() 39 | print("Wrote model to {}".format(tfp.manager.latest_checkpoint)) 40 | -------------------------------------------------------------------------------- /move_prediction/maia_chess_backend/maia/proto/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-chess/2e307a140de747c795a293d1d533ee315943e6bf/move_prediction/maia_chess_backend/maia/proto/__init__.py -------------------------------------------------------------------------------- /move_prediction/maia_chess_backend/maia/shufflebuffer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # This file is part of Leela Chess. 4 | # Copyright (C) 2018 Michael O 5 | # 6 | # Leela Chess is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # Leela Chess is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with Leela Chess. If not, see . 18 | 19 | import random 20 | import unittest 21 | 22 | class ShuffleBuffer: 23 | def __init__(self, elem_size, elem_count): 24 | """ 25 | A shuffle buffer for fixed sized elements. 26 | 27 | Manages 'elem_count' items in a fixed buffer, each item being exactly 28 | 'elem_size' bytes. 29 | """ 30 | assert elem_size > 0, elem_size 31 | assert elem_count > 0, elem_count 32 | # Size of each element. 
# Methods of ShuffleBuffer (class statement begins earlier in the file).

def extract(self):
    """
    Pop one element from the shuffle buffer, or None when it is empty.

    Elements are kept in already-shuffled order, so taking the tail
    element is an unbiased draw.
    """
    if self.used < 1:
        return None
    self.used -= 1
    start = self.used * self.elem_size
    return self.buffer[start : start + self.elem_size]

def insert_or_replace(self, item):
    """
    Insert 'item' into the shuffle buffer, returning a uniformly random
    displaced element.

    Returns None while the buffer is still filling up. Swapping the new
    item with a random resident and appending the displaced one gives a
    full Fisher-Yates shuffle online.
    """
    assert len(item) == self.elem_size, len(item)
    size = self.elem_size
    if self.used > 0:
        # Swap 'item' with a random element already in the buffer.
        j = random.randint(0, self.used - 1)
        lo, hi = j * size, (j + 1) * size
        item, self.buffer[lo:hi] = self.buffer[lo:hi], item
    if self.used < self.elem_count:
        # Not yet full: park the displaced item at the tail.
        tail = self.used * size
        self.buffer[tail : tail + size] = item
        self.used += 1
        return None
    return item
112 | out.append(sb.extract()) 113 | out.append(sb.extract()) 114 | assert sorted(items) == sorted(out), (items, out) 115 | # Check that buffer is empty 116 | r = sb.extract() 117 | assert r is None, r 118 | 119 | 120 | if __name__ == '__main__': 121 | unittest.main() 122 | -------------------------------------------------------------------------------- /move_prediction/maia_chess_backend/maia/update_steps.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import os 4 | import yaml 5 | import sys 6 | import tensorflow as tf 7 | from .tfprocess import TFProcess 8 | 9 | START_FROM = 0 10 | 11 | def main(cmd): 12 | cfg = yaml.safe_load(cmd.cfg.read()) 13 | print(yaml.dump(cfg, default_flow_style=False)) 14 | 15 | root_dir = os.path.join(cfg['training']['path'], cfg['name']) 16 | if not os.path.exists(root_dir): 17 | os.makedirs(root_dir) 18 | 19 | tfprocess = TFProcess(cfg) 20 | tfprocess.init_net_v2() 21 | 22 | tfprocess.restore_v2() 23 | 24 | START_FROM = cmd.start 25 | 26 | tfprocess.global_step.assign(START_FROM) 27 | tfprocess.manager.save() 28 | 29 | if __name__ == "__main__": 30 | argparser = argparse.ArgumentParser(description=\ 31 | 'Convert current checkpoint to new step count.') 32 | argparser.add_argument('--cfg', type=argparse.FileType('r'), 33 | help='yaml configuration with training parameters') 34 | argparser.add_argument('--start', type=int, default=0, 35 | help='Offset to set global_step to.') 36 | 37 | main(argparser.parse_args()) 38 | -------------------------------------------------------------------------------- /move_prediction/maia_chess_backend/model_loader.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import os 3 | import os.path 4 | 5 | from .tourney import RandomEngine, StockfishEngine, LC0Engine 6 | 7 | def load_model_config(config_dir_path, lc0_depth = None, lc0Path = None, noise = False, 
class Trained_Model(object):
    """A directory of model checkpoints together with its config.yaml.

    `weights` maps step number -> checkpoint file path, one entry per
    *.txt or *.pb.gz file found in the directory (file names are assumed
    to look like <name>-<step>.<ext>).
    """
    def __init__(self, path):
        self.path = path
        config_path = os.path.join(path, 'config.yaml')
        try:
            with open(config_path) as f:
                self.config = yaml.safe_load(f.read())
        except FileNotFoundError:
            raise FileNotFoundError(f"No config file found in: {path}")

        self.weights = {}
        for entry in os.scandir(path):
            if entry.name.endswith(('.txt', '.pb.gz')):
                step = int(entry.name.split('-')[-1].split('.')[0])
                self.weights[step] = entry.path

    def getMostTrained(self):
        """Return the checkpoint path with the highest step count."""
        return self.weights[max(self.weights)]
-------------------------------------------------------------------------------- 1 | import chess 2 | import chess.engine 3 | 4 | class ChessEngine(object): 5 | def __init__(self, engine, limits): 6 | self.limits = chess.engine.Limit(**limits) 7 | self.engine = engine 8 | 9 | def getMove(self, board): 10 | try: 11 | results = self.engine.play( 12 | board, 13 | limit=self.limits, 14 | info = chess.engine.INFO_ALL 15 | ) 16 | 17 | if isinstance(board, str): 18 | board 19 | -------------------------------------------------------------------------------- /move_prediction/maia_chess_backend/plt_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import os 3 | import os.path 4 | import seaborn 5 | 6 | def multi_savefig(save_name, dir_name = 'images', save_types = ('pdf', 'png', 'svg')): 7 | os.makedirs(dir_name, exist_ok = True) 8 | for sType in save_types: 9 | dName = os.path.join(dir_name, sType) 10 | os.makedirs(dName, exist_ok = True) 11 | 12 | fname = f'{save_name}.{sType}' 13 | 14 | plt.savefig(os.path.join(dName, fname), format = sType, dpi = 300, transparent = True) 15 | 16 | def plot_pieces(board_a): 17 | fig, axes = plt.subplots(nrows=3, ncols=6, figsize = (16, 10)) 18 | axiter = iter(axes.flatten()) 19 | for i in range(17): 20 | seaborn.heatmap(board_a[i], ax = next(axiter), cbar = False, vmin=0, vmax=1, square = True) 21 | 22 | axes[-1,-1].set_axis_off() 23 | for i, n in enumerate(['Knights', 'Bishops', 'Rooks','Queen', 'King']): 24 | axes[0,i + 1].set_title(n) 25 | axes[1,i + 1].set_title(n) 26 | axes[0,0].set_title('Active Player Pieces\nPawns') 27 | axes[1,0].set_title('Opponent Pieces\nPawns') 28 | axes[2,0].set_title('Other Values\n Is White') 29 | for i in range(4): 30 | axes[2,i + 1].set_title('Castling') 31 | -------------------------------------------------------------------------------- /move_prediction/maia_chess_backend/uci.py: 
import chess.uci

import collections
import concurrent.futures
import threading

import re
import os.path

# Matches the "(P: xx.xx%)" term lc0 prints in verbose move stats.
probRe = re.compile(r"\(P: +([^)]+)\)")


class ProbInfoHandler(chess.uci.InfoHandler):
    """Info handler that collects lc0's per-move probability strings.

    NOTE(review): this targets the legacy `chess.uci` API, which was
    removed in python-chess >= 1.0 — confirm the pinned version.
    """
    def __init__(self):
        super().__init__()
        self.info["probs"] = []

    def on_go(self):
        """
        Notified when a *go* command is beeing sent.

        Since information about the previous search is invalidated, the
        dictionary with the current information will be cleared.
        """
        with self.lock:
            self.info.clear()
            self.info["refutation"] = {}
            self.info["currline"] = {}
            self.info["pv"] = {}
            self.info["score"] = {}
            self.info["probs"] = []

    def string(self, string):
        """Collect engine `info string` lines carrying a (P: ...) term.

        Bug fix: the original called re.search(...).group(1)
        unconditionally, so any info string without a (P: ...) match
        crashed with AttributeError. (The extracted probability itself was
        computed but never used; the full string is what getBoardProbs
        consumers receive, so keep appending the string.)
        """
        if re.search(probRe, string) is not None:
            self.info["probs"].append(string)

class EngineHandler(object):
    """Wraps an lc0 binary started with verbose move stats enabled."""
    def __init__(self, engine, weights, threads = 2):
        self.enginePath = os.path.normpath(engine)
        self.weightsPath = os.path.normpath(weights)

        self.engine = chess.uci.popen_engine([self.enginePath, "--verbose-move-stats", f"--threads={threads}", f"--weights={self.weightsPath}"])

        self.info_handler = ProbInfoHandler()
        self.engine.info_handlers.append(self.info_handler)

        self.engine.uci()
        self.engine.isready()

    def __repr__(self):
        # NOTE(review): this repr appears truncated by text extraction
        # (angle-bracket content stripped from the dump); restore the
        # original f-string contents from version control.
        return f""

    def getBoardProbs(self, board, movetime = 1000, nodes = 1000):
        """Search `board` and return (best move result, collected prob strings)."""
        self.engine.ucinewgame()
        self.engine.position(board)
        moves = self.engine.go(movetime = movetime, nodes = nodes)
        probs = self.info_handler.info['probs']
        return moves, probs
-------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 0 4 | 5 | dataset: 6 | input_train: 'PATH_TO_TRAINING_FILES' # This uses glob so training/*/* is probably the end 7 | input_test: 'PATH_TO_VAL_FILES' # This uses glob so validate/*/* is probably the end 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 35 | -------------------------------------------------------------------------------- /move_prediction/model_files/1100/config.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | name: final_maia_1100 4 | display_name: Final Maia 1100 5 | engine: lc0_23 6 | options: 7 | nodes: 1 8 | weightsPath: final_1100-40.pb.gz 9 | movetime: 10 10 | threads: 8 11 | ... 
12 | -------------------------------------------------------------------------------- /move_prediction/model_files/1100/final_1100-40.pb.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-chess/2e307a140de747c795a293d1d533ee315943e6bf/move_prediction/model_files/1100/final_1100-40.pb.gz -------------------------------------------------------------------------------- /move_prediction/model_files/1200/config.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | name: final_maia_1200 4 | display_name: Final Maia 1200 5 | engine: lc0_23 6 | options: 7 | nodes: 1 8 | weightsPath: final_1200-40.pb.gz 9 | movetime: 10 10 | threads: 8 11 | ... 12 | -------------------------------------------------------------------------------- /move_prediction/model_files/1200/final_1200-40.pb.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-chess/2e307a140de747c795a293d1d533ee315943e6bf/move_prediction/model_files/1200/final_1200-40.pb.gz -------------------------------------------------------------------------------- /move_prediction/model_files/1300/config.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | name: final_maia_1300 4 | display_name: Final Maia 1300 5 | engine: lc0_23 6 | options: 7 | nodes: 1 8 | weightsPath: final_1300-40.pb.gz 9 | movetime: 10 10 | threads: 8 11 | ... 
12 | -------------------------------------------------------------------------------- /move_prediction/model_files/1300/final_1300-40.pb.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-chess/2e307a140de747c795a293d1d533ee315943e6bf/move_prediction/model_files/1300/final_1300-40.pb.gz -------------------------------------------------------------------------------- /move_prediction/model_files/1400/config.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | name: final_maia_1400 4 | display_name: Final Maia 1400 5 | engine: lc0_23 6 | options: 7 | nodes: 1 8 | weightsPath: final_1400-40.pb.gz 9 | movetime: 10 10 | threads: 8 11 | ... 12 | -------------------------------------------------------------------------------- /move_prediction/model_files/1400/final_1400-40.pb.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-chess/2e307a140de747c795a293d1d533ee315943e6bf/move_prediction/model_files/1400/final_1400-40.pb.gz -------------------------------------------------------------------------------- /move_prediction/model_files/1500/config.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | name: final_maia_1500 4 | display_name: Final Maia 1500 5 | engine: lc0_23 6 | options: 7 | nodes: 1 8 | weightsPath: final_1500-40.pb.gz 9 | movetime: 10 10 | threads: 8 11 | ... 
12 | -------------------------------------------------------------------------------- /move_prediction/model_files/1500/final_1500-40.pb.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-chess/2e307a140de747c795a293d1d533ee315943e6bf/move_prediction/model_files/1500/final_1500-40.pb.gz -------------------------------------------------------------------------------- /move_prediction/model_files/1600/config.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | name: final_maia_1600 4 | display_name: Final Maia 1600 5 | engine: lc0_23 6 | options: 7 | nodes: 1 8 | weightsPath: final_1600-40.pb.gz 9 | movetime: 10 10 | threads: 8 11 | ... 12 | -------------------------------------------------------------------------------- /move_prediction/model_files/1600/final_1600-40.pb.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-chess/2e307a140de747c795a293d1d533ee315943e6bf/move_prediction/model_files/1600/final_1600-40.pb.gz -------------------------------------------------------------------------------- /move_prediction/model_files/1700/config.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | name: final_maia_1700 4 | display_name: Final Maia 1700 5 | engine: lc0_23 6 | options: 7 | nodes: 1 8 | weightsPath: final_1700-40.pb.gz 9 | movetime: 10 10 | threads: 8 11 | ... 
12 | -------------------------------------------------------------------------------- /move_prediction/model_files/1700/final_1700-40.pb.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-chess/2e307a140de747c795a293d1d533ee315943e6bf/move_prediction/model_files/1700/final_1700-40.pb.gz -------------------------------------------------------------------------------- /move_prediction/model_files/1800/config.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | name: final_maia_1800 4 | display_name: Final Maia 1800 5 | engine: lc0_23 6 | options: 7 | nodes: 1 8 | weightsPath: final_1800-40.pb.gz 9 | movetime: 10 10 | threads: 8 11 | ... 12 | -------------------------------------------------------------------------------- /move_prediction/model_files/1800/final_1800-40.pb.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-chess/2e307a140de747c795a293d1d533ee315943e6bf/move_prediction/model_files/1800/final_1800-40.pb.gz -------------------------------------------------------------------------------- /move_prediction/model_files/1900/config.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | name: final_maia_1900 4 | display_name: Final Maia 1900 5 | engine: lc0_23 6 | options: 7 | nodes: 1 8 | weightsPath: final_1900-40.pb.gz 9 | movetime: 10 10 | threads: 8 11 | ... 
12 | -------------------------------------------------------------------------------- /move_prediction/model_files/1900/final_1900-40.pb.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-chess/2e307a140de747c795a293d1d533ee315943e6bf/move_prediction/model_files/1900/final_1900-40.pb.gz -------------------------------------------------------------------------------- /move_prediction/pgn_to_trainingdata.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | max_procs=10 5 | 6 | #sorry these are relative only 7 | #remove the $PWD to make work with absolute paths 8 | input_file=$PWD/$1 9 | output_files=$PWD/$2 10 | mkdir -p $output_files 11 | 12 | mkdir -p $output_files/blocks 13 | mkdir -p $output_files/training 14 | mkdir -p $output_files/validation 15 | 16 | cd $output_files/blocks 17 | 18 | #using tool from: 19 | #https://www.cs.kent.ac.uk/people/staff/djb/pgn-extract/ 20 | 21 | pgn-extract -7 -C -N -#1000 $input_file 22 | 23 | #use the first 3000 as validation set 24 | mv {1..3}.pgn $output_files/validation/ 25 | 26 | mv *.pgn $output_files/training/ 27 | 28 | cd .. 
29 | rm -rv $output_files/blocks 30 | 31 | for data_type in "training" "validation"; do 32 | cd $output_files/$data_type 33 | for p in *.pgn; do 34 | cd $output_files/$data_type 35 | p_num=${p%".pgn"} 36 | echo "Starting on" $data_type $p_num 37 | mkdir $p_num 38 | cd $p_num 39 | #using tool from: 40 | #https://github.com/DanielUranga/trainingdata-tool 41 | trainingdata-tool ../$p & 42 | while [ `echo $(pgrep -c -P$$)` -gt $max_procs ]; do 43 | printf "waiting\r" 44 | sleep 1 45 | done 46 | done 47 | done 48 | echo "Almost done" 49 | wait 50 | echo "Done" 51 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/defaults/1200.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 0 # gpu id to process on 4 | 5 | dataset: 6 | input_train: '/datadrive/pgns_ranged/1200/train/supervised-*' 7 | input_test: '/datadrive/pgns_ranged/1200/test/supervised-*' 8 | 9 | training: 10 | batch_size: 2048 # training batch 11 | test_steps: 2000 # eval test set values after this many steps 12 | train_avg_report_steps: 200 # training reports its average values after this many steps. 13 | total_steps: 140000 # terminate after these steps 14 | warmup_steps: 250 # if global step is less than this, scale the current LR by ratio of global step to this value 15 | checkpoint_steps: 10000 # optional frequency for checkpointing before finish 16 | shuffle_size: 524288 # size of the shuffle buffer 17 | lr_values: # list of learning rates 18 | - 0.02 19 | - 0.002 20 | - 0.0005 21 | lr_boundaries: # list of boundaries 22 | - 100000 23 | - 130000 24 | policy_loss_weight: 1.0 # weight of policy loss 25 | value_loss_weight: 1.0 # weight of value loss 26 | 27 | model: 28 | filters: 64 29 | residual_blocks: 6 30 | se_ratio: 8 31 | ... 
32 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/defaults/1500.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 1 # gpu id to process on 4 | 5 | dataset: 6 | input_train: '/datadrive/pgns_ranged/1500/train/supervised-*' 7 | input_test: '/datadrive/pgns_ranged/1500/test/supervised-*' 8 | 9 | training: 10 | batch_size: 2048 # training batch 11 | test_steps: 2000 # eval test set values after this many steps 12 | train_avg_report_steps: 200 # training reports its average values after this many steps. 13 | total_steps: 140000 # terminate after these steps 14 | warmup_steps: 250 # if global step is less than this, scale the current LR by ratio of global step to this value 15 | checkpoint_steps: 10000 # optional frequency for checkpointing before finish 16 | shuffle_size: 524288 # size of the shuffle buffer 17 | lr_values: # list of learning rates 18 | - 0.02 19 | - 0.002 20 | - 0.0005 21 | lr_boundaries: # list of boundaries 22 | - 100000 23 | - 130000 24 | policy_loss_weight: 1.0 # weight of policy loss 25 | value_loss_weight: 1.0 # weight of value loss 26 | 27 | model: 28 | filters: 64 29 | residual_blocks: 6 30 | se_ratio: 8 31 | ... 32 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/defaults/1800.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 1 # gpu id to process on 4 | 5 | dataset: 6 | input_train: '/datadrive/pgns_ranged/1800/train/supervised-*' 7 | input_test: '/datadrive/pgns_ranged/1800/test/supervised-*' 8 | 9 | training: 10 | batch_size: 2048 # training batch 11 | test_steps: 2000 # eval test set values after this many steps 12 | train_avg_report_steps: 200 # training reports its average values after this many steps. 
13 | total_steps: 140000 # terminate after these steps 14 | warmup_steps: 250 # if global step is less than this, scale the current LR by ratio of global step to this value 15 | checkpoint_steps: 10000 # optional frequency for checkpointing before finish 16 | shuffle_size: 524288 # size of the shuffle buffer 17 | lr_values: # list of learning rates 18 | - 0.02 19 | - 0.002 20 | - 0.0005 21 | lr_boundaries: # list of boundaries 22 | - 100000 23 | - 130000 24 | policy_loss_weight: 1.0 # weight of policy loss 25 | value_loss_weight: 1.0 # weight of value loss 26 | 27 | model: 28 | filters: 64 29 | residual_blocks: 6 30 | se_ratio: 8 31 | ... 32 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/extras/1000.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 1 4 | 5 | dataset: 6 | input_train: '/maiadata/maia_extra_training/500-1100/train/*/*' 7 | input_test: '/maiadata/maia_extra_training/500-1100/val/*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 
35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/extras/2000.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 1 4 | 5 | dataset: 6 | input_train: '/maiadata/maia_extra_training/2000-2200/train/*/*' 7 | input_test: '/maiadata/maia_extra_training/2000-2200/val/*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/extras/2300.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 1 4 | 5 | dataset: 6 | input_train: '/maiadata/maia_extra_training/2100-4000/train/*/*' 7 | input_test: '/maiadata/maia_extra_training/2100-4000/val/*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 
35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/extras/all.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 0 4 | 5 | dataset: 6 | input_train: '/finaldata/elo_ranges/*/train/*/*' 7 | input_test: '/finaldata/elo_ranges/*/test/*/*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/extras/double.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 0 4 | 5 | dataset: 6 | input_train: '/finaldata/elo_ranges/1[19]00/train/*/*' 7 | input_test: '/finaldata/elo_ranges/1[19]00/test/*/*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 
35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/final/1100_final.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 0 4 | 5 | dataset: 6 | input_train: '/finaldata/elo_ranges/1100/train/*/*' 7 | input_test: '/finaldata/elo_ranges/1100/test/*/*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/final/1200_final.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 1 4 | 5 | dataset: 6 | input_train: '/finaldata/elo_ranges/1200/train/*/*' 7 | input_test: '/finaldata/elo_ranges/1200/test/*/*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 
35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/final/1300_final.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 1 4 | 5 | dataset: 6 | input_train: '/finaldata/elo_ranges/1300/train/*/*' 7 | input_test: '/finaldata/elo_ranges/1300/test/*/*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/final/1400_final.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 2 4 | 5 | dataset: 6 | input_train: '/finaldata/elo_ranges/1400/train/*/*' 7 | input_test: '/finaldata/elo_ranges/1400/test/*/*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 
35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/final/1500_final.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 3 4 | 5 | dataset: 6 | input_train: '/finaldata/elo_ranges/1500/train/*/*' 7 | input_test: '/finaldata/elo_ranges/1500/test/*/*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/final/1600_final.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 0 4 | 5 | dataset: 6 | input_train: '/finaldata/elo_ranges/1600/train/*/*' 7 | input_test: '/finaldata/elo_ranges/1600/test/*/*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 
35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/final/1700_final.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 2 4 | 5 | dataset: 6 | input_train: '/finaldata/elo_ranges/1700/train/*/*' 7 | input_test: '/finaldata/elo_ranges/1700/test/*/*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/final/1800_final.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 3 4 | 5 | dataset: 6 | input_train: '/finaldata/elo_ranges/1800/train/*/*' 7 | input_test: '/finaldata/elo_ranges/1800/test/*/*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 
35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/final/1900_final.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 0 4 | 5 | dataset: 6 | input_train: '/finaldata/elo_ranges/1900/train/*/*' 7 | input_test: '/finaldata/elo_ranges/1900/test/*/*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/final_unfiltered/1000_final.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 0 4 | 5 | dataset: 6 | input_train: '/maiadata/pgns_ranged/1000/train/20*/*/supervised-*' 7 | input_test: '/maiadata/pgns_ranged/1000/test/20*/1/supervised-*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 
35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/final_unfiltered/1100_final.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 0 4 | 5 | dataset: 6 | input_train: '/maiadata/pgns_ranged/1100/train/20*/*/supervised-*' 7 | input_test: '/maiadata/pgns_ranged/1100/test/20*/1/supervised-*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/final_unfiltered/1200_final.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 0 4 | 5 | dataset: 6 | input_train: '/maiadata/pgns_ranged/1200/train/20*/*/supervised-*' 7 | input_test: '/maiadata/pgns_ranged/1200/test/20*/1/supervised-*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 
35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/final_unfiltered/1300_final.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 1 4 | 5 | dataset: 6 | input_train: '/maiadata/pgns_ranged/1300/train/20*/*/supervised-*' 7 | input_test: '/maiadata/pgns_ranged/1300/test/20*/1/supervised-*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/final_unfiltered/1400_final.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 1 4 | 5 | dataset: 6 | input_train: '/maiadata/pgns_ranged/1400/train/20*/*/supervised-*' 7 | input_test: '/maiadata/pgns_ranged/1400/test/20*/1/supervised-*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 
35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/final_unfiltered/1500_final.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 2 4 | 5 | dataset: 6 | input_train: '/maiadata/pgns_ranged/1500/train/20*/*/supervised-*' 7 | input_test: '/maiadata/pgns_ranged/1500/test/20*/1/supervised-*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/final_unfiltered/1600_final.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 2 4 | 5 | dataset: 6 | input_train: '/maiadata/pgns_ranged/1600/train/20*/*/supervised-*' 7 | input_test: '/maiadata/pgns_ranged/1600/test/20*/1/supervised-*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 
35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/final_unfiltered/1700_final.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 2 4 | 5 | dataset: 6 | input_train: '/maiadata/pgns_ranged/1700/train/20*/*/supervised-*' 7 | input_test: '/maiadata/pgns_ranged/1700/test/20*/1/supervised-*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/final_unfiltered/1800_final.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 3 4 | 5 | dataset: 6 | input_train: '/maiadata/pgns_ranged/1800/train/20*/*/supervised-*' 7 | input_test: '/maiadata/pgns_ranged/1800/test/20*/1/supervised-*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 
35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/final_unfiltered/1900_final.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 3 4 | 5 | dataset: 6 | input_train: '/maiadata/pgns_ranged/1900/train/20*/*/supervised-*' 7 | input_test: '/maiadata/pgns_ranged/1900/test/20*/1/supervised-*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/final_unfiltered/2000_final.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 3 4 | 5 | dataset: 6 | input_train: '/maiadata/pgns_ranged/2000/train/20*/*/supervised-*' 7 | input_test: '/maiadata/pgns_ranged/2000/test/20*/1/supervised-*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 
35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/leela_best/1200.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 0 # gpu id to process on 4 | 5 | dataset: 6 | input_train: '/datadrive/pgns_ranged/1200/train/supervised-*' 7 | input_test: '/datadrive/pgns_ranged/1200/test/supervised-*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 28 | value_loss_weight: 0.5 29 | swa: True 30 | swa_steps: 2000 31 | swa_max_n: 8 32 | warmup_steps: 100 33 | num_test_positions: 40000 34 | mask_legal_moves: true 35 | renorm: true 36 | renorm_max_r: 1.0 37 | renorm_max_d: 0.0 38 | 39 | model: 40 | filters: 320 41 | residual_blocks: 24 42 | se_ratio: 10 43 | ... 
44 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/leela_best/1500.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 0 # gpu id to process on 4 | 5 | dataset: 6 | input_train: '/datadrive/pgns_ranged/1500/train/supervised-*' 7 | input_test: '/datadrive/pgns_ranged/1500/test/supervised-*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 28 | value_loss_weight: 0.5 29 | swa: True 30 | swa_steps: 2000 31 | swa_max_n: 8 32 | warmup_steps: 100 33 | num_test_positions: 40000 34 | mask_legal_moves: true 35 | renorm: true 36 | renorm_max_r: 1.0 37 | renorm_max_d: 0.0 38 | 39 | model: 40 | filters: 320 41 | residual_blocks: 24 42 | se_ratio: 10 43 | ... 
44 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/leela_best/1800.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 1 # gpu id to process on 4 | 5 | dataset: 6 | input_train: '/datadrive/pgns_ranged/1800/train/supervised-*' 7 | input_test: '/datadrive/pgns_ranged/1800/test/supervised-*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 28 | value_loss_weight: 0.5 29 | swa: True 30 | swa_steps: 2000 31 | swa_max_n: 8 32 | warmup_steps: 100 33 | num_test_positions: 40000 34 | mask_legal_moves: true 35 | renorm: true 36 | renorm_max_r: 1.0 37 | renorm_max_d: 0.0 38 | 39 | model: 40 | filters: 320 41 | residual_blocks: 24 42 | se_ratio: 10 43 | ... 
44 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/new_LR/1200_LR.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 1 4 | 5 | dataset: 6 | input_train: '/maiadata/pgn_splits/1200/train/*_files/supervised-*' 7 | input_test: '/maiadata/pgn_splits/1200/test/*_files/supervised-*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/new_LR/1200_LR_big_batch.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 1 4 | 5 | dataset: 6 | input_train: '/maiadata/pgn_splits/1200/train/*_files/supervised-*' 7 | input_test: '/maiadata/pgn_splits/1200/test/*_files/supervised-*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 2048 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 
35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/new_LR/1500_LR.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 0 4 | 5 | dataset: 6 | input_train: '/maiadata/pgn_splits/1500/train/*_files/supervised-*' 7 | input_test: '/maiadata/pgn_splits/1500/test/*_files/supervised-*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/new_LR/1500_LR_big_batch.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 1 4 | 5 | dataset: 6 | input_train: '/maiadata/pgn_splits/1500/train/*_files/supervised-*' 7 | input_test: '/maiadata/pgn_splits/1500/test/*_files/supervised-*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 2048 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 
35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/new_LR/1800_LR.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 0 4 | 5 | dataset: 6 | input_train: '/maiadata/pgn_splits/1800/train/*_files/supervised-*' 7 | input_test: '/maiadata/pgn_splits/1800/test/*_files/supervised-*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/new_LR/1800_LR_big_batch.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 1 4 | 5 | dataset: 6 | input_train: '/maiadata/pgn_splits/1800/train/*_files/supervised-*' 7 | input_test: '/maiadata/pgn_splits/1800/test/*_files/supervised-*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 2048 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 
35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/sweep/1800_LR.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 0 4 | 5 | dataset: 6 | input_train: '/datadrive/pgns_ranged/1800/train/supervised-*' 7 | input_test: '/datadrive/pgns_ranged/1800/test/supervised-*' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 1024 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 400000 16 | checkpoint_steps: 10000 17 | shuffle_size: 250000 18 | lr_values: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | lr_boundaries: 24 | - 80000 25 | - 200000 26 | - 360000 27 | policy_loss_weight: 1.0 # weight of policy loss 28 | value_loss_weight: 1.0 # weight of value loss 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/sweep/1800_policy_value.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 0 4 | 5 | dataset: 6 | input_train: '/datadrive/pgns_ranged/1800/train/supervised-*' 7 | input_test: '/datadrive/pgns_ranged/1800/test/supervised-*' 8 | 9 | training: 10 | batch_size: 2048 # training batch 11 | test_steps: 2000 # eval test set values after this many steps 12 | train_avg_report_steps: 200 # training reports its average values after this many steps. 
13 | total_steps: 140000 # terminate after these steps 14 | warmup_steps: 250 # if global step is less than this, scale the current LR by ratio of global step to this value 15 | checkpoint_steps: 10000 # optional frequency for checkpointing before finish 16 | shuffle_size: 524288 # size of the shuffle buffer 17 | lr_values: # list of learning rates 18 | - 0.02 19 | - 0.002 20 | - 0.0005 21 | lr_boundaries: # list of boundaries 22 | - 100000 23 | - 130000 24 | policy_loss_weight: 1.0 # weight of policy loss 25 | value_loss_weight: 0.5 # weight of value loss 26 | 27 | model: 28 | filters: 64 29 | residual_blocks: 6 30 | se_ratio: 8 31 | ... 32 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/sweep/1800_renorm.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 0 4 | 5 | dataset: 6 | input_train: '/datadrive/pgns_ranged/1800/train/supervised-*' 7 | input_test: '/datadrive/pgns_ranged/1800/test/supervised-*' 8 | 9 | training: 10 | batch_size: 2048 # training batch 11 | test_steps: 2000 # eval test set values after this many steps 12 | train_avg_report_steps: 200 # training reports its average values after this many steps. 13 | total_steps: 140000 # terminate after these steps 14 | warmup_steps: 250 # if global step is less than this, scale the current LR by ratio of global step to this value 15 | checkpoint_steps: 10000 # optional frequency for checkpointing before finish 16 | shuffle_size: 524288 # size of the shuffle buffer 17 | lr_values: # list of learning rates 18 | - 0.02 19 | - 0.002 20 | - 0.0005 21 | lr_boundaries: # list of boundaries 22 | - 100000 23 | - 130000 24 | policy_loss_weight: 1.0 # weight of policy loss 25 | value_loss_weight: 1.0 # weight of value loss 26 | renorm: true 27 | renorm_max_r: 1.0 28 | renorm_max_d: 0.0 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | ... 
35 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/sweep/1800_swa.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 0 4 | 5 | dataset: 6 | input_train: '/datadrive/pgns_ranged/1800/train/supervised-*' 7 | input_test: '/datadrive/pgns_ranged/1800/test/supervised-*' 8 | 9 | training: 10 | batch_size: 2048 # training batch 11 | test_steps: 2000 # eval test set values after this many steps 12 | train_avg_report_steps: 200 # training reports its average values after this many steps. 13 | total_steps: 140000 # terminate after these steps 14 | warmup_steps: 250 # if global step is less than this, scale the current LR by ratio of global step to this value 15 | checkpoint_steps: 10000 # optional frequency for checkpointing before finish 16 | shuffle_size: 524288 # size of the shuffle buffer 17 | lr_values: # list of learning rates 18 | - 0.02 19 | - 0.002 20 | - 0.0005 21 | lr_boundaries: # list of boundaries 22 | - 100000 23 | - 130000 24 | policy_loss_weight: 1.0 # weight of policy loss 25 | value_loss_weight: 1.0 # weight of value loss 26 | swa: True 27 | swa_steps: 2000 28 | 29 | model: 30 | filters: 64 31 | residual_blocks: 6 32 | se_ratio: 8 33 | ... 34 | -------------------------------------------------------------------------------- /move_prediction/replication-configs/testing/example.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | name: 'kb1-64x6' # ideally no spaces 4 | gpu: 0 # gpu id to process on 5 | 6 | dataset: 7 | num_chunks: 100000 # newest nof chunks to parse 8 | train_ratio: 0.90 # trainingset ratio 9 | # For separated test and train data. 10 | # input_train: '/path/to/chunks/*/draw/' # supports glob 11 | # input_test: '/path/to/chunks/*/draw/' # supports glob 12 | # For a one-shot run with all data in one directory. 
import maia_chess_backend

import argparse
import bz2

#@haibrid_chess_utils.logged_main
def main():
    """Filter PGN games from bz2 archives into an Elo range (eloMin, eloMax].

    Keeps only games where BOTH players' ratings fall in the half-open range
    (eloMin, eloMax] and the result is a decisive win or a draw; surviving
    games are appended to a bz2-compressed output file. Optionally drops
    bullet games and strips low-time moves.
    """
    parser = argparse.ArgumentParser(description='Filter PGN games into an Elo range')
    parser.add_argument('eloMin', type=int, help='min ELO')
    parser.add_argument('eloMax', type=int, help='max ELO')
    parser.add_argument('output', help='output file')
    parser.add_argument('targets', nargs='+', help='target files')
    parser.add_argument('--remove_bullet', action='store_true', help='Remove bullet and ultrabullet games')
    parser.add_argument('--remove_low_time', action='store_true', help='Remove low time moves from games')

    args = parser.parse_args()
    gamesWritten = 0
    print(f"Starting writing to: {args.output}")
    with bz2.open(args.output, 'wt') as f:
        for num_files, target in enumerate(sorted(args.targets)):
            print(f"{num_files} reading: {target}")
            Games = maia_chess_backend.LightGamesFile(target, parseMoves = False)
            # BUGFIX: i must exist even if the target file yields no games,
            # otherwise the "Done" print below raises NameError.
            i = -1
            for i, (dat, lines) in enumerate(Games):
                try:
                    whiteELO = int(dat['WhiteElo'])
                    BlackELO = int(dat['BlackElo'])
                except ValueError:
                    # Unparseable ratings (e.g. "?") — skip the game.
                    continue
                if whiteELO > args.eloMax or whiteELO <= args.eloMin:
                    continue
                elif BlackELO > args.eloMax or BlackELO <= args.eloMin:
                    continue
                elif dat['Result'] not in ['1-0', '0-1', '1/2-1/2']:
                    # Drop aborted / unknown-result games.
                    continue
                elif args.remove_bullet and 'Bullet' in dat['Event']:
                    continue
                else:
                    if args.remove_low_time:
                        f.write(maia_chess_backend.remove_low_time(lines))
                    else:
                        f.write(lines)
                    gamesWritten += 1
                if i % 1000 == 0:
                    # \r keeps the progress report on a single console line.
                    print(f"{i}: written {gamesWritten} files {num_files}: {target}".ljust(79), end = '\r')
            print(f"Done: {target} {i}".ljust(79))

if __name__ == '__main__':
    main()
for the screens to finish to do this 19 | # We use pgn-extract to normalize the games and prepare for preprocessing 20 | # This also creates blocks of 200,000 games which are useful for the next step 21 | 22 | mkdir ../data/pgns_ranged_blocks 23 | for i in {1000..2000..100}; do 24 | echo $i 25 | cw=`pwd` 26 | outputdir="../data/pgns_ranged_blocks/${i}" 27 | mkdir $outputdir 28 | cd $outputdir 29 | for y in {2017..2019}; do 30 | echo "${i}-${y}" 31 | mkdir $y 32 | cd $y 33 | screen -S "${i}-${y}" -dm bash -c "source ~/.bashrc; bzcat \"../../../pgns_ranged_filtered/${i}/lichess_db_standard_rated_${y}\"* | pgn-extract -7 -C -N -#200000" 34 | cd .. 35 | done 36 | cd $cw 37 | done 38 | 39 | #Now we have all the pgns in blocks we can randomly sample and creat testing and training sets of 60 and 3 blocks respectively 40 | python3 move_training_set.py 41 | -------------------------------------------------------------------------------- /move_prediction/replication-make_leela_files.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | mkdir ../data/elo_ranges 5 | cw=`pwd` 6 | for elo in {1100..1900..100}; do 7 | echo $i 8 | mkdir "../data/elo_ranges/${elo}" 9 | outputtest="../data/elo_ranges/${elo}/test" 10 | outputtrain="../data/elo_ranges/${elo}/train" 11 | mkdir $outputtest 12 | mkdir $outputtrain 13 | for te in "../data/final_training_data/pgns_ranged_training/${elo}"/*; do 14 | fname="$(basename -- $te)" 15 | echo "${elo}-${fname}" 16 | cd $outputtrain 17 | mkdir $fname 18 | cd $fname 19 | screen -S "${elo}-${fname}-test" -dm bash -c "trainingdata-tool -v -files-per-dir 5000 ${te}" 20 | cd .. 
21 | done 22 | for te in "../data/final_training_data/pgns_ranged_testing/${elo}"/{1..2}.pgn; do 23 | fname="$(basename -- $te)" 24 | echo "${elo}-${fname}" 25 | cd $outputtest 26 | mkdir $fname 27 | cd $fname 28 | echo "trainingdata-tool -v -files-per-dir 5000 ${te}" 29 | screen -S "${elo}-${fname}-test" -dm bash -c "trainingdata-tool -v -files-per-dir 5000 ${te}" 30 | cd .. 31 | done 32 | te="../data/final_training_data/pgns_ranged_testing/${elo}/3.pgn" fname="$(basename -- $te)" 33 | echo "${elo}-${fname}" 34 | cd $outputtest 35 | mkdir $fname 36 | cd $fname 37 | trainingdata-tool -v -files-per-dir 5000 ${te} 38 | cd .. 39 | done 40 | cd $cw 41 | 42 | 43 | 44 | 45 | 46 | 47 | #After merging split pgns 48 | pgn-extract -7 -C -N -#400000 /datadrive/pgns_ranged/1200/lichess_1200.pgn 49 | pgn-extract -7 -C -N -#400000 /datadrive/pgns_ranged/1500/lichess_1500.pgn 50 | pgn-extract -7 -C -N -#400000 /datadrive/pgns_ranged/1800/lichess_1800.pgn 51 | 52 | 53 | #Then on all the results 54 | 55 | trainingdata-tool -v -files-per-dir 5000 lichess_1800.pgn 56 | for f in *.pgn; do echo "${f%.*}"; mkdir "${f%.*}_files"; cd "${f%.*}_files"; trainingdata-tool -v -files-per-dir 5000 "../${f}"; cd ..; done 57 | 58 | for f in {1..10}.pgn; do echo "${f%.*}"; mkdir "${f%.*}_files"; cd "${f%.*}_files"; trainingdata-tool -v -files-per-dir 5000 "../${f}"; cd ..; done 59 | 60 | mkdir train 61 | mkdir test 62 | mv 10_files/ test/ 63 | mv 10.pgn test/ 64 | mv *_* train/ 65 | mv *.pgn train 66 | 67 | 68 | download pgns_ranged.zip 69 | unzip pgns_ranged.zip 70 | cd pgns_ranged 71 | for elo in *; do 72 | echo $elo 73 | cd $elo 74 | for year in *.pgn.bz2; do 75 | echo "${year%.*.*}" 76 | mkdir "${year%.*.*}" 77 | cd "${year%.*.*}" 78 | screen -S "${elo}-${year%.*.*}" -dm bash -c "bzcat \"../${year}\" | pgn-extract -7 -C -N -#400000" 79 | cd .. 80 | done 81 | cd .. 
82 | done 83 | 84 | for elo in *; do 85 | echo $elo 86 | cd $elo 87 | mkdir -p train 88 | mkdir -p test 89 | for year in lichess-*/; do 90 | yearonly="${year#lichess-}" 91 | yearonly="${yearonly%/}" 92 | echo "${elo}-${yearonly}" 93 | cd test 94 | mkdir -p "${yearonly}" 95 | mkdir -p "${yearonly}/1" 96 | cd "${yearonly}/1" 97 | 98 | screen -S "${elo}-${yearonly}-test" -dm bash -c "trainingdata-tool -v -files-per-dir 5000 \"../../../${year}/1.pgn\"" 99 | 100 | cd ../../.. 101 | cd train 102 | mkdir -p "${yearonly}" 103 | cd "${yearonly}" 104 | for i in {2..10}; do 105 | echo "${i}" 106 | mkdir -p "${i}" 107 | cd "${i}" 108 | screen -S "${elo}-${yearonly}-train-${i}" -dm bash -c "trainingdata-tool -v -files-per-dir 5000 \"../../../${year}/${i}.pgn\"" 109 | cd .. 110 | done 111 | cd ../.. 112 | done 113 | cd .. 114 | done 115 | 116 | for scr in $(screen -ls | awk '{print $1}'); do if [[ $scr == *"test"* ]]; then echo $scr; screen -S $scr -X kill; fi; done 117 | 118 | for scr in $(screen -ls | awk '{print $1}'); do if [[ $scr == *"2200"* ]]; then echo $scr; screen -S $scr -X kill; fi; done 119 | 120 | 121 | for scr in $(screen -ls | awk '{print $1}'); do if [[ $scr == *"final"* ]]; then echo $scr; screen -S $scr -X kill; fi; done 122 | -------------------------------------------------------------------------------- /move_prediction/replication-make_testing_pgns.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #Read the raw pgns from lichess and filter out the elo ranges we care about to make our validation set 4 | 5 | mkdir ../data/pgns_traj_testing/ 6 | for i in {1000..2500..100}; do 7 | echo $i 8 | upperval=$(($i + 100)) 9 | screen -S "${i}-testing" -dm bash -c "source ~/.bashrc; python3 replication-extractELOrange.py --remove_bullet ${i} ${upperval} ../data/pgns_traj_testing/${i}_2019-12.pgn.bz2 ../datasets/lichess_db_standard_rated_2019-12.pgn.bz2;bzcat ../data/pgns_traj_testing/${i}_2019-12.pgn.bz2 | 
import os
import os.path
import random
import shutil

# Fixed seed so the train/test split is reproducible across runs.
random.seed(1)

train_dir = '../data/pgns_ranged_training/'
test_dir = '../data/pgns_ranged_testing/'
# Number of 200k-game PGN blocks to sample into the training set per year.
target_per_year = 20

def main():
    """Split the pre-chunked PGN blocks into training and test sets.

    For each Elo directory under ../data/pgns_ranged_blocks/ (skipping the
    1000 and 2000 edge bins), copies the last three 2019 blocks into the test
    set and samples `target_per_year` shuffled blocks per year into the
    training set, carrying any shortfall forward to later years.
    """
    for elo in sorted(os.scandir('../data/pgns_ranged_blocks/'), key = lambda x : x.name):
        if elo.name in ['1000','2000']:
            continue
        print(elo.name)
        train_path = os.path.join(train_dir, elo.name)
        os.makedirs(train_path, exist_ok = True)
        test_path = os.path.join(test_dir, elo.name)
        os.makedirs(test_path, exist_ok = True)
        count_train = 0   # running index used to name copied training blocks
        num_missing = 0   # shortfall from earlier years, made up in later years
        count = 0         # total number of blocks seen for this Elo bin
        for year in sorted(os.scandir(elo.path), key = lambda x : x.name):
            # Blocks are named <n>.pgn; sort numerically and drop the last
            # (presumably partial) block of each year.
            targets = sorted(os.scandir(year.path), key = lambda x : int(x.name.split('.')[0]))[:-1]
            if year.name == '2019':
                # Reserve the final three 2019 blocks as the test set
                # (renamed 1.pgn, 2.pgn, 3.pgn), then exclude them from training.
                for i, t in enumerate(targets[-3:]):
                    shutil.copy(t.path, os.path.join(test_path, f"{i+1}.pgn"))
                targets = targets[:-3]
            random.shuffle(targets)
            # Take this year's quota plus any blocks earlier years were short.
            for t in targets[:target_per_year + num_missing]:
                count_train += 1
                shutil.copy(t.path, os.path.join(train_path, f"{count_train}_{year.name}.pgn"))
            if len(targets[:target_per_year + num_missing]) < target_per_year + num_missing:
                # Year came up short — roll the deficit into the next year.
                num_missing = target_per_year + num_missing - len(targets[:target_per_year + num_missing])
            else:
                num_missing = 0
            c = len(os.listdir(year.path))
            print(year.name, c, count_train)
            count += c
        print(count)

if __name__ == '__main__':
    main()