├── .gitignore
├── 0-player_counting
│   ├── 0-find_top_players.sh
│   ├── 1-collect_top_players.sh
│   ├── 2-select_extended_set.sh
│   ├── README.md
│   ├── combine_player_counts.py
│   ├── find_top_players.py
│   ├── player_game_counts.py
│   ├── select_binned_players.py
│   ├── select_top_players.py
│   └── split_by_players.py
├── 1-data_generation
│   ├── 0-0-make_training_datasets.sh
│   ├── 0-1-make_training_csvs.sh
│   ├── 0-2-make_reduced_datasets.sh
│   ├── 1-0-make_val_datasets.sh
│   ├── 1-1-make_val_csvs.sh
│   ├── 2-make_testing_datasets.sh
│   ├── 9-pgn_to_training_data.sh
│   ├── pgn_fractional_split.py
│   ├── player_splits.sh
│   └── split_by_player.py
├── 2-training
│   ├── extended_configs
│   │   ├── frozen_copy
│   │   │   └── frozen_copy.yaml
│   │   ├── frozen_random
│   │   │   └── frozen_random.yaml
│   │   ├── unfrozen_copy
│   │   │   └── unfrozen_copy.yaml
│   │   └── unfrozen_random
│   │       └── unfrozen_random.yaml
│   ├── final_config.yaml
│   └── train_transfer.py
├── 3-analysis
│   ├── 1-0-baselines_results.sh
│   ├── 1-1-baselines_results_validation.sh
│   ├── 2-0-baseline_results.sh
│   ├── 2-1-model_results.sh
│   ├── 2-2-model_results_val.sh
│   ├── 3-0-model_cross_table.sh
│   ├── 3-1-model_cross_table_val_generation.sh
│   ├── 3-2-model_cross_table_val.sh
│   ├── 4-0-result_summaries.sh
│   ├── 4-1-result_summaries_cross.sh
│   ├── 4-2-result_summaries_val.sh
│   ├── csv_trimmer.py
│   ├── get_accuracy.py
│   ├── get_models_player.py
│   ├── make_summary.py
│   ├── move_predictions.sh
│   ├── prediction_generator.py
│   └── run-kdd-tests.sh
├── 4-cp_loss_stylo_baseline
│   ├── README.md
│   ├── get_cp_loss.py
│   ├── get_cp_loss_per_game.py
│   ├── get_cp_loss_per_move.py
│   ├── get_cp_loss_per_move_per_game.py
│   ├── get_cp_loss_per_move_per_game_count.py
│   ├── results
│   │   ├── games_accuracy.csv
│   │   ├── start_after.csv
│   │   ├── start_after_all_game.csv
│   │   ├── stop_after.csv
│   │   └── stop_after_all_game.csv
│   ├── results_validation
│   │   ├── games_accuracy.csv
│   │   ├── start_after.csv
│   │   ├── start_after_4games.csv
│   │   ├── stop_after.csv
│   │   └── stop_after_4games.csv
│   ├── sweep_moves_all_games.py
│   ├── sweep_moves_num_games.py
│   ├── sweep_moves_per_game.py
│   ├── sweep_num_games.py
│   ├── test_all_games.py
│   └── train_cploss_per_game.py
├── 9-reduced-data
│   └── configs
│       ├── Best_frozen.yaml
│       ├── NFP.yaml
│       └── Tuned.yaml
├── CITATION.cff
├── LICENSE
├── README.md
├── backend
│   ├── __init__.py
│   ├── fen_to_vec.py
│   ├── multiproc.py
│   ├── pgn_parsering.py
│   ├── pgn_to_csv.py
│   ├── proto
│   │   ├── __init__.py
│   │   ├── net.proto
│   │   └── net_pb2.py
│   ├── tf_transfer
│   │   ├── __init__.py
│   │   ├── chunkparser.py
│   │   ├── decode_training.py
│   │   ├── lc0_az_policy_map.py
│   │   ├── net.py
│   │   ├── net_to_model.py
│   │   ├── policy_index.py
│   │   ├── shufflebuffer.py
│   │   ├── tfprocess.py
│   │   ├── tfprocess_reg_lr_noise.py
│   │   ├── training_shared.py
│   │   ├── update_steps.py
│   │   └── utils.py
│   ├── uci_engine.py
│   └── utils.py
├── environment.yml
├── images
│   └── kdd_indiv_final.jpg
├── models
│   └── maia-1900
│       ├── ckpt
│       │   ├── checkpoint
│       │   ├── ckpt-40-400000.pb.gz
│       │   ├── ckpt-40.data-00000-of-00002
│       │   ├── ckpt-40.data-00001-of-00002
│       │   └── ckpt-40.index
│       ├── config.yaml
│       └── final_1900-40.pb.gz
└── setup.py

/.gitignore:
--------------------------------------------------------------------------------
*.zip

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/0-player_counting/0-find_top_players.sh:
--------------------------------------------------------------------------------
#!/bin/bash

lichess_raw_dir='/data/chess/bz2/standard/'
output_dir='../../data/player_counts'
mkdir -p $output_dir

for t in $lichess_raw_dir/*-{01..11}.pgn.bz2 $lichess_raw_dir/*{3..8}-12.pgn.bz2; do
    fname="$(basename -- $t)"
    echo "${t} ${output_dir}/${fname}.csv.bz2"
    screen -S "filter-${fname}" -dm bash -c "source ~/.bashrc; python3 find_top_players.py ${t} ${output_dir}/${fname}.csv.bz2"
done
--------------------------------------------------------------------------------
/0-player_counting/1-collect_top_players.sh:
--------------------------------------------------------------------------------
#!/bin/bash

lichess_raw_dir='/data/chess/bz2/standard/'
counts_dir='../../data/player_counts'
counts_file='../../data/player_counts_combined.csv.bz2'
top_list='../../data/player_counts_combined_top_names.csv.bz2'

output_2000_dir='../../data/top_2000_player_games'
output_2000_metadata_dir='../../data/top_2000_player_data'

players_list='../../data/select_transfer_players'

final_data_dir='../../data/transfer_players_data'

num_train=10
num_val=900
num_test=100

python3 combine_player_counts.py $counts_dir/* $counts_file

bzcat $counts_file | head -n 2000 | bzip2 > $top_list

mkdir -p $output_2000_dir

python3 split_by_players.py $top_list $lichess_raw_dir/*-{01..11}.pgn.bz2 $lichess_raw_dir/*{3..8}-12.pgn.bz2 $output_2000_dir

rm -v $top_list

mkdir -p $output_2000_metadata_dir

python3 player_game_counts.py $output_2000_dir $output_2000_metadata_dir

python3 select_top_players.py $output_2000_metadata_dir \
    ${players_list}_train.csv $num_train \
    ${players_list}_validate.csv $num_val \
    ${players_list}_test.csv $num_test

mkdir -p $final_data_dir
mkdir -p $final_data_dir/metadata
cp -v ${players_list}*.csv $final_data_dir/metadata

for c in "train" "validate" "test"; do
    mkdir $final_data_dir/${c}
    mkdir $final_data_dir/${c}_metadata
    for t in `tail -n +2 ${players_list}_${c}.csv | awk -F ',' '{print $1}'`; do
        cp -v ${output_2000_dir}/${t}.pgn.bz2 $final_data_dir/${c}
        cp ${output_2000_metadata_dir}/${t}.csv.bz2 $final_data_dir/${c}_metadata
    done
done
--------------------------------------------------------------------------------
/0-player_counting/2-select_extended_set.sh:
--------------------------------------------------------------------------------
#!/bin/bash
set -e

vals_dat_dir="../../data/transfer_players_data/validate_metadata/"
vals_dir="../../data/transfer_players_validate"
output_dir="../../data/transfer_players_extended"
list_file='../../data/extended_list.csv'

num_per_bin=5
bins="1100 1300 1500 1700 1900"

python3 select_binned_players.py $vals_dat_dir $list_file $num_per_bin $bins

mkdir -p $output_dir

while read player; do
    echo $player
    cp -r ${vals_dir}/${player} ${output_dir}
done < $list_file
--------------------------------------------------------------------------------
/0-player_counting/README.md:
--------------------------------------------------------------------------------
# Player Counting

This is the code we used to count the number of games each player has played.
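The numbered scripts are meant to run in order. Roughly (the hard-coded paths and the screen-based parallelism come from the scripts above; adjust them to your setup):

```bash
# 0: count how often each player appears in each monthly Lichess dump
bash 0-find_top_players.sh
# 1: merge the counts, keep the top 2000 accounts, pull out their games,
#    compute per-player metadata, and pick the train/validate/test players
bash 1-collect_top_players.sh
# 2: pick 5 validation players per 100-point rating bin (1100, 1300, ..., 1900)
#    as the extended set
bash 2-select_extended_set.sh
```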
--------------------------------------------------------------------------------
/0-player_counting/combine_player_counts.py:
--------------------------------------------------------------------------------
import backend

import argparse
import bz2

import pandas

@backend.logged_main
def main():
    parser = argparse.ArgumentParser(description='Collect counts and create list from them', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('inputs', nargs='+', help='input csvs')
    parser.add_argument('output', help='output csv')
    args = parser.parse_args()

    counts = {}
    for p in args.inputs:
        backend.printWithDate(f"Processing {p}", end='\r')
        df = pandas.read_csv(p)
        for i, row in df.iterrows():
            try:
                counts[row['player']] += row['count']
            except KeyError:
                counts[row['player']] = row['count']
    backend.printWithDate("Writing")
    with bz2.open(args.output, 'wt') as f:
        f.write('player,count\n')
        for p, c in sorted(counts.items(), key=lambda x: x[1], reverse=True):
            f.write(f"{p},{c}\n")

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/0-player_counting/find_top_players.py:
--------------------------------------------------------------------------------
import backend

import argparse
import bz2

@backend.logged_main
def main():
    parser = argparse.ArgumentParser(description='Count number of times each player occurs in pgn', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('input', help='input pgn')
    parser.add_argument('output', help='output csv')
    parser.add_argument('--include_bullet', dest='exclude_bullet', action='store_false', help='Include bullet games in the counts (they are excluded by default)')
    args = parser.parse_args()

    games = backend.GamesFile(args.input)

    counts = {}

    for i, (d, _) in enumerate(games):
        if args.exclude_bullet and 'Bullet' in d['Event']:
            continue
        else:
            add_player(d['White'], counts)
            add_player(d['Black'], counts)
        if i % 10000 == 0:
            backend.printWithDate(f"{i} done with {len(counts)} players from {args.input}", end='\r')

    backend.printWithDate(f"{i} found total of {len(counts)} players from {args.input}")
    with bz2.open(args.output, 'wt') as f:
        f.write("player,count\n")
        for p, c in sorted(counts.items(), key=lambda x: x[1], reverse=True):
            f.write(f"{p},{c}\n")
    backend.printWithDate("done")

def add_player(p, d):
    try:
        d[p] += 1
    except KeyError:
        d[p] = 1

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/0-player_counting/player_game_counts.py:
--------------------------------------------------------------------------------
import backend

import os
import os.path
import csv
import bz2
import argparse

@backend.logged_main
def main():
    parser = argparse.ArgumentParser(description='Get some stats about each of the games')
    parser.add_argument('targets_dir', help='input pgns dir')
    parser.add_argument('output_dir', help='output csvs dir')
    parser.add_argument('--pool_size', type=int, help='Number of worker processes to run in parallel', default=64)
    args = parser.parse_args()

    multiProc = backend.Multiproc(args.pool_size)
    multiProc.reader_init(Files_lister, args.targets_dir)
    multiProc.processor_init(Games_processor, args.output_dir)

    multiProc.run()

class Files_lister(backend.MultiprocIterable):
    def __init__(self, targets_dir):
        self.targets_dir = targets_dir
        self.targets = [(p.path, p.name.split('.')[0]) for p in os.scandir(targets_dir) if '.pgn.bz2' in p.name]
        backend.printWithDate(f"Found {len(self.targets)} targets in {targets_dir}")

    def __next__(self):
        try:
            backend.printWithDate(f"Pushed target, {len(self.targets)} remaining", end='\r', flush=True)
            return self.targets.pop()
        except IndexError:
            raise StopIteration

class Games_processor(backend.MultiprocWorker):
    def __init__(self, output_dir):
        self.output_dir = output_dir

    def __call__(self, path, name):
        games = backend.GamesFile(path)
        with bz2.open(os.path.join(self.output_dir, f"{name}.csv.bz2"), 'wt') as f:
            writer = csv.DictWriter(f, ["player", "opponent", "game_id", "ELO", "opp_ELO", "was_white", "result", "won", "UTCDate", "UTCTime", "TimeControl"])
            writer.writeheader()
            for d, _ in games:
                game_dat = {}
                game_dat['player'] = name
                game_dat['game_id'] = d['Site'].split('/')[-1]
                game_dat['result'] = d['Result']
                game_dat['UTCDate'] = d['UTCDate']
                game_dat['UTCTime'] = d['UTCTime']
                game_dat['TimeControl'] = d['TimeControl']
                if d['Black'] == name:
                    game_dat['was_white'] = False
                    game_dat['opponent'] = d['White']
                    game_dat['ELO'] = d['BlackElo']
                    game_dat['opp_ELO'] = d['WhiteElo']
                    game_dat['won'] = d['Result'] == '0-1'
                else:
                    game_dat['was_white'] = True
                    game_dat['opponent'] = d['Black']
                    game_dat['ELO'] = d['WhiteElo']
                    game_dat['opp_ELO'] = d['BlackElo']
                    game_dat['won'] = d['Result'] == '1-0'
                writer.writerow(game_dat)

if __name__ == '__main__':
    main()
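player_game_counts.py above (and split_by_players.py below) drive the repo's backend.Multiproc helper. The reader/worker protocol, as inferred from how these scripts use it (backend itself is not shown here, so treat this as an illustrative sketch rather than the canonical API; all names are made up):

```python
import backend

class PathLister(backend.MultiprocIterable):
    # Reader: each __next__ returns the next work item; StopIteration ends the
    # run. A returned tuple is unpacked into the worker's __call__ arguments.
    def __init__(self, paths):
        self.paths = list(paths)

    def __next__(self):
        try:
            return self.paths.pop()
        except IndexError:
            raise StopIteration

class LineCounter(backend.MultiprocWorker):
    # Worker: one instance per pool process; __call__ handles one work item.
    def __init__(self, label):
        self.label = label

    def __call__(self, path):
        with open(path) as f:
            print(self.label, path, sum(1 for _ in f))

multi_proc = backend.Multiproc(8)                       # pool of 8 worker processes
multi_proc.reader_init(PathLister, ['a.txt', 'b.txt'])  # reader class + its __init__ args
multi_proc.processor_init(LineCounter, 'lines')         # worker class + its __init__ args
multi_proc.run()
```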
backend.printWithDate(f"Found: " + ', '.join([f"{b} : {len(p)}" for b, p in binned_players.items()])) 35 | 36 | with open(args.output_list, 'wt') as f: 37 | for b, p in binned_players.items(): 38 | random.shuffle(p) 39 | print(b, [d['name'] for d in p[:args.bin_size]]) 40 | f.write('\n'.join([d['name'] for d in p[:args.bin_size]]) +'\n') 41 | 42 | def load_player(path): 43 | df = pandas.read_csv(path, low_memory=False) 44 | elo = df['ELO'][-10000:].mean() 45 | count = len(df) 46 | return { 47 | 'name' : df['player'].iloc[0], 48 | 'elo' : elo, 49 | 'count' : count, 50 | } 51 | if __name__ == "__main__": 52 | main() 53 | -------------------------------------------------------------------------------- /0-player_counting/select_top_players.py: -------------------------------------------------------------------------------- 1 | import backend 2 | 3 | import argparse 4 | import bz2 5 | import glob 6 | import random 7 | import os.path 8 | import multiprocessing 9 | 10 | import pandas 11 | 12 | @backend.logged_main 13 | def main(): 14 | parser = argparse.ArgumentParser(description='Read all the metadata and select top n players for training/validation/testing', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 15 | parser.add_argument('inputs', help='input csvs dir') 16 | parser.add_argument('output_train', help='output csv for training data') 17 | parser.add_argument('num_train', type=int, help='num for main training') 18 | parser.add_argument('output_val', help='output csv for validation data') 19 | parser.add_argument('num_val', type=int, help='num for big validation run') 20 | parser.add_argument('output_test', help='output csv for testing data') 21 | parser.add_argument('num_test', type=int, help='num for holdout set') 22 | parser.add_argument('--pool_size', type=int, help='Number of models to run in parallel', default = 48) 23 | parser.add_argument('--min_elo', type=int, help='min elo to select', default = 1100) 24 | parser.add_argument('--max_elo', type=int, help='max elo to select', default = 2000) 25 | parser.add_argument('--seed', type=int, help='random seed', default = 1) 26 | args = parser.parse_args() 27 | random.seed(args.seed) 28 | 29 | targets = glob.glob(os.path.join(args.inputs, '*csv.bz2')) 30 | 31 | with multiprocessing.Pool(args.pool_size) as pool: 32 | players = pool.starmap(check_player, ((t, args.min_elo, args.max_elo) for t in targets)) 33 | 34 | players_top = sorted( 35 | (p for p in players if p is not None), 36 | key = lambda x : x[1], 37 | reverse=True, 38 | )[:args.num_train + args.num_val + args.num_test] 39 | 40 | random.shuffle(players_top) 41 | 42 | write_output_file(args.output_train, args.num_train, players_top) 43 | write_output_file(args.output_val, args.num_val, players_top) 44 | write_output_file(args.output_test, args.num_test, players_top) 45 | 46 | def write_output_file(path, count, targets): 47 | with open(path, 'wt') as f: 48 | f.write("player,count,ELO\n") 49 | for i in range(count): 50 | t = targets.pop() 51 | f.write(f"{t[0]},{t[1]},{t[2]}\n") 52 | 53 | def check_player(path, min_elo, max_elo): 54 | df = pandas.read_csv(path, low_memory=False) 55 | elo = df['ELO'][-10000:].mean() 56 | count = len(df) 57 | if elo > min_elo and elo < max_elo: 58 | return path.split('/')[-1].split('.')[0], count, elo 59 | else: 60 | return None 61 | 62 | if __name__ == "__main__": 63 | main() 64 | -------------------------------------------------------------------------------- /0-player_counting/split_by_players.py: 
--------------------------------------------------------------------------------
/0-player_counting/split_by_players.py:
--------------------------------------------------------------------------------
import backend

import pandas
import lockfile

import argparse
import bz2
import os
import os.path

@backend.logged_main
def main():
    parser = argparse.ArgumentParser(description='Write pgns of games with selected players in them', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('target', help='target players list as csv')
    parser.add_argument('inputs', nargs='+', help='input pgns')
    parser.add_argument('output', help='output dir')
    parser.add_argument('--include_bullet', dest='exclude_bullet', action='store_false', help='Include bullet games (they are excluded by default)')
    parser.add_argument('--pool_size', type=int, help='Number of worker processes to run in parallel', default=48)
    args = parser.parse_args()

    df_targets = pandas.read_csv(args.target)
    targets = set(df_targets['player'])

    os.makedirs(args.output, exist_ok=True)

    multiProc = backend.Multiproc(args.pool_size)
    multiProc.reader_init(Files_lister, args.inputs)
    multiProc.processor_init(Games_processor, targets, args.output, args.exclude_bullet)
    multiProc.run()
    backend.printWithDate("done")

class Files_lister(backend.MultiprocIterable):
    def __init__(self, inputs):
        self.inputs = list(inputs)
        backend.printWithDate(f"Found {len(self.inputs)} input files")

    def __next__(self):
        try:
            backend.printWithDate(f"Pushed target, {len(self.inputs)} remaining", end='\r', flush=True)
            return self.inputs.pop()
        except IndexError:
            raise StopIteration

class Games_processor(backend.MultiprocWorker):
    def __init__(self, targets, output_dir, exclude_bullet):
        self.output_dir = output_dir
        self.targets = targets
        self.exclude_bullet = exclude_bullet

        self.c = 0

    def __call__(self, path):
        games = backend.GamesFile(path)
        self.c = 0
        for i, (d, s) in enumerate(games):
            if self.exclude_bullet and 'Bullet' in d['Event']:
                continue
            else:
                if d['White'] in self.targets:
                    self.write_player(d['White'], s)
                    self.c += 1
                if d['Black'] in self.targets:
                    self.write_player(d['Black'], s)
                    self.c += 1
            if i % 10000 == 0:
                backend.printWithDate(f"{path} {i} done with {self.c} writes", end='\r')

    def write_player(self, p_name, s):
        # Multiple workers can hit the same player's file, so writes are serialized
        # through a per-file lock.
        p_path = os.path.join(self.output_dir, f"{p_name}.pgn.bz2")
        lock_path = p_path + '.lock'
        lock = lockfile.FileLock(lock_path)
        with lock:
            with bz2.open(p_path, 'at') as f:
                f.write(s)

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/1-data_generation/0-0-make_training_datasets.sh:
--------------------------------------------------------------------------------
#!/bin/bash

train_frac=80
val_frac=10
test_frac=10

input_files="/maiadata/transfer_players_data/train"
output_files="/maiadata/transfer_players_train"
mkdir -p $output_files

for player_file in $input_files/*.bz2; do
    f=${player_file##*/}
    p_name=${f%.pgn.bz2}
    p_dir=$output_files/$p_name
    split_dir=$output_files/$p_name/split
    mkdir -p $p_dir
    mkdir -p $split_dir
    echo $p_name $p_dir
    python split_by_player.py $player_file $p_name $split_dir/games

    for c in "white" "black"; do
        python pgn_fractional_split.py $split_dir/games_$c.pgn.bz2 $split_dir/train_$c.pgn.bz2 $split_dir/validate_$c.pgn.bz2 $split_dir/test_$c.pgn.bz2 --ratios $train_frac $val_frac $test_frac

        cd $p_dir
        mkdir -p pgns
        for s in "train" "validate" "test"; do
            mkdir -p $s
            mkdir -p $s/$c

            # using tool from:
            # https://www.cs.kent.ac.uk/people/staff/djb/pgn-extract/
            bzcat $split_dir/${s}_${c}.pgn.bz2 | pgn-extract -7 -C -N -#1000

            cat *.pgn > pgns/${s}_${c}.pgn
            rm -v *.pgn

            # using tool from:
            # https://github.com/DanielUranga/trainingdata-tool
            screen -S "${p_name}-${c}-${s}" -dm bash -c "cd ${s}/${c}; trainingdata-tool -v ../../pgns/${s}_${c}.pgn"
        done
        cd -
    done
done
--------------------------------------------------------------------------------
/1-data_generation/0-1-make_training_csvs.sh:
--------------------------------------------------------------------------------
#!/bin/bash

data_dir="/maiadata/transfer_players_train"

for player_dir in $data_dir/*; do
    player_name=`basename ${player_dir}`
    mkdir -p $player_dir/csvs
    for c in "white" "black"; do
        for s in "train" "validate" "test"; do
            target=$player_dir/split/${s}_${c}.pgn.bz2
            output=$player_dir/csvs/${s}_${c}.csv.bz2
            echo ${player_name} ${s} ${c}
            screen -S "csv-${player_name}-${c}-${s}" -dm bash -c "python3 ../../data_generators/pgn_to_csv.py ${target} ${output}"
        done
    done
done
--------------------------------------------------------------------------------
/1-data_generation/0-2-make_reduced_datasets.sh:
--------------------------------------------------------------------------------
#!/bin/bash

train_frac=80
val_frac=10
test_frac=10

input_files="/maiadata/transfer_players_data/train"
output_files="/maiadata/transfer_players_train_reduced"
mkdir -p $output_files

fractions='100 10 1'

for frac in `echo $fractions`; do
    for player_file in $input_files/*.bz2; do
        f=${player_file##*/}
        p_name=${f%.pgn.bz2}
        p_dir=$output_files/$p_name/$frac
        split_dir=$output_files/$p_name/$frac/split
        mkdir -p $p_dir
        mkdir -p $split_dir

        python pgn_fractional_split.py $player_file $p_dir/raw_reduced.pgn.bz2 $p_dir/extra.pgn.bz2 --ratios $frac `echo "1000 - $frac" | bc`

        echo $p_name $frac $p_dir
        python split_by_player.py $p_dir/raw_reduced.pgn.bz2 $p_name $split_dir/games

        for c in "white" "black"; do
            python pgn_fractional_split.py $split_dir/games_$c.pgn.bz2 $split_dir/train_$c.pgn.bz2 $split_dir/validate_$c.pgn.bz2 $split_dir/test_$c.pgn.bz2 --ratios $train_frac $val_frac $test_frac

            cd $p_dir
            mkdir -p pgns
            for s in "train" "validate" "test"; do
                mkdir -p $s
                mkdir -p $s/$c

                # using tool from:
                # https://www.cs.kent.ac.uk/people/staff/djb/pgn-extract/
                bzcat $split_dir/${s}_${c}.pgn.bz2 | pgn-extract -7 -C -N -#1000

                cat *.pgn > pgns/${s}_${c}.pgn
                rm -v *.pgn

                # using tool from:
                # https://github.com/DanielUranga/trainingdata-tool
                screen -S "${p_name}-${c}-${s}" -dm bash -c "cd ${s}/${c}; trainingdata-tool -v ../../pgns/${s}_${c}.pgn"
            done
            cd -
        done
    done
done
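The fractions in 0-2-make_reduced_datasets.sh are expressed in tenths of a percent: each games file is first split with `--ratios $frac (1000 - $frac)`, and since pgn_fractional_split.py normalizes ratios by their sum, the kept share is frac/1000:

```bash
echo "1000 - 100" | bc   # 900, so frac=100 -> --ratios 100 900 -> keep 10%
echo "1000 - 10"  | bc   # 990, so frac=10  -> --ratios 10  990 -> keep 1%
echo "1000 - 1"   | bc   # 999, so frac=1   -> --ratios 1   999 -> keep 0.1%
```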
--------------------------------------------------------------------------------
/1-data_generation/1-0-make_val_datasets.sh:
--------------------------------------------------------------------------------
#!/bin/bash

set -e

train_frac=80
val_frac=10
test_frac=10

input_files="/maiadata/transfer_players_data/validate"
output_files="/maiadata/transfer_players_validate"
mkdir -p $output_files

for player_file in $input_files/*.bz2; do
    f=${player_file##*/}
    p_name=${f%.pgn.bz2}
    p_dir=$output_files/$p_name

    # resume support: redo a player only if their output dir is suspiciously small (< 50 kB)
    f_size=$(du -sb ${output_files}/${p_name} | cut -f1)
    if [ $((f_size)) -lt 50000 ]; then
        echo $f_size $p_dir
        rm -rv $p_dir/*
    else
        continue
    fi

    split_dir=$output_files/$p_name/split
    mkdir -p $p_dir
    mkdir -p $split_dir
    echo $p_name $p_dir
    python split_by_player.py $player_file $p_name $split_dir/games

    for c in "white" "black"; do
        python pgn_fractional_split.py $split_dir/games_$c.pgn.bz2 $split_dir/train_$c.pgn.bz2 $split_dir/validate_$c.pgn.bz2 $split_dir/test_$c.pgn.bz2 --ratios $train_frac $val_frac $test_frac

        cd $p_dir
        mkdir -p pgns
        for s in "train" "validate" "test"; do
            mkdir -p $s
            mkdir -p $s/$c

            # using tool from:
            # https://www.cs.kent.ac.uk/people/staff/djb/pgn-extract/
            bzcat $split_dir/${s}_${c}.pgn.bz2 | pgn-extract -7 -C -N -#1000

            cat *.pgn > pgns/${s}_${c}.pgn
            rm -v *.pgn

            # using tool from:
            # https://github.com/DanielUranga/trainingdata-tool
            screen -S "${p_name}-${c}-${s}" -dm bash -c "cd ${s}/${c}; trainingdata-tool -v ../../pgns/${s}_${c}.pgn"
        done
        cd -
    done
    while [ `screen -ls | wc -l` -gt 20 ]; do
        printf "waiting\r"
        sleep 10
    done
done
--------------------------------------------------------------------------------
/1-data_generation/1-1-make_val_csvs.sh:
--------------------------------------------------------------------------------
#!/bin/bash

data_dir="../../data/transfer_players_validate"

for player_dir in $data_dir/*; do
    player_name=`basename ${player_dir}`
    mkdir -p $player_dir/csvs
    for c in "white" "black"; do
        for s in "train" "validate" "test"; do
            target=$player_dir/split/${s}_${c}.pgn.bz2
            output=$player_dir/csvs/${s}_${c}.csv.bz2
            echo ${player_name} ${s} ${c}
            screen -S "csv-${player_name}-${c}-${s}" -dm bash -c "python3 ../../data_generators/pgn_to_csv.py ${target} ${output}"
        done
    done
    while [ `screen -ls | wc -l` -gt 50 ]; do
        printf "waiting\r"
        sleep 10
    done
done
--------------------------------------------------------------------------------
/1-data_generation/2-make_testing_datasets.sh:
--------------------------------------------------------------------------------
#!/bin/bash

train_frac=80
val_frac=10
test_frac=10

input_files="/maiadata/transfer_players_data/test"
output_files="/maiadata/transfer_players_test"
mkdir -p $output_files

for player_file in $input_files/*.bz2; do
    f=${player_file##*/}
    p_name=${f%.pgn.bz2}
    p_dir=$output_files/$p_name
    split_dir=$output_files/$p_name/split
    mkdir -p $p_dir
    mkdir -p $split_dir
    echo $p_name $p_dir
    python split_by_player.py $player_file $p_name $split_dir/games

    for c in "white" "black"; do
        python pgn_fractional_split.py $split_dir/games_$c.pgn.bz2 $split_dir/train_$c.pgn.bz2 $split_dir/validate_$c.pgn.bz2 $split_dir/test_$c.pgn.bz2 --ratios $train_frac $val_frac $test_frac

        cd $p_dir
        mkdir -p pgns
        for s in "train" "validate" "test"; do
            mkdir -p $s
            mkdir -p $s/$c

            # using tool from:
            # https://www.cs.kent.ac.uk/people/staff/djb/pgn-extract/
            bzcat $split_dir/${s}_${c}.pgn.bz2 | pgn-extract -7 -C -N -#1000

            cat *.pgn > pgns/${s}_${c}.pgn
            rm -v *.pgn

            # using tool from:
            # https://github.com/DanielUranga/trainingdata-tool
            screen -S "${p_name}-${c}-${s}" -dm bash -c "cd ${s}/${c}; trainingdata-tool -v ../../pgns/${s}_${c}.pgn"
        done
        cd -
    done
    while [ `screen -ls | wc -l` -gt 20 ]; do
        printf "waiting\r"
        sleep 10
    done
done
--------------------------------------------------------------------------------
/1-data_generation/9-pgn_to_training_data.sh:
--------------------------------------------------------------------------------
#!/bin/bash
set -e

# args: input_path output_dir player

player_file=${1}
p_dir=${2}
p_name=${3}

train_frac=90
val_frac=10

split_dir=$p_dir/split

mkdir -p ${p_dir}
mkdir -p ${split_dir}

echo "${p_name} to ${p_dir}"

python split_by_player.py $player_file $p_name $split_dir/games

for c in "white" "black"; do
    python pgn_fractional_split.py $split_dir/games_$c.pgn.bz2 $split_dir/train_$c.pgn.bz2 $split_dir/validate_$c.pgn.bz2 --ratios $train_frac $val_frac

    cd $p_dir
    mkdir -p pgns
    for s in "train" "validate"; do
        mkdir -p $s
        mkdir -p $s/$c

        # using tool from:
        # https://www.cs.kent.ac.uk/people/staff/djb/pgn-extract/
        bzcat $split_dir/${s}_${c}.pgn.bz2 | pgn-extract -7 -C -N -#1000

        cat *.pgn > pgns/${s}_${c}.pgn
        rm -v *.pgn

        # using tool from:
        # https://github.com/DanielUranga/trainingdata-tool
        screen -S "${p_name}-${c}-${s}" -dm bash -c "cd ${s}/${c}; trainingdata-tool -v ../../pgns/${s}_${c}.pgn"
    done
    cd -
done
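All of the `*make_*_datasets.sh` variants above funnel each split through the same two external tools; an annotated sketch of that inner step (the flag readings are mine, from the pgn-extract documentation):

```bash
# Clean and normalize the PGNs: -7 keeps only the seven-tag roster, -C drops
# comments, -N drops NAGs, and -#1000 splits the output into numbered files
# of 1000 games each (1.pgn, 2.pgn, ...).
bzcat split/train_white.pgn.bz2 | pgn-extract -7 -C -N -#1000

# Recombine the chunks, then convert them to lc0-style training data with
# https://github.com/DanielUranga/trainingdata-tool
cat *.pgn > pgns/train_white.pgn && rm -v *.pgn
(cd train/white && trainingdata-tool -v ../../pgns/train_white.pgn)
```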
--------------------------------------------------------------------------------
/1-data_generation/pgn_fractional_split.py:
--------------------------------------------------------------------------------
import backend

import argparse
import bz2
import random

@backend.logged_main
def main():
    parser = argparse.ArgumentParser(description='Split games into some number of subsets, by percentage', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('input', help='input pgn')
    parser.add_argument('outputs', nargs='+', help='output pgn files', type=str)
    parser.add_argument('--ratios', nargs='+', help='ratios of games for the outputs', required=True, type=float)
    parser.add_argument('--no_shuffle', action='store_true', help='Disable output shuffling')
    parser.add_argument('--seed', type=int, help='random seed', default=1)
    args = parser.parse_args()

    if len(args.ratios) != len(args.outputs):
        raise RuntimeError(f"Invalid outputs specified: {args.outputs} and {args.ratios}")

    random.seed(args.seed)
    games = backend.GamesFile(args.input)

    game_strs = []

    for i, (d, l) in enumerate(games):
        game_strs.append(l)
        if i % 10000 == 0:
            backend.printWithDate(f"{i} done from {args.input}", end='\r')
    backend.printWithDate(f"{i} done total from {args.input}")
    if not args.no_shuffle:
        random.shuffle(game_strs)

    split_indices = [int(r * len(game_strs) / sum(args.ratios)) for r in args.ratios]

    # Correction for rounding: any remainder left by the flooring above goes to
    # the first output, so the counts always sum to the number of games.
    split_indices[0] += len(game_strs) - sum(split_indices)

    for p, c in zip(args.outputs, split_indices):
        backend.printWithDate(f"Writing {c} games to: {p}")
        with bz2.open(p, 'wt') as f:
            f.write(''.join(
                [game_strs.pop() for i in range(c)]
            ))

    backend.printWithDate("done")

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/1-data_generation/player_splits.sh:
--------------------------------------------------------------------------------
#!/bin/bash -e

input_files="../../data/top_player_games"
output_files="../../data/transfer_players_pgns_split"
mkdir -p $output_files

train_frac=80
val_frac=10
test_frac=10

for p in $input_files/*; do
    name=`basename $p`
    p_name=${name%.pgn.bz2}
    split_dir=$output_files/$name
    mkdir -p $split_dir

    screen -S "${p_name}" -dm bash -c "python3 pgn_fractional_split.py $p $split_dir/train.pgn.bz2 $split_dir/validate.pgn.bz2 $split_dir/test.pgn.bz2 --ratios $train_frac $val_frac $test_frac"
done
--------------------------------------------------------------------------------
/1-data_generation/split_by_player.py:
--------------------------------------------------------------------------------
import backend

import argparse
import bz2
import random

@backend.logged_main
def main():
    parser = argparse.ArgumentParser(description='Split games into games where the target player was White or Black', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('input', help='input pgn')
    parser.add_argument('player', help='target player name')
    parser.add_argument('output', help='output pgn prefix')
    parser.add_argument('--no_shuffle', action='store_true', help='Disable output shuffling')
    parser.add_argument('--seed', type=int, help='random seed', default=1)
    args = parser.parse_args()

    random.seed(args.seed)

    games = backend.GamesFile(args.input)

    outputs_white = []
    outputs_black = []

    for i, (d, l) in enumerate(games):
        if d['White'] == args.player:
            outputs_white.append(l)
        elif d['Black'] == args.player:
            outputs_black.append(l)
        else:
            raise ValueError(f"{args.player} not found in game {i}:\n{l}")
        if i % 10000 == 0:
            backend.printWithDate(f"{i} done with {len(outputs_white)}:{len(outputs_black)} games from {args.input}", end='\r')
    backend.printWithDate(f"{i} found totals of {len(outputs_white)}:{len(outputs_black)} games from {args.input}")
    backend.printWithDate("Writing white")
    with bz2.open(f"{args.output}_white.pgn.bz2", 'wt') as f:
        if not args.no_shuffle:
            random.shuffle(outputs_white)
        f.write(''.join(outputs_white))
    backend.printWithDate("Writing black")
    with bz2.open(f"{args.output}_black.pgn.bz2", 'wt') as f:
        if not args.no_shuffle:
            random.shuffle(outputs_black)
        f.write(''.join(outputs_black))
    backend.printWithDate("done")

if __name__ == '__main__':
    main()
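pgn_fractional_split.py converts `--ratios` into per-output game counts by normalizing against the ratio sum and flooring, then hands the rounding remainder to the first output. A worked example of that arithmetic:

```python
>>> ratios, n_games = [80, 10, 10], 1005
>>> split = [int(r * n_games / sum(ratios)) for r in ratios]
>>> split                             # flooring dropped one game
[804, 100, 100]
>>> split[0] += n_games - sum(split)  # remainder goes to the first output
>>> split
[805, 100, 100]
```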
--------------------------------------------------------------------------------
/2-training/extended_configs/frozen_copy/frozen_copy.yaml:
--------------------------------------------------------------------------------
%YAML 1.2
---
#gpu: 1

dataset:
    path: '/data/transfer_players_extended/'
    #name: ''

training:
    precision: 'half'
    batch_size: 256
    num_batch_splits: 1
    test_steps: 2000
    train_avg_report_steps: 50
    total_steps: 150000
    checkpoint_steps: 500
    shuffle_size: 256
    lr_values:
        - 0.01
        - 0.001
        - 0.0001
        - 0.00001
    lr_boundaries:
        - 35000
        - 80000
        - 110000
    policy_loss_weight: 1.0
    value_loss_weight: 1.0

model:
    filters: 64
    residual_blocks: 6
    se_ratio: 8
    path: "maia/1900"
    keep_weights: true
    back_prop_blocks: 3
...
--------------------------------------------------------------------------------
/2-training/extended_configs/frozen_random/frozen_random.yaml:
--------------------------------------------------------------------------------
%YAML 1.2
---
#gpu: 1

dataset:
    path: '/data/transfer_players_extended/'
    #name: ''

training:
    precision: 'half'
    batch_size: 256
    num_batch_splits: 1
    test_steps: 2000
    train_avg_report_steps: 50
    total_steps: 150000
    checkpoint_steps: 500
    shuffle_size: 256
    lr_values:
        - 0.01
        - 0.001
        - 0.0001
        - 0.00001
    lr_boundaries:
        - 35000
        - 80000
        - 110000
    policy_loss_weight: 1.0
    value_loss_weight: 1.0

model:
    filters: 64
    residual_blocks: 6
    se_ratio: 8
    path: "maia/1900"
    keep_weights: false
    back_prop_blocks: 3
...
--------------------------------------------------------------------------------
/2-training/extended_configs/unfrozen_copy/unfrozen_copy.yaml:
--------------------------------------------------------------------------------
%YAML 1.2
---
#gpu: 1

dataset:
    path: '/data/transfer_players_extended/'
    #name: ''

training:
    precision: 'half'
    batch_size: 256
    num_batch_splits: 1
    test_steps: 2000
    train_avg_report_steps: 50
    total_steps: 150000
    checkpoint_steps: 500
    shuffle_size: 256
    lr_values:
        - 0.01
        - 0.001
        - 0.0001
        - 0.00001
    lr_boundaries:
        - 35000
        - 80000
        - 110000
    policy_loss_weight: 1.0
    value_loss_weight: 1.0

model:
    filters: 64
    residual_blocks: 6
    se_ratio: 8
    path: "maia/1900"
    keep_weights: true
    back_prop_blocks: 99
...
--------------------------------------------------------------------------------
/2-training/extended_configs/unfrozen_random/unfrozen_random.yaml:
--------------------------------------------------------------------------------
%YAML 1.2
---
#gpu: 1

dataset:
    path: '/data/transfer_players_extended/'
    #name: ''

training:
    precision: 'half'
    batch_size: 256
    num_batch_splits: 1
    test_steps: 2000
    train_avg_report_steps: 50
    total_steps: 150000
    checkpoint_steps: 500
    shuffle_size: 256
    lr_values:
        - 0.01
        - 0.001
        - 0.0001
        - 0.00001
    lr_boundaries:
        - 35000
        - 80000
        - 110000
    policy_loss_weight: 1.0
    value_loss_weight: 1.0

model:
    filters: 64
    residual_blocks: 6
    se_ratio: 8
    path: "maia/1900"
    keep_weights: false
    back_prop_blocks: 99
...
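The four extended configs differ only in two model fields, spanning a 2x2 grid of transfer variants. Since the network has 6 residual blocks, `back_prop_blocks: 99` reads as "backprop through everything"; the gloss in the comments below is my reading of the directory names, not something stated in the configs:

```yaml
#                  keep_weights   back_prop_blocks
# frozen_copy:     true           3    # start from Maia-1900, train only the top blocks
# frozen_random:   false          3    # random init, train only the top blocks
# unfrozen_copy:   true           99   # start from Maia-1900, train all blocks
# unfrozen_random: false          99   # random init, train all blocks
```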
--------------------------------------------------------------------------------
/2-training/final_config.yaml:
--------------------------------------------------------------------------------
%YAML 1.2
---
gpu: 0

dataset:
    path: 'path to player data'
    #name: ''

training:
    precision: 'half'
    batch_size: 256
    num_batch_splits: 1
    test_steps: 2000
    train_avg_report_steps: 50
    total_steps: 150000
    checkpoint_steps: 500
    shuffle_size: 256
    lr_values:
        - 0.01
        - 0.001
        - 0.0001
        - 0.00001
    lr_boundaries:
        - 35000
        - 80000
        - 110000
    policy_loss_weight: 1.0
    value_loss_weight: 1.0

model:
    filters: 64
    residual_blocks: 6
    se_ratio: 8
    path: "maia-1900"
    keep_weights: true
    back_prop_blocks: 99
...
--------------------------------------------------------------------------------
/2-training/train_transfer.py:
--------------------------------------------------------------------------------
import argparse
import os
import os.path
import yaml
import sys
import glob
import gzip
import random
import multiprocessing
import shutil

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf

import backend
import backend.tf_transfer

SKIP = 32

@backend.logged_main
def main(config_path, name, collection_name, player_name, gpu, num_workers):
    output_name = os.path.join('models', collection_name, name + '.txt')

    with open(config_path) as f:
        cfg = yaml.safe_load(f.read())

    if player_name is not None:
        cfg['dataset']['name'] = player_name
    if gpu is not None:
        cfg['gpu'] = gpu

    backend.printWithDate(yaml.dump(cfg, default_flow_style=False))

    train_chunks_white, train_chunks_black = backend.tf_transfer.get_latest_chunks(os.path.join(
        cfg['dataset']['path'],
        cfg['dataset']['name'],
        'train',
    ))
    val_chunks_white, val_chunks_black = backend.tf_transfer.get_latest_chunks(os.path.join(
        cfg['dataset']['path'],
        cfg['dataset']['name'],
        'validate',
    ))

    shuffle_size = cfg['training']['shuffle_size']
    total_batch_size = cfg['training']['batch_size']
    backend.tf_transfer.ChunkParser.BATCH_SIZE = total_batch_size
    tfprocess = backend.tf_transfer.TFProcess(cfg, name, collection_name)

    train_parser = backend.tf_transfer.ChunkParser(
        backend.tf_transfer.FileDataSrc(train_chunks_white.copy(), train_chunks_black.copy()),
        shuffle_size=shuffle_size,
        sample=SKIP,
        batch_size=backend.tf_transfer.ChunkParser.BATCH_SIZE,
        workers=num_workers,
    )
    train_dataset = tf.data.Dataset.from_generator(
        train_parser.parse,
        output_types=(
            tf.string, tf.string, tf.string, tf.string
        ),
    )
    train_dataset = train_dataset.map(
        backend.tf_transfer.ChunkParser.parse_function)
    train_dataset = train_dataset.prefetch(4)

    test_parser = backend.tf_transfer.ChunkParser(
        backend.tf_transfer.FileDataSrc(val_chunks_white.copy(), val_chunks_black.copy()),
        shuffle_size=shuffle_size,
        sample=SKIP,
        batch_size=backend.tf_transfer.ChunkParser.BATCH_SIZE,
        workers=num_workers,
    )
    test_dataset = tf.data.Dataset.from_generator(
        test_parser.parse,
        output_types=(tf.string, tf.string, tf.string, tf.string),
    )
    test_dataset = test_dataset.map(
        backend.tf_transfer.ChunkParser.parse_function)
    test_dataset = test_dataset.prefetch(4)

    tfprocess.init_v2(train_dataset, test_dataset)

    tfprocess.restore_v2()

    num_evals = cfg['training'].get('num_test_positions', (len(val_chunks_white) + len(val_chunks_black)) * 10)
    num_evals = max(1, num_evals // backend.tf_transfer.ChunkParser.BATCH_SIZE)
    print("Using {} evaluation batches".format(num_evals))
    try:
        tfprocess.process_loop_v2(total_batch_size, num_evals, batch_splits=1)
    except KeyboardInterrupt:
        backend.printWithDate("KeyboardInterrupt: Stopping")
        train_parser.shutdown()
        test_parser.shutdown()
        raise
    tfprocess.save_leelaz_weights_v2(output_name)

    train_parser.shutdown()
    test_parser.shutdown()
    return cfg

def make_model_files(cfg, name, collection_name, save_dir):
    output_name = os.path.join(save_dir, collection_name, name)
    models_dir = os.path.join('models', collection_name, name)
    models = [(int(p.name.split('-')[1]), p.name, p.path) for p in os.scandir(models_dir) if p.name.endswith('.pb.gz')]
    top_model = max(models, key=lambda x: x[0])

    os.makedirs(output_name, exist_ok=True)
    model_file_name = top_model[1].replace('ckpt', name)
    shutil.copy(top_model[2], os.path.join(output_name, model_file_name))
    with open(os.path.join(output_name, "config.yaml"), 'w') as f:
        cfg_yaml = yaml.dump(cfg).replace('\n', '\n    ').strip()
        f.write(f"""
%YAML 1.2
---
name: {name}
display_name: {name.replace('_', ' ')}
engine: lc0_23
options:
    weightsPath: {model_file_name}
full_config:
    {cfg_yaml}
...""")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Tensorflow pipeline for training Leela Chess.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('config', help='config file for model / training')
    parser.add_argument('player_name', nargs='?', help='player name to train on', default=None)
    parser.add_argument('--gpu', help='gpu to use', default=0, type=int)
    parser.add_argument('--num_workers', help='number of worker threads to use', default=max(1, multiprocessing.cpu_count() - 2), type=int)
    parser.add_argument('--copy_dir', help='dir to save final models in', default='final_models')
    args = parser.parse_args()

    collection_name = os.path.basename(os.path.dirname(args.config)).replace('configs_', '')
    name = os.path.basename(args.config).split('.')[0]

    if args.player_name is not None:
        name = f"{args.player_name}_{name}"

    multiprocessing.set_start_method('spawn')
    cfg = main(args.config, name, collection_name, args.player_name, args.gpu, args.num_workers)
    make_model_files(cfg, name, collection_name, args.copy_dir)
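A typical invocation of train_transfer.py, given the argparse interface above (the player name and paths are illustrative). The collection name is taken from the config's parent directory, the player name is prepended to the run name, checkpoints go under models/<collection>/, and the best network plus a generated config.yaml are copied to --copy_dir:

```bash
python3 train_transfer.py extended_configs/frozen_copy/frozen_copy.yaml SomePlayer \
    --gpu 0 --num_workers 8 --copy_dir final_models
```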
--------------------------------------------------------------------------------
/3-analysis/1-0-baselines_results.sh:
--------------------------------------------------------------------------------
#!/bin/bash

targets_dir="../../data/transfer_players_train"
outputs_dir="../../data/transfer_results/train"

maias="../../models/maia"
stockfish="../../models/stockfish/stockfish_d15"
leela="../../models/leela/sergio"

mkdir -p $outputs_dir

for player_dir in $targets_dir/*; do
    player=`basename ${player_dir}`
    player_ret_dir=$outputs_dir/$player

    echo $player_dir

    mkdir -p $player_ret_dir
    mkdir -p $player_ret_dir/maia
    mkdir -p $player_ret_dir/leela
    #mkdir -p $player_ret_dir/stockfish
    for c in "white" "black"; do
        player_files=$player_dir/csvs/test_${c}.csv.bz2
        #screen -S "baseline-tests-${player}-leela-${c}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player} $leela ${player_files} ${player_ret_dir}/leela/sergio_${c}.csv.bz2"
        for maia_path in $maias/*; do
            maia_name=`basename ${maia_path}`
            printf "$maia_name\r"
            screen -S "baseline-tests-${player}-${maia_name}-${c}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player} $maia_path ${player_files} ${player_ret_dir}/maia/${maia_name}_${c}.csv.bz2"
        done
    done
done
--------------------------------------------------------------------------------
/3-analysis/1-1-baselines_results_validation.sh:
--------------------------------------------------------------------------------
#!/bin/bash

targets_dir="../../data/transfer_players_validate"
outputs_dir="../../data/transfer_results/validate"

maias="../../models/maia"
stockfish="../../models/stockfish/stockfish_d15"
leela="../../models/leela/sergio"

mkdir -p $outputs_dir

for player_dir in $targets_dir/*; do
    player=`basename ${player_dir}`
    player_ret_dir=$outputs_dir/$player

    echo $player_dir

    mkdir -p $player_ret_dir
    mkdir -p $player_ret_dir/maia
    mkdir -p $player_ret_dir/leela
    #mkdir -p $player_ret_dir/stockfish
    for c in "white" "black"; do
        player_files=$player_dir/csvs/test_${c}.csv.bz2
        #screen -S "baseline-tests-${player}-leela-${c}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player} $leela ${player_files} ${player_ret_dir}/leela/sergio_${c}.csv.bz2"
        for maia_path in $maias/1{1..9..2}00; do
            maia_name=`basename ${maia_path}`
            printf "$maia_name\r"
            screen -S "baseline-tests-${player}-${maia_name}-${c}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player} $maia_path ${player_files} ${player_ret_dir}/maia/${maia_name}_${c}.csv.bz2"
        done
    done
    while [ `screen -ls | wc -l` -gt 70 ]; do
        printf "waiting\r"
        sleep 10
    done
done
--------------------------------------------------------------------------------
/3-analysis/2-0-baseline_results.sh:
--------------------------------------------------------------------------------
#!/bin/bash
set -e

max_screens=40

targets_dir="../../data/transfer_players"
outputs_dir="../../data/transfer_results"
kdd_path="../../datasets/10000_full_2019-12.csv.bz2"

mkdir -p $outputs_dir

maias_dir=../../models/maia

for t in "train" "extended" "validate"; do
    for player_dir in ${targets_dir}_${t}/*; do
        player=`basename ${player_dir}`
        for model in $maias_dir/1{1..9..2}00; do
            maia_type=`basename ${model}`
            player_ret_dir=$outputs_dir/$t/$player/maia
            mkdir -p $player_ret_dir
            echo $t $maia_type $player
            for c in "white" "black"; do
                player_files=${player_dir}/csvs/test_${c}.csv.bz2
                screen -S "baselines-${player}-${maia_type}-${c}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player} $model ${player_files} ${player_ret_dir}/${maia_type}_${c}.csv.bz2"
            done
            while [ `screen -ls | wc -l` -gt $max_screens ]; do
                printf "waiting\r"
                sleep 10
            done
        done
    done
done
--------------------------------------------------------------------------------
/3-analysis/2-1-model_results.sh:
--------------------------------------------------------------------------------
#!/bin/bash
set -e

max_screens=50

targets_dir="../../data/transfer_players"
outputs_dir="../../data/transfer_results"
kdd_path="../../datasets/10000_full_2019-12.csv.bz2"

models_dir="../../transfer_training/final_models"
mkdir -p $outputs_dir

for model in $models_dir/*/*; do
    player=`python3 get_models_player.py ${model}`
    model_type=`dirname ${model}`
    model_type=`basename ${model_type}`
    model_name=`basename ${model}`
    #echo $model $model_type $model_name $player

    for c in "white" "black"; do
        for t in "train" "extended"; do
            player_files=${targets_dir}_${t}/$player/csvs/test_${c}.csv.bz2
            if [ -f "$player_files" ]; then
                echo $player_files
                player_ret_dir=$outputs_dir/$t/$player/transfer/$model_type
                mkdir -p $player_ret_dir
                screen -S "transfer-tests-${player}-${model_type}-${c}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player} $model ${player_files} ${player_ret_dir}/${model_name}_${c}.csv.bz2"
            fi
        done
    done
    while [ `screen -ls | wc -l` -gt $max_screens ]; do
        printf "waiting\r"
        sleep 10
    done
    #screen -S "transfer-tests-${player}-${model_type}-kdd" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py $model ${kdd_path} ${player_ret_dir}/${model_name}_kdd.csv.bz2"
done
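Several of these scripts select the odd-hundred Maia baselines with a stepped brace expansion; for reference:

```bash
$ echo 1{1..9..2}00
1100 1300 1500 1700 1900
```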
${player_ret_dir}/${model_name}_kdd_reduced.csv.bz2 35 | ${player_sum_dir}/${model_name}_kdd_reduced.csv.bz2" 36 | while [ `screen -ls | wc -l` -gt $max_screens ]; do 37 | printf "waiting\r" 38 | sleep 10 39 | done 40 | #screen -S "transfer-tests-${player}-${model_type}-kdd" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py $model ${kdd_path} ${player_ret_dir}/${model_name}_kdd.csv.bz2" 41 | done 42 | 43 | 44 | -------------------------------------------------------------------------------- /3-analysis/3-0-model_cross_table.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | max_screens=40 5 | 6 | targets_dir="../../data/transfer_players" 7 | outputs_dir="../../data/transfer_results_cross" 8 | 9 | models_dir="../../transfer_training/final_models" 10 | 11 | target_models=`echo ../../transfer_training/final_models/{no_stop,unfrozen_copy}/*` 12 | 13 | mkdir -p $outputs_dir 14 | 15 | for model in $target_models; do 16 | player=`python3 get_models_player.py ${model}` 17 | model_type=`dirname ${model}` 18 | model_type=`basename ${model_type}` 19 | model_name=`basename ${model}` 20 | echo $player $model_type $model 21 | for c in "white" "black"; do 22 | for t in "train" "extended"; do 23 | player_files=${targets_dir}_${t}/$player/csvs/test_${c}.csv.bz2 24 | if [ -f "$player_files" ]; then 25 | player_ret_dir=$outputs_dir/$player 26 | mkdir -p $player_ret_dir 27 | echo $player_files 28 | for model2 in $target_models; do 29 | model2_name=`basename ${model2}` 30 | model2_player=`python3 get_models_player.py ${model2}` 31 | screen -S "cross-${player}-${model2_player}-${c}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player} $model2 ${player_files} ${player_ret_dir}/${model2_player}_${c}.csv.bz2" 32 | done 33 | while [ `screen -ls | wc -l` -gt $max_screens ]; do 34 | printf "waiting\r" 35 | sleep 10 36 | done 37 | fi 38 | done 39 | done 40 | done 41 | -------------------------------------------------------------------------------- /3-analysis/3-1-model_cross_table_val_generation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | max_screens=40 5 | 6 | targets_dir="../../data/transfer_players_validate" 7 | outputs_dir="../../data/transfer_players_validate_cross_csvs" 8 | 9 | mkdir -p $outputs_dir 10 | 11 | for player_dir in $targets_dir/*; do 12 | player=`basename ${player_dir}` 13 | echo $player 14 | mkdir -p ${outputs_dir}/${player} 15 | for c in "white" "black"; do 16 | screen -S "cross-${player}-${c}" -dm bash -c "sourcer ~/.basrc; python3 csv_trimmer.py ${player_dir}/csvs/test_${c}.csv.bz2 ${outputs_dir}/${player}/test_${c}_reduced.csv.bz2" 17 | done 18 | done 19 | 20 | -------------------------------------------------------------------------------- /3-analysis/3-2-model_cross_table_val.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | max_screens=40 5 | 6 | targets_dir="../../data/transfer_players" 7 | outputs_dir="../../data/transfer_results_cross_val" 8 | 9 | models_dir="../../transfer_training/final_models" 10 | 11 | target_models=`echo ../../transfer_training/final_models_val/unfrozen_copy/* ../../transfer_training/final_models/{no_stop,unfrozen_copy}/* ` 12 | 13 | mkdir -p $outputs_dir 14 | 15 | for model in $target_models; do 16 | player=`python3 get_models_player.py ${model}` 17 | model_type=`dirname ${model}` 18 | model_type=`basename 
${model_type}` 19 | model_name=`basename ${model}` 20 | echo $player $model_type $model 21 | for c in "white" "black"; do 22 | for t in "train" "extended" "validate"; do 23 | player_files=${targets_dir}_${t}/$player/csvs/test_${c}.csv.bz2 24 | if [ -f "$player_files" ]; then 25 | player_ret_dir=$outputs_dir/$player 26 | mkdir -p $player_ret_dir 27 | echo $player_files 28 | for model2 in $target_models; do 29 | model2_name=`basename ${model2}` 30 | model2_player=`python3 get_models_player.py ${model2}` 31 | screen -S "cross-${player}-${model2_player}-${c}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player} $model2 ${player_files} ${player_ret_dir}/${model2_player}_${c}.csv.bz2" 32 | done 33 | while [ `screen -ls | wc -l` -gt $max_screens ]; do 34 | printf "waiting\r" 35 | sleep 10 36 | done 37 | fi 38 | done 39 | done 40 | done 41 | -------------------------------------------------------------------------------- /3-analysis/4-0-result_summaries.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | max_screens=80 5 | 6 | targets_dir="../../data/transfer_results" 7 | outputs_dir="../../data/transfer_summaries" 8 | mkdir -p $outputs_dir 9 | 10 | for p in $targets_dir/*/*/*/*.bz2 $targets_dir/*/*/*/*/*.bz2; do 11 | #result=$(echo "$p" | sed "s/$targets_dir/$outputs_dir/g") 12 | out_path=${p/$targets_dir/$outputs_dir} 13 | out_path=${out_path/.csv.bz2/.json} 14 | base=`dirname ${p/$targets_dir/}` 15 | base=${base//\//-} 16 | mkdir -p `dirname $out_path` 17 | echo $base 18 | #"${${}/${outputs_dir}/${targets_dir}}" 19 | screen -S "summary${base}" -dm bash -c "source ~/.bashrc; python3 make_summary.py ${p} ${out_path}" 20 | while [ `screen -ls | wc -l` -gt $max_screens ]; do 21 | printf "waiting\r" 22 | sleep 10 23 | done 24 | done 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /3-analysis/4-1-result_summaries_cross.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | max_screens=80 5 | 6 | targets_dir="../../data/transfer_results_cross" 7 | outputs_dir="../../data/transfer_results_cross_summaries" 8 | mkdir -p $outputs_dir 9 | 10 | for p in $targets_dir/*/*.bz2; do 11 | #result=$(echo "$p" | sed "s/$targets_dir/$outputs_dir/g") 12 | out_path=${p/$targets_dir/$outputs_dir} 13 | out_path=${out_path/.csv.bz2/.json} 14 | base=`dirname ${p/$targets_dir/}` 15 | base=${base//\//-} 16 | mkdir -p `dirname $out_path` 17 | echo $base 18 | #"${${}/${outputs_dir}/${targets_dir}}" 19 | screen -S "summary${base}" -dm bash -c "source ~/.bashrc; python3 make_summary.py ${p} ${out_path}" 20 | while [ `screen -ls | wc -l` -gt $max_screens ]; do 21 | printf "waiting\r" 22 | sleep 10 23 | done 24 | done 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /3-analysis/4-2-result_summaries_val.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | max_screens=80 5 | 6 | targets_dir="../../data/transfer_results_val/validate" 7 | outputs_dir="../../data/transfer_summaries_val" 8 | mkdir -p $outputs_dir 9 | 10 | for p in $targets_dir/*/*/*/*.bz2; do 11 | #result=$(echo "$p" | sed "s/$targets_dir/$outputs_dir/g") 12 | out_path=${p/$targets_dir/$outputs_dir} 13 | out_path=${out_path/.csv.bz2/.json} 14 | base=`dirname ${p/$targets_dir/}` 15 | base=${base//\//-} 16 | mkdir -p `dirname $out_path` 17 
| echo $base 18 | #"${${}/${outputs_dir}/${targets_dir}}" 19 | screen -S "summary${base}" -dm bash -c "source ~/.bashrc; python3 make_summary.py ${p} ${out_path}" 20 | while [ `screen -ls | wc -l` -gt $max_screens ]; do 21 | printf "waiting\r" 22 | sleep 10 23 | done 24 | done 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /3-analysis/csv_trimmer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path 4 | import bz2 5 | import csv 6 | import multiprocessing 7 | import humanize 8 | import time 9 | import queue 10 | import json 11 | import pandas 12 | 13 | import chess 14 | 15 | import backend 16 | 17 | #@backend.logged_main 18 | def main(): 19 | parser = argparse.ArgumentParser(description='Trim a results csv down to a few games within a ply range', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 20 | 21 | parser.add_argument('input', help='input CSV') 22 | parser.add_argument('output', help='output CSV') 23 | parser.add_argument('--ngames', type=int, help='number of games to read in', default = 10) 24 | parser.add_argument('--min_ply', type=int, help='look at games with ply above this', default = 50) 25 | parser.add_argument('--max_ply', type=int, help='look at games with ply below this', default = 100) 26 | args = parser.parse_args() 27 | backend.printWithDate(f"Starting {args.input} to {args.output}") 28 | 29 | with bz2.open(args.input, 'rt') as fin, bz2.open(args.output, 'wt') as fout: 30 | reader = csv.DictReader(fin) 31 | writer = csv.DictWriter(fout, reader.fieldnames) 32 | writer.writeheader() 33 | games_count = 0 34 | current_game = None 35 | for row in reader: 36 | if args.min_ply is not None and int(row['num_ply']) <= args.min_ply: 37 | continue 38 | elif args.max_ply is not None and int(row['num_ply']) >= args.max_ply: 39 | continue 40 | elif row['game_id'] != current_game: 41 | current_game = row['game_id'] 42 | games_count += 1 43 | if args.ngames is not None and games_count > args.ngames: 44 | break 45 | writer.writerow(row) 46 | 47 | 48 | if __name__ == "__main__": 49 | main() 50 | -------------------------------------------------------------------------------- /3-analysis/get_accuracy.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path 4 | 5 | import pandas 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser(description='Quick helper for getting model accuracies', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 9 | parser.add_argument('inputs', nargs = '+', help='input CSVs') 10 | parser.add_argument('--nrows', help='num lines', type = int, default=None) 11 | args = parser.parse_args() 12 | 13 | for p in args.inputs: 14 | try: 15 | df = pandas.read_csv(p, nrows = args.nrows) 16 | except EOFError: 17 | print(f"{os.path.abspath(p).split('.')[0]} EOF") 18 | else: 19 | print(f"{os.path.abspath(p).split('.')[0]} {df['model_correct'].mean() * 100:.2f}%") 20 | 21 | if __name__ == "__main__": 22 | main() 23 | -------------------------------------------------------------------------------- /3-analysis/get_models_player.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path 4 | import yaml 5 | 6 | import pandas 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser(description='Quick helper for getting model players', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 10 | 
parser.add_argument('input', help='input model dir') 11 | args = parser.parse_args() 12 | 13 | conf_path = os.path.abspath(os.path.join(args.input, "config.yaml")) 14 | if os.path.isfile(conf_path): 15 | with open(conf_path) as f: 16 | cfg = yaml.safe_load(f) 17 | try: 18 | print(cfg['full_config']['name']) 19 | except (KeyError, TypeError): 20 | #some have corrupted configs 21 | if 'Eron_Capivara' in args.input: 22 | print('Eron_Capivara') #hack 23 | else: 24 | 25 | print(os.path.basename(os.path.dirname(conf_path)).split('_')[0]) 26 | else: 27 | raise FileNotFoundError(f"Not a config path: {conf_path}") 28 | 29 | if __name__ == "__main__": 30 | main() 31 | -------------------------------------------------------------------------------- /3-analysis/make_summary.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path 4 | import json 5 | import re 6 | import glob 7 | 8 | import pandas 9 | import numpy as np 10 | 11 | root_dir = os.path.relpath("../..", start=os.path.dirname(os.path.abspath(__file__))) 12 | root_dir = os.path.abspath(root_dir) 13 | cats_ply = { 14 | 'early' : (0, 10), 15 | 'mid' : (11, 50), 16 | 'late' : (51, 999), 17 | 'kdd' : (11, 999), 18 | } 19 | 20 | last_n = [2**n for n in range(12)] 21 | 22 | def main(): 23 | parser = argparse.ArgumentParser(description='Create summary json from results csv', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 24 | parser.add_argument('input', help='input CSV') 25 | parser.add_argument('output', help='output JSON') 26 | parser.add_argument('--players_infos', default="/ada/projects/chess/backend-backend/data/players_infos.json")#os.path.join(root_dir, 'data/players_infos.json')) 27 | args = parser.parse_args() 28 | 29 | with open(args.players_infos) as f: 30 | player_to_dat = json.load(f) 31 | 32 | fixed_data_paths = glob.glob("/ada/projects/chess/backend-backend/data/top_2000_player_data/*.csv.bz2") 33 | fixed_data_lookup = {p.split('/')[-1].replace('.csv.bz2','') :p for p in fixed_data_paths} 34 | 35 | 36 | df = collect_results_csv(args.input ,player_to_dat, fixed_data_lookup) 37 | 38 | r_d = dict(df.iloc[0]) 39 | sum_dict = { 40 | 'count' : len(df), 41 | 'player' : r_d['player_name'], 42 | 'model' : r_d['model_name'], 43 | 'backend' : bool(r_d['backend']), 44 | #'model_correct' : float(df['model_correct'].mean()), 45 | 'elo' : r_d['elo'], 46 | } 47 | c = 'white' if 'white' in args.input.split('/')[-1].split('_')[-1] else 'black' 48 | 49 | add_infos(sum_dict, "full", df) 50 | 51 | try: 52 | csv_raw_path = glob.glob(f"/ada/projects/chess/backend-backend/data/transfer_players_*/{r_d['player_name']}/csvs/test_{c}.csv.bz2")[0] 53 | except IndexError: 54 | if args.input.endswith('kdd.csv.bz2'): 55 | csv_raw_path = "/ada/projects/chess/backend-backend/data/reduced_kdd_test_set.csv.bz2" 56 | else: 57 | csv_raw_path = None 58 | if csv_raw_path is not None: 59 | csv_base = pandas.read_csv(csv_raw_path, index_col=['game_id', 'move_ply'], low_memory=False) 60 | csv_base['winrate_no_0'] = np.where(csv_base.reset_index()['move_ply'] < 2,np.nan, csv_base['winrate']) 61 | csv_base['next_wr'] = 1 - csv_base['winrate_no_0'].shift(-1) 62 | csv_base['move_delta_wr'] = csv_base['next_wr'] - csv_base['winrate'] 63 | csv_base_dropped = csv_base[~csv_base['winrate_loss'].isna()] 64 | csv_base_dropped = csv_base_dropped.join(df.set_index(['game_id', 'move_ply']), how = 'inner', lsuffix = 'r_') 65 | csv_base_dropped['move_delta_wr_rounded'] = 
(csv_base_dropped['move_delta_wr'] * 1).round(2) / 1 66 | 67 | for dr in csv_base_dropped['move_delta_wr_rounded'].unique(): 68 | if dr < 0 and dr >-.32: 69 | df_dr = csv_base_dropped[csv_base_dropped['move_delta_wr_rounded'] == dr] 70 | add_infos(sum_dict, f"delta_wr_{dr}", df_dr) 71 | 72 | for k, v in player_to_dat[r_d['player_name']].items(): 73 | if k != 'name': 74 | sum_dict[k] = v 75 | if r_d['backend']: 76 | sum_dict['backend_elo'] = int(r_d['model_name'].split('_')[-1]) 77 | 78 | 79 | for c, (p_min, p_max) in cats_ply.items(): 80 | df_c = df[(df['move_ply'] >= p_min) & (df['move_ply'] <= p_max)] 81 | add_infos(sum_dict, c, df_c) 82 | 83 | for year in df['UTCDate'].dt.year.unique(): 84 | df_y = df[df['UTCDate'].dt.year == year] 85 | add_infos(sum_dict, int(year), df_y) 86 | 87 | for ply in range(50): 88 | df_p = df[df['move_ply'] == ply] 89 | if len(df_p) > 0: 90 | # ignore the 50% missing ones 91 | add_infos(sum_dict, f"ply_{ply}", df_p) 92 | 93 | for won in [True, False]: 94 | df_w = df[df['won'] == won] 95 | add_infos(sum_dict, "won" if won else "lost", df_w) 96 | 97 | games = list(df.groupby('game_id').first().sort_values('UTCDate').index) 98 | 99 | for n in last_n: 100 | df_n = df[df['game_id'].isin(games[-n:])] 101 | add_infos(sum_dict, f"last_{n}", df_n) 102 | 103 | p_min, p_max = cats_ply['kdd'] 104 | df_n_kdd = df_n[(df_n['move_ply'] >= p_min) & (df_n['move_ply'] <= p_max)] 105 | add_infos(sum_dict, f"last_{n}_kdd", df_n_kdd) 106 | 107 | with open(args.output, 'wt') as f: 108 | json.dump(sum_dict, f) 109 | 110 | def collect_results_csv(path, player_to_dat, fixed_data_lookup): 111 | try: 112 | df = pandas.read_csv(path, low_memory=False) 113 | except EOFError: 114 | print(f"Error on: {path}") 115 | return None 116 | if len(df) < 1: 117 | return None 118 | df['colour'] = re.search("(black|white)\.csv\.bz2", path).group(1) 119 | #df['class'] = re.search(f"{base_dir}/([a-z]*)/", path).group(1) 120 | backend = 'final_backend_' in df['model_name'].iloc[0] 121 | df['backend'] = backend 122 | try: 123 | df['player'] = df['player_name'] 124 | except KeyError: 125 | pass 126 | try: 127 | if backend: 128 | df['model_type'] = df['model_name'].iloc[0].replace('final_', '') 129 | else: 130 | df['model_type'] = df['model_name'].iloc[0].replace(f'{df.iloc[0]["player"]}_', '') 131 | for k, v in player_to_dat[df.iloc[0]["player"]].items(): 132 | if k != 'name': 133 | df[k] = v 134 | games_df = pandas.read_csv(fixed_data_lookup[df['player'].iloc[0]], 135 | low_memory=False, parse_dates = ['UTCDate'], index_col = 'game_id') 136 | df = df.join(games_df, how= 'left', on = 'game_id', rsuffix='_per_game') 137 | except Exception as e: 138 | print(f"{e} : {path}") 139 | raise 140 | return df 141 | 142 | def add_infos(target_dict, name, df_sub): 143 | target_dict[f'model_correct_{name}'] = float(df_sub['model_correct'].dropna().mean()) 144 | target_dict[f'model_correct_per_game_{name}'] = float(df_sub.groupby(['game_id']).mean()['model_correct'].dropna().mean()) 145 | target_dict[f'count_{name}'] = len(df_sub) 146 | target_dict[f'std_{name}'] = float(df_sub['model_correct'].dropna().std()) 147 | target_dict[f'num_games_{name}'] = len(df_sub.groupby('game_id').count()) 148 | 149 | if __name__ == "__main__": 150 | main() 151 | -------------------------------------------------------------------------------- /3-analysis/move_predictions.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | data_dir="../../data/transfer_players_train" 5 | 
outputs_dir="../../data/transfer_players_train_results/weights_testing" 6 | maia_path="../../models/maia/1900" 7 | models_path="../models/weights_testing" 8 | 9 | kdd_path="../../datasets/10000_full_2019-12.csv.bz2" 10 | 11 | mkdir -p $outputs_dir 12 | 13 | for player_dir in $data_dir/*; do 14 | player_name=`basename ${player_dir}` 15 | echo $player_name 16 | mkdir -p $outputs_dir/$player_name 17 | 18 | for c in "white" "black"; do 19 | #echo "source ~/.bashrc; python3 prediction_generator.py --target_player ${player_name} ${models_path}/${player_name}*/ ${data_dir}/${player_name}/csvs/test_${c}.csv.bz2 $outputs_dir/${player_name}/transfer_test_${c}.csv.bz2" 20 | screen -S "test-transfer-${c}-${player_name}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player_name} ${models_path}/${player_name}*/ ${data_dir}/${player_name}/csvs/test_${c}.csv.bz2 $outputs_dir/${player_name}/transfer_test_${c}.csv.bz2" 21 | screen -S "test-maia-${c}-${player_name}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player_name} ${maia_path} ${data_dir}/${player_name}/csvs/test_${c}.csv.bz2 $outputs_dir/$player_name/maia_test_${c}.csv.bz2" 22 | done 23 | screen -S "kdd-transfer-${player_name}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py ${models_path}/${player_name}*/ ${kdd_path} $outputs_dir/$player_name/transfer_kdd.csv.bz2" 24 | 25 | while [ `screen -ls | wc -l` -gt 250 ]; do 26 | printf "waiting\r" 27 | sleep 10 28 | done 29 | done 30 | -------------------------------------------------------------------------------- /3-analysis/run-kdd-tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir -p ../data 4 | mkdir -p ../data/kdd_sweeps 5 | 6 | screen -S "kdd-sweep" -dm bash -c "source ~/.bashrc; python3 ../../analysis/move_prediction_csv.py ../../transfer_models ../../datasets/10000_full_2019-12.csv.bz2 ../data/kdd_sweeps" 7 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/README.md: -------------------------------------------------------------------------------- 1 | `cp_loss_hist`: This is where I save all the train/validation/test data, aggregated over all games. 2 | 3 | `cp_loss_count_per_game`: This is where I save all the train/validation/test data per game. Note that the counts have not been normalized. 4 | 5 | `cp_loss_hist_per_move`: This is where I save all the train/validation/test data per move, aggregated over all games. 6 | 7 | `cp_loss_hist_per_move_per_game`: This is where I save all the train/validation/test data per move per game. 8 | 9 | `cp_loss_hist_per_move_per_game_count`: This is where I save all the train/validation/test data per move per game, as raw counts so they can be added up later. 10 | 11 | `get_cp_loss.py`: Parsing code that computes cp_loss and its histograms for both train and extended players, and saves them as **.npy** files (the sketch below shows the shared histogram recipe). 12 | 13 | `get_cp_loss_per_game.py`: Parsing code that computes cp_loss and its histograms (counts) for extended players for each game, and saves them as **.npy** files. Note that I don't normalize when saving, so the number of games can be swept over later. 14 | 15 | `get_cp_loss_per_move.py`: Parsing code that computes cp_loss and its histograms for both train and extended players for all games, broken down by move, and saves them as **.npy** files. 
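All of the histogram scripts above share one underlying recipe. As a minimal sketch (the function names here are illustrative, not the exact code of these scripts): per-move centipawn losses are binned into a 50-bin count histogram over (0, 5), L2-normalized, and a query histogram is attributed to the training player whose histogram is nearest in Euclidean distance.

```python
import numpy as np

def cp_loss_histogram(cp_losses):
    # 50 fixed-width bins over (0, 5); raw counts, then L2-normalized
    hist, _ = np.histogram(cp_losses, bins=50, range=(0, 5), density=False)
    hist = hist.astype(float)
    norm = np.linalg.norm(hist)
    return hist / norm if norm > 0 else hist

def nearest_player(query_hist, train_hists):
    # train_hists: dict of player name -> normalized training histogram
    return min(train_hists, key=lambda name: np.linalg.norm(train_hists[name] - query_hist))
```

Accuracy is then the fraction of held-out histograms whose nearest training histogram belongs to the same player.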
16 | 17 | `get_cp_loss_per_move_per_game.py`: Parsing code that computes cp_loss and its histograms for both train and extended players for each game, broken down by move, and saves them as **.npy** files. 18 | 19 | `get_cp_loss_per_move_per_game_count.py`: Parsing code that computes cp_loss and its histograms (counts) for both train and extended players for each game, broken down by move, and saves them as **.npy** files. 20 | 21 | `test_all_games.py`: Baseline that tests accuracy using all games instead of individual games, with Euclidean Distance or Naive Bayes. Data is from `cp_loss_hist`. 22 | 23 | `sweep_num_games.py`: Baseline using Euclidean Distance or Naive Bayes. Training data is from `cp_loss_hist` and test data is from `cp_loss_count_per_game`. Sweeps across [1, 2, 4, 8, 16] games. 24 | 25 | `sweep_moves_per_game.py`: Naive Bayes on per-move evaluation, reported as average accuracy per game. Training data is from `cp_loss_hist_per_move`, test data is from `cp_loss_hist_per_move_per_game`. 26 | 27 | `sweep_moves_all_games.py`: Naive Bayes on per-move evaluation, reported as average accuracy per game. Data is from `cp_loss_hist_per_move`. 28 | 29 | `sweep_moves_num_games.py`: Naive Bayes on per-move evaluation given a number of games. Training data is from `cp_loss_hist_per_move`, test data is from `cp_loss_hist_per_move_per_game_count`. Setting it to 1 is the same as `sweep_moves_per_game.py`. 30 | 31 | `train_cploss_per_game.py`: Baseline using a simple neural network with 2 fully-connected layers, trained on each game and evaluated on per-game accuracy. **This currently gives NaN values when training on 30 players.** 32 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/get_cp_loss.py: -------------------------------------------------------------------------------- 1 | 2 | import bz2 3 | import csv 4 | import argparse 5 | import os 6 | import numpy as np 7 | # import matplotlib 8 | # matplotlib.use('TkAgg') 9 | import matplotlib.pyplot as plt 10 | import multiprocessing 11 | from functools import partial 12 | 13 | def parse_argument(): 14 | parser = argparse.ArgumentParser(description='arg parser') 15 | 16 | parser.add_argument('--input_dir', default='/data/transfer_players_validate') 17 | parser.add_argument('--player_name_dir', default='../transfer_training/final_models_val/unfrozen_copy') 18 | parser.add_argument('--saved_dir', default='cp_loss_hist') 19 | parser.add_argument('--will_save', default=True) 20 | 21 | return parser.parse_args() 22 | 23 | def normalize(data): 24 | norm = np.linalg.norm(data) 25 | data_norm = data/norm 26 | return data_norm 27 | 28 | def prepare_dataset(players, player_name, cp_hist, dataset): 29 | # add up black and white games (counts can be directly added) 30 | if players[player_name][dataset] is None: 31 | players[player_name][dataset] = cp_hist 32 | else: 33 | players[player_name][dataset] = players[player_name][dataset] + cp_hist 34 | 35 | def save_npy(saved_dir, players, player_name, dataset): 36 | if not os.path.exists(saved_dir): 37 | os.mkdir(saved_dir) 38 | 39 | saved = os.path.join(saved_dir, player_name + '_{}.npy'.format(dataset)) 40 | print('saving data to {}'.format(saved)) 41 | np.save(saved, players[player_name][dataset]) 42 | 43 | def multi_parse(input_dir, saved_dir, players, save, player_name): 44 | 45 | print("=============================================") 46 | print("parsing data for {}".format(player_name)) 47 | players[player_name] = {'train': None, 'validation': None, 'test': None} 
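# NOTE: multi_parse runs inside a multiprocessing.Pool worker (see construct_datasets
# below), so each worker fills in its own copy of the `players` dict; the parent
# process never sees these entries, and the parsed histograms persist only through
# the save_npy() calls at the end of this function.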
48 | 49 | csv_dir = os.path.join(input_dir, player_name, 'csvs') 50 | # for each csv, add up black and white games (counts can be directly added) 51 | for csv_fname in os.listdir(csv_dir): 52 | path = os.path.join(csv_dir, csv_fname) 53 | # parse bz2 file 54 | source_file = bz2.BZ2File(path, "r") 55 | cp_hist, num_games = get_cp_loss_from_csv(player_name, source_file) 56 | print(path) 57 | 58 | if csv_fname.startswith('train'): 59 | prepare_dataset(players, player_name, cp_hist, 'train') 60 | 61 | elif csv_fname.startswith('validate'): 62 | prepare_dataset(players, player_name, cp_hist, 'validation') 63 | 64 | elif csv_fname.startswith('test'): 65 | prepare_dataset(players, player_name, cp_hist, 'test') 66 | 67 | # normalize the histogram to range [0, 1] 68 | players[player_name]['train'] = normalize(players[player_name]['train']) 69 | players[player_name]['validation'] = normalize(players[player_name]['validation']) 70 | players[player_name]['test'] = normalize(players[player_name]['test']) 71 | 72 | # save for future use, parsing takes too long... 73 | if save: 74 | save_npy(saved_dir, players, player_name, 'train') 75 | save_npy(saved_dir, players, player_name, 'validation') 76 | save_npy(saved_dir, players, player_name, 'test') 77 | 78 | def construct_datasets(player_names, input_dir, saved_dir, will_save): 79 | players = {} 80 | 81 | pool = multiprocessing.Pool(25) 82 | func = partial(multi_parse, input_dir, saved_dir, players, will_save) 83 | pool.map(func, player_names) 84 | pool.close() 85 | pool.join() 86 | 87 | def get_cp_loss_from_csv(player_name, path): 88 | cp_losses = [] 89 | games = {} 90 | # NOTE: 'path' here is the already-open BZ2File created in multi_parse; iterating it yields raw bytes lines 91 | for i, line in enumerate(path): 92 | if i > 0: 93 | line = line.decode("utf-8") 94 | row = line.rstrip().split(',') 95 | # avoid empty line 96 | if row[0] == '': 97 | continue 98 | 99 | game_id = row[0] 100 | cp_loss = row[17] 101 | active_player = row[25] 102 | if player_name != active_player: 103 | continue 104 | 105 | # ignore cases like -inf, inf, nan 106 | if cp_loss != str(-1 * np.inf) and cp_loss != str(np.inf) and cp_loss != 'nan': 107 | # append cp loss per move 108 | cp_losses.append(float(cp_loss)) 109 | 110 | # for purpose of counting how many games 111 | if game_id not in games: 112 | games[game_id] = 1 113 | 114 | 115 | ######################## plot for viewing ######################## 116 | # plt.hist(cp_losses, density=False, bins=50) 117 | # plt.ylabel('Count') 118 | # plt.xlabel('Cp Loss') 119 | # plt.show() 120 | 121 | cp_hist = np.histogram(cp_losses, density=False, bins=50, range=(0, 5)) # density=False for counts 122 | 123 | cp_hist = cp_hist[0] # cp_hist in format of (hist count, range) 124 | 125 | print("number of games: {}".format(len(games))) 126 | 127 | return cp_hist, len(games) 128 | 129 | 130 | def get_player_names(player_name_dir): 131 | 132 | player_names = [] 133 | for player_name in os.listdir(player_name_dir): 134 | player = player_name.replace("_unfrozen_copy", "") 135 | player_names.append(player) 136 | 137 | # print(player_names) 138 | return player_names 139 | 140 | 141 | if __name__ == '__main__': 142 | args = parse_argument() 143 | 144 | player_names = get_player_names(args.player_name_dir) 145 | 146 | construct_datasets(player_names, args.input_dir, args.saved_dir, args.will_save) 147 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/get_cp_loss_per_game.py: -------------------------------------------------------------------------------- 1 | import bz2 
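# This variant keeps one unnormalized 50-bin cp-loss count histogram per game (rather
# than one per player), so per-game counts can be summed later when sweeping over the
# number of games; see this directory's README.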
2 | import csv 3 | import argparse 4 | import os 5 | import numpy as np 6 | # import matplotlib 7 | # matplotlib.use('TkAgg') 8 | import matplotlib.pyplot as plt 9 | import multiprocessing 10 | from functools import partial 11 | 12 | def parse_argument(): 13 | parser = argparse.ArgumentParser(description='arg parser') 14 | 15 | parser.add_argument('--input_dir', default='/data/transfer_players_validate') 16 | parser.add_argument('--player_name_dir', default='../transfer_training/final_models_val/unfrozen_copy') 17 | parser.add_argument('--saved_dir', default='cp_loss_count_per_game') 18 | parser.add_argument('--will_save', default=True) 19 | 20 | return parser.parse_args() 21 | 22 | def normalize(data): 23 | norm = np.linalg.norm(data) 24 | data_norm = data/norm 25 | return data_norm 26 | 27 | def prepare_dataset(players, player_name, games, dataset): 28 | # add up black and white games 29 | if players[player_name][dataset] is None: 30 | players[player_name][dataset] = games 31 | else: 32 | players[player_name][dataset].update(games) 33 | 34 | def save_npy(saved_dir, players, player_name, dataset): 35 | if not os.path.exists(saved_dir): 36 | os.mkdir(saved_dir) 37 | 38 | saved = os.path.join(saved_dir, player_name + '_{}.npy'.format(dataset)) 39 | print('saving data to {}'.format(saved)) 40 | print('total number of games: {}'.format(len(players[player_name][dataset]))) 41 | np.save(saved, players[player_name][dataset]) 42 | 43 | def multi_parse(input_dir, saved_dir, players, save, player_name): 44 | 45 | print("=============================================") 46 | print("parsing data for {}".format(player_name)) 47 | players[player_name] = {'train': None, 'validation': None, 'test': None} 48 | 49 | csv_dir = os.path.join(input_dir, player_name, 'csvs') 50 | for csv_fname in os.listdir(csv_dir): 51 | path = os.path.join(csv_dir, csv_fname) 52 | print(path) 53 | source_file = bz2.BZ2File(path, "r") 54 | games, num_games = get_cp_loss_from_csv(player_name, source_file) 55 | 56 | if csv_fname.startswith('train'): 57 | prepare_dataset(players, player_name, games, 'train') 58 | 59 | elif csv_fname.startswith('validate'): 60 | prepare_dataset(players, player_name, games, 'validation') 61 | 62 | elif csv_fname.startswith('test'): 63 | prepare_dataset(players, player_name, games, 'test') 64 | 65 | if save: 66 | save_npy(saved_dir, players, player_name, 'train') 67 | save_npy(saved_dir, players, player_name, 'validation') 68 | save_npy(saved_dir, players, player_name, 'test') 69 | 70 | def construct_datasets(player_names, input_dir, saved_dir, will_save): 71 | players = {} 72 | 73 | pool = multiprocessing.Pool(25) 74 | func = partial(multi_parse, input_dir, saved_dir, players, will_save) 75 | pool.map(func, player_names) 76 | pool.close() 77 | pool.join() 78 | 79 | def get_cp_loss_from_csv(player_name, path): 80 | cp_losses = [] 81 | games = {} 82 | # NOTE: 'path' here is the already-open BZ2File created in multi_parse; iterating it yields raw bytes lines 83 | for i, line in enumerate(path): 84 | if i > 0: 85 | line = line.decode("utf-8") 86 | row = line.rstrip().split(',') 87 | # avoid empty line 88 | if row[0] == '': 89 | continue 90 | 91 | game_id = row[0] 92 | cp_loss = row[17] 93 | active_player = row[25] 94 | if player_name != active_player: 95 | continue 96 | 97 | if cp_loss != str(-1 * np.inf) and cp_loss != str(np.inf) and cp_loss != 'nan': 98 | cp_losses.append(float(cp_loss)) 99 | 100 | # for purpose of counting how many games 101 | if game_id not in games: 102 | games[game_id] = [float(cp_loss)] 103 | else: 104 | games[game_id].append(float(cp_loss)) 105 | 106 | 
# get per game histogram 107 | for key, value in games.items(): 108 | games[key] = np.histogram(games[key], density=False, bins=50, range=(0, 5)) 109 | # cp_hist in format (hist, range) 110 | games[key] = games[key][0] 111 | # games[key] = normalize(games[key]) 112 | 113 | print("number of games: {}".format(len(games))) 114 | 115 | return games, len(games) 116 | 117 | def get_player_names(player_name_dir): 118 | 119 | player_names = [] 120 | for player_name in os.listdir(player_name_dir): 121 | player = player_name.replace("_unfrozen_copy", "") 122 | player_names.append(player) 123 | 124 | # print(player_names) 125 | return player_names 126 | 127 | if __name__ == '__main__': 128 | args = parse_argument() 129 | 130 | player_names = get_player_names(args.player_name_dir) 131 | 132 | construct_datasets(player_names, args.input_dir, args.saved_dir, args.will_save) 133 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/get_cp_loss_per_move.py: -------------------------------------------------------------------------------- 1 | 2 | import bz2 3 | import csv 4 | import argparse 5 | import os 6 | import numpy as np 7 | # import matplotlib 8 | # matplotlib.use('TkAgg') 9 | import matplotlib.pyplot as plt 10 | import multiprocessing 11 | from functools import partial 12 | 13 | def parse_argument(): 14 | parser = argparse.ArgumentParser(description='arg parser') 15 | 16 | parser.add_argument('--input_dir', default='/data/transfer_players_validate') 17 | parser.add_argument('--player_name_dir', default='../transfer_training/final_models_val/unfrozen_copy') 18 | parser.add_argument('--saved_dir', default='cp_loss_hist_per_move') 19 | parser.add_argument('--will_save', default=True) 20 | 21 | return parser.parse_args() 22 | 23 | def normalize(data): 24 | for i in range(101): 25 | # update in-place 26 | if any(v != 0 for v in data['start_after'][i]): 27 | start_after_norm = np.linalg.norm(data['start_after'][i]) 28 | data['start_after'][i] = data['start_after'][i] / start_after_norm 29 | if any(v != 0 for v in data['stop_after'][i]): 30 | stop_after_norm = np.linalg.norm(data['stop_after'][i]) 31 | data['stop_after'][i] = data['stop_after'][i] / stop_after_norm 32 | 33 | 34 | def prepare_dataset(players, player_name, cp_loss_hist_dict, dataset): 35 | # add up black and white games (counts can be directly added) 36 | if players[player_name][dataset] is None: 37 | players[player_name][dataset] = cp_loss_hist_dict 38 | else: 39 | # add up each move 40 | for i in range(101): 41 | players[player_name][dataset]['start_after'][i] = players[player_name][dataset]['start_after'][i] + cp_loss_hist_dict['start_after'][i] 42 | players[player_name][dataset]['stop_after'][i] = players[player_name][dataset]['stop_after'][i] + cp_loss_hist_dict['stop_after'][i] 43 | 44 | def save_npy(saved_dir, players, player_name, dataset): 45 | if not os.path.exists(saved_dir): 46 | os.mkdir(saved_dir) 47 | 48 | saved = os.path.join(saved_dir, player_name + '_{}.npy'.format(dataset)) 49 | print('saving data to {}'.format(saved)) 50 | np.save(saved, players[player_name][dataset]) 51 | 52 | def multi_parse(input_dir, saved_dir, players, save, player_name): 53 | 54 | print("=============================================") 55 | print("parsing data for {}".format(player_name)) 56 | players[player_name] = {'train': None, 'validation': None, 'test': None} 57 | 58 | csv_dir = os.path.join(input_dir, player_name, 'csvs') 59 | # for each csv, add up black and white games (counts can be 
directly added) 60 | for csv_fname in os.listdir(csv_dir): 61 | path = os.path.join(csv_dir, csv_fname) 62 | # parse bz2 file 63 | source_file = bz2.BZ2File(path, "r") 64 | cp_loss_hist_dict, num_games = get_cp_loss_from_csv(player_name, source_file) 65 | print(path) 66 | 67 | if csv_fname.startswith('train'): 68 | prepare_dataset(players, player_name, cp_loss_hist_dict, 'train') 69 | 70 | elif csv_fname.startswith('validate'): 71 | prepare_dataset(players, player_name, cp_loss_hist_dict, 'validation') 72 | 73 | elif csv_fname.startswith('test'): 74 | prepare_dataset(players, player_name, cp_loss_hist_dict, 'test') 75 | 76 | # normalize the histogram to range [0, 1] 77 | normalize(players[player_name]['train']) 78 | normalize(players[player_name]['validation']) 79 | normalize(players[player_name]['test']) 80 | 81 | # save for future use, parsing takes too long... 82 | if save: 83 | save_npy(saved_dir, players, player_name, 'train') 84 | save_npy(saved_dir, players, player_name, 'validation') 85 | save_npy(saved_dir, players, player_name, 'test') 86 | 87 | def construct_datasets(player_names, input_dir, saved_dir, will_save): 88 | players = {} 89 | 90 | pool = multiprocessing.Pool(25) 91 | func = partial(multi_parse, input_dir, saved_dir, players, will_save) 92 | pool.map(func, player_names) 93 | pool.close() 94 | pool.join() 95 | 96 | def get_cp_loss_start_after(cp_losses, move_start=0): 97 | return cp_losses[move_start:] 98 | 99 | # 0 will be empty 100 | def get_cp_loss_stop_after(cp_losses, move_stop=100): 101 | return cp_losses[:move_stop] 102 | 103 | # move_stop is in range [0, 100] 104 | def get_cp_loss_from_csv(player_name, path): 105 | cp_loss_hist_dict = {'start_after': {}, 'stop_after': {}} 106 | # cp_losses = [] 107 | games = {} 108 | # NOTE: 'path' here is the already-open BZ2File created in multi_parse; iterating it yields raw bytes lines 109 | for i, line in enumerate(path): 110 | if i > 0: 111 | line = line.decode("utf-8") 112 | row = line.rstrip().split(',') 113 | # avoid empty line 114 | if row[0] == '': 115 | continue 116 | 117 | active_player = row[25] 118 | if player_name != active_player: 119 | continue 120 | 121 | # move_ply starts from 0, need to add 1, move will be parsed in order 122 | move_ply = int(row[13]) 123 | move = move_ply // 2 + 1 124 | 125 | game_id = row[0] 126 | cp_loss = row[17] 127 | 128 | # ignore cases like -inf, inf, nan 129 | if cp_loss != str(-1 * np.inf) and cp_loss != str(np.inf) and cp_loss != 'nan': 130 | if game_id not in games: 131 | games[game_id] = [float(cp_loss)] 132 | else: 133 | games[game_id].append(float(cp_loss)) 134 | 135 | # get per game histogram 136 | for i in range(101): 137 | cp_loss_hist_dict['start_after'][i] = [] 138 | cp_loss_hist_dict['stop_after'][i] = [] 139 | for key, value in games.items(): 140 | cp_loss_hist_dict['start_after'][i].extend(get_cp_loss_start_after(value, i)) 141 | cp_loss_hist_dict['stop_after'][i].extend(get_cp_loss_stop_after(value, i)) 142 | 143 | # transform into counts 144 | cp_loss_hist_dict['start_after'][i] = np.histogram(cp_loss_hist_dict['start_after'][i], density=False, bins=50, range=(0, 5))[0] 145 | cp_loss_hist_dict['stop_after'][i] = np.histogram(cp_loss_hist_dict['stop_after'][i], density=False, bins=50, range=(0, 5))[0] 146 | 147 | 148 | print("number of games: {}".format(len(games))) 149 | 150 | return cp_loss_hist_dict, len(games) 151 | 152 | def get_player_names(player_name_dir): 153 | 154 | player_names = [] 155 | for player_name in os.listdir(player_name_dir): 156 | player = player_name.replace("_unfrozen_copy", "") 157 | player_names.append(player) 158 | 
159 | # print(player_names) 160 | return player_names 161 | 162 | if __name__ == '__main__': 163 | args = parse_argument() 164 | 165 | player_names = get_player_names(args.player_name_dir) 166 | 167 | construct_datasets(player_names, args.input_dir, args.saved_dir, args.will_save) 168 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/get_cp_loss_per_move_per_game.py: -------------------------------------------------------------------------------- 1 | 2 | import bz2 3 | import csv 4 | import argparse 5 | import os 6 | import numpy as np 7 | # import matplotlib 8 | # matplotlib.use('TkAgg') 9 | import matplotlib.pyplot as plt 10 | import multiprocessing 11 | from functools import partial 12 | 13 | def parse_argument(): 14 | parser = argparse.ArgumentParser(description='arg parser') 15 | 16 | parser.add_argument('--input_dir', default='/data/transfer_players_validate') 17 | parser.add_argument('--player_name_dir', default='../transfer_training/final_models_val/unfrozen_copy') 18 | parser.add_argument('--saved_dir', default='cp_loss_hist_per_move_per_game') 19 | parser.add_argument('--will_save', default=True) 20 | 21 | return parser.parse_args() 22 | 23 | def normalize(data): 24 | data_norm = data 25 | 26 | if any(v != 0 for v in data): 27 | norm = np.linalg.norm(data) 28 | data_norm = data/norm 29 | 30 | return data_norm 31 | 32 | def prepare_dataset(players, player_name, games, dataset): 33 | # add up black and white games 34 | if players[player_name][dataset] is None: 35 | players[player_name][dataset] = games 36 | else: 37 | players[player_name][dataset].update(games) 38 | 39 | 40 | def save_npy(saved_dir, players, player_name, dataset): 41 | if not os.path.exists(saved_dir): 42 | os.mkdir(saved_dir) 43 | 44 | saved = os.path.join(saved_dir, player_name + '_{}.npy'.format(dataset)) 45 | print('saving data to {}'.format(saved)) 46 | np.save(saved, players[player_name][dataset]) 47 | 48 | def multi_parse(input_dir, saved_dir, players, save, player_name): 49 | 50 | print("=============================================") 51 | print("parsing data for {}".format(player_name)) 52 | players[player_name] = {'train': None, 'validation': None, 'test': None} 53 | 54 | csv_dir = os.path.join(input_dir, player_name, 'csvs') 55 | # for each csv, add up black and white games (counts can be directly added) 56 | for csv_fname in os.listdir(csv_dir): 57 | path = os.path.join(csv_dir, csv_fname) 58 | print(path) 59 | # parse bz2 file 60 | source_file = bz2.BZ2File(path, "r") 61 | games, num_games = get_cp_loss_from_csv(player_name, source_file) 62 | 63 | if csv_fname.startswith('train'): 64 | prepare_dataset(players, player_name, games, 'train') 65 | 66 | elif csv_fname.startswith('validate'): 67 | prepare_dataset(players, player_name, games, 'validation') 68 | 69 | elif csv_fname.startswith('test'): 70 | prepare_dataset(players, player_name, games, 'test') 71 | 72 | # save for future use, parsing takes too long... 
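# NOTE: each dataset entry saved below is a plain Python dict (game_id -> per-move
# histograms), so np.save() stores it as a pickled 0-d object array; it must be read
# back with np.load(path, allow_pickle=True).item() to recover the dict.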
73 | if save: 74 | save_npy(saved_dir, players, player_name, 'train') 75 | save_npy(saved_dir, players, player_name, 'validation') 76 | save_npy(saved_dir, players, player_name, 'test') 77 | 78 | def construct_datasets(player_names, input_dir, saved_dir, will_save): 79 | players = {} 80 | 81 | pool = multiprocessing.Pool(25) 82 | func = partial(multi_parse, input_dir, saved_dir, players, will_save) 83 | pool.map(func, player_names) 84 | pool.close() 85 | pool.join() 86 | 87 | 88 | def get_cp_loss_start_after(cp_losses, move_start=0): 89 | return cp_losses[move_start:] 90 | 91 | # 0 will be empty 92 | def get_cp_loss_stop_after(cp_losses, move_stop=100): 93 | return cp_losses[:move_stop] 94 | 95 | # move_stop is in range [0, 100] 96 | def get_cp_loss_from_csv(player_name, path): 97 | # cp_losses = [] 98 | games = {} 99 | # NOTE: 'path' here is the already-open BZ2File created in multi_parse; iterating it yields raw bytes lines 100 | for i, line in enumerate(path): 101 | if i > 0: 102 | line = line.decode("utf-8") 103 | row = line.rstrip().split(',') 104 | # avoid empty line 105 | if row[0] == '': 106 | continue 107 | 108 | active_player = row[25] 109 | if player_name != active_player: 110 | continue 111 | 112 | # move_ply starts from 0, need to add 1, move will be parsed in order 113 | move_ply = int(row[13]) 114 | move = move_ply // 2 + 1 115 | 116 | game_id = row[0] 117 | cp_loss = row[17] 118 | 119 | # ignore cases like -inf, inf, nan 120 | if cp_loss != str(-1 * np.inf) and cp_loss != str(np.inf) and cp_loss != 'nan': 121 | if game_id not in games: 122 | games[game_id] = [float(cp_loss)] 123 | else: 124 | games[game_id].append(float(cp_loss)) 125 | 126 | final_games = {key: {'start_after': {}, 'stop_after': {}} for key in games.keys()} 127 | 128 | # get per game histogram 129 | for i in range(101): 130 | for key, value in games.items(): 131 | final_games[key]['start_after'][i] = get_cp_loss_start_after(value, i) 132 | final_games[key]['stop_after'][i] = get_cp_loss_stop_after(value, i) 133 | 134 | final_games[key]['start_after'][i] = np.histogram(final_games[key]['start_after'][i], density=False, bins=50, range=(0, 5))[0] 135 | final_games[key]['stop_after'][i] = np.histogram(final_games[key]['stop_after'][i], density=False, bins=50, range=(0, 5))[0] 136 | 137 | final_games[key]['start_after'][i] = normalize(final_games[key]['start_after'][i]) 138 | final_games[key]['stop_after'][i] = normalize(final_games[key]['stop_after'][i]) 139 | 140 | print("number of games: {}".format(len(games))) 141 | 142 | return final_games, len(games) 143 | 144 | 145 | def get_player_names(player_name_dir): 146 | 147 | player_names = [] 148 | for player_name in os.listdir(player_name_dir): 149 | player = player_name.replace("_unfrozen_copy", "") 150 | player_names.append(player) 151 | 152 | # print(player_names) 153 | return player_names 154 | 155 | if __name__ == '__main__': 156 | args = parse_argument() 157 | 158 | player_names = get_player_names(args.player_name_dir) 159 | 160 | construct_datasets(player_names, args.input_dir, args.saved_dir, args.will_save) 161 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/get_cp_loss_per_move_per_game_count.py: -------------------------------------------------------------------------------- 1 | 2 | import bz2 3 | import csv 4 | import argparse 5 | import os 6 | import numpy as np 7 | # import matplotlib 8 | # matplotlib.use('TkAgg') 9 | import matplotlib.pyplot as plt 10 | import multiprocessing 11 | from functools import partial 12 | 13 | def parse_argument(): 14 | parser = 
argparse.ArgumentParser(description='arg parser') 15 | 16 | parser.add_argument('--input_dir', default='/data/transfer_players_validate') 17 | parser.add_argument('--player_name_dir', default='../transfer_training/final_models_val/unfrozen_copy') 18 | parser.add_argument('--saved_dir', default='cp_loss_hist_per_move_per_game_count') 19 | parser.add_argument('--will_save', default=True) 20 | 21 | return parser.parse_args() 22 | 23 | def normalize(data): 24 | data_norm = data 25 | 26 | if any(v != 0 for v in data): 27 | norm = np.linalg.norm(data) 28 | data_norm = data/norm 29 | 30 | return data_norm 31 | 32 | def prepare_dataset(players, player_name, games, dataset): 33 | # add up black and white games 34 | if players[player_name][dataset] is None: 35 | players[player_name][dataset] = games 36 | else: 37 | players[player_name][dataset].update(games) 38 | 39 | 40 | def save_npy(saved_dir, players, player_name, dataset): 41 | if not os.path.exists(saved_dir): 42 | os.mkdir(saved_dir) 43 | 44 | saved = os.path.join(saved_dir, player_name + '_{}.npy'.format(dataset)) 45 | print('saving data to {}'.format(saved)) 46 | np.save(saved, players[player_name][dataset]) 47 | 48 | def multi_parse(input_dir, saved_dir, players, save, player_name): 49 | 50 | print("=============================================") 51 | print("parsing data for {}".format(player_name)) 52 | players[player_name] = {'train': None, 'validation': None, 'test': None} 53 | 54 | csv_dir = os.path.join(input_dir, player_name, 'csvs') 55 | # for each csv, add up black and white games (counts can be directly added) 56 | for csv_fname in os.listdir(csv_dir): 57 | path = os.path.join(csv_dir, csv_fname) 58 | print(path) 59 | # parse bz2 file 60 | source_file = bz2.BZ2File(path, "r") 61 | games, num_games = get_cp_loss_from_csv(player_name, source_file) 62 | 63 | if csv_fname.startswith('train'): 64 | prepare_dataset(players, player_name, games, 'train') 65 | 66 | elif csv_fname.startswith('validate'): 67 | prepare_dataset(players, player_name, games, 'validation') 68 | 69 | elif csv_fname.startswith('test'): 70 | prepare_dataset(players, player_name, games, 'test') 71 | 72 | # save for future use, parsing takes too long... 
73 | if save: 74 | save_npy(saved_dir, players, player_name, 'train') 75 | save_npy(saved_dir, players, player_name, 'validation') 76 | save_npy(saved_dir, players, player_name, 'test') 77 | 78 | def construct_datasets(player_names, input_dir, saved_dir, will_save): 79 | players = {} 80 | 81 | pool = multiprocessing.Pool(40) 82 | func = partial(multi_parse, input_dir, saved_dir, players, will_save) 83 | pool.map(func, player_names) 84 | pool.close() 85 | pool.join() 86 | 87 | 88 | def get_cp_loss_start_after(cp_losses, move_start=0): 89 | return cp_losses[move_start:] 90 | 91 | # 0 will be empty 92 | def get_cp_loss_stop_after(cp_losses, move_stop=100): 93 | return cp_losses[:move_stop] 94 | 95 | # move_stop is in range [0, 100] 96 | def get_cp_loss_from_csv(player_name, path): 97 | # cp_losses = [] 98 | games = {} 99 | # NOTE: 'path' here is the already-open BZ2File created in multi_parse; iterating it yields raw bytes lines 100 | for i, line in enumerate(path): 101 | if i > 0: 102 | line = line.decode("utf-8") 103 | row = line.rstrip().split(',') 104 | # avoid empty line 105 | if row[0] == '': 106 | continue 107 | 108 | active_player = row[25] 109 | if player_name != active_player: 110 | continue 111 | 112 | # move_ply starts from 0, need to add 1, move will be parsed in order 113 | move_ply = int(row[13]) 114 | move = move_ply // 2 + 1 115 | 116 | game_id = row[0] 117 | cp_loss = row[17] 118 | 119 | if game_id in games: 120 | if cp_loss == str(-1 * np.inf) or cp_loss == str(np.inf) or cp_loss == 'nan': 121 | cp_loss = float(-100) 122 | 123 | # ignore cases like -inf, inf, nan 124 | if cp_loss != str(-1 * np.inf) and cp_loss != str(np.inf) and cp_loss != 'nan': 125 | if game_id not in games: 126 | games[game_id] = [float(cp_loss)] 127 | else: 128 | games[game_id].append(float(cp_loss)) 129 | 130 | final_games = {key: {'start_after': {}, 'stop_after': {}} for key, value in games.items() if len(value) > 25 and len(value) < 50} 131 | # final_games = {key: {'start_after': {}, 'stop_after': {}} for key in games.keys()} 132 | 133 | # get per game histogram 134 | for i in range(101): 135 | for key, value in games.items(): 136 | if len(value) > 25 and len(value) < 50: 137 | if key not in final_games: 138 | print(key) 139 | 140 | final_games[key]['start_after'][i] = get_cp_loss_start_after(value, i) 141 | final_games[key]['stop_after'][i] = get_cp_loss_stop_after(value, i) 142 | 143 | final_games[key]['start_after'][i] = np.histogram(final_games[key]['start_after'][i], density=False, bins=50, range=(0, 5))[0] 144 | final_games[key]['stop_after'][i] = np.histogram(final_games[key]['stop_after'][i], density=False, bins=50, range=(0, 5))[0] 145 | 146 | # final_games[key]['start_after'][i] = normalize(final_games[key]['start_after'][i]) 147 | # final_games[key]['stop_after'][i] = normalize(final_games[key]['stop_after'][i]) 148 | 149 | print("number of games: {}".format(len(games))) 150 | return final_games, len(games) 151 | 152 | 153 | def get_player_names(player_name_dir): 154 | 155 | player_names = [] 156 | for player_name in os.listdir(player_name_dir): 157 | player = player_name.replace("_unfrozen_copy", "") 158 | player_names.append(player) 159 | 160 | # print(player_names) 161 | return player_names 162 | 163 | if __name__ == '__main__': 164 | args = parse_argument() 165 | 166 | player_names = get_player_names(args.player_name_dir) 167 | 168 | construct_datasets(player_names, args.input_dir, args.saved_dir, args.will_save) 169 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/results/games_accuracy.csv: 
-------------------------------------------------------------------------------- 1 | num_games,accuracy 2 | 1,0.05579916684937514 3 | 2,0.06811689738519007 4 | 4,0.08245149911816578 5 | 8,0.11002661934338953 6 | 16,0.1223021582733813 7 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/results/start_after.csv: -------------------------------------------------------------------------------- 1 | move,accuracy 2 | 0,0.059309021113243765 3 | 1,0.05508637236084453 4 | 2,0.053358925143953934 5 | 3,0.052207293666026874 6 | 4,0.050287907869481764 7 | 5,0.054318618042226485 8 | 6,0.05182341650671785 9 | 7,0.05105566218809981 10 | 8,0.05259117082533589 11 | 9,0.05356114417354579 12 | 10,0.05586484929928969 13 | 11,0.05682472643501632 14 | 12,0.05779569892473118 15 | 13,0.05915114269252929 16 | 14,0.05859750240153699 17 | 15,0.05573707476455891 18 | 16,0.05714835482008851 19 | 17,0.05772200772200772 20 | 18,0.05515773175924134 21 | 19,0.05539358600583091 22 | 20,0.05190989226248776 23 | 21,0.05554457402648745 24 | 22,0.05923836389280677 25 | 23,0.06203627370156636 26 | 24,0.0579647917561185 27 | 25,0.053575482406356414 28 | 26,0.053108174253548704 29 | 27,0.05705944798301486 30 | 28,0.061177152797912436 31 | 29,0.048 32 | 30,0.05115452930728242 33 | 31,0.045636509207365894 34 | 32,0.040467625899280574 35 | 33,0.04314329738058552 36 | 34,0.03882915173237754 37 | 35,0.03529411764705882 38 | 36,0.0575831305758313 39 | 37,0.04664723032069971 40 | 38,0.05052878965922444 41 | 39,0.04850213980028531 42 | 40,0.06204379562043796 43 | 41,0.06697459584295612 44 | 42,0.06382978723404255 45 | 43,0.05761316872427984 46 | 44,0.08928571428571429 47 | 45,0.13274336283185842 48 | 46,0.1506849315068493 49 | 47,0.2571428571428571 50 | 48,0.3076923076923077 51 | 49,0 52 | 50,0 53 | 51,0 54 | 52,0 55 | 53,0 56 | 54,0 57 | 55,0 58 | 56,0 59 | 57,0 60 | 58,0 61 | 59,0 62 | 60,0 63 | 61,0 64 | 62,0 65 | 63,0 66 | 64,0 67 | 65,0 68 | 66,0 69 | 67,0 70 | 68,0 71 | 69,0 72 | 70,0 73 | 71,0 74 | 72,0 75 | 73,0 76 | 74,0 77 | 75,0 78 | 76,0 79 | 77,0 80 | 78,0 81 | 79,0 82 | 80,0 83 | 81,0 84 | 82,0 85 | 83,0 86 | 84,0 87 | 85,0 88 | 86,0 89 | 87,0 90 | 88,0 91 | 89,0 92 | 90,0 93 | 91,0 94 | 92,0 95 | 93,0 96 | 94,0 97 | 95,0 98 | 96,0 99 | 97,0 100 | 98,0 101 | 99,0 102 | 100,0 103 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/results/start_after_all_game.csv: -------------------------------------------------------------------------------- 1 | move,accuracy 2 | 0,0.6333333333333333 3 | 1,0.6666666666666666 4 | 2,0.5666666666666667 5 | 3,0.5666666666666667 6 | 4,0.5333333333333333 7 | 5,0.5333333333333333 8 | 6,0.5666666666666667 9 | 7,0.5333333333333333 10 | 8,0.5 11 | 9,0.5666666666666667 12 | 10,0.5 13 | 11,0.4666666666666667 14 | 12,0.4666666666666667 15 | 13,0.5 16 | 14,0.5 17 | 15,0.43333333333333335 18 | 16,0.3 19 | 17,0.23333333333333334 20 | 18,0.3 21 | 19,0.3333333333333333 22 | 20,0.3 23 | 21,0.36666666666666664 24 | 22,0.3333333333333333 25 | 23,0.26666666666666666 26 | 24,0.3 27 | 25,0.23333333333333334 28 | 26,0.26666666666666666 29 | 27,0.36666666666666664 30 | 28,0.26666666666666666 31 | 29,0.3 32 | 30,0.4666666666666667 33 | 31,0.43333333333333335 34 | 32,0.36666666666666664 35 | 33,0.36666666666666664 36 | 34,0.36666666666666664 37 | 35,0.3333333333333333 38 | 36,0.4 39 | 37,0.36666666666666664 40 | 38,0.4 41 | 39,0.4 42 | 40,0.4 43 | 41,0.23333333333333334 44 | 42,0.3333333333333333 45 | 
43,0.2857142857142857 46 | 44,0.3333333333333333 47 | 45,0.2222222222222222 48 | 46,0.28 49 | 47,0.5 50 | 48,0.3333333333333333 51 | 49,0 52 | 50,0 53 | 51,0 54 | 52,0 55 | 53,0 56 | 54,0 57 | 55,0 58 | 56,0 59 | 57,0 60 | 58,0 61 | 59,0 62 | 60,0 63 | 61,0 64 | 62,0 65 | 63,0 66 | 64,0 67 | 65,0 68 | 66,0 69 | 67,0 70 | 68,0 71 | 69,0 72 | 70,0 73 | 71,0 74 | 72,0 75 | 73,0 76 | 74,0 77 | 75,0 78 | 76,0 79 | 77,0 80 | 78,0 81 | 79,0 82 | 80,0 83 | 81,0 84 | 82,0 85 | 83,0 86 | 84,0 87 | 85,0 88 | 86,0 89 | 87,0 90 | 88,0 91 | 89,0 92 | 90,0 93 | 91,0 94 | 92,0 95 | 93,0 96 | 94,0 97 | 95,0 98 | 96,0 99 | 97,0 100 | 98,0 101 | 99,0 102 | 100,0 103 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/results/stop_after.csv: -------------------------------------------------------------------------------- 1 | move,accuracy 2 | 0,0.01689059500959693 3 | 1,0.10134357005758157 4 | 2,0.07044145873320537 5 | 3,0.06967370441458733 6 | 4,0.0783109404990403 7 | 5,0.06756238003838771 8 | 6,0.06641074856046066 9 | 7,0.06602687140115163 10 | 8,0.06506717850287908 11 | 9,0.06641074856046066 12 | 10,0.06238003838771593 13 | 11,0.05911708253358925 14 | 12,0.0581573896353167 15 | 13,0.059309021113243765 16 | 14,0.0600767754318618 17 | 15,0.06065259117082534 18 | 16,0.06218809980806142 19 | 17,0.06065259117082534 20 | 18,0.061420345489443376 21 | 19,0.061036468330134354 22 | 20,0.060460652591170824 23 | 21,0.061420345489443376 24 | 22,0.060844529750479846 25 | 23,0.06161228406909789 26 | 24,0.0600767754318618 27 | 25,0.05969289827255278 28 | 26,0.060844529750479846 29 | 27,0.06238003838771593 30 | 28,0.05969289827255278 31 | 29,0.05969289827255278 32 | 30,0.06333973128598848 33 | 31,0.061036468330134354 34 | 32,0.061036468330134354 35 | 33,0.06161228406909789 36 | 34,0.060844529750479846 37 | 35,0.05854126679462572 38 | 36,0.05969289827255278 39 | 37,0.0600767754318618 40 | 38,0.059309021113243765 41 | 39,0.05892514395393474 42 | 40,0.059884836852207295 43 | 41,0.05892514395393474 44 | 42,0.05911708253358925 45 | 43,0.05873320537428023 46 | 44,0.05873320537428023 47 | 45,0.05873320537428023 48 | 46,0.05854126679462572 49 | 47,0.0581573896353167 50 | 48,0.05873320537428023 51 | 49,0.059309021113243765 52 | 50,0.059309021113243765 53 | 51,0.059309021113243765 54 | 52,0.059309021113243765 55 | 53,0.059309021113243765 56 | 54,0.059309021113243765 57 | 55,0.059309021113243765 58 | 56,0.059309021113243765 59 | 57,0.059309021113243765 60 | 58,0.059309021113243765 61 | 59,0.059309021113243765 62 | 60,0.059309021113243765 63 | 61,0.059309021113243765 64 | 62,0.059309021113243765 65 | 63,0.059309021113243765 66 | 64,0.059309021113243765 67 | 65,0.059309021113243765 68 | 66,0.059309021113243765 69 | 67,0.059309021113243765 70 | 68,0.059309021113243765 71 | 69,0.059309021113243765 72 | 70,0.059309021113243765 73 | 71,0.059309021113243765 74 | 72,0.059309021113243765 75 | 73,0.059309021113243765 76 | 74,0.059309021113243765 77 | 75,0.059309021113243765 78 | 76,0.059309021113243765 79 | 77,0.059309021113243765 80 | 78,0.059309021113243765 81 | 79,0.059309021113243765 82 | 80,0.059309021113243765 83 | 81,0.059309021113243765 84 | 82,0.059309021113243765 85 | 83,0.059309021113243765 86 | 84,0.059309021113243765 87 | 85,0.059309021113243765 88 | 86,0.059309021113243765 89 | 87,0.059309021113243765 90 | 88,0.059309021113243765 91 | 89,0.059309021113243765 92 | 90,0.059309021113243765 93 | 91,0.059309021113243765 94 | 92,0.059309021113243765 95 | 93,0.059309021113243765 
96 | 94,0.059309021113243765 97 | 95,0.059309021113243765 98 | 96,0.059309021113243765 99 | 97,0.059309021113243765 100 | 98,0.059309021113243765 101 | 99,0.059309021113243765 102 | 100,0.059309021113243765 103 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/results/stop_after_all_game.csv: -------------------------------------------------------------------------------- 1 | move,accuracy 2 | 0,0.03333333333333333 3 | 1,0.16666666666666666 4 | 2,0.5333333333333333 5 | 3,0.6333333333333333 6 | 4,0.5 7 | 5,0.5 8 | 6,0.5 9 | 7,0.43333333333333335 10 | 8,0.43333333333333335 11 | 9,0.43333333333333335 12 | 10,0.5 13 | 11,0.43333333333333335 14 | 12,0.4666666666666667 15 | 13,0.5 16 | 14,0.5333333333333333 17 | 15,0.5 18 | 16,0.5333333333333333 19 | 17,0.5666666666666667 20 | 18,0.5333333333333333 21 | 19,0.5666666666666667 22 | 20,0.6 23 | 21,0.5333333333333333 24 | 22,0.5666666666666667 25 | 23,0.5333333333333333 26 | 24,0.6 27 | 25,0.6 28 | 26,0.6 29 | 27,0.6666666666666666 30 | 28,0.6333333333333333 31 | 29,0.6333333333333333 32 | 30,0.6333333333333333 33 | 31,0.6666666666666666 34 | 32,0.6666666666666666 35 | 33,0.6666666666666666 36 | 34,0.6 37 | 35,0.6 38 | 36,0.6 39 | 37,0.6 40 | 38,0.5666666666666667 41 | 39,0.6 42 | 40,0.6 43 | 41,0.6 44 | 42,0.6333333333333333 45 | 43,0.6 46 | 44,0.6333333333333333 47 | 45,0.6333333333333333 48 | 46,0.6333333333333333 49 | 47,0.6333333333333333 50 | 48,0.6333333333333333 51 | 49,0.6333333333333333 52 | 50,0.6333333333333333 53 | 51,0.6333333333333333 54 | 52,0.6333333333333333 55 | 53,0.6333333333333333 56 | 54,0.6333333333333333 57 | 55,0.6333333333333333 58 | 56,0.6333333333333333 59 | 57,0.6333333333333333 60 | 58,0.6333333333333333 61 | 59,0.6333333333333333 62 | 60,0.6333333333333333 63 | 61,0.6333333333333333 64 | 62,0.6333333333333333 65 | 63,0.6333333333333333 66 | 64,0.6333333333333333 67 | 65,0.6333333333333333 68 | 66,0.6333333333333333 69 | 67,0.6333333333333333 70 | 68,0.6333333333333333 71 | 69,0.6333333333333333 72 | 70,0.6333333333333333 73 | 71,0.6333333333333333 74 | 72,0.6333333333333333 75 | 73,0.6333333333333333 76 | 74,0.6333333333333333 77 | 75,0.6333333333333333 78 | 76,0.6333333333333333 79 | 77,0.6333333333333333 80 | 78,0.6333333333333333 81 | 79,0.6333333333333333 82 | 80,0.6333333333333333 83 | 81,0.6333333333333333 84 | 82,0.6333333333333333 85 | 83,0.6333333333333333 86 | 84,0.6333333333333333 87 | 85,0.6333333333333333 88 | 86,0.6333333333333333 89 | 87,0.6333333333333333 90 | 88,0.6333333333333333 91 | 89,0.6333333333333333 92 | 90,0.6333333333333333 93 | 91,0.6333333333333333 94 | 92,0.6333333333333333 95 | 93,0.6333333333333333 96 | 94,0.6333333333333333 97 | 95,0.6333333333333333 98 | 96,0.6333333333333333 99 | 97,0.6333333333333333 100 | 98,0.6333333333333333 101 | 99,0.6333333333333333 102 | 100,0.6333333333333333 103 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/results_validation/games_accuracy.csv: -------------------------------------------------------------------------------- 1 | num_games,accuracy 2 | 1,0.005322294500295683 3 | 2,0.007152042305413879 4 | 4,0.00927616894222686 5 | 8,0.013657853265627736 6 | 16,0.020702709097361882 7 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/results_validation/start_after.csv: -------------------------------------------------------------------------------- 1 | move,accuracy 2 | 
0,0.008001119393492254 3 | 1,0.007988399012898465 4 | 2,0.006487394102831557 5 | 3,0.00449029434960694 6 | 4,0.0042867682601063425 7 | 5,0.004032360648230595 8 | 6,0.004032360648230595 9 | 7,0.004235940620508059 10 | 8,0.004312373586393762 11 | 9,0.004134439242825158 12 | 10,0.004134649636150832 13 | 11,0.0044149829507863 14 | 12,0.004517631488527761 15 | 13,0.004315392840776007 16 | 14,0.004419762835780974 17 | 15,0.004282436910527657 18 | 16,0.004453405132262304 19 | 17,0.0038727983844167794 20 | 18,0.003896336930609315 21 | 19,0.0038616499543038087 22 | 20,0.0037848837962902956 23 | 21,0.0037588076590617386 24 | 22,0.004161569962240068 25 | 23,0.0038119613902767757 26 | 24,0.0037130110684436414 27 | 25,0.003143389199255121 28 | 26,0.002826501275963433 29 | 27,0.002573004599686305 30 | 28,0.0027073019801980196 31 | 29,0.002663825253063399 32 | 30,0.0027465372321534274 33 | 31,0.0026176626123744053 34 | 32,0.0026072529035316427 35 | 33,0.0031696249833177634 36 | 34,0.0031620252200083815 37 | 35,0.0035925520262869663 38 | 36,0.003196184871391609 39 | 37,0.0030963439323567943 40 | 38,0.0032433194669674965 41 | 39,0.004067796610169492 42 | 40,0.003640902943930095 43 | 41,0.003702234563004099 44 | 42,0.004691572545612511 45 | 43,0.004493850520340586 46 | 44,0.0050150451354062184 47 | 45,0.010349926071956629 48 | 46,0.008507347254447023 49 | 47,0.008253094910591471 50 | 48,0.038338658146964855 51 | 49,0 52 | 50,0 53 | 51,0 54 | 52,0 55 | 53,0 56 | 54,0 57 | 55,0 58 | 56,0 59 | 57,0 60 | 58,0 61 | 59,0 62 | 60,0 63 | 61,0 64 | 62,0 65 | 63,0 66 | 64,0 67 | 65,0 68 | 66,0 69 | 67,0 70 | 68,0 71 | 69,0 72 | 70,0 73 | 71,0 74 | 72,0 75 | 73,0 76 | 74,0 77 | 75,0 78 | 76,0 79 | 77,0 80 | 78,0 81 | 79,0 82 | 80,0 83 | 81,0 84 | 82,0 85 | 83,0 86 | 84,0 87 | 85,0 88 | 86,0 89 | 87,0 90 | 88,0 91 | 89,0 92 | 90,0 93 | 91,0 94 | 92,0 95 | 93,0 96 | 94,0 97 | 95,0 98 | 96,0 99 | 97,0 100 | 98,0 101 | 99,0 102 | 100,0 103 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/results_validation/start_after_4games.csv: -------------------------------------------------------------------------------- 1 | move,accuracy 2 | 0,0.013087661671114761 3 | 1,0.01365222746869226 4 | 2,0.011599260932046808 5 | 3,0.00949497023198522 6 | 4,0.008160541983165674 7 | 5,0.007647300349004311 8 | 6,0.008160541983165674 9 | 7,0.0083662680285377 10 | 8,0.00810963403993225 11 | 9,0.00810963403993225 12 | 10,0.00805913454134798 13 | 11,0.00765004877547877 14 | 12,0.008062445437272121 15 | 13,0.008527689304428234 16 | 14,0.008069490131578948 17 | 15,0.007920588386565858 18 | 16,0.007926295743476247 19 | 17,0.006804825239715435 20 | 18,0.006308822008480711 21 | 19,0.005659691572771172 22 | 20,0.0067447453727909655 23 | 21,0.0066029264169880095 24 | 22,0.0075660012878300065 25 | 23,0.006154184295840431 26 | 24,0.006217557469625236 27 | 25,0.005656517029726802 28 | 26,0.004635977799542932 29 | 27,0.0037816625044595075 30 | 28,0.00454331818893937 31 | 29,0.0036316472114137485 32 | 30,0.0045231450293523245 33 | 31,0.004412397761515282 34 | 32,0.0063978754225012075 35 | 33,0.006824075337791729 36 | 34,0.006715602061533656 37 | 35,0.007404731804226115 38 | 36,0.009222385244183609 39 | 37,0.0061942517343904855 40 | 38,0.006505026611472502 41 | 39,0.00972972972972973 42 | 40,0.009379187137114784 43 | 41,0.01678240740740741 44 | 42,0.01564945226917058 45 | 43,0.022598870056497175 46 | 44,0.029209621993127148 47 | 45,0.03932584269662921 48 | 46,0.06349206349206349 49 | 
47,0.0641025641025641 50 | 48,0.375 51 | 49,0 52 | 50,0 53 | 51,0 54 | 52,0 55 | 53,0 56 | 54,0 57 | 55,0 58 | 56,0 59 | 57,0 60 | 58,0 61 | 59,0 62 | 60,0 63 | 61,0 64 | 62,0 65 | 63,0 66 | 64,0 67 | 65,0 68 | 66,0 69 | 67,0 70 | 68,0 71 | 69,0 72 | 70,0 73 | 71,0 74 | 72,0 75 | 73,0 76 | 74,0 77 | 75,0 78 | 76,0 79 | 77,0 80 | 78,0 81 | 79,0 82 | 80,0 83 | 81,0 84 | 82,0 85 | 83,0 86 | 84,0 87 | 85,0 88 | 86,0 89 | 87,0 90 | 88,0 91 | 89,0 92 | 90,0 93 | 91,0 94 | 92,0 95 | 93,0 96 | 94,0 97 | 95,0 98 | 96,0 99 | 97,0 100 | 98,0 101 | 99,0 102 | 100,0 103 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/results_validation/stop_after.csv: -------------------------------------------------------------------------------- 1 | move,accuracy 2 | 0,0.0015773271936296335 3 | 1,0.003510825043885313 4 | 2,0.01217340422825451 5 | 3,0.011651868623909227 6 | 4,0.010201745236217467 7 | 5,0.010367110183936703 8 | 6,0.010061821049685806 9 | 7,0.009985498766123082 10 | 8,0.009743811534841123 11 | 9,0.00961660772890325 12 | 10,0.00927315745287099 13 | 11,0.009285877833464778 14 | 12,0.009196835169308266 15 | 13,0.009311318594652352 16 | 14,0.009362200117027502 17 | 15,0.009234996311089629 18 | 16,0.008878825654463582 19 | 17,0.008853384893276008 20 | 18,0.008993309079807667 21 | 19,0.00884066451268222 22 | 20,0.009018749840995242 23 | 21,0.00927315745287099 24 | 22,0.009069631363370393 25 | 23,0.008967868318620092 26 | 24,0.008777062609713282 27 | 25,0.008916986796244943 28 | 26,0.008866105273869794 29 | 27,0.008878825654463582 30 | 28,0.008853384893276008 31 | 29,0.00884066451268222 32 | 30,0.008929707176838731 33 | 31,0.008904266415651157 34 | 32,0.008891546035057369 35 | 33,0.008929707176838731 36 | 34,0.008815223751494645 37 | 35,0.008649858803775409 38 | 36,0.008700740326150558 39 | 37,0.008713460706744346 40 | 38,0.00854809575902511 41 | 39,0.00830640852774315 42 | 40,0.008433612333681024 43 | 41,0.008420891953087236 44 | 42,0.008319128908336937 45 | 43,0.008344569669524512 46 | 44,0.00820464548299285 47 | 45,0.007975678632304679 48 | 46,0.008090162057648766 49 | 47,0.008039280535273615 50 | 48,0.008001119393492254 51 | 49,0.008001119393492254 52 | 50,0.008001119393492254 53 | 51,0.008001119393492254 54 | 52,0.008001119393492254 55 | 53,0.008001119393492254 56 | 54,0.008001119393492254 57 | 55,0.008001119393492254 58 | 56,0.008001119393492254 59 | 57,0.008001119393492254 60 | 58,0.008001119393492254 61 | 59,0.008001119393492254 62 | 60,0.008001119393492254 63 | 61,0.008001119393492254 64 | 62,0.008001119393492254 65 | 63,0.008001119393492254 66 | 64,0.008001119393492254 67 | 65,0.008001119393492254 68 | 66,0.008001119393492254 69 | 67,0.008001119393492254 70 | 68,0.008001119393492254 71 | 69,0.008001119393492254 72 | 70,0.008001119393492254 73 | 71,0.008001119393492254 74 | 72,0.008001119393492254 75 | 73,0.008001119393492254 76 | 74,0.008001119393492254 77 | 75,0.008001119393492254 78 | 76,0.008001119393492254 79 | 77,0.008001119393492254 80 | 78,0.008001119393492254 81 | 79,0.008001119393492254 82 | 80,0.008001119393492254 83 | 81,0.008001119393492254 84 | 82,0.008001119393492254 85 | 83,0.008001119393492254 86 | 84,0.008001119393492254 87 | 85,0.008001119393492254 88 | 86,0.008001119393492254 89 | 87,0.008001119393492254 90 | 88,0.008001119393492254 91 | 89,0.008001119393492254 92 | 90,0.008001119393492254 93 | 91,0.008001119393492254 94 | 92,0.008001119393492254 95 | 93,0.008001119393492254 96 | 94,0.008001119393492254 97 | 
95,0.008001119393492254 98 | 96,0.008001119393492254 99 | 97,0.008001119393492254 100 | 98,0.008001119393492254 101 | 99,0.008001119393492254 102 | 100,0.008001119393492254 103 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/results_validation/stop_after_4games.csv: -------------------------------------------------------------------------------- 1 | move,accuracy 2 | 0,0.0015910490659002258 3 | 1,0.006056251283104086 4 | 2,0.016115787312666805 5 | 3,0.017193594744405665 6 | 4,0.014884007390679532 7 | 5,0.016013138985834532 8 | 6,0.015499897351673168 9 | 7,0.015448573188257032 10 | 8,0.014986655717511805 11 | 9,0.014678710737014987 12 | 10,0.01591049065900226 13 | 11,0.015294600698008622 14 | 12,0.014524738246766578 15 | 13,0.014730034900431123 16 | 14,0.015037979880927942 17 | 15,0.014730034900431123 18 | 16,0.014832683227263395 19 | 17,0.015345924861424758 20 | 18,0.015089304044344077 21 | 19,0.014114144939437486 22 | 20,0.014422089919934305 23 | 21,0.015037979880927942 24 | 22,0.014678710737014987 25 | 23,0.013960172449189078 26 | 24,0.013857524122356805 27 | 25,0.013395606651611578 28 | 26,0.013087661671114761 29 | 27,0.014165469102853623 30 | 28,0.014576062410182713 31 | 29,0.01365222746869226 32 | 30,0.014319441593102033 33 | 31,0.015037979880927942 34 | 32,0.014576062410182713 35 | 33,0.01478135906384726 36 | 34,0.014422089919934305 37 | 35,0.014576062410182713 38 | 36,0.013960172449189078 39 | 37,0.013754875795524533 40 | 38,0.013600903305276125 41 | 39,0.013908848285772941 42 | 40,0.014319441593102033 43 | 41,0.014114144939437486 44 | 42,0.014268117429685897 45 | 43,0.01365222746869226 46 | 44,0.013395606651611578 47 | 45,0.013292958324779306 48 | 46,0.013190309997947033 49 | 47,0.013292958324779306 50 | 48,0.013190309997947033 51 | 49,0.013087661671114761 52 | 50,0.013087661671114761 53 | 51,0.013087661671114761 54 | 52,0.013087661671114761 55 | 53,0.013087661671114761 56 | 54,0.013087661671114761 57 | 55,0.013087661671114761 58 | 56,0.013087661671114761 59 | 57,0.013087661671114761 60 | 58,0.013087661671114761 61 | 59,0.013087661671114761 62 | 60,0.013087661671114761 63 | 61,0.013087661671114761 64 | 62,0.013087661671114761 65 | 63,0.013087661671114761 66 | 64,0.013087661671114761 67 | 65,0.013087661671114761 68 | 66,0.013087661671114761 69 | 67,0.013087661671114761 70 | 68,0.013087661671114761 71 | 69,0.013087661671114761 72 | 70,0.013087661671114761 73 | 71,0.013087661671114761 74 | 72,0.013087661671114761 75 | 73,0.013087661671114761 76 | 74,0.013087661671114761 77 | 75,0.013087661671114761 78 | 76,0.013087661671114761 79 | 77,0.013087661671114761 80 | 78,0.013087661671114761 81 | 79,0.013087661671114761 82 | 80,0.013087661671114761 83 | 81,0.013087661671114761 84 | 82,0.013087661671114761 85 | 83,0.013087661671114761 86 | 84,0.013087661671114761 87 | 85,0.013087661671114761 88 | 86,0.013087661671114761 89 | 87,0.013087661671114761 90 | 88,0.013087661671114761 91 | 89,0.013087661671114761 92 | 90,0.013087661671114761 93 | 91,0.013087661671114761 94 | 92,0.013087661671114761 95 | 93,0.013087661671114761 96 | 94,0.013087661671114761 97 | 95,0.013087661671114761 98 | 96,0.013087661671114761 99 | 97,0.013087661671114761 100 | 98,0.013087661671114761 101 | 99,0.013087661671114761 102 | 100,0.013087661671114761 103 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/sweep_moves_all_games.py: 
-------------------------------------------------------------------------------- 1 | import bz2 2 | import csv 3 | import argparse 4 | import os 5 | import numpy as np 6 | from sklearn.naive_bayes import GaussianNB 7 | from matplotlib import pyplot as plt 8 | 9 | def parse_argument(): 10 | parser = argparse.ArgumentParser(description='arg parser') 11 | 12 | parser.add_argument('--input_dir', default='cp_loss_hist_per_move') 13 | parser.add_argument('--output_start_after_csv', default='start_after_all_game.csv') 14 | parser.add_argument('--output_stop_after_csv', default='stop_after_all_game.csv') 15 | parser.add_argument('--saved_plot', default='plot_all_game.png') 16 | 17 | return parser.parse_args() 18 | 19 | def read_npy(input_dir): 20 | 21 | player_list = {} 22 | for input_data in os.listdir(input_dir): 23 | # will split into [player_name, 'train/test/val'] 24 | input_name = input_data.split('_') 25 | if len(input_name) > 2: 26 | player_name = input_name[:-1] 27 | player_name = '_'.join(player_name) 28 | else: 29 | player_name = input_name[0] 30 | # add into player list 31 | if player_name not in player_list: 32 | player_list[player_name] = 1 33 | 34 | player_list = list(player_list.keys()) 35 | 36 | player_data = {} 37 | for player_name in player_list: 38 | player_data[player_name] = {'train': None, 'validation': None, 'test': None} 39 | train_path = os.path.join(input_dir, player_name + '_{}.npy'.format('train')) 40 | val_path = os.path.join(input_dir, player_name + '_{}.npy'.format('validation')) 41 | test_path = os.path.join(input_dir, player_name + '_{}.npy'.format('test')) 42 | 43 | player_data[player_name]['train'] = np.load(train_path, allow_pickle=True) 44 | player_data[player_name]['train'] = player_data[player_name]['train'].item() 45 | player_data[player_name]['validation'] = np.load(val_path, allow_pickle=True) 46 | player_data[player_name]['validation'] = player_data[player_name]['validation'].item() 47 | player_data[player_name]['test'] = np.load(test_path, allow_pickle=True) 48 | player_data[player_name]['test'] = player_data[player_name]['test'].item() 49 | 50 | return player_data 51 | 52 | 53 | def construct_train_set(player_data, is_start_after, move_stop): 54 | player_index = {} 55 | train_list = [] 56 | train_label = [] 57 | 58 | i = 0 59 | for player in player_data.keys(): 60 | # if player in os.listdir('/data/csvs'): 61 | player_index[player] = i 62 | train_label.append(i) 63 | if is_start_after: 64 | train_list.append(player_data[player]['train']['start_after'][move_stop]) 65 | else: 66 | train_list.append(player_data[player]['train']['stop_after'][move_stop]) 67 | 68 | i += 1 69 | 70 | train_label = np.asarray(train_label) 71 | # one_hot = np.zeros((train_label.size, train_label.max()+1)) 72 | # one_hot[np.arange(train_label.size),train_label] = 1 73 | # print(one_hot.shape) 74 | 75 | train_data = np.stack(train_list, 0) 76 | return train_data, train_label, player_index 77 | 78 | 79 | def predict(train_data, train_label, player_data, player_index, is_start_after, move_stop): 80 | correct = 0 81 | total = 0 82 | model = GaussianNB() 83 | model.fit(train_data, train_label) 84 | for player in player_data.keys(): 85 | test = player_data[player]['test'] 86 | if is_start_after: 87 | test = test['start_after'][move_stop] 88 | if all(v == 0 for v in test): 89 | continue 90 | else: 91 | test = test['stop_after'][move_stop] 92 | 93 | predicted = model.predict(np.expand_dims(test, axis=0)) 94 | index = predicted[0] 95 | if index == player_index[player]: 96 | correct += 1 
97 | total += 1 98 | 99 | if total == 0: 100 | accuracy = 0 101 | else: 102 | accuracy = correct / total 103 | 104 | print(accuracy) 105 | return accuracy 106 | 107 | 108 | def make_plots(moves, start_after_accuracies, stop_after_accuracies, plot_name): 109 | plt.plot(moves, start_after_accuracies, label="Start after x moves") 110 | plt.plot(moves, stop_after_accuracies, label="Stop after x moves") 111 | plt.legend() 112 | plt.xlabel("Moves") 113 | plt.savefig(plot_name) 114 | 115 | 116 | if __name__ == '__main__': 117 | args = parse_argument() 118 | 119 | player_data = read_npy(args.input_dir) 120 | moves = [i for i in range(101)] 121 | start_after_accuracies = [] 122 | stop_after_accuracies = [] 123 | output_start_csv = open(args.output_start_after_csv, 'w', newline='') 124 | writer_start = csv.writer(output_start_csv) 125 | writer_start.writerow(['move', 'accuracy']) 126 | 127 | output_stop_csv = open(args.output_stop_after_csv, 'w', newline='') 128 | writer_stop = csv.writer(output_stop_csv) 129 | writer_stop.writerow(['move', 'accuracy']) 130 | 131 | for is_start_after in (True, False): 132 | for i in range(101): 133 | print('testing {} move {}'.format('start_after' if is_start_after else 'stop_after', i)) 134 | train_data, train_label, player_index = construct_train_set(player_data, is_start_after, i) 135 | 136 | accuracy = predict(train_data, train_label, player_data, player_index, is_start_after, i) 137 | 138 | if is_start_after: 139 | start_after_accuracies.append(accuracy) 140 | writer_start.writerow([i, accuracy]) 141 | else: 142 | stop_after_accuracies.append(accuracy) 143 | writer_stop.writerow([i, accuracy]) 144 | 145 | make_plots(moves, start_after_accuracies, stop_after_accuracies, args.saved_plot) 146 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/sweep_moves_num_games.py: -------------------------------------------------------------------------------- 1 | import bz2 2 | import csv 3 | import argparse 4 | import os 5 | import numpy as np 6 | from sklearn.naive_bayes import GaussianNB 7 | from matplotlib import pyplot as plt 8 | 9 | def parse_argument(): 10 | parser = argparse.ArgumentParser(description='arg parser') 11 | 12 | parser.add_argument('--train_dir', default='cp_loss_hist_per_move') 13 | parser.add_argument('--input_dir', default='cp_loss_hist_per_move_per_game_count') 14 | parser.add_argument('--output_start_after_csv', default='start_after_4games.csv') 15 | parser.add_argument('--output_stop_after_csv', default='stop_after_4games.csv') 16 | parser.add_argument('--num_games', default=4) 17 | parser.add_argument('--saved_plot', default='plot_4games.png') 18 | 19 | return parser.parse_args() 20 | 21 | def read_npy(train_dir, input_dir): 22 | 23 | player_list = {} 24 | for input_data in os.listdir(input_dir): 25 | # will split into [player_name, 'train/test/val'] 26 | input_name = input_data.split('_') 27 | if len(input_name) > 2: 28 | player_name = input_name[:-1] 29 | player_name = '_'.join(player_name) 30 | else: 31 | player_name = input_name[0] 32 | # add into player list 33 | if player_name not in player_list: 34 | player_list[player_name] = 1 35 | 36 | player_list = list(player_list.keys()) 37 | 38 | player_data = {} 39 | for player_name in player_list: 40 | player_data[player_name] = {'train': None, 'validation': None, 'test': None} 41 | train_path = os.path.join(train_dir, player_name + '_{}.npy'.format('train')) 42 | val_path = os.path.join(input_dir, player_name + 
'_{}.npy'.format('validation')) 43 | test_path = os.path.join(input_dir, player_name + '_{}.npy'.format('test')) 44 | 45 | player_data[player_name]['train'] = np.load(train_path, allow_pickle=True) 46 | player_data[player_name]['train'] = player_data[player_name]['train'].item() 47 | player_data[player_name]['validation'] = np.load(val_path, allow_pickle=True) 48 | player_data[player_name]['validation'] = player_data[player_name]['validation'].item() 49 | player_data[player_name]['test'] = np.load(test_path, allow_pickle=True) 50 | player_data[player_name]['test'] = player_data[player_name]['test'].item() 51 | 52 | return player_data 53 | 54 | def normalize(data): 55 | data_norm = data 56 | 57 | if any(v != 0 for v in data): 58 | norm = np.linalg.norm(data) 59 | data_norm = data/norm 60 | 61 | return data_norm 62 | 63 | def construct_train_set(player_data, is_start_after, move_stop): 64 | player_index = {} 65 | train_list = [] 66 | train_label = [] 67 | 68 | i = 0 69 | for player in player_data.keys(): 70 | # if player in os.listdir('/data/csvs'): 71 | player_index[player] = i 72 | train_label.append(i) 73 | if is_start_after: 74 | train_list.append(player_data[player]['train']['start_after'][move_stop]) 75 | else: 76 | train_list.append(player_data[player]['train']['stop_after'][move_stop]) 77 | 78 | i += 1 79 | 80 | train_label = np.asarray(train_label) 81 | # one_hot = np.zeros((train_label.size, train_label.max()+1)) 82 | # one_hot[np.arange(train_label.size),train_label] = 1 83 | # print(one_hot.shape) 84 | 85 | train_data = np.stack(train_list, 0) 86 | return train_data, train_label, player_index 87 | 88 | 89 | def predict(train_data, train_label, player_data, player_index, is_start_after, move_stop, num_games): 90 | accurcies = [] 91 | correct = 0 92 | total = 0 93 | model = GaussianNB() 94 | model.fit(train_data, train_label) 95 | results = None 96 | for player in player_data.keys(): 97 | test_game = None 98 | tmp_game = None 99 | test_games = [] 100 | test = player_data[player]['test'] 101 | count = 1 102 | 103 | # key is game id 104 | for key, value in test.items(): 105 | # get which game to use 106 | if is_start_after: 107 | tmp_game = test[key]['start_after'][move_stop] 108 | # ignore all 0 cases, essentially there's no more move in this game 109 | if all(v == 0 for v in tmp_game): 110 | continue 111 | else: 112 | tmp_game = test[key]['stop_after'][move_stop] 113 | 114 | # add up counts in each game 115 | if test_game is None: 116 | test_game = tmp_game 117 | else: 118 | test_game = test_game + tmp_game 119 | 120 | if count == num_games: 121 | # test_game is addition of counts, need to normalize before testing 122 | test_game = normalize(test_game) 123 | test_games.append(test_game) 124 | 125 | # reset 126 | test_game = None 127 | tmp_game = None 128 | count = 1 129 | 130 | else: 131 | count += 1 132 | 133 | # skip player if all games are beyond move_stop 134 | if not test_games: 135 | continue 136 | 137 | test_games = np.stack(test_games, axis=0) 138 | predicted = model.predict(test_games) 139 | result = (predicted == player_index[player]).astype(float) 140 | 141 | # append to the overall result 142 | if results is None: 143 | results = result 144 | else: 145 | results = np.append(results, result, 0) 146 | 147 | if results is None: 148 | accuracy = 0 149 | 150 | else: 151 | accuracy = np.mean(results) 152 | 153 | print(accuracy) 154 | 155 | return accuracy 156 | 157 | 158 | def make_plots(moves, start_after_accuracies, stop_after_accuracies, plot_name): 159 | plt.plot(moves, 
start_after_accuracies, label="Start after x moves") 160 | plt.plot(moves, stop_after_accuracies, label="Stop after x moves") 161 | plt.legend() 162 | plt.xlabel("Moves") 163 | plt.savefig(plot_name) 164 | 165 | 166 | if __name__ == '__main__': 167 | args = parse_argument() 168 | 169 | player_data = read_npy(args.train_dir, args.input_dir) 170 | moves = [i for i in range(101)] 171 | start_after_accuracies = [] 172 | stop_after_accuracies = [] 173 | output_start_csv = open(args.output_start_after_csv, 'w', newline='') 174 | writer_start = csv.writer(output_start_csv) 175 | writer_start.writerow(['move', 'accuracy']) 176 | 177 | output_stop_csv = open(args.output_stop_after_csv, 'w', newline='') 178 | writer_stop = csv.writer(output_stop_csv) 179 | writer_stop.writerow(['move', 'accuracy']) 180 | 181 | for is_start_after in (True, False): 182 | for i in range(101): 183 | print('testing {} move {}'.format('start_after' if is_start_after else 'stop_after', i)) 184 | train_data, train_label, player_index = construct_train_set(player_data, is_start_after, i) 185 | 186 | accuracy = predict(train_data, train_label, player_data, player_index, is_start_after, i, args.num_games) 187 | 188 | if is_start_after: 189 | start_after_accuracies.append(accuracy) 190 | writer_start.writerow([i, accuracy]) 191 | else: 192 | stop_after_accuracies.append(accuracy) 193 | writer_stop.writerow([i, accuracy]) 194 | 195 | make_plots(moves, start_after_accuracies, stop_after_accuracies, args.saved_plot) 196 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/sweep_moves_per_game.py: -------------------------------------------------------------------------------- 1 | import bz2 2 | import csv 3 | import argparse 4 | import os 5 | import numpy as np 6 | from sklearn.naive_bayes import GaussianNB 7 | from matplotlib import pyplot as plt 8 | 9 | def parse_argument(): 10 | parser = argparse.ArgumentParser(description='arg parser') 11 | 12 | parser.add_argument('--train_dir', default='cp_loss_hist_per_move') 13 | parser.add_argument('--input_dir', default='cp_loss_hist_per_move_per_game') 14 | parser.add_argument('--output_start_after_csv', default='start_after.csv') 15 | parser.add_argument('--output_stop_after_csv', default='stop_after.csv') 16 | parser.add_argument('--saved_plot', default='plot.png') 17 | 18 | return parser.parse_args() 19 | 20 | def read_npy(train_dir, input_dir): 21 | 22 | player_list = {} 23 | for input_data in os.listdir(input_dir): 24 | # will split into [player_name, 'train/test/val'] 25 | input_name = input_data.split('_') 26 | if len(input_name) > 2: 27 | player_name = input_name[:-1] 28 | player_name = '_'.join(player_name) 29 | else: 30 | player_name = input_name[0] 31 | # add into player list 32 | if player_name not in player_list: 33 | player_list[player_name] = 1 34 | 35 | player_list = list(player_list.keys()) 36 | 37 | player_data = {} 38 | for player_name in player_list: 39 | player_data[player_name] = {'train': None, 'validation': None, 'test': None} 40 | train_path = os.path.join(train_dir, player_name + '_{}.npy'.format('train')) 41 | val_path = os.path.join(input_dir, player_name + '_{}.npy'.format('validation')) 42 | test_path = os.path.join(input_dir, player_name + '_{}.npy'.format('test')) 43 | 44 | player_data[player_name]['train'] = np.load(train_path, allow_pickle=True) 45 | player_data[player_name]['train'] = player_data[player_name]['train'].item() 46 | player_data[player_name]['validation'] = np.load(val_path, 
allow_pickle=True) 47 | player_data[player_name]['validation'] = player_data[player_name]['validation'].item() 48 | player_data[player_name]['test'] = np.load(test_path, allow_pickle=True) 49 | player_data[player_name]['test'] = player_data[player_name]['test'].item() 50 | 51 | return player_data 52 | 53 | 54 | def construct_train_set(player_data, is_start_after, move_stop): 55 | player_index = {} 56 | train_list = [] 57 | train_label = [] 58 | 59 | i = 0 60 | for player in player_data.keys(): 61 | # if player in os.listdir('/data/csvs'): 62 | player_index[player] = i 63 | train_label.append(i) 64 | if is_start_after: 65 | train_list.append(player_data[player]['train']['start_after'][move_stop]) 66 | else: 67 | train_list.append(player_data[player]['train']['stop_after'][move_stop]) 68 | 69 | i += 1 70 | 71 | train_label = np.asarray(train_label) 72 | # one_hot = np.zeros((train_label.size, train_label.max()+1)) 73 | # one_hot[np.arange(train_label.size),train_label] = 1 74 | # print(one_hot.shape) 75 | 76 | train_data = np.stack(train_list, 0) 77 | return train_data, train_label, player_index 78 | 79 | 80 | def predict(train_data, train_label, player_data, player_index, is_start_after, move_stop): 81 | accurcies = [] 82 | correct = 0 83 | total = 0 84 | model = GaussianNB() 85 | model.fit(train_data, train_label) 86 | results = None 87 | for player in player_data.keys(): 88 | test_game = None 89 | test_games = [] 90 | test = player_data[player]['test'] 91 | 92 | # key is game id 93 | for key, value in test.items(): 94 | if is_start_after: 95 | test_game = test[key]['start_after'][move_stop] 96 | # ignore all 0 cases, essentially there's no more move in this game 97 | if all(v == 0 for v in test_game): 98 | continue 99 | else: 100 | test_game = test[key]['stop_after'][move_stop] 101 | 102 | test_games.append(test_game) 103 | 104 | # skip player if all games are beyond move_stop 105 | if not test_games: 106 | continue 107 | 108 | test_games = np.stack(test_games, axis=0) 109 | predicted = model.predict(test_games) 110 | result = (predicted == player_index[player]).astype(float) 111 | 112 | # append to the overall result 113 | if results is None: 114 | results = result 115 | else: 116 | results = np.append(results, result, 0) 117 | 118 | if results is None: 119 | accuracy = 0 120 | 121 | else: 122 | accuracy = np.mean(results) 123 | 124 | print(accuracy) 125 | 126 | return accuracy 127 | 128 | 129 | def make_plots(moves, start_after_accuracies, stop_after_accuracies, plot_name): 130 | plt.plot(moves, start_after_accuracies, label="Start after x moves") 131 | plt.plot(moves, stop_after_accuracies, label="Stop after x moves") 132 | plt.legend() 133 | plt.xlabel("Moves") 134 | plt.savefig(plot_name) 135 | 136 | 137 | if __name__ == '__main__': 138 | args = parse_argument() 139 | 140 | player_data = read_npy(args.train_dir, args.input_dir) 141 | moves = [i for i in range(101)] 142 | start_after_accuracies = [] 143 | stop_after_accuracies = [] 144 | output_start_csv = open(args.output_start_after_csv, 'w', newline='') 145 | writer_start = csv.writer(output_start_csv) 146 | writer_start.writerow(['move', 'accuracy']) 147 | 148 | output_stop_csv = open(args.output_stop_after_csv, 'w', newline='') 149 | writer_stop = csv.writer(output_stop_csv) 150 | writer_stop.writerow(['move', 'accuracy']) 151 | 152 | for is_start_after in (True, False): 153 | for i in range(101): 154 | print('testing {} move {}'.format('start_after' if is_start_after else 'stop_after', i)) 155 | train_data, train_label, 
player_index = construct_train_set(player_data, is_start_after, i) 156 | 157 | accuracy = predict(train_data, train_label, player_data, player_index, is_start_after, i) 158 | 159 | if is_start_after: 160 | start_after_accuracies.append(accuracy) 161 | writer_start.writerow([i, accuracy]) 162 | else: 163 | stop_after_accuracies.append(accuracy) 164 | writer_stop.writerow([i, accuracy]) 165 | 166 | make_plots(moves, start_after_accuracies, stop_after_accuracies, args.saved_plot) 167 | -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/test_all_games.py: -------------------------------------------------------------------------------- 1 | import bz2 2 | import csv 3 | import argparse 4 | import os 5 | import numpy as np 6 | from sklearn.naive_bayes import GaussianNB 7 | 8 | def parse_argument(): 9 | parser = argparse.ArgumentParser(description='arg parser') 10 | 11 | parser.add_argument('--input_dir', default='cp_loss_hist') 12 | parser.add_argument('--use_bayes', default=True) 13 | 14 | return parser.parse_args() 15 | 16 | def read_npy(input_dir): 17 | 18 | player_list = {} 19 | for input_data in os.listdir(input_dir): 20 | # will split into [player_name, 'train/test/val'] 21 | input_name = input_data.split('_') 22 | if len(input_name) > 2: 23 | player_name = input_name[:-1] 24 | player_name = '_'.join(player_name) 25 | else: 26 | player_name = input_name[0] 27 | # add into player list 28 | if player_name not in player_list: 29 | player_list[player_name] = 1 30 | 31 | player_list = list(player_list.keys()) 32 | 33 | player_data = {} 34 | for player_name in player_list: 35 | player_data[player_name] = {'train': None, 'validation': None, 'test': None} 36 | train_path = os.path.join(input_dir, player_name + '_{}.npy'.format('train')) 37 | val_path = os.path.join(input_dir, player_name + '_{}.npy'.format('validation')) 38 | test_path = os.path.join(input_dir, player_name + '_{}.npy'.format('test')) 39 | 40 | player_data[player_name]['train'] = np.load(train_path) 41 | player_data[player_name]['validation'] = np.load(val_path) 42 | player_data[player_name]['test'] = np.load(test_path) 43 | 44 | return player_data 45 | 46 | # =============================== Naive Bayes =============================== 47 | def construct_train_set(player_data): 48 | player_index = {} 49 | train_list = [] 50 | train_label = [] 51 | 52 | i = 0 53 | for player in player_data.keys(): 54 | player_index[player] = i 55 | train_label.append(i) 56 | train_list.append(player_data[player]['train']) 57 | i += 1 58 | 59 | train_label = np.asarray(train_label) 60 | 61 | train_data = np.stack(train_list, 0) 62 | return train_data, train_label, player_index 63 | 64 | def predict(train_data, train_label, player_data, player_index): 65 | print(player_index) 66 | correct = 0 67 | total = 0 68 | model = GaussianNB() 69 | model.fit(train_data, train_label) 70 | 71 | for player in player_data.keys(): 72 | test = player_data[player]['test'] 73 | predicted = model.predict(np.expand_dims(test, axis=0)) 74 | index = predicted[0] 75 | if index == player_index[player]: 76 | correct += 1 77 | total += 1 78 | 79 | print('accuracy is {}'.format(correct / total)) 80 | 81 | # =============================== Euclidean Distance =============================== 82 | def construct_train_list(player_data): 83 | # player_index is {player_name: id} mapping 84 | player_index = {} 85 | train_list = [] 86 | i = 0 87 | for player in player_data.keys(): 88 | player_index[player] = i 89 | 
train_list.append(player_data[player]['train']) 90 | i += 1 91 | 92 | return train_list, player_index 93 | 94 | def test_euclidean_dist(train_list, player_data, player_index): 95 | print(player_index) 96 | correct = 0 97 | total = 0 98 | # loop through each player and test their 'test set' 99 | for player in player_data.keys(): 100 | dist_list = [] 101 | test = player_data[player]['test'] 102 | 103 | # save distance for each (test, train) 104 | for train_data in train_list: 105 | dist = np.linalg.norm(train_data - test) 106 | dist_list.append(dist) 107 | 108 | # find minimum distance and its index 109 | min_index = dist_list.index(min(dist_list)) 110 | if min_index == player_index[player]: 111 | correct += 1 112 | total += 1 113 | 114 | print('accuracy is {}'.format(correct / total)) 115 | 116 | # =============================== run bayes or euclidean =============================== 117 | def run_bayes(player_data): 118 | print("Using Naive Bayes") 119 | train_data, train_label, player_index = construct_train_set(player_data) 120 | predict(train_data, train_label, player_data, player_index) 121 | 122 | def run_euclidean_dist(player_data): 123 | print("Using Euclidean Distance") 124 | train_list, player_index = construct_train_list(player_data) 125 | test_euclidean_dist(train_list, player_data, player_index) 126 | 127 | 128 | if __name__ == '__main__': 129 | args = parse_argument() 130 | 131 | player_data = read_npy(args.input_dir) 132 | 133 | if args.use_bayes: 134 | run_bayes(player_data) 135 | else: 136 | run_euclidean_dist(player_data) -------------------------------------------------------------------------------- /4-cp_loss_stylo_baseline/train_cploss_per_game.py: -------------------------------------------------------------------------------- 1 | import bz2 2 | import csv 3 | import argparse 4 | import os 5 | import numpy as np 6 | import tensorflow as tf 7 | from sklearn.naive_bayes import GaussianNB 8 | 9 | def parse_argument(): 10 | parser = argparse.ArgumentParser(description='arg parser') 11 | 12 | parser.add_argument('--input_dir', default='cp_loss_count_per_game') 13 | parser.add_argument('--gpu', default=0, type=int) 14 | 15 | return parser.parse_args() 16 | 17 | def normalize(data): 18 | norm = np.linalg.norm(data) 19 | data_norm = data/norm 20 | return data_norm 21 | 22 | def read_npy(input_dir): 23 | 24 | player_list = {} 25 | for input_data in os.listdir(input_dir): 26 | # will split into [player_name, 'train/test/val'] 27 | input_name = input_data.split('_') 28 | if len(input_name) > 2: 29 | player_name = input_name[:-1] 30 | player_name = '_'.join(player_name) 31 | else: 32 | player_name = input_name[0] 33 | # add into player list 34 | if player_name not in player_list: 35 | player_list[player_name] = 1 36 | 37 | player_list = list(player_list.keys()) 38 | 39 | player_data = {} 40 | for player_name in player_list: 41 | player_data[player_name] = {'train': None, 'validation': None, 'test': None} 42 | train_path = os.path.join(input_dir, player_name + '_{}.npy'.format('train')) 43 | val_path = os.path.join(input_dir, player_name + '_{}.npy'.format('validation')) 44 | test_path = os.path.join(input_dir, player_name + '_{}.npy'.format('test')) 45 | 46 | player_data[player_name]['train'] = np.load(train_path, allow_pickle=True) 47 | player_data[player_name]['train'] = player_data[player_name]['train'].item() 48 | player_data[player_name]['validation'] = np.load(val_path, allow_pickle=True) 49 | player_data[player_name]['validation'] = 
player_data[player_name]['validation'].item() 50 | player_data[player_name]['test'] = np.load(test_path, allow_pickle=True) 51 | player_data[player_name]['test'] = player_data[player_name]['test'].item() 52 | 53 | return player_data 54 | 55 | def construct_datasets(player_data): 56 | player_index = {} 57 | train_list = [] 58 | train_labels = [] 59 | validation_list = [] 60 | validation_labels = [] 61 | test_list = [] 62 | test_labels = [] 63 | i = 0 64 | for player in player_data.keys(): 65 | label = i 66 | player_index[player] = i 67 | for key, value in player_data[player]['train'].items(): 68 | train_list.append(normalize(value)) 69 | train_labels.append(label) 70 | 71 | for key, value in player_data[player]['validation'].items(): 72 | validation_list.append(normalize(value)) 73 | validation_labels.append(label) 74 | 75 | for key, value in player_data[player]['test'].items(): 76 | test_list.append(normalize(value)) 77 | test_labels.append(label) 78 | 79 | i += 1 80 | # convert lists into numpy arrays 81 | train_list_np = np.stack(train_list, axis=0) 82 | validation_list_np = np.stack(validation_list, axis=0) 83 | test_list_np = np.stack(test_list, axis=0) 84 | 85 | train_labels_np = np.stack(train_labels, axis=0) 86 | validation_labels_np = np.stack(validation_labels, axis=0) 87 | test_labels_np = np.stack(test_labels, axis=0) 88 | 89 | return train_list_np, train_labels_np, validation_list_np, validation_labels_np, test_list_np, test_labels_np, player_index 90 | 91 | 92 | def init_net(output_size): 93 | l2reg = tf.keras.regularizers.l2(l=0.5 * (0.0001)) 94 | input_var = tf.keras.Input(shape=(50, )) 95 | dense_1 = tf.keras.layers.Dense(40, kernel_initializer='glorot_normal', kernel_regularizer=l2reg, bias_regularizer=l2reg, activation='relu')(input_var) 96 | dense_2 = tf.keras.layers.Dense(output_size, kernel_initializer='glorot_normal', kernel_regularizer=l2reg, bias_regularizer=l2reg)(dense_1) # logits layer sized to the number of players 97 | 98 | model = tf.keras.Model(inputs=input_var, outputs=dense_2) 99 | return model 100 | 101 | def train(train_dataset, train_labels, val_dataset, val_labels, test_dataset, test_labels, player_index): 102 | net = init_net(max(test_labels) + 1) 103 | net.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001, clipnorm=1), 104 | loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), 105 | metrics=['accuracy']) 106 | 107 | net.fit(train_dataset, train_labels, batch_size=32, epochs=10, validation_data=(val_dataset, val_labels)) 108 | 109 | test_loss, test_acc = net.evaluate(test_dataset, test_labels, verbose=2) 110 | 111 | print('\nTest accuracy:', test_acc) 112 | 113 | return net 114 | 115 | # predict is to verify if keras test is correct 116 | def predict(net, test, test_labels): 117 | probability_model = tf.keras.Sequential([net, 118 | tf.keras.layers.Softmax()]) 119 | predictions = probability_model.predict(test) 120 | 121 | correct = 0 122 | total = 0 123 | for i, prediction in enumerate(predictions): 124 | if test_labels[i] == np.argmax(prediction): 125 | correct += 1 126 | total += 1 127 | 128 | print('test accuracy is: {}'.format(correct / total)) 129 | 130 | if __name__ == '__main__': 131 | args = parse_argument() 132 | 133 | gpus = tf.config.experimental.list_physical_devices('GPU') 134 | tf.config.experimental.set_visible_devices(gpus[args.gpu], 'GPU') 135 | tf.config.experimental.set_memory_growth(gpus[args.gpu], True) 136 | 137 | player_data = read_npy(args.input_dir) 138 | 139 | train_dataset, train_labels, val_dataset, val_labels, test_dataset, test_labels,
player_index = construct_datasets(player_data) 140 | 141 | net = train(train_dataset, train_labels, val_dataset, val_labels, test_dataset, test_labels, player_index) 142 | 143 | # predict is to verify if test is correct 144 | # predict(net, test_dataset, test_labels) 145 | -------------------------------------------------------------------------------- /9-reduced-data/configs/Best_frozen.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | dataset: 4 | name: '' 5 | train_path: '' 6 | validate_path: '' 7 | gpu: 3 8 | model: 9 | back_prop_blocks: 3 10 | filters: 64 11 | keep_weights: true 12 | path: maia/1700 13 | residual_blocks: 6 14 | se_ratio: 8 15 | training: 16 | batch_size: 16 17 | checkpoint_small_steps: 18 | - 100 19 | - 200 20 | - 400 21 | - 800 22 | - 1600 23 | - 2500 24 | checkpoint_steps: 5000 25 | lr_boundaries: 26 | - 50000 27 | - 110000 28 | - 160000 29 | lr_values: 30 | - 1.0e-05 31 | - 1.0e-06 32 | - 1.0e-07 33 | - 1.0e-08 34 | num_batch_splits: 1 35 | policy_loss_weight: 1.0 36 | precision: half 37 | shuffle_size: 256 38 | small_mode: true 39 | test_small_boundaries: 40 | - 20000 41 | - 40000 42 | - 60000 43 | - 80000 44 | - 100000 45 | test_small_steps: 46 | - 100 47 | - 200 48 | - 400 49 | - 800 50 | - 1600 51 | - 2500 52 | test_steps: 2000 53 | total_steps: 200000 54 | train_avg_report_steps: 50 55 | value_loss_weight: 1.0 56 | ... 57 | -------------------------------------------------------------------------------- /9-reduced-data/configs/NFP.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 0 4 | 5 | dataset: 6 | train_path: '' 7 | validate_path: '' 8 | name: '' 9 | 10 | training: 11 | precision: 'half' 12 | batch_size: 256 13 | num_batch_splits: 1 14 | test_steps: 1000 15 | train_avg_report_steps: 50 16 | total_steps: 150000 17 | checkpoint_steps: 5000 18 | shuffle_size: 256 19 | lr_values: 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | - 0.00001 24 | lr_boundaries: 25 | - 35000 26 | - 80000 27 | - 110000 28 | policy_loss_weight: 1.0 29 | value_loss_weight: 1.0 30 | 31 | model: 32 | filters: 64 33 | residual_blocks: 6 34 | se_ratio: 8 35 | path: "maia/1900" 36 | keep_weights: true 37 | back_prop_blocks: 99 38 | ... 39 | -------------------------------------------------------------------------------- /9-reduced-data/configs/Tuned.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | gpu: 0 4 | model: 5 | back_prop_blocks: 99 6 | filters: 64 7 | keep_weights: true 8 | path: maia/1900 9 | residual_blocks: 6 10 | se_ratio: 8 11 | training: 12 | batch_size: 16 13 | checkpoint_small_steps: 14 | - 50 15 | - 200 16 | - 400 17 | - 800 18 | - 1600 19 | - 2500 20 | checkpoint_steps: 5000 21 | early_stopping_steps: 10000 22 | lr_boundaries: 23 | - 50000 24 | - 110000 25 | - 160000 26 | lr_values: 27 | - 1.0e-05 28 | - 1.0e-06 29 | - 1.0e-07 30 | - 1.0e-08 31 | num_batch_splits: 1 32 | policy_loss_weight: 1.0 33 | precision: half 34 | shuffle_size: 256 35 | small_mode: true 36 | test_small_boundaries: 37 | - 20000 38 | - 40000 39 | - 60000 40 | - 80000 41 | - 100000 42 | test_small_steps: 43 | - 50 44 | - 200 45 | - 400 46 | - 800 47 | - 1600 48 | - 2500 49 | test_steps: 2000 50 | total_steps: 200000 51 | train_avg_report_steps: 50 52 | value_loss_weight: 1.0 53 | ... 
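The three configs above (`Best_frozen.yaml`, `NFP.yaml`, `Tuned.yaml`) all encode the learning-rate schedule as paired `lr_values`/`lr_boundaries` lists. A minimal sketch of resolving the rate at a given step, assuming the usual lc0-style convention of a piecewise-constant schedule with one more value than boundary (the helper `lr_at_step` is illustrative, not part of the repo):

```python
# Resolve the learning rate implied by a config at a given training step.
import yaml

def lr_at_step(cfg, step):
    boundaries = cfg['training']['lr_boundaries']
    values = cfg['training']['lr_values']
    for boundary, value in zip(boundaries, values):
        if step < boundary:
            return value
    return values[-1]  # past the last boundary

with open('9-reduced-data/configs/Tuned.yaml') as f:
    cfg = yaml.safe_load(f)

print(lr_at_step(cfg, 60000))  # 1e-06 under the Tuned schedule
```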
54 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | # YAML 1.2 2 | --- 3 | abstract: | 4 | "AI systems that can capture human-like behavior are becoming increasingly useful in situations where humans may want to learn from these systems, collaborate with them, or engage with them as partners for an extended duration. In order to develop human-oriented AI systems, the problem of predicting human actions---as opposed to predicting optimal actions---has received considerable attention. Existing work has focused on capturing human behavior in an aggregate sense, which potentially limits the benefit any particular individual could gain from interaction with these systems. We extend this line of work by developing highly accurate predictive models of individual human behavior in chess. Chess is a rich domain for exploring human-AI interaction because it combines a unique set of properties: AI systems achieved superhuman performance many years ago, and yet humans still interact with them closely, both as opponents and as preparation tools, and there is an enormous corpus of recorded data on individual player games. Starting with Maia, an open-source version of AlphaZero trained on a population of human players, we demonstrate that we can significantly improve prediction accuracy of a particular player's moves by applying a series of fine-tuning methods. Furthermore, our personalized models can be used to perform stylometry---predicting who made a given set of moves---indicating that they capture human decision-making at an individual level. Our work demonstrates a way to bring AI systems into better alignment with the behavior of individual people, which could lead to large improvements in human-AI interaction." 5 | authors: 6 | - 7 | affiliation: "University of Toronto" 8 | family-names: "McIlroy-Young" 9 | given-names: Reid 10 | - 11 | affiliation: "University of Toronto" 12 | family-names: Wang 13 | given-names: Russell 14 | - 15 | affiliation: "Microsoft Research" 16 | family-names: Sen 17 | given-names: Siddhartha 18 | - 19 | affiliation: "Cornell University" 20 | family-names: Kleinberg 21 | given-names: Jon 22 | - 23 | affiliation: "University of Toronto" 24 | family-names: Anderson 25 | given-names: Ashton 26 | cff-version: "1.2.0" 27 | date-released: 2022-09-16 28 | doi: "10.1145/3534678.3539367" 29 | license: "GPL-3.0" 30 | message: "If you use Maia Individual, please cite the original paper" 31 | repository-code: "https://github.com/CSSLab/maia-individual" 32 | url: "https://maiachess.com" 33 | title: "Maia Individual" 34 | version: "1.0.0" 35 | preferred-citation: 36 | type: article 37 | authors: 38 | - 39 | affiliation: "University of Toronto" 40 | family-names: "McIlroy-Young" 41 | given-names: Reid 42 | - 43 | affiliation: "Microsoft Research" 44 | family-names: Sen 45 | given-names: Siddhartha 46 | - 47 | affiliation: "Cornell University" 48 | family-names: Kleinberg 49 | given-names: Jon 50 | - 51 | affiliation: "University of Toronto" 52 | family-names: Anderson 53 | given-names: Ashton 54 | doi: "10.1145/3534678.3539367" 55 | journal: "KDD '22: Proceedings of the 28th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining" 56 | month: 8 57 | title: "Learning Models of Individual Behavior in Chess" 58 | year: 2022 59 | ...
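Since `CITATION.cff` is plain YAML, the citation metadata above is machine-readable. A minimal sketch of extracting the preferred citation with PyYAML (illustrative only, not part of the repo):

```python
# Print a one-line version of the preferred citation from CITATION.cff.
import yaml

with open('CITATION.cff') as f:
    meta = yaml.safe_load(f)

pref = meta['preferred-citation']
authors = ', '.join(a['family-names'] for a in pref['authors'])
print(f"{authors} ({pref['year']}). {pref['title']}. doi:{pref['doi']}")
```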
60 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning Models of Individual Behavior in Chess 2 | 3 | ## [website](https://maiachess.com)/[paper](https://arxiv.org/abs/2008.10086)/[code](https://github.com/CSSLab/maia-individual) 4 | 5 |
8 | 9 | ## Overview 10 | 11 | The main code used in this project is stored in `backend`, which is set up as a Python package; running `python setup.py install` will install it. Then the various scripts can be used. We also recommend using the virtual environment config we include in `environment.yml`, as some packages are required to be up to date. In addition, generating training data requires two more tools: [`pgn-extract`](https://www.cs.kent.ac.uk/people/staff/djb/pgn-extract/) to clean the PGNs and [`trainingdata-tool`](https://github.com/DanielUranga/trainingdata-tool) to convert them into training data. 12 | 13 | To run a model as a chess engine, [`lc0`](https://github.com/LeelaChessZero/lc0) version 23 has been tested and should work with all models (you can specify a path to a model with the `-w` argument). 14 | 15 | All testing was done on Ubuntu 18.04 with CUDA version 10. 16 | 17 | ## Models 18 | 19 | We do not have any models of individual players publicly available at this time. Because of the stylometry results shown in Section 5.2 of the paper, we cannot release even anonymized models. 20 | 21 | We have included the Maia model from [https://github.com/CSSLab/maia-chess](https://github.com/CSSLab/maia-chess) that was used as the base. 22 | 23 | ## Running the code 24 | 25 | The code for this project is divided into different sections, each with a series of numbered shell scripts. If run in order, they generate the training data and then the final models. For the full release we plan to make the process of generating a model more streamlined. 26 | 27 | ### Quick Run 28 | 29 | To get the model for a single player from a single PGN, a simpler system can be used: 30 | 31 | 1. Run `1-data_generation/9-pgn_to_training_data.sh input_PGN_file output_directory player_name` 32 | 2. Create a config file by copying `2-training/final_config.yaml` and adding `output_directory` and `player_name` 33 | 3. Run `python 2-training/train_transfer.py path_to_config` 34 | 4. The final model will be written to `final_models`; read the `--help` for more information 35 | 36 | ### Full Run 37 | 38 | Where applicable, the scripts start with a list of variables; these will need to be edited to match the paths on your system. 39 | 40 | The list of players we used was selected using the code in `0-player_counting`. The standard games from Lichess ([database.lichess.org](https://database.lichess.org)) up to April are required to reproduce our exact results, but the code should work with other sets, even non-Lichess ones, with a bit of work. 41 | 42 | Then the players' games are extracted and the various sets are constructed from them in `1-data_generation`. 43 | 44 | Finally, `2-training` has the main training script along with a configuration file that specifies the hyperparameters. All four configurations discussed in the main text are included. 45 | 46 | ### Extras 47 | 48 | The analysis code (`3-analysis`) is included for completeness, but as it is for generating the data used in the plots and relies on various hard-coded paths, we have not tested it. That said, `3-analysis/prediction_generator.py` is the main workhorse and has a `--help`; note it is designed for files output by `backend.gameToCSVlines`, but less complicated CSVs could be used. 49 | 50 | The baseline model code and results are in `4-cp_loss_stylo_baseline`; this is a simple baseline to compare our results to, and is included for completeness. A minimal sketch of its core idea is shown below.
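The sketch below shows that baseline's core idea under stated assumptions: each player is represented by a histogram of the centipawn losses of their moves (50 bins here, matching the feature size used by `train_cploss_per_game.py`), and a Gaussian Naive Bayes classifier attributes a new histogram to a player. The random vectors stand in for the real `.npy` features, so it is illustrative only:

```python
# Toy version of the cp-loss stylometry baseline (cf. test_all_games.py):
# one cp-loss histogram per player, classified with Gaussian Naive Bayes.
import numpy as np
from sklearn.naive_bayes import GaussianNB

rng = np.random.default_rng(0)
n_players, n_bins = 30, 50

train = rng.random((n_players, n_bins))  # stand-in per-player histograms
labels = np.arange(n_players)

model = GaussianNB()
model.fit(train, labels)

# A noisy copy of player 7's histogram should be attributed back to player 7.
query = train[7] + 0.05 * rng.random(n_bins)
print(model.predict(query[None, :]))  # -> [7]
```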
51 | 52 | ## Reduced Data 53 | 54 | The model configurations for the reduced data training are included in `9-reduced-data/configs`. To train them yourself, simply use these configs in the quick-run training. 55 | 56 | ## Citation 57 | 58 | ``` 59 | @article{McIlroyYoung_Learning_Models_Chess_2022, 60 | author = {McIlroy-Young, Reid and Sen, Siddhartha and Kleinberg, Jon and Anderson, Ashton}, 61 | doi = {10.1145/3534678.3539367}, 62 | journal = {KDD '22: Proceedings of the 28th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining}, 63 | month = {8}, 64 | title = {{Learning Models of Individual Behavior in Chess}}, 65 | year = {2022} 66 | } 67 | ``` 68 | 69 | ## License 70 | 71 | The software is available under the GPL License and includes code from the [Leela Chess Zero](https://github.com/LeelaChessZero/lczero-training) project. 72 | 73 | ## Contact 74 | 75 | Please [open an issue](https://github.com/CSSLab/maia-individual/issues/new) or email [Reid McIlroy-Young](https://reidmcy.com/) to get in touch. 76 | -------------------------------------------------------------------------------- /backend/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from .uci_engine import * 3 | from .pgn_parsering import * 4 | from .multiproc import * 5 | from .fen_to_vec import fen_to_vec, array_to_fen, array_to_board, game_to_vecs 6 | from .pgn_to_csv import * 7 | 8 | __version__ = '1.0.0' 9 | -------------------------------------------------------------------------------- /backend/fen_to_vec.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import chess 4 | import numpy as np 5 | 6 | # Generate the regexes 7 | boardRE = re.compile(r"(([^/]+)/([^/]+)/([^/]+)/([^/]+)/([^/]+)/([^/]+)/([^/]+)/([^/]+)) ((w)|(b)) ((-)|(K)?(Q)?(k)?(q)?)( ((-)|(\w+)))?( \d+)?( \d+)?") 8 | 9 | replaceRE = re.compile(r'[1-8/]') 10 | 11 | pieceMapWhite = {'E' : [False] * 12} 12 | pieceMapBlack = {'E' : [False] * 12} 13 | 14 | piece_reverse_lookup = {} 15 | 16 | all_pieces = 'PNBRQK' 17 | 18 | for i, p in enumerate(all_pieces): 19 | #White pieces first 20 | mP = [False] * 12 21 | mP[i] = True 22 | pieceMapBlack[p] = mP 23 | piece_reverse_lookup[i] = p 24 | 25 | #then black 26 | mP = [False] * 12 27 | mP[i + len(all_pieces)] = True 28 | pieceMapBlack[p.lower()] = mP 29 | piece_reverse_lookup[i + len(all_pieces)] = p.lower() 30 | 31 | 32 | #Black pieces first 33 | mP = [False] * 12 34 | mP[i] = True 35 | pieceMapWhite[p.lower()] = mP 36 | 37 | #then white 38 | mP = [False] * 12 39 | mP[i + len(all_pieces)] = True 40 | pieceMapWhite[p] = mP 41 | 42 | iSs = [str(i + 1) for i in range(8)] 43 | eSubss = [('E' * i, str(i)) for i in range(8,0, -1)] 44 | castling_vals = 'KQkq' 45 | 46 | def toByteBuff(l): 47 | return b''.join([b'\1' if e else b'\0' for e in l]) 48 | 49 | pieceMapBin = {k : toByteBuff(v) for k,v in pieceMapBlack.items()} 50 | 51 | def toBin(c): 52 | return pieceMapBin[c] 53 | 54 | castlesMap = {True : b'\1'*64, False : b'\0'*64} 55 | 56 | #Some previous lines are left in just in case 57 | 58 | # using N,C,H,W format 59 | 60 | move_letters = list('abcdefgh') 61 | 62 | moves_lookup = {} 63 | move_ind = 0 64 | for r_1 in range(8): 65 | for c_1 in range(8): 66 | for r_2 in range(8): 67 | for c_2 in range(8): 68 | moves_lookup[f"{move_letters[r_1]}{c_1+1}{move_letters[r_2]}{c_2+1}"] = move_ind 69 | move_ind += 1 70 | 71 | def move_to_index(move_str): 72 | return moves_lookup[move_str[:4]]
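# Note on moves_lookup above: it enumerates all 64 * 64 = 4096
# (from-square, to-square) pairs, nested file-then-rank, so for example
# move_to_index('e2e4') == 4*512 + 1*64 + 4*8 + 3 == 2147; the [:4] slice
# in move_to_index means promotion moves such as 'e7e8q' share an index
# with the bare 'e7e8' square pair.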
73 | 74 | def array_to_preproc(a_target): 75 | if not isinstance(a_target, np.ndarray): 76 | #check if torch Tensor without importing torch 77 | a_target = a_target.cpu().numpy() 78 | if a_target.dtype != np.bool_: 79 | a_target = a_target.astype(np.bool_) 80 | piece_layers = a_target[:12] 81 | board_a = np.moveaxis(piece_layers, 2, 0).reshape(64, 12) 82 | board_str = '' 83 | is_white = bool(a_target[12, 0, 0]) 84 | castling = [bool(l[0,0]) for l in a_target[13:]] 85 | board = [['E'] * 8 for i in range(8)] 86 | for i in range(12): 87 | for x in range(8): 88 | for y in range(8): 89 | if piece_layers[i,x,y]: 90 | board[x][y] = piece_reverse_lookup[i] 91 | board = [''.join(r) for r in board] 92 | return ''.join(board), is_white, tuple(castling) 93 | 94 | def preproc_to_fen(boardStr, is_white, castling): 95 | rows = [boardStr[(i*8):(i*8)+8] for i in range(8)] 96 | 97 | if not is_white: 98 | castling = castling[2:] + castling[:2] 99 | new_rows = [] 100 | for b in rows: 101 | new_rows.append(b.swapcase()[::-1].replace('e', 'E')) 102 | 103 | rows = reversed(new_rows) 104 | row_strs = [] 105 | for r in rows: 106 | for es, i in eSubss: 107 | if es in r: 108 | r = r.replace(es, i) 109 | row_strs.append(r) 110 | castle_str = '' 111 | for i, v in enumerate(castling): 112 | if v: 113 | castle_str += castling_vals[i] 114 | if len(castle_str) < 1: 115 | castle_str = '-' 116 | 117 | is_white_str = 'w' if is_white else 'b' 118 | board_str = '/'.join(row_strs) 119 | return f"{board_str} {is_white_str} {castle_str} - 0 1" 120 | 121 | def array_to_fen(a_target): 122 | return preproc_to_fen(*array_to_preproc(a_target)) 123 | 124 | def array_to_board(a_target): 125 | return chess.Board(fen = array_to_fen(a_target)) 126 | 127 | def simple_fen_vec(boardStr, is_white, castling): 128 | castles = [np.frombuffer(castlesMap[c], dtype='bool').reshape(1, 8, 8) for c in castling] 129 | board_buff_map = map(toBin, boardStr) 130 | board_buff = b''.join(board_buff_map) 131 | a = np.frombuffer(board_buff, dtype='bool') 132 | a = a.reshape(8, 8, -1) 133 | a = np.moveaxis(a, 2, 0) 134 | if is_white: 135 | colour_plane = np.ones((1, 8, 8), dtype='bool') 136 | else: 137 | colour_plane = np.zeros((1, 8, 8), dtype='bool') 138 | 139 | return np.concatenate([a, colour_plane, *castles], axis = 0) 140 | 141 | def preproc_fen(fenstr): 142 | r = boardRE.match(fenstr) 143 | if r.group(14): 144 | castling = (False, False, False, False) 145 | else: 146 | castling = (bool(r.group(15)), bool(r.group(16)), bool(r.group(17)), bool(r.group(18))) 147 | if r.group(11): 148 | is_white = True 149 | rows_lst = r.group(1).split('/') 150 | else: 151 | is_white = False 152 | castling = castling[2:] + castling[:2] 153 | rows_lst = r.group(1).swapcase().split('/') 154 | rows_lst = reversed([s[::-1] for s in rows_lst]) 155 | 156 | rowsS = ''.join(rows_lst) 157 | for i, iS in enumerate(iSs): 158 | if iS in rowsS: 159 | rowsS = rowsS.replace(iS, 'E' * (i + 1)) 160 | return rowsS, is_white, castling 161 | 162 | def fen_to_vec(fenstr): 163 | return simple_fen_vec(*preproc_fen(fenstr)) 164 | 165 | def game_to_vecs(game): 166 | boards = [] 167 | board = game.board() 168 | for i, node in enumerate(game.mainline()): 169 | fen = str(board.fen()) 170 | board.push(node.move) 171 | boards.append(fen_to_vec(fen)) 172 | return np.stack(boards, axis = 0) 173 | -------------------------------------------------------------------------------- /backend/multiproc.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 |
import collections.abc 3 | import time 4 | import sys 5 | import traceback 6 | import functools 7 | import pickle 8 | 9 | class Multiproc(object): 10 | def __init__(self, num_procs, max_queue_size = 1000, proc_check_interval = .1): 11 | self.num_procs = num_procs 12 | self.max_queue_size = max_queue_size 13 | self.proc_check_interval = proc_check_interval 14 | 15 | self.reader = MultiprocIterable 16 | self.reader_args = [] 17 | self.reader_kwargs = {} 18 | 19 | self.processor = MultiprocWorker 20 | self.processor_args = [] 21 | self.processor_kwargs = {} 22 | 23 | self.writer = MultiprocWorker 24 | self.writer_args = [] 25 | self.writer_kwargs = {} 26 | 27 | def reader_init(self, reader_cls, *reader_args, **reader_kwargs): 28 | self.reader = reader_cls 29 | self.reader_args = reader_args 30 | self.reader_kwargs = reader_kwargs 31 | 32 | def processor_init(self, processor_cls, *processor_args, **processor_kwargs): 33 | self.processor = processor_cls 34 | self.processor_args = processor_args 35 | self.processor_kwargs = processor_kwargs 36 | 37 | def writer_init(self, writer_cls, *writer_args, **writer_kwargs): 38 | self.writer = writer_cls 39 | self.writer_args = writer_args 40 | self.writer_kwargs = writer_kwargs 41 | 42 | 43 | def run(self): 44 | with multiprocessing.Pool(self.num_procs + 2) as pool, multiprocessing.Manager() as manager: 45 | inputQueue = manager.Queue(self.max_queue_size) 46 | resultsQueue = manager.Queue(self.max_queue_size) 47 | reader_proc = pool.apply_async(reader_loop, (inputQueue, self.num_procs, self.reader, self.reader_args, self.reader_kwargs)) 48 | 49 | worker_procs = [] 50 | for _ in range(self.num_procs): 51 | wp = pool.apply_async(processor_loop, (inputQueue, resultsQueue, self.processor, self.processor_args, self.processor_kwargs)) 52 | worker_procs.append(wp) 53 | 54 | writer_proc = pool.apply_async(writer_loop, (resultsQueue, self.num_procs, self.writer, self.writer_args, self.writer_kwargs)) 55 | 56 | self.cleanup(reader_proc, worker_procs, writer_proc) 57 | 58 | def cleanup(self, reader_proc, worker_procs, writer_proc): 59 | reader_working = True 60 | processor_working = True 61 | writer_working = True 62 | while reader_working or processor_working or writer_working: 63 | if reader_working and reader_proc.ready(): 64 | reader_proc.get() 65 | reader_working = False 66 | 67 | if processor_working: 68 | new_procs = [] 69 | for p in worker_procs: 70 | if p.ready(): 71 | p.get() 72 | else: 73 | new_procs.append(p) 74 | if len(new_procs) < 1: 75 | processor_working = False 76 | else: 77 | worker_procs = new_procs 78 | 79 | if writer_working and writer_proc.ready(): 80 | writer_proc.get() 81 | writer_working = False 82 | time.sleep(self.proc_check_interval) 83 | 84 | def catch_remote_exceptions(wrapped_function): 85 | """ https://stackoverflow.com/questions/6126007/python-getting-a-traceback """ 86 | 87 | @functools.wraps(wrapped_function) 88 | def new_function(*args, **kwargs): 89 | try: 90 | return wrapped_function(*args, **kwargs) 91 | 92 | except: 93 | raise Exception( "".join(traceback.format_exception(*sys.exc_info())) ) 94 | 95 | return new_function 96 | 97 | @catch_remote_exceptions 98 | def reader_loop(inputQueue, num_workers, reader_cls, reader_args, reader_kwargs): 99 | with reader_cls(*reader_args, **reader_kwargs) as R: 100 | for dat in R: 101 | inputQueue.put(dat) 102 | for i in range(num_workers): 103 | inputQueue.put(_QueueDone(count = i)) 104 | 105 | @catch_remote_exceptions 106 | def processor_loop(inputQueue, resultsQueue, processor_cls, 
--------------------------------------------------------------------------------
/backend/pgn_parsering.py:
--------------------------------------------------------------------------------
import re
import bz2

import chess.pgn

moveRegex = re.compile(r'\d+[.][ \.](\S+) (?:{[^}]*} )?(\S+)')

class GamesFile(object):
    """Streams games from a (possibly bz2-compressed) PGN file.

    Each iteration yields (headers, raw_text): a dict of the PGN header tags
    and the exact text of the game, so games can be filtered and re-written
    without full move parsing.
    """
    def __init__(self, path):
        if path.endswith('bz2'):
            self.f = bz2.open(path, 'rt')
        else:
            self.f = open(path, 'r')
        self.path = path
        self.i = 0

    def __iter__(self):
        try:
            while True:
                yield next(self)
        except StopIteration:
            return

    def __del__(self):
        try:
            self.f.close()
        except AttributeError:
            pass

    def __next__(self):
        ret = {}
        lines = ''
        for l in self.f:
            self.i += 1
            lines += l
            if len(l) < 2:
                # Blank line: end of the header section, unless too few tags
                # were seen for a valid game.
                if len(ret) >= 2:
                    break
                else:
                    raise RuntimeError(l)
            else:
                try:
                    k, v, _ = l.split('"')
                except ValueError:
                    # Malformed header line; bare 'null' lines are tolerated.
                    if l == 'null\n':
                        pass
                    else:
                        raise
                else:
                    ret[k[1:-1]] = v
        # The movetext line and the blank line that follows it.
        nl = self.f.readline()
        lines += nl
        lines += self.f.readline()
        if len(lines) < 1:
            raise StopIteration
        return ret, lines
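A short usage sketch (the path and rating cutoff are hypothetical): stream a Lichess dump and filter games by header tags, skipping the '?' placeholder Lichess uses for missing ratings.

games = GamesFile('lichess_db_sample.pgn.bz2')  # hypothetical path
for headers, raw_text in games:
    if headers.get('WhiteElo', '?') != '?' and int(headers['WhiteElo']) >= 2500:
        print(headers['White'], headers['Result'])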
--------------------------------------------------------------------------------
/backend/proto/__init__.py:
--------------------------------------------------------------------------------
from .net_pb2 import Net, NetworkFormat
--------------------------------------------------------------------------------
/backend/proto/net.proto:
--------------------------------------------------------------------------------
/*
 This file is part of Leela Chess Zero.
 Copyright (C) 2018 The LCZero Authors

 Leela Chess is free software: you can redistribute it and/or modify
 it under the terms of the GNU General Public License as published by
 the Free Software Foundation, either version 3 of the License, or
 (at your option) any later version.

 Leela Chess is distributed in the hope that it will be useful,
 but WITHOUT ANY WARRANTY; without even the implied warranty of
 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 GNU General Public License for more details.

 You should have received a copy of the GNU General Public License
 along with Leela Chess. If not, see