├── .github
│   ├── CODEOWNERS
│   └── CODE_OF_CONDUCT.md
├── .gitignore
├── .travis.yml
├── LICENSE
├── README.rst
├── examples
│   └── README.md
├── logo
│   └── metriks-logo.svg
├── metriks
│   ├── __init__.py
│   └── ranking.py
├── setup.py
├── tests
│   ├── __init__.py
│   ├── test_confusion_matrix_at_k.py
│   ├── test_label_mean_reciprocal_rank.py
│   ├── test_mean_reciprocal_rank.py
│   ├── test_ndcg.py
│   ├── test_precision_at_k.py
│   └── test_recall_at_k.py
└── tox.ini
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # List of source code paths and code owners
2 | # common services & repos
3 | * @wontonswaggie @karenclo @jonathanlunt @crystal-chung @seedambiereed @rdedhia
--------------------------------------------------------------------------------
/.github/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | Open source projects are “living.” Contributions in the form of issues and pull requests are welcomed and encouraged. When you contribute, you affirm that you are part of the community and agree to abide by its Code of Conduct.
2 |
3 | # The Code
4 |
5 | At Intuit, we foster a kind, respectful, harassment-free cooperative community. Our open source community works to:
6 |
7 | - Be kind and respectful;
8 | - Act as a global community;
9 | - Conduct ourselves professionally.
10 |
11 | As members of this community, we will not tolerate behaviors including, but not limited to:
12 |
13 | - Violent threats or language;
14 | - Discriminatory or derogatory jokes or language;
15 | - Public or private harassment of any kind;
16 | - Other conduct considered inappropriate in a professional setting.
17 |
18 | ## Reporting Concerns
19 |
20 | If you see someone violating the Code of Conduct, please email TechOpenSource@intuit.com.
21 |
22 | ## Scope
23 |
24 | This code of conduct applies to:
25 |
26 | - All repos and communities for Intuit-managed projects, whether or not the text is included in an Intuit-managed project’s repository;
27 |
28 | - Individuals or teams representing projects in an official capacity, such as via official social media channels or at in-person meetups.
29 |
30 | ## Attribution
31 |
32 | This Code of Conduct is partly inspired by and based on those of Amazon, CocoaPods, GitHub, Microsoft, thoughtbot, and on the Contributor Covenant version 1.4.1.
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .pytest_cache
2 | # Mac
3 | .DS_Store
4 |
5 | ### Python ###
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | env/
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | wheelhouse/
32 | cover/
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *,cover
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 |
61 | # Sphinx documentation
62 | docs/_build/
63 |
64 | # PyBuilder
65 | target/
66 |
67 |
68 | ### PyCharm ###
69 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio
70 |
71 | *.iml
72 |
73 | ## Directory-based project format:
74 | .idea/
75 | # if you remove the above rule, at least ignore the following:
76 |
77 | # User-specific stuff:
78 | # .idea/workspace.xml
79 | # .idea/tasks.xml
80 | # .idea/dictionaries
81 |
82 | # Sensitive or high-churn files:
83 | # .idea/dataSources.ids
84 | # .idea/dataSources.xml
85 | # .idea/sqlDataSources.xml
86 | # .idea/dynamic.xml
87 | # .idea/uiDesigner.xml
88 |
89 | # Gradle:
90 | # .idea/gradle.xml
91 | # .idea/libraries
92 |
93 | # Mongo Explorer plugin:
94 | # .idea/mongoSettings.xml
95 |
96 | ## File-based project format:
97 | *.ipr
98 | *.iws
99 |
100 | ## Plugin-specific files:
101 |
102 | # IntelliJ
103 | /out/
104 |
105 | # mpeltonen/sbt-idea plugin
106 | .idea_modules/
107 |
108 | # JIRA plugin
109 | atlassian-ide-plugin.xml
110 |
111 | # Crashlytics plugin (for Android Studio and IntelliJ)
112 | com_crashlytics_export_strings.xml
113 | crashlytics.properties
114 | crashlytics-build.properties
115 |
116 | # VS Code and virtual env
117 | .vscode/*
118 | venv/*
119 | __version.py
120 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | python:
3 | - "3.6"
4 | install:
5 | - pip install pytest-cov
6 | - pip install coveralls
7 | - pip install .
8 | script: pytest --cov=metriks
9 | after_success:
10 | - coveralls
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2019 Intuit Inc.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of
6 | this software and associated documentation files (the "Software"), to deal in
7 | the Software without restriction, including without limitation the rights to use,
8 | copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
9 | Software, and to permit persons to whom the Software is furnished to do so,
10 | subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
16 | INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
17 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
18 | HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
19 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
20 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | .. image:: logo/metriks-logo.svg
2 |
3 | |python| |build| |coverage|
4 |
5 | .. |python| image:: https://img.shields.io/badge/python-3.6%20-blue.svg
6 | :target: https://www.python.org/downloads/release/python-360/
7 | :alt: Python Version
8 |
9 | .. |build| image:: https://travis-ci.com/intuit/metriks.svg?branch=master
10 | :target: https://travis-ci.com/intuit/metriks
11 |
12 | .. |coverage| image:: https://coveralls.io/repos/github/intuit/metriks/badge.svg?branch=master
13 | :target: https://coveralls.io/github/intuit/metriks?branch=master
14 |
15 | -----
16 |
17 | metriks is a Python package of commonly used metrics for evaluating information retrieval models.
18 |
19 | Available Metrics
20 | ---------------------------
21 | +------------------------------------------------------------+-------------------------------------------------------------------------------+
22 | | Python API | Description |
23 | +============================================================+===============================================================================+
24 | | `metriks.recall_at_k(y_true, y_prob, k)` | Calculates recall at k for binary classification ranking problems. |
25 | +------------------------------------------------------------+-------------------------------------------------------------------------------+
26 | | `metriks.precision_at_k(y_true, y_prob, k)` | Calculates precision at k for binary classification ranking problems. |
27 | +------------------------------------------------------------+-------------------------------------------------------------------------------+
28 | | `metriks.mean_reciprocal_rank(y_true, y_prob)` | Gets a positional score on how well you did at rank 1, rank 2, etc. |
29 | +------------------------------------------------------------+-------------------------------------------------------------------------------+
30 | | `metriks.ndcg(y_true, y_prob, k)` | A score for measuring the quality of a set of ranked results. |
31 | +------------------------------------------------------------+-------------------------------------------------------------------------------+
32 | | `metriks.label_mean_reciprocal_rank(y_true, y_prob)` | Determines the average rank each label was placed across samples. Only labels |
33 | | | that are relevant in the true data set are considered in the calculation. |
34 | +------------------------------------------------------------+-------------------------------------------------------------------------------+
35 | | `metriks.confusion_matrix_at_k(y_true, y_prob, k)` | Generates binary predictions from probabilities by evaluating the top k |
36 | | | items (in ranked order by y_prob) as true. |
37 | +------------------------------------------------------------+-------------------------------------------------------------------------------+
38 |
39 | Installation
40 | ------------
41 | Install using ``pip``
42 | ::
43 |
44 | pip install metriks
45 |
46 | Alternatively, specific distributions can be downloaded from the
47 | GitHub releases page. Once downloaded, install the ``.tar.gz`` file
48 | directly:
49 | ::
50 |
51 | pip install metriks-\*.tar.gz
52 |
53 | Development
54 | -----------
55 | 1. (*Optional*) If you have ``virtualenv`` and ``virtualenvwrapper``, create a new virtual environment:
56 | ::
57 |
58 | mkvirtualenv metriks
59 |
60 | This isolates your specific project dependencies to avoid conflicts
61 | with other projects.
62 |
63 | 2. Clone and install the repository:
64 | ::
65 |
66 | git clone git@github.com:intuit/metriks.git
67 | cd metriks
68 | pip install -e .
69 |
70 |
71 | This will install the package into an isolated environment in editable
72 | mode. As you update the code in the repository, the new code will
73 | immediately be available to run within the environment (without the
74 | need to ``pip install`` it again).
75 |
76 | 3. Run the tests using `tox`:
77 | ::
78 |
79 | pip install tox
80 | tox
81 |
82 | Tox will run all of the tests in isolated environments.
83 |
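84 | Usage
85 | -----
86 | A minimal sketch of calling the package; the toy arrays below are the same
87 | example used in the docstrings of ``metriks/ranking.py``, and the expected
88 | values are taken from there:
89 | ::
90 |
91 |     import numpy as np
92 |     import metriks
93 |
94 |     y_true = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]])
95 |     y_prob = np.array([[0.4, 0.6, 0.3], [0.1, 0.2, 0.9], [0.9, 0.6, 0.3]])
96 |
97 |     metriks.recall_at_k(y_true, y_prob, k=2)     # 0.6666666666666666
98 |     metriks.precision_at_k(y_true, y_prob, k=2)  # 0.3333333333333333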
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Example Jupyter Notebooks
3 |
4 | This folder will contain examples of how metriks can be used to calculate precision and recall for ranking models.
5 | Please create subfolders for each example depending on your dataset.
6 |
7 | Each example walks through the following steps:
8 |
9 | 1. Choose a default dataset.
10 | 2. Train a ranking model on the chosen dataset.
11 | 3. Demonstrate how the metrics below can be used on that dataset (see the snippet after the table):
12 |
13 | | Python API                                  | Description                                                            |
14 | |---------------------------------------------|------------------------------------------------------------------------|
15 | | `metriks.recall_at_k(y_true, y_prob, k)`    | Calculates recall at k for binary classification ranking problems.     |
16 | | `metriks.precision_at_k(y_true, y_prob, k)` | Calculates precision at k for binary classification ranking problems.  |
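17 |
18 | As a rough sketch of step 3, with placeholder data (a real notebook would use its own relevance flags and model scores):
19 |
20 | ```python
21 | import numpy as np
22 | import metriks
23 |
24 | # Placeholder relevance flags and model scores; both arrays must have shape (n_samples, n_items)
25 | y_test = np.array([[0, 1, 1], [1, 0, 1]])
26 | y_prob = np.array([[0.2, 0.8, 0.6], [0.7, 0.1, 0.4]])
27 |
28 | print(metriks.recall_at_k(y_test, y_prob, k=2))
29 | print(metriks.precision_at_k(y_test, y_prob, k=2))
30 | ```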
--------------------------------------------------------------------------------
/logo/metriks-logo.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/metriks/__init__.py:
--------------------------------------------------------------------------------
1 | try:
2 | from metriks.__version import __version__
3 | except ImportError: # pragma: no cover
4 | __version__ = "dev"
5 |
6 | from metriks.ranking import (
7 | recall_at_k,
8 | precision_at_k,
9 | mean_reciprocal_rank,
10 | label_mean_reciprocal_rank,
11 | ndcg,
12 | confusion_matrix_at_k,
13 | )
14 |
--------------------------------------------------------------------------------
/metriks/ranking.py:
--------------------------------------------------------------------------------
1 | """
2 | Ranking
3 | =======
4 | Metrics to use for ranking models.
5 | """
6 |
7 | import numpy as np
8 |
9 |
10 | def check_arrays(y_true, y_prob):
11 | # Make sure that the inputs conform to our expectations
12 | assert isinstance(y_true, np.ndarray), AssertionError(
13 | 'Expect y_true to be a {expected}. Got {actual}'
14 | .format(expected=np.ndarray, actual=type(y_true))
15 | )
16 |
17 | assert isinstance(y_prob, np.ndarray), AssertionError(
18 | 'Expect y_prob to be a {expected}. Got {actual}'
19 | .format(expected=np.ndarray, actual=type(y_prob))
20 | )
21 |
22 | assert y_true.shape == y_prob.shape, AssertionError(
23 | 'Shapes must match. Got y_true={true_shape}, y_prob={prob_shape}'
24 | .format(true_shape=y_true.shape, prob_shape=y_prob.shape)
25 | )
26 |
27 | assert len(y_true.shape) == 2, AssertionError(
28 | 'Shapes should be of rank 2. Got {rank}'
29 | .format(rank=len(y_true.shape))
30 | )
31 |
32 | uniques = np.unique(y_true)
33 | assert np.isin(uniques, [0, 1]).all(), AssertionError(
34 | 'Expected labels: [0, 1]. Got: {uniques}'
35 | .format(uniques=uniques)
36 | )
37 |
38 |
39 | def check_k(n_items, k):
40 | # Make sure that inputs conform to our expectations
41 | assert isinstance(k, int), AssertionError(
42 | 'Expect k to be a {expected}. Got {actual}'
43 | .format(expected=int, actual=type(k))
44 | )
45 |
46 | assert 0 <= k <= n_items, AssertionError(
47 | 'Expect 0 <= k <= {n_items}. Got {k}'
48 | .format(n_items=n_items, k=k)
49 | )
50 |
51 |
52 | def recall_at_k(y_true, y_prob, k):
53 | """
54 | Calculates recall at k for binary classification ranking problems. Recall
55 | at k measures the proportion of total relevant items that are found in the
56 | top k (in ranked order by y_prob). If k=5, there are 6 total relevant documents,
57 | and 3 of the top 5 items are relevant, then the recall at k will be 0.5.
58 |
59 | Samples where y_true is 0 for all labels are filtered out because they have
60 | no true positives or false negatives, which would leave recall undefined.
61 |
62 | Args:
63 | y_true (~np.ndarray): Flags (0, 1) which indicate whether a column is
64 | relevant or not. size=(n_samples, n_items)
65 | y_prob (~np.ndarray): The predicted probability that the given flag
66 | is relevant. size=(n_samples, n_items)
67 | k (int): Number of items to evaluate for relevancy, in descending
68 | sorted order by y_prob
69 |
70 | Returns:
71 | recall (~np.ndarray): The recall at k
72 |
73 | Example:
74 | >>> y_true = np.array([
75 | [0, 0, 1],
76 | [0, 1, 0],
77 | [1, 0, 0],
78 | ])
79 | >>> y_prob = np.array([
80 | [0.4, 0.6, 0.3],
81 | [0.1, 0.2, 0.9],
82 | [0.9, 0.6, 0.3],
83 | ])
84 | >>> recall_at_k(y_true, y_prob, 2)
85 | 0.6666666666666666
86 |
87 | In the example above, each of the samples has 1 total relevant document.
88 | For the first sample, there are 0 relevant documents in the top k for k=2,
89 | because 0.3 is the 3rd value for y_prob in descending order. For the second
90 | sample, there is 1 relevant document in the top k, because 0.2 is the 2nd
91 | value for y_prob in descending order. For the third sample, there is 1
92 | relevant document in the top k, because 0.9 is the 1st value for y_prob in
93 | descending order. Averaging the values for all of these samples (0, 1, 1)
94 | gives a value for recall at k of 2/3.
95 | """
96 | check_arrays(y_true, y_prob)
97 | check_k(y_true.shape[1], k)
98 |
99 | # Filter out rows of all zeros
100 | mask = y_true.sum(axis=1).astype(bool)
101 | y_prob = y_prob[mask]
102 | y_true = y_true[mask]
103 |
104 | # Extract shape components
105 | n_samples, n_items = y_true.shape
106 |
107 | # List of locations indexing
108 | y_prob_index_order = np.argsort(-y_prob)
109 | rows = np.reshape(np.arange(n_samples), (-1, 1))
110 | ranking = y_true[rows, y_prob_index_order]
111 |
112 | # Calculate number true positives for numerator and number of relevant documents for denominator
113 | num_tp = np.sum(ranking[:, :k], axis=1)
114 | num_relevant = np.sum(ranking, axis=1)
115 | # Calculate recall at k
116 | recall = np.mean(num_tp / num_relevant)
117 |
118 | return recall
119 |
120 |
121 | def precision_at_k(y_true, y_prob, k):
122 | """
123 | Calculates precision at k for binary classification ranking problems.
124 | Precision at k measures the proportion of items in the top k (in ranked
125 | order by y_prob) that are relevant (as defined by y_true). If k=5, and
126 | 3 of the top 5 items are relevant, the precision at k will be 0.6.
127 |
128 | Args:
129 | y_true (~np.ndarray): Flags (0, 1) which indicate whether a column is
130 | relevant or not. size=(n_samples, n_items)
131 | y_prob (~np.ndarray): The predicted probability that the given flag
132 | is relevant. size=(n_samples, n_items)
133 | k (int): Number of items to evaluate for relevancy, in descending
134 | sorted order by y_prob
135 |
136 | Returns:
137 | precision_k (~np.ndarray): The precision at k
138 |
139 | Example:
140 | >>> y_true = np.array([
141 | [0, 0, 1],
142 | [0, 1, 0],
143 | [1, 0, 0],
144 | ])
145 | >>> y_prob = np.array([
146 | [0.4, 0.6, 0.3],
147 | [0.1, 0.2, 0.9],
148 | [0.9, 0.6, 0.3],
149 | ])
150 | >>> precision_at_k(y_true, y_prob, 2)
151 | 0.3333333333333333
152 |
153 | For the first sample, there are 0 relevant documents in the top k for k=2,
154 | because 0.3 is the 3rd value for y_prob in descending order. For the second
155 | sample, there is 1 relevant document in the top k, because 0.2 is the 2nd
156 | value for y_prob in descending order. For the third sample, there is 1
157 | relevant document in the top k, because 0.9 is the 1st value for y_prob in
158 | descending order. Because k=2, the values for precision of k for each sample
159 | are 0, 1/2, and 1/2 respectively. Averaging these gives a value for precision
160 | at k of 1/3.
161 | """
162 | check_arrays(y_true, y_prob)
163 | check_k(y_true.shape[1], k)
164 |
165 | # Extract shape components
166 | n_samples, n_items = y_true.shape
167 |
168 | # List of locations indexing
169 | y_prob_index_order = np.argsort(-y_prob)
170 | rows = np.reshape(np.arange(n_samples), (-1, 1))
171 | ranking = y_true[rows, y_prob_index_order]
172 |
173 | # Calculate number of true positives for numerator
174 | num_tp = np.sum(ranking[:, :k], axis=1)
175 | # Calculate precision at k
176 | precision = np.mean(num_tp / k)
177 |
178 | return precision
179 |
180 |
181 | def mean_reciprocal_rank(y_true, y_prob):
182 | """
183 | Gets a positional score about how well you did at rank 1, rank 2,
184 | etc. The resulting vector is of size (n_items,) but element 0 corresponds to
185 | rank 1 not label 0.
186 |
187 | Args:
188 | y_true (~np.ndarray): Flags (0, 1) which indicate whether a column is
189 | relevant or not. size=(n_samples, n_items)
190 | y_prob (~np.ndarray): The predicted probability that the given flag
191 | is relevant. size=(n_samples, n_items)
192 |
193 | Returns:
194 | mrr (~np.ma.array): The positional ranking score. This will be masked
195 | for ranks where there were no relevant values. size=(n_items,)
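
    Example (a small illustration; the arrays and the expected value mirror the
    Wikipedia example in tests/test_mean_reciprocal_rank.py):
        >>> y_true = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]])
        >>> y_prob = np.array([[0.4, 0.6, 0.3], [0.1, 0.2, 0.9], [0.9, 0.6, 0.3]])
        >>> mean_reciprocal_rank(y_true, y_prob)  # masked array [11/18, --, --]

        Each sample has exactly one relevant label, so only element 0 (rank 1) is
        unmasked; its value is the mean of the reciprocal ranks 1/3, 1/2, and 1.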
196 | """
197 |
198 | check_arrays(y_true, y_prob)
199 |
200 | # Extract shape components
201 | n_samples, n_items = y_true.shape
202 |
203 | # Determine the ranking order
204 | rank_true = np.flip(np.argsort(y_true, axis=1), axis=1)
205 | rank_prob = np.flip(np.argsort(y_prob, axis=1), axis=1)
206 |
207 | # Compute reciprocal ranks
208 | reciprocal = 1.0 / (np.argsort(rank_prob, axis=1) + 1)
209 |
210 | # Now order the reciprocal ranks by the true order
211 | rows = np.reshape(np.arange(n_samples), (-1, 1))
212 | cols = rank_true
213 | ordered = reciprocal[rows, cols]
214 |
215 | # Create a masked array of true labels only
216 | ma = np.ma.array(ordered, mask=np.isclose(y_true[rows, cols], 0))
217 | return ma.mean(axis=0)
218 |
219 |
220 | def label_mean_reciprocal_rank(y_true, y_prob):
221 | """
222 | Determines the average rank each label was placed across samples. Only labels that are
223 | relevant in the true data set are considered in the calculation.
224 |
225 | Args:
226 | y_true (~np.ndarray): Flags (0, 1) which indicate whether a column is
227 | relevant or not. size=(n_samples, n_items)
228 | y_prob (~np.ndarray): The predicted probability that the given flag
229 | is relevant. size=(n_samples, n_items)
230 | Returns:
231 | mrr (~np.ma.array): The mean reciprocal rank of each label. This will be masked
232 | for labels that are never relevant in y_true. size=(n_items,)
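
    Example (values mirror test_label_mrr_ones in tests/test_label_mean_reciprocal_rank.py):
        >>> y_true = np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]])
        >>> y_prob = np.array([[0.3, 0.7, 0.0], [0.1, 0.9, 0.0], [0.2, 0.5, 0.1]])
        >>> label_mean_reciprocal_rank(y_true, y_prob)  # [0.5, 1.0, 0.333...]

        The middle label is always ranked first (reciprocal rank 1), the first
        label second (1/2), and the last label third (1/3); every label is
        relevant in every sample, so no element is masked.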
233 | """
234 |
235 | check_arrays(y_true, y_prob)
236 |
237 | rank_prob = np.flip(np.argsort(y_prob, axis=1), axis=1)
238 | reciprocal = 1 / (np.argsort(rank_prob, axis=1) + 1)
239 | ma = np.ma.array(reciprocal, mask=~y_true.astype(bool))
240 |
241 | return ma.mean(axis=0)
242 |
243 |
244 | def ndcg(y_true, y_prob, k=0):
245 | """
246 | A score for measuring the quality of a set of ranked results. The resulting score is between 0.0 and 1.0.
247 | Results that are relevant and appear earlier in the result set are given a heavier weight, so the
248 | higher the score, the more relevant your results are.
249 |
250 | The optional k param is recommended for data sets where the first few labels are almost always ranked first,
251 | and hence skew the overall score. To compute this "NDCG after k" metric, we remove the top k (by predicted
252 | probability) labels and compute NDCG as usual for the remaining labels.
253 |
254 | Args:
255 | y_true (~np.ndarray): Flags (0, 1) which indicate whether a column is
256 | relevant or not. size=(n_samples, n_items)
257 | y_prob (~np.ndarray): The predicted probability that the given flag
258 | is relevant. size=(n_samples, n_items)
259 | k (int): Optional, the top k classes to exclude
260 | Returns:
261 | ndcg (~np.float64): The normalized dcg score across all queries, excluding the top k
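
    Example (values mirror test_ndcg_perfect in tests/test_ndcg.py):
        >>> y_true = np.array([[0, 1, 1, 0], [0, 1, 1, 1], [1, 1, 1, 0]])
        >>> y_prob = np.array([[0.2, 0.9, 0.8, 0.4], [0.4, 0.8, 0.7, 0.5], [0.9, 0.6, 0.7, 0.2]])
        >>> ndcg(y_true, y_prob)
        1.0

        Every sample ranks all of its relevant labels ahead of its irrelevant
        ones, so DCG equals IDCG for each sample and the mean NDCG is 1.0.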
262 | """
263 | # Get the sorted prob indices in descending order
264 | rank_prob = np.flip(np.argsort(y_prob, axis=1), axis=1)
265 |
266 | # Get the sorted true indices in descending order
267 | rank_true = np.flip(np.argsort(y_true, axis=1), axis=1)
268 |
269 | prob_samples, prob_items = y_prob.shape
270 | true_samples, true_items = y_true.shape
271 |
272 | # Compute DCG
273 |
274 | # Order y_true and y_prob by y_prob order indices
275 | prob_vals = y_prob[np.arange(prob_samples).reshape(prob_samples, 1), rank_prob]
276 | true_vals = y_true[np.arange(true_samples).reshape(true_samples, 1), rank_prob]
277 |
278 | # Remove the first k columns
279 | prob_vals = prob_vals[:, k:]
280 | true_vals = true_vals[:, k:]
281 |
282 | rank_prob_k = np.flip(np.argsort(prob_vals, axis=1), axis=1)
283 |
284 | n_samples, n_items = true_vals.shape
285 |
286 | values = np.arange(n_samples).reshape(n_samples, 1)
287 |
288 | # Construct the dcg numerator, which are the relevant items for each rank
289 | dcg_numerator = true_vals[values, rank_prob_k]
290 |
291 | # Construct the denominator, which is the log2 of the current rank + 1
292 | position = np.arange(1, n_items + 1)
293 | denominator = np.log2(np.tile(position, (n_samples, 1)) + 1.0)
294 |
295 | dcg = np.sum(dcg_numerator / denominator, axis=1)
296 |
297 | # Compute IDCG
298 | rank_true_idcg = np.flip(np.argsort(true_vals, axis=1), axis=1)
299 |
300 | idcg_true_samples, idcg_true_items = rank_true_idcg.shape
301 |
302 | # Order y_true indices
303 | idcg_true_vals = true_vals[np.arange(idcg_true_samples).reshape(idcg_true_samples, 1), rank_true_idcg]
304 |
305 | rank_true_k = np.flip(np.argsort(idcg_true_vals, axis=1), axis=1)
306 |
307 | idcg_numerator = idcg_true_vals[values, rank_true_k]
308 |
309 | idcg = np.sum(idcg_numerator / denominator, axis=1)
310 |
311 | # Ignore divide-by-zero and invalid-value warnings; a zero idcg is handled explicitly below
312 | with np.errstate(divide='ignore', invalid='ignore'):
313 | sample_ndcg = np.divide(dcg, idcg)
314 |
315 | # ndcg may be NaN if idcg is 0; this happens when there are no relevant documents
316 | # in this case, showing anything in any order should be considered correct
317 | where_are_nans = np.isnan(sample_ndcg)
318 | where_are_infs = np.isinf(sample_ndcg)
319 | sample_ndcg[where_are_nans] = 1.0
320 | sample_ndcg[where_are_infs] = 0.0
321 |
322 | return np.mean(sample_ndcg, dtype=np.float64)
323 |
324 |
325 | def generate_y_pred_at_k(y_prob, k):
326 | """
327 | Generates a matrix of binary predictions from a matrix of probabilities
328 | by evaluating the top k items (in ranked order by y_prob) as true.
329 |
330 | In the case where multiple probabilities for a sample are identical, the
331 | behavior is undefined in terms of how the probabilities are ranked by argsort.
332 |
333 | Args:
334 | y_prob (~np.ndarray): The predicted probability that the given flag
335 | is relevant. size=(n_samples, n_items)
336 | k (int): Number of items to evaluate as true, in descending
337 | sorted order by y_prob
338 |
339 | Returns:
340 | y_pred (~np.ndarray): A binary prediction that the given flag is
341 | relevant. size=(n_samples, n_items)
342 |
343 | Example:
344 | >>> y_prob = np.array([
345 | [0.4, 0.6, 0.3],
346 | [0.1, 0.2, 0.9],
347 | [0.9, 0.6, 0.3],
348 | ])
349 | >>> generate_y_pred_at_k(y_prob, 2)
350 | array([
351 | [1, 1, 0],
352 | [0, 1, 1],
353 | [1, 1, 0]
354 | ])
355 |
356 | For the first sample, the top 2 values for y_prob are 0.6 and 0.4, so y_pred
357 | at those positions is 1. For the second sample, the top 2 values for y_prob
358 | are 0.9 and 0.2, so y_pred at these positions is 1. For the third sample, the
359 | top 2 values for y_prob are 0.9 and 0.6, so y_pred at these positions is 1.
360 | """
361 | n_items = y_prob.shape[1]
362 | index_array = np.argsort(y_prob, axis=1)
363 | row_idx = np.arange(y_prob.shape[0]).reshape(-1, 1)
364 | y_pred = np.zeros(np.shape(y_prob))
365 | y_pred[row_idx, index_array[:, n_items-k:n_items]] = 1
366 | return y_pred
367 |
368 |
369 | def confusion_matrix_at_k(y_true, y_prob, k):
370 | """
371 | Generates binary predictions from probabilities by evaluating the top k items
372 | (in ranked order by y_prob) as true. Uses these binary predictions along with
373 | true flags to calculate the confusion matrix per label for binary
374 | classification problems.
375 |
376 | Args:
377 | y_true (~np.ndarray): Flags (0, 1) which indicate whether a column is
378 | relevant or not. size=(n_samples, n_items)
379 | y_prob (~np.ndarray): The predicted probability that the given flag
380 | is relevant. size=(n_samples, n_items)
381 | k (int): Number of items to evaluate as true, in descending
382 | sorted order by y_prob
383 |
384 | Returns:
385 | tn, fp, fn, tp (tuple of ~np.ndarrays): A tuple of ndarrays containing
386 | the number of true negatives (tn), false positives (fp),
387 | false negatives (fn), and true positives (tp) for each item. The
388 | length of each ndarray is equal to n_items
389 |
390 | Example:
391 | >>> y_true = np.array([
392 | [0, 0, 1],
393 | [0, 1, 0],
394 | [1, 0, 0],
395 | ])
396 | >>> y_prob = np.array([
397 | [0.4, 0.6, 0.3],
398 | [0.1, 0.2, 0.9],
399 | [0.9, 0.6, 0.3],
400 | ])
401 | >>> y_pred = np.array([
402 | [1, 1, 0],
403 | [0, 1, 1],
404 | [1, 1, 0]
405 | ])
406 | >>> label_names = ['moved', 'hadAJob', 'farmIncome']
407 | >>> confusion_matrix_at_k(y_true, y_prob, 2)
408 | (
409 | np.array([1, 0, 1]),
410 | np.array([1, 2, 1]),
411 | np.array([0, 0, 1]),
412 | np.array([1, 1, 0])
413 | )
414 |
415 | In the example above, y_pred is not passed into the function, but is
416 | generated by calling generate_y_pred_at_k with y_prob and k.
417 |
418 | For the first item (moved), the first sample is a false positive, the
419 | second is a true negative, and the third is a true positive.
420 | For the second item (hadAJob), the first and third samples are false
421 | positives, and the second is a true positive.
422 | For the third item (farmIncome), the first item is a false negative, the
423 | second is a false positive, and the third is a true positive.
424 | """
425 | check_arrays(y_true, y_prob)
426 | check_k(y_true.shape[1], k)
427 |
428 | y_pred = generate_y_pred_at_k(y_prob, k)
429 |
430 | tp = np.count_nonzero(y_pred * y_true, axis=0)
431 | tn = np.count_nonzero((y_pred - 1) * (y_true - 1), axis=0)
432 | fp = np.count_nonzero(y_pred * (y_true - 1), axis=0)
433 | fn = np.count_nonzero((y_pred - 1) * y_true, axis=0)
434 |
435 | return tn, fp, fn, tp
436 |
437 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | with open('README.rst', encoding='utf-8') as fh:
4 | long_description = fh.read()
5 |
6 | setup(
7 | # Package information
8 | name='metriks',
9 | description='metriks is a Python package of commonly used metrics for evaluating information retrieval models.',
10 | long_description=long_description,
11 | long_description_content_type='text/x-rst',
12 | url='https://github.com/intuit/metriks',
13 | classifiers=[
14 | 'Programming Language :: Python :: 3',
15 | 'License :: OSI Approved :: MIT License'
16 | ],
17 | python_requires='>=3.6',
18 |
19 | # Package data
20 | use_scm_version={
21 | "write_to": "metriks/__version.py",
22 | "write_to_template": '__version__ = "{version}"\n',
23 | },
24 | packages=find_packages(exclude=('tests*',)),
25 |
26 | # Insert dependencies list here
27 | install_requires=[
28 | 'numpy',
29 | ],
30 | setup_requires=["setuptools-scm"],
31 | extras_require={
32 | 'dev': [
33 | 'pytest',
34 | 'tox',
35 | ]
36 | }
37 | )
38 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/intuit/metriks/a0e62df847b8997bfc32b337d7e41db16d0a8ce4/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_confusion_matrix_at_k.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import metriks
3 |
4 |
5 | def test_confusion_matrix_at_k_wikipedia():
6 | """
7 | Tests the wikipedia example.
8 |
9 | https://en.wikipedia.org/wiki/Mean_reciprocal_rank
10 | """
11 |
12 | # Flag indicating which is actually relevant
13 | y_true = np.array([
14 | [0, 0, 1],
15 | [0, 1, 0],
16 | [1, 0, 0],
17 | ])
18 |
19 | # The predicted probability
20 | y_prob = np.array([
21 | [0.4, 0.6, 0.3],
22 | [0.1, 0.2, 0.9],
23 | [0.9, 0.6, 0.3],
24 | ])
25 |
26 | expected = (
27 | np.array([1, 0, 1]),
28 | np.array([1, 2, 1]),
29 | np.array([0, 0, 1]),
30 | np.array([1, 1, 0])
31 | )
32 | result = metriks.confusion_matrix_at_k(y_true, y_prob, 2)
33 | for i in range(4):
34 | assert np.array_equal(expected[i], result[i]), AssertionError(
35 | 'Expected:\n{expected}. \nGot:\n{actual}'
36 | .format(expected=expected, actual=result)
37 | )
38 |
39 | print(result)
40 |
41 |
42 | def test_confusion_matrix_at_k_with_errors():
43 | """Test confusion_matrix_at_k where there are errors in the probabilities"""
44 | y_true = np.array([
45 | [1, 1, 0],
46 | [1, 0, 1],
47 | [1, 1, 0],
48 | [1, 0, 1],
49 | [0, 1, 0],
50 | [0, 1, 1],
51 | [0, 1, 1],
52 | [1, 0, 1],
53 | [1, 1, 1],
54 | ])
55 |
56 | y_prob = np.array([
57 | [0.7, 0.6, 0.3],
58 | [0.9, 0.2, 0.1], # 3rd probability is an error
59 | [0.7, 0.8, 0.9], # 3rd probability is an error
60 | [0.9, 0.8, 0.3], # 2nd and 3rd probability are swapped
61 | [0.4, 0.6, 0.3],
62 | [0.1, 0.6, 0.9],
63 | [0.1, 0.6, 0.9],
64 | [0.1, 0.6, 0.5],
65 | [0.9, 0.8, 0.7],
66 | ])
67 |
68 | # Check results
69 | expected = (
70 | np.array([2, 0, 2]),
71 | np.array([1, 3, 1]),
72 | np.array([2, 0, 3]),
73 | np.array([4, 6, 3])
74 | )
75 | result = metriks.confusion_matrix_at_k(y_true, y_prob, 2)
76 | for i in range(4):
77 | assert np.array_equal(expected[i], result[i]), AssertionError(
78 | 'Expected:\n{expected}. \nGot:\n{actual}'
79 | .format(expected=expected, actual=result)
80 | )
81 |
82 | print(result)
83 |
84 |
85 | def test_confusion_matrix_at_k_perfect():
86 | """Test confusion_matrix_at_k where the probabilities are perfect"""
87 | y_true = np.array([
88 | [0, 1, 0],
89 | [1, 0, 0],
90 | [0, 1, 0],
91 | [1, 0, 0],
92 | [0, 0, 1],
93 | ])
94 |
95 | y_prob = np.array([
96 | [0.3, 0.7, 0.0],
97 | [0.1, 0.0, 0.0],
98 | [0.1, 0.5, 0.3],
99 | [0.6, 0.2, 0.4],
100 | [0.1, 0.2, 0.3],
101 | ])
102 |
103 | # Check results
104 | expected = (
105 | np.array([3, 3, 4]),
106 | np.array([0, 0, 0]),
107 | np.array([0, 0, 0]),
108 | np.array([2, 2, 1])
109 | )
110 | result = metriks.confusion_matrix_at_k(y_true, y_prob, 1)
111 | for i in range(4):
112 | assert np.array_equal(expected[i], result[i]), AssertionError(
113 | 'Expected:\n{expected}. \nGot:\n{actual}'
114 | .format(expected=expected, actual=result)
115 | )
116 |
117 | print(result)
118 |
119 |
120 | def test_confusion_matrix_at_k_perfect_multiple_true():
121 | """Test confusion_matrix_at_k where the probabilities are perfect and there
122 | are multiple true labels for some samples"""
123 | y_true = np.array([
124 | [1, 1, 0],
125 | [1, 0, 0],
126 | [0, 1, 1],
127 | [1, 0, 1],
128 | [0, 1, 1],
129 | ])
130 |
131 | y_prob = np.array([
132 | [0.3, 0.7, 0.0],
133 | [0.1, 0.05, 0.0],
134 | [0.1, 0.5, 0.3],
135 | [0.6, 0.2, 0.4],
136 | [0.1, 0.2, 0.3],
137 | ])
138 |
139 | # Check results for k=0
140 | expected = (
141 | np.array([2, 2, 2]),
142 | np.array([0, 0, 0]),
143 | np.array([3, 3, 3]),
144 | np.array([0, 0, 0])
145 | )
146 | result = metriks.confusion_matrix_at_k(y_true, y_prob, 0)
147 | for i in range(4):
148 | assert np.array_equal(expected[i], result[i]), AssertionError(
149 | 'Expected:\n{expected}. \nGot:\n{actual}'
150 | .format(expected=expected, actual=result)
151 | )
152 |
153 | print(result)
154 |
155 | # Check results for k=2
156 | expected = (
157 | np.array([2, 1, 2]),
158 | np.array([0, 1, 0]),
159 | np.array([0, 0, 0]),
160 | np.array([3, 3, 3])
161 | )
162 | result = metriks.confusion_matrix_at_k(y_true, y_prob, 2)
163 | for i in range(4):
164 | assert np.array_equal(expected[i], result[i]), AssertionError(
165 | 'Expected:\n{expected}. \nGot:\n{actual}'
166 | .format(expected=expected, actual=result)
167 | )
168 |
169 | print(result)
170 |
171 |
172 | def test_confusion_matrix_at_k_few_zeros():
173 | """Test confusion_matrix_at_k where there are very few zeros"""
174 | y_true = np.array([
175 | [0, 1, 1, 1, 1],
176 | [1, 1, 1, 1, 1],
177 | [1, 1, 1, 1, 0]
178 | ])
179 |
180 | y_prob = np.array([
181 | [0.1, 0.4, 0.35, 0.8, 0.9],
182 | [0.3, 0.2, 0.7, 0.8, 0.6],
183 | [0.1, 0.2, 0.3, 0.4, 0.5]
184 | ])
185 |
186 | # Check results for k=2
187 | expected = (
188 | np.array([1, 0, 0, 0, 0]),
189 | np.array([0, 0, 0, 0, 1]),
190 | np.array([2, 3, 2, 0, 1]),
191 | np.array([0, 0, 1, 3, 1])
192 | )
193 | result = metriks.confusion_matrix_at_k(y_true, y_prob, 2)
194 | for i in range(4):
195 | assert np.array_equal(expected[i], result[i]), AssertionError(
196 | 'Expected:\n{expected}. \nGot:\n{actual}'
197 | .format(expected=expected, actual=result)
198 | )
199 |
200 | print(result)
201 |
202 | # Check results for k=5
203 | # Expected counts per label, for reference:
204 | #   'moved':      {'fp': 1, 'tn': 0, 'fn': 0, 'tp': 2}
205 | #   'hadAJob':    {'fp': 0, 'tn': 0, 'fn': 0, 'tp': 3}
206 | #   'farmIncome': {'fp': 0, 'tn': 0, 'fn': 0, 'tp': 3}
207 | #   'married':    {'fp': 0, 'tn': 0, 'fn': 0, 'tp': 3}
208 | #   'alimony':    {'fp': 1, 'tn': 0, 'fn': 0, 'tp': 2}
209 | # expressed below as (tn, fp, fn, tp) arrays:
210 | expected = (
211 | np.array([0, 0, 0, 0, 0]),
212 | np.array([1, 0, 0, 0, 1]),
213 | np.array([0, 0, 0, 0, 0]),
214 | np.array([2, 3, 3, 3, 2])
215 | )
216 | result = metriks.confusion_matrix_at_k(y_true, y_prob, 5)
217 | for i in range(4):
218 | assert np.array_equal(expected[i], result[i]), AssertionError(
219 | 'Expected:\n{expected}. \nGot:\n{actual}'
220 | .format(expected=expected, actual=result)
221 | )
222 |
223 | print(result)
224 |
225 |
226 | def test_confusion_matrix_at_k_zeros():
227 | """Test confusion_matrix_at_k where there is a sample of all zeros"""
228 | y_true = np.array([
229 | [0, 0, 1, 1],
230 | [1, 1, 1, 0],
231 | [0, 0, 0, 0]
232 | ])
233 |
234 | y_prob = np.array([
235 | [0.1, 0.4, 0.35, 0.8],
236 | [0.3, 0.2, 0.7, 0.8],
237 | [0.1, 0.2, 0.3, 0.4]
238 | ])
239 |
240 | # Check results
241 | # Expected counts per label, for reference:
242 | #   'moved':      {'fp': 0, 'tn': 2, 'fn': 1, 'tp': 0}
243 | #   'hadAJob':    {'fp': 1, 'tn': 1, 'fn': 1, 'tp': 0}
244 | #   'farmIncome': {'fp': 1, 'tn': 0, 'fn': 1, 'tp': 1}
245 | #   'married':    {'fp': 2, 'tn': 0, 'fn': 0, 'tp': 1}
246 | # expressed below as (tn, fp, fn, tp) arrays:
247 | expected = (
248 | np.array([2, 1, 0, 0]),
249 | np.array([0, 1, 1, 2]),
250 | np.array([1, 1, 1, 0]),
251 | np.array([0, 0, 1, 1])
252 | )
253 | result = metriks.confusion_matrix_at_k(y_true, y_prob, 2)
254 | for i in range(4):
255 | assert np.array_equal(expected[i], result[i]), AssertionError(
256 | 'Expected:\n{expected}. \nGot:\n{actual}'
257 | .format(expected=expected, actual=result)
258 | )
259 |
260 | print(result)
261 |
262 |
263 | if __name__ == '__main__':
264 | test_confusion_matrix_at_k_wikipedia()
265 | test_confusion_matrix_at_k_with_errors()
266 | test_confusion_matrix_at_k_perfect()
267 | test_confusion_matrix_at_k_perfect_multiple_true()
268 | test_confusion_matrix_at_k_few_zeros()
269 | test_confusion_matrix_at_k_zeros()
270 |
--------------------------------------------------------------------------------
/tests/test_label_mean_reciprocal_rank.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import metriks
4 |
5 |
6 | def test_label_mrr_with_errors():
7 | y_true = np.array([[0, 1, 1, 0],
8 | [1, 1, 1, 0],
9 | [0, 1, 1, 1]])
10 | y_prob = np.array([[0.2, 0.9, 0.8, 0.4],
11 | [0.9, 0.2, 0.8, 0.4],
12 | [0.2, 0.8, 0.9, 0.4]])
13 |
14 | result = metriks.label_mean_reciprocal_rank(y_true, y_prob)
15 |
16 | expected = np.ma.array([1.0, (7/4) / 3, 2/3, 1/3])
17 |
18 | np.testing.assert_allclose(result, expected)
19 | print(result)
20 |
21 |
22 | def test_label_mrr_perfect():
23 | """Test MRR where the probabilities are perfect"""
24 | y_true = np.array([
25 | [0, 1, 0],
26 | [1, 0, 0],
27 | [0, 1, 0],
28 | [1, 0, 0],
29 | [0, 0, 1],
30 | ])
31 |
32 | y_prob = np.array([
33 | [0.3, 0.7, 0.0],
34 | [0.1, 0.0, 0.0],
35 | [0.1, 0.5, 0.3],
36 | [0.6, 0.2, 0.4],
37 | [0.1, 0.2, 0.3],
38 | ])
39 |
40 | # Check results
41 | expected = np.ma.array([1.0, 0.0, 0.0], mask=[False, True, True])
42 | result = metriks.label_mean_reciprocal_rank(y_true, y_prob)
43 | np.testing.assert_allclose(result, expected)
44 |
45 | print(result)
46 |
47 |
48 | def test_label_mrr_zeros():
49 | """Test MRR where no relevant labels"""
50 | y_true = np.array([
51 | [0, 0, 0],
52 | [0, 0, 0],
53 | [0, 0, 0]
54 | ])
55 |
56 | y_prob = np.array([
57 | [0.3, 0.7, 0.0],
58 | [0.1, 0.0, 0.0],
59 | [0.1, 0.5, 0.3]
60 | ])
61 |
62 | # Check results
63 | result = metriks.label_mean_reciprocal_rank(y_true, y_prob)
64 | np.testing.assert_equal(result.mask.all(), True)
65 |
66 | print(result)
67 |
68 |
69 | def test_label_mrr_some_zeros():
70 | """Test MRR where some relevant labels"""
71 | y_true = np.array([
72 | [0, 1, 0],
73 | [0, 0, 0],
74 | [0, 1, 0]
75 | ])
76 |
77 | y_prob = np.array([
78 | [0.3, 0.7, 0.0],
79 | [0.1, 0.0, 0.0],
80 | [0.1, 0.5, 0.3]
81 | ])
82 |
83 | # Check results
84 | expected = np.ma.array([0.0, 1.0, 0.0], mask=[True, False, True])
85 | result = metriks.label_mean_reciprocal_rank(y_true, y_prob)
86 | np.testing.assert_allclose(result, expected)
87 |
88 | print(result)
89 |
90 |
91 | def test_label_mrr_ones():
92 | """Test MRR where all labels are relevant and predictions are perfect"""
93 | y_true = np.array([
94 | [1, 1, 1],
95 | [1, 1, 1],
96 | [1, 1, 1]
97 | ])
98 |
99 | y_prob = np.array([
100 | [0.3, 0.7, 0.0],
101 | [0.1, 0.9, 0.0],
102 | [0.2, 0.5, 0.1]
103 | ])
104 |
105 | # Check results
106 | expected = np.ma.array([0.5, 1.0, 1/3])
107 | result = metriks.label_mean_reciprocal_rank(y_true, y_prob)
108 | np.testing.assert_allclose(result, expected)
109 |
110 | print(result)
111 |
112 |
113 | if __name__ == '__main__':
114 | test_label_mrr_with_errors()
115 | test_label_mrr_perfect()
116 | test_label_mrr_zeros()
117 | test_label_mrr_some_zeros()
118 | test_label_mrr_ones()
119 |
120 |
121 |
--------------------------------------------------------------------------------
/tests/test_mean_reciprocal_rank.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import metriks
4 |
5 |
6 | def test_mrr_wikipedia():
7 | """
8 | Tests the wikipedia example.
9 |
10 | https://en.wikipedia.org/wiki/Mean_reciprocal_rank
11 | """
12 | values = [
13 | ['catten', 'cati', 'cats'],
14 | ['torii', 'tori', 'toruses'],
15 | ['viruses', 'virii', 'viri'],
16 | ]
17 |
18 | # Flag indicating which is actually relevant
19 | y_true = np.array([
20 | [0, 0, 1],
21 | [0, 1, 0],
22 | [1, 0, 0],
23 | ])
24 |
25 | # The predicted probability
26 | y_prob = np.array([
27 | [0.4, 0.6, 0.3],
28 | [0.1, 0.2, 0.9],
29 | [0.9, 0.6, 0.3],
30 | ])
31 |
32 | # Check results
33 | expected = np.ma.array([11.0/18.0, 0.0, 0.0], mask=[False, True, True])
34 | result = metriks.mean_reciprocal_rank(y_true, y_prob)
35 | np.testing.assert_allclose(result, expected)
36 |
37 | print(result)
38 |
39 |
40 | def test_mrr_with_errors():
41 | """Test MRR where there are errors in the probabilities"""
42 | y_true = np.array([
43 | [1, 1, 0],
44 | [1, 0, 1],
45 | [1, 1, 0],
46 | [1, 0, 1],
47 | [0, 1, 0],
48 | [0, 1, 1],
49 | [0, 1, 1],
50 | [1, 0, 1],
51 | [1, 1, 1],
52 | ])
53 |
54 | y_prob = np.array([
55 | [0.7, 0.6, 0.3],
56 | [0.9, 0.2, 0.1], # 3rd probability is an error
57 | [0.7, 0.8, 0.9], # 3rd probability is an error
58 | [0.9, 0.8, 0.3], # 2nd and 3rd probability are swapped
59 | [0.4, 0.6, 0.3],
60 | [0.1, 0.6, 0.9],
61 | [0.1, 0.6, 0.9],
62 | [0.1, 0.6, 0.5],
63 | [0.9, 0.8, 0.7],
64 | ])
65 |
66 | # Check results
67 | expected = np.ma.array([11.0 / 18.0, 31.0 / 48.0, 1.0])
68 | result = metriks.mean_reciprocal_rank(y_true, y_prob)
69 | np.testing.assert_allclose(result, expected)
70 |
71 | print(result)
72 |
73 |
74 | def test_mrr_perfect():
75 | """Test MRR where the probabilities are perfect"""
76 | y_true = np.array([
77 | [0, 1, 0],
78 | [1, 0, 0],
79 | [0, 1, 0],
80 | [1, 0, 0],
81 | [0, 0, 1],
82 | ])
83 |
84 | y_prob = np.array([
85 | [0.3, 0.7, 0.0],
86 | [0.1, 0.0, 0.0],
87 | [0.1, 0.5, 0.3],
88 | [0.6, 0.2, 0.4],
89 | [0.1, 0.2, 0.3],
90 | ])
91 |
92 | # Check results
93 | expected = np.ma.array([1.0, 0.0, 0.0], mask=[False, True, True])
94 | result = metriks.mean_reciprocal_rank(y_true, y_prob)
95 | np.testing.assert_allclose(result, expected)
96 |
97 | print(result)
98 |
99 |
100 | def test_mrr_zeros():
101 | """Test MRR where there is a sample of all zeros"""
102 | y_true = np.array([
103 | [0, 0, 1, 1],
104 | [1, 1, 1, 0],
105 | [0, 0, 0, 0]
106 | ])
107 |
108 | y_prob = np.array([
109 | [0.1, 0.4, 0.35, 0.8],
110 | [0.3, 0.2, 0.7, 0.8],
111 | [0.1, 0.2, 0.3, 0.4]
112 | ])
113 |
114 | # Check results
115 | expected = np.ma.array([0.75, 7.0/24.0, 1.0/3.0, 0.0], mask=[False, False, False, True])
116 | result = metriks.mean_reciprocal_rank(y_true, y_prob)
117 | np.testing.assert_allclose(result, expected)
118 |
119 | print(result)
120 |
121 |
122 | if __name__ == '__main__':
123 | test_mrr_wikipedia()
124 | test_mrr_with_errors()
125 | test_mrr_perfect()
126 | test_mrr_zeros()
--------------------------------------------------------------------------------
/tests/test_ndcg.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import metriks
3 |
4 |
5 | def test_ndcg_perfect():
6 | """
7 | All predicted rankings match the expected ranking
8 | """
9 | y_true = np.array([[0, 1, 1, 0],
10 | [0, 1, 1, 1],
11 | [1, 1, 1, 0]])
12 |
13 | y_prob = np.array([[0.2, 0.9, 0.8, 0.4],
14 | [0.4, 0.8, 0.7, 0.5],
15 | [0.9, 0.6, 0.7, 0.2]])
16 |
17 | expected = 1.0
18 | actual = metriks.ndcg(y_true, y_prob)
19 |
20 | np.testing.assert_equal([actual], [expected])
21 |
22 |
23 | def test_ndcg_errors():
24 | """
25 | Some samples predicted the order incorrectly
26 | """
27 | y_true = np.array([[1, 1, 1, 0],
28 | [1, 0, 1, 1],
29 | [0, 1, 1, 0]])
30 |
31 | y_prob = np.array([[0.2, 0.9, 0.8, 0.4],
32 | [0.5, 0.8, 0.7, 0.4],
33 | [0.9, 0.6, 0.7, 0.2]])
34 |
35 | expected = 0.7979077
36 | actual = metriks.ndcg(y_true, y_prob)
37 |
38 | np.testing.assert_allclose([actual], [expected])
39 |
40 |
41 | def test_ndcg_all_ones():
42 | """
43 | Every item in each sample is relevant
44 | """
45 | y_true = np.array([[1, 1, 1, 1],
46 | [1, 1, 1, 1],
47 | [1, 1, 1, 1]])
48 |
49 | y_prob = np.array([[0.2, 0.9, 0.8, 0.4],
50 | [0.5, 0.8, 0.7, 0.4],
51 | [0.9, 0.6, 0.7, 0.2]])
52 |
53 | expected = 1.0
54 | actual = metriks.ndcg(y_true, y_prob)
55 |
56 | np.testing.assert_allclose([actual], [expected])
57 |
58 |
59 | def test_all_zeros():
60 | """
61 | There are no relevant items in any sample
62 | """
63 | y_true = np.array([[0, 0, 0, 0],
64 | [0, 0, 0, 0],
65 | [0, 0, 0, 0]])
66 |
67 | y_prob = np.array([[0.2, 0.9, 0.8, 0.4],
68 | [0.5, 0.8, 0.7, 0.4],
69 | [0.9, 0.6, 0.7, 0.2]])
70 |
71 | expected = 1.0
72 | actual = metriks.ndcg(y_true, y_prob)
73 |
74 | np.testing.assert_allclose([actual], [expected])
75 |
76 |
77 | def test_combination():
78 | """
79 | Mixture of all relevant, no relevant, and some relevant samples
80 | """
81 | y_true = np.array([[0, 0, 0, 0],
82 | [0, 1, 1, 1],
83 | [1, 1, 1, 1]])
84 |
85 | y_prob = np.array([[0.2, 0.9, 0.8, 0.4],
86 | [0.5, 0.8, 0.7, 0.4],
87 | [0.9, 0.6, 0.7, 0.2]])
88 |
89 | expected = 0.989156
90 | actual = metriks.ndcg(y_true, y_prob)
91 |
92 | np.testing.assert_allclose([actual], [expected])
93 |
94 |
95 | def test_ndcg_after_0():
96 | """
97 | ndcg with no classes removed
98 | """
99 | y_true = np.array([[0, 1, 1, 0, 0, 1],
100 | [0, 1, 1, 0, 1, 1],
101 | [1, 1, 1, 1, 1, 0],
102 | [0, 1, 1, 0, 1, 0],
103 | [1, 1, 0, 0, 1, 0]])
104 |
105 | y_prob = np.array([[0.5, 0.9, 0.8, 0.4, 0.2, 0.2],
106 | [0.9, 0.8, 0.2, 0.5, 0.2, 0.4],
107 | [0.2, 0.9, 0.8, 0.4, 0.6, 0.5],
108 | [0.2, 0.7, 0.8, 0.4, 0.6, 0.9],
109 | [0.2, 0.9, 0.8, 0.5, 0.3, 0.2]])
110 |
111 | expected = 0.8395053
112 | actual = metriks.ndcg(y_true, y_prob)
113 |
114 | np.testing.assert_allclose([actual], [expected])
115 |
116 |
117 | def test_ndcg_after_1():
118 | """
119 | ndcg with top 1 class removed
120 | """
121 | y_true = np.array([[0, 1, 1, 0, 0, 1],
122 | [1, 1, 1, 0, 1, 0],
123 | [1, 1, 1, 1, 1, 0],
124 | [0, 1, 1, 0, 0, 1],
125 | [1, 1, 0, 0, 1, 1]])
126 |
127 | y_prob = np.array([[0.5, 0.9, 0.8, 0.4, 0.2, 0.2],
128 | [0.9, 0.8, 0.2, 0.5, 0.2, 0.4],
129 | [0.2, 0.9, 0.8, 0.4, 0.6, 0.5],
130 | [0.2, 0.7, 0.8, 0.4, 0.6, 0.9],
131 | [0.2, 0.2, 0.8, 0.5, 0.3, 0.9]])
132 |
133 | expected = 0.8554782
134 | actual = metriks.ndcg(y_true, y_prob, 1)
135 |
136 | np.testing.assert_allclose([actual], [expected])
137 |
138 |
139 | def test_ndcg_after_2_perfect():
140 | """
141 | ndcg with top 2 classes removed
142 | """
143 | y_true = np.array([[1, 1, 1, 0, 0, 0],
144 | [1, 1, 0, 1, 0, 1],
145 | [0, 1, 1, 1, 1, 1],
146 | [0, 1, 1, 0, 1, 1],
147 | [0, 0, 1, 1, 1, 1]])
148 |
149 | y_prob = np.array([[0.5, 0.9, 0.8, 0.4, 0.2, 0.2],
150 | [0.9, 0.8, 0.2, 0.5, 0.2, 0.4],
151 | [0.2, 0.9, 0.8, 0.4, 0.6, 0.5],
152 | [0.2, 0.7, 0.8, 0.4, 0.6, 0.9],
153 | [0.2, 0.2, 0.8, 0.5, 0.3, 0.9]])
154 |
155 | expected = 1.0
156 | actual = metriks.ndcg(y_true, y_prob, 2)
157 |
158 | np.testing.assert_allclose([actual], [expected])
159 |
160 |
161 | def test_ndcg_after_2_errors():
162 | """
163 | ndcg with top 2 classes removed
164 | """
165 | y_true = np.array([[0, 1, 1, 0, 0, 1],
166 | [1, 1, 1, 0, 1, 1],
167 | [1, 1, 1, 1, 1, 0],
168 | [0, 1, 1, 0, 1, 1],
169 | [0, 0, 1, 1, 1, 1]])
170 |
171 | y_prob = np.array([[0.5, 0.9, 0.8, 0.4, 0.2, 0.2],
172 | [0.9, 0.8, 0.2, 0.5, 0.2, 0.4],
173 | [0.2, 0.9, 0.8, 0.4, 0.6, 0.5],
174 | [0.2, 0.7, 0.8, 0.4, 0.6, 0.9],
175 | [0.2, 0.2, 0.8, 0.5, 0.3, 0.9]])
176 |
177 | expected = 0.8139061
178 | actual = metriks.ndcg(y_true, y_prob, 2)
179 |
180 | np.testing.assert_allclose([actual], [expected])
181 |
182 |
183 | def test_ndcg_1_over_0_error():
184 | """
185 | ndcg with division by 0
186 | """
187 | y_true = np.array([[0, 1, 1, 1]])
188 |
189 | y_prob = np.array([[0.9, 0.8, 0.7, 0.6]])
190 |
191 | expected = 1.0
192 | actual = metriks.ndcg(y_true, y_prob, 1)
193 |
194 | np.testing.assert_allclose([actual], [expected])
195 |
196 |
197 | def test_ndcg_0_over_0_error():
198 | """
199 | ndcg with division by 0
200 | """
201 | y_true = np.array([[0, 0, 0, 0]])
202 |
203 | y_prob = np.array([[0.9, 0.8, 0.7, 0.6]])
204 |
205 | expected = 1.0
206 | actual = metriks.ndcg(y_true, y_prob, 1)
207 |
208 | np.testing.assert_allclose([actual], [expected])
209 |
210 | if __name__ == '__main__':
211 | test_ndcg_perfect()
212 | test_ndcg_errors()
213 | test_ndcg_all_ones()
214 | test_all_zeros()
215 | test_combination()
216 | test_ndcg_after_0()
217 | test_ndcg_after_1()
218 | test_ndcg_after_2_perfect()
219 | test_ndcg_after_2_errors()
220 | test_ndcg_1_over_0_error()
221 | test_ndcg_0_over_0_error()
222 |
--------------------------------------------------------------------------------
/tests/test_precision_at_k.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import metriks
4 |
5 |
6 | def test_precision_at_k_wikipedia():
7 | """
8 | Tests the wikipedia example.
9 |
10 | https://en.wikipedia.org/wiki/Mean_reciprocal_rank
11 | """
12 |
13 | # Flag indicating which is actually relevant
14 | y_true = np.array([
15 | [0, 0, 1],
16 | [0, 1, 0],
17 | [1, 0, 0],
18 | ])
19 |
20 | # The predicted probability
21 | y_prob = np.array([
22 | [0.4, 0.6, 0.3],
23 | [0.1, 0.2, 0.9],
24 | [0.9, 0.6, 0.3],
25 | ])
26 |
27 | # Check results
28 | expected = 2.0/6.0
29 | result = metriks.precision_at_k(y_true, y_prob, 2)
30 | np.testing.assert_allclose(result, expected)
31 |
32 | print(result)
33 |
34 |
35 | def test_precision_at_k_with_errors():
36 | """Test precision_at_k where there are errors in the probabilities"""
37 | y_true = np.array([
38 | [1, 1, 0],
39 | [1, 0, 1],
40 | [1, 1, 0],
41 | [1, 0, 1],
42 | [0, 1, 0],
43 | [0, 1, 1],
44 | [0, 1, 1],
45 | [1, 0, 1],
46 | [1, 1, 1],
47 | ])
48 |
49 | y_prob = np.array([
50 | [0.7, 0.6, 0.3],
51 | [0.9, 0.2, 0.1], # 3rd probability is an error
52 | [0.7, 0.8, 0.9], # 3rd probability is an error
53 | [0.9, 0.8, 0.3], # 2nd and 3rd probability are swapped
54 | [0.4, 0.6, 0.3],
55 | [0.1, 0.6, 0.9],
56 | [0.1, 0.6, 0.9],
57 | [0.1, 0.6, 0.5],
58 | [0.9, 0.8, 0.7],
59 | ])
60 |
61 | # Check results
62 | expected = 18.0/27.0
63 | result = metriks.precision_at_k(y_true, y_prob, 3)
64 | np.testing.assert_allclose(result, expected)
65 |
66 | print(result)
67 |
68 |
69 | def test_precision_at_k_perfect():
70 | """Test precision_at_k where the probabilities are perfect"""
71 | y_true = np.array([
72 | [0, 1, 0],
73 | [1, 0, 0],
74 | [0, 1, 0],
75 | [1, 0, 0],
76 | [0, 0, 1],
77 | ])
78 |
79 | y_prob = np.array([
80 | [0.3, 0.7, 0.0],
81 | [0.1, 0.0, 0.0],
82 | [0.1, 0.5, 0.3],
83 | [0.6, 0.2, 0.4],
84 | [0.1, 0.2, 0.3],
85 | ])
86 |
87 | # Check results
88 | expected = 1.0
89 | result = metriks.precision_at_k(y_true, y_prob, 1)
90 | np.testing.assert_allclose(result, expected)
91 |
92 | print(result)
93 |
94 |
95 | def test_precision_at_k_perfect_multiple_true():
96 | """Test precision_at_k where the probabilities are perfect and there
97 | are multiple true labels for some samples"""
98 | y_true = np.array([
99 | [1, 1, 0],
100 | [1, 0, 0],
101 | [0, 1, 1],
102 | [1, 0, 1],
103 | [0, 1, 1],
104 | ])
105 |
106 | y_prob = np.array([
107 | [0.3, 0.7, 0.0],
108 | [0.1, 0.0, 0.0],
109 | [0.1, 0.5, 0.3],
110 | [0.6, 0.2, 0.4],
111 | [0.1, 0.2, 0.3],
112 | ])
113 |
114 | # Check results for k=1
115 | expected = 1.0
116 | result = metriks.precision_at_k(y_true, y_prob, 1)
117 | np.testing.assert_allclose(result, expected)
118 |
119 | print(result)
120 |
121 | # Check results for k=2
122 | expected = 9.0/10.0
123 | result = metriks.precision_at_k(y_true, y_prob, 2)
124 | np.testing.assert_allclose(result, expected)
125 |
126 | print(result)
127 |
128 |
129 | def test_precision_at_k_few_zeros():
130 | """Test precision_at_k where there are very few zeros"""
131 | y_true = np.array([
132 | [0, 1, 1, 1, 1],
133 | [1, 1, 1, 1, 1],
134 | [1, 1, 1, 1, 0]
135 | ])
136 |
137 | y_prob = np.array([
138 | [0.1, 0.4, 0.35, 0.8, 0.9],
139 | [0.3, 0.2, 0.7, 0.8, 0.6],
140 | [0.1, 0.2, 0.3, 0.4, 0.5]
141 | ])
142 |
143 | # Check results
144 | expected = 5.0/6.0
145 | result = metriks.precision_at_k(y_true, y_prob, 2)
146 | np.testing.assert_allclose(result, expected)
147 |
148 | print(result)
149 |
150 |
151 | def test_precision_at_k_zeros():
152 | """Test precision_at_k where there is a sample of all zeros"""
153 | y_true = np.array([
154 | [0, 0, 1, 1],
155 | [1, 1, 1, 0],
156 | [0, 0, 0, 0]
157 | ])
158 |
159 | y_prob = np.array([
160 | [0.1, 0.4, 0.35, 0.8],
161 | [0.3, 0.2, 0.7, 0.8],
162 | [0.1, 0.2, 0.3, 0.4]
163 | ])
164 |
165 | # Check results
166 | expected = 4.0/9.0
167 | result = metriks.precision_at_k(y_true, y_prob, 3)
168 | np.testing.assert_allclose(result, expected)
169 |
170 | print(result)
171 |
172 |
173 | if __name__ == '__main__':
174 | test_precision_at_k_wikipedia()
175 | test_precision_at_k_with_errors()
176 | test_precision_at_k_perfect()
177 | test_precision_at_k_perfect_multiple_true()
178 | test_precision_at_k_few_zeros()
179 | test_precision_at_k_zeros()
180 |
--------------------------------------------------------------------------------
/tests/test_recall_at_k.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import metriks
4 |
5 |
6 | def test_recall_at_k_wikipedia():
7 | """
8 | Tests the wikipedia example.
9 |
10 | https://en.wikipedia.org/wiki/Mean_reciprocal_rank
11 | """
12 |
13 | # Flag indicating which is actually relevant
14 | y_true = np.array([
15 | [0, 0, 1],
16 | [0, 1, 0],
17 | [1, 0, 0],
18 | ])
19 |
20 | # The predicted probability
21 | y_prob = np.array([
22 | [0.4, 0.6, 0.3],
23 | [0.1, 0.2, 0.9],
24 | [0.9, 0.6, 0.3],
25 | ])
26 |
27 | # Check results
28 | expected = 2.0/3.0
29 | result = metriks.recall_at_k(y_true, y_prob, 2)
30 | np.testing.assert_allclose(result, expected)
31 |
32 | print(result)
33 |
34 |
35 | def test_recall_at_k_with_errors():
36 | """Test recall_at_k where there are errors in the probabilities"""
37 | y_true = np.array([
38 | [1, 1, 0],
39 | [1, 0, 1],
40 | [1, 1, 0],
41 | [1, 0, 1],
42 | [0, 1, 0],
43 | [0, 1, 1],
44 | [0, 1, 1],
45 | [1, 0, 1],
46 | [1, 1, 1],
47 | ])
48 |
49 | y_prob = np.array([
50 | [0.7, 0.6, 0.3],
51 | [0.9, 0.2, 0.1], # 3rd probability is an error
52 | [0.7, 0.8, 0.9], # 3rd probability is an error
53 | [0.9, 0.8, 0.3], # 2nd and 3rd probability are swapped
54 | [0.4, 0.6, 0.3],
55 | [0.1, 0.6, 0.9],
56 | [0.1, 0.6, 0.9],
57 | [0.1, 0.6, 0.5],
58 | [0.9, 0.8, 0.7],
59 | ])
60 |
61 | # Check results
62 | expected = 40.0/54.0
63 | result = metriks.recall_at_k(y_true, y_prob, 2)
64 | np.testing.assert_allclose(result, expected)
65 |
66 | print(result)
67 |
68 |
69 | def test_recall_at_k_perfect():
70 | """Test recall_at_k where the probabilities are perfect"""
71 | y_true = np.array([
72 | [0, 1, 0],
73 | [1, 0, 0],
74 | [0, 1, 0],
75 | [1, 0, 0],
76 | [0, 0, 1],
77 | ])
78 |
79 | y_prob = np.array([
80 | [0.3, 0.7, 0.0],
81 | [0.1, 0.0, 0.0],
82 | [0.1, 0.5, 0.3],
83 | [0.6, 0.2, 0.4],
84 | [0.1, 0.2, 0.3],
85 | ])
86 |
87 | # Check results
88 | expected = 1.0
89 | result = metriks.recall_at_k(y_true, y_prob, 1)
90 | np.testing.assert_allclose(result, expected)
91 |
92 | print(result)
93 |
94 |
95 | def test_recall_at_k_perfect_multiple_true():
96 | """Test recall_at_k where the probabilities are perfect and there
97 | are multiple true labels for some samples"""
98 | y_true = np.array([
99 | [1, 1, 0],
100 | [1, 0, 0],
101 | [0, 1, 1],
102 | [1, 0, 1],
103 | [0, 1, 1],
104 | ])
105 |
106 | y_prob = np.array([
107 | [0.3, 0.7, 0.0],
108 | [0.1, 0.0, 0.0],
109 | [0.1, 0.5, 0.3],
110 | [0.6, 0.2, 0.4],
111 | [0.1, 0.2, 0.3],
112 | ])
113 |
114 | # Check results for k=1
115 | expected = 6.0/10.0
116 | result = metriks.recall_at_k(y_true, y_prob, 1)
117 | np.testing.assert_allclose(result, expected)
118 |
119 | print(result)
120 |
121 | # Check results for k=2
122 | expected = 1.0
123 | result = metriks.recall_at_k(y_true, y_prob, 2)
124 | np.testing.assert_allclose(result, expected)
125 |
126 | print(result)
127 |
128 |
129 | def test_recall_at_k_few_zeros():
130 | """Test recall_at_k where there are very few zeros"""
131 | y_true = np.array([
132 | [0, 1, 1, 1, 1],
133 | [1, 1, 1, 1, 1],
134 | [1, 1, 1, 1, 0]
135 | ])
136 |
137 | y_prob = np.array([
138 | [0.1, 0.4, 0.35, 0.8, 0.9],
139 | [0.3, 0.2, 0.7, 0.8, 0.6],
140 | [0.1, 0.2, 0.3, 0.4, 0.5]
141 | ])
142 |
143 | # Check results
144 | expected = 23.0/60.0
145 | result = metriks.recall_at_k(y_true, y_prob, 2)
146 | np.testing.assert_allclose(result, expected)
147 |
148 | print(result)
149 |
150 |
151 | def test_recall_at_k_remove_zeros():
152 | """Test recall_at_k where there is a sample of all zeros, which is filtered out"""
153 | y_true = np.array([
154 | [0, 0, 1, 1],
155 | [1, 1, 1, 0],
156 | [0, 0, 0, 0]
157 | ])
158 |
159 | y_prob = np.array([
160 | [0.1, 0.4, 0.35, 0.8],
161 | [0.3, 0.2, 0.7, 0.8],
162 | [0.1, 0.2, 0.3, 0.4]
163 | ])
164 |
165 | # Check results
166 | expected = 5.0/12.0
167 | result = metriks.recall_at_k(y_true, y_prob, 2)
168 | np.testing.assert_allclose(result, expected)
169 |
170 | print(result)
171 |
172 |
173 | if __name__ == '__main__':
174 | test_recall_at_k_wikipedia()
175 | test_recall_at_k_with_errors()
176 | test_recall_at_k_perfect()
177 | test_recall_at_k_perfect_multiple_true()
178 | test_recall_at_k_few_zeros()
179 | test_recall_at_k_remove_zeros()
180 |
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 |
2 | # tox (https://tox.readthedocs.io/) is a tool for running tests
3 | # in multiple virtualenvs. This configuration file will run the
4 | # test suite on all supported python versions. To use it, "pip install tox"
5 | # and then run "tox" from this directory.
6 |
7 | [tox]
8 | envlist = py36
9 |
10 | [testenv]
11 | deps = coverage
12 | pytest
13 | commands =
14 | coverage run -m pytest
15 | coverage html --include='metriks/*'
16 |
--------------------------------------------------------------------------------