├── .dvc ├── .gitignore └── config ├── .gitattributes ├── .gitignore ├── README.md ├── configs └── preprocess_params.json ├── data ├── .gitignore ├── preprocessed │ └── .gitignore └── raw_data │ └── json │ └── 2017 │ └── .gitignore ├── data_exploration_and_heatmaps.ipynb ├── dvc-stages ├── eval_get_moves.dvc ├── eval_moves_to_fen.dvc ├── eval_pgn_to_json.dvc ├── moves_to_fen_black_human.dvc ├── moves_to_fen_white_human.dvc ├── pgn_to_json_2016.dvc ├── pgn_to_json_2017.dvc ├── pgn_to_json_2018.dvc ├── pgn_to_json_2019.dvc ├── preprocess_black_human_player.dvc ├── preprocess_white_human_player.dvc ├── train_CNN_LSTM_black_human.dvc ├── train_CNN_LSTM_white_human.dvc ├── train_conv3D_black_human.dvc └── train_fully_connected_LSTM_black_human.dvc ├── environment.yml ├── models ├── .gitignore ├── best_model_black_human.h5 └── best_model_white_human.h5 └── python_code ├── .gitignore ├── __intit__.py ├── __pycache__ └── pgn_to_fen.cpython-37.pyc ├── eval_get_moves.py ├── make_prediction.py ├── moves_to_fen.py ├── pgn_to_fen.py ├── pgn_to_json.py ├── preprocess.py ├── test_moves_to_fen.py ├── train_CNN_LSTM.py ├── train_Conv3D.py └── train_Fully_Connected_LSTM.py /.dvc/.gitignore: -------------------------------------------------------------------------------- 1 | /config.local 2 | /updater 3 | /lock 4 | /updater.lock 5 | /tmp 6 | /state-journal 7 | /state-wal 8 | /state 9 | /cache 10 | -------------------------------------------------------------------------------- /.dvc/config: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/moritzhambach/Detecting-Cheating-in-Chess/658798f5d5e36a3420603e182323ebef5130f535/.dvc/config -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-vendored -------------------------------------------------------------------------------- 
/.gitignore: -------------------------------------------------------------------------------- 1 | /overview.pdf 2 | /__pycache__/ 3 | /.vscode/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Detecting-Cheating-in-Chess 2 | Can chess engine use be detected just from the moves on the board (without running a chess engine)? Let's try it out, using a CNN - LSTM architecture and other architectures. 3 | - [Usage](#Usage) 4 | - [Why is cheating in chess hard to detect](#Why-is-cheating-in-chess-hard-to-detect) 5 | - [Getting the data](#Getting-the-data) 6 | - [Preprocessing](#Preprocessing) 7 | - [Visualisation as heatmaps](#Visualisation-as-heatmaps) 8 | - [CNN LSTM](#CNN-LSTM) 9 | 10 | 11 | 12 | 13 | ## Usage 14 | ### installation 15 | clone repo, then from root of project: `conda env create -f environment.yml` , `conda activate chess-classifier` 16 | ### classify a game 17 | * get the pgn notation, save it under data/raw_data/evaluation/eval.pgn 18 | * run `dvc repro dvc-stages/eval_moves_to_fen.dvc && python python_code/make_prediction.py --human-player-color White` (or Black if you were playing the black pieces) 19 | ### train with your own data 20 | * get pgn data, put it into folders under data/raw_data/, make sure the preprocess stages have the correct input path variable, then run `dvc repro dvc-stages/train_CNN_LSTM_black_human.dvc`. Might take a while for a lot of data. If happy with the results, put the trained model into models/best_model_black_human.h5 (or white), and start predicting. 21 | 22 | ### reproducibility 23 | * I use Data Version Control (DVC, https://dvc.org/) for a reproducible pipeline from data ingestion to preprocessing and training. 24 | 25 | ## Why is cheating in chess hard to detect 26 | Online chess suffers from the problem that the opponent could simply enter the moves into a chess engine on their smartphone and win easily.
27 | Chess websites try to detect this by running an engine themselves and comparing the moves played to its suggestions. Sophisticated cheaters could circumvent this by randomly choosing moves that are further down the list of suggestions, or by playing a bad move once in a while, since they will win anyway. Also, with new chess engines based on neural networks (like DeepMind's "AlphaZero" or the open source "Leela Chess Zero"), the comparison of chess moves might need to be done for several engines. Another current approach to catching cheaters is to analyze the timing between moves, but future engines might not require much time to calculate, and "natural" waiting times can be added to fool the detection tool. 28 | 29 | Let's approach the problem from another side. Experienced chess players can often detect a cheater when he plays "unintuitive" moves, 30 | the kind that a human wouldn't ever naturally come up with. This hints that there might be hidden patterns in human chess, which could be 31 | distinguished from the way a computer plays (non-principled, just play whatever works out best 20-30 moves into the future). 32 | 33 | Enter machine learning, specifically Deep Learning. Since the problem is both spatial (8x8 board) and temporal (board changes each move), 34 | my intuitive approach is using convolutional neural networks (CNN) which feed their results into recurrent neural networks (RNN), specifically 35 | a type called Long Short-Term Memory (LSTM). But let's start from the beginning: 36 | 37 | ## Getting the data 38 | 39 | Millions of chess games are easily available online, and are usually labeled whether a computer or human played each side. I downloaded close to 1 million games from https://www.ficsgames.org/ . The data comes in pgn format, which I first convert into JSON (using https://github.com/Assios/pgn-to-json ) and then read into a pandas DataFrame for preprocessing.
The moves themselves are converted into board states (with the help of https://github.com/niklasf/python-chess), meaning (channel x 8 x 8) arrays of 1s and 0s, where a 1 means a piece exists at this position, and the channel determines the piece type (white knight, black queen, etc). As there are 6 piece types per color (pawn, knight, bishop, rook, queen, king), we have 12 channels. The board states of successive moves are then stacked to create a tensor of shape (time, channel, 8, 8) for each game. I also get another 12 channels with the squares that each piece type can currently attack, although this does not seem to help too much. 40 | 41 | ## Preprocessing 42 | 43 | 1.) only take games where the human player lost (who cares about engines if you at least draw) 44 | 2.) opponent is 50/50 human or computer 45 | 3.) select games with lengths between 20 and 100 ply, and only train and evaluate the model on ply 20 to 40. The first 10 moves (20 ply) could be memorized, so detecting engines there does not make sense. 46 | 47 | Currently left with about 90k games for each case (human plays black or white). 48 | 49 | ## Visualisation as heatmaps 50 | Below we see heatmaps (average square occupation over the game), averaged over a thousand games each (here including the whole game, also the opening). There seem to be slight differences between the black human and black computer heatmaps, but in first baseline attempts I wasn't able to get more than 60% test accuracy when using only heatmaps as training data for classification. More heatmaps can be seen in the data exploration notebook.
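As a rough illustration of the board encoding and of how such a heatmap can be computed, here is a minimal sketch. It is not the code used in this repo: the function names `fen_to_planes` and `occupation_heatmap` are made up for this example, and the channel order is an assumption borrowed from the `pieceList` tuple in python_code/moves_to_fen.py.

```python
import numpy as np

# Assumed channel order (taken from pieceList in python_code/moves_to_fen.py).
PIECES = ("P", "R", "N", "B", "Q", "K", "p", "r", "n", "b", "q", "k")


def fen_to_planes(fen):
    """Encode the piece-placement field of a FEN string as a (12, 8, 8) array of 0s and 1s."""
    planes = np.zeros((len(PIECES), 8, 8), dtype=np.float32)
    placement = fen.split(" ")[0]
    for row, rank in enumerate(placement.split("/")):  # row 0 is rank 8
        col = 0
        for char in rank:
            if char.isdigit():
                col += int(char)  # a digit stands for that many empty squares
            else:
                planes[PIECES.index(char), row, col] = 1
                col += 1
    return planes


def occupation_heatmap(game):
    """Average a (time, channel, 8, 8) game tensor over time, giving (channel, 8, 8)."""
    return game.mean(axis=0)


# Two positions: the starting position and the position after 1. e4.
start = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
after_e4 = "rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq e3 0 1"
game = np.stack([fen_to_planes(start), fen_to_planes(after_e4)])  # shape (2, 12, 8, 8)
heatmap = occupation_heatmap(game)  # shape (12, 8, 8)
```

Averaging the per-game heatmaps over many games (for example with `np.mean` over a stacked array) would then give maps comparable to the pictures in this section.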
51 | 52 | black pawns, black is computer: 53 | 54 | ![alt text](https://user-images.githubusercontent.com/33765868/43685360-05665d3e-98b2-11e8-80d3-7586e53cdc1e.png) 55 | 56 | black pawns, black is human: 57 | 58 | ![alt text](https://user-images.githubusercontent.com/33765868/43685394-8e200774-98b2-11e8-88b6-e95bfd5b7ade.png) 59 | 60 | 61 | ## CNN LSTM 62 | (work in progress) 63 | Using the TimeDistributed wrapper on Conv2D layers allows easy setup of my network. The (channel x 8 x 8) maps of each time step first undergo 3 convolutional layers of kernel size 3x3 without padding, reducing the size to (filter x 2 x 2), are then flattened and fed into LSTM neurons, followed by a Dense (fully connected) layer. The currently best result is 80% accuracy, see below. It is still overfitting, although dropout is applied. Will add more data soon. 64 | 65 | ![alt text](https://user-images.githubusercontent.com/33765868/43685326-382504ce-98b1-11e8-8564-a89dd4d4c57a.png) 66 | 67 | ## Transformers to the rescue?! 68 | With recent applications of Transformers (originally used in NLP) to images, see https://paperswithcode.com/paper/an-image-is-worth-16x16-words-transformers, they might be of great use here. After all, when a human looks at a chess board, we don't look at all pieces and squares equally, but focus our attention (!) on the ones relevant for tactics etc. I will try it out, stay tuned!
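A side note on the CNN LSTM section: the size reduction from 8x8 to 2x2 quoted there can be checked with a few lines of plain Python. This is only a sketch of the shape arithmetic, assuming stride 1 and no padding; the helper names and the filter count of 64 are made up for this example and are not taken from the training scripts.

```python
def valid_conv_size(size, kernel=3, layers=3):
    """Spatial size after `layers` unpadded, stride-1 convolutions with a square kernel."""
    for _ in range(layers):
        size = size - kernel + 1  # each unpadded 3x3 convolution trims 2 squares per axis
    return size


def lstm_input_features(filters, board_size=8):
    """Length of the flattened vector handed to the LSTM at each time step."""
    side = valid_conv_size(board_size)
    return filters * side * side


# Board side length after 0, 1, 2 and 3 convolutional layers.
sizes = [valid_conv_size(8, layers=n) for n in range(4)]
# Hypothetical example: 64 filters in the last convolutional layer.
features = lstm_input_features(64)
```

So after three such layers each time step is a (filter x 2 x 2) block, and flattening it gives `filters * 4` values per move.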
69 | -------------------------------------------------------------------------------- /configs/preprocess_params.json: -------------------------------------------------------------------------------- 1 | { 2 | "timecontrols": [ 3 | "300+0", 4 | "600+0", 5 | "900+0", 6 | "900+5", 7 | "900+10", 8 | "1200+0" 9 | ], 10 | "plymin": 20, 11 | "plymax": 40, 12 | "max_game_length": 100 13 | } -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | /raw_data 2 | -------------------------------------------------------------------------------- /data/preprocessed/.gitignore: -------------------------------------------------------------------------------- 1 | /games.parquet 2 | /games_black_human.parquet 3 | /games_white_human.parquet 4 | /fen_white_human.npy 5 | /fen_black_human.npy.npz 6 | /fen_white_human.npy.npz 7 | /fen_black_human_labels.npy.npz 8 | /fen_white_human_labels.npy.npz 9 | /fen_black_human_attacks.npy.npz 10 | /fen_white_human_attacks.npy.npz 11 | /eval.parquet 12 | /fen_eval_attacks.npy.npz 13 | /fen_eval_labels.npy.npz 14 | /fen_eval.npy.npz 15 | -------------------------------------------------------------------------------- /data/raw_data/json/2017/.gitignore: -------------------------------------------------------------------------------- 1 | /join_data.json 2 | -------------------------------------------------------------------------------- /dvc-stages/eval_get_moves.dvc: -------------------------------------------------------------------------------- 1 | md5: d78705e2289b88a8e5f375e8697bd618 2 | cmd: python python_code/eval_get_moves.py --input-path data/raw_data/json/evaluation/eval.json 3 | --output-path data/preprocessed/eval.parquet 4 | wdir: .. 
5 | deps: 6 | - md5: 6e4c0894a24382bfb3d70825f18f4a18 7 | path: data/raw_data/json/evaluation/eval.json 8 | - md5: ecf426d8c33067e8bb6b6e0c086d81cd 9 | path: python_code/eval_get_moves.py 10 | outs: 11 | - md5: e95868f5758d9067cc38429447f6da04 12 | path: data/preprocessed/eval.parquet 13 | cache: true 14 | metric: false 15 | persist: false 16 | -------------------------------------------------------------------------------- /dvc-stages/eval_moves_to_fen.dvc: -------------------------------------------------------------------------------- 1 | md5: e6b363097e002fc5a3fc111c95b906eb 2 | cmd: python python_code/moves_to_fen.py --input-path data/preprocessed/eval.parquet 3 | --params-path configs/preprocess_params.json --output-path data/preprocessed/fen_eval.npy.npz 4 | --output-path-labels data/preprocessed/fen_eval_labels.npy.npz --output-path-attacks 5 | data/preprocessed/fen_eval_attacks.npy.npz 6 | wdir: .. 7 | deps: 8 | - md5: e95868f5758d9067cc38429447f6da04 9 | path: data/preprocessed/eval.parquet 10 | - md5: fe8b2c3497df52f1bccfa478744f551a 11 | path: python_code/moves_to_fen.py 12 | - md5: c3f100f339d669cc08267f7143d33ada 13 | path: configs/preprocess_params.json 14 | outs: 15 | - md5: d2d7cc3a5b9b854a1bf64e214c35839c 16 | path: data/preprocessed/fen_eval_attacks.npy.npz 17 | cache: true 18 | metric: false 19 | persist: false 20 | - md5: 3f1b333faa0f693f2367574f8eb9dfe7 21 | path: data/preprocessed/fen_eval_labels.npy.npz 22 | cache: true 23 | metric: false 24 | persist: false 25 | - md5: 2fe8c83e199fc1435ae6751d036d166f 26 | path: data/preprocessed/fen_eval.npy.npz 27 | cache: true 28 | metric: false 29 | persist: false 30 | -------------------------------------------------------------------------------- /dvc-stages/eval_pgn_to_json.dvc: -------------------------------------------------------------------------------- 1 | md5: d03444e1df9981768b18aa5e38151953 2 | cmd: python python_code/pgn_to_json.py data/raw_data/evaluation/ data/raw_data/json/evaluation/ 3 
| 1 4 | wdir: .. 5 | deps: 6 | - md5: 88757bbf6a1f88c90784efd1509696bf.dir 7 | path: data/raw_data/evaluation 8 | - md5: 8fb0a3a2f55cab1d4c96432badc31a1c 9 | path: python_code/pgn_to_json.py 10 | outs: 11 | - md5: 6e4c0894a24382bfb3d70825f18f4a18 12 | path: data/raw_data/json/evaluation/eval.json 13 | cache: true 14 | metric: false 15 | persist: false 16 | -------------------------------------------------------------------------------- /dvc-stages/moves_to_fen_black_human.dvc: -------------------------------------------------------------------------------- 1 | md5: d650934c79e7dcbd5c13b861f8b4e003 2 | cmd: python python_code/moves_to_fen.py --input-path data/preprocessed/games_black_human.parquet 3 | --output-path data/preprocessed/fen_black_human.npy.npz --output-path-labels data/preprocessed/fen_black_human_labels.npy.npz 4 | --output-path-attacks data/preprocessed/fen_black_human_attacks.npy.npz --params-path 5 | configs/preprocess_params.json 6 | wdir: .. 7 | deps: 8 | - md5: c3f100f339d669cc08267f7143d33ada 9 | path: configs/preprocess_params.json 10 | - md5: ff6bca4ddb76453a3aa2018cf1cf8594 11 | path: data/preprocessed/games_black_human.parquet 12 | - md5: fe8b2c3497df52f1bccfa478744f551a 13 | path: python_code/moves_to_fen.py 14 | outs: 15 | - md5: 30c01ca9651e2ceb9e775aa94bc371cd 16 | path: data/preprocessed/fen_black_human.npy.npz 17 | cache: true 18 | metric: false 19 | persist: false 20 | - md5: 84bca342766ea9338e2f76c0ee117825 21 | path: data/preprocessed/fen_black_human_labels.npy.npz 22 | cache: true 23 | metric: false 24 | persist: false 25 | - md5: 01aafc5a798af0d8a30cb7bae8f05337 26 | path: data/preprocessed/fen_black_human_attacks.npy.npz 27 | cache: true 28 | metric: false 29 | persist: false 30 | -------------------------------------------------------------------------------- /dvc-stages/moves_to_fen_white_human.dvc: -------------------------------------------------------------------------------- 1 | md5: 6d5d1f24b604062b2e8fb7aeb8ce741a 2 | 
cmd: python python_code/moves_to_fen.py --input-path data/preprocessed/games_white_human.parquet 3 | --output-path data/preprocessed/fen_white_human.npy.npz --output-path-labels data/preprocessed/fen_white_human_labels.npy.npz 4 | --output-path-attacks data/preprocessed/fen_white_human_attacks.npy.npz --params-path 5 | configs/preprocess_params.json 6 | wdir: .. 7 | deps: 8 | - md5: c3f100f339d669cc08267f7143d33ada 9 | path: configs/preprocess_params.json 10 | - md5: 6a31f8dfbf282208dc4ebf6b1c369589 11 | path: data/preprocessed/games_white_human.parquet 12 | - md5: fe8b2c3497df52f1bccfa478744f551a 13 | path: python_code/moves_to_fen.py 14 | outs: 15 | - md5: 16bb411932a37931f31400183e2aa4ba 16 | path: data/preprocessed/fen_white_human.npy.npz 17 | cache: true 18 | metric: false 19 | persist: false 20 | - md5: d79731e398f2d902d7e9db6fbbe55318 21 | path: data/preprocessed/fen_white_human_labels.npy.npz 22 | cache: true 23 | metric: false 24 | persist: false 25 | - md5: 84c9becedfe8cfbce87a10d59bc1fc53 26 | path: data/preprocessed/fen_white_human_attacks.npy.npz 27 | cache: true 28 | metric: false 29 | persist: false 30 | -------------------------------------------------------------------------------- /dvc-stages/pgn_to_json_2016.dvc: -------------------------------------------------------------------------------- 1 | md5: eca0facdd1c83e18785545a948a6c130 2 | cmd: python python_code/pgn_to_json.py data/raw_data/2016/ data/raw_data/json/2016/ 3 | 100000 join 4 | wdir: .. 
5 | deps: 6 | - md5: 06a30e133f24c1d0e8dea0432941a098.dir 7 | path: data/raw_data/2016 8 | - md5: 8fb0a3a2f55cab1d4c96432badc31a1c 9 | path: python_code/pgn_to_json.py 10 | outs: 11 | - md5: fe375e5805a974805c31c58a1b62682d 12 | path: data/raw_data/json/2016/join_data.json 13 | cache: true 14 | metric: false 15 | persist: false 16 | -------------------------------------------------------------------------------- /dvc-stages/pgn_to_json_2017.dvc: -------------------------------------------------------------------------------- 1 | md5: 97f6b49f6eaf299ee1634bbff5a25605 2 | cmd: python python_code/pgn_to_json.py data/raw_data/2017/ data/raw_data/json/2017/ 3 | 200000 join 4 | wdir: .. 5 | deps: 6 | - md5: 295654578ab2abf08b2bdf51679da12c.dir 7 | path: data/raw_data/2017 8 | - md5: 8fb0a3a2f55cab1d4c96432badc31a1c 9 | path: python_code/pgn_to_json.py 10 | outs: 11 | - md5: b8cdb3aa78a075b373ea0ea2886a240d 12 | path: data/raw_data/json/2017/join_data.json 13 | cache: true 14 | metric: false 15 | persist: false 16 | -------------------------------------------------------------------------------- /dvc-stages/pgn_to_json_2018.dvc: -------------------------------------------------------------------------------- 1 | md5: 9527f314708132ac874e1ca61becbb6d 2 | cmd: python python_code/pgn_to_json.py data/raw_data/2018/ data/raw_data/json/2018/ 3 | 200000 join 4 | wdir: .. 
5 | deps: 6 | - md5: 28a49f0219c3a816b1ce8356bfd84960.dir 7 | path: data/raw_data/2018 8 | - md5: 8fb0a3a2f55cab1d4c96432badc31a1c 9 | path: python_code/pgn_to_json.py 10 | outs: 11 | - md5: 1fdbbc8194a0bfded0f9324829b85ebf 12 | path: data/raw_data/json/2018/join_data.json 13 | cache: true 14 | metric: false 15 | persist: false 16 | -------------------------------------------------------------------------------- /dvc-stages/pgn_to_json_2019.dvc: -------------------------------------------------------------------------------- 1 | md5: de5f0913a2bd1fdf27e3dc83633f9dab 2 | cmd: python python_code/pgn_to_json.py data/raw_data/2019/ data/raw_data/json/2019/ 3 | 200000 join 4 | wdir: .. 5 | deps: 6 | - md5: 3f0b579225c4e29a3ac852949630cce0.dir 7 | path: data/raw_data/2019 8 | - md5: 8fb0a3a2f55cab1d4c96432badc31a1c 9 | path: python_code/pgn_to_json.py 10 | outs: 11 | - md5: 6528fe344a049f0e695cde35d7b43aba 12 | path: data/raw_data/json/2019/join_data.json 13 | cache: true 14 | metric: false 15 | persist: false 16 | -------------------------------------------------------------------------------- /dvc-stages/preprocess_black_human_player.dvc: -------------------------------------------------------------------------------- 1 | md5: 203bc8d936994280d0c5c09ea0b7c636 2 | cmd: python python_code/preprocess.py --input-paths data/raw_data/json/2017/join_data.json,data/raw_data/json/2018/join_data.json,data/raw_data/json/2019/join_data.json,data/raw_data/json/2016/join_data.json 3 | --output-path data/preprocessed/games_black_human.parquet --params-path configs/preprocess_params.json 4 | --human-color Black 5 | wdir: .. 
6 | deps: 7 | - md5: c3f100f339d669cc08267f7143d33ada 8 | path: configs/preprocess_params.json 9 | - md5: fe375e5805a974805c31c58a1b62682d 10 | path: data/raw_data/json/2016/join_data.json 11 | - md5: b8cdb3aa78a075b373ea0ea2886a240d 12 | path: data/raw_data/json/2017/join_data.json 13 | - md5: 1fdbbc8194a0bfded0f9324829b85ebf 14 | path: data/raw_data/json/2018/join_data.json 15 | - md5: 6528fe344a049f0e695cde35d7b43aba 16 | path: data/raw_data/json/2019/join_data.json 17 | - md5: 2f7dc46cf530fa85d06f1b41bfd61bfe 18 | path: python_code/preprocess.py 19 | outs: 20 | - md5: ff6bca4ddb76453a3aa2018cf1cf8594 21 | path: data/preprocessed/games_black_human.parquet 22 | cache: true 23 | metric: false 24 | persist: false 25 | -------------------------------------------------------------------------------- /dvc-stages/preprocess_white_human_player.dvc: -------------------------------------------------------------------------------- 1 | md5: ae71195367a25d9558df0f49b83e1e1e 2 | cmd: python python_code/preprocess.py --input-paths data/raw_data/json/2017/join_data.json,data/raw_data/json/2018/join_data.json,data/raw_data/json/2019/join_data.json,data/raw_data/json/2016/join_data.json 3 | --output-path data/preprocessed/games_white_human.parquet --params-path configs/preprocess_params.json 4 | --human-color White 5 | wdir: .. 
6 | deps: 7 | - md5: c3f100f339d669cc08267f7143d33ada 8 | path: configs/preprocess_params.json 9 | - md5: fe375e5805a974805c31c58a1b62682d 10 | path: data/raw_data/json/2016/join_data.json 11 | - md5: b8cdb3aa78a075b373ea0ea2886a240d 12 | path: data/raw_data/json/2017/join_data.json 13 | - md5: 1fdbbc8194a0bfded0f9324829b85ebf 14 | path: data/raw_data/json/2018/join_data.json 15 | - md5: 6528fe344a049f0e695cde35d7b43aba 16 | path: data/raw_data/json/2019/join_data.json 17 | - md5: 2f7dc46cf530fa85d06f1b41bfd61bfe 18 | path: python_code/preprocess.py 19 | outs: 20 | - md5: 6a31f8dfbf282208dc4ebf6b1c369589 21 | path: data/preprocessed/games_white_human.parquet 22 | cache: true 23 | metric: false 24 | persist: false 25 | -------------------------------------------------------------------------------- /dvc-stages/train_CNN_LSTM_black_human.dvc: -------------------------------------------------------------------------------- 1 | md5: c0ec342499b812524cf460c65f336303 2 | cmd: python python_code/train_CNN_LSTM.py --input-path data/preprocessed/fen_black_human.npy.npz 3 | --input-path-labels data/preprocessed/fen_black_human_labels.npy.npz --output-path 4 | models/LSTM_model_black_human.h5 --input-path-attacks data/preprocessed/fen_black_human_attacks.npy.npz 5 | wdir: .. 
6 | deps: 7 | - md5: 01aafc5a798af0d8a30cb7bae8f05337 8 | path: data/preprocessed/fen_black_human_attacks.npy.npz 9 | - md5: 30c01ca9651e2ceb9e775aa94bc371cd 10 | path: data/preprocessed/fen_black_human.npy.npz 11 | - md5: 84bca342766ea9338e2f76c0ee117825 12 | path: data/preprocessed/fen_black_human_labels.npy.npz 13 | - md5: a0b666a45095e47d9731441e17a92229 14 | path: python_code/train_CNN_LSTM.py 15 | outs: 16 | - md5: 714e5b2bb81df45018ec4615f0a784ad 17 | path: models/LSTM_model_black_human.h5 18 | cache: true 19 | metric: false 20 | persist: false 21 | -------------------------------------------------------------------------------- /dvc-stages/train_CNN_LSTM_white_human.dvc: -------------------------------------------------------------------------------- 1 | md5: 02e4a42412390dd2defcdf3522e5f88a 2 | cmd: python python_code/train_CNN_LSTM.py --input-path data/preprocessed/fen_white_human.npy.npz 3 | --input-path-labels data/preprocessed/fen_white_human_labels.npy.npz --output-path 4 | models/LSTM_model_white_human.h5 --input-path-attacks data/preprocessed/fen_white_human_attacks.npy.npz 5 | wdir: .. 
6 | deps: 7 | - md5: 84c9becedfe8cfbce87a10d59bc1fc53 8 | path: data/preprocessed/fen_white_human_attacks.npy.npz 9 | - md5: 16bb411932a37931f31400183e2aa4ba 10 | path: data/preprocessed/fen_white_human.npy.npz 11 | - md5: d79731e398f2d902d7e9db6fbbe55318 12 | path: data/preprocessed/fen_white_human_labels.npy.npz 13 | - md5: a0b666a45095e47d9731441e17a92229 14 | path: python_code/train_CNN_LSTM.py 15 | outs: 16 | - md5: 9ec511d539d37d4ce1b94012dd477265 17 | path: models/LSTM_model_white_human.h5 18 | cache: true 19 | metric: false 20 | persist: false 21 | -------------------------------------------------------------------------------- /dvc-stages/train_conv3D_black_human.dvc: -------------------------------------------------------------------------------- 1 | md5: df3b33fc23cb8daa97b7185bc7bcfed3 2 | cmd: python python_code/train_conv3D.py --input-path data/preprocessed/fen_black_human.npy.npz 3 | --input-path-labels data/preprocessed/fen_black_human_labels.npy.npz --input-path-attacks 4 | data/preprocessed/fen_black_human_attacks.npy.npz --output-path models/model_conv3D_black_human.h5 5 | wdir: .. 
6 | deps: 7 | - md5: 755e3d09f7a96762c1e09cb051998b9e 8 | path: data/preprocessed/fen_black_human.npy.npz 9 | - md5: 75c4f9d3b5a54fc82fcefa8f42124332 10 | path: data/preprocessed/fen_black_human_attacks.npy.npz 11 | - md5: 7efd277d3eccc6625e268dd9dcbf1b55 12 | path: data/preprocessed/fen_black_human_labels.npy.npz 13 | - md5: cb03edfa1b672834dcfbcb22d462e5e8 14 | path: python_code/train_conv3D.py 15 | outs: 16 | - md5: 1824bd38e2438edfc9b8852ed0f21583 17 | path: models/model_conv3D_black_human.h5 18 | cache: true 19 | metric: false 20 | persist: false 21 | -------------------------------------------------------------------------------- /dvc-stages/train_fully_connected_LSTM_black_human.dvc: -------------------------------------------------------------------------------- 1 | md5: b6fe1b85a0ce42fe96296eb35744300d 2 | cmd: python python_code/train_Fully_Connected_LSTM.py --input-path data/preprocessed/fen_black_human.npy.npz 3 | --input-path-labels data/preprocessed/fen_black_human_labels.npy.npz --output-path 4 | models/FC_LSTM_model_black_human.h5 --input-path-attacks data/preprocessed/fen_black_human_attacks.npy.npz 5 | wdir: .. 
6 | deps: 7 | - md5: 01aafc5a798af0d8a30cb7bae8f05337 8 | path: data/preprocessed/fen_black_human_attacks.npy.npz 9 | - md5: 30c01ca9651e2ceb9e775aa94bc371cd 10 | path: data/preprocessed/fen_black_human.npy.npz 11 | - md5: 84bca342766ea9338e2f76c0ee117825 12 | path: data/preprocessed/fen_black_human_labels.npy.npz 13 | - md5: d8fed540f40a8d4b8afb7b2532065c38 14 | path: python_code/train_Fully_Connected_LSTM.py 15 | outs: 16 | - md5: a4ad8d470dc252492bb51590d838f4f1 17 | path: models/FC_LSTM_model_black_human.h5 18 | cache: true 19 | metric: false 20 | persist: false 21 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: chess-classifier 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - python=3.7 6 | - pytest=5.0.1 7 | - pytest-cov 8 | - flake8=3.8.1 9 | - pip 10 | - click=7.0 11 | - numpy=1.16.4 12 | - pandas=1.0.4 13 | - tqdm 14 | - vaex 15 | - xarray 16 | - dask=2.1.0 17 | - tqdm 18 | - pip: 19 | - dvc[s3]==0.93.0 20 | - tensorflow==2.0.0-alpha0 21 | - python-chess 22 | - seaborn 23 | - scikit-learn -------------------------------------------------------------------------------- /models/.gitignore: -------------------------------------------------------------------------------- 1 | /model_black_human.h5 2 | /model_conv3D_black_human.h5 3 | /LSTM_model_black_human.h5 4 | /FC_LSTM_model_black_human.h5 5 | /LSTM_model_white_human.h5 6 | -------------------------------------------------------------------------------- /models/best_model_black_human.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/moritzhambach/Detecting-Cheating-in-Chess/658798f5d5e36a3420603e182323ebef5130f535/models/best_model_black_human.h5 -------------------------------------------------------------------------------- /models/best_model_white_human.h5: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/moritzhambach/Detecting-Cheating-in-Chess/658798f5d5e36a3420603e182323ebef5130f535/models/best_model_white_human.h5 -------------------------------------------------------------------------------- /python_code/.gitignore: -------------------------------------------------------------------------------- 1 | /__pycache__ 2 | -------------------------------------------------------------------------------- /python_code/__intit__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/moritzhambach/Detecting-Cheating-in-Chess/658798f5d5e36a3420603e182323ebef5130f535/python_code/__intit__.py -------------------------------------------------------------------------------- /python_code/__pycache__/pgn_to_fen.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/moritzhambach/Detecting-Cheating-in-Chess/658798f5d5e36a3420603e182323ebef5130f535/python_code/__pycache__/pgn_to_fen.cpython-37.pyc -------------------------------------------------------------------------------- /python_code/eval_get_moves.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import click 3 | import logging 4 | 5 | logging.basicConfig(level=logging.INFO,) 6 | 7 | LOGGER = logging.getLogger() 8 | 9 | 10 | @click.command() 11 | @click.option("--input-path", help="filename of single input file", required=True) 12 | @click.option("--output-path", help="where to save result (as parquet)", required=True) 13 | def main(input_path, output_path): 14 | df = pd.read_json(input_path) 15 | df["opponentIsComp"] = -1 # column is expected by next stage 16 | df[["moves", "opponentIsComp"]].to_parquet(output_path) 17 | 18 | 19 | if __name__ == "__main__": 20 | main() 21 | 
-------------------------------------------------------------------------------- /python_code/make_prediction.py: -------------------------------------------------------------------------------- 1 | import click 2 | import logging 3 | import numpy as np 4 | import tensorflow as tf 5 | import h5py 6 | import json 7 | import pandas as pd 8 | 9 | 10 | logging.basicConfig(level=logging.INFO,) 11 | 12 | LOGGER = logging.getLogger() 13 | 14 | 15 | def checkParameters(df, human_color, params): 16 | if not len(df) == 1: 17 | LOGGER.info("evaluating single games only, check your input!") 18 | raise ValueError 19 | else: 20 | LOGGER.info("check 1 passed") 21 | if not ( 22 | (human_color == "White" and df.Result[0] == "0-1") 23 | or (human_color == "Black" and df.Result[0] == "1-0") 24 | ): 25 | LOGGER.info("you did not lose the game, why do you care if engine was used?") 26 | LOGGER.info("model was not trained on won games, might be incorrect") 27 | else: 28 | LOGGER.info("check 2 passed") 29 | if not (len(df.moves[0]) > params["plymin"]): 30 | LOGGER.info( 31 | "game is too short, can not distinguish engine use from opening knowledge" 32 | ) 33 | raise ValueError 34 | else: 35 | LOGGER.info("check 3 passed") 36 | if not (df.TimeControl[0] in params["timecontrols"]): 37 | LOGGER.info( 38 | "game has a different time control than the model knows about, results might be incorrect" 39 | ) 40 | else: 41 | LOGGER.info("check 4 passed") 42 | 43 | 44 | @click.command() 45 | @click.option( 46 | "--input-path-attacks", 47 | help="input array of training data (attacked squares)", 48 | required=True, 49 | default="data/preprocessed/fen_eval_attacks.npy.npz", 50 | ) 51 | @click.option( 52 | "--input-path", 53 | help="input array of training data (positions)", 54 | required=True, 55 | default="data/preprocessed/fen_eval.npy.npz", 56 | ) 57 | @click.option( 58 | "--input-path-model-black", 59 | help="path to load best model, if human played black", 60 | required=True, 61 | 
default="models/best_model_black_human.h5", 62 | ) 63 | @click.option( 64 | "--input-path-model-white", 65 | help="path to load best model, if human played white", 66 | required=True, 67 | default="models/best_model_white_human.h5", 68 | ) 69 | @click.option( 70 | "--params-path", required=True, default="configs/preprocess_params.json", 71 | ) 72 | @click.option( 73 | "--path-json-data", required=True, default="data/raw_data/json/evaluation/eval.json" 74 | ) 75 | @click.option( 76 | "--human-player-color", 77 | help="Black or White, what did the human play", 78 | required=True, 79 | ) 80 | def main( 81 | input_path, 82 | path_json_data, 83 | input_path_attacks, 84 | input_path_model_black, 85 | input_path_model_white, 86 | human_player_color, 87 | params_path, 88 | ): 89 | with open(params_path) as f: 90 | params = json.load(f) 91 | df = pd.read_json(path_json_data) 92 | checkParameters(df, human_player_color, params) 93 | 94 | data_positions = np.load(input_path)["arr_0"] 95 | data_attacks = np.load(input_path_attacks)["arr_0"] 96 | data = np.concatenate((data_positions, data_attacks), axis=2) 97 | if human_player_color == "White": 98 | opponent = "Black" 99 | model_path = input_path_model_white 100 | elif human_player_color == "Black": 101 | opponent = "White" 102 | model_path = input_path_model_black 103 | else: 104 | raise ValueError("please specify the color you played, Black or White") 105 | model = tf.keras.models.load_model(model_path) 106 | try: 107 | res = model.predict(data)[0][1] 108 | except Exception: 109 | data = data.reshape( 110 | data.shape[0], data.shape[1], -1 111 | ) # flatten for use with non-CNN model 112 | res = model.predict(data)[0][1] 113 | 114 | LOGGER.info( 115 | f" Probability that your opponent ({opponent} Player) is using an engine: {res}" 116 | ) 117 | 118 | 119 | if __name__ == "__main__": 120 | main() 121 | -------------------------------------------------------------------------------- /python_code/moves_to_fen.py: 
-------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import click 3 | import logging 4 | import numpy as np 5 | import pgn_to_fen 6 | import chess 7 | from tqdm import tqdm 8 | import json 9 | 10 | logging.basicConfig(level=logging.INFO,) 11 | 12 | LOGGER = logging.getLogger() 13 | 14 | 15 | def getFenArray(fen): 16 | fen = fen.split(" ")[0] 17 | fen = ( 18 | fen.replace("1", "0") 19 | .replace("2", "00") 20 | .replace("3", "000") 21 | .replace("4", "0000") 22 | .replace("5", "00000") 23 | .replace("6", "000000") 24 | .replace("7", "0000000") 25 | .replace("8", "00000000") 26 | ) 27 | fen = list(fen.replace("/", "")) 28 | fenArray = np.array(fen) 29 | fenArray = np.reshape(fenArray, [8, 8]) 30 | return fenArray 31 | 32 | 33 | def movesToFenList(movesList): 34 | """given a list of moves, return a list of standard fen notations""" 35 | fenlist = [] 36 | pgnConverter = pgn_to_fen.PgnToFen() 37 | pgnConverter.resetBoard() 38 | try: 39 | for move in movesList: 40 | pgnConverter.move(str(move)) 41 | fen = pgnConverter.getFullFen() 42 | fenlist.append(fen.split(" ")[0]) 43 | return fenlist 44 | except Exception: 45 | LOGGER.info("can not create fen") 46 | 47 | 48 | def fenList_to_fenArray(fenList): 49 | """turn standard fen notation into a fen version with zeros for empty squares 50 | (e.g.
"00000000" instead of "8" for eight free squares)""" 51 | fenArrayList = [] 52 | for fen in fenList: 53 | fenArray = getFenArray(fen) 54 | fenArrayList.append(fenArray) 55 | 56 | fenArrayOverTime = np.stack(fenArrayList, axis=0) 57 | return fenArrayOverTime 58 | 59 | 60 | def getAttacksPerPiece(fen): 61 | baseBoard = chess.BaseBoard(fen) 62 | attacks_list = [] 63 | for k in range(64): # 0 = A1, 1 = B1, etc 64 | piece = baseBoard.piece_at(k) 65 | if piece: 66 | attacks = baseBoard.attacks( 67 | k 68 | ).tolist() # 64 bools, true if square is attacked from piece in square k 69 | attacks_sparse = [j for j, x in enumerate(attacks) if x] 70 | attacks_list.append((str(piece), attacks_sparse)) 71 | 72 | return attacks_list 73 | 74 | 75 | def getAttacksByPiecetype(fen): 76 | attacks_by_piece = getAttacksPerPiece(fen) 77 | pieceList = ("P", "R", "N", "B", "Q", "K", "p", "r", "n", "b", "q", "k") 78 | attacks_by_pieceType = {} 79 | for piece in pieceList: 80 | attacked_squares = [ 81 | tup[1] for tup in attacks_by_piece if tup[0] == piece 82 | ] # list of all squares attacked by all pieces of this type (for example all white pawns) 83 | attacked_squares = sum(attacked_squares, []) # flatten 84 | attacks_by_pieceType[piece] = set(attacked_squares) # ignoring double attacks 85 | return attacks_by_pieceType 86 | 87 | 88 | def getAttackTensor(attacks_by_pieceType): 89 | pieceList = ("P", "R", "N", "B", "Q", "K", "p", "r", "n", "b", "q", "k") 90 | output_array = np.zeros((12, 8, 8)) 91 | for channel, piece in enumerate(pieceList): 92 | attacksList = attacks_by_pieceType[piece] 93 | for pos in attacksList: 94 | pos_x = pos % 8 95 | pos_y = 7 - pos // 8 96 | output_array[channel, pos_y, pos_x] = 1 97 | return output_array 98 | 99 | 100 | def getAttacksTensorOverTime(fenList): 101 | """returns tensor of shape (time, channel, row, col) describing a single game, 102 | where the channels represent the 12 piece types and a value of 1 means this piece 103 | attacks this square at this
time""" 104 | res = np.zeros((len(fenList), 12, 8, 8)) 105 | for j, fen in enumerate(fenList): 106 | attacks_by_pieceType = getAttacksByPiecetype(fen) 107 | res[j] = getAttackTensor(attacks_by_pieceType) 108 | return res.astype(int) 109 | 110 | 111 | def getFenPerChannel(input_array): 112 | """ takes a fen (with strings describing the pieces on the field) and expands it 113 | in an additional dimension, basically one-hot-encoding the pieces""" 114 | pieceList = ("P", "R", "N", "B", "Q", "K", "p", "r", "n", "b", "q", "k") 115 | if input_array is None: 116 | return np.zeros((0, 12, 8, 8), dtype=int) # nothing to encode 117 | 118 | res = np.zeros((input_array.shape[0], 12, 8, 8)) # time, channel, row, column 119 | for k, piece in enumerate(pieceList): 120 | mask = input_array == piece 121 | res[:, k, :, :] = mask 122 | return res.astype(int) 123 | 124 | 125 | def getArrayLists(df, min_ply, max_ply): 126 | resList = [] 127 | labelList = [] 128 | attacksList = [] 129 | failCounter = 0 # count games that could not be converted, across the whole loop 130 | for moveList, label in tqdm(zip(df["moves"], df["opponentIsComp"])): 131 | # loop over games, TODO: optimize 132 | fenList = movesToFenList(moveList) 133 | if not fenList: 134 | failCounter += 1 135 | continue 136 | fenList = fenList[min_ply:max_ply] # only keep positions of the middle game!
137 | fenArray = fenList_to_fenArray(fenList) 138 | fen_per_channel = getFenPerChannel(fenArray) 139 | attacksTensor = getAttacksTensorOverTime(fenList) 140 | 141 | if np.count_nonzero(fen_per_channel) > 0: 142 | resList.append(fen_per_channel) 143 | labelList.append(label) 144 | attacksList.append(attacksTensor) 145 | LOGGER.info(f"failed games: {failCounter}") 146 | return resList, labelList, attacksList 147 | 148 | 149 | @click.command() 150 | @click.option( 151 | "--input-path", help="expects parquet file", required=True, type=click.Path() 152 | ) 153 | @click.option( 154 | "--params-path", help="configuration params", 155 | ) 156 | @click.option("--output-path", help="where to save result", required=True) 157 | @click.option("--output-path-labels", help="where to save labels", required=True) 158 | @click.option( 159 | "--output-path-attacks", help="where to save attack tensors", required=True 160 | ) 161 | def main( 162 | input_path, output_path, output_path_labels, output_path_attacks, params_path, 163 | ): 164 | df = pd.read_parquet(input_path) 165 | with open(params_path) as f: 166 | params = json.load(f) 167 | 168 | resList, labelList, attacksList = getArrayLists( 169 | df, params["plymin"], params["plymax"] 170 | ) 171 | res = np.stack(resList, axis=0).astype(int) 172 | labels = np.array(labelList).astype(int) 173 | attacks = np.stack(attacksList, axis=0).astype(int) 174 | 175 | LOGGER.info( 176 | f"output shape: {res.shape}, labels shape: {labels.shape}, attacks shape: {attacks.shape}" 177 | ) 178 | np.savez_compressed(output_path, res) 179 | np.savez_compressed(output_path_labels, labels) 180 | np.savez_compressed(output_path_attacks, attacks) 181 | 182 | 183 | if __name__ == "__main__": 184 | main() 185 | -------------------------------------------------------------------------------- /python_code/pgn_to_fen.py: -------------------------------------------------------------------------------- 1 | 
########################################################################################## 2 | # not my own work! # 3 | # copied from https://github.com/SindreSvendby/pgnToFen/blob/master/pgntofen.py # 4 | ########################################################################################## 5 | 6 | 7 | 8 | #!/bin/python 9 | # coding=utf8 10 | from __future__ import print_function 11 | from __future__ import division 12 | from functools import partial 13 | import math 14 | import re 15 | import os 16 | 17 | 18 | class PgnToFen: 19 | fen = 'rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR' 20 | whiteToMove = True 21 | internalChessBoard = [ 22 | 'R','N','B','Q','K','B','N','R', 23 | 'P','P','P','P','P','P','P','P', 24 | '1','1','1','1','1','1','1','1', 25 | '1','1','1','1','1','1','1','1', 26 | '1','1','1','1','1','1','1','1', 27 | '1','1','1','1','1','1','1','1', 28 | 'p','p','p','p','p','p','p','p', 29 | 'r','n','b','q','k','b','n','r'] 30 | enpassant = '-' 31 | castlingRights = 'KQkq' 32 | DEBUG = False 33 | lastMove = 'Before first move' 34 | fens = [] 35 | result = '' 36 | 37 | def getFullFen(self): 38 | return self.getFen() + ' ' + ('w ' if self.whiteToMove else 'b ') + self.enpassant + ' ' + (self.castlingRights if self.castlingRights else '-') 39 | 40 | def getFen(self): 41 | fenpos = '' 42 | for n in reversed((8,16,24,32,40,48,56,64)): 43 | emptyPosLength = 0; 44 | for i in self.internalChessBoard[n-8:n]: 45 | if(i != '1'): 46 | if(emptyPosLength != 0): 47 | fenpos = fenpos + str(emptyPosLength); 48 | emptyPosLength = 0 49 | fenpos = fenpos + i 50 | else: 51 | emptyPosLength = emptyPosLength + 1 52 | if(emptyPosLength != 0): 53 | fenpos = fenpos + str(emptyPosLength); 54 | fenpos = fenpos + '/' 55 | fenpos = fenpos[:-1] 56 | return fenpos 57 | 58 | def printFen(self): 59 | print(self.getFen()) 60 | 61 | def moves(self, moves): 62 | if isinstance(moves, str): 63 | nrReCompile = re.compile(r'[0-9]+\.') 64 | transformedMoves = nrReCompile.sub('',
moves) 65 | pgnMoves = transformedMoves.replace('  ', ' ').split(' ') 66 | result = pgnMoves[-1:][0] 67 | if(result in ['1/2-1/2', '1-0', '0-1']): 68 | self.result = result 69 | pgnMoves = pgnMoves[:-1] 70 | print('pgnMoves') 71 | print(pgnMoves) 72 | return self.pgnToFen(pgnMoves) 73 | else: 74 | return self.pgnToFen(moves) 75 | 76 | def pgnFile(self, file): 77 | pgnGames = { 78 | 'failed' : [], 79 | 'succeeded' : [], 80 | } 81 | started = False 82 | game_info = [] 83 | pgnMoves = '' 84 | for moves in open(file, 'rt').readlines(): 85 | 86 | if moves[:1] == '[': 87 | #print('game_info line: ', moves) 88 | game_info.append(moves) 89 | continue 90 | if moves[:2] == '1.': 91 | started = True 92 | if (moves == '\n' or moves == '\r\n') and started: 93 | try: 94 | #print('Processing ', game_info[0:6]) 95 | pgnToFen = PgnToFen() 96 | pgnToFen.resetBoard() 97 | fens = pgnToFen.moves(pgnMoves).getAllFens() 98 | pgnGames['succeeded'].append((game_info, fens)) 99 | except ValueError as e: 100 | pgnGames['failed'].append((game_info, '"' + pgnToFen.lastMove + '"', pgnToFen.getFullFen(), e)) 101 | except TypeError as e: 102 | pgnGames['failed'].append((game_info, '"' + pgnToFen.lastMove + '"', pgnToFen.getFullFen(), e)) 103 | except IndexError as e: 104 | raise IndexError(game_info, '"' + pgnToFen.lastMove + '"', pgnToFen.getFullFen(), e) 105 | pgnGames['failed'].append((game_info, '"' + pgnToFen.lastMove + '"', pgnToFen.getFullFen(), e)) 106 | except ZeroDivisionError as e: 107 | pgnGames['failed'].append((game_info, '"' + pgnToFen.lastMove + '"', pgnToFen.getFullFen(), e)) 108 | finally: 109 | started = False 110 | game_info = [] 111 | pgnMoves = '' 112 | if(started): 113 | pgnMoves = pgnMoves + ' ' + moves.replace('\n', '').replace('\r', '') 114 | return pgnGames 115 | 116 | def pgnToFen(self, moves): 117 | try: 118 | loopC = 1 119 | for move in moves: 120 | self.lastMove = move 121 | self.DEBUG and print('=========') 122 | self.DEBUG and print('Movenumber',loopC) 123 |
self.DEBUG and print('TO MOVE:', 'w' if self.whiteToMove else 'b') 124 | self.DEBUG and print('MOVE:', move) 125 | self.move(move) 126 | self.DEBUG and print('after move:') 127 | self.DEBUG and self.printBoard() 128 | loopC = loopC + 1 129 | self.fens.append(self.getFullFen()) 130 | self.sucess = True 131 | return self 132 | except ValueError: 133 | print('Converting PGN to FEN failed.') 134 | print('Move that failed:', self.lastMove) 135 | self.printBoard() 136 | print(self.getFullFen()) 137 | self.fens = [] 138 | self.sucess = False 139 | 140 | 141 | 142 | def move(self, move): 143 | try: 144 | self.lastMove = move 145 | self.handleAllmoves(move) 146 | if(self.whiteToMove): 147 | self.whiteToMove = False 148 | else: 149 | self.whiteToMove = True 150 | return self 151 | except ValueError: 152 | self.DEBUG and print('Converting PGN to FEN failed.') 153 | self.DEBUG and print('Move that failed:', self.lastMove) 154 | self.DEBUG and self.printBoard() 155 | self.DEBUG and print('FEN:', self.getFullFen()) 156 | 157 | def getAllFens(self): 158 | return self.fens 159 | 160 | def handleAllmoves(self, move): 161 | move = move.replace('+', '') 162 | move = move.replace('#', '') 163 | promote = '' 164 | if(move.find('=') > -1): 165 | promote = move[-1] 166 | move = move[:-2] 167 | if(move.find('-O') != -1): 168 | self.castelingMove(move) 169 | return; 170 | toPosition = move[-2:] 171 | move = move[:-2] 172 | if len(move) > 0: 173 | if move[0] in ['R','N','B','Q','K']: 174 | officer = move[0] 175 | move = move[1:] 176 | else: 177 | officer = 'P' 178 | else: 179 | officer = 'P' 180 | takes = False 181 | if 'x' in move: 182 | takes = True 183 | move = move[:-1] 184 | specificRow = "" 185 | specificCol = "" 186 | if len(move) > 0: 187 | if move in ['1','2','3','4','5','6','7','8']: 188 | specificRow = move 189 | elif move in ['a','b','c','d','e','f','g','h']: 190 | specificCol = move 191 | elif len(move) == 2: 192 | specificCol = move[0] 193 | specificRow = move[1] 194 | 
if(officer != 'P'): 195 | self.enpassant = '-' 196 | if(officer == 'N'): 197 | self.knightMove(toPosition, specificCol, specificRow) 198 | elif(officer == 'B'): 199 | self.bishopMove(toPosition, specificCol, specificRow) 200 | elif(officer == 'R'): 201 | self.rookMove(toPosition, specificCol, specificRow) 202 | elif(officer == 'Q'): 203 | self.queenMove(toPosition, specificCol, specificRow) 204 | elif(officer == 'K'): 205 | self.kingMove(toPosition, specificCol, specificRow) 206 | elif(officer == 'P'): 207 | self.pawnMove(toPosition, specificCol, specificRow, takes, promote) 208 | 209 | def castelingMove(self, move): 210 | if(len(move) == 3): #short castling 211 | if(self.whiteToMove): 212 | self.internalChessBoard[7] = '1' 213 | self.internalChessBoard[6] = 'K' 214 | self.internalChessBoard[5] = 'R' 215 | self.internalChessBoard[4] = '1' 216 | self.castlingRights = self.castlingRights.replace('KQ','') 217 | 218 | else: 219 | self.internalChessBoard[63] = '1' 220 | self.internalChessBoard[62] = 'k' 221 | self.internalChessBoard[61] = 'r' 222 | self.internalChessBoard[60] = '1' 223 | self.castlingRights = self.castlingRights.replace('kq', '') 224 | else: # long castling 225 | if(self.whiteToMove): 226 | self.internalChessBoard[0] = '1' 227 | self.internalChessBoard[2] = 'K' 228 | self.internalChessBoard[3] = 'R' 229 | self.internalChessBoard[4] = '1' 230 | self.castlingRights = self.castlingRights.replace('KQ', '') 231 | else: 232 | self.internalChessBoard[60] = '1' 233 | self.internalChessBoard[59] = 'r' 234 | self.internalChessBoard[58] = 'k' 235 | self.internalChessBoard[56] = '1' 236 | self.castlingRights = self.castlingRights.replace('kq', '') 237 | 238 | def queenMove(self, move, specificCol, specificRow): 239 | column = move[:1] 240 | row = move[1:2] 241 | chessBoardNumber = self.placeOnBoard(row, column) 242 | piece = 'Q' if self.whiteToMove else 'q' 243 | possibelPositons = [i for i, pos in enumerate(self.internalChessBoard) if pos == piece] 244 | 
self.validQueenMoves(possibelPositons, move, specificCol, specificRow) 245 | self.internalChessBoard[chessBoardNumber] = piece 246 | 247 | def validQueenMoves(self, posistions, move, specificCol, specificRow): 248 | newColumn = self.columnToInt(move[:1]) 249 | newRow = self.rowToInt(move[1:2]) 250 | newPos = self.placeOnBoard(newRow + 1, move[:1]) 251 | potensialPosisitionsToRemove=[] 252 | for pos in posistions: 253 | (existingRow, existingCol) = self.internalChessBoardPlaceToPlaceOnBoard(pos) 254 | diffRow = int(existingRow - newRow) 255 | diffCol = int(self.columnToInt(existingCol) - newColumn) 256 | if diffRow == 0 or diffCol == 0 or diffRow == diffCol or -diffRow == diffCol or diffRow == -diffCol: 257 | if not specificCol or specificCol == existingCol: 258 | if not specificRow or (int(specificRow) -1) == int(existingRow): 259 | xVect = 0 260 | yVect = 0 261 | if abs(diffRow) > abs(diffCol): 262 | xVect = -(diffCol / abs(diffRow)) 263 | yVect = -(diffRow / abs(diffRow)) 264 | else: 265 | xVect = -(diffCol / abs(diffCol)) 266 | yVect = -(diffRow / abs(diffCol)) 267 | checkPos = pos 268 | nothingInBetween = True 269 | while(checkPos != newPos): 270 | checkPos = int(checkPos + yVect * 8 + xVect) 271 | if(checkPos == newPos): 272 | continue 273 | if self.internalChessBoard[checkPos] != "1": 274 | nothingInBetween = False 275 | if nothingInBetween: 276 | potensialPosisitionsToRemove.append(pos) 277 | if len(potensialPosisitionsToRemove) == 1: 278 | correctPos = potensialPosisitionsToRemove[0]; 279 | else: 280 | if len(potensialPosisitionsToRemove) == 0: 281 | raise ValueError('Cant find a valid posistion to remove', potensialPosisitionsToRemove) 282 | notInCheckLineBindNewPos = partial(self.notInCheckLine, self.posOnBoard('K')) 283 | correctPosToRemove = list(filter(notInCheckLineBindNewPos, potensialPosisitionsToRemove)) 284 | if len(correctPosToRemove) > 1: 285 | raise ValueError('Several valid positions to remove from the board') 286 | if len(correctPosToRemove) 
== 0: 287 | raise ValueError('None valid positions to remove from the board') 288 | correctPos = correctPosToRemove[0] 289 | self.internalChessBoard[correctPos] = "1" 290 | return 291 | 292 | 293 | def rookMove(self, move, specificCol, specificRow): 294 | column = move[:1] 295 | row = move[1:2] 296 | chessBoardNumber = self.placeOnBoard(row, column) 297 | piece = 'R' if self.whiteToMove else 'r' 298 | possibelPositons = [i for i, pos in enumerate(self.internalChessBoard) if pos == piece] 299 | self.validRookMoves(possibelPositons, move, specificCol, specificRow) 300 | self.internalChessBoard[chessBoardNumber] = piece 301 | 302 | def validRookMoves(self, posistions, move, specificCol, specificRow): 303 | newColumn = self.columnToInt(move[:1]) 304 | newRow = self.rowToInt(move[1:2]) 305 | newPos = self.placeOnBoard(newRow + 1, move[:1]) 306 | potensialPosisitionsToRemove=[] 307 | if(len(posistions) == 1): 308 | self.internalChessBoard[posistions[0]] = "1" 309 | return 310 | for pos in posistions: 311 | (existingRow, existingCol) = self.internalChessBoardPlaceToPlaceOnBoard(pos) 312 | diffRow = int(existingRow - newRow) 313 | diffCol = int(self.columnToInt(existingCol) - newColumn) 314 | if diffRow == 0 or diffCol == 0: 315 | if not specificCol or specificCol == existingCol: 316 | if not specificRow or (int(specificRow) -1) == int(existingRow): 317 | xVect = 0 318 | yVect = 0 319 | if abs(diffRow) > abs(diffCol): 320 | xVect = -(diffCol / abs(diffRow)) 321 | yVect = -(diffRow / abs(diffRow)) 322 | else: 323 | xVect = -(diffCol / abs(diffCol)) 324 | yVect = -(diffRow / abs(diffCol)) 325 | checkPos = pos 326 | nothingInBetween = True 327 | while(checkPos != newPos): 328 | checkPos = int(checkPos + yVect * 8 + xVect) 329 | if(checkPos == newPos): 330 | continue 331 | if self.internalChessBoard[checkPos] != "1": 332 | nothingInBetween = False 333 | if nothingInBetween: 334 | potensialPosisitionsToRemove.append(pos) 335 | if len(potensialPosisitionsToRemove) == 1: 336 | 
correctPos = potensialPosisitionsToRemove[0]; 337 | else: 338 | if len(potensialPosisitionsToRemove) == 0: 339 | raise ValueError('Cant find a valid posistion to remove', potensialPosisitionsToRemove) 340 | notInCheckLineBindNewPos = partial(self.notInCheckLine, self.posOnBoard('K')) 341 | correctPosToRemove = list(filter(notInCheckLineBindNewPos, potensialPosisitionsToRemove)) 342 | if len(correctPosToRemove) > 1: 343 | raise ValueError('Several valid positions to remove from the board') 344 | correctPos = correctPosToRemove[0] 345 | if(correctPos == 0): 346 | self.castlingRights = self.castlingRights.replace('Q', '') 347 | elif(correctPos == 63): 348 | self.castlingRights = self.castlingRights.replace('k', '') 349 | elif(correctPos == 7): 350 | self.castlingRights = self.castlingRights.replace('K', '') 351 | elif(correctPos == 56): 352 | self.castlingRights = self.castlingRights.replace('q', '') 353 | self.internalChessBoard[correctPos] = "1" 354 | return 355 | 356 | def kingMove(self, move, specificCol, specificRow): 357 | column = move[:1] 358 | row = move[1:2] 359 | chessBoardNumber = self.placeOnBoard(row, column) 360 | piece = 'K' if self.whiteToMove else 'k' 361 | lostCastleRights = 'Q' if self.whiteToMove else 'q' 362 | kingPos = [i for i, pos in enumerate(self.internalChessBoard) if pos == piece] 363 | self.castlingRights = self.castlingRights.replace(piece, '') 364 | self.castlingRights = self.castlingRights.replace(lostCastleRights, '') 365 | self.internalChessBoard[chessBoardNumber] = piece 366 | self.internalChessBoard[kingPos[0]] = '1' 367 | 368 | 369 | def bishopMove(self, move, specificCol, specificRow): 370 | column = move[:1] 371 | row = move[1:2] 372 | chessBoardNumber = self.placeOnBoard(row, column) 373 | piece = 'B' if self.whiteToMove else 'b' 374 | possibelPositons = [i for i, pos in enumerate(self.internalChessBoard) if pos == piece] 375 | self.validBishopMoves(possibelPositons, move, specificCol, specificRow) 376 |
self.internalChessBoard[chessBoardNumber] = piece 377 | 378 | def validBishopMoves(self, posistions, move, specificCol, specificRow): 379 | newColumn = self.columnToInt(move[:1]) 380 | newRow = self.rowToInt(move[1:2]) 381 | newPos = self.placeOnBoard(newRow + 1, move[:1]) 382 | potensialPosisitionsToRemove = [] 383 | for pos in posistions: 384 | (existingRow, existingCol) = self.internalChessBoardPlaceToPlaceOnBoard(pos) 385 | diffRow = int(existingRow - newRow) 386 | diffCol = int(self.columnToInt(existingCol) - newColumn) 387 | if diffRow == diffCol or -diffRow == diffCol or diffRow == -diffCol: 388 | if not specificCol or specificCol == existingCol: 389 | if not specificRow or (int(specificRow) -1) == int(existingRow): 390 | xVect = 0 391 | yVect = 0 392 | if abs(diffRow) > abs(diffCol): 393 | xVect = -(diffCol / abs(diffRow)) 394 | yVect = -(diffRow / abs(diffRow)) 395 | else: 396 | xVect = -(diffCol / abs(diffCol)) 397 | yVect = -(diffRow / abs(diffCol)) 398 | checkPos = pos 399 | nothingInBetween = True 400 | while(checkPos != newPos): 401 | checkPos = int(checkPos + yVect * 8 + xVect) 402 | if(checkPos == newPos): 403 | continue 404 | if self.internalChessBoard[checkPos] != "1": 405 | nothingInBetween = False 406 | if nothingInBetween: 407 | potensialPosisitionsToRemove.append(pos) 408 | if len(potensialPosisitionsToRemove) == 1: 409 | correctPos = potensialPosisitionsToRemove[0]; 410 | else: 411 | if len(potensialPosisitionsToRemove) == 0: 412 | raise ValueError('Cant find a valid posistion to remove', potensialPosisitionsToRemove) 413 | notInCheckLineBindNewPos = partial(self.notInCheckLine, self.posOnBoard('K')) 414 | correctPosToRemove = list(filter(notInCheckLineBindNewPos, potensialPosisitionsToRemove)) 415 | if len(correctPosToRemove) > 1: 416 | raise ValueError('Several valid positions to remove from the board') 417 | correctPos = correctPosToRemove[0] 418 | self.internalChessBoard[correctPos] = "1" 419 | 420 | def knightMove(self, move, 
specificCol, specificRow): 421 | column = move[:1] 422 | row = move[1:2] 423 | chessBoardNumber = self.placeOnBoard(row, column) 424 | piece = 'N' if self.whiteToMove else 'n' 425 | knightPositons = [i for i, pos in enumerate(self.internalChessBoard) if pos == piece] 426 | self.validKnighMoves(knightPositons, move, specificCol, specificRow) 427 | self.internalChessBoard[chessBoardNumber] = piece 428 | 429 | def validKnighMoves(self, posistions, move, specificCol, specificRow): 430 | newColumn = self.columnToInt(move[:1]) 431 | newRow = self.rowToInt(move[1:2]) 432 | potensialPosisitionsToRemove = [] 433 | for pos in posistions: 434 | (existingRow, existingCol) = self.internalChessBoardPlaceToPlaceOnBoard(pos) 435 | validatePos = str(int(existingRow - newRow)) + str(int(self.columnToInt(existingCol) - newColumn)) 436 | if validatePos in ['2-1','21','1-2','12','-1-2','-12','-2-1','-21']: 437 | if not specificCol or specificCol == existingCol: 438 | if not specificRow or (int(specificRow) -1) == int(existingRow): 439 | potensialPosisitionsToRemove.append(pos) 440 | if len(potensialPosisitionsToRemove) == 1: 441 | correctPos = potensialPosisitionsToRemove[0]; 442 | else: 443 | if len(potensialPosisitionsToRemove) == 0: 444 | raise ValueError('Cant find a valid posistion to remove', potensialPosisitionsToRemove) 445 | notInCheckLineBindNewPos = partial(self.notInCheckLine, self.posOnBoard('K')) 446 | correctPosToRemove = list(filter(notInCheckLineBindNewPos, potensialPosisitionsToRemove)) 447 | if len(correctPosToRemove) > 1: 448 | raise ValueError('Several valid positions to remove from the board') 449 | if len(correctPosToRemove) == 0: 450 | raise ValueError('None valid positions to remove from the board') 451 | correctPos = correctPosToRemove[0] 452 | self.internalChessBoard[correctPos] = "1" 453 | return 454 | def pawnMove(self, toPosition, specificCol, specificRow, takes, promote): 455 | column = toPosition[:1] 456 | row = toPosition[1:2] 457 | chessBoardNumber = 
self.placeOnBoard(row, column) 458 | if(promote): 459 | piece = promote if self.whiteToMove else promote.lower() 460 | else: 461 | piece = 'P' if self.whiteToMove else 'p' 462 | self.internalChessBoard[chessBoardNumber] = piece 463 | if(takes): 464 | removeFromRow = (int(row) - 1) if self.whiteToMove else (int(row) + 1) 465 | posistion = self.placeOnBoard(removeFromRow, specificCol) 466 | piece = self.internalChessBoard[posistion] = '1' 467 | if(self.enpassant != '-'): 468 | enpassantPos = self.placeOnBoard(self.enpassant[1], self.enpassant[0]) 469 | toPositionPos = self.placeOnBoard(toPosition[1], toPosition[0]) 470 | if(self.enpassant == toPosition): 471 | if(self.whiteToMove == True): 472 | self.internalChessBoard[chessBoardNumber - 8] = '1' 473 | else: 474 | self.internalChessBoard[chessBoardNumber + 8] = '1' 475 | return 476 | 477 | else: 478 | #run piece one more time if case of promotion 479 | piece = 'P' if self.whiteToMove else 'p' 480 | self.updateOldLinePos(piece,chessBoardNumber, toPosition) 481 | 482 | 483 | def updateOldLinePos(self, char, posistion, toPosition): 484 | startPos = posistion 485 | counter = 0; 486 | piece = '' 487 | step = 8 488 | while(posistion >= 0 and posistion < 64): 489 | if(piece == char): 490 | if(abs(posistion - startPos) > 10): 491 | (row, column) = self.internalChessBoardPlaceToPlaceOnBoard(startPos) 492 | rowAdjustedByColor = -1 if self.whiteToMove else 1 493 | enpassant = str(column) + str(int(row) + 1 + rowAdjustedByColor) 494 | self.enpassant = enpassant 495 | else: 496 | self.enpassant = '-' 497 | piece = self.internalChessBoard[posistion] = '1' 498 | return; 499 | else: 500 | if(self.whiteToMove == True): 501 | posistion = posistion - step 502 | else: 503 | posistion = posistion + step 504 | piece = self.internalChessBoard[posistion] 505 | 506 | 507 | def placeOnBoard(self, row, column): 508 | # returns internalChessBoard place 509 | return 8 * (int(row) - 1) + self.columnToInt(column); 510 | 511 | def 
internalChessBoardPlaceToPlaceOnBoard(self, chessPos): 512 | column = int(chessPos) % 8 513 | row = math.floor(chessPos/8) 514 | return (row, self.intToColum(column)) 515 | 516 | def rowToInt(self, n): 517 | return int(n)-1 518 | 519 | def columnToInt(self, char): 520 | # TODO: char.toLowerCase??? 521 | if(char == 'a'): 522 | return 0 523 | elif(char == 'b'): 524 | return 1 525 | elif(char == 'c'): 526 | return 2 527 | elif(char == 'd'): 528 | return 3 529 | elif(char == 'e'): 530 | return 4 531 | elif(char == 'f'): 532 | return 5 533 | elif(char == 'g'): 534 | return 6 535 | elif(char == 'h'): 536 | return 7 537 | 538 | def intToColum(self, num): 539 | # TODO: char.toLowerCase??? 540 | if(num == 0): 541 | return 'a' 542 | elif(num == 1): 543 | return 'b' 544 | elif(num == 2): 545 | return 'c' 546 | elif(num == 3): 547 | return 'd' 548 | elif(num == 4): 549 | return 'e' 550 | elif(num == 5): 551 | return 'f' 552 | elif(num == 6): 553 | return 'g' 554 | elif(num == 7): 555 | return 'h' 556 | 557 | def resetBoard(self): 558 | self.fen = 'rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR' 559 | self.whiteToMove = True 560 | self.enpassant = '-' 561 | self.internalChessBoard = [ 562 | 'R','N','B','Q','K','B','N','R', 563 | 'P','P','P','P','P','P','P','P', 564 | '1','1','1','1','1','1','1','1', 565 | '1','1','1','1','1','1','1','1', 566 | '1','1','1','1','1','1','1','1', 567 | '1','1','1','1','1','1','1','1', 568 | 'p','p','p','p','p','p','p','p', 569 | 'r','n','b','q','k','b','n','r'] 570 | self.result = '' 571 | 572 | def printBoard(self): 573 | loop = 1 574 | for i in self.internalChessBoard: 575 | print(i, end=' ') 576 | if(loop%8 == 0): 577 | print() 578 | loop = loop + 1 579 | 580 | def notInCheckLine(self, kingPos, piecePos): 581 | """ 582 | Verifies that the piece is not standing in "line of fire" between and enemy piece and your king as the only piece 583 | :returns: True if the piece can move 584 | """ 585 | return self.checkLine(kingPos, piecePos) 586 | 587 | def 
checkLine(self, kingPos, piecePos): 588 | (kingRowInt, kingColumn) = self.internalChessBoardPlaceToPlaceOnBoard(kingPos) 589 | kingColumnInt = self.columnToInt(kingColumn) 590 | (pieceRowInt, pieceColumn) = self.internalChessBoardPlaceToPlaceOnBoard(piecePos); 591 | pieceColumnInt = self.columnToInt(pieceColumn) 592 | 593 | diffRow = int(kingRowInt - pieceRowInt) 594 | diffCol = int(kingColumnInt - pieceColumnInt) 595 | if (abs(diffRow) != abs(diffCol)) and diffRow != 0 and diffCol != 0: 596 | return True 597 | if abs(diffRow) > abs(diffCol): 598 | xVect = (diffCol / abs(diffRow)) 599 | yVect = -(diffRow / abs(diffRow)) 600 | else: 601 | xVect = -(diffCol / abs(diffCol)) 602 | yVect = -(diffRow / abs(diffCol)) 603 | checkPos = kingPos 604 | nothingInBetween = True 605 | while checkPos != piecePos and (checkPos < 64 and checkPos > 0): 606 | checkPos = int(checkPos + yVect * 8 + xVect) 607 | if(checkPos == piecePos): 608 | continue 609 | if self.internalChessBoard[checkPos] != "1": 610 | #print('Something between king and piece, returning a false value') 611 | # Piece between the king and the piece can not be a self-disvoery-check. 
612 | return True 613 | #print('No piece between the king and the piece, need to verify if an enemy piece with the possibily to go that direction exist') 614 | # No piece between the king and the piece, need to verify if an enemy piece with the possibily to go that direction exist 615 | 616 | columnNr = (piecePos % 8) 617 | if(xVect == 1): 618 | columnsLeft = 7- columnNr 619 | else: 620 | columnsLeft = columnNr 621 | posInMove = (yVect * 8) + xVect 622 | 623 | while checkPos >= 0 and checkPos < 64 and columnsLeft > -1: 624 | columnsLeft = columnsLeft - abs(xVect) 625 | checkPos = int(checkPos + posInMove) 626 | if(checkPos < 0 or checkPos > 63): 627 | continue 628 | if self.internalChessBoard[checkPos] in self.getOppositePieces(["Q", "R"]) and (xVect == 0 or yVect == 0): 629 | return False 630 | elif self.internalChessBoard[checkPos] in self.getOppositePieces(["Q", "B"]) and True: 631 | #TODO: check direction 632 | return False 633 | #else: 634 | #print('Friendly pieces or empty:', self.internalChessBoard[checkPos], checkPos) 635 | return True 636 | 637 | def getOppositePieces(self, pieces): 638 | """" 639 | Takes a list of pieces and returns it in uppercase if blacks turn, or lowercase if white. 640 | """ 641 | return map(lambda p: p.lower() if self.whiteToMove else p.upper(), pieces) 642 | 643 | 644 | def posOnBoard(self, piece): 645 | """ 646 | :param piece: a case _sensitiv_ one letter string. 
Valid 'K', 'Q', 'N', 'P', 'B', 'R', will be transformed to lowercase if it's black's turn to move 647 | :return int|[int]: Returns the posistion(s) on the board for a piece, if only one pos, a int is return, else a list of int is returned 648 | """ 649 | correctPiece = piece if self.whiteToMove else piece.lower() 650 | posistionsOnBoard = [i for i, pos in enumerate(self.internalChessBoard) if pos == correctPiece] 651 | if len(posistionsOnBoard) == 1: 652 | return posistionsOnBoard[0] 653 | else: 654 | return posistionsOnBoard 655 | 656 | if __name__ == "__main__": 657 | pgnFormat = 'c4 Nc6 Nc3 e5 Nf3 Nf6 g3 d5 cxd5 Nxd5 Bg2 Nb6 O-O Be7 a3 Be6 b4 a5 b5 Nd4 Nxd4 exd4 Na4 Bd5 Nxb6 cxb6 Bxd5' 658 | converter = PgnToFen() 659 | for move in pgnFormat.split(' '): 660 | converter.move(move) 661 | print(converter.getFullFen()) -------------------------------------------------------------------------------- /python_code/pgn_to_json.py: -------------------------------------------------------------------------------- 1 | ########################################################################################## 2 | # not my own work! 
taken from # 3 | # https://github.com/JonathanCauchi/PGN-to-JSON-Parser/blob/master/pgn_to_json.py # 4 | ########################################################################################## 5 | 6 | #!/usr/bin/env python 7 | # -*- coding: utf-8 -*- 8 | import json 9 | import chess.pgn 10 | import re 11 | import sys 12 | import os.path 13 | from tqdm import tqdm 14 | import pathlib 15 | import logging 16 | from datetime import datetime 17 | import traceback 18 | 19 | log = logging.getLogger().error 20 | 21 | for i in [1, 2]: 22 | dir_ = sys.argv[i] 23 | if not os.path.exists(dir_): 24 | raise Exception(dir_ + " not found") 25 | 26 | max_games = int(sys.argv[3]) 27 | 28 | is_join = False 29 | if len(sys.argv) == 5: 30 | if sys.argv[4] == "join": 31 | is_join = True 32 | 33 | 34 | inp_dir = pathlib.Path(sys.argv[1]) 35 | out_dir = pathlib.Path(sys.argv[2]) 36 | 37 | 38 | def get_file_list(local_path): 39 | tree = os.walk(str(local_path)) 40 | file_list = [] 41 | out = [] 42 | test = r".+pgn$" 43 | for i in tree: 44 | file_list = i[2] 45 | 46 | for name in file_list: 47 | if len(re.findall(test, name)): 48 | out.append(str(local_path / name)) 49 | return out 50 | 51 | 52 | def get_data(pgn_file, max_games): 53 | node = chess.pgn.read_game(pgn_file) 54 | error_counter = 0 55 | game_counter = 0 56 | while node is not None and game_counter < max_games: # stop after max_games games 57 | game_counter += 1 58 | try: 59 | data = node.headers 60 | 61 | data["moves"] = [] 62 | 63 | while node.variations: 64 | next_node = node.variation(0) 65 | data["moves"].append( 66 | re.sub(r"\{.*?\}", "", node.board().san(next_node.move)) 67 | ) 68 | node = next_node 69 | 70 | out_dict = {} 71 | 72 | for key in data.keys(): 73 | out_dict[key] = data.get(key) 74 | 75 | # log(data.get('Event')) 76 | node = chess.pgn.read_game(pgn_file) 77 | yield out_dict 78 | except Exception: 79 | error_counter = error_counter + 1 80 | print("skipping {}".format(error_counter)) 81 | node = chess.pgn.read_game(pgn_file) 82 | continue 83
| 84 | 85 | def convert_file(file_path, max_games): 86 | file_name = file_path.name.replace(file_path.suffix, "") + ".json" 87 | log("convert file " + file_path.name) 88 | out_list = [] 89 | try: 90 | json_file = open(str(out_dir / file_name), "w") 91 | pgn_file = open(str(file_path), encoding="utf-8-sig") # changed encoding 92 | 93 | for count_d, data in tqdm(enumerate(get_data(pgn_file, max_games), start=0)): 94 | # log(file_path.name + " " + str(count_d)) 95 | out_list.append(data) 96 | 97 | log(" save " + file_path.name) 98 | json.dump(out_list, json_file) 99 | json_file.close() 100 | log("done") 101 | except Exception as e: 102 | log(traceback.format_exc(10)) 103 | log("ERROR file " + file_name + " not converted") 104 | 105 | 106 | def create_join_file(file_list, max_games): 107 | log(" create_join_file ") 108 | name = str(out_dir / "join_data.json") 109 | open(name, "w").close() 110 | json_file = open(str(out_dir / "join_data.json"), "a") 111 | json_file.write("[") 112 | for count_f, file in enumerate(file_list, start=0): 113 | pgn_file = open(file, encoding="ISO-8859-1") 114 | for count_d, data in tqdm(enumerate(get_data(pgn_file, max_games), start=0)): 115 | # log(str(count_f) + " " + str(count_d)) 116 | if count_f or count_d: 117 | json_file.write(",") 118 | data_str = json.dumps(data) 119 | json_file.write(data_str) 120 | log(pathlib.Path(file).name) 121 | json_file.write("]") 122 | json_file.close() 123 | 124 | 125 | file_list = get_file_list(inp_dir) 126 | 127 | start_time = datetime.now() 128 | if not is_join: 129 | for file in file_list: 130 | convert_file(pathlib.Path(file), max_games) 131 | else: 132 | create_join_file(file_list, max_games) 133 | 134 | end_time = datetime.now() 135 | log("time " + str(end_time - start_time)) 136 | 137 | -------------------------------------------------------------------------------- /python_code/preprocess.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 
| import click 3 | import logging 4 | import json 5 | 6 | logging.basicConfig(level=logging.INFO,) 7 | 8 | LOGGER = logging.getLogger() 9 | 10 | 11 | def loadData(file_list): 12 | df = pd.DataFrame() 13 | for file in file_list: 14 | data = pd.read_json(file) 15 | df = pd.concat([df, data], ignore_index=True) # DataFrame.append was removed in pandas 2.0 16 | return df 17 | 18 | 19 | def balanceEngineRatio(df): 20 | df_comp = df[df.opponentIsComp == 1.0] 21 | df_human = df[df.opponentIsComp == 0.0] 22 | n_min = min(len(df_comp), len(df_human)) 23 | df = pd.concat([df_comp[:n_min], df_human[:n_min]]) 24 | df = df.sample(frac=1) # reshuffle 25 | df = df.reset_index(drop=True) 26 | return df 27 | 28 | 29 | def prefilterGames(df, params, human_color): 30 | """we want games where the human player (with color human_color) lost, and the opponent is 50/50 human or computer""" 31 | min_game_length = params["plymax"] # need at least as many moves as will be used in algorithm 32 | max_game_length = params["max_game_length"] 33 | df = df[ 34 | (df.PlyCount > min_game_length) & (df.PlyCount < max_game_length) 35 | ] # restrict game lengths. First moves are irrelevant as they can be memorized. 36 | 37 | # choose timecontrol. 
Very short games are weird (and hard to use engines on due to computation time) 38 | df = df[df["TimeControl"].isin(params["timecontrols"])] 39 | 40 | df.loc[ 41 | df["WhiteIsComp"].isnull(), "WhiteIsComp" 42 | ] = 0.0 # field is null for human vs human games 43 | df.loc[df["WhiteIsComp"] == "Yes", "WhiteIsComp"] = 1.0 44 | df.loc[df["BlackIsComp"].isnull(), "BlackIsComp"] = 0.0 45 | df.loc[df["BlackIsComp"] == "Yes", "BlackIsComp"] = 1.0 46 | 47 | if human_color == "White": 48 | df = df[df["WhiteIsComp"] == 0.0] 49 | df = df[ 50 | df.Result == "0-1" 51 | ] # only interested in games where the human player lost 52 | df = df.rename(columns={"BlackIsComp": "opponentIsComp"}) 53 | elif human_color == "Black": 54 | df = df[df["BlackIsComp"] == 0.0] 55 | df = df[df.Result == "1-0"] 56 | df = df.rename(columns={"WhiteIsComp": "opponentIsComp"}) 57 | 58 | df = df[["opponentIsComp", "moves"]] 59 | df = balanceEngineRatio(df) 60 | return df 61 | 62 | 63 | @click.command() 64 | @click.option( 65 | "--input-paths", help="filenames of input files, separated by comma", required=True 66 | ) 67 | @click.option("--params-path", help="path to config file", required=True) 68 | @click.option("--output-path", help="where to save result (as parquet)", required=True) 69 | @click.option( 70 | "--human-color", help="Black or White, what was the human playing", required=True 71 | ) 72 | def main(input_paths, output_path, params_path, human_color): 73 | with open(params_path) as f: 74 | params = json.load(f) 75 | file_list = input_paths.split(",") 76 | n_files = len(file_list) 77 | LOGGER.info("found {} files".format(n_files)) 78 | df = loadData(file_list) 79 | df = prefilterGames(df, params, human_color) 80 | 81 | LOGGER.info("number of games after preprocessing: {}".format(df.shape[0])) 82 | df.to_parquet(output_path) 83 | 84 | 85 | if __name__ == "__main__": 86 | main() 87 | -------------------------------------------------------------------------------- 
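The labelling rules in `prefilterGames` and the class balancing in `balanceEngineRatio` can be sketched without pandas. This is only an illustration under stated assumptions: `label_lost_game` and `balance` are hypothetical helpers that do not exist in the repo, the `PlyCount` and `TimeControl` filters are left out, and any `*IsComp` value other than `"Yes"` is treated as human, which is a simplification of the original column handling.

```python
import random


def label_lost_game(headers, human_color):
    """Return 1.0 (engine opponent) or 0.0 (human opponent) for a game the
    human lost, or None if the game should be dropped.

    `headers` is a plain dict of PGN header fields (hypothetical input format).
    """
    opponent_color = "Black" if human_color == "White" else "White"
    losing_result = "0-1" if human_color == "White" else "1-0"
    if headers.get(f"{human_color}IsComp") == "Yes":
        return None  # the supposedly human side was itself an engine
    if headers.get("Result") != losing_result:
        return None  # only games the human player lost are kept
    return 1.0 if headers.get(f"{opponent_color}IsComp") == "Yes" else 0.0


def balance(labels, seed=0):
    """Keep equally many engine and human games, then shuffle."""
    comp = [x for x in labels if x == 1.0]
    human = [x for x in labels if x == 0.0]
    n_min = min(len(comp), len(human))
    balanced = comp[:n_min] + human[:n_min]
    random.Random(seed).shuffle(balanced)
    return balanced


games = [
    {"Result": "0-1", "BlackIsComp": "Yes"},  # human (White) lost to an engine
    {"Result": "0-1"},                        # human lost to a human
    {"Result": "1-0"},                        # human won: dropped
]
labels = [label_lost_game(g, "White") for g in games]
kept = balance([x for x in labels if x is not None])
```

Truncating both classes to the size of the smaller one discards data, but it keeps a 50/50 baseline, so accuracy is easy to interpret.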
/python_code/test_moves_to_fen.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | print(sys.path) 4 | import moves_to_fen as mtf 5 | import numpy as np 6 | 7 | 8 | def test_movesToFenList(): 9 | movesList = ["e4", "e5"] 10 | res = mtf.movesToFenList(movesList) 11 | expected = [ 12 | "rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR", # e4 13 | "rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR", # e5 14 | ] 15 | assert res == expected 16 | 17 | 18 | def test_fenList_to_fenArray(): 19 | fenList = [ 20 | "rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR", # e4 21 | "rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR", # e5 22 | ] 23 | expected_res = np.array( 24 | [ 25 | [ 26 | ["r", "n", "b", "q", "k", "b", "n", "r"], 27 | ["p", "p", "p", "p", "p", "p", "p", "p"], 28 | ["0", "0", "0", "0", "0", "0", "0", "0"], 29 | ["0", "0", "0", "0", "0", "0", "0", "0"], 30 | ["0", "0", "0", "0", "P", "0", "0", "0"], 31 | ["0", "0", "0", "0", "0", "0", "0", "0"], 32 | ["P", "P", "P", "P", "0", "P", "P", "P"], 33 | ["R", "N", "B", "Q", "K", "B", "N", "R"], 34 | ], 35 | [ 36 | ["r", "n", "b", "q", "k", "b", "n", "r"], 37 | ["p", "p", "p", "p", "0", "p", "p", "p"], 38 | ["0", "0", "0", "0", "0", "0", "0", "0"], 39 | ["0", "0", "0", "0", "p", "0", "0", "0"], 40 | ["0", "0", "0", "0", "P", "0", "0", "0"], 41 | ["0", "0", "0", "0", "0", "0", "0", "0"], 42 | ["P", "P", "P", "P", "0", "P", "P", "P"], 43 | ["R", "N", "B", "Q", "K", "B", "N", "R"], 44 | ], 45 | ] 46 | ) 47 | res = mtf.fenList_to_fenArray(fenList) 48 | print(res) 49 | assert np.array_equal(res, expected_res) 50 | 51 | 52 | def test_getFenPerChannel(): 53 | # channel order: ("P", "R", "N", "B", "Q", "K", "p", "r", "n", "b", "q", "k") 54 | fenList = np.array( 55 | [ 56 | [ 57 | ["r", "n", "b", "q", "k", "b", "n", "r"], 58 | ["p", "p", "p", "p", "p", "p", "p", "p"], 59 | ["0", "0", "0", "0", "0", "0", "0", "0"], 60 | ["0", "0", "0", "0", "0", "0", "0", "0"], 61 | ["0", "0", "0", 
"0", "P", "0", "0", "0"], 62 | ["0", "0", "0", "0", "0", "0", "0", "0"], 63 | ["P", "P", "P", "P", "0", "P", "P", "P"], 64 | ["R", "N", "B", "Q", "K", "B", "N", "R"], 65 | ], 66 | ] 67 | ) 68 | res = mtf.getFenPerChannel(fenList) 69 | expected_res = np.array( 70 | [ 71 | [ 72 | [ 73 | [0, 0, 0, 0, 0, 0, 0, 0], 74 | [0, 0, 0, 0, 0, 0, 0, 0], 75 | [0, 0, 0, 0, 0, 0, 0, 0], 76 | [0, 0, 0, 0, 0, 0, 0, 0], 77 | [0, 0, 0, 0, 1, 0, 0, 0], 78 | [0, 0, 0, 0, 0, 0, 0, 0], 79 | [1, 1, 1, 1, 0, 1, 1, 1], 80 | [0, 0, 0, 0, 0, 0, 0, 0], 81 | ], 82 | [ 83 | [0, 0, 0, 0, 0, 0, 0, 0], 84 | [0, 0, 0, 0, 0, 0, 0, 0], 85 | [0, 0, 0, 0, 0, 0, 0, 0], 86 | [0, 0, 0, 0, 0, 0, 0, 0], 87 | [0, 0, 0, 0, 0, 0, 0, 0], 88 | [0, 0, 0, 0, 0, 0, 0, 0], 89 | [0, 0, 0, 0, 0, 0, 0, 0], 90 | [1, 0, 0, 0, 0, 0, 0, 1], 91 | ], 92 | [ 93 | [0, 0, 0, 0, 0, 0, 0, 0], 94 | [0, 0, 0, 0, 0, 0, 0, 0], 95 | [0, 0, 0, 0, 0, 0, 0, 0], 96 | [0, 0, 0, 0, 0, 0, 0, 0], 97 | [0, 0, 0, 0, 0, 0, 0, 0], 98 | [0, 0, 0, 0, 0, 0, 0, 0], 99 | [0, 0, 0, 0, 0, 0, 0, 0], 100 | [0, 1, 0, 0, 0, 0, 1, 0], 101 | ], 102 | [ 103 | [0, 0, 0, 0, 0, 0, 0, 0], 104 | [0, 0, 0, 0, 0, 0, 0, 0], 105 | [0, 0, 0, 0, 0, 0, 0, 0], 106 | [0, 0, 0, 0, 0, 0, 0, 0], 107 | [0, 0, 0, 0, 0, 0, 0, 0], 108 | [0, 0, 0, 0, 0, 0, 0, 0], 109 | [0, 0, 0, 0, 0, 0, 0, 0], 110 | [0, 0, 1, 0, 0, 1, 0, 0], 111 | ], 112 | [ 113 | [0, 0, 0, 0, 0, 0, 0, 0], 114 | [0, 0, 0, 0, 0, 0, 0, 0], 115 | [0, 0, 0, 0, 0, 0, 0, 0], 116 | [0, 0, 0, 0, 0, 0, 0, 0], 117 | [0, 0, 0, 0, 0, 0, 0, 0], 118 | [0, 0, 0, 0, 0, 0, 0, 0], 119 | [0, 0, 0, 0, 0, 0, 0, 0], 120 | [0, 0, 0, 1, 0, 0, 0, 0], 121 | ], 122 | [ 123 | [0, 0, 0, 0, 0, 0, 0, 0], 124 | [0, 0, 0, 0, 0, 0, 0, 0], 125 | [0, 0, 0, 0, 0, 0, 0, 0], 126 | [0, 0, 0, 0, 0, 0, 0, 0], 127 | [0, 0, 0, 0, 0, 0, 0, 0], 128 | [0, 0, 0, 0, 0, 0, 0, 0], 129 | [0, 0, 0, 0, 0, 0, 0, 0], 130 | [0, 0, 0, 0, 1, 0, 0, 0], 131 | ], 132 | [ 133 | [0, 0, 0, 0, 0, 0, 0, 0], 134 | [1, 1, 1, 1, 1, 1, 1, 1], 135 | [0, 0, 0, 0, 0, 0, 0, 0], 136 | [0, 
0, 0, 0, 0, 0, 0, 0], 137 | [0, 0, 0, 0, 0, 0, 0, 0], 138 | [0, 0, 0, 0, 0, 0, 0, 0], 139 | [0, 0, 0, 0, 0, 0, 0, 0], 140 | [0, 0, 0, 0, 0, 0, 0, 0], 141 | ], 142 | [ 143 | [1, 0, 0, 0, 0, 0, 0, 1], 144 | [0, 0, 0, 0, 0, 0, 0, 0], 145 | [0, 0, 0, 0, 0, 0, 0, 0], 146 | [0, 0, 0, 0, 0, 0, 0, 0], 147 | [0, 0, 0, 0, 0, 0, 0, 0], 148 | [0, 0, 0, 0, 0, 0, 0, 0], 149 | [0, 0, 0, 0, 0, 0, 0, 0], 150 | [0, 0, 0, 0, 0, 0, 0, 0], 151 | ], 152 | [ 153 | [0, 1, 0, 0, 0, 0, 1, 0], 154 | [0, 0, 0, 0, 0, 0, 0, 0], 155 | [0, 0, 0, 0, 0, 0, 0, 0], 156 | [0, 0, 0, 0, 0, 0, 0, 0], 157 | [0, 0, 0, 0, 0, 0, 0, 0], 158 | [0, 0, 0, 0, 0, 0, 0, 0], 159 | [0, 0, 0, 0, 0, 0, 0, 0], 160 | [0, 0, 0, 0, 0, 0, 0, 0], 161 | ], 162 | [ 163 | [0, 0, 1, 0, 0, 1, 0, 0], 164 | [0, 0, 0, 0, 0, 0, 0, 0], 165 | [0, 0, 0, 0, 0, 0, 0, 0], 166 | [0, 0, 0, 0, 0, 0, 0, 0], 167 | [0, 0, 0, 0, 0, 0, 0, 0], 168 | [0, 0, 0, 0, 0, 0, 0, 0], 169 | [0, 0, 0, 0, 0, 0, 0, 0], 170 | [0, 0, 0, 0, 0, 0, 0, 0], 171 | ], 172 | [ 173 | [0, 0, 0, 1, 0, 0, 0, 0], 174 | [0, 0, 0, 0, 0, 0, 0, 0], 175 | [0, 0, 0, 0, 0, 0, 0, 0], 176 | [0, 0, 0, 0, 0, 0, 0, 0], 177 | [0, 0, 0, 0, 0, 0, 0, 0], 178 | [0, 0, 0, 0, 0, 0, 0, 0], 179 | [0, 0, 0, 0, 0, 0, 0, 0], 180 | [0, 0, 0, 0, 0, 0, 0, 0], 181 | ], 182 | [ 183 | [0, 0, 0, 0, 1, 0, 0, 0], 184 | [0, 0, 0, 0, 0, 0, 0, 0], 185 | [0, 0, 0, 0, 0, 0, 0, 0], 186 | [0, 0, 0, 0, 0, 0, 0, 0], 187 | [0, 0, 0, 0, 0, 0, 0, 0], 188 | [0, 0, 0, 0, 0, 0, 0, 0], 189 | [0, 0, 0, 0, 0, 0, 0, 0], 190 | [0, 0, 0, 0, 0, 0, 0, 0], 191 | ], 192 | ] 193 | ] 194 | ) 195 | print(res.shape, expected_res.shape) 196 | 197 | assert np.array_equal(res, expected_res) 198 | 199 | 200 | def test_getAttacksPerPiece(): 201 | fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR" # normal starting position 202 | res = mtf.getAttacksPerPiece(fen) 203 | print(res) 204 | expected = [ 205 | ("R", [1, 8]), 206 | ("N", [11, 16, 18]), 207 | ("B", [9, 11]), 208 | ("Q", [2, 4, 10, 11, 12]), 209 | ("K", [3, 5, 11, 12, 13]), 210 | 
("B", [12, 14]), 211 | ("N", [12, 21, 23]), 212 | ("R", [6, 15]), 213 | ("P", [17]), 214 | ("P", [16, 18]), 215 | ("P", [17, 19]), 216 | ("P", [18, 20]), 217 | ("P", [19, 21]), 218 | ("P", [20, 22]), 219 | ("P", [21, 23]), 220 | ("P", [22]), 221 | ("p", [41]), 222 | ("p", [40, 42]), 223 | ("p", [41, 43]), 224 | ("p", [42, 44]), 225 | ("p", [43, 45]), 226 | ("p", [44, 46]), 227 | ("p", [45, 47]), 228 | ("p", [46]), 229 | ("r", [48, 57]), 230 | ("n", [40, 42, 51]), 231 | ("b", [49, 51]), 232 | ("q", [50, 51, 52, 58, 60]), 233 | ("k", [51, 52, 53, 59, 61]), 234 | ("b", [52, 54]), 235 | ("n", [45, 47, 52]), 236 | ("r", [55, 62]), 237 | ] 238 | assert res == expected 239 | 240 | 241 | def test_getAttacksByPiecetype(): 242 | fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR" # normal starting position 243 | res = mtf.getAttacksByPiecetype(fen) 244 | print(res) 245 | # numbers 0-63 describe board, starting from 0=A1 246 | expected = { 247 | "P": {16, 17, 18, 19, 20, 21, 22, 23}, 248 | "R": {8, 1, 6, 15}, 249 | "N": {11, 12, 16, 18, 21, 23}, 250 | "B": {9, 11, 12, 14}, 251 | "Q": {2, 4, 10, 11, 12}, 252 | "K": {3, 5, 11, 12, 13}, 253 | "p": {40, 41, 42, 43, 44, 45, 46, 47}, 254 | "r": {48, 57, 62, 55}, 255 | "n": {40, 42, 45, 47, 51, 52}, 256 | "b": {49, 51, 52, 54}, 257 | "q": {50, 51, 52, 58, 60}, 258 | "k": {51, 52, 53, 59, 61}, 259 | } 260 | assert res == expected 261 | 262 | 263 | def test_getAttacksTensorOverTime(): 264 | # channel order: ("P", "R", "N", "B", "Q", "K", "p", "r", "n", "b", "q", "k") 265 | 266 | fenList = [ 267 | "rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR", # e4 268 | ] 269 | res = mtf.getAttacksTensorOverTime(fenList) 270 | expected = np.array( 271 | [ 272 | [ 273 | [ 274 | [0, 0, 0, 0, 0, 0, 0, 0], # white pawns 275 | [0, 0, 0, 0, 0, 0, 0, 0], 276 | [0, 0, 0, 0, 0, 0, 0, 0], 277 | [0, 0, 0, 1, 0, 1, 0, 0], 278 | [0, 0, 0, 0, 0, 0, 0, 0], 279 | [1, 1, 1, 1, 1, 1, 1, 1], 280 | [0, 0, 0, 0, 0, 0, 0, 0], 281 | [0, 0, 0, 0, 0, 0, 0, 0], 282 | ], 283
| [ 284 | [0, 0, 0, 0, 0, 0, 0, 0], # white rooks 285 | [0, 0, 0, 0, 0, 0, 0, 0], 286 | [0, 0, 0, 0, 0, 0, 0, 0], 287 | [0, 0, 0, 0, 0, 0, 0, 0], 288 | [0, 0, 0, 0, 0, 0, 0, 0], 289 | [0, 0, 0, 0, 0, 0, 0, 0], 290 | [1, 0, 0, 0, 0, 0, 0, 1], 291 | [0, 1, 0, 0, 0, 0, 1, 0], 292 | ], 293 | [ 294 | [0, 0, 0, 0, 0, 0, 0, 0], # white knights 295 | [0, 0, 0, 0, 0, 0, 0, 0], 296 | [0, 0, 0, 0, 0, 0, 0, 0], 297 | [0, 0, 0, 0, 0, 0, 0, 0], 298 | [0, 0, 0, 0, 0, 0, 0, 0], 299 | [1, 0, 1, 0, 0, 1, 0, 1], 300 | [0, 0, 0, 1, 1, 0, 0, 0], 301 | [0, 0, 0, 0, 0, 0, 0, 0], 302 | ], 303 | [ 304 | [0, 0, 0, 0, 0, 0, 0, 0], # white bishops 305 | [0, 0, 0, 0, 0, 0, 0, 0], 306 | [1, 0, 0, 0, 0, 0, 0, 0], 307 | [0, 1, 0, 0, 0, 0, 0, 0], 308 | [0, 0, 1, 0, 0, 0, 0, 0], 309 | [0, 0, 0, 1, 0, 0, 0, 0], 310 | [0, 1, 0, 1, 1, 0, 1, 0], 311 | [0, 0, 0, 0, 0, 0, 0, 0], 312 | ], 313 | [ 314 | [0, 0, 0, 0, 0, 0, 0, 0], # white queen 315 | [0, 0, 0, 0, 0, 0, 0, 0], 316 | [0, 0, 0, 0, 0, 0, 0, 0], 317 | [0, 0, 0, 0, 0, 0, 0, 1], 318 | [0, 0, 0, 0, 0, 0, 1, 0], 319 | [0, 0, 0, 0, 0, 1, 0, 0], 320 | [0, 0, 1, 1, 1, 0, 0, 0], 321 | [0, 0, 1, 0, 1, 0, 0, 0], 322 | ], 323 | [ 324 | [0, 0, 0, 0, 0, 0, 0, 0], # white king 325 | [0, 0, 0, 0, 0, 0, 0, 0], 326 | [0, 0, 0, 0, 0, 0, 0, 0], 327 | [0, 0, 0, 0, 0, 0, 0, 0], 328 | [0, 0, 0, 0, 0, 0, 0, 0], 329 | [0, 0, 0, 0, 0, 0, 0, 0], 330 | [0, 0, 0, 1, 1, 1, 0, 0], 331 | [0, 0, 0, 1, 0, 1, 0, 0], 332 | ], 333 | [ 334 | [0, 0, 0, 0, 0, 0, 0, 0], # black pawns 335 | [0, 0, 0, 0, 0, 0, 0, 0], 336 | [1, 1, 1, 1, 1, 1, 1, 1], 337 | [0, 0, 0, 0, 0, 0, 0, 0], 338 | [0, 0, 0, 0, 0, 0, 0, 0], 339 | [0, 0, 0, 0, 0, 0, 0, 0], 340 | [0, 0, 0, 0, 0, 0, 0, 0], 341 | [0, 0, 0, 0, 0, 0, 0, 0], 342 | ], 343 | [ 344 | [0, 1, 0, 0, 0, 0, 1, 0], # rooks 345 | [1, 0, 0, 0, 0, 0, 0, 1], 346 | [0, 0, 0, 0, 0, 0, 0, 0], 347 | [0, 0, 0, 0, 0, 0, 0, 0], 348 | [0, 0, 0, 0, 0, 0, 0, 0], 349 | [0, 0, 0, 0, 0, 0, 0, 0], 350 | [0, 0, 0, 0, 0, 0, 0, 0], 351 | [0, 0, 0, 0, 0, 0, 0, 0], 352 | 
], 353 | [ 354 | [0, 0, 0, 0, 0, 0, 0, 0], # knights 355 | [0, 0, 0, 1, 1, 0, 0, 0], 356 | [1, 0, 1, 0, 0, 1, 0, 1], 357 | [0, 0, 0, 0, 0, 0, 0, 0], 358 | [0, 0, 0, 0, 0, 0, 0, 0], 359 | [0, 0, 0, 0, 0, 0, 0, 0], 360 | [0, 0, 0, 0, 0, 0, 0, 0], 361 | [0, 0, 0, 0, 0, 0, 0, 0], 362 | ], 363 | [ 364 | [0, 0, 0, 0, 0, 0, 0, 0], # bishops 365 | [0, 1, 0, 1, 1, 0, 1, 0], 366 | [0, 0, 0, 0, 0, 0, 0, 0], 367 | [0, 0, 0, 0, 0, 0, 0, 0], 368 | [0, 0, 0, 0, 0, 0, 0, 0], 369 | [0, 0, 0, 0, 0, 0, 0, 0], 370 | [0, 0, 0, 0, 0, 0, 0, 0], 371 | [0, 0, 0, 0, 0, 0, 0, 0], 372 | ], 373 | [ 374 | [0, 0, 1, 0, 1, 0, 0, 0], # queen 375 | [0, 0, 1, 1, 1, 0, 0, 0], 376 | [0, 0, 0, 0, 0, 0, 0, 0], 377 | [0, 0, 0, 0, 0, 0, 0, 0], 378 | [0, 0, 0, 0, 0, 0, 0, 0], 379 | [0, 0, 0, 0, 0, 0, 0, 0], 380 | [0, 0, 0, 0, 0, 0, 0, 0], 381 | [0, 0, 0, 0, 0, 0, 0, 0], 382 | ], 383 | [ 384 | [0, 0, 0, 1, 0, 1, 0, 0], # king 385 | [0, 0, 0, 1, 1, 1, 0, 0], 386 | [0, 0, 0, 0, 0, 0, 0, 0], 387 | [0, 0, 0, 0, 0, 0, 0, 0], 388 | [0, 0, 0, 0, 0, 0, 0, 0], 389 | [0, 0, 0, 0, 0, 0, 0, 0], 390 | [0, 0, 0, 0, 0, 0, 0, 0], 391 | [0, 0, 0, 0, 0, 0, 0, 0], 392 | ], 393 | ] 394 | ] 395 | ) 396 | print(res.shape, expected.shape) 397 | print(res) 398 | 399 | assert np.array_equal(res, expected) 400 | 401 | -------------------------------------------------------------------------------- /python_code/train_CNN_LSTM.py: -------------------------------------------------------------------------------- 1 | import click 2 | import logging 3 | import numpy as np 4 | import tensorflow as tf 5 | from tensorflow import keras as k 6 | from tensorflow.keras.layers import ( 7 | Dense, 8 | Flatten, 9 | Conv2D, 10 | Dropout, 11 | TimeDistributed, 12 | LSTM, 13 | MaxPooling2D, 14 | BatchNormalization, 15 | Input, 16 | Permute, 17 | ) 18 | from tensorflow.keras.optimizers import SGD, Adam 19 | from tensorflow.keras.regularizers import l2 20 | from tensorflow.keras import Model 21 | import h5py 22 | from tensorflow.keras.models import 
Sequential 23 | 24 | # from keras.models import load_model 25 | # from keras.preprocessing.image import ImageDataGenerator 26 | # from keras.models import load_model 27 | from tensorflow.keras.callbacks import EarlyStopping 28 | 29 | 30 | from sklearn.model_selection import train_test_split 31 | 32 | 33 | logging.basicConfig(level=logging.INFO,) 34 | 35 | LOGGER = logging.getLogger() 36 | 37 | 38 | def scaleAndSplit(data, labels): 39 | # data = data - np.mean(data) 40 | X_train, X_test, y_train, y_test = train_test_split( 41 | data, labels, test_size=0.2, random_state=42 42 | ) 43 | num_classes = 2 44 | y_train = k.utils.to_categorical(y_train, num_classes) 45 | y_test = k.utils.to_categorical(y_test, num_classes) 46 | return X_train, X_test, y_train, y_test 47 | 48 | 49 | def buildModel( 50 | num_input_channels=24, 51 | num_timesteps=20, 52 | num_filters=20, 53 | num_LSTM=20, 54 | num_Dense=200, 55 | drop_rate=0.2, 56 | reg=None, 57 | ): 58 | kernel_size = (3, 3) 59 | num_classes = 2 60 | ac = "relu" 61 | opt = Adam(lr=0.001, decay=0, beta_1=0.9, beta_2=0.999, epsilon=1e-08) 62 | 63 | inp = Input( 64 | (num_timesteps, num_input_channels, 8, 8) 65 | ) # timesteps, channels, rows, columns 66 | permuted = Permute((1, 3, 4, 2))(inp) # expects channels_last 67 | 68 | x = TimeDistributed( 69 | Conv2D( 70 | num_filters, 71 | kernel_size, 72 | activation=ac, 73 | kernel_regularizer=reg, 74 | # data_format="channels_first", 75 | ) 76 | )(permuted) 77 | 78 | x = BatchNormalization(axis=-1)(x) 79 | x = TimeDistributed( 80 | Conv2D( 81 | 2 * num_filters, 82 | kernel_size, 83 | activation=ac, 84 | kernel_regularizer=reg, 85 | # data_format="channels_first", 86 | ) 87 | )(x) 88 | x = BatchNormalization(axis=-1)(x) 89 | x = TimeDistributed( 90 | Conv2D( 91 | 4 * num_filters, 92 | kernel_size, 93 | activation=ac, 94 | kernel_regularizer=reg, 95 | # data_format="channels_first", 96 | ) 97 | )(x) 98 | x = BatchNormalization(axis=-1)(x) 99 | x = TimeDistributed( 100 | Conv2D( 101 | 
8 * num_filters, 102 | kernel_size, 103 | activation=ac, 104 | kernel_regularizer=reg, 105 | padding="same", 106 | # data_format="channels_first", 107 | ) 108 | )(x) 109 | x = BatchNormalization(axis=-1)(x) 110 | x = TimeDistributed(Flatten())(x) 111 | x = LSTM(num_LSTM)(x) 112 | x = Dropout(drop_rate)(x) 113 | x = Dense(num_Dense)(x) 114 | x = Dropout(drop_rate)(x) 115 | out = Dense(num_classes, activation="softmax")(x) 116 | myModel = Model(inputs=[inp], outputs=[out]) 117 | myModel.compile(loss="binary_crossentropy", metrics=["accuracy"], optimizer=opt) 118 | LOGGER.info("built model!") 119 | LOGGER.info(myModel.summary()) 120 | return myModel 121 | 122 | 123 | def trainModel(model, X_train, y_train, X_test, y_test): 124 | # early_stopping = EarlyStopping(monitor="val_loss", patience=2) 125 | hist = model.fit( 126 | X_train, 127 | y_train, 128 | batch_size=128, 129 | epochs=20, 130 | validation_data=(X_test, y_test), 131 | # callbacks=[early_stopping,], 132 | ) 133 | 134 | 135 | @click.command() 136 | @click.option( 137 | "--input-path-attacks", 138 | help="input array of training data (attacked squares)", 139 | required=True, 140 | ) 141 | @click.option("--input-path", help="input array of training data", required=True) 142 | @click.option("--input-path-labels", help="input array of labels", required=True) 143 | @click.option("--output-path", help="where to save the model", required=True) 144 | def main(input_path, input_path_labels, input_path_attacks, output_path): 145 | data_positions = np.load(input_path)["arr_0"] 146 | data_attacks = np.load(input_path_attacks)["arr_0"] 147 | data = np.concatenate((data_positions, data_attacks), axis=2) 148 | labels = np.load(input_path_labels)["arr_0"] 149 | LOGGER.info(f"Training CNN-LSTM on data of shape {data.shape}") 150 | 151 | model = buildModel(num_input_channels=data.shape[2], num_timesteps=data.shape[1]) 152 | 153 | X_train, X_test, y_train, y_test = scaleAndSplit(data, labels) 154 | trainModel(model, X_train, 
y_train, X_test, y_test) 155 | model.save(output_path) 156 | 157 | 158 | if __name__ == "__main__": 159 | main() 160 | -------------------------------------------------------------------------------- /python_code/train_Conv3D.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import click 4 | import h5py 5 | import numpy as np 6 | import tensorflow as tf 7 | from sklearn.model_selection import train_test_split 8 | from tensorflow import keras as k 9 | from tensorflow.keras import Model 10 | from tensorflow.keras.callbacks import EarlyStopping 11 | from tensorflow.keras.layers import ( 12 | BatchNormalization, 13 | Conv3D, 14 | Dense, 15 | Dropout, 16 | Flatten, 17 | Input, 18 | Permute, 19 | ) 20 | from tensorflow.keras.optimizers import SGD, Adam 21 | from tensorflow.keras.regularizers import l2 22 | 23 | logging.basicConfig(level=logging.INFO,) 24 | 25 | LOGGER = logging.getLogger() 26 | 27 | 28 | def scaleAndSplit(data, labels): 29 | # data = data - np.mean(data) # not needed? 
30 | X_train, X_test, y_train, y_test = train_test_split( 31 | data, labels, test_size=0.2, random_state=42 32 | ) 33 | num_classes = 2 34 | y_train = k.utils.to_categorical(y_train, num_classes) 35 | y_test = k.utils.to_categorical(y_test, num_classes) 36 | return X_train, X_test, y_train, y_test 37 | 38 | 39 | def buildModel( 40 | num_input_channels=24, 41 | num_timesteps=20, 42 | num_filters=20, 43 | num_Dense=200, 44 | drop_rate=0.0, 45 | reg=None, 46 | ): 47 | kernel_size = (3, 3, 3) 48 | num_classes = 2 49 | ac = "relu" 50 | opt = Adam(lr=0.001, decay=0, beta_1=0.9, beta_2=0.999, epsilon=1e-08) 51 | 52 | inp = Input( 53 | (num_timesteps, num_input_channels, 8, 8) 54 | ) # timesteps, channels, rows, columns 55 | permuted = Permute((1, 3, 4, 2))(inp) # expects channels_last 56 | conv1 = Conv3D( 57 | num_filters, kernel_size, activation=ac, kernel_regularizer=reg, padding="valid" 58 | )(permuted) 59 | conv2 = Conv3D( 60 | num_filters, kernel_size, activation=ac, kernel_regularizer=reg, padding="valid" 61 | )(conv1) 62 | conv3 = Conv3D( 63 | num_filters, kernel_size, activation=ac, kernel_regularizer=reg, padding="valid" 64 | )(conv2) 65 | conv4 = Conv3D( 66 | num_filters, kernel_size, activation=ac, kernel_regularizer=reg, padding="valid" 67 | )(conv3) 68 | flat = Flatten()(conv4) 69 | dense = Dense(num_Dense)(flat) 70 | dropout = Dropout(drop_rate)(dense) 71 | out = Dense(num_classes, activation="softmax")(dropout) 72 | 73 | model = Model(inputs=[inp], outputs=[out]) 74 | 75 | model.compile(loss="binary_crossentropy", metrics=["accuracy"], optimizer=opt) 76 | LOGGER.info("build model!") 77 | LOGGER.info(model.summary()) 78 | return model 79 | 80 | 81 | def trainModel(model, X_train, y_train, X_test, y_test): 82 | # early_stopping = EarlyStopping(monitor="val_loss", patience=3) 83 | # early_stopping = EarlyStopping(monitor="loss", patience=3) 84 | 85 | hist = model.fit( 86 | X_train, 87 | y_train, 88 | batch_size=128, 89 | epochs=20, 90 | 
validation_data=(X_test, y_test), 91 | # callbacks=[early_stopping,], 92 | ) 93 | 94 | 95 | @click.command() 96 | @click.option( 97 | "--input-path", help="input array of training data (piece positions)", required=True 98 | ) 99 | @click.option( 100 | "--input-path-attacks", 101 | help="input array of training data (attacked squares)", 102 | required=True, 103 | ) 104 | @click.option("--input-path-labels", help="input array of labels", required=True) 105 | @click.option("--output-path", help="where to save the model", required=True) 106 | def main(input_path, input_path_labels, input_path_attacks, output_path): 107 | # data dims: (samples, time, channel, row, col) 108 | data_positions = np.load(input_path)["arr_0"] 109 | data_attacks = np.load(input_path_attacks)["arr_0"] 110 | data = np.concatenate((data_positions, data_attacks), axis=2) 111 | labels = np.load(input_path_labels)["arr_0"] 112 | LOGGER.info(f"Training Conv3D model on data of shape {data.shape}") 113 | 114 | model = buildModel(num_input_channels=data.shape[2], num_timesteps=data.shape[1]) 115 | 116 | X_train, X_test, y_train, y_test = scaleAndSplit(data, labels) 117 | trainModel(model, X_train, y_train, X_test, y_test) 118 | model.save(output_path) 119 | 120 | 121 | if __name__ == "__main__": 122 | main() 123 | -------------------------------------------------------------------------------- /python_code/train_Fully_Connected_LSTM.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import click 4 | import h5py 5 | import numpy as np 6 | import tensorflow as tf 7 | from sklearn.model_selection import train_test_split 8 | from tensorflow import keras as k 9 | from tensorflow.keras import Model 10 | from tensorflow.keras.callbacks import EarlyStopping 11 | from tensorflow.keras.layers import ( 12 | BatchNormalization, 13 | Conv3D, 14 | Dense, 15 | TimeDistributed, 16 | Dropout, 17 | Flatten, 18 | Input, 19 | Permute, 20 | LSTM, 21 | ) 22 | from 
tensorflow.keras.optimizers import SGD, Adam 23 | from tensorflow.keras.regularizers import l2 24 | 25 | logging.basicConfig(level=logging.INFO,) 26 | 27 | LOGGER = logging.getLogger() 28 | 29 | 30 | def scaleAndSplit(data, labels): 31 | # data = data - np.mean(data) # not needed? 32 | X_train, X_test, y_train, y_test = train_test_split( 33 | data, labels, test_size=0.2, random_state=42 34 | ) 35 | num_classes = 2 36 | y_train = k.utils.to_categorical(y_train, num_classes) 37 | y_test = k.utils.to_categorical(y_test, num_classes) 38 | return X_train, X_test, y_train, y_test 39 | 40 | 41 | def buildModel( 42 | num_timesteps=20, 43 | num_input_channels=24, 44 | num_Dense=400, 45 | num_LSTM=20, 46 | drop_rate=0.3, 47 | reg=None, 48 | ): 49 | num_classes = 2 50 | ac = "relu" 51 | opt = Adam(lr=0.001, decay=0, beta_1=0.9, beta_2=0.999, epsilon=1e-08) 52 | 53 | inp = Input((num_timesteps, num_input_channels * 8 * 8)) 54 | 55 | x = TimeDistributed(Dense(num_Dense, activation=ac, kernel_regularizer=reg,))(inp) 56 | x = BatchNormalization(axis=-1)(x) 57 | x = Dropout(drop_rate)(x) 58 | x = LSTM(num_LSTM)(x) 59 | out = Dense(num_classes, activation="softmax")(x) 60 | myModel = Model(inputs=[inp], outputs=[out]) 61 | myModel.compile(loss="binary_crossentropy", metrics=["accuracy"], optimizer=opt) 62 | LOGGER.info("built model!") 63 | LOGGER.info(myModel.summary()) 64 | return myModel 65 | 66 | 67 | def trainModel(model, X_train, y_train, X_test, y_test): 68 | # early_stopping = EarlyStopping(monitor="val_loss", patience=3) 69 | # early_stopping = EarlyStopping(monitor="loss", patience=3) 70 | 71 | hist = model.fit( 72 | X_train, 73 | y_train, 74 | batch_size=128, 75 | epochs=20, 76 | validation_data=(X_test, y_test), 77 | # callbacks=[early_stopping,], 78 | ) 79 | 80 | 81 | @click.command() 82 | @click.option( 83 | "--input-path", help="input array of training data (piece positions)", required=True 84 | ) 85 | @click.option( 86 | "--input-path-attacks", 87 | help="input array 
of training data (attacked squares)", 88 | required=True, 89 | ) 90 | @click.option("--input-path-labels", help="input array of labels", required=True) 91 | @click.option("--output-path", help="where to save the model", required=True) 92 | def main(input_path, input_path_labels, input_path_attacks, output_path): 93 | # data dims: (samples, time, channel, row, col) 94 | data_positions = np.load(input_path)["arr_0"] 95 | data_attacks = np.load(input_path_attacks)["arr_0"] 96 | data = np.concatenate((data_positions, data_attacks), axis=2) 97 | num_channels = data.shape[2] 98 | data = data.reshape(data.shape[0], data.shape[1], -1) 99 | labels = np.load(input_path_labels)["arr_0"] 100 | LOGGER.info(f"Training fully connected LSTM model on data of shape {data.shape}") 101 | 102 | model = buildModel(num_input_channels=num_channels, num_timesteps=data.shape[1]) 103 | 104 | X_train, X_test, y_train, y_test = scaleAndSplit(data, labels) 105 | trainModel(model, X_train, y_train, X_test, y_test) 106 | model.save(output_path) 107 | 108 | 109 | if __name__ == "__main__": 110 | main() 111 | --------------------------------------------------------------------------------
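The three `buildModel` functions above differ mainly in how the 8x8 board shrinks through the convolution stack, so a quick shape check may help. The sketch below only does the arithmetic for stride-1 convolutions with a 3-wide kernel; `conv_out` and `trace` are hypothetical helpers rather than part of the repo, and the default sizes (20 timesteps, `num_filters=20`, 24 channels) are taken from the function signatures.

```python
def conv_out(size, kernel=3, padding="valid"):
    """Output length of one stride-1 convolution along a single axis."""
    return size if padding == "same" else size - kernel + 1


def trace(size, paddings, kernel=3):
    """Sizes along one axis before and after each layer in `paddings`."""
    sizes = [size]
    for padding in paddings:
        sizes.append(conv_out(sizes[-1], kernel, padding))
    return sizes


num_filters = 20

# CNN-LSTM: three "valid" Conv2D layers, then one with padding="same".
board_cnn_lstm = trace(8, ["valid", "valid", "valid", "same"])
# The last Conv2D has 8 * num_filters output channels; Flatten runs per timestep.
features_per_timestep = board_cnn_lstm[-1] ** 2 * 8 * num_filters

# Conv3D: four "valid" layers shrink time, rows and columns alike.
board_conv3d = trace(8, ["valid"] * 4)
time_conv3d = trace(20, ["valid"] * 4)

# Fully connected LSTM: each timestep is flattened to channels * 8 * 8 inputs.
dense_inputs = 24 * 8 * 8

print(board_cnn_lstm, features_per_timestep)  # [8, 6, 4, 2, 2] 640
print(board_conv3d, time_conv3d)              # [8, 6, 4, 2, 0] [20, 18, 16, 14, 12]
print(dense_inputs)                           # 1536
```

By this arithmetic the fourth `"valid"` Conv3D layer would leave zero rows and columns on an 8x8 board, which Keras normally rejects when the model is built. It may be worth checking whether `train_Conv3D.py` runs as written, or whether one layer needs `padding="same"` as in the CNN-LSTM variant.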