├── .bumpversion.cfg
├── .github
│   └── ISSUE_TEMPLATE
│       ├── bug_report.md
│       └── feature_request.md
├── .gitignore
├── .travis.yml
├── LICENSE
├── Makefile
├── README.md
├── data
│   └── .gitkeep
├── images
│   └── hero.png
├── notebooks
│   ├── 1-load-and-convert-statsbomb-data.ipynb
│   ├── 1-load-and-convert-wyscout-data.ipynb
│   ├── 2-basic-usage.ipynb
│   ├── 3-computing-and-storing-features.ipynb
│   └── 4-creating-custom-xg-pipelines.ipynb
├── poetry.lock
├── pyproject.toml
├── setup.cfg
├── soccer_xg
│   ├── __init__.py
│   ├── api.py
│   ├── calibration.py
│   ├── features.py
│   ├── metrics.py
│   ├── ml
│   │   ├── logreg.py
│   │   ├── mlp.py
│   │   ├── pipeline.py
│   │   ├── preprocessing.py
│   │   ├── tree_based_LR.py
│   │   └── xgboost.py
│   ├── models
│   │   ├── openplay_logreg_advanced
│   │   ├── openplay_logreg_basic
│   │   ├── openplay_xgboost_advanced
│   │   └── openplay_xgboost_basic
│   ├── utils.py
│   ├── visualisation.py
│   └── xg.py
└── tests
    ├── conftest.py
    ├── data
    │   └── download.py
    ├── test_api.py
    ├── test_metrics.py
    └── test_xg.py
/.bumpversion.cfg:
--------------------------------------------------------------------------------
1 | [bumpversion]
2 | current_version = 0.0.1
3 | commit = True
4 | tag = True
5 |
6 | [bumpversion:file:pyproject.toml]
7 | search = version = "{current_version}"
8 | replace = version = "{new_version}"
9 |
10 | [bumpversion:file:soccer_xg/__init__.py]
11 | search = __version__ = '{current_version}'
12 | replace = __version__ = '{new_version}'
13 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: bug
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 | Steps to reproduce the behavior:
15 | 1. Run code
16 | ```python
17 | print("Hello world")
18 | ```
19 | 2. On [minimal data example](/link/to/data.ext)
20 | 3. See error
21 |
22 | **Expected behavior**
23 | A clear and concise description of what you expected to happen.
24 |
25 | **Additional context**
26 | Add any other context about the problem here.
27 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: enhancement
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Is your feature request related to a problem? Please describe.**
11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12 |
13 | **Describe the solution you'd like**
14 | A clear and concise description of what you want to happen.
15 |
16 | **Describe alternatives you've considered**
17 | A clear and concise description of any alternative solutions or features you've considered.
18 |
19 | **Additional context**
20 | Add any other context or screenshots about the feature request here.
21 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .envrc
2 | analysis/
3 | # exclude all data, but keep the folder
4 | data/*
5 | !data/.gitkeep
6 | tests/data/*
7 | !tests/data/download.py
8 | # exclude all models, unless explicitly allowed
9 | soccer_xg/models/*
10 | !soccer_xg/models/openplay_logreg_basic
11 | !soccer_xg/models/openplay_xgboost_basic
12 | !soccer_xg/models/openplay_logreg_advanced
13 | !soccer_xg/models/openplay_xgboost_advanced
14 |
15 |
16 | # Byte-compiled / optimized / DLL files
17 | __pycache__/
18 | *.py[cod]
19 | *$py.class
20 |
21 | # C extensions
22 | *.so
23 |
24 | # Distribution / packaging
25 | .Python
26 | build/
27 | develop-eggs/
28 | dist/
29 | downloads/
30 | eggs/
31 | .eggs/
32 | lib/
33 | lib64/
34 | parts/
35 | sdist/
36 | var/
37 | wheels/
38 | share/python-wheels/
39 | *.egg-info/
40 | .installed.cfg
41 | *.egg
42 | MANIFEST
43 |
44 | # PyInstaller
45 | # Usually these files are written by a python script from a template
46 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
47 | *.manifest
48 | *.spec
49 |
50 | # Installer logs
51 | pip-log.txt
52 | pip-delete-this-directory.txt
53 |
54 | # Unit test / coverage reports
55 | htmlcov/
56 | .tox/
57 | .nox/
58 | .coverage
59 | .coverage.*
60 | .cache
61 | nosetests.xml
62 | coverage.xml
63 | *.cover
64 | *.py,cover
65 | .hypothesis/
66 | .pytest_cache/
67 | cover/
68 |
69 | # Translations
70 | *.mo
71 | *.pot
72 |
73 | # Django stuff:
74 | *.log
75 | local_settings.py
76 | db.sqlite3
77 | db.sqlite3-journal
78 |
79 | # Flask stuff:
80 | instance/
81 | .webassets-cache
82 |
83 | # Scrapy stuff:
84 | .scrapy
85 |
86 | # Sphinx documentation
87 | docs/_build/
88 |
89 | # PyBuilder
90 | .pybuilder/
91 | target/
92 |
93 | # Jupyter Notebook
94 | .ipynb_checkpoints
95 |
96 | # IPython
97 | profile_default/
98 | ipython_config.py
99 |
100 | # pyenv
101 | # For a library or package, you might want to ignore these files since the code is
102 | # intended to run in multiple environments; otherwise, check them in:
103 | # .python-version
104 |
105 | # pipenv
106 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
107 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
108 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
109 | # install all needed dependencies.
110 | #Pipfile.lock
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 |
3 | cache: pip
4 |
5 | python:
6 | - "3.6"
7 |
8 | install:
9 | - pip install poetry codecov
10 | - poetry install
11 |
12 | script:
13 | - make test BIN=""
14 |
15 | after_success:
16 | - codecov
17 |
18 | before_deploy:
19 | - poetry build
20 |
21 | deploy:
22 | provider: script
23 | script: poetry publish --no-interaction --username=$PYPI_USER --password=$PYPI_PASS
24 | skip_cleanup: true
25 | on:
26 | branch: master
27 | tags: true
28 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | https://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | Copyright 2020 DTAI - KU Leuven
180 |
181 | Licensed under the Apache License, Version 2.0 (the "License");
182 | you may not use this file except in compliance with the License.
183 | You may obtain a copy of the License at
184 |
185 | https://www.apache.org/licenses/LICENSE-2.0
186 |
187 | Unless required by applicable law or agreed to in writing, software
188 | distributed under the License is distributed on an "AS IS" BASIS,
189 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
190 | See the License for the specific language governing permissions and
191 | limitations under the License.
192 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: init test lint pretty precommit_install bump_major bump_minor bump_patch clean
2 |
3 | BIN = .venv/bin/
4 | CODE = soccer_xg
5 |
6 | init:
7 | python3 -m venv .venv
8 | poetry install
9 |
10 | tests/data/spadl-statsbomb-WC-2018.h5:
11 | $(BIN)python tests/data/download.py
12 |
13 | test: tests/data/spadl-statsbomb-WC-2018.h5
14 | $(BIN)pytest --verbosity=2 --showlocals --strict --log-level=DEBUG --cov=$(CODE) $(args)
15 |
16 | lint:
17 | $(BIN)flake8 --jobs 4 --statistics --show-source $(CODE) tests
18 | $(BIN)pylint --jobs 4 --rcfile=setup.cfg $(CODE)
19 | $(BIN)mypy $(CODE) tests
20 | 	$(BIN)black --target-version py36 --skip-string-normalization --line-length=79 --check $(CODE) tests
21 | $(BIN)pytest --dead-fixtures --dup-fixtures
22 |
23 | pretty:
24 | $(BIN)isort --apply --recursive $(CODE) tests
25 | $(BIN)black --target-version py36 --skip-string-normalization --line-length=79 $(CODE) tests
26 | $(BIN)unify --in-place --recursive $(CODE) tests
27 |
28 | precommit_install:
29 | 	printf '#!/bin/sh\nmake test\n' > .git/hooks/pre-commit
30 | chmod +x .git/hooks/pre-commit
31 |
32 | bump_major:
33 | $(BIN)bumpversion major
34 |
35 | bump_minor:
36 | $(BIN)bumpversion minor
37 |
38 | bump_patch:
39 | $(BIN)bumpversion patch
40 |
41 | clean:
42 | find . -type f -name "*.py[co]" -delete
43 | find . -type d -name "__pycache__" -delete
44 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Soccer xG
2 | 
3 | A Python package for training and analyzing expected goals (xG) models in soccer.
4 | 
5 | 
6 | 
7 | 
8 | 
9 | 
10 | ## About
11 |
12 | This repository contains the code and models for our series on the analysis of xG models:
13 |
14 | - [How data availability affects the ability to learn good xG models](https://dtai.cs.kuleuven.be/sports/blog/how-data-availability-affects-the-ability-to-learn-good-xg-models)
15 | - [Illustrating the interplay between features and models in xG](https://dtai.cs.kuleuven.be/sports/blog/illustrating-the-interplay-between-features-and-models-in-xg)
16 | - [How data quality affects xG](https://dtai.cs.kuleuven.be/sports/blog/how-data-quality-affects-xg)
17 |
18 | In particular, it contains code for experimenting with an exhaustive set of features and machine learning pipelines for predicting xG values from soccer event stream data. Since we rely on the [SPADL](https://github.com/ML-KULeuven/socceraction) language as input format, `soccer_xg` currently supports event streams provided by Opta, Wyscout, and StatsBomb.
19 |
20 | ## Getting started
21 |
22 | The recommended way to install `soccer_xg` is to simply use pip:
23 |
24 | ```sh
25 | $ pip install soccer-xg
26 | ```
27 |
28 | Once installed, a basic xG model can be trained and applied with the code below:
29 |
30 | ```python
31 | from itertools import product
32 | from soccer_xg import XGModel, DataApi
33 |
34 | # load the data
35 | provider = 'wyscout_opensource'
36 | leagues = ['ENG', 'ESP', 'ITA', 'GER', 'FRA']
37 | seasons = ['1718']
38 | api = DataApi([f"data/{provider}/spadl-{provider}-{l}-{s}.h5"
39 | for (l,s) in product(leagues, seasons)])
40 | # load the default pipeline
41 | model = XGModel()
42 | # train the model
43 | model.train(api, training_seasons=[('ESP', '1718'), ('ITA', '1718'), ('GER', '1718')])
44 | # validate the model
45 | model.validate(api, validation_seasons=[('ENG', '1718')])
46 | # predict xG values
47 | model.estimate(api, game_ids=[2500098])
48 | ```
49 |
50 | Although this default pipeline is suitable for computing xG, it is by no means the best possible model.
51 | The notebook [`4-creating-custom-xg-pipelines`](./notebooks/4-creating-custom-xg-pipelines.ipynb) illustrates how you can train your own xG models. Alternatively, you can use one of the four pipelines from our blog post series. These can be loaded with:
52 |
53 | ```python
54 | XGModel.load_model('openplay_logreg_basic')
55 | XGModel.load_model('openplay_xgboost_basic')
56 | XGModel.load_model('openplay_logreg_advanced')
57 | XGModel.load_model('openplay_xgboost_advanced')
58 | ```
59 |
60 | Note that these models are meant to predict xG for open-play shots only. To compute xG values for all shot types, you will have to combine them with pipelines for penalties and free kicks.
61 |
62 | ```python
63 | from soccer_xg import xg
64 |
65 | openplay_model = xg.XGModel.load_model('openplay_xgboost_advanced') # custom pipeline for open play shots
66 | penalty_model = xg.PenaltyXGModel() # default pipeline for penalties
67 | freekick_model = xg.FreekickXGModel() # default pipeline for free kicks
68 |
69 | model = xg.XGModel()
70 | model.model = [openplay_model, penalty_model, freekick_model]
71 | model.train(api, training_seasons=...)
72 | ```
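
Once trained, the combined model can be used in exactly the same way as the default pipeline. A minimal usage sketch, reusing the `api` object and the `estimate` call from the quick-start example above:

```python
# Assumes `api` was constructed as in the quick-start example and that the
# combined model has been trained; estimate() returns the xG values per shot.
xg_values = model.estimate(api, game_ids=[2500098])
```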
73 |
74 | ## For developers
75 |
76 | **Create venv and install deps**
77 |
78 | make init
79 |
80 | **Install git precommit hook**
81 |
82 | make precommit_install
83 |
84 | **Run linters, autoformat, tests etc.**
85 |
86 | make pretty lint test
87 |
88 | **Bump new version**
89 |
90 | make bump_major
91 | make bump_minor
92 | make bump_patch
93 |
94 | ## Research
95 |
96 | If you make use of this package in your research, please use the following citation:
97 |
98 | ```
99 | @inproceedings{robberechts2020data,
100 | title={How data availability affects the ability to learn good xG models},
101 | author={Robberechts, Pieter and Davis, Jesse},
102 | booktitle={International Workshop on Machine Learning and Data Mining for Sports Analytics},
103 | pages={17--27},
104 | year={2020},
105 | organization={Springer}
106 | }
107 | ```
108 |
109 | ## License
110 |
111 | Copyright (c) DTAI - KU Leuven – All rights reserved.
112 | Licensed under the Apache License, Version 2.0
113 | Written by [Pieter Robberechts](https://people.cs.kuleuven.be/~pieter.robberechts/), 2020
114 |
--------------------------------------------------------------------------------
/data/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ML-KULeuven/soccer_xg/b9489d929e0fa34771267d429256366d0fda27ad/data/.gitkeep
--------------------------------------------------------------------------------
/images/hero.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ML-KULeuven/soccer_xg/b9489d929e0fa34771267d429256366d0fda27ad/images/hero.png
--------------------------------------------------------------------------------
/notebooks/1-load-and-convert-statsbomb-data.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Data Preparation\n",
8 | "\n",
9 | "This notebook loads the 2018 World Cup dataset provided by StatsBomb and converts it to the [SPADL format](https://github.com/ML-KULeuven/socceraction)."
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "**Disclaimer**: this notebook is compatible with the following package versions:\n",
17 | "\n",
18 | "- tqdm 4.42.1\n",
19 | "- pandas 1.0\n",
20 | "- socceraction 0.1.1"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": 1,
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "import os; import sys\n",
30 | "from tqdm.notebook import tqdm\n",
31 | "\n",
32 | "import math\n",
33 | "import pandas as pd\n",
34 | "\n",
35 | "import socceraction.spadl as spadl\n",
36 | "import socceraction.spadl.statsbomb as statsbomb"
37 | ]
38 | },
39 | {
40 | "cell_type": "markdown",
41 | "metadata": {},
42 | "source": [
43 | "## Configure leagues and seasons to download and convert\n",
44 | "The two dictionaries below map my internal season and league IDs to StatsBomb's IDs. Using an internal ID makes it easier to work with data from multiple providers."
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": 2,
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "seasons = {\n",
54 | " 3: '2018',\n",
55 | "}\n",
56 | "leagues = {\n",
57 | " 'FIFA World Cup': 'WC',\n",
58 | "}"
59 | ]
60 | },
61 | {
62 | "cell_type": "markdown",
63 | "metadata": {},
64 | "source": [
65 | "## Configure folder names and download URLs\n",
66 | "\n",
67 | "The two cells below define the URLs from which the data is downloaded and the folders where it is stored."
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": 3,
73 | "metadata": {},
74 | "outputs": [],
75 | "source": [
76 | "free_open_data_remote = \"https://raw.githubusercontent.com/statsbomb/open-data/master/data/\""
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": 4,
82 | "metadata": {},
83 | "outputs": [
84 | {
85 | "name": "stdout",
86 | "output_type": "stream",
87 | "text": [
88 | "Directory ../data/statsbomb_opensource/raw created \n"
89 | ]
90 | }
91 | ],
92 | "source": [
93 | "spadl_datafolder = \"../data/statsbomb_opensource\"\n",
94 | "raw_datafolder = f\"../data/statsbomb_opensource/raw\"\n",
95 | "\n",
96 | "# Create data folder if it doesn't exist\n",
97 | "for d in [raw_datafolder, spadl_datafolder]:\n",
98 | " if not os.path.exists(d):\n",
99 | " os.makedirs(d, exist_ok=True)\n",
100 | " print(f\"Directory {d} created \")"
101 | ]
102 | },
103 | {
104 | "cell_type": "markdown",
105 | "metadata": {},
106 | "source": [
107 | "## Set up the StatsBombLoader"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": 5,
113 | "metadata": {},
114 | "outputs": [],
115 | "source": [
116 | "SBL = statsbomb.StatsBombLoader(root=free_open_data_remote, getter=\"remote\")"
117 | ]
118 | },
119 | {
120 | "cell_type": "markdown",
121 | "metadata": {},
122 | "source": [
123 | "## Select competitions to load and convert"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": 6,
129 | "metadata": {},
130 | "outputs": [
131 | {
132 | "data": {
133 | "text/plain": [
134 | "{'Champions League',\n",
135 | " \"FA Women's Super League\",\n",
136 | " 'FIFA World Cup',\n",
137 | " 'La Liga',\n",
138 | " 'NWSL',\n",
139 | " 'Premier League',\n",
140 | " \"Women's World Cup\"}"
141 | ]
142 | },
143 | "execution_count": 6,
144 | "metadata": {},
145 | "output_type": "execute_result"
146 | }
147 | ],
148 | "source": [
149 | "# View all available competitions\n",
150 | "df_competitions = SBL.competitions()\n",
151 | "set(df_competitions.competition_name)"
152 | ]
153 | },
154 | {
155 | "cell_type": "code",
156 | "execution_count": 7,
157 | "metadata": {},
158 | "outputs": [
159 | {
160 | "data": {
206 | "text/plain": [
207 | " competition_id season_id country_name competition_name \\\n",
208 | "17 43 3 International FIFA World Cup \n",
209 | "\n",
210 | " competition_gender season_name match_updated \\\n",
211 | "17 male 2018 2019-12-16T23:09:16.168756 \n",
212 | "\n",
213 | " match_available \n",
214 | "17 2019-12-16T23:09:16.168756 "
215 | ]
216 | },
217 | "execution_count": 7,
218 | "metadata": {},
219 | "output_type": "execute_result"
220 | }
221 | ],
222 | "source": [
223 | "df_selected_competitions = df_competitions[df_competitions.competition_name.isin(\n",
224 | " leagues.keys()\n",
225 | ")]\n",
226 | "\n",
227 | "df_selected_competitions"
228 | ]
229 | },
230 | {
231 | "cell_type": "markdown",
232 | "metadata": {},
233 | "source": [
234 | "## Convert to the SPADL format"
235 | ]
236 | },
237 | {
238 | "cell_type": "code",
239 | "execution_count": 8,
240 | "metadata": {
241 | "scrolled": true
242 | },
243 | "outputs": [
244 | {
245 | "data": {
246 | "application/vnd.jupyter.widget-view+json": {
247 | "model_id": "ea0aaba0b07640538fcea5bc70a41681",
248 | "version_major": 2,
249 | "version_minor": 0
250 | },
251 | "text/plain": [
252 | "HBox(children=(FloatProgress(value=0.0, description='Loading match data', max=64.0, style=ProgressStyle(descri…"
253 | ]
254 | },
255 | "metadata": {},
256 | "output_type": "display_data"
257 | },
258 | {
259 | "name": "stdout",
260 | "output_type": "stream",
261 | "text": [
262 | "\n"
263 | ]
264 | },
265 | {
266 | "name": "stderr",
267 | "output_type": "stream",
268 | "text": [
269 | "/home/pieterr/Jupiter/Projects/soccer_dataprovider_comparison/.venv/lib/python3.6/site-packages/IPython/core/interactiveshell.py:3331: PerformanceWarning: \n",
270 | "your performance may suffer as PyTables will pickle object types that it cannot\n",
271 | "map directly to c-types [inferred_type->mixed,key->block1_values] [items->Index(['game_date', 'kick_off', 'competition_id', 'country_name',\n",
272 | " 'competition_name', 'season_id', 'season_name', 'home_team_name',\n",
273 | " 'home_team_gender', 'home_team_group', 'name', 'managers',\n",
274 | " 'away_team_name', 'away_team_gender', 'away_team_group', 'match_status',\n",
275 | " 'last_updated', 'data_version'],\n",
276 | " dtype='object')]\n",
277 | "\n",
278 | " exec(code_obj, self.user_global_ns, self.user_ns)\n",
279 | "/home/pieterr/Jupiter/Projects/soccer_dataprovider_comparison/.venv/lib/python3.6/site-packages/IPython/core/interactiveshell.py:3331: PerformanceWarning: \n",
280 | "your performance may suffer as PyTables will pickle object types that it cannot\n",
281 | "map directly to c-types [inferred_type->mixed,key->block1_values] [items->Index(['player_name', 'player_nickname', 'country_name', 'extra'], dtype='object')]\n",
282 | "\n",
283 | " exec(code_obj, self.user_global_ns, self.user_ns)\n",
284 | "/home/pieterr/Jupiter/Projects/soccer_dataprovider_comparison/.venv/lib/python3.6/site-packages/IPython/core/interactiveshell.py:3331: PerformanceWarning: \n",
285 | "your performance may suffer as PyTables will pickle object types that it cannot\n",
286 | "map directly to c-types [inferred_type->mixed-integer,key->block2_values] [items->Index(['player_name', 'position_name', 'extra', 'team_name'], dtype='object')]\n",
287 | "\n",
288 | " exec(code_obj, self.user_global_ns, self.user_ns)\n"
289 | ]
290 | }
291 | ],
292 | "source": [
293 | "for competition in df_selected_competitions.itertuples():\n",
294 | "    # Get matches from the selected competition\n",
295 | " matches = SBL.matches(competition.competition_id, competition.season_id)\n",
296 | "\n",
297 | " matches_verbose = tqdm(list(matches.itertuples()), desc=\"Loading match data\")\n",
298 | " teams, players, player_games = [], [], []\n",
299 | " \n",
300 | " competition_id = leagues[competition.competition_name]\n",
301 | " season_id = seasons[competition.season_id]\n",
302 | " spadl_h5 = os.path.join(spadl_datafolder, f\"spadl-statsbomb_opensource-{competition_id}-{season_id}.h5\")\n",
303 | " with pd.HDFStore(spadl_h5) as spadlstore:\n",
304 | " \n",
305 | " spadlstore[\"actiontypes\"] = spadl.actiontypes_df()\n",
306 | " spadlstore[\"results\"] = spadl.results_df()\n",
307 | " spadlstore[\"bodyparts\"] = spadl.bodyparts_df()\n",
308 | " \n",
309 | " for match in matches_verbose:\n",
310 | " # load data\n",
311 | " teams.append(SBL.teams(match.match_id))\n",
312 | " players.append(SBL.players(match.match_id))\n",
313 | " events = SBL.events(match.match_id)\n",
314 | "\n",
315 | " # convert data\n",
316 | " player_games.append(statsbomb.extract_player_games(events))\n",
317 | " spadlstore[f\"actions/game_{match.match_id}\"] = statsbomb.convert_to_actions(events,match.home_team_id)\n",
318 | "\n",
319 | " games = matches.rename(columns={\"match_id\": \"game_id\", \"match_date\": \"game_date\"})\n",
320 | " games.season_id = season_id\n",
321 | " games.competition_id = competition_id\n",
322 | " spadlstore[\"games\"] = games\n",
323 | " spadlstore[\"teams\"] = pd.concat(teams).drop_duplicates(\"team_id\").reset_index(drop=True)\n",
324 | " spadlstore[\"players\"] = pd.concat(players).drop_duplicates(\"player_id\").reset_index(drop=True)\n",
325 | " spadlstore[\"player_games\"] = pd.concat(player_games).reset_index(drop=True)"
326 | ]
327 | },
328 | {
329 | "cell_type": "code",
330 | "execution_count": null,
331 | "metadata": {},
332 | "outputs": [],
333 | "source": []
334 | }
335 | ],
336 | "metadata": {
337 | "kernelspec": {
338 | "display_name": "soccer_dataprovider_comparison",
339 | "language": "python",
340 | "name": "soccer_dataprovider_comparison"
341 | },
342 | "language_info": {
343 | "codemirror_mode": {
344 | "name": "ipython",
345 | "version": 3
346 | },
347 | "file_extension": ".py",
348 | "mimetype": "text/x-python",
349 | "name": "python",
350 | "nbconvert_exporter": "python",
351 | "pygments_lexer": "ipython3",
352 | "version": "3.6.2"
353 | },
354 | "toc": {
355 | "base_numbering": 1,
356 | "nav_menu": {},
357 | "number_sections": true,
358 | "sideBar": true,
359 | "skip_h1_title": false,
360 | "title_cell": "Table of Contents",
361 | "title_sidebar": "Contents",
362 | "toc_cell": false,
363 | "toc_position": {},
364 | "toc_section_display": true,
365 | "toc_window_display": true
366 | }
367 | },
368 | "nbformat": 4,
369 | "nbformat_minor": 2
370 | }
371 |
--------------------------------------------------------------------------------
/notebooks/1-load-and-convert-wyscout-data.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Data Preparation\n",
8 | "\n",
9 | "This notebook downloads the open-source [Wyscout match event dataset](https://figshare.com/collections/Soccer_match_event_dataset/4415000/2) and converts it to the [SPADL format](https://github.com/ML-KULeuven/socceraction). This dataset contains all spatio-temporal events (passes, shots, fouls, etc.) that occurred during all matches of the 2017/18 season of the top-5 European leagues (La Liga, Serie A, Bundesliga, Premier League, Ligue 1) as well as the 2018 FIFA World Cup and UEFA Euro 2016."
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "**Disclaimer**: this notebook is compatible with [v5 of the Soccer match event dataset](https://figshare.com/collections/Soccer_match_event_dataset/4415000/5) and the following package versions:\n",
17 | "\n",
18 | "- tqdm 4.42.1\n",
19 | "- pandas 1.0\n",
20 | "- socceraction 0.1.1"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": 11,
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "import os\n",
30 | "import sys\n",
31 | "\n",
32 | "from tqdm.notebook import tqdm\n",
33 | "\n",
34 | "import math\n",
35 | "\n",
36 | "import pandas as pd\n",
37 | "pd.set_option('display.max_columns', None)\n",
38 | "\n",
39 | "from io import BytesIO\n",
40 | "from pathlib import Path\n",
41 | "\n",
42 | "from urllib.parse import urlparse\n",
43 | "from urllib.request import urlopen, urlretrieve\n",
44 | "# optional: if you get a SSL CERTIFICATE_VERIFY_FAILED exception\n",
45 | "import ssl; ssl._create_default_https_context = ssl._create_unverified_context\n",
46 | "\n",
47 | "from zipfile import ZipFile, is_zipfile\n",
48 | "\n",
49 | "import socceraction.spadl as spadl\n",
50 | "import socceraction.spadl.wyscout as wyscout"
51 | ]
52 | },
53 | {
54 | "cell_type": "markdown",
55 | "metadata": {},
56 | "source": [
57 | "## Configure leagues and seasons to download and convert\n",
58 | "The two dictionaries below map my internal season and league IDs to Wyscout's IDs. Using an internal ID makes it easier to work with data from multiple providers."
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": 12,
64 | "metadata": {},
65 | "outputs": [],
66 | "source": [
67 | "seasons = {\n",
68 | " 181248: '1718',\n",
69 | " 181150: '1718',\n",
70 | " 181144: '1718',\n",
71 | " 181189: '1718',\n",
72 | " 181137: '1718'\n",
73 | "}\n",
74 | "leagues = {\n",
75 | " 'England':'ENG',\n",
76 | " 'France':'FRA',\n",
77 | " 'Germany':'GER',\n",
78 | " 'Italy':'ITA',\n",
79 | " 'Spain':'ESP'\n",
80 | "}"
81 | ]
82 | },
83 | {
84 | "cell_type": "markdown",
85 | "metadata": {},
86 | "source": [
87 | "## Configure folder names and download URLs\n",
88 | "\n",
89 | "The two cells below define the URLs from which the data is downloaded and the folders where it is stored."
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": 13,
95 | "metadata": {},
96 | "outputs": [],
97 | "source": [
98 | "# https://figshare.com/collections/Soccer_match_event_dataset/4415000/5\n",
99 | "dataset_urls = dict(\n",
100 | " competitions = \"https://ndownloader.figshare.com/files/15073685\",\n",
101 | " teams = \"https://ndownloader.figshare.com/files/15073697\",\n",
102 | " players = \"https://ndownloader.figshare.com/files/15073721\",\n",
103 | " matches = \"https://ndownloader.figshare.com/files/14464622\",\n",
104 | " events = \"https://ndownloader.figshare.com/files/14464685\"\n",
105 | ")"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": 14,
111 | "metadata": {},
112 | "outputs": [],
113 | "source": [
114 | "raw_datafolder = \"../data/wyscout_opensource/raw\"\n",
115 | "spadl_datafolder = \"../data/wyscout_opensource\"\n",
116 | "\n",
117 | "# Create data folder if it doesn't exist\n",
118 | "for d in [raw_datafolder, spadl_datafolder]:\n",
119 | " if not os.path.exists(d):\n",
120 | " os.makedirs(d, exist_ok=True)\n",
121 | " print(f\"Directory {d} created \")"
122 | ]
123 | },
124 | {
125 | "cell_type": "markdown",
126 | "metadata": {},
127 | "source": [
128 | "## Download Wyscout data\n",
129 | "\n",
130 | "The following cell loops through the `dataset_urls` dict and stores each downloaded data file in the `raw_datafolder` on the local file system.\n",
131 | "\n",
132 | "If the downloaded data file is a ZIP archive, the included JSON files are extracted from the ZIP archive."
133 | ]
134 | },
135 | {
136 | "cell_type": "code",
137 | "execution_count": 15,
138 | "metadata": {},
139 | "outputs": [
140 | {
141 | "data": {
142 | "application/vnd.jupyter.widget-view+json": {
143 | "model_id": "ff51207484174bffbc24d14fde05ca2b",
144 | "version_major": 2,
145 | "version_minor": 0
146 | },
147 | "text/plain": [
148 | "HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))"
149 | ]
150 | },
151 | "metadata": {},
152 | "output_type": "display_data"
153 | },
154 | {
155 | "name": "stdout",
156 | "output_type": "stream",
157 | "text": [
158 | "\n",
159 | "Downloaded files:\n"
160 | ]
161 | },
162 | {
163 | "data": {
164 | "text/plain": [
165 | "['events_France.json',\n",
166 | " 'events_England.json',\n",
167 | " 'events.zip',\n",
168 | " 'eventid2name.csv',\n",
169 | " 'coaches.json',\n",
170 | " 'tags2name.csv',\n",
171 | " 'competitions.json',\n",
172 | " 'referees.json',\n",
173 | " 'matches_Germany.json',\n",
174 | " 'events_Italy.json',\n",
175 | " 'matches.zip',\n",
176 | " 'matches_Italy.json',\n",
177 | " 'events_European_Championship.json',\n",
178 | " 'teams.json',\n",
179 | " 'matches_France.json',\n",
180 | " 'events_Germany.json',\n",
181 | " 'events_Spain.json',\n",
182 | " 'events_World_Cup.json',\n",
183 | " 'players.json',\n",
184 | " 'matches_World_Cup.json',\n",
185 | " 'matches_European_Championship.json',\n",
186 | " 'matches_England.json',\n",
187 | " 'matches_Spain.json']"
188 | ]
189 | },
190 | "execution_count": 15,
191 | "metadata": {},
192 | "output_type": "execute_result"
193 | }
194 | ],
195 | "source": [
196 | "for url in tqdm(dataset_urls.values()):\n",
197 | " url_obj = urlopen(url).geturl()\n",
198 | " path = Path(urlparse(url_obj).path)\n",
199 | " file_name = os.path.join(raw_datafolder, path.name)\n",
200 | " file_local, _ = urlretrieve(url_obj, file_name)\n",
201 | " if is_zipfile(file_local):\n",
202 | " with ZipFile(file_local) as zip_file:\n",
203 | " zip_file.extractall(raw_datafolder)\n",
204 | "\n",
205 | "print(\"Downloaded files:\")\n",
206 | "os.listdir(raw_datafolder)"
207 | ]
208 | },
209 | {
210 | "cell_type": "markdown",
211 | "metadata": {},
212 | "source": [
213 | "## Preprocess Wyscout data"
214 | ]
215 | },
216 | {
217 | "cell_type": "markdown",
218 | "metadata": {},
219 | "source": [
220 | "The read_json_file function reads and returns the content of a given JSON file. The function handles the encoding of special characters (e.g., accents in names of players and teams) that the pd.read_json function cannot handle properly."
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": 16,
226 | "metadata": {},
227 | "outputs": [],
228 | "source": [
229 | "def read_json_file(filename):\n",
230 | " with open(filename, 'rb') as json_file:\n",
231 | " return BytesIO(json_file.read()).getvalue().decode('unicode_escape')"
232 | ]
233 | },
234 | {
235 | "cell_type": "markdown",
236 | "metadata": {},
237 | "source": [
238 | "Wyscout does not distinguish between headers and other body\n",
239 | "parts on shots. The socceraction converter simply labels all\n",
240 | "shots as performed by foot. I think it is better to label\n",
241 | "shots tagged with the head/body flag as headers."
242 | ]
243 | },
244 | {
245 | "cell_type": "code",
246 | "execution_count": 17,
247 | "metadata": {},
248 | "outputs": [],
249 | "source": [
250 | "def determine_bodypart_id(event):\n",
251 | " \"\"\"\n",
252 | " This function determines the body part used for an event\n",
253 | " Args:\n",
254 | " event (pd.Series): Wyscout event Series\n",
255 | " Returns:\n",
256 | " int: id of the body part used for the action\n",
257 | " \"\"\"\n",
258 | " if event[\"subtype_id\"] in [81, 36, 21, 90, 91]:\n",
259 | " body_part = \"other\"\n",
260 | " elif event[\"subtype_id\"] == 82 or event['head/body']:\n",
261 | " body_part = \"head\"\n",
262 | " else: # all other cases\n",
263 | " body_part = \"foot\"\n",
264 | " return spadl.config.bodyparts.index(body_part)\n",
265 | "wyscout.determine_bodypart_id = determine_bodypart_id"
266 | ]
267 | },
268 | {
269 | "cell_type": "markdown",
270 | "metadata": {},
271 | "source": [
272 | "### Select competitions to load and convert"
273 | ]
274 | },
275 | {
276 | "cell_type": "code",
277 | "execution_count": 18,
278 | "metadata": {},
279 | "outputs": [
280 | {
281 | "data": {
282 | "text/plain": [
283 | "{'England',\n",
284 | " 'European Championship',\n",
285 | " 'France',\n",
286 | " 'Germany',\n",
287 | " 'Italy',\n",
288 | " 'Spain',\n",
289 | " 'World Cup'}"
290 | ]
291 | },
292 | "execution_count": 18,
293 | "metadata": {},
294 | "output_type": "execute_result"
295 | }
296 | ],
297 | "source": [
298 | "json_competitions = read_json_file(f\"{raw_datafolder}/competitions.json\")\n",
299 | "df_competitions = pd.read_json(json_competitions)\n",
300 | "# Rename competitions to the names used in the file names\n",
301 | "df_competitions['name'] = df_competitions.apply(lambda x: x.area['name'] if x.area['name'] != \"\" else x['name'], axis=1)\n",
302 | "df_competitions['id'] = df_competitions.apply(lambda x: leagues.get(x.area['name'], 'NULL'), axis=1)\n",
303 | "# View all available competitions\n",
304 | "set(df_competitions.name)"
305 | ]
306 | },
307 | {
308 | "cell_type": "code",
309 | "execution_count": 19,
310 | "metadata": {},
311 | "outputs": [
312 | {
313 | "data": {
391 | "text/plain": [
392 | " name wyId format \\\n",
393 | "0 Italy 524 Domestic league \n",
394 | "1 England 364 Domestic league \n",
395 | "2 Spain 795 Domestic league \n",
396 | "3 France 412 Domestic league \n",
397 | "4 Germany 426 Domestic league \n",
398 | "\n",
399 | " area type id \n",
400 | "0 {'name': 'Italy', 'id': '380', 'alpha3code': '... club ITA \n",
401 | "1 {'name': 'England', 'id': '0', 'alpha3code': '... club ENG \n",
402 | "2 {'name': 'Spain', 'id': '724', 'alpha3code': '... club ESP \n",
403 | "3 {'name': 'France', 'id': '250', 'alpha3code': ... club FRA \n",
404 | "4 {'name': 'Germany', 'id': '276', 'alpha3code':... club GER "
405 | ]
406 | },
407 | "execution_count": 19,
408 | "metadata": {},
409 | "output_type": "execute_result"
410 | }
411 | ],
412 | "source": [
413 | "df_selected_competitions = df_competitions[df_competitions.name.isin(\n",
414 | " leagues.keys()\n",
415 | ")]\n",
416 | "df_selected_competitions"
417 | ]
418 | },
419 | {
420 | "cell_type": "markdown",
421 | "metadata": {},
422 | "source": [
423 | "## Convert to the SPADL format"
424 | ]
425 | },
426 | {
427 | "cell_type": "code",
428 | "execution_count": 22,
429 | "metadata": {},
430 | "outputs": [
431 | {
432 | "name": "stdout",
433 | "output_type": "stream",
434 | "text": [
435 | "Converting ITA 1718\n"
436 | ]
437 | },
438 | {
439 | "data": {
440 | "application/vnd.jupyter.widget-view+json": {
441 | "model_id": "1ab744a7c2ae49959e710b309b62cb1a",
442 | "version_major": 2,
443 | "version_minor": 0
444 | },
445 | "text/plain": [
446 | "HBox(children=(FloatProgress(value=0.0, max=380.0), HTML(value='')))"
447 | ]
448 | },
449 | "metadata": {},
450 | "output_type": "display_data"
451 | },
452 | {
453 | "name": "stdout",
454 | "output_type": "stream",
455 | "text": [
456 | "\n",
457 | "Converting ENG 1718\n"
458 | ]
459 | },
460 | {
461 | "data": {
462 | "application/vnd.jupyter.widget-view+json": {
463 | "model_id": "50692b69f4da415397980283ea063570",
464 | "version_major": 2,
465 | "version_minor": 0
466 | },
467 | "text/plain": [
468 | "HBox(children=(FloatProgress(value=0.0, max=380.0), HTML(value='')))"
469 | ]
470 | },
471 | "metadata": {},
472 | "output_type": "display_data"
473 | },
474 | {
475 | "name": "stdout",
476 | "output_type": "stream",
477 | "text": [
478 | "\n",
479 | "Converting ESP 1718\n"
480 | ]
481 | },
482 | {
483 | "data": {
484 | "application/vnd.jupyter.widget-view+json": {
485 | "model_id": "f8357efe4bb34d7981647a9dea4acb43",
486 | "version_major": 2,
487 | "version_minor": 0
488 | },
489 | "text/plain": [
490 | "HBox(children=(FloatProgress(value=0.0, max=380.0), HTML(value='')))"
491 | ]
492 | },
493 | "metadata": {},
494 | "output_type": "display_data"
495 | },
496 | {
497 | "name": "stdout",
498 | "output_type": "stream",
499 | "text": [
500 | "\n",
501 | "Converting FRA 1718\n"
502 | ]
503 | },
504 | {
505 | "data": {
506 | "application/vnd.jupyter.widget-view+json": {
507 | "model_id": "758211fbdae845cab30001d48dba442c",
508 | "version_major": 2,
509 | "version_minor": 0
510 | },
511 | "text/plain": [
512 | "HBox(children=(FloatProgress(value=0.0, max=380.0), HTML(value='')))"
513 | ]
514 | },
515 | "metadata": {},
516 | "output_type": "display_data"
517 | },
518 | {
519 | "name": "stdout",
520 | "output_type": "stream",
521 | "text": [
522 | "\n",
523 | "Converting GER 1718\n"
524 | ]
525 | },
526 | {
527 | "data": {
528 | "application/vnd.jupyter.widget-view+json": {
529 | "model_id": "f4594b50fcd948d9a371a3f8fa47dd65",
530 | "version_major": 2,
531 | "version_minor": 0
532 | },
533 | "text/plain": [
534 | "HBox(children=(FloatProgress(value=0.0, max=306.0), HTML(value='')))"
535 | ]
536 | },
537 | "metadata": {},
538 | "output_type": "display_data"
539 | },
540 | {
541 | "name": "stdout",
542 | "output_type": "stream",
543 | "text": [
544 | "\n"
545 | ]
546 | }
547 | ],
548 | "source": [
549 | "json_teams = read_json_file(f\"{raw_datafolder}/teams.json\")\n",
550 | "df_teams = wyscout.convert_teams(pd.read_json(json_teams))\n",
551 | "\n",
552 | "json_players = read_json_file(f\"{raw_datafolder}/players.json\")\n",
553 | "df_players = wyscout.convert_players(pd.read_json(json_players))\n",
554 | "\n",
555 | "\n",
556 | "for competition in df_selected_competitions.itertuples():\n",
557 | " json_matches = read_json_file(f\"{raw_datafolder}/matches_{competition.name}.json\")\n",
558 | " df_matches = pd.read_json(json_matches)\n",
559 | " season_id = seasons[df_matches.seasonId.unique()[0]]\n",
560 | " df_games = wyscout.convert_games(df_matches)\n",
561 | " df_games['competition_id'] = competition.id\n",
562 | " df_games['season_id'] = season_id\n",
563 | " \n",
564 | " json_events = read_json_file(f\"{raw_datafolder}/events_{competition.name}.json\")\n",
565 | " df_events = pd.read_json(json_events).groupby('matchId', as_index=False)\n",
566 | " \n",
567 | " player_games = []\n",
568 | " \n",
569 | " spadl_h5 = os.path.join(spadl_datafolder, f\"spadl-wyscout_opensource-{competition.id}-{season_id}.h5\")\n",
570 | "\n",
571 | " # Store all spadl data in h5-file\n",
572 | " print(f\"Converting {competition.id} {season_id}\")\n",
573 | " with pd.HDFStore(spadl_h5) as spadlstore:\n",
574 | " \n",
575 | " spadlstore[\"actiontypes\"] = spadl.actiontypes_df()\n",
576 | " spadlstore[\"results\"] = spadl.results_df()\n",
577 | " spadlstore[\"bodyparts\"] = spadl.bodyparts_df()\n",
578 | " spadlstore[\"games\"] = df_games\n",
579 | "\n",
580 | " for game in tqdm(list(df_games.itertuples())):\n",
581 | " game_id = game.game_id\n",
582 | " game_events = df_events.get_group(game_id)\n",
583 | "\n",
584 | " # filter the players that were lined up in this season\n",
585 | " player_games.append(wyscout.get_player_games(df_matches[df_matches.wyId == game_id].iloc[0], game_events))\n",
586 | "\n",
587 | " # convert events to SPADL actions\n",
588 | " home_team = game.home_team_id\n",
589 | " df_actions = wyscout.convert_actions(game_events, home_team)\n",
590 | " df_actions[\"action_id\"] = range(len(df_actions))\n",
591 | " spadlstore[f\"actions/game_{game_id}\"] = df_actions\n",
592 | "\n",
593 | " player_games = pd.concat(player_games).reset_index(drop=True) \n",
594 | " spadlstore[\"player_games\"] = player_games\n",
595 | " spadlstore[\"players\"] = df_players[df_players.player_id.isin(player_games.player_id)]\n",
596 | " spadlstore[\"teams\"] = df_teams[df_teams.team_id.isin(df_games.home_team_id) | df_teams.team_id.isin(df_games.away_team_id)]"
597 | ]
598 | },
599 | {
600 | "cell_type": "code",
601 | "execution_count": null,
602 | "metadata": {},
603 | "outputs": [],
604 | "source": []
605 | }
606 | ],
607 | "metadata": {
608 | "kernelspec": {
609 | "display_name": "soccer_dataprovider_comparison",
610 | "language": "python",
611 | "name": "soccer_dataprovider_comparison"
612 | },
613 | "language_info": {
614 | "codemirror_mode": {
615 | "name": "ipython",
616 | "version": 3
617 | },
618 | "file_extension": ".py",
619 | "mimetype": "text/x-python",
620 | "name": "python",
621 | "nbconvert_exporter": "python",
622 | "pygments_lexer": "ipython3",
623 | "version": "3.6.2"
624 | },
625 | "toc": {
626 | "base_numbering": 1,
627 | "nav_menu": {},
628 | "number_sections": true,
629 | "sideBar": true,
630 | "skip_h1_title": false,
631 | "title_cell": "Table of Contents",
632 | "title_sidebar": "Contents",
633 | "toc_cell": false,
634 | "toc_position": {},
635 | "toc_section_display": true,
636 | "toc_window_display": true
637 | }
638 | },
639 | "nbformat": 4,
640 | "nbformat_minor": 2
641 | }
642 |
--------------------------------------------------------------------------------
/notebooks/3-computing-and-storing-features.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Features\n",
8 | "This notebook generates features and labels (goal/no goal) for all shots and stores them in an HDF file. Precomputing them this way saves computational time when you want to experiment with multiple pipelines."
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 1,
14 | "metadata": {},
15 | "outputs": [],
16 | "source": [
17 | "import pandas as pd\n",
18 | "pd.set_option('display.max_columns', None)\n",
19 | "import numpy as np\n",
20 | "import itertools"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": 2,
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "%load_ext autoreload\n",
30 | "%autoreload 2\n",
31 | "from soccer_xg.api import DataApi\n",
32 | "import soccer_xg.xg as xg\n",
33 | "import soccer_xg.features as fs"
34 | ]
35 | },
36 | {
37 | "cell_type": "markdown",
38 | "metadata": {},
39 | "source": [
40 | "## Config"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": 3,
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "# dataset\n",
50 | "dir_data = \"../data\"\n",
51 | "provider = 'wyscout_opensource'\n",
52 | "leagues = ['ENG', 'ESP', 'ITA', 'GER', 'FRA']\n",
53 | "seasons = ['1718']\n",
54 | "\n",
55 | "# features\n",
56 | "store_features = f'../data/{provider}/features.h5'"
57 | ]
58 | },
59 | {
60 | "cell_type": "markdown",
61 | "metadata": {},
62 | "source": [
63 | "By default, all features defined in `soccer_xg.features.all_features` are computed. It is also possible to compute only a subset of these features or to add additional feature generators; a minimal sketch of a custom generator is shown below. Each feature generator is a function that expects either a DataFrame of individual actions or a list of DataFrames of consecutive actions (i.e., game states), and returns the corresponding features for each action or game state. Features that contain information about the shot's outcome are automatically removed."
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": 4,
69 | "metadata": {},
70 | "outputs": [],
71 | "source": [
72 | "feature_generators = fs.all_features"
73 | ]
74 | },
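75 |   {
76 |    "cell_type": "markdown",
77 |    "metadata": {},
78 |    "source": [
79 |     "As a minimal sketch of the generator interface, the hypothetical `dist_to_center` feature below illustrates how a custom generator could be written. The `@simple` decorator from `socceraction.vaep.features` lifts a function defined on individual actions to one defined on game states."
80 |    ]
81 |   },
82 |   {
83 |    "cell_type": "code",
84 |    "execution_count": null,
85 |    "metadata": {},
86 |    "outputs": [],
87 |    "source": [
88 |     "from socceraction.vaep.features import simple\n",
89 |     "\n",
90 |     "@simple\n",
91 |     "def dist_to_center(actions):\n",
92 |     "    # absolute distance to the pitch's long axis (y = 34m in SPADL coordinates)\n",
93 |     "    return pd.DataFrame({'dist_to_center': (actions['start_y'] - 34).abs()})\n",
94 |     "\n",
95 |     "# feature_generators = fs.all_features + [dist_to_center]"
96 |    ]
97 |   },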
75 | {
76 | "cell_type": "markdown",
77 | "metadata": {},
78 | "source": [
79 | "## Compute features and labels"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": 5,
85 | "metadata": {},
86 | "outputs": [
87 | {
88 | "name": "stdout",
89 | "output_type": "stream",
90 | "text": [
91 | "ENG 1718\n"
92 | ]
93 | },
94 | {
95 | "name": "stderr",
96 | "output_type": "stream",
97 | "text": [
98 | "Generating features: 100%|██████████| 380/380 [01:40<00:00, 3.78it/s]\n",
99 | "Generating labels: 100%|██████████| 380/380 [00:08<00:00, 43.51it/s]\n"
100 | ]
101 | },
102 | {
103 | "name": "stdout",
104 | "output_type": "stream",
105 | "text": [
106 | "ESP 1718\n"
107 | ]
108 | },
109 | {
110 | "name": "stderr",
111 | "output_type": "stream",
112 | "text": [
113 | "Generating features: 100%|██████████| 380/380 [01:40<00:00, 3.79it/s]\n",
114 | "Generating labels: 100%|██████████| 380/380 [00:08<00:00, 43.70it/s]\n"
115 | ]
116 | },
117 | {
118 | "name": "stdout",
119 | "output_type": "stream",
120 | "text": [
121 | "ITA 1718\n"
122 | ]
123 | },
124 | {
125 | "name": "stderr",
126 | "output_type": "stream",
127 | "text": [
128 | "Generating features: 100%|██████████| 380/380 [01:40<00:00, 3.77it/s]\n",
129 | "Generating labels: 100%|██████████| 380/380 [00:08<00:00, 43.62it/s]\n"
130 | ]
131 | },
132 | {
133 | "name": "stdout",
134 | "output_type": "stream",
135 | "text": [
136 | "GER 1718\n"
137 | ]
138 | },
139 | {
140 | "name": "stderr",
141 | "output_type": "stream",
142 | "text": [
143 | "Generating features: 100%|██████████| 306/306 [01:21<00:00, 3.77it/s]\n",
144 | "Generating labels: 100%|██████████| 306/306 [00:06<00:00, 43.93it/s]\n"
145 | ]
146 | },
147 | {
148 | "name": "stdout",
149 | "output_type": "stream",
150 | "text": [
151 | "FRA 1718\n"
152 | ]
153 | },
154 | {
155 | "name": "stderr",
156 | "output_type": "stream",
157 | "text": [
158 | "Generating features: 100%|██████████| 380/380 [01:40<00:00, 3.78it/s]\n",
159 | "Generating labels: 100%|██████████| 380/380 [00:08<00:00, 43.76it/s]\n"
160 | ]
161 | }
162 | ],
163 | "source": [
164 | "for (l,s) in itertools.product(leagues, seasons):\n",
165 | " print(l, s)\n",
166 | " api = DataApi([f\"{dir_data}/{provider}/spadl-{provider}-{l}-{s}.h5\"])\n",
167 | " xg.get_features(api, xfns=feature_generators).to_hdf(store_features, key=f'{l}/{s}/features', format='table') \n",
168 | " xg.get_labels(api).to_hdf(store_features, key=f'{l}/{s}/labels', format='table') "
169 | ]
170 | },
171 | {
172 | "cell_type": "markdown",
173 | "metadata": {},
174 | "source": [
175 | "## Load features"
176 | ]
177 | },
178 | {
179 | "cell_type": "code",
180 | "execution_count": 6,
181 | "metadata": {},
182 | "outputs": [
183 | {
184 | "data": {
645 | "text/plain": [
646 | " type_id_a0 type_id_a1 type_id_a2 bodypart_id_a0 \\\n",
647 | "game_id action_id \n",
648 | "2500098 17 shot dribble cross foot \n",
649 | " 40 shot corner_crossed pass foot \n",
650 | " 77 shot clearance cross foot \n",
651 | " 140 shot cross dribble foot \n",
652 | " 145 shot pass pass foot \n",
653 | "\n",
654 | " bodypart_id_a1 bodypart_id_a2 result_id_a1 result_id_a2 \\\n",
655 | "game_id action_id \n",
656 | "2500098 17 foot foot success success \n",
657 | " 40 foot foot success fail \n",
658 | " 77 foot foot fail fail \n",
659 | " 140 foot foot success success \n",
660 | " 145 foot foot success success \n",
661 | "\n",
662 | " start_x_a0 start_y_a0 start_x_a1 start_y_a1 start_x_a2 \\\n",
663 | "game_id action_id \n",
664 | "2500098 17 99.75 26.52 91.35 29.92 97.65 \n",
665 | " 40 91.35 35.36 105.00 0.00 96.60 \n",
666 | " 77 75.60 29.92 94.50 27.20 98.70 \n",
667 | " 140 92.40 43.52 98.70 51.68 91.35 \n",
668 | " 145 99.75 37.40 96.60 38.76 93.45 \n",
669 | "\n",
670 | " start_y_a2 end_x_a1 end_y_a1 end_x_a2 end_y_a2 dx_a1 \\\n",
671 | "game_id action_id \n",
672 | "2500098 17 6.12 99.75 26.52 91.35 29.92 8.40 \n",
673 | " 40 23.80 91.35 35.36 0.00 53.72 -13.65 \n",
674 | " 77 65.96 75.60 29.92 94.50 27.20 -18.90 \n",
675 | " 140 54.40 92.40 43.52 98.70 51.68 -6.30 \n",
676 | " 145 45.56 99.75 37.40 96.60 38.76 3.15 \n",
677 | "\n",
678 | " dy_a1 movement_a1 dx_a2 dy_a2 movement_a2 dx_a01 \\\n",
679 | "game_id action_id \n",
680 | "2500098 17 -3.40 9.062009 -6.30 23.80 24.619708 0.0 \n",
681 | " 40 35.36 37.903194 -96.60 29.92 101.127476 0.0 \n",
682 | " 77 2.72 19.094722 -4.20 -38.76 38.986890 0.0 \n",
683 | " 140 -8.16 10.309006 7.35 -2.72 7.837149 0.0 \n",
684 | " 145 -1.36 3.431049 3.15 -6.80 7.494164 0.0 \n",
685 | "\n",
686 | " dy_a01 mov_a01 dx_a02 dy_a02 mov_a02 \\\n",
687 | "game_id action_id \n",
688 | "2500098 17 0.0 0.0 -8.40 3.40 9.062009 \n",
689 | " 40 0.0 0.0 -91.35 18.36 93.176779 \n",
690 | " 77 0.0 0.0 18.90 -2.72 19.094722 \n",
691 | " 140 0.0 0.0 6.30 8.16 10.309006 \n",
692 | " 145 0.0 0.0 -3.15 1.36 3.431049 \n",
693 | "\n",
694 | " start_dist_to_goal_a0 start_angle_to_goal_a0 \\\n",
695 | "game_id action_id \n",
696 | "2500098 17 9.138539 0.958815 \n",
697 | " 40 13.717584 0.099306 \n",
698 | " 77 29.681752 0.137895 \n",
699 | " 140 15.792099 0.647047 \n",
700 | " 145 6.254798 0.574700 \n",
701 | "\n",
702 | " start_dist_to_goal_a1 start_angle_to_goal_a1 \\\n",
703 | "game_id action_id \n",
704 | "2500098 17 14.246715 0.290448 \n",
705 | " 40 34.000000 1.570796 \n",
706 | " 77 12.509596 0.574700 \n",
707 | " 140 18.768921 1.228489 \n",
708 | " 145 9.654926 0.515549 \n",
709 | "\n",
710 | " start_dist_to_goal_a2 start_angle_to_goal_a2 \\\n",
711 | "game_id action_id \n",
712 | "2500098 17 28.832567 1.313031 \n",
713 | " 40 13.213629 0.881872 \n",
714 | " 77 32.575015 1.376170 \n",
715 | " 140 24.545519 0.981099 \n",
716 | " 145 16.341239 0.785831 \n",
717 | "\n",
718 | " end_dist_to_goal_a1 end_angle_to_goal_a1 \\\n",
719 | "game_id action_id \n",
720 | "2500098 17 9.138539 0.958815 \n",
721 | " 40 13.717584 0.099306 \n",
722 | " 77 29.681752 0.137895 \n",
723 | " 140 15.792099 0.647047 \n",
724 | " 145 6.254798 0.574700 \n",
725 | "\n",
726 | " end_dist_to_goal_a2 end_angle_to_goal_a2 team_1 team_2 \\\n",
727 | "game_id action_id \n",
728 | "2500098 17 14.246715 0.290448 True True \n",
729 | " 40 106.835754 0.185647 True True \n",
730 | " 77 12.509596 0.574700 False True \n",
731 | " 140 18.768921 1.228489 True True \n",
732 | " 145 9.654926 0.515549 True True \n",
733 | "\n",
734 | " time_delta_1 time_delta_2 speedx_a01 speedy_a01 \\\n",
735 | "game_id action_id \n",
736 | "2500098 17 3.433228 6.866456 0.0 0.0 \n",
737 | " 40 2.102531 21.927228 0.0 0.0 \n",
738 | " 77 2.629861 3.250682 0.0 0.0 \n",
739 | " 140 1.052499 5.000627 0.0 0.0 \n",
740 | " 145 1.677755 2.659997 0.0 0.0 \n",
741 | "\n",
742 | " speed_a01 speedx_a02 speedy_a02 speed_a02 \\\n",
743 | "game_id action_id \n",
744 | "2500098 17 0.0 1.223339 0.495161 1.319750 \n",
745 | " 40 0.0 4.166053 0.837315 4.249364 \n",
746 | " 77 0.0 5.814165 0.836747 5.874066 \n",
747 | " 140 0.0 1.259842 1.631795 2.061543 \n",
748 | " 145 0.0 1.184212 0.511279 1.289870 \n",
749 | "\n",
750 | " shot_angle_a0 shot_angle_a1 shot_angle_a2 caley_zone_a0 \\\n",
751 | "game_id action_id \n",
752 | "2500098 17 0.499778 0.483780 0.065500 2 \n",
753 | " 40 0.517985 0.000000 0.363334 3 \n",
754 | " 77 0.242481 0.491555 0.043863 6 \n",
755 | " 140 0.371538 0.134860 0.167545 4 \n",
756 | " 145 0.978291 0.654611 0.320841 1 \n",
757 | "\n",
758 | " caley_zone_a1 caley_zone_a2 angle_zone_a0 angle_zone_a1 \\\n",
759 | "game_id action_id \n",
760 | "2500098 17 3 8 9 12 \n",
761 | " 40 8 4 12 21 \n",
762 | " 77 3 0 18 12 \n",
763 | " 140 5 0 12 15 \n",
764 | " 145 3 4 6 9 \n",
765 | "\n",
766 | " angle_zone_a2 \n",
767 | "game_id action_id \n",
768 | "2500098 17 18 \n",
769 | " 40 12 \n",
770 | " 77 18 \n",
771 | " 140 18 \n",
772 | " 145 15 "
773 | ]
774 | },
775 | "metadata": {},
776 | "output_type": "display_data"
777 | },
778 | {
779 | "data": {
834 | "text/plain": [
835 | " goal\n",
836 | "game_id action_id \n",
837 | "2500098 17 False\n",
838 | " 40 False\n",
839 | " 77 False\n",
840 | " 140 False\n",
841 | " 145 False"
842 | ]
843 | },
844 | "metadata": {},
845 | "output_type": "display_data"
846 | }
847 | ],
848 | "source": [
849 | "features = []\n",
850 | "labels = []\n",
851 | "for (l,s) in itertools.product(leagues, seasons):\n",
852 | " features.append(pd.read_hdf(store_features, key=f'{l}/{s}/features'))\n",
853 | " labels.append(pd.read_hdf(store_features, key=f'{l}/{s}/labels'))\n",
854 | "features = pd.concat(features)\n",
855 | "labels = pd.concat(labels)\n",
856 | "\n",
857 | "display(features.head())\n",
858 | "display(labels.to_frame().head())"
859 | ]
860 | },
861 | {
862 | "cell_type": "code",
863 | "execution_count": null,
864 | "metadata": {},
865 | "outputs": [],
866 | "source": []
867 | }
868 | ],
869 | "metadata": {
870 | "kernelspec": {
871 | "display_name": "soccer_dataprovider_comparison",
872 | "language": "python",
873 | "name": "soccer_dataprovider_comparison"
874 | },
875 | "language_info": {
876 | "codemirror_mode": {
877 | "name": "ipython",
878 | "version": 3
879 | },
880 | "file_extension": ".py",
881 | "mimetype": "text/x-python",
882 | "name": "python",
883 | "nbconvert_exporter": "python",
884 | "pygments_lexer": "ipython3",
885 | "version": "3.6.2"
886 | },
887 | "toc": {
888 | "base_numbering": 1,
889 | "nav_menu": {},
890 | "number_sections": true,
891 | "sideBar": true,
892 | "skip_h1_title": false,
893 | "title_cell": "Table of Contents",
894 | "title_sidebar": "Contents",
895 | "toc_cell": false,
896 | "toc_position": {},
897 | "toc_section_display": true,
898 | "toc_window_display": true
899 | }
900 | },
901 | "nbformat": 4,
902 | "nbformat_minor": 2
903 | }
904 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "soccer_xg"
3 | version = "0.0.1"
4 | description = "Train and analyse xG models on soccer event stream data"
5 | authors = ["Pieter Robberechts "]
6 | license = "Apache Software License (http://www.apache.org/licenses/LICENSE-2.0)"
7 | readme = 'README.md'
8 | repository = "https://github.com/probberechts/soccer_xg"
9 | homepage = "https://pypi.org/project/soccer_xg"
10 | keywords = ["expected goals", "xG", "Soccer", "Football", "event stream data", "sports analytics", "Statsbomb", "Opta", "Wyscout"]
11 | classifiers=[
12 | "Programming Language :: Python :: 3",
13 | "License :: OSI Approved :: Apache Software License",
14 | "Operating System :: OS Independent"
15 | ]
16 |
17 | [tool.poetry.dependencies]
18 | python = "^3.6.1"
19 | numpy = "^1.18"
20 | pandas = "^1.0"
21 | scikit-learn = "^0.22.1"
22 | ipykernel = "^5.1"
23 | matplotsoccer = "^0.0.8"
24 | matplotlib = "^3.1"
25 | click = "^7.0"
26 | tables = "^3.6"
27 | requests = "^2.23"
28 | xgboost = "^1.0"
29 | seaborn = "^0.10.0"
30 | fuzzywuzzy = "^0.18.0"
31 | python-Levenshtein = "^0.12.0"
32 | dask = {extras = ["array","distributed","dataframe"], version = "^2.15.0", optional = true}
33 | dask_ml = {version = "^1.3.0", optional = true}
34 | category_encoders = "^2.2.2"
35 | asyncssh = {version = "^2.2.1", optional = true}
36 | paramiko = {version = "^2.7.1", optional = true}
37 | understat = "^0.1.2"
38 | socceraction = "^0.2.1"
39 | betacal = "^0.2.7"
40 |
41 | [tool.poetry.extras]
42 | dask = ["dask", "dask_ml", "asyncssh", "paramiko"]
43 |
44 | [tool.poetry.dev-dependencies]
45 | flake8-awesome = "^1.2"
46 | mypy = "^0.761.0"
47 | pylint = "^2.4"
48 | pytest = "^5.3"
49 | pytest-cov = "^2.8"
50 | pytest-deadfixtures = "^2.1"
51 | unify = "^0.5.0"
52 | black = {version = "^19.10b0", allow-prereleases = true}
53 | bumpversion = "^0.6.0"
54 |
55 | [build-system]
56 | requires = ["poetry>=0.12"]
57 | build-backend = "poetry.masonry.api"
58 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | enable-extensions = G
3 | exclude = .git, .venv
4 | ignore =
5 | A003 ; 'id' is a python builtin, consider renaming the class attribute
6 | W503 ; line break before binary operator
7 | max-complexity = 8
8 | max-line-length = 100
9 | show-source = true
10 |
11 | [mypy]
12 | check_untyped_defs = true
13 | disallow_any_generics = true
14 | disallow_incomplete_defs = true
15 | disallow_untyped_defs = true
16 | ignore_missing_imports = True
17 | no_implicit_optional = true
18 |
19 | [mypy-tests.*]
20 | disallow_untyped_defs = false
21 |
22 | [isort]
23 | balanced_wrapping = true
24 | default_section = THIRDPARTY
25 | include_trailing_comma=True
26 | known_first_party = soccer_xg, tests
27 | line_length = 79
28 | multi_line_output = 3
29 | not_skip = __init__.py
30 |
31 | [pylint]
32 | good-names=i,j,k,e,x,_,pk,id
33 | max-args=5
34 | max-attributes=10
35 | max-bool-expr=5
36 | max-module-lines=200
37 | max-nested-blocks=2
38 | max-public-methods=5
39 | max-returns=5
40 | max-statements=20
41 | output-format = colorized
42 |
43 | disable=
44 | C0103, ; Constant name "api" doesn't conform to UPPER_CASE naming style (invalid-name)
45 | C0111, ; Missing module docstring (missing-docstring)
46 | C0330, ; Wrong hanging indentation before block (add 4 spaces)
47 | E0213, ; Method should have "self" as first argument (no-self-argument) - N805 for flake8
48 | R0201, ; Method could be a function (no-self-use)
49 | R0901, ; Too many ancestors (m/n) (too-many-ancestors)
50 | R0903, ; Too few public methods (m/n) (too-few-public-methods)
51 |
52 | ignored-classes=
53 | contextlib.closing,
54 |
55 | [coverage:run]
56 | omit = tests/*,**/__main__.py
57 | branch = True
58 |
59 | [coverage:report]
60 | show_missing = True
61 | skip_covered = True
62 | fail_under = 0
63 |
--------------------------------------------------------------------------------
/soccer_xg/__init__.py:
--------------------------------------------------------------------------------
1 | from soccer_xg.api import DataApi
2 | from soccer_xg.xg import XGModel
3 |
4 | __version__ = '0.0.1'
5 |
--------------------------------------------------------------------------------
/soccer_xg/api.py:
--------------------------------------------------------------------------------
1 | """An API wrapper for the SPADL format."""
2 | import logging
3 | import os
4 | import warnings
5 |
6 | import pandas as pd
7 |
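8 | # For each table name: (sort columns, identifier columns). When the same table
9 | # appears in multiple HDF files, records are sorted on the former and
10 | # deduplicated on the latter before indexing (see DataApi.__getattr__).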
8 | deduplic = dict(
9 | games=('game_id', 'game_id'),
10 | teams=(['team_id'], 'team_id'),
11 | players=(['player_id'], 'player_id'),
12 | player_games=(
13 | ['game_id', 'team_id', 'player_id'],
14 | ['game_id', 'team_id', 'player_id'],
15 | ),
16 | files=('file_url', 'file_url'),
17 | )
18 |
19 |
20 | class DataApi:
21 |     """An object that provides easy access to a SPADL event stream dataset.
22 | 
23 |     Automatically defines an attribute that lazily loads the contents of
24 |     each table in the HDF files, and defines a couple of methods to easily execute
25 | common queries on the SPADL data.
26 |
27 | Parameters
28 | ----------
29 | db_path : A list of strings or a single string
30 | Path(s) to HDF files containing the data.
31 |
32 | Attributes
33 | ----------
34 | ``table_name`` : ``pd.DataFrame``
35 | A single pandas dataframe that contains all records from all ``table_name``
36 | tables in each HDF file.
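37 | 
38 |     Examples
39 |     --------
40 |     A minimal sketch, assuming ``data.h5`` is an HDF file in the SPADL
41 |     format (the path and game ID are placeholders)::
42 | 
43 |         >>> api = DataApi('data.h5')
44 |         >>> games = api.games  # lazily loaded on first access
45 |         >>> actions = api.get_actions(game_id=2500098, only_home=True)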
37 | """
38 |
39 | def __init__(self, db_path):
40 | self.logger = logging.getLogger(__name__)
41 | self.logger.info('Loading datasets')
42 |         if isinstance(db_path, (list, set)):
43 |             self.db_path = set(db_path)
44 |         else:
45 |             self.db_path = {db_path}
46 |
47 | for p in self.db_path:
48 | if not os.path.exists(p):
49 | raise ValueError(
50 | 'A database `{}` does not exist.'.format(str(p))
51 | )
52 |
53 | def __getattr__(self, name):
54 | self.logger.info(f'Loading `{name}` data')
55 | DB = []
56 | for p in self.db_path:
57 | with pd.HDFStore(p, 'r') as store:
58 | for key in [
59 | k for k in store.keys() if (k[1:].rsplit('/')[0] == name)
60 | ]:
61 | db = store[key]
62 | db['db_path'] = p
63 | DB.append(db)
64 | if len(DB) == 0:
65 | raise ValueError('A table `{}` does not exist.'.format(str(name)))
66 | else:
67 | DB = pd.concat(DB, sort=False)
68 | if name in deduplic:
69 | sortcols, idcols = deduplic[name]
70 | DB.sort_values(by=sortcols, ascending=False, inplace=True)
71 | DB.drop_duplicates(subset=idcols, inplace=True)
72 | DB.set_index(idcols, inplace=True)
73 | setattr(self, name, DB)
74 | return DB
75 |
76 | def get_events(self, game_id, only_home=False, only_away=False):
77 | """Return all events performed in a given game.
78 |
79 | Parameters
80 | ----------
81 | game_id : int
82 | The ID of a game.
83 | only_home : bool
84 | Include only events from the home team.
85 | only_away : bool
86 | Include only events from the away team.
87 |
88 | Returns
89 | -------
90 | pd.DataFrame
91 | A dataframe with a row for each event, indexed by period_id and
92 | a timestamp (ms) in which the event happened.
93 |
94 | Raises
95 | ------
96 | ValueError
97 | If both `only_home` and `only_away` are True.
98 | IndexError
99 | If no game exists with the provided ID.
100 | """
101 | if only_home and only_away:
102 | raise ValueError('only_home and only_away cannot be both True.')
103 |
104 | try:
105 | db = self.games.at[game_id, 'db_path']
106 | with pd.HDFStore(db, 'r') as store:
107 | df_game_events = store.get(f'events/game_{game_id}')
108 | home_team_id, away_team_id = self.get_home_away_team_id(
109 | game_id
110 | )
111 | if only_home:
112 | team_filter = df_game_events.team_id == home_team_id
113 | elif only_away:
114 | team_filter = df_game_events.team_id == away_team_id
115 | else:
116 | team_filter = [True] * len(df_game_events)
117 | return df_game_events.loc[(team_filter), :].set_index(
118 | ['period_id', 'period_milliseconds']
119 | )
120 | except KeyError:
121 | raise IndexError(
122 | 'No events found for a game with the provided ID.'
123 | )
124 |
125 | def get_actions(
126 | self, game_id, only_home=False, only_away=False, features=False
127 | ):
128 | """Return all actions performed in a given game.
129 |
130 | Parameters
131 | ----------
132 | game_id : int
133 | The ID of a game.
134 | only_home : bool
135 | Include only actions from the home team.
136 | only_away : bool
137 | Include only actions from the away team.
138 |
139 | Returns
140 | -------
141 | pd.DataFrame
142 | A dataframe with a row for each action, indexed by period_id and
143 | a timestamp (ms) in which the action was executed.
144 |
145 | Raises
146 | ------
147 | ValueError
148 | If both `only_home` and `only_away` are True.
149 | IndexError
150 | If no game exists with the provided ID.
151 | """
152 | if only_home and only_away:
153 | raise ValueError('only_home and only_away cannot be both True.')
154 |
155 | try:
156 | db = self.games.at[game_id, 'db_path']
157 | with pd.HDFStore(db, 'r') as store:
158 | df_game_actions = store.get(f'actions/game_{game_id}')
159 | if features:
160 | try:
161 | df_game_features = store.get(
162 | f'features/game_{game_id}'
163 | )
164 | df_game_actions = pd.concat(
165 | [df_game_actions, df_game_features], axis=1
166 | )
167 | except KeyError:
168 | warnings.warn('Could not find precomputed features')
169 |
170 | home_team_id, away_team_id = self.get_home_away_team_id(game_id)
171 | if only_home:
172 | team_filter = df_game_actions.team_id == home_team_id
173 | elif only_away:
174 | team_filter = df_game_actions.team_id == away_team_id
175 | else:
176 | team_filter = [True] * len(df_game_actions)
177 | return df_game_actions.loc[(team_filter), :].set_index(
178 | ['action_id']
179 | )
180 | except KeyError:
181 | raise IndexError(
182 | 'No actions found for a game with the provided ID.'
183 | )
184 |
185 | # Games ##################################################################
186 |
187 | def get_home_away_team_id(self, game_id):
188 | """Return the id of the home and away team in a given game.
189 |
190 | Parameters
191 | ----------
192 | game_id : int
193 | The ID of a game.
194 |
195 | Returns
196 | -------
197 | (int, int)
198 | The ID of the home and away team.
199 |
200 | Raises
201 | ------
202 | IndexError
203 | If no game exists with the provided ID.
204 | """
205 | try:
206 | return self.games.loc[
207 | game_id, ['home_team_id', 'away_team_id']
208 | ].values
209 | except KeyError:
210 | raise IndexError('No game found with the provided ID.')
211 |
212 | # Players ################################################################
213 |
214 | def get_player_name(self, player_id):
215 | """Return the name of a player with a given ID.
216 |
217 | Parameters
218 | ----------
219 | player_id : int
220 | The ID of a player.
221 |
222 | Returns
223 | -------
224 | The name of the player.
225 |
226 | Raises
227 | ------
228 | IndexError
229 | If no player exists with the provided ID.
230 | """
231 | try:
232 | return self.players.at[player_id, 'player_name']
233 | except KeyError:
234 | raise IndexError('No player found with the provided ID.')
235 |
236 | def search_player(self, query, limit=10):
237 | """Search for a player by name.
238 |
239 | Parameters
240 | ----------
241 | query : str
242 | The name of a player.
243 | limit : int
244 | Max number of results that are returned.
245 |
246 | Returns
247 | -------
248 | pd.DataFrame
249 |             The first `limit` players that match the given query.
250 |
251 | """
252 | return self.players[
253 | self.players.player_name.str.contains(query, case=False)
254 | ].head(limit)
255 |
256 | # Teams ##################################################################
257 |
258 | def get_team_name(self, team_id):
259 | """Return the name of a team with a given ID.
260 |
261 | Parameters
262 | ----------
263 | team_id : int
264 | The ID of a team.
265 |
266 | Returns
267 | -------
268 | The name of the team.
269 |
270 | Raises
271 | ------
272 | IndexError
273 | If no team exists with the provided ID.
274 | """
275 | try:
276 | return self.teams.at[team_id, 'team_name']
277 | except KeyError:
278 | raise IndexError('No team found with the provided ID.')
279 |
280 | def search_team(self, query, limit=10):
281 | """Search for a team by name.
282 |
283 | Parameters
284 | ----------
285 | query : str
286 | The name of a team.
287 | limit : int
288 | Max number of results that are returned.
289 |
290 | Returns
291 | -------
292 | pd.DataFrame
293 |             The first `limit` teams that match the given query.
294 |
295 | """
296 | return self.teams[
297 | self.teams.team_name.str.contains(query, case=False)
298 | ].head(limit)
299 |
--------------------------------------------------------------------------------
/soccer_xg/calibration.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 |
3 | import warnings
4 | from inspect import signature
5 |
6 | import numpy as np
7 | from betacal import BetaCalibration
8 | from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin, clone
9 | from sklearn.isotonic import IsotonicRegression
10 | from sklearn.linear_model import LogisticRegression
11 | from sklearn.model_selection import check_cv
12 | from sklearn.preprocessing import LabelBinarizer, label_binarize
13 | from sklearn.svm import LinearSVC
14 | from sklearn.utils import check_consistent_length, check_X_y, column_or_1d, indexable
15 | from sklearn.utils.validation import check_is_fitted
16 |
17 |
18 | class CalibratedClassifierCV(BaseEstimator, ClassifierMixin):
19 | """Probability calibration with isotonic regression, sigmoid or beta.
20 |
21 | With this class, the base_estimator is fit on the train set of the
22 | cross-validation generator and the test set is used for calibration.
23 | The probabilities for each of the folds are then averaged
24 | for prediction. In case cv="prefit" is passed to __init__,
25 | it is assumed that base_estimator has been
26 | fitted already and all data is used for calibration. Note that
27 | data for fitting the classifier and for calibrating it must be disjoint.
28 |
29 |     Read more in the :ref:`User Guide <calibration>`.
30 |
31 | Parameters
32 | ----------
33 | base_estimator : instance BaseEstimator
34 | The classifier whose output decision function needs to be calibrated
35 | to offer more accurate predict_proba outputs. If cv=prefit, the
36 | classifier must have been fit already on data.
37 |
38 | method : None, 'sigmoid', 'isotonic', 'beta', 'beta_am' or 'beta_ab'
39 | The method to use for calibration. Can be 'sigmoid' which
40 | corresponds to Platt's method, 'isotonic' which is a
41 | non-parameteric approach or 'beta', 'beta_am' or 'beta_ab' which
42 | correspond to three different beta calibration methods. It is
43 | not advised to use isotonic calibration with too few calibration
44 | samples ``(<<1000)`` since it tends to overfit.
45 | Use beta models in this case.
46 |
47 | cv : integer, cross-validation generator, iterable or "prefit", optional
48 | Determines the cross-validation splitting strategy.
49 | Possible inputs for cv are:
50 |
51 | - None, to use the default 3-fold cross-validation,
52 | - integer, to specify the number of folds.
53 | - An object to be used as a cross-validation generator.
54 | - An iterable yielding train/test splits.
55 |
56 | For integer/None inputs, if ``y`` is binary or multiclass,
57 |         :class:`StratifiedKFold` is used. If ``y`` is neither binary nor
58 | multiclass, :class:`KFold` is used.
59 |
60 |         Refer to the :ref:`User Guide <cross_validation>` for the various
61 | cross-validation strategies that can be used here.
62 |
63 | If "prefit" is passed, it is assumed that base_estimator has been
64 | fitted already and all data is used for calibration.
65 |
66 | Attributes
67 | ----------
68 | classes_ : array, shape (n_classes)
69 | The class labels.
70 |
71 | calibrated_classifiers_: list (len() equal to cv or 1 if cv == "prefit")
72 | The list of calibrated classifiers, one for each cross-validation fold,
73 | which has been fitted on all but the validation fold and calibrated
74 | on the validation fold.
75 |
76 | References
77 | ----------
78 | .. [1] Obtaining calibrated probability estimates from decision trees
79 | and naive Bayesian classifiers, B. Zadrozny & C. Elkan, ICML 2001
80 |
81 | .. [2] Transforming Classifier Scores into Accurate Multiclass
82 | Probability Estimates, B. Zadrozny & C. Elkan, (KDD 2002)
83 |
84 | .. [3] Probabilistic Outputs for Support Vector Machines and Comparisons to
85 | Regularized Likelihood Methods, J. Platt, (1999)
86 |
87 | .. [4] Predicting Good Probabilities with Supervised Learning,
88 | A. Niculescu-Mizil & R. Caruana, ICML 2005
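89 | 
90 |     Examples
91 |     --------
92 |     A minimal sketch, assuming ``clf`` was already fitted on training data
93 |     that is disjoint from the calibration set (``X_cal`` and ``y_cal`` are
94 |     placeholders)::
95 | 
96 |         >>> calibrated = CalibratedClassifierCV(clf, method='beta', cv='prefit')
97 |         >>> calibrated.fit(X_cal, y_cal)
98 |         >>> proba = calibrated.predict_proba(X_cal)[:, 1]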
89 | """
90 |
91 | def __init__(
92 | self, base_estimator=None, method=None, cv=3, score_type=None
93 | ):
94 | self.base_estimator = base_estimator
95 | self.method = method
96 | self.cv = cv
97 | self.score_type = score_type
98 |
99 | def fit(self, X, y, sample_weight=None):
100 | """Fit the calibrated model
101 |
102 | Parameters
103 | ----------
104 | X : array-like, shape (n_samples, n_features)
105 | Training data.
106 |
107 | y : array-like, shape (n_samples,)
108 | Target values.
109 |
110 | sample_weight : array-like, shape = [n_samples] or None
111 | Sample weights. If None, then samples are equally weighted.
112 |
113 | Returns
114 | -------
115 | self : object
116 | Returns an instance of self.
117 | """
118 | # X, y = check_X_y(X, y, accept_sparse=['csc', 'csr', 'coo'],
119 | # force_all_finite=False)
120 | X, y = indexable(X, y)
121 | lb = LabelBinarizer().fit(y)
122 | self.classes_ = lb.classes_
123 |
124 | # Check that each cross-validation fold can have at least one
125 | # example per class
126 | n_folds = (
127 | self.cv
128 | if isinstance(self.cv, int)
129 | else self.cv.n_folds
130 | if hasattr(self.cv, 'n_folds')
131 | else None
132 | )
133 | if n_folds and np.any(
134 | [np.sum(y == class_) < n_folds for class_ in self.classes_]
135 | ):
136 | raise ValueError(
137 | 'Requesting %d-fold cross-validation but provided'
138 | ' less than %d examples for at least one class.'
139 | % (n_folds, n_folds)
140 | )
141 |
142 | self.calibrated_classifiers_ = []
143 | if self.base_estimator is None:
144 | # we want all classifiers that don't expose a random_state
145 | # to be deterministic (and we don't want to expose this one).
146 | base_estimator = LinearSVC(random_state=0)
147 | else:
148 | base_estimator = self.base_estimator
149 |
150 | if self.cv == 'prefit':
151 | calibrated_classifier = _CalibratedClassifier(
152 | base_estimator, method=self.method, score_type=self.score_type
153 | )
154 | if sample_weight is not None:
155 | calibrated_classifier.fit(X, y, sample_weight)
156 | else:
157 | calibrated_classifier.fit(X, y)
158 | self.calibrated_classifiers_.append(calibrated_classifier)
159 | else:
160 |             cv = check_cv(self.cv, y, classifier=True)
161 | fit_parameters = signature(base_estimator.fit).parameters
162 | estimator_name = type(base_estimator).__name__
163 | if (
164 | sample_weight is not None
165 | and 'sample_weight' not in fit_parameters
166 | ):
167 | warnings.warn(
168 | '%s does not support sample_weight. Samples'
169 | ' weights are only used for the calibration'
170 | ' itself.' % estimator_name
171 | )
172 | base_estimator_sample_weight = None
173 | else:
174 | base_estimator_sample_weight = sample_weight
175 |             for train, test in cv.split(X, y):
176 | this_estimator = clone(base_estimator)
177 | if base_estimator_sample_weight is not None:
178 | this_estimator.fit(
179 | X[train],
180 | y[train],
181 | sample_weight=base_estimator_sample_weight[train],
182 | )
183 | else:
184 | this_estimator.fit(X[train], y[train])
185 |
186 | calibrated_classifier = _CalibratedClassifier(
187 | this_estimator,
188 | method=self.method,
189 | score_type=self.score_type,
190 | )
191 | if sample_weight is not None:
192 | calibrated_classifier.fit(
193 | X[test], y[test], sample_weight[test]
194 | )
195 | else:
196 | calibrated_classifier.fit(X[test], y[test])
197 | self.calibrated_classifiers_.append(calibrated_classifier)
198 |
199 | return self
200 |
201 | def predict_proba(self, X):
202 | """Posterior probabilities of classification
203 |
204 | This function returns posterior probabilities of classification
205 | according to each class on an array of test vectors X.
206 |
207 | Parameters
208 | ----------
209 | X : array-like, shape (n_samples, n_features)
210 | The samples.
211 |
212 | Returns
213 | -------
214 | C : array, shape (n_samples, n_classes)
215 | The predicted probas.
216 | """
217 | check_is_fitted(self, ['classes_', 'calibrated_classifiers_'])
218 | # X = check_array(X, accept_sparse=['csc', 'csr', 'coo'],
219 | # force_all_finite=False)
220 | # Compute the arithmetic mean of the predictions of the calibrated
221 | # classfiers
222 | mean_proba = np.zeros((X.shape[0], len(self.classes_)))
223 | for calibrated_classifier in self.calibrated_classifiers_:
224 | proba = calibrated_classifier.predict_proba(X)
225 | mean_proba += proba
226 |
227 | mean_proba /= len(self.calibrated_classifiers_)
228 |
229 | return mean_proba
230 |
231 | def calibrate_scores(self, df):
232 |         """Calibrate raw classifier scores.
233 | 
234 |         This function maps uncalibrated classifier scores to posterior
235 |         probabilities of classification for each class.
236 | 
237 |         Parameters
238 |         ----------
239 |         df : array-like, shape (n_samples,)
240 |             Uncalibrated scores.
241 | 
242 |         Returns
243 |         -------
244 |         C : array, shape (n_samples, n_classes)
245 |             The predicted probas.
246 | """
247 | check_is_fitted(self, ['classes_', 'calibrated_classifiers_'])
248 | # Compute the arithmetic mean of the predictions of the calibrated
249 | # classifiers
250 | df = df.reshape(-1, 1)
251 | mean_proba = np.zeros((len(df), len(self.classes_)))
252 | for calibrated_classifier in self.calibrated_classifiers_:
253 | proba = calibrated_classifier.calibrate_scores(df)
254 | mean_proba += proba
255 |
256 | mean_proba /= len(self.calibrated_classifiers_)
257 |
258 | return mean_proba
259 |
260 | def predict(self, X):
261 | """Predict the target of new samples. Can be different from the
262 | prediction of the uncalibrated classifier.
263 |
264 | Parameters
265 | ----------
266 | X : array-like, shape (n_samples, n_features)
267 | The samples.
268 |
269 | Returns
270 | -------
271 | C : array, shape (n_samples,)
272 | The predicted class.
273 | """
274 | check_is_fitted(self, ['classes_', 'calibrated_classifiers_'])
275 | return self.classes_[np.argmax(self.predict_proba(X), axis=1)]
276 |
277 |
278 | class _CalibratedClassifier(object):
279 | """Probability calibration with isotonic regression or sigmoid.
280 |
281 | It assumes that base_estimator has already been fit, and trains the
282 | calibration on the input set of the fit function. Note that this class
283 | should not be used as an estimator directly. Use CalibratedClassifierCV
284 | with cv="prefit" instead.
285 |
286 | Parameters
287 | ----------
288 | base_estimator : instance BaseEstimator
289 | The classifier whose output decision function needs to be calibrated
290 | to offer more accurate predict_proba outputs. No default value since
291 | it has to be an already fitted estimator.
292 |
293 | method : 'sigmoid' | 'isotonic' | 'beta' | 'beta_am' | 'beta_ab'
294 | The method to use for calibration. Can be 'sigmoid' which
295 | corresponds to Platt's method, 'isotonic' which is a
296 | non-parameteric approach based on isotonic regression or 'beta',
297 | 'beta_am' or 'beta_ab' which correspond to beta calibration methods.
298 |
299 | References
300 | ----------
301 | .. [1] Obtaining calibrated probability estimates from decision trees
302 | and naive Bayesian classifiers, B. Zadrozny & C. Elkan, ICML 2001
303 |
304 | .. [2] Transforming Classifier Scores into Accurate Multiclass
305 | Probability Estimates, B. Zadrozny & C. Elkan, (KDD 2002)
306 |
307 | .. [3] Probabilistic Outputs for Support Vector Machines and Comparisons to
308 | Regularized Likelihood Methods, J. Platt, (1999)
309 |
310 | .. [4] Predicting Good Probabilities with Supervised Learning,
311 | A. Niculescu-Mizil & R. Caruana, ICML 2005
312 | """
313 |
314 | def __init__(self, base_estimator, method='beta', score_type=None):
315 | self.base_estimator = base_estimator
316 | self.method = method
317 | self.score_type = score_type
318 |
319 | def _preproc(self, X):
320 | n_classes = len(self.classes_)
321 | if self.score_type is None:
322 | if hasattr(self.base_estimator, 'decision_function'):
323 | df = self.base_estimator.decision_function(X)
324 | if df.ndim == 1:
325 | df = df[:, np.newaxis]
326 | elif hasattr(self.base_estimator, 'predict_proba'):
327 | df = self.base_estimator.predict_proba(X)
328 | if n_classes == 2:
329 | df = df[:, 1:]
330 | else:
331 | raise RuntimeError(
332 | 'classifier has no decision_function or '
333 | 'predict_proba method.'
334 | )
335 | else:
336 | if hasattr(self.base_estimator, self.score_type):
337 | df = getattr(self.base_estimator, self.score_type)(X)
338 | if self.score_type == 'decision_function':
339 | if df.ndim == 1:
340 | df = df[:, np.newaxis]
341 | elif self.score_type == 'predict_proba':
342 | if n_classes == 2:
343 | df = df[:, 1:]
344 | else:
345 | raise RuntimeError(
346 |                         'classifier has no ' + self.score_type + ' method.'
347 | )
348 |
349 | idx_pos_class = np.arange(df.shape[1])
350 |
351 | return df, idx_pos_class
352 |
353 | def fit(self, X, y, sample_weight=None):
354 | """Calibrate the fitted model
355 |
356 | Parameters
357 | ----------
358 | X : array-like, shape (n_samples, n_features)
359 | Training data.
360 |
361 | y : array-like, shape (n_samples,)
362 | Target values.
363 |
364 | sample_weight : array-like, shape = [n_samples] or None
365 | Sample weights. If None, then samples are equally weighted.
366 |
367 | Returns
368 | -------
369 | self : object
370 | Returns an instance of self.
371 | """
372 | lb = LabelBinarizer()
373 | Y = lb.fit_transform(y)
374 | self.classes_ = lb.classes_
375 |
376 | df, idx_pos_class = self._preproc(X)
377 | self.calibrators_ = []
378 |
379 | for k, this_df in zip(idx_pos_class, df.T):
380 | if self.method is None:
381 | calibrator = _DummyCalibration()
382 | elif self.method == 'isotonic':
383 | calibrator = IsotonicRegression(out_of_bounds='clip')
384 | elif self.method == 'sigmoid':
385 | calibrator = _SigmoidCalibration()
386 | elif self.method == 'beta':
387 | calibrator = BetaCalibration(parameters='abm')
388 | elif self.method == 'beta_am':
389 | calibrator = BetaCalibration(parameters='am')
390 | elif self.method == 'beta_ab':
391 | calibrator = BetaCalibration(parameters='ab')
392 | else:
393 | raise ValueError(
394 | 'method should be None, "sigmoid", '
395 |                     '"isotonic", "beta", "beta_am" or "beta_ab". '
396 | 'Got %s.' % self.method
397 | )
398 | calibrator.fit(this_df, Y[:, k], sample_weight)
399 | self.calibrators_.append(calibrator)
400 |
401 | return self
402 |
403 | def predict_proba(self, X):
404 | """Posterior probabilities of classification
405 |
406 | This function returns posterior probabilities of classification
407 | according to each class on an array of test vectors X.
408 |
409 | Parameters
410 | ----------
411 | X : array-like, shape (n_samples, n_features)
412 | The samples.
413 |
414 | Returns
415 | -------
416 | C : array, shape (n_samples, n_classes)
417 | The predicted probas. Can be exact zeros.
418 | """
419 | n_classes = len(self.classes_)
420 | proba = np.zeros((X.shape[0], n_classes))
421 |
422 | df, idx_pos_class = self._preproc(X)
423 | for k, this_df, calibrator in zip(
424 | idx_pos_class, df.T, self.calibrators_
425 | ):
426 | if n_classes == 2:
427 | k += 1
428 | proba[:, k] = calibrator.predict(this_df)
429 |
430 | # Normalize the probabilities
431 | if n_classes == 2:
432 | proba[:, 0] = 1.0 - proba[:, 1]
433 | else:
434 | proba /= np.sum(proba, axis=1)[:, np.newaxis]
435 |
436 | # XXX : for some reason all probas can be 0
437 | proba[np.isnan(proba)] = 1.0 / n_classes
438 |
439 | # Deal with cases where the predicted probability minimally exceeds 1.0
440 | proba[(1.0 < proba) & (proba <= 1.0 + 1e-5)] = 1.0
441 |
442 | return proba
443 |
444 | def calibrate_scores(self, df):
445 |         """Calibrate raw classifier scores.
446 | 
447 |         This function maps uncalibrated classifier scores to posterior
448 |         probabilities of classification for each class.
449 | 
450 |         Parameters
451 |         ----------
452 |         df : array-like, shape (n_samples, 1)
453 |             Uncalibrated scores.
454 | 
455 |         Returns
456 |         -------
457 |         C : array, shape (n_samples, n_classes)
458 |             The predicted probas. Can be exact zeros.
459 | """
460 | n_classes = len(self.classes_)
461 | proba = np.zeros((len(df), n_classes))
462 | idx_pos_class = [0]
463 |
464 | for k, this_df, calibrator in zip(
465 | idx_pos_class, df.T, self.calibrators_
466 | ):
467 | if n_classes == 2:
468 | k += 1
469 | pro = calibrator.predict(this_df)
470 | if np.any(np.isnan(pro)):
471 | pro[np.isnan(pro)] = calibrator.predict(
472 | this_df[np.isnan(pro)] + 1e-300
473 | )
474 | proba[:, k] = pro
475 |
476 | # Normalize the probabilities
477 | if n_classes == 2:
478 | proba[:, 0] = 1.0 - proba[:, 1]
479 | else:
480 | proba /= np.sum(proba, axis=1)[:, np.newaxis]
481 |
482 | # XXX : for some reason all probas can be 0
483 | proba[np.isnan(proba)] = 1.0 / n_classes
484 |
485 | # Deal with cases where the predicted probability minimally exceeds 1.0
486 | proba[(1.0 < proba) & (proba <= 1.0 + 1e-5)] = 1.0
487 | return proba
488 |
489 |
490 | class _SigmoidCalibration(BaseEstimator, RegressorMixin):
491 | """Sigmoid regression model.
492 |
493 | Attributes
494 | ----------
495 | a_ : float
496 | The slope.
497 |
498 | b_ : float
499 | The intercept.
500 | """
501 |
502 | def fit(self, X, y, sample_weight=None):
503 | """Fit the model using X, y as training data.
504 |
505 | Parameters
506 | ----------
507 | X : array-like, shape (n_samples,)
508 | Training data.
509 |
510 | y : array-like, shape (n_samples,)
511 | Training target.
512 |
513 | sample_weight : array-like, shape = [n_samples] or None
514 | Sample weights. If None, then samples are equally weighted.
515 |
516 | Returns
517 | -------
518 | self : object
519 | Returns an instance of self.
520 | """
521 | X = column_or_1d(X)
522 | y = column_or_1d(y)
523 | X, y = indexable(X, y)
524 | self.lr = LogisticRegression(C=99999999999)
525 | self.lr.fit(X.reshape(-1, 1), y)
526 | return self
527 |
528 | def predict(self, T):
529 |         """Predict new data with the fitted sigmoid.
530 |
531 | Parameters
532 | ----------
533 | T : array-like, shape (n_samples,)
534 | Data to predict from.
535 |
536 | Returns
537 | -------
538 | T_ : array, shape (n_samples,)
539 | The predicted data.
540 | """
541 | T = column_or_1d(T)
542 | return self.lr.predict_proba(T.reshape(-1, 1))[:, 1]
543 |
544 |
545 | class _DummyCalibration(BaseEstimator, RegressorMixin):
546 | """Dummy regression model. The purpose of this class is to give
547 | the CalibratedClassifierCV class the option to just return the
548 | probabilities of the base classifier.
549 |     """
552 |
553 | def fit(self, X, y, sample_weight=None):
554 | """Does nothing.
555 |
556 | Parameters
557 | ----------
558 | X : array-like, shape (n_samples,)
559 | Training data.
560 |
561 | y : array-like, shape (n_samples,)
562 | Training target.
563 |
564 | sample_weight : array-like, shape = [n_samples] or None
565 | Sample weights. If None, then samples are equally weighted.
566 |
567 | Returns
568 | -------
569 | self : object
570 | Returns an instance of self.
571 | """
572 | return self
573 |
574 | def predict(self, T):
575 | """Return the probabilities of the base classifier.
576 |
577 | Parameters
578 | ----------
579 | T : array-like, shape (n_samples,)
580 | Data to predict from.
581 |
582 | Returns
583 | -------
584 | T_ : array, shape (n_samples,)
585 | The predicted data.
586 | """
587 | return T
588 |
589 |
590 | def calibration_curve(y_true, y_prob, normalize=False, n_bins=5):
591 | """Compute true and predicted probabilities for a calibration curve.
592 |
593 |     Read more in the :ref:`User Guide <calibration>`.
594 |
595 | Parameters
596 | ----------
597 | y_true : array, shape (n_samples,)
598 | True targets.
599 |
600 | y_prob : array, shape (n_samples,)
601 | Probabilities of the positive class.
602 |
603 | normalize : bool, optional, default=False
604 |         Whether y_prob needs to be normalized into the interval [0, 1], i.e. is not
605 | a proper probability. If True, the smallest value in y_prob is mapped
606 | onto 0 and the largest one onto 1.
607 |
608 | n_bins : int
609 | Number of bins. A bigger number requires more data.
610 |
611 | Returns
612 | -------
613 | prob_true : array, shape (n_bins,)
614 | The true probability in each bin (fraction of positives).
615 |
616 | prob_pred : array, shape (n_bins,)
617 | The mean predicted probability in each bin.
618 |
619 | References
620 | ----------
621 | Alexandru Niculescu-Mizil and Rich Caruana (2005) Predicting Good
622 | Probabilities With Supervised Learning, in Proceedings of the 22nd
623 | International Conference on Machine Learning (ICML).
624 | See section 4 (Qualitative Analysis of Predictions).
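625 | 
626 |     Examples
627 |     --------
628 |     A minimal sketch with hand-made probabilities::
629 | 
630 |         >>> import numpy as np
631 |         >>> y_true = np.array([0, 0, 0, 1, 1, 1])
632 |         >>> y_prob = np.array([0.1, 0.2, 0.8, 0.3, 0.9, 0.95])
633 |         >>> prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=2)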
625 | """
626 | y_true = column_or_1d(y_true)
627 | y_prob = column_or_1d(y_prob)
628 |
629 | if normalize: # Normalize predicted values into interval [0, 1]
630 | y_prob = (y_prob - y_prob.min()) / (y_prob.max() - y_prob.min())
631 | elif y_prob.min() < 0 or y_prob.max() > 1:
632 | raise ValueError(
633 | 'y_prob has values outside [0, 1] and normalize is '
634 | 'set to False.'
635 | )
636 |
637 | y_true = _check_binary_probabilistic_predictions(y_true, y_prob)
638 |
639 | bins = np.linspace(0.0, 1.0 + 1e-8, n_bins + 1)
640 | binids = np.digitize(y_prob, bins) - 1
641 |
642 | bin_sums = np.bincount(binids, weights=y_prob, minlength=len(bins))
643 | bin_true = np.bincount(binids, weights=y_true, minlength=len(bins))
644 | bin_total = np.bincount(binids, minlength=len(bins))
645 |
646 |     # Give empty bins a nonzero total so the divisions below yield 0 for them
647 |     zero = bin_total == 0
648 |     bin_total[zero] = 2
649 |
650 | prob_true = bin_true / bin_total
651 | prob_pred = bin_sums / bin_total
652 |
653 | return prob_true, prob_pred
654 |
655 |
656 | def _check_binary_probabilistic_predictions(y_true, y_prob):
657 | """Check that y_true is binary and y_prob contains valid probabilities"""
658 | check_consistent_length(y_true, y_prob)
659 |
660 | labels = np.unique(y_true)
661 |
662 | if len(labels) != 2:
663 | raise ValueError(
664 | 'Only binary classification is supported. '
665 | 'Provided labels %s.' % labels
666 | )
667 |
668 | if y_prob.max() > 1:
669 | raise ValueError('y_prob contains values greater than 1.')
670 |
671 | if y_prob.min() < 0:
672 | raise ValueError('y_prob contains values less than 0.')
673 |
674 |     return label_binarize(y_true, classes=labels)[:, 0]
675 |
--------------------------------------------------------------------------------
/soccer_xg/features.py:
--------------------------------------------------------------------------------
1 | """A collection of feature generators."""
2 | import numpy as np
3 | import pandas as pd
4 | from socceraction.vaep.features import *
5 |
6 | _spadl_cfg = {
7 | 'length': 105,
8 | 'width': 68,
9 | 'penalty_box_length': 16.5,
10 | 'penalty_box_width': 40.3,
11 | 'six_yard_box_length': 5.5,
12 | 'six_yard_box_width': 18.3,
13 |     'penalty_spot_distance': 11,
14 |     'goal_width': 7.32,
16 | 'goal_length': 2,
17 | 'origin_x': 0,
18 | 'origin_y': 0,
19 | 'circle_radius': 9.15,
20 | }
21 |
22 |
23 | @simple
24 | def goalangle(actions, cfg=_spadl_cfg):
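25 |     # Angle subtended by the goal mouth as seen from the action's start
26 |     # location: arctan(w * dx / (dx^2 + dy^2 - (w/2)^2)), with w the goal
27 |     # width. Negative values are shifted into [0, pi]; shots from on or
28 |     # behind the goal line are special-cased below.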
25 | dx = cfg['length'] - actions['start_x']
26 | dy = cfg['width'] / 2 - actions['start_y']
27 | angledf = pd.DataFrame()
28 | angledf['shot_angle'] = np.arctan(
29 | cfg['goal_width']
30 | * dx
31 | / (dx ** 2 + dy ** 2 - (cfg['goal_width'] / 2) ** 2)
32 | )
33 | angledf.loc[angledf['shot_angle'] < 0, 'shot_angle'] += np.pi
34 | angledf.loc[(actions['start_x'] >= cfg['length']), 'shot_angle'] = 0
35 | # Ball is on the goal line
36 | angledf.loc[
37 | (actions['start_x'] == cfg['length'])
38 | & (
39 | actions['start_y'].between(
40 | cfg['width'] / 2 - cfg['goal_width'] / 2,
41 | cfg['width'] / 2 + cfg['goal_width'] / 2,
42 | )
43 | ),
44 | 'shot_angle',
45 | ] = np.pi
46 | return angledf
47 |
48 |
49 | def speed(gamestates):
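50 |     # For each earlier action a_i in the game state, estimate the ball speed
51 |     # between the end of a_i and the start of the current action a0; the
52 |     # elapsed time is clamped to at least 1s to avoid division blow-ups.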
50 | a0 = gamestates[0]
51 | spaced = pd.DataFrame()
52 | for i, a in enumerate(gamestates[1:]):
53 | dt = a0.time_seconds - a.time_seconds
54 | dt[dt < 1] = 1
55 | dx = a.end_x - a0.start_x
56 | spaced['speedx_a0' + (str(i + 1))] = dx.abs() / dt
57 | dy = a.end_y - a0.start_y
58 | spaced['speedy_a0' + (str(i + 1))] = dy.abs() / dt
59 | spaced['speed_a0' + (str(i + 1))] = np.sqrt(dx ** 2 + dy ** 2) / dt
60 | return spaced
61 |
62 |
63 | def _caley_shot_matrix(cfg=_spadl_cfg):
64 |     """Michael Caley's shot-matrix zones; see
65 |     https://cartilagefreecaptain.sbnation.com/2013/11/13/5098186/shot-matrix-i-shot-location-and-expected-goals
66 |     """
67 | m = (cfg['origin_y'] + cfg['width']) / 2
68 |
69 | zones = []
70 | # Zone 1 is the central area of the six-yard box
71 | x1 = cfg['origin_x'] + cfg['length'] - cfg['six_yard_box_length']
72 | x2 = cfg['origin_x'] + cfg['length']
73 | y1 = m - cfg['goal_width'] / 2
74 | y2 = m + cfg['goal_width'] / 2
75 | zones.append([(x1, y1, x2, y2)])
76 | # Zone 2 includes the wide areas, left and right, of the six-yard box.
77 | ## Left
78 | x1 = cfg['origin_x'] + cfg['length'] - cfg['six_yard_box_length']
79 | x2 = cfg['origin_x'] + cfg['length']
80 | y1 = m - cfg['six_yard_box_width'] / 2
81 | y2 = m - cfg['goal_width'] / 2
82 | zone_left = (x1, y1, x2, y2)
83 | ## Right
84 | x1 = cfg['origin_x'] + cfg['length'] - cfg['six_yard_box_length']
85 | x2 = cfg['origin_x'] + cfg['length']
86 | y1 = m + cfg['goal_width'] / 2
87 | y2 = m + cfg['six_yard_box_width'] / 2
88 | zone_right = (x1, y1, x2, y2)
89 | zones.append([zone_left, zone_right])
90 | # Zone 3 is the central area between the edges of the six- and eighteen-yard boxes.
91 | x1 = cfg['origin_x'] + cfg['length'] - cfg['penalty_box_length']
92 | x2 = cfg['origin_x'] + cfg['length'] - cfg['six_yard_box_length']
93 | y1 = m - cfg['six_yard_box_width'] / 2
94 | y2 = m + cfg['six_yard_box_width'] / 2
95 | zones.append([(x1, y1, x2, y2)])
96 | # Zone 4 comprises the wide areas in the eighteen-yard box, further from the endline than the six-yard box extended.
97 | ## Left
98 | x1 = cfg['origin_x'] + cfg['length'] - cfg['penalty_box_length']
99 | x2 = cfg['origin_x'] + cfg['length'] - cfg['six_yard_box_length'] - 2
100 | y1 = m - cfg['penalty_box_width'] / 2
101 | y2 = m - cfg['six_yard_box_width'] / 2
102 | zone_left = (x1, y1, x2, y2)
103 | ## Right
104 | x1 = cfg['origin_x'] + cfg['length'] - cfg['penalty_box_length']
105 | x2 = cfg['origin_x'] + cfg['length'] - cfg['six_yard_box_length'] - 2
106 | y1 = m + cfg['six_yard_box_width'] / 2
107 | y2 = m + cfg['penalty_box_width'] / 2
108 | zone_right = (x1, y1, x2, y2)
109 | zones.append([zone_left, zone_right])
110 | # Zone 5 includes the wide areas left and right in the eighteen yard box within the six-yard box extended.
111 | ## Left
112 | x1 = cfg['origin_x'] + cfg['length'] - cfg['six_yard_box_length'] - 2
113 | x2 = cfg['origin_x'] + cfg['length']
114 | y1 = m - cfg['penalty_box_width'] / 2
115 | y2 = m - cfg['six_yard_box_width'] / 2
116 | zone_left = (x1, y1, x2, y2)
117 | ## Right
118 | x1 = cfg['origin_x'] + cfg['length'] - cfg['six_yard_box_length'] - 2
119 | x2 = cfg['origin_x'] + cfg['length']
120 | y1 = m + cfg['six_yard_box_width'] / 2
121 | y2 = m + cfg['penalty_box_width'] / 2
122 | zone_right = (x1, y1, x2, y2)
123 | zones.append([zone_left, zone_right])
124 | # Zone 6 is the eighteen-yard box extended out to roughly 35 yards (=32m).
125 | x1 = cfg['origin_x'] + cfg['length'] - 32
126 | x2 = cfg['origin_x'] + cfg['length'] - cfg['penalty_box_length']
127 | y1 = m - cfg['penalty_box_width'] / 2
128 | y2 = m + cfg['penalty_box_width'] / 2
129 | zones.append([(x1, y1, x2, y2)])
130 | # Zone 7 is the deep, deep area beyond that
131 | x1 = cfg['origin_x']
132 | x2 = cfg['origin_x'] + cfg['length'] - 32
133 | y1 = cfg['origin_y']
134 | y2 = cfg['origin_y'] + cfg['width']
135 | zones.append([(x1, y1, x2, y2)])
136 | # Zone 8 comprises the regions right and left of the box.
137 | ## Left
138 | x1 = cfg['origin_x'] + cfg['length'] - 32
139 | x2 = cfg['origin_x'] + cfg['length']
140 |     y1 = m + cfg['penalty_box_width'] / 2
141 |     y2 = cfg['origin_y'] + cfg['width']
142 | zone_left = (x1, y1, x2, y2)
143 | ## Right
144 | x1 = cfg['origin_x'] + cfg['length'] - 32
145 | x2 = cfg['origin_x'] + cfg['length']
146 | y1 = cfg['origin_y']
147 | y2 = m - cfg['penalty_box_width'] / 2
148 | zone_right = (x1, y1, x2, y2)
149 | zones.append([zone_left, zone_right])
150 | return zones
151 |
152 |
153 | def _point_in_rect(rect):
154 | x1, y1, x2, y2 = rect
155 |
156 | def fn(point):
157 | x, y = point
158 | if x1 <= x and x <= x2:
159 | if y1 <= y and y <= y2:
160 | return True
161 | return False
162 |
163 | return fn
164 |
165 |
166 | def triangular_grid(name, angle_bins, dist_bins, symmetrical=False):
167 | @simple
168 | def fn(actions):
169 | zonedf = startpolar(actions)
170 | if symmetrical:
171 | zonedf.loc[
172 | zonedf.start_angle_to_goal_a0 > np.pi / 2,
173 | 'start_angle_to_goal_a0',
174 | ] -= (np.pi / 2)
175 | dist_bin = np.digitize(zonedf.start_dist_to_goal_a0, dist_bins)
176 | angle_bin = np.digitize(zonedf.start_angle_to_goal_a0, angle_bins)
177 |         zonedf[name] = dist_bin * (len(angle_bins) + 1) + angle_bin  # unique zone id
178 |         zonedf[name] = pd.Categorical(
179 |             zonedf[name],
180 |             categories=list(range((len(dist_bins) + 1) * (len(angle_bins) + 1))),
181 | ordered=False,
182 | )
183 | return zonedf[[name]]
184 |
185 | return fn
186 |
187 |
188 | def rectangular_grid(name, x_bins, y_bins, symmetrical=False, cfg=_spadl_cfg):
189 | @simple
190 | def fn(actions):
191 | zonedf = actions[['start_x', 'start_y']].copy()
192 | if symmetrical:
193 | m = (cfg['origin_y'] + cfg['width']) / 2
194 |             zonedf.loc[zonedf.start_y > m, 'start_y'] = 2 * m - zonedf.start_y  # mirror across the pitch axis
195 |         x_bin = np.digitize(zonedf.start_x, x_bins)
196 |         y_bin = np.digitize(zonedf.start_y, y_bins)
197 |         zonedf[name] = x_bin * (len(y_bins) + 1) + y_bin  # unique zone id
198 |         zonedf[name] = pd.Categorical(
199 |             zonedf[name],
200 |             categories=list(range((len(x_bins) + 1) * (len(y_bins) + 1))),
201 | ordered=False,
202 | )
203 | return zonedf[[name]]
204 |
205 | return fn
206 |
207 |
208 | def custom_grid(name, zones, is_in_zone):
209 | @simple
210 | def fn(actions):
211 | zonedf = actions[['start_x', 'start_y']].copy()
212 | zonedf[name] = [0] * len(actions) # zone 0 if no match
213 | for (i, zone) in enumerate(zones):
214 | for subzone in zone:
215 | zonedf.loc[
216 | np.apply_along_axis(
217 | is_in_zone(subzone),
218 | 1,
219 | zonedf[['start_x', 'start_y']].values,
220 | ),
221 | name,
222 | ] = (i + 1)
223 | zonedf[name] = pd.Categorical(
224 | zonedf[name], categories=list(range(len(zones) + 1)), ordered=False
225 | )
226 | return zonedf[[name]]
227 |
228 | return fn
229 |
230 |
231 | caley_grid = custom_grid('caley_zone', _caley_shot_matrix(), _point_in_rect)
232 |
233 |
234 | all_features = [
235 | actiontype,
236 | bodypart,
237 | result,
238 | startlocation,
239 | endlocation,
240 | movement,
241 | space_delta,
242 | startpolar,
243 | endpolar,
244 | team,
245 | time_delta,
246 | speed,
247 | goalangle,
248 | caley_grid,
249 | triangular_grid(
250 | 'angle_zone',
251 | [-50, -20, 20, 50],
252 | [2, 4, 8, 11, 16, 24, 34, 50],
253 | symmetrical=True,
254 | ),
255 | ]
256 |
--------------------------------------------------------------------------------
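To make the zone lookup concrete, here is a small sketch that classifies one shot location with the module-private helpers defined above (the point and the expected zone are illustrative):

```python
from soccer_xg.features import _caley_shot_matrix, _point_in_rect

zones = _caley_shot_matrix()
point = (100.0, 34.0)  # a central shot roughly 5 m from goal

zone_id = 0  # zone 0 means "no match", as in custom_grid
for i, zone in enumerate(zones):
    if any(_point_in_rect(subzone)(point) for subzone in zone):
        zone_id = i + 1
        break
print(zone_id)  # -> 1, the central area of the six-yard box
```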
/soccer_xg/metrics.py:
--------------------------------------------------------------------------------
1 | """A collection of metrics for evaluating xG models."""
2 | import numpy as np
3 | from scipy import integrate
4 | from sklearn.neighbors import KernelDensity
5 |
6 |
7 | def expected_calibration_error(y_true, y_prob, n_bins=5, strategy='uniform'):
8 | """Compute the Expected Calibration Error (ECE).
9 |
10 | This method implements equation (3) in [1], as well as the ACE variant in [2].
11 | In this equation the probability of the decided label being correct is
12 | used to estimate the calibration property of the predictor.
13 |
14 |     Note: a trade-off exists between using a small number of `n_bins` and the
15 |     estimation reliability of the ECE. In particular, this method may produce
16 |     unreliable ECE estimates when few samples are available in some bins.
17 |
18 | Parameters
19 | ----------
20 | y_true : array, shape (n_samples,)
21 | True targets.
22 | y_prob : array, shape (n_samples,)
23 | Probabilities of the positive class.
24 | n_bins : int, default=5
25 |         Number of bins to discretize the [0, 1] interval. A bigger
26 |         number requires more data. Bins with no samples (i.e. without
27 |         corresponding values in `y_prob`) do not contribute to the
28 |         error estimate.
29 | strategy : {'uniform', 'quantile'}, default='uniform'
30 | Strategy used to define the widths of the bins.
31 | uniform
32 | The bins have identical widths. This corresponds to the ECE formula.
33 | quantile
34 | The bins have the same number of samples and depend on `y_prob`. This
35 | corresponds to the ACE formula.
36 |
37 | Returns
38 | -------
39 | ece : float
40 | The expected calibration error.
41 |
42 | References
43 | ----------
44 | [1]: Chuan Guo, Geoff Pleiss, Yu Sun, Kilian Q. Weinberger,
45 | On Calibration of Modern Neural Networks.
46 | Proceedings of the 34th International Conference on Machine Learning
47 | (ICML 2017).
48 | arXiv:1706.04599
49 | https://arxiv.org/pdf/1706.04599.pdf
50 | [2]: Nixon, Jeremy, et al.,
51 | Measuring calibration in deep learning.
52 | arXiv:1904.01685
53 | https://arxiv.org/abs/1904.01685
54 |
55 | """
56 |
57 | if y_prob.shape != y_true.shape:
58 |         raise ValueError('y_true and y_prob must have the same shape.')
59 | if y_prob.min() < 0 or y_prob.max() > 1:
60 | raise ValueError('y_prob has values outside [0, 1].')
61 | labels = np.unique(y_true)
62 | if len(labels) > 2:
63 | raise ValueError('Only binary classification is supported.')
64 |
65 | if strategy == 'quantile': # Determine bin edges by distribution of data
66 | quantiles = np.linspace(0, 1, n_bins + 1)
67 | bins = np.percentile(y_prob, quantiles * 100)
68 | bins[-1] = bins[-1] + 1e-8
69 | elif strategy == 'uniform':
70 | bins = np.linspace(0.0, 1.0 + 1e-8, n_bins + 1)
71 | else:
72 | raise ValueError(
73 | "Invalid entry to 'strategy' input. Strategy "
74 | "must be either 'quantile' or 'uniform'."
75 | )
76 |
77 | n = y_prob.shape[0]
78 | accs, confs, counts = _reliability(y_true, y_prob, bins)
79 | return np.sum(counts * np.abs(accs - confs) / n)
80 |
81 |
82 | def _reliability(y_true, y_prob, bins):
83 | n_bins = len(bins) - 1
84 | accs = np.zeros(n_bins)
85 | confs = np.zeros(n_bins)
86 | counts = np.zeros(n_bins)
87 | for m in range(n_bins):
88 | low = bins[m]
89 | high = bins[m + 1]
90 |
91 | where_in_bin = (low <= y_prob) & (y_prob < high)
92 | if where_in_bin.sum() > 0:
93 | accs[m] = (
94 | np.sum((y_prob[where_in_bin] >= 0.5) == y_true[where_in_bin])
95 | / where_in_bin.sum()
96 | )
97 | confs[m] = np.mean(
98 | np.maximum(y_prob[where_in_bin], 1 - y_prob[where_in_bin])
99 | )
100 | counts[m] = where_in_bin.sum()
101 |
102 | return accs, confs, counts
103 |
104 |
105 | def bayesian_calibration_curve(y_true, y_pred, n_bins=100):
106 | """Compute true and predicted probabilities for a calibration curve using
107 | kernel density estimation instead of bins with a fixed width.
108 |
109 | Parameters
110 | ----------
111 | y_true : array-like of shape (n_samples,)
112 | True targets.
113 |     y_pred : array-like of shape (n_samples,)
114 |         Probabilities of the positive class.
115 |     n_bins : int, default=100
116 |         Inverse of the kernel bandwidth (bandwidth = 1 / n_bins). A bigger
117 |         number requires more data.
118 |
119 | Returns
120 | -------
121 |     sample_probabilities : ndarray of shape (99,)
122 |         The probabilities at which the curve is evaluated, expressed as
123 |         percentages.
124 |     predicted_pos_percents : ndarray of shape (99,)
125 |         The estimated percentage of positives at each sampled probability.
126 |     number_total : ndarray of shape (99,)
127 |         The estimated number of examples at each sampled probability.
128 | """
129 | y_true = np.array(y_true, dtype=bool)
130 | bandwidth = 1 / n_bins
131 | kde_pos = KernelDensity(kernel='gaussian', bandwidth=bandwidth).fit(
132 | (y_pred[y_true])[:, np.newaxis]
133 | )
134 | kde_total = KernelDensity(kernel='gaussian', bandwidth=bandwidth).fit(
135 | y_pred[:, np.newaxis]
136 | )
137 | sample_probabilities = np.linspace(0.01, 0.99, 99)
138 | number_density_offense_won = np.exp(
139 | kde_pos.score_samples(sample_probabilities[:, np.newaxis])
140 | ) * np.sum((y_true))
141 | number_density_total = np.exp(
142 | kde_total.score_samples(sample_probabilities[:, np.newaxis])
143 | ) * len(y_true)
144 | number_pos = (
145 | number_density_offense_won
146 | * np.sum(y_true)
147 | / np.sum(number_density_offense_won)
148 | )
149 | number_total = (
150 | number_density_total * len(y_true) / np.sum(number_density_total)
151 | )
152 |     predicted_pos_percents = np.nan_to_num(number_pos / number_total)  # empty bins map to 0
153 |
154 | return (
155 | 100.0 * sample_probabilities,
156 | 100.0 * predicted_pos_percents,
157 | number_total,
158 | )
159 |
160 |
161 | def max_deviation(sample_probabilities, predicted_pos_percents):
162 | """Compute the largest discrepancy between the model and expectation.
163 | """
164 | abs_deviations = np.abs(predicted_pos_percents - sample_probabilities)
165 | return np.max(abs_deviations)
166 |
167 |
168 | def residual_area(sample_probabilities, predicted_pos_percents):
169 | """Compute the total area under the curve of |predicted prob - expected prob|
170 | """
171 | abs_deviations = np.abs(predicted_pos_percents - sample_probabilities)
172 | return integrate.trapz(abs_deviations, sample_probabilities)
173 |
--------------------------------------------------------------------------------
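As a quick sanity check of the estimator above, synthetic predictions that are perfectly calibrated by construction should give an ECE (and ACE) close to zero; a minimal sketch:

```python
import numpy as np
from soccer_xg.metrics import expected_calibration_error

rng = np.random.default_rng(42)
y_prob = rng.uniform(0, 1, size=10_000)
# labels drawn with probability y_prob are perfectly calibrated by construction
y_true = (rng.uniform(0, 1, size=10_000) < y_prob).astype(int)

print(expected_calibration_error(y_true, y_prob, n_bins=10))                       # ECE, ~0
print(expected_calibration_error(y_true, y_prob, n_bins=10, strategy='quantile'))  # ACE, ~0
```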
/soccer_xg/ml/logreg.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.linear_model import LogisticRegression, SGDClassifier
3 | from sklearn.pipeline import Pipeline
4 | from soccer_xg.ml.preprocessing import simple_proc_for_linear_algoritms
5 |
6 |
7 | def logreg_gridsearch_classifier(
8 | numeric_features,
9 | categoric_features,
10 | learning_rate=0.08,
11 | use_dask=False,
12 | n_iter=100,
13 | scoring='roc_auc',
14 | ):
15 | """
16 |     Simple classification pipeline using randomized search to optimize logreg hyper-parameters
17 | Parameters
18 | ----------
19 | `numeric_features` : The list of numeric features
20 | `categoric_features` : The list of categoric features
21 | `learning_rate` : The learning rate
22 | """
23 |
24 | return _logreg_gridsearch_model(
25 | 'classification',
26 | numeric_features,
27 | categoric_features,
28 | learning_rate,
29 | use_dask,
30 | n_iter,
31 | scoring,
32 | )
33 |
34 |
35 | def logreg_gridsearch_regressor(
36 | numeric_features,
37 | categoric_features,
38 | learning_rate=0.08,
39 | use_dask=False,
40 | n_iter=100,
41 | scoring='roc_auc',
42 | ):
43 | """
44 |     Simple regression pipeline using randomized search to optimize logreg hyper-parameters
45 | Parameters
46 | ----------
47 | `numeric_features` : The list of numeric features
48 | `categoric_features` : The list of categoric features
49 | `learning_rate` : The learning rate
50 | """
51 |
52 | return _logreg_gridsearch_model(
53 | 'regression',
54 | numeric_features,
55 | categoric_features,
56 | learning_rate,
57 | use_dask,
58 | n_iter,
59 | scoring,
60 | )
61 |
62 |
63 | def _logreg_gridsearch_model(
64 | task,
65 | numeric_features,
66 | categoric_features,
67 | learning_rate,
68 | use_dask,
69 | n_iter,
70 | scoring,
71 | ):
72 |     if learning_rate is None:  # NOTE: `task` is currently ignored; both branches build classifiers
73 | param_space = {
74 | 'clf__C': np.logspace(-5, 5, 100),
75 | 'clf__class_weight': ['balanced', None],
76 | }
77 | model = LogisticRegression(max_iter=10000, fit_intercept=False)
78 | else:
79 | param_space = {
80 | 'clf__penalty': ['l1', 'l2'],
81 | 'clf__alpha': np.logspace(-5, 5, 100),
82 | 'clf__class_weight': ['balanced', None],
83 | }
84 | learning_rate_schedule = (
85 | 'constant' if isinstance(learning_rate, float) else learning_rate
86 | )
87 | eta0 = learning_rate if isinstance(learning_rate, float) else 0
88 | model = SGDClassifier(
89 | learning_rate=learning_rate_schedule,
90 | eta0=eta0,
91 | loss='log',
92 | max_iter=10000,
93 | fit_intercept=False,
94 | )
95 |
96 | pipe = Pipeline(
97 | [
98 | (
99 | 'preprocessing',
100 | simple_proc_for_linear_algoritms(
101 | numeric_features, categoric_features
102 | ),
103 | ),
104 | ('clf', model),
105 | ]
106 | )
107 |
108 | if use_dask:
109 | from dask_ml.model_selection import RandomizedSearchCV
110 |
111 | return RandomizedSearchCV(
112 | pipe, param_space, n_iter=n_iter, scoring=scoring, cv=5
113 | )
114 | else:
115 | from sklearn.model_selection import RandomizedSearchCV
116 |
117 | return RandomizedSearchCV(
118 | pipe, param_space, n_iter=n_iter, scoring=scoring, cv=5
119 | )
120 |
--------------------------------------------------------------------------------
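A usage sketch on a toy feature frame (column names and data are illustrative; with `learning_rate=None` the search tunes a plain `LogisticRegression`):

```python
import pandas as pd
from soccer_xg.ml.logreg import logreg_gridsearch_classifier

X = pd.DataFrame({'dist': [5.0, 20.0, 30.0] * 20,
                  'bodypart': ['foot', 'head', 'foot'] * 20})
y = pd.Series([1, 0, 0] * 20)

search = logreg_gridsearch_classifier(
    ['dist'], ['bodypart'], learning_rate=None, n_iter=5
)
search.fit(X, y)
print(search.best_params_)
```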
/soccer_xg/ml/mlp.py:
--------------------------------------------------------------------------------
1 | from scipy.stats.distributions import randint, uniform
2 | from sklearn.neural_network import MLPClassifier, MLPRegressor
3 | from sklearn.pipeline import Pipeline
4 |
5 | from .preprocessing import simple_proc_for_linear_algoritms
6 |
7 |
8 | def mlp_gridsearch_classifier(
9 | numeric_features,
10 | categoric_features,
11 | learning_rate=0.08,
12 | use_dask=False,
13 | n_iter=100,
14 | scoring='roc_auc',
15 | ):
16 | """
17 |     Simple classification pipeline using randomized search to optimize mlp hyper-parameters
18 | Parameters
19 | ----------
20 | `numeric_features` : The list of numeric features
21 | `categoric_features` : The list of categoric features
22 | `learning_rate` : The learning rate
23 | """
24 |
25 | return _mlp_gridsearch_model(
26 | 'classification',
27 | numeric_features,
28 | categoric_features,
29 | learning_rate,
30 | use_dask,
31 | n_iter,
32 | scoring,
33 | )
34 |
35 |
36 | def mlp_gridsearch_regressor(
37 | numeric_features,
38 | categoric_features,
39 | learning_rate=0.08,
40 | use_dask=False,
41 | n_iter=100,
42 | scoring='roc_auc',
43 | ):
44 | """
45 |     Simple regression pipeline using randomized search to optimize mlp hyper-parameters
46 | Parameters
47 | ----------
48 | `numeric_features` : The list of numeric features
49 | `categoric_features` : The list of categoric features
50 | `learning_rate` : The learning rate
51 | """
52 |
53 | return _mlp_gridsearch_model(
54 | 'regression',
55 | numeric_features,
56 | categoric_features,
57 | learning_rate,
58 | use_dask,
59 | n_iter,
60 | scoring,
61 | )
62 |
63 |
64 | def _mlp_gridsearch_model(
65 | task,
66 | numeric_features,
67 | categoric_features,
68 | learning_rate,
69 | use_dask,
70 | n_iter,
71 | scoring,
72 | ):
73 | param_space = {
74 | 'clf__hidden_layer_sizes': [
75 | (24,),
76 | (12, 12),
77 | (6, 6, 6, 6),
78 | (4, 4, 4, 4, 4, 4),
79 | (12, 6, 3, 3),
80 | ],
81 | 'clf__activation': ['relu', 'logistic', 'tanh'],
82 | 'clf__batch_size': [16, 32, 64, 128, 256, 512],
83 | 'clf__alpha': uniform(0.0001, 0.9),
84 | 'clf__learning_rate': ['constant', 'adaptive'],
85 | }
86 |
87 | model = (
88 | MLPClassifier(learning_rate_init=learning_rate)
89 | if task == 'classification'
90 | else MLPRegressor(learning_rate_init=learning_rate)
91 | )
92 |
93 | pipe = Pipeline(
94 | [
95 | (
96 | 'preprocessing',
97 | simple_proc_for_linear_algoritms(
98 | numeric_features, categoric_features
99 | ),
100 | ),
101 | ('clf', model),
102 | ]
103 | )
104 |
105 | if use_dask:
106 | from dask_ml.model_selection import RandomizedSearchCV
107 |
108 | return RandomizedSearchCV(
109 | pipe, param_space, n_iter=n_iter, scoring=scoring, cv=5
110 | )
111 | else:
112 | from sklearn.model_selection import RandomizedSearchCV
113 |
114 | return RandomizedSearchCV(
115 | pipe, param_space, n_iter=n_iter, scoring=scoring, cv=5
116 | )
117 |
--------------------------------------------------------------------------------
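Because the mlp factory mirrors the signature of the other factories in `soccer_xg.ml`, model families can be swapped behind a single call site; a sketch (the `make_search` helper is hypothetical):

```python
from soccer_xg.ml.mlp import mlp_gridsearch_classifier
from soccer_xg.ml.xgboost import xgboost_gridsearch_classifier

def make_search(kind, numeric_features, categoric_features, **kwargs):
    # hypothetical dispatch helper; all factories share one interface
    factories = {
        'mlp': mlp_gridsearch_classifier,
        'xgboost': xgboost_gridsearch_classifier,
    }
    return factories[kind](numeric_features, categoric_features, **kwargs)

search = make_search('mlp', ['dist'], ['bodypart'], n_iter=5)
```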
/soccer_xg/ml/pipeline.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import scipy.sparse as sp
4 | from scipy import stats
5 | from sklearn import clone
6 | from sklearn.base import BaseEstimator, TransformerMixin
7 | from sklearn.decomposition import TruncatedSVD
8 | from sklearn.model_selection import cross_val_predict
9 | from sklearn.preprocessing import LabelEncoder, MinMaxScaler
10 |
11 |
12 | class ColumnsSelector(BaseEstimator, TransformerMixin):
13 | def __init__(self, columns):
14 | assert isinstance(columns, list)
15 | self.columns = columns
16 |
17 | def fit(self, X, y=None):
18 | return self
19 |
20 | def transform(self, X):
21 | return X[self.columns]
22 |
23 |
24 | class TolerantLE(LabelEncoder):
25 | def transform(self, y):
26 | return np.searchsorted(self.classes_, y)
27 |
28 |
29 | class UniqueCountColumnSelector(BaseEstimator, TransformerMixin):
30 | """
31 | To select those columns whose unique-count values are between
32 | lowerbound (inclusive) and upperbound (exclusive)
33 | """
34 |
35 | def __init__(self, lowerbound, upperbound):
36 | self.lowerbound = lowerbound
37 | self.upperbound = upperbound
38 |
39 | def fit(self, X, y=None):
40 | counts = X.apply(lambda vect: vect.unique().shape[0])
41 |         self.columns = counts.index[
42 |             (counts >= self.lowerbound) & (counts < self.upperbound)
43 |         ]
44 | return self
45 |
46 | def transform(self, X):
47 | return X[self.columns]
48 |
49 |
50 | class ColumnApplier(BaseEstimator, TransformerMixin):
51 | """
52 | Some sklearn transformers can apply only on ONE column at a time
53 | Wrap them with ColumnApplier to apply on all the dataset
54 | """
55 |
56 | def __init__(self, underlying):
57 | self.underlying = underlying
58 |
59 | def fit(self, X, y=None):
60 | m = {}
61 | X = pd.DataFrame(X) # TODO: :( reimplement in pure numpy?
62 | for c in X.columns:
63 | k = clone(self.underlying)
64 | k.fit(X[c])
65 | m[c] = k
66 | self._column_stages = m
67 | return self
68 |
69 | def transform(self, X):
70 | ret = {}
71 | X = pd.DataFrame(X)
72 | for c, k in self._column_stages.items():
73 | ret[c] = k.transform(X[c])
74 | return pd.DataFrame(ret)[X.columns] # keep the same order
75 |
76 |
77 | class OrdinalEncoder(BaseEstimator, TransformerMixin):
78 | """
79 |     Encode categorical values as natural numbers, ordered by descending frequency.
80 |     N/A values are encoded to -2, rare values (below `min_support`) to -1,
81 |     and values unseen during fit to -3.
82 |     Very similar to TolerantLabelEncoder
83 |     TODO: improve the implementation
84 | """
85 |
86 | def __init__(self, min_support):
87 | self.min_support = min_support
88 | self.vc = {}
89 |
90 | def _mapping(self, vc):
91 | mapping = {}
92 | for i, v in enumerate(vc[vc >= self.min_support].index):
93 | mapping[v] = i
94 | for v in vc.index[vc < self.min_support]:
95 | mapping[v] = -1
96 | mapping['nan'] = -2
97 | return mapping
98 |
99 | def _transform_column(self, x):
100 | x = x.astype(str)
101 | vc = self.vc[x.name]
102 |
103 | mapping = self._mapping(vc)
104 |
105 | output = pd.DataFrame()
106 | output[x.name] = x.map(
107 | lambda a: mapping[a] if a in mapping.keys() else -3
108 | )
109 | output.index = x.index
110 | return output.astype(int)
111 |
112 | def fit(self, x, y=None):
113 | x = x.astype(str)
114 | self.vc = dict((c, x[c].value_counts()) for c in x.columns)
115 | return self
116 |
117 | def transform(self, df):
118 | if len(df[df.index.duplicated()]):
119 | print(df[df.index.duplicated()].index)
120 | raise ValueError('Input contains duplicate index')
121 | dfs = [self._transform_column(df[c]) for c in df.columns]
122 | out = pd.DataFrame(index=df.index)
123 |         for col_df in dfs:  # avoid shadowing the input frame
124 |             out = out.join(col_df)
125 | return out.values
126 |
127 |
128 | class CountFrequencyEncoder(BaseEstimator, TransformerMixin):
129 | """
130 | Encode the value by their frequency observed in the training set
131 | """
132 |
133 | def __init__(self, min_card=5, count_na=False):
134 | self.min_card = min_card
135 | self.count_na = count_na
136 | self.vc = None
137 |
138 | def fit(self, x, y=None):
139 | x = pd.Series(x)
140 | vc = x.value_counts()
141 | self.others_count = vc[vc < self.min_card].sum()
142 | self.vc = vc[vc >= self.min_card].to_dict()
143 | self.num_na = x.isnull().sum()
144 | return self
145 |
146 | def transform(self, x):
147 | vc = self.vc
148 | output = x.map(lambda a: vc.get(a, self.others_count))
149 | if self.count_na:
150 | output = output.fillna(self.num_na)
151 | return output.values
152 |
153 |
154 | class BoxCoxTransformer(BaseEstimator, TransformerMixin):
155 | """
156 | Boxcox transformation for numerical columns
157 | To make them more Gaussian-like
158 | """
159 |
160 | def __init__(self):
161 | self.scaler = MinMaxScaler()
162 | self.shift = 0.0001
163 |
164 | def fit(self, x, y=None):
165 | x = x.values.reshape(-1, 1)
166 | x = self.scaler.fit_transform(x) + self.shift
167 |         self.boxcox_lmbda = stats.boxcox(x.ravel())[1]  # boxcox expects 1-D input
168 | return self
169 |
170 | def transform(self, x):
171 | x = x.values.reshape(-1, 1)
172 | scaled = np.maximum(self.shift, self.scaler.transform(x) + self.shift)
173 | ret = stats.boxcox(scaled, self.boxcox_lmbda)
174 | return ret[:, 0]
175 |
176 |
177 | class Logify(BaseEstimator, TransformerMixin):
178 | """
179 | Log transformation
180 | """
181 |
182 | def __init__(self):
183 | self.shift = 2
184 |
185 | def fit(self, x, y=None):
186 | return self
187 |
188 | def transform(self, x):
189 | return np.log10(x - x.min() + self.shift)
190 |
191 |
192 | class YToLog(BaseEstimator, TransformerMixin):
193 | """
194 | Transforming Y to log before fitting
195 | and transforming back the prediction to real values before return
196 | """
197 |
198 | def __init__(self, delegate, shift=0):
199 | self.delegate = delegate
200 | self.shift = shift
201 |
202 | def fit(self, X, y):
203 | logy = np.log(y + self.shift)
204 | self.delegate.fit(X, logy)
205 | return self
206 |
207 | def predict(self, X):
208 | pred = self.delegate.predict(X)
209 | return np.exp(pred) - self.shift
210 |
211 |
212 | class FillNaN(BaseEstimator, TransformerMixin):
213 | def __init__(self, replace):
214 | self.replace = replace
215 |
216 | def fit(self, x, y=None):
217 | return self
218 |
219 | def transform(self, x):
220 | return x.fillna(self.replace)
221 |
222 |
223 | class AsString(BaseEstimator, TransformerMixin):
224 | def fit(self, x, y=None):
225 | return self
226 |
227 | def transform(self, x):
228 | return x.astype(str)
229 |
230 |
231 | class StackedModel(BaseEstimator, TransformerMixin):
232 | def __init__(self, delegate, cv=5, method='predict_proba'):
233 | self.delegate = delegate
234 | self.cv = cv
235 | self.method = method
236 |
237 | def fit(self, X, y):
238 |         raise NotImplementedError('use fit_transform, which stacks out-of-fold predictions')
239 |
240 | def fit_transform(self, X, y):
241 | a = cross_val_predict(
242 | self.delegate, X, y, cv=self.cv, method=self.method
243 | )
244 | self.delegate.fit(X, y)
245 | if len(a.shape) == 1:
246 | a = a.reshape(-1, 1)
247 | return a
248 |
249 | def transform(self, X):
250 | if self.method == 'predict_proba':
251 | return self.delegate.predict_proba(X)
252 | else:
253 | return self.delegate.predict(X).reshape(-1, 1)
254 |
255 |
256 | class To1D(BaseEstimator, TransformerMixin):
257 | def fit(self, X, y=None):
258 | return self
259 |
260 | def transform(self, X):
261 | return X.values.reshape(-1)
262 |
263 |
264 | class TolerantLabelEncoder(TransformerMixin):
265 | """ LabelEncoder is not tolerant to unseen values
266 | """
267 |
268 | def __init__(self, min_count=10):
269 | self.min_count = min_count
270 |
271 | def fit(self, x, y=None):
272 | assert len(x.shape) == 1
273 | vc = x.value_counts()
274 | vc = vc[vc > self.min_count]
275 | self.values = {
276 | value: (1 + index) for index, value in enumerate(vc.index)
277 | }
278 | return self
279 |
280 | def transform(self, x):
281 | values = self.values
282 | return x.map(lambda a: values.get(a, 0))
283 |
284 | def inverse_transform(self, y):
285 | if not hasattr(self, 'inversed_mapping'):
286 | self.inversed_mapping = {v: k for k, v in self.values.items()}
287 | self.inversed_mapping[0] = None
288 | mapping = self.inversed_mapping
289 | return pd.Series(y).map(lambda a: mapping[a])
290 |
291 |
292 | class SVD_Embedding(TransformerMixin):
293 | def __init__(
294 |         self, rowname, colname, valuename=None, svd_kwargs=None
295 | ):
296 | self.rowname = rowname
297 | self.colname = colname
298 | self.valuename = valuename
299 | self.rowle = TolerantLabelEncoder()
300 | self.colle = TolerantLabelEncoder()
301 |         self.svd = TruncatedSVD(**(svd_kwargs or {'n_components': 10}))
302 |
303 | def fit(self, X, y=None):
304 | row = self.rowle.fit_transform(X[self.rowname])
305 | col = self.colle.fit_transform(X[self.colname])
306 | if not self.valuename:
307 | data = np.ones(X.shape[0])
308 | else:
309 | data = X[self.valuename].groupby([row, col]).mean()
310 | row = data.index.get_level_values(0).values
311 | col = data.index.get_level_values(1).values
312 | data = data.values
313 | matrix = sp.coo_matrix((data, (row, col)))
314 | self.embedding = self.svd.fit_transform(matrix)
315 | return self
316 |
317 | def transform(self, X):
318 | row = self.rowle.transform(X[self.rowname])
319 | return self.embedding[row, :]
320 |
--------------------------------------------------------------------------------
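The encoding conventions of `OrdinalEncoder` are easiest to see on a toy frame; a sketch of the expected output (data is illustrative):

```python
import numpy as np
import pandas as pd
from soccer_xg.ml.pipeline import OrdinalEncoder

train = pd.DataFrame({'foot': ['left'] * 6 + ['right'] * 5 + ['other', np.nan]})
test = pd.DataFrame({'foot': ['left', 'right', 'other', np.nan, 'knee']})

enc = OrdinalEncoder(min_support=5).fit(train)
# frequent values -> 0, 1, ...; rare -> -1; NaN -> -2; unseen -> -3
print(enc.transform(test).ravel())  # -> [ 0  1 -1 -2 -3]
```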
/soccer_xg/ml/preprocessing.py:
--------------------------------------------------------------------------------
1 | import category_encoders as ce
2 | from sklearn.impute import SimpleImputer
3 | from sklearn.pipeline import make_pipeline, make_union
4 | from sklearn.preprocessing import OneHotEncoder, StandardScaler
5 |
6 | from .pipeline import AsString, ColumnsSelector, OrdinalEncoder
7 |
8 |
9 | def simple_proc_for_tree_algoritms(numeric_features, categoric_features):
10 | """
11 | Create a simple preprocessing pipeline for tree based algorithms
12 | """
13 |
14 | catpipe = make_pipeline(
15 | ColumnsSelector(categoric_features),
16 | OrdinalEncoder(min_support=5)
17 | # ColumnApplier(FillNaN('nan')),
18 | # ColumnApplier(TolerantLabelEncoder())
19 | )
20 | numpipe = make_pipeline(
21 | ColumnsSelector(numeric_features),
22 | SimpleImputer(strategy='mean'),
23 | StandardScaler(),
24 | )
25 | if numeric_features and categoric_features:
26 | return make_union(catpipe, numpipe)
27 | elif numeric_features:
28 | return numpipe
29 | elif categoric_features:
30 | return catpipe
31 |     raise ValueError('Both feature lists are empty')
32 |
33 |
34 | def simple_proc_for_linear_algoritms(numeric_features, categoric_features):
35 | """
36 | Create a simple preprocessing pipeline for linear algorithms
37 | """
38 |
39 | catpipe = make_pipeline(
40 | ColumnsSelector(categoric_features),
41 | AsString(),
42 | ce.OneHotEncoder()
43 | # ColumnApplier(FillNaN('nan')),
44 | # ColumnApplier(TolerantLabelEncoder())
45 | )
46 | numpipe = make_pipeline(
47 | ColumnsSelector(numeric_features),
48 | SimpleImputer(strategy='mean'),
49 | StandardScaler(),
50 | )
51 | if numeric_features and categoric_features:
52 | return make_union(catpipe, numpipe)
53 | elif numeric_features:
54 | return numpipe
55 | elif categoric_features:
56 | return catpipe
57 |     raise ValueError('Both feature lists are empty')
58 |
--------------------------------------------------------------------------------
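The builders return ordinary scikit-learn transformers, so they can be fitted and inspected on their own; a sketch on a toy frame:

```python
import pandas as pd
from soccer_xg.ml.preprocessing import simple_proc_for_tree_algoritms

X = pd.DataFrame({'dist': [5.0, 20.0, None],
                  'bodypart': ['foot', 'head', 'foot']})
proc = simple_proc_for_tree_algoritms(['dist'], ['bodypart'])
# one ordinal-encoded column + one imputed-and-scaled numeric column
print(proc.fit_transform(X).shape)  # -> (3, 2)
```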
/soccer_xg/ml/tree_based_LR.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.base import BaseEstimator, ClassifierMixin
3 | from sklearn.preprocessing import OneHotEncoder
4 |
5 |
6 | # class for the tree-based/logistic regression pipeline
7 | # see: https://gdmarmerola.github.io/probability-calibration/
8 | class TreeBasedLR(BaseEstimator, ClassifierMixin):
9 |
10 | # initialization
11 | def __init__(self, forest, lr):
12 |
13 | # configuring the models
14 | self.forest = forest
15 | self.lr = lr
16 |
17 | # method for fitting the model
18 |     def fit(self, X, y, sample_weight=None, fit_params=None):
19 |
20 | self.classes_ = np.unique(y)
21 |
22 | # first, we fit our tree-based model on the dataset
23 |         self.forest.fit(X, y, **(fit_params or {}))
24 |
25 |         # then, we apply the model to the data in order to get the leaf indexes
26 | # if self.forest_model == 'cat':
27 | # leaves = self.forest.calc_leaf_indexes(X)
28 | # else:
29 | leaves = self.forest.named_steps['clf'].apply(
30 | self.forest.named_steps['preprocessing'].transform(X)
31 | )
32 |
33 |         # then, we one-hot encode the leaf indexes so we can use them in the logistic regression
34 | self.encoder = OneHotEncoder(sparse=True)
35 | leaves_encoded = self.encoder.fit_transform(leaves)
36 |
37 | # and fit it to the encoded leaves
38 |         self.lr.fit(leaves_encoded, y)
39 |         return self  # sklearn convention: fit returns the estimator
40 | # method for predicting probabilities
41 | def predict_proba(self, X):
42 |
43 |         # then, we apply the model to the data in order to get the leaf indexes
44 | # if self.forest_model == 'cat':
45 | # leaves = self.forest.calc_leaf_indexes(X)
46 | # else:
47 | leaves = self.forest.named_steps['clf'].apply(
48 | self.forest.named_steps['preprocessing'].transform(X)
49 | )
50 |
51 |         # then, we one-hot encode the leaf indexes so we can use them in the logistic regression
52 |         leaves_encoded = self.encoder.transform(leaves)
53 |
54 |         # and predict probabilities from the encoded leaves
55 |         y_hat = self.lr.predict_proba(leaves_encoded)
56 |
57 |         # returning probabilities
58 | return y_hat
59 |
60 | # get_params, needed for sklearn estimators
61 | def get_params(self, deep=True):
62 | return {
63 | 'forest': self.forest,
64 | 'lr': self.lr,
65 | }
66 |
--------------------------------------------------------------------------------
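A usage sketch: `forest` must be a Pipeline with steps named 'preprocessing' and 'clf', where the classifier exposes `.apply` (e.g. a random forest). Data and hyper-parameters below are illustrative, and the internal `OneHotEncoder(sparse=True)` assumes the scikit-learn versions this repo targets:

```python
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from soccer_xg.ml.tree_based_LR import TreeBasedLR

X = pd.DataFrame({'dist': [5.0, 25.0] * 20, 'angle': [1.2, 0.3] * 20})
y = pd.Series([1, 0] * 20)

forest = Pipeline([
    ('preprocessing', StandardScaler()),
    ('clf', RandomForestClassifier(n_estimators=20, random_state=0)),
])
model = TreeBasedLR(forest, LogisticRegression(max_iter=1000))
model.fit(X, y)
print(model.predict_proba(X)[:3, 1])
```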
/soccer_xg/ml/xgboost.py:
--------------------------------------------------------------------------------
1 | from scipy.stats.distributions import randint, uniform
2 | from sklearn.pipeline import Pipeline
3 | from xgboost import sklearn as xgbsk
4 |
5 | from .preprocessing import simple_proc_for_tree_algoritms
6 |
7 |
8 | def xgboost_gridsearch_classifier(
9 | numeric_features,
10 | categoric_features,
11 | learning_rate=0.08,
12 | use_dask=False,
13 | n_iter=100,
14 | scoring='roc_auc',
15 | ):
16 | """
17 |     Simple classification pipeline using randomized search to optimize xgboost hyper-parameters
18 | Parameters
19 | ----------
20 | `numeric_features` : The list of numeric features
21 | `categoric_features` : The list of categoric features
22 | `learning_rate` : The learning rate
23 | """
24 |
25 | return _xgboost_gridsearch_model(
26 | 'classification',
27 | numeric_features,
28 | categoric_features,
29 | learning_rate,
30 | use_dask,
31 | n_iter,
32 | scoring,
33 | )
34 |
35 |
36 | def xgboost_gridsearch_regressor(
37 | numeric_features,
38 | categoric_features,
39 | learning_rate=0.08,
40 | use_dask=False,
41 | n_iter=100,
42 | scoring='roc_auc',
43 | ):
44 | """
45 |     Simple regression pipeline using randomized search to optimize xgboost hyper-parameters
46 | Parameters
47 | ----------
48 | `numeric_features` : The list of numeric features
49 | `categoric_features` : The list of categoric features
50 | `learning_rate` : The learning rate
51 | """
52 |
53 | return _xgboost_gridsearch_model(
54 | 'regression',
55 | numeric_features,
56 | categoric_features,
57 | learning_rate,
58 | use_dask,
59 | n_iter,
60 | scoring,
61 | )
62 |
63 |
64 | def _xgboost_gridsearch_model(
65 | task,
66 | numeric_features,
67 | categoric_features,
68 | learning_rate,
69 | use_dask,
70 | n_iter,
71 | scoring,
72 | ):
73 | param_space = {
74 | 'clf__max_depth': randint(2, 11),
75 | 'clf__min_child_weight': randint(1, 11),
76 | 'clf__subsample': uniform(0.5, 0.5),
77 | 'clf__colsample_bytree': uniform(0.5, 0.5),
78 | 'clf__colsample_bylevel': uniform(0.5, 0.5),
79 | 'clf__gamma': uniform(0, 1),
80 | 'clf__reg_alpha': uniform(0, 1),
81 | 'clf__reg_lambda': uniform(0, 10),
82 | 'clf__base_score': uniform(0.1, 0.9),
83 | 'clf__scale_pos_weight': uniform(0.1, 9.9),
84 | }
85 |
86 | model = (
87 | xgbsk.XGBClassifier(learning_rate=learning_rate)
88 | if task == 'classification'
89 | else xgbsk.XGBRegressor(learning_rate=learning_rate)
90 | )
91 |
92 | pipe = Pipeline(
93 | [
94 | (
95 | 'preprocessing',
96 | simple_proc_for_tree_algoritms(
97 | numeric_features, categoric_features
98 | ),
99 | ),
100 | ('clf', model),
101 | ]
102 | )
103 |
104 | if use_dask:
105 | from dask_ml.model_selection import RandomizedSearchCV
106 |
107 | return RandomizedSearchCV(
108 | pipe, param_space, n_iter=n_iter, scoring=scoring, cv=5
109 | )
110 | else:
111 | from sklearn.model_selection import RandomizedSearchCV
112 |
113 | return RandomizedSearchCV(
114 | pipe, param_space, n_iter=n_iter, scoring=scoring, cv=5
115 | )
116 |
--------------------------------------------------------------------------------
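One detail of the search space worth spelling out: scipy's `uniform(loc, scale)` samples from `[loc, loc + scale]`, so `uniform(0.5, 0.5)` draws subsampling ratios in [0.5, 1.0] and `uniform(0.1, 9.9)` draws `scale_pos_weight` in [0.1, 10.0]:

```python
from scipy.stats.distributions import uniform

subsample = uniform(0.5, 0.5)
print(subsample.ppf([0.0, 1.0]))  # -> [0.5 1. ], i.e. the sampled interval
```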
/soccer_xg/models/openplay_logreg_advanced:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ML-KULeuven/soccer_xg/b9489d929e0fa34771267d429256366d0fda27ad/soccer_xg/models/openplay_logreg_advanced
--------------------------------------------------------------------------------
/soccer_xg/models/openplay_logreg_basic:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ML-KULeuven/soccer_xg/b9489d929e0fa34771267d429256366d0fda27ad/soccer_xg/models/openplay_logreg_basic
--------------------------------------------------------------------------------
/soccer_xg/models/openplay_xgboost_advanced:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ML-KULeuven/soccer_xg/b9489d929e0fa34771267d429256366d0fda27ad/soccer_xg/models/openplay_xgboost_advanced
--------------------------------------------------------------------------------
/soccer_xg/models/openplay_xgboost_basic:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ML-KULeuven/soccer_xg/b9489d929e0fa34771267d429256366d0fda27ad/soccer_xg/models/openplay_xgboost_basic
--------------------------------------------------------------------------------
/soccer_xg/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import pandas as pd
4 | import socceraction.spadl.config as spadlcfg
5 | from fuzzywuzzy import fuzz
6 |
7 |
8 | def play_left_to_right(actions, home_team_id):
9 | away_idx = actions.team_id != home_team_id
10 | for col in ['start_x', 'end_x']:
11 | actions.loc[away_idx, col] = (
12 | spadlcfg.field_length - actions.loc[away_idx][col].values
13 | )
14 | for col in ['start_y', 'end_y']:
15 | actions.loc[away_idx, col] = (
16 | spadlcfg.field_width - actions.loc[away_idx][col].values
17 | )
18 | return actions
19 |
20 |
21 | def enhance_actions(actions):
22 | # data
23 | actiontypes = pd.DataFrame(
24 | list(enumerate(spadlcfg.actiontypes)), columns=['type_id', 'type_name']
25 | )
26 |
27 | bodyparts = pd.DataFrame(
28 | list(enumerate(spadlcfg.bodyparts)),
29 | columns=['bodypart_id', 'bodypart_name'],
30 | )
31 |
32 | results = pd.DataFrame(
33 | list(enumerate(spadlcfg.results)), columns=['result_id', 'result_name']
34 | )
35 |
36 | return (
37 | actions.merge(actiontypes, how='left')
38 | .merge(results, how='left')
39 | .merge(bodyparts, how='left')
40 | # .sort_values(["period_id", "time_seconds", "timestamp"])
41 | )
42 |
43 |
44 | def match_name(name, list_names, min_score=0):
45 |     # -1 score in case we don't get any matches
46 | max_score = -1
47 | # Returning empty name for no match as well
48 | max_name = ''
49 |     # Iterating over all names in the other list
50 | for name2 in list_names:
51 | # Finding fuzzy match score
52 | score = fuzz.ratio(name, name2)
53 | # Checking if we are above our threshold and have a better score
54 | if (score > min_score) & (score > max_score):
55 | max_name = name2
56 | max_score = score
57 | return (max_name, max_score)
58 |
59 |
60 | def map_names(
61 | df1,
62 | df1_match_colname,
63 | df1_output_colname,
64 | df2,
65 | df2_match_colname,
66 | df2_output_colname,
67 | threshold=75,
68 | ):
69 | # List for dicts for easy dataframe creation
70 | dict_list = []
71 | for _, (id, name) in df1[
72 | [df1_output_colname, df1_match_colname]
73 | ].iterrows():
74 | # Use our method to find best match, we can set a threshold here
75 | match = match_name(name, df2[df2_match_colname], threshold)
76 | # New dict for storing data
77 | dict_ = {}
78 | dict_.update({'df1_name': name})
79 | dict_.update({'df1_id': id})
80 | if match[1] > threshold:
81 | dict_.update({'df2_name': match[0]})
82 | dict_.update(
83 | {
84 | 'df2_id': df2.loc[
85 | df2[df2_match_colname] == match[0], df2_output_colname
86 | ].iloc[0]
87 | }
88 | )
89 | else:
90 | dict_.update({'df2_name': 'unknown'})
91 | dict_.update({'df2_id': 0})
92 | dict_list.append(dict_)
93 | merge_table = pd.DataFrame(dict_list)
94 | return merge_table
95 |
96 |
97 | def get_matching_game(api, game_id, provider, other_provider, teams_mapping):
98 | season_id = str(api[provider].games.loc[game_id, 'season_id'])
99 | competition_id = api[provider].games.loc[game_id, 'competition_id']
100 | # Get matching game
101 | home_team, away_team = api[provider].get_home_away_team_id(game_id)
102 | other_home_team = teams_mapping.set_index(f'{provider}_id').loc[
103 | home_team, f'{other_provider}_id'
104 | ]
105 | other_away_team = teams_mapping.set_index(f'{provider}_id').loc[
106 | away_team, f'{other_provider}_id'
107 | ]
108 | other_games = api[other_provider].games
109 | other_game_id = (
110 | other_games[
111 | (other_games.home_team_id == other_home_team)
112 | & (other_games.away_team_id == other_away_team)
113 | & (other_games.competition_id == competition_id)
114 | & (other_games.season_id.astype(str) == season_id)
115 | ]
116 | .iloc[0]
117 | .name
118 | )
119 | return other_game_id
120 |
121 |
122 | def get_matching_shot(
123 | api,
124 | shot,
125 | provider_shot,
126 | other_shots,
127 | provider_other_shots,
128 | teams_mapping,
129 | players_mapping=None,
130 | ):
131 | # Get matching game
132 | game_id = shot.game_id
133 | season_id = str(api[provider_shot].games.loc[game_id, 'season_id'])
134 | competition_id = api[provider_shot].games.loc[game_id, 'competition_id']
135 | home_team, away_team = api[provider_shot].get_home_away_team_id(game_id)
136 | other_home_team = teams_mapping.set_index(f'{provider_shot}_id').loc[
137 | home_team, f'{provider_other_shots}_id'
138 | ]
139 | other_away_team = teams_mapping.set_index(f'{provider_shot}_id').loc[
140 | away_team, f'{provider_other_shots}_id'
141 | ]
142 | other_games = api[provider_other_shots].games
143 | other_game_id = (
144 | other_games[
145 | (other_games.home_team_id == other_home_team)
146 | & (other_games.away_team_id == other_away_team)
147 | & (other_games.competition_id == competition_id)
148 | & (other_games.season_id.astype(str) == season_id)
149 | ]
150 | .iloc[0]
151 | .name
152 | )
153 | other_shots_in_game = other_shots[other_shots.game_id == other_game_id]
154 | # Get matching shot-taker
155 | if players_mapping is not None:
156 | player_id = shot.player_id
157 | other_player_id = players_mapping.set_index(f'{provider_shot}_id').loc[
158 | int(player_id), f'{provider_other_shots}_id'
159 | ]
160 | other_shots_by_player = other_shots_in_game[
161 | other_shots_in_game.player_id == other_player_id
162 | ]
163 | else:
164 | other_shots_by_player = other_shots_in_game
165 | # Get shots in same period
166 | period_id = shot.period_id
167 | other_shots_by_player_in_period = other_shots_by_player[
168 | other_shots_by_player.period_id == period_id
169 | ]
170 | # Get shots that happened around the same time
171 | ts = shot.time_seconds
172 | best_match = other_shots_by_player_in_period.iloc[
173 | (other_shots_by_player_in_period['time_seconds'] - ts)
174 | .abs()
175 | .argsort()[:1]
176 | ].iloc[0]
177 | if abs(ts - best_match.time_seconds) < 3:
178 | return best_match
179 | return None
180 |
181 |
182 | def sample_temporal(api, size_val=0.0, size_test=0.2):
183 | game_ids = api.games.sort_values(by='game_date').index.values
184 | nb_games = len(game_ids)
185 | games_train = game_ids[
186 | 0 : math.floor((1 - size_val - size_test) * nb_games)
187 | ]
188 | games_val = game_ids[
189 | math.ceil((1 - size_val - size_test) * nb_games) : math.floor(
190 | (1 - size_test) * nb_games
191 | )
192 | ]
193 |     games_test = game_ids[math.ceil((1 - size_test) * nb_games) :]
194 | return games_train, games_val, games_test
195 |
--------------------------------------------------------------------------------
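A sketch of fuzzy-matching team names between two providers with `map_names` (frames and column names are illustrative):

```python
import pandas as pd
from soccer_xg.utils import map_names

statsbomb = pd.DataFrame({'team_id': [1, 2],
                          'team_name': ['Arsenal', 'Chelsea']})
wyscout = pd.DataFrame({'teamId': [10, 20],
                        'name': ['Arsenal FC', 'Chelsea FC']})

mapping = map_names(statsbomb, 'team_name', 'team_id',
                    wyscout, 'name', 'teamId', threshold=75)
print(mapping[['df1_name', 'df1_id', 'df2_name', 'df2_id']])
```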
/soccer_xg/visualisation.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import matplotsoccer as mps
3 | import numpy as np
4 | import numpy.ma as ma
5 | from matplotlib.ticker import MultipleLocator
6 | from sklearn.metrics import auc, roc_curve
7 | from soccer_xg import metrics
8 |
9 |
10 | def plot_calibration_curve(
11 | y_true,
12 | y_pred,
13 | name='Calibration curve',
14 | min_samples=None,
15 | axis=None,
16 | **kwargs,
17 | ):
18 |     """Plot a Bayesian calibration curve for a set of predictions.
19 |
20 |     Parameters
21 |     ----------
22 |     y_true : array-like of shape (n_samples,)
23 |         True targets.
24 |     y_pred : array-like of shape (n_samples,)
25 |         Predicted probabilities of the positive class.
26 |     min_samples : int or ``None`` (default=``None``)
27 |         Mask bins containing fewer than ``min_samples`` shots.
28 |     axis : matplotlib.pyplot.axis object or ``None`` (default=``None``)
29 |         If provided, the curve is overlaid on ``axis``; otherwise a new figure and axis are created.
30 |     **kwargs
31 |         Arguments to ``axis.plot``.
32 |
33 |     Returns
34 |     -------
35 |     matplotlib.pyplot.axis
36 |         The axis the plot was made on.
37 |     """
38 |
39 | if axis is None:
40 | axis = plt.figure(figsize=(5, 5)).add_subplot(111)
41 |
42 | axis.set_title(name)
43 | axis.plot([0, 100], [0, 100], ls='--', lw=1, color='grey')
44 | axis.set_xlabel('Predicted probability')
45 | axis.set_ylabel('True probability in each bin')
46 | axis.set_xlim((0, 100))
47 | axis.xaxis.set_major_locator(MultipleLocator(20))
48 | axis.xaxis.set_minor_locator(MultipleLocator(10))
49 | axis.set_ylim((0, 100))
50 | axis.yaxis.set_major_locator(MultipleLocator(20))
51 | axis.yaxis.set_minor_locator(MultipleLocator(10))
52 | # axis.set_aspect(1)
53 | axis.grid(which='both')
54 |
55 | (
56 | sample_probabilities,
57 | predicted_pos_percents,
58 | num_plays_used,
59 | ) = metrics.bayesian_calibration_curve(y_true, y_pred)
60 |
61 | if min_samples is not None:
62 | axis.plot(
63 | sample_probabilities,
64 | predicted_pos_percents,
65 | c='c',
66 | alpha=0.3,
67 | **kwargs,
68 | )
69 | sample_probabilities = ma.array(sample_probabilities)
70 | sample_probabilities[num_plays_used < min_samples] = ma.masked
71 | predicted_pos_percents = ma.array(predicted_pos_percents)
72 | predicted_pos_percents[num_plays_used < min_samples] = ma.masked
73 |
74 | max_deviation = metrics.max_deviation(
75 | sample_probabilities, predicted_pos_percents
76 | )
77 | residual_area = metrics.residual_area(
78 | sample_probabilities, predicted_pos_percents
79 | )
80 |
81 | axis.plot(
82 | sample_probabilities,
83 | predicted_pos_percents,
84 | c='c',
85 | label='Calibration curve\n(area = %0.2f, max dev = %0.2f)'
86 | % (residual_area, max_deviation),
87 | **kwargs,
88 | )
89 |
90 | axis.legend(loc='lower right')
91 |
92 | ax2 = axis.twinx()
93 | ax2.hist(
94 | y_pred * 100,
95 | bins=np.arange(0, 101, 1),
96 | density=True,
97 | alpha=0.4,
98 | facecolor='grey',
99 | )
100 | ax2.set_ylim([0, 0.2])
101 | ax2.set_yticks([0, 0.1, 0.2])
102 |
103 | plt.tight_layout()
104 | return axis
105 |
106 |
107 | def plot_roc_curve(y_true, y_prob, name='ROC curve', axis=None):
108 |
109 | fpr, tpr, _ = roc_curve(y_true, y_prob)
110 | roc_auc = auc(fpr, tpr)
111 |
112 | if axis is None:
113 | axis = plt.figure(figsize=(5, 5)).add_subplot(111)
114 |
115 | axis.plot(
116 | fpr, tpr, linewidth=1, label='ROC curve (area = %0.2f)' % roc_auc
117 | )
118 |
119 | # reference line, legends, and axis labels
120 | axis.plot([0, 1], [0, 1], linestyle='--', color='gray')
121 |     axis.set_title(name)
122 | axis.set_xlabel('False Positive Rate')
123 | axis.set_ylabel('True Positive Rate')
124 | axis.set_xlim(0, 1)
125 | axis.xaxis.set_major_locator(MultipleLocator(0.20))
126 | axis.xaxis.set_minor_locator(MultipleLocator(0.10))
127 | axis.set_ylim(0, 1)
128 | axis.yaxis.set_major_locator(MultipleLocator(0.20))
129 | axis.yaxis.set_minor_locator(MultipleLocator(0.10))
130 | axis.grid(which='both')
131 |
132 |     axis.legend(loc='lower right')
133 |     plt.tight_layout()
134 |     return axis
139 |
140 |
141 | def plot_heatmap(model, data, axis=None):
142 |
143 | if axis is None:
144 | axis = plt.figure(figsize=(8, 10)).add_subplot(111)
145 |
146 | z = model.estimate(data)['xG'].values
147 | axis = mps.field(ax=axis, show=False)
148 | axis = mps.heatmap(
149 | z.reshape((106, 69)).T, show=False, ax=axis, cmap='viridis_r'
150 | )
151 | axis.set_xlim((70, 108))
152 | axis.set_axis_off()
153 | return axis
154 |
--------------------------------------------------------------------------------
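A sketch that exercises both plotting helpers on synthetic, well-calibrated predictions (the resulting ROC area is meaningless here; the point is the calling convention):

```python
import matplotlib.pyplot as plt
import numpy as np
from soccer_xg.visualisation import plot_calibration_curve, plot_roc_curve

rng = np.random.default_rng(0)
y_prob = rng.uniform(0, 1, size=5_000)
y_true = (rng.uniform(0, 1, size=5_000) < y_prob).astype(int)

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
plot_roc_curve(y_true, y_prob, axis=ax[0])
plot_calibration_curve(y_true, y_prob, min_samples=50, axis=ax[1])
plt.show()
```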
/soccer_xg/xg.py:
--------------------------------------------------------------------------------
1 | """Tools for creating and analyzing xG models."""
2 | import os
3 |
4 | import joblib
5 | import pandas as pd
6 | import socceraction.spadl.config as spadlcfg
7 | from sklearn.linear_model import LogisticRegression
8 | from sklearn.metrics import brier_score_loss, roc_auc_score
9 | from sklearn.pipeline import make_pipeline
10 | from sklearn.utils.validation import NotFittedError
11 | from soccer_xg import features as fs
12 | from soccer_xg import metrics, utils
13 | from soccer_xg.api import DataApi
14 | from soccer_xg.ml.preprocessing import simple_proc_for_linear_algoritms
15 | from tqdm import tqdm
16 |
17 |
18 | class XGModel(object):
19 | """A wrapper around a pipeline for computing xG values.
20 |
21 | Parameters
22 | ----------
23 | copy_data : boolean (default=``True``)
24 | Whether or not to copy data when fitting and applying the model. Running the model
25 | in-place (``copy_data=False``) will be faster and have a smaller memory footprint,
26 | but if not done carefully can lead to data integrity issues.
27 |
28 | Attributes
29 | ----------
30 | model : A Scikit-learn pipeline (or equivalent)
31 | The actual model used to compute xG. Upon initialization it will be set to
32 | a default model, but can be overridden by the user.
33 | column_descriptions : dictionary
34 | A dictionary whose keys are the names of the columns used in the model, and the values are
35 | string descriptions of what the columns mean. Set at initialization to be the default model,
36 | if you create your own model you'll need to update this attribute manually.
37 | training_seasons : A list of tuples, or ``None`` (default=``None``)
38 | If the model was trained using data from the DataApi, a list of (competition_id, season_id) tuples
39 | used to train the model. If the DataApi was **not** used, an empty list. If no model
40 | has been trained yet, ``None``.
41 | validation_seasons : same as ``training_seasons``, but for validation data.
42 | sample_probabilities : A numpy array of floats or ``None`` (default=``None``)
43 | After the model has been validated, contains the sampled predicted probabilities used to
44 | compute the validation statistic.
45 | predicted_goal_percents : A numpy array of floats or ``None`` (default=``None``)
46 | After the model has been validated, contains the actual probabilities in the test
47 | set at each probability in ``sample_probabilities``.
48 | num_shots_used : A numpy array of floats or ``None`` (default=``None``)
49 | After the model has been validated, contains the number of shots used to compute each
50 | element of ``predicted_goal_percents``.
51 | model_directory : string
52 | The directory where all models will be saved to or loaded from.
53 | """
54 |
55 | model_directory = os.path.join(
56 | os.path.dirname(os.path.abspath(__file__)), 'models'
57 | )
58 | _default_model_filename = 'default_model.xg'
59 |
60 | def __init__(self, copy_data=True):
61 | self.copy_data = copy_data
62 | self.column_descriptions = None
63 |
64 | self.model = self.create_default_pipeline()
65 | self._fitted = False
66 | self._training_seasons = None
67 | self._validation_seasons = None
68 |
69 | self._sample_probabilities = None
70 | self._predicted_goal_percents = None
71 | self._num_shots_used = None
72 |
73 | @property
74 | def training_seasons(self):
75 | return self._training_seasons
76 |
77 | @property
78 | def validation_seasons(self):
79 | return self._validation_seasons
80 |
81 | @property
82 | def sample_probabilities(self):
83 | return self._sample_probabilities
84 |
85 | @property
86 | def predicted_goal_percents(self):
87 | return self._predicted_goal_percents
88 |
89 | @property
90 | def num_shots_used(self):
91 | return self._num_shots_used
92 |
93 | def train(
94 | self,
95 | source_data,
96 | training_seasons=(('ENG', '1617'), ('ENG', '1718')),
97 | target_colname='goal',
98 | ):
99 | """Train the model.
100 |
101 | Once a modeling pipeline is set up (either the default or something
102 | custom-generated), historical data needs to be fed into it in order to
103 | "fit" the model so that it can then be used to predict future results.
104 | This method implements a simple wrapper around the core Scikit-learn functionality
105 | which does this.
106 |
107 | The default is to use data from a DataApi object, however that can be changed
108 | to a simple Pandas DataFrame with precomputed features and labels if desired.
109 |
110 | There is no particular output from this function, rather the parameters governing
111 | the fit of the model are saved inside the model object itself. If you want to get an
112 | estimate of the quality of the fit, use the ``validate_model`` method after running
113 | this method.
114 |
115 | Parameters
116 | ----------
117 | source_data : ``DataApi`` or a Pandas DataFrame
118 | The data to be used to train the model. If an instance of
119 | ``DataApi`` is given, will query the api database for the training data.
120 | training_seasons : list of tuples (default=``[('ENG', '1617'), ('ENG', '1718')]``)
121 | What seasons to use to train the model if getting data from a DataApi instance.
122 | If ``source_data`` is not a ``DataApi``, this argument will be ignored.
123 | **NOTE:** it is critical not to use all possible data in order to train the
124 | model - some will need to be reserved for a final validation (see the
125 | ``validate_model`` method). A good dataset to reserve
126 | for validation is the most recent one or two seasons.
127 | target_colname : string or integer (default=``"goal"``)
128 | The name of the target variable column. This is only relevant if
129 | ``source_data`` is not a ``DataApi``.
130 |
131 | Returns
132 | -------
133 | ``None``
134 | """
135 | if isinstance(self.model, list):
136 | for model in self.model:
137 | model.train(source_data, training_seasons, target_colname)
138 | else:
139 | self._training_seasons = []
140 | if isinstance(source_data, DataApi):
141 | game_ids = source_data.games[
142 | source_data.games.season_id.astype(str).isin(
143 | [s[1] for s in training_seasons]
144 | )
145 | & source_data.games.competition_id.astype(str).isin(
146 | [s[0] for s in training_seasons]
147 | )
148 | ].index
149 | feature_cols = get_features(source_data, game_ids)
150 | target_col = get_labels(source_data, game_ids)
151 | self._training_seasons = training_seasons
152 | else:
153 | target_col = source_data[target_colname]
154 | feature_cols = source_data.drop(target_colname, axis=1)
155 | self.model.fit(feature_cols, target_col)
156 | self._fitted = True
157 |
158 | def validate(
159 | self,
160 | source_data,
161 |         validation_seasons=(('ENG', '1819'),),
162 | target_colname='goal',
163 | plot=True,
164 | ):
165 | """Validate the model.
166 |
167 | Once a modeling pipeline is trained, a different dataset must be fed into the trained model
168 | to validate the quality of the fit.
169 | This method implements a simple wrapper around the core Scikit-learn functionality
170 | which does this.
171 |
172 | The default is to use data from a DataApi object, however that can be changed
173 | to a simple Pandas DataFrame with precomputed features and labels if desired.
174 |
175 | The output of this method is a dictionary with relevant error metrics (see ``soccer_xg.metrics``).
176 |
177 | Parameters
178 | ----------
179 | source_data : ``DataApi`` or a Pandas DataFrame
180 | The data to be used to validate the model. If an instance of
181 |             ``DataApi`` is given, will query the api database for the validation data.
182 | validation_seasons : list of tuples (default=``[('ENG', '1819')]``)
183 |             What seasons to use to validate the model if getting data from a DataApi instance.
184 | If ``source_data`` is not a ``DataApi``, this argument will be ignored.
185 | **NOTE:** it is critical not to use the same data to validate the model as was used
186 | in the fit. Generally a good data set to use for validation is one from a time
187 | period more recent than was used to train the model.
188 | target_colname : string or integer (default=``"goal"``)
189 | The name of the target variable column. This is only relevant if
190 | ``source_data`` is not a ``DataApi``.
191 | plot : bool (default=``True``)
192 | Whether to plot the AUROC and probability calibration curves.
193 |
194 | Returns
195 | -------
196 | dict
197 | Error metrics on the validation data.
198 |
199 | Raises
200 | ------
201 | NotFittedError
202 | If the model hasn't been fit.
203 |
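Examples
--------
A minimal usage sketch; ``api`` is assumed to be a ``DataApi`` holding a
season that was not used for training:

>>> scores = model.validate(
...     api, validation_seasons=[('ENG', '1819')], plot=False
... )

The returned dict contains the keys ``'max_dev'``, ``'residual_area'``,
``'roc'``, ``'brier'``, ``'ece'``, ``'ace'`` and ``'fig'``.
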
204 | """
205 | if not self._fitted:
206 | raise NotFittedError('Must fit model before validating.')
207 |
208 | if isinstance(source_data, DataApi):
209 | game_ids = source_data.games[
210 | source_data.games.season_id.astype(str).isin(
211 | [s[1] for s in validation_seasons]
212 | )
213 | & source_data.games.competition_id.astype(str).isin(
214 | [s[0] for s in validation_seasons]
215 | )
216 | ].index
217 | target_col = get_labels(source_data, game_ids)
218 | self._validation_seasons = validation_seasons
219 | else:
220 | game_ids = None
221 | target_col = source_data[target_colname]
222 | self._validation_seasons = []
223 |
224 | df_predictions = self.estimate(source_data, game_ids)
225 | predicted_probabilities = df_predictions['xG']
226 | target_col = target_col.loc[df_predictions.index]
227 |
228 | (
229 | self._sample_probabilities,
230 | self._predicted_goal_percents,
231 | self._num_shots_used,
232 | ) = metrics.bayesian_calibration_curve(
233 | target_col.values, predicted_probabilities
234 | )
235 |
236 | # Compute the maximal deviation from a perfect prediction as well as the
237 | # area under the residual curve |predicted - perfect|:
238 | max_deviation = metrics.max_deviation(
239 | self.sample_probabilities, self.predicted_goal_percents
240 | )
241 | residual_area = metrics.residual_area(
242 | self.sample_probabilities, self.predicted_goal_percents
243 | )
244 | roc = roc_auc_score(target_col, predicted_probabilities)
245 | brier = brier_score_loss(target_col, predicted_probabilities)
246 | ece = metrics.expected_calibration_error(
247 | target_col, predicted_probabilities, 10, 'uniform'
248 | )
249 | ace = metrics.expected_calibration_error(
250 | target_col, predicted_probabilities, 10, 'quantile'
251 | )
252 |
253 | if plot:
254 | import matplotlib.pyplot as plt
255 | from soccer_xg.visualisation import (
256 | plot_roc_curve,
257 | plot_calibration_curve,
258 | )
259 |
260 | fig, ax = plt.subplots(1, 2, figsize=(10, 5))
261 | plot_roc_curve(target_col, predicted_probabilities, axis=ax[0])
262 | plot_calibration_curve(
263 | target_col,
264 | predicted_probabilities,
265 | min_samples=100,
266 | axis=ax[1],
267 | )
268 |
269 | return {
270 | 'max_dev': max_deviation,
271 | 'residual_area': residual_area,
272 | 'roc': roc,
273 | 'brier': brier,
274 | 'ece': ece,
275 | 'ace': ace,
276 | 'fig': fig if plot else None,
277 | }
278 |
279 | def estimate(self, source_data, game_ids=None):
280 | """Estimate the xG values for all shots in a set of games.
281 |
282 | The default is to use data from a ``DataApi`` object, but a simple Pandas
283 | DataFrame with precomputed features can be used instead.
284 |
285 | Parameters
286 | ----------
287 | source_data : ``DataApi`` or a Pandas DataFrame
288 | The data for which to estimate xG values. If an instance of
289 | ``DataApi`` is given, will query the api database for the shot data.
290 | game_ids : list of ints (default=None)
291 | Only xG values for the games in this list are returned. By default,
292 | xG values are computed for all games in the source data.
293 | If ``source_data`` is not a ``DataApi``, this argument will be ignored.
294 |
295 | Returns
296 | -------
297 | A Pandas DataFrame
298 | A dataframe with a column 'xG', containing the predicted xG value
299 | of each shot in the given data, indexed by (game_id, action_id) of
300 | the corresponding shot.
301 |
302 | Raises
303 | ------
304 | NotFittedError
305 | If the model hasn't been fit.
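
Examples
--------
A minimal usage sketch; ``api`` is an assumed ``DataApi`` instance and the
game id is illustrative:

>>> df_xg = model.estimate(api, game_ids=[7584])
>>> df_xg.columns.tolist()
['xG']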
306 | """
307 | if not self._fitted:
308 | raise NotFittedError('Must fit model before estimating xG.')
309 |
310 | if isinstance(self.model, list):
311 | xg = []
312 | for model in self.model:
313 | xg.append(model.estimate(source_data, game_ids))
314 | return pd.concat(xg).sort_index()
315 | else:
316 | if isinstance(source_data, DataApi):
317 | # default to every game in the data source
318 | game_ids = (
319 | source_data.games.index
320 | if game_ids is None
321 | else game_ids
322 | )
323 | source_data = get_features(source_data, game_ids)
324 |
325 | xg = pd.DataFrame(index=source_data.index)
326 | xg['xG'] = self.model.predict_proba(source_data)[:, 1]
327 | return xg
328 |
329 | def create_default_pipeline(self):
330 | """Create the default xG estimation pipeline.
331 |
332 | Returns
333 | -------
334 | Scikit-learn pipeline
335 | The default pipeline, suitable for computing xG
336 | but by no means the best possible model.
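
Notes
-----
The default "pipeline" is in fact a list of three sub-models: one each for
open-play shots, free kicks and penalties.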
337 | """
338 | models = [OpenplayXGModel(), FreekickXGModel(), PenaltyXGModel()]
339 | self.column_descriptions = {
340 | m.__class__.__name__: m.column_descriptions for m in models
341 | }
342 | return models
343 |
344 | def save_model(self, filename=None):
345 | """Save the XGModel instance to disk.
346 |
347 | All models are saved to the same place, with the installed
348 | soccer_xg library (given by ``XGModel.model_directory``).
349 |
350 | Parameters
351 | ----------
352 | filename : string (default=None)
353 | The filename to use for the saved model. If this parameter
354 | is not specified, save to the default filename. Note that if a model
355 | already exists with this filename, it will be overwritten. Note also that
356 | this is a filename only, **not** a full path. If a full path is specified
357 | it is likely (albeit not guaranteed) to cause errors.
358 |
359 | Returns
360 | -------
361 | ``None``
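
Examples
--------
A sketch of a save/load round trip (the filename is illustrative):

>>> model.save_model('my_model.xg')
>>> loaded = XGModel.load_model('my_model.xg')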
362 | """
363 | if filename is None:
364 | filename = self._default_model_filename
365 | joblib.dump(self, os.path.join(self.model_directory, filename))
366 |
367 | @classmethod
368 | def load_model(cls, filename=None):
369 | """Load a saved XGModel.
370 |
371 | Parameters
372 | ----------
373 | filename : string (default=None)
374 | The filename to use for the saved model. If this parameter
375 | is not specified, load the default model. Note that
376 | this is a filename only, **not** a full path.
377 |
378 | Returns
379 | -------
380 | ``soccer_xg.XGModel`` instance.
381 | """
382 | if filename is None:
383 | filename = cls._default_model_filename
384 |
385 | return joblib.load(os.path.join(cls.model_directory, filename))
386 |
387 |
388 | class OpenplayXGModel(XGModel):
389 | _default_model_filename = 'default_openplay_model.xg'
390 |
391 | def train(
392 | self,
393 | source_data,
394 | training_seasons=(('ENG', '1617'), ('ENG', '1718')),
395 | target_colname='goal',
396 | ):
397 | self._training_seasons = []
398 | if isinstance(source_data, DataApi):
399 | game_ids = source_data.games[
400 | source_data.games.season_id.astype(str).isin(
401 | [s[1] for s in training_seasons]
402 | )
403 | & source_data.games.competition_id.astype(str).isin(
404 | [s[0] for s in training_seasons]
405 | )
406 | ].index
407 | feature_cols = get_features(
408 | source_data, game_ids, shotfilter=OpenplayXGModel.filter_shots
409 | )
410 | target_col = get_labels(
411 | source_data, game_ids, shotfilter=OpenplayXGModel.filter_shots
412 | )
413 | self._training_seasons = training_seasons
414 | else:
415 | target_col = source_data[target_colname]
416 | feature_cols = source_data.drop(target_colname, axis=1)
417 | self.model.fit(feature_cols, target_col)
418 | self._fitted = True
419 |
420 | def estimate(self, source_data, game_ids=None):
421 |
422 | if isinstance(source_data, DataApi):
423 | game_ids = (
424 | source_data.games.index if game_ids is None else game_ids
425 | )
426 | source_data = get_features(
427 | source_data, game_ids, shotfilter=OpenplayXGModel.filter_shots
428 | )
429 |
430 | xg = pd.DataFrame(index=source_data.index)
431 | xg['xG'] = self.model.predict_proba(source_data)[:, 1]
432 | return xg
433 |
434 | def create_default_pipeline(self):
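"""Create the default open-play pipeline: a logistic regression on the
distance to goal, the angle to goal and the bodypart used for the shot."""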
435 | bodypart_colname = 'bodypart_id_a0'
436 | dist_to_goal_colname = 'start_dist_to_goal_a0'
437 | angle_to_goal_colname = 'start_angle_to_goal_a0'
438 |
439 | self.column_descriptions = {
440 | bodypart_colname: 'Bodypart used for the shot (head, foot or other)',
441 | dist_to_goal_colname: 'Distance to goal',
442 | angle_to_goal_colname: 'Angle to goal',
443 | }
444 |
445 | preprocess_pipeline = simple_proc_for_linear_algoritms(
446 | [dist_to_goal_colname, angle_to_goal_colname], [bodypart_colname]
447 | )
448 | base_model = LogisticRegression(
449 | max_iter=10000, solver='lbfgs', fit_intercept=False
450 | )
451 | pipe = make_pipeline(preprocess_pipeline, base_model)
452 | return pipe
453 |
454 | @staticmethod
455 | def filter_shots(df_actions):
456 | shot_idx = (
457 | df_actions.type_name == 'shot'
458 | ) & df_actions.result_name.isin(['fail', 'success'])
459 | return shot_idx
460 |
461 |
462 | class PenaltyXGModel(XGModel):
463 | _default_model_filename = 'default_penalty_model.xg'
464 |
465 | def __init__(self, copy_data=True):
466 | super().__init__(copy_data)
467 | self._fitted = True
468 |
469 | def train(
470 | self,
471 | source_data,
472 | training_seasons=(('ENG', '1617'), ('ENG', '1718')),
473 | target_colname='goal',
474 | ):
475 | pass
476 |
477 | def estimate(self, source_data, game_ids=None):
478 |
479 | if isinstance(source_data, DataApi):
480 | game_ids = (
481 | source_data.games.index if game_ids is None else game_ids
482 | )
483 | source_data = get_features(
484 | source_data,
485 | game_ids,
486 | xfns=[],
487 | shotfilter=PenaltyXGModel.filter_shots,
488 | )
489 |
490 | xg = pd.DataFrame(index=source_data.index)
491 | # assign a constant xG (~0.79, roughly the average penalty conversion rate)
492 | xg['xG'] = 0.792453
493 | return xg
494 |
495 | def create_default_pipeline(self):
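"""Penalty xG is a fixed constant, so no pipeline is needed."""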
496 | return None
497 |
498 | @staticmethod
499 | def filter_shots(df_actions):
500 | shot_idx = df_actions.type_name == 'shot_penalty'
501 | return shot_idx
502 |
503 |
504 | class FreekickXGModel(XGModel):
505 |
506 | _default_model_filename = 'default_freekick_model.xg'
507 |
508 | def train(
509 | self,
510 | source_data,
511 | training_seasons=(('ENG', '1617'), ('ENG', '1718')),
512 | target_colname='goal',
513 | ):
514 | self._training_seasons = []
515 | if isinstance(source_data, DataApi):
516 | game_ids = source_data.games[
517 | source_data.games.season_id.astype(str).isin(
518 | [s[1] for s in training_seasons]
519 | )
520 | & source_data.games.competition_id.astype(str).isin(
521 | [s[0] for s in training_seasons]
522 | )
523 | ].index
524 | feature_cols = get_features(
525 | source_data, game_ids, shotfilter=FreekickXGModel.filter_shots
526 | )
527 | target_col = get_labels(
528 | source_data, game_ids, shotfilter=FreekickXGModel.filter_shots
529 | )
530 | self._training_seasons = training_seasons
531 | else:
532 | target_col = source_data[target_colname]
533 | feature_cols = source_data.drop(target_colname, axis=1)
534 | self.model.fit(feature_cols, target_col)
535 | self._fitted = True
536 |
537 | def estimate(self, source_data, game_ids=None):
538 |
539 | if isinstance(source_data, DataApi):
540 | game_ids = (
541 | source_data.games.index if game_ids is None else game_ids
542 | )
543 | source_data = get_features(
544 | source_data, game_ids, shotfilter=FreekickXGModel.filter_shots
545 | )
546 |
547 | xg = pd.DataFrame(index=source_data.index)
548 | xg['xG'] = self.model.predict_proba(source_data)[:, 1]
549 | return xg
550 |
551 | def create_default_pipeline(self):
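"""Create the default free-kick pipeline: a logistic regression on the
distance and angle to goal."""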
552 | dist_to_goal_colname = 'start_dist_to_goal_a0'
553 | angle_to_goal_colname = 'start_angle_to_goal_a0'
554 |
555 | self.column_descriptions = {
556 | dist_to_goal_colname: 'Distance to goal',
557 | angle_to_goal_colname: 'Angle to goal',
558 | }
559 |
560 | preprocess_pipeline = simple_proc_for_linear_algoritms(
561 | [dist_to_goal_colname, angle_to_goal_colname], []
562 | )
563 | base_model = LogisticRegression(
564 | max_iter=10000, solver='lbfgs', fit_intercept=True
565 | )
566 | pipe = make_pipeline(preprocess_pipeline, base_model)
567 | return pipe
568 |
569 | @staticmethod
570 | def filter_shots(df_actions):
571 | shot_idx = df_actions.type_name == 'shot_freekick'
572 | return shot_idx
573 |
574 |
575 | def get_features(
576 | api,
577 | game_ids=None,
578 | xfns=fs.all_features,
579 | shotfilter=None,
580 | nb_prev_actions=3,
581 | ):
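"""Compute shot features for a set of games.

Returns a dataframe with one row per shot, indexed by
``(game_id, action_id)``. Columns with a single unique value (i.e.,
post-shot attributes) are dropped.
"""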
582 | game_ids = api.games.index if game_ids is None else game_ids
583 | X = {}
584 | for game_id in tqdm(game_ids, desc='Generating features'):
585 | # try:
586 | game = api.games.loc[game_id]
587 | game_actions = utils.enhance_actions(api.get_actions(game_id))
588 | X[game_id] = _compute_features_game(
589 | game, game_actions, xfns, shotfilter, nb_prev_actions
590 | )
591 | X[game_id].index.name = 'action_id'
592 | X[game_id]['game_id'] = game_id
593 | # except Exception as e:
594 | # print(f"Failed for game with id={game_id}: {e}")
595 | X = pd.concat(X.values()).reset_index().set_index(['game_id', 'action_id'])
596 | # remove post-shot features (these will all have a single unique value)
597 | f = X.columns[X.nunique() > 1]
598 | return X[f]
599 |
600 |
601 | def _compute_features_game(
602 | game, actions, xfns=fs.all_features, shotfilter=None, nb_prev_actions=3
603 | ):
604 | if shotfilter is None:
605 | # filter shots and ignore own goals
606 | shot_idx = actions.type_name.isin(
607 | ['shot', 'shot_penalty', 'shot_freekick']
608 | ) & actions.result_name.isin(['fail', 'success'])
609 | else:
610 | shot_idx = shotfilter(actions)
611 | if shot_idx.sum() < 1:
612 | return pd.DataFrame()
613 | if len(xfns) < 1:
614 | return pd.DataFrame(index=actions.index.values[shot_idx])
615 | # convert actions to gamestates
616 | gamestates = [
617 | states.loc[shot_idx].copy()
618 | for states in fs.gamestates(actions, nb_prev_actions)
619 | ]
620 | gamestates = fs.play_left_to_right(gamestates, game.home_team_id)
621 | # remove post-shot attributes
622 | gamestates[0].loc[shot_idx, 'end_x'] = float('NaN')
623 | gamestates[0].loc[shot_idx, 'end_y'] = float('NaN')
624 | gamestates[0].loc[shot_idx, 'result_id'] = float('NaN')
625 | # compute features
626 | X = pd.concat([fn(gamestates) for fn in xfns], axis=1)
627 | # fix data types
628 | for c in [c for c in X.columns.values if c.startswith('type_id')]:
629 | X[c] = pd.Categorical(
630 | X[c].replace(spadlcfg.actiontypes_df().type_name.to_dict()),
631 | categories=spadlcfg.actiontypes,
632 | ordered=False,
633 | )
634 | for c in [c for c in X.columns.values if c.startswith('result_id')]:
635 | X[c] = pd.Categorical(
636 | X[c].replace(spadlcfg.results_df().result_name.to_dict()),
637 | categories=spadlcfg.results,
638 | ordered=False,
639 | )
640 | for c in [c for c in X.columns.values if c.startswith('bodypart_id')]:
641 | X[c] = pd.Categorical(
642 | X[c].replace(spadlcfg.bodyparts_df().bodypart_name.to_dict()),
643 | categories=spadlcfg.bodyparts,
644 | ordered=False,
645 | )
646 | return X
647 |
648 |
649 | def get_labels(api, game_ids=None, shotfilter=None):
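"""Compute a binary goal label for each shot in a set of games.

Returns a series named 'goal', indexed by ``(game_id, action_id)``.
"""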
650 | game_ids = api.games.index if game_ids is None else game_ids
651 | y = {}
652 | for game_id in tqdm(game_ids, desc='Generating labels'):
653 | try:
654 | game = api.games.loc[game_id]
655 | game_actions = utils.enhance_actions(api.get_actions(game_id))
656 | y[game_id] = _compute_labels_game(game, game_actions, shotfilter)
657 | y[game_id].index.name = 'action_id'
658 | y[game_id]['game_id'] = game_id
659 | except Exception as e:
660 | print(f'Failed for game with id={game_id}: {e}')
661 | return (
662 | pd.concat(y.values())
663 | .reset_index()
664 | .set_index(['game_id', 'action_id'])['goal']
665 | )
666 |
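# A typical end-to-end use of the two helpers above (mirroring the usage in
# tests/test_xg.py): precompute features and labels once, then train an
# ``XGModel`` on the resulting frame.
#
#     api = DataApi('tests/data/spadl-statsbomb-WC-2018.h5')
#     df = get_features(api).assign(goal=get_labels(api))
#     XGModel().train(source_data=df)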
667 |
668 | def _compute_labels_game(game, actions, shotfilter=None):
669 | # compute labels
670 | y = actions['result_name'] == 'success'
671 | if shotfilter is None:
672 | # filter shots and ignore own goals
673 | shot_idx = actions.type_name.isin(
674 | ['shot', 'shot_penalty', 'shot_freekick']
675 | ) & actions.result_name.isin(['fail', 'success'])
676 | else:
677 | shot_idx = shotfilter(actions)
678 | return y.loc[shot_idx].to_frame('goal')
679 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import soccer_xg.xg as xg
3 | from soccer_xg.api import DataApi
4 |
5 |
6 | @pytest.fixture(scope='session')
7 | def api():
8 | return DataApi('tests/data/spadl-statsbomb-WC-2018.h5')
9 |
10 |
11 | @pytest.fixture()
12 | def model():
13 | return xg.XGModel()
14 |
--------------------------------------------------------------------------------
/tests/data/download.py:
--------------------------------------------------------------------------------
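"""Download StatsBomb's public FIFA World Cup 2018 data and convert it to
the SPADL format, storing the result as an HDF5 file in tests/data."""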
1 | import os
2 |
3 | import pandas as pd
4 | import socceraction.spadl as spadl
5 | import socceraction.spadl.statsbomb as statsbomb
6 | from tqdm import tqdm
7 |
8 | seasons = {
9 | 3: '2018',
10 | }
11 | leagues = {
12 | 'FIFA World Cup': 'WC',
13 | }
14 |
15 | free_open_data_remote = (
16 | 'https://raw.githubusercontent.com/statsbomb/open-data/master/data/'
17 | )
18 | spadl_datafolder = 'tests/data'
19 |
20 | SBL = statsbomb.StatsBombLoader(root=free_open_data_remote, getter='remote')
21 |
22 | # View all available competitions
23 | df_competitions = SBL.competitions()
24 | df_selected_competitions = df_competitions[
25 | df_competitions.competition_name.isin(leagues.keys())
26 | ]
27 |
28 | for competition in df_selected_competitions.itertuples():
29 | # Get matches from all selected competition
30 | matches = SBL.matches(competition.competition_id, competition.season_id)
31 |
32 | matches_verbose = tqdm(
33 | list(matches.itertuples()), desc='Loading match data'
34 | )
35 | teams, players, player_games = [], [], []
36 |
37 | competition_id = leagues[competition.competition_name]
38 | season_id = seasons[competition.season_id]
39 | spadl_h5 = os.path.join(
40 | spadl_datafolder, f'spadl-statsbomb-{competition_id}-{season_id}.h5'
41 | )
42 | with pd.HDFStore(spadl_h5) as spadlstore:
43 |
44 | spadlstore.put('actiontypes', spadl.actiontypes_df(), format='table')
45 | spadlstore.put('results', spadl.results_df(), format='table')
46 | spadlstore.put('bodyparts', spadl.bodyparts_df(), format='table')
47 |
48 | for match in matches_verbose:
49 | # load data
50 | teams.append(SBL.teams(match.match_id))
51 | players.append(SBL.players(match.match_id))
52 | events = SBL.events(match.match_id)
53 |
54 | # convert data
55 | player_games.append(statsbomb.extract_player_games(events))
56 | spadlstore.put(
57 | f'actions/game_{match.match_id}',
58 | statsbomb.convert_to_actions(events, match.home_team_id),
59 | format='table',
60 | )
61 |
62 | games = matches.rename(
63 | columns={'match_id': 'game_id', 'match_date': 'game_date'}
64 | )
65 | games.season_id = season_id
66 | games.competition_id = competition_id
67 | spadlstore.put('games', games)
68 | spadlstore.put(
69 | 'teams',
70 | pd.concat(teams).drop_duplicates('team_id').reset_index(drop=True),
71 | )
72 | spadlstore.put(
73 | 'players',
74 | pd.concat(players)
75 | .drop_duplicates('player_id')
76 | .reset_index(drop=True),
77 | )
78 | spadlstore.put(
79 | 'player_games', pd.concat(player_games).reset_index(drop=True)
80 | )
81 |
--------------------------------------------------------------------------------
/tests/test_api.py:
--------------------------------------------------------------------------------
1 | def test_get_actions(api):
2 | df_actions = api.get_actions(7584)
3 | assert len(df_actions)
4 |
5 |
6 | def test_get_home_away_team_id(api):
7 | home_id, away_id = api.get_home_away_team_id(7584)
8 | assert home_id == 782
9 | assert away_id == 778
10 |
11 |
12 | def test_get_team_name(api):
13 | name = api.get_team_name(782)
14 | assert name == 'Belgium'
15 |
16 |
17 | def test_get_player_name(api):
18 | name = api.get_player_name(3089)
19 | assert name == 'Kevin De Bruyne'
20 |
--------------------------------------------------------------------------------
/tests/test_metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import soccer_xg.metrics as metrics_lib
4 |
5 | #
6 | # expected_calibration_error
7 | #
8 |
9 |
10 | def test_expected_calibration_error():
11 | np.random.seed(1)
12 | nsamples = 100
13 | probs = np.linspace(0, 1, nsamples)
14 | labels = np.random.rand(nsamples) < probs
15 | ece = metrics_lib.expected_calibration_error(labels, probs)
16 | bad_ece = metrics_lib.expected_calibration_error(labels, probs / 2)
17 |
18 | assert ece > 0 and ece < 1
19 | assert bad_ece > 0 and bad_ece < 1
20 | assert ece < bad_ece
21 |
22 |
23 | def test_expected_calibration_error_all_wrong():
24 | n_bins = 90
25 | ece = metrics_lib.expected_calibration_error(
26 | np.ones(10), np.zeros(10), n_bins=n_bins
27 | )
28 | assert ece == pytest.approx(1.0)
29 |
30 | ece = metrics_lib.expected_calibration_error(
31 | np.zeros(10), np.ones(10), n_bins=n_bins
32 | )
33 | assert ece == pytest.approx(1.0)
34 |
35 |
36 | def test_expected_calibration_error_all_right():
37 | n_bins = 90
38 | ece = metrics_lib.expected_calibration_error(
39 | np.ones(10), np.ones(10), n_bins=n_bins
40 | )
41 | assert ece == pytest.approx(0.0)
42 |
43 | ece = metrics_lib.expected_calibration_error(
44 | np.zeros(10), np.zeros(10), n_bins=n_bins
45 | )
46 | assert ece == pytest.approx(0.0)
47 |
--------------------------------------------------------------------------------
/tests/test_xg.py:
--------------------------------------------------------------------------------
1 | import collections.abc
2 | import os
3 |
4 | from soccer_xg import xg
5 |
6 |
7 | class TestDefaults(object):
8 | """Tests for defaults."""
9 |
10 | def test_column_descriptions_set(self, model):
11 | assert isinstance(model.column_descriptions, collections.abc.Mapping)
12 |
13 |
14 | class TestModelTrain(object):
15 | """Tests for the train_model method."""
16 |
17 | def test_api_input(self, model, api):
18 | model.train(source_data=api, training_seasons=[('WC', '2018')])
19 |
20 | def test_dataframe_input(self, model, api):
21 | features = xg.get_features(api)
22 | labels = xg.get_labels(api)
23 | df = features.assign(goal=labels)
24 | model.train(source_data=df)
25 |
26 |
27 | class TestModelValidate(object):
28 | """Tests for the validate_model method."""
29 |
30 | def test_api_input(self, model, api):
31 | model.train(source_data=api, training_seasons=[('WC', '2018')])
32 | model.validate(
33 | source_data=api, validation_seasons=[('WC', '2018')], plot=False
34 | )
35 |
36 | def test_dataframe_input(self, model, api):
37 | features = xg.get_features(api)
38 | labels = xg.get_labels(api)
39 | df = features.assign(goal=labels)
40 | model.train(source_data=df)
41 | model.validate(source_data=df, plot=False)
42 |
43 |
44 | class TestModelIO(object):
45 | """Tests functions that deal with model saving and loading"""
46 |
47 | def teardown_method(self, method):
48 | try:
49 | os.remove(self.expected_path)
50 | except OSError:
51 | pass
52 |
53 | def test_model_save_default(self, model):
54 | model_name = 'test_hazard.xgmodel'
55 | model._default_model_filename = model_name
56 |
57 | self.expected_path = os.path.join(
58 | os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
59 | 'soccer_xg',
60 | 'models',
61 | model_name,
62 | )
63 | assert os.path.isfile(self.expected_path) is False
64 |
65 | model.save_model()
66 | assert os.path.isfile(self.expected_path) is True
67 |
68 | def test_model_save_specified(self, model):
69 | model = xg.XGModel()
70 | model_name = 'test_lukaku.xgmodel'
71 |
72 | self.expected_path = os.path.join(
73 | os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
74 | 'soccer_xg',
75 | 'models',
76 | model_name,
77 | )
78 | assert os.path.isfile(self.expected_path) is False
79 |
80 | model.save_model(filename=model_name)
81 | assert os.path.isfile(self.expected_path) is True
82 |
83 | def test_model_load_default(self, model):
84 | model_name = 'test_witsel.xgmodel'
85 | model._default_model_filename = model_name
86 |
87 | self.expected_path = os.path.join(
88 | os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
89 | 'soccer_xg',
90 | 'models',
91 | model_name,
92 | )
93 | assert os.path.isfile(self.expected_path) is False
94 |
95 | model.save_model()
96 |
97 | xGModel_class = xg.XGModel
98 | xGModel_class._default_model_filename = model_name
99 |
100 | loaded_model = xGModel_class.load_model()
101 |
102 | assert isinstance(loaded_model, xg.XGModel)
103 |
104 | def test_model_load_specified(self):
105 | model = xg.XGModel()
106 | model_name = 'test_kompany.xgmodel'
107 |
108 | self.expected_path = os.path.join(
109 | os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
110 | 'soccer_xg',
111 | 'models',
112 | model_name,
113 | )
114 | assert os.path.isfile(self.expected_path) is False
115 |
116 | model.save_model(filename=model_name)
117 |
118 | loaded_model = xg.XGModel.load_model(filename=model_name)
119 | assert isinstance(loaded_model, xg.XGModel)
120 |
121 |
122 | def test_get_features(api):
123 | features = xg.get_features(api, game_ids=[7584])
124 | assert len(features) == 40 # one row for each shot
125 |
126 |
127 | def test_get_labels(api):
128 | labels = xg.get_labels(api, game_ids=[7584])
129 | assert len(labels) == 40 # one row for each shot
130 | assert labels.sum() == 5
131 |
--------------------------------------------------------------------------------