├── .bumpversion.cfg
├── .github
│   └── ISSUE_TEMPLATE
│       ├── bug_report.md
│       └── feature_request.md
├── .gitignore
├── .travis.yml
├── LICENSE
├── Makefile
├── README.md
├── data
│   └── .gitkeep
├── images
│   └── hero.png
├── notebooks
│   ├── 1-load-and-convert-statsbomb-data.ipynb
│   ├── 1-load-and-convert-wyscout-data.ipynb
│   ├── 2-basic-usage.ipynb
│   ├── 3-computing-and-storing-features.ipynb
│   └── 4-creating-custom-xg-pipelines.ipynb
├── poetry.lock
├── pyproject.toml
├── setup.cfg
├── soccer_xg
│   ├── __init__.py
│   ├── api.py
│   ├── calibration.py
│   ├── features.py
│   ├── metrics.py
│   ├── ml
│   │   ├── logreg.py
│   │   ├── mlp.py
│   │   ├── pipeline.py
│   │   ├── preprocessing.py
│   │   ├── tree_based_LR.py
│   │   └── xgboost.py
│   ├── models
│   │   ├── openplay_logreg_advanced
│   │   ├── openplay_logreg_basic
│   │   ├── openplay_xgboost_advanced
│   │   └── openplay_xgboost_basic
│   ├── utils.py
│   ├── visualisation.py
│   └── xg.py
└── tests
    ├── conftest.py
    ├── data
    │   └── download.py
    ├── test_api.py
    ├── test_metrics.py
    └── test_xg.py
/.bumpversion.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 0.0.1 3 | commit = True 4 | tag = True 5 | 6 | [bumpversion:file:pyproject.toml] 7 | search = version = "{current_version}" 8 | replace = version = "{new_version}" 9 | 10 | [bumpversion:file:soccer_xg/__init__.py] 11 | search = __version__ = '{current_version}' 12 | replace = __version__ = '{new_version}' 13 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Run code 16 | ```python 17 | print("Hello world") 18 | ``` 19 | 2. On [minimal data example](/link/to/data.ext) 20 | 3. See error 21 | 22 | **Expected behavior** 23 | A clear and concise description of what you expected to happen. 24 | 25 | **Additional context** 26 | Add any other context about the problem here. 27 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .envrc 2 | analysis/ 3 | # exclude all data, but keep the folder 4 | data/* 5 | !data/.gitkeep 6 | tests/data/* 7 | !tests/data/download.py 8 | # exclude all models, unless explicitly allowed 9 | soccer_xg/models/* 10 | !soccer_xg/models/openplay_logreg_basic 11 | !soccer_xg/models/openplay_xgboost_basic 12 | !soccer_xg/models/openplay_logreg_advanced 13 | !soccer_xg/models/openplay_xgboost_advanced 14 | 15 | 16 | # Byte-compiled / optimized / DLL files 17 | __pycache__/ 18 | *.py[cod] 19 | *$py.class 20 | 21 | # C extensions 22 | *.so 23 | 24 | # Distribution / packaging 25 | .Python 26 | build/ 27 | develop-eggs/ 28 | dist/ 29 | downloads/ 30 | eggs/ 31 | .eggs/ 32 | lib/ 33 | lib64/ 34 | parts/ 35 | sdist/ 36 | var/ 37 | wheels/ 38 | share/python-wheels/ 39 | *.egg-info/ 40 | .installed.cfg 41 | *.egg 42 | MANIFEST 43 | 44 | # PyInstaller 45 | # Usually these files are written by a python script from a template 46 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 47 | *.manifest 48 | *.spec 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | 54 | # Unit test / coverage reports 55 | htmlcov/ 56 | .tox/ 57 | .nox/ 58 | .coverage 59 | .coverage.* 60 | .cache 61 | nosetests.xml 62 | coverage.xml 63 | *.cover 64 | *.py,cover 65 | .hypothesis/ 66 | .pytest_cache/ 67 | cover/ 68 | 69 | # Translations 70 | *.mo 71 | *.pot 72 | 73 | # Django stuff: 74 | *.log 75 | local_settings.py 76 | db.sqlite3 77 | db.sqlite3-journal 78 | 79 | # Flask stuff: 80 | instance/ 81 | .webassets-cache 82 | 83 | # Scrapy stuff: 84 | .scrapy 85 | 86 | # Sphinx documentation 87 | docs/_build/ 88 | 89 | # PyBuilder 90 | .pybuilder/ 91 | target/ 92 | 93 | # Jupyter Notebook 94 | .ipynb_checkpoints 95 | 96 | # IPython 97 | profile_default/ 98 | ipython_config.py 99 | 100 | # pyenv 101 | # For a library or package, you might want to ignore these files since the code is 102 | # intended to run in multiple environments; otherwise, check them in: 103 | # .python-version 104 | 105 | # pipenv 106 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 107 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 108 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 109 | # install all needed dependencies. 110 | #Pipfile.lock 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | 3 | cache: pip 4 | 5 | python: 6 | - "3.6" 7 | 8 | install: 9 | - pip install poetry codecov 10 | - poetry install 11 | 12 | script: 13 | - make test BIN="" 14 | 15 | after_success: 16 | - codecov 17 | 18 | before_deploy: 19 | - poetry build 20 | 21 | deploy: 22 | provider: script 23 | script: poetry publish --no-interaction --username=$PYPI_USER --password=$PYPI_PASS 24 | skip_cleanup: true 25 | on: 26 | branch: master 27 | tags: true 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | https://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 
40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | Copyright 2013-2018 Docker, Inc. 180 | 181 | Licensed under the Apache License, Version 2.0 (the "License"); 182 | you may not use this file except in compliance with the License. 183 | You may obtain a copy of the License at 184 | 185 | https://www.apache.org/licenses/LICENSE-2.0 186 | 187 | Unless required by applicable law or agreed to in writing, software 188 | distributed under the License is distributed on an "AS IS" BASIS, 189 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 190 | See the License for the specific language governing permissions and 191 | limitations under the License. 
192 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: init test lint pretty precommit_install bump_major bump_minor bump_patch clean 2 | 3 | BIN = .venv/bin/ 4 | CODE = soccer_xg 5 | 6 | init: 7 | python3 -m venv .venv 8 | poetry install 9 | 10 | tests/data/spadl-statsbomb-WC-2018.h5: 11 | $(BIN)python tests/data/download.py 12 | 13 | test: tests/data/spadl-statsbomb-WC-2018.h5 14 | $(BIN)pytest --verbosity=2 --showlocals --strict --log-level=DEBUG --cov=$(CODE) $(args) 15 | 16 | lint: 17 | $(BIN)flake8 --jobs 4 --statistics --show-source $(CODE) tests 18 | $(BIN)pylint --jobs 4 --rcfile=setup.cfg $(CODE) 19 | $(BIN)mypy $(CODE) tests 20 | $(BIN)black --py36 --skip-string-normalization --line-length=79 --check $(CODE) tests 21 | $(BIN)pytest --dead-fixtures --dup-fixtures 22 | 23 | pretty: 24 | $(BIN)isort --apply --recursive $(CODE) tests 25 | $(BIN)black --target-version py36 --skip-string-normalization --line-length=79 $(CODE) tests 26 | $(BIN)unify --in-place --recursive $(CODE) tests 27 | 28 | precommit_install: 29 | echo '#!/bin/sh\nmake test\n' > .git/hooks/pre-commit 30 | chmod +x .git/hooks/pre-commit 31 | 32 | bump_major: 33 | $(BIN)bumpversion major 34 | 35 | bump_minor: 36 | $(BIN)bumpversion minor 37 | 38 | bump_patch: 39 | $(BIN)bumpversion patch 40 | 41 | clean: 42 | find . -type f -name "*.py[co]" -delete 43 | find . -type d -name "__pycache__" -delete 44 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
# Soccer xG

A Python package for training and analyzing expected goals (xG) models in soccer.
9 | 10 | ## About 11 | 12 | This repository contains the code and models for our series on the analysis of xG models: 13 | 14 | - [How data availability affects the ability to learn good xG models](https://dtai.cs.kuleuven.be/sports/blog/how-data-availability-affects-the-ability-to-learn-good-xg-models) 15 | - [Illustrating the interplay between features and models in xG](https://dtai.cs.kuleuven.be/sports/blog/illustrating-the-interplay-between-features-and-models-in-xg) 16 | - [How data quality affects xG](https://dtai.cs.kuleuven.be/sports/blog/how-data-quality-affects-xg) 17 | 18 | In particular, it contains code for experimenting with an exhaustive set of features and machine learning pipelines for predicting xG values from soccer event stream data. Since we rely on the [SPADL](https://github.com/ML-KULeuven/socceraction) language as input format, `soccer_xg` currently supports event streams provided by Opta, Wyscout, and StatsBomb. 19 | 20 | ## Getting started 21 | 22 | The recommended way to install `soccer_xg` is to simply use pip: 23 | 24 | ```sh 25 | $ pip install soccer-xg 26 | ``` 27 | 28 | Subsequently, a basic xG model can be trained and applied with the code below: 29 | 30 | ```python 31 | from itertools import product 32 | from soccer_xg import XGModel, DataApi 33 | 34 | # load the data 35 | provider = 'wyscout_opensource' 36 | leagues = ['ENG', 'ESP', 'ITA', 'GER', 'FRA'] 37 | seasons = ['1718'] 38 | api = DataApi([f"data/{provider}/spadl-{provider}-{l}-{s}.h5" 39 | for (l,s) in product(leagues, seasons)]) 40 | # load the default pipeline 41 | model = XGModel() 42 | # train the model 43 | model.train(api, training_seasons=[('ESP', '1718'), ('ITA', '1718'), ('GER', '1718')]) 44 | # validate the model 45 | model.validate(api, validation_seasons=[('ENG', '1718')]) 46 | # predict xG values 47 | model.estimate(api, game_ids=[2500098]) 48 | ``` 49 | 50 | Although this default pipeline is suitable for computing xG, it is by no means the best possible model. 51 | The notebook [`4-creating-custom-xg-pipelines`](./notebooks/4-creating-custom-xg-pipelines.ipynb) illustrates how you can train your own xG models, or you can use one of the four pipelines from our blog post series. These can be loaded with: 52 | 53 | ```python 54 | XGModel.load_model('openplay_logreg_basic') 55 | XGModel.load_model('openplay_xgboost_basic') 56 | XGModel.load_model('openplay_logreg_advanced') 57 | XGModel.load_model('openplay_xgboost_advanced') 58 | ``` 59 | 60 | Note that these models are meant to predict shots from open play. To be able to compute xG values from all shot types, you will have to combine them with a pipeline for penalties and free kicks. 61 | 62 | ```python 63 | from soccer_xg import xg 64 | 65 | openplay_model = xg.XGModel.load_model('openplay_xgboost_advanced') # custom pipeline for open play shots 66 | penalty_model = xg.PenaltyXGModel() # default pipeline for penalties 67 | freekick_model = xg.FreekickXGModel() # default pipeline for free kicks 68 | 69 | model = xg.XGModel() 70 | model.model = [openplay_model, penalty_model, freekick_model] 71 | model.train(api, training_seasons=...)
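# Once trained, the combined model can be applied just like the default
# pipeline above (a sketch; it reuses the `api` object and the game ID
# from the "Getting started" example):
model.estimate(api, game_ids=[2500098])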
72 | ``` 73 | 74 | ## For developers 75 | 76 | **Create venv and install deps** 77 | 78 | make init 79 | 80 | **Install git precommit hook** 81 | 82 | make precommit_install 83 | 84 | **Run linters, autoformat, tests etc.** 85 | 86 | make pretty lint test 87 | 88 | **Bump new version** 89 | 90 | make bump_major 91 | make bump_minor 92 | make bump_patch 93 | 94 | ## Research 95 | 96 | If you make use of this package in your research, please use the following citation: 97 | 98 | ``` 99 | @inproceedings{robberechts2020data, 100 | title={How data availability affects the ability to learn good xG models}, 101 | author={Robberechts, Pieter and Davis, Jesse}, 102 | booktitle={International Workshop on Machine Learning and Data Mining for Sports Analytics}, 103 | pages={17--27}, 104 | year={2020}, 105 | organization={Springer} 106 | } 107 | ``` 108 | 109 | ## License 110 | 111 | Copyright (c) DTAI - KU Leuven – All rights reserved. 112 | Licensed under the Apache License, Version 2.0 113 | Written by [Pieter Robberechts](https://people.cs.kuleuven.be/~pieter.robberechts/), 2020 114 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-KULeuven/soccer_xg/b9489d929e0fa34771267d429256366d0fda27ad/data/.gitkeep -------------------------------------------------------------------------------- /images/hero.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-KULeuven/soccer_xg/b9489d929e0fa34771267d429256366d0fda27ad/images/hero.png -------------------------------------------------------------------------------- /notebooks/1-load-and-convert-statsbomb-data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Data Preparation\n", 8 | "\n", 9 | "This notebook loads the 2018 World Cup dataset provided by StatsBomb and converts it to the [SPADL format](https://github.com/ML-KULeuven/socceraction)." 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "**Disclaimer**: this notebook is compatible with the following package versions:\n", 17 | "\n", 18 | "- tqdm 4.42.1\n", 19 | "- pandas 1.0\n", 20 | "- socceraction 0.1.1" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 1, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "import os; import sys\n", 30 | "from tqdm.notebook import tqdm\n", 31 | "\n", 32 | "import math\n", 33 | "import pandas as pd\n", 34 | "\n", 35 | "import socceraction.spadl as spadl\n", 36 | "import socceraction.spadl.statsbomb as statsbomb" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "## Configure leagues and seasons to download and convert\n", 44 | "The two dictionaries below map my internal season and league IDs to Statsbomb's IDs. Using an internal ID makes it easier to work with data from multiple providers." 
45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 2, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "seasons = {\n", 54 | " 3: '2018',\n", 55 | "}\n", 56 | "leagues = {\n", 57 | " 'FIFA World Cup': 'WC',\n", 58 | "}" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "## Configure folder names and download URLs\n", 66 | "\n", 67 | "The two cells below define the URLs from which the data are downloaded and the folders where the data are stored." 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 3, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "free_open_data_remote = \"https://raw.githubusercontent.com/statsbomb/open-data/master/data/\"" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 4, 82 | "metadata": {}, 83 | "outputs": [ 84 | { 85 | "name": "stdout", 86 | "output_type": "stream", 87 | "text": [ 88 | "Directory ../data/statsbomb_opensource/raw created \n" 89 | ] 90 | } 91 | ], 92 | "source": [ 93 | "spadl_datafolder = \"../data/statsbomb_opensource\"\n", 94 | "raw_datafolder = \"../data/statsbomb_opensource/raw\"\n", 95 | "\n", 96 | "# Create data folder if it doesn't exist\n", 97 | "for d in [raw_datafolder, spadl_datafolder]:\n", 98 | " if not os.path.exists(d):\n", 99 | " os.makedirs(d, exist_ok=True)\n", 100 | " print(f\"Directory {d} created \")" 101 | ] 102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "metadata": {}, 106 | "source": [ 107 | "## Set up the StatsBombLoader" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 5, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "SBL = statsbomb.StatsBombLoader(root=free_open_data_remote, getter=\"remote\")" 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": {}, 122 | "source": [ 123 | "## Select competitions to load and convert" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 6, 129 | "metadata": {}, 130 | "outputs": [ 131 | { 132 | "data": { 133 | "text/plain": [ 134 | "{'Champions League',\n", 135 | " \"FA Women's Super League\",\n", 136 | " 'FIFA World Cup',\n", 137 | " 'La Liga',\n", 138 | " 'NWSL',\n", 139 | " 'Premier League',\n", 140 | " \"Women's World Cup\"}" 141 | ] 142 | }, 143 | "execution_count": 6, 144 | "metadata": {}, 145 | "output_type": "execute_result" 146 | } 147 | ], 148 | "source": [ 149 | "# View all available competitions\n", 150 | "df_competitions = SBL.competitions()\n", 151 | "set(df_competitions.competition_name)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 7, 157 | "metadata": {}, 158 | "outputs": [ 159 | { 160 | "data": { 161 | "text/html": [ 162 | 
\n", 163 | "\n", 176 | "\n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | "
competition_idseason_idcountry_namecompetition_namecompetition_genderseason_namematch_updatedmatch_available
17433InternationalFIFA World Cupmale20182019-12-16T23:09:16.1687562019-12-16T23:09:16.168756
\n", 204 | "
" 205 | ], 206 | "text/plain": [ 207 | " competition_id season_id country_name competition_name \\\n", 208 | "17 43 3 International FIFA World Cup \n", 209 | "\n", 210 | " competition_gender season_name match_updated \\\n", 211 | "17 male 2018 2019-12-16T23:09:16.168756 \n", 212 | "\n", 213 | " match_available \n", 214 | "17 2019-12-16T23:09:16.168756 " 215 | ] 216 | }, 217 | "execution_count": 7, 218 | "metadata": {}, 219 | "output_type": "execute_result" 220 | } 221 | ], 222 | "source": [ 223 | "df_selected_competitions = df_competitions[df_competitions.competition_name.isin(\n", 224 | " leagues.keys()\n", 225 | ")]\n", 226 | "\n", 227 | "df_selected_competitions" 228 | ] 229 | }, 230 | { 231 | "cell_type": "markdown", 232 | "metadata": {}, 233 | "source": [ 234 | "## Convert to the SPADL format" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": 8, 240 | "metadata": { 241 | "scrolled": true 242 | }, 243 | "outputs": [ 244 | { 245 | "data": { 246 | "application/vnd.jupyter.widget-view+json": { 247 | "model_id": "ea0aaba0b07640538fcea5bc70a41681", 248 | "version_major": 2, 249 | "version_minor": 0 250 | }, 251 | "text/plain": [ 252 | "HBox(children=(FloatProgress(value=0.0, description='Loading match data', max=64.0, style=ProgressStyle(descri…" 253 | ] 254 | }, 255 | "metadata": {}, 256 | "output_type": "display_data" 257 | }, 258 | { 259 | "name": "stdout", 260 | "output_type": "stream", 261 | "text": [ 262 | "\n" 263 | ] 264 | }, 265 | { 266 | "name": "stderr", 267 | "output_type": "stream", 268 | "text": [ 269 | "/home/pieterr/Jupiter/Projects/soccer_dataprovider_comparison/.venv/lib/python3.6/site-packages/IPython/core/interactiveshell.py:3331: PerformanceWarning: \n", 270 | "your performance may suffer as PyTables will pickle object types that it cannot\n", 271 | "map directly to c-types [inferred_type->mixed,key->block1_values] [items->Index(['game_date', 'kick_off', 'competition_id', 'country_name',\n", 272 | " 'competition_name', 'season_id', 'season_name', 'home_team_name',\n", 273 | " 'home_team_gender', 'home_team_group', 'name', 'managers',\n", 274 | " 'away_team_name', 'away_team_gender', 'away_team_group', 'match_status',\n", 275 | " 'last_updated', 'data_version'],\n", 276 | " dtype='object')]\n", 277 | "\n", 278 | " exec(code_obj, self.user_global_ns, self.user_ns)\n", 279 | "/home/pieterr/Jupiter/Projects/soccer_dataprovider_comparison/.venv/lib/python3.6/site-packages/IPython/core/interactiveshell.py:3331: PerformanceWarning: \n", 280 | "your performance may suffer as PyTables will pickle object types that it cannot\n", 281 | "map directly to c-types [inferred_type->mixed,key->block1_values] [items->Index(['player_name', 'player_nickname', 'country_name', 'extra'], dtype='object')]\n", 282 | "\n", 283 | " exec(code_obj, self.user_global_ns, self.user_ns)\n", 284 | "/home/pieterr/Jupiter/Projects/soccer_dataprovider_comparison/.venv/lib/python3.6/site-packages/IPython/core/interactiveshell.py:3331: PerformanceWarning: \n", 285 | "your performance may suffer as PyTables will pickle object types that it cannot\n", 286 | "map directly to c-types [inferred_type->mixed-integer,key->block2_values] [items->Index(['player_name', 'position_name', 'extra', 'team_name'], dtype='object')]\n", 287 | "\n", 288 | " exec(code_obj, self.user_global_ns, self.user_ns)\n" 289 | ] 290 | } 291 | ], 292 | "source": [ 293 | "for competition in df_selected_competitions.itertuples():\n", 294 | " # Get matches from all selected competition\n", 295 | " matches = 
SBL.matches(competition.competition_id, competition.season_id)\n", 296 | "\n", 297 | " matches_verbose = tqdm(list(matches.itertuples()), desc=\"Loading match data\")\n", 298 | " teams, players, player_games = [], [], []\n", 299 | " \n", 300 | " competition_id = leagues[competition.competition_name]\n", 301 | " season_id = seasons[competition.season_id]\n", 302 | " spadl_h5 = os.path.join(spadl_datafolder, f\"spadl-statsbomb_opensource-{competition_id}-{season_id}.h5\")\n", 303 | " with pd.HDFStore(spadl_h5) as spadlstore:\n", 304 | " \n", 305 | " spadlstore[\"actiontypes\"] = spadl.actiontypes_df()\n", 306 | " spadlstore[\"results\"] = spadl.results_df()\n", 307 | " spadlstore[\"bodyparts\"] = spadl.bodyparts_df()\n", 308 | " \n", 309 | " for match in matches_verbose:\n", 310 | " # load data\n", 311 | " teams.append(SBL.teams(match.match_id))\n", 312 | " players.append(SBL.players(match.match_id))\n", 313 | " events = SBL.events(match.match_id)\n", 314 | "\n", 315 | " # convert data\n", 316 | " player_games.append(statsbomb.extract_player_games(events))\n", 317 | " spadlstore[f\"actions/game_{match.match_id}\"] = statsbomb.convert_to_actions(events, match.home_team_id)\n", 318 | "\n", 319 | " games = matches.rename(columns={\"match_id\": \"game_id\", \"match_date\": \"game_date\"})\n", 320 | " games.season_id = season_id\n", 321 | " games.competition_id = competition_id\n", 322 | " spadlstore[\"games\"] = games\n", 323 | " spadlstore[\"teams\"] = pd.concat(teams).drop_duplicates(\"team_id\").reset_index(drop=True)\n", 324 | " spadlstore[\"players\"] = pd.concat(players).drop_duplicates(\"player_id\").reset_index(drop=True)\n", 325 | " spadlstore[\"player_games\"] = pd.concat(player_games).reset_index(drop=True)" 326 | ] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "execution_count": null, 331 | "metadata": {}, 332 | "outputs": [], 333 | "source": [] 334 | } 335 | ], 336 | "metadata": { 337 | "kernelspec": { 338 | "display_name": "soccer_dataprovider_comparison", 339 | "language": "python", 340 | "name": "soccer_dataprovider_comparison" 341 | }, 342 | "language_info": { 343 | "codemirror_mode": { 344 | "name": "ipython", 345 | "version": 3 346 | }, 347 | "file_extension": ".py", 348 | "mimetype": "text/x-python", 349 | "name": "python", 350 | "nbconvert_exporter": "python", 351 | "pygments_lexer": "ipython3", 352 | "version": "3.6.2" 353 | }, 354 | "toc": { 355 | "base_numbering": 1, 356 | "nav_menu": {}, 357 | "number_sections": true, 358 | "sideBar": true, 359 | "skip_h1_title": false, 360 | "title_cell": "Table of Contents", 361 | "title_sidebar": "Contents", 362 | "toc_cell": false, 363 | "toc_position": {}, 364 | "toc_section_display": true, 365 | "toc_window_display": true 366 | } 367 | }, 368 | "nbformat": 4, 369 | "nbformat_minor": 2 370 | } 371 | -------------------------------------------------------------------------------- /notebooks/1-load-and-convert-wyscout-data.ipynb: 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Data Preparation\n", 8 | "\n", 9 | "This notebook downloads the opensource [Wyscout match event dataset](https://figshare.com/collections/Soccer_match_event_dataset/4415000/2) and converts it to the [SPADL format](https://github.com/ML-KULeuven/socceraction). This dataset contains all spatio-temporal events (passes, shots, fouls, etc.) 
that occurred during all matches of the 2017/18 season of the top-5 European leagues (La Liga, Serie A, Bundesliga, Premier League, Ligue 1) as well as the FIFA World Cup 2018 and UEFA Euro 2016." 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "**Disclaimer**: this notebook is compatible with [v5 of the Soccer match event dataset](https://figshare.com/collections/Soccer_match_event_dataset/4415000/5) and the following package versions:\n", 17 | "\n", 18 | "- tqdm 4.42.1\n", 19 | "- pandas 1.0\n", 20 | "- socceraction 0.1.1" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 11, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "import os\n", 30 | "import sys\n", 31 | "\n", 32 | "from tqdm.notebook import tqdm\n", 33 | "\n", 34 | "import math\n", 35 | "\n", 36 | "import pandas as pd\n", 37 | "pd.set_option('display.max_columns', None)\n", 38 | "\n", 39 | "from io import BytesIO\n", 40 | "from pathlib import Path\n", 41 | "\n", 42 | "from urllib.parse import urlparse\n", 43 | "from urllib.request import urlopen, urlretrieve\n", 44 | "# optional: if you get an SSL CERTIFICATE_VERIFY_FAILED exception\n", 45 | "import ssl; ssl._create_default_https_context = ssl._create_unverified_context\n", 46 | "\n", 47 | "from zipfile import ZipFile, is_zipfile\n", 48 | "\n", 49 | "import socceraction.spadl as spadl\n", 50 | "import socceraction.spadl.wyscout as wyscout" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "metadata": {}, 56 | "source": [ 57 | "## Configure leagues and seasons to download and convert\n", 58 | "The two dictionaries below map my internal season and league IDs to Wyscout's IDs. Using an internal ID makes it easier to work with data from multiple providers." 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 12, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "seasons = {\n", 68 | " 181248: '1718',\n", 69 | " 181150: '1718',\n", 70 | " 181144: '1718',\n", 71 | " 181189: '1718',\n", 72 | " 181137: '1718'\n", 73 | "}\n", 74 | "leagues = {\n", 75 | " 'England':'ENG',\n", 76 | " 'France':'FRA',\n", 77 | " 'Germany':'GER',\n", 78 | " 'Italy':'ITA',\n", 79 | " 'Spain':'ESP'\n", 80 | "}" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "## Configure folder names and download URLs\n", 88 | "\n", 89 | "The two cells below define the URLs from which the data are downloaded and the folders where the data are stored." 
90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 13, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "# https://figshare.com/collections/Soccer_match_event_dataset/4415000/5\n", 99 | "dataset_urls = dict(\n", 100 | " competitions = \"https://ndownloader.figshare.com/files/15073685\",\n", 101 | " teams = \"https://ndownloader.figshare.com/files/15073697\",\n", 102 | " players = \"https://ndownloader.figshare.com/files/15073721\",\n", 103 | " matches = \"https://ndownloader.figshare.com/files/14464622\",\n", 104 | " events = \"https://ndownloader.figshare.com/files/14464685\"\n", 105 | ")" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 14, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "raw_datafolder = \"../data/wyscout_opensource/raw\"\n", 115 | "spadl_datafolder = \"../data/wyscout_opensource\"\n", 116 | "\n", 117 | "# Create data folder if it doesn't exist\n", 118 | "for d in [raw_datafolder, spadl_datafolder]:\n", 119 | " if not os.path.exists(d):\n", 120 | " os.makedirs(d, exist_ok=True)\n", 121 | " print(f\"Directory {d} created \")" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": {}, 127 | "source": [ 128 | "## Download WyScout data \n", 129 | "\n", 130 | "The following cell loops through the dataset_urls dict and stores each downloaded data file to the `raw_datafolder` in the local file system.\n", 131 | "\n", 132 | "If the downloaded data file is a ZIP archive, the included JSON files are extracted from the ZIP archive." 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 15, 138 | "metadata": {}, 139 | "outputs": [ 140 | { 141 | "data": { 142 | "application/vnd.jupyter.widget-view+json": { 143 | "model_id": "ff51207484174bffbc24d14fde05ca2b", 144 | "version_major": 2, 145 | "version_minor": 0 146 | }, 147 | "text/plain": [ 148 | "HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))" 149 | ] 150 | }, 151 | "metadata": {}, 152 | "output_type": "display_data" 153 | }, 154 | { 155 | "name": "stdout", 156 | "output_type": "stream", 157 | "text": [ 158 | "\n", 159 | "Downloaded files:\n" 160 | ] 161 | }, 162 | { 163 | "data": { 164 | "text/plain": [ 165 | "['events_France.json',\n", 166 | " 'events_England.json',\n", 167 | " 'events.zip',\n", 168 | " 'eventid2name.csv',\n", 169 | " 'coaches.json',\n", 170 | " 'tags2name.csv',\n", 171 | " 'competitions.json',\n", 172 | " 'referees.json',\n", 173 | " 'matches_Germany.json',\n", 174 | " 'events_Italy.json',\n", 175 | " 'matches.zip',\n", 176 | " 'matches_Italy.json',\n", 177 | " 'events_European_Championship.json',\n", 178 | " 'teams.json',\n", 179 | " 'matches_France.json',\n", 180 | " 'events_Germany.json',\n", 181 | " 'events_Spain.json',\n", 182 | " 'events_World_Cup.json',\n", 183 | " 'players.json',\n", 184 | " 'matches_World_Cup.json',\n", 185 | " 'matches_European_Championship.json',\n", 186 | " 'matches_England.json',\n", 187 | " 'matches_Spain.json']" 188 | ] 189 | }, 190 | "execution_count": 15, 191 | "metadata": {}, 192 | "output_type": "execute_result" 193 | } 194 | ], 195 | "source": [ 196 | "for url in tqdm(dataset_urls.values()):\n", 197 | " url_obj = urlopen(url).geturl()\n", 198 | " path = Path(urlparse(url_obj).path)\n", 199 | " file_name = os.path.join(raw_datafolder, path.name)\n", 200 | " file_local, _ = urlretrieve(url_obj, file_name)\n", 201 | " if is_zipfile(file_local):\n", 202 | " with ZipFile(file_local) as zip_file:\n", 203 | " 
zip_file.extractall(raw_datafolder)\n", 204 | "\n", 205 | "print(\"Downloaded files:\")\n", 206 | "os.listdir(raw_datafolder)" 207 | ] 208 | }, 209 | { 210 | "cell_type": "markdown", 211 | "metadata": {}, 212 | "source": [ 213 | "## Preprocess Wyscout data" 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "metadata": {}, 219 | "source": [ 220 | "The read_json_file function reads and returns the content of a given JSON file. The function handles the encoding of special characters (e.g., accents in names of players and teams) that the pd.read_json function cannot handle properly." 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 16, 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "def read_json_file(filename):\n", 230 | " with open(filename, 'rb') as json_file:\n", 231 | " return BytesIO(json_file.read()).getvalue().decode('unicode_escape')" 232 | ] 233 | }, 234 | { 235 | "cell_type": "markdown", 236 | "metadata": {}, 237 | "source": [ 238 | "Wyscout does not distinguish between headers and other body\n", 239 | "parts on shots. The socceraction convertor simply labels all\n", 240 | "shots as performed by foot. I think it is better to label \n", 241 | "them as headers." 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": 17, 247 | "metadata": {}, 248 | "outputs": [], 249 | "source": [ 250 | "def determine_bodypart_id(event):\n", 251 | " \"\"\"\n", 252 | " This function determines the body part used for an event\n", 253 | " Args:\n", 254 | " event (pd.Series): Wyscout event Series\n", 255 | " Returns:\n", 256 | " int: id of the body part used for the action\n", 257 | " \"\"\"\n", 258 | " if event[\"subtype_id\"] in [81, 36, 21, 90, 91]:\n", 259 | " body_part = \"other\"\n", 260 | " elif event[\"subtype_id\"] == 82 or event['head/body']:\n", 261 | " body_part = \"head\"\n", 262 | " else: # all other cases\n", 263 | " body_part = \"foot\"\n", 264 | " return spadl.config.bodyparts.index(body_part)\n", 265 | "wyscout.determine_bodypart_id = determine_bodypart_id" 266 | ] 267 | }, 268 | { 269 | "cell_type": "markdown", 270 | "metadata": {}, 271 | "source": [ 272 | "### Select competitions to load and convert" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": 18, 278 | "metadata": {}, 279 | "outputs": [ 280 | { 281 | "data": { 282 | "text/plain": [ 283 | "{'England',\n", 284 | " 'European Championship',\n", 285 | " 'France',\n", 286 | " 'Germany',\n", 287 | " 'Italy',\n", 288 | " 'Spain',\n", 289 | " 'World Cup'}" 290 | ] 291 | }, 292 | "execution_count": 18, 293 | "metadata": {}, 294 | "output_type": "execute_result" 295 | } 296 | ], 297 | "source": [ 298 | "json_competitions = read_json_file(f\"{raw_datafolder}/competitions.json\")\n", 299 | "df_competitions = pd.read_json(json_competitions)\n", 300 | "# Rename competitions to the names used in the file names\n", 301 | "df_competitions['name'] = df_competitions.apply(lambda x: x.area['name'] if x.area['name'] != \"\" else x['name'], axis=1)\n", 302 | "df_competitions['id'] = df_competitions.apply(lambda x: leagues.get(x.area['name'], 'NULL'), axis=1)\n", 303 | "# View all available competitions\n", 304 | "set(df_competitions.name)" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 19, 310 | "metadata": {}, 311 | "outputs": [ 312 | { 313 | "data": { 314 | "text/html": [ 315 | "
\n", 316 | "\n", 329 | "\n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | "
namewyIdformatareatypeid
0Italy524Domestic league{'name': 'Italy', 'id': '380', 'alpha3code': '...clubITA
1England364Domestic league{'name': 'England', 'id': '0', 'alpha3code': '...clubENG
2Spain795Domestic league{'name': 'Spain', 'id': '724', 'alpha3code': '...clubESP
3France412Domestic league{'name': 'France', 'id': '250', 'alpha3code': ...clubFRA
4Germany426Domestic league{'name': 'Germany', 'id': '276', 'alpha3code':...clubGER
\n", 389 | "
" 390 | ], 391 | "text/plain": [ 392 | " name wyId format \\\n", 393 | "0 Italy 524 Domestic league \n", 394 | "1 England 364 Domestic league \n", 395 | "2 Spain 795 Domestic league \n", 396 | "3 France 412 Domestic league \n", 397 | "4 Germany 426 Domestic league \n", 398 | "\n", 399 | " area type id \n", 400 | "0 {'name': 'Italy', 'id': '380', 'alpha3code': '... club ITA \n", 401 | "1 {'name': 'England', 'id': '0', 'alpha3code': '... club ENG \n", 402 | "2 {'name': 'Spain', 'id': '724', 'alpha3code': '... club ESP \n", 403 | "3 {'name': 'France', 'id': '250', 'alpha3code': ... club FRA \n", 404 | "4 {'name': 'Germany', 'id': '276', 'alpha3code':... club GER " 405 | ] 406 | }, 407 | "execution_count": 19, 408 | "metadata": {}, 409 | "output_type": "execute_result" 410 | } 411 | ], 412 | "source": [ 413 | "df_selected_competitions = df_competitions[df_competitions.name.isin(\n", 414 | " leagues.keys()\n", 415 | ")]\n", 416 | "df_selected_competitions" 417 | ] 418 | }, 419 | { 420 | "cell_type": "markdown", 421 | "metadata": {}, 422 | "source": [ 423 | "## Convert to the SPADL format" 424 | ] 425 | }, 426 | { 427 | "cell_type": "code", 428 | "execution_count": 22, 429 | "metadata": {}, 430 | "outputs": [ 431 | { 432 | "name": "stdout", 433 | "output_type": "stream", 434 | "text": [ 435 | "Converting ITA 1718\n" 436 | ] 437 | }, 438 | { 439 | "data": { 440 | "application/vnd.jupyter.widget-view+json": { 441 | "model_id": "1ab744a7c2ae49959e710b309b62cb1a", 442 | "version_major": 2, 443 | "version_minor": 0 444 | }, 445 | "text/plain": [ 446 | "HBox(children=(FloatProgress(value=0.0, max=380.0), HTML(value='')))" 447 | ] 448 | }, 449 | "metadata": {}, 450 | "output_type": "display_data" 451 | }, 452 | { 453 | "name": "stdout", 454 | "output_type": "stream", 455 | "text": [ 456 | "\n", 457 | "Converting ENG 1718\n" 458 | ] 459 | }, 460 | { 461 | "data": { 462 | "application/vnd.jupyter.widget-view+json": { 463 | "model_id": "50692b69f4da415397980283ea063570", 464 | "version_major": 2, 465 | "version_minor": 0 466 | }, 467 | "text/plain": [ 468 | "HBox(children=(FloatProgress(value=0.0, max=380.0), HTML(value='')))" 469 | ] 470 | }, 471 | "metadata": {}, 472 | "output_type": "display_data" 473 | }, 474 | { 475 | "name": "stdout", 476 | "output_type": "stream", 477 | "text": [ 478 | "\n", 479 | "Converting ESP 1718\n" 480 | ] 481 | }, 482 | { 483 | "data": { 484 | "application/vnd.jupyter.widget-view+json": { 485 | "model_id": "f8357efe4bb34d7981647a9dea4acb43", 486 | "version_major": 2, 487 | "version_minor": 0 488 | }, 489 | "text/plain": [ 490 | "HBox(children=(FloatProgress(value=0.0, max=380.0), HTML(value='')))" 491 | ] 492 | }, 493 | "metadata": {}, 494 | "output_type": "display_data" 495 | }, 496 | { 497 | "name": "stdout", 498 | "output_type": "stream", 499 | "text": [ 500 | "\n", 501 | "Converting FRA 1718\n" 502 | ] 503 | }, 504 | { 505 | "data": { 506 | "application/vnd.jupyter.widget-view+json": { 507 | "model_id": "758211fbdae845cab30001d48dba442c", 508 | "version_major": 2, 509 | "version_minor": 0 510 | }, 511 | "text/plain": [ 512 | "HBox(children=(FloatProgress(value=0.0, max=380.0), HTML(value='')))" 513 | ] 514 | }, 515 | "metadata": {}, 516 | "output_type": "display_data" 517 | }, 518 | { 519 | "name": "stdout", 520 | "output_type": "stream", 521 | "text": [ 522 | "\n", 523 | "Converting GER 1718\n" 524 | ] 525 | }, 526 | { 527 | "data": { 528 | "application/vnd.jupyter.widget-view+json": { 529 | "model_id": "f4594b50fcd948d9a371a3f8fa47dd65", 530 | "version_major": 2, 531 | 
"version_minor": 0 532 | }, 533 | "text/plain": [ 534 | "HBox(children=(FloatProgress(value=0.0, max=306.0), HTML(value='')))" 535 | ] 536 | }, 537 | "metadata": {}, 538 | "output_type": "display_data" 539 | }, 540 | { 541 | "name": "stdout", 542 | "output_type": "stream", 543 | "text": [ 544 | "\n" 545 | ] 546 | } 547 | ], 548 | "source": [ 549 | "json_teams = read_json_file(f\"{raw_datafolder}/teams.json\")\n", 550 | "df_teams = wyscout.convert_teams(pd.read_json(json_teams))\n", 551 | "\n", 552 | "json_players = read_json_file(f\"{raw_datafolder}/players.json\")\n", 553 | "df_players = wyscout.convert_players(pd.read_json(json_players))\n", 554 | "\n", 555 | "\n", 556 | "for competition in df_selected_competitions.itertuples():\n", 557 | " json_matches = read_json_file(f\"{raw_datafolder}/matches_{competition.name}.json\")\n", 558 | " df_matches = pd.read_json(json_matches)\n", 559 | " season_id = seasons[df_matches.seasonId.unique()[0]]\n", 560 | " df_games = wyscout.convert_games(df_matches)\n", 561 | " df_games['competition_id'] = competition.id\n", 562 | " df_games['season_id'] = season_id\n", 563 | " \n", 564 | " json_events = read_json_file(f\"{raw_datafolder}/events_{competition.name}.json\")\n", 565 | " df_events = pd.read_json(json_events).groupby('matchId', as_index=False)\n", 566 | " \n", 567 | " player_games = []\n", 568 | " \n", 569 | " spadl_h5 = os.path.join(spadl_datafolder, f\"spadl-wyscout_opensource-{competition.id}-{season_id}.h5\")\n", 570 | "\n", 571 | " # Store all spadl data in h5-file\n", 572 | " print(f\"Converting {competition.id} {season_id}\")\n", 573 | " with pd.HDFStore(spadl_h5) as spadlstore:\n", 574 | " \n", 575 | " spadlstore[\"actiontypes\"] = spadl.actiontypes_df()\n", 576 | " spadlstore[\"results\"] = spadl.results_df()\n", 577 | " spadlstore[\"bodyparts\"] = spadl.bodyparts_df()\n", 578 | " spadlstore[\"games\"] = df_games\n", 579 | "\n", 580 | " for game in tqdm(list(df_games.itertuples())):\n", 581 | " game_id = game.game_id\n", 582 | " game_events = df_events.get_group(game_id)\n", 583 | "\n", 584 | " # filter the players that were lined up in this season\n", 585 | " player_games.append(wyscout.get_player_games(df_matches[df_matches.wyId == game_id].iloc[0], game_events))\n", 586 | "\n", 587 | " # convert events to SPADL actions\n", 588 | " home_team = game.home_team_id\n", 589 | " df_actions = wyscout.convert_actions(game_events, home_team)\n", 590 | " df_actions[\"action_id\"] = range(len(df_actions))\n", 591 | " spadlstore[f\"actions/game_{game_id}\"] = df_actions\n", 592 | "\n", 593 | " player_games = pd.concat(player_games).reset_index(drop=True) \n", 594 | " spadlstore[\"player_games\"] = player_games\n", 595 | " spadlstore[\"players\"] = df_players[df_players.player_id.isin(player_games.player_id)]\n", 596 | " spadlstore[\"teams\"] = df_teams[df_teams.team_id.isin(df_games.home_team_id) | df_teams.team_id.isin(df_games.away_team_id)]" 597 | ] 598 | }, 599 | { 600 | "cell_type": "code", 601 | "execution_count": null, 602 | "metadata": {}, 603 | "outputs": [], 604 | "source": [] 605 | } 606 | ], 607 | "metadata": { 608 | "kernelspec": { 609 | "display_name": "soccer_dataprovider_comparison", 610 | "language": "python", 611 | "name": "soccer_dataprovider_comparison" 612 | }, 613 | "language_info": { 614 | "codemirror_mode": { 615 | "name": "ipython", 616 | "version": 3 617 | }, 618 | "file_extension": ".py", 619 | "mimetype": "text/x-python", 620 | "name": "python", 621 | "nbconvert_exporter": "python", 622 | "pygments_lexer": "ipython3", 
623 | "version": "3.6.2" 624 | }, 625 | "toc": { 626 | "base_numbering": 1, 627 | "nav_menu": {}, 628 | "number_sections": true, 629 | "sideBar": true, 630 | "skip_h1_title": false, 631 | "title_cell": "Table of Contents", 632 | "title_sidebar": "Contents", 633 | "toc_cell": false, 634 | "toc_position": {}, 635 | "toc_section_display": true, 636 | "toc_window_display": true 637 | } 638 | }, 639 | "nbformat": 4, 640 | "nbformat_minor": 2 641 | } 642 | -------------------------------------------------------------------------------- /notebooks/3-computing-and-storing-features.ipynb: 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Features\n", 8 | "This notebook generates features and labels (goal/no goal) for all shots and stores them in an HDF file. Precomputing and storing them saves computational time when you want to experiment with multiple pipelines." 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "metadata": {}, 15 | "outputs": [], 16 | "source": [ 17 | "import pandas as pd\n", 18 | "pd.set_option('display.max_columns', None)\n", 19 | "import numpy as np\n", 20 | "import itertools" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "%load_ext autoreload\n", 30 | "%autoreload 2\n", 31 | "from soccer_xg.api import DataApi\n", 32 | "import soccer_xg.xg as xg\n", 33 | "import soccer_xg.features as fs" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "## Config" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 3, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "# dataset\n", 50 | "dir_data = \"../data\"\n", 51 | "provider = 'wyscout_opensource'\n", 52 | "leagues = ['ENG', 'ESP', 'ITA', 'GER', 'FRA']\n", 53 | "seasons = ['1718']\n", 54 | "\n", 55 | "# features\n", 56 | "store_features = f'../data/{provider}/features.h5'" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "By default, all features defined in `soccer_xg.features.all_features` are computed. It is also possible to compute a subset of these features or add additional feature generators. Each feature generator is a function that expects either a DataFrame object containing individual actions or a list of DataFrame objects containing consecutive actions (i.e., game states), and returns the corresponding feature for each individual action or game state. Features that contain information about the shot's outcome are automatically removed." 
64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 4, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "feature_generators = fs.all_features" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "metadata": {}, 78 | "source": [ 79 | "## Compute features and labels" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 5, 85 | "metadata": {}, 86 | "outputs": [ 87 | { 88 | "name": "stdout", 89 | "output_type": "stream", 90 | "text": [ 91 | "ENG 1718\n" 92 | ] 93 | }, 94 | { 95 | "name": "stderr", 96 | "output_type": "stream", 97 | "text": [ 98 | "Generating features: 100%|██████████| 380/380 [01:40<00:00, 3.78it/s]\n", 99 | "Generating labels: 100%|██████████| 380/380 [00:08<00:00, 43.51it/s]\n" 100 | ] 101 | }, 102 | { 103 | "name": "stdout", 104 | "output_type": "stream", 105 | "text": [ 106 | "ESP 1718\n" 107 | ] 108 | }, 109 | { 110 | "name": "stderr", 111 | "output_type": "stream", 112 | "text": [ 113 | "Generating features: 100%|██████████| 380/380 [01:40<00:00, 3.79it/s]\n", 114 | "Generating labels: 100%|██████████| 380/380 [00:08<00:00, 43.70it/s]\n" 115 | ] 116 | }, 117 | { 118 | "name": "stdout", 119 | "output_type": "stream", 120 | "text": [ 121 | "ITA 1718\n" 122 | ] 123 | }, 124 | { 125 | "name": "stderr", 126 | "output_type": "stream", 127 | "text": [ 128 | "Generating features: 100%|██████████| 380/380 [01:40<00:00, 3.77it/s]\n", 129 | "Generating labels: 100%|██████████| 380/380 [00:08<00:00, 43.62it/s]\n" 130 | ] 131 | }, 132 | { 133 | "name": "stdout", 134 | "output_type": "stream", 135 | "text": [ 136 | "GER 1718\n" 137 | ] 138 | }, 139 | { 140 | "name": "stderr", 141 | "output_type": "stream", 142 | "text": [ 143 | "Generating features: 100%|██████████| 306/306 [01:21<00:00, 3.77it/s]\n", 144 | "Generating labels: 100%|██████████| 306/306 [00:06<00:00, 43.93it/s]\n" 145 | ] 146 | }, 147 | { 148 | "name": "stdout", 149 | "output_type": "stream", 150 | "text": [ 151 | "FRA 1718\n" 152 | ] 153 | }, 154 | { 155 | "name": "stderr", 156 | "output_type": "stream", 157 | "text": [ 158 | "Generating features: 100%|██████████| 380/380 [01:40<00:00, 3.78it/s]\n", 159 | "Generating labels: 100%|██████████| 380/380 [00:08<00:00, 43.76it/s]\n" 160 | ] 161 | } 162 | ], 163 | "source": [ 164 | "for (l,s) in itertools.product(leagues, seasons):\n", 165 | " print(l, s)\n", 166 | " api = DataApi([f\"{dir_data}/{provider}/spadl-{provider}-{l}-{s}.h5\"])\n", 167 | " xg.get_features(api, xfns=feature_generators).to_hdf(store_features, key=f'{l}/{s}/features', format='table') \n", 168 | " xg.get_labels(api).to_hdf(store_features, key=f'{l}/{s}/labels', format='table') " 169 | ] 170 | }, 171 | { 172 | "cell_type": "markdown", 173 | "metadata": {}, 174 | "source": [ 175 | "## Load features" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 6, 181 | "metadata": {}, 182 | "outputs": [ 183 | { 184 | "data": { 185 | "text/html": [ 186 | "
\n", 187 | "\n", 200 | "\n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 
472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | " \n", 611 | " \n", 612 | " \n", 613 | " \n", 614 | " \n", 615 | " \n", 616 | " \n", 617 | " \n", 618 | " \n", 619 | " \n", 620 | " \n", 621 | " \n", 622 | " \n", 623 | " \n", 624 | " \n", 625 | " \n", 626 | " \n", 627 | " \n", 628 | " \n", 629 | " \n", 630 | " \n", 631 | " \n", 632 | " \n", 633 | " \n", 634 | " \n", 635 | " \n", 636 | " \n", 637 | " \n", 638 | " \n", 639 | " \n", 640 | " \n", 641 | " \n", 642 | "
type_id_a0type_id_a1type_id_a2bodypart_id_a0bodypart_id_a1bodypart_id_a2result_id_a1result_id_a2start_x_a0start_y_a0start_x_a1start_y_a1start_x_a2start_y_a2end_x_a1end_y_a1end_x_a2end_y_a2dx_a1dy_a1movement_a1dx_a2dy_a2movement_a2dx_a01dy_a01mov_a01dx_a02dy_a02mov_a02start_dist_to_goal_a0start_angle_to_goal_a0start_dist_to_goal_a1start_angle_to_goal_a1start_dist_to_goal_a2start_angle_to_goal_a2end_dist_to_goal_a1end_angle_to_goal_a1end_dist_to_goal_a2end_angle_to_goal_a2team_1team_2time_delta_1time_delta_2speedx_a01speedy_a01speed_a01speedx_a02speedy_a02speed_a02shot_angle_a0shot_angle_a1shot_angle_a2caley_zone_a0caley_zone_a1caley_zone_a2angle_zone_a0angle_zone_a1angle_zone_a2
game_idaction_id
250009817shotdribblecrossfootfootfootsuccesssuccess99.7526.5291.3529.9297.656.1299.7526.5291.3529.928.40-3.409.062009-6.3023.8024.6197080.00.00.0-8.403.409.0620099.1385390.95881514.2467150.29044828.8325671.3130319.1385390.95881514.2467150.290448TrueTrue3.4332286.8664560.00.00.01.2233390.4951611.3197500.4997780.4837800.06550023891218
40shotcorner_crossedpassfootfootfootsuccessfail91.3535.36105.000.0096.6023.8091.3535.360.0053.72-13.6535.3637.903194-96.6029.92101.1274760.00.00.0-91.3518.3693.17677913.7175840.09930634.0000001.57079613.2136290.88187213.7175840.099306106.8357540.185647TrueTrue2.10253121.9272280.00.00.04.1660530.8373154.2493640.5179850.0000000.363334384122112
77shotclearancecrossfootfootfootfailfail75.6029.9294.5027.2098.7065.9675.6029.9294.5027.20-18.902.7219.094722-4.20-38.7638.9868900.00.00.018.90-2.7219.09472229.6817520.13789512.5095960.57470032.5750151.37617029.6817520.13789512.5095960.574700FalseTrue2.6298613.2506820.00.00.05.8141650.8367475.8740660.2424810.4915550.043863630181218
140shotcrossdribblefootfootfootsuccesssuccess92.4043.5298.7051.6891.3554.4092.4043.5298.7051.68-6.30-8.1610.3090067.35-2.727.8371490.00.00.06.308.1610.30900615.7920990.64704718.7689211.22848924.5455190.98109915.7920990.64704718.7689211.228489TrueTrue1.0524995.0006270.00.00.01.2598421.6317952.0615430.3715380.1348600.167545450121518
145shotpasspassfootfootfootsuccesssuccess99.7537.4096.6038.7693.4545.5699.7537.4096.6038.763.15-1.363.4310493.15-6.807.4941640.00.00.0-3.151.363.4310496.2547980.5747009.6549260.51554916.3412390.7858316.2547980.5747009.6549260.515549TrueTrue1.6777552.6599970.00.00.01.1842120.5112791.2898700.9782910.6546110.3208411346915
\n", 643 | "
" 644 | ], 645 | "text/plain": [ 646 | " type_id_a0 type_id_a1 type_id_a2 bodypart_id_a0 \\\n", 647 | "game_id action_id \n", 648 | "2500098 17 shot dribble cross foot \n", 649 | " 40 shot corner_crossed pass foot \n", 650 | " 77 shot clearance cross foot \n", 651 | " 140 shot cross dribble foot \n", 652 | " 145 shot pass pass foot \n", 653 | "\n", 654 | " bodypart_id_a1 bodypart_id_a2 result_id_a1 result_id_a2 \\\n", 655 | "game_id action_id \n", 656 | "2500098 17 foot foot success success \n", 657 | " 40 foot foot success fail \n", 658 | " 77 foot foot fail fail \n", 659 | " 140 foot foot success success \n", 660 | " 145 foot foot success success \n", 661 | "\n", 662 | " start_x_a0 start_y_a0 start_x_a1 start_y_a1 start_x_a2 \\\n", 663 | "game_id action_id \n", 664 | "2500098 17 99.75 26.52 91.35 29.92 97.65 \n", 665 | " 40 91.35 35.36 105.00 0.00 96.60 \n", 666 | " 77 75.60 29.92 94.50 27.20 98.70 \n", 667 | " 140 92.40 43.52 98.70 51.68 91.35 \n", 668 | " 145 99.75 37.40 96.60 38.76 93.45 \n", 669 | "\n", 670 | " start_y_a2 end_x_a1 end_y_a1 end_x_a2 end_y_a2 dx_a1 \\\n", 671 | "game_id action_id \n", 672 | "2500098 17 6.12 99.75 26.52 91.35 29.92 8.40 \n", 673 | " 40 23.80 91.35 35.36 0.00 53.72 -13.65 \n", 674 | " 77 65.96 75.60 29.92 94.50 27.20 -18.90 \n", 675 | " 140 54.40 92.40 43.52 98.70 51.68 -6.30 \n", 676 | " 145 45.56 99.75 37.40 96.60 38.76 3.15 \n", 677 | "\n", 678 | " dy_a1 movement_a1 dx_a2 dy_a2 movement_a2 dx_a01 \\\n", 679 | "game_id action_id \n", 680 | "2500098 17 -3.40 9.062009 -6.30 23.80 24.619708 0.0 \n", 681 | " 40 35.36 37.903194 -96.60 29.92 101.127476 0.0 \n", 682 | " 77 2.72 19.094722 -4.20 -38.76 38.986890 0.0 \n", 683 | " 140 -8.16 10.309006 7.35 -2.72 7.837149 0.0 \n", 684 | " 145 -1.36 3.431049 3.15 -6.80 7.494164 0.0 \n", 685 | "\n", 686 | " dy_a01 mov_a01 dx_a02 dy_a02 mov_a02 \\\n", 687 | "game_id action_id \n", 688 | "2500098 17 0.0 0.0 -8.40 3.40 9.062009 \n", 689 | " 40 0.0 0.0 -91.35 18.36 93.176779 \n", 690 | " 77 0.0 0.0 18.90 -2.72 19.094722 \n", 691 | " 140 0.0 0.0 6.30 8.16 10.309006 \n", 692 | " 145 0.0 0.0 -3.15 1.36 3.431049 \n", 693 | "\n", 694 | " start_dist_to_goal_a0 start_angle_to_goal_a0 \\\n", 695 | "game_id action_id \n", 696 | "2500098 17 9.138539 0.958815 \n", 697 | " 40 13.717584 0.099306 \n", 698 | " 77 29.681752 0.137895 \n", 699 | " 140 15.792099 0.647047 \n", 700 | " 145 6.254798 0.574700 \n", 701 | "\n", 702 | " start_dist_to_goal_a1 start_angle_to_goal_a1 \\\n", 703 | "game_id action_id \n", 704 | "2500098 17 14.246715 0.290448 \n", 705 | " 40 34.000000 1.570796 \n", 706 | " 77 12.509596 0.574700 \n", 707 | " 140 18.768921 1.228489 \n", 708 | " 145 9.654926 0.515549 \n", 709 | "\n", 710 | " start_dist_to_goal_a2 start_angle_to_goal_a2 \\\n", 711 | "game_id action_id \n", 712 | "2500098 17 28.832567 1.313031 \n", 713 | " 40 13.213629 0.881872 \n", 714 | " 77 32.575015 1.376170 \n", 715 | " 140 24.545519 0.981099 \n", 716 | " 145 16.341239 0.785831 \n", 717 | "\n", 718 | " end_dist_to_goal_a1 end_angle_to_goal_a1 \\\n", 719 | "game_id action_id \n", 720 | "2500098 17 9.138539 0.958815 \n", 721 | " 40 13.717584 0.099306 \n", 722 | " 77 29.681752 0.137895 \n", 723 | " 140 15.792099 0.647047 \n", 724 | " 145 6.254798 0.574700 \n", 725 | "\n", 726 | " end_dist_to_goal_a2 end_angle_to_goal_a2 team_1 team_2 \\\n", 727 | "game_id action_id \n", 728 | "2500098 17 14.246715 0.290448 True True \n", 729 | " 40 106.835754 0.185647 True True \n", 730 | " 77 12.509596 0.574700 False True \n", 731 | " 140 18.768921 1.228489 True True 
\n", 732 | " 145 9.654926 0.515549 True True \n", 733 | "\n", 734 | " time_delta_1 time_delta_2 speedx_a01 speedy_a01 \\\n", 735 | "game_id action_id \n", 736 | "2500098 17 3.433228 6.866456 0.0 0.0 \n", 737 | " 40 2.102531 21.927228 0.0 0.0 \n", 738 | " 77 2.629861 3.250682 0.0 0.0 \n", 739 | " 140 1.052499 5.000627 0.0 0.0 \n", 740 | " 145 1.677755 2.659997 0.0 0.0 \n", 741 | "\n", 742 | " speed_a01 speedx_a02 speedy_a02 speed_a02 \\\n", 743 | "game_id action_id \n", 744 | "2500098 17 0.0 1.223339 0.495161 1.319750 \n", 745 | " 40 0.0 4.166053 0.837315 4.249364 \n", 746 | " 77 0.0 5.814165 0.836747 5.874066 \n", 747 | " 140 0.0 1.259842 1.631795 2.061543 \n", 748 | " 145 0.0 1.184212 0.511279 1.289870 \n", 749 | "\n", 750 | " shot_angle_a0 shot_angle_a1 shot_angle_a2 caley_zone_a0 \\\n", 751 | "game_id action_id \n", 752 | "2500098 17 0.499778 0.483780 0.065500 2 \n", 753 | " 40 0.517985 0.000000 0.363334 3 \n", 754 | " 77 0.242481 0.491555 0.043863 6 \n", 755 | " 140 0.371538 0.134860 0.167545 4 \n", 756 | " 145 0.978291 0.654611 0.320841 1 \n", 757 | "\n", 758 | " caley_zone_a1 caley_zone_a2 angle_zone_a0 angle_zone_a1 \\\n", 759 | "game_id action_id \n", 760 | "2500098 17 3 8 9 12 \n", 761 | " 40 8 4 12 21 \n", 762 | " 77 3 0 18 12 \n", 763 | " 140 5 0 12 15 \n", 764 | " 145 3 4 6 9 \n", 765 | "\n", 766 | " angle_zone_a2 \n", 767 | "game_id action_id \n", 768 | "2500098 17 18 \n", 769 | " 40 12 \n", 770 | " 77 18 \n", 771 | " 140 18 \n", 772 | " 145 15 " 773 | ] 774 | }, 775 | "metadata": {}, 776 | "output_type": "display_data" 777 | }, 778 | { 779 | "data": { 780 | "text/html": [ 781 | "
\n", 782 | "\n", 795 | "\n", 796 | " \n", 797 | " \n", 798 | " \n", 799 | " \n", 800 | " \n", 801 | " \n", 802 | " \n", 803 | " \n", 804 | " \n", 805 | " \n", 806 | " \n", 807 | " \n", 808 | " \n", 809 | " \n", 810 | " \n", 811 | " \n", 812 | " \n", 813 | " \n", 814 | " \n", 815 | " \n", 816 | " \n", 817 | " \n", 818 | " \n", 819 | " \n", 820 | " \n", 821 | " \n", 822 | " \n", 823 | " \n", 824 | " \n", 825 | " \n", 826 | " \n", 827 | " \n", 828 | " \n", 829 | " \n", 830 | " \n", 831 | "
goal
game_idaction_id
250009817False
40False
77False
140False
145False
\n", 832 | "
" 833 | ], 834 | "text/plain": [ 835 | " goal\n", 836 | "game_id action_id \n", 837 | "2500098 17 False\n", 838 | " 40 False\n", 839 | " 77 False\n", 840 | " 140 False\n", 841 | " 145 False" 842 | ] 843 | }, 844 | "metadata": {}, 845 | "output_type": "display_data" 846 | } 847 | ], 848 | "source": [ 849 | "features = []\n", 850 | "labels = []\n", 851 | "for (l,s) in itertools.product(leagues, seasons):\n", 852 | " features.append(pd.read_hdf(store_features, key=f'{l}/{s}/features'))\n", 853 | " labels.append(pd.read_hdf(store_features, key=f'{l}/{s}/labels'))\n", 854 | "features = pd.concat(features)\n", 855 | "labels = pd.concat(labels)\n", 856 | "\n", 857 | "display(features.head())\n", 858 | "display(labels.to_frame().head())" 859 | ] 860 | }, 861 | { 862 | "cell_type": "code", 863 | "execution_count": null, 864 | "metadata": {}, 865 | "outputs": [], 866 | "source": [] 867 | } 868 | ], 869 | "metadata": { 870 | "kernelspec": { 871 | "display_name": "soccer_dataprovider_comparison", 872 | "language": "python", 873 | "name": "soccer_dataprovider_comparison" 874 | }, 875 | "language_info": { 876 | "codemirror_mode": { 877 | "name": "ipython", 878 | "version": 3 879 | }, 880 | "file_extension": ".py", 881 | "mimetype": "text/x-python", 882 | "name": "python", 883 | "nbconvert_exporter": "python", 884 | "pygments_lexer": "ipython3", 885 | "version": "3.6.2" 886 | }, 887 | "toc": { 888 | "base_numbering": 1, 889 | "nav_menu": {}, 890 | "number_sections": true, 891 | "sideBar": true, 892 | "skip_h1_title": false, 893 | "title_cell": "Table of Contents", 894 | "title_sidebar": "Contents", 895 | "toc_cell": false, 896 | "toc_position": {}, 897 | "toc_section_display": true, 898 | "toc_window_display": true 899 | } 900 | }, 901 | "nbformat": 4, 902 | "nbformat_minor": 2 903 | } 904 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "soccer_xg" 3 | version = "0.0.1" 4 | description = "Train and analyse xG models on soccer event stream data" 5 | authors = ["Pieter Robberechts "] 6 | license = "Apache Software License (http://www.apache.org/licenses/LICENSE-2.0)" 7 | readme = 'README.md' 8 | repository = "https://github.com/probberechts/soccer_xg" 9 | homepage = "https://pypi.org/project/soccer_xg" 10 | keywords = ["expected goals", "xG", "Soccer", "Football", "event stream data", "sports analytics", "Statsbomb", "Opta", "Wyscout"] 11 | classifiers=[ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | "Operating System :: OS Independent" 15 | ] 16 | 17 | [tool.poetry.dependencies] 18 | python = "^3.6.1" 19 | numpy = "^1.18" 20 | pandas = "^1.0" 21 | scikit-learn = "^0.22.1" 22 | ipykernel = "^5.1" 23 | matplotsoccer = "^0.0.8" 24 | matplotlib = "^3.1" 25 | click = "^7.0" 26 | tables = "^3.6" 27 | requests = "^2.23" 28 | xgboost = "^1.0" 29 | seaborn = "^0.10.0" 30 | fuzzywuzzy = "^0.18.0" 31 | python-Levenshtein = "^0.12.0" 32 | dask = {extras = ["array","distributed","dataframe"], version = "^2.15.0", optional = true} 33 | dask_ml = {version = "^1.3.0", optional = true} 34 | category_encoders = "^2.2.2" 35 | asyncssh = {version = "^2.2.1", optional = true} 36 | paramiko = {version = "^2.7.1", optional = true} 37 | understat = "^0.1.2" 38 | socceraction = "^0.2.1" 39 | betacal = "^0.2.7" 40 | 41 | [tool.poetry.extras] 42 | dask = ["dask", "dask_ml", "asyncssh", "paramiko"] 43 | 44 | 
[tool.poetry.dev-dependencies] 45 | flake8-awesome = "^1.2" 46 | mypy = "^0.761.0" 47 | pylint = "^2.4" 48 | pytest = "^5.3" 49 | pytest-cov = "^2.8" 50 | pytest-deadfixtures = "^2.1" 51 | unify = "^0.5.0" 52 | black = {version = "^19.10b0", allow-prereleases = true} 53 | bumpversion = "^0.6.0" 54 | 55 | [build-system] 56 | requires = ["poetry>=0.12"] 57 | build-backend = "poetry.masonry.api" 58 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | enable-extensions = G 3 | exclude = .git, .venv 4 | ignore = 5 | A003 ; 'id' is a python builtin, consider renaming the class attribute 6 | W503 ; line break before binary operator 7 | max-complexity = 8 8 | max-line-length = 100 9 | show-source = true 10 | 11 | [mypy] 12 | check_untyped_defs = true 13 | disallow_any_generics = true 14 | disallow_incomplete_defs = true 15 | disallow_untyped_defs = true 16 | ignore_missing_imports = True 17 | no_implicit_optional = true 18 | 19 | [mypy-tests.*] 20 | disallow_untyped_defs = false 21 | 22 | [isort] 23 | balanced_wrapping = true 24 | default_section = THIRDPARTY 25 | include_trailing_comma=True 26 | known_first_party = src, tests 27 | line_length = 79 28 | multi_line_output = 3 29 | not_skip = __init__.py 30 | 31 | [pylint] 32 | good-names=i,j,k,e,x,_,pk,id 33 | max-args=5 34 | max-attributes=10 35 | max-bool-expr=5 36 | max-module-lines=200 37 | max-nested-blocks=2 38 | max-public-methods=5 39 | max-returns=5 40 | max-statements=20 41 | output-format = colorized 42 | 43 | disable= 44 | C0103, ; Constant name "api" doesn't conform to UPPER_CASE naming style (invalid-name) 45 | C0111, ; Missing module docstring (missing-docstring) 46 | C0330, ; Wrong hanging indentation before block (add 4 spaces) 47 | E0213, ; Method should have "self" as first argument (no-self-argument) - N805 for flake8 48 | R0201, ; Method could be a function (no-self-use) 49 | R0901, ; Too many ancestors (m/n) (too-many-ancestors) 50 | R0903, ; Too few public methods (m/n) (too-few-public-methods) 51 | 52 | ignored-classes= 53 | contextlib.closing, 54 | 55 | [coverage:run] 56 | omit = tests/*,**/__main__.py 57 | branch = True 58 | 59 | [coverage:report] 60 | show_missing = True 61 | skip_covered = True 62 | fail_under = 0 63 | -------------------------------------------------------------------------------- /soccer_xg/__init__.py: -------------------------------------------------------------------------------- 1 | from soccer_xg.api import DataApi 2 | from soccer_xg.xg import XGModel 3 | 4 | __version__ = '0.0.1' 5 | -------------------------------------------------------------------------------- /soccer_xg/api.py: -------------------------------------------------------------------------------- 1 | """An API wrapper for the SPADL format.""" 2 | import logging 3 | import os 4 | import warnings 5 | 6 | import pandas as pd 7 | 8 | deduplic = dict( 9 | games=('game_id', 'game_id'), 10 | teams=(['team_id'], 'team_id'), 11 | players=(['player_id'], 'player_id'), 12 | player_games=( 13 | ['game_id', 'team_id', 'player_id'], 14 | ['game_id', 'team_id', 'player_id'], 15 | ), 16 | files=('file_url', 'file_url'), 17 | ) 18 | 19 | 20 | class DataApi: 21 | """An object that provides easy access to a SPADL event stream dataset.
22 | 23 | Automatically defines an attribute which lazily loads the contents of 24 | each table in the HDF files and defines a couple of methods to easily execute 25 | common queries on the SPADL data. 26 | 27 | Parameters 28 | ---------- 29 | db_path : A list of strings or a single string 30 | Path(s) to HDF files containing the data. 31 | 32 | Attributes 33 | ---------- 34 | ``table_name`` : ``pd.DataFrame`` 35 | A single pandas dataframe that contains all records from all ``table_name`` 36 | tables in each HDF file. 37 | """ 38 | 39 | def __init__(self, db_path): 40 | self.logger = logging.getLogger(__name__) 41 | self.logger.info('Loading datasets') 42 | if isinstance(db_path, (list, set)): 43 | self.db_path = set(db_path) 44 | else: 45 | self.db_path = {db_path} 46 | 47 | for p in self.db_path: 48 | if not os.path.exists(p): 49 | raise ValueError( 50 | 'A database `{}` does not exist.'.format(str(p)) 51 | ) 52 | 53 | def __getattr__(self, name): 54 | self.logger.info(f'Loading `{name}` data') 55 | DB = [] 56 | for p in self.db_path: 57 | with pd.HDFStore(p, 'r') as store: 58 | for key in [ 59 | k for k in store.keys() if (k[1:].rsplit('/')[0] == name) 60 | ]: 61 | db = store[key] 62 | db['db_path'] = p 63 | DB.append(db) 64 | if len(DB) == 0: 65 | raise ValueError('A table `{}` does not exist.'.format(str(name))) 66 | else: 67 | DB = pd.concat(DB, sort=False) 68 | if name in deduplic: 69 | sortcols, idcols = deduplic[name] 70 | DB.sort_values(by=sortcols, ascending=False, inplace=True) 71 | DB.drop_duplicates(subset=idcols, inplace=True) 72 | DB.set_index(idcols, inplace=True) 73 | setattr(self, name, DB) 74 | return DB 75 | 76 | def get_events(self, game_id, only_home=False, only_away=False): 77 | """Return all events performed in a given game. 78 | 79 | Parameters 80 | ---------- 81 | game_id : int 82 | The ID of a game. 83 | only_home : bool 84 | Include only events from the home team. 85 | only_away : bool 86 | Include only events from the away team. 87 | 88 | Returns 89 | ------- 90 | pd.DataFrame 91 | A dataframe with a row for each event, indexed by period_id and 92 | a timestamp (ms) in which the event happened. 93 | 94 | Raises 95 | ------ 96 | ValueError 97 | If both `only_home` and `only_away` are True. 98 | IndexError 99 | If no game exists with the provided ID. 100 | """ 101 | if only_home and only_away: 102 | raise ValueError('only_home and only_away cannot both be True.') 103 | 104 | try: 105 | db = self.games.at[game_id, 'db_path'] 106 | with pd.HDFStore(db, 'r') as store: 107 | df_game_events = store.get(f'events/game_{game_id}') 108 | home_team_id, away_team_id = self.get_home_away_team_id( 109 | game_id 110 | ) 111 | if only_home: 112 | team_filter = df_game_events.team_id == home_team_id 113 | elif only_away: 114 | team_filter = df_game_events.team_id == away_team_id 115 | else: 116 | team_filter = [True] * len(df_game_events) 117 | return df_game_events.loc[(team_filter), :].set_index( 118 | ['period_id', 'period_milliseconds'] 119 | ) 120 | except KeyError: 121 | raise IndexError( 122 | 'No events found for a game with the provided ID.' 123 | ) 124 | 125 | def get_actions( 126 | self, game_id, only_home=False, only_away=False, features=False 127 | ): 128 | """Return all actions performed in a given game. 129 | 130 | Parameters 131 | ---------- 132 | game_id : int 133 | The ID of a game. 134 | only_home : bool 135 | Include only actions from the home team. 136 | only_away : bool 137 | Include only actions from the away team.
138 | 139 | Returns 140 | ------- 141 | pd.DataFrame 142 | A dataframe with a row for each action, indexed by period_id and 143 | a timestamp (ms) in which the action was executed. 144 | 145 | Raises 146 | ------ 147 | ValueError 148 | If both `only_home` and `only_away` are True. 149 | IndexError 150 | If no game exists with the provided ID. 151 | """ 152 | if only_home and only_away: 153 | raise ValueError('only_home and only_away cannot both be True.') 154 | 155 | try: 156 | db = self.games.at[game_id, 'db_path'] 157 | with pd.HDFStore(db, 'r') as store: 158 | df_game_actions = store.get(f'actions/game_{game_id}') 159 | if features: 160 | try: 161 | df_game_features = store.get( 162 | f'features/game_{game_id}' 163 | ) 164 | df_game_actions = pd.concat( 165 | [df_game_actions, df_game_features], axis=1 166 | ) 167 | except KeyError: 168 | warnings.warn('Could not find precomputed features') 169 | 170 | home_team_id, away_team_id = self.get_home_away_team_id(game_id) 171 | if only_home: 172 | team_filter = df_game_actions.team_id == home_team_id 173 | elif only_away: 174 | team_filter = df_game_actions.team_id == away_team_id 175 | else: 176 | team_filter = [True] * len(df_game_actions) 177 | return df_game_actions.loc[(team_filter), :].set_index( 178 | ['action_id'] 179 | ) 180 | except KeyError: 181 | raise IndexError( 182 | 'No actions found for a game with the provided ID.' 183 | ) 184 | 185 | # Games ################################################################## 186 | 187 | def get_home_away_team_id(self, game_id): 188 | """Return the IDs of the home and away team in a given game. 189 | 190 | Parameters 191 | ---------- 192 | game_id : int 193 | The ID of a game. 194 | 195 | Returns 196 | ------- 197 | (int, int) 198 | The ID of the home and away team. 199 | 200 | Raises 201 | ------ 202 | IndexError 203 | If no game exists with the provided ID. 204 | """ 205 | try: 206 | return self.games.loc[ 207 | game_id, ['home_team_id', 'away_team_id'] 208 | ].values 209 | except KeyError: 210 | raise IndexError('No game found with the provided ID.') 211 | 212 | # Players ################################################################ 213 | 214 | def get_player_name(self, player_id): 215 | """Return the name of a player with a given ID. 216 | 217 | Parameters 218 | ---------- 219 | player_id : int 220 | The ID of a player. 221 | 222 | Returns 223 | ------- 224 | The name of the player. 225 | 226 | Raises 227 | ------ 228 | IndexError 229 | If no player exists with the provided ID. 230 | """ 231 | try: 232 | return self.players.at[player_id, 'player_name'] 233 | except KeyError: 234 | raise IndexError('No player found with the provided ID.') 235 | 236 | def search_player(self, query, limit=10): 237 | """Search for a player by name. 238 | 239 | Parameters 240 | ---------- 241 | query : str 242 | The name of a player. 243 | limit : int 244 | Max number of results that are returned. 245 | 246 | Returns 247 | ------- 248 | pd.DataFrame 249 | The first `limit` players that match the given query. 250 | 251 | """ 252 | return self.players[ 253 | self.players.player_name.str.contains(query, case=False) 254 | ].head(limit) 255 | 256 | # Teams ################################################################## 257 | 258 | def get_team_name(self, team_id): 259 | """Return the name of a team with a given ID. 260 | 261 | Parameters 262 | ---------- 263 | team_id : int 264 | The ID of a team. 265 | 266 | Returns 267 | ------- 268 | The name of the team.
269 | 270 | Raises 271 | ------ 272 | IndexError 273 | If no team exists with the provided ID. 274 | """ 275 | try: 276 | return self.teams.at[team_id, 'team_name'] 277 | except KeyError: 278 | raise IndexError('No team found with the provided ID.') 279 | 280 | def search_team(self, query, limit=10): 281 | """Search for a team by name. 282 | 283 | Parameters 284 | ---------- 285 | query : str 286 | The name of a team. 287 | limit : int 288 | Max number of results that are returned. 289 | 290 | Returns 291 | ------- 292 | pd.DataFrame 293 | The first `limit` teams that match the given query. 294 | 295 | """ 296 | return self.teams[ 297 | self.teams.team_name.str.contains(query, case=False) 298 | ].head(limit) 299 | -------------------------------------------------------------------------------- /soccer_xg/calibration.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import warnings 4 | from inspect import signature 5 | 6 | import numpy as np 7 | from betacal import BetaCalibration 8 | from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin, clone 9 | from sklearn.isotonic import IsotonicRegression 10 | from sklearn.linear_model import LogisticRegression 11 | from sklearn.model_selection import check_cv 12 | from sklearn.preprocessing import LabelBinarizer, label_binarize 13 | from sklearn.svm import LinearSVC 14 | from sklearn.utils import check_consistent_length, check_X_y, column_or_1d, indexable 15 | from sklearn.utils.validation import check_is_fitted 16 | 17 | 18 | class CalibratedClassifierCV(BaseEstimator, ClassifierMixin): 19 | """Probability calibration with isotonic regression, sigmoid or beta. 20 | 21 | With this class, the base_estimator is fit on the train set of the 22 | cross-validation generator and the test set is used for calibration. 23 | The probabilities for each of the folds are then averaged 24 | for prediction. In case cv="prefit" is passed to __init__, 25 | it is assumed that base_estimator has been 26 | fitted already and all data is used for calibration. Note that 27 | data for fitting the classifier and for calibrating it must be disjoint. 28 | 29 | Read more in the :ref:`User Guide <calibration>`. 30 | 31 | Parameters 32 | ---------- 33 | base_estimator : instance BaseEstimator 34 | The classifier whose output decision function needs to be calibrated 35 | to offer more accurate predict_proba outputs. If cv=prefit, the 36 | classifier must have been fit already on data. 37 | 38 | method : None, 'sigmoid', 'isotonic', 'beta', 'beta_am' or 'beta_ab' 39 | The method to use for calibration. Can be 'sigmoid' which 40 | corresponds to Platt's method, 'isotonic' which is a 41 | non-parametric approach or 'beta', 'beta_am' or 'beta_ab' which 42 | correspond to three different beta calibration methods. It is 43 | not advised to use isotonic calibration with too few calibration 44 | samples ``(<<1000)`` since it tends to overfit. 45 | Use beta models in this case. 46 | 47 | cv : integer, cross-validation generator, iterable or "prefit", optional 48 | Determines the cross-validation splitting strategy. 49 | Possible inputs for cv are: 50 | 51 | - None, to use the default 3-fold cross-validation, 52 | - integer, to specify the number of folds. 53 | - An object to be used as a cross-validation generator. 54 | - An iterable yielding train/test splits. 55 | 56 | For integer/None inputs, if ``y`` is binary or multiclass, 57 | :class:`StratifiedKFold` is used.
If ``y`` is neither binary nor 58 | multiclass, :class:`KFold` is used. 59 | 60 | Refer :ref:`User Guide ` for the various 61 | cross-validation strategies that can be used here. 62 | 63 | If "prefit" is passed, it is assumed that base_estimator has been 64 | fitted already and all data is used for calibration. 65 | 66 | Attributes 67 | ---------- 68 | classes_ : array, shape (n_classes) 69 | The class labels. 70 | 71 | calibrated_classifiers_: list (len() equal to cv or 1 if cv == "prefit") 72 | The list of calibrated classifiers, one for each cross-validation fold, 73 | which has been fitted on all but the validation fold and calibrated 74 | on the validation fold. 75 | 76 | References 77 | ---------- 78 | .. [1] Obtaining calibrated probability estimates from decision trees 79 | and naive Bayesian classifiers, B. Zadrozny & C. Elkan, ICML 2001 80 | 81 | .. [2] Transforming Classifier Scores into Accurate Multiclass 82 | Probability Estimates, B. Zadrozny & C. Elkan, (KDD 2002) 83 | 84 | .. [3] Probabilistic Outputs for Support Vector Machines and Comparisons to 85 | Regularized Likelihood Methods, J. Platt, (1999) 86 | 87 | .. [4] Predicting Good Probabilities with Supervised Learning, 88 | A. Niculescu-Mizil & R. Caruana, ICML 2005 89 | """ 90 | 91 | def __init__( 92 | self, base_estimator=None, method=None, cv=3, score_type=None 93 | ): 94 | self.base_estimator = base_estimator 95 | self.method = method 96 | self.cv = cv 97 | self.score_type = score_type 98 | 99 | def fit(self, X, y, sample_weight=None): 100 | """Fit the calibrated model 101 | 102 | Parameters 103 | ---------- 104 | X : array-like, shape (n_samples, n_features) 105 | Training data. 106 | 107 | y : array-like, shape (n_samples,) 108 | Target values. 109 | 110 | sample_weight : array-like, shape = [n_samples] or None 111 | Sample weights. If None, then samples are equally weighted. 112 | 113 | Returns 114 | ------- 115 | self : object 116 | Returns an instance of self. 117 | """ 118 | # X, y = check_X_y(X, y, accept_sparse=['csc', 'csr', 'coo'], 119 | # force_all_finite=False) 120 | X, y = indexable(X, y) 121 | lb = LabelBinarizer().fit(y) 122 | self.classes_ = lb.classes_ 123 | 124 | # Check that each cross-validation fold can have at least one 125 | # example per class 126 | n_folds = ( 127 | self.cv 128 | if isinstance(self.cv, int) 129 | else self.cv.n_folds 130 | if hasattr(self.cv, 'n_folds') 131 | else None 132 | ) 133 | if n_folds and np.any( 134 | [np.sum(y == class_) < n_folds for class_ in self.classes_] 135 | ): 136 | raise ValueError( 137 | 'Requesting %d-fold cross-validation but provided' 138 | ' less than %d examples for at least one class.' 139 | % (n_folds, n_folds) 140 | ) 141 | 142 | self.calibrated_classifiers_ = [] 143 | if self.base_estimator is None: 144 | # we want all classifiers that don't expose a random_state 145 | # to be deterministic (and we don't want to expose this one). 
146 | base_estimator = LinearSVC(random_state=0) 147 | else: 148 | base_estimator = self.base_estimator 149 | 150 | if self.cv == 'prefit': 151 | calibrated_classifier = _CalibratedClassifier( 152 | base_estimator, method=self.method, score_type=self.score_type 153 | ) 154 | if sample_weight is not None: 155 | calibrated_classifier.fit(X, y, sample_weight) 156 | else: 157 | calibrated_classifier.fit(X, y) 158 | self.calibrated_classifiers_.append(calibrated_classifier) 159 | else: 160 | cv = check_cv(self.cv, y, classifier=True) 161 | fit_parameters = signature(base_estimator.fit).parameters 162 | estimator_name = type(base_estimator).__name__ 163 | if ( 164 | sample_weight is not None 165 | and 'sample_weight' not in fit_parameters 166 | ): 167 | warnings.warn( 168 | '%s does not support sample_weight. Sample' 169 | ' weights are only used for the calibration' 170 | ' itself.' % estimator_name 171 | ) 172 | base_estimator_sample_weight = None 173 | else: 174 | base_estimator_sample_weight = sample_weight 175 | for train, test in cv.split(X, y): 176 | this_estimator = clone(base_estimator) 177 | if base_estimator_sample_weight is not None: 178 | this_estimator.fit( 179 | X[train], 180 | y[train], 181 | sample_weight=base_estimator_sample_weight[train], 182 | ) 183 | else: 184 | this_estimator.fit(X[train], y[train]) 185 | 186 | calibrated_classifier = _CalibratedClassifier( 187 | this_estimator, 188 | method=self.method, 189 | score_type=self.score_type, 190 | ) 191 | if sample_weight is not None: 192 | calibrated_classifier.fit( 193 | X[test], y[test], sample_weight[test] 194 | ) 195 | else: 196 | calibrated_classifier.fit(X[test], y[test]) 197 | self.calibrated_classifiers_.append(calibrated_classifier) 198 | 199 | return self 200 | 201 | def predict_proba(self, X): 202 | """Posterior probabilities of classification 203 | 204 | This function returns posterior probabilities of classification 205 | according to each class on an array of test vectors X. 206 | 207 | Parameters 208 | ---------- 209 | X : array-like, shape (n_samples, n_features) 210 | The samples. 211 | 212 | Returns 213 | ------- 214 | C : array, shape (n_samples, n_classes) 215 | The predicted probas. 216 | """ 217 | check_is_fitted(self, ['classes_', 'calibrated_classifiers_']) 218 | # X = check_array(X, accept_sparse=['csc', 'csr', 'coo'], 219 | # force_all_finite=False) 220 | # Compute the arithmetic mean of the predictions of the calibrated 221 | # classifiers 222 | mean_proba = np.zeros((X.shape[0], len(self.classes_))) 223 | for calibrated_classifier in self.calibrated_classifiers_: 224 | proba = calibrated_classifier.predict_proba(X) 225 | mean_proba += proba 226 | 227 | mean_proba /= len(self.calibrated_classifiers_) 228 | 229 | return mean_proba 230 | 231 | def calibrate_scores(self, df): 232 | """Calibrate a set of decision scores. 233 | 234 | This function returns calibrated probabilities of classification 235 | according to each class for the given array of decision scores. 236 | 237 | Parameters 238 | ---------- 239 | df : array-like, shape (n_samples,) 240 | The uncalibrated decision scores. 241 | 242 | Returns 243 | ------- 244 | C : array, shape (n_samples, n_classes) 245 | The predicted probas.
246 | """ 247 | check_is_fitted(self, ['classes_', 'calibrated_classifiers_']) 248 | # Compute the arithmetic mean of the predictions of the calibrated 249 | # classifiers 250 | df = df.reshape(-1, 1) 251 | mean_proba = np.zeros((len(df), len(self.classes_))) 252 | for calibrated_classifier in self.calibrated_classifiers_: 253 | proba = calibrated_classifier.calibrate_scores(df) 254 | mean_proba += proba 255 | 256 | mean_proba /= len(self.calibrated_classifiers_) 257 | 258 | return mean_proba 259 | 260 | def predict(self, X): 261 | """Predict the target of new samples. Can be different from the 262 | prediction of the uncalibrated classifier. 263 | 264 | Parameters 265 | ---------- 266 | X : array-like, shape (n_samples, n_features) 267 | The samples. 268 | 269 | Returns 270 | ------- 271 | C : array, shape (n_samples,) 272 | The predicted class. 273 | """ 274 | check_is_fitted(self, ['classes_', 'calibrated_classifiers_']) 275 | return self.classes_[np.argmax(self.predict_proba(X), axis=1)] 276 | 277 | 278 | class _CalibratedClassifier(object): 279 | """Probability calibration with isotonic regression or sigmoid. 280 | 281 | It assumes that base_estimator has already been fit, and trains the 282 | calibration on the input set of the fit function. Note that this class 283 | should not be used as an estimator directly. Use CalibratedClassifierCV 284 | with cv="prefit" instead. 285 | 286 | Parameters 287 | ---------- 288 | base_estimator : instance BaseEstimator 289 | The classifier whose output decision function needs to be calibrated 290 | to offer more accurate predict_proba outputs. No default value since 291 | it has to be an already fitted estimator. 292 | 293 | method : 'sigmoid' | 'isotonic' | 'beta' | 'beta_am' | 'beta_ab' 294 | The method to use for calibration. Can be 'sigmoid' which 295 | corresponds to Platt's method, 'isotonic' which is a 296 | non-parameteric approach based on isotonic regression or 'beta', 297 | 'beta_am' or 'beta_ab' which correspond to beta calibration methods. 298 | 299 | References 300 | ---------- 301 | .. [1] Obtaining calibrated probability estimates from decision trees 302 | and naive Bayesian classifiers, B. Zadrozny & C. Elkan, ICML 2001 303 | 304 | .. [2] Transforming Classifier Scores into Accurate Multiclass 305 | Probability Estimates, B. Zadrozny & C. Elkan, (KDD 2002) 306 | 307 | .. [3] Probabilistic Outputs for Support Vector Machines and Comparisons to 308 | Regularized Likelihood Methods, J. Platt, (1999) 309 | 310 | .. [4] Predicting Good Probabilities with Supervised Learning, 311 | A. Niculescu-Mizil & R. Caruana, ICML 2005 312 | """ 313 | 314 | def __init__(self, base_estimator, method='beta', score_type=None): 315 | self.base_estimator = base_estimator 316 | self.method = method 317 | self.score_type = score_type 318 | 319 | def _preproc(self, X): 320 | n_classes = len(self.classes_) 321 | if self.score_type is None: 322 | if hasattr(self.base_estimator, 'decision_function'): 323 | df = self.base_estimator.decision_function(X) 324 | if df.ndim == 1: 325 | df = df[:, np.newaxis] 326 | elif hasattr(self.base_estimator, 'predict_proba'): 327 | df = self.base_estimator.predict_proba(X) 328 | if n_classes == 2: 329 | df = df[:, 1:] 330 | else: 331 | raise RuntimeError( 332 | 'classifier has no decision_function or ' 333 | 'predict_proba method.' 
334 | ) 335 | else: 336 | if hasattr(self.base_estimator, self.score_type): 337 | df = getattr(self.base_estimator, self.score_type)(X) 338 | if self.score_type == 'decision_function': 339 | if df.ndim == 1: 340 | df = df[:, np.newaxis] 341 | elif self.score_type == 'predict_proba': 342 | if n_classes == 2: 343 | df = df[:, 1:] 344 | else: 345 | raise RuntimeError( 346 | 'classifier has no ' + self.score_type + ' method.' 347 | ) 348 | 349 | idx_pos_class = np.arange(df.shape[1]) 350 | 351 | return df, idx_pos_class 352 | 353 | def fit(self, X, y, sample_weight=None): 354 | """Calibrate the fitted model 355 | 356 | Parameters 357 | ---------- 358 | X : array-like, shape (n_samples, n_features) 359 | Training data. 360 | 361 | y : array-like, shape (n_samples,) 362 | Target values. 363 | 364 | sample_weight : array-like, shape = [n_samples] or None 365 | Sample weights. If None, then samples are equally weighted. 366 | 367 | Returns 368 | ------- 369 | self : object 370 | Returns an instance of self. 371 | """ 372 | lb = LabelBinarizer() 373 | Y = lb.fit_transform(y) 374 | self.classes_ = lb.classes_ 375 | 376 | df, idx_pos_class = self._preproc(X) 377 | self.calibrators_ = [] 378 | 379 | for k, this_df in zip(idx_pos_class, df.T): 380 | if self.method is None: 381 | calibrator = _DummyCalibration() 382 | elif self.method == 'isotonic': 383 | calibrator = IsotonicRegression(out_of_bounds='clip') 384 | elif self.method == 'sigmoid': 385 | calibrator = _SigmoidCalibration() 386 | elif self.method == 'beta': 387 | calibrator = BetaCalibration(parameters='abm') 388 | elif self.method == 'beta_am': 389 | calibrator = BetaCalibration(parameters='am') 390 | elif self.method == 'beta_ab': 391 | calibrator = BetaCalibration(parameters='ab') 392 | else: 393 | raise ValueError( 394 | 'method should be None, "sigmoid", ' 395 | '"isotonic", "beta", "beta_am" or "beta_ab". ' 396 | 'Got %s.' % self.method 397 | ) 398 | calibrator.fit(this_df, Y[:, k], sample_weight) 399 | self.calibrators_.append(calibrator) 400 | 401 | return self 402 | 403 | def predict_proba(self, X): 404 | """Posterior probabilities of classification 405 | 406 | This function returns posterior probabilities of classification 407 | according to each class on an array of test vectors X. 408 | 409 | Parameters 410 | ---------- 411 | X : array-like, shape (n_samples, n_features) 412 | The samples. 413 | 414 | Returns 415 | ------- 416 | C : array, shape (n_samples, n_classes) 417 | The predicted probas. Can be exact zeros. 418 | """ 419 | n_classes = len(self.classes_) 420 | proba = np.zeros((X.shape[0], n_classes)) 421 | 422 | df, idx_pos_class = self._preproc(X) 423 | for k, this_df, calibrator in zip( 424 | idx_pos_class, df.T, self.calibrators_ 425 | ): 426 | if n_classes == 2: 427 | k += 1 428 | proba[:, k] = calibrator.predict(this_df) 429 | 430 | # Normalize the probabilities 431 | if n_classes == 2: 432 | proba[:, 0] = 1.0 - proba[:, 1] 433 | else: 434 | proba /= np.sum(proba, axis=1)[:, np.newaxis] 435 | 436 | # XXX : for some reason all probas can be 0 437 | proba[np.isnan(proba)] = 1.0 / n_classes 438 | 439 | # Deal with cases where the predicted probability minimally exceeds 1.0 440 | proba[(1.0 < proba) & (proba <= 1.0 + 1e-5)] = 1.0 441 | 442 | return proba 443 | 444 | def calibrate_scores(self, df): 445 | """Calibrate a set of decision scores. 446 | 447 | This function returns calibrated probabilities of classification 448 | according to each class for the given array of decision scores.
449 | 450 | Parameters 451 | ---------- 452 | df : array-like, shape (n_samples, 1) 453 | The uncalibrated decision scores. 454 | 455 | Returns 456 | ------- 457 | C : array, shape (n_samples, n_classes) 458 | The predicted probas. Can be exact zeros. 459 | """ 460 | n_classes = len(self.classes_) 461 | proba = np.zeros((len(df), n_classes)) 462 | idx_pos_class = [0] 463 | 464 | for k, this_df, calibrator in zip( 465 | idx_pos_class, df.T, self.calibrators_ 466 | ): 467 | if n_classes == 2: 468 | k += 1 469 | pro = calibrator.predict(this_df) 470 | if np.any(np.isnan(pro)): 471 | pro[np.isnan(pro)] = calibrator.predict( 472 | this_df[np.isnan(pro)] + 1e-300 473 | ) 474 | proba[:, k] = pro 475 | 476 | # Normalize the probabilities 477 | if n_classes == 2: 478 | proba[:, 0] = 1.0 - proba[:, 1] 479 | else: 480 | proba /= np.sum(proba, axis=1)[:, np.newaxis] 481 | 482 | # XXX : for some reason all probas can be 0 483 | proba[np.isnan(proba)] = 1.0 / n_classes 484 | 485 | # Deal with cases where the predicted probability minimally exceeds 1.0 486 | proba[(1.0 < proba) & (proba <= 1.0 + 1e-5)] = 1.0 487 | return proba 488 | 489 | 490 | class _SigmoidCalibration(BaseEstimator, RegressorMixin): 491 | """Sigmoid regression model. 492 | 493 | Attributes 494 | ---------- 495 | lr : LogisticRegression 496 | The fitted univariate logistic regression model that maps 497 | uncalibrated scores onto calibrated probabilities via the 498 | sigmoid (Platt) transformation. 499 | 500 | """ 501 | 502 | def fit(self, X, y, sample_weight=None): 503 | """Fit the model using X, y as training data. 504 | 505 | Parameters 506 | ---------- 507 | X : array-like, shape (n_samples,) 508 | Training data. 509 | 510 | y : array-like, shape (n_samples,) 511 | Training target. 512 | 513 | sample_weight : array-like, shape = [n_samples] or None 514 | Sample weights. If None, then samples are equally weighted. 515 | 516 | Returns 517 | ------- 518 | self : object 519 | Returns an instance of self. 520 | """ 521 | X = column_or_1d(X) 522 | y = column_or_1d(y) 523 | X, y = indexable(X, y) 524 | self.lr = LogisticRegression(C=99999999999)  # huge C: effectively unregularized Platt scaling 525 | self.lr.fit(X.reshape(-1, 1), y) 526 | return self 527 | 528 | def predict(self, T): 529 | """Predict new data using the fitted sigmoid. 530 | 531 | Parameters 532 | ---------- 533 | T : array-like, shape (n_samples,) 534 | Data to predict from. 535 | 536 | Returns 537 | ------- 538 | T_ : array, shape (n_samples,) 539 | The predicted data. 540 | """ 541 | T = column_or_1d(T) 542 | return self.lr.predict_proba(T.reshape(-1, 1))[:, 1] 543 | 544 | 545 | class _DummyCalibration(BaseEstimator, RegressorMixin): 546 | """Dummy regression model. The purpose of this class is to give 547 | the CalibratedClassifierCV class the option to just return the 548 | probabilities of the base classifier. 549 | 550 | 551 | """ 552 | 553 | def fit(self, X, y, sample_weight=None): 554 | """Does nothing. 555 | 556 | Parameters 557 | ---------- 558 | X : array-like, shape (n_samples,) 559 | Training data. 560 | 561 | y : array-like, shape (n_samples,) 562 | Training target. 563 | 564 | sample_weight : array-like, shape = [n_samples] or None 565 | Sample weights. If None, then samples are equally weighted. 566 | 567 | Returns 568 | ------- 569 | self : object 570 | Returns an instance of self. 571 | """ 572 | return self 573 | 574 | def predict(self, T): 575 | """Return the probabilities of the base classifier. 576 | 577 | Parameters 578 | ---------- 579 | T : array-like, shape (n_samples,) 580 | Data to predict from. 581 | 582 | Returns 583 | ------- 584 | T_ : array, shape (n_samples,) 585 | The predicted data.
586 | """ 587 | return T 588 | 589 | 590 | def calibration_curve(y_true, y_prob, normalize=False, n_bins=5): 591 | """Compute true and predicted probabilities for a calibration curve. 592 | 593 | Read more in the :ref:`User Guide `. 594 | 595 | Parameters 596 | ---------- 597 | y_true : array, shape (n_samples,) 598 | True targets. 599 | 600 | y_prob : array, shape (n_samples,) 601 | Probabilities of the positive class. 602 | 603 | normalize : bool, optional, default=False 604 | Whether y_prob needs to be normalized into the bin [0, 1], i.e. is not 605 | a proper probability. If True, the smallest value in y_prob is mapped 606 | onto 0 and the largest one onto 1. 607 | 608 | n_bins : int 609 | Number of bins. A bigger number requires more data. 610 | 611 | Returns 612 | ------- 613 | prob_true : array, shape (n_bins,) 614 | The true probability in each bin (fraction of positives). 615 | 616 | prob_pred : array, shape (n_bins,) 617 | The mean predicted probability in each bin. 618 | 619 | References 620 | ---------- 621 | Alexandru Niculescu-Mizil and Rich Caruana (2005) Predicting Good 622 | Probabilities With Supervised Learning, in Proceedings of the 22nd 623 | International Conference on Machine Learning (ICML). 624 | See section 4 (Qualitative Analysis of Predictions). 625 | """ 626 | y_true = column_or_1d(y_true) 627 | y_prob = column_or_1d(y_prob) 628 | 629 | if normalize: # Normalize predicted values into interval [0, 1] 630 | y_prob = (y_prob - y_prob.min()) / (y_prob.max() - y_prob.min()) 631 | elif y_prob.min() < 0 or y_prob.max() > 1: 632 | raise ValueError( 633 | 'y_prob has values outside [0, 1] and normalize is ' 634 | 'set to False.' 635 | ) 636 | 637 | y_true = _check_binary_probabilistic_predictions(y_true, y_prob) 638 | 639 | bins = np.linspace(0.0, 1.0 + 1e-8, n_bins + 1) 640 | binids = np.digitize(y_prob, bins) - 1 641 | 642 | bin_sums = np.bincount(binids, weights=y_prob, minlength=len(bins)) 643 | bin_true = np.bincount(binids, weights=y_true, minlength=len(bins)) 644 | bin_total = np.bincount(binids, minlength=len(bins)) 645 | 646 | zero = bin_total == 0 647 | bin_total[zero] = 2 648 | # nonzero = bin_total != 0 649 | 650 | prob_true = bin_true / bin_total 651 | prob_pred = bin_sums / bin_total 652 | 653 | return prob_true, prob_pred 654 | 655 | 656 | def _check_binary_probabilistic_predictions(y_true, y_prob): 657 | """Check that y_true is binary and y_prob contains valid probabilities""" 658 | check_consistent_length(y_true, y_prob) 659 | 660 | labels = np.unique(y_true) 661 | 662 | if len(labels) != 2: 663 | raise ValueError( 664 | 'Only binary classification is supported. ' 665 | 'Provided labels %s.' 
% labels 666 | ) 667 | 668 | if y_prob.max() > 1: 669 | raise ValueError('y_prob contains values greater than 1.') 670 | 671 | if y_prob.min() < 0: 672 | raise ValueError('y_prob contains values less than 0.') 673 | 674 | return label_binarize(y_true, labels)[:, 0] 675 | -------------------------------------------------------------------------------- /soccer_xg/features.py: -------------------------------------------------------------------------------- 1 | """A collection of feature generators.""" 2 | import numpy as np 3 | import pandas as pd 4 | from socceraction.vaep.features import * 5 | 6 | _spadl_cfg = { 7 | 'length': 105, 8 | 'width': 68, 9 | 'penalty_box_length': 16.5, 10 | 'penalty_box_width': 40.3, 11 | 'six_yard_box_length': 5.5, 12 | 'six_yard_box_width': 18.3, 13 | 'goal_width': 7.32, 14 | 'penalty_spot_distance': 11, 15 | # all dimensions are in metres 16 | 'goal_length': 2, 17 | 'origin_x': 0, 18 | 'origin_y': 0, 19 | 'circle_radius': 9.15, 20 | } 21 | 22 | 23 | @simple 24 | def goalangle(actions, cfg=_spadl_cfg): 25 | dx = cfg['length'] - actions['start_x'] 26 | dy = cfg['width'] / 2 - actions['start_y'] 27 | angledf = pd.DataFrame() 28 | angledf['shot_angle'] = np.arctan( 29 | cfg['goal_width'] 30 | * dx 31 | / (dx ** 2 + dy ** 2 - (cfg['goal_width'] / 2) ** 2) 32 | ) 33 | angledf.loc[angledf['shot_angle'] < 0, 'shot_angle'] += np.pi 34 | angledf.loc[(actions['start_x'] >= cfg['length']), 'shot_angle'] = 0 35 | # Ball is on the goal line 36 | angledf.loc[ 37 | (actions['start_x'] == cfg['length']) 38 | & ( 39 | actions['start_y'].between( 40 | cfg['width'] / 2 - cfg['goal_width'] / 2, 41 | cfg['width'] / 2 + cfg['goal_width'] / 2, 42 | ) 43 | ), 44 | 'shot_angle', 45 | ] = np.pi 46 | return angledf 47 | 48 | 49 | def speed(gamestates): 50 | a0 = gamestates[0] 51 | spaced = pd.DataFrame() 52 | for i, a in enumerate(gamestates[1:]): 53 | dt = a0.time_seconds - a.time_seconds 54 | dt[dt < 1] = 1 55 | dx = a.end_x - a0.start_x 56 | spaced['speedx_a0' + (str(i + 1))] = dx.abs() / dt 57 | dy = a.end_y - a0.start_y 58 | spaced['speedy_a0' + (str(i + 1))] = dy.abs() / dt 59 | spaced['speed_a0' + (str(i + 1))] = np.sqrt(dx ** 2 + dy ** 2) / dt 60 | return spaced 61 | 62 | 63 | def _caley_shot_matrix(cfg=_spadl_cfg): 64 | """ 65 | https://cartilagefreecaptain.sbnation.com/2013/11/13/5098186/shot-matrix-i-shot-location-and-expected-goals 66 | """ 67 | m = (cfg['origin_y'] + cfg['width']) / 2 68 | 69 | zones = [] 70 | # Zone 1 is the central area of the six-yard box 71 | x1 = cfg['origin_x'] + cfg['length'] - cfg['six_yard_box_length'] 72 | x2 = cfg['origin_x'] + cfg['length'] 73 | y1 = m - cfg['goal_width'] / 2 74 | y2 = m + cfg['goal_width'] / 2 75 | zones.append([(x1, y1, x2, y2)]) 76 | # Zone 2 includes the wide areas, left and right, of the six-yard box. 77 | ## Left 78 | x1 = cfg['origin_x'] + cfg['length'] - cfg['six_yard_box_length'] 79 | x2 = cfg['origin_x'] + cfg['length'] 80 | y1 = m - cfg['six_yard_box_width'] / 2 81 | y2 = m - cfg['goal_width'] / 2 82 | zone_left = (x1, y1, x2, y2) 83 | ## Right 84 | x1 = cfg['origin_x'] + cfg['length'] - cfg['six_yard_box_length'] 85 | x2 = cfg['origin_x'] + cfg['length'] 86 | y1 = m + cfg['goal_width'] / 2 87 | y2 = m + cfg['six_yard_box_width'] / 2 88 | zone_right = (x1, y1, x2, y2) 89 | zones.append([zone_left, zone_right]) 90 | # Zone 3 is the central area between the edges of the six- and eighteen-yard boxes.
91 | x1 = cfg['origin_x'] + cfg['length'] - cfg['penalty_box_length'] 92 | x2 = cfg['origin_x'] + cfg['length'] - cfg['six_yard_box_length'] 93 | y1 = m - cfg['six_yard_box_width'] / 2 94 | y2 = m + cfg['six_yard_box_width'] / 2 95 | zones.append([(x1, y1, x2, y2)]) 96 | # Zone 4 comprises the wide areas in the eighteen-yard box, further from the endline than the six-yard box extended. 97 | ## Left 98 | x1 = cfg['origin_x'] + cfg['length'] - cfg['penalty_box_length'] 99 | x2 = cfg['origin_x'] + cfg['length'] - cfg['six_yard_box_length'] - 2 100 | y1 = m - cfg['penalty_box_width'] / 2 101 | y2 = m - cfg['six_yard_box_width'] / 2 102 | zone_left = (x1, y1, x2, y2) 103 | ## Right 104 | x1 = cfg['origin_x'] + cfg['length'] - cfg['penalty_box_length'] 105 | x2 = cfg['origin_x'] + cfg['length'] - cfg['six_yard_box_length'] - 2 106 | y1 = m + cfg['six_yard_box_width'] / 2 107 | y2 = m + cfg['penalty_box_width'] / 2 108 | zone_right = (x1, y1, x2, y2) 109 | zones.append([zone_left, zone_right]) 110 | # Zone 5 includes the wide areas left and right in the eighteen yard box within the six-yard box extended. 111 | ## Left 112 | x1 = cfg['origin_x'] + cfg['length'] - cfg['six_yard_box_length'] - 2 113 | x2 = cfg['origin_x'] + cfg['length'] 114 | y1 = m - cfg['penalty_box_width'] / 2 115 | y2 = m - cfg['six_yard_box_width'] / 2 116 | zone_left = (x1, y1, x2, y2) 117 | ## Right 118 | x1 = cfg['origin_x'] + cfg['length'] - cfg['six_yard_box_length'] - 2 119 | x2 = cfg['origin_x'] + cfg['length'] 120 | y1 = m + cfg['six_yard_box_width'] / 2 121 | y2 = m + cfg['penalty_box_width'] / 2 122 | zone_right = (x1, y1, x2, y2) 123 | zones.append([zone_left, zone_right]) 124 | # Zone 6 is the eighteen-yard box extended out to roughly 35 yards (=32m). 125 | x1 = cfg['origin_x'] + cfg['length'] - 32 126 | x2 = cfg['origin_x'] + cfg['length'] - cfg['penalty_box_length'] 127 | y1 = m - cfg['penalty_box_width'] / 2 128 | y2 = m + cfg['penalty_box_width'] / 2 129 | zones.append([(x1, y1, x2, y2)]) 130 | # Zone 7 is the deep, deep area beyond that 131 | x1 = cfg['origin_x'] 132 | x2 = cfg['origin_x'] + cfg['length'] - 32 133 | y1 = cfg['origin_y'] 134 | y2 = cfg['origin_y'] + cfg['width'] 135 | zones.append([(x1, y1, x2, y2)]) 136 | # Zone 8 comprises the regions right and left of the box. 
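# Note: _point_in_rect below assumes x1 <= x2 and y1 <= y2, so each corner
# tuple must list the lower-left corner before the upper-right one.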
137 | ## Left 138 | x1 = cfg['origin_x'] + cfg['length'] - 32 139 | x2 = cfg['origin_x'] + cfg['length'] 140 | y1 = m + cfg['penalty_box_width'] / 2 141 | y2 = cfg['origin_y'] + cfg['width'] 142 | zone_left = (x1, y1, x2, y2) 143 | ## Right 144 | x1 = cfg['origin_x'] + cfg['length'] - 32 145 | x2 = cfg['origin_x'] + cfg['length'] 146 | y1 = cfg['origin_y'] 147 | y2 = m - cfg['penalty_box_width'] / 2 148 | zone_right = (x1, y1, x2, y2) 149 | zones.append([zone_left, zone_right]) 150 | return zones 151 | 152 | 153 | def _point_in_rect(rect): 154 | x1, y1, x2, y2 = rect 155 | 156 | def fn(point): 157 | x, y = point 158 | if x1 <= x <= x2: 159 | if y1 <= y <= y2: 160 | return True 161 | return False 162 | 163 | return fn 164 | 165 | 166 | def triangular_grid(name, angle_bins, dist_bins, symmetrical=False): 167 | @simple 168 | def fn(actions): 169 | zonedf = startpolar(actions) 170 | if symmetrical: 171 | zonedf.loc[ 172 | zonedf.start_angle_to_goal_a0 > np.pi / 2, 173 | 'start_angle_to_goal_a0', 174 | ] -= (np.pi / 2) 175 | dist_bin = np.digitize(zonedf.start_dist_to_goal_a0, dist_bins) 176 | angle_bin = np.digitize(zonedf.start_angle_to_goal_a0, angle_bins) 177 | zonedf[name] = angle_bin * (len(dist_bins) + 1) + dist_bin 178 | zonedf[name] = pd.Categorical( 179 | zonedf[name], 180 | categories=list(range((len(dist_bins) + 1) * (len(angle_bins) + 1))), 181 | ordered=False, 182 | ) 183 | return zonedf[[name]] 184 | 185 | return fn 186 | 187 | 188 | def rectangular_grid(name, x_bins, y_bins, symmetrical=False, cfg=_spadl_cfg): 189 | @simple 190 | def fn(actions): 191 | zonedf = actions[['start_x', 'start_y']].copy() 192 | if symmetrical: 193 | m = (cfg['origin_y'] + cfg['width']) / 2 194 | zonedf.loc[zonedf.start_y > m, 'start_y'] -= m 195 | x_bin = np.digitize(zonedf.start_x, x_bins) 196 | y_bin = np.digitize(zonedf.start_y, y_bins) 197 | zonedf[name] = x_bin * (len(y_bins) + 1) + y_bin 198 | zonedf[name] = pd.Categorical( 199 | zonedf[name], 200 | categories=list(range((len(x_bins) + 1) * (len(y_bins) + 1))), 201 | ordered=False, 202 | ) 203 | return zonedf[[name]] 204 | 205 | return fn 206 | 207 | 208 | def custom_grid(name, zones, is_in_zone): 209 | @simple 210 | def fn(actions): 211 | zonedf = actions[['start_x', 'start_y']].copy() 212 | zonedf[name] = [0] * len(actions) # zone 0 if no match 213 | for (i, zone) in enumerate(zones): 214 | for subzone in zone: 215 | zonedf.loc[ 216 | np.apply_along_axis( 217 | is_in_zone(subzone), 218 | 1, 219 | zonedf[['start_x', 'start_y']].values, 220 | ), 221 | name, 222 | ] = (i + 1) 223 | zonedf[name] = pd.Categorical( 224 | zonedf[name], categories=list(range(len(zones) + 1)), ordered=False 225 | ) 226 | return zonedf[[name]] 227 | 228 | return fn 229 | 230 | 231 | caley_grid = custom_grid('caley_zone', _caley_shot_matrix(), _point_in_rect) 232 | 233 | 234 | all_features = [ 235 | actiontype, 236 | bodypart, 237 | result, 238 | startlocation, 239 | endlocation, 240 | movement, 241 | space_delta, 242 | startpolar, 243 | endpolar, 244 | team, 245 | time_delta, 246 | speed, 247 | goalangle, 248 | caley_grid, 249 | triangular_grid( 250 | 'angle_zone', 251 | [-50, -20, 20, 50], 252 | [2, 4, 8, 11, 16, 24, 34, 50], 253 | symmetrical=True, 254 | ), 255 | ] 256 | -------------------------------------------------------------------------------- /soccer_xg/metrics.py: -------------------------------------------------------------------------------- 1 | """A collection of metrics for evaluating xG models.""" 2 | import numpy as np 3 | from scipy import integrate 4 | from sklearn.neighbors import
KernelDensity 5 | 6 | 7 | def expected_calibration_error(y_true, y_prob, n_bins=5, strategy='uniform'): 8 | """Compute the Expected Calibration Error (ECE). 9 | 10 | This method implements equation (3) in [1], as well as the ACE variant in [2]. 11 | In this equation the probability of the decided label being correct is 12 | used to estimate the calibration property of the predictor. 13 | 14 | Note: a trade-off exists between using a small number of `n_bins` and the 15 | estimation reliability of the ECE. In particular, this method may produce 16 | unreliable ECE estimates in case there are few samples available in some bins. 17 | 18 | Parameters 19 | ---------- 20 | y_true : array, shape (n_samples,) 21 | True targets. 22 | y_prob : array, shape (n_samples,) 23 | Probabilities of the positive class. 24 | n_bins : int, default=5 25 | Number of bins to discretize the [0, 1] interval. A bigger number 26 | requires more data. Bins that receive no samples (i.e. without 27 | corresponding values in `y_prob`) simply contribute zero weight to 28 | the returned error. 29 | strategy : {'uniform', 'quantile'}, default='uniform' 30 | Strategy used to define the widths of the bins. 31 | uniform 32 | The bins have identical widths. This corresponds to the ECE formula. 33 | quantile 34 | The bins have the same number of samples and depend on `y_prob`. This 35 | corresponds to the ACE formula. 36 | 37 | Returns 38 | ------- 39 | ece : float 40 | The expected calibration error. 41 | 42 | References 43 | ---------- 44 | [1]: Chuan Guo, Geoff Pleiss, Yu Sun, Kilian Q. Weinberger, 45 | On Calibration of Modern Neural Networks. 46 | Proceedings of the 34th International Conference on Machine Learning 47 | (ICML 2017). 48 | arXiv:1706.04599 49 | https://arxiv.org/pdf/1706.04599.pdf 50 | [2]: Nixon, Jeremy, et al., 51 | Measuring calibration in deep learning. 52 | arXiv:1904.01685 53 | https://arxiv.org/abs/1904.01685 54 | 55 | """ 56 | 57 | if y_prob.shape != y_true.shape: 58 | raise ValueError(f'y_prob and y_true must have the same shape, got {y_prob.shape} and {y_true.shape}.') 59 | if y_prob.min() < 0 or y_prob.max() > 1: 60 | raise ValueError('y_prob has values outside [0, 1].') 61 | labels = np.unique(y_true) 62 | if len(labels) > 2: 63 | raise ValueError('Only binary classification is supported.') 64 | 65 | if strategy == 'quantile': # Determine bin edges by distribution of data 66 | quantiles = np.linspace(0, 1, n_bins + 1) 67 | bins = np.percentile(y_prob, quantiles * 100) 68 | bins[-1] = bins[-1] + 1e-8 69 | elif strategy == 'uniform': 70 | bins = np.linspace(0.0, 1.0 + 1e-8, n_bins + 1) 71 | else: 72 | raise ValueError( 73 | "Invalid entry to 'strategy' input. Strategy " 74 | "must be either 'quantile' or 'uniform'."
75 | ) 76 | 77 | n = y_prob.shape[0] 78 | accs, confs, counts = _reliability(y_true, y_prob, bins) 79 | return np.sum(counts * np.abs(accs - confs) / n) 80 | 81 | 82 | def _reliability(y_true, y_prob, bins): 83 | n_bins = len(bins) - 1 84 | accs = np.zeros(n_bins) 85 | confs = np.zeros(n_bins) 86 | counts = np.zeros(n_bins) 87 | for m in range(n_bins): 88 | low = bins[m] 89 | high = bins[m + 1] 90 | 91 | where_in_bin = (low <= y_prob) & (y_prob < high) 92 | if where_in_bin.sum() > 0: 93 | accs[m] = ( 94 | np.sum((y_prob[where_in_bin] >= 0.5) == y_true[where_in_bin]) 95 | / where_in_bin.sum() 96 | ) 97 | confs[m] = np.mean( 98 | np.maximum(y_prob[where_in_bin], 1 - y_prob[where_in_bin]) 99 | ) 100 | counts[m] = where_in_bin.sum() 101 | 102 | return accs, confs, counts 103 | 104 | 105 | def bayesian_calibration_curve(y_true, y_pred, n_bins=100): 106 | """Compute true and predicted probabilities for a calibration curve using 107 | kernel density estimation instead of bins with a fixed width. 108 | 109 | Parameters 110 | ---------- 111 | y_true : array-like of shape (n_samples,) 112 | True targets. 113 | y_pred : array-like of shape (n_samples,) 114 | Probabilities of the positive class. 115 | n_bins : int, default=100 116 | Sets the Gaussian kernel bandwidth to ``1 / n_bins``. A bigger number 117 | requires more data. 118 | 119 | Returns 120 | ------- 121 | prob_pred : ndarray of shape (99,) 122 | The evaluation grid of predicted probabilities (1% to 99%), in 123 | percent. 124 | prob_true : ndarray of shape (99,) 125 | The estimated percentage of positives at each grid point. 126 | number_total : ndarray of shape (99,) 127 | The estimated number of examples at each grid point. 128 | """ 129 | y_true = np.array(y_true, dtype=bool) 130 | bandwidth = 1 / n_bins 131 | kde_pos = KernelDensity(kernel='gaussian', bandwidth=bandwidth).fit( 132 | (y_pred[y_true])[:, np.newaxis] 133 | ) 134 | kde_total = KernelDensity(kernel='gaussian', bandwidth=bandwidth).fit( 135 | y_pred[:, np.newaxis] 136 | ) 137 | sample_probabilities = np.linspace(0.01, 0.99, 99) 138 | number_density_pos = np.exp( 139 | kde_pos.score_samples(sample_probabilities[:, np.newaxis]) 140 | ) * np.sum(y_true) 141 | number_density_total = np.exp( 142 | kde_total.score_samples(sample_probabilities[:, np.newaxis]) 143 | ) * len(y_true) 144 | number_pos = ( 145 | number_density_pos 146 | * np.sum(y_true) 147 | / np.sum(number_density_pos) 148 | ) 149 | number_total = ( 150 | number_density_total * len(y_true) / np.sum(number_density_total) 151 | ) 152 | predicted_pos_percents = np.nan_to_num(number_pos / number_total, nan=1.0)  # fill 0/0 regions with 1 153 | 154 | return ( 155 | 100.0 * sample_probabilities, 156 | 100.0 * predicted_pos_percents, 157 | number_total, 158 | ) 159 | 160 | 161 | def max_deviation(sample_probabilities, predicted_pos_percents): 162 | """Compute the largest discrepancy between the model and expectation.
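For example, with both curves on the percent scale returned by ``bayesian_calibration_curve``, predicted probabilities (10, 50, 90) that turn out to be (12, 45, 93) in reality deviate by (2, 5, 3), so the value returned is 5.0 (illustrative numbers, not from the library).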
163 | """ 164 | abs_deviations = np.abs(predicted_pos_percents - sample_probabilities) 165 | return np.max(abs_deviations) 166 | 167 | 168 | def residual_area(sample_probabilities, predicted_pos_percents): 169 | """Compute the total area under the curve of |predicted prob - expected prob| 170 | """ 171 | abs_deviations = np.abs(predicted_pos_percents - sample_probabilities) 172 | return integrate.trapz(abs_deviations, sample_probabilities) 173 | -------------------------------------------------------------------------------- /soccer_xg/ml/logreg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.linear_model import LogisticRegression, SGDClassifier 3 | from sklearn.pipeline import Pipeline 4 | from soccer_xg.ml.preprocessing import simple_proc_for_linear_algoritms 5 | 6 | 7 | def logreg_gridsearch_classifier( 8 | numeric_features, 9 | categoric_features, 10 | learning_rate=0.08, 11 | use_dask=False, 12 | n_iter=100, 13 | scoring='roc_auc', 14 | ): 15 | """ 16 | Simple classification pipeline using hyperband to optimize logreg hyper-parameters 17 | Parameters 18 | ---------- 19 | `numeric_features` : The list of numeric features 20 | `categoric_features` : The list of categoric features 21 | `learning_rate` : The learning rate 22 | """ 23 | 24 | return _logreg_gridsearch_model( 25 | 'classification', 26 | numeric_features, 27 | categoric_features, 28 | learning_rate, 29 | use_dask, 30 | n_iter, 31 | scoring, 32 | ) 33 | 34 | 35 | def logreg_gridsearch_regressor( 36 | numeric_features, 37 | categoric_features, 38 | learning_rate=0.08, 39 | use_dask=False, 40 | n_iter=100, 41 | scoring='roc_auc', 42 | ): 43 | """ 44 | Simple regression pipeline using hyperband to optimize logreg hyper-parameters 45 | Parameters 46 | ---------- 47 | `numeric_features` : The list of numeric features 48 | `categoric_features` : The list of categoric features 49 | `learning_rate` : The learning rate 50 | """ 51 | 52 | return _logreg_gridsearch_model( 53 | 'regression', 54 | numeric_features, 55 | categoric_features, 56 | learning_rate, 57 | use_dask, 58 | n_iter, 59 | scoring, 60 | ) 61 | 62 | 63 | def _logreg_gridsearch_model( 64 | task, 65 | numeric_features, 66 | categoric_features, 67 | learning_rate, 68 | use_dask, 69 | n_iter, 70 | scoring, 71 | ): 72 | if learning_rate is None: 73 | param_space = { 74 | 'clf__C': np.logspace(-5, 5, 100), 75 | 'clf__class_weight': ['balanced', None], 76 | } 77 | model = LogisticRegression(max_iter=10000, fit_intercept=False) 78 | else: 79 | param_space = { 80 | 'clf__penalty': ['l1', 'l2'], 81 | 'clf__alpha': np.logspace(-5, 5, 100), 82 | 'clf__class_weight': ['balanced', None], 83 | } 84 | learning_rate_schedule = ( 85 | 'constant' if isinstance(learning_rate, float) else learning_rate 86 | ) 87 | eta0 = learning_rate if isinstance(learning_rate, float) else 0 88 | model = SGDClassifier( 89 | learning_rate=learning_rate_schedule, 90 | eta0=eta0, 91 | loss='log', 92 | max_iter=10000, 93 | fit_intercept=False, 94 | ) 95 | 96 | pipe = Pipeline( 97 | [ 98 | ( 99 | 'preprocessing', 100 | simple_proc_for_linear_algoritms( 101 | numeric_features, categoric_features 102 | ), 103 | ), 104 | ('clf', model), 105 | ] 106 | ) 107 | 108 | if use_dask: 109 | from dask_ml.model_selection import RandomizedSearchCV 110 | 111 | return RandomizedSearchCV( 112 | pipe, param_space, n_iter=n_iter, scoring=scoring, cv=5 113 | ) 114 | else: 115 | from sklearn.model_selection import RandomizedSearchCV 116 | 117 | return 
RandomizedSearchCV( 118 | pipe, param_space, n_iter=n_iter, scoring=scoring, cv=5 119 | ) 120 | -------------------------------------------------------------------------------- /soccer_xg/ml/mlp.py: -------------------------------------------------------------------------------- 1 | from scipy.stats.distributions import randint, uniform 2 | from sklearn.neural_network import MLPClassifier, MLPRegressor 3 | from sklearn.pipeline import Pipeline 4 | 5 | from .preprocessing import simple_proc_for_linear_algoritms 6 | 7 | 8 | def mlp_gridsearch_classifier( 9 | numeric_features, 10 | categoric_features, 11 | learning_rate=0.08, 12 | use_dask=False, 13 | n_iter=100, 14 | scoring='roc_auc', 15 | ): 16 | """ 17 | Simple classification pipeline using randomized search to optimize mlp hyper-parameters 18 | Parameters 19 | ---------- 20 | `numeric_features` : The list of numeric features 21 | `categoric_features` : The list of categoric features 22 | `learning_rate` : The initial learning rate (`learning_rate_init`) 23 | """ 24 | 25 | return _mlp_gridsearch_model( 26 | 'classification', 27 | numeric_features, 28 | categoric_features, 29 | learning_rate, 30 | use_dask, 31 | n_iter, 32 | scoring, 33 | ) 34 | 35 | 36 | def mlp_gridsearch_regressor( 37 | numeric_features, 38 | categoric_features, 39 | learning_rate=0.08, 40 | use_dask=False, 41 | n_iter=100, 42 | scoring='roc_auc', 43 | ): 44 | """ 45 | Simple regression pipeline using randomized search to optimize mlp hyper-parameters 46 | Parameters 47 | ---------- 48 | `numeric_features` : The list of numeric features 49 | `categoric_features` : The list of categoric features 50 | `learning_rate` : The initial learning rate (`learning_rate_init`) 51 | """ 52 | 53 | return _mlp_gridsearch_model( 54 | 'regression', 55 | numeric_features, 56 | categoric_features, 57 | learning_rate, 58 | use_dask, 59 | n_iter, 60 | scoring, 61 | ) 62 | 63 | 64 | def _mlp_gridsearch_model( 65 | task, 66 | numeric_features, 67 | categoric_features, 68 | learning_rate, 69 | use_dask, 70 | n_iter, 71 | scoring, 72 | ): 73 | param_space = { 74 | 'clf__hidden_layer_sizes': [ 75 | (24,), 76 | (12, 12), 77 | (6, 6, 6, 6), 78 | (4, 4, 4, 4, 4, 4), 79 | (12, 6, 3, 3), 80 | ], 81 | 'clf__activation': ['relu', 'logistic', 'tanh'], 82 | 'clf__batch_size': [16, 32, 64, 128, 256, 512], 83 | 'clf__alpha': uniform(0.0001, 0.9), 84 | 'clf__learning_rate': ['constant', 'adaptive'], 85 | } 86 | 87 | model = ( 88 | MLPClassifier(learning_rate_init=learning_rate) 89 | if task == 'classification' 90 | else MLPRegressor(learning_rate_init=learning_rate) 91 | ) 92 | 93 | pipe = Pipeline( 94 | [ 95 | ( 96 | 'preprocessing', 97 | simple_proc_for_linear_algoritms( 98 | numeric_features, categoric_features 99 | ), 100 | ), 101 | ('clf', model), 102 | ] 103 | ) 104 | 105 | if use_dask: 106 | from dask_ml.model_selection import RandomizedSearchCV 107 | 108 | return RandomizedSearchCV( 109 | pipe, param_space, n_iter=n_iter, scoring=scoring, cv=5 110 | ) 111 | else: 112 | from sklearn.model_selection import RandomizedSearchCV 113 | 114 | return RandomizedSearchCV( 115 | pipe, param_space, n_iter=n_iter, scoring=scoring, cv=5 116 | ) 117 | -------------------------------------------------------------------------------- /soccer_xg/ml/pipeline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import scipy.sparse as sp 4 | from scipy import stats 5 | from sklearn import clone 6 | from sklearn.base import BaseEstimator, TransformerMixin 7 | from sklearn.decomposition import TruncatedSVD 8 | from
sklearn.model_selection import cross_val_predict 9 | from sklearn.preprocessing import LabelEncoder, MinMaxScaler 10 | 11 | 12 | class ColumnsSelector(BaseEstimator, TransformerMixin): 13 | def __init__(self, columns): 14 | assert isinstance(columns, list) 15 | self.columns = columns 16 | 17 | def fit(self, X, y=None): 18 | return self 19 | 20 | def transform(self, X): 21 | return X[self.columns] 22 | 23 | 24 | class TolerantLE(LabelEncoder): 25 | def transform(self, y): 26 | return np.searchsorted(self.classes_, y) 27 | 28 | 29 | class UniqueCountColumnSelector(BaseEstimator, TransformerMixin): 30 | """ 31 | To select those columns whose unique-count values are between 32 | lowerbound (inclusive) and upperbound (exclusive) 33 | """ 34 | 35 | def __init__(self, lowerbound, upperbound): 36 | self.lowerbound = lowerbound 37 | self.upperbound = upperbound 38 | 39 | def fit(self, X, y=None): 40 | counts = X.apply(lambda vect: vect.unique().shape[0]) 41 | self.columns = counts.index[ 42 | (counts >= self.lowerbound) & (counts < self.upperbound) 43 | ] 44 | return self 45 | 46 | def transform(self, X): 47 | return X[self.columns] 48 | 49 | 50 | class ColumnApplier(BaseEstimator, TransformerMixin): 51 | """ 52 | Some sklearn transformers can only be applied to ONE column at a time 53 | Wrap them with ColumnApplier to apply them to the whole dataset 54 | """ 55 | 56 | def __init__(self, underlying): 57 | self.underlying = underlying 58 | 59 | def fit(self, X, y=None): 60 | m = {} 61 | X = pd.DataFrame(X) # TODO: :( reimplement in pure numpy? 62 | for c in X.columns: 63 | k = clone(self.underlying) 64 | k.fit(X[c]) 65 | m[c] = k 66 | self._column_stages = m 67 | return self 68 | 69 | def transform(self, X): 70 | ret = {} 71 | X = pd.DataFrame(X) 72 | for c, k in self._column_stages.items(): 73 | ret[c] = k.transform(X[c]) 74 | return pd.DataFrame(ret)[X.columns] # keep the same order 75 | 76 | 77 | class OrdinalEncoder(BaseEstimator, TransformerMixin): 78 | """ 79 | Encode each categorical value as a natural number, ordered by frequency 80 | N/A are encoded to -2, 81 | rare values to -1 and values unseen during fit to -3 82 | Very similar to TolerantLabelEncoder 83 | TODO: improve the implementation 84 | """ 85 | 86 | def __init__(self, min_support): 87 | self.min_support = min_support 88 | self.vc = {} 89 | 90 | def _mapping(self, vc): 91 | mapping = {} 92 | for i, v in enumerate(vc[vc >= self.min_support].index): 93 | mapping[v] = i 94 | for v in vc.index[vc < self.min_support]: 95 | mapping[v] = -1 96 | mapping['nan'] = -2 97 | return mapping 98 | 99 | def _transform_column(self, x): 100 | x = x.astype(str) 101 | vc = self.vc[x.name] 102 | 103 | mapping = self._mapping(vc) 104 | 105 | output = pd.DataFrame() 106 | output[x.name] = x.map( 107 | lambda a: mapping[a] if a in mapping else -3 108 | ) 109 | output.index = x.index 110 | return output.astype(int) 111 | 112 | def fit(self, x, y=None): 113 | x = x.astype(str) 114 | self.vc = dict((c, x[c].value_counts()) for c in x.columns) 115 | return self 116 | 117 | def transform(self, df): 118 | if df.index.duplicated().any(): 119 | dupes = df.index[df.index.duplicated()] 120 | raise ValueError(f'Input contains duplicate index: {list(dupes)}') 121 | dfs = [self._transform_column(df[c]) for c in df.columns] 122 | out = pd.DataFrame(index=df.index) 123 | for col_df in dfs: 124 | out = out.join(col_df) 125 | return out.values 126 | 127 | 128 | class CountFrequencyEncoder(BaseEstimator, TransformerMixin): 129 | """ 130 | Encode each value by its frequency observed in the training set 131 | """ 132 | 133 | def
__init__(self, min_card=5, count_na=False): 134 | self.min_card = min_card 135 | self.count_na = count_na 136 | self.vc = None 137 | 138 | def fit(self, x, y=None): 139 | x = pd.Series(x) 140 | vc = x.value_counts() 141 | self.others_count = vc[vc < self.min_card].sum() 142 | self.vc = vc[vc >= self.min_card].to_dict() 143 | self.num_na = x.isnull().sum() 144 | return self 145 | 146 | def transform(self, x): 147 | vc = self.vc 148 | output = x.map(lambda a: vc.get(a, self.others_count)) 149 | if self.count_na: 150 | output = output.fillna(self.num_na) 151 | return output.values 152 | 153 | 154 | class BoxCoxTransformer(BaseEstimator, TransformerMixin): 155 | """ 156 | Boxcox transformation for numerical columns 157 | To make them more Gaussian-like 158 | """ 159 | 160 | def __init__(self): 161 | self.scaler = MinMaxScaler() 162 | self.shift = 0.0001 163 | 164 | def fit(self, x, y=None): 165 | x = x.values.reshape(-1, 1) 166 | x = self.scaler.fit_transform(x) + self.shift 167 | self.boxcox_lmbda = stats.boxcox(x)[1] 168 | return self 169 | 170 | def transform(self, x): 171 | x = x.values.reshape(-1, 1) 172 | scaled = np.maximum(self.shift, self.scaler.transform(x) + self.shift) 173 | ret = stats.boxcox(scaled, self.boxcox_lmbda) 174 | return ret[:, 0] 175 | 176 | 177 | class Logify(BaseEstimator, TransformerMixin): 178 | """ 179 | Log transformation 180 | """ 181 | 182 | def __init__(self): 183 | self.shift = 2 184 | 185 | def fit(self, x, y=None): 186 | return self 187 | 188 | def transform(self, x): 189 | return np.log10(x - x.min() + self.shift) 190 | 191 | 192 | class YToLog(BaseEstimator, TransformerMixin): 193 | """ 194 | Transforming Y to log before fitting 195 | and transforming back the prediction to real values before return 196 | """ 197 | 198 | def __init__(self, delegate, shift=0): 199 | self.delegate = delegate 200 | self.shift = shift 201 | 202 | def fit(self, X, y): 203 | logy = np.log(y + self.shift) 204 | self.delegate.fit(X, logy) 205 | return self 206 | 207 | def predict(self, X): 208 | pred = self.delegate.predict(X) 209 | return np.exp(pred) - self.shift 210 | 211 | 212 | class FillNaN(BaseEstimator, TransformerMixin): 213 | def __init__(self, replace): 214 | self.replace = replace 215 | 216 | def fit(self, x, y=None): 217 | return self 218 | 219 | def transform(self, x): 220 | return x.fillna(self.replace) 221 | 222 | 223 | class AsString(BaseEstimator, TransformerMixin): 224 | def fit(self, x, y=None): 225 | return self 226 | 227 | def transform(self, x): 228 | return x.astype(str) 229 | 230 | 231 | class StackedModel(BaseEstimator, TransformerMixin): 232 | def __init__(self, delegate, cv=5, method='predict_proba'): 233 | self.delegate = delegate 234 | self.cv = cv 235 | self.method = method 236 | 237 | def fit(self, X, y): 238 | raise Exception 239 | 240 | def fit_transform(self, X, y): 241 | a = cross_val_predict( 242 | self.delegate, X, y, cv=self.cv, method=self.method 243 | ) 244 | self.delegate.fit(X, y) 245 | if len(a.shape) == 1: 246 | a = a.reshape(-1, 1) 247 | return a 248 | 249 | def transform(self, X): 250 | if self.method == 'predict_proba': 251 | return self.delegate.predict_proba(X) 252 | else: 253 | return self.delegate.predict(X).reshape(-1, 1) 254 | 255 | 256 | class To1D(BaseEstimator, TransformerMixin): 257 | def fit(self, X, y=None): 258 | return self 259 | 260 | def transform(self, X): 261 | return X.values.reshape(-1) 262 | 263 | 264 | class TolerantLabelEncoder(TransformerMixin): 265 | """ LabelEncoder is not tolerant to unseen values 266 | 
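Values seen no more than ``min_count`` times during ``fit``, as well as values never seen at all, are encoded as 0; the retained values receive the codes 1..n in descending order of frequency.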
""" 267 | 268 | def __init__(self, min_count=10): 269 | self.min_count = min_count 270 | 271 | def fit(self, x, y=None): 272 | assert len(x.shape) == 1 273 | vc = x.value_counts() 274 | vc = vc[vc > self.min_count] 275 | self.values = { 276 | value: (1 + index) for index, value in enumerate(vc.index) 277 | } 278 | return self 279 | 280 | def transform(self, x): 281 | values = self.values 282 | return x.map(lambda a: values.get(a, 0)) 283 | 284 | def inverse_transform(self, y): 285 | if not hasattr(self, 'inversed_mapping'): 286 | self.inversed_mapping = {v: k for k, v in self.values.items()} 287 | self.inversed_mapping[0] = None 288 | mapping = self.inversed_mapping 289 | return pd.Series(y).map(lambda a: mapping[a]) 290 | 291 | 292 | class SVD_Embedding(TransformerMixin): 293 | def __init__( 294 | self, rowname, colname, valuename=None, svd_kwargs={'n_components': 10} 295 | ): 296 | self.rowname = rowname 297 | self.colname = colname 298 | self.valuename = valuename 299 | self.rowle = TolerantLabelEncoder() 300 | self.colle = TolerantLabelEncoder() 301 | self.svd = TruncatedSVD(**svd_kwargs) 302 | 303 | def fit(self, X, y=None): 304 | row = self.rowle.fit_transform(X[self.rowname]) 305 | col = self.colle.fit_transform(X[self.colname]) 306 | if not self.valuename: 307 | data = np.ones(X.shape[0]) 308 | else: 309 | data = X[self.valuename].groupby([row, col]).mean() 310 | row = data.index.get_level_values(0).values 311 | col = data.index.get_level_values(1).values 312 | data = data.values 313 | matrix = sp.coo_matrix((data, (row, col))) 314 | self.embedding = self.svd.fit_transform(matrix) 315 | return self 316 | 317 | def transform(self, X): 318 | row = self.rowle.transform(X[self.rowname]) 319 | return self.embedding[row, :] 320 | -------------------------------------------------------------------------------- /soccer_xg/ml/preprocessing.py: -------------------------------------------------------------------------------- 1 | import category_encoders as ce 2 | from sklearn.impute import SimpleImputer 3 | from sklearn.pipeline import make_pipeline, make_union 4 | from sklearn.preprocessing import OneHotEncoder, StandardScaler 5 | 6 | from .pipeline import AsString, ColumnsSelector, OrdinalEncoder 7 | 8 | 9 | def simple_proc_for_tree_algoritms(numeric_features, categoric_features): 10 | """ 11 | Create a simple preprocessing pipeline for tree based algorithms 12 | """ 13 | 14 | catpipe = make_pipeline( 15 | ColumnsSelector(categoric_features), 16 | OrdinalEncoder(min_support=5) 17 | # ColumnApplier(FillNaN('nan')), 18 | # ColumnApplier(TolerantLabelEncoder()) 19 | ) 20 | numpipe = make_pipeline( 21 | ColumnsSelector(numeric_features), 22 | SimpleImputer(strategy='mean'), 23 | StandardScaler(), 24 | ) 25 | if numeric_features and categoric_features: 26 | return make_union(catpipe, numpipe) 27 | elif numeric_features: 28 | return numpipe 29 | elif categoric_features: 30 | return catpipe 31 | raise Exception('Both variable lists are empty') 32 | 33 | 34 | def simple_proc_for_linear_algoritms(numeric_features, categoric_features): 35 | """ 36 | Create a simple preprocessing pipeline for linear algorithms 37 | """ 38 | 39 | catpipe = make_pipeline( 40 | ColumnsSelector(categoric_features), 41 | AsString(), 42 | ce.OneHotEncoder() 43 | # ColumnApplier(FillNaN('nan')), 44 | # ColumnApplier(TolerantLabelEncoder()) 45 | ) 46 | numpipe = make_pipeline( 47 | ColumnsSelector(numeric_features), 48 | SimpleImputer(strategy='mean'), 49 | StandardScaler(), 50 | ) 51 | if numeric_features and 
categoric_features: 52 | return make_union(catpipe, numpipe) 53 | elif numeric_features: 54 | return numpipe 55 | elif categoric_features: 56 | return catpipe 57 | raise Exception('Both variable lists are empty') 58 | -------------------------------------------------------------------------------- /soccer_xg/ml/tree_based_LR.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.base import BaseEstimator, ClassifierMixin 3 | from sklearn.preprocessing import OneHotEncoder 4 | 5 | 6 | # class for the tree-based/logistic regression pipeline 7 | # see: https://gdmarmerola.github.io/probability-calibration/ 8 | class TreeBasedLR(BaseEstimator, ClassifierMixin): 9 | 10 | # initialization 11 | def __init__(self, forest, lr): 12 | 13 | # configuring the models 14 | self.forest = forest 15 | self.lr = lr 16 | 17 | # method for fitting the model 18 | def fit(self, X, y, sample_weight=None, fit_params=None): 19 | 20 | self.classes_ = np.unique(y) 21 | 22 | # first, we fit our tree-based model on the dataset 23 | self.forest.fit(X, y, **(fit_params or {})) 24 | 25 | # then, we apply the model to the data in order to get the leaf indices 26 | # if self.forest_model == 'cat': 27 | # leaves = self.forest.calc_leaf_indexes(X) 28 | # else: 29 | leaves = self.forest.named_steps['clf'].apply( 30 | self.forest.named_steps['preprocessing'].transform(X) 31 | ) 32 | 33 | # then, we one-hot encode the leaf indices so we can use them in the logistic regression 34 | self.encoder = OneHotEncoder(sparse=True) 35 | leaves_encoded = self.encoder.fit_transform(leaves) 36 | 37 | # and fit it to the encoded leaves 38 | self.lr.fit(leaves_encoded, y) 39 | 40 | # method for predicting probabilities 41 | def predict_proba(self, X): 42 | 43 | # then, we apply the model to the data in order to get the leaf indices 44 | # if self.forest_model == 'cat': 45 | # leaves = self.forest.calc_leaf_indexes(X) 46 | # else: 47 | leaves = self.forest.named_steps['clf'].apply( 48 | self.forest.named_steps['preprocessing'].transform(X) 49 | ) 50 | 51 | # then, we one-hot encode the leaf indices so we can use them in the logistic regression 52 | leaves_encoded = self.encoder.transform(leaves) 53 | 54 | # and apply the logistic regression to the encoded leaves 55 | y_hat = self.lr.predict_proba(leaves_encoded) 56 | 57 | # returning probabilities 58 | return y_hat 59 | 60 | # get_params, needed for sklearn estimators 61 | def get_params(self, deep=True): 62 | return { 63 | 'forest': self.forest, 64 | 'lr': self.lr, 65 | } 66 | -------------------------------------------------------------------------------- /soccer_xg/ml/xgboost.py: -------------------------------------------------------------------------------- 1 | from scipy.stats.distributions import randint, uniform 2 | from sklearn.pipeline import Pipeline 3 | from xgboost import sklearn as xgbsk 4 | 5 | from .preprocessing import simple_proc_for_tree_algoritms 6 | 7 | 8 | def xgboost_gridsearch_classifier( 9 | numeric_features, 10 | categoric_features, 11 | learning_rate=0.08, 12 | use_dask=False, 13 | n_iter=100, 14 | scoring='roc_auc', 15 | ): 16 | """ 17 | Simple classification pipeline using randomized search to optimize xgboost hyper-parameters 18 | Parameters 19 | ---------- 20 | `numeric_features` : The list of numeric features 21 | `categoric_features` : The list of categoric features 22 | `learning_rate` : The learning rate 23 | """ 24 | 25 | return _xgboost_gridsearch_model( 26 | 'classification', 27 | numeric_features, 28 | categoric_features, 29
| learning_rate, 30 | use_dask, 31 | n_iter, 32 | scoring, 33 | ) 34 | 35 | 36 | def xgboost_gridsearch_regressor( 37 | numeric_features, 38 | categoric_features, 39 | learning_rate=0.08, 40 | use_dask=False, 41 | n_iter=100, 42 | scoring='roc_auc', 43 | ): 44 | """ 45 | Simple regression pipeline using randomized search to optimize xgboost hyper-parameters 46 | Parameters 47 | ---------- 48 | `numeric_features` : The list of numeric features 49 | `categoric_features` : The list of categoric features 50 | `learning_rate` : The learning rate 51 | """ 52 | 53 | return _xgboost_gridsearch_model( 54 | 'regression', 55 | numeric_features, 56 | categoric_features, 57 | learning_rate, 58 | use_dask, 59 | n_iter, 60 | scoring, 61 | ) 62 | 63 | 64 | def _xgboost_gridsearch_model( 65 | task, 66 | numeric_features, 67 | categoric_features, 68 | learning_rate, 69 | use_dask, 70 | n_iter, 71 | scoring, 72 | ): 73 | param_space = { 74 | 'clf__max_depth': randint(2, 11), 75 | 'clf__min_child_weight': randint(1, 11), 76 | 'clf__subsample': uniform(0.5, 0.5), 77 | 'clf__colsample_bytree': uniform(0.5, 0.5), 78 | 'clf__colsample_bylevel': uniform(0.5, 0.5), 79 | 'clf__gamma': uniform(0, 1), 80 | 'clf__reg_alpha': uniform(0, 1), 81 | 'clf__reg_lambda': uniform(0, 10), 82 | 'clf__base_score': uniform(0.1, 0.9), 83 | 'clf__scale_pos_weight': uniform(0.1, 9.9), 84 | } 85 | 86 | model = ( 87 | xgbsk.XGBClassifier(learning_rate=learning_rate) 88 | if task == 'classification' 89 | else xgbsk.XGBRegressor(learning_rate=learning_rate) 90 | ) 91 | 92 | pipe = Pipeline( 93 | [ 94 | ( 95 | 'preprocessing', 96 | simple_proc_for_tree_algoritms( 97 | numeric_features, categoric_features 98 | ), 99 | ), 100 | ('clf', model), 101 | ] 102 | ) 103 | 104 | if use_dask: 105 | from dask_ml.model_selection import RandomizedSearchCV 106 | 107 | return RandomizedSearchCV( 108 | pipe, param_space, n_iter=n_iter, scoring=scoring, cv=5 109 | ) 110 | else: 111 | from sklearn.model_selection import RandomizedSearchCV 112 | 113 | return RandomizedSearchCV( 114 | pipe, param_space, n_iter=n_iter, scoring=scoring, cv=5 115 | ) 116 | -------------------------------------------------------------------------------- /soccer_xg/models/openplay_logreg_advanced: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-KULeuven/soccer_xg/b9489d929e0fa34771267d429256366d0fda27ad/soccer_xg/models/openplay_logreg_advanced -------------------------------------------------------------------------------- /soccer_xg/models/openplay_logreg_basic: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-KULeuven/soccer_xg/b9489d929e0fa34771267d429256366d0fda27ad/soccer_xg/models/openplay_logreg_basic -------------------------------------------------------------------------------- /soccer_xg/models/openplay_xgboost_advanced: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-KULeuven/soccer_xg/b9489d929e0fa34771267d429256366d0fda27ad/soccer_xg/models/openplay_xgboost_advanced -------------------------------------------------------------------------------- /soccer_xg/models/openplay_xgboost_basic: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-KULeuven/soccer_xg/b9489d929e0fa34771267d429256366d0fda27ad/soccer_xg/models/openplay_xgboost_basic
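The four model files above are pretrained pipelines serialised with joblib. As a rough sketch (assuming the `XGModel.load_model` helper defined in soccer_xg/xg.py below), they can be restored by filename:

    from soccer_xg.xg import XGModel
    model = XGModel.load_model('openplay_xgboost_basic')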
-------------------------------------------------------------------------------- /soccer_xg/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import pandas as pd 4 | import socceraction.spadl.config as spadlcfg 5 | from fuzzywuzzy import fuzz 6 | 7 | 8 | def play_left_to_right(actions, home_team_id): 9 | away_idx = actions.team_id != home_team_id 10 | for col in ['start_x', 'end_x']: 11 | actions.loc[away_idx, col] = ( 12 | spadlcfg.field_length - actions.loc[away_idx][col].values 13 | ) 14 | for col in ['start_y', 'end_y']: 15 | actions.loc[away_idx, col] = ( 16 | spadlcfg.field_width - actions.loc[away_idx][col].values 17 | ) 18 | return actions 19 | 20 | 21 | def enhance_actions(actions): 22 | # data 23 | actiontypes = pd.DataFrame( 24 | list(enumerate(spadlcfg.actiontypes)), columns=['type_id', 'type_name'] 25 | ) 26 | 27 | bodyparts = pd.DataFrame( 28 | list(enumerate(spadlcfg.bodyparts)), 29 | columns=['bodypart_id', 'bodypart_name'], 30 | ) 31 | 32 | results = pd.DataFrame( 33 | list(enumerate(spadlcfg.results)), columns=['result_id', 'result_name'] 34 | ) 35 | 36 | return ( 37 | actions.merge(actiontypes, how='left') 38 | .merge(results, how='left') 39 | .merge(bodyparts, how='left') 40 | # .sort_values(["period_id", "time_seconds", "timestamp"]) 41 | ) 42 | 43 | 44 | def match_name(name, list_names, min_score=0): 45 | # -1 score in case we don't get any matches 46 | max_score = -1 47 | # Returning empty name for no match as well 48 | max_name = '' 49 | # Iterating over all names in the other list 50 | for name2 in list_names: 51 | # Finding fuzzy match score 52 | score = fuzz.ratio(name, name2) 53 | # Checking if we are above our threshold and have a better score 54 | if (score > min_score) and (score > max_score): 55 | max_name = name2 56 | max_score = score 57 | return (max_name, max_score) 58 | 59 | 60 | def map_names( 61 | df1, 62 | df1_match_colname, 63 | df1_output_colname, 64 | df2, 65 | df2_match_colname, 66 | df2_output_colname, 67 | threshold=75, 68 | ): 69 | # List of dicts for easy dataframe creation 70 | dict_list = [] 71 | for _, (row_id, name) in df1[ 72 | [df1_output_colname, df1_match_colname] 73 | ].iterrows(): 74 | # Use our method to find best match, we can set a threshold here 75 | match = match_name(name, df2[df2_match_colname], threshold) 76 | # New dict for storing data 77 | dict_ = {} 78 | dict_.update({'df1_name': name}) 79 | dict_.update({'df1_id': row_id}) 80 | if match[1] > threshold: 81 | dict_.update({'df2_name': match[0]}) 82 | dict_.update( 83 | { 84 | 'df2_id': df2.loc[ 85 | df2[df2_match_colname] == match[0], df2_output_colname 86 | ].iloc[0] 87 | } 88 | ) 89 | else: 90 | dict_.update({'df2_name': 'unknown'}) 91 | dict_.update({'df2_id': 0}) 92 | dict_list.append(dict_) 93 | merge_table = pd.DataFrame(dict_list) 94 | return merge_table 95 | 96 | 97 | def get_matching_game(api, game_id, provider, other_provider, teams_mapping): 98 | season_id = str(api[provider].games.loc[game_id, 'season_id']) 99 | competition_id = api[provider].games.loc[game_id, 'competition_id'] 100 | # Get matching game 101 | home_team, away_team = api[provider].get_home_away_team_id(game_id) 102 | other_home_team = teams_mapping.set_index(f'{provider}_id').loc[ 103 | home_team, f'{other_provider}_id' 104 | ] 105 | other_away_team = teams_mapping.set_index(f'{provider}_id').loc[ 106 | away_team, f'{other_provider}_id' 107 | ] 108 | other_games = api[other_provider].games 109 | other_game_id = ( 110 | other_games[ 111 |
(other_games.home_team_id == other_home_team) 112 | & (other_games.away_team_id == other_away_team) 113 | & (other_games.competition_id == competition_id) 114 | & (other_games.season_id.astype(str) == season_id) 115 | ] 116 | .iloc[0] 117 | .name 118 | ) 119 | return other_game_id 120 | 121 | 122 | def get_matching_shot( 123 | api, 124 | shot, 125 | provider_shot, 126 | other_shots, 127 | provider_other_shots, 128 | teams_mapping, 129 | players_mapping=None, 130 | ): 131 | # Get matching game 132 | game_id = shot.game_id 133 | season_id = str(api[provider_shot].games.loc[game_id, 'season_id']) 134 | competition_id = api[provider_shot].games.loc[game_id, 'competition_id'] 135 | home_team, away_team = api[provider_shot].get_home_away_team_id(game_id) 136 | other_home_team = teams_mapping.set_index(f'{provider_shot}_id').loc[ 137 | home_team, f'{provider_other_shots}_id' 138 | ] 139 | other_away_team = teams_mapping.set_index(f'{provider_shot}_id').loc[ 140 | away_team, f'{provider_other_shots}_id' 141 | ] 142 | other_games = api[provider_other_shots].games 143 | other_game_id = ( 144 | other_games[ 145 | (other_games.home_team_id == other_home_team) 146 | & (other_games.away_team_id == other_away_team) 147 | & (other_games.competition_id == competition_id) 148 | & (other_games.season_id.astype(str) == season_id) 149 | ] 150 | .iloc[0] 151 | .name 152 | ) 153 | other_shots_in_game = other_shots[other_shots.game_id == other_game_id] 154 | # Get matching shot-taker 155 | if players_mapping is not None: 156 | player_id = shot.player_id 157 | other_player_id = players_mapping.set_index(f'{provider_shot}_id').loc[ 158 | int(player_id), f'{provider_other_shots}_id' 159 | ] 160 | other_shots_by_player = other_shots_in_game[ 161 | other_shots_in_game.player_id == other_player_id 162 | ] 163 | else: 164 | other_shots_by_player = other_shots_in_game 165 | # Get shots in same period 166 | period_id = shot.period_id 167 | other_shots_by_player_in_period = other_shots_by_player[ 168 | other_shots_by_player.period_id == period_id 169 | ] 170 | # Get shots that happened around the same time 171 | ts = shot.time_seconds 172 | best_match = other_shots_by_player_in_period.iloc[ 173 | (other_shots_by_player_in_period['time_seconds'] - ts) 174 | .abs() 175 | .argsort()[:1] 176 | ].iloc[0] 177 | if abs(ts - best_match.time_seconds) < 3: 178 | return best_match 179 | return None 180 | 181 | 182 | def sample_temporal(api, size_val=0.0, size_test=0.2): 183 | game_ids = api.games.sort_values(by='game_date').index.values 184 | nb_games = len(game_ids) 185 | games_train = game_ids[ 186 | 0 : math.floor((1 - size_val - size_test) * nb_games) 187 | ] 188 | games_val = game_ids[ 189 | math.ceil((1 - size_val - size_test) * nb_games) : math.floor( 190 | (1 - size_test) * nb_games 191 | ) 192 | ] 193 | games_test = game_ids[math.ceil((1 - size_test) * nb_games) :] 194 | return games_train, games_val, games_test 195 | -------------------------------------------------------------------------------- /soccer_xg/visualisation.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import matplotsoccer as mps 3 | import numpy as np 4 | import numpy.ma as ma 5 | from matplotlib.ticker import MultipleLocator 6 | from sklearn.metrics import auc, roc_curve 7 | from soccer_xg import metrics 8 | 9 | 10 | def plot_calibration_curve( 11 | y_true, 12 | y_pred, 13 | name='Calibration curve', 14 | min_samples=None, 15 | axis=None, 16 | **kwargs, 17 | ): 18 | """Plot
a probability calibration curve for a set of predicted goal probabilities. 19 | 20 | Parameters 21 | ---------- 22 | y_true, y_pred : array-like 23 | The true binary outcomes and the predicted goal probabilities. 24 | name : str (default=``'Calibration curve'``) 25 | The plot title. 26 | min_samples : int or ``None`` (default=``None``) 27 | If given, points on the curve backed by fewer samples are masked. 28 | axis : matplotlib.pyplot.axis object or ``None`` (default=``None``) 29 | If provided, the curve will be overlaid on ``axis``; otherwise, a new 30 | figure and axis will be generated. Remaining ``**kwargs`` are passed 31 | to ``axis.plot``. 32 | 33 | Returns 34 | ------- 35 | matplotlib.pyplot.axis 36 | The axis the plot was made on. 37 | """ 38 | 39 | if axis is None: 40 | axis = plt.figure(figsize=(5, 5)).add_subplot(111) 41 | 42 | axis.set_title(name) 43 | axis.plot([0, 100], [0, 100], ls='--', lw=1, color='grey') 44 | axis.set_xlabel('Predicted probability') 45 | axis.set_ylabel('True probability in each bin') 46 | axis.set_xlim((0, 100)) 47 | axis.xaxis.set_major_locator(MultipleLocator(20)) 48 | axis.xaxis.set_minor_locator(MultipleLocator(10)) 49 | axis.set_ylim((0, 100)) 50 | axis.yaxis.set_major_locator(MultipleLocator(20)) 51 | axis.yaxis.set_minor_locator(MultipleLocator(10)) 52 | # axis.set_aspect(1) 53 | axis.grid(which='both') 54 | 55 | ( 56 | sample_probabilities, 57 | predicted_pos_percents, 58 | num_plays_used, 59 | ) = metrics.bayesian_calibration_curve(y_true, y_pred) 60 | 61 | if min_samples is not None: 62 | axis.plot( 63 | sample_probabilities, 64 | predicted_pos_percents, 65 | c='c', 66 | alpha=0.3, 67 | **kwargs, 68 | ) 69 | sample_probabilities = ma.array(sample_probabilities) 70 | sample_probabilities[num_plays_used < min_samples] = ma.masked 71 | predicted_pos_percents = ma.array(predicted_pos_percents) 72 | predicted_pos_percents[num_plays_used < min_samples] = ma.masked 73 | 74 | max_deviation = metrics.max_deviation( 75 | sample_probabilities, predicted_pos_percents 76 | ) 77 | residual_area = metrics.residual_area( 78 | sample_probabilities, predicted_pos_percents 79 | ) 80 | 81 | axis.plot( 82 | sample_probabilities, 83 | predicted_pos_percents, 84 | c='c', 85 | label='Calibration curve\n(area = %0.2f, max dev = %0.2f)' 86 | % (residual_area, max_deviation), 87 | **kwargs, 88 | ) 89 | 90 | axis.legend(loc='lower right') 91 | 92 | ax2 = axis.twinx() 93 | ax2.hist( 94 | y_pred * 100, 95 | bins=np.arange(0, 101, 1), 96 | density=True, 97 | alpha=0.4, 98 | facecolor='grey', 99 | ) 100 | ax2.set_ylim([0, 0.2]) 101 | ax2.set_yticks([0, 0.1, 0.2]) 102 | 103 | plt.tight_layout() 104 | return axis 105 | 106 | 107 | def plot_roc_curve(y_true, y_prob, name='ROC curve', axis=None): 108 | 109 | fpr, tpr, _ = roc_curve(y_true, y_prob) 110 | roc_auc = auc(fpr, tpr) 111 | 112 | if axis is None: 113 | axis = plt.figure(figsize=(5, 5)).add_subplot(111) 114 | 115 | axis.plot( 116 | fpr, tpr, linewidth=1, label='ROC curve (area = %0.2f)' % roc_auc 117 | ) 118 | 119 | # reference line, legends, and axis labels 120 | axis.plot([0, 1], [0, 1], linestyle='--', color='gray') 121 | axis.set_title(name) 122 | axis.set_xlabel('False Positive Rate') 123 | axis.set_ylabel('True Positive Rate') 124 | axis.set_xlim(0, 1) 125 | axis.xaxis.set_major_locator(MultipleLocator(0.20)) 126 | axis.xaxis.set_minor_locator(MultipleLocator(0.10)) 127 | axis.set_ylim(0, 1) 128 | axis.yaxis.set_major_locator(MultipleLocator(0.20)) 129 | axis.yaxis.set_minor_locator(MultipleLocator(0.10)) 130 | axis.grid(which='both') 131 | 132 | # sns.despine() 133 | # plt.gca().xaxis.set_ticks_position('none') 134 | # plt.gca().yaxis.set_ticks_position('none') 135 | 136 | 137 | axis.legend(loc='lower right')
138 | plt.tight_layout() 139 | 140 | 141 | def plot_heatmap(model, data, axis=None): 142 | 143 | if axis is None: 144 | axis = plt.figure(figsize=(8, 10)).add_subplot(111) 145 | 146 | z = model.estimate(data)['xG'].values 147 | axis = mps.field(ax=axis, show=False) 148 | axis = mps.heatmap( 149 | z.reshape((106, 69)).T, show=False, ax=axis, cmap='viridis_r' 150 | ) 151 | axis.set_xlim((70, 108)) 152 | axis.set_axis_off() 153 | return axis 154 | -------------------------------------------------------------------------------- /soccer_xg/xg.py: -------------------------------------------------------------------------------- 1 | """Tools for creating and analyzing xG models.""" 2 | import os 3 | 4 | import joblib 5 | import pandas as pd 6 | import socceraction.spadl.config as spadlcfg 7 | from sklearn.linear_model import LogisticRegression 8 | from sklearn.metrics import brier_score_loss, roc_auc_score 9 | from sklearn.pipeline import make_pipeline 10 | from sklearn.utils.validation import NotFittedError 11 | from soccer_xg import features as fs 12 | from soccer_xg import metrics, utils 13 | from soccer_xg.api import DataApi 14 | from soccer_xg.ml.preprocessing import simple_proc_for_linear_algoritms 15 | from tqdm import tqdm 16 | 17 | 18 | class XGModel(object): 19 | """A wrapper around a pipeline for computing xG values. 20 | 21 | Parameters 22 | ---------- 23 | copy_data : boolean (default=``True``) 24 | Whether or not to copy data when fitting and applying the model. Running the model 25 | in-place (``copy_data=False``) will be faster and have a smaller memory footprint, 26 | but if not done carefully can lead to data integrity issues. 27 | 28 | Attributes 29 | ---------- 30 | model : A Scikit-learn pipeline (or equivalent) 31 | The actual model used to compute xG. Upon initialization it will be set to 32 | a default model, but can be overridden by the user. 33 | column_descriptions : dictionary 34 | A dictionary whose keys are the names of the columns used in the model, and the values are 35 | string descriptions of what the columns mean. Set at initialization to be the default model, 36 | if you create your own model you'll need to update this attribute manually. 37 | training_seasons : A list of tuples, or ``None`` (default=``None``) 38 | If the model was trained using data from the DataApi, a list of (competition_id, season_id) tuples 39 | used to train the model. If the DataApi was **not** used, an empty list. If no model 40 | has been trained yet, ``None``. 41 | validation_seasons : same as ``training_seasons``, but for validation data. 42 | sample_probabilities : A numpy array of floats or ``None`` (default=``None``) 43 | After the model has been validated, contains the sampled predicted probabilities used to 44 | compute the validation statistic. 45 | predicted_goal_percents : A numpy array of floats or ``None`` (default=``None``) 46 | After the model has been validated, contains the actual probabilities in the test 47 | set at each probability in ``sample_probabilities``. 48 | num_shots_used : A numpy array of floats or ``None`` (default=``None``) 49 | After the model has been validated, contains the number of shots used to compute each 50 | element of ``predicted_goal_percents``. 51 | model_directory : string 52 | The directory where all models will be saved to or loaded from. 
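Examples
--------
A typical session (season identifiers are illustrative and assume the
corresponding data is reachable through a ``DataApi``) might look like::

    api = DataApi(...)
    model = XGModel()
    model.train(api, training_seasons=[('ENG', '1617'), ('ENG', '1718')])
    model.validate(api, validation_seasons=[('ENG', '1819')])
    xg_values = model.estimate(api)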
53 | """ 54 | 55 | model_directory = os.path.join( 56 | os.path.dirname(os.path.abspath(__file__)), 'models' 57 | ) 58 | _default_model_filename = 'default_model.xg' 59 | 60 | def __init__(self, copy_data=True): 61 | self.copy_data = copy_data 62 | self.column_descriptions = None 63 | 64 | self.model = self.create_default_pipeline() 65 | self._fitted = False 66 | self._training_seasons = None 67 | self._validation_seasons = None 68 | 69 | self._sample_probabilities = None 70 | self._predicted_goal_percents = None 71 | self._num_shots_used = None 72 | 73 | @property 74 | def training_seasons(self): 75 | return self._training_seasons 76 | 77 | @property 78 | def validation_seasons(self): 79 | return self._validation_seasons 80 | 81 | @property 82 | def sample_probabilities(self): 83 | return self._sample_probabilities 84 | 85 | @property 86 | def predicted_goal_percents(self): 87 | return self._predicted_goal_percents 88 | 89 | @property 90 | def num_shots_used(self): 91 | return self._num_shots_used 92 | 93 | def train( 94 | self, 95 | source_data, 96 | training_seasons=(('ENG', '1617'), ('ENG', '1718')), 97 | target_colname='goal', 98 | ): 99 | """Train the model. 100 | 101 | Once a modeling pipeline is set up (either the default or something 102 | custom-generated), historical data needs to be fed into it in order to 103 | "fit" the model so that it can then be used to predict future results. 104 | This method implements a simple wrapper around the core Scikit-learn functionality 105 | which does this. 106 | 107 | The default is to use data from a DataApi object, however that can be changed 108 | to a simple Pandas DataFrame with precomputed features and labels if desired. 109 | 110 | There is no particular output from this function, rather the parameters governing 111 | the fit of the model are saved inside the model object itself. If you want to get an 112 | estimate of the quality of the fit, use the ``validate_model`` method after running 113 | this method. 114 | 115 | Parameters 116 | ---------- 117 | source_data : ``DataApi`` or a Pandas DataFrame 118 | The data to be used to train the model. If an instance of 119 | ``DataApi`` is given, will query the api database for the training data. 120 | training_seasons : list of tuples (default=``[('ENG', '1617'), ('ENG', '1718')]``) 121 | What seasons to use to train the model if getting data from a DataApi instance. 122 | If ``source_data`` is not a ``DataApi``, this argument will be ignored. 123 | **NOTE:** it is critical not to use all possible data in order to train the 124 | model - some will need to be reserved for a final validation (see the 125 | ``validate_model`` method). A good dataset to reserve 126 | for validation is the most recent one or two seasons. 127 | target_colname : string or integer (default=``"goal"``) 128 | The name of the target variable column. This is only relevant if 129 | ``source_data`` is not a ``DataApi``. 
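For example, with a precomputed feature table the call is simply ``model.train(df_shots, target_colname='goal')``, which fits on every column of ``df_shots`` except ``goal`` (names illustrative).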
130 | 131 | Returns 132 | ------- 133 | ``None`` 134 | """ 135 | if isinstance(self.model, list): 136 | for model in self.model: 137 | model.train(source_data, training_seasons, target_colname) 138 | else: 139 | self._training_seasons = [] 140 | if isinstance(source_data, DataApi): 141 | game_ids = source_data.games[ 142 | source_data.games.season_id.astype(str).isin( 143 | [s[1] for s in training_seasons] 144 | ) 145 | & source_data.games.competition_id.astype(str).isin( 146 | [s[0] for s in training_seasons] 147 | ) 148 | ].index 149 | feature_cols = get_features(source_data, game_ids) 150 | target_col = get_labels(source_data, game_ids) 151 | self._training_seasons = training_seasons 152 | else: 153 | target_col = source_data[target_colname] 154 | feature_cols = source_data.drop(target_colname, axis=1) 155 | self.model.fit(feature_cols, target_col) 156 | self._fitted = True 157 | 158 | def validate( 159 | self, 160 | source_data, 161 | validation_seasons=(('ENG', '1819'),), 162 | target_colname='goal', 163 | plot=True, 164 | ): 165 | """Validate the model. 166 | 167 | Once a modeling pipeline is trained, a different dataset must be fed into the trained model 168 | to validate the quality of the fit. 169 | This method implements a simple wrapper around the core Scikit-learn functionality 170 | which does this. 171 | 172 | The default is to use data from a DataApi object, however that can be changed 173 | to a simple Pandas DataFrame with precomputed features and labels if desired. 174 | 175 | The output of this method is a dictionary with relevant error metrics (see ``soccer_xg.metrics``). 176 | 177 | Parameters 178 | ---------- 179 | source_data : ``DataApi`` or a Pandas DataFrame 180 | The data to be used to validate the model. If an instance of 181 | ``DataApi`` is given, will query the api database for the training data. 182 | validation_seasons : list of tuples (default=``[('ENG', '1819')]``) 183 | What seasons to use to validate the model if getting data from a DataApi instance. 184 | If ``source_data`` is not a ``DataApi``, this argument will be ignored. 185 | **NOTE:** it is critical not to use the same data to validate the model as was used 186 | in the fit. Generally a good data set to use for validation is one from a time 187 | period more recent than was used to train the model. 188 | target_colname : string or integer (default=``"goal"``) 189 | The name of the target variable column. This is only relevant if 190 | ``source_data`` is not a ``DataApi``. 191 | plot: bool (default=``True``) 192 | Whether to plot the AUROC and probability calibration curves. 193 | 194 | Returns 195 | ------- 196 | dict 197 | Error metrics on the validation data. 198 | 199 | Raises 200 | ------ 201 | NotFittedError 202 | If the model hasn't been fit.
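Examples
--------
``scores = model.validate(api, plot=False)`` returns the metric dictionary; for instance ``scores['roc']`` and ``scores['brier']`` hold the AUROC and the Brier score (``api`` is an illustrative ``DataApi``).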
203 | 204 | """ 205 | if not self._fitted: 206 | raise NotFittedError('Must fit model before validating.') 207 | 208 | if isinstance(source_data, DataApi): 209 | game_ids = source_data.games[ 210 | source_data.games.season_id.astype(str).isin( 211 | [s[1] for s in validation_seasons] 212 | ) 213 | & source_data.games.competition_id.astype(str).isin( 214 | [s[0] for s in validation_seasons] 215 | ) 216 | ].index 217 | target_col = get_labels(source_data, game_ids) 218 | self._validation_seasons = validation_seasons 219 | else: 220 | game_ids = None 221 | target_col = source_data[target_colname] 222 | self._validation_seasons = [] 223 | 224 | df_predictions = self.estimate(source_data, game_ids) 225 | predicted_probabilities = df_predictions['xG'] 226 | target_col = target_col.loc[df_predictions.index] 227 | 228 | ( 229 | self._sample_probabilities, 230 | self._predicted_goal_percents, 231 | self._num_shots_used, 232 | ) = metrics.bayesian_calibration_curve( 233 | target_col.values, predicted_probabilities 234 | ) 235 | 236 | # Compute the maximal deviation from a perfect prediction as well as the area under the 237 | # curve of the residual between |predicted - perfect|: 238 | max_deviation = metrics.max_deviation( 239 | self.sample_probabilities, self.predicted_goal_percents 240 | ) 241 | residual_area = metrics.residual_area( 242 | self.sample_probabilities, self.predicted_goal_percents 243 | ) 244 | roc = roc_auc_score(target_col, predicted_probabilities) 245 | brier = brier_score_loss(target_col, predicted_probabilities) 246 | ece = metrics.expected_calibration_error( 247 | target_col, predicted_probabilities, 10, 'uniform' 248 | ) 249 | ace = metrics.expected_calibration_error( 250 | target_col, predicted_probabilities, 10, 'quantile' 251 | ) 252 | 253 | if plot: 254 | import matplotlib.pyplot as plt 255 | from soccer_xg.visualisation import ( 256 | plot_roc_curve, 257 | plot_calibration_curve, 258 | ) 259 | 260 | fig, ax = plt.subplots(1, 2, figsize=(10, 5)) 261 | plot_roc_curve(target_col, predicted_probabilities, axis=ax[0]) 262 | plot_calibration_curve( 263 | target_col, 264 | predicted_probabilities, 265 | min_samples=100, 266 | axis=ax[1], 267 | ) 268 | 269 | return { 270 | 'max_dev': max_deviation, 271 | 'residual_area': residual_area, 272 | 'roc': roc, 273 | 'brier': brier, 274 | 'ece': ece, 275 | 'ace': ace, 276 | 'fig': fig if plot else None, 277 | } 278 | 279 | def estimate(self, source_data, game_ids=None): 280 | """Estimate the xG values for all shots in a set of games. 281 | 282 | The default is to use data from a DataApi object, however that can be changed 283 | to a simple Pandas DataFrame with precomputed features and labels if desired. 284 | 285 | Parameters 286 | ---------- 287 | source_data : ``DataApi`` or a Pandas DataFrame 288 | The data for which xG values should be estimated. If an instance of 289 | ``DataApi`` is given, will query the api database for the shot data. 290 | game_ids : list of ints (default=None) 291 | Only xG values for the games in this list are returned. By default, 292 | xG values are computed for all games in the source data. 293 | If ``source_data`` is not a ``DataApi``, this argument will be ignored. 294 | 295 | Returns 296 | ------- 297 | A Pandas DataFrame 298 | A dataframe with a column 'xG', containing the predicted xG value 299 | of each shot in the given data, indexed by (game_id, action_id) of 300 | the corresponding shot. 301 | 302 | Raises 303 | ------ 304 | NotFittedError 305 | If the model hasn't been fit.
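Examples
--------
``model.estimate(api, game_ids=[1234])`` returns the xG value of every shot in that single game (the id is illustrative).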
306 |         """
307 |         if not self._fitted:
308 |             raise NotFittedError('Must fit model before estimating xG.')
309 | 
310 |         if isinstance(self.model, list):
311 |             xg = []
312 |             for model in self.model:
313 |                 xg.append(model.estimate(source_data, game_ids))
314 |             return pd.concat(xg).sort_index()
315 |         else:
316 |             if isinstance(source_data, DataApi):
317 |                 if game_ids is None:
318 |                     game_ids = source_data.games.index
319 |                 source_data = get_features(source_data, game_ids)
320 | 
321 |             xg = pd.DataFrame(index=source_data.index)
322 |             xg['xG'] = self.model.predict_proba(source_data)[:, 1]
323 |             return xg
324 | 
325 |     def create_default_pipeline(self):
326 |         """Create the default xG estimation pipeline.
327 | 
328 |         Returns
329 |         -------
330 |         list of ``XGModel``
331 |             One sub-model each for open-play shots, free kicks and penalties;
332 |             suitable for computing xG but by no means the best possible model.
333 |         """
334 |         models = [OpenplayXGModel(), FreekickXGModel(), PenaltyXGModel()]
335 |         self.column_descriptions = {
336 |             m.__class__.__name__: m.column_descriptions for m in models
337 |         }
338 |         return models
339 | 
340 |     def save_model(self, filename=None):
341 |         """Save the XGModel instance to disk.
342 | 
343 |         All models are saved to the same place, alongside the installed
344 |         soccer_xg library (given by ``XGModel.model_directory``).
345 | 
346 |         Parameters
347 |         ----------
348 |         filename : string (default=None)
349 |             The filename to use for the saved model. If this parameter
350 |             is not specified, save to the default filename. Note that if a model
351 |             already exists with this filename, it will be overwritten. Note also that
352 |             this is a filename only, **not** a full path. If a full path is specified
353 |             it is likely (albeit not guaranteed) to cause errors.
354 | 
355 |         Returns
356 |         -------
357 |         ``None``
358 |         """
359 |         if filename is None:
360 |             filename = self._default_model_filename
361 |         joblib.dump(self, os.path.join(self.model_directory, filename))
362 | 
363 |     @classmethod
364 |     def load_model(cls, filename=None):
365 |         """Load a saved XGModel.
366 | 
367 |         Parameters
368 |         ----------
369 |         filename : string (default=None)
370 |             The filename to use for the saved model. If this parameter
371 |             is not specified, load the default model. Note that
372 |             this is a filename only, **not** a full path.
373 | 
374 |         Returns
375 |         -------
376 |         ``soccer_xg.XGModel`` instance.
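        Examples
        --------
        A sketch of the save/load round trip; the filename is illustrative:

        >>> model.save_model(filename='my_model.xgmodel')
        >>> loaded = XGModel.load_model(filename='my_model.xgmodel')
        >>> isinstance(loaded, XGModel)
        True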
381 | """ 382 | if filename is None: 383 | filename = cls._default_model_filename 384 | 385 | return joblib.load(os.path.join(cls.model_directory, filename)) 386 | 387 | 388 | class OpenplayXGModel(XGModel): 389 | _default_model_filename = 'default_openplay_model.xg' 390 | 391 | def train( 392 | self, 393 | source_data, 394 | training_seasons=(('ENG', '1617'), ('ENG', '1718')), 395 | target_colname='goal', 396 | ): 397 | self._training_seasons = [] 398 | if isinstance(source_data, DataApi): 399 | game_ids = source_data.games[ 400 | source_data.games.season_id.astype(str).isin( 401 | [s[1] for s in training_seasons] 402 | ) 403 | & source_data.games.competition_id.astype(str).isin( 404 | [s[0] for s in training_seasons] 405 | ) 406 | ].index 407 | feature_cols = get_features( 408 | source_data, game_ids, shotfilter=OpenplayXGModel.filter_shots 409 | ) 410 | target_col = get_labels( 411 | source_data, game_ids, shotfilter=OpenplayXGModel.filter_shots 412 | ) 413 | self._training_seasons = training_seasons 414 | else: 415 | target_col = source_data[target_colname] 416 | feature_cols = source_data.drop(target_colname, axis=1) 417 | self.model.fit(feature_cols, target_col) 418 | self._fitted = True 419 | 420 | def estimate(self, source_data, game_ids=None): 421 | 422 | if isinstance(source_data, DataApi): 423 | game_ids = ( 424 | source_data.games.index if game_ids is None else game_ids 425 | ) 426 | source_data = get_features( 427 | source_data, game_ids, shotfilter=OpenplayXGModel.filter_shots 428 | ) 429 | 430 | xg = pd.DataFrame(index=source_data.index) 431 | xg['xG'] = self.model.predict_proba(source_data)[:, 1] 432 | return xg 433 | 434 | def create_default_pipeline(self): 435 | bodypart_colname = 'bodypart_id_a0' 436 | dist_to_goal_colname = 'start_dist_to_goal_a0' 437 | angle_to_goal_colname = 'start_angle_to_goal_a0' 438 | 439 | self.column_descriptions = { 440 | bodypart_colname: 'Bodypart used for the shot (head, foot or other)', 441 | dist_to_goal_colname: 'Distance to goal', 442 | angle_to_goal_colname: 'Angle to goal', 443 | } 444 | 445 | preprocess_pipeline = simple_proc_for_linear_algoritms( 446 | [dist_to_goal_colname, angle_to_goal_colname], [bodypart_colname] 447 | ) 448 | base_model = LogisticRegression( 449 | max_iter=10000, solver='lbfgs', fit_intercept=False 450 | ) 451 | pipe = make_pipeline(preprocess_pipeline, base_model) 452 | return pipe 453 | 454 | @staticmethod 455 | def filter_shots(df_actions): 456 | shot_idx = ( 457 | df_actions.type_name == 'shot' 458 | ) & df_actions.result_name.isin(['fail', 'success']) 459 | return shot_idx 460 | 461 | 462 | class PenaltyXGModel(XGModel): 463 | _default_model_filename = 'default_penalty_model.xg' 464 | 465 | def __init__(self, copy_data=True): 466 | super().__init__(copy_data) 467 | self._fitted = True 468 | 469 | def train( 470 | self, 471 | source_data, 472 | training_seasons=(('ENG', '1617'), ('ENG', '1718')), 473 | target_colname='goal', 474 | ): 475 | pass 476 | 477 | def estimate(self, source_data, game_ids=None): 478 | 479 | if isinstance(source_data, DataApi): 480 | game_ids = ( 481 | source_data.games.index if game_ids is None else game_ids 482 | ) 483 | source_data = get_features( 484 | source_data, 485 | game_ids, 486 | xfns=[], 487 | shotfilter=PenaltyXGModel.filter_shots, 488 | ) 489 | 490 | xg = pd.DataFrame(index=source_data.index) 491 | xg['xG'] = 0.792453 492 | 493 | return xg 494 | 495 | def create_default_pipeline(self): 496 | return None 497 | 498 | @staticmethod 499 | def filter_shots(df_actions): 500 | 
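        # penalties are a separate SPADL action type; no result-based filtering
        # is needed, since PenaltyXGModel assigns every penalty the same fixed
        # conversion rate rather than predicting from features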
shot_idx = df_actions.type_name == 'shot_penalty'
501 |         return shot_idx
502 | 
503 | 
504 | class FreekickXGModel(XGModel):
505 | 
506 |     _default_model_filename = 'default_freekick_model.xg'
507 | 
508 |     def train(
509 |         self,
510 |         source_data,
511 |         training_seasons=(('ENG', '1617'), ('ENG', '1718')),
512 |         target_colname='goal',
513 |     ):
514 |         self._training_seasons = []
515 |         if isinstance(source_data, DataApi):
516 |             game_ids = source_data.games[
517 |                 source_data.games.season_id.astype(str).isin(
518 |                     [s[1] for s in training_seasons]
519 |                 )
520 |                 & source_data.games.competition_id.astype(str).isin(
521 |                     [s[0] for s in training_seasons]
522 |                 )
523 |             ].index
524 |             feature_cols = get_features(
525 |                 source_data, game_ids, shotfilter=FreekickXGModel.filter_shots
526 |             )
527 |             target_col = get_labels(
528 |                 source_data, game_ids, shotfilter=FreekickXGModel.filter_shots
529 |             )
530 |             self._training_seasons = training_seasons
531 |         else:
532 |             target_col = source_data[target_colname]
533 |             feature_cols = source_data.drop(target_colname, axis=1)
534 |         self.model.fit(feature_cols, target_col)
535 |         self._fitted = True
536 | 
537 |     def estimate(self, source_data, game_ids=None):
538 | 
539 |         if isinstance(source_data, DataApi):
540 |             game_ids = (
541 |                 source_data.games.index if game_ids is None else game_ids
542 |             )
543 |             source_data = get_features(
544 |                 source_data, game_ids, shotfilter=FreekickXGModel.filter_shots
545 |             )
546 | 
547 |         xg = pd.DataFrame(index=source_data.index)
548 |         xg['xG'] = self.model.predict_proba(source_data)[:, 1]
549 |         return xg
550 | 
551 |     def create_default_pipeline(self):
552 |         dist_to_goal_colname = 'start_dist_to_goal_a0'
553 |         angle_to_goal_colname = 'start_angle_to_goal_a0'
554 | 
555 |         self.column_descriptions = {
556 |             dist_to_goal_colname: 'Distance to goal',
557 |             angle_to_goal_colname: 'Angle to goal',
558 |         }
559 | 
560 |         preprocess_pipeline = simple_proc_for_linear_algoritms(
561 |             [dist_to_goal_colname, angle_to_goal_colname], []
562 |         )
563 |         base_model = LogisticRegression(
564 |             max_iter=10000, solver='lbfgs', fit_intercept=True
565 |         )
566 |         pipe = make_pipeline(preprocess_pipeline, base_model)
567 |         return pipe
568 | 
569 |     @staticmethod
570 |     def filter_shots(df_actions):
571 |         shot_idx = df_actions.type_name == 'shot_freekick'
572 |         return shot_idx
573 | 
574 | 
575 | def get_features(
576 |     api,
577 |     game_ids=None,
578 |     xfns=fs.all_features,
579 |     shotfilter=None,
580 |     nb_prev_actions=3,
581 | ):
582 |     game_ids = api.games.index if game_ids is None else game_ids
583 |     X = {}
584 |     for game_id in tqdm(game_ids, desc='Generating features'):
585 |         try:
586 |             game = api.games.loc[game_id]
587 |             game_actions = utils.enhance_actions(api.get_actions(game_id))
588 |             X[game_id] = _compute_features_game(
589 |                 game, game_actions, xfns, shotfilter, nb_prev_actions
590 |             )
591 |             X[game_id].index.name = 'action_id'
592 |             X[game_id]['game_id'] = game_id
593 |         except Exception as e:
594 |             # skip games whose event data cannot be converted (mirrors get_labels)
595 |             print(f"Failed for game with id={game_id}: {e}")
596 |     X = pd.concat(X.values()).reset_index().set_index(['game_id', 'action_id'])
597 |     # remove post-shot features (these will all have a single unique value)
598 |     f = X.columns[X.nunique() > 1]
599 |     return X[f]
600 | 
601 | 
602 | def _compute_features_game(
603 |     game, actions, xfns=fs.all_features, shotfilter=None, nb_prev_actions=3
604 | ):
605 |     if shotfilter is None:
606 |         # filter shots and ignore own goals
607 |         shot_idx = actions.type_name.isin(
608 |             ['shot', 'shot_penalty', 'shot_freekick']
609 |         ) & actions.result_name.isin(['fail', 'success'])
610 |     else:
611 |         shot_idx = shotfilter(actions)
612 |     if shot_idx.sum() < 1:
613 |         return pd.DataFrame()
614 |     if len(xfns) < 1:
615 |         return pd.DataFrame(index=actions.index.values[shot_idx])
616 |     # convert actions to gamestates
617 |     gamestates = [
618 |         states.loc[shot_idx].copy()
619 |         for states in fs.gamestates(actions, nb_prev_actions)
620 |     ]
621 |     gamestates = fs.play_left_to_right(gamestates, game.home_team_id)
622 |     # remove post-shot attributes
623 |     gamestates[0].loc[shot_idx, 'end_x'] = float('NaN')
624 |     gamestates[0].loc[shot_idx, 'end_y'] = float('NaN')
625 |     gamestates[0].loc[shot_idx, 'result_id'] = float('NaN')
626 |     # compute features
627 |     X = pd.concat([fn(gamestates) for fn in xfns], axis=1)
628 |     # fix data types
629 |     for c in [c for c in X.columns.values if c.startswith('type_id')]:
630 |         X[c] = pd.Categorical(
631 |             X[c].replace(spadlcfg.actiontypes_df().type_name.to_dict()),
632 |             categories=spadlcfg.actiontypes,
633 |             ordered=False,
634 |         )
635 |     for c in [c for c in X.columns.values if c.startswith('result_id')]:
636 |         X[c] = pd.Categorical(
637 |             X[c].replace(spadlcfg.results_df().result_name.to_dict()),
638 |             categories=spadlcfg.results,
639 |             ordered=False,
640 |         )
641 |     for c in [c for c in X.columns.values if c.startswith('bodypart_id')]:
642 |         X[c] = pd.Categorical(
643 |             X[c].replace(spadlcfg.bodyparts_df().bodypart_name.to_dict()),
644 |             categories=spadlcfg.bodyparts,
645 |             ordered=False,
646 |         )
647 |     return X
648 | 
649 | 
650 | def get_labels(api, game_ids=None, shotfilter=None):
651 |     game_ids = api.games.index if game_ids is None else game_ids
652 |     y = {}
653 |     for game_id in tqdm(game_ids, desc='Generating labels'):
654 |         try:
655 |             game = api.games.loc[game_id]
656 |             game_actions = utils.enhance_actions(api.get_actions(game_id))
657 |             y[game_id] = _compute_labels_game(game, game_actions, shotfilter)
658 |             y[game_id].index.name = 'action_id'
659 |             y[game_id]['game_id'] = game_id
660 |         except Exception as e:
661 |             print(f"Failed for game with id={game_id}: {e}")
662 |     return (
663 |         pd.concat(y.values())
664 |         .reset_index()
665 |         .set_index(['game_id', 'action_id'])['goal']
666 |     )
667 | 
668 | 
669 | def _compute_labels_game(game, actions, shotfilter=None):
670 |     # compute labels
671 |     y = actions['result_name'] == 'success'
672 |     if shotfilter is None:
673 |         # filter shots and ignore own goals
674 |         shot_idx = actions.type_name.isin(
675 |             ['shot', 'shot_penalty', 'shot_freekick']
676 |         ) & actions.result_name.isin(['fail', 'success'])
677 |     else:
678 |         shot_idx = shotfilter(actions)
679 |     return y.loc[shot_idx].to_frame('goal')
680 | 
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import soccer_xg.xg as xg
3 | from soccer_xg.api import DataApi
4 | 
5 | 
6 | @pytest.fixture(scope='session')
7 | def api():
8 |     return DataApi('tests/data/spadl-statsbomb-WC-2018.h5')
9 | 
10 | 
11 | @pytest.fixture()
12 | def model():
13 |     return xg.XGModel()
14 | 
--------------------------------------------------------------------------------
/tests/data/download.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import pandas as pd
4 | import socceraction.spadl as spadl
5 | import socceraction.spadl.statsbomb as statsbomb
6 | from tqdm import tqdm
7 | 
8 | seasons = {
9 |     3: '2018',
10 | }
11 | leagues = {
12 |     'FIFA World Cup': 'WC',
13 | }
14 | 
15 | 
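# StatsBomb's public event data, read straight from GitHub; the dicts above
# map StatsBomb competition/season ids to the short keys used in the HDF5
# filenames and in the (competition, season) tuples passed to XGModel.train,
# e.g. ('WC', '2018')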
free_open_data_remote = (
16 |     'https://raw.githubusercontent.com/statsbomb/open-data/master/data/'
17 | )
18 | spadl_datafolder = 'tests/data'
19 | 
20 | SBL = statsbomb.StatsBombLoader(root=free_open_data_remote, getter='remote')
21 | 
22 | # View all available competitions
23 | df_competitions = SBL.competitions()
24 | df_selected_competitions = df_competitions[
25 |     df_competitions.competition_name.isin(leagues.keys())
26 | ]
27 | 
28 | for competition in df_selected_competitions.itertuples():
29 |     # Get all matches of the selected competition
30 |     matches = SBL.matches(competition.competition_id, competition.season_id)
31 | 
32 |     matches_verbose = tqdm(
33 |         list(matches.itertuples()), desc='Loading match data'
34 |     )
35 |     teams, players, player_games = [], [], []
36 | 
37 |     competition_id = leagues[competition.competition_name]
38 |     season_id = seasons[competition.season_id]
39 |     spadl_h5 = os.path.join(
40 |         spadl_datafolder, f'spadl-statsbomb-{competition_id}-{season_id}.h5'
41 |     )
42 |     with pd.HDFStore(spadl_h5) as spadlstore:
43 | 
44 |         spadlstore.put('actiontypes', spadl.actiontypes_df(), format='table')
45 |         spadlstore.put('results', spadl.results_df(), format='table')
46 |         spadlstore.put('bodyparts', spadl.bodyparts_df(), format='table')
47 | 
48 |         for match in matches_verbose:
49 |             # load data
50 |             teams.append(SBL.teams(match.match_id))
51 |             players.append(SBL.players(match.match_id))
52 |             events = SBL.events(match.match_id)
53 | 
54 |             # convert data
55 |             player_games.append(statsbomb.extract_player_games(events))
56 |             spadlstore.put(
57 |                 f'actions/game_{match.match_id}',
58 |                 statsbomb.convert_to_actions(events, match.home_team_id),
59 |                 format='table',
60 |             )
61 | 
62 |         games = matches.rename(
63 |             columns={'match_id': 'game_id', 'match_date': 'game_date'}
64 |         )
65 |         games.season_id = season_id
66 |         games.competition_id = competition_id
67 |         spadlstore.put('games', games)
68 |         spadlstore.put(
69 |             'teams',
70 |             pd.concat(teams).drop_duplicates('team_id').reset_index(drop=True),
71 |         )
72 |         spadlstore.put(
73 |             'players',
74 |             pd.concat(players)
75 |             .drop_duplicates('player_id')
76 |             .reset_index(drop=True),
77 |         )
78 |         spadlstore.put(
79 |             'player_games', pd.concat(player_games).reset_index(drop=True)
80 |         )
81 | 
--------------------------------------------------------------------------------
/tests/test_api.py:
--------------------------------------------------------------------------------
1 | def test_get_actions(api):
2 |     df_actions = api.get_actions(7584)
3 |     assert len(df_actions)
4 | 
5 | 
6 | def test_get_home_away_team_id(api):
7 |     home_id, away_id = api.get_home_away_team_id(7584)
8 |     assert home_id == 782
9 |     assert away_id == 778
10 | 
11 | 
12 | def test_get_team_name(api):
13 |     name = api.get_team_name(782)
14 |     assert name == 'Belgium'
15 | 
16 | 
17 | def test_get_player_name(api):
18 |     name = api.get_player_name(3089)
19 |     assert name == 'Kevin De Bruyne'
20 | 
--------------------------------------------------------------------------------
/tests/test_metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import soccer_xg.metrics as metrics_lib
4 | 
5 | #
6 | # expected_calibration_error
7 | #
8 | 
9 | 
10 | def test_expected_calibration_error():
11 |     np.random.seed(1)
12 |     nsamples = 100
13 |     probs = np.linspace(0, 1, nsamples)
14 |     labels = np.random.rand(nsamples) < probs
15 |     ece = metrics_lib.expected_calibration_error(labels, probs)
16 |     bad_ece = metrics_lib.expected_calibration_error(labels, probs / 2)
17 | 
18 |     assert 0 < ece < 1
19 |     assert 0 < bad_ece < 1
20 |     assert ece < bad_ece
21 | 
22 | 
23 | def test_expected_calibration_error_all_wrong():
24 |     n_bins = 90
25 |     ece = metrics_lib.expected_calibration_error(
26 |         np.ones(10), np.zeros(10), n_bins=n_bins
27 |     )
28 |     assert ece == pytest.approx(1.0)
29 | 
30 |     ece = metrics_lib.expected_calibration_error(
31 |         np.zeros(10), np.ones(10), n_bins=n_bins
32 |     )
33 |     assert ece == pytest.approx(1.0)
34 | 
35 | 
36 | def test_expected_calibration_error_all_right():
37 |     n_bins = 90
38 |     ece = metrics_lib.expected_calibration_error(
39 |         np.ones(10), np.ones(10), n_bins=n_bins
40 |     )
41 |     assert ece == pytest.approx(0.0)
42 | 
43 |     ece = metrics_lib.expected_calibration_error(
44 |         np.zeros(10), np.zeros(10), n_bins=n_bins
45 |     )
46 |     assert ece == pytest.approx(0.0)
47 | 
--------------------------------------------------------------------------------
/tests/test_xg.py:
--------------------------------------------------------------------------------
1 | import collections.abc
2 | import os
3 | 
4 | from soccer_xg import xg
5 | 
6 | 
7 | class TestDefaults(object):
8 |     """Tests for defaults."""
9 | 
10 |     def test_column_descriptions_set(self, model):
11 |         assert isinstance(model.column_descriptions, collections.abc.Mapping)
12 | 
13 | 
14 | class TestModelTrain(object):
15 |     """Tests for the train_model method."""
16 | 
17 |     def test_api_input(self, model, api):
18 |         model.train(source_data=api, training_seasons=[('WC', '2018')])
19 | 
20 |     def test_dataframe_input(self, model, api):
21 |         features = xg.get_features(api)
22 |         labels = xg.get_labels(api)
23 |         df = features.assign(goal=labels)
24 |         model.train(source_data=df)
25 | 
26 | 
27 | class TestModelValidate(object):
28 |     """Tests for the validate_model method."""
29 | 
30 |     def test_api_input(self, model, api):
31 |         model.train(source_data=api, training_seasons=[('WC', '2018')])
32 |         model.validate(
33 |             source_data=api, validation_seasons=[('WC', '2018')], plot=False
34 |         )
35 | 
36 |     def test_dataframe_input(self, model, api):
37 |         features = xg.get_features(api)
38 |         labels = xg.get_labels(api)
39 |         df = features.assign(goal=labels)
40 |         model.train(source_data=df)
41 |         model.validate(source_data=df, plot=False)
42 | 
43 | 
44 | class TestModelIO(object):
45 |     """Tests functions that deal with model saving and loading."""
46 | 
47 |     def teardown_method(self, method):
48 |         try:
49 |             os.remove(self.expected_path)
50 |         except OSError:
51 |             pass
52 | 
53 |     def test_model_save_default(self, model):
54 |         model_name = 'test_hazard.xgmodel'
55 |         model._default_model_filename = model_name
56 | 
57 |         self.expected_path = os.path.join(
58 |             os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
59 |             'soccer_xg',
60 |             'models',
61 |             model_name,
62 |         )
63 |         assert os.path.isfile(self.expected_path) is False
64 | 
65 |         model.save_model()
66 |         assert os.path.isfile(self.expected_path) is True
67 | 
68 |     def test_model_save_specified(self, model):
69 |         model = xg.XGModel()
70 |         model_name = 'test_lukaku.xgmodel'
71 | 
72 |         self.expected_path = os.path.join(
73 |             os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
74 |             'soccer_xg',
75 |             'models',
76 |             model_name,
77 |         )
78 |         assert os.path.isfile(self.expected_path) is False
79 | 
80 |         model.save_model(filename=model_name)
81 |         assert os.path.isfile(self.expected_path) is True
82 | 
83 |     def test_model_load_default(self, model):
84 |         model_name = 'test_witsel.xgmodel'
85 | 
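        # repoint the default filename at a throw-away name so save_model()
        # does not overwrite one of the models bundled with the package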
model._default_model_filename = model_name 86 | 87 | self.expected_path = os.path.join( 88 | os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 89 | 'soccer_xg', 90 | 'models', 91 | model_name, 92 | ) 93 | assert os.path.isfile(self.expected_path) is False 94 | 95 | model.save_model() 96 | 97 | xGModel_class = xg.XGModel 98 | xGModel_class._default_model_filename = model_name 99 | 100 | loaded_model = xGModel_class.load_model() 101 | 102 | assert isinstance(loaded_model, xg.XGModel) 103 | 104 | def test_model_load_specified(self): 105 | model = xg.XGModel() 106 | model_name = 'test_kompany.xgmodel' 107 | 108 | self.expected_path = os.path.join( 109 | os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 110 | 'soccer_xg', 111 | 'models', 112 | model_name, 113 | ) 114 | assert os.path.isfile(self.expected_path) is False 115 | 116 | model.save_model(filename=model_name) 117 | 118 | loaded_model = xg.XGModel.load_model(filename=model_name) 119 | assert isinstance(loaded_model, xg.XGModel) 120 | 121 | 122 | def test_get_features(api): 123 | features = xg.get_features(api, game_ids=[7584]) 124 | assert len(features) == 40 # one row for each shot 125 | 126 | 127 | def test_get_labels(api): 128 | labels = xg.get_labels(api, game_ids=[7584]) 129 | assert len(labels) == 40 # one row for each shot 130 | assert labels.sum() == 5 131 | --------------------------------------------------------------------------------