├── .gitignore ├── LICENSE ├── README.md ├── docs ├── pass-mixture.png └── soccermix_ecml2020.pdf ├── mixture.py ├── notebooks ├── 1-load-and-convert-statsbomb-data.ipynb ├── 2-create-mixture-models.ipynb ├── paper-casestudy-city-liverpool.ipynb ├── paper-experiment1-deanonymizing-players.ipynb ├── paper-experiment2-player-illustration.ipynb ├── paper-experiment3-team-illustration.ipynb └── paper-experiment4-defensive-style.ipynb └── vis.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 KU Leuven Machine Learning Research Group 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # soccermix 2 | SoccerMix is a soft clustering technique based on mixture models that decomposes event stream data into a number of *prototypical* actions of a specific type, location, and direction. 3 | 4 | Here is an example of 71 different prototypical passes discovered by SoccerMix. 5 | ![](docs/pass-mixture.png) 6 | 7 | 8 | 9 | 10 | Copyright 2020 Tom Decroos, Maaike Van Roy, Jesse Davis 11 | -------------------------------------------------------------------------------- /docs/pass-mixture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-KULeuven/soccermix/9f23c98b8e20b5cfeb569a43f8a2b110bf50dfd2/docs/pass-mixture.png -------------------------------------------------------------------------------- /docs/soccermix_ecml2020.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-KULeuven/soccermix/9f23c98b8e20b5cfeb569a43f8a2b110bf50dfd2/docs/soccermix_ecml2020.pdf -------------------------------------------------------------------------------- /mixture.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import tqdm 4 | 5 | import scipy.stats as stats 6 | import sklearn.cluster as cluster 7 | 8 | import cvxpy as cp 9 | 10 | 11 | class MultiGauss: 12 | 13 | def init_responsibilities(X, weights,n_components): 14 | # initialize weights with KMeans 15 | n_samples, _ = X.shape 16 | labels = cluster.KMeans(n_clusters=n_components, n_init=10).fit_predict( 17 | X, sample_weight=weights 18 | ) 19 | resp = np.zeros((n_samples, n_components)) 20 | resp[np.arange(n_samples), labels] = 1 21 | return resp 22 | 23 | def fit(self, X, w): 24 | self.mean = np.average(X, 
weights=w, axis=0) 25 | self.cov = np.nan_to_num(np.cov(X.T, aweights=w), 0) 26 | np.fill_diagonal(self.cov, self.cov.diagonal() + 1e-6) 27 | 28 | self._n_parameters = len(self.mean) + len(self.cov.flatten()) 29 | return self 30 | 31 | def pdf(self, X): 32 | return np.nan_to_num( 33 | stats.multivariate_normal.pdf( 34 | X, mean=self.mean, cov=self.cov 35 | ) 36 | ) 37 | 38 | class Gauss: 39 | 40 | def fit(self, X, w): 41 | self.mean = np.average(X, weights=w, axis=0) 42 | self.std = np.sqrt(np.cov(X, aweights=w)) 43 | self.std += 1e-6 44 | 45 | self._n_parameters = 2 46 | return self 47 | 48 | def pdf(self, X): 49 | return np.nan_to_num( 50 | stats.norm.pdf( 51 | X, loc=self.mean, scale=self.std 52 | ) 53 | ) 54 | 55 | class VonMises: 56 | 57 | def init_responsibilities(alpha, weights,n_components): 58 | # initialize weights with KMeans 59 | n_samples, _ = alpha.shape 60 | X = np.concatenate([np.sin(alpha),np.cos(alpha)],axis=1) 61 | labels = cluster.KMeans(n_clusters=n_components, n_init=10).fit_predict( 62 | X, sample_weight=weights 63 | ) 64 | resp = np.zeros((n_samples, n_components)) 65 | resp[np.arange(n_samples), labels] = 1 66 | return resp 67 | 68 | def fit(self, alpha, w): 69 | sin = np.average(np.sin(alpha), weights=w, axis=0) 70 | cos = np.average(np.cos(alpha), weights=w, axis=0) 71 | 72 | self.loc = np.arctan2(sin, cos) 73 | self.R = np.sqrt(sin ** 2 + cos ** 2) # mean resultant length 74 | # self.R = min(self.R,0.99) 75 | 76 | maxR = np.empty_like(self.R) 77 | maxR[0] = 0.99 78 | self.R = min(self.R, maxR) 79 | 80 | self.kappa = ( 81 | self.R * (2 - self.R ** 2) / (1 - self.R ** 2) 82 | ) # approximation for kappa 83 | 84 | self._n_parameters = 2 85 | return self 86 | 87 | def pdf(self, alpha): 88 | return np.nan_to_num( 89 | stats.vonmises.pdf(alpha, kappa=self.kappa, loc=self.loc).flatten() 90 | ) 91 | 92 | 93 | class CategoricalModel: 94 | 95 | def __init__(self, tol=1e-6): 96 | self.tol = tol 97 | 98 | def fit(self, X, weights=None): 99 | if 
weights: 100 | X = X[weights > self.tol] 101 | self.categories = set(X) 102 | return self 103 | 104 | def predict_proba(self, X, weights=None): 105 | p = pd.DataFrame() 106 | if weights is None: 107 | weights = np.zeros(len(X)) + 1 108 | for c in self.categories: 109 | p[str(c)] = ((X == c) & (weights > self.tol)).apply(float) 110 | return p 111 | 112 | 113 | class MixtureModel: 114 | def __init__(self, n_components, distribution=MultiGauss, max_iter=100, tol=1e-6): 115 | self.n_components = n_components 116 | self.distribution = distribution 117 | self.max_iter = max_iter 118 | self.tol = tol 119 | 120 | def fit(self, X, weights=None, verbose=False): 121 | 122 | # handle sparsity 123 | if weights is None: 124 | weights = np.zeros(len(X)) + 1 125 | pos_weights_idx = weights > self.tol 126 | X = X[pos_weights_idx] 127 | weights = weights[pos_weights_idx] 128 | 129 | self.weight_total = weights.sum() 130 | self.loglikelihood = -np.inf 131 | self.submodels = list(self.distribution() for _i in range(self.n_components)) 132 | 133 | if len(X) < self.n_components: 134 | return None 135 | 136 | responsibilities = self.distribution.init_responsibilities(X, weights,self.n_components) 137 | 138 | # learn models on initial weights 139 | self.priors = responsibilities.sum(axis=0) / responsibilities.sum() 140 | # invalid model if less clusters found than given components 141 | if any(self.priors < self.tol): 142 | return None 143 | 144 | for i in range(self.n_components): 145 | self.submodels[i].fit(X, weights * responsibilities[:, i]) 146 | 147 | iterations = ( 148 | range(self.max_iter) if not verbose else tqdm.tqdm(range(self.max_iter)) 149 | ) 150 | 151 | for self._n_iter in iterations: 152 | # Expectation 153 | for i in range(self.n_components): 154 | responsibilities[:, i] = self.priors[i] * self.submodels[i].pdf(X) 155 | 156 | # enough improvement or not? 
157 | new_loglikelihood = (weights * np.log(responsibilities.sum(axis=1))).sum() 158 | 159 | if new_loglikelihood > self.loglikelihood + self.tol: 160 | self.loglikelihood = new_loglikelihood 161 | # self.responsibilities = responsibilities 162 | # self.weights = weights 163 | else: 164 | break 165 | 166 | # normalize responsibilities such that each data point occurs with P=1 167 | responsibilities /= responsibilities.sum(axis=1)[:, np.newaxis] 168 | 169 | # Maximalization 170 | self.priors = responsibilities.sum(axis=0) / responsibilities.sum() 171 | for i in range(self.n_components): 172 | self.submodels[i].fit(X, weights * responsibilities[:, i]) 173 | 174 | if np.isinf(self.loglikelihood): 175 | return None 176 | 177 | return self 178 | 179 | def predict_proba(self, X, weights=None): 180 | p = np.zeros((len(X), self.n_components)) 181 | 182 | # handle sparsity 183 | if weights is None: 184 | weights = np.zeros(len(X)) + 1 185 | pos_weights_idx = weights > self.tol 186 | X = X[pos_weights_idx] 187 | weights = weights[pos_weights_idx] 188 | 189 | pdfs = np.vstack([m.pdf(X) for m in self.submodels]).T 190 | resp = self.priors * pdfs 191 | probs = resp / resp.sum(axis=1)[:, np.newaxis] 192 | 193 | p[pos_weights_idx, :] = (weights * probs.T).T 194 | return p 195 | 196 | def params(self): 197 | return list(m.__dict__ for m in self.submodels) 198 | 199 | def _n_parameters(self): 200 | return ( 201 | sum(m._n_parameters for m in self.submodels) 202 | - self.submodels[0]._n_parameters 203 | ) 204 | 205 | 206 | def ilp_select_models_max(models, max_components, verbose=False): 207 | x = cp.Variable(len(models), boolean=True) 208 | c = np.array(list(m.loglikelihood for m in models)) 209 | n_components = np.array(list(m.n_components for m in models)) 210 | 211 | objective = cp.Maximize(cp.sum(c * x)) 212 | constraints = [] 213 | constraints += [n_components * x <= max_components] 214 | for name in set(m.name for m in models): 215 | name_idx = np.array(list(int(m.name == 
name) for m in models)) 216 | constraints += [name_idx * x == 1] 217 | 218 | prob = cp.Problem(objective, constraints) 219 | prob.solve(verbose=verbose) 220 | idx, = np.where(x.value > 0.3) 221 | return list(models[i] for i in idx) 222 | 223 | 224 | def ilp_select_models_bic_triangle(models, verbose=False): 225 | x = cp.Variable(len(models), boolean=True) 226 | c = np.array(list(m.loglikelihood for m in models)) 227 | n_parameters = np.array(list(m.n_components for m in models)) 228 | dataweights = {} 229 | for m in models: 230 | if m.name not in dataweights: 231 | dataweights[m.name] = m.weight_total 232 | n_data = sum(dataweights.values()) 233 | 234 | n = cp.sum(n_parameters * x) 235 | para = n + (cp.square(n) + n) / 2 236 | objective = cp.Minimize(np.log(n_data) * para - 2 * cp.sum(c * x)) 237 | 238 | constraints = [] 239 | for name in set(m.name for m in models): 240 | name_idx = np.array(list(int(m.name == name) for m in models)) 241 | constraints += [name_idx * x == 1] 242 | 243 | prob = cp.Problem(objective, constraints) 244 | prob.solve(verbose=verbose) 245 | idx, = np.where(x.value > 0.3) 246 | return list(models[i] for i in idx) 247 | 248 | 249 | def ilp_select_models_bic(models, verbose=False): 250 | x = cp.Variable(len(models), boolean=True) 251 | c = np.array(list(m.loglikelihood for m in models)) 252 | n_parameters = np.array(list(m._n_parameters() for m in models)) 253 | dataweights = {} 254 | for m in models: 255 | if m.name not in dataweights: 256 | dataweights[m.name] = m.weight_total 257 | n_data = sum(dataweights.values()) 258 | 259 | objective = cp.Minimize( 260 | np.log(n_data) * cp.sum(n_parameters * x) - 2 * cp.sum(c * x) 261 | ) 262 | 263 | constraints = [] 264 | for name in set(m.name for m in models): 265 | name_idx = np.array(list(int(m.name == name) for m in models)) 266 | constraints += [name_idx * x == 1] 267 | 268 | prob = cp.Problem(objective, constraints) 269 | prob.solve(verbose=verbose) 270 | idx, = np.where(x.value > 0.3) 271 | 
return list(models[i] for i in idx) 272 | 273 | 274 | def select_models_solo_bic(models): 275 | for m in models: 276 | m.solo_bic = np.log(m.weight_total) * m._n_parameters() - 2 * m.loglikelihood 277 | 278 | ms = [] 279 | for name in set(m.name for m in models): 280 | bestm = min([m for m in models if m.name == name], key=lambda m: m.solo_bic) 281 | ms.append(bestm) 282 | return ms 283 | 284 | 285 | def probabilities(models, X, W): 286 | weights = [] 287 | for model in models: 288 | probs = model.predict_proba(X, W[model.name].values) 289 | nextlevel_columns = list(f"{model.name}_{i}" for i in range(model.n_components)) 290 | weights.append(pd.DataFrame(probs, columns=nextlevel_columns)) 291 | return pd.concat(weights, axis=1) 292 | 293 | -------------------------------------------------------------------------------- /notebooks/1-load-and-convert-statsbomb-data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "**Disclaimer**: this notebook's compatibility with StatsBomb event data 4.0.0 was last checked on June 15th, 2020" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "%load_ext autoreload\n", 17 | "%autoreload 2\n", 18 | "import os;\n", 19 | "import warnings\n", 20 | "import pandas as pd\n", 21 | "warnings.simplefilter(action='ignore', category=pd.errors.PerformanceWarning)\n", 22 | "import tqdm" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "import socceraction.spadl as spadl\n", 32 | "import socceraction.spadl.statsbomb as statsbomb\n", 33 | "import socceraction.atomic.spadl as atomicspadl" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "## Set up the statsbombloader" 41 | ] 42 | }, 43 | { 44 | 
"cell_type": "code", 45 | "execution_count": 3, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "# Use this if you only want to use the free public statsbomb data\n", 50 | "free_open_data_remote = \"https://raw.githubusercontent.com/statsbomb/open-data/master/data/\"\n", 51 | "SBL = statsbomb.StatsBombLoader(root=free_open_data_remote,getter=\"remote\")\n", 52 | "\n", 53 | "# # Uncomment the code below if you have a local folder on your computer with statsbomb data\n", 54 | "#datafolder = \"../data-epl\" # Example of local folder with statsbomb data\n", 55 | "#SBL = statsbomb.StatsBombLoader(root=datafolder,getter=\"local\")" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "## Select competitions to load and convert" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 4, 68 | "metadata": {}, 69 | "outputs": [ 70 | { 71 | "data": { 72 | "text/plain": [ 73 | "{'Champions League',\n", 74 | " \"FA Women's Super League\",\n", 75 | " 'FIFA World Cup',\n", 76 | " 'La Liga',\n", 77 | " 'NWSL',\n", 78 | " 'Premier League',\n", 79 | " \"Women's World Cup\"}" 80 | ] 81 | }, 82 | "execution_count": 4, 83 | "metadata": {}, 84 | "output_type": "execute_result" 85 | } 86 | ], 87 | "source": [ 88 | "# View all available competitions\n", 89 | "competitions = SBL.competitions()\n", 90 | "set(competitions.competition_name)" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 5, 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "data": { 100 | "text/html": [ 101 | "
\n", 102 | "\n", 115 | "\n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | "
competition_idseason_idcountry_namecompetition_namecompetition_genderseason_namematch_updatedmatch_available
17433InternationalFIFA World Cupmale20182019-12-16T23:09:16.1687562019-12-16T23:09:16.168756
\n", 143 | "
" 144 | ], 145 | "text/plain": [ 146 | " competition_id season_id country_name competition_name \\\n", 147 | "17 43 3 International FIFA World Cup \n", 148 | "\n", 149 | " competition_gender season_name match_updated \\\n", 150 | "17 male 2018 2019-12-16T23:09:16.168756 \n", 151 | "\n", 152 | " match_available \n", 153 | "17 2019-12-16T23:09:16.168756 " 154 | ] 155 | }, 156 | "execution_count": 5, 157 | "metadata": {}, 158 | "output_type": "execute_result" 159 | } 160 | ], 161 | "source": [ 162 | "# Fifa world cup\n", 163 | "selected_competitions = competitions[competitions.competition_name==\"FIFA World Cup\"]\n", 164 | "\n", 165 | "# # Messi data\n", 166 | "# selected_competitions = competitions[competitions.competition_name==\"La Liga\"]\n", 167 | "\n", 168 | "# # FA Women's Super League\n", 169 | "# selected_competitions = competitions[competitions.competition_name==\"FA Women's Super League\"]\n", 170 | "selected_competitions" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 6, 176 | "metadata": { 177 | "scrolled": true 178 | }, 179 | "outputs": [ 180 | { 181 | "data": { 182 | "text/html": [ 183 | "
\n", 184 | "\n", 197 | "\n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | "
home_team_nameaway_team_namematch_datehome_scoreaway_score
0CroatiaDenmark2018-07-0111
1AustraliaPeru2018-06-2602
2NigeriaIceland2018-06-2220
3SerbiaBrazil2018-06-2702
4IranPortugal2018-06-2511
..................
59ColombiaJapan2018-06-1912
60JapanPoland2018-06-2801
61DenmarkAustralia2018-06-2111
62SpainRussia2018-07-0111
63CroatiaEngland2018-07-1121
\n", 299 | "

64 rows × 5 columns

\n", 300 | "
" 301 | ], 302 | "text/plain": [ 303 | " home_team_name away_team_name match_date home_score away_score\n", 304 | "0 Croatia Denmark 2018-07-01 1 1\n", 305 | "1 Australia Peru 2018-06-26 0 2\n", 306 | "2 Nigeria Iceland 2018-06-22 2 0\n", 307 | "3 Serbia Brazil 2018-06-27 0 2\n", 308 | "4 Iran Portugal 2018-06-25 1 1\n", 309 | ".. ... ... ... ... ...\n", 310 | "59 Colombia Japan 2018-06-19 1 2\n", 311 | "60 Japan Poland 2018-06-28 0 1\n", 312 | "61 Denmark Australia 2018-06-21 1 1\n", 313 | "62 Spain Russia 2018-07-01 1 1\n", 314 | "63 Croatia England 2018-07-11 2 1\n", 315 | "\n", 316 | "[64 rows x 5 columns]" 317 | ] 318 | }, 319 | "execution_count": 6, 320 | "metadata": {}, 321 | "output_type": "execute_result" 322 | } 323 | ], 324 | "source": [ 325 | "# Get matches from all selected competitions\n", 326 | "matches = list(\n", 327 | " SBL.matches(row.competition_id, row.season_id)\n", 328 | " for row in selected_competitions.itertuples()\n", 329 | ")\n", 330 | "matches = pd.concat(matches, sort=True).reset_index(drop=True)\n", 331 | "matches[[\"home_team_name\",\"away_team_name\",\"match_date\",\"home_score\",\"away_score\"]]" 332 | ] 333 | }, 334 | { 335 | "cell_type": "markdown", 336 | "metadata": {}, 337 | "source": [ 338 | "## Load and convert match data" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": 7, 344 | "metadata": {}, 345 | "outputs": [ 346 | { 347 | "name": "stderr", 348 | "output_type": "stream", 349 | "text": [ 350 | "Loading match data: 100%|██████████| 64/64 [02:09<00:00, 2.03s/it]\n" 351 | ] 352 | } 353 | ], 354 | "source": [ 355 | "matches_verbose = tqdm.tqdm(list(matches.itertuples()),desc=\"Loading match data\")\n", 356 | "teams,players,player_games = [],[],[]\n", 357 | "actions = {}\n", 358 | "atomic_actions = {}\n", 359 | "for match in matches_verbose:\n", 360 | " # load data\n", 361 | " teams.append(SBL.teams(match.match_id))\n", 362 | " players.append(SBL.players(match.match_id))\n", 363 | " events = 
SBL.events(match.match_id)\n", 364 | " \n", 365 | " # convert data\n", 366 | " player_games.append(statsbomb.extract_player_games(events))\n", 367 | " actions = statsbomb.convert_to_actions(events,match.home_team_id)\n", 368 | " atomic_actions[match.match_id] = atomicspadl.convert_to_atomic(actions)\n", 369 | "\n", 370 | "games = matches.rename(columns={\"match_id\":\"game_id\"})\n", 371 | "teams = pd.concat(teams).drop_duplicates(\"team_id\").reset_index(drop=True)\n", 372 | "players = pd.concat(players).drop_duplicates(\"player_id\").reset_index(drop=True)\n", 373 | "player_games = pd.concat(player_games).reset_index(drop=True)" 374 | ] 375 | }, 376 | { 377 | "cell_type": "markdown", 378 | "metadata": {}, 379 | "source": [ 380 | "## Store converted spadl data in a h5-file" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": 8, 386 | "metadata": { 387 | "scrolled": true 388 | }, 389 | "outputs": [ 390 | { 391 | "name": "stdout", 392 | "output_type": "stream", 393 | "text": [ 394 | "Directory ../data-fifa created \n" 395 | ] 396 | } 397 | ], 398 | "source": [ 399 | "datafolder = \"../data-fifa\"\n", 400 | "\n", 401 | "# Create data folder if it doesn't exist\n", 402 | "if not os.path.exists(datafolder):\n", 403 | " os.mkdir(datafolder)\n", 404 | " print(f\"Directory {datafolder} created \")\n", 405 | "\n", 406 | "spadl_h5 = os.path.join(datafolder, \"atomic-spadl-statsbomb.h5\")\n", 407 | "\n", 408 | "# Store all spadl data in h5-file\n", 409 | "with pd.HDFStore(spadl_h5) as spadlstore:\n", 410 | " spadlstore[\"competitions\"] = selected_competitions\n", 411 | " spadlstore[\"games\"] = games\n", 412 | " spadlstore[\"teams\"] = teams\n", 413 | " spadlstore[\"players\"] = players\n", 414 | " spadlstore[\"player_games\"] = player_games\n", 415 | " for game_id in atomic_actions.keys():\n", 416 | " spadlstore[f\"atomic_actions/game_{game_id}\"] = atomic_actions[game_id]\n", 417 | "\n", 418 | " spadlstore[\"results\"] = spadl.results_df()\n", 419 
| " spadlstore[\"bodyparts\"] = spadl.bodyparts_df()\n", 420 | " spadlstore[\"atomic_actiontypes\"] = atomicspadl.actiontypes_df()" 421 | ] 422 | }, 423 | { 424 | "cell_type": "markdown", 425 | "metadata": {}, 426 | "source": [ 427 | "## Plot the spadl data\n", 428 | "Extra library required: ```pip install matplotsoccer```" 429 | ] 430 | }, 431 | { 432 | "cell_type": "code", 433 | "execution_count": 14, 434 | "metadata": {}, 435 | "outputs": [], 436 | "source": [ 437 | "# Select England vs Belgium game at World Cup\n", 438 | "with pd.HDFStore(spadl_h5) as spadlstore:\n", 439 | " games = spadlstore[\"games\"].merge(spadlstore[\"competitions\"])\n", 440 | " game_id = games[(games.competition_name == \"FIFA World Cup\") \n", 441 | " & (games.home_team_name == \"Belgium\")\n", 442 | " & (games.away_team_name == \"England\")].game_id.values[0]\n", 443 | " \n", 444 | " atomic_actions = spadlstore[f\"atomic_actions/game_{game_id}\"]\n", 445 | " atomic_actions = (\n", 446 | " atomic_actions.merge(spadlstore[\"atomic_actiontypes\"],how=\"left\")\n", 447 | " #.merge(spadlstore[\"results\"],how=\"left\")\n", 448 | " .merge(spadlstore[\"bodyparts\"],how=\"left\")\n", 449 | " .merge(spadlstore[\"players\"],how=\"left\")\n", 450 | " .merge(spadlstore[\"teams\"],how=\"left\")\n", 451 | " )\n", 452 | "\n", 453 | "# use nickname if available else use full name\n", 454 | "atomic_actions[\"player\"] = atomic_actions[[\"player_nickname\",\"player_name\"]].apply(lambda x: x[0] if x[0] else x[1],axis=1)" 455 | ] 456 | }, 457 | { 458 | "cell_type": "code", 459 | "execution_count": 17, 460 | "metadata": {}, 461 | "outputs": [ 462 | { 463 | "name": "stdout", 464 | "output_type": "stream", 465 | "text": [ 466 | "2018-07-14 Belgium 2-0 England 4'\n" 467 | ] 468 | }, 469 | { 470 | "data": { 471 | "image/png": "\n", 472 | "text/plain": [ 473 | "
" 474 | ] 475 | }, 476 | "metadata": { 477 | "needs_background": "light" 478 | }, 479 | "output_type": "display_data" 480 | }, 481 | { 482 | "name": "stdout", 483 | "output_type": "stream", 484 | "text": [ 485 | "2018-07-14 Belgium 2-0 England 82'\n" 486 | ] 487 | }, 488 | { 489 | "data": { 490 | "image/png": "\n", 491 | "text/plain": [ 492 | "
" 493 | ] 494 | }, 495 | "metadata": { 496 | "needs_background": "light" 497 | }, 498 | "output_type": "display_data" 499 | } 500 | ], 501 | "source": [ 502 | "import matplotsoccer\n", 503 | "\n", 504 | "for shot in list(atomic_actions[(atomic_actions.type_name == \"goal\")].index):\n", 505 | " a = atomic_actions[shot-8:shot+1].copy()\n", 506 | "\n", 507 | " a[\"start_x\"] = a.x\n", 508 | " a[\"start_y\"] = a.y\n", 509 | " a[\"end_x\"] = a.x + a.dx\n", 510 | " a[\"end_y\"] = a.y + a.dy\n", 511 | "\n", 512 | " g = list(games[games.game_id == a.game_id.values[0]].itertuples())[0]\n", 513 | " minute = int((a.period_id.values[0]-1)*45 +a.time_seconds.values[0] // 60)\n", 514 | " game_info = f\"{g.match_date} {g.home_team_name} {g.home_score}-{g.away_score} {g.away_team_name} {minute + 1}'\"\n", 515 | " print(game_info)\n", 516 | "\n", 517 | " def nice_time(row):\n", 518 | " minute = int((row.period_id-1)*45 +row.time_seconds // 60)\n", 519 | " second = int(row.time_seconds % 60)\n", 520 | " return f\"{minute}m{second}s\"\n", 521 | "\n", 522 | " a[\"nice_time\"] = a.apply(nice_time,axis=1)\n", 523 | " labels = a[[\"nice_time\", \"type_name\", \"player\", \"team_name\"]]\n", 524 | "\n", 525 | " matplotsoccer.actions(\n", 526 | " location=a[[\"start_x\", \"start_y\", \"end_x\", \"end_y\"]],\n", 527 | " action_type=a.type_name,\n", 528 | " team= a.team_name,\n", 529 | " label=labels,\n", 530 | " labeltitle=[\"time\",\"actiontype\",\"player\",\"team\"],\n", 531 | " zoom=False,\n", 532 | " figsize=6\n", 533 | " )" 534 | ] 535 | } 536 | ], 537 | "metadata": { 538 | "kernelspec": { 539 | "display_name": "Python 3", 540 | "language": "python", 541 | "name": "python3" 542 | }, 543 | "language_info": { 544 | "codemirror_mode": { 545 | "name": "ipython", 546 | "version": 3 547 | }, 548 | "file_extension": ".py", 549 | "mimetype": "text/x-python", 550 | "name": "python", 551 | "nbconvert_exporter": "python", 552 | "pygments_lexer": "ipython3", 553 | "version": "3.7.1" 554 | }, 555 
| "varInspector": { 556 | "cols": { 557 | "lenName": 16, 558 | "lenType": 16, 559 | "lenVar": 40 560 | }, 561 | "kernels_config": { 562 | "python": { 563 | "delete_cmd_postfix": "", 564 | "delete_cmd_prefix": "del ", 565 | "library": "var_list.py", 566 | "varRefreshCmd": "print(var_dic_list())" 567 | }, 568 | "r": { 569 | "delete_cmd_postfix": ") ", 570 | "delete_cmd_prefix": "rm(", 571 | "library": "var_list.r", 572 | "varRefreshCmd": "cat(var_dic_list()) " 573 | } 574 | }, 575 | "types_to_exclude": [ 576 | "module", 577 | "function", 578 | "builtin_function_or_method", 579 | "instance", 580 | "_Feature" 581 | ], 582 | "window_display": false 583 | } 584 | }, 585 | "nbformat": 4, 586 | "nbformat_minor": 2 587 | } 588 | -------------------------------------------------------------------------------- /notebooks/paper-experiment1-deanonymizing-players.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "import os; import sys; sys.path.insert(0, '../')\n", 12 | "import pandas as pd\n", 13 | "import tqdm\n", 14 | "import pickle\n", 15 | "\n", 16 | "import numpy as np\n", 17 | "import warnings" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "### Configure file and folder names\n", 27 | "data_h5 = \"../data/paper/soccermix_all_data.h5\"\n", 28 | "\n", 29 | "d_weights = \"../data/paper/soccermix_all_dirweights.pkl\"\n", 30 | "\n", 31 | "spadl_h5 = \"../data/tomd/spadl-statsbomb.h5\"" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 3, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "X = pd.read_hdf(data_h5, \"X\")" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 4, 46 | "metadata": {}, 47 | 
"outputs": [], 48 | "source": [ 49 | "def loadall(filename):\n", 50 | " with open(filename, \"rb\") as f:\n", 51 | " while True:\n", 52 | " try:\n", 53 | " yield pickle.load(f)\n", 54 | " except EOFError:\n", 55 | " break\n", 56 | "\n", 57 | "d_w = loadall(d_weights)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 5, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "dir_weights = next(d_w)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 6, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "games = pd.read_hdf(spadl_h5, \"games\")\n", 76 | "\n", 77 | "games_1819 = games[games.season_name == '2018/2019']\n", 78 | "games_1718 = games[games.season_name == '2017/2018']" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 7, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "players = pd.read_hdf(spadl_h5, \"players\")\n", 88 | "pg = pd.read_hdf(spadl_h5, \"player_games\")" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 8, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "pg_1819 = pg[pg.game_id.isin(games_1819.game_id)]\n", 98 | "pg_1718 = pg[pg.game_id.isin(games_1718.game_id)]" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 9, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "players_1819 = players[players.player_id.isin(pg_1819.player_id)]\n", 108 | "players_1718 = players[players.player_id.isin(pg_1718.player_id)]" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 10, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "mp_1819 = pg_1819[[\"player_id\", \"minutes_played\"]].groupby(\"player_id\").sum().reset_index()\n", 118 | "mp_1718 = pg_1718[[\"player_id\", \"minutes_played\"]].groupby(\"player_id\").sum().reset_index()" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 11, 124 | "metadata": {}, 125 | 
"outputs": [ 126 | { 127 | "name": "stderr", 128 | "output_type": "stream", 129 | "text": [ 130 | "100%|██████████| 515/515 [02:00<00:00, 4.29it/s]\n", 131 | "100%|██████████| 505/505 [01:46<00:00, 4.73it/s]\n" 132 | ] 133 | } 134 | ], 135 | "source": [ 136 | "# Get player vectors\n", 137 | "\n", 138 | "merged_weights = dir_weights.copy()\n", 139 | "merged_weights[\"player_id\"] = X.player_id.values\n", 140 | "merged_weights[\"game_id\"] = X.game_id.values\n", 141 | "\n", 142 | "vectors_1718 = {}\n", 143 | "for p in tqdm.tqdm(list(players_1718.player_id.unique())):\n", 144 | " vectors_1718[int(p)] = merged_weights.loc[((merged_weights.player_id == p)\n", 145 | " & (merged_weights.game_id.isin(games_1718.game_id))),\n", 146 | " dir_weights.columns].sum().values\n", 147 | " \n", 148 | "vectors_1819 = {}\n", 149 | "for p in tqdm.tqdm(list(players_1819.player_id.unique())):\n", 150 | " vectors_1819[int(p)] = merged_weights.loc[((merged_weights.player_id == p)\n", 151 | " & (merged_weights.game_id.isin(games_1819.game_id))),\n", 152 | " dir_weights.columns].sum().values\n", 153 | " \n", 154 | "vectors_1718_pd = pd.concat({k: pd.DataFrame(v).T for k,v in vectors_1718.items()}).droplevel(level=1)\n", 155 | "vectors_1718_pd.index.name = \"player_id\"\n", 156 | "vectors_1718_pd.columns = dir_weights.columns\n", 157 | "\n", 158 | "vectors_1819_pd = pd.concat({k: pd.DataFrame(v).T for k,v in vectors_1819.items()}).droplevel(level=1)\n", 159 | "vectors_1819_pd.index.name = \"player_id\"\n", 160 | "vectors_1819_pd.columns = dir_weights.columns" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 12, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "# Normalize vectors per 90 min game time\n", 170 | "\n", 171 | "vectors_1718_norm = pd.merge(vectors_1718_pd, mp_1718, left_index=True, right_on='player_id').set_index('player_id')\n", 172 | "df1 = vectors_1718_norm.loc[:, dir_weights.columns] * 90\n", 173 | "vectors_1718_norm.loc[:, 
dir_weights.columns] = df1.divide(vectors_1718_norm.minutes_played, axis='rows')\n", 174 | "vectors_1718_norm.drop(columns=['minutes_played'], inplace=True)\n", 175 | "\n", 176 | "vectors_1819_norm = pd.merge(vectors_1819_pd, mp_1819, left_index=True, right_on='player_id').set_index('player_id')\n", 177 | "df1 = vectors_1819_norm.loc[:, dir_weights.columns] * 90\n", 178 | "vectors_1819_norm.loc[:, dir_weights.columns] = df1.divide(vectors_1819_norm.minutes_played, axis='rows')\n", 179 | "vectors_1819_norm.drop(columns=['minutes_played'], inplace=True)" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 13, 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "# Code below mainly from Pieter's implementation of this experiment with soccer vectors\n", 189 | "# https://github.com/probberechts/soccer-player-vectors-thesis/blob/master/notebooks/5-experiments.ipynb\n", 190 | "\n", 191 | "# Select correct players to test on \n", 192 | "\n", 193 | "train_players = pg_1718.groupby('player_id').agg({\n", 194 | " 'minutes_played': 'sum',\n", 195 | " 'team_id': set\n", 196 | "}).merge(players_1718, on=\"player_id\", how='left')\n", 197 | "\n", 198 | "test_players = pg_1819.groupby('player_id').agg({\n", 199 | " 'minutes_played': 'sum',\n", 200 | " 'team_id': set\n", 201 | "}).merge(players_1819, on=\"player_id\", how='left')" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 14, 207 | "metadata": {}, 208 | "outputs": [], 209 | "source": [ 210 | "all_players = pd.merge(train_players, test_players, on=\"player_id\", suffixes=(\"_train\", \"_test\"))\n", 211 | "all_players['nb_teams'] = all_players.apply(lambda x: len(x.team_id_train | x.team_id_test), axis=1)\n", 212 | "all_players = all_players[all_players.nb_teams == 1]" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 15, 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "# Only players who played >= 900 minutes in 
both train and test season\n", 222 | "all_players = all_players[(all_players.minutes_played_train >= 900) & (all_players.minutes_played_test >= 900)]" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 16, 228 | "metadata": {}, 229 | "outputs": [ 230 | { 231 | "name": "stdout", 232 | "output_type": "stream", 233 | "text": [ 234 | "Number of players: 193\n" 235 | ] 236 | } 237 | ], 238 | "source": [ 239 | "all_players = all_players.player_id.unique()\n", 240 | "print(\"Number of players: \", len(all_players))" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 17, 246 | "metadata": {}, 247 | "outputs": [], 248 | "source": [ 249 | "# Compute pairwise distances\n", 250 | "\n", 251 | "from sklearn.metrics import pairwise_distances\n", 252 | "from sklearn import preprocessing\n", 253 | "\n", 254 | "# D = pairwise_distances(\n", 255 | "# vectors_1718_norm.loc[all_players],\n", 256 | "# vectors_1819_norm.loc[all_players],\n", 257 | "# metric='manhattan'\n", 258 | "# )\n", 259 | "\n", 260 | "D = pairwise_distances(\n", 261 | " preprocessing.normalize(vectors_1718_norm.loc[all_players], norm=\"l1\"),\n", 262 | " preprocessing.normalize(vectors_1819_norm.loc[all_players], norm=\"l1\"),\n", 263 | " metric=\"manhattan\")\n", 264 | "\n", 265 | "# sort each row\n", 266 | "k_d = np.sort(D, axis = 1) \n", 267 | "# sort each row and replace distances by index\n", 268 | "k_i = np.argsort(D, axis = 1) \n", 269 | "# replace indices by player ids\n", 270 | "p_i = np.take(all_players, k_i, axis = 0)" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 18, 276 | "metadata": {}, 277 | "outputs": [ 278 | { 279 | "data": { 280 | "text/plain": [ 281 | "array([ 13, 1, 4, 0, 4, 0, 90, 0, 0, 0, 0, 0, 0,\n", 282 | " 4, 0, 0, 0, 0, 20, 6, 0, 0, 13, 0, 1, 0,\n", 283 | " 5, 0, 0, 29, 0, 3, 2, 1, 15, 1, 1, 142, 0,\n", 284 | " 1, 1, 0, 0, 0, 3, 0, 0, 0, 6, 7, 1, 0,\n", 285 | " 0, 4, 5, 0, 0, 0, 0, 0, 2, 7, 0, 15, 0,\n", 286 
| " 7, 0, 5, 2, 0, 0, 11, 5, 12, 0, 0, 4, 0,\n", 287 | " 0, 2, 0, 1, 0, 0, 0, 60, 0, 8, 3, 0, 8,\n", 288 | " 2, 0, 0, 0, 10, 13, 0, 0, 3, 25, 27, 23, 0,\n", 289 | " 2, 0, 0, 34, 0, 1, 20, 1, 0, 0, 1, 0, 2,\n", 290 | " 16, 3, 0, 0, 0, 13, 1, 11, 11, 9, 0, 8, 3,\n", 291 | " 158, 0, 0, 106, 0, 0, 0, 5, 0, 4, 0, 40, 1,\n", 292 | " 0, 90, 6, 0, 0, 0, 0, 0, 0, 0, 0, 1, 4,\n", 293 | " 1, 0, 0, 0, 2, 0, 6, 0, 1, 0, 0, 24, 11,\n", 294 | " 37, 11, 17, 4, 4, 73, 53, 1, 1, 6, 2, 0, 0,\n", 295 | " 31, 0, 3, 0, 0, 0, 15, 14, 15, 3, 6])" 296 | ] 297 | }, 298 | "execution_count": 18, 299 | "metadata": {}, 300 | "output_type": "execute_result" 301 | } 302 | ], 303 | "source": [ 304 | "rs = np.argmax(np.array([p_i[i,:] == all_players[i] for i in range(p_i.shape[0])]), axis=1)\n", 305 | "rs" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": 19, 311 | "metadata": {}, 312 | "outputs": [], 313 | "source": [ 314 | "def mean_reciprocal_rank(rs):\n", 315 | " return np.mean(1. / (rs + 1))\n", 316 | "\n", 317 | "def top_k(rs, k):\n", 318 | " return (rs < k).sum() / len(rs)" 319 | ] 320 | }, 321 | { 322 | "cell_type": "code", 323 | "execution_count": 20, 324 | "metadata": {}, 325 | "outputs": [ 326 | { 327 | "data": { 328 | "text/plain": [ 329 | "0.5885390745244184" 330 | ] 331 | }, 332 | "execution_count": 20, 333 | "metadata": {}, 334 | "output_type": "execute_result" 335 | } 336 | ], 337 | "source": [ 338 | "mean_reciprocal_rank(rs)" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": 21, 344 | "metadata": {}, 345 | "outputs": [ 346 | { 347 | "data": { 348 | "text/plain": [ 349 | "0.8082901554404145" 350 | ] 351 | }, 352 | "execution_count": 21, 353 | "metadata": {}, 354 | "output_type": "execute_result" 355 | } 356 | ], 357 | "source": [ 358 | "top_k(rs, 10)" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": 22, 364 | "metadata": {}, 365 | "outputs": [ 366 | { 367 | "name": "stdout", 368 | "output_type": 
"stream", 369 | "text": [ 370 | "0.7150259067357513\n", 371 | "0.6269430051813472\n", 372 | "0.48186528497409326\n" 373 | ] 374 | } 375 | ], 376 | "source": [ 377 | "print(top_k(rs, 5))\n", 378 | "print(top_k(rs, 3))\n", 379 | "print(top_k(rs, 1))" 380 | ] 381 | }, 382 | { 383 | "cell_type": "markdown", 384 | "metadata": {}, 385 | "source": [ 386 | "# Get similar players to player" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": 23, 392 | "metadata": {}, 393 | "outputs": [], 394 | "source": [ 395 | "def get_similar_players(player_id):\n", 396 | " player_index = np.where(all_players == player_id)[0][0]\n", 397 | " print(player_index)\n", 398 | " sims = p_i[player_index,:]\n", 399 | " names = players_1819.set_index(\"player_id\").loc[sims, \"player_name\"].values\n", 400 | " dists = k_d[player_index,:]\n", 401 | " return pd.DataFrame({\"name\": names, \"dist\": dists})" 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "execution_count": 24, 407 | "metadata": {}, 408 | "outputs": [ 409 | { 410 | "name": "stdout", 411 | "output_type": "stream", 412 | "text": [ 413 | "55 3202\n", 414 | "Name: player_id, dtype: int64\n", 415 | "48 3202\n", 416 | "Name: player_id, dtype: int64\n", 417 | "60 3237\n", 418 | "Name: player_id, dtype: int64\n", 419 | "53 3237\n", 420 | "Name: player_id, dtype: int64\n" 421 | ] 422 | } 423 | ], 424 | "source": [ 425 | "print(train_players[train_players.player_name.str.contains('Jesus')].player_id)\n", 426 | "print(test_players[test_players.player_name.str.contains('Jesus')].player_id)\n", 427 | "\n", 428 | "print(train_players[train_players.player_name.str.contains('Agüero')].player_id)\n", 429 | "print(test_players[test_players.player_name.str.contains('Agüero')].player_id)" 430 | ] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": 25, 435 | "metadata": {}, 436 | "outputs": [ 437 | { 438 | "name": "stdout", 439 | "output_type": "stream", 440 | "text": [ 441 | "43\n" 442 | ] 443 | }, 444 
| { 445 | "data": { 446 | "text/html": [ 447 | "
\n", 448 | "\n", 461 | "\n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | "
namedist
0Sergio Leonel Agüero del Castillo0.208176
1Marko Arnautović0.293974
2Gabriel Fernando de Jesus0.319744
3Cenk Tosun0.322811
4Jamie Vardy0.349722
.........
188Alex McCarthy1.914147
189Martin Dúbravka1.914678
190Asmir Begović1.915225
191Hugo Lloris1.924485
192David de Gea Quintana1.924613
\n", 527 | "

193 rows × 2 columns

\n", 528 | "
" 529 | ], 530 | "text/plain": [ 531 | " name dist\n", 532 | "0 Sergio Leonel Agüero del Castillo 0.208176\n", 533 | "1 Marko Arnautović 0.293974\n", 534 | "2 Gabriel Fernando de Jesus 0.319744\n", 535 | "3 Cenk Tosun 0.322811\n", 536 | "4 Jamie Vardy 0.349722\n", 537 | ".. ... ...\n", 538 | "188 Alex McCarthy 1.914147\n", 539 | "189 Martin Dúbravka 1.914678\n", 540 | "190 Asmir Begović 1.915225\n", 541 | "191 Hugo Lloris 1.924485\n", 542 | "192 David de Gea Quintana 1.924613\n", 543 | "\n", 544 | "[193 rows x 2 columns]" 545 | ] 546 | }, 547 | "execution_count": 25, 548 | "metadata": {}, 549 | "output_type": "execute_result" 550 | } 551 | ], 552 | "source": [ 553 | "get_similar_players(3237) # Similar to Aguero" 554 | ] 555 | }, 556 | { 557 | "cell_type": "code", 558 | "execution_count": 26, 559 | "metadata": { 560 | "scrolled": false 561 | }, 562 | "outputs": [ 563 | { 564 | "name": "stdout", 565 | "output_type": "stream", 566 | "text": [ 567 | "39\n" 568 | ] 569 | }, 570 | { 571 | "data": { 572 | "text/html": [ 573 | "
\n", 574 | "\n", 587 | "\n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | " \n", 611 | " \n", 612 | " \n", 613 | " \n", 614 | " \n", 615 | " \n", 616 | " \n", 617 | " \n", 618 | " \n", 619 | " \n", 620 | " \n", 621 | " \n", 622 | " \n", 623 | " \n", 624 | " \n", 625 | " \n", 626 | " \n", 627 | " \n", 628 | " \n", 629 | " \n", 630 | " \n", 631 | " \n", 632 | " \n", 633 | " \n", 634 | " \n", 635 | " \n", 636 | " \n", 637 | " \n", 638 | " \n", 639 | " \n", 640 | " \n", 641 | " \n", 642 | " \n", 643 | " \n", 644 | " \n", 645 | " \n", 646 | " \n", 647 | " \n", 648 | " \n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | "
namedist
0Sergio Leonel Agüero del Castillo0.232574
1Gabriel Fernando de Jesus0.235393
2Jamie Vardy0.289722
3Harry Kane0.297915
4Troy Deeney0.314603
.........
188Alex McCarthy1.896625
189Mathew Ryan1.897377
190Asmir Begović1.899667
191Hugo Lloris1.905453
192David de Gea Quintana1.912968
\n", 653 | "

193 rows × 2 columns

\n", 654 | "
" 655 | ], 656 | "text/plain": [ 657 | " name dist\n", 658 | "0 Sergio Leonel Agüero del Castillo 0.232574\n", 659 | "1 Gabriel Fernando de Jesus 0.235393\n", 660 | "2 Jamie Vardy 0.289722\n", 661 | "3 Harry Kane 0.297915\n", 662 | "4 Troy Deeney 0.314603\n", 663 | ".. ... ...\n", 664 | "188 Alex McCarthy 1.896625\n", 665 | "189 Mathew Ryan 1.897377\n", 666 | "190 Asmir Begović 1.899667\n", 667 | "191 Hugo Lloris 1.905453\n", 668 | "192 David de Gea Quintana 1.912968\n", 669 | "\n", 670 | "[193 rows x 2 columns]" 671 | ] 672 | }, 673 | "execution_count": 26, 674 | "metadata": {}, 675 | "output_type": "execute_result" 676 | } 677 | ], 678 | "source": [ 679 | "get_similar_players(3202) # Similar to Jesus" 680 | ] 681 | } 682 | ], 683 | "metadata": { 684 | "kernelspec": { 685 | "display_name": "Python 3", 686 | "language": "python", 687 | "name": "python3" 688 | }, 689 | "language_info": { 690 | "codemirror_mode": { 691 | "name": "ipython", 692 | "version": 3 693 | }, 694 | "file_extension": ".py", 695 | "mimetype": "text/x-python", 696 | "name": "python", 697 | "nbconvert_exporter": "python", 698 | "pygments_lexer": "ipython3", 699 | "version": "3.7.1" 700 | }, 701 | "varInspector": { 702 | "cols": { 703 | "lenName": 16, 704 | "lenType": 16, 705 | "lenVar": 40 706 | }, 707 | "kernels_config": { 708 | "python": { 709 | "delete_cmd_postfix": "", 710 | "delete_cmd_prefix": "del ", 711 | "library": "var_list.py", 712 | "varRefreshCmd": "print(var_dic_list())" 713 | }, 714 | "r": { 715 | "delete_cmd_postfix": ") ", 716 | "delete_cmd_prefix": "rm(", 717 | "library": "var_list.r", 718 | "varRefreshCmd": "cat(var_dic_list()) " 719 | } 720 | }, 721 | "types_to_exclude": [ 722 | "module", 723 | "function", 724 | "builtin_function_or_method", 725 | "instance", 726 | "_Feature" 727 | ], 728 | "window_display": false 729 | } 730 | }, 731 | "nbformat": 4, 732 | "nbformat_minor": 4 733 | } 734 | -------------------------------------------------------------------------------- 
# /vis.py
"""Plotting helpers for location and direction mixture models on a pitch."""
from scipy import linalg
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotsoccer as mps
import numpy as np
import math


def dual_axes(figsize=4):
    """Create a figure with two side-by-side axes and return them.

    The figure is ``3 * figsize`` inches wide and ``figsize`` inches tall.
    """
    fig, axs = plt.subplots(1, 2)
    fig.set_size_inches((figsize * 3, figsize))
    return axs[0], axs[1]


def loc_angle_axes(figsize=4):
    """Create a pitch axes (left) and a polar axes (right) in one figure.

    Returns:
        ``(axloc, axpol)``: the pitch axes prepared by :func:`field` and a
        polar-projection axes.
    """
    # Add each axes exactly once. Creating both with ``plt.subplots`` and then
    # re-requesting slot 122 as polar relies on matplotlib deleting the
    # overlapping rectilinear axes, which recent releases no longer do.
    fig = plt.figure()
    fig.set_size_inches((figsize * 3, figsize))
    axloc = field(fig.add_subplot(1, 2, 1))
    axpol = fig.add_subplot(1, 2, 2, projection="polar")
    return axloc, axpol


def field(ax):
    """Draw a matplotsoccer pitch on *ax* and pad the 105 x 68 limits by 1."""
    ax = mps.field(ax=ax, show=False)
    ax.set_xlim(-1, 105 + 1)
    ax.set_ylim(-1, 68 + 1)
    return ax


def movement(ax):
    """Style *ax* as a centred, equally scaled (dx, dy) plane of +/- 60."""
    # Act on the axes that was passed in, not on whatever pyplot considers
    # the current axes.
    ax.axis("on")
    ax.axis("scaled")
    ax.spines["left"].set_position("center")
    ax.spines["bottom"].set_position("center")
    ax.spines["right"].set_color("none")
    ax.spines["top"].set_color("none")
    ax.set_xlim(-60, 60)
    ax.set_ylim(-60, 60)
    return ax


def polar(ax):
    """Style *ax* for angles in [-3.2, 3.2] with a centred vertical spine."""
    ax.axis("on")
    ax.set_xlim(-3.2, 3.2)
    ax.spines["left"].set_position("center")
    ax.spines["right"].set_color("none")
    ax.spines["top"].set_color("none")
    return ax


##################################
# MODEL-BASED VISUALIZATION
#################################


def _matching_dir_models(loc_model, index, dir_models):
    """Yield the direction models named ``"<loc_model.name>_<index>"``."""
    wanted = f"{loc_model.name}_{index}"
    for dir_model in dir_models:
        if dir_model.name == wanted:
            yield dir_model


def _unit_direction(vonmises):
    """Return ``(cos, sin)`` of the first entry of ``vonmises.loc``."""
    return np.cos(vonmises.loc)[0], np.sin(vonmises.loc)[0]


def show_location_model(loc_model, show=True, figsize=6):
    """Draw one ellipse per component of *loc_model* on a fresh pitch.

    Ellipse opacity is the component prior scaled so that the largest prior
    maps to 0.8. Colours cycle through the module-level ``colors`` palette.
    """
    ax = mps.field(show=False, figsize=figsize)

    norm_strengths = loc_model.priors / np.max(loc_model.priors) * 0.8
    for i, (strength, gauss) in enumerate(
        zip(norm_strengths, loc_model.submodels)
    ):
        # Modulo indexing keeps every component visible, however many
        # components the model has.
        color = colors[i % len(colors)]
        add_ellips(ax, gauss.mean, gauss.cov, color=color, alpha=strength)
    if show:
        plt.show()


def show_direction_model(gauss, dir_models, show=True, figsize=6):
    """Draw *gauss* as an ellipse with one arrow per direction component.

    Args:
        gauss: component exposing ``mean`` (x, y) and ``cov``.
        dir_models: a single direction mixture; its ``submodels`` are
            iterated (the plural name is kept for compatibility).
        show: call ``plt.show()`` when true.
        figsize: forwarded to ``matplotsoccer.field``.
    """
    ax = mps.field(show=False, figsize=figsize)
    add_ellips(ax, gauss.mean, gauss.cov, alpha=0.5)

    x, y = gauss.mean
    for vonmises in dir_models.submodels:
        dx, dy = _unit_direction(vonmises)
        add_arrow(ax, x, y, 10 * dx, 10 * dy, linewidth=0.5)

    if show:
        plt.show()


def show_location_models(loc_models, figsize=6):
    """
    Model-based visualization: print and plot every location model.
    """
    for model in loc_models:
        print(model.name, model.n_components)
        # Forward the caller's figsize instead of a hard-coded 6.
        show_location_model(model, figsize=figsize)


def show_all_models(loc_models, dir_models):
    """Plot each location model together with its direction models.

    When every matching direction model has exactly one component, only the
    ellipses are drawn. Otherwise ellipses are grey and one arrow of fixed
    length is drawn per direction component.
    """
    for loc_model in loc_models:
        print(loc_model.name, loc_model.n_components)
        ax = mps.field(show=False, figsize=8)

        am_subclusters = np.array(
            [
                dir_model.n_components
                for a, _ in enumerate(loc_model.submodels)
                for dir_model in _matching_dir_models(loc_model, a, dir_models)
            ]
        )
        # Loop-invariant, so evaluate once per location model.
        single_direction = (am_subclusters == 1).all()

        for i, gauss in enumerate(loc_model.submodels):
            if single_direction:
                add_ellips(ax, gauss.mean, gauss.cov, alpha=0.5)
                continue

            add_ellips(ax, gauss.mean, gauss.cov, color='grey')
            x, y = gauss.mean
            for dir_model in _matching_dir_models(loc_model, i, dir_models):
                print(dir_model.name, dir_model.n_components)
                for vonmises in dir_model.submodels:
                    dx, dy = _unit_direction(vonmises)
                    add_arrow(ax, x, y, 10 * dx, 10 * dy, linewidth=0.5)

        plt.show()


def show_direction_models(loc_models, dir_models, figsize=8):
    """
    Model-based visualization: ellipses plus prior-weighted arrows.

    Arrow length is ``10 * vonmises.R[0]`` and arrow opacity is the direction
    prior scaled so that the largest prior maps to 0.8.
    """
    for loc_model in loc_models:
        print(loc_model.name, loc_model.n_components)
        ax = mps.field(show=False, figsize=figsize)

        norm_strengths = loc_model.priors / np.max(loc_model.priors) * 0.8
        for i, (strength, gauss) in enumerate(
            zip(norm_strengths, loc_model.submodels)
        ):
            add_ellips(ax, gauss.mean, gauss.cov, alpha=strength)

            x, y = gauss.mean
            for dir_model in _matching_dir_models(loc_model, i, dir_models):
                print(dir_model.name, dir_model.n_components)
                dir_norm_strengths = (
                    dir_model.priors / np.max(dir_model.priors) * 0.8
                )
                for dir_strength, vonmises in zip(
                    dir_norm_strengths, dir_model.submodels
                ):
                    dx, dy = _unit_direction(vonmises)
                    r = vonmises.R[0]
                    add_arrow(
                        ax,
                        x,
                        y,
                        10 * r * dx,
                        10 * r * dy,
                        alpha=dir_strength,
                        threshold=0,
                    )
        plt.show()


def add_ellips(ax, mean, covar, color=None, alpha=0.7):
    """Add an ellipse for a 2-D Gaussian (``mean``, ``covar``) to *ax*.

    Axis lengths are ``2 * sqrt(2) * sqrt(eigenvalue)``, floored at 3 so that
    very tight components stay visible. Returns *ax*.
    """
    v, w = linalg.eigh(covar)
    # Clip tiny negative eigenvalues caused by round-off; sqrt would turn
    # them into NaN and the ellipse would silently disappear.
    v = 2.0 * np.sqrt(2.0) * np.sqrt(np.clip(v, 0.0, None))
    u = w[0] / linalg.norm(w[0])

    # arctan2 gives the same orientation as arctan(u[1] / u[0]) modulo 180
    # degrees (an ellipse is symmetric under that rotation) without dividing
    # by zero for an axis-aligned component.
    angle = np.degrees(np.arctan2(u[1], u[0]))

    # ``angle`` is passed by keyword: newer matplotlib releases do not
    # accept it positionally.
    ell = mpl.patches.Ellipse(
        mean,
        max(v[0], 3),
        max(v[1], 3),
        angle=180.0 + angle,
        color=color,
    )
    ell.set_alpha(alpha)
    ax.add_artist(ell)
    return ax


def add_arrow(ax, x, y, dx, dy, arrowsize=2.5, linewidth=2, threshold=2, alpha=1, fc='black', ec='black'):
    """Draw an arrow from ``(x, y)`` along ``(dx, dy)`` on *ax*.

    The arrow is drawn only when ``abs(dx)`` or ``abs(dy)`` exceeds
    *threshold*.

    Returns:
        The artist returned by ``ax.arrow``, or ``None`` when the
        displacement is at or below the threshold.
    """
    if abs(dx) <= threshold and abs(dy) <= threshold:
        return None
    return ax.arrow(
        x,
        y,
        dx,
        dy,
        head_width=arrowsize,
        head_length=arrowsize,
        linewidth=linewidth,
        fc=fc,
        ec=ec,
        length_includes_head=True,
        alpha=alpha,
        zorder=3,
    )


######################################################
# PROBABILITY-DENSITY-FUNCTION BASED VISUALIZATION
######################################################


def show_direction_models_pdf(loc_models, dir_models):
    """
    Probability-density function based visualization.

    For every location component that has a direction model named
    ``"<loc_model.name>_<i>"``, show the component's density contour next to
    the direction mixture's weighted pdf curves.
    """
    for loc_model in loc_models:
        print(loc_model.name, loc_model.n_components)
        for i, gauss in enumerate(loc_model.submodels):
            wanted = f"{loc_model.name}_{i}"
            for dir_model in dir_models:
                if dir_model.name != wanted:
                    continue
                print(dir_model.name, dir_model.n_components)

                axcol, axpol = loc_angle_axes()
                draw_contour(axcol, gauss, cmap="Blues")
                draw_vonmises_pdfs(dir_model, axpol)
                plt.show()


def draw_contour(ax, gauss, n=100, cmap="Blues", length=105, width=68):
    """Draw filled contours of ``gauss.pdf`` over the pitch on *ax*.

    Args:
        ax: axes exposing ``contourf``.
        gauss: object whose ``pdf`` accepts an ``(n * n, 2)`` array of
            ``(x, y)`` points and returns one value per point.
        n: number of grid samples along each axis.
        cmap: colormap name passed to ``contourf``.
        length: upper x bound of the sampling grid.
        width: upper y bound of the sampling grid. The default is the pitch
            width used by ``field`` (68), so grid points are not spent on
            y values outside the pitch.

    Returns:
        *ax*.
    """
    x = np.linspace(0, length, n)
    y = np.linspace(0, width, n)
    xx, yy = np.meshgrid(x, y)
    points = np.array([xx.flatten(), yy.flatten()]).T
    zz = gauss.pdf(points).reshape(xx.shape)
    ax.contourf(xx, yy, zz, cmap=cmap)
    return ax


def draw_vonmises_pdfs(model, ax=None,figsize=4,projection="polar",n=200,show=True):
    """Plot ``prior * pdf`` of every component of a direction mixture.

    Args:
        model: mixture exposing parallel ``priors`` and ``submodels``; each
            submodel needs a ``pdf`` method.
        ax: axes to draw on. When ``None`` a new ``figsize`` x ``figsize``
            subplot with the given *projection* is created.
        figsize: used only when *ax* is ``None``.
        projection: used only when *ax* is ``None``.
        n: number of angles sampled on ``[-pi, pi]``.
        show: accepted for backward compatibility; this function never calls
            ``plt.show()``.

    Returns:
        The axes that was drawn on.
    """
    if ax is None:
        ax = plt.subplot(111, projection=projection)
        plt.gcf().set_size_inches((figsize, figsize))
    x = np.linspace(-np.pi, np.pi, n)
    for i, (prior, vonmises) in enumerate(zip(model.priors, model.submodels)):
        p = np.nan_to_num(prior * vonmises.pdf(x))
        ax.plot(
            x,
            p,
            linewidth=2,
            # Cycle through the palette so any number of components works.
            color=colors[i % len(colors)],
            label=f"Component {i}",
        )
    return ax


#################################
# DATA-BASED VISUALIZATION
#################################

colors = [
    "#377eb8",
    "#e41a1c",
    "#4daf4a",
    "#984ea3",
    "#ff7f00",
    "#ffff33",
    "#a65628",
    "#f781bf",
    "#999999",
]


def _confident_rows(model, X, W, tol):
    """Return ``(mask, probs[mask])`` for rows whose summed probability > tol.

    ``probs`` is ``model.predict_proba(X, W[model.name].values)`` with NaNs
    replaced by zero.
    """
    probs = np.nan_to_num(model.predict_proba(X, W[model.name].values))
    mask = probs.sum(axis=1) > tol
    return mask, probs[mask]


def _location_point_colors(loc_model, weights, samplefn):
    """Pick one colour per row of *weights* for a location model.

    Models with more components than palette entries get a palette built by
    ``color_submodels`` from the component means.
    """
    if loc_model.n_components > len(colors):
        means = [m.mean for m in loc_model.submodels]
        palette = color_submodels(means, colors)
    else:
        palette = colors
    return scattercolors(weights, palette, samplefn=samplefn)


def scatter_location_model(
    loc_model, actions, W, samplefn="max", tol=0.1, figsize=6, alpha=0.5, show=True
):
    """Scatter the actions assigned to *loc_model*, coloured per component.

    Args:
        loc_model: location mixture with ``name``, ``n_components``,
            ``submodels`` and ``predict_proba``.
        actions: table with ``x`` and ``y`` columns.
        W: mapping/frame indexed by model name; ``W[name].values`` is passed
            to ``predict_proba``.
        samplefn: forwarded to ``scattercolors``.
        tol: rows whose summed probability is not above this are dropped.
        figsize: forwarded to ``matplotsoccer.field``.
        alpha: marker opacity.
        show: call ``plt.show()`` when true.
    """
    X = actions[["x", "y"]]
    mask, w = _confident_rows(loc_model, X, W, tol)
    x = X[mask]
    c = _location_point_colors(loc_model, w, samplefn)

    ax = mps.field(show=False, figsize=figsize)
    ax.scatter(x.x, x.y, c=c, alpha=alpha)
    if show:
        plt.show()


def scatter_location_model_black(
    loc_model, actions, W, samplefn="max", tol=0.1, figsize=6, alpha=0.5, show=True
):
    """Scatter the actions assigned to *loc_model* in plain black.

    Takes the same arguments as ``scatter_location_model``. *samplefn* is
    accepted for signature parity only: no per-component colours are
    computed, because every marker is black.
    """
    X = actions[["x", "y"]]
    mask, _ = _confident_rows(loc_model, X, W, tol)
    x = X[mask]

    ax = mps.field(show=False, figsize=figsize)
    ax.scatter(x.x, x.y, c="black", alpha=alpha)
    if show:
        plt.show()


def scatter_location_models(
    loc_models, actions, W, samplefn="max", tol=0.1, figsize=8, alpha=0.5
):
    """
    Data-based visualization: one coloured scatter plot per location model.
    """
    for model in loc_models:
        print(model.name, model.n_components)
        scatter_location_model(
            model,
            actions,
            W,
            samplefn=samplefn,
            tol=tol,
            figsize=figsize,
            alpha=alpha,
            show=True,
        )


def scatter_direction_models(
    dir_models, actions, X, W, samplefn="max", tol=0.1, figsize=4, alpha=0.5
):
    """Scatter locations and displacements of the actions of each model.

    For every direction model the retained actions are drawn twice: at
    ``(x, y)`` on a pitch and at ``(dx, dy)`` on a movement plane, with the
    same per-component colours.

    Args:
        dir_models: iterable of direction mixtures.
        actions: table with ``x``, ``y``, ``dx`` and ``dy`` columns.
        X: input passed to each model's ``predict_proba``.
        W: mapping/frame indexed by model name.
        samplefn: forwarded to ``scattercolors``.
        tol: rows whose summed probability is not above this are dropped.
        figsize: forwarded to ``dual_axes``.
        alpha: marker opacity.
    """
    for model in dir_models:
        print(model.name, model.n_components)
        mask, w = _confident_rows(model, X, W, tol)
        c = scattercolors(w, samplefn=samplefn)

        # Honour the caller's figsize (the default matches dual_axes').
        axloc, axmov = dual_axes(figsize)
        field(axloc)
        movement(axmov)

        x = actions[mask]
        axloc.scatter(x.x, x.y, c=c, alpha=alpha)
        axmov.scatter(x.dx, x.dy, c=c, alpha=alpha)
        plt.show()


def hist_direction_model(
    dir_model,
    actions,
    W,
    samplefn="max",
    tol=0.1,
    figsize=4,
    alpha=0.5,
    projection="polar",
    bins=20,
    show=False,
):
    """Plot weighted histograms of ``mov_angle_a0``, one per component.

    Args:
        dir_model: direction mixture with ``name`` and ``predict_proba``.
        actions: table with a ``mov_angle_a0`` column.
        W: mapping/frame indexed by model name.
        samplefn: accepted for signature parity; unused because histograms
            are weighted by the probabilities themselves.
        tol: rows whose summed probability is not above this are dropped.
        figsize: side length of the square figure in inches.
        alpha: bar opacity.
        projection: projection of the created subplot.
        bins: number of histogram bins.
        show: call ``plt.show()`` when true (off by default).
    """
    X = actions["mov_angle_a0"]
    mask, w = _confident_rows(dir_model, X, W, tol)

    axpol = plt.subplot(111, projection=projection)
    plt.gcf().set_size_inches((figsize, figsize))

    x = actions[mask]
    for k, component_weights in enumerate(w.T):
        axpol.hist(
            x.mov_angle_a0,
            weights=component_weights.flatten(),
            # Cycle the palette so components beyond its length are drawn.
            color=colors[k % len(colors)],
            alpha=alpha,
            bins=bins,
        )
    if show:
        plt.show()


def _hist_directions(dir_model, actions, W, samplefn, tol, figsize, alpha):
    """Show locations and weighted angle histograms for one direction model.

    Retained rows are those whose summed ``predict_proba`` output exceeds
    *tol*. They are scattered at ``(x, y)`` on a pitch, and ``mov_angle_a0``
    is histogrammed on a polar axes once per component, weighted by that
    component's probability column.
    """
    angles = actions["mov_angle_a0"]
    probs = np.nan_to_num(
        dir_model.predict_proba(angles, W[dir_model.name].values)
    )
    keep = probs.sum(axis=1) > tol
    weights = probs[keep]
    point_colors = scattercolors(weights, samplefn=samplefn)

    axloc, axpol = loc_angle_axes(figsize)

    kept = actions[keep]
    axloc.scatter(kept.x, kept.y, c=point_colors, alpha=alpha)
    for k, component_weights in enumerate(weights.T):
        axpol.hist(
            kept.mov_angle_a0,
            weights=component_weights.flatten(),
            # Cycle the palette so components beyond its length are drawn.
            color=colors[k % len(colors)],
            alpha=alpha,
            bins=100,
        )
    plt.show()


def hist_direction_models(
    dir_models, actions, W, samplefn="max", tol=0.1, figsize=4, alpha=0.5
):
    """Plot the location scatter and angle histograms of every model.

    Args:
        dir_models: iterable of direction mixtures.
        actions: table with ``x``, ``y`` and ``mov_angle_a0`` columns.
        W: mapping/frame indexed by model name; ``W[name].values`` is passed
            to ``predict_proba``.
        samplefn: forwarded to ``scattercolors``.
        tol: rows whose summed probability is not above this are dropped.
        figsize: forwarded to ``loc_angle_axes``.
        alpha: marker and bar opacity.
    """
    for model in dir_models:
        print(model.name, model.n_components)
        _hist_directions(model, actions, W, samplefn, tol, figsize, alpha)


def model_vs_data(
    dir_models, loc_models, actions, W, samplefn="max", tol=0.1, figsize=4, alpha=0.5
):
    """Show each fitted direction model next to the data assigned to it.

    For every location component with a direction model named
    ``"<loc_model.name>_<i>"`` two figures are shown: the model view
    (density contour plus weighted pdf curves) and the data view (location
    scatter plus weighted angle histograms). Arguments are the same as for
    ``hist_direction_models``, plus *loc_models*.
    """
    for loc_model in loc_models:
        print(loc_model.name, loc_model.n_components)
        for i, gauss in enumerate(loc_model.submodels):
            wanted = f"{loc_model.name}_{i}"
            for dir_model in dir_models:
                if dir_model.name != wanted:
                    continue

                print(dir_model.name, dir_model.n_components)
                axcol, axpol = loc_angle_axes(figsize)
                draw_contour(axcol, gauss, cmap="Blues")
                # The model comes first and the axes second; the previous
                # call had them swapped and failed on the first match.
                draw_vonmises_pdfs(dir_model, axpol)
                plt.show()

                _hist_directions(
                    dir_model, actions, W, samplefn, tol, figsize, alpha
                )


from scipy.spatial import Delaunay
import networkx as nx


def color_submodels(means, colors):
    """Assign palette colours so neighbouring components tend to differ.

    Components are neighbours when they share an edge of the Delaunay
    triangulation of *means*. The neighbour graph is coloured with networkx
    (equitable colouring when the palette is larger than the maximum degree,
    greedy colouring otherwise).

    Args:
        means: sequence of 2-D points, one per component.
        colors: palette to draw from.

    Returns:
        A list with one palette entry per element of *means*.
    """
    tri = Delaunay(means)
    edges = set()
    for a, b, c in tri.simplices:
        edges |= {frozenset([a, b]), frozenset([b, c]), frozenset([c, a])}

    G = nx.Graph()
    for edge in edges:
        i, j = list(edge)
        G.add_edge(i, j)

    max_degree = max((G.degree(node) for node in G.nodes), default=0)
    if max_degree > len(colors) - 1:
        colorassign = nx.algorithms.coloring.greedy_color(G)
    else:
        colorassign = nx.algorithms.coloring.equitable_color(G, len(colors))

    colorvector = [0] * len(means)
    for node, color_index in colorassign.items():
        colorvector[node] = int(color_index)

    # Greedy colouring may use more colours than the palette holds; wrap
    # around instead of raising IndexError.
    return [colors[i % len(colors)] for i in colorvector]


def sample(probs):
    """Draw a random index with probability proportional to *probs*."""
    probs = np.asarray(probs, dtype=float)
    return np.random.choice(len(probs), p=probs / probs.sum())


def scattercolors(weights, colors=colors, samplefn="max"):
    """Return one colour per row of *weights*.

    Args:
        weights: 2-D array with one row per point and one column per
            component.
        colors: palette, indexed modulo its length.
        samplefn: ``"max"`` picks the column with the largest weight; any
            other value samples a column with ``sample``.

    Returns:
        A list of palette entries, empty when *weights* has no rows.
    """
    # Neither argmax nor apply_along_axis handles an array without rows in
    # every case, and there is nothing to colour anyway.
    if len(weights) == 0:
        return []

    if samplefn == "max":
        labels = np.argmax(weights, axis=1)
    else:
        labels = np.apply_along_axis(sample, axis=1, arr=weights)

    return [colors[label % len(colors)] for label in labels]


#################################
# EXPERIMENTS VISUALIZATION
#################################


def savefigure(figname):
    """Save the current figure to *figname* at 300 dpi, tightly cropped."""
    plt.savefig(figname, dpi=300, bbox_inches="tight", pad_inches=0.0)


def show_component_differences(loc_models, dir_models, vec_p1, vec_p2, name1, name2, save=True):
    """Visualise ``vec_p1 - vec_p2`` per component, one pitch per model.

    Colours come from the ``'bwr_r'`` colormap with a symmetric scale centred
    on zero, computed separately for each location model from the entries
    whose label starts with ``"<loc_model.name>_"``. When every matching
    direction model has a single component the ellipses themselves are
    coloured (label ``"<name>_<i>_0"``); otherwise ellipses are light grey
    and each arrow is coloured by label ``"<name>_<i>_<j>"``.

    Args:
        loc_models: iterable of location mixtures.
        dir_models: iterable of direction mixtures.
        vec_p1, vec_p2: label-indexed vectors (string index supporting
            ``.index.str`` and ``.loc``) that can be subtracted.
        name1, name2: used only in the output file name.
        save: when true, write
            ``../figures/<name1>-<name2>-<model name>.png``; otherwise show
            the figure.
    """
    difference = vec_p1 - vec_p2

    # Both lookups are written so they work on old and new matplotlib:
    # ``mpl.cm.get_cmap`` has been removed and ``DivergingNorm`` was renamed
    # to ``TwoSlopeNorm``.
    cmap = plt.get_cmap('bwr_r')
    norm_cls = getattr(mpl.colors, "TwoSlopeNorm", None)
    if norm_cls is None:
        norm_cls = mpl.colors.DivergingNorm

    for loc_model in loc_models:
        selected = difference.loc[
            difference.index.str.contains(f"^{loc_model.name}_")
        ]
        ab = max(abs(min(selected)), abs(max(selected)))
        if ab == 0:
            # A zero-width range is not a valid normalisation interval.
            ab = 0.0001

        norm = norm_cls(vcenter=0, vmin=-ab, vmax=ab)

        print(loc_model.name, loc_model.n_components)
        ax = mps.field(show=False, figsize=8)

        am_subclusters = np.array(
            [
                dir_model.n_components
                for a, _ in enumerate(loc_model.submodels)
                for dir_model in dir_models
                if f"{loc_model.name}_{a}" == dir_model.name
            ]
        )
        # Loop-invariant, so evaluate once per location model.
        single_direction = (am_subclusters == 1).all()

        for i, gauss in enumerate(loc_model.submodels):
            if single_direction:
                value = difference.loc[f"{loc_model.name}_{i}_0"]
                add_ellips(
                    ax, gauss.mean, gauss.cov, color=cmap(norm(value)), alpha=1
                )
                continue

            add_ellips(ax, gauss.mean, gauss.cov, color='gainsboro')

            x, y = gauss.mean
            for dir_model in dir_models:
                if f"{loc_model.name}_{i}" != dir_model.name:
                    continue
                print(dir_model.name, dir_model.n_components)

                for j, vonmises in enumerate(dir_model.submodels):
                    dx = np.cos(vonmises.loc)[0]
                    dy = np.sin(vonmises.loc)[0]
                    value = difference.loc[f"{loc_model.name}_{i}_{j}"]
                    add_arrow(
                        ax,
                        x,
                        y,
                        10 * dx,
                        10 * dy,
                        fc=cmap(norm(value)),
                        arrowsize=4.5,
                        linewidth=1,
                    )

        cb = plt.colorbar(
            plt.cm.ScalarMappable(cmap=cmap, norm=norm),
            ax=ax,
            fraction=0.065,
            pad=-0.05,
            orientation='horizontal',
        )
        cb.ax.xaxis.set_ticks_position('bottom')
        cb.ax.tick_params(labelsize=16)
        plt.axis("scaled")

        if save:
            savefigure(f"../figures/{name1}-{name2}-{loc_model.name}.png")
        else:
            plt.show()
--------------------------------------------------------------------------------