├── .gitignore ├── 0-player_counting ├── 0-find_top_players.sh ├── 1-collect_top_players.sh ├── 2-select_extended_set.sh ├── README.md ├── combine_player_counts.py ├── find_top_players.py ├── player_game_counts.py ├── select_binned_players.py ├── select_top_players.py └── split_by_players.py ├── 1-data_generation ├── 0-0-make_training_datasets.sh ├── 0-1-make_training_csvs.sh ├── 0-2-make_reduced_datasets.sh ├── 1-0-make_val_datasets.sh ├── 1-1-make_val_csvs.sh ├── 2-make_testing_datasets.sh ├── 9-pgn_to_training_data.sh ├── pgn_fractional_split.py ├── player_splits.sh └── split_by_player.py ├── 2-training ├── extended_configs │ ├── frozen_copy │ │ └── frozen_copy.yaml │ ├── frozen_random │ │ └── frozen_random.yaml │ ├── unfrozen_copy │ │ └── unfrozen_copy.yaml │ └── unfrozen_random │ │ └── unfrozen_random.yaml ├── final_config.yaml └── train_transfer.py ├── 3-analysis ├── 1-0-baselines_results.sh ├── 1-1-baselines_results_validation.sh ├── 2-0-baseline_results.sh ├── 2-1-model_results.sh ├── 2-2-model_results_val.sh ├── 3-0-model_cross_table.sh ├── 3-1-model_cross_table_val_generation.sh ├── 3-2-model_cross_table_val.sh ├── 4-0-result_summaries.sh ├── 4-1-result_summaries_cross.sh ├── 4-2-result_summaries_val.sh ├── csv_trimmer.py ├── get_accuracy.py ├── get_models_player.py ├── make_summary.py ├── move_predictions.sh ├── prediction_generator.py └── run-kdd-tests.sh ├── 4-cp_loss_stylo_baseline ├── README.md ├── get_cp_loss.py ├── get_cp_loss_per_game.py ├── get_cp_loss_per_move.py ├── get_cp_loss_per_move_per_game.py ├── get_cp_loss_per_move_per_game_count.py ├── results │ ├── games_accuracy.csv │ ├── start_after.csv │ ├── start_after_all_game.csv │ ├── stop_after.csv │ └── stop_after_all_game.csv ├── results_validation │ ├── games_accuracy.csv │ ├── start_after.csv │ ├── start_after_4games.csv │ ├── stop_after.csv │ └── stop_after_4games.csv ├── sweep_moves_all_games.py ├── sweep_moves_num_games.py ├── sweep_moves_per_game.py ├── sweep_num_games.py ├── test_all_games.py └── train_cploss_per_game.py ├── 9-reduced-data └── configs │ ├── Best_frozen.yaml │ ├── NFP.yaml │ └── Tuned.yaml ├── CITATION.cff ├── LICENSE ├── README.md ├── backend ├── __init__.py ├── fen_to_vec.py ├── multiproc.py ├── pgn_parsering.py ├── pgn_to_csv.py ├── proto │ ├── __init__.py │ ├── net.proto │ └── net_pb2.py ├── tf_transfer │ ├── __init__.py │ ├── chunkparser.py │ ├── decode_training.py │ ├── lc0_az_policy_map.py │ ├── net.py │ ├── net_to_model.py │ ├── policy_index.py │ ├── shufflebuffer.py │ ├── tfprocess.py │ ├── tfprocess_reg_lr_noise.py │ ├── training_shared.py │ ├── update_steps.py │ └── utils.py ├── uci_engine.py └── utils.py ├── environment.yml ├── images └── kdd_indiv_final.jpg ├── models └── maia-1900 │ ├── ckpt │ ├── checkpoint │ ├── ckpt-40-400000.pb.gz │ ├── ckpt-40.data-00000-of-00002 │ ├── ckpt-40.data-00001-of-00002 │ └── ckpt-40.index │ ├── config.yaml │ └── final_1900-40.pb.gz └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.zip 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a 
python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /0-player_counting/0-find_top_players.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | lichess_raw_dir='/data/chess/bz2/standard/' 4 | output_dir='../../data/player_counts' 5 | mkdir -p $output_dir 6 | 7 | for t in $lichess_raw_dir/*-{01..11}.pgn.bz2 $lichess_raw_dir/*{3..8}-12.pgn.bz2; do 8 | fname="$(basename -- $t)" 9 | echo "${t} ${output_dir}/${fname}.csv.bz2" 10 | screen -S "filter-${fname}" -dm bash -c "source ~/.bashrc; python3 find_top_players.py ${t} ${output_dir}/${fname}.csv.bz2" 11 | done 12 | -------------------------------------------------------------------------------- /0-player_counting/1-collect_top_players.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | lichess_raw_dir='/data/chess/bz2/standard/' 4 | counts_dir='../../data/player_counts' 5 | counts_file='../../data/player_counts_combined.csv.bz2' 6 | top_list='../../data/player_counts_combined_top_names.csv.bz2' 7 | 8 | output_2000_dir='../../data/top_2000_player_games' 9 | output_2000_metadata_dir='../../data/top_2000_player_data' 10 | 11 | players_list='../../data/select_transfer_players' 12 | 13 | final_data_dir='../../data/transfer_players_data' 14 | 15 | num_train=10 16 | num_val=900 17 | num_test=100 18 | 19 | python3 combine_player_counts.py $counts_dir/* $counts_file 20 | 21 | bzcat $counts_file | head -n 2000 | bzip2 > $top_list 22 | 23 | mkdir -p $output_2000_dir 24 | 25 | python3 split_by_players.py 
$top_list $lichess_raw_dir/*-{01..11}.pgn.bz2 $lichess_raw_dir/*{3..8}-12.pgn.bz2 $output_2000_dir 26 | 27 | rm -v $top_list 28 | 29 | mkdir -p $output_2000_metadata_dir 30 | 31 | python3 player_game_counts.py $output_2000_dir $output_2000_metadata_dir 32 | 33 | python3 select_top_players.py $output_2000_metadata_dir \ 34 | ${players_list}_train.csv $num_train \ 35 | ${players_list}_validate.csv $num_val \ 36 | ${players_list}_test.csv $num_test 37 | 38 | mkdir -p $final_data_dir 39 | mkdir -p $final_data_dir/metadata 40 | cp -v ${players_list}*.csv $final_data_dir/metadata 41 | 42 | for c in "train" "validate" "test"; do 43 | mkdir -p $final_data_dir/${c} 44 | mkdir -p $final_data_dir/${c}_metadata 45 | for t in `tail -n +2 ${players_list}_${c}.csv|awk -F ',' '{print $1}'`; do 46 | cp -v ${output_2000_dir}/${t}.pgn.bz2 $final_data_dir/${c} 47 | cp ${output_2000_metadata_dir}/${t}.csv.bz2 $final_data_dir/${c}_metadata 48 | done 49 | done 50 | -------------------------------------------------------------------------------- /0-player_counting/2-select_extended_set.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | vals_dat_dir="../../data/transfer_players_data/validate_metadata/" 5 | vals_dir="../../data/transfer_players_validate" 6 | output_dir="../../data/transfer_players_extended" 7 | list_file='../../data/extended_list.csv' 8 | 9 | num_per_bin=5 10 | bins="1100 1300 1500 1700 1900" 11 | 12 | 13 | python3 select_binned_players.py $vals_dat_dir $list_file $num_per_bin $bins 14 | 15 | mkdir -p $output_dir 16 | 17 | while read player; do 18 | echo $player 19 | cp -r ${vals_dir}/${player} ${output_dir} 20 | done < $list_file 21 | -------------------------------------------------------------------------------- /0-player_counting/README.md: -------------------------------------------------------------------------------- 1 | # Player Counting 2 | 3 | This is the code we used to count the number of games each player has played, and to select the sets of players used for training, validation, and testing. 
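A hypothetical end-to-end run of this stage (assuming the Lichess dump paths and output directories configured at the top of each script) would look like:

```bash
# 0: count how often each player appears in every monthly Lichess dump
bash 0-find_top_players.sh
# 1: combine the counts, extract the top 2000 players' games, and write the
#    train/validate/test player lists plus per-player metadata
bash 1-collect_top_players.sh
# 2: choose the ELO-binned extended player set from the validation players
bash 2-select_extended_set.sh
```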
4 | -------------------------------------------------------------------------------- /0-player_counting/combine_player_counts.py: -------------------------------------------------------------------------------- 1 | import backend 2 | 3 | import argparse 4 | import bz2 5 | 6 | import pandas 7 | 8 | @backend.logged_main 9 | def main(): 10 | parser = argparse.ArgumentParser(description='Collect counts and create list from them', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 11 | parser.add_argument('inputs', nargs = '+', help='input csvs') 12 | parser.add_argument('output', help='output csv') 13 | args = parser.parse_args() 14 | 15 | counts = {} 16 | for p in args.inputs: 17 | backend.printWithDate(f"Processing {p}", end = '\r') 18 | df = pandas.read_csv(p) 19 | for i, row in df.iterrows(): 20 | try: 21 | counts[row['player']] += row['count'] 22 | except KeyError: 23 | counts[row['player']] = row['count'] 24 | backend.printWithDate("Writing") 25 | with bz2.open(args.output, 'wt') as f: 26 | f.write('player,count\n') 27 | for p, c in sorted(counts.items(), key = lambda x: x[1], reverse=True): 28 | f.write(f"{p},{c}\n") 29 | 30 | if __name__ == '__main__': 31 | main() 32 | -------------------------------------------------------------------------------- /0-player_counting/find_top_players.py: -------------------------------------------------------------------------------- 1 | import backend 2 | 3 | import argparse 4 | import bz2 5 | 6 | @backend.logged_main 7 | def main(): 8 | parser = argparse.ArgumentParser(description='Count the number of times each player occurs in a pgn', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 9 | 10 | parser.add_argument('input', help='input pgn') 11 | parser.add_argument('output', help='output csv') 12 | parser.add_argument('--include_bullet', dest='exclude_bullet', action='store_false', help='Include bullet games in the counts (they are excluded by default)') 13 | args = parser.parse_args() 14 | 15 | games = backend.GamesFile(args.input) 16 | 17 | counts = {} 18 | 19 | for i, (d, _) in enumerate(games): 20 | if args.exclude_bullet and 'Bullet' in d['Event']: 21 | continue 22 | else: 23 | add_player(d['White'], counts) 24 | add_player(d['Black'], counts) 25 | if i % 10000 == 0: 26 | backend.printWithDate(f"{i} done with {len(counts)} players from {args.input}", end = '\r') 27 | 28 | backend.printWithDate(f"{i} found a total of {len(counts)} players from {args.input}") 29 | with bz2.open(args.output, 'wt') as f: 30 | f.write("player,count\n") 31 | for p, c in sorted(counts.items(), key = lambda x: x[1], reverse=True): 32 | f.write(f"{p},{c}\n") 33 | backend.printWithDate("done") 34 | 35 | def add_player(p, d): 36 | try: 37 | d[p] += 1 38 | except KeyError: 39 | d[p] = 1 40 | 41 | if __name__ == '__main__': 42 | main() 43 | -------------------------------------------------------------------------------- /0-player_counting/player_game_counts.py: -------------------------------------------------------------------------------- 1 | import backend 2 | 3 | import os 4 | import os.path 5 | import csv 6 | import bz2 7 | import argparse 8 | 9 | @backend.logged_main 10 | def main(): 11 | parser = argparse.ArgumentParser(description='Get some stats about each of the games') 12 | parser.add_argument('targets_dir', help='input pgns dir') 13 | parser.add_argument('output_dir', help='output csvs dir') 14 | parser.add_argument('--pool_size', type=int, help='Number of worker processes to run in parallel', default = 64) 15 | args = parser.parse_args() 16 | multiProc = backend.Multiproc(args.pool_size) 17 | 
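    # Note on the backend.Multiproc pattern used throughout this repo (pattern
    # inferred from its use here; see backend/multiproc.py for the
    # implementation): reader_init registers an iterable that yields work
    # items -- here one (path, name) pair per player pgn -- and processor_init
    # registers the worker class that consumes them, so each player's metadata
    # CSV below is written by an independent worker process.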
multiProc.reader_init(Files_lister, args.targets_dir) 18 | multiProc.processor_init(Games_processor, args.output_dir) 19 | 20 | multiProc.run() 21 | 22 | class Files_lister(backend.MultiprocIterable): 23 | def __init__(self, targets_dir): 24 | self.targets_dir = targets_dir 25 | self.targets = [(p.path, p.name.split('.')[0]) for p in os.scandir(targets_dir) if '.pgn.bz2' in p.name] 26 | backend.printWithDate(f"Found {len(self.targets)} targets in {targets_dir}") 27 | def __next__(self): 28 | try: 29 | backend.printWithDate(f"Pushed target {len(self.targets)} remaining", end = '\r', flush = True) 30 | return self.targets.pop() 31 | except IndexError: 32 | raise StopIteration 33 | 34 | class Games_processor(backend.MultiprocWorker): 35 | def __init__(self, output_dir): 36 | self.output_dir = output_dir 37 | 38 | def __call__(self, path, name): 39 | games = backend.GamesFile(path) 40 | with bz2.open(os.path.join(self.output_dir, f"{name}.csv.bz2"), 'wt') as f: 41 | writer = csv.DictWriter(f, ["player", "opponent","game_id", "ELO", "opp_ELO", "was_white", "result", "won", "UTCDate", "UTCTime", "TimeControl"]) 42 | 43 | writer.writeheader() 44 | for d, _ in games: 45 | game_dat = {} 46 | game_dat['player'] = name 47 | game_dat['game_id'] = d['Site'].split('/')[-1] 48 | game_dat['result'] = d['Result'] 49 | game_dat['UTCDate'] = d['UTCDate'] 50 | game_dat['UTCTime'] = d['UTCTime'] 51 | game_dat['TimeControl'] = d['TimeControl'] 52 | if d['Black'] == name: 53 | game_dat['was_white'] = False 54 | game_dat['opponent'] = d['White'] 55 | game_dat['ELO'] = d['BlackElo'] 56 | game_dat['opp_ELO'] = d['WhiteElo'] 57 | game_dat['won'] = d['Result'] == '0-1' 58 | else: 59 | game_dat['was_white'] = True 60 | game_dat['opponent'] = d['Black'] 61 | game_dat['ELO'] = d['WhiteElo'] 62 | game_dat['opp_ELO'] = d['BlackElo'] 63 | game_dat['won'] = d['Result'] == '1-0' 64 | writer.writerow(game_dat) 65 | 66 | if __name__ == '__main__': 67 | main() 68 | -------------------------------------------------------------------------------- /0-player_counting/select_binned_players.py: -------------------------------------------------------------------------------- 1 | import backend 2 | 3 | import argparse 4 | import bz2 5 | import glob 6 | import random 7 | import os.path 8 | import multiprocessing 9 | 10 | import pandas 11 | 12 | @backend.logged_main 13 | def main(): 14 | parser = argparse.ArgumentParser(description='Read all the metadata and select top n players for training/validation/testing', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 15 | parser.add_argument('csvs_dir', help='dir of csvs') 16 | parser.add_argument('output_list', help='list of targets') 17 | parser.add_argument('bin_size', type=int, help='players per bin') 18 | parser.add_argument('bins', type=int, nargs = '+', help='bins') 19 | parser.add_argument('--pool_size', type=int, help='Number of threads to use for reading', default = 48) 20 | parser.add_argument('--seed', type=int, help='random seed', default = 1) 21 | args = parser.parse_args() 22 | random.seed(args.seed) 23 | 24 | bins = [int(b // 100 * 100) for b in args.bins] 25 | 26 | with multiprocessing.Pool(args.pool_size) as pool: 27 | players = pool.map(load_player, glob.glob(os.path.join(args.csvs_dir, '*.csv.bz2'))) 28 | backend.printWithDate(f"Found {len(players)} players, using {len(bins)} bins") 29 | binned_players = {b : [] for b in bins} 30 | for p in players: 31 | pe_round = int(p['elo'] // 100 * 100) 32 | if pe_round in bins: 33 | binned_players[pe_round].append(p) 34 | 
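    # Worked example of the binning above: a player whose mean ELO over their
    # last 10000 games is 1342 gets pe_round == 1342 // 100 * 100 == 1300 and
    # so lands in the 1300 bin; players outside the requested bins are dropped.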
backend.printWithDate(f"Found: " + ', '.join([f"{b} : {len(p)}" for b, p in binned_players.items()])) 35 | 36 | with open(args.output_list, 'wt') as f: 37 | for b, p in binned_players.items(): 38 | random.shuffle(p) 39 | print(b, [d['name'] for d in p[:args.bin_size]]) 40 | f.write('\n'.join([d['name'] for d in p[:args.bin_size]]) +'\n') 41 | 42 | def load_player(path): 43 | df = pandas.read_csv(path, low_memory=False) 44 | elo = df['ELO'][-10000:].mean() 45 | count = len(df) 46 | return { 47 | 'name' : df['player'].iloc[0], 48 | 'elo' : elo, 49 | 'count' : count, 50 | } 51 | if __name__ == "__main__": 52 | main() 53 | -------------------------------------------------------------------------------- /0-player_counting/select_top_players.py: -------------------------------------------------------------------------------- 1 | import backend 2 | 3 | import argparse 4 | import bz2 5 | import glob 6 | import random 7 | import os.path 8 | import multiprocessing 9 | 10 | import pandas 11 | 12 | @backend.logged_main 13 | def main(): 14 | parser = argparse.ArgumentParser(description='Read all the metadata and select top n players for training/validation/testing', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 15 | parser.add_argument('inputs', help='input csvs dir') 16 | parser.add_argument('output_train', help='output csv for training data') 17 | parser.add_argument('num_train', type=int, help='num for main training') 18 | parser.add_argument('output_val', help='output csv for validation data') 19 | parser.add_argument('num_val', type=int, help='num for big validation run') 20 | parser.add_argument('output_test', help='output csv for testing data') 21 | parser.add_argument('num_test', type=int, help='num for holdout set') 22 | parser.add_argument('--pool_size', type=int, help='Number of models to run in parallel', default = 48) 23 | parser.add_argument('--min_elo', type=int, help='min elo to select', default = 1100) 24 | parser.add_argument('--max_elo', type=int, help='max elo to select', default = 2000) 25 | parser.add_argument('--seed', type=int, help='random seed', default = 1) 26 | args = parser.parse_args() 27 | random.seed(args.seed) 28 | 29 | targets = glob.glob(os.path.join(args.inputs, '*csv.bz2')) 30 | 31 | with multiprocessing.Pool(args.pool_size) as pool: 32 | players = pool.starmap(check_player, ((t, args.min_elo, args.max_elo) for t in targets)) 33 | 34 | players_top = sorted( 35 | (p for p in players if p is not None), 36 | key = lambda x : x[1], 37 | reverse=True, 38 | )[:args.num_train + args.num_val + args.num_test] 39 | 40 | random.shuffle(players_top) 41 | 42 | write_output_file(args.output_train, args.num_train, players_top) 43 | write_output_file(args.output_val, args.num_val, players_top) 44 | write_output_file(args.output_test, args.num_test, players_top) 45 | 46 | def write_output_file(path, count, targets): 47 | with open(path, 'wt') as f: 48 | f.write("player,count,ELO\n") 49 | for i in range(count): 50 | t = targets.pop() 51 | f.write(f"{t[0]},{t[1]},{t[2]}\n") 52 | 53 | def check_player(path, min_elo, max_elo): 54 | df = pandas.read_csv(path, low_memory=False) 55 | elo = df['ELO'][-10000:].mean() 56 | count = len(df) 57 | if elo > min_elo and elo < max_elo: 58 | return path.split('/')[-1].split('.')[0], count, elo 59 | else: 60 | return None 61 | 62 | if __name__ == "__main__": 63 | main() 64 | -------------------------------------------------------------------------------- /0-player_counting/split_by_players.py: 
-------------------------------------------------------------------------------- 1 | import backend 2 | 3 | import pandas 4 | import lockfile 5 | 6 | import argparse 7 | import bz2 8 | import os 9 | import os.path 10 | 11 | @backend.logged_main 12 | def main(): 13 | parser = argparse.ArgumentParser(description='Write pgns of games with selected players in them', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 14 | 15 | parser.add_argument('target', help='target players list as csv') 16 | parser.add_argument('inputs', nargs = '+', help='input pgns') 17 | parser.add_argument('output', help='output dir') 18 | parser.add_argument('--include_bullet', dest='exclude_bullet', action='store_false', help='Include bullet games (they are excluded by default)') 19 | parser.add_argument('--pool_size', type=int, help='Number of worker processes to run in parallel', default = 48) 20 | args = parser.parse_args() 21 | 22 | df_targets = pandas.read_csv(args.target) 23 | targets = set(df_targets['player']) 24 | 25 | os.makedirs(args.output, exist_ok=True) 26 | 27 | multiProc = backend.Multiproc(args.pool_size) 28 | multiProc.reader_init(Files_lister, args.inputs) 29 | multiProc.processor_init(Games_processor, targets, args.output, args.exclude_bullet) 30 | multiProc.run() 31 | backend.printWithDate("done") 32 | 33 | class Files_lister(backend.MultiprocIterable): 34 | def __init__(self, inputs): 35 | self.inputs = list(inputs) 36 | backend.printWithDate(f"Found {len(self.inputs)} input pgns") 37 | def __next__(self): 38 | try: 39 | backend.printWithDate(f"Pushed target {len(self.inputs)} remaining", end = '\r', flush = True) 40 | return self.inputs.pop() 41 | except IndexError: 42 | raise StopIteration 43 | 44 | class Games_processor(backend.MultiprocWorker): 45 | def __init__(self, targets, output_dir, exclude_bullet): 46 | self.output_dir = output_dir 47 | self.targets = targets 48 | self.exclude_bullet = exclude_bullet 49 | 50 | self.c = 0 51 | 52 | def __call__(self, path): 53 | games = backend.GamesFile(path) 54 | self.c = 0 55 | for i, (d, s) in enumerate(games): 56 | if self.exclude_bullet and 'Bullet' in d['Event']: 57 | continue 58 | else: 59 | if d['White'] in self.targets: 60 | self.write_player(d['White'], s) 61 | self.c += 1 62 | if d['Black'] in self.targets: 63 | self.write_player(d['Black'], s) 64 | self.c += 1 65 | if i % 10000 == 0: 66 | backend.printWithDate(f"{path} {i} done with {self.c} writes", end = '\r') 67 | 68 | def write_player(self, p_name, s): 69 | 70 | p_path = os.path.join(self.output_dir, f"{p_name}.pgn.bz2") 71 | lock_path = p_path + '.lock' 72 | lock = lockfile.FileLock(lock_path) 73 | with lock: 74 | with bz2.open(p_path, 'at') as f: 75 | f.write(s) 76 | 77 | 78 | if __name__ == '__main__': 79 | main() 80 | -------------------------------------------------------------------------------- /1-data_generation/0-0-make_training_datasets.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | train_frac=80 4 | val_frac=10 5 | test_frac=10 6 | 7 | input_files="/maiadata/transfer_players_data/train" 8 | output_files="/maiadata/transfer_players_train" 9 | mkdir -p $output_files 10 | 11 | for player_file in $input_files/*.bz2; do 12 | f=${player_file##*/} 13 | p_name=${f%.pgn.bz2} 14 | p_dir=$output_files/$p_name 15 | split_dir=$output_files/$p_name/split 16 | mkdir -p $p_dir 17 | mkdir -p $split_dir 18 | echo $p_name $p_dir 19 | python split_by_player.py $player_file $p_name $split_dir/games 20 | 21 | 22 | for c in "white" "black"; do 23 | python pgn_fractional_split.py 
$split_dir/games_$c.pgn.bz2 $split_dir/train_$c.pgn.bz2 $split_dir/validate_$c.pgn.bz2 $split_dir/test_$c.pgn.bz2 --ratios $train_frac $val_frac $test_frac 24 | 25 | cd $p_dir 26 | mkdir -p pgns 27 | for s in "train" "validate" "test"; do 28 | mkdir -p $s 29 | mkdir $s/$c 30 | 31 | #using tool from: 32 | #https://www.cs.kent.ac.uk/people/staff/djb/pgn-extract/ 33 | bzcat $split_dir/${s}_${c}.pgn.bz2 | pgn-extract -7 -C -N -#1000 34 | 35 | cat *.pgn > pgns/${s}_${c}.pgn 36 | rm -v *.pgn 37 | 38 | #using tool from: 39 | #https://github.com/DanielUranga/trainingdata-tool 40 | screen -S "${p_name}-${c}-${s}" -dm bash -c "cd ${s}/${c}; trainingdata-tool -v ../../pgns/${s}_${c}.pgn" 41 | done 42 | cd - 43 | done 44 | 45 | done 46 | -------------------------------------------------------------------------------- /1-data_generation/0-1-make_training_csvs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | data_dir="/maiadata/transfer_players_train" 4 | 5 | for player_dir in $data_dir/*; do 6 | player_name=`basename ${player_dir}` 7 | mkdir $player_dir/csvs 8 | for c in "white" "black"; do 9 | for s in "train" "validate" "test"; do 10 | target=$player_dir/split/${s}_${c}.pgn.bz2 11 | output=$player_dir/csvs/${s}_${c}.csv.bz2 12 | echo ${player_name} ${s} ${c} 13 | screen -S "csv-${player_name}-${c}-${s}" -dm bash -c "python3 ../../data_generators/pgn_to_csv.py ${target} ${output}" 14 | done 15 | done 16 | done 17 | -------------------------------------------------------------------------------- /1-data_generation/0-2-make_reduced_datasets.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | train_frac=80 4 | val_frac=10 5 | test_frac=10 6 | 7 | input_files="/maiadata/transfer_players_data/train" 8 | output_files="/maiadata/transfer_players_train_reduced" 9 | mkdir -p $output_files 10 | 11 | fractions='100 10 1' 12 | 13 | for frac in `echo $fractions`; do 14 | for player_file in $input_files/*.bz2; do 15 | f=${player_file##*/} 16 | p_name=${f%.pgn.bz2} 17 | p_dir=$output_files/$p_name/$frac 18 | split_dir=$output_files/$p_name/$frac/split 19 | mkdir -p $p_dir 20 | mkdir -p $split_dir 21 | 22 | python pgn_fractional_split.py $player_file $p_dir/raw_reduced.pgn.bz2 $p_dir/extra.pgn.bz2 --ratios $frac `echo "1000- $frac " | bc` 23 | 24 | echo $p_name $frac $p_dir 25 | python split_by_player.py $p_dir/raw_reduced.pgn.bz2 $p_name $split_dir/games 26 | 27 | 28 | for c in "white" "black"; do 29 | python pgn_fractional_split.py $split_dir/games_$c.pgn.bz2 $split_dir/train_$c.pgn.bz2 $split_dir/validate_$c.pgn.bz2 $split_dir/test_$c.pgn.bz2 --ratios $train_frac $val_frac $test_frac 30 | 31 | cd $p_dir 32 | mkdir -p pgns 33 | for s in "train" "validate" "test"; do 34 | mkdir -p $s 35 | mkdir $s/$c 36 | 37 | #using tool from: 38 | #https://www.cs.kent.ac.uk/people/staff/djb/pgn-extract/ 39 | bzcat $split_dir/${s}_${c}.pgn.bz2 | pgn-extract -7 -C -N -#1000 40 | 41 | cat *.pgn > pgns/${s}_${c}.pgn 42 | rm -v *.pgn 43 | 44 | #using tool from: 45 | #https://github.com/DanielUranga/trainingdata-tool 46 | screen -S "${p_name}-${c}-${s}" -dm bash -c "cd ${s}/${c}; trainingdata-tool -v ../../pgns/${s}_${c}.pgn" 47 | done 48 | cd - 49 | done 50 | done 51 | done 52 | -------------------------------------------------------------------------------- /1-data_generation/1-0-make_val_datasets.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | 
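# Same 80/10/10 per-player game split as the training scripts. The du-based
# size check below acts as a resume mechanism: players whose output directory
# already looks complete are skipped, the rest are wiped and regenerated.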
5 | train_frac=80 6 | val_frac=10 7 | test_frac=10 8 | 9 | input_files="/maiadata/transfer_players_data/validate" 10 | output_files="/maiadata/transfer_players_validate" 11 | mkdir -p $output_files 12 | 13 | for player_file in $input_files/*.bz2; do 14 | f=${player_file##*/} 15 | p_name=${f%.pgn.bz2} 16 | p_dir=$output_files/$p_name 17 | 18 | f_size=$(du -sb ${output_files}/${p_name} 2>/dev/null | cut -f1) 19 | if [ $((f_size)) -lt 50000 ]; then 20 | echo $f_size $p_dir 21 | rm -rvf $p_dir 22 | else 23 | continue 24 | fi 25 | 26 | split_dir=$output_files/$p_name/split 27 | mkdir -p $p_dir 28 | mkdir -p $split_dir 29 | echo $p_name $p_dir 30 | python split_by_player.py $player_file $p_name $split_dir/games 31 | 32 | for c in "white" "black"; do 33 | python pgn_fractional_split.py $split_dir/games_$c.pgn.bz2 $split_dir/train_$c.pgn.bz2 $split_dir/validate_$c.pgn.bz2 $split_dir/test_$c.pgn.bz2 --ratios $train_frac $val_frac $test_frac 34 | 35 | cd $p_dir 36 | mkdir -p pgns 37 | for s in "train" "validate" "test"; do 38 | mkdir -p $s 39 | mkdir $s/$c 40 | 41 | #using tool from: 42 | #https://www.cs.kent.ac.uk/people/staff/djb/pgn-extract/ 43 | bzcat $split_dir/${s}_${c}.pgn.bz2 | pgn-extract -7 -C -N -#1000 44 | 45 | cat *.pgn > pgns/${s}_${c}.pgn 46 | rm -v *.pgn 47 | 48 | #using tool from: 49 | #https://github.com/DanielUranga/trainingdata-tool 50 | screen -S "${p_name}-${c}-${s}" -dm bash -c "cd ${s}/${c}; trainingdata-tool -v ../../pgns/${s}_${c}.pgn" 51 | done 52 | cd - 53 | done 54 | while [ `screen -ls | wc -l` -gt 20 ]; do 55 | printf "waiting\r" 56 | sleep 10 57 | done 58 | done 59 | -------------------------------------------------------------------------------- /1-data_generation/1-1-make_val_csvs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | data_dir="../../data/transfer_players_validate" 4 | 5 | for player_dir in $data_dir/*; do 6 | player_name=`basename ${player_dir}` 7 | mkdir -p $player_dir/csvs 8 | for c in "white" "black"; do 9 | for s in "train" "validate" "test"; do 10 | target=$player_dir/split/${s}_${c}.pgn.bz2 11 | output=$player_dir/csvs/${s}_${c}.csv.bz2 12 | echo ${player_name} ${s} ${c} 13 | screen -S "csv-${player_name}-${c}-${s}" -dm bash -c "python3 ../../data_generators/pgn_to_csv.py ${target} ${output}" 14 | done 15 | done 16 | while [ `screen -ls | wc -l` -gt 50 ]; do 17 | printf "waiting\r" 18 | sleep 10 19 | done 20 | done 21 | -------------------------------------------------------------------------------- /1-data_generation/2-make_testing_datasets.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | train_frac=80 4 | val_frac=10 5 | test_frac=10 6 | 7 | input_files="/maiadata/transfer_players_data/test" 8 | output_files="/maiadata/transfer_players_test" 9 | mkdir -p $output_files 10 | 11 | for player_file in $input_files/*.bz2; do 12 | f=${player_file##*/} 13 | p_name=${f%.pgn.bz2} 14 | p_dir=$output_files/$p_name 15 | split_dir=$output_files/$p_name/split 16 | mkdir -p $p_dir 17 | mkdir -p $split_dir 18 | echo $p_name $p_dir 19 | python split_by_player.py $player_file $p_name $split_dir/games 20 | 21 | 22 | for c in "white" "black"; do 23 | python pgn_fractional_split.py $split_dir/games_$c.pgn.bz2 $split_dir/train_$c.pgn.bz2 $split_dir/validate_$c.pgn.bz2 $split_dir/test_$c.pgn.bz2 --ratios $train_frac $val_frac $test_frac 24 | 25 | cd $p_dir 26 | mkdir -p pgns 27 | for s in "train" "validate" "test"; do 28 | mkdir -p $s 29 | mkdir $s/$c 30 | 31 | 
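# pgn-extract flags: -7 keeps only the seven-tag roster headers, -C strips
# comments, -N strips NAGs, and -#1000 splits the output into numbered files
# of 1000 games each, which the cat below stitches back together.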
#using tool from: 32 | #https://www.cs.kent.ac.uk/people/staff/djb/pgn-extract/ 33 | bzcat $split_dir/${s}_${c}.pgn.bz2 | pgn-extract -7 -C -N -#1000 34 | 35 | cat *.pgn > pgns/${s}_${c}.pgn 36 | rm -v *.pgn 37 | 38 | #using tool from: 39 | #https://github.com/DanielUranga/trainingdata-tool 40 | screen -S "${p_name}-${c}-${s}" -dm bash -c "cd ${s}/${c}; trainingdata-tool -v ../../pgns/${s}_${c}.pgn" 41 | done 42 | cd - 43 | done 44 | while [ `screen -ls | wc -l` -gt 20 ]; do 45 | printf "waiting\r" 46 | sleep 10 47 | done 48 | 49 | done 50 | -------------------------------------------------------------------------------- /1-data_generation/9-pgn_to_training_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | #args input_path output_dir player 5 | 6 | player_file=${1} 7 | p_dir=${2} 8 | p_name=${3} 9 | 10 | train_frac=90 11 | val_frac=10 12 | 13 | split_dir=$p_dir/split 14 | 15 | mkdir -p ${p_dir} 16 | mkdir -p ${split_dir} 17 | 18 | echo "${p_name} to ${p_dir}" 19 | 20 | python split_by_player.py $player_file $p_name $split_dir/games 21 | 22 | for c in "white" "black"; do 23 | python pgn_fractional_split.py $split_dir/games_$c.pgn.bz2 $split_dir/train_$c.pgn.bz2 $split_dir/validate_$c.pgn.bz2 --ratios $train_frac $val_frac 24 | 25 | cd $p_dir 26 | mkdir -p pgns 27 | for s in "train" "validate"; do 28 | mkdir -p $s 29 | mkdir -p $s/$c 30 | 31 | #using tool from: 32 | #https://www.cs.kent.ac.uk/people/staff/djb/pgn-extract/ 33 | bzcat $split_dir/${s}_${c}.pgn.bz2 | pgn-extract -7 -C -N -#1000 34 | 35 | cat *.pgn > pgns/${s}_${c}.pgn 36 | rm -v *.pgn 37 | 38 | #using tool from: 39 | #https://github.com/DanielUranga/trainingdata-tool 40 | screen -S "${p_name}-${c}-${s}" -dm bash -c "cd ${s}/${c}; trainingdata-tool -v ../../pgns/${s}_${c}.pgn" 41 | done 42 | cd - 43 | done 44 | -------------------------------------------------------------------------------- /1-data_generation/pgn_fractional_split.py: -------------------------------------------------------------------------------- 1 | import backend 2 | 3 | import argparse 4 | import bz2 5 | import random 6 | 7 | @backend.logged_main 8 | def main(): 9 | parser = argparse.ArgumentParser(description='Split games into some number of subsets, by percentage', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 10 | 11 | parser.add_argument('input', help='input pgn') 12 | 13 | parser.add_argument('outputs', nargs='+', help='output pgn files', type = str) 14 | 15 | parser.add_argument('--ratios', nargs='+', help='ratios of games for the outputs', required = True, type = float) 16 | 17 | parser.add_argument('--no_shuffle', action='store_true', help='Stop output shuffling') 18 | parser.add_argument('--seed', type=int, help='random seed', default = 1) 19 | args = parser.parse_args() 20 | 21 | if len(args.ratios) != len(args.outputs): 22 | raise RuntimeError(f"Invalid outputs specified: {args.outputs} and {args.ratios}") 23 | 24 | random.seed(args.seed) 25 | games = backend.GamesFile(args.input) 26 | 27 | game_strs = [] 28 | 29 | for i, (d, l) in enumerate(games): 30 | game_strs.append(l) 31 | if i % 10000 == 0: 32 | backend.printWithDate(f"{i} done from {args.input}", end = '\r') 33 | backend.printWithDate(f"{i} done total from {args.input}") 34 | if not args.no_shuffle: 35 | random.shuffle(game_strs) 36 | 37 | split_indices = [int(r * len(game_strs) / sum(args.ratios)) for r in args.ratios] 38 | 39 | #Correction for rounding, not very precise 40 | split_indices[0] += 
len(game_strs) - sum(split_indices) 41 | 42 | for p, c in zip(args.outputs, split_indices): 43 | backend.printWithDate(f"Writing {c} games to: {p}") 44 | with bz2.open(p, 'wt') as f: 45 | f.write(''.join( 46 | [game_strs.pop() for i in range(c)] 47 | )) 48 | 49 | backend.printWithDate("done") 50 | 51 | if __name__ == '__main__': 52 | main() 53 | -------------------------------------------------------------------------------- /1-data_generation/player_splits.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | 3 | input_files="../../data/top_player_games" 4 | output_files="../../data/transfer_players_pgns_split" 5 | mkdir -p $output_files 6 | 7 | train_frac=80 8 | val_frac=10 9 | test_frac=10 10 | 11 | for p in $input_files/*; do 12 | name=`basename $p` 13 | p_name=${name%.pgn.bz2} 14 | split_dir=$output_files/$name 15 | mkdir -p $split_dir 16 | 17 | screen -S "${p_name}" -dm bash -c "python3 pgn_fractional_split.py $p $split_dir/train.pgn.bz2 $split_dir/validate.pgn.bz2 $split_dir/test.pgn.bz2 --ratios $train_frac $val_frac $test_frac" 18 | done 19 | -------------------------------------------------------------------------------- /1-data_generation/split_by_player.py: -------------------------------------------------------------------------------- 1 | import backend 2 | 3 | import argparse 4 | import bz2 5 | import random 6 | 7 | @backend.logged_main 8 | def main(): 9 | parser = argparse.ArgumentParser(description='Split games into games where the target was White or Black', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 10 | 11 | parser.add_argument('input', help='input pgn') 12 | parser.add_argument('player', help='target player name') 13 | parser.add_argument('output', help='output pgn prefix') 14 | parser.add_argument('--no_shuffle', action='store_true', help='Stop output shuffling') 15 | parser.add_argument('--seed', type=int, help='random seed', default = 1) 16 | args = parser.parse_args() 17 | 18 | random.seed(args.seed) 19 | 20 | games = backend.GamesFile(args.input) 21 | 22 | outputs_white = [] 23 | outputs_black = [] 24 | 25 | for i, (d, l) in enumerate(games): 26 | if d['White'] == args.player: 27 | outputs_white.append(l) 28 | elif d['Black'] == args.player: 29 | outputs_black.append(l) 30 | else: 31 | raise ValueError(f"{args.player} not found in game {i}:\n{l}") 32 | if i % 10000 == 0: 33 | backend.printWithDate(f"{i} done with {len(outputs_white)}:{len(outputs_black)} games from {args.input}", end = '\r') 34 | backend.printWithDate(f"{i} found totals of {len(outputs_white)}:{len(outputs_black)} games from {args.input}") 35 | backend.printWithDate("Writing white") 36 | with bz2.open(f"{args.output}_white.pgn.bz2", 'wt') as f: 37 | if not args.no_shuffle: 38 | random.shuffle(outputs_white) 39 | f.write(''.join(outputs_white)) 40 | backend.printWithDate("Writing black") 41 | with bz2.open(f"{args.output}_black.pgn.bz2", 'wt') as f: 42 | if not args.no_shuffle: 43 | random.shuffle(outputs_black) 44 | f.write(''.join(outputs_black)) 45 | backend.printWithDate("done") 46 | 47 | if __name__ == '__main__': 48 | main() 49 | -------------------------------------------------------------------------------- /2-training/extended_configs/frozen_copy/frozen_copy.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | #gpu: 1 4 | 5 | dataset: 6 | path: '/data/transfer_players_extended/' 7 | #name: '' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 256 12 | 
num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 150000 16 | checkpoint_steps: 500 17 | shuffle_size: 256 18 | lr_values: 19 | - 0.01 20 | - 0.001 21 | - 0.0001 22 | - 0.00001 23 | lr_boundaries: 24 | - 35000 25 | - 80000 26 | - 110000 27 | policy_loss_weight: 1.0 28 | value_loss_weight: 1.0 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | path: "maia/1900" 35 | keep_weights: true 36 | back_prop_blocks: 3 37 | ... 38 | -------------------------------------------------------------------------------- /2-training/extended_configs/frozen_random/frozen_random.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | #gpu: 1 4 | 5 | dataset: 6 | path: '/data/transfer_players_extended/' 7 | #name: '' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 256 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 150000 16 | checkpoint_steps: 500 17 | shuffle_size: 256 18 | lr_values: 19 | - 0.01 20 | - 0.001 21 | - 0.0001 22 | - 0.00001 23 | lr_boundaries: 24 | - 35000 25 | - 80000 26 | - 110000 27 | policy_loss_weight: 1.0 28 | value_loss_weight: 1.0 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | path: "maia/1900" 35 | keep_weights: false 36 | back_prop_blocks: 3 37 | ... 38 | -------------------------------------------------------------------------------- /2-training/extended_configs/unfrozen_copy/unfrozen_copy.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | #gpu: 1 4 | 5 | dataset: 6 | path: '/data/transfer_players_extended/' 7 | #name: '' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 256 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 150000 16 | checkpoint_steps: 500 17 | shuffle_size: 256 18 | lr_values: 19 | - 0.01 20 | - 0.001 21 | - 0.0001 22 | - 0.00001 23 | lr_boundaries: 24 | - 35000 25 | - 80000 26 | - 110000 27 | policy_loss_weight: 1.0 28 | value_loss_weight: 1.0 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | path: "maia/1900" 35 | keep_weights: true 36 | back_prop_blocks: 99 37 | ... 38 | -------------------------------------------------------------------------------- /2-training/extended_configs/unfrozen_random/unfrozen_random.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | #gpu: 1 4 | 5 | dataset: 6 | path: '/data/transfer_players_extended/' 7 | #name: '' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 256 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 150000 16 | checkpoint_steps: 500 17 | shuffle_size: 256 18 | lr_values: 19 | - 0.01 20 | - 0.001 21 | - 0.0001 22 | - 0.00001 23 | lr_boundaries: 24 | - 35000 25 | - 80000 26 | - 110000 27 | policy_loss_weight: 1.0 28 | value_loss_weight: 1.0 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | path: "maia/1900" 35 | keep_weights: false 36 | back_prop_blocks: 99 37 | ... 
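# The four extended_configs differ only in the last two model keys, giving a
# 2x2 grid: keep_weights (true = start from the maia-1900 weights, false =
# random re-initialization) crossed with back_prop_blocks (3 = only the last
# blocks receive gradients while earlier ones stay frozen; 99 = effectively
# the whole network is unfrozen).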
38 | -------------------------------------------------------------------------------- /2-training/final_config.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 0 4 | 5 | dataset: 6 | path: 'path to player data' 7 | #name: '' 8 | 9 | training: 10 | precision: 'half' 11 | batch_size: 256 12 | num_batch_splits: 1 13 | test_steps: 2000 14 | train_avg_report_steps: 50 15 | total_steps: 150000 16 | checkpoint_steps: 500 17 | shuffle_size: 256 18 | lr_values: 19 | - 0.01 20 | - 0.001 21 | - 0.0001 22 | - 0.00001 23 | lr_boundaries: 24 | - 35000 25 | - 80000 26 | - 110000 27 | policy_loss_weight: 1.0 28 | value_loss_weight: 1.0 29 | 30 | model: 31 | filters: 64 32 | residual_blocks: 6 33 | se_ratio: 8 34 | path: "maia-1900" 35 | keep_weights: true 36 | back_prop_blocks: 99 37 | ... 38 | -------------------------------------------------------------------------------- /2-training/train_transfer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path 4 | import yaml 5 | import sys 6 | import glob 7 | import gzip 8 | import random 9 | import multiprocessing 10 | import shutil 11 | 12 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 13 | import tensorflow as tf 14 | 15 | 16 | import backend 17 | import backend.tf_transfer 18 | 19 | SKIP = 32 20 | 21 | @backend.logged_main 22 | def main(config_path, name, collection_name, player_name, gpu, num_workers): 23 | output_name = os.path.join('models', collection_name, name + '.txt') 24 | 25 | with open(config_path) as f: 26 | cfg = yaml.safe_load(f.read()) 27 | 28 | if player_name is not None: 29 | cfg['dataset']['name'] = player_name 30 | if gpu is not None: 31 | cfg['gpu'] = gpu 32 | 33 | backend.printWithDate(yaml.dump(cfg, default_flow_style=False)) 34 | 35 | train_chunks_white, train_chunks_black = backend.tf_transfer.get_latest_chunks(os.path.join( 36 | cfg['dataset']['path'], 37 | cfg['dataset']['name'], 38 | 'train', 39 | )) 40 | val_chunks_white, val_chunks_black = backend.tf_transfer.get_latest_chunks(os.path.join( 41 | cfg['dataset']['path'], 42 | cfg['dataset']['name'], 43 | 'validate', 44 | )) 45 | 46 | shuffle_size = cfg['training']['shuffle_size'] 47 | total_batch_size = cfg['training']['batch_size'] 48 | backend.tf_transfer.ChunkParser.BATCH_SIZE = total_batch_size 49 | tfprocess = backend.tf_transfer.TFProcess(cfg, name, collection_name) 50 | 51 | train_parser = backend.tf_transfer.ChunkParser( 52 | backend.tf_transfer.FileDataSrc(train_chunks_white.copy(), train_chunks_black.copy()), 53 | shuffle_size=shuffle_size, 54 | sample=SKIP, 55 | batch_size=backend.tf_transfer.ChunkParser.BATCH_SIZE, 56 | workers=num_workers, 57 | ) 58 | train_dataset = tf.data.Dataset.from_generator( 59 | train_parser.parse, 60 | output_types=( 61 | tf.string, tf.string, tf.string, tf.string 62 | ), 63 | ) 64 | train_dataset = train_dataset.map( 65 | backend.tf_transfer.ChunkParser.parse_function) 66 | train_dataset = train_dataset.prefetch(4) 67 | 68 | test_parser = backend.tf_transfer.ChunkParser( 69 | backend.tf_transfer.FileDataSrc(val_chunks_white.copy(), val_chunks_black.copy()), 70 | shuffle_size=shuffle_size, 71 | sample=SKIP, 72 | batch_size=backend.tf_transfer.ChunkParser.BATCH_SIZE, 73 | workers=num_workers, 74 | ) 75 | test_dataset = tf.data.Dataset.from_generator( 76 | test_parser.parse, 77 | output_types=(tf.string, tf.string, tf.string, tf.string), 78 | ) 79 | test_dataset = test_dataset.map( 80 | 
backend.tf_transfer.ChunkParser.parse_function) 81 | test_dataset = test_dataset.prefetch(4) 82 | 83 | tfprocess.init_v2(train_dataset, test_dataset) 84 | 85 | tfprocess.restore_v2() 86 | 87 | num_evals = cfg['training'].get('num_test_positions', (len(val_chunks_white) + len(val_chunks_black)) * 10) 88 | num_evals = max(1, num_evals // backend.tf_transfer.ChunkParser.BATCH_SIZE) 89 | print("Using {} evaluation batches".format(num_evals)) 90 | try: 91 | tfprocess.process_loop_v2(total_batch_size, num_evals, batch_splits=1) 92 | except KeyboardInterrupt: 93 | backend.printWithDate("KeyboardInterrupt: Stopping") 94 | train_parser.shutdown() 95 | test_parser.shutdown() 96 | raise 97 | tfprocess.save_leelaz_weights_v2(output_name) 98 | 99 | train_parser.shutdown() 100 | test_parser.shutdown() 101 | return cfg 102 | 103 | def make_model_files(cfg, name, collection_name, save_dir): 104 | output_name = os.path.join(save_dir, collection_name, name) 105 | models_dir = os.path.join('models', collection_name, name) 106 | models = [(int(p.name.split('-')[1]), p.name, p.path) for p in os.scandir(models_dir) if p.name.endswith('.pb.gz')] 107 | top_model = max(models, key = lambda x : x[0]) 108 | 109 | os.makedirs(output_name, exist_ok=True) 110 | model_file_name = top_model[1].replace('ckpt', name) 111 | shutil.copy(top_model[2], os.path.join(output_name, model_file_name)) 112 | with open(os.path.join(output_name, "config.yaml"), 'w') as f: 113 | cfg_yaml = yaml.dump(cfg).replace('\n', '\n ').strip() 114 | f.write(f""" 115 | %YAML 1.2 116 | --- 117 | name: {name} 118 | display_name: {name.replace('_', ' ')} 119 | engine: lc0_23 120 | options: 121 | weightsPath: {model_file_name} 122 | full_config: 123 | {cfg_yaml} 124 | ...""") 125 | 126 | if __name__ == "__main__": 127 | parser = argparse.ArgumentParser(description='Tensorflow pipeline for training Leela Chess.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 128 | 129 | parser.add_argument('config', help='config file for model / training') 130 | parser.add_argument('player_name', nargs='?', help='player name to train on', default=None) 131 | parser.add_argument('--gpu', help='gpu to use', default = 0, type = int) 132 | parser.add_argument('--num_workers', help='number of worker threads to use', default = max(1, multiprocessing.cpu_count() - 2), type = int) 133 | parser.add_argument('--copy_dir', help='dir to save final models in', default = 'final_models') 134 | args = parser.parse_args() 135 | 136 | collection_name = os.path.basename(os.path.dirname(args.config)).replace('configs_', '') 137 | name = os.path.basename(args.config).split('.')[0] 138 | 139 | if args.player_name is not None: 140 | name = f"{args.player_name}_{name}" 141 | 142 | multiprocessing.set_start_method('spawn') 143 | cfg = main(args.config, name, collection_name, args.player_name, args.gpu, args.num_workers) 144 | make_model_files(cfg, name, collection_name, args.copy_dir) 145 | -------------------------------------------------------------------------------- /3-analysis/1-0-baselines_results.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | targets_dir="../../data/transfer_players_train" 4 | outputs_dir="../../data/transfer_results/train" 5 | 6 | maias="../../models/maia" 7 | stockfish="../../models/stockfish/stockfish_d15" 8 | leela="../../models/leela/sergio" 9 | 10 | 11 | mkdir -p $outputs_dir 12 | 13 | for player_dir in $targets_dir/*; do 14 | player=`basename ${player_dir}` 15 | 
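    # One screen session per (player, maia version, colour):
    # prediction_generator.py replays the player's held-out test moves through
    # the model and writes its per-move predictions to a CSV for the later
    # summary/accuracy stages.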
player_ret_dir=$outputs_dir/$player 16 | 17 | echo $player_dir 18 | 19 | mkdir -p $player_ret_dir 20 | mkdir -p $player_ret_dir/maia 21 | mkdir -p $player_ret_dir/leela 22 | #mkdir -p $player_ret_dir/stockfish 23 | for c in "white" "black"; do 24 | player_files=$player_dir/csvs/test_${c}.csv.bz2 25 | #screen -S "baseline-tests-${player}-leela-${c}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player} $leela ${player_files} ${player_ret_dir}/leela/segio_${c}.csv.bz2" 26 | for maia_path in $maias/*; do 27 | maia_name=`basename ${maia_path}` 28 | printf "$maia_name\r" 29 | screen -S "baseline-tests-${player}-${maia_name}-${c}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player} $maia_path ${player_files} ${player_ret_dir}/maia/${maia_name}_${c}.csv.bz2" 30 | done 31 | done 32 | done 33 | -------------------------------------------------------------------------------- /3-analysis/1-1-baselines_results_validation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | targets_dir="../../data/transfer_players_validate" 4 | outputs_dir="../../data/transfer_results/validate" 5 | 6 | maias="../../models/maia" 7 | stockfish="../../models/stockfish/stockfish_d15" 8 | leela="../../models/leela/sergio" 9 | 10 | 11 | mkdir -p $outputs_dir 12 | 13 | for player_dir in $targets_dir/*; do 14 | player=`basename ${player_dir}` 15 | player_ret_dir=$outputs_dir/$player 16 | 17 | echo $player_dir 18 | 19 | mkdir -p $player_ret_dir 20 | mkdir -p $player_ret_dir/maia 21 | mkdir -p $player_ret_dir/leela 22 | #mkdir -p $player_ret_dir/stockfish 23 | for c in "white" "black"; do 24 | player_files=$player_dir/csvs/test_${c}.csv.bz2 25 | #screen -S "baseline-tests-${player}-leela-${c}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player} $leela ${player_files} ${player_ret_dir}/leela/segio_${c}.csv.bz2" 26 | for maia_path in $maias/1{1..9..2}00; do 27 | maia_name=`basename ${maia_path}` 28 | printf "$maia_name\r" 29 | screen -S "baseline-tests-${player}-${maia_name}-${c}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player} $maia_path ${player_files} ${player_ret_dir}/maia/${maia_name}_${c}.csv.bz2" 30 | done 31 | done 32 | while [ `screen -ls | wc -l` -gt 70 ]; do 33 | printf "waiting\r" 34 | sleep 10 35 | done 36 | done 37 | -------------------------------------------------------------------------------- /3-analysis/2-0-baseline_results.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | max_screens=40 5 | 6 | targets_dir="../../data/transfer_players" 7 | outputs_dir="../../data/transfer_results" 8 | kdd_path="../../datasets/10000_full_2019-12.csv.bz2" 9 | 10 | mkdir -p $outputs_dir 11 | 12 | maias_dir=../../models/maia 13 | 14 | for t in "train" "extended" "validate"; do 15 | for player_dir in ${targets_dir}_${t}/*; do 16 | for model in $maias_dir/1{1..9..2}00; do 17 | maia_type=`basename ${model}` 18 | player=`basename ${player_dir}` 19 | player_ret_dir=$outputs_dir/$t/$player/maia 20 | mkdir -p $player_ret_dir 21 | echo $t $maia_type $player 22 | for c in "white" "black"; do 23 | player_files=${player_dir}/csvs/test_${c}.csv.bz2 24 | screen -S "baselines-${player}-${maia_type}-${c}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player} $model ${player_files} ${player_ret_dir}/${maia_type}_${c}.csv.bz2" 25 | done 26 | 
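        # Throttle pattern used by all the analysis scripts: `screen -ls`
        # prints one line per live session (plus a couple of header/footer
        # lines), so this busy-waits until the session count drops back under
        # $max_screens before launching more.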
while [ `screen -ls | wc -l` -gt $max_screens ]; do 27 | printf "waiting\r" 28 | sleep 10 29 | done 30 | done 31 | done 32 | done 33 | 34 | -------------------------------------------------------------------------------- /3-analysis/2-1-model_results.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | max_screens=50 5 | 6 | targets_dir="../../data/transfer_players" 7 | outputs_dir="../../data/transfer_results" 8 | kdd_path="../../datasets/10000_full_2019-12.csv.bz2" 9 | 10 | 11 | models_dir="../../transfer_training/final_models" 12 | mkdir -p $outputs_dir 13 | 14 | for model in $models_dir/*/*; do 15 | player=`python3 get_models_player.py ${model}` 16 | model_type=`dirname ${model}` 17 | model_type=`basename ${model_type}` 18 | model_name=`basename ${model}` 19 | #echo $model $model_type $model_name $player 20 | 21 | for c in "white" "black"; do 22 | for t in "train" "extended"; do 23 | player_files=${targets_dir}_${t}/$player/csvs/test_${c}.csv.bz2 24 | if [ -f "$player_files" ]; then 25 | echo $player_files 26 | player_ret_dir=$outputs_dir/$t/$player/transfer/$model_type 27 | mkdir -p $player_ret_dir 28 | screen -S "transfer-tests-${player}-${model_type}-${c}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player} $model ${player_files} ${player_ret_dir}/${model_name}_${c}.csv.bz2" 29 | fi 30 | done 31 | done 32 | while [ `screen -ls | wc -l` -gt $max_screens ]; do 33 | printf "waiting\r" 34 | sleep 10 35 | done 36 | #screen -S "transfer-tests-${player}-${model_type}-kdd" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py $model ${kdd_path} ${player_ret_dir}/${model_name}_kdd.csv.bz2" 37 | done 38 | 39 | 40 | -------------------------------------------------------------------------------- /3-analysis/2-2-model_results_val.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | max_screens=50 5 | 6 | targets_dir="../../data/transfer_players" 7 | outputs_dir="../../data/transfer_results_val" 8 | summaries_dir="../../data/transfer_summaries" 9 | kdd_path="../../data/reduced_kdd_test_set.csv.bz2" 10 | 11 | 12 | models_dir="../../transfer_training/final_models_val" 13 | mkdir -p $outputs_dir 14 | mkdir -p $summaries_dir 15 | 16 | for model in $models_dir/*/*; do 17 | player=`python3 get_models_player.py ${model}` 18 | model_type=`dirname ${model}` 19 | model_type=`basename ${model_type}` 20 | model_name=`basename ${model}` 21 | #echo $model $model_type $model_name $player 22 | for t in "train" "validate"; do 23 | for c in "white" "black"; do 24 | player_files=${targets_dir}_${t}/$player/csvs/test_${c}.csv.bz2 25 | if [ -f "$player_files" ]; then 26 | echo $player_files 27 | player_ret_dir=$outputs_dir/$t/$player/${t}/$model_type 28 | player_sum_dir=$summaries_dir/$t/$player/${t}/$model_type 29 | mkdir -p $player_ret_dir $player_sum_dir 30 | screen -S "val-tests-${player}-${model_type}-${c}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player} $model ${player_files} ${player_ret_dir}/${model_name}_${c}.csv.bz2;python3 make_summary.py ${player_ret_dir}/${model_name}_${c}.csv.bz2 ${player_sum_dir}/${model_name}_${c}.json" 31 | fi 32 | done 33 | done 34 | screen -S "val-tests-${player}-${model_type}-kdd" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py $model ${kdd_path} ${player_ret_dir}/${model_name}_kdd_reduced.csv.bz2;python3 make_summary.py 
${player_ret_dir}/${model_name}_kdd_reduced.csv.bz2 ${player_sum_dir}/${model_name}_kdd_reduced.json" 35 | 36 | while [ `screen -ls | wc -l` -gt $max_screens ]; do 37 | printf "waiting\r" 38 | sleep 10 39 | done 40 | #screen -S "transfer-tests-${player}-${model_type}-kdd" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py $model ${kdd_path} ${player_ret_dir}/${model_name}_kdd.csv.bz2" 41 | done 42 | 43 | 44 | -------------------------------------------------------------------------------- /3-analysis/3-0-model_cross_table.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | max_screens=40 5 | 6 | targets_dir="../../data/transfer_players" 7 | outputs_dir="../../data/transfer_results_cross" 8 | 9 | models_dir="../../transfer_training/final_models" 10 | 11 | target_models=`echo ../../transfer_training/final_models/{no_stop,unfrozen_copy}/*` 12 | 13 | mkdir -p $outputs_dir 14 | 15 | for model in $target_models; do 16 | player=`python3 get_models_player.py ${model}` 17 | model_type=`dirname ${model}` 18 | model_type=`basename ${model_type}` 19 | model_name=`basename ${model}` 20 | echo $player $model_type $model 21 | for c in "white" "black"; do 22 | for t in "train" "extended"; do 23 | player_files=${targets_dir}_${t}/$player/csvs/test_${c}.csv.bz2 24 | if [ -f "$player_files" ]; then 25 | player_ret_dir=$outputs_dir/$player 26 | mkdir -p $player_ret_dir 27 | echo $player_files 28 | for model2 in $target_models; do 29 | model2_name=`basename ${model2}` 30 | model2_player=`python3 get_models_player.py ${model2}` 31 | screen -S "cross-${player}-${model2_player}-${c}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player} $model2 ${player_files} ${player_ret_dir}/${model2_player}_${c}.csv.bz2" 32 | done 33 | while [ `screen -ls | wc -l` -gt $max_screens ]; do 34 | printf "waiting\r" 35 | sleep 10 36 | done 37 | fi 38 | done 39 | done 40 | done 41 | -------------------------------------------------------------------------------- /3-analysis/3-1-model_cross_table_val_generation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | max_screens=40 5 | 6 | targets_dir="../../data/transfer_players_validate" 7 | outputs_dir="../../data/transfer_players_validate_cross_csvs" 8 | 9 | mkdir -p $outputs_dir 10 | 11 | for player_dir in $targets_dir/*; do 12 | player=`basename ${player_dir}` 13 | echo $player 14 | mkdir -p ${outputs_dir}/${player} 15 | for c in "white" "black"; do 16 | screen -S "cross-${player}-${c}" -dm bash -c "source ~/.bashrc; python3 csv_trimmer.py ${player_dir}/csvs/test_${c}.csv.bz2 ${outputs_dir}/${player}/test_${c}_reduced.csv.bz2" 17 | done 18 | done 19 | 20 | -------------------------------------------------------------------------------- /3-analysis/3-2-model_cross_table_val.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | max_screens=40 5 | 6 | targets_dir="../../data/transfer_players" 7 | outputs_dir="../../data/transfer_results_cross_val" 8 | 9 | models_dir="../../transfer_training/final_models" 10 | 11 | target_models=`echo ../../transfer_training/final_models_val/unfrozen_copy/* ../../transfer_training/final_models/{no_stop,unfrozen_copy}/* ` 12 | 13 | mkdir -p $outputs_dir 14 | 15 | for model in $target_models; do 16 | player=`python3 get_models_player.py ${model}` 17 | model_type=`dirname ${model}` 18 | model_type=`basename 
${model_type}` 19 | model_name=`basename ${model}` 20 | echo $player $model_type $model 21 | for c in "white" "black"; do 22 | for t in "train" "extended" "validate"; do 23 | player_files=${targets_dir}_${t}/$player/csvs/test_${c}.csv.bz2 24 | if [ -f "$player_files" ]; then 25 | player_ret_dir=$outputs_dir/$player 26 | mkdir -p $player_ret_dir 27 | echo $player_files 28 | for model2 in $target_models; do 29 | model2_name=`basename ${model2}` 30 | model2_player=`python3 get_models_player.py ${model2}` 31 | screen -S "cross-${player}-${model2_player}-${c}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player} $model2 ${player_files} ${player_ret_dir}/${model2_player}_${c}.csv.bz2" 32 | done 33 | while [ `screen -ls | wc -l` -gt $max_screens ]; do 34 | printf "waiting\r" 35 | sleep 10 36 | done 37 | fi 38 | done 39 | done 40 | done 41 | -------------------------------------------------------------------------------- /3-analysis/4-0-result_summaries.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | max_screens=80 5 | 6 | targets_dir="../../data/transfer_results" 7 | outputs_dir="../../data/transfer_summaries" 8 | mkdir -p $outputs_dir 9 | 10 | for p in $targets_dir/*/*/*/*.bz2 $targets_dir/*/*/*/*/*.bz2; do 11 | #result=$(echo "$p" | sed "s/$targets_dir/$outputs_dir/g") 12 | out_path=${p/$targets_dir/$outputs_dir} 13 | out_path=${out_path/.csv.bz2/.json} 14 | base=`dirname ${p/$targets_dir/}` 15 | base=${base//\//-} 16 | mkdir -p `dirname $out_path` 17 | echo $base 18 | #"${${}/${outputs_dir}/${targets_dir}}" 19 | screen -S "summary${base}" -dm bash -c "source ~/.bashrc; python3 make_summary.py ${p} ${out_path}" 20 | while [ `screen -ls | wc -l` -gt $max_screens ]; do 21 | printf "waiting\r" 22 | sleep 10 23 | done 24 | done 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /3-analysis/4-1-result_summaries_cross.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | max_screens=80 5 | 6 | targets_dir="../../data/transfer_results_cross" 7 | outputs_dir="../../data/transfer_results_cross_summaries" 8 | mkdir -p $outputs_dir 9 | 10 | for p in $targets_dir/*/*.bz2; do 11 | #result=$(echo "$p" | sed "s/$targets_dir/$outputs_dir/g") 12 | out_path=${p/$targets_dir/$outputs_dir} 13 | out_path=${out_path/.csv.bz2/.json} 14 | base=`dirname ${p/$targets_dir/}` 15 | base=${base//\//-} 16 | mkdir -p `dirname $out_path` 17 | echo $base 18 | #"${${}/${outputs_dir}/${targets_dir}}" 19 | screen -S "summary${base}" -dm bash -c "source ~/.bashrc; python3 make_summary.py ${p} ${out_path}" 20 | while [ `screen -ls | wc -l` -gt $max_screens ]; do 21 | printf "waiting\r" 22 | sleep 10 23 | done 24 | done 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /3-analysis/4-2-result_summaries_val.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | max_screens=80 5 | 6 | targets_dir="../../data/transfer_results_val/validate" 7 | outputs_dir="../../data/transfer_summaries_val" 8 | mkdir -p $outputs_dir 9 | 10 | for p in $targets_dir/*/*/*/*.bz2; do 11 | #result=$(echo "$p" | sed "s/$targets_dir/$outputs_dir/g") 12 | out_path=${p/$targets_dir/$outputs_dir} 13 | out_path=${out_path/.csv.bz2/.json} 14 | base=`dirname ${p/$targets_dir/}` 15 | base=${base//\//-} 16 | mkdir -p `dirname $out_path` 17 
| echo $base 18 | #"${${}/${outputs_dir}/${targets_dir}}" 19 | screen -S "summary${base}" -dm bash -c "source ~/.bashrc; python3 make_summary.py ${p} ${out_path}" 20 | while [ `screen -ls | wc -l` -gt $max_screens ]; do 21 | printf "waiting\r" 22 | sleep 10 23 | done 24 | done 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /3-analysis/csv_trimmer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path 4 | import bz2 5 | import csv 6 | import multiprocessing 7 | import humanize 8 | import time 9 | import queue 10 | import json 11 | import pandas 12 | 13 | import chess 14 | 15 | import backend 16 | 17 | #@backend.logged_main 18 | def main(): 19 | parser = argparse.ArgumentParser(description='Trim a csv down to the first ngames games within a ply window', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 20 | 21 | parser.add_argument('input', help='input CSV') 22 | parser.add_argument('output', help='output CSV') 23 | parser.add_argument('--ngames', type=int, help='number of games to read in', default = 10) 24 | parser.add_argument('--min_ply', type=int, help='look at games with ply above this', default = 50) 25 | parser.add_argument('--max_ply', type=int, help='look at games with ply below this', default = 100) 26 | args = parser.parse_args() 27 | backend.printWithDate(f"Starting {args.input} to {args.output}") 28 | 29 | with bz2.open(args.input, 'rt') as fin, bz2.open(args.output, 'wt') as fout: 30 | reader = csv.DictReader(fin) 31 | writer = csv.DictWriter(fout, reader.fieldnames) 32 | writer.writeheader() 33 | games_count = 0 34 | current_game = None 35 | for row in reader: 36 | if args.min_ply is not None and int(row['num_ply']) <= args.min_ply: 37 | continue 38 | elif args.max_ply is not None and int(row['num_ply']) >= args.max_ply: 39 | continue 40 | elif row['game_id'] != current_game: 41 | current_game = row['game_id'] 42 | games_count += 1 43 | if args.ngames is not None and games_count > args.ngames: 44 | break 45 | writer.writerow(row) 46 | 47 | 48 | if __name__ == "__main__": 49 | main() 50 | -------------------------------------------------------------------------------- /3-analysis/get_accuracy.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path 4 | 5 | import pandas 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser(description='Quick helper for getting model accuracies', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 9 | parser.add_argument('inputs', nargs = '+', help='input CSVs') 10 | parser.add_argument('--nrows', help='num lines', type = int, default=None) 11 | args = parser.parse_args() 12 | 13 | for p in args.inputs: 14 | try: 15 | df = pandas.read_csv(p, nrows = args.nrows) 16 | except EOFError: 17 | print(f"{os.path.abspath(p).split('.')[0]} EOF") 18 | else: 19 | print(f"{os.path.abspath(p).split('.')[0]} {df['model_correct'].mean() * 100:.2f}%") 20 | 21 | if __name__ == "__main__": 22 | main() 23 | -------------------------------------------------------------------------------- /3-analysis/get_models_player.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path 4 | import yaml 5 | 6 | import pandas 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser(description='Quick helper for getting model players', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 10 | 
parser.add_argument('input', help='input model dir') 11 | args = parser.parse_args() 12 | 13 | conf_path = os.path.abspath(os.path.join(args.input, "config.yaml")) 14 | if os.path.isfile(conf_path): 15 | with open(conf_path) as f: 16 | cfg = yaml.safe_load(f) 17 | try: 18 | print(cfg['full_config']['name']) 19 | except (KeyError, TypeError): 20 | #some have corrupted configs 21 | if 'Eron_Capivara' in args.input: 22 | print('Eron_Capivara') #hack 23 | else: 24 | 25 | print(os.path.basename(os.path.dirname(conf_path)).split('_')[0]) 26 | else: 27 | raise FileNotFoundError(f"Not a config path: {conf_path}") 28 | 29 | if __name__ == "__main__": 30 | main() 31 | -------------------------------------------------------------------------------- /3-analysis/make_summary.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path 4 | import json 5 | import re 6 | import glob 7 | 8 | import pandas 9 | import numpy as np 10 | 11 | root_dir = os.path.relpath("../..", start=os.path.dirname(os.path.abspath(__file__))) 12 | root_dir = os.path.abspath(root_dir) 13 | cats_ply = { 14 | 'early' : (0, 10), 15 | 'mid' : (11, 50), 16 | 'late' : (51, 999), 17 | 'kdd' : (11, 999), 18 | } 19 | 20 | last_n = [2**n for n in range(12)] 21 | 22 | def main(): 23 | parser = argparse.ArgumentParser(description='Create summary json from results csv', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 24 | parser.add_argument('input', help='input CSV') 25 | parser.add_argument('output', help='output JSON') 26 | parser.add_argument('--players_infos', default="/ada/projects/chess/backend-backend/data/players_infos.json")#os.path.join(root_dir, 'data/players_infos.json')) 27 | args = parser.parse_args() 28 | 29 | with open(args.players_infos) as f: 30 | player_to_dat = json.load(f) 31 | 32 | fixed_data_paths = glob.glob("/ada/projects/chess/backend-backend/data/top_2000_player_data/*.csv.bz2") 33 | fixed_data_lookup = {p.split('/')[-1].replace('.csv.bz2','') :p for p in fixed_data_paths} 34 | 35 | 36 | df = collect_results_csv(args.input ,player_to_dat, fixed_data_lookup) 37 | 38 | r_d = dict(df.iloc[0]) 39 | sum_dict = { 40 | 'count' : len(df), 41 | 'player' : r_d['player_name'], 42 | 'model' : r_d['model_name'], 43 | 'backend' : bool(r_d['backend']), 44 | #'model_correct' : float(df['model_correct'].mean()), 45 | 'elo' : r_d['elo'], 46 | } 47 | c = 'white' if 'white' in args.input.split('/')[-1].split('_')[-1] else 'black' 48 | 49 | add_infos(sum_dict, "full", df) 50 | 51 | try: 52 | csv_raw_path = glob.glob(f"/ada/projects/chess/backend-backend/data/transfer_players_*/{r_d['player_name']}/csvs/test_{c}.csv.bz2")[0] 53 | except IndexError: 54 | if args.input.endswith('kdd.csv.bz2'): 55 | csv_raw_path = "/ada/projects/chess/backend-backend/data/reduced_kdd_test_set.csv.bz2" 56 | else: 57 | csv_raw_path = None 58 | if csv_raw_path is not None: 59 | csv_base = pandas.read_csv(csv_raw_path, index_col=['game_id', 'move_ply'], low_memory=False) 60 | csv_base['winrate_no_0'] = np.where(csv_base.reset_index()['move_ply'] < 2,np.nan, csv_base['winrate']) 61 | csv_base['next_wr'] = 1 - csv_base['winrate_no_0'].shift(-1) 62 | csv_base['move_delta_wr'] = csv_base['next_wr'] - csv_base['winrate'] 63 | csv_base_dropped = csv_base[~csv_base['winrate_loss'].isna()] 64 | csv_base_dropped = csv_base_dropped.join(df.set_index(['game_id', 'move_ply']), how = 'inner', lsuffix = 'r_') 65 | csv_base_dropped['move_delta_wr_rounded'] = 
(csv_base_dropped['move_delta_wr'] * 1).round(2) / 1 66 | 67 | for dr in csv_base_dropped['move_delta_wr_rounded'].unique(): 68 | if dr < 0 and dr >-.32: 69 | df_dr = csv_base_dropped[csv_base_dropped['move_delta_wr_rounded'] == dr] 70 | add_infos(sum_dict, f"delta_wr_{dr}", df_dr) 71 | 72 | for k, v in player_to_dat[r_d['player_name']].items(): 73 | if k != 'name': 74 | sum_dict[k] = v 75 | if r_d['backend']: 76 | sum_dict['backend_elo'] = int(r_d['model_name'].split('_')[-1]) 77 | 78 | 79 | for c, (p_min, p_max) in cats_ply.items(): 80 | df_c = df[(df['move_ply'] >= p_min) & (df['move_ply'] <= p_max)] 81 | add_infos(sum_dict, c, df_c) 82 | 83 | for year in df['UTCDate'].dt.year.unique(): 84 | df_y = df[df['UTCDate'].dt.year == year] 85 | add_infos(sum_dict, int(year), df_y) 86 | 87 | for ply in range(50): 88 | df_p = df[df['move_ply'] == ply] 89 | if len(df_p) > 0: 90 | # ignore the 50% missing ones 91 | add_infos(sum_dict, f"ply_{ply}", df_p) 92 | 93 | for won in [True, False]: 94 | df_w = df[df['won'] == won] 95 | add_infos(sum_dict, "won" if won else "lost", df_w) 96 | 97 | games = list(df.groupby('game_id').first().sort_values('UTCDate').index) 98 | 99 | for n in last_n: 100 | df_n = df[df['game_id'].isin(games[-n:])] 101 | add_infos(sum_dict, f"last_{n}", df_n) 102 | 103 | p_min, p_max = cats_ply['kdd'] 104 | df_n_kdd = df_n[(df_n['move_ply'] >= p_min) & (df_n['move_ply'] <= p_max)] 105 | add_infos(sum_dict, f"last_{n}_kdd", df_n_kdd) 106 | 107 | with open(args.output, 'wt') as f: 108 | json.dump(sum_dict, f) 109 | 110 | def collect_results_csv(path, player_to_dat, fixed_data_lookup): 111 | try: 112 | df = pandas.read_csv(path, low_memory=False) 113 | except EOFError: 114 | print(f"Error on: {path}") 115 | return None 116 | if len(df) < 1: 117 | return None 118 | df['colour'] = re.search("(black|white)\.csv\.bz2", path).group(1) 119 | #df['class'] = re.search(f"{base_dir}/([a-z]*)/", path).group(1) 120 | backend = 'final_backend_' in df['model_name'].iloc[0] 121 | df['backend'] = backend 122 | try: 123 | df['player'] = df['player_name'] 124 | except KeyError: 125 | pass 126 | try: 127 | if backend: 128 | df['model_type'] = df['model_name'].iloc[0].replace('final_', '') 129 | else: 130 | df['model_type'] = df['model_name'].iloc[0].replace(f'{df.iloc[0]["player"]}_', '') 131 | for k, v in player_to_dat[df.iloc[0]["player"]].items(): 132 | if k != 'name': 133 | df[k] = v 134 | games_df = pandas.read_csv(fixed_data_lookup[df['player'].iloc[0]], 135 | low_memory=False, parse_dates = ['UTCDate'], index_col = 'game_id') 136 | df = df.join(games_df, how= 'left', on = 'game_id', rsuffix='_per_game') 137 | except Exception as e: 138 | print(f"{e} : {path}") 139 | raise 140 | return df 141 | 142 | def add_infos(target_dict, name, df_sub): 143 | target_dict[f'model_correct_{name}'] = float(df_sub['model_correct'].dropna().mean()) 144 | target_dict[f'model_correct_per_game_{name}'] = float(df_sub.groupby(['game_id']).mean()['model_correct'].dropna().mean()) 145 | target_dict[f'count_{name}'] = len(df_sub) 146 | target_dict[f'std_{name}'] = float(df_sub['model_correct'].dropna().std()) 147 | target_dict[f'num_games_{name}'] = len(df_sub.groupby('game_id').count()) 148 | 149 | if __name__ == "__main__": 150 | main() 151 | -------------------------------------------------------------------------------- /3-analysis/move_predictions.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | data_dir="../../data/transfer_players_train" 5 | 
outputs_dir="../../data/transfer_players_train_results/weights_testing" 6 | maia_path="../../models/maia/1900" 7 | models_path="../models/weights_testing" 8 | 9 | kdd_path="../../datasets/10000_full_2019-12.csv.bz2" 10 | 11 | mkdir -p $outputs_dir 12 | 13 | for player_dir in $data_dir/*; do 14 | player_name=`basename ${player_dir}` 15 | echo $player_name 16 | mkdir -p $outputs_dir/$player_name 17 | 18 | for c in "white" "black"; do 19 | #echo "source ~/.bashrc; python3 prediction_generator.py --target_player ${player_name} ${models_path}/${player_name}*/ ${data_dir}/${player_name}/csvs/test_${c}.csv.bz2 $outputs_dir/${player_name}/transfer_test_${c}.csv.bz2" 20 | screen -S "test-transfer-${c}-${player_name}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player_name} ${models_path}/${player_name}*/ ${data_dir}/${player_name}/csvs/test_${c}.csv.bz2 $outputs_dir/${player_name}/transfer_test_${c}.csv.bz2" 21 | screen -S "test-maia-${c}-${player_name}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player_name} ${maia_path} ${data_dir}/${player_name}/csvs/test_${c}.csv.bz2 $outputs_dir/$player_name/maia_test_${c}.csv.bz2" 22 | done 23 | screen -S "kdd-transfer-${player_name}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py ${models_path}/${player_name}*/ ${kdd_path} $outputs_dir/$player_name/transfer_kdd.csv.bz2" 24 | 25 | while [ `screen -ls | wc -l` -gt 250 ]; do 26 | printf "waiting\r" 27 | sleep 10 28 | done 29 | done 30 | -------------------------------------------------------------------------------- /3-analysis/run-kdd-tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir -p ../data 4 | mkdir -p ../data/kdd_sweeps 5 | 6 | screen -S "kdd-sweep" -dm bash -c "source ~/.bashrc; python3 ../../analysis/move_prediction_csv.py ../../transfer_models ../../datasets/10000_full_2019-12.csv.bz2 ../data/kdd_sweeps" 7 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/README.md: -------------------------------------------------------------------------------- 1 | `cp_loss_hist`: This is where i save all the train/validation/test data for adding up all games 2 | 3 | `cp_loss_count_per_game`: This is where i save all the train/validation/test data per game. Note the counts haven't been normalized. 4 | 5 | `cp_loss_hist_per_move`: This is where i save all the train/validation/test data per move adding up all games. 6 | 7 | `cp_loss_hist_per_move_per_game`: This is where i save all the train/validation/test data per move per game. 8 | 9 | `cp_loss_hist_per_move_per_game_count`: This is where i save all the train/validation/test data per move per game in counts so they can be added later. 10 | 11 | `get_cp_loss.py`: Parsing code to get cp_loss and its histograms for both train and extended players, and save them in format of **.npy** 12 | 13 | `get_cp_loss_per_game.py`: Parsing code to get cp_loss and its histograms (counts) for extended players for each game, and save them in format of **.npy**. Note I don't normalize when saving, so I can sweep across it to get parametrization of num_games. 14 | 15 | `get_cp_loss_per_move.py`: Parsing code to get cp_loss and its histograms for both train and extended players for all games by moves, and save them in format of **.npy**. 
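All of the `get_cp_loss*` parsers listed here share one featurization step: per-move centipawn losses are bucketed into a 50-bin histogram over the range (0, 5) as raw counts (so white and black csvs can be summed), and the histogram variants are then L2-normalized. Below is a minimal sketch of that shared step, with the bin count and range copied from `get_cp_loss.py`; `cp_loss_histogram`, `l2_normalize`, and `nearest_player` are illustrative names rather than code from this repo, and `nearest_player` only sketches the Euclidean-distance matching used by the baselines described further down:

```python
import numpy as np

def cp_loss_histogram(cp_losses, bins=50, value_range=(0, 5)):
    # raw counts (density=False), so histograms can be added before normalizing
    counts, _ = np.histogram(cp_losses, density=False, bins=bins, range=value_range)
    return counts

def l2_normalize(counts):
    # the parsers call np.linalg.norm directly; the zero guard here is extra safety
    norm = np.linalg.norm(counts)
    return counts / norm if norm > 0 else counts

def nearest_player(test_hist, train_hists):
    # train_hists: {player_name: normalized histogram from that player's training games}
    return min(train_hists, key=lambda name: np.linalg.norm(test_hist - train_hists[name]))
```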
16 | 17 | `get_cp_loss_per_move_per_game.py`: Parsing code to get cp_loss and its histograms for both train and extended players for each game by moves, and save them in **.npy** format. 18 | 19 | `get_cp_loss_per_move_per_game_count.py`: Parsing code to get cp_loss and its histograms (counts) for both train and extended players for each game by moves, and save them in **.npy** format. 20 | 21 | `test_all_games.py`: Baseline that tests accuracy using all games instead of individual games, with Euclidean Distance or Naive Bayes. Data is from `cp_loss_hist`. 22 | 23 | `sweep_num_games.py`: Baseline using Euclidean Distance or Naive Bayes. Training data is from `cp_loss_hist` and test data is from `cp_loss_count_per_game`. Sweeps across [1, 2, 4, 8, 16] games. 24 | 25 | `sweep_moves_per_game.py`: Naive Bayes on per-move evaluation, reported as the average accuracy over games. Training data is from `cp_loss_hist_per_move`, test data is from `cp_loss_hist_per_move_per_game`. 26 | 27 | `sweep_moves_all_games.py`: Naive Bayes on per-move evaluation, done over all games added together. Data is from `cp_loss_hist_per_move`. 28 | 29 | `sweep_moves_num_games.py`: Naive Bayes on per-move evaluation given a number of games. Training data is from `cp_loss_hist_per_move`, test data is from `cp_loss_hist_per_move_per_game_count`. Setting the number of games to 1 is the same as `sweep_moves_per_game.py`. 30 | 31 | `train_cploss_per_game.py`: Baseline using a simple neural network with 2 fully-connected layers. Trains on each game, and also evaluates per-game accuracy. **This now gives NaN values when training on 30 players.** 32 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/get_cp_loss.py: -------------------------------------------------------------------------------- 1 | 2 | import bz2 3 | import csv 4 | import argparse 5 | import os 6 | import numpy as np 7 | # import matplotlib 8 | # matplotlib.use('TkAgg') 9 | import matplotlib.pyplot as plt 10 | import multiprocessing 11 | from functools import partial 12 | 13 | def parse_argument(): 14 | parser = argparse.ArgumentParser(description='arg parser') 15 | 16 | parser.add_argument('--input_dir', default='/data/transfer_players_validate') 17 | parser.add_argument('--player_name_dir', default='../transfer_training/final_models_val/unfrozen_copy') 18 | parser.add_argument('--saved_dir', default='cp_loss_hist') 19 | parser.add_argument('--will_save', default=True) 20 | 21 | return parser.parse_args() 22 | 23 | def normalize(data): 24 | norm = np.linalg.norm(data) 25 | data_norm = data/norm 26 | return data_norm 27 | 28 | def prepare_dataset(players, player_name, cp_hist, dataset): 29 | # add up black and white games (counts can be directly added) 30 | if players[player_name][dataset] is None: 31 | players[player_name][dataset] = cp_hist 32 | else: 33 | players[player_name][dataset] = players[player_name][dataset] + cp_hist 34 | 35 | def save_npy(saved_dir, players, player_name, dataset): 36 | if not os.path.exists(saved_dir): 37 | os.mkdir(saved_dir) 38 | 39 | saved = os.path.join(saved_dir, player_name + '_{}.npy'.format(dataset)) 40 | print('saving data to {}'.format(saved)) 41 | np.save(saved, players[player_name][dataset]) 42 | 43 | def multi_parse(input_dir, saved_dir, players, save, player_name): 44 | 45 | print("=============================================") 46 | print("parsing data for {}".format(player_name)) 47 | players[player_name] = {'train': None, 'validation': None, 'test': None} 
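# Each split ('train'/'validation'/'test') accumulates a single 50-bin histogram:
# white and black csvs are summed as raw counts in prepare_dataset, and the
# combined histogram is only L2-normalized once, right before saving.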
48 | 49 | csv_dir = os.path.join(input_dir, player_name, 'csvs') 50 | # for each csv, add up black and white games (counts can be directly added) 51 | for csv_fname in os.listdir(csv_dir): 52 | path = os.path.join(csv_dir, csv_fname) 53 | # parse bz2 file 54 | source_file = bz2.BZ2File(path, "r") 55 | cp_hist, num_games = get_cp_loss_from_csv(player_name, source_file) 56 | print(path) 57 | 58 | if csv_fname.startswith('train'): 59 | prepare_dataset(players, player_name, cp_hist, 'train') 60 | 61 | elif csv_fname.startswith('validate'): 62 | prepare_dataset(players, player_name, cp_hist, 'validation') 63 | 64 | elif csv_fname.startswith('test'): 65 | prepare_dataset(players, player_name, cp_hist, 'test') 66 | 67 | # normalize the histogram to range [0, 1] 68 | players[player_name]['train'] = normalize(players[player_name]['train']) 69 | players[player_name]['validation'] = normalize(players[player_name]['validation']) 70 | players[player_name]['test'] = normalize(players[player_name]['test']) 71 | 72 | # save for future use, parsing takes too long... 73 | if save: 74 | save_npy(saved_dir, players, player_name, 'train') 75 | save_npy(saved_dir, players, player_name, 'validation') 76 | save_npy(saved_dir, players, player_name, 'test') 77 | 78 | def construct_datasets(player_names, input_dir, saved_dir, will_save): 79 | players = {} 80 | 81 | pool = multiprocessing.Pool(25) 82 | func = partial(multi_parse, input_dir, saved_dir, players, will_save) 83 | pool.map(func, player_names) 84 | pool.close() 85 | pool.join() 86 | 87 | def get_cp_loss_from_csv(player_name, path): 88 | cp_losses = [] 89 | games = {} 90 | # 'path' is the already-open BZ2File passed in from multi_parse; iterate its raw byte lines 91 | for i, line in enumerate(path): 92 | if i > 0: 93 | line = line.decode("utf-8") 94 | row = line.rstrip().split(',') 95 | # avoid empty line 96 | if row[0] == '': 97 | continue 98 | 99 | game_id = row[0] 100 | cp_loss = row[17] 101 | active_player = row[25] 102 | if player_name != active_player: 103 | continue 104 | 105 | # ignore cases like -inf, inf, nan 106 | if cp_loss != str(-1 * np.inf) and cp_loss != str(np.inf) and cp_loss != 'nan': 107 | # append cp loss per move 108 | cp_losses.append(float(cp_loss)) 109 | 110 | # for purpose of counting how many games 111 | if game_id not in games: 112 | games[game_id] = 1 113 | 114 | 115 | ######################## plot for viewing ######################## 116 | # plt.hist(cp_losses, density=False, bins=50) 117 | # plt.ylabel('Count') 118 | # plt.xlabel('Cp Loss') 119 | # plt.show() 120 | 121 | cp_hist = np.histogram(cp_losses, density=False, bins=50, range=(0, 5)) # density=False for counts 122 | 123 | cp_hist = cp_hist[0] # cp_hist in format of (hist count, range) 124 | 125 | print("number of games: {}".format(len(games))) 126 | 127 | return cp_hist, len(games) 128 | 129 | 130 | def get_player_names(player_name_dir): 131 | 132 | player_names = [] 133 | for player_name in os.listdir(player_name_dir): 134 | player = player_name.replace("_unfrozen_copy", "") 135 | player_names.append(player) 136 | 137 | # print(player_names) 138 | return player_names 139 | 140 | 141 | if __name__ == '__main__': 142 | args = parse_argument() 143 | 144 | player_names = get_player_names(args.player_name_dir) 145 | 146 | construct_datasets(player_names, args.input_dir, args.saved_dir, args.will_save) 147 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/get_cp_loss_per_game.py: -------------------------------------------------------------------------------- 1 | import bz2 
2 | import csv 3 | import argparse 4 | import os 5 | import numpy as np 6 | # import matplotlib 7 | # matplotlib.use('TkAgg') 8 | import matplotlib.pyplot as plt 9 | import multiprocessing 10 | from functools import partial 11 | 12 | def parse_argument(): 13 | parser = argparse.ArgumentParser(description='arg parser') 14 | 15 | parser.add_argument('--input_dir', default='/data/transfer_players_validate') 16 | parser.add_argument('--player_name_dir', default='../transfer_training/final_models_val/unfrozen_copy') 17 | parser.add_argument('--saved_dir', default='cp_loss_count_per_game') 18 | parser.add_argument('--will_save', default=True) 19 | 20 | return parser.parse_args() 21 | 22 | def normalize(data): 23 | norm = np.linalg.norm(data) 24 | data_norm = data/norm 25 | return data_norm 26 | 27 | def prepare_dataset(players, player_name, games, dataset): 28 | # add up black and white games 29 | if players[player_name][dataset] is None: 30 | players[player_name][dataset] = games 31 | else: 32 | players[player_name][dataset].update(games) 33 | 34 | def save_npy(saved_dir, players, player_name, dataset): 35 | if not os.path.exists(saved_dir): 36 | os.mkdir(saved_dir) 37 | 38 | saved = os.path.join(saved_dir, player_name + '_{}.npy'.format(dataset)) 39 | print('saving data to {}'.format(saved)) 40 | print('total number of games: {}'.format(len(players[player_name][dataset]))) 41 | np.save(saved, players[player_name][dataset]) 42 | 43 | def multi_parse(input_dir, saved_dir, players, save, player_name): 44 | 45 | print("=============================================") 46 | print("parsing data for {}".format(player_name)) 47 | players[player_name] = {'train': None, 'validation': None, 'test': None} 48 | 49 | csv_dir = os.path.join(input_dir, player_name, 'csvs') 50 | for csv_fname in os.listdir(csv_dir): 51 | path = os.path.join(csv_dir, csv_fname) 52 | print(path) 53 | source_file = bz2.BZ2File(path, "r") 54 | games, num_games = get_cp_loss_from_csv(player_name, source_file) 55 | 56 | if csv_fname.startswith('train'): 57 | prepare_dataset(players, player_name, games, 'train') 58 | 59 | elif csv_fname.startswith('validate'): 60 | prepare_dataset(players, player_name, games, 'validation') 61 | 62 | elif csv_fname.startswith('test'): 63 | prepare_dataset(players, player_name, games, 'test') 64 | 65 | if save: 66 | save_npy(saved_dir, players, player_name, 'train') 67 | save_npy(saved_dir, players, player_name, 'validation') 68 | save_npy(saved_dir, players, player_name, 'test') 69 | 70 | def construct_datasets(player_names, input_dir, saved_dir, will_save): 71 | players = {} 72 | 73 | pool = multiprocessing.Pool(25) 74 | func = partial(multi_parse, input_dir, saved_dir, players, will_save) 75 | pool.map(func, player_names) 76 | pool.close() 77 | pool.join() 78 | 79 | def get_cp_loss_from_csv(player_name, path): 80 | cp_losses = [] 81 | games = {} 82 | # 'path' is the already-open BZ2File passed in from multi_parse; iterate its raw byte lines 83 | for i, line in enumerate(path): 84 | if i > 0: 85 | line = line.decode("utf-8") 86 | row = line.rstrip().split(',') 87 | # avoid empty line 88 | if row[0] == '': 89 | continue 90 | 91 | game_id = row[0] 92 | cp_loss = row[17] 93 | active_player = row[25] 94 | if player_name != active_player: 95 | continue 96 | 97 | if cp_loss != str(-1 * np.inf) and cp_loss != str(np.inf) and cp_loss != 'nan': 98 | cp_losses.append(float(cp_loss)) 99 | 100 | # for purpose of counting how many games 101 | if game_id not in games: 102 | games[game_id] = [float(cp_loss)] 103 | else: 104 | games[game_id].append(float(cp_loss)) 105 | 106 | 
# get per game histogram 107 | for key, value in games.items(): 108 | games[key] = np.histogram(games[key], density=False, bins=50, range=(0, 5)) 109 | # cp_hist in format (hist, range) 110 | games[key] = games[key][0] 111 | # games[key] = normalize(games[key]) 112 | 113 | print("number of games: {}".format(len(games))) 114 | 115 | return games, len(games) 116 | 117 | def get_player_names(player_name_dir): 118 | 119 | player_names = [] 120 | for player_name in os.listdir(player_name_dir): 121 | player = player_name.replace("_unfrozen_copy", "") 122 | player_names.append(player) 123 | 124 | # print(player_names) 125 | return player_names 126 | 127 | if __name__ == '__main__': 128 | args = parse_argument() 129 | 130 | player_names = get_player_names(args.player_name_dir) 131 | 132 | construct_datasets(player_names, args.input_dir, args.saved_dir, args.will_save) 133 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/get_cp_loss_per_move.py: -------------------------------------------------------------------------------- 1 | 2 | import bz2 3 | import csv 4 | import argparse 5 | import os 6 | import numpy as np 7 | # import matplotlib 8 | # matplotlib.use('TkAgg') 9 | import matplotlib.pyplot as plt 10 | import multiprocessing 11 | from functools import partial 12 | 13 | def parse_argument(): 14 | parser = argparse.ArgumentParser(description='arg parser') 15 | 16 | parser.add_argument('--input_dir', default='/data/transfer_players_validate') 17 | parser.add_argument('--player_name_dir', default='../transfer_training/final_models_val/unfrozen_copy') 18 | parser.add_argument('--saved_dir', default='cp_loss_hist_per_move') 19 | parser.add_argument('--will_save', default=True) 20 | 21 | return parser.parse_args() 22 | 23 | def normalize(data): 24 | for i in range(101): 25 | # update in-place 26 | if any(v != 0 for v in data['start_after'][i]): 27 | start_after_norm = np.linalg.norm(data['start_after'][i]) 28 | data['start_after'][i] = data['start_after'][i] / start_after_norm 29 | if any(v != 0 for v in data['stop_after'][i]): 30 | stop_after_norm = np.linalg.norm(data['stop_after'][i]) 31 | data['stop_after'][i] = data['stop_after'][i] / stop_after_norm 32 | 33 | 34 | def prepare_dataset(players, player_name, cp_loss_hist_dict, dataset): 35 | # add up black and white games (counts can be directly added) 36 | if players[player_name][dataset] is None: 37 | players[player_name][dataset] = cp_loss_hist_dict 38 | else: 39 | # add up each move 40 | for i in range(101): 41 | players[player_name][dataset]['start_after'][i] = players[player_name][dataset]['start_after'][i] + cp_loss_hist_dict['start_after'][i] 42 | players[player_name][dataset]['stop_after'][i] = players[player_name][dataset]['stop_after'][i] + cp_loss_hist_dict['stop_after'][i] 43 | 44 | def save_npy(saved_dir, players, player_name, dataset): 45 | if not os.path.exists(saved_dir): 46 | os.mkdir(saved_dir) 47 | 48 | saved = os.path.join(saved_dir, player_name + '_{}.npy'.format(dataset)) 49 | print('saving data to {}'.format(saved)) 50 | np.save(saved, players[player_name][dataset]) 51 | 52 | def multi_parse(input_dir, saved_dir, players, save, player_name): 53 | 54 | print("=============================================") 55 | print("parsing data for {}".format(player_name)) 56 | players[player_name] = {'train': None, 'validation': None, 'test': None} 57 | 58 | csv_dir = os.path.join(input_dir, player_name, 'csvs') 59 | # for each csv, add up black and white games (counts can be 
directly added) 60 | for csv_fname in os.listdir(csv_dir): 61 | path = os.path.join(csv_dir, csv_fname) 62 | # parse bz2 file 63 | source_file = bz2.BZ2File(path, "r") 64 | cp_loss_hist_dict, num_games = get_cp_loss_from_csv(player_name, source_file) 65 | print(path) 66 | 67 | if csv_fname.startswith('train'): 68 | prepare_dataset(players, player_name, cp_loss_hist_dict, 'train') 69 | 70 | elif csv_fname.startswith('validate'): 71 | prepare_dataset(players, player_name, cp_loss_hist_dict, 'validation') 72 | 73 | elif csv_fname.startswith('test'): 74 | prepare_dataset(players, player_name, cp_loss_hist_dict, 'test') 75 | 76 | # normalize the histogram to range [0, 1] 77 | normalize(players[player_name]['train']) 78 | normalize(players[player_name]['validation']) 79 | normalize(players[player_name]['test']) 80 | 81 | # save for future use, parsing takes too long... 82 | if save: 83 | save_npy(saved_dir, players, player_name, 'train') 84 | save_npy(saved_dir, players, player_name, 'validation') 85 | save_npy(saved_dir, players, player_name, 'test') 86 | 87 | def construct_datasets(player_names, input_dir, saved_dir, will_save): 88 | players = {} 89 | 90 | pool = multiprocessing.Pool(25) 91 | func = partial(multi_parse, input_dir, saved_dir, players, will_save) 92 | pool.map(func, player_names) 93 | pool.close() 94 | pool.join() 95 | 96 | def get_cp_loss_start_after(cp_losses, move_start=0): 97 | return cp_losses[move_start:] 98 | 99 | # 0 will be empty 100 | def get_cp_loss_stop_after(cp_losses, move_stop=100): 101 | return cp_losses[:move_stop] 102 | 103 | # move_stop is in range [0, 100] 104 | def get_cp_loss_from_csv(player_name, path): 105 | cp_loss_hist_dict = {'start_after': {}, 'stop_after': {}} 106 | # cp_losses = [] 107 | games = {} 108 | # 'path' is the already-open BZ2File passed in from multi_parse; iterate its raw byte lines 109 | for i, line in enumerate(path): 110 | if i > 0: 111 | line = line.decode("utf-8") 112 | row = line.rstrip().split(',') 113 | # avoid empty line 114 | if row[0] == '': 115 | continue 116 | 117 | active_player = row[25] 118 | if player_name != active_player: 119 | continue 120 | 121 | # move_ply starts from 0, need to add 1, move will be parsed in order 122 | move_ply = int(row[13]) 123 | move = move_ply // 2 + 1 124 | 125 | game_id = row[0] 126 | cp_loss = row[17] 127 | 128 | # ignore cases like -inf, inf, nan 129 | if cp_loss != str(-1 * np.inf) and cp_loss != str(np.inf) and cp_loss != 'nan': 130 | if game_id not in games: 131 | games[game_id] = [float(cp_loss)] 132 | else: 133 | games[game_id].append(float(cp_loss)) 134 | 135 | # get per game histogram 136 | for i in range(101): 137 | cp_loss_hist_dict['start_after'][i] = [] 138 | cp_loss_hist_dict['stop_after'][i] = [] 139 | for key, value in games.items(): 140 | cp_loss_hist_dict['start_after'][i].extend(get_cp_loss_start_after(value, i)) 141 | cp_loss_hist_dict['stop_after'][i].extend(get_cp_loss_stop_after(value, i)) 142 | 143 | # transform into counts 144 | cp_loss_hist_dict['start_after'][i] = np.histogram(cp_loss_hist_dict['start_after'][i], density=False, bins=50, range=(0, 5))[0] 145 | cp_loss_hist_dict['stop_after'][i] = np.histogram(cp_loss_hist_dict['stop_after'][i], density=False, bins=50, range=(0, 5))[0] 146 | 147 | 148 | print("number of games: {}".format(len(games))) 149 | 150 | return cp_loss_hist_dict, len(games) 151 | 152 | def get_player_names(player_name_dir): 153 | 154 | player_names = [] 155 | for player_name in os.listdir(player_name_dir): 156 | player = player_name.replace("_unfrozen_copy", "") 157 | player_names.append(player) 158 | 
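# Model directories under player_name_dir are named <player>_unfrozen_copy,
# so stripping the suffix recovers the player name.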
159 | # print(player_names) 160 | return player_names 161 | 162 | if __name__ == '__main__': 163 | args = parse_argument() 164 | 165 | player_names = get_player_names(args.player_name_dir) 166 | 167 | construct_datasets(player_names, args.input_dir, args.saved_dir, args.will_save) 168 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/get_cp_loss_per_move_per_game.py: -------------------------------------------------------------------------------- 1 | 2 | import bz2 3 | import csv 4 | import argparse 5 | import os 6 | import numpy as np 7 | # import matplotlib 8 | # matplotlib.use('TkAgg') 9 | import matplotlib.pyplot as plt 10 | import multiprocessing 11 | from functools import partial 12 | 13 | def parse_argument(): 14 | parser = argparse.ArgumentParser(description='arg parser') 15 | 16 | parser.add_argument('--input_dir', default='/data/transfer_players_validate') 17 | parser.add_argument('--player_name_dir', default='../transfer_training/final_models_val/unfrozen_copy') 18 | parser.add_argument('--saved_dir', default='cp_loss_hist_per_move_per_game') 19 | parser.add_argument('--will_save', default=True) 20 | 21 | return parser.parse_args() 22 | 23 | def normalize(data): 24 | data_norm = data 25 | 26 | if any(v != 0 for v in data): 27 | norm = np.linalg.norm(data) 28 | data_norm = data/norm 29 | 30 | return data_norm 31 | 32 | def prepare_dataset(players, player_name, games, dataset): 33 | # add up black and white games 34 | if players[player_name][dataset] is None: 35 | players[player_name][dataset] = games 36 | else: 37 | players[player_name][dataset].update(games) 38 | 39 | 40 | def save_npy(saved_dir, players, player_name, dataset): 41 | if not os.path.exists(saved_dir): 42 | os.mkdir(saved_dir) 43 | 44 | saved = os.path.join(saved_dir, player_name + '_{}.npy'.format(dataset)) 45 | print('saving data to {}'.format(saved)) 46 | np.save(saved, players[player_name][dataset]) 47 | 48 | def multi_parse(input_dir, saved_dir, players, save, player_name): 49 | 50 | print("=============================================") 51 | print("parsing data for {}".format(player_name)) 52 | players[player_name] = {'train': None, 'validation': None, 'test': None} 53 | 54 | csv_dir = os.path.join(input_dir, player_name, 'csvs') 55 | # for each csv, add up black and white games (counts can be directly added) 56 | for csv_fname in os.listdir(csv_dir): 57 | path = os.path.join(csv_dir, csv_fname) 58 | print(path) 59 | # parse bz2 file 60 | source_file = bz2.BZ2File(path, "r") 61 | games, num_games = get_cp_loss_from_csv(player_name, source_file) 62 | 63 | if csv_fname.startswith('train'): 64 | prepare_dataset(players, player_name, games, 'train') 65 | 66 | elif csv_fname.startswith('validate'): 67 | prepare_dataset(players, player_name, games, 'validation') 68 | 69 | elif csv_fname.startswith('test'): 70 | prepare_dataset(players, player_name, games, 'test') 71 | 72 | # save for future use, parsing takes too long... 
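# The saved .npy maps game_id -> {'start_after': {i: hist}, 'stop_after': {i: hist}}
# for i in 0..100: 'start_after'[i] is built from the player's moves from index i
# onwards, 'stop_after'[i] from the moves before index i (see get_cp_loss_from_csv below).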
73 | if save: 74 | save_npy(saved_dir, players, player_name, 'train') 75 | save_npy(saved_dir, players, player_name, 'validation') 76 | save_npy(saved_dir, players, player_name, 'test') 77 | 78 | def construct_datasets(player_names, input_dir, saved_dir, will_save): 79 | players = {} 80 | 81 | pool = multiprocessing.Pool(25) 82 | func = partial(multi_parse, input_dir, saved_dir, players, will_save) 83 | pool.map(func, player_names) 84 | pool.close() 85 | pool.join() 86 | 87 | 88 | def get_cp_loss_start_after(cp_losses, move_start=0): 89 | return cp_losses[move_start:] 90 | 91 | # 0 will be empty 92 | def get_cp_loss_stop_after(cp_losses, move_stop=100): 93 | return cp_losses[:move_stop] 94 | 95 | # move_stop is in range [0, 100] 96 | def get_cp_loss_from_csv(player_name, path): 97 | # cp_losses = [] 98 | games = {} 99 | # 'path' is the already-open BZ2File passed in from multi_parse; iterate its raw byte lines 100 | for i, line in enumerate(path): 101 | if i > 0: 102 | line = line.decode("utf-8") 103 | row = line.rstrip().split(',') 104 | # avoid empty line 105 | if row[0] == '': 106 | continue 107 | 108 | active_player = row[25] 109 | if player_name != active_player: 110 | continue 111 | 112 | # move_ply starts from 0, need to add 1, move will be parsed in order 113 | move_ply = int(row[13]) 114 | move = move_ply // 2 + 1 115 | 116 | game_id = row[0] 117 | cp_loss = row[17] 118 | 119 | # ignore cases like -inf, inf, nan 120 | if cp_loss != str(-1 * np.inf) and cp_loss != str(np.inf) and cp_loss != 'nan': 121 | if game_id not in games: 122 | games[game_id] = [float(cp_loss)] 123 | else: 124 | games[game_id].append(float(cp_loss)) 125 | 126 | final_games = {key: {'start_after': {}, 'stop_after': {}} for key in games.keys()} 127 | 128 | # get per game histogram 129 | for i in range(101): 130 | for key, value in games.items(): 131 | final_games[key]['start_after'][i] = get_cp_loss_start_after(value, i) 132 | final_games[key]['stop_after'][i] = get_cp_loss_stop_after(value, i) 133 | 134 | final_games[key]['start_after'][i] = np.histogram(final_games[key]['start_after'][i], density=False, bins=50, range=(0, 5))[0] 135 | final_games[key]['stop_after'][i] = np.histogram(final_games[key]['stop_after'][i], density=False, bins=50, range=(0, 5))[0] 136 | 137 | final_games[key]['start_after'][i] = normalize(final_games[key]['start_after'][i]) 138 | final_games[key]['stop_after'][i] = normalize(final_games[key]['stop_after'][i]) 139 | 140 | print("number of games: {}".format(len(games))) 141 | 142 | return final_games, len(games) 143 | 144 | 145 | def get_player_names(player_name_dir): 146 | 147 | player_names = [] 148 | for player_name in os.listdir(player_name_dir): 149 | player = player_name.replace("_unfrozen_copy", "") 150 | player_names.append(player) 151 | 152 | # print(player_names) 153 | return player_names 154 | 155 | if __name__ == '__main__': 156 | args = parse_argument() 157 | 158 | player_names = get_player_names(args.player_name_dir) 159 | 160 | construct_datasets(player_names, args.input_dir, args.saved_dir, args.will_save) 161 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/get_cp_loss_per_move_per_game_count.py: -------------------------------------------------------------------------------- 1 | 2 | import bz2 3 | import csv 4 | import argparse 5 | import os 6 | import numpy as np 7 | # import matplotlib 8 | # matplotlib.use('TkAgg') 9 | import matplotlib.pyplot as plt 10 | import multiprocessing 11 | from functools import partial 12 | 13 | def parse_argument(): 14 | parser = 
argparse.ArgumentParser(description='arg parser') 15 | 16 | parser.add_argument('--input_dir', default='/data/transfer_players_validate') 17 | parser.add_argument('--player_name_dir', default='../transfer_training/final_models_val/unfrozen_copy') 18 | parser.add_argument('--saved_dir', default='cp_loss_hist_per_move_per_game_count') 19 | parser.add_argument('--will_save', default=True) 20 | 21 | return parser.parse_args() 22 | 23 | def normalize(data): 24 | data_norm = data 25 | 26 | if any(v != 0 for v in data): 27 | norm = np.linalg.norm(data) 28 | data_norm = data/norm 29 | 30 | return data_norm 31 | 32 | def prepare_dataset(players, player_name, games, dataset): 33 | # add up black and white games 34 | if players[player_name][dataset] is None: 35 | players[player_name][dataset] = games 36 | else: 37 | players[player_name][dataset].update(games) 38 | 39 | 40 | def save_npy(saved_dir, players, player_name, dataset): 41 | if not os.path.exists(saved_dir): 42 | os.mkdir(saved_dir) 43 | 44 | saved = os.path.join(saved_dir, player_name + '_{}.npy'.format(dataset)) 45 | print('saving data to {}'.format(saved)) 46 | np.save(saved, players[player_name][dataset]) 47 | 48 | def multi_parse(input_dir, saved_dir, players, save, player_name): 49 | 50 | print("=============================================") 51 | print("parsing data for {}".format(player_name)) 52 | players[player_name] = {'train': None, 'validation': None, 'test': None} 53 | 54 | csv_dir = os.path.join(input_dir, player_name, 'csvs') 55 | # for each csv, add up black and white games (counts can be directly added) 56 | for csv_fname in os.listdir(csv_dir): 57 | path = os.path.join(csv_dir, csv_fname) 58 | print(path) 59 | # parse bz2 file 60 | source_file = bz2.BZ2File(path, "r") 61 | games, num_games = get_cp_loss_from_csv(player_name, source_file) 62 | 63 | if csv_fname.startswith('train'): 64 | prepare_dataset(players, player_name, games, 'train') 65 | 66 | elif csv_fname.startswith('validate'): 67 | prepare_dataset(players, player_name, games, 'validation') 68 | 69 | elif csv_fname.startswith('test'): 70 | prepare_dataset(players, player_name, games, 'test') 71 | 72 | # save for future use, parsing takes too long... 
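# Unlike get_cp_loss_per_move_per_game.py, the histograms here are left as raw
# counts (the normalize calls below are commented out) so several games can be
# summed later, and only games with more than 25 and fewer than 50 recorded
# moves are kept.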
73 | if save: 74 | save_npy(saved_dir, players, player_name, 'train') 75 | save_npy(saved_dir, players, player_name, 'validation') 76 | save_npy(saved_dir, players, player_name, 'test') 77 | 78 | def construct_datasets(player_names, input_dir, saved_dir, will_save): 79 | players = {} 80 | 81 | pool = multiprocessing.Pool(40) 82 | func = partial(multi_parse, input_dir, saved_dir, players, will_save) 83 | pool.map(func, player_names) 84 | pool.close() 85 | pool.join() 86 | 87 | 88 | def get_cp_loss_start_after(cp_losses, move_start=0): 89 | return cp_losses[move_start:] 90 | 91 | # 0 will be empty 92 | def get_cp_loss_stop_after(cp_losses, move_stop=100): 93 | return cp_losses[:move_stop] 94 | 95 | # move_stop is in range [0, 100] 96 | def get_cp_loss_from_csv(player_name, path): 97 | # cp_losses = [] 98 | games = {} 99 | # 'path' is the already-open BZ2File passed in from multi_parse; iterate its raw byte lines 100 | for i, line in enumerate(path): 101 | if i > 0: 102 | line = line.decode("utf-8") 103 | row = line.rstrip().split(',') 104 | # avoid empty line 105 | if row[0] == '': 106 | continue 107 | 108 | active_player = row[25] 109 | if player_name != active_player: 110 | continue 111 | 112 | # move_ply starts from 0, need to add 1, move will be parsed in order 113 | move_ply = int(row[13]) 114 | move = move_ply // 2 + 1 115 | 116 | game_id = row[0] 117 | cp_loss = row[17] 118 | 119 | if game_id in games: 120 | if cp_loss == str(-1 * np.inf) or cp_loss == str(np.inf) or cp_loss == 'nan': 121 | cp_loss = float(-100) 122 | 123 | # ignore cases like -inf, inf, nan 124 | if cp_loss != str(-1 * np.inf) and cp_loss != str(np.inf) and cp_loss != 'nan': 125 | if game_id not in games: 126 | games[game_id] = [float(cp_loss)] 127 | else: 128 | games[game_id].append(float(cp_loss)) 129 | 130 | final_games = {key: {'start_after': {}, 'stop_after': {}} for key, value in games.items() if len(value) > 25 and len(value) < 50} 131 | # final_games = {key: {'start_after': {}, 'stop_after': {}} for key in games.keys()} 132 | 133 | # get per game histogram 134 | for i in range(101): 135 | for key, value in games.items(): 136 | if len(value) > 25 and len(value) < 50: 137 | if key not in final_games: 138 | print(key) 139 | 140 | final_games[key]['start_after'][i] = get_cp_loss_start_after(value, i) 141 | final_games[key]['stop_after'][i] = get_cp_loss_stop_after(value, i) 142 | 143 | final_games[key]['start_after'][i] = np.histogram(final_games[key]['start_after'][i], density=False, bins=50, range=(0, 5))[0] 144 | final_games[key]['stop_after'][i] = np.histogram(final_games[key]['stop_after'][i], density=False, bins=50, range=(0, 5))[0] 145 | 146 | # final_games[key]['start_after'][i] = normalize(final_games[key]['start_after'][i]) 147 | # final_games[key]['stop_after'][i] = normalize(final_games[key]['stop_after'][i]) 148 | 149 | print("number of games: {}".format(len(games))) 150 | return final_games, len(games) 151 | 152 | 153 | def get_player_names(player_name_dir): 154 | 155 | player_names = [] 156 | for player_name in os.listdir(player_name_dir): 157 | player = player_name.replace("_unfrozen_copy", "") 158 | player_names.append(player) 159 | 160 | # print(player_names) 161 | return player_names 162 | 163 | if __name__ == '__main__': 164 | args = parse_argument() 165 | 166 | player_names = get_player_names(args.player_name_dir) 167 | 168 | construct_datasets(player_names, args.input_dir, args.saved_dir, args.will_save) 169 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/results/games_accuracy.csv: 
-------------------------------------------------------------------------------- 1 | num_games,accuracy 2 | 1,0.05579916684937514 3 | 2,0.06811689738519007 4 | 4,0.08245149911816578 5 | 8,0.11002661934338953 6 | 16,0.1223021582733813 7 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/results/start_after.csv: -------------------------------------------------------------------------------- 1 | move,accuracy 2 | 0,0.059309021113243765 3 | 1,0.05508637236084453 4 | 2,0.053358925143953934 5 | 3,0.052207293666026874 6 | 4,0.050287907869481764 7 | 5,0.054318618042226485 8 | 6,0.05182341650671785 9 | 7,0.05105566218809981 10 | 8,0.05259117082533589 11 | 9,0.05356114417354579 12 | 10,0.05586484929928969 13 | 11,0.05682472643501632 14 | 12,0.05779569892473118 15 | 13,0.05915114269252929 16 | 14,0.05859750240153699 17 | 15,0.05573707476455891 18 | 16,0.05714835482008851 19 | 17,0.05772200772200772 20 | 18,0.05515773175924134 21 | 19,0.05539358600583091 22 | 20,0.05190989226248776 23 | 21,0.05554457402648745 24 | 22,0.05923836389280677 25 | 23,0.06203627370156636 26 | 24,0.0579647917561185 27 | 25,0.053575482406356414 28 | 26,0.053108174253548704 29 | 27,0.05705944798301486 30 | 28,0.061177152797912436 31 | 29,0.048 32 | 30,0.05115452930728242 33 | 31,0.045636509207365894 34 | 32,0.040467625899280574 35 | 33,0.04314329738058552 36 | 34,0.03882915173237754 37 | 35,0.03529411764705882 38 | 36,0.0575831305758313 39 | 37,0.04664723032069971 40 | 38,0.05052878965922444 41 | 39,0.04850213980028531 42 | 40,0.06204379562043796 43 | 41,0.06697459584295612 44 | 42,0.06382978723404255 45 | 43,0.05761316872427984 46 | 44,0.08928571428571429 47 | 45,0.13274336283185842 48 | 46,0.1506849315068493 49 | 47,0.2571428571428571 50 | 48,0.3076923076923077 51 | 49,0 52 | 50,0 53 | 51,0 54 | 52,0 55 | 53,0 56 | 54,0 57 | 55,0 58 | 56,0 59 | 57,0 60 | 58,0 61 | 59,0 62 | 60,0 63 | 61,0 64 | 62,0 65 | 63,0 66 | 64,0 67 | 65,0 68 | 66,0 69 | 67,0 70 | 68,0 71 | 69,0 72 | 70,0 73 | 71,0 74 | 72,0 75 | 73,0 76 | 74,0 77 | 75,0 78 | 76,0 79 | 77,0 80 | 78,0 81 | 79,0 82 | 80,0 83 | 81,0 84 | 82,0 85 | 83,0 86 | 84,0 87 | 85,0 88 | 86,0 89 | 87,0 90 | 88,0 91 | 89,0 92 | 90,0 93 | 91,0 94 | 92,0 95 | 93,0 96 | 94,0 97 | 95,0 98 | 96,0 99 | 97,0 100 | 98,0 101 | 99,0 102 | 100,0 103 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/results/start_after_all_game.csv: -------------------------------------------------------------------------------- 1 | move,accuracy 2 | 0,0.6333333333333333 3 | 1,0.6666666666666666 4 | 2,0.5666666666666667 5 | 3,0.5666666666666667 6 | 4,0.5333333333333333 7 | 5,0.5333333333333333 8 | 6,0.5666666666666667 9 | 7,0.5333333333333333 10 | 8,0.5 11 | 9,0.5666666666666667 12 | 10,0.5 13 | 11,0.4666666666666667 14 | 12,0.4666666666666667 15 | 13,0.5 16 | 14,0.5 17 | 15,0.43333333333333335 18 | 16,0.3 19 | 17,0.23333333333333334 20 | 18,0.3 21 | 19,0.3333333333333333 22 | 20,0.3 23 | 21,0.36666666666666664 24 | 22,0.3333333333333333 25 | 23,0.26666666666666666 26 | 24,0.3 27 | 25,0.23333333333333334 28 | 26,0.26666666666666666 29 | 27,0.36666666666666664 30 | 28,0.26666666666666666 31 | 29,0.3 32 | 30,0.4666666666666667 33 | 31,0.43333333333333335 34 | 32,0.36666666666666664 35 | 33,0.36666666666666664 36 | 34,0.36666666666666664 37 | 35,0.3333333333333333 38 | 36,0.4 39 | 37,0.36666666666666664 40 | 38,0.4 41 | 39,0.4 42 | 40,0.4 43 | 41,0.23333333333333334 44 | 42,0.3333333333333333 45 | 
43,0.2857142857142857 46 | 44,0.3333333333333333 47 | 45,0.2222222222222222 48 | 46,0.28 49 | 47,0.5 50 | 48,0.3333333333333333 51 | 49,0 52 | 50,0 53 | 51,0 54 | 52,0 55 | 53,0 56 | 54,0 57 | 55,0 58 | 56,0 59 | 57,0 60 | 58,0 61 | 59,0 62 | 60,0 63 | 61,0 64 | 62,0 65 | 63,0 66 | 64,0 67 | 65,0 68 | 66,0 69 | 67,0 70 | 68,0 71 | 69,0 72 | 70,0 73 | 71,0 74 | 72,0 75 | 73,0 76 | 74,0 77 | 75,0 78 | 76,0 79 | 77,0 80 | 78,0 81 | 79,0 82 | 80,0 83 | 81,0 84 | 82,0 85 | 83,0 86 | 84,0 87 | 85,0 88 | 86,0 89 | 87,0 90 | 88,0 91 | 89,0 92 | 90,0 93 | 91,0 94 | 92,0 95 | 93,0 96 | 94,0 97 | 95,0 98 | 96,0 99 | 97,0 100 | 98,0 101 | 99,0 102 | 100,0 103 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/results/stop_after.csv: -------------------------------------------------------------------------------- 1 | move,accuracy 2 | 0,0.01689059500959693 3 | 1,0.10134357005758157 4 | 2,0.07044145873320537 5 | 3,0.06967370441458733 6 | 4,0.0783109404990403 7 | 5,0.06756238003838771 8 | 6,0.06641074856046066 9 | 7,0.06602687140115163 10 | 8,0.06506717850287908 11 | 9,0.06641074856046066 12 | 10,0.06238003838771593 13 | 11,0.05911708253358925 14 | 12,0.0581573896353167 15 | 13,0.059309021113243765 16 | 14,0.0600767754318618 17 | 15,0.06065259117082534 18 | 16,0.06218809980806142 19 | 17,0.06065259117082534 20 | 18,0.061420345489443376 21 | 19,0.061036468330134354 22 | 20,0.060460652591170824 23 | 21,0.061420345489443376 24 | 22,0.060844529750479846 25 | 23,0.06161228406909789 26 | 24,0.0600767754318618 27 | 25,0.05969289827255278 28 | 26,0.060844529750479846 29 | 27,0.06238003838771593 30 | 28,0.05969289827255278 31 | 29,0.05969289827255278 32 | 30,0.06333973128598848 33 | 31,0.061036468330134354 34 | 32,0.061036468330134354 35 | 33,0.06161228406909789 36 | 34,0.060844529750479846 37 | 35,0.05854126679462572 38 | 36,0.05969289827255278 39 | 37,0.0600767754318618 40 | 38,0.059309021113243765 41 | 39,0.05892514395393474 42 | 40,0.059884836852207295 43 | 41,0.05892514395393474 44 | 42,0.05911708253358925 45 | 43,0.05873320537428023 46 | 44,0.05873320537428023 47 | 45,0.05873320537428023 48 | 46,0.05854126679462572 49 | 47,0.0581573896353167 50 | 48,0.05873320537428023 51 | 49,0.059309021113243765 52 | 50,0.059309021113243765 53 | 51,0.059309021113243765 54 | 52,0.059309021113243765 55 | 53,0.059309021113243765 56 | 54,0.059309021113243765 57 | 55,0.059309021113243765 58 | 56,0.059309021113243765 59 | 57,0.059309021113243765 60 | 58,0.059309021113243765 61 | 59,0.059309021113243765 62 | 60,0.059309021113243765 63 | 61,0.059309021113243765 64 | 62,0.059309021113243765 65 | 63,0.059309021113243765 66 | 64,0.059309021113243765 67 | 65,0.059309021113243765 68 | 66,0.059309021113243765 69 | 67,0.059309021113243765 70 | 68,0.059309021113243765 71 | 69,0.059309021113243765 72 | 70,0.059309021113243765 73 | 71,0.059309021113243765 74 | 72,0.059309021113243765 75 | 73,0.059309021113243765 76 | 74,0.059309021113243765 77 | 75,0.059309021113243765 78 | 76,0.059309021113243765 79 | 77,0.059309021113243765 80 | 78,0.059309021113243765 81 | 79,0.059309021113243765 82 | 80,0.059309021113243765 83 | 81,0.059309021113243765 84 | 82,0.059309021113243765 85 | 83,0.059309021113243765 86 | 84,0.059309021113243765 87 | 85,0.059309021113243765 88 | 86,0.059309021113243765 89 | 87,0.059309021113243765 90 | 88,0.059309021113243765 91 | 89,0.059309021113243765 92 | 90,0.059309021113243765 93 | 91,0.059309021113243765 94 | 92,0.059309021113243765 95 | 93,0.059309021113243765 
96 | 94,0.059309021113243765 97 | 95,0.059309021113243765 98 | 96,0.059309021113243765 99 | 97,0.059309021113243765 100 | 98,0.059309021113243765 101 | 99,0.059309021113243765 102 | 100,0.059309021113243765 103 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/results/stop_after_all_game.csv: -------------------------------------------------------------------------------- 1 | move,accuracy 2 | 0,0.03333333333333333 3 | 1,0.16666666666666666 4 | 2,0.5333333333333333 5 | 3,0.6333333333333333 6 | 4,0.5 7 | 5,0.5 8 | 6,0.5 9 | 7,0.43333333333333335 10 | 8,0.43333333333333335 11 | 9,0.43333333333333335 12 | 10,0.5 13 | 11,0.43333333333333335 14 | 12,0.4666666666666667 15 | 13,0.5 16 | 14,0.5333333333333333 17 | 15,0.5 18 | 16,0.5333333333333333 19 | 17,0.5666666666666667 20 | 18,0.5333333333333333 21 | 19,0.5666666666666667 22 | 20,0.6 23 | 21,0.5333333333333333 24 | 22,0.5666666666666667 25 | 23,0.5333333333333333 26 | 24,0.6 27 | 25,0.6 28 | 26,0.6 29 | 27,0.6666666666666666 30 | 28,0.6333333333333333 31 | 29,0.6333333333333333 32 | 30,0.6333333333333333 33 | 31,0.6666666666666666 34 | 32,0.6666666666666666 35 | 33,0.6666666666666666 36 | 34,0.6 37 | 35,0.6 38 | 36,0.6 39 | 37,0.6 40 | 38,0.5666666666666667 41 | 39,0.6 42 | 40,0.6 43 | 41,0.6 44 | 42,0.6333333333333333 45 | 43,0.6 46 | 44,0.6333333333333333 47 | 45,0.6333333333333333 48 | 46,0.6333333333333333 49 | 47,0.6333333333333333 50 | 48,0.6333333333333333 51 | 49,0.6333333333333333 52 | 50,0.6333333333333333 53 | 51,0.6333333333333333 54 | 52,0.6333333333333333 55 | 53,0.6333333333333333 56 | 54,0.6333333333333333 57 | 55,0.6333333333333333 58 | 56,0.6333333333333333 59 | 57,0.6333333333333333 60 | 58,0.6333333333333333 61 | 59,0.6333333333333333 62 | 60,0.6333333333333333 63 | 61,0.6333333333333333 64 | 62,0.6333333333333333 65 | 63,0.6333333333333333 66 | 64,0.6333333333333333 67 | 65,0.6333333333333333 68 | 66,0.6333333333333333 69 | 67,0.6333333333333333 70 | 68,0.6333333333333333 71 | 69,0.6333333333333333 72 | 70,0.6333333333333333 73 | 71,0.6333333333333333 74 | 72,0.6333333333333333 75 | 73,0.6333333333333333 76 | 74,0.6333333333333333 77 | 75,0.6333333333333333 78 | 76,0.6333333333333333 79 | 77,0.6333333333333333 80 | 78,0.6333333333333333 81 | 79,0.6333333333333333 82 | 80,0.6333333333333333 83 | 81,0.6333333333333333 84 | 82,0.6333333333333333 85 | 83,0.6333333333333333 86 | 84,0.6333333333333333 87 | 85,0.6333333333333333 88 | 86,0.6333333333333333 89 | 87,0.6333333333333333 90 | 88,0.6333333333333333 91 | 89,0.6333333333333333 92 | 90,0.6333333333333333 93 | 91,0.6333333333333333 94 | 92,0.6333333333333333 95 | 93,0.6333333333333333 96 | 94,0.6333333333333333 97 | 95,0.6333333333333333 98 | 96,0.6333333333333333 99 | 97,0.6333333333333333 100 | 98,0.6333333333333333 101 | 99,0.6333333333333333 102 | 100,0.6333333333333333 103 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/results_validation/games_accuracy.csv: -------------------------------------------------------------------------------- 1 | num_games,accuracy 2 | 1,0.005322294500295683 3 | 2,0.007152042305413879 4 | 4,0.00927616894222686 5 | 8,0.013657853265627736 6 | 16,0.020702709097361882 7 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/results_validation/start_after.csv: -------------------------------------------------------------------------------- 1 | move,accuracy 2 | 
0,0.008001119393492254 3 | 1,0.007988399012898465 4 | 2,0.006487394102831557 5 | 3,0.00449029434960694 6 | 4,0.0042867682601063425 7 | 5,0.004032360648230595 8 | 6,0.004032360648230595 9 | 7,0.004235940620508059 10 | 8,0.004312373586393762 11 | 9,0.004134439242825158 12 | 10,0.004134649636150832 13 | 11,0.0044149829507863 14 | 12,0.004517631488527761 15 | 13,0.004315392840776007 16 | 14,0.004419762835780974 17 | 15,0.004282436910527657 18 | 16,0.004453405132262304 19 | 17,0.0038727983844167794 20 | 18,0.003896336930609315 21 | 19,0.0038616499543038087 22 | 20,0.0037848837962902956 23 | 21,0.0037588076590617386 24 | 22,0.004161569962240068 25 | 23,0.0038119613902767757 26 | 24,0.0037130110684436414 27 | 25,0.003143389199255121 28 | 26,0.002826501275963433 29 | 27,0.002573004599686305 30 | 28,0.0027073019801980196 31 | 29,0.002663825253063399 32 | 30,0.0027465372321534274 33 | 31,0.0026176626123744053 34 | 32,0.0026072529035316427 35 | 33,0.0031696249833177634 36 | 34,0.0031620252200083815 37 | 35,0.0035925520262869663 38 | 36,0.003196184871391609 39 | 37,0.0030963439323567943 40 | 38,0.0032433194669674965 41 | 39,0.004067796610169492 42 | 40,0.003640902943930095 43 | 41,0.003702234563004099 44 | 42,0.004691572545612511 45 | 43,0.004493850520340586 46 | 44,0.0050150451354062184 47 | 45,0.010349926071956629 48 | 46,0.008507347254447023 49 | 47,0.008253094910591471 50 | 48,0.038338658146964855 51 | 49,0 52 | 50,0 53 | 51,0 54 | 52,0 55 | 53,0 56 | 54,0 57 | 55,0 58 | 56,0 59 | 57,0 60 | 58,0 61 | 59,0 62 | 60,0 63 | 61,0 64 | 62,0 65 | 63,0 66 | 64,0 67 | 65,0 68 | 66,0 69 | 67,0 70 | 68,0 71 | 69,0 72 | 70,0 73 | 71,0 74 | 72,0 75 | 73,0 76 | 74,0 77 | 75,0 78 | 76,0 79 | 77,0 80 | 78,0 81 | 79,0 82 | 80,0 83 | 81,0 84 | 82,0 85 | 83,0 86 | 84,0 87 | 85,0 88 | 86,0 89 | 87,0 90 | 88,0 91 | 89,0 92 | 90,0 93 | 91,0 94 | 92,0 95 | 93,0 96 | 94,0 97 | 95,0 98 | 96,0 99 | 97,0 100 | 98,0 101 | 99,0 102 | 100,0 103 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/results_validation/start_after_4games.csv: -------------------------------------------------------------------------------- 1 | move,accuracy 2 | 0,0.013087661671114761 3 | 1,0.01365222746869226 4 | 2,0.011599260932046808 5 | 3,0.00949497023198522 6 | 4,0.008160541983165674 7 | 5,0.007647300349004311 8 | 6,0.008160541983165674 9 | 7,0.0083662680285377 10 | 8,0.00810963403993225 11 | 9,0.00810963403993225 12 | 10,0.00805913454134798 13 | 11,0.00765004877547877 14 | 12,0.008062445437272121 15 | 13,0.008527689304428234 16 | 14,0.008069490131578948 17 | 15,0.007920588386565858 18 | 16,0.007926295743476247 19 | 17,0.006804825239715435 20 | 18,0.006308822008480711 21 | 19,0.005659691572771172 22 | 20,0.0067447453727909655 23 | 21,0.0066029264169880095 24 | 22,0.0075660012878300065 25 | 23,0.006154184295840431 26 | 24,0.006217557469625236 27 | 25,0.005656517029726802 28 | 26,0.004635977799542932 29 | 27,0.0037816625044595075 30 | 28,0.00454331818893937 31 | 29,0.0036316472114137485 32 | 30,0.0045231450293523245 33 | 31,0.004412397761515282 34 | 32,0.0063978754225012075 35 | 33,0.006824075337791729 36 | 34,0.006715602061533656 37 | 35,0.007404731804226115 38 | 36,0.009222385244183609 39 | 37,0.0061942517343904855 40 | 38,0.006505026611472502 41 | 39,0.00972972972972973 42 | 40,0.009379187137114784 43 | 41,0.01678240740740741 44 | 42,0.01564945226917058 45 | 43,0.022598870056497175 46 | 44,0.029209621993127148 47 | 45,0.03932584269662921 48 | 46,0.06349206349206349 49 | 
47,0.0641025641025641 50 | 48,0.375 51 | 49,0 52 | 50,0 53 | 51,0 54 | 52,0 55 | 53,0 56 | 54,0 57 | 55,0 58 | 56,0 59 | 57,0 60 | 58,0 61 | 59,0 62 | 60,0 63 | 61,0 64 | 62,0 65 | 63,0 66 | 64,0 67 | 65,0 68 | 66,0 69 | 67,0 70 | 68,0 71 | 69,0 72 | 70,0 73 | 71,0 74 | 72,0 75 | 73,0 76 | 74,0 77 | 75,0 78 | 76,0 79 | 77,0 80 | 78,0 81 | 79,0 82 | 80,0 83 | 81,0 84 | 82,0 85 | 83,0 86 | 84,0 87 | 85,0 88 | 86,0 89 | 87,0 90 | 88,0 91 | 89,0 92 | 90,0 93 | 91,0 94 | 92,0 95 | 93,0 96 | 94,0 97 | 95,0 98 | 96,0 99 | 97,0 100 | 98,0 101 | 99,0 102 | 100,0 103 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/results_validation/stop_after.csv: -------------------------------------------------------------------------------- 1 | move,accuracy 2 | 0,0.0015773271936296335 3 | 1,0.003510825043885313 4 | 2,0.01217340422825451 5 | 3,0.011651868623909227 6 | 4,0.010201745236217467 7 | 5,0.010367110183936703 8 | 6,0.010061821049685806 9 | 7,0.009985498766123082 10 | 8,0.009743811534841123 11 | 9,0.00961660772890325 12 | 10,0.00927315745287099 13 | 11,0.009285877833464778 14 | 12,0.009196835169308266 15 | 13,0.009311318594652352 16 | 14,0.009362200117027502 17 | 15,0.009234996311089629 18 | 16,0.008878825654463582 19 | 17,0.008853384893276008 20 | 18,0.008993309079807667 21 | 19,0.00884066451268222 22 | 20,0.009018749840995242 23 | 21,0.00927315745287099 24 | 22,0.009069631363370393 25 | 23,0.008967868318620092 26 | 24,0.008777062609713282 27 | 25,0.008916986796244943 28 | 26,0.008866105273869794 29 | 27,0.008878825654463582 30 | 28,0.008853384893276008 31 | 29,0.00884066451268222 32 | 30,0.008929707176838731 33 | 31,0.008904266415651157 34 | 32,0.008891546035057369 35 | 33,0.008929707176838731 36 | 34,0.008815223751494645 37 | 35,0.008649858803775409 38 | 36,0.008700740326150558 39 | 37,0.008713460706744346 40 | 38,0.00854809575902511 41 | 39,0.00830640852774315 42 | 40,0.008433612333681024 43 | 41,0.008420891953087236 44 | 42,0.008319128908336937 45 | 43,0.008344569669524512 46 | 44,0.00820464548299285 47 | 45,0.007975678632304679 48 | 46,0.008090162057648766 49 | 47,0.008039280535273615 50 | 48,0.008001119393492254 51 | 49,0.008001119393492254 52 | 50,0.008001119393492254 53 | 51,0.008001119393492254 54 | 52,0.008001119393492254 55 | 53,0.008001119393492254 56 | 54,0.008001119393492254 57 | 55,0.008001119393492254 58 | 56,0.008001119393492254 59 | 57,0.008001119393492254 60 | 58,0.008001119393492254 61 | 59,0.008001119393492254 62 | 60,0.008001119393492254 63 | 61,0.008001119393492254 64 | 62,0.008001119393492254 65 | 63,0.008001119393492254 66 | 64,0.008001119393492254 67 | 65,0.008001119393492254 68 | 66,0.008001119393492254 69 | 67,0.008001119393492254 70 | 68,0.008001119393492254 71 | 69,0.008001119393492254 72 | 70,0.008001119393492254 73 | 71,0.008001119393492254 74 | 72,0.008001119393492254 75 | 73,0.008001119393492254 76 | 74,0.008001119393492254 77 | 75,0.008001119393492254 78 | 76,0.008001119393492254 79 | 77,0.008001119393492254 80 | 78,0.008001119393492254 81 | 79,0.008001119393492254 82 | 80,0.008001119393492254 83 | 81,0.008001119393492254 84 | 82,0.008001119393492254 85 | 83,0.008001119393492254 86 | 84,0.008001119393492254 87 | 85,0.008001119393492254 88 | 86,0.008001119393492254 89 | 87,0.008001119393492254 90 | 88,0.008001119393492254 91 | 89,0.008001119393492254 92 | 90,0.008001119393492254 93 | 91,0.008001119393492254 94 | 92,0.008001119393492254 95 | 93,0.008001119393492254 96 | 94,0.008001119393492254 97 | 
95,0.008001119393492254 98 | 96,0.008001119393492254 99 | 97,0.008001119393492254 100 | 98,0.008001119393492254 101 | 99,0.008001119393492254 102 | 100,0.008001119393492254 103 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/results_validation/stop_after_4games.csv: -------------------------------------------------------------------------------- 1 | move,accuracy 2 | 0,0.0015910490659002258 3 | 1,0.006056251283104086 4 | 2,0.016115787312666805 5 | 3,0.017193594744405665 6 | 4,0.014884007390679532 7 | 5,0.016013138985834532 8 | 6,0.015499897351673168 9 | 7,0.015448573188257032 10 | 8,0.014986655717511805 11 | 9,0.014678710737014987 12 | 10,0.01591049065900226 13 | 11,0.015294600698008622 14 | 12,0.014524738246766578 15 | 13,0.014730034900431123 16 | 14,0.015037979880927942 17 | 15,0.014730034900431123 18 | 16,0.014832683227263395 19 | 17,0.015345924861424758 20 | 18,0.015089304044344077 21 | 19,0.014114144939437486 22 | 20,0.014422089919934305 23 | 21,0.015037979880927942 24 | 22,0.014678710737014987 25 | 23,0.013960172449189078 26 | 24,0.013857524122356805 27 | 25,0.013395606651611578 28 | 26,0.013087661671114761 29 | 27,0.014165469102853623 30 | 28,0.014576062410182713 31 | 29,0.01365222746869226 32 | 30,0.014319441593102033 33 | 31,0.015037979880927942 34 | 32,0.014576062410182713 35 | 33,0.01478135906384726 36 | 34,0.014422089919934305 37 | 35,0.014576062410182713 38 | 36,0.013960172449189078 39 | 37,0.013754875795524533 40 | 38,0.013600903305276125 41 | 39,0.013908848285772941 42 | 40,0.014319441593102033 43 | 41,0.014114144939437486 44 | 42,0.014268117429685897 45 | 43,0.01365222746869226 46 | 44,0.013395606651611578 47 | 45,0.013292958324779306 48 | 46,0.013190309997947033 49 | 47,0.013292958324779306 50 | 48,0.013190309997947033 51 | 49,0.013087661671114761 52 | 50,0.013087661671114761 53 | 51,0.013087661671114761 54 | 52,0.013087661671114761 55 | 53,0.013087661671114761 56 | 54,0.013087661671114761 57 | 55,0.013087661671114761 58 | 56,0.013087661671114761 59 | 57,0.013087661671114761 60 | 58,0.013087661671114761 61 | 59,0.013087661671114761 62 | 60,0.013087661671114761 63 | 61,0.013087661671114761 64 | 62,0.013087661671114761 65 | 63,0.013087661671114761 66 | 64,0.013087661671114761 67 | 65,0.013087661671114761 68 | 66,0.013087661671114761 69 | 67,0.013087661671114761 70 | 68,0.013087661671114761 71 | 69,0.013087661671114761 72 | 70,0.013087661671114761 73 | 71,0.013087661671114761 74 | 72,0.013087661671114761 75 | 73,0.013087661671114761 76 | 74,0.013087661671114761 77 | 75,0.013087661671114761 78 | 76,0.013087661671114761 79 | 77,0.013087661671114761 80 | 78,0.013087661671114761 81 | 79,0.013087661671114761 82 | 80,0.013087661671114761 83 | 81,0.013087661671114761 84 | 82,0.013087661671114761 85 | 83,0.013087661671114761 86 | 84,0.013087661671114761 87 | 85,0.013087661671114761 88 | 86,0.013087661671114761 89 | 87,0.013087661671114761 90 | 88,0.013087661671114761 91 | 89,0.013087661671114761 92 | 90,0.013087661671114761 93 | 91,0.013087661671114761 94 | 92,0.013087661671114761 95 | 93,0.013087661671114761 96 | 94,0.013087661671114761 97 | 95,0.013087661671114761 98 | 96,0.013087661671114761 99 | 97,0.013087661671114761 100 | 98,0.013087661671114761 101 | 99,0.013087661671114761 102 | 100,0.013087661671114761 103 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/sweep_moves_all_games.py: 
-------------------------------------------------------------------------------- 1 | import bz2 2 | import csv 3 | import argparse 4 | import os 5 | import numpy as np 6 | from sklearn.naive_bayes import GaussianNB 7 | from matplotlib import pyplot as plt 8 | 9 | def parse_argument(): 10 | parser = argparse.ArgumentParser(description='arg parser') 11 | 12 | parser.add_argument('--input_dir', default='cp_loss_hist_per_move') 13 | parser.add_argument('--output_start_after_csv', default='start_after_all_game.csv') 14 | parser.add_argument('--output_stop_after_csv', default='stop_after_all_game.csv') 15 | parser.add_argument('--saved_plot', default='plot_all_game.png') 16 | 17 | return parser.parse_args() 18 | 19 | def read_npy(input_dir): 20 | 21 | player_list = {} 22 | for input_data in os.listdir(input_dir): 23 | # will split into [player_name, 'train/test/val'] 24 | input_name = input_data.split('_') 25 | if len(input_name) > 2: 26 | player_name = input_name[:-1] 27 | player_name = '_'.join(player_name) 28 | else: 29 | player_name = input_name[0] 30 | # add into player list 31 | if player_name not in player_list: 32 | player_list[player_name] = 1 33 | 34 | player_list = list(player_list.keys()) 35 | 36 | player_data = {} 37 | for player_name in player_list: 38 | player_data[player_name] = {'train': None, 'validation': None, 'test': None} 39 | train_path = os.path.join(input_dir, player_name + '_{}.npy'.format('train')) 40 | val_path = os.path.join(input_dir, player_name + '_{}.npy'.format('validation')) 41 | test_path = os.path.join(input_dir, player_name + '_{}.npy'.format('test')) 42 | 43 | player_data[player_name]['train'] = np.load(train_path, allow_pickle=True) 44 | player_data[player_name]['train'] = player_data[player_name]['train'].item() 45 | player_data[player_name]['validation'] = np.load(val_path, allow_pickle=True) 46 | player_data[player_name]['validation'] = player_data[player_name]['validation'].item() 47 | player_data[player_name]['test'] = np.load(test_path, allow_pickle=True) 48 | player_data[player_name]['test'] = player_data[player_name]['test'].item() 49 | 50 | return player_data 51 | 52 | 53 | def construct_train_set(player_data, is_start_after, move_stop): 54 | player_index = {} 55 | train_list = [] 56 | train_label = [] 57 | 58 | i = 0 59 | for player in player_data.keys(): 60 | # if player in os.listdir('/data/csvs'): 61 | player_index[player] = i 62 | train_label.append(i) 63 | if is_start_after: 64 | train_list.append(player_data[player]['train']['start_after'][move_stop]) 65 | else: 66 | train_list.append(player_data[player]['train']['stop_after'][move_stop]) 67 | 68 | i += 1 69 | 70 | train_label = np.asarray(train_label) 71 | # one_hot = np.zeros((train_label.size, train_label.max()+1)) 72 | # one_hot[np.arange(train_label.size),train_label] = 1 73 | # print(one_hot.shape) 74 | 75 | train_data = np.stack(train_list, 0) 76 | return train_data, train_label, player_index 77 | 78 | 79 | def predict(train_data, train_label, player_data, player_index, is_start_after, move_stop): 80 | correct = 0 81 | total = 0 82 | model = GaussianNB() 83 | model.fit(train_data, train_label) 84 | for player in player_data.keys(): 85 | test = player_data[player]['test'] 86 | if is_start_after: 87 | test = test['start_after'][move_stop] 88 | if all(v == 0 for v in test): 89 | continue 90 | else: 91 | test = test['stop_after'][move_stop] 92 | 93 | predicted = model.predict(np.expand_dims(test, axis=0)) 94 | index = predicted[0] 95 | if index == player_index[player]: 96 | correct += 1 
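        # players skipped above (all-zero start_after histograms) never reach
        # this point, so total only counts players with a usable test histogram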
97 | total += 1 98 | 99 | if total == 0: 100 | accuracy = 0 101 | else: 102 | accuracy = correct / total 103 | 104 | print(accuracy) 105 | return accuracy 106 | 107 | 108 | def make_plots(moves, start_after_accuracies, stop_after_accuracies, plot_name): 109 | plt.plot(moves, start_after_accuracies, label="Start after x moves") 110 | plt.plot(moves, stop_after_accuracies, label="Stop after x moves") 111 | plt.legend() 112 | plt.xlabel("Moves") 113 | plt.savefig(plot_name) 114 | 115 | 116 | if __name__ == '__main__': 117 | args = parse_argument() 118 | 119 | player_data = read_npy(args.input_dir) 120 | moves = [i for i in range(101)] 121 | start_after_accuracies = [] 122 | stop_after_accuracies = [] 123 | output_start_csv = open(args.output_start_after_csv, 'w', newline='') 124 | writer_start = csv.writer(output_start_csv) 125 | writer_start.writerow(['move', 'accuracy']) 126 | 127 | output_stop_csv = open(args.output_stop_after_csv, 'w', newline='') 128 | writer_stop = csv.writer(output_stop_csv) 129 | writer_stop.writerow(['move', 'accuracy']) 130 | 131 | for is_start_after in (True, False): 132 | for i in range(101): 133 | print('testing {} move {}'.format('start_after' if is_start_after else 'stop_after', i)) 134 | train_data, train_label, player_index = construct_train_set(player_data, is_start_after, i) 135 | 136 | accuracy = predict(train_data, train_label, player_data, player_index, is_start_after, i) 137 | 138 | if is_start_after: 139 | start_after_accuracies.append(accuracy) 140 | writer_start.writerow([i, accuracy]) 141 | else: 142 | stop_after_accuracies.append(accuracy) 143 | writer_stop.writerow([i, accuracy]) 144 | 145 | make_plots(moves, start_after_accuracies, stop_after_accuracies, args.saved_plot) 146 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/sweep_moves_num_games.py: -------------------------------------------------------------------------------- 1 | import bz2 2 | import csv 3 | import argparse 4 | import os 5 | import numpy as np 6 | from sklearn.naive_bayes import GaussianNB 7 | from matplotlib import pyplot as plt 8 | 9 | def parse_argument(): 10 | parser = argparse.ArgumentParser(description='arg parser') 11 | 12 | parser.add_argument('--train_dir', default='cp_loss_hist_per_move') 13 | parser.add_argument('--input_dir', default='cp_loss_hist_per_move_per_game_count') 14 | parser.add_argument('--output_start_after_csv', default='start_after_4games.csv') 15 | parser.add_argument('--output_stop_after_csv', default='stop_after_4games.csv') 16 | parser.add_argument('--num_games', default=4) 17 | parser.add_argument('--saved_plot', default='plot_4games.png') 18 | 19 | return parser.parse_args() 20 | 21 | def read_npy(train_dir, input_dir): 22 | 23 | player_list = {} 24 | for input_data in os.listdir(input_dir): 25 | # will split into [player_name, 'train/test/val'] 26 | input_name = input_data.split('_') 27 | if len(input_name) > 2: 28 | player_name = input_name[:-1] 29 | player_name = '_'.join(player_name) 30 | else: 31 | player_name = input_name[0] 32 | # add into player list 33 | if player_name not in player_list: 34 | player_list[player_name] = 1 35 | 36 | player_list = list(player_list.keys()) 37 | 38 | player_data = {} 39 | for player_name in player_list: 40 | player_data[player_name] = {'train': None, 'validation': None, 'test': None} 41 | train_path = os.path.join(train_dir, player_name + '_{}.npy'.format('train')) 42 | val_path = os.path.join(input_dir, player_name + 
'_{}.npy'.format('validation')) 43 | test_path = os.path.join(input_dir, player_name + '_{}.npy'.format('test')) 44 | 45 | player_data[player_name]['train'] = np.load(train_path, allow_pickle=True) 46 | player_data[player_name]['train'] = player_data[player_name]['train'].item() 47 | player_data[player_name]['validation'] = np.load(val_path, allow_pickle=True) 48 | player_data[player_name]['validation'] = player_data[player_name]['validation'].item() 49 | player_data[player_name]['test'] = np.load(test_path, allow_pickle=True) 50 | player_data[player_name]['test'] = player_data[player_name]['test'].item() 51 | 52 | return player_data 53 | 54 | def normalize(data): 55 | data_norm = data 56 | 57 | if any(v != 0 for v in data): 58 | norm = np.linalg.norm(data) 59 | data_norm = data/norm 60 | 61 | return data_norm 62 | 63 | def construct_train_set(player_data, is_start_after, move_stop): 64 | player_index = {} 65 | train_list = [] 66 | train_label = [] 67 | 68 | i = 0 69 | for player in player_data.keys(): 70 | # if player in os.listdir('/data/csvs'): 71 | player_index[player] = i 72 | train_label.append(i) 73 | if is_start_after: 74 | train_list.append(player_data[player]['train']['start_after'][move_stop]) 75 | else: 76 | train_list.append(player_data[player]['train']['stop_after'][move_stop]) 77 | 78 | i += 1 79 | 80 | train_label = np.asarray(train_label) 81 | # one_hot = np.zeros((train_label.size, train_label.max()+1)) 82 | # one_hot[np.arange(train_label.size),train_label] = 1 83 | # print(one_hot.shape) 84 | 85 | train_data = np.stack(train_list, 0) 86 | return train_data, train_label, player_index 87 | 88 | 89 | def predict(train_data, train_label, player_data, player_index, is_start_after, move_stop, num_games): 90 | accurcies = [] 91 | correct = 0 92 | total = 0 93 | model = GaussianNB() 94 | model.fit(train_data, train_label) 95 | results = None 96 | for player in player_data.keys(): 97 | test_game = None 98 | tmp_game = None 99 | test_games = [] 100 | test = player_data[player]['test'] 101 | count = 1 102 | 103 | # key is game id 104 | for key, value in test.items(): 105 | # get which game to use 106 | if is_start_after: 107 | tmp_game = test[key]['start_after'][move_stop] 108 | # ignore all 0 cases, essentially there's no more move in this game 109 | if all(v == 0 for v in tmp_game): 110 | continue 111 | else: 112 | tmp_game = test[key]['stop_after'][move_stop] 113 | 114 | # add up counts in each game 115 | if test_game is None: 116 | test_game = tmp_game 117 | else: 118 | test_game = test_game + tmp_game 119 | 120 | if count == num_games: 121 | # test_game is addition of counts, need to normalize before testing 122 | test_game = normalize(test_game) 123 | test_games.append(test_game) 124 | 125 | # reset 126 | test_game = None 127 | tmp_game = None 128 | count = 1 129 | 130 | else: 131 | count += 1 132 | 133 | # skip player if all games are beyond move_stop 134 | if not test_games: 135 | continue 136 | 137 | test_games = np.stack(test_games, axis=0) 138 | predicted = model.predict(test_games) 139 | result = (predicted == player_index[player]).astype(float) 140 | 141 | # append to the overall result 142 | if results is None: 143 | results = result 144 | else: 145 | results = np.append(results, result, 0) 146 | 147 | if results is None: 148 | accuracy = 0 149 | 150 | else: 151 | accuracy = np.mean(results) 152 | 153 | print(accuracy) 154 | 155 | return accuracy 156 | 157 | 158 | def make_plots(moves, start_after_accuracies, stop_after_accuracies, plot_name): 159 | plt.plot(moves, 
start_after_accuracies, label="Start after x moves") 160 | plt.plot(moves, stop_after_accuracies, label="Stop after x moves") 161 | plt.legend() 162 | plt.xlabel("Moves") 163 | plt.savefig(plot_name) 164 | 165 | 166 | if __name__ == '__main__': 167 | args = parse_argument() 168 | 169 | player_data = read_npy(args.train_dir, args.input_dir) 170 | moves = [i for i in range(101)] 171 | start_after_accuracies = [] 172 | stop_after_accuracies = [] 173 | output_start_csv = open(args.output_start_after_csv, 'w', newline='') 174 | writer_start = csv.writer(output_start_csv) 175 | writer_start.writerow(['move', 'accuracy']) 176 | 177 | output_stop_csv = open(args.output_stop_after_csv, 'w', newline='') 178 | writer_stop = csv.writer(output_stop_csv) 179 | writer_stop.writerow(['move', 'accuracy']) 180 | 181 | for is_start_after in (True, False): 182 | for i in range(101): 183 | print('testing {} move {}'.format('start_after' if is_start_after else 'stop_after', i)) 184 | train_data, train_label, player_index = construct_train_set(player_data, is_start_after, i) 185 | 186 | accuracy = predict(train_data, train_label, player_data, player_index, is_start_after, i, args.num_games) 187 | 188 | if is_start_after: 189 | start_after_accuracies.append(accuracy) 190 | writer_start.writerow([i, accuracy]) 191 | else: 192 | stop_after_accuracies.append(accuracy) 193 | writer_stop.writerow([i, accuracy]) 194 | 195 | make_plots(moves, start_after_accuracies, stop_after_accuracies, args.saved_plot) 196 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/sweep_moves_per_game.py: -------------------------------------------------------------------------------- 1 | import bz2 2 | import csv 3 | import argparse 4 | import os 5 | import numpy as np 6 | from sklearn.naive_bayes import GaussianNB 7 | from matplotlib import pyplot as plt 8 | 9 | def parse_argument(): 10 | parser = argparse.ArgumentParser(description='arg parser') 11 | 12 | parser.add_argument('--train_dir', default='cp_loss_hist_per_move') 13 | parser.add_argument('--input_dir', default='cp_loss_hist_per_move_per_game') 14 | parser.add_argument('--output_start_after_csv', default='start_after.csv') 15 | parser.add_argument('--output_stop_after_csv', default='stop_after.csv') 16 | parser.add_argument('--saved_plot', default='plot.png') 17 | 18 | return parser.parse_args() 19 | 20 | def read_npy(train_dir, input_dir): 21 | 22 | player_list = {} 23 | for input_data in os.listdir(input_dir): 24 | # will split into [player_name, 'train/test/val'] 25 | input_name = input_data.split('_') 26 | if len(input_name) > 2: 27 | player_name = input_name[:-1] 28 | player_name = '_'.join(player_name) 29 | else: 30 | player_name = input_name[0] 31 | # add into player list 32 | if player_name not in player_list: 33 | player_list[player_name] = 1 34 | 35 | player_list = list(player_list.keys()) 36 | 37 | player_data = {} 38 | for player_name in player_list: 39 | player_data[player_name] = {'train': None, 'validation': None, 'test': None} 40 | train_path = os.path.join(train_dir, player_name + '_{}.npy'.format('train')) 41 | val_path = os.path.join(input_dir, player_name + '_{}.npy'.format('validation')) 42 | test_path = os.path.join(input_dir, player_name + '_{}.npy'.format('test')) 43 | 44 | player_data[player_name]['train'] = np.load(train_path, allow_pickle=True) 45 | player_data[player_name]['train'] = player_data[player_name]['train'].item() 46 | player_data[player_name]['validation'] = np.load(val_path, 
allow_pickle=True) 47 | player_data[player_name]['validation'] = player_data[player_name]['validation'].item() 48 | player_data[player_name]['test'] = np.load(test_path, allow_pickle=True) 49 | player_data[player_name]['test'] = player_data[player_name]['test'].item() 50 | 51 | return player_data 52 | 53 | 54 | def construct_train_set(player_data, is_start_after, move_stop): 55 | player_index = {} 56 | train_list = [] 57 | train_label = [] 58 | 59 | i = 0 60 | for player in player_data.keys(): 61 | # if player in os.listdir('/data/csvs'): 62 | player_index[player] = i 63 | train_label.append(i) 64 | if is_start_after: 65 | train_list.append(player_data[player]['train']['start_after'][move_stop]) 66 | else: 67 | train_list.append(player_data[player]['train']['stop_after'][move_stop]) 68 | 69 | i += 1 70 | 71 | train_label = np.asarray(train_label) 72 | # one_hot = np.zeros((train_label.size, train_label.max()+1)) 73 | # one_hot[np.arange(train_label.size),train_label] = 1 74 | # print(one_hot.shape) 75 | 76 | train_data = np.stack(train_list, 0) 77 | return train_data, train_label, player_index 78 | 79 | 80 | def predict(train_data, train_label, player_data, player_index, is_start_after, move_stop): 81 | accurcies = [] 82 | correct = 0 83 | total = 0 84 | model = GaussianNB() 85 | model.fit(train_data, train_label) 86 | results = None 87 | for player in player_data.keys(): 88 | test_game = None 89 | test_games = [] 90 | test = player_data[player]['test'] 91 | 92 | # key is game id 93 | for key, value in test.items(): 94 | if is_start_after: 95 | test_game = test[key]['start_after'][move_stop] 96 | # ignore all 0 cases, essentially there's no more move in this game 97 | if all(v == 0 for v in test_game): 98 | continue 99 | else: 100 | test_game = test[key]['stop_after'][move_stop] 101 | 102 | test_games.append(test_game) 103 | 104 | # skip player if all games are beyond move_stop 105 | if not test_games: 106 | continue 107 | 108 | test_games = np.stack(test_games, axis=0) 109 | predicted = model.predict(test_games) 110 | result = (predicted == player_index[player]).astype(float) 111 | 112 | # append to the overall result 113 | if results is None: 114 | results = result 115 | else: 116 | results = np.append(results, result, 0) 117 | 118 | if results is None: 119 | accuracy = 0 120 | 121 | else: 122 | accuracy = np.mean(results) 123 | 124 | print(accuracy) 125 | 126 | return accuracy 127 | 128 | 129 | def make_plots(moves, start_after_accuracies, stop_after_accuracies, plot_name): 130 | plt.plot(moves, start_after_accuracies, label="Start after x moves") 131 | plt.plot(moves, stop_after_accuracies, label="Stop after x moves") 132 | plt.legend() 133 | plt.xlabel("Moves") 134 | plt.savefig(plot_name) 135 | 136 | 137 | if __name__ == '__main__': 138 | args = parse_argument() 139 | 140 | player_data = read_npy(args.train_dir, args.input_dir) 141 | moves = [i for i in range(101)] 142 | start_after_accuracies = [] 143 | stop_after_accuracies = [] 144 | output_start_csv = open(args.output_start_after_csv, 'w', newline='') 145 | writer_start = csv.writer(output_start_csv) 146 | writer_start.writerow(['move', 'accuracy']) 147 | 148 | output_stop_csv = open(args.output_stop_after_csv, 'w', newline='') 149 | writer_stop = csv.writer(output_stop_csv) 150 | writer_stop.writerow(['move', 'accuracy']) 151 | 152 | for is_start_after in (True, False): 153 | for i in range(101): 154 | print('testing {} move {}'.format('start_after' if is_start_after else 'stop_after', i)) 155 | train_data, train_label, 
player_index = construct_train_set(player_data, is_start_after, i) 156 | 157 | accuracy = predict(train_data, train_label, player_data, player_index, is_start_after, i) 158 | 159 | if is_start_after: 160 | start_after_accuracies.append(accuracy) 161 | writer_start.writerow([i, accuracy]) 162 | else: 163 | stop_after_accuracies.append(accuracy) 164 | writer_stop.writerow([i, accuracy]) 165 | 166 | make_plots(moves, start_after_accuracies, stop_after_accuracies, args.saved_plot) 167 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/test_all_games.py: -------------------------------------------------------------------------------- 1 | import bz2 2 | import csv 3 | import argparse 4 | import os 5 | import numpy as np 6 | from sklearn.naive_bayes import GaussianNB 7 | 8 | def parse_argument(): 9 | parser = argparse.ArgumentParser(description='arg parser') 10 | 11 | parser.add_argument('--input_dir', default='cp_loss_hist') 12 | parser.add_argument('--use_bayes', default=True) 13 | 14 | return parser.parse_args() 15 | 16 | def read_npy(input_dir): 17 | 18 | player_list = {} 19 | for input_data in os.listdir(input_dir): 20 | # will split into [player_name, 'train/test/val'] 21 | input_name = input_data.split('_') 22 | if len(input_name) > 2: 23 | player_name = input_name[:-1] 24 | player_name = '_'.join(player_name) 25 | else: 26 | player_name = input_name[0] 27 | # add into player list 28 | if player_name not in player_list: 29 | player_list[player_name] = 1 30 | 31 | player_list = list(player_list.keys()) 32 | 33 | player_data = {} 34 | for player_name in player_list: 35 | player_data[player_name] = {'train': None, 'validation': None, 'test': None} 36 | train_path = os.path.join(input_dir, player_name + '_{}.npy'.format('train')) 37 | val_path = os.path.join(input_dir, player_name + '_{}.npy'.format('validation')) 38 | test_path = os.path.join(input_dir, player_name + '_{}.npy'.format('test')) 39 | 40 | player_data[player_name]['train'] = np.load(train_path) 41 | player_data[player_name]['validation'] = np.load(val_path) 42 | player_data[player_name]['test'] = np.load(test_path) 43 | 44 | return player_data 45 | 46 | # =============================== Naive Bayes =============================== 47 | def construct_train_set(player_data): 48 | player_index = {} 49 | train_list = [] 50 | train_label = [] 51 | 52 | i = 0 53 | for player in player_data.keys(): 54 | player_index[player] = i 55 | train_label.append(i) 56 | train_list.append(player_data[player]['train']) 57 | i += 1 58 | 59 | train_label = np.asarray(train_label) 60 | 61 | train_data = np.stack(train_list, 0) 62 | return train_data, train_label, player_index 63 | 64 | def predict(train_data, train_label, player_data, player_index): 65 | print(player_index) 66 | correct = 0 67 | total = 0 68 | model = GaussianNB() 69 | model.fit(train_data, train_label) 70 | 71 | for player in player_data.keys(): 72 | test = player_data[player]['test'] 73 | predicted = model.predict(np.expand_dims(test, axis=0)) 74 | index = predicted[0] 75 | if index == player_index[player]: 76 | correct += 1 77 | total += 1 78 | 79 | print('accuracy is {}'.format(correct / total)) 80 | 81 | # =============================== Euclidean Distance =============================== 82 | def construct_train_list(player_data): 83 | # player_index is {player_name: id} mapping 84 | player_index = {} 85 | train_list = [] 86 | i = 0 87 | for player in player_data.keys(): 88 | player_index[player] = i 89 | 
train_list.append(player_data[player]['train']) 90 | i += 1 91 | 92 | return train_list, player_index 93 | 94 | def test_euclidean_dist(train_list, player_data, player_index): 95 | print(player_index) 96 | correct = 0 97 | total = 0 98 | # loop through each player and test their 'test set' 99 | for player in player_data.keys(): 100 | dist_list = [] 101 | test = player_data[player]['test'] 102 | 103 | # save distance for each (test, train) 104 | for train_data in train_list: 105 | dist = np.linalg.norm(train_data - test) 106 | dist_list.append(dist) 107 | 108 | # find minimum distance and its index 109 | min_index = dist_list.index(min(dist_list)) 110 | if min_index == player_index[player]: 111 | correct += 1 112 | total += 1 113 | 114 | print('accuracy is {}'.format(correct / total)) 115 | 116 | # =============================== run bayes or euclidean =============================== 117 | def run_bayes(player_data): 118 | print("Using Naive Bayes") 119 | train_data, train_label, player_index = construct_train_set(player_data) 120 | predict(train_data, train_label, player_data, player_index) 121 | 122 | def run_euclidean_dist(player_data): 123 | print("Using Euclidean Distance") 124 | train_list, player_index = construct_train_list(player_data) 125 | test_euclidean_dist(train_list, player_data, player_index) 126 | 127 | 128 | if __name__ == '__main__': 129 | args = parse_argument() 130 | 131 | player_data = read_npy(args.input_dir) 132 | 133 | if args.use_bayes: 134 | run_bayes(player_data) 135 | else: 136 | run_euclidean_dist(player_data) -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/train_cploss_per_game.py: -------------------------------------------------------------------------------- 1 | import bz2 2 | import csv 3 | import argparse 4 | import os 5 | import numpy as np 6 | import tensorflow as tf 7 | from sklearn.naive_bayes import GaussianNB 8 | 9 | def parse_argument(): 10 | parser = argparse.ArgumentParser(description='arg parser') 11 | 12 | parser.add_argument('--input_dir', default='cp_loss_count_per_game') 13 | parser.add_argument('--gpu', default=0, type=int) 14 | 15 | return parser.parse_args() 16 | 17 | def normalize(data): 18 | norm = np.linalg.norm(data) 19 | data_norm = data/norm 20 | return data_norm 21 | 22 | def read_npy(input_dir): 23 | 24 | player_list = {} 25 | for input_data in os.listdir(input_dir): 26 | # will split into [player_name, 'train/test/val'] 27 | input_name = input_data.split('_') 28 | if len(input_name) > 2: 29 | player_name = input_name[:-1] 30 | player_name = '_'.join(player_name) 31 | else: 32 | player_name = input_name[0] 33 | # add into player list 34 | if player_name not in player_list: 35 | player_list[player_name] = 1 36 | 37 | player_list = list(player_list.keys()) 38 | 39 | player_data = {} 40 | for player_name in player_list: 41 | player_data[player_name] = {'train': None, 'validation': None, 'test': None} 42 | train_path = os.path.join(input_dir, player_name + '_{}.npy'.format('train')) 43 | val_path = os.path.join(input_dir, player_name + '_{}.npy'.format('validation')) 44 | test_path = os.path.join(input_dir, player_name + '_{}.npy'.format('test')) 45 | 46 | player_data[player_name]['train'] = np.load(train_path, allow_pickle=True) 47 | player_data[player_name]['train'] = player_data[player_name]['train'].item() 48 | player_data[player_name]['validation'] = np.load(val_path, allow_pickle=True) 49 | player_data[player_name]['validation'] = 
player_data[player_name]['validation'].item() 50 |         player_data[player_name]['test'] = np.load(test_path, allow_pickle=True) 51 |         player_data[player_name]['test'] = player_data[player_name]['test'].item() 52 | 53 |     return player_data 54 | 55 | def construct_datasets(player_data): 56 |     player_index = {} 57 |     train_list = [] 58 |     train_labels = [] 59 |     validation_list = [] 60 |     validation_labels = [] 61 |     test_list = [] 62 |     test_labels = [] 63 |     i = 0 64 |     for player in player_data.keys(): 65 |         label = i 66 |         player_index[player] = i 67 |         for key, value in player_data[player]['train'].items(): 68 |             train_list.append(normalize(value)) 69 |             train_labels.append(label) 70 | 71 |         for key, value in player_data[player]['validation'].items(): 72 |             validation_list.append(normalize(value)) 73 |             validation_labels.append(label) 74 | 75 |         for key, value in player_data[player]['test'].items(): 76 |             test_list.append(normalize(value)) 77 |             test_labels.append(label) 78 | 79 |         i += 1 80 |     # convert lists into numpy arrays 81 |     train_list_np = np.stack(train_list, axis=0) 82 |     validation_list_np = np.stack(validation_list, axis=0) 83 |     test_list_np = np.stack(test_list, axis=0) 84 | 85 |     train_labels_np = np.stack(train_labels, axis=0) 86 |     validation_labels_np = np.stack(validation_labels, axis=0) 87 |     test_labels_np = np.stack(test_labels, axis=0) 88 | 89 |     return train_list_np, train_labels_np, validation_list_np, validation_labels_np, test_list_np, test_labels_np, player_index 90 | 91 | 92 | def init_net(output_size): 93 |     l2reg = tf.keras.regularizers.l2(l=0.5 * (0.0001)) 94 |     input_var = tf.keras.Input(shape=(50, )) 95 |     dense_1 = tf.keras.layers.Dense(40, kernel_initializer='glorot_normal', kernel_regularizer=l2reg, bias_regularizer=l2reg, activation='relu')(input_var) 96 |     dense_2 = tf.keras.layers.Dense(output_size, kernel_initializer='glorot_normal', kernel_regularizer=l2reg, bias_regularizer=l2reg)(dense_1)  # one output logit per player 97 | 98 |     model = tf.keras.Model(inputs=input_var, outputs=dense_2) 99 |     return model 100 | 101 | def train(train_dataset, train_labels, val_dataset, val_labels, test_dataset, test_labels, player_index): 102 |     net = init_net(max(test_labels) + 1) 103 |     net.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001, clipnorm=1), 104 |                 loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), 105 |                 metrics=['accuracy']) 106 | 107 |     net.fit(train_dataset, train_labels, batch_size=32, epochs=10, validation_data=(val_dataset, val_labels)) 108 | 109 |     test_loss, test_acc = net.evaluate(test_dataset, test_labels, verbose=2) 110 | 111 |     print('\nTest accuracy:', test_acc) 112 | 113 |     return net 114 | 115 | # predict is to verify if keras test is correct 116 | def predict(net, test, test_labels): 117 |     probability_model = tf.keras.Sequential([net, 118 |                                              tf.keras.layers.Softmax()]) 119 |     predictions = probability_model.predict(test) 120 | 121 |     correct = 0 122 |     total = 0 123 |     for i, prediction in enumerate(predictions): 124 |         if test_labels[i] == np.argmax(prediction): 125 |             correct += 1 126 |         total += 1 127 | 128 |     print('test accuracy is: {}'.format(correct / total)) 129 | 130 | if __name__ == '__main__': 131 |     args = parse_argument() 132 | 133 |     gpus = tf.config.experimental.list_physical_devices('GPU') 134 |     tf.config.experimental.set_visible_devices(gpus[args.gpu], 'GPU') 135 |     tf.config.experimental.set_memory_growth(gpus[args.gpu], True) 136 | 137 |     player_data = read_npy(args.input_dir) 138 | 139 |     train_dataset, train_labels, val_dataset, val_labels, test_dataset, test_labels, 
player_index = construct_datasets(player_data) 140 | 141 | net = train(train_dataset, train_labels, val_dataset, val_labels, test_dataset, test_labels, player_index) 142 | 143 | # predict is to verify if test is correct 144 | # predict(net, test_dataset, test_labels) 145 | -------------------------------------------------------------------------------- /9-reduced-data/configs/Best_frozen.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | dataset: 4 | name: '' 5 | train_path: '' 6 | validate_path: '' 7 | gpu: 3 8 | model: 9 | back_prop_blocks: 3 10 | filters: 64 11 | keep_weights: true 12 | path: maia/1700 13 | residual_blocks: 6 14 | se_ratio: 8 15 | training: 16 | batch_size: 16 17 | checkpoint_small_steps: 18 | - 100 19 | - 200 20 | - 400 21 | - 800 22 | - 1600 23 | - 2500 24 | checkpoint_steps: 5000 25 | lr_boundaries: 26 | - 50000 27 | - 110000 28 | - 160000 29 | lr_values: 30 | - 1.0e-05 31 | - 1.0e-06 32 | - 1.0e-07 33 | - 1.0e-08 34 | num_batch_splits: 1 35 | policy_loss_weight: 1.0 36 | precision: half 37 | shuffle_size: 256 38 | small_mode: true 39 | test_small_boundaries: 40 | - 20000 41 | - 40000 42 | - 60000 43 | - 80000 44 | - 100000 45 | test_small_steps: 46 | - 100 47 | - 200 48 | - 400 49 | - 800 50 | - 1600 51 | - 2500 52 | test_steps: 2000 53 | total_steps: 200000 54 | train_avg_report_steps: 50 55 | value_loss_weight: 1.0 56 | ... 57 | -------------------------------------------------------------------------------- /9-reduced-data/configs/NFP.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 0 4 | 5 | dataset: 6 | train_path: '' 7 | validate_path: '' 8 | name: '' 9 | 10 | training: 11 | precision: 'half' 12 | batch_size: 256 13 | num_batch_splits: 1 14 | test_steps: 1000 15 | train_avg_report_steps: 50 16 | total_steps: 150000 17 | checkpoint_steps: 5000 18 | shuffle_size: 256 19 | lr_values: 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | - 0.00001 24 | lr_boundaries: 25 | - 35000 26 | - 80000 27 | - 110000 28 | policy_loss_weight: 1.0 29 | value_loss_weight: 1.0 30 | 31 | model: 32 | filters: 64 33 | residual_blocks: 6 34 | se_ratio: 8 35 | path: "maia/1900" 36 | keep_weights: true 37 | back_prop_blocks: 99 38 | ... 39 | -------------------------------------------------------------------------------- /9-reduced-data/configs/Tuned.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 0 4 | model: 5 | back_prop_blocks: 99 6 | filters: 64 7 | keep_weights: true 8 | path: maia/1900 9 | residual_blocks: 6 10 | se_ratio: 8 11 | training: 12 | batch_size: 16 13 | checkpoint_small_steps: 14 | - 50 15 | - 200 16 | - 400 17 | - 800 18 | - 1600 19 | - 2500 20 | checkpoint_steps: 5000 21 | early_stopping_steps: 10000 22 | lr_boundaries: 23 | - 50000 24 | - 110000 25 | - 160000 26 | lr_values: 27 | - 1.0e-05 28 | - 1.0e-06 29 | - 1.0e-07 30 | - 1.0e-08 31 | num_batch_splits: 1 32 | policy_loss_weight: 1.0 33 | precision: half 34 | shuffle_size: 256 35 | small_mode: true 36 | test_small_boundaries: 37 | - 20000 38 | - 40000 39 | - 60000 40 | - 80000 41 | - 100000 42 | test_small_steps: 43 | - 50 44 | - 200 45 | - 400 46 | - 800 47 | - 1600 48 | - 2500 49 | test_steps: 2000 50 | total_steps: 200000 51 | train_avg_report_steps: 50 52 | value_loss_weight: 1.0 53 | ... 
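The `lr_values`/`lr_boundaries` pairs in these configs describe a piecewise-constant learning-rate schedule, as is typical for lczero-derived training code: training starts at the first value and steps down to the next one each time the global step crosses a boundary, so `lr_values` always has one more entry than `lr_boundaries`. A minimal sketch of that mapping (the helper name `lr_for_step` is ours, not part of the training code):

```python
import bisect

def lr_for_step(step, boundaries, values):
    # values[i] applies while step is below boundaries[i];
    # the final value applies after the last boundary
    return values[bisect.bisect_right(boundaries, step)]

# With Tuned.yaml's schedule: 1e-05 until step 50000, 1e-06 until 110000, ...
boundaries = [50000, 110000, 160000]
values = [1.0e-05, 1.0e-06, 1.0e-07, 1.0e-08]
assert lr_for_step(0, boundaries, values) == 1.0e-05
assert lr_for_step(120000, boundaries, values) == 1.0e-07
```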
54 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | # YAML 1.2 2 | --- 3 | abstract: | 4 |     "AI systems that can capture human-like behavior are becoming increasingly useful in situations where humans may want to learn from these systems, collaborate with them, or engage with them as partners for an extended duration. In order to develop human-oriented AI systems, the problem of predicting human actions---as opposed to predicting optimal actions---has received considerable attention. Existing work has focused on capturing human behavior in an aggregate sense, which potentially limits the benefit any particular individual could gain from interaction with these systems. We extend this line of work by developing highly accurate predictive models of individual human behavior in chess. Chess is a rich domain for exploring human-AI interaction because it combines a unique set of properties: AI systems achieved superhuman performance many years ago, and yet humans still interact with them closely, both as opponents and as preparation tools, and there is an enormous corpus of recorded data on individual player games. Starting with Maia, an open-source version of AlphaZero trained on a population of human players, we demonstrate that we can significantly improve prediction accuracy of a particular player's moves by applying a series of fine-tuning methods. Furthermore, our personalized models can be used to perform stylometry---predicting who made a given set of moves---indicating that they capture human decision-making at an individual level. Our work demonstrates a way to bring AI systems into better alignment with the behavior of individual people, which could lead to large improvements in human-AI interaction." 5 | authors: 6 |   - 7 |     affiliation: "University of Toronto" 8 |     family-names: "McIlroy-Young" 9 |     given-names: Reid 10 |   - 11 |     affiliation: "University of Toronto" 12 |     family-names: Wang 13 |     given-names: Russell 14 |   - 15 |     affiliation: "Microsoft Research" 16 |     family-names: Sen 17 |     given-names: Siddhartha 18 |   - 19 |     affiliation: "Cornell University" 20 |     family-names: Kleinberg 21 |     given-names: Jon 22 |   - 23 |     affiliation: "University of Toronto" 24 |     family-names: Anderson 25 |     given-names: Ashton 26 | cff-version: "1.2.0" 27 | date-released: 2022-09-16 28 | doi: "10.1145/3534678.3539367" 29 | license: "GPL-3.0" 30 | message: "If you use Maia Individual, please cite the original paper" 31 | repository-code: "https://github.com/CSSLab/maia-individual" 32 | url: "https://maiachess.com" 33 | title: "Maia Individual" 34 | version: "1.0.0" 35 | preferred-citation: 36 |   type: article 37 |   authors: 38 |   - 39 |     affiliation: "University of Toronto" 40 |     family-names: "McIlroy-Young" 41 |     given-names: Reid 42 |   - 43 |     affiliation: "Microsoft Research" 44 |     family-names: Sen 45 |     given-names: Siddhartha 46 |   - 47 |     affiliation: "Cornell University" 48 |     family-names: Kleinberg 49 |     given-names: Jon 50 |   - 51 |     affiliation: "University of Toronto" 52 |     family-names: Anderson 53 |     given-names: Ashton 54 |   doi: "10.1145/3534678.3539367" 55 |   journal: "KDD '22: Proceedings of the 28th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining" 56 |   month: 8 57 |   title: "Learning Models of Individual Behavior in Chess" 58 |   year: 2022 59 | ... 
60 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning Models of Individual Behavior in Chess 2 | 3 | ## [website](https://maiachess.com)/[paper](https://arxiv.org/abs/2008.10086)/[code](https://github.com/CSSLab/maia-individual) 4 | 5 |

6 | Uplift of individualized models vs the baseline 7 |

8 | 9 | ## Overview 10 | 11 | The main code used in this project is stored in `backend`, which is set up as a Python package; running `python setup.py install` will install it. Then the various scripts can be used. We also recommend using the virtual environment config we include in `environment.yml`, as some packages are required to be up to date. In addition, generating training data requires two more tools: [`pgn-extract`](https://www.cs.kent.ac.uk/people/staff/djb/pgn-extract/) to clean the PGNs and [`trainingdata-tool`](https://github.com/DanielUranga/trainingdata-tool) to convert them into training data. 12 | 13 | To run a model as a chess engine, [`lc0`](https://github.com/LeelaChessZero/lc0) version 23 has been tested and should work with all models (you can specify a path to a model with the `-w` argument). 14 | 15 | All testing was done on Ubuntu 18.04 with CUDA version 10. 16 | 17 | ## Models 18 | 19 | We do not have any models of individual players publicly available at this time. Because of the stylometry results shown in Section 5.2 of the paper, we cannot release even anonymized models. 20 | 21 | We have included the Maia model from [https://github.com/CSSLab/maia-chess](https://github.com/CSSLab/maia-chess) that was used as the base. 22 | 23 | ## Running the code 24 | 25 | The code for this project is divided into different sections, each with a series of numbered shell scripts. If run in order, they generate the training data and then the final models. For the full release we plan to make the process of generating a model more streamlined. 26 | 27 | ### Quick Run 28 | 29 | To get the model for a single player from a single PGN, a simpler system can be used: 30 | 31 | 1. Run `1-data_generation/9-pgn_to_training_data.sh input_PGN_file output_directory player_name` 32 | 2. Create a config file by copying `2-training/final_config.yaml` and adding `output_directory` and `player_name` 33 | 3. Run `python 2-training/train_transfer.py path_to_config` 34 | 4. The final model will be written to `final_models`; read the `--help` for more information 35 | 36 | ### Full Run 37 | 38 | Where applicable, scripts start with a list of variables; these will need to be edited to match the paths on your system. 39 | 40 | The list of players we used was selected using the code in `0-player_counting`. The standard games from Lichess ([database.lichess.org](https://database.lichess.org)) up to April are required to reproduce our exact results, but the code should work with other game collections, even non-Lichess ones, with a bit of work. 41 | 42 | Then the players' games are extracted and the various sets are constructed from them in `1-data_generation`. 43 | 44 | Finally, `2-training` has the main training script along with a configuration file that specifies the hyperparameters. Configurations for all four fine-tuning methods discussed in the main text are included. 45 | 46 | ### Extras 47 | 48 | The analysis code (`3-analysis`) is included for completeness, but as it only generates the data used in the plots and relies on various hard-coded paths, we have not tested it. That said, `3-analysis/prediction_generator.py` is the main workhorse and has a `--help`; note it is designed for files output by `backend.gameToCSVlines`, but less complicated CSVs could be used. 49 | 50 | The baseline model's code and results are in `4-cp_loss_stylo_baseline`; this is a simple baseline to compare our results against, included for completeness (a sketch of the idea follows below).
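As a rough illustration of what that baseline does (fit one Gaussian Naive Bayes classifier on a per-player centipawn-loss histogram, then attribute held-out histograms back to players), here is a minimal sketch with random stand-in data in place of the real `.npy` histogram files:

```python
import numpy as np
from sklearn.naive_bayes import GaussianNB

rng = np.random.default_rng(0)
n_players, n_bins = 5, 50

# stand-ins for the normalized per-player cp-loss histograms
train_hists = rng.random((n_players, n_bins))
test_hists = train_hists + rng.normal(0, 0.01, (n_players, n_bins))
labels = np.arange(n_players)

model = GaussianNB()
model.fit(train_hists, labels)                 # one histogram per player
accuracy = (model.predict(test_hists) == labels).mean()
print(f"stylometry accuracy: {accuracy:.2f}")
```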
51 | 52 | ## Reduced Data 53 | 54 | The model configurations for the reduced-data training are included in `9-reduced-data/configs`. To train them yourself, simply use these configs with the Quick Run training steps above. 55 | 56 | ## Citation 57 | 58 | ``` 59 | @article{McIlroyYoung_Learning_Models_Chess_2022, 60 |     author = {McIlroy-Young, Reid and Sen, Siddhartha and Kleinberg, Jon and Anderson, Ashton}, 61 |     doi = {10.1145/3534678.3539367}, 62 |     journal = {KDD '22: Proceedings of the 28th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining}, 63 |     month = {8}, 64 |     title = {{Learning Models of Individual Behavior in Chess}}, 65 |     year = {2022} 66 | } 67 | ``` 68 | 69 | ## License 70 | 71 | The software is available under the GPL License and includes code from the [Leela Chess Zero](https://github.com/LeelaChessZero/lczero-training) project. 72 | 73 | ## Contact 74 | 75 | Please [open an issue](https://github.com/CSSLab/maia-individual/issues/new) or email [Reid McIlroy-Young](https://reidmcy.com/) to get in touch. 76 | -------------------------------------------------------------------------------- /backend/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from .uci_engine import * 3 | from .pgn_parsering import * 4 | from .multiproc import * 5 | from .fen_to_vec import fen_to_vec, array_to_fen, array_to_board, game_to_vecs 6 | from .pgn_to_csv import * 7 | 8 | __version__ = '1.0.0' 9 | -------------------------------------------------------------------------------- /backend/fen_to_vec.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import chess 4 | import numpy as np 5 | 6 | # Generate the regexs 7 | boardRE = re.compile(r"(([^/]+)/([^/]+)/([^/]+)/([^/]+)/([^/]+)/([^/]+)/([^/]+)/([^/]+)) ((w)|(b)) ((-)|(K)?(Q)?(k)?(q)?)( ((-)|(\w+)))?( \d+)?( \d+)?") 8 | 9 | replaceRE = re.compile(r'[1-8/]') 10 | 11 | pieceMapWhite = {'E' : [False] * 12} 12 | pieceMapBlack = {'E' : [False] * 12} 13 | 14 | piece_reverse_lookup = {} 15 | 16 | all_pieces = 'PNBRQK' 17 | 18 | for i, p in enumerate(all_pieces): 19 |     #White pieces first 20 |     mP = [False] * 12 21 |     mP[i] = True 22 |     pieceMapBlack[p] = mP 23 |     piece_reverse_lookup[i] = p 24 | 25 |     #then black 26 |     mP = [False] * 12 27 |     mP[i + len(all_pieces)] = True 28 |     pieceMapBlack[p.lower()] = mP 29 |     piece_reverse_lookup[i + len(all_pieces)] = p.lower() 30 | 31 | 32 |     #Black pieces first 33 |     mP = [False] * 12 34 |     mP[i] = True 35 |     pieceMapWhite[p.lower()] = mP 36 | 37 |     #then white 38 |     mP = [False] * 12 39 |     mP[i + len(all_pieces)] = True 40 |     pieceMapWhite[p] = mP 41 | 42 | iSs = [str(i + 1) for i in range(8)] 43 | eSubss = [('E' * i, str(i)) for i in range(8,0, -1)] 44 | castling_vals = 'KQkq' 45 | 46 | def toByteBuff(l): 47 |     return b''.join([b'\1' if e else b'\0' for e in l]) 48 | 49 | pieceMapBin = {k : toByteBuff(v) for k,v in pieceMapBlack.items()} 50 | 51 | def toBin(c): 52 |     return pieceMapBin[c] 53 | 54 | castlesMap = {True : b'\1'*64, False : b'\0'*64} 55 | 56 | #Some previous lines are left in just in case 57 | 58 | # using N,C,H,W format 59 | 60 | move_letters = list('abcdefgh') 61 | 62 | moves_lookup = {} 63 | move_ind = 0 64 | for r_1 in range(8): 65 |     for c_1 in range(8): 66 |         for r_2 in range(8): 67 |             for c_2 in range(8): 68 |                 moves_lookup[f"{move_letters[r_1]}{c_1+1}{move_letters[r_2]}{c_2+1}"] = move_ind 69 |                 move_ind += 1 70 | 71 | def move_to_index(move_str): 72 |     return moves_lookup[move_str[:4]]
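# Note: moves_lookup enumerates all 64 * 64 = 4096 (from-square, to-square)
# pairs, so move_to_index maps a UCI move string such as "e2e4" to a flat
# policy index; slicing to move_str[:4] collapses promotions ("e7e8q" -> "e7e8").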
73 | 74 | def array_to_preproc(a_target): 75 |     if not isinstance(a_target, np.ndarray): 76 |         #check if torch Tensor without importing torch 77 |         a_target = a_target.cpu().numpy() 78 |     if a_target.dtype != np.bool_: 79 |         a_target = a_target.astype(np.bool_) 80 |     piece_layers = a_target[:12] 81 |     board_a = np.moveaxis(piece_layers, 2, 0).reshape(64, 12) 82 |     board_str = '' 83 |     is_white = bool(a_target[12, 0, 0]) 84 |     castling = [bool(l[0,0]) for l in a_target[13:]] 85 |     board = [['E'] * 8 for i in range(8)] 86 |     for i in range(12): 87 |         for x in range(8): 88 |             for y in range(8): 89 |                 if piece_layers[i,x,y]: 90 |                     board[x][y] = piece_reverse_lookup[i] 91 |     board = [''.join(r) for r in board] 92 |     return ''.join(board), is_white, tuple(castling) 93 | 94 | def preproc_to_fen(boardStr, is_white, castling): 95 |     rows = [boardStr[(i*8):(i*8)+8] for i in range(8)] 96 | 97 |     if not is_white: 98 |         castling = castling[2:] + castling[:2] 99 |         new_rows = [] 100 |         for b in rows: 101 |             new_rows.append(b.swapcase()[::-1].replace('e', 'E')) 102 | 103 |         rows = reversed(new_rows) 104 |     row_strs = [] 105 |     for r in rows: 106 |         for es, i in eSubss: 107 |             if es in r: 108 |                 r = r.replace(es, i) 109 |         row_strs.append(r) 110 |     castle_str = '' 111 |     for i, v in enumerate(castling): 112 |         if v: 113 |             castle_str += castling_vals[i] 114 |     if len(castle_str) < 1: 115 |         castle_str = '-' 116 | 117 |     is_white_str = 'w' if is_white else 'b' 118 |     board_str = '/'.join(row_strs) 119 |     return f"{board_str} {is_white_str} {castle_str} - 0 1" 120 | 121 | def array_to_fen(a_target): 122 |     return preproc_to_fen(*array_to_preproc(a_target)) 123 | 124 | def array_to_board(a_target): 125 |     return chess.Board(fen = array_to_fen(a_target)) 126 | 127 | def simple_fen_vec(boardStr, is_white, castling): 128 |     castles = [np.frombuffer(castlesMap[c], dtype='bool').reshape(1, 8, 8) for c in castling] 129 |     board_buff_map = map(toBin, boardStr) 130 |     board_buff = b''.join(board_buff_map) 131 |     a = np.frombuffer(board_buff, dtype='bool') 132 |     a = a.reshape(8, 8, -1) 133 |     a = np.moveaxis(a, 2, 0) 134 |     if is_white: 135 |         colour_plane = np.ones((1, 8, 8), dtype='bool') 136 |     else: 137 |         colour_plane = np.zeros((1, 8, 8), dtype='bool') 138 | 139 |     return np.concatenate([a, colour_plane, *castles], axis = 0) 140 | 141 | def preproc_fen(fenstr): 142 |     r = boardRE.match(fenstr) 143 |     if r.group(14): 144 |         castling = (False, False, False, False) 145 |     else: 146 |         castling = (bool(r.group(15)), bool(r.group(16)), bool(r.group(17)), bool(r.group(18))) 147 |     if r.group(11): 148 |         is_white = True 149 |         rows_lst = r.group(1).split('/') 150 |     else: 151 |         is_white = False 152 |         castling = castling[2:] + castling[:2] 153 |         rows_lst = r.group(1).swapcase().split('/') 154 |         rows_lst = reversed([s[::-1] for s in rows_lst]) 155 | 156 |     rowsS = ''.join(rows_lst) 157 |     for i, iS in enumerate(iSs): 158 |         if iS in rowsS: 159 |             rowsS = rowsS.replace(iS, 'E' * (i + 1)) 160 |     return rowsS, is_white, castling 161 | 162 | def fen_to_vec(fenstr): 163 |     return simple_fen_vec(*preproc_fen(fenstr)) 164 | 165 | def game_to_vecs(game): 166 |     boards = [] 167 |     board = game.board() 168 |     for i, node in enumerate(game.mainline()): 169 |         fen = str(board.fen()) 170 |         board.push(node.move) 171 |         boards.append(fen_to_vec(fen)) 172 |     return np.stack(boards, axis = 0) 173 | -------------------------------------------------------------------------------- /backend/multiproc.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | 
import collections.abc 3 | import time 4 | import sys 5 | import traceback 6 | import functools 7 | import pickle 8 | 9 | class Multiproc(object): 10 | def __init__(self, num_procs, max_queue_size = 1000, proc_check_interval = .1): 11 | self.num_procs = num_procs 12 | self.max_queue_size = max_queue_size 13 | self.proc_check_interval = proc_check_interval 14 | 15 | self.reader = MultiprocIterable 16 | self.reader_args = [] 17 | self.reader_kwargs = {} 18 | 19 | self.processor = MultiprocWorker 20 | self.processor_args = [] 21 | self.processor_kwargs = {} 22 | 23 | self.writer = MultiprocWorker 24 | self.writer_args = [] 25 | self.writer_kwargs = {} 26 | 27 | def reader_init(self, reader_cls, *reader_args, **reader_kwargs): 28 | self.reader = reader_cls 29 | self.reader_args = reader_args 30 | self.reader_kwargs = reader_kwargs 31 | 32 | def processor_init(self, processor_cls, *processor_args, **processor_kwargs): 33 | self.processor = processor_cls 34 | self.processor_args = processor_args 35 | self.processor_kwargs = processor_kwargs 36 | 37 | def writer_init(self, writer_cls, *writer_args, **writer_kwargs): 38 | self.writer = writer_cls 39 | self.writer_args = writer_args 40 | self.writer_kwargs = writer_kwargs 41 | 42 | 43 | def run(self): 44 | with multiprocessing.Pool(self.num_procs + 2) as pool, multiprocessing.Manager() as manager: 45 | inputQueue = manager.Queue(self.max_queue_size) 46 | resultsQueue = manager.Queue(self.max_queue_size) 47 | reader_proc = pool.apply_async(reader_loop, (inputQueue, self.num_procs, self.reader, self.reader_args, self.reader_kwargs)) 48 | 49 | worker_procs = [] 50 | for _ in range(self.num_procs): 51 | wp = pool.apply_async(processor_loop, (inputQueue, resultsQueue, self.processor, self.processor_args, self.processor_kwargs)) 52 | worker_procs.append(wp) 53 | 54 | writer_proc = pool.apply_async(writer_loop, (resultsQueue, self.num_procs, self.writer, self.writer_args, self.writer_kwargs)) 55 | 56 | self.cleanup(reader_proc, worker_procs, writer_proc) 57 | 58 | def cleanup(self, reader_proc, worker_procs, writer_proc): 59 | reader_working = True 60 | processor_working = True 61 | writer_working = True 62 | while reader_working or processor_working or writer_working: 63 | if reader_working and reader_proc.ready(): 64 | reader_proc.get() 65 | reader_working = False 66 | 67 | if processor_working: 68 | new_procs = [] 69 | for p in worker_procs: 70 | if p.ready(): 71 | p.get() 72 | else: 73 | new_procs.append(p) 74 | if len(new_procs) < 1: 75 | processor_working = False 76 | else: 77 | worker_procs = new_procs 78 | 79 | if writer_working and writer_proc.ready(): 80 | writer_proc.get() 81 | writer_working = False 82 | time.sleep(self.proc_check_interval) 83 | 84 | def catch_remote_exceptions(wrapped_function): 85 | """ https://stackoverflow.com/questions/6126007/python-getting-a-traceback """ 86 | 87 | @functools.wraps(wrapped_function) 88 | def new_function(*args, **kwargs): 89 | try: 90 | return wrapped_function(*args, **kwargs) 91 | 92 | except: 93 | raise Exception( "".join(traceback.format_exception(*sys.exc_info())) ) 94 | 95 | return new_function 96 | 97 | @catch_remote_exceptions 98 | def reader_loop(inputQueue, num_workers, reader_cls, reader_args, reader_kwargs): 99 | with reader_cls(*reader_args, **reader_kwargs) as R: 100 | for dat in R: 101 | inputQueue.put(dat) 102 | for i in range(num_workers): 103 | inputQueue.put(_QueueDone(count = i)) 104 | 105 | @catch_remote_exceptions 106 | def processor_loop(inputQueue, resultsQueue, processor_cls, 
processor_args, processor_kwargs): 107 | with processor_cls(*processor_args, **processor_kwargs) as Proc: 108 | while True: 109 | dat = inputQueue.get() 110 | if isinstance(dat, _QueueDone): 111 | resultsQueue.put(dat) 112 | break 113 | try: 114 | if isinstance(dat, tuple): 115 | procced_dat = Proc(*dat) 116 | else: 117 | procced_dat = Proc(dat) 118 | except SkipCallMultiProc: 119 | continue # skip this input without emitting a result 120 | except: 121 | raise 122 | resultsQueue.put(procced_dat) 123 | 124 | @catch_remote_exceptions 125 | def writer_loop(resultsQueue, num_workers, writer_cls, writer_args, writer_kwargs): 126 | complete_workers = 0 127 | with writer_cls(*writer_args, **writer_kwargs) as W: 128 | if W is None: 129 | raise AttributeError("Worker was created, but closure failed to form") 130 | while complete_workers < num_workers: 131 | dat = resultsQueue.get() 132 | if isinstance(dat, _QueueDone): 133 | complete_workers += 1 134 | else: 135 | if isinstance(dat, tuple): 136 | W(*dat) 137 | else: 138 | W(dat) 139 | 140 | class SkipCallMultiProc(Exception): 141 | pass 142 | 143 | class _QueueDone(object): 144 | def __init__(self, count = 0): 145 | self.count = count 146 | 147 | class MultiprocWorker(collections.abc.Callable): 148 | 149 | def __call__(self, *args): 150 | return None 151 | 152 | def __enter__(self): 153 | return self 154 | 155 | def __exit__(self, exc_type, exc_value, traceback): 156 | pass 157 | 158 | class MultiprocIterable(MultiprocWorker, collections.abc.Iterator): 159 | def __next__(self): 160 | raise StopIteration 161 | 162 | def __call__(self, *args): 163 | return next(self) 164 | -------------------------------------------------------------------------------- /backend/pgn_parsering.py: -------------------------------------------------------------------------------- 1 | import re 2 | import bz2 3 | 4 | import chess.pgn 5 | 6 | moveRegex = re.compile(r'\d+[.][ \.](\S+) (?:{[^}]*} )?(\S+)') 7 | 8 | class GamesFile(object): 9 | def __init__(self, path): 10 | if path.endswith('bz2'): 11 | self.f = bz2.open(path, 'rt') 12 | else: 13 | self.f = open(path, 'r') 14 | self.path = path 15 | self.i = 0 16 | 17 | def __iter__(self): 18 | try: 19 | while True: 20 | yield next(self) 21 | except StopIteration: 22 | return 23 | 24 | def __del__(self): 25 | try: 26 | self.f.close() 27 | except AttributeError: 28 | pass 29 | 30 | def __next__(self): 31 | 32 | ret = {} 33 | lines = '' 34 | for l in self.f: 35 | self.i += 1 36 | lines += l 37 | if len(l) < 2: 38 | if len(ret) >= 2: 39 | break 40 | else: 41 | raise RuntimeError(l) 42 | else: 43 | try: 44 | k, v, _ = l.split('"') 45 | except ValueError: 46 | #bad line 47 | if l == 'null\n': 48 | pass 49 | else: 50 | raise 51 | else: 52 | ret[k[1:-1]] = v 53 | nl = self.f.readline() 54 | lines += nl 55 | lines += self.f.readline() 56 | if len(lines) < 1: 57 | raise StopIteration 58 | return ret, lines 59 | -------------------------------------------------------------------------------- /backend/proto/__init__.py: -------------------------------------------------------------------------------- 1 | from .net_pb2 import Net, NetworkFormat 2 | -------------------------------------------------------------------------------- /backend/proto/net.proto: -------------------------------------------------------------------------------- 1 | /* 2 | This file is part of Leela Chess Zero. 
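The Multiproc harness in backend/multiproc.py above is a reader → workers → writer pipeline: reader_loop fills the input queue, num_procs copies of the processor drain it, and a single writer consumes the results queue until it has seen one _QueueDone marker per worker. A hypothetical end-to-end sketch (all three classes are invented for illustration, and must live at module top level so multiprocessing can pickle them):

from backend.multiproc import Multiproc, MultiprocIterable, MultiprocWorker

class ListReader(MultiprocIterable):
    def __init__(self, items):
        self.items = list(items)
    def __next__(self):
        if not self.items:
            raise StopIteration
        return self.items.pop()

class UpperCaser(MultiprocWorker):
    def __call__(self, s):
        return s.upper()

class StdoutWriter(MultiprocWorker):
    def __call__(self, s):
        print(s)

if __name__ == '__main__':
    mp = Multiproc(num_procs=4)
    mp.reader_init(ListReader, ['a', 'b', 'c'])
    mp.processor_init(UpperCaser)
    mp.writer_init(StdoutWriter)
    mp.run()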
3 | Copyright (C) 2018 The LCZero Authors 4 | 5 | Leela Chess is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | Leela Chess is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with Leela Chess. If not, see . 17 | 18 | Additional permission under GNU GPL version 3 section 7 19 | 20 | If you modify this Program, or any covered work, by linking or 21 | combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA 22 | Toolkit and the NVIDIA CUDA Deep Neural Network library (or a 23 | modified version of those libraries), containing parts covered by the 24 | terms of the respective license agreement, the licensors of this 25 | Program grant you additional permission to convey the resulting work. 26 | */ 27 | syntax = "proto2"; 28 | 29 | package pblczero; 30 | 31 | message EngineVersion { 32 | optional uint32 major = 1; 33 | optional uint32 minor = 2; 34 | optional uint32 patch = 3; 35 | } 36 | 37 | message Weights { 38 | message Layer { 39 | optional float min_val = 1; 40 | optional float max_val = 2; 41 | optional bytes params = 3; 42 | } 43 | 44 | message ConvBlock { 45 | optional Layer weights = 1; 46 | optional Layer biases = 2; 47 | optional Layer bn_means = 3; 48 | optional Layer bn_stddivs = 4; 49 | optional Layer bn_gammas = 5; 50 | optional Layer bn_betas = 6; 51 | } 52 | 53 | message SEunit { 54 | // Squeeze-excitation unit (https://arxiv.org/abs/1709.01507) 55 | // weights and biases of the two fully connected layers. 56 | optional Layer w1 = 1; 57 | optional Layer b1 = 2; 58 | optional Layer w2 = 3; 59 | optional Layer b2 = 4; 60 | } 61 | 62 | message Residual { 63 | optional ConvBlock conv1 = 1; 64 | optional ConvBlock conv2 = 2; 65 | optional SEunit se = 3; 66 | } 67 | 68 | // Input convnet. 69 | optional ConvBlock input = 1; 70 | 71 | // Residual tower. 72 | repeated Residual residual = 2; 73 | 74 | // Policy head 75 | // Extra convolution for AZ-style policy head 76 | optional ConvBlock policy1 = 11; 77 | optional ConvBlock policy = 3; 78 | optional Layer ip_pol_w = 4; 79 | optional Layer ip_pol_b = 5; 80 | 81 | // Value head 82 | optional ConvBlock value = 6; 83 | optional Layer ip1_val_w = 7; 84 | optional Layer ip1_val_b = 8; 85 | optional Layer ip2_val_w = 9; 86 | optional Layer ip2_val_b = 10; 87 | } 88 | 89 | message TrainingParams { 90 | optional uint32 training_steps = 1; 91 | optional float learning_rate = 2; 92 | optional float mse_loss = 3; 93 | optional float policy_loss = 4; 94 | optional float accuracy = 5; 95 | optional string lc0_params = 6; 96 | } 97 | 98 | message NetworkFormat { 99 | // Format to encode the input planes with. Used by position encoder. 100 | enum InputFormat { 101 | INPUT_UNKNOWN = 0; 102 | INPUT_CLASSICAL_112_PLANE = 1; 103 | // INPUT_WITH_COORDINATE_PLANES = 2; // Example. Uncomment/rename. 104 | } 105 | optional InputFormat input = 1; 106 | 107 | // Output format of the NN. Used by search code to interpret results. 
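// Example (sketch): a serialized Net message can be inspected from Python
// via the generated bindings in backend/proto/net_pb2.py, e.g. for the
// checked-in maia-1900 weights:
//
//   import gzip
//   from backend.proto import Net
//
//   net = Net()
//   with gzip.open('models/maia-1900/final_1900-40.pb.gz', 'rb') as f:
//       net.ParseFromString(f.read())
//   print(net.format.network_format)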
108 | enum OutputFormat { 109 | OUTPUT_UNKNOWN = 0; 110 | OUTPUT_CLASSICAL = 1; 111 | OUTPUT_WDL = 2; 112 | } 113 | optional OutputFormat output = 2; 114 | 115 | // Network architecture. Used by backends to build the network. 116 | enum NetworkStructure { 117 | // Networks without PolicyFormat or ValueFormat specified 118 | NETWORK_UNKNOWN = 0; 119 | NETWORK_CLASSICAL = 1; 120 | NETWORK_SE = 2; 121 | // Networks with PolicyFormat and ValueFormat specified 122 | NETWORK_CLASSICAL_WITH_HEADFORMAT = 3; 123 | NETWORK_SE_WITH_HEADFORMAT = 4; 124 | } 125 | optional NetworkStructure network = 3; 126 | 127 | // Policy head architecture 128 | enum PolicyFormat { 129 | POLICY_UNKNOWN = 0; 130 | POLICY_CLASSICAL = 1; 131 | POLICY_CONVOLUTION = 2; 132 | } 133 | optional PolicyFormat policy = 4; 134 | 135 | // Value head architecture 136 | enum ValueFormat { 137 | VALUE_UNKNOWN = 0; 138 | VALUE_CLASSICAL = 1; 139 | VALUE_WDL = 2; 140 | } 141 | optional ValueFormat value = 5; 142 | } 143 | 144 | message Format { 145 | enum Encoding { 146 | UNKNOWN = 0; 147 | LINEAR16 = 1; 148 | } 149 | 150 | optional Encoding weights_encoding = 1; 151 | // If network_format is missing, it's assumed to have 152 | // INPUT_CLASSICAL_112_PLANE / OUTPUT_CLASSICAL / NETWORK_CLASSICAL format. 153 | optional NetworkFormat network_format = 2; 154 | } 155 | 156 | message Net { 157 | optional fixed32 magic = 1; 158 | optional string license = 2; 159 | optional EngineVersion min_version = 3; 160 | optional Format format = 4; 161 | optional TrainingParams training_params = 5; 162 | optional Weights weights = 10; 163 | } 164 | -------------------------------------------------------------------------------- /backend/tf_transfer/__init__.py: -------------------------------------------------------------------------------- 1 | from .tfprocess import TFProcess 2 | from .chunkparser import ChunkParser 3 | from .net import * 4 | from .training_shared import * 5 | from .utils import * 6 | -------------------------------------------------------------------------------- /backend/tf_transfer/lc0_az_policy_map.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import sys 3 | import numpy as np 4 | from .policy_index import policy_index 5 | 6 | columns = 'abcdefgh' 7 | rows = '12345678' 8 | promotions = 'rbq' # N is encoded as normal move 9 | 10 | col_index = {columns[i] : i for i in range(len(columns))} 11 | row_index = {rows[i] : i for i in range(len(rows))} 12 | 13 | def index_to_position(x): 14 | return columns[x[0]] + rows[x[1]] 15 | 16 | def position_to_index(p): 17 | return col_index[p[0]], row_index[p[1]] 18 | 19 | def valid_index(i): 20 | if i[0] > 7 or i[0] < 0: 21 | return False 22 | if i[1] > 7 or i[1] < 0: 23 | return False 24 | return True 25 | 26 | def queen_move(start, direction, steps): 27 | i = position_to_index(start) 28 | dir_vectors = {'N': (0, 1), 'NE': (1, 1), 'E': (1, 0), 'SE': (1, -1), 29 | 'S':(0, -1), 'SW':(-1, -1), 'W': (-1, 0), 'NW': (-1, 1)} 30 | v = dir_vectors[direction] 31 | i = i[0] + v[0] * steps, i[1] + v[1] * steps 32 | if not valid_index(i): 33 | return None 34 | return index_to_position(i) 35 | 36 | def knight_move(start, direction, steps): 37 | i = position_to_index(start) 38 | dir_vectors = {'N': (1, 2), 'NE': (2, 1), 'E': (2, -1), 'SE': (1, -2), 39 | 'S':(-1, -2), 'SW':(-2, -1), 'W': (-2, 1), 'NW': (-1, 2)} 40 | v = dir_vectors[direction] 41 | i = i[0] + v[0] * steps, i[1] + v[1] * steps 42 | if not valid_index(i): 43 | return None 44 | 
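# A few expected values for the geometry helpers above (kept as comments,
# since make_map below depends on them):
#   queen_move('e4', 'NE', 2)  -> 'g6'
#   knight_move('g1', 'NW', 1) -> 'f3'
#   queen_move('a1', 'W', 1)   -> None (falls off the board)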
return index_to_position(i) 45 | 46 | def make_map(kind='matrix'): 47 | # 56 planes of queen moves 48 | moves = [] 49 | for direction in ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW']: 50 | for steps in range(1, 8): 51 | for r0 in rows: 52 | for c0 in columns: 53 | start = c0 + r0 54 | end = queen_move(start, direction, steps) 55 | if end is None: 56 | moves.append('illegal') 57 | else: 58 | moves.append(start+end) 59 | 60 | # 8 planes of knight moves 61 | for direction in ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW']: 62 | for r0 in rows: 63 | for c0 in columns: 64 | start = c0 + r0 65 | end = knight_move(start, direction, 1) 66 | if end is None: 67 | moves.append('illegal') 68 | else: 69 | moves.append(start+end) 70 | 71 | # 9 promotions 72 | for direction in ['NW', 'N', 'NE']: 73 | for promotion in promotions: 74 | for r0 in rows: 75 | for c0 in columns: 76 | # Promotion only in the second last rank 77 | if r0 != '7': 78 | moves.append('illegal') 79 | continue 80 | start = c0 + r0 81 | end = queen_move(start, direction, 1) 82 | if end is None: 83 | moves.append('illegal') 84 | else: 85 | moves.append(start+end+promotion) 86 | 87 | for m in policy_index: 88 | if m not in moves: 89 | raise ValueError('Missing move: {}'.format(m)) 90 | 91 | az_to_lc0 = np.zeros((80*8*8, len(policy_index)), dtype=np.float32) 92 | indices = [] 93 | legal_moves = 0 94 | for e, m in enumerate(moves): 95 | if m == 'illegal': 96 | indices.append(-1) 97 | continue 98 | legal_moves += 1 99 | # Check for missing moves 100 | if m not in policy_index: 101 | raise ValueError('Missing move: {}'.format(m)) 102 | i = policy_index.index(m) 103 | indices.append(i) 104 | az_to_lc0[e][i] = 1 105 | 106 | assert legal_moves == len(policy_index) 107 | assert np.sum(az_to_lc0) == legal_moves 108 | # each AZ plane/square slot maps to at most one lc0 policy index, 109 | # so az_to_lc0 is a 0/1 matrix with exactly legal_moves ones 110 | 111 | if kind == 'matrix': 112 | return az_to_lc0 113 | elif kind == 'index': 114 | return indices 115 | 116 | if __name__ == "__main__": 117 | # Generate policy map include file for lc0 118 | if len(sys.argv) != 2: 119 | raise ValueError("Output filename is needed as a command line argument") 120 | 121 | az_to_lc0 = np.ravel(make_map('index')) 122 | header = \ 123 | """/* 124 | This file is part of Leela Chess Zero. 125 | Copyright (C) 2019 The LCZero Authors 126 | 127 | Leela Chess is free software: you can redistribute it and/or modify 128 | it under the terms of the GNU General Public License as published by 129 | the Free Software Foundation, either version 3 of the License, or 130 | (at your option) any later version. 131 | 132 | Leela Chess is distributed in the hope that it will be useful, 133 | but WITHOUT ANY WARRANTY; without even the implied warranty of 134 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 135 | GNU General Public License for more details. 136 | 137 | You should have received a copy of the GNU General Public License 138 | along with Leela Chess. If not, see . 
139 | */ 140 | 141 | #pragma once 142 | 143 | namespace lczero { 144 | """ 145 | line_length = 12 146 | with open(sys.argv[1], 'w') as f: 147 | f.write(header+'\n') 148 | f.write('const short kConvPolicyMap[] = {\\\n') 149 | for e, i in enumerate(az_to_lc0): 150 | if e % line_length == 0 and e > 0: 151 | f.write('\n') 152 | f.write(str(i).rjust(5)) 153 | if e != len(az_to_lc0)-1: 154 | f.write(',') 155 | f.write('};\n\n') 156 | f.write('} // namespace lczero') 157 | -------------------------------------------------------------------------------- /backend/tf_transfer/net_to_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import tensorflow as tf 4 | import os 5 | import yaml 6 | from .tfprocess import TFProcess 7 | from .net import Net 8 | 9 | argparser = argparse.ArgumentParser(description='Convert net to model.') 10 | argparser.add_argument('net', type=str, 11 | help='Net file to be converted to a model checkpoint.') 12 | argparser.add_argument('--start', type=int, default=0, 13 | help='Offset to set global_step to.') 14 | argparser.add_argument('--cfg', type=argparse.FileType('r'), 15 | help='yaml configuration with training parameters') 16 | args = argparser.parse_args() 17 | cfg = yaml.safe_load(args.cfg.read()) 18 | print(yaml.dump(cfg, default_flow_style=False)) 19 | START_FROM = args.start 20 | net = Net() 21 | net.parse_proto(args.net) 22 | 23 | filters, blocks = net.filters(), net.blocks() 24 | if cfg['model']['filters'] != filters: 25 | raise ValueError("Number of filters in YAML doesn't match the network") 26 | if cfg['model']['residual_blocks'] != blocks: 27 | raise ValueError("Number of blocks in YAML doesn't match the network") 28 | weights = net.get_weights() 29 | 30 | tfp = TFProcess(cfg) 31 | tfp.init_net_v2() 32 | tfp.replace_weights_v2(weights) 33 | tfp.global_step.assign(START_FROM) 34 | 35 | root_dir = os.path.join(cfg['training']['path'], cfg['name']) 36 | if not os.path.exists(root_dir): 37 | os.makedirs(root_dir) 38 | tfp.manager.save() 39 | print("Wrote model to {}".format(tfp.manager.latest_checkpoint)) 40 | -------------------------------------------------------------------------------- /backend/tf_transfer/shufflebuffer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # This file is part of Leela Chess. 4 | # Copyright (C) 2018 Michael O 5 | # 6 | # Leela Chess is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # Leela Chess is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with Leela Chess. If not, see . 18 | 19 | import random 20 | import unittest 21 | 22 | class ShuffleBuffer: 23 | def __init__(self, elem_size, elem_count): 24 | """ 25 | A shuffle buffer for fixed sized elements. 26 | 27 | Manages 'elem_count' items in a fixed buffer, each item being exactly 28 | 'elem_size' bytes. 29 | """ 30 | assert elem_size > 0, elem_size 31 | assert elem_count > 0, elem_count 32 | # Size of each element. 
33 | self.elem_size = elem_size 34 | # Number of elements in the buffer. 35 | self.elem_count = elem_count 36 | # Fixed size buffer used to hold all the element. 37 | self.buffer = bytearray(elem_size * elem_count) 38 | # Number of elements actually contained in the buffer. 39 | self.used = 0 40 | 41 | def extract(self): 42 | """ 43 | Return an item from the shuffle buffer. 44 | 45 | If the buffer is empty, returns None 46 | """ 47 | if self.used < 1: 48 | return None 49 | # The items in the shuffle buffer are held in shuffled order 50 | # so returning the last item is sufficient. 51 | self.used -= 1 52 | i = self.used 53 | return self.buffer[i * self.elem_size : (i+1) * self.elem_size] 54 | 55 | def insert_or_replace(self, item): 56 | """ 57 | Inserts 'item' into the shuffle buffer, returning 58 | a random item. 59 | 60 | If the buffer is not yet full, returns None 61 | """ 62 | assert len(item) == self.elem_size, len(item) 63 | # putting the new item in a random location, and appending 64 | # the displaced item to the end of the buffer achieves a full 65 | # random shuffle (Fisher-Yates) 66 | if self.used > 0: 67 | # swap 'item' with random item in buffer. 68 | i = random.randint(0, self.used-1) 69 | old_item = self.buffer[i * self.elem_size : (i+1) * self.elem_size] 70 | self.buffer[i * self.elem_size : (i+1) * self.elem_size] = item 71 | item = old_item 72 | # If the buffer isn't yet full, append 'item' to the end of the buffer. 73 | if self.used < self.elem_count: 74 | # Not yet full, so place the returned item at the end of the buffer. 75 | i = self.used 76 | self.buffer[i * self.elem_size : (i+1) * self.elem_size] = item 77 | self.used += 1 78 | return None 79 | return item 80 | 81 | 82 | class ShuffleBufferTest(unittest.TestCase): 83 | def test_extract(self): 84 | sb = ShuffleBuffer(3, 1) 85 | r = sb.extract() 86 | assert r == None, r # empty buffer => None 87 | r = sb.insert_or_replace(b'111') 88 | assert r == None, r # buffer not yet full => None 89 | r = sb.extract() 90 | assert r == b'111', r # one item in buffer => item 91 | r = sb.extract() 92 | assert r == None, r # buffer empty => None 93 | def test_wrong_size(self): 94 | sb = ShuffleBuffer(3, 1) 95 | try: 96 | sb.insert_or_replace(b'1') # wrong length, so should throw. 97 | assert False # Should not be reached. 98 | except: 99 | pass 100 | def test_insert_or_replace(self): 101 | n=10 # number of test items. 102 | items=[bytes([x,x,x]) for x in range(n)] 103 | sb = ShuffleBuffer(elem_size=3, elem_count=2) 104 | out=[] 105 | for i in items: 106 | r = sb.insert_or_replace(i) 107 | if not r is None: 108 | out.append(r) 109 | # Buffer size is 2, 10 items, should be 8 seen so far. 110 | assert len(out) == n - 2, len(out) 111 | # Get the last two items. 
112 | out.append(sb.extract()) 113 | out.append(sb.extract()) 114 | assert sorted(items) == sorted(out), (items, out) 115 | # Check that buffer is empty 116 | r = sb.extract() 117 | assert r is None, r 118 | 119 | 120 | if __name__ == '__main__': 121 | unittest.main() 122 | -------------------------------------------------------------------------------- /backend/tf_transfer/training_shared.py: -------------------------------------------------------------------------------- 1 | from ..utils import printWithDate 2 | 3 | import tensorflow as tf 4 | 5 | import glob 6 | import os 7 | import os.path 8 | import random 9 | import gzip 10 | import sys 11 | 12 | def get_latest_chunks(path): 13 | chunks = [] 14 | printWithDate(f"found {glob.glob(path)} chunk dirs") 15 | whites = [] 16 | blacks = [] 17 | 18 | for d in glob.glob(path): 19 | for root, dirs, files in os.walk(d): 20 | for fpath in files: 21 | if fpath.endswith('.gz'): 22 | #TODO: Make less sketchy 23 | if 'black' in root: 24 | blacks.append(os.path.join(root, fpath)) 25 | elif 'white' in root: 26 | whites.append(os.path.join(root, fpath)) 27 | else: 28 | raise RuntimeError( 29 | f"invalid chunk path found:{os.path.join(root, fpath)}") 30 | 31 | printWithDate( 32 | f"found {len(whites)} white {len(blacks)} black chunks", end='\r') 33 | printWithDate(f"found {len(whites) + len(blacks)} chunks total") 34 | if len(whites) < 1 or len(blacks) < 1: 35 | print("Not enough chunks {}".format(len(blacks))) 36 | sys.exit(1) 37 | 38 | print("sorting {} B chunks...".format(len(blacks)), end='') 39 | blacks.sort(key=os.path.getmtime, reverse=True) 40 | print("sorting {} W chunks...".format(len(whites)), end='') 41 | whites.sort(key=os.path.getmtime, reverse=True) 42 | print("[done]") 43 | print("{} - {}".format(os.path.basename(whites[-1]), os.path.basename(whites[0]))) 44 | print("{} - {}".format(os.path.basename(blacks[-1]), os.path.basename(blacks[0]))) 45 | random.shuffle(blacks) 46 | random.shuffle(whites) 47 | return whites, blacks 48 | 49 | 50 | class FileDataSrc: 51 | """ 52 | data source yielding chunkdata from chunk files. 
53 | """ 54 | def __init__(self, white_chunks, black_chunks): 55 | self.white_chunks = [] 56 | self.white_done = white_chunks 57 | 58 | self.black_chunks = [] 59 | self.black_done = black_chunks 60 | 61 | self.next_is_white = True 62 | 63 | def next(self): 64 | self.next_is_white = not self.next_is_white 65 | return self.next_by_colour(not self.next_is_white) 66 | 67 | def next_by_colour(self, is_white): 68 | if is_white: 69 | if not self.white_chunks: 70 | self.white_chunks, self.white_done = self.white_done, self.white_chunks 71 | random.shuffle(self.white_chunks) 72 | if not self.white_chunks: 73 | return None 74 | while len(self.white_chunks): 75 | filename = self.white_chunks.pop() 76 | try: 77 | with gzip.open(filename, 'rb') as chunk_file: 78 | self.white_done.append(filename) 79 | return chunk_file.read(), True 80 | except: 81 | print("failed to parse {}".format(filename)) 82 | else: 83 | if not self.black_chunks: 84 | self.black_chunks, self.black_done = self.black_done, self.black_chunks 85 | random.shuffle(self.black_chunks) 86 | if not self.black_chunks: 87 | return None, False 88 | while len(self.black_chunks): 89 | filename = self.black_chunks.pop() 90 | try: 91 | with gzip.open(filename, 'rb') as chunk_file: 92 | self.black_done.append(filename) 93 | return chunk_file.read(), False 94 | except: 95 | print("failed to parse {}".format(filename)) 96 | 97 | -------------------------------------------------------------------------------- /backend/tf_transfer/update_steps.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import os 4 | import yaml 5 | import sys 6 | import tensorflow as tf 7 | from .tfprocess import TFProcess 8 | 9 | START_FROM = 0 10 | 11 | def main(cmd): 12 | cfg = yaml.safe_load(cmd.cfg.read()) 13 | print(yaml.dump(cfg, default_flow_style=False)) 14 | 15 | root_dir = os.path.join(cfg['training']['path'], cfg['name']) 16 | if not os.path.exists(root_dir): 17 | os.makedirs(root_dir) 18 | 19 | tfprocess = TFProcess(cfg) 20 | tfprocess.init_net_v2() 21 | 22 | tfprocess.restore_v2() 23 | 24 | START_FROM = cmd.start 25 | 26 | tfprocess.global_step.assign(START_FROM) 27 | tfprocess.manager.save() 28 | 29 | if __name__ == "__main__": 30 | argparser = argparse.ArgumentParser(description=\ 31 | 'Convert current checkpoint to new step count.') 32 | argparser.add_argument('--cfg', type=argparse.FileType('r'), 33 | help='yaml configuration with training parameters') 34 | argparser.add_argument('--start', type=int, default=0, 35 | help='Offset to set global_step to.') 36 | 37 | main(argparser.parse_args()) 38 | -------------------------------------------------------------------------------- /backend/tf_transfer/utils.py: -------------------------------------------------------------------------------- 1 | import io 2 | import tempfile 3 | 4 | import tensorflow as tf 5 | 6 | def show_model(model, filename = None, detailed = False, show_shapes = False): 7 | if filename is None: 8 | tempf = tempfile.NamedTemporaryFile(suffix='.png') 9 | filename = tempf.name 10 | return tf.keras.utils.plot_model( 11 | model, 12 | to_file=filename, 13 | show_shapes=show_shapes, 14 | show_layer_names=True, 15 | rankdir='TB', 16 | expand_nested=detailed, 17 | dpi=96, 18 | ) 19 | -------------------------------------------------------------------------------- /backend/utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import sys 3 | import time 4 | 
import datetime 5 | import os 6 | import os.path 7 | import traceback 8 | 9 | import pytz 10 | 11 | 12 | min_run_time = 60 * 10 # 10 minutes 13 | infos_dir_name = 'runinfos' 14 | tz = pytz.timezone('Canada/Eastern') 15 | 16 | colours = { 17 | 'blue' : '\033[94m', 18 | 'green' : '\033[92m', 19 | 'yellow' : '\033[93m', 20 | 'red' : '\033[91m', 21 | 'pink' : '\033[95m', 22 | } 23 | endColour = '\033[0m' 24 | 25 | def printWithDate(s, colour = None, **kwargs): 26 | if colour is None: 27 | print(f"{datetime.datetime.now(tz).strftime('%Y-%m-%d %H:%M:%S')} {s}", **kwargs) 28 | else: 29 | print(f"{datetime.datetime.now(tz).strftime('%Y-%m-%d %H:%M:%S')}{colours[colour]} {s}{endColour}", **kwargs) 30 | 31 | class Tee(object): 32 | #Based on https://stackoverflow.com/a/616686 33 | def __init__(self, fname, is_err = False): 34 | self.file = open(fname, 'a') 35 | self.is_err = is_err 36 | if is_err: 37 | self.stdstream = sys.stderr 38 | sys.stderr = self 39 | else: 40 | self.stdstream = sys.stdout 41 | sys.stdout = self 42 | def __del__(self): 43 | if self.is_err: 44 | sys.stderr = self.stdstream 45 | else: 46 | sys.stdout = self.stdstream 47 | self.file.close() 48 | def write(self, data): 49 | self.file.write(data) 50 | self.stdstream.write(data) 51 | def flush(self): 52 | self.file.flush() 53 | 54 | class LockedName(object): 55 | def __init__(self, script_name, start_time): 56 | self.script_name = script_name 57 | self.start_time = start_time 58 | os.makedirs(infos_dir_name, exist_ok = True) 59 | os.makedirs(os.path.join(infos_dir_name, self.script_name), exist_ok = True) 60 | 61 | self.file_prefix = self.get_name_prefix() 62 | self.full_prefix = self.file_prefix + f"-{start_time.strftime('%Y-%m-%d-%H%M')}_" 63 | self.lock = None 64 | self.lock_name = None 65 | 66 | def __enter__(self): 67 | try: 68 | self.lock_name = self.file_prefix + '.lock' 69 | self.lock = open(self.lock_name, 'x') 70 | except FileExistsError: 71 | self.file_prefix = self.get_name_prefix() 72 | self.full_prefix = self.file_prefix + f"-{self.start_time.strftime('%Y-%m-%d-%H%M')}_" 73 | return self.__enter__() 74 | return self 75 | 76 | def __exit__(self, exc_type, exc_value, tb): 77 | try: 78 | self.lock.close() 79 | os.remove(self.lock_name) 80 | except: 81 | pass 82 | 83 | def get_name_prefix(self): 84 | fdir = os.path.join(infos_dir_name, self.script_name) 85 | prefixes = [n.name.split('-')[0] for n in os.scandir(fdir) if n.is_file()] 86 | file_num = 1 87 | nums = [] 88 | for p in set(prefixes): 89 | try: 90 | nums.append(int(p)) 91 | except ValueError: 92 | pass 93 | if len(nums) > 0: 94 | file_num = max(nums) + 1 95 | 96 | return os.path.join(fdir, f"{file_num:04.0f}") 97 | 98 | def logged_main(mainFunc): 99 | @functools.wraps(mainFunc) 100 | def wrapped_main(*args, **kwds): 101 | start_time = datetime.datetime.now(tz) 102 | script_name = os.path.basename(sys.argv[0])[:-3] 103 | 104 | with LockedName(script_name, start_time) as name_lock: 105 | tee_out = Tee(name_lock.full_prefix + 'stdout.log', is_err = False) 106 | tee_err = Tee(name_lock.full_prefix + 'stderr.log', is_err = True) 107 | logs_prefix = name_lock.full_prefix 108 | printWithDate(' '.join(sys.argv), colour = 'blue') 109 | printWithDate(f"Starting {script_name}", colour = 'blue') 110 | try: 111 | tstart = time.time() 112 | val = mainFunc(*args, **kwds) 113 | except (Exception, KeyboardInterrupt) as e: 114 | printWithDate("Error encountered", colour = 'blue') 115 | if (time.time() - tstart) > min_run_time: 116 | makeLog(logs_prefix, start_time, tstart, True, 
'Error', e, traceback.format_exc()) 117 | raise 118 | else: 119 | printWithDate("Run completed", colour = 'blue') 120 | if (time.time() - tstart) > min_run_time: 121 | makeLog(logs_prefix, start_time, tstart, False, 'Successful') 122 | tee_out.flush() 123 | tee_err.flush() 124 | return val 125 | return wrapped_main 126 | 127 | def makeLog(logs_prefix, start_time, tstart, is_error, *notes): 128 | fname = 'error.log' if is_error else 'run.log' 129 | with open(logs_prefix + fname, 'w') as f: 130 | f.write(f"start: {start_time.strftime('%Y-%m-%d-%H:%M:%S')}\n") 131 | f.write(f"stop: {datetime.datetime.now(tz).strftime('%Y-%m-%d-%H:%M:%S')}\n") 132 | f.write(f"duration: {int(time.time() - tstart)}s\n") 133 | f.write(f"dir: {os.path.abspath(os.getcwd())}\n") 134 | f.write(f"{' '.join(sys.argv)}\n") 135 | f.write('\n'.join([str(n) for n in notes])) 136 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: transfer_chess 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _tflow_select=2.1.0=gpu 8 | - absl-py=0.9.0=py37_0 9 | - asn1crypto=1.3.0=py37_0 10 | - astor=0.8.0=py37_0 11 | - attrs=19.3.0=py_0 12 | - backcall=0.1.0=py37_0 13 | - blas=1.0=mkl 14 | - bleach=3.1.4=py_0 15 | - blinker=1.4=py37_0 16 | - c-ares=1.15.0=h7b6447c_1001 17 | - ca-certificates=2020.1.1=0 18 | - cachetools=3.1.1=py_0 19 | - cairo=1.14.12=h8948797_3 20 | - certifi=2020.4.5.1=py37_0 21 | - cffi=1.13.2=py37h2e261b9_0 22 | - chardet=3.0.4=py37_1003 23 | - click=7.0=py37_0 24 | - conda=4.8.3=py37_0 25 | - conda-package-handling=1.6.0=py37h7b6447c_0 26 | - cryptography=2.8=py37h1ba5d50_0 27 | - cudatoolkit=10.1.243=h6bb024c_0 28 | - cudnn=7.6.5=cuda10.1_0 29 | - cupti=10.1.168=0 30 | - dbus=1.13.12=h746ee38_0 31 | - decorator=4.4.1=py_0 32 | - defusedxml=0.6.0=py_0 33 | - entrypoints=0.3=py37_0 34 | - expat=2.2.6=he6710b0_0 35 | - fontconfig=2.13.0=h9420a91_0 36 | - freetype=2.9.1=h8a8886c_1 37 | - fribidi=1.0.5=h7b6447c_0 38 | - gast=0.2.2=py37_0 39 | - glib=2.63.1=h5a9c865_0 40 | - gmp=6.1.2=h6c8ec71_1 41 | - google-auth=1.11.2=py_0 42 | - google-auth-oauthlib=0.4.1=py_2 43 | - google-pasta=0.1.8=py_0 44 | - graphite2=1.3.13=h23475e2_0 45 | - graphviz=2.40.1=h21bd128_2 46 | - grpcio=1.27.2=py37hf8bcb03_0 47 | - gst-plugins-base=1.14.0=hbbd80ab_1 48 | - gstreamer=1.14.0=hb453b48_1 49 | - h5py=2.10.0=py37h7918eee_0 50 | - harfbuzz=1.8.8=hffaf4a1_0 51 | - hdf5=1.10.4=hb1b8bf9_0 52 | - icu=58.2=h9c2bf20_1 53 | - idna=2.8=py37_0 54 | - importlib_metadata=1.4.0=py37_0 55 | - intel-openmp=2020.0=166 56 | - ipykernel=5.1.4=py37h39e3cac_0 57 | - ipython=7.11.1=py37h39e3cac_0 58 | - ipython_genutils=0.2.0=py37_0 59 | - ipywidgets=7.5.1=py_0 60 | - jedi=0.16.0=py37_0 61 | - jinja2=2.11.1=py_0 62 | - joblib=0.14.1=py_0 63 | - jpeg=9b=h024ee3a_2 64 | - jsonschema=3.2.0=py37_0 65 | - jupyter=1.0.0=py37_7 66 | - jupyter_client=5.3.4=py37_0 67 | - jupyter_console=6.1.0=py_0 68 | - jupyter_core=4.6.1=py37_0 69 | - keras-applications=1.0.8=py_0 70 | - keras-preprocessing=1.1.0=py_1 71 | - libedit=3.1.20181209=hc058e9b_0 72 | - libffi=3.2.1=hd88cf55_4 73 | - libgcc-ng=9.1.0=hdf63c60_0 74 | - libgfortran-ng=7.3.0=hdf63c60_0 75 | - libpng=1.6.37=hbc83047_0 76 | - libprotobuf=3.11.4=hd408876_0 77 | - libsodium=1.0.16=h1bed415_0 78 | - libstdcxx-ng=9.1.0=hdf63c60_0 79 | - libtiff=4.1.0=h2733197_0 80 | - libuuid=1.0.3=h1bed415_2 81 | - libxcb=1.13=h1bed415_1 82 | 
- libxml2=2.9.9=hea5a465_1 83 | - markdown=3.1.1=py37_0 84 | - markupsafe=1.1.1=py37h7b6447c_0 85 | - meson=0.52.0=py_0 86 | - mistune=0.8.4=py37h7b6447c_0 87 | - mkl=2020.0=166 88 | - mkl-service=2.3.0=py37he904b0f_0 89 | - mkl_fft=1.0.15=py37ha843d7b_0 90 | - mkl_random=1.1.0=py37hd6b4f25_0 91 | - more-itertools=8.0.2=py_0 92 | - nb_conda_kernels=2.2.2=py37_0 93 | - nbconvert=5.6.1=py37_0 94 | - nbformat=5.0.4=py_0 95 | - ncurses=6.1=he6710b0_1 96 | - ninja=1.9.0=py37hfd86e86_0 97 | - notebook=6.0.3=py37_0 98 | - numpy=1.18.1=py37h4f9e942_0 99 | - numpy-base=1.18.1=py37hde5b4d6_1 100 | - oauthlib=3.1.0=py_0 101 | - olefile=0.46=py37_0 102 | - openssl=1.1.1f=h7b6447c_0 103 | - opt_einsum=3.1.0=py_0 104 | - pandoc=2.2.3.2=0 105 | - pandocfilters=1.4.2=py37_1 106 | - pango=1.42.4=h049681c_0 107 | - parso=0.6.0=py_0 108 | - pcre=8.43=he6710b0_0 109 | - pexpect=4.8.0=py37_0 110 | - pickleshare=0.7.5=py37_0 111 | - pillow=7.0.0=py37hb39fc2d_0 112 | - pip=20.0.2=py37_1 113 | - pixman=0.38.0=h7b6447c_0 114 | - prometheus_client=0.7.1=py_0 115 | - prompt_toolkit=3.0.3=py_0 116 | - protobuf=3.11.4=py37he6710b0_0 117 | - ptyprocess=0.6.0=py37_0 118 | - pyasn1=0.4.8=py_0 119 | - pyasn1-modules=0.2.7=py_0 120 | - pycosat=0.6.3=py37h7b6447c_0 121 | - pycparser=2.19=py37_0 122 | - pydot=1.4.1=py37_0 123 | - pygments=2.5.2=py_0 124 | - pyjwt=1.7.1=py37_0 125 | - pyopenssl=19.1.0=py37_0 126 | - pyparsing=2.4.6=py_0 127 | - pyqt=5.9.2=py37h05f1152_2 128 | - pyrsistent=0.15.7=py37h7b6447c_0 129 | - pysocks=1.7.1=py37_0 130 | - python=3.7.4=h265db76_1 131 | - python-dateutil=2.8.1=py_0 132 | - pytorch=1.4.0=py3.7_cuda10.1.243_cudnn7.6.3_0 133 | - pyyaml=5.3=py37h7b6447c_0 134 | - pyzmq=18.1.1=py37he6710b0_0 135 | - qt=5.9.7=h5867ecd_1 136 | - qtconsole=4.6.0=py_1 137 | - readline=7.0=h7b6447c_5 138 | - requests=2.22.0=py37_1 139 | - requests-oauthlib=1.3.0=py_0 140 | - rsa=4.0=py_0 141 | - ruamel_yaml=0.15.87=py37h7b6447c_0 142 | - scikit-learn=0.22.1=py37hd81dba3_0 143 | - scipy=1.4.1=py37h0b6359f_0 144 | - send2trash=1.5.0=py37_0 145 | - setuptools=45.1.0=py37_0 146 | - sip=4.19.8=py37hf484d3e_0 147 | - six=1.14.0=py37_0 148 | - sqlite=3.30.1=h7b6447c_0 149 | - tensorboard=2.1.0=py3_0 150 | - tensorflow=2.1.0=gpu_py37h7a4bb67_0 151 | - tensorflow-base=2.1.0=gpu_py37h6c5654b_0 152 | - tensorflow-estimator=2.1.0=pyhd54b08b_0 153 | - tensorflow-gpu=2.1.0=h0d30ee6_0 154 | - termcolor=1.1.0=py37_1 155 | - terminado=0.8.3=py37_0 156 | - testpath=0.4.4=py_0 157 | - tk=8.6.8=hbc83047_0 158 | - torchvision=0.5.0=py37_cu101 159 | - tornado=6.0.3=py37h7b6447c_0 160 | - tqdm=4.42.0=py_0 161 | - traitlets=4.3.3=py37_0 162 | - urllib3=1.25.8=py37_0 163 | - wcwidth=0.1.9=py_0 164 | - webencodings=0.5.1=py37_1 165 | - werkzeug=1.0.0=py_0 166 | - wheel=0.34.1=py37_0 167 | - widgetsnbextension=3.5.1=py37_0 168 | - wrapt=1.11.2=py37h7b6447c_0 169 | - xz=5.2.4=h14c3975_4 170 | - yaml=0.1.7=had09818_2 171 | - zeromq=4.3.1=he6710b0_3 172 | - zipp=2.2.0=py_0 173 | - zlib=1.2.11=h7b6447c_3 174 | - zstd=1.3.7=h0b5b093_0 175 | - pip: 176 | - cycler==0.10.0 177 | - humanize==2.4.0 178 | - kiwisolver==1.2.0 179 | - matplotlib==3.2.1 180 | - natsort==7.0.1 181 | - pandas==1.0.3 182 | - python-chess==0.30.1 183 | - pytz==2019.3 184 | - seaborn==0.10.0 185 | - tensorboardx==2.0 186 | 187 | -------------------------------------------------------------------------------- /images/kdd_indiv_final.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CSSLab/maia-individual/012c2179788ed6db8c44a0f9a94bbe84c1024294/images/kdd_indiv_final.jpg -------------------------------------------------------------------------------- /models/maia-1900/ckpt/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "ckpt-40" 2 | all_model_checkpoint_paths: "ckpt-1" 3 | all_model_checkpoint_paths: "ckpt-2" 4 | all_model_checkpoint_paths: "ckpt-3" 5 | all_model_checkpoint_paths: "ckpt-4" 6 | all_model_checkpoint_paths: "ckpt-5" 7 | all_model_checkpoint_paths: "ckpt-6" 8 | all_model_checkpoint_paths: "ckpt-7" 9 | all_model_checkpoint_paths: "ckpt-8" 10 | all_model_checkpoint_paths: "ckpt-9" 11 | all_model_checkpoint_paths: "ckpt-10" 12 | all_model_checkpoint_paths: "ckpt-11" 13 | all_model_checkpoint_paths: "ckpt-12" 14 | all_model_checkpoint_paths: "ckpt-13" 15 | all_model_checkpoint_paths: "ckpt-14" 16 | all_model_checkpoint_paths: "ckpt-15" 17 | all_model_checkpoint_paths: "ckpt-16" 18 | all_model_checkpoint_paths: "ckpt-17" 19 | all_model_checkpoint_paths: "ckpt-18" 20 | all_model_checkpoint_paths: "ckpt-19" 21 | all_model_checkpoint_paths: "ckpt-20" 22 | all_model_checkpoint_paths: "ckpt-21" 23 | all_model_checkpoint_paths: "ckpt-22" 24 | all_model_checkpoint_paths: "ckpt-23" 25 | all_model_checkpoint_paths: "ckpt-24" 26 | all_model_checkpoint_paths: "ckpt-25" 27 | all_model_checkpoint_paths: "ckpt-26" 28 | all_model_checkpoint_paths: "ckpt-27" 29 | all_model_checkpoint_paths: "ckpt-28" 30 | all_model_checkpoint_paths: "ckpt-29" 31 | all_model_checkpoint_paths: "ckpt-30" 32 | all_model_checkpoint_paths: "ckpt-31" 33 | all_model_checkpoint_paths: "ckpt-32" 34 | all_model_checkpoint_paths: "ckpt-33" 35 | all_model_checkpoint_paths: "ckpt-34" 36 | all_model_checkpoint_paths: "ckpt-35" 37 | all_model_checkpoint_paths: "ckpt-36" 38 | all_model_checkpoint_paths: "ckpt-37" 39 | all_model_checkpoint_paths: "ckpt-38" 40 | all_model_checkpoint_paths: "ckpt-39" 41 | all_model_checkpoint_paths: "ckpt-40" 42 | all_model_checkpoint_timestamps: 1580106783.8790061 43 | all_model_checkpoint_timestamps: 1580113034.5215666 44 | all_model_checkpoint_timestamps: 1580119167.9981554 45 | all_model_checkpoint_timestamps: 1580125270.5550704 46 | all_model_checkpoint_timestamps: 1580131382.6197543 47 | all_model_checkpoint_timestamps: 1580138060.0350215 48 | all_model_checkpoint_timestamps: 1580144931.4751053 49 | all_model_checkpoint_timestamps: 1580151357.3907902 50 | all_model_checkpoint_timestamps: 1580157406.0482683 51 | all_model_checkpoint_timestamps: 1580163445.5980349 52 | all_model_checkpoint_timestamps: 1580169474.1105049 53 | all_model_checkpoint_timestamps: 1580175510.0387604 54 | all_model_checkpoint_timestamps: 1580181567.815861 55 | all_model_checkpoint_timestamps: 1580187622.8185244 56 | all_model_checkpoint_timestamps: 1580193674.1944962 57 | all_model_checkpoint_timestamps: 1580199721.2665217 58 | all_model_checkpoint_timestamps: 1580205792.755944 59 | all_model_checkpoint_timestamps: 1580211859.5465987 60 | all_model_checkpoint_timestamps: 1580217928.1305025 61 | all_model_checkpoint_timestamps: 1580223989.668282 62 | all_model_checkpoint_timestamps: 1580231494.4801118 63 | all_model_checkpoint_timestamps: 1580240895.8979034 64 | all_model_checkpoint_timestamps: 1580250465.895426 65 | all_model_checkpoint_timestamps: 1580259628.7052832 66 | all_model_checkpoint_timestamps: 1580268883.0895178 67 | all_model_checkpoint_timestamps: 1580278314.7480402 68 | 
all_model_checkpoint_timestamps: 1580288003.8131309 69 | all_model_checkpoint_timestamps: 1580297809.2752874 70 | all_model_checkpoint_timestamps: 1580307735.15046 71 | all_model_checkpoint_timestamps: 1580318164.597156 72 | all_model_checkpoint_timestamps: 1580328825.4124599 73 | all_model_checkpoint_timestamps: 1580339783.5046844 74 | all_model_checkpoint_timestamps: 1580347138.0900939 75 | all_model_checkpoint_timestamps: 1580354427.078483 76 | all_model_checkpoint_timestamps: 1580360702.8677912 77 | all_model_checkpoint_timestamps: 1580366508.5701687 78 | all_model_checkpoint_timestamps: 1580372158.3093505 79 | all_model_checkpoint_timestamps: 1580377816.579277 80 | all_model_checkpoint_timestamps: 1580383466.9756734 81 | all_model_checkpoint_timestamps: 1580389118.3248632 82 | last_preserved_timestamp: 1580099931.4647074 83 | -------------------------------------------------------------------------------- /models/maia-1900/ckpt/ckpt-40-400000.pb.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-individual/012c2179788ed6db8c44a0f9a94bbe84c1024294/models/maia-1900/ckpt/ckpt-40-400000.pb.gz -------------------------------------------------------------------------------- /models/maia-1900/ckpt/ckpt-40.data-00000-of-00002: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-individual/012c2179788ed6db8c44a0f9a94bbe84c1024294/models/maia-1900/ckpt/ckpt-40.data-00000-of-00002 -------------------------------------------------------------------------------- /models/maia-1900/ckpt/ckpt-40.data-00001-of-00002: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-individual/012c2179788ed6db8c44a0f9a94bbe84c1024294/models/maia-1900/ckpt/ckpt-40.data-00001-of-00002 -------------------------------------------------------------------------------- /models/maia-1900/ckpt/ckpt-40.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-individual/012c2179788ed6db8c44a0f9a94bbe84c1024294/models/maia-1900/ckpt/ckpt-40.index -------------------------------------------------------------------------------- /models/maia-1900/config.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | name: final_maia_1900 4 | display_name: Final Maia 1900 5 | engine: lc0_23 6 | options: 7 | nodes: 1 8 | weightsPath: final_1900-40.pb.gz 9 | movetime: 10 10 | threads: 8 11 | ... 
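The model config above is plain YAML (note the %YAML 1.2 directive and the explicit --- / ... document markers), so it can be read with PyYAML; a minimal sketch (how backend/uci_engine.py actually consumes these options is not shown here):

import yaml

with open('models/maia-1900/config.yaml') as f:
    cfg = yaml.safe_load(f)
print(cfg['engine'], cfg['options']['weightsPath'])  # lc0_23 final_1900-40.pb.gz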
12 | -------------------------------------------------------------------------------- /models/maia-1900/final_1900-40.pb.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSSLab/maia-individual/012c2179788ed6db8c44a0f9a94bbe84c1024294/models/maia-1900/final_1900-40.pb.gz -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | import re 3 | 4 | with open('backend/__init__.py') as f: 5 | versionString = re.search(r"__version__ = '(.+)'", f.read()).group(1) 6 | 7 | if __name__ == '__main__': 8 | setup(name='backend', 9 | version = versionString, 10 | author="Anon", 11 | author_email="anon@anon", 12 | packages = find_packages(), 13 | install_requires = [ 14 | 'numpy', 15 | 'matplotlib', 16 | 'pandas', 17 | 'seaborn', 18 | 'python-chess>=0.30.0', 19 | 'pytz', 20 | 'natsort', 21 | 'humanize', 22 | 'pyyaml', 23 | 'tensorboardX', 24 | ], 25 | ) 26 | --------------------------------------------------------------------------------
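With setup.py reading the version from backend/__init__.py, a typical way to get a working checkout is to create the conda environment (the name transfer_chess comes from environment.yml) and then install the backend package in editable mode; a sketch:

conda env create -f environment.yml
conda activate transfer_chess
pip install -e .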