216 |
217 |
218 |
--------------------------------------------------------------------------------
/nbs/blocks.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "87d13599",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "# default_exp blocks"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "id": "ec3f14c3",
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "#export\n",
21 | "import abc\n",
22 | "\n",
23 | "import numpy as np\n",
24 | "\n",
25 | "import mezzala.parameters"
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "id": "8240d697",
31 | "metadata": {},
32 | "source": [
33 | "# Model blocks"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": null,
39 | "id": "0ffc35b3",
40 | "metadata": {},
41 | "outputs": [],
42 | "source": [
43 | "#exporti\n",
44 | "\n",
45 | "\n",
46 | "class ModelBlockABC(abc.ABC):\n",
47 | " \"\"\"\n",
48 | " Base class for model blocks\n",
49 | " \"\"\"\n",
50 | " PRIORITY = 0\n",
51 | " \n",
52 | " def param_keys(self, adapter, data):\n",
53 | " return []\n",
54 | "\n",
55 | " def constraints(self, adapter, data):\n",
56 | " return []\n",
57 | " \n",
58 | " def home_terms(self, adapter, data):\n",
59 | " return []\n",
60 | " \n",
61 | " def away_terms(self, adapter, data):\n",
62 | " return []"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": null,
68 | "id": "dd79c4c0",
69 | "metadata": {},
70 | "outputs": [],
71 | "source": [
72 | "#export\n",
73 | "\n",
74 | "\n",
75 | "class BaseRate(ModelBlockABC):\n",
76 | " \"\"\"\n",
77 | " Estimate average goalscoring rate as a separate parameter.\n",
78 | " \n",
79 | " This can be useful, since it results in both team offence and\n",
80 | " team defence parameters being centered around 1.0\n",
81 | " \"\"\"\n",
82 | " \n",
83 | " def __init__(self):\n",
84 | " pass\n",
85 | " \n",
86 | " def __repr__(self):\n",
87 | " return 'BaseRate()'\n",
88 | " \n",
89 | " def param_keys(self, adapter, data):\n",
90 | " return [mezzala.parameters.AVG_KEY]\n",
91 | " \n",
92 | " def home_terms(self, adapter, row):\n",
93 | " return [\n",
94 | " (mezzala.parameters.AVG_KEY, 1.0)\n",
95 | " ]\n",
96 | " \n",
97 | " def away_terms(self, adapter, row):\n",
98 | " return [\n",
99 | " (mezzala.parameters.AVG_KEY, 1.0)\n",
100 | " ]"
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": null,
106 | "id": "2b40048d",
107 | "metadata": {},
108 | "outputs": [],
109 | "source": [
110 | "#export\n",
111 | "\n",
112 | "\n",
113 | "class HomeAdvantage(ModelBlockABC):\n",
114 | " \"\"\"\n",
115 | " Estimate home advantage.\n",
116 | " \n",
117 | " Assumes constant home advantage is present in every match in the\n",
118 | " dataset\n",
119 | " \"\"\"\n",
120 | " \n",
121 | " def __init__(self):\n",
122 | " # TODO: allow HFA on/off depending on the data?\n",
123 | " pass\n",
124 | " \n",
125 | " def __repr__(self):\n",
126 | " return 'HomeAdvantage()'\n",
127 | " \n",
128 | " def param_keys(self, adapter, data):\n",
129 | " return [mezzala.parameters.HFA_KEY]\n",
130 | " \n",
131 | " def home_terms(self, adapter, row):\n",
132 | " return [\n",
133 | " (mezzala.parameters.HFA_KEY, 1.0)\n",
134 | " ]"
135 | ]
136 | },
137 | {
138 | "cell_type": "code",
139 | "execution_count": null,
140 | "id": "3b4c3987",
141 | "metadata": {},
142 | "outputs": [],
143 | "source": [
144 | "#export\n",
145 | "\n",
146 | "\n",
147 | "class TeamStrength(ModelBlockABC):\n",
148 | " \"\"\"\n",
149 | " Estimate team offence and team defence parameters.\n",
150 | " \"\"\"\n",
151 | " \n",
152 | " # This is a gross hack so that we know that the \n",
153 | " # team strength parameters come first, and thus can\n",
154 | " # do the constraints (which are positionally indexed)\n",
155 | " PRIORITY = 1\n",
156 | " \n",
157 | " def __init__(self):\n",
158 | " pass\n",
159 | " \n",
160 | " def __repr__(self):\n",
161 | " return 'TeamStrength()'\n",
162 | " \n",
163 | " def _teams(self, adapter, data):\n",
164 | " return set(adapter.home_team(r) for r in data) | set(adapter.away_team(r) for r in data)\n",
165 | " \n",
166 | " def offence_key(self, label):\n",
167 | " return mezzala.parameters.OffenceParameterKey(label)\n",
168 | " \n",
169 | " def defence_key(self, label):\n",
170 | " return mezzala.parameters.DefenceParameterKey(label)\n",
171 | " \n",
172 | " def param_keys(self, adapter, data):\n",
173 | " teams = self._teams(adapter, data)\n",
174 | "\n",
175 | " offence = [self.offence_key(t) for t in teams]\n",
176 | " defence = [self.defence_key(t) for t in teams]\n",
177 | "\n",
178 | " return offence + defence\n",
179 | " \n",
180 | " def constraints(self, adapter, data):\n",
181 | " n_teams = len(self._teams(adapter, data))\n",
182 | " return [\n",
183 | " # Force team offence parameters to average to 1\n",
184 | " {'fun': lambda x: 1 - np.mean(np.exp(x[0:n_teams])),\n",
185 | " 'type': 'eq'},\n",
186 | " ]\n",
187 | " \n",
188 | " def home_terms(self, adapter, row):\n",
189 | " return [\n",
190 | " (self.offence_key(adapter.home_team(row)), 1.0),\n",
191 | " (self.defence_key(adapter.away_team(row)), 1.0),\n",
192 | " ]\n",
193 | " \n",
194 | " def away_terms(self, adapter, row):\n",
195 | " return [\n",
196 | " (self.offence_key(adapter.away_team(row)), 1.0),\n",
197 | " (self.defence_key(adapter.home_team(row)), 1.0),\n",
198 | " ]"
199 | ]
200 | },
201 | {
202 | "cell_type": "code",
203 | "execution_count": null,
204 | "id": "bcc5939f",
205 | "metadata": {},
206 | "outputs": [],
207 | "source": [
208 | "#export\n",
209 | "\n",
210 | "\n",
211 | "class KeyBlock(ModelBlockABC):\n",
212 | " \"\"\"\n",
213 | " Generic model block for adding arbitrary model terms from the data\n",
214 | " to both home and away team\n",
215 | " \"\"\"\n",
216 | " def __init__(self, key):\n",
217 | " self.key = key\n",
218 | " \n",
219 | " def __repr__(self):\n",
220 | " return 'KeyBlock()'\n",
221 | " \n",
222 | " def param_keys(self, adapter, data):\n",
223 | " return list(set(self.key(r) for r in data))\n",
224 | " \n",
225 | " def home_terms(self, adapter, row):\n",
226 | " return [self.key(row)]\n",
227 | " \n",
228 | " def away_terms(self, adapter, row):\n",
229 | " return [self.key(row)]"
230 | ]
231 | },
232 | {
233 | "cell_type": "code",
234 | "execution_count": null,
235 | "id": "57c18565",
236 | "metadata": {},
237 | "outputs": [],
238 | "source": [
239 | "#export\n",
240 | "\n",
241 | "\n",
242 | "class ConstantBlock(ModelBlockABC):\n",
243 | " \"\"\"\n",
244 | " A model block for adding specific model terms to the parameter keys.\n",
245 | " \n",
246 | " Can be useful in conjunction with `LumpedAdapter` to ensure that certain parameters\n",
247 | " are in the model (even if they aren't estimated)\n",
248 | " \"\"\"\n",
249 | " def __init__(self, *args):\n",
250 | " self.terms = args\n",
251 | " \n",
252 | " def __repr__(self):\n",
253 | " return 'ConstantBlock()'\n",
254 | " \n",
255 | " def param_keys(self, adapter, data):\n",
256 | " return list(self.terms)"
257 | ]
258 | }
259 | ],
260 | "metadata": {
261 | "kernelspec": {
262 | "display_name": "Python 3",
263 | "language": "python",
264 | "name": "python3"
265 | }
266 | },
267 | "nbformat": 4,
268 | "nbformat_minor": 5
269 | }
270 |
--------------------------------------------------------------------------------
/docs/Gemfile.lock:
--------------------------------------------------------------------------------
1 | GEM
2 | remote: https://rubygems.org/
3 | specs:
4 | activesupport (6.0.3.7)
5 | concurrent-ruby (~> 1.0, >= 1.0.2)
6 | i18n (>= 0.7, < 2)
7 | minitest (~> 5.1)
8 | tzinfo (~> 1.1)
9 | zeitwerk (~> 2.2, >= 2.2.2)
10 | addressable (2.7.0)
11 | public_suffix (>= 2.0.2, < 5.0)
12 | coffee-script (2.4.1)
13 | coffee-script-source
14 | execjs
15 | coffee-script-source (1.11.1)
16 | colorator (1.1.0)
17 | commonmarker (0.17.13)
18 | ruby-enum (~> 0.5)
19 | concurrent-ruby (1.1.8)
20 | dnsruby (1.61.5)
21 | simpleidn (~> 0.1)
22 | em-websocket (0.5.2)
23 | eventmachine (>= 0.12.9)
24 | http_parser.rb (~> 0.6.0)
25 | ethon (0.14.0)
26 | ffi (>= 1.15.0)
27 | eventmachine (1.2.7)
28 | execjs (2.8.1)
29 | faraday (1.4.1)
30 | faraday-excon (~> 1.1)
31 | faraday-net_http (~> 1.0)
32 | faraday-net_http_persistent (~> 1.1)
33 | multipart-post (>= 1.2, < 3)
34 | ruby2_keywords (>= 0.0.4)
35 | faraday-excon (1.1.0)
36 | faraday-net_http (1.0.1)
37 | faraday-net_http_persistent (1.1.0)
38 | ffi (1.15.0)
39 | forwardable-extended (2.6.0)
40 | gemoji (3.0.1)
41 | github-pages (214)
42 | github-pages-health-check (= 1.17.0)
43 | jekyll (= 3.9.0)
44 | jekyll-avatar (= 0.7.0)
45 | jekyll-coffeescript (= 1.1.1)
46 | jekyll-commonmark-ghpages (= 0.1.6)
47 | jekyll-default-layout (= 0.1.4)
48 | jekyll-feed (= 0.15.1)
49 | jekyll-gist (= 1.5.0)
50 | jekyll-github-metadata (= 2.13.0)
51 | jekyll-mentions (= 1.6.0)
52 | jekyll-optional-front-matter (= 0.3.2)
53 | jekyll-paginate (= 1.1.0)
54 | jekyll-readme-index (= 0.3.0)
55 | jekyll-redirect-from (= 0.16.0)
56 | jekyll-relative-links (= 0.6.1)
57 | jekyll-remote-theme (= 0.4.3)
58 | jekyll-sass-converter (= 1.5.2)
59 | jekyll-seo-tag (= 2.7.1)
60 | jekyll-sitemap (= 1.4.0)
61 | jekyll-swiss (= 1.0.0)
62 | jekyll-theme-architect (= 0.1.1)
63 | jekyll-theme-cayman (= 0.1.1)
64 | jekyll-theme-dinky (= 0.1.1)
65 | jekyll-theme-hacker (= 0.1.2)
66 | jekyll-theme-leap-day (= 0.1.1)
67 | jekyll-theme-merlot (= 0.1.1)
68 | jekyll-theme-midnight (= 0.1.1)
69 | jekyll-theme-minimal (= 0.1.1)
70 | jekyll-theme-modernist (= 0.1.1)
71 | jekyll-theme-primer (= 0.5.4)
72 | jekyll-theme-slate (= 0.1.1)
73 | jekyll-theme-tactile (= 0.1.1)
74 | jekyll-theme-time-machine (= 0.1.1)
75 | jekyll-titles-from-headings (= 0.5.3)
76 | jemoji (= 0.12.0)
77 | kramdown (= 2.3.1)
78 | kramdown-parser-gfm (= 1.1.0)
79 | liquid (= 4.0.3)
80 | mercenary (~> 0.3)
81 | minima (= 2.5.1)
82 | nokogiri (>= 1.10.4, < 2.0)
83 | rouge (= 3.26.0)
84 | terminal-table (~> 1.4)
85 | github-pages-health-check (1.17.0)
86 | addressable (~> 2.3)
87 | dnsruby (~> 1.60)
88 | octokit (~> 4.0)
89 | public_suffix (>= 2.0.2, < 5.0)
90 | typhoeus (~> 1.3)
91 | html-pipeline (2.14.0)
92 | activesupport (>= 2)
93 | nokogiri (>= 1.4)
94 | http_parser.rb (0.6.0)
95 | i18n (0.9.5)
96 | concurrent-ruby (~> 1.0)
97 | jekyll (3.9.0)
98 | addressable (~> 2.4)
99 | colorator (~> 1.0)
100 | em-websocket (~> 0.5)
101 | i18n (~> 0.7)
102 | jekyll-sass-converter (~> 1.0)
103 | jekyll-watch (~> 2.0)
104 | kramdown (>= 1.17, < 3)
105 | liquid (~> 4.0)
106 | mercenary (~> 0.3.3)
107 | pathutil (~> 0.9)
108 | rouge (>= 1.7, < 4)
109 | safe_yaml (~> 1.0)
110 | jekyll-avatar (0.7.0)
111 | jekyll (>= 3.0, < 5.0)
112 | jekyll-coffeescript (1.1.1)
113 | coffee-script (~> 2.2)
114 | coffee-script-source (~> 1.11.1)
115 | jekyll-commonmark (1.3.1)
116 | commonmarker (~> 0.14)
117 | jekyll (>= 3.7, < 5.0)
118 | jekyll-commonmark-ghpages (0.1.6)
119 | commonmarker (~> 0.17.6)
120 | jekyll-commonmark (~> 1.2)
121 | rouge (>= 2.0, < 4.0)
122 | jekyll-default-layout (0.1.4)
123 | jekyll (~> 3.0)
124 | jekyll-feed (0.15.1)
125 | jekyll (>= 3.7, < 5.0)
126 | jekyll-gist (1.5.0)
127 | octokit (~> 4.2)
128 | jekyll-github-metadata (2.13.0)
129 | jekyll (>= 3.4, < 5.0)
130 | octokit (~> 4.0, != 4.4.0)
131 | jekyll-mentions (1.6.0)
132 | html-pipeline (~> 2.3)
133 | jekyll (>= 3.7, < 5.0)
134 | jekyll-optional-front-matter (0.3.2)
135 | jekyll (>= 3.0, < 5.0)
136 | jekyll-paginate (1.1.0)
137 | jekyll-readme-index (0.3.0)
138 | jekyll (>= 3.0, < 5.0)
139 | jekyll-redirect-from (0.16.0)
140 | jekyll (>= 3.3, < 5.0)
141 | jekyll-relative-links (0.6.1)
142 | jekyll (>= 3.3, < 5.0)
143 | jekyll-remote-theme (0.4.3)
144 | addressable (~> 2.0)
145 | jekyll (>= 3.5, < 5.0)
146 | jekyll-sass-converter (>= 1.0, <= 3.0.0, != 2.0.0)
147 | rubyzip (>= 1.3.0, < 3.0)
148 | jekyll-sass-converter (1.5.2)
149 | sass (~> 3.4)
150 | jekyll-seo-tag (2.7.1)
151 | jekyll (>= 3.8, < 5.0)
152 | jekyll-sitemap (1.4.0)
153 | jekyll (>= 3.7, < 5.0)
154 | jekyll-swiss (1.0.0)
155 | jekyll-theme-architect (0.1.1)
156 | jekyll (~> 3.5)
157 | jekyll-seo-tag (~> 2.0)
158 | jekyll-theme-cayman (0.1.1)
159 | jekyll (~> 3.5)
160 | jekyll-seo-tag (~> 2.0)
161 | jekyll-theme-dinky (0.1.1)
162 | jekyll (~> 3.5)
163 | jekyll-seo-tag (~> 2.0)
164 | jekyll-theme-hacker (0.1.2)
165 | jekyll (> 3.5, < 5.0)
166 | jekyll-seo-tag (~> 2.0)
167 | jekyll-theme-leap-day (0.1.1)
168 | jekyll (~> 3.5)
169 | jekyll-seo-tag (~> 2.0)
170 | jekyll-theme-merlot (0.1.1)
171 | jekyll (~> 3.5)
172 | jekyll-seo-tag (~> 2.0)
173 | jekyll-theme-midnight (0.1.1)
174 | jekyll (~> 3.5)
175 | jekyll-seo-tag (~> 2.0)
176 | jekyll-theme-minimal (0.1.1)
177 | jekyll (~> 3.5)
178 | jekyll-seo-tag (~> 2.0)
179 | jekyll-theme-modernist (0.1.1)
180 | jekyll (~> 3.5)
181 | jekyll-seo-tag (~> 2.0)
182 | jekyll-theme-primer (0.5.4)
183 | jekyll (> 3.5, < 5.0)
184 | jekyll-github-metadata (~> 2.9)
185 | jekyll-seo-tag (~> 2.0)
186 | jekyll-theme-slate (0.1.1)
187 | jekyll (~> 3.5)
188 | jekyll-seo-tag (~> 2.0)
189 | jekyll-theme-tactile (0.1.1)
190 | jekyll (~> 3.5)
191 | jekyll-seo-tag (~> 2.0)
192 | jekyll-theme-time-machine (0.1.1)
193 | jekyll (~> 3.5)
194 | jekyll-seo-tag (~> 2.0)
195 | jekyll-titles-from-headings (0.5.3)
196 | jekyll (>= 3.3, < 5.0)
197 | jekyll-watch (2.2.1)
198 | listen (~> 3.0)
199 | jemoji (0.12.0)
200 | gemoji (~> 3.0)
201 | html-pipeline (~> 2.2)
202 | jekyll (>= 3.0, < 5.0)
203 | kramdown (2.3.1)
204 | rexml
205 | kramdown-parser-gfm (1.1.0)
206 | kramdown (~> 2.0)
207 | liquid (4.0.3)
208 | listen (3.5.1)
209 | rb-fsevent (~> 0.10, >= 0.10.3)
210 | rb-inotify (~> 0.9, >= 0.9.10)
211 | mercenary (0.3.6)
212 | mini_portile2 (2.5.1)
213 | minima (2.5.1)
214 | jekyll (>= 3.5, < 5.0)
215 | jekyll-feed (~> 0.9)
216 | jekyll-seo-tag (~> 2.1)
217 | minitest (5.14.4)
218 | multipart-post (2.1.1)
219 | nokogiri (1.11.0)
220 | mini_portile2 (~> 2.5.0)
221 | racc (~> 1.4)
222 | octokit (4.21.0)
223 | faraday (>= 0.9)
224 | sawyer (~> 0.8.0, >= 0.5.3)
225 | pathutil (0.16.2)
226 | forwardable-extended (~> 2.6)
227 | public_suffix (4.0.6)
228 | racc (1.5.2)
229 | rb-fsevent (0.11.0)
230 | rb-inotify (0.10.1)
231 | ffi (~> 1.0)
232 | rexml (3.2.5)
233 | rouge (3.26.0)
234 | ruby-enum (0.9.0)
235 | i18n
236 | ruby2_keywords (0.0.4)
237 | rubyzip (2.3.0)
238 | safe_yaml (1.0.5)
239 | sass (3.7.4)
240 | sass-listen (~> 4.0.0)
241 | sass-listen (4.0.0)
242 | rb-fsevent (~> 0.9, >= 0.9.4)
243 | rb-inotify (~> 0.9, >= 0.9.7)
244 | sawyer (0.8.2)
245 | addressable (>= 2.3.5)
246 | faraday (> 0.8, < 2.0)
247 | simpleidn (0.2.1)
248 | unf (~> 0.1.4)
249 | terminal-table (1.8.0)
250 | unicode-display_width (~> 1.1, >= 1.1.1)
251 | thread_safe (0.3.6)
252 | typhoeus (1.4.0)
253 | ethon (>= 0.9.0)
254 | tzinfo (1.2.9)
255 | thread_safe (~> 0.1)
256 | unf (0.1.4)
257 | unf_ext
258 | unf_ext (0.0.7.7)
259 | unicode-display_width (1.7.0)
260 | webrick (1.7.0)
261 | zeitwerk (2.4.2)
262 |
263 | PLATFORMS
264 | ruby
265 |
266 | DEPENDENCIES
267 | github-pages
268 | jekyll (>= 3.7)
269 | jekyll-remote-theme
270 | kramdown (>= 2.3.1)
271 | nokogiri (< 1.11.1)
272 | webrick (~> 1.7)
273 |
274 | BUNDLED WITH
275 | 2.2.17
276 |
--------------------------------------------------------------------------------
/mezzala/models.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: nbs/models.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['ScorelinePrediction', 'Outcomes', 'OutcomePrediction', 'scoreline_to_outcome', 'scorelines_to_outcomes',
4 | 'DixonColes']
5 |
6 | # Cell
7 | import collections
8 | import dataclasses
9 | import enum
10 | import functools
11 | import itertools
12 | import typing
13 | import warnings
14 |
15 | import numpy as np
16 | import scipy.stats
17 | import scipy.optimize
18 |
19 | import mezzala.blocks
20 | import mezzala.weights
21 | import mezzala.parameters
22 |
23 | # Cell
24 |
25 |
@dataclasses.dataclass(frozen=True)
class ScorelinePrediction:
    """
    Probability assigned to a single exact scoreline (e.g. 2-1).
    """
    # home_goals / away_goals: the exact scoreline being predicted
    home_goals: int
    away_goals: int
    # probability of this exact scoreline occurring
    probability: float
31 |
32 |
33 | # Cell
34 |
35 |
class Outcomes(enum.Enum):
    """
    Match outcome, seen from the home team's perspective.
    """
    HOME_WIN = 'Home win'
    DRAW = 'Draw'
    AWAY_WIN = 'Away win'

    def __repr__(self):
        # e.g. Outcomes('Home win')
        return f"Outcomes('{self.value}')"
43 |
44 |
@dataclasses.dataclass(frozen=True)
class OutcomePrediction:
    """
    Probability assigned to a match outcome (home win, draw or away win).
    """
    outcome: Outcomes
    # total probability of `outcome` (sum over its constituent scorelines)
    probability: float
49 |
50 | # Cell
51 |
52 |
def scoreline_to_outcome(home_goals, away_goals):
    """
    Map a single scoreline to its match outcome.

    Returns `Outcomes.HOME_WIN`, `Outcomes.DRAW` or `Outcomes.AWAY_WIN`.
    """
    if home_goals == away_goals:
        return Outcomes.DRAW
    if home_goals > away_goals:
        return Outcomes.HOME_WIN
    if home_goals < away_goals:
        return Outcomes.AWAY_WIN
60 |
61 |
def scorelines_to_outcomes(scorelines):
    """
    Aggregate scoreline predictions into outcome probabilities.

    Returns a dict mapping each `Outcomes` member to an
    `OutcomePrediction` whose probability is the sum of the
    probabilities of the scorelines producing that outcome.
    """
    totals = {outcome: 0 for outcome in Outcomes}
    for scoreline in scorelines:
        outcome = scoreline_to_outcome(scoreline.home_goals, scoreline.away_goals)
        totals[outcome] += scoreline.probability
    return {
        outcome: OutcomePrediction(outcome, probability)
        for outcome, probability in totals.items()
    }
70 |
71 | # Cell
72 |
# Blocks used when `DixonColes` is constructed without an explicit
# `blocks` argument: average scoring rate, home advantage, and
# per-team offence/defence strengths.
# NOTE(review): this list is module-level and mutable -- callers
# should not mutate it in place.
_DEFAULT_BLOCKS = [
    mezzala.blocks.BaseRate(),
    mezzala.blocks.HomeAdvantage(),
    mezzala.blocks.TeamStrength(),
]
78 |
79 |
class DixonColes:
    """
    Dixon-Coles models in Python.

    Home and away scoring rates are log-linear in the parameters, with
    the per-row terms supplied by the configured model blocks, plus the
    Dixon-Coles `tau` adjustment for dependence between low-scoring
    results. Fitting is by (weighted) maximum likelihood.
    """

    def __init__(self, adapter, blocks=None, weight=None, params=None):
        """
        adapter: extracts home/away teams and goals from each row of data
        blocks:  list of model blocks; defaults to base rate, home
                 advantage and team strength
        weight:  callable returning a per-row weight; defaults to uniform
        params:  optional existing parameter dict (used to warm-start `fit`)
        """
        # NOTE: Should params be stored internally as separate lists of keys and values?
        # Then `params` (the dict) can be a property?
        self.params = params
        self.adapter = adapter
        # Resolve defaults here rather than in the signature: defaults in
        # the signature are evaluated once at definition time, so a default
        # list/instance would be shared by every DixonColes instance.
        self.weight = weight if weight is not None else mezzala.weights.UniformWeight()
        self._blocks = list(blocks) if blocks is not None else list(_DEFAULT_BLOCKS)

    def __repr__(self):
        # NB: the closing parenthesis previously appeared after `blocks=...`,
        # producing a malformed repr; it now encloses all three fields.
        return f'DixonColes(adapter={repr(self.adapter)}, blocks={repr(self.blocks)}, weight={repr(self.weight)})'

    @property
    def blocks(self):
        # Make sure blocks are always in the correct order
        # (higher PRIORITY first -- see e.g. TeamStrength, whose positional
        # constraints require its parameters to come first)
        return sorted(self._blocks, key=lambda x: -x.PRIORITY)

    def home_goals(self, row):
        """ Returns home goals scored """
        return self.adapter.home_goals(row)

    def away_goals(self, row):
        """ Returns away goals scored """
        return self.adapter.away_goals(row)

    def parse_params(self, data):
        """ Returns a tuple of (parameter_names, [constraints]) """
        # Rho (the low-score correlation parameter) is always estimated,
        # and always comes after the block parameters
        base_params = [mezzala.parameters.RHO_KEY]
        block_params = list(itertools.chain.from_iterable(
            b.param_keys(self.adapter, data) for b in self.blocks
        ))
        return (
            block_params + base_params,
            list(itertools.chain.from_iterable(
                b.constraints(self.adapter, data) for b in self.blocks
            )),
        )

    def _home_terms(self, row):
        """ Map of parameter key -> coefficient for the home scoring rate """
        return dict(itertools.chain.from_iterable(
            b.home_terms(self.adapter, row) for b in self.blocks
        ))

    def _away_terms(self, row):
        """ Map of parameter key -> coefficient for the away scoring rate """
        return dict(itertools.chain.from_iterable(
            b.away_terms(self.adapter, row) for b in self.blocks
        ))

    # Core methods

    @staticmethod
    def _assign_params(param_keys, param_values):
        """ Zip parameter keys and fitted values into a parameter dict """
        return dict(zip(param_keys, param_values))

    def _create_feature_matrices(self, param_keys, data):
        """ Create X (feature) matrices for home and away poisson rates """
        home_X = np.empty([len(data), len(param_keys)])
        away_X = np.empty([len(data), len(param_keys)])
        for row_i, row in enumerate(data):
            home_rate_terms = self._home_terms(row)
            away_rate_terms = self._away_terms(row)
            for param_i, param_key in enumerate(param_keys):
                # Parameters absent from a row contribute nothing (coefficient 0)
                home_X[row_i, param_i] = home_rate_terms.get(param_key, 0)
                away_X[row_i, param_i] = away_rate_terms.get(param_key, 0)
        return home_X, away_X

    @staticmethod
    def _tau(home_goals, away_goals, home_rate, away_rate, rho):
        """
        Dixon-Coles adjustment factor for low-scoring results.

        Modifies the independent-Poisson probabilities of the 0-0, 0-1,
        1-0 and 1-1 scorelines; all other scorelines are unadjusted (1).
        """
        tau = np.ones(len(home_goals))
        tau = np.where((home_goals == 0) & (away_goals == 0), 1 - home_rate*away_rate*rho, tau)
        tau = np.where((home_goals == 0) & (away_goals == 1), 1 + home_rate*rho, tau)
        tau = np.where((home_goals == 1) & (away_goals == 0), 1 + away_rate*rho, tau)
        tau = np.where((home_goals == 1) & (away_goals == 1), 1 - rho, tau)

        return tau

    def _log_like(self, home_goals, away_goals, home_rate, away_rate, rho):
        """ Per-observation log-likelihood (Poisson x Poisson x tau) """
        return (
            scipy.stats.poisson.logpmf(home_goals, home_rate) +
            scipy.stats.poisson.logpmf(away_goals, away_rate) +
            np.log(self._tau(home_goals, away_goals, home_rate, away_rate, rho))
        )

    def objective_fn(self, data, home_goals, away_goals, weights, home_X, away_X, rho_ix, xs):
        """ Negative weighted log-likelihood; minimised by `fit` """
        rho = xs[rho_ix]

        # Parameters are estimated in log-space, but `scipy.stats.poisson`
        # expects real number inputs, so we have to use `np.exp`
        home_rate = np.exp(np.dot(home_X, xs))
        away_rate = np.exp(np.dot(away_X, xs))

        log_like = self._log_like(home_goals, away_goals, home_rate, away_rate, rho)
        pseudo_log_like = log_like * weights
        return -np.sum(pseudo_log_like)

    def fit(self, data, **kwargs):
        """
        Fit the model to `data` (an iterable of rows understood by the
        adapter). Extra keyword arguments are passed through to
        `scipy.optimize.minimize`. Returns `self`.
        """
        param_keys, constraints = self.parse_params(data)

        init_params = (
            # Attempt to initialise parameters from any already-existing parameters
            # This substantially speeds up fitting during (e.g.) backtesting
            np.asarray([self.params.get(p, 0) for p in param_keys])
            # If the model has no parameters, just initialise with 0s
            if self.params
            else np.zeros(len(param_keys))
        )

        # Precalculate the things we can (for speed)

        # Create X (feature) matrices for home and away poisson rates
        home_X, away_X = self._create_feature_matrices(param_keys, data)

        # Get home goals, away goals, and weights from the data
        home_goals, away_goals = np.empty(len(data)), np.empty(len(data))
        weights = np.empty(len(data))
        for i, row in enumerate(data):
            home_goals[i] = self.home_goals(row)
            away_goals[i] = self.away_goals(row)
            weights[i] = self.weight(row)

        # Get the index of the Rho correlation parameter
        rho_ix = param_keys.index(mezzala.parameters.RHO_KEY)

        # Optimise!
        with warnings.catch_warnings():
            # This is a hack
            # Because we haven't properly constrained `rho`, it's possible for 0 or even negative
            # values of `tau` (and therefore invalid probabilities)
            # Ignoring the warnings has little practical impact, since the model
            # will still find the objective function's minimum point regardless
            warnings.simplefilter('ignore')

            estimate = scipy.optimize.minimize(
                lambda xs: self.objective_fn(data, home_goals, away_goals, weights, home_X, away_X, rho_ix, xs),
                x0=init_params,
                constraints=constraints,
                **kwargs
            )

        # Parse the estimates into parameter map
        self.params = self._assign_params(param_keys, estimate.x)

        return self

    def predict_one(self, row, up_to=26):
        """
        Predict all scorelines (up to `up_to - 1` goals per team) for a
        single row. Returns a list of `ScorelinePrediction`.
        """
        scorelines = list(itertools.product(range(up_to), repeat=2))

        home_goals = np.asarray([h for h, a in scorelines])
        away_goals = np.asarray([a for h, a in scorelines])

        param_keys = list(self.params.keys())
        param_values = np.asarray(list(self.params.values()))

        home_X, away_X = self._create_feature_matrices(param_keys, [row])

        home_rate = np.exp(np.dot(home_X, param_values))
        away_rate = np.exp(np.dot(away_X, param_values))
        rho = self.params[mezzala.parameters.RHO_KEY]

        probs = np.exp(self._log_like(home_goals, away_goals, home_rate, away_rate, rho))

        return [ScorelinePrediction(*vals) for vals in zip(home_goals.tolist(), away_goals.tolist(), probs)]

    def predict(self, data, up_to=26):
        """ Predict scorelines for each row in `data` (see `predict_one`) """
        scorelines = [self.predict_one(row, up_to=up_to) for row in data]
        return scorelines
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/nbs/index.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Mezzala\n",
8 | "\n",
9 | "> Models for estimating football (soccer) team-strength"
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "## Install"
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {},
22 | "source": [
23 | "`pip install mezzala`"
24 | ]
25 | },
26 | {
27 | "cell_type": "markdown",
28 | "metadata": {},
29 | "source": [
30 | "## How to use"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": null,
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "import mezzala"
40 | ]
41 | },
42 | {
43 | "cell_type": "markdown",
44 | "metadata": {},
45 | "source": [
46 | "Fitting a Dixon-Coles team strength model:"
47 | ]
48 | },
49 | {
50 | "cell_type": "markdown",
51 | "metadata": {},
52 | "source": [
53 | "First, we need to get some data"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": null,
59 | "metadata": {},
60 | "outputs": [
61 | {
62 | "data": {
63 | "text/plain": [
64 | "[{'date': '2016-08-13',\n",
65 | " 'team1': 'Hull City AFC',\n",
66 | " 'team2': 'Leicester City FC',\n",
67 | " 'score': {'ft': [2, 1]}},\n",
68 | " {'date': '2016-08-13',\n",
69 | " 'team1': 'Everton FC',\n",
70 | " 'team2': 'Tottenham Hotspur FC',\n",
71 | " 'score': {'ft': [1, 1]}},\n",
72 | " {'date': '2016-08-13',\n",
73 | " 'team1': 'Crystal Palace FC',\n",
74 | " 'team2': 'West Bromwich Albion FC',\n",
75 | " 'score': {'ft': [0, 1]}}]"
76 | ]
77 | },
78 | "execution_count": null,
79 | "metadata": {},
80 | "output_type": "execute_result"
81 | }
82 | ],
83 | "source": [
84 | "import itertools\n",
85 | "import json\n",
86 | "import urllib.request\n",
87 | "\n",
88 | "\n",
89 | "# Use 2016/17 Premier League data from the openfootball repo\n",
90 | "url = 'https://raw.githubusercontent.com/openfootball/football.json/master/2016-17/en.1.json'\n",
91 | "\n",
92 | "\n",
93 | "response = urllib.request.urlopen(url)\n",
94 | "data_raw = json.loads(response.read())\n",
95 | "\n",
96 | "# Reshape the data to just get the matches\n",
97 | "data = list(itertools.chain(*[d['matches'] for d in data_raw['rounds']]))\n",
98 | "\n",
99 | "data[0:3]"
100 | ]
101 | },
102 | {
103 | "cell_type": "markdown",
104 | "metadata": {},
105 | "source": [
106 | "### Fitting a model"
107 | ]
108 | },
109 | {
110 | "cell_type": "markdown",
111 | "metadata": {},
112 | "source": [
113 | "To fit a model with mezzala, you need to create an \"adapter\". Adapters are used to connect a model to a data source.\n",
114 | "\n",
115 | "Because our data is a list of dicts, we are going to use a `KeyAdapter`."
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": null,
121 | "metadata": {},
122 | "outputs": [
123 | {
124 | "data": {
125 | "text/plain": [
126 | "'Hull City AFC'"
127 | ]
128 | },
129 | "execution_count": null,
130 | "metadata": {},
131 | "output_type": "execute_result"
132 | }
133 | ],
134 | "source": [
135 | "adapter = mezzala.KeyAdapter( # `KeyAdapter` = datum['...']\n",
136 | " home_team='team1',\n",
137 | " away_team='team2',\n",
138 | " home_goals=['score', 'ft', 0], # Get nested fields with lists of fields\n",
139 | " away_goals=['score', 'ft', 1], # i.e. datum['score']['ft'][1]\n",
140 | ")\n",
141 | "\n",
142 | "# You'll never need to call the methods on an \n",
143 | "# adapter directly, but just to show that it \n",
144 | "# works as expected:\n",
145 | "adapter.home_team(data[0])"
146 | ]
147 | },
148 | {
149 | "cell_type": "markdown",
150 | "metadata": {},
151 | "source": [
152 | "Once we have an adapter for our specific data source, we can fit the model:"
153 | ]
154 | },
155 | {
156 | "cell_type": "code",
157 | "execution_count": null,
158 | "metadata": {},
159 | "outputs": [
160 | {
161 | "data": {
162 | "text/plain": [
163 | "DixonColes(adapter=KeyAdapter(home_goals=['score', 'ft', 0], away_goals=['score', 'ft', 1], home_team='team1', away_team='team2'), blocks=[TeamStrength(), BaseRate(), HomeAdvantage()]), weight=UniformWeight()"
164 | ]
165 | },
166 | "execution_count": null,
167 | "metadata": {},
168 | "output_type": "execute_result"
169 | }
170 | ],
171 | "source": [
172 | "model = mezzala.DixonColes(adapter=adapter)\n",
173 | "model.fit(data)"
174 | ]
175 | },
176 | {
177 | "cell_type": "markdown",
178 | "metadata": {},
179 | "source": [
180 | "### Making predictions"
181 | ]
182 | },
183 | {
184 | "cell_type": "markdown",
185 | "metadata": {},
186 | "source": [
187 | "By default, you only need to supply the home and away team to get predictions. This should be supplied in the same format as the training data.\n",
188 | "\n",
189 | "`DixonColes` has two methods for making predictions:\n",
190 | "\n",
191 | "* `predict_one` - for predicting a single match\n",
192 | "* `predict` - for predicting multiple matches"
193 | ]
194 | },
195 | {
196 | "cell_type": "code",
197 | "execution_count": null,
198 | "metadata": {},
199 | "outputs": [
200 | {
201 | "data": {
202 | "text/plain": [
203 | "[ScorelinePrediction(home_goals=0, away_goals=0, probability=0.023625049697587167),\n",
204 | " ScorelinePrediction(home_goals=0, away_goals=1, probability=0.012682094432376022),\n",
205 | " ScorelinePrediction(home_goals=0, away_goals=2, probability=0.00623268833779594),\n",
206 | " ScorelinePrediction(home_goals=0, away_goals=3, probability=0.0016251514235046444),\n",
207 | " ScorelinePrediction(home_goals=0, away_goals=4, probability=0.00031781436109636405)]"
208 | ]
209 | },
210 | "execution_count": null,
211 | "metadata": {},
212 | "output_type": "execute_result"
213 | }
214 | ],
215 | "source": [
216 | "match_to_predict = {\n",
217 | " 'team1': 'Manchester City FC',\n",
218 | " 'team2': 'Swansea City FC',\n",
219 | "}\n",
220 | "\n",
221 | "scorelines = model.predict_one(match_to_predict)\n",
222 | "\n",
223 | "scorelines[0:5]"
224 | ]
225 | },
226 | {
227 | "cell_type": "markdown",
228 | "metadata": {},
229 | "source": [
230 | "Each of these methods returns predictions in the form of `ScorelinePredictions`. \n",
231 | "\n",
232 | "* `predict_one` returns a list of `ScorelinePredictions`\n",
233 | "* `predict` returns a list of `ScorelinePredictions` for each predicted match (i.e. a list of lists)\n",
234 | "\n",
235 | "However, it can sometimes be more useful to have predictions in the form of match _outcomes_. Mezzala exposes the `scorelines_to_outcomes` function for this purpose:"
236 | ]
237 | },
238 | {
239 | "cell_type": "code",
240 | "execution_count": null,
241 | "metadata": {},
242 | "outputs": [
243 | {
244 | "data": {
245 | "text/plain": [
246 | "{Outcomes('Home win'): OutcomePrediction(outcome=Outcomes('Home win'), probability=0.8255103334702835),\n",
247 | " Outcomes('Draw'): OutcomePrediction(outcome=Outcomes('Draw'), probability=0.11615659853961693),\n",
248 | " Outcomes('Away win'): OutcomePrediction(outcome=Outcomes('Away win'), probability=0.058333067990098304)}"
249 | ]
250 | },
251 | "execution_count": null,
252 | "metadata": {},
253 | "output_type": "execute_result"
254 | }
255 | ],
256 | "source": [
257 | "mezzala.scorelines_to_outcomes(scorelines)"
258 | ]
259 | },
260 | {
261 | "cell_type": "markdown",
262 | "metadata": {},
263 | "source": [
264 | "### Extending the model\n",
265 | "\n",
266 | "It's possible to fit more sophisticated models with mezzala, using **weights** and **model blocks**\n",
267 | "\n",
268 | "#### Weights\n",
269 | "\n",
270 | "You can weight individual data points by supplying a function (or callable) to the `weight` argument to `DixonColes`:"
271 | ]
272 | },
273 | {
274 | "cell_type": "code",
275 | "execution_count": null,
276 | "metadata": {},
277 | "outputs": [
278 | {
279 | "data": {
280 | "text/plain": [
281 | "DixonColes(adapter=KeyAdapter(home_goals=['score', 'ft', 0], away_goals=['score', 'ft', 1], home_team='team1', away_team='team2'), blocks=[TeamStrength(), BaseRate(), HomeAdvantage()]), weight=<function <lambda> at 0x123067488>"
282 | ]
283 | },
284 | "execution_count": null,
285 | "metadata": {},
286 | "output_type": "execute_result"
287 | }
288 | ],
289 | "source": [
290 | "mezzala.DixonColes(\n",
291 | " adapter=adapter,\n",
292 | " # By default, all data points are weighted equally,\n",
293 | " # which is equivalent to:\n",
294 | " weight=lambda x: 1\n",
295 | ")"
296 | ]
297 | },
298 | {
299 | "cell_type": "markdown",
300 | "metadata": {},
301 | "source": [
302 | "Mezzala also provides an `ExponentialWeight` for the purpose of time-discounting:"
303 | ]
304 | },
305 | {
306 | "cell_type": "code",
307 | "execution_count": null,
308 | "metadata": {},
309 | "outputs": [
310 | {
311 | "data": {
312 | "text/plain": [
313 | "DixonColes(adapter=KeyAdapter(home_goals=['score', 'ft', 0], away_goals=['score', 'ft', 1], home_team='team1', away_team='team2'), blocks=[TeamStrength(), BaseRate(), HomeAdvantage()]), weight=ExponentialWeight(epsilon=-0.0065, key=<function <lambda> at 0x122f938c8>)"
314 | ]
315 | },
316 | "execution_count": null,
317 | "metadata": {},
318 | "output_type": "execute_result"
319 | }
320 | ],
321 | "source": [
322 | "mezzala.DixonColes(\n",
323 | " adapter=adapter,\n",
324 | " weight=mezzala.ExponentialWeight(\n",
325 | " epsilon=-0.0065, # Decay rate\n",
326 | " key=lambda x: x['days_ago']\n",
327 | " )\n",
328 | ")"
329 | ]
330 | },
331 | {
332 | "cell_type": "markdown",
333 | "metadata": {},
334 | "source": [
335 | "#### Model blocks\n",
336 | "\n",
337 | "Model \"blocks\" define the calculation and estimation of home and away goalscoring rates."
338 | ]
339 | },
340 | {
341 | "cell_type": "code",
342 | "execution_count": null,
343 | "metadata": {},
344 | "outputs": [
345 | {
346 | "data": {
347 | "text/plain": [
348 | "DixonColes(adapter=KeyAdapter(home_goals=['score', 'ft', 0], away_goals=['score', 'ft', 1], home_team='team1', away_team='team2'), blocks=[TeamStrength(), HomeAdvantage(), BaseRate()]), weight=UniformWeight()"
349 | ]
350 | },
351 | "execution_count": null,
352 | "metadata": {},
353 | "output_type": "execute_result"
354 | }
355 | ],
356 | "source": [
357 | "mezzala.DixonColes(\n",
358 | " adapter=adapter,\n",
359 | " # By default, only team strength and home advantage\n",
360 | " # are estimated:\n",
361 | " blocks=[\n",
362 | " mezzala.blocks.HomeAdvantage(),\n",
363 | " mezzala.blocks.TeamStrength(),\n",
364 | " mezzala.blocks.BaseRate(), # Adds \"average goalscoring rate\" as a distinct parameter\n",
365 | " ]\n",
366 | ")"
367 | ]
368 | },
369 | {
370 | "cell_type": "markdown",
371 | "metadata": {},
372 | "source": [
373 | "To add custom parameters (e.g. per-league home advantage), you need to add additional model blocks."
374 | ]
375 | }
376 | ],
377 | "metadata": {
378 | "kernelspec": {
379 | "display_name": "Python 3",
380 | "language": "python",
381 | "name": "python3"
382 | }
383 | },
384 | "nbformat": 4,
385 | "nbformat_minor": 2
386 | }
387 |
--------------------------------------------------------------------------------
/nbs/models.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "edbf830b",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "# default_exp models"
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "id": "62b5819d",
16 | "metadata": {},
17 | "source": [
18 | "# Models"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "id": "a2ed6dc9",
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "#hide\n",
29 | "from nbdev.showdoc import *"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": null,
35 | "id": "1c6d1b2e",
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "#export\n",
40 | "import collections\n",
41 | "import dataclasses\n",
42 | "import enum\n",
43 | "import functools\n",
44 | "import itertools\n",
45 | "import typing\n",
46 | "import warnings\n",
47 | "\n",
48 | "import numpy as np\n",
49 | "import scipy.stats\n",
50 | "import scipy.optimize\n",
51 | "\n",
52 | "import mezzala.blocks\n",
53 | "import mezzala.weights\n",
54 | "import mezzala.parameters"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": null,
60 | "id": "476722af",
61 | "metadata": {},
62 | "outputs": [],
63 | "source": [
64 | "#export\n",
65 | "\n",
66 | "\n",
67 | "@dataclasses.dataclass(frozen=True)\n",
68 | "class ScorelinePrediction:\n",
69 | " home_goals: int\n",
70 | " away_goals: int\n",
71 | " probability: float\n",
72 | " "
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": null,
78 | "id": "ca2fbcf3",
79 | "metadata": {},
80 | "outputs": [],
81 | "source": [
82 | "#export\n",
83 | "\n",
84 | "\n",
85 | "class Outcomes(enum.Enum):\n",
86 | " HOME_WIN = 'Home win'\n",
87 | " DRAW = 'Draw'\n",
88 | " AWAY_WIN = 'Away win'\n",
89 | " \n",
90 | " def __repr__(self):\n",
91 | " return f\"Outcomes('{self.value}')\"\n",
92 | "\n",
93 | "\n",
94 | "@dataclasses.dataclass(frozen=True)\n",
95 | "class OutcomePrediction:\n",
96 | " outcome: Outcomes\n",
97 | " probability: float"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": null,
103 | "id": "1a44f039",
104 | "metadata": {},
105 | "outputs": [],
106 | "source": [
107 | "#export\n",
108 | "\n",
109 | "\n",
110 | "def scoreline_to_outcome(home_goals, away_goals):\n",
111 | " if home_goals > away_goals:\n",
112 | " return Outcomes.HOME_WIN\n",
113 | " if home_goals == away_goals:\n",
114 | " return Outcomes.DRAW\n",
115 | " if home_goals < away_goals:\n",
116 | " return Outcomes.AWAY_WIN\n",
117 | " \n",
118 | " \n",
119 | "def scorelines_to_outcomes(scorelines):\n",
120 | " return {\n",
121 | " outcome: OutcomePrediction(\n",
122 | " outcome, \n",
123 | " sum(s.probability for s in scorelines if scoreline_to_outcome(s.home_goals, s.away_goals) == outcome)\n",
124 | " )\n",
125 | " for outcome in Outcomes\n",
126 | " }"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": null,
132 | "id": "8bff2f9f",
133 | "metadata": {},
134 | "outputs": [],
135 | "source": [
136 | "#export\n",
137 | "\n",
138 | "_DEFAULT_BLOCKS = [\n",
139 | " mezzala.blocks.BaseRate(),\n",
140 | " mezzala.blocks.HomeAdvantage(),\n",
141 | " mezzala.blocks.TeamStrength(),\n",
142 | "]\n",
143 | "\n",
144 | "\n",
145 | "class DixonColes:\n",
146 | " \"\"\"\n",
147 | " Dixon-Coles models in Python\n",
148 | " \"\"\"\n",
149 | " \n",
150 | " def __init__(self, adapter, blocks=_DEFAULT_BLOCKS, weight=mezzala.weights.UniformWeight(), params=None):\n",
151 | " # NOTE: Should params be stored internally as separate lists of keys and values? \n",
152 | " # Then `params` (the dict) can be a property?\n",
153 | " self.params = params\n",
154 | " self.adapter = adapter\n",
155 | " self.weight = weight\n",
156 | " self._blocks = blocks\n",
157 | " \n",
158 | " def __repr__(self):\n",
159 | " return f'DixonColes(adapter={repr(self.adapter)}, blocks={repr([b for b in self.blocks])}), weight={repr(self.weight)}'\n",
160 | " \n",
161 | " @property\n",
162 | " def blocks(self):\n",
163 | " # Make sure blocks are always in the correct order\n",
164 | " return sorted(self._blocks, key=lambda x: -x.PRIORITY)\n",
165 | "\n",
166 | " def home_goals(self, row):\n",
167 | " \"\"\" Returns home goals scored \"\"\"\n",
168 | " return self.adapter.home_goals(row)\n",
169 | "\n",
170 | " def away_goals(self, row):\n",
171 | " \"\"\" Returns away goals scored \"\"\"\n",
172 | " return self.adapter.away_goals(row)\n",
173 | "\n",
174 | " def parse_params(self, data):\n",
175 | " \"\"\" Returns a tuple of (parameter_names, [constraints]) \"\"\"\n",
176 | " base_params = [mezzala.parameters.RHO_KEY]\n",
177 | " block_params = list(itertools.chain(*[b.param_keys(self.adapter, data) for b in self.blocks]))\n",
178 | " return (\n",
179 | " block_params + base_params,\n",
180 | " list(itertools.chain(*[b.constraints(self.adapter, data) for b in self.blocks]))\n",
181 | " )\n",
182 | " \n",
183 | " def _home_terms(self, row):\n",
184 | " return dict(itertools.chain(*[b.home_terms(self.adapter, row) for b in self.blocks]))\n",
185 | " \n",
186 | " def _away_terms(self, row):\n",
187 | " return dict(itertools.chain(*[b.away_terms(self.adapter, row) for b in self.blocks]))\n",
188 | " \n",
189 | " # Core methods\n",
190 | "\n",
191 | " @staticmethod\n",
192 | " def _assign_params(param_keys, param_values):\n",
193 | " return dict(zip(param_keys, param_values))\n",
194 | " \n",
195 | " def _create_feature_matrices(self, param_keys, data):\n",
196 | " \"\"\" Create X (feature) matrices for home and away poisson rates \"\"\"\n",
197 | " home_X = np.empty([len(data), len(param_keys)])\n",
198 | " away_X = np.empty([len(data), len(param_keys)])\n",
199 | " for row_i, row in enumerate(data):\n",
200 | " home_rate_terms = self._home_terms(row)\n",
201 | " away_rate_terms = self._away_terms(row)\n",
202 | " for param_i, param_key in enumerate(param_keys):\n",
203 | " home_X[row_i, param_i] = home_rate_terms.get(param_key, 0)\n",
204 | " away_X[row_i, param_i] = away_rate_terms.get(param_key, 0)\n",
205 | " return home_X, away_X\n",
206 | "\n",
207 | " @staticmethod\n",
208 | " def _tau(home_goals, away_goals, home_rate, away_rate, rho):\n",
209 | " \n",
210 | " tau = np.ones(len(home_goals))\n",
211 | " tau = np.where((home_goals == 0) & (away_goals == 0), 1 - home_rate*away_rate*rho, tau)\n",
212 | " tau = np.where((home_goals == 0) & (away_goals == 1), 1 + home_rate*rho, tau)\n",
213 | " tau = np.where((home_goals == 1) & (away_goals == 0), 1 + away_rate*rho, tau)\n",
214 | " tau = np.where((home_goals == 1) & (away_goals == 1), 1 - rho, tau) \n",
215 | " \n",
216 | " return tau\n",
217 | "\n",
218 | " def _log_like(self, home_goals, away_goals, home_rate, away_rate, rho):\n",
219 | " return (\n",
220 | " scipy.stats.poisson.logpmf(home_goals, home_rate) +\n",
221 | " scipy.stats.poisson.logpmf(away_goals, away_rate) +\n",
222 | " np.log(self._tau(home_goals, away_goals, home_rate, away_rate, rho))\n",
223 | " )\n",
224 | "\n",
225 | " def objective_fn(self, data, home_goals, away_goals, weights, home_X, away_X, rho_ix, xs):\n",
226 | " rho = xs[rho_ix]\n",
227 | "\n",
228 | " # Parameters are estimated in log-space, but `scipy.stats.poisson`\n",
229 | " # expects real number inputs, so we have to use `np.exp`\n",
230 | " home_rate = np.exp(np.dot(home_X, xs))\n",
231 | " away_rate = np.exp(np.dot(away_X, xs))\n",
232 | "\n",
233 | " log_like = self._log_like(home_goals, away_goals, home_rate, away_rate, rho)\n",
234 | " pseudo_log_like = log_like * weights\n",
235 | " return -np.sum(pseudo_log_like)\n",
236 | "\n",
237 | " def fit(self, data, **kwargs):\n",
238 | " param_keys, constraints = self.parse_params(data)\n",
239 | "\n",
240 | " init_params = (\n",
241 | " # Attempt to initialise parameters from any already-existing parameters\n",
242 | " # This substantially speeds up fitting during (e.g.) backtesting\n",
243 | " np.asarray([self.params.get(p, 0) for p in param_keys])\n",
244 | " # If the model has no parameters, just initialise with 0s\n",
245 | " if self.params\n",
246 | " else np.zeros(len(param_keys))\n",
247 | " )\n",
248 | "\n",
249 | " # Precalculate the things we can (for speed)\n",
250 | " \n",
251 | " # Create X (feature) matrices for home and away poisson rates\n",
252 | " home_X, away_X = self._create_feature_matrices(param_keys, data)\n",
253 | " \n",
254 | " # Get home goals, away goals, and weights from the data\n",
255 | " home_goals, away_goals = np.empty(len(data)), np.empty(len(data))\n",
256 | " weights = np.empty(len(data))\n",
257 | " for i, row in enumerate(data):\n",
258 | " home_goals[i] = self.home_goals(row)\n",
259 | " away_goals[i] = self.away_goals(row)\n",
260 | " weights[i] = self.weight(row)\n",
261 | " \n",
262 | " # Get the index of the Rho correlation parameter\n",
263 | " rho_ix = param_keys.index(mezzala.parameters.RHO_KEY)\n",
264 | "\n",
265 | " # Optimise!\n",
266 | " with warnings.catch_warnings():\n",
267 | " # This is a hack\n",
268 | " # Because we haven't properly constrained `rho`, it's possible for 0 or even negative\n",
269 | " # values of `tau` (and therefore invalid probabilities)\n",
270 | " # Ignoring the warnings has little practical impact, since the model\n",
271 | " # will still find the objective function's minimum point regardless\n",
272 | " warnings.simplefilter('ignore')\n",
273 | " \n",
274 | " estimate = scipy.optimize.minimize(\n",
275 | " lambda xs: self.objective_fn(data, home_goals, away_goals, weights, home_X, away_X, rho_ix, xs),\n",
276 | " x0=init_params,\n",
277 | " constraints=constraints,\n",
278 | " **kwargs\n",
279 | " )\n",
280 | "\n",
281 | " # Parse the estimates into parameter map\n",
282 | " self.params = self._assign_params(param_keys, estimate.x)\n",
283 | "\n",
284 | " return self\n",
285 | "\n",
286 | " def predict_one(self, row, up_to=26):\n",
287 | " scorelines = list(itertools.product(range(up_to), repeat=2))\n",
288 | "\n",
289 | " home_goals = np.asarray([h for h, a in scorelines])\n",
290 | " away_goals = np.asarray([a for h, a in scorelines])\n",
291 | " \n",
292 | " param_keys = self.params.keys()\n",
293 | " param_values = np.asarray([v for v in self.params.values()])\n",
294 | " \n",
295 | " home_X, away_X = self._create_feature_matrices(param_keys, [row])\n",
296 | " \n",
297 | " home_rate = np.exp(np.dot(home_X, param_values))\n",
298 | " away_rate = np.exp(np.dot(away_X, param_values))\n",
299 | " rho = self.params[mezzala.parameters.RHO_KEY]\n",
300 | " \n",
301 | " probs = np.exp(self._log_like(home_goals, away_goals, home_rate, away_rate, rho))\n",
302 | " \n",
303 | " return [ScorelinePrediction(*vals) for vals in zip(home_goals.tolist(), away_goals.tolist(), probs)]\n",
304 | "\n",
305 | " def predict(self, data, up_to=26):\n",
306 | " scorelines = [self.predict_one(row, up_to=up_to) for row in data]\n",
307 | " return scorelines"
308 | ]
309 | }
310 | ],
311 | "metadata": {
312 | "kernelspec": {
313 | "display_name": "Python 3",
314 | "language": "python",
315 | "name": "python3"
316 | }
317 | },
318 | "nbformat": 4,
319 | "nbformat_minor": 5
320 | }
321 |
--------------------------------------------------------------------------------
/nbs/adapters.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "220e6de0",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "# default_exp adapters"
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "id": "6c394436",
16 | "metadata": {},
17 | "source": [
18 | "# Data Adapters"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "id": "a81115fc",
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "#hide\n",
29 | "from nbdev.showdoc import *"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": null,
35 | "id": "d8ad4fbc",
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "import dataclasses\n",
40 | "import typing"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": null,
46 | "id": "9b18f71f",
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "#export\n",
51 | "import collections\n",
52 | "import functools\n",
53 | "\n",
54 | "import mezzala.parameters"
55 | ]
56 | },
57 | {
58 | "cell_type": "markdown",
59 | "id": "86cfc9f2",
60 | "metadata": {},
61 | "source": [
62 | "## Basic adapters"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": null,
68 | "id": "211bdaa4",
69 | "metadata": {},
70 | "outputs": [],
71 | "source": [
72 | "#export\n",
73 | "\n",
74 | "\n",
75 | "class KeyAdapter:\n",
76 | " \"\"\"\n",
77 | " Get data from subscriptable objects.\n",
78 | " \"\"\"\n",
79 | " \n",
80 | " def __init__(self, home_goals, away_goals, **kwargs):\n",
81 | " self._lookup = {\n",
82 | " 'home_goals': home_goals,\n",
83 | " 'away_goals': away_goals,\n",
84 | " **kwargs\n",
85 | " }\n",
86 | " \n",
87 | " def __repr__(self):\n",
88 | " args_repr = ', '.join(f'{k}={repr(v)}' for k, v in self._lookup.items())\n",
89 | " return f'KeyAdapter({args_repr})'\n",
90 | "\n",
91 | " def _get_in(self, row, item):\n",
92 | " if isinstance(item, list):\n",
93 | " return functools.reduce(lambda d, i: d[i], item, row)\n",
94 | " return row[item]\n",
95 | " \n",
96 | " def __getattr__(self, key): \n",
97 | " def getter(row):\n",
98 | " return self._get_in(row, self._lookup[key])\n",
99 | " return getter"
100 | ]
101 | },
102 | {
103 | "cell_type": "markdown",
104 | "id": "e7427e3d",
105 | "metadata": {},
106 | "source": [
107 | "Anything subscriptable can be used with this type of adapter. For example,\n",
108 | "you might have input data as a list of tuples (e.g. using Python's\n",
109 | "in-built `csv` library)"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": null,
115 | "id": "37639fc2",
116 | "metadata": {},
117 | "outputs": [],
118 | "source": [
119 | "index_adapter = KeyAdapter(0, 1)\n",
120 | "\n",
121 | "assert index_adapter.home_goals([1, 2]) == 1\n",
122 | "assert index_adapter.away_goals([1, 2]) == 2"
123 | ]
124 | },
125 | {
126 | "cell_type": "markdown",
127 | "id": "dcee4f73",
128 | "metadata": {},
129 | "source": [
130 | "Or, you might be using a list of dicts."
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "execution_count": null,
136 | "id": "eb152699",
137 | "metadata": {},
138 | "outputs": [],
139 | "source": [
140 | "dict_adapter = KeyAdapter('hg', 'ag', home_team='home', away_team='away')\n",
141 | "\n",
142 | "example_dict = {\n",
143 | " 'home': 'Team 1',\n",
144 | " 'away': 'Team 2',\n",
145 | " 'hg': 4,\n",
146 | " 'ag': 3,\n",
147 | "}\n",
148 | "\n",
149 | "assert dict_adapter.home_goals(example_dict) == 4\n",
150 | "assert dict_adapter.away_goals(example_dict) == 3\n",
151 | "assert dict_adapter.home_team(example_dict) == 'Team 1'\n",
152 | "assert dict_adapter.away_team(example_dict) == 'Team 2'"
153 | ]
154 | },
155 | {
156 | "cell_type": "markdown",
157 | "id": "c2c6b4dc",
158 | "metadata": {},
159 | "source": [
160 | "Nested data can be supplied using a list"
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": null,
166 | "id": "fd4dd144",
167 | "metadata": {},
168 | "outputs": [],
169 | "source": [
170 | "nested_dict_adapter = KeyAdapter(\n",
171 | " home_goals=['scoreline', 0], \n",
172 | " away_goals=['scoreline', 1]\n",
173 | ")\n",
174 | "\n",
175 | "example_nested_dict = {\n",
176 | " 'scoreline': [1, 1]\n",
177 | "}\n",
178 | "\n",
179 | "assert nested_dict_adapter.home_goals(example_nested_dict) == 1\n",
180 | "assert nested_dict_adapter.away_goals(example_nested_dict) == 1"
181 | ]
182 | },
183 | {
184 | "cell_type": "markdown",
185 | "id": "5b8acbec",
186 | "metadata": {},
187 | "source": [
188 | "`KeyAdapter` could be used alongside `pd.DataFrame.iterrows` as well; however, iterating with `pd.DataFrame.itertuples` is much faster, and the named tuples it yields must be accessed by attribute rather than by key.\n",
189 | "\n",
190 | "Likewise, you can't use a `KeyAdapter` with custom objects (e.g. dataclasses).\n",
191 | "\n",
192 | "In this case, you need an `AttributeAdapter`."
193 | ]
194 | },
195 | {
196 | "cell_type": "code",
197 | "execution_count": null,
198 | "id": "4d42dc0a",
199 | "metadata": {},
200 | "outputs": [],
201 | "source": [
202 | "#export\n",
203 | "\n",
204 | "\n",
205 | "class AttributeAdapter:\n",
206 | " \"\"\"\n",
207 | " Get data from object attributes.\n",
208 | " \"\"\"\n",
209 | " def __init__(self, home_goals, away_goals, **kwargs):\n",
210 | " self._lookup = {\n",
211 | " 'home_goals': home_goals,\n",
212 | " 'away_goals': away_goals,\n",
213 | " **kwargs\n",
214 | " }\n",
215 | " \n",
216 | " def __repr__(self):\n",
217 | " args_repr = ', '.join(f'{k}={repr(v)}' for k, v in self._lookup.items())\n",
218 | " return f'AttributeAdapter({args_repr})'\n",
219 | " \n",
220 | " def _get_in(self, row, item):\n",
221 | " if isinstance(item, list):\n",
222 | " return functools.reduce(getattr, item, row)\n",
223 | " return getattr(row, item)\n",
224 | " \n",
225 | " def __getattr__(self, key): \n",
226 | " def getter(row):\n",
227 | " return self._get_in(row, self._lookup[key])\n",
228 | " return getter"
229 | ]
230 | },
231 | {
232 | "cell_type": "code",
233 | "execution_count": null,
234 | "id": "1d3fb1dc",
235 | "metadata": {},
236 | "outputs": [],
237 | "source": [
238 | "@dataclasses.dataclass()\n",
239 | "class ExampleData:\n",
240 | " hg: int\n",
241 | " ag: int\n",
242 | " home: str\n",
243 | " away: str\n",
244 | "\n",
245 | "\n",
246 | "attr_adapter = AttributeAdapter('hg', 'ag', home_team='home', away_team='away')\n",
247 | "\n",
248 | "\n",
249 | "example_attr = ExampleData(\n",
250 | " home='Another home team',\n",
251 | " away='Another away team',\n",
252 | " hg=5,\n",
253 | " ag=1,\n",
254 | ")\n",
255 | "\n",
256 | "assert attr_adapter.home_goals(example_attr) == 5\n",
257 | "assert attr_adapter.away_goals(example_attr) == 1\n",
258 | "assert attr_adapter.home_team(example_attr) == 'Another home team'\n",
259 | "assert attr_adapter.away_team(example_attr) == 'Another away team'"
260 | ]
261 | },
262 | {
263 | "cell_type": "markdown",
264 | "id": "0de5484e",
265 | "metadata": {},
266 | "source": [
267 | "As with `KeyAdapter`, nested attributes can also be fetched using lists"
268 | ]
269 | },
270 | {
271 | "cell_type": "code",
272 | "execution_count": null,
273 | "id": "ae96b0e2",
274 | "metadata": {},
275 | "outputs": [],
276 | "source": [
277 | "@dataclasses.dataclass()\n",
278 | "class Scoreline:\n",
279 | " home: int\n",
280 | " away: int\n",
281 | "\n",
282 | "\n",
283 | "@dataclasses.dataclass()\n",
284 | "class ExampleNestedData:\n",
285 | " scoreline: Scoreline\n",
286 | " home: str\n",
287 | " away: str\n",
288 | "\n",
289 | "\n",
290 | "nested_attr_adapter = AttributeAdapter(\n",
291 | " home_team='home',\n",
292 | " home_goals=['scoreline', 'home'], \n",
293 | " away_team='away',\n",
294 | " away_goals=['scoreline', 'away'],\n",
295 | ")\n",
296 | "\n",
297 | "example_nested_attr = ExampleNestedData(\n",
298 | " home='Another home team',\n",
299 | " away='Another away team',\n",
300 | " scoreline=Scoreline(2, 5),\n",
301 | ")\n",
302 | "\n",
303 | "assert nested_attr_adapter.home_goals(example_nested_attr) == 2\n",
304 | "assert nested_attr_adapter.away_goals(example_nested_attr) == 5"
305 | ]
306 | },
307 | {
308 | "cell_type": "markdown",
309 | "id": "456f7c47",
310 | "metadata": {},
311 | "source": [
312 | "## Composite adapters"
313 | ]
314 | },
315 | {
316 | "cell_type": "code",
317 | "execution_count": null,
318 | "id": "a01bc537",
319 | "metadata": {},
320 | "outputs": [],
321 | "source": [
322 | "#export\n",
323 | "\n",
324 | "\n",
325 | "class LumpedAdapter:\n",
326 | " \"\"\" \n",
327 | " Lump terms which have appeared below a minimum number of times in\n",
328 | " the training data into a placeholder term\n",
329 | " \"\"\"\n",
330 | "\n",
331 | " def __init__(self, base_adapter, **kwargs):\n",
332 | " self.base_adapter = base_adapter\n",
333 | " \n",
334 | " # Match terms to placeholders\n",
335 | " # If multiple terms have the same placeholder (e.g. Home and Away\n",
336 | " # teams) they will share a counter\n",
337 | " self._term_lookup = kwargs\n",
338 | " \n",
339 | " self._counters = None\n",
340 | " \n",
341 | " def __repr__(self):\n",
342 | " args_repr = ', '.join(f'{k}={repr(v)}' for k, v in self._term_lookup.items())\n",
343 | " return f'LumpedAdapter(base_adapter={repr(self.base_adapter)}, {args_repr})'\n",
344 | " \n",
345 | " def fit(self, data):\n",
346 | " self._counters = {}\n",
347 | " for term, (placeholder, _) in self._term_lookup.items():\n",
348 |     "            # Initialise with an empty counter if it doesn't already exist.\n",
349 |     "            # We need to do this so that multiple terms with the same placeholder\n",
350 |     "            # (e.g. home and away teams) are counted together\n",
351 | " init_counter = self._counters.get(placeholder, collections.Counter())\n",
352 | " \n",
353 | " counter = collections.Counter(getattr(self.base_adapter, term)(row) for row in data)\n",
354 | " \n",
355 | " self._counters[placeholder] = init_counter + counter\n",
356 | " return self\n",
357 | " \n",
358 | " def __getattr__(self, key):\n",
359 | " if not self._counters:\n",
360 | " raise ValueError(\n",
361 | " 'No counts found! You need to call `LumpedAdapter.fit` '\n",
362 | " 'on the training data before you can use it!'\n",
363 | " )\n",
364 | " \n",
365 | " def getter(row):\n",
366 | " value = getattr(self.base_adapter, key)(row)\n",
367 | " placeholder, min_obs = self._term_lookup.get(key, (None, None))\n",
368 | " if placeholder and self._counters[placeholder][value] < min_obs:\n",
369 | " return placeholder\n",
370 | " return value\n",
371 | " return getter"
372 | ]
373 | },
374 | {
375 | "cell_type": "code",
376 | "execution_count": null,
377 | "id": "8255aa1a",
378 | "metadata": {},
379 | "outputs": [
380 | {
381 | "data": {
382 | "text/plain": [
383 | "LumpedAdapter(base_adapter=KeyAdapter(home_goals='hg', away_goals='ag', home_team='home', away_team='away'), home_team=('Other team', 5), away_team=('Other team', 5))"
384 | ]
385 | },
386 | "execution_count": null,
387 | "metadata": {},
388 | "output_type": "execute_result"
389 | }
390 | ],
391 | "source": [
392 | "example_lumped_data = [\n",
393 | " *([example_dict]*4), # i.e., 'Team 1' and 'Team 2' appear in the data 4 times\n",
394 | " {'away': 'Team 1', # 'Team 1' now appears an additional time, (5 total)\n",
395 | " # Although this time appears as an *away* team\n",
396 | " 'home': 'Team 3', # While 'Team 3' appears once\n",
397 | " 'hg': 4, \n",
398 | " 'ag': 3},\n",
399 | "]\n",
400 | "\n",
401 | "\n",
402 | "lumped_dict_adapter = LumpedAdapter(\n",
403 | " base_adapter=dict_adapter,\n",
404 | " home_team=('Other team', 5), # Because `home_team` and `away_team` share the same\n",
405 | " # placeholder value ('Other team'), they are counted\n",
406 | " # together. I.e. a team has to appear at least 5 times\n",
407 | " # as _either_ the home team, or the away team\n",
408 | " away_team=('Other team', 5)\n",
409 | ")\n",
410 | "lumped_dict_adapter.fit(example_lumped_data)\n",
411 | "\n",
412 | "lumped_dict_adapter"
413 | ]
414 | },
415 | {
416 | "cell_type": "code",
417 | "execution_count": null,
418 | "id": "a16e9648",
419 | "metadata": {},
420 | "outputs": [],
421 | "source": [
422 | "example_lumped_1 = {\n",
423 | " 'home': 'Team 1',\n",
424 | " 'away': 'Team 3',\n",
425 | " 'hg': 1, \n",
426 | " 'ag': 2\n",
427 | "}\n",
428 | "\n",
429 | "# A team with more than the minimum number of observations appears as before\n",
430 | "assert lumped_dict_adapter.home_team(example_lumped_1) == 'Team 1'\n",
431 | "\n",
432 | "# But a team with fewer observations appears as the placeholder\n",
433 | "assert lumped_dict_adapter.away_team(example_lumped_1) == 'Other team'\n",
434 | "\n",
435 | "# Meanwhile, values without a placeholder in the LumpedAdapter\n",
436 | "# also appear as before\n",
437 | "assert lumped_dict_adapter.home_goals(example_lumped_1) == 1\n",
438 | "assert lumped_dict_adapter.away_goals(example_lumped_1) == 2"
439 | ]
440 | },
441 | {
442 | "cell_type": "markdown",
443 | "id": "b8440a77",
444 | "metadata": {},
445 | "source": [
446 | "Using a lumped adapter can also allow you to handle items which didn't appear in the training set at all:"
447 | ]
448 | },
449 | {
450 | "cell_type": "code",
451 | "execution_count": null,
452 | "id": "5829767f",
453 | "metadata": {},
454 | "outputs": [],
455 | "source": [
456 | "example_lumped_2 = {\n",
457 | " 'home': 'Team 2', # Only appeared 4 times, below threshold of 5\n",
458 | " 'away': 'Team 4', # Appeared 0 times in the data\n",
459 | " 'hg': 1, \n",
460 | " 'ag': 2\n",
461 | "}\n",
462 | "\n",
463 | "assert lumped_dict_adapter.home_team(example_lumped_2) == 'Other team'\n",
464 | "assert lumped_dict_adapter.away_team(example_lumped_2) == 'Other team'"
465 | ]
466 | }
467 | ],
468 | "metadata": {
469 | "kernelspec": {
470 | "display_name": "Python 3",
471 | "language": "python",
472 | "name": "python3"
473 | }
474 | },
475 | "nbformat": 4,
476 | "nbformat_minor": 5
477 | }
478 |
--------------------------------------------------------------------------------
/nbs/core.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "# default_exp __init__"
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "# Core\n",
17 | "\n",
18 | "> Team-strength models in Python"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "metadata": {},
25 | "outputs": [],
26 | "source": [
27 | "#hide\n",
28 | "from nbdev.showdoc import *"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "import json\n",
38 | "import datetime as dt\n",
39 | "import numpy as np\n",
40 | "import pprint"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": null,
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "#export\n",
50 | "\n",
51 | "# For now, just re-export everything\n",
52 | "from mezzala.adapters import *\n",
53 | "from mezzala.blocks import *\n",
54 | "from mezzala.models import *\n",
55 | "from mezzala.weights import *\n",
56 | "from mezzala.parameters import *"
57 | ]
58 | },
59 | {
60 | "cell_type": "markdown",
61 | "metadata": {},
62 | "source": [
63 | "Let's demo"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": null,
69 | "metadata": {},
70 | "outputs": [
71 | {
72 | "data": {
73 | "text/plain": [
74 | "[{'date': datetime.datetime(2015, 8, 8, 0, 0),\n",
75 | " 'team1': 'Manchester United FC',\n",
76 | " 'team2': 'Tottenham Hotspur FC',\n",
77 | " 'score': {'ft': [1, 0]}},\n",
78 | " {'date': datetime.datetime(2015, 8, 8, 0, 0),\n",
79 | " 'team1': 'AFC Bournemouth',\n",
80 | " 'team2': 'Aston Villa FC',\n",
81 | " 'score': {'ft': [0, 1]}},\n",
82 | " {'date': datetime.datetime(2015, 8, 8, 0, 0),\n",
83 | " 'team1': 'Leicester City FC',\n",
84 | " 'team2': 'Sunderland AFC',\n",
85 | " 'score': {'ft': [4, 2]}}]"
86 | ]
87 | },
88 | "execution_count": null,
89 | "metadata": {},
90 | "output_type": "execute_result"
91 | }
92 | ],
93 | "source": [
94 | "with open('../data/premier-league-1516.json', 'r') as f:\n",
95 | " pl_1516 = json.load(f)\n",
96 | "\n",
97 | "# Let's parse the dates, too\n",
98 | "for match in pl_1516:\n",
99 | " match['date'] = dt.datetime.strptime(match['date'], '%Y-%m-%d')\n",
100 | " \n",
101 | "pl_1516[0:3]"
102 | ]
103 | },
104 | {
105 | "cell_type": "markdown",
106 | "metadata": {},
107 | "source": [
108 | "A model in `mezzala` is composed of 2 parts:\n",
109 | "\n",
110 | "* Model blocks (see `mezzala.blocks`)\n",
111 | "* An adapter (see `mezzala.adapters`)"
112 | ]
113 | },
114 | {
115 | "cell_type": "markdown",
116 | "metadata": {},
117 | "source": [
118 | "The model blocks determine which terms your model estimates. In general, you will want\n",
119 | "to estimate offensive and defensive strength for each team (`TeamStrength`),\n",
120 | "as well as home advantage (`HomeAdvantage`).\n",
121 | "\n",
122 | "The selected model blocks can be supplied to the model as a list:"
123 | ]
124 | },
125 | {
126 | "cell_type": "code",
127 | "execution_count": null,
128 | "metadata": {},
129 | "outputs": [],
130 | "source": [
131 | "blocks = [TeamStrength(), HomeAdvantage()]"
132 | ]
133 | },
134 | {
135 | "cell_type": "markdown",
136 | "metadata": {},
137 | "source": [
138 | "An adapter connects your model to the data source. In other words, it tells the model how to find the information needed to fit.\n",
139 | "\n",
140 | "The information needed is determined by which model blocks are used. In our case,\n",
141 | "\n",
142 | "* All models require `home_goals` and `away_goals`\n",
143 | "* `TeamStrength` - requires `home_team` and `away_team`\n",
144 | "\n",
145 | "`HomeAdvantage` doesn't require any information, since it assumes all matches have equal home-field advantage by default."
146 | ]
147 | },
148 | {
149 | "cell_type": "code",
150 | "execution_count": null,
151 | "metadata": {},
152 | "outputs": [],
153 | "source": [
154 | "adapter = KeyAdapter( # `KeyAdapter` = data['...']\n",
155 | " home_team='team1',\n",
156 | " away_team='team2',\n",
157 | " home_goals=['score', 'ft', 0], # Get nested fields with lists of fields\n",
158 | " away_goals=['score', 'ft', 1], # i.e. data['score']['ft'][1]\n",
159 | ")"
160 | ]
161 | },
162 | {
163 | "cell_type": "markdown",
164 | "metadata": {},
165 | "source": [
166 | "Pulling this together, we can construct a model from an adapter and blocks"
167 | ]
168 | },
169 | {
170 | "cell_type": "code",
171 | "execution_count": null,
172 | "metadata": {},
173 | "outputs": [],
174 | "source": [
175 | "model = DixonColes(adapter=adapter, blocks=blocks)\n",
176 | "model.fit(pl_1516)\n",
177 | "\n",
178 | "# All estimates should be valid numbers\n",
179 | "assert all(not np.isnan(x) for x in model.params.values())\n",
180 | "\n",
181 | "# Home advantage should be positive\n",
182 | "assert 1.0 < np.exp(model.params[HFA_KEY]) < 2.0"
183 | ]
184 | },
185 | {
186 | "cell_type": "markdown",
187 | "metadata": {},
188 | "source": [
189 | "Let's inspect the parameters a bit. First, let's look at the boring (non-team) ones:"
190 | ]
191 | },
192 | {
193 | "cell_type": "code",
194 | "execution_count": null,
195 | "metadata": {},
196 | "outputs": [
197 | {
198 | "name": "stdout",
199 | "output_type": "stream",
200 | "text": [
201 | "ParameterKey(label='Home-field advantage') : 1.23\n",
202 | "ParameterKey(label='Rho') : 0.94\n"
203 | ]
204 | }
205 | ],
206 | "source": [
207 | "param_keys = model.params.keys()\n",
208 | "param_key_len = max(len(str(k)) for k in param_keys)\n",
209 | "\n",
210 | "for k in param_keys:\n",
211 | " if not isinstance(k, TeamParameterKey):\n",
212 | " key_str = str(k).ljust(param_key_len + 1)\n",
213 | " print(f'{key_str}: {np.exp(model.params[k]):0.2f}')"
214 | ]
215 | },
216 | {
217 | "cell_type": "markdown",
218 | "metadata": {},
219 | "source": [
220 | "And the team ones. Let's look at each team's attacking quality:"
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": null,
226 | "metadata": {},
227 | "outputs": [
228 | {
229 | "name": "stdout",
230 | "output_type": "stream",
231 | "text": [
232 | "Manchester City FC: 1.38\n",
233 | "Tottenham Hotspur FC: 1.33\n",
234 | "Leicester City FC: 1.31\n",
235 | "West Ham United FC: 1.27\n",
236 | "Arsenal FC: 1.25\n",
237 | "Liverpool FC: 1.23\n",
238 | "Everton FC: 1.16\n",
239 | "Chelsea FC: 1.15\n",
240 | "Southampton FC: 1.14\n",
241 | "Manchester United FC: 0.94\n",
242 | "Sunderland AFC: 0.94\n",
243 | "AFC Bournemouth: 0.89\n",
244 | "Newcastle United FC: 0.87\n",
245 | "Swansea City FC: 0.82\n",
246 | "Stoke City FC: 0.81\n",
247 | "Watford FC: 0.78\n",
248 | "Norwich City FC: 0.77\n",
249 | "Crystal Palace FC: 0.76\n",
250 | "West Bromwich Albion FC: 0.66\n",
251 | "Aston Villa FC: 0.54\n"
252 | ]
253 | }
254 | ],
255 | "source": [
256 | "teams = {k.label for k in param_keys if isinstance(k, TeamParameterKey)}\n",
257 | "\n",
258 | "team_offence = [(t, np.exp(model.params[OffenceParameterKey(t)])) for t in teams]\n",
259 | "for team, estimate in sorted(team_offence, key=lambda x: -x[1]):\n",
260 | " print(f'{team}: {estimate:0.2f}')"
261 | ]
262 | },
263 | {
264 | "cell_type": "code",
265 | "execution_count": null,
266 | "metadata": {},
267 | "outputs": [
268 | {
269 | "name": "stdout",
270 | "output_type": "stream",
271 | "text": [
272 | "Manchester United FC: 0.82\n",
273 | "Tottenham Hotspur FC: 0.84\n",
274 | "Leicester City FC: 0.86\n",
275 | "Arsenal FC: 0.86\n",
276 | "Southampton FC: 0.97\n",
277 | "Manchester City FC: 0.99\n",
278 | "West Bromwich Albion FC: 1.10\n",
279 | "Watford FC: 1.17\n",
280 | "Liverpool FC: 1.19\n",
281 | "Crystal Palace FC: 1.19\n",
282 | "Swansea City FC: 1.21\n",
283 | "West Ham United FC: 1.22\n",
284 | "Chelsea FC: 1.26\n",
285 | "Stoke City FC: 1.28\n",
286 | "Everton FC: 1.32\n",
287 | "Sunderland AFC: 1.46\n",
288 | "Newcastle United FC: 1.52\n",
289 | "Norwich City FC: 1.55\n",
290 | "AFC Bournemouth: 1.57\n",
291 | "Aston Villa FC: 1.75\n"
292 | ]
293 | }
294 | ],
295 | "source": [
296 | "team_defence = [(t, np.exp(model.params[DefenceParameterKey(t)])) for t in teams]\n",
297 | "for team, estimate in sorted(team_defence, key=lambda x: x[1]):\n",
298 | " print(f'{team}: {estimate:0.2f}')"
299 | ]
300 | },
301 | {
302 | "cell_type": "markdown",
303 | "metadata": {},
304 | "source": [
305 | "Making predictions for a single match"
306 | ]
307 | },
308 | {
309 | "cell_type": "code",
310 | "execution_count": null,
311 | "metadata": {},
312 | "outputs": [
313 | {
314 | "data": {
315 | "text/plain": [
316 | "[ScorelinePrediction(home_goals=0, away_goals=0, probability=0.0619999820129133),\n",
317 | " ScorelinePrediction(home_goals=0, away_goals=1, probability=0.03970300056443736),\n",
318 | " ScorelinePrediction(home_goals=0, away_goals=2, probability=0.018568356365315872),\n",
319 | " ScorelinePrediction(home_goals=0, away_goals=3, probability=0.005037154039480389),\n",
320 | " ScorelinePrediction(home_goals=0, away_goals=4, probability=0.0010248451849317163)]"
321 | ]
322 | },
323 | "execution_count": null,
324 | "metadata": {},
325 | "output_type": "execute_result"
326 | }
327 | ],
328 | "source": [
329 | "scorelines = model.predict_one({\n",
330 | " 'team1': 'Manchester City FC',\n",
331 | " 'team2': 'Swansea City FC',\n",
332 | "})\n",
333 | "\n",
334 | "# Probabilities should sum to 1\n",
335 | "assert np.isclose(\n",
336 | " sum(p.probability for p in scorelines),\n",
337 | " 1.0\n",
338 | ")\n",
339 | "\n",
340 | "scorelines[0:5]"
341 | ]
342 | },
343 | {
344 | "cell_type": "code",
345 | "execution_count": null,
346 | "metadata": {},
347 | "outputs": [
348 | {
349 | "data": {
350 | "text/plain": [
351 | "[OutcomePrediction(outcome=Outcomes('Home win'), probability=0.658650484098139),\n",
352 | " OutcomePrediction(outcome=Outcomes('Draw'), probability=0.21019557218753862),\n",
353 | " OutcomePrediction(outcome=Outcomes('Away win'), probability=0.13115394371432296)]"
354 | ]
355 | },
356 | "execution_count": null,
357 | "metadata": {},
358 | "output_type": "execute_result"
359 | }
360 | ],
361 | "source": [
362 | "outcomes = scorelines_to_outcomes(scorelines)\n",
363 | "\n",
364 | "# MCFC should have a better chance of beating Swansea\n",
365 | "# at home than Swansea do of winning away\n",
366 | "assert outcomes[Outcomes('Home win')].probability > outcomes[Outcomes('Away win')].probability\n",
367 | "\n",
368 | "list(outcomes.values())"
369 | ]
370 | },
371 | {
372 | "cell_type": "markdown",
373 | "metadata": {},
374 | "source": [
375 | "Or for multiple matches"
376 | ]
377 | },
378 | {
379 | "cell_type": "code",
380 | "execution_count": null,
381 | "metadata": {},
382 | "outputs": [],
383 | "source": [
384 | "many_scorelines = model.predict([\n",
385 | " {'team1': 'Manchester City FC',\n",
386 | " 'team2': 'Swansea City FC'},\n",
387 | " {'team1': 'Manchester City FC',\n",
388 | " 'team2': 'West Ham United FC'}\n",
389 | "])"
390 | ]
391 | },
392 | {
393 | "cell_type": "markdown",
394 | "metadata": {},
395 | "source": [
396 | "What about a model with a different weighting method?\n",
397 | "\n",
398 | "By default, the `DixonColes` model weights all matches equally. However, it's more realistic to give matches\n",
399 | "closer to the current date a bigger weight than those a long time ago.\n",
400 | "\n",
401 | "The original Dixon-Coles paper suggests using an exponential weight, and we can use the same:"
402 | ]
403 | },
404 | {
405 | "cell_type": "code",
406 | "execution_count": null,
407 | "metadata": {},
408 | "outputs": [],
409 | "source": [
410 | "season_end_date = max(match['date'] for match in pl_1516)\n",
411 | "\n",
412 | "weight = ExponentialWeight(\n",
413 | " # Value of `epsilon` is taken from the original paper\n",
414 | " epsilon=-0.0065, \n",
415 | " key=lambda x: (season_end_date - x['date']).days\n",
416 | ")"
417 | ]
418 | },
419 | {
420 | "cell_type": "code",
421 | "execution_count": null,
422 | "metadata": {},
423 | "outputs": [
424 | {
425 | "data": {
426 | "text/plain": [
427 | "DixonColes(adapter=KeyAdapter(home_goals=['score', 'ft', 0], away_goals=['score', 'ft', 1], home_team='team1', away_team='team2'), blocks=[TeamStrength(), HomeAdvantage()]), weight=ExponentialWeight(epsilon=-0.0065, key= at 0x11eecd158>)"
428 | ]
429 | },
430 | "execution_count": null,
431 | "metadata": {},
432 | "output_type": "execute_result"
433 | }
434 | ],
435 | "source": [
436 | "model_exp = DixonColes(\n",
437 | " adapter=adapter,\n",
438 | " blocks=blocks,\n",
439 | " weight=weight\n",
440 | ")\n",
441 | "model_exp.fit(pl_1516)"
442 | ]
443 | },
444 | {
445 | "cell_type": "markdown",
446 | "metadata": {},
447 | "source": [
448 | "How much does that change the ratings at season-end?"
449 | ]
450 | },
451 | {
452 | "cell_type": "code",
453 | "execution_count": null,
454 | "metadata": {},
455 | "outputs": [
456 | {
457 | "name": "stdout",
458 | "output_type": "stream",
459 | "text": [
460 | "OffenceParameterKey(label='AFC Bournemouth') : 0.89 -> 0.88 (0.99)\n",
461 | "DefenceParameterKey(label='AFC Bournemouth') : 1.57 -> 1.61 (1.02)\n",
462 | "OffenceParameterKey(label='Arsenal FC') : 1.25 -> 1.25 (1.00)\n",
463 | "DefenceParameterKey(label='Arsenal FC') : 0.86 -> 0.85 (0.98)\n",
464 | "OffenceParameterKey(label='Aston Villa FC') : 0.54 -> 0.49 (0.91)\n",
465 | "DefenceParameterKey(label='Aston Villa FC') : 1.75 -> 1.83 (1.04)\n",
466 | "OffenceParameterKey(label='Chelsea FC') : 1.15 -> 1.20 (1.04)\n",
467 | "DefenceParameterKey(label='Chelsea FC') : 1.26 -> 1.16 (0.92)\n",
468 | "OffenceParameterKey(label='Crystal Palace FC') : 0.76 -> 0.70 (0.92)\n",
469 | "DefenceParameterKey(label='Crystal Palace FC') : 1.19 -> 1.25 (1.05)\n",
470 | "OffenceParameterKey(label='Everton FC') : 1.16 -> 1.02 (0.88)\n",
471 | "DefenceParameterKey(label='Everton FC') : 1.32 -> 1.33 (1.01)\n",
472 | "ParameterKey(label='Home-field advantage') : 1.23 -> 1.30 (1.05)\n",
473 | "OffenceParameterKey(label='Leicester City FC') : 1.31 -> 1.25 (0.95)\n",
474 | "DefenceParameterKey(label='Leicester City FC') : 0.86 -> 0.68 (0.79)\n",
475 | "OffenceParameterKey(label='Liverpool FC') : 1.23 -> 1.33 (1.08)\n",
476 | "DefenceParameterKey(label='Liverpool FC') : 1.19 -> 1.18 (1.00)\n",
477 | "OffenceParameterKey(label='Manchester City FC') : 1.38 -> 1.36 (0.98)\n",
478 | "DefenceParameterKey(label='Manchester City FC') : 0.99 -> 1.00 (1.01)\n",
479 | "OffenceParameterKey(label='Manchester United FC') : 0.94 -> 0.92 (0.98)\n",
480 | "DefenceParameterKey(label='Manchester United FC') : 0.82 -> 0.83 (1.01)\n",
481 | "OffenceParameterKey(label='Newcastle United FC') : 0.87 -> 0.93 (1.08)\n",
482 | "DefenceParameterKey(label='Newcastle United FC') : 1.52 -> 1.37 (0.90)\n",
483 | "OffenceParameterKey(label='Norwich City FC') : 0.77 -> 0.69 (0.90)\n",
484 | "DefenceParameterKey(label='Norwich City FC') : 1.55 -> 1.51 (0.97)\n",
485 | "ParameterKey(label='Rho') : 0.94 -> 0.91 (0.97)\n",
486 | "OffenceParameterKey(label='Southampton FC') : 1.14 -> 1.26 (1.11)\n",
487 | "DefenceParameterKey(label='Southampton FC') : 0.97 -> 0.95 (0.98)\n",
488 | "OffenceParameterKey(label='Stoke City FC') : 0.81 -> 0.82 (1.01)\n",
489 | "DefenceParameterKey(label='Stoke City FC') : 1.28 -> 1.42 (1.11)\n",
490 | "OffenceParameterKey(label='Sunderland AFC') : 0.94 -> 0.99 (1.05)\n",
491 | "DefenceParameterKey(label='Sunderland AFC') : 1.46 -> 1.22 (0.84)\n",
492 | "OffenceParameterKey(label='Swansea City FC') : 0.82 -> 0.88 (1.08)\n",
493 | "DefenceParameterKey(label='Swansea City FC') : 1.21 -> 1.18 (0.97)\n",
494 | "OffenceParameterKey(label='Tottenham Hotspur FC') : 1.33 -> 1.34 (1.01)\n",
495 | "DefenceParameterKey(label='Tottenham Hotspur FC') : 0.84 -> 0.95 (1.12)\n",
496 | "OffenceParameterKey(label='Watford FC') : 0.78 -> 0.77 (0.99)\n",
497 | "DefenceParameterKey(label='Watford FC') : 1.17 -> 1.33 (1.14)\n",
498 | "OffenceParameterKey(label='West Bromwich Albion FC') : 0.66 -> 0.60 (0.91)\n",
499 | "DefenceParameterKey(label='West Bromwich Albion FC') : 1.10 -> 1.04 (0.94)\n",
500 | "OffenceParameterKey(label='West Ham United FC') : 1.27 -> 1.33 (1.04)\n",
501 | "DefenceParameterKey(label='West Ham United FC') : 1.22 -> 1.33 (1.09)\n"
502 | ]
503 | }
504 | ],
505 | "source": [
506 | "for k in sorted(param_keys, key=lambda x: x.label):\n",
507 | " key_str = str(k).ljust(param_key_len + 1)\n",
508 | " model_param = np.exp(model.params[k])\n",
509 | " model_exp_param = np.exp(model_exp.params[k])\n",
510 | " print(f'{key_str}: {model_param:0.2f} -> {model_exp_param:0.2f} ({model_exp_param/model_param:0.2f})')"
511 | ]
512 | },
513 | {
514 | "cell_type": "code",
515 | "execution_count": null,
516 | "metadata": {},
517 | "outputs": [],
518 | "source": []
519 | }
520 | ],
521 | "metadata": {
522 | "kernelspec": {
523 | "display_name": "Python 3",
524 | "language": "python",
525 | "name": "python3"
526 | }
527 | },
528 | "nbformat": 4,
529 | "nbformat_minor": 2
530 | }
531 |
--------------------------------------------------------------------------------
/docs/index.html:
--------------------------------------------------------------------------------
1 | ---
2 |
3 | title: Mezzala
4 |
5 |
6 | keywords: fastai
7 | sidebar: home_sidebar
8 |
9 | summary: "Models for estimating football (soccer) team-strength"
10 | description: "Models for estimating football (soccer) team-strength"
11 | nb_path: "nbs/index.ipynb"
12 | ---
13 |
22 |
23 |
24 |
25 | {% raw %}
26 |
27 |
28 |
29 |
30 | {% endraw %}
31 |
32 |
33 |
34 |
Install
35 |
36 |
37 |
38 |
39 |
40 |
pip install mezzala
41 |
42 |
43 |
44 |
45 |
46 |
47 |
How to use
48 |
49 |
50 |
51 | {% raw %}
52 |
53 |
54 |
55 |
56 |
57 |
58 |
importmezzala
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 | {% endraw %}
67 |
68 |
69 |
70 |
Fitting a Dixon-Coles team strength model:
71 |
72 |
73 |
74 |
75 |
76 |
77 |
First, we need to get some data
78 |
79 |
80 |
81 |
82 | {% raw %}
83 |
84 |
85 |
86 |
87 |
88 |
89 |
importitertools
90 | importjson
91 | importurllib.request
92 |
93 |
94 | # Use 2016/17 Premier League data from the openfootball repo
95 | url='https://raw.githubusercontent.com/openfootball/football.json/master/2016-17/en.1.json'
96 |
97 |
98 | response=urllib.request.urlopen(url)
99 | data_raw=json.loads(response.read())
100 |
101 | # Reshape the data to just get the matches
102 | data=list(itertools.chain(*[d['matches']fordindata_raw['rounds']]))
103 |
104 | data[0:3]
105 |
To fit a model with mezzala, you need to create an "adapter". Adapters are used to connect a model to a data source.
150 |
Because our data is a list of dicts, we are going to use a KeyAdapter.
151 |
152 |
153 |
154 |
155 | {% raw %}
156 |
157 |
158 |
159 |
160 |
161 |
162 |
adapter=mezzala.KeyAdapter(# `KeyAdapter` = datum['...']
163 | home_team='team1',
164 | away_team='team2',
165 | home_goals=['score','ft',0],# Get nested fields with lists of fields
166 | away_goals=['score','ft',1],# i.e. datum['score']['ft'][1]
167 | )
168 |
169 | # You'll never need to call the methods on an
170 | # adapter directly, but just to show that it
171 | # works as expected:
172 | adapter.home_team(data[0])
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
'Hull City AFC'
188 |
189 |
190 |
191 |
192 |
193 |
194 |
195 |
196 | {% endraw %}
197 |
198 |
199 |
200 |
Once we have an adapter for our specific data source, we can fit the model:
Each of these methods return predictions in the form of ScorelinePredictions.
304 |
305 |
predict_one returns a list of ScorelinePredictions
306 |
predict returns a list of ScorelinePredictions for each predicted match (i.e. a list of lists)
307 |
308 |
However, it can sometimes be more useful to have predictions in the form of match outcomes. Mezzala exposes the scorelines_to_outcomes function for this purpose:
It's possible to fit more sophisticated models with mezzala, using weights and model blocks
351 |
Weights
You can weight individual data points by supplying a function (or callable) to the weight argument to DixonColes:
352 |
353 |
354 |
355 |
356 | {% raw %}
357 |
358 |
359 |
360 |
361 |
362 |
363 |
mezzala.DixonColes(
364 | adapter=adapter,
365 | # By default, all data points are weighted equally,
366 | # which is equivalent to:
367 | weight=lambdax:1
368 | )
369 |
Anything subscriptable can be with this type of adapter. For example,
98 | you might have input data as a list of tuples (e.g. using Python's
99 | in-built csv library)
example_lumped_data=[
371 | *([example_dict]*4),# i.e., 'Team 1' and 'Team 2' appear in the data 4 times
372 | {'away':'Team 1',# 'Team 1' now appears an additional time, (5 total)
373 | # Although this time appears as an *away* team
374 | 'home':'Team 3',# While 'Team 3' appears once
375 | 'hg':4,
376 | 'ag':3},
377 | ]
378 |
379 |
380 | lumped_dict_adapter=LumpedAdapter(
381 | base_adapter=dict_adapter,
382 | home_team=('Other team',5),# Because `home_team` and `away_team` share the same
383 | # placeholder value ('Other team'), they are counted
384 | # together. I.e. a team has to appear at least 5 times
385 | # as _either_ the home team, or the away team
386 | away_team=('Other team',5)
387 | )
388 | lumped_dict_adapter.fit(example_lumped_data)
389 |
390 | lumped_dict_adapter
391 |
example_lumped_1={
424 | 'home':'Team 1',
425 | 'away':'Team 3',
426 | 'hg':1,
427 | 'ag':2
428 | }
429 |
430 | # A team with more than the minimum number of observations appears as before
431 | assertlumped_dict_adapter.home_team(example_lumped_1)=='Team 1'
432 |
433 | # But a team with fewer observations appears as the placeholder
434 | assertlumped_dict_adapter.away_team(example_lumped_1)=='Other team'
435 |
436 | # Meanwhile, values without a placeholder in the LumpedAdapter
437 | # also appear as before
438 | assertlumped_dict_adapter.home_goals(example_lumped_1)==1
439 | assertlumped_dict_adapter.away_goals(example_lumped_1)==2
440 |
441 |
442 |
443 |
444 |
445 |
446 |
447 | {% endraw %}
448 |
449 |
450 |
451 |
Using a lumped adapter can also allow you to handle items which didn't appear in the training set at all:
452 |
453 |
454 |
455 |
456 | {% raw %}
457 |
458 |
459 |
460 |
461 |
462 |
463 |
example_lumped_2={
464 | 'home':'Team 2',# Only appeared 4 times, below threshold of 5
465 | 'away':'Team 4',# Appeared 0 times in the data
466 | 'hg':1,
467 | 'ag':2
468 | }
469 |
470 | assertlumped_dict_adapter.home_team(example_lumped_2)=='Other team'
471 | assertlumped_dict_adapter.away_team(example_lumped_2)=='Other team'
472 |