├── .github ├── CODEOWNERS └── CODE_OF_CONDUCT.md ├── .gitignore ├── .travis.yml ├── LICENSE ├── README.rst ├── examples └── README.md ├── logo └── metriks-logo.svg ├── metriks ├── __init__.py └── ranking.py ├── setup.py ├── tests ├── __init__.py ├── test_confusion_matrix_at_k.py ├── test_label_mean_reciprocal_rank.py ├── test_mean_reciprocal_rank.py ├── test_ndcg.py ├── test_precision_at_k.py └── test_recall_at_k.py └── tox.ini /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # List of source code paths and code owners 2 | # common services & repos 3 | * @wontonswaggie, @karenclo, @jonathanlunt, @crystal-chung, @seedambiereed, @rdedhia -------------------------------------------------------------------------------- /.github/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | Open source projects are “living.” Contributions in the form of issues and pull requests are welcomed and encouraged. When you contribute, you explicitly say you are part of the community and abide by its Code of Conduct. 2 | 3 | # The Code 4 | 5 | At Intuit, we foster a kind, respectful, harassment-free cooperative community. Our open source community works to: 6 | 7 | - Be kind and respectful; 8 | - Act as a global community; 9 | - Conduct ourselves professionally. 10 | 11 | As members of this community, we will not tolerate behaviors including, but not limited to: 12 | 13 | - Violent threats or language; 14 | - Discriminatory or derogatory jokes or language; 15 | - Public or private harassment of any kind; 16 | - Other conduct considered inappropriate in a professional setting. 17 | 18 | ## Reporting Concerns 19 | 20 | If you see someone violating the Code of Conduct please email TechOpenSource@intuit.com 21 | 22 | ## Scope 23 | 24 | This code of conduct applies to: 25 | 26 | All repos and communities for Intuit-managed projects, whether or not the text is included in a Intuit-managed project’s repository; 27 | 28 | Individuals or teams representing projects in official capacity, such as via official social media channels or at in-person meetups. 29 | 30 | ## Attribution 31 | 32 | This Code of Conduct is partly inspired by and based on those of Amazon, CocoaPods, GitHub, Microsoft, thoughtbot, and on the Contributor Covenant version 1.4.1. -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .pytest_cache 2 | # Mac 3 | .DS_Store 4 | 5 | ### Python ### 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | env/ 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | wheelhouse/ 32 | cover/ 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *,cover 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | 61 | # Sphinx documentation 62 | docs/_build/ 63 | 64 | # PyBuilder 65 | target/ 66 | 67 | 68 | ### PyCharm ### 69 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio 70 | 71 | *.iml 72 | 73 | ## Directory-based project format: 74 | .idea/ 75 | # if you remove the above rule, at least ignore the following: 76 | 77 | # User-specific stuff: 78 | # .idea/workspace.xml 79 | # .idea/tasks.xml 80 | # .idea/dictionaries 81 | 82 | # Sensitive or high-churn files: 83 | # .idea/dataSources.ids 84 | # .idea/dataSources.xml 85 | # .idea/sqlDataSources.xml 86 | # .idea/dynamic.xml 87 | # .idea/uiDesigner.xml 88 | 89 | # Gradle: 90 | # .idea/gradle.xml 91 | # .idea/libraries 92 | 93 | # Mongo Explorer plugin: 94 | # .idea/mongoSettings.xml 95 | 96 | ## File-based project format: 97 | *.ipr 98 | *.iws 99 | 100 | ## Plugin-specific files: 101 | 102 | # IntelliJ 103 | /out/ 104 | 105 | # mpeltonen/sbt-idea plugin 106 | .idea_modules/ 107 | 108 | # JIRA plugin 109 | atlassian-ide-plugin.xml 110 | 111 | # Crashlytics plugin (for Android Studio and IntelliJ) 112 | com_crashlytics_export_strings.xml 113 | crashlytics.properties 114 | crashlytics-build.properties 115 | 116 | # VS Code and virtual env 117 | .vscode/* 118 | venv/* 119 | __version.py 120 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.6" 4 | install: 5 | - pip install pytest-cov 6 | - pip install coveralls 7 | - pip install . 8 | script: pytest --cov=metriks 9 | after_success: 10 | - coveralls -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2019 Intuit Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to use, 8 | copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 9 | Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 16 | INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 17 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 18 | HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 19 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE 20 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | .. 
image:: logo/metriks-logo.svg 2 | 3 | |python| |build| |coverage| 4 | 5 | .. |python| image:: https://img.shields.io/badge/python-3.6%20-blue.svg 6 | :target: https://www.python.org/downloads/release/python-360/ 7 | :alt: Python Version 8 | 9 | .. |build| image:: https://travis-ci.com/intuit/metriks.svg?branch=master 10 | :target: https://travis-ci.com/intuit/metriks 11 | 12 | .. |coverage| image:: https://coveralls.io/repos/github/intuit/metriks/badge.svg?branch=master 13 | :target: https://coveralls.io/github/intuit/metriks?branch=master 14 | 15 | ----- 16 | 17 | metriks is a Python package of commonly used metrics for evaluating information retrieval models. 18 | 19 | Available Metrics 20 | --------------------------- 21 | +------------------------------------------------------------+-------------------------------------------------------------------------------+ 22 | | Python API | Description | 23 | +============================================================+===============================================================================+ 24 | | `metriks.recall_at_k(y_true, y_prob, k)` | Calculates recall at k for binary classification ranking problems. | 25 | +------------------------------------------------------------+-------------------------------------------------------------------------------+ 26 | | `metriks.precision_at_k(y_true, y_prob, k)` | Calculates precision at k for binary classification ranking problems. | 27 | +------------------------------------------------------------+-------------------------------------------------------------------------------+ 28 | | `metriks.mean_reciprocal_rank(y_true, y_prob)` | Gets a positional score on how well you did at rank 1, rank 2, etc. | 29 | +------------------------------------------------------------+-------------------------------------------------------------------------------+ 30 | | `metriks.ndcg(y_true, y_prob, k)` | A score for measuring the quality of a set of ranked results. | 31 | +------------------------------------------------------------+-------------------------------------------------------------------------------+ 32 | | `metriks.label_mean_reciprocal_rank(y_true, y_prob)` | Determines the average rank each label was placed across samples. Only labels | 33 | | | that are relevant in the true data set are considered in the calculation. | 34 | +------------------------------------------------------------+-------------------------------------------------------------------------------+ 35 | | `metriks.confusion_matrix_at_k(y_true, y_prob, k)` | Generates binary predictions from probabilities by evaluating the top k | 36 | | | items (in ranked order by y_prob) as true. | 37 | +------------------------------------------------------------+-------------------------------------------------------------------------------+ 38 | 39 | Installation 40 | ------------ 41 | Install using `pip `_ 42 | :: 43 | 44 | pip install metriks 45 | 46 | Alternatively, specific distributions can be downloaded from the 47 | github `release `_ 48 | page. Once downloaded, install the ``.tar.gz`` file directly: 49 | :: 50 | 51 | pip install metriks-\*.tar.gz 52 | 53 | Development 54 | ----------- 55 | 1. (*Optional*) If you have `virtualenv` and `virtualenvwrapper` create a new virtual environment: 56 | :: 57 | 58 | mkvirtualenv metriks 59 | 60 | This isolates your specific project dependencies to avoid conflicts 61 | with other projects. 62 | 63 | 2. 
Clone and install the repository: 64 | :: 65 | 66 | git clone git@github.com:intuit/metriks.git 67 | cd metriks 68 | pip install -e . 69 | 70 | 71 | This will install a version to an isolated environment in editable 72 | mode. As you update the code in the repository, the new code will 73 | immediately be available to run within the environment (without the 74 | need to `pip install` it again) 75 | 76 | 3. Run the tests using `tox`: 77 | :: 78 | 79 | pip install tox 80 | tox 81 | 82 | Tox will run all of the tests in isolated environments 83 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Example Jupyter Notebooks 3 | 4 | This folder will have examples on how metriks can be used for calculating precision and recall on ranking models. 5 | Please create subfolders for each example depending on your dataset. 6 | 7 | Steps included in all examples are, 8 | 9 | 1. A default dataset has been choosen. 10 | 2. Trained a ranking model with the dataset choosen. 11 | 3. Finally demonstrated how these metrics can be used on our dataset, 12 | 13 | | Python API | Description | 14 | +============================================================+===============================================================================+ 15 | | `metriks.recall_at_k(y_true, y_prob, k)` | Calculates recall at k for binary classification ranking problems. | 16 | +------------------------------------------------------------+-------------------------------------------------------------------------------+ 17 | | `metriks.precision_at_k(y_true, y_prob, k)` | Calculates precision at k for binary classification ranking problems. | 18 | +------------------------------------------------------------+-------------------------------------------------------------------------------+ -------------------------------------------------------------------------------- /logo/metriks-logo.svg: -------------------------------------------------------------------------------- 1 | metriks-logom e t r i k s -------------------------------------------------------------------------------- /metriks/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from metriks.__version import __version__ 3 | except ImportError: # pragma: no cover 4 | __version__ = "dev" 5 | 6 | from metriks.ranking import ( 7 | recall_at_k, 8 | precision_at_k, 9 | mean_reciprocal_rank, 10 | label_mean_reciprocal_rank, 11 | ndcg, 12 | confusion_matrix_at_k, 13 | ) 14 | -------------------------------------------------------------------------------- /metriks/ranking.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ranking 3 | ======= 4 | Metrics to use for ranking models. 5 | """ 6 | 7 | import numpy as np 8 | 9 | 10 | def check_arrays(y_true, y_prob): 11 | # Make sure that inputs this conforms to our expectations 12 | assert isinstance(y_true, np.ndarray), AssertionError( 13 | 'Expect y_true to be a {expected}. Got {actual}' 14 | .format(expected=np.ndarray, actual=type(y_true)) 15 | ) 16 | 17 | assert isinstance(y_prob, np.ndarray), AssertionError( 18 | 'Expect y_prob to be a {expected}. Got {actual}' 19 | .format(expected=np.ndarray, actual=type(y_prob)) 20 | ) 21 | 22 | assert y_true.shape == y_prob.shape, AssertionError( 23 | 'Shapes must match. 
Got y_true={true_shape}, y_prob={prob_shape}' 24 | .format(true_shape=y_true.shape, prob_shape=y_prob.shape) 25 | ) 26 | 27 | assert len(y_true.shape) == 2, AssertionError( 28 | 'Shapes should be of rank 2. Got {rank}' 29 | .format(rank=len(y_true.shape)) 30 | ) 31 | 32 | uniques = np.unique(y_true) 33 | assert len(uniques) <= 2, AssertionError( 34 | 'Expected labels: [0, 1]. Got: {uniques}' 35 | .format(uniques=uniques) 36 | ) 37 | 38 | 39 | def check_k(n_items, k): 40 | # Make sure that inputs conform to our expectations 41 | assert isinstance(k, int), AssertionError( 42 | 'Expect k to be a {expected}. Got {actual}' 43 | .format(expected=int, actual=type(k)) 44 | ) 45 | 46 | assert 0 <= k <= n_items, AssertionError( 47 | 'Expect 0 <= k <= {n_items}. Got {k}' 48 | .format(n_items=n_items, k=k) 49 | ) 50 | 51 | 52 | def recall_at_k(y_true, y_prob, k): 53 | """ 54 | Calculates recall at k for binary classification ranking problems. Recall 55 | at k measures the proportion of total relevant items that are found in the 56 | top k (in ranked order by y_prob). If k=5, there are 6 total relevant documents, 57 | and 3 of the top 5 items are relevant, the recall at k will be 0.5. 58 | 59 | Samples where y_true is 0 for all labels are filtered out because there will be 60 | 0 true positives and false negatives. 61 | 62 | Args: 63 | y_true (~np.ndarray): Flags (0, 1) which indicate whether a column is 64 | relevant or not. size=(n_samples, n_items) 65 | y_prob (~np.ndarray): The predicted probability that the given flag 66 | is relevant. size=(n_samples, n_items) 67 | k (int): Number of items to evaluate for relevancy, in descending 68 | sorted order by y_prob 69 | 70 | Returns: 71 | recall (~np.ndarray): The recall at k 72 | 73 | Example: 74 | >>> y_true = np.array([ 75 | [0, 0, 1], 76 | [0, 1, 0], 77 | [1, 0, 0], 78 | ]) 79 | >>> y_prob = np.array([ 80 | [0.4, 0.6, 0.3], 81 | [0.1, 0.2, 0.9], 82 | [0.9, 0.6, 0.3], 83 | ]) 84 | >>> recall_at_k(y_true, y_prob, 2) 85 | 0.6666666666666666 86 | 87 | In the example above, each of the samples has 1 total relevant document. 88 | For the first sample, there are 0 relevant documents in the top k for k=2, 89 | because 0.3 is the 3rd value for y_prob in descending order. For the second 90 | sample, there is 1 relevant document in the top k, because 0.2 is the 2nd 91 | value for y_prob in descending order. For the third sample, there is 1 92 | relevant document in the top k, because 0.9 is the 1st value for y_prob in 93 | descending order. Averaging the values for all of these samples (0, 1, 1) 94 | gives a value for recall at k of 2/3. 95 | """ 96 | check_arrays(y_true, y_prob) 97 | check_k(y_true.shape[1], k) 98 | 99 | # Filter out rows of all zeros 100 | mask = y_true.sum(axis=1).astype(bool) 101 | y_prob = y_prob[mask] 102 | y_true = y_true[mask] 103 | 104 | # Extract shape components 105 | n_samples, n_items = y_true.shape 106 | 107 | # List of locations indexing 108 | y_prob_index_order = np.argsort(-y_prob) 109 | rows = np.reshape(np.arange(n_samples), (-1, 1)) 110 | ranking = y_true[rows, y_prob_index_order] 111 | 112 | # Calculate number true positives for numerator and number of relevant documents for denominator 113 | num_tp = np.sum(ranking[:, :k], axis=1) 114 | num_relevant = np.sum(ranking, axis=1) 115 | # Calculate recall at k 116 | recall = np.mean(num_tp / num_relevant) 117 | 118 | return recall 119 | 120 | 121 | def precision_at_k(y_true, y_prob, k): 122 | """ 123 | Calculates precision at k for binary classification ranking problems. 
124 | Precision at k measures the proportion of items in the top k (in ranked 125 | order by y_prob) that are relevant (as defined by y_true). If k=5, and 126 | 3 of the top 5 items are relevant, the precision at k will be 0.6. 127 | 128 | Args: 129 | y_true (~np.ndarray): Flags (0, 1) which indicate whether a column is 130 | relevant or not. size=(n_samples, n_items) 131 | y_prob (~np.ndarray): The predicted probability that the given flag 132 | is relevant. size=(n_samples, n_items) 133 | k (int): Number of items to evaluate for relevancy, in descending 134 | sorted order by y_prob 135 | 136 | Returns: 137 | precision_k (~np.ndarray): The precision at k 138 | 139 | Example: 140 | >>> y_true = np.array([ 141 | [0, 0, 1], 142 | [0, 1, 0], 143 | [1, 0, 0], 144 | ]) 145 | >>> y_prob = np.array([ 146 | [0.4, 0.6, 0.3], 147 | [0.1, 0.2, 0.9], 148 | [0.9, 0.6, 0.3], 149 | ]) 150 | >>> precision_at_k(y_true, y_prob, 2) 151 | 0.3333333333333333 152 | 153 | For the first sample, there are 0 relevant documents in the top k for k=2, 154 | because 0.3 is the 3rd value for y_prob in descending order. For the second 155 | sample, there is 1 relevant document in the top k, because 0.2 is the 2nd 156 | value for y_prob in descending order. For the third sample, there is 1 157 | relevant document in the top k, because 0.9 is the 1st value for y_prob in 158 | descending order. Because k=2, the values for precision of k for each sample 159 | are 0, 1/2, and 1/2 respectively. Averaging these gives a value for precision 160 | at k of 1/3. 161 | """ 162 | check_arrays(y_true, y_prob) 163 | check_k(y_true.shape[1], k) 164 | 165 | # Extract shape components 166 | n_samples, n_items = y_true.shape 167 | 168 | # List of locations indexing 169 | y_prob_index_order = np.argsort(-y_prob) 170 | rows = np.reshape(np.arange(n_samples), (-1, 1)) 171 | ranking = y_true[rows, y_prob_index_order] 172 | 173 | # Calculate number of true positives for numerator 174 | num_tp = np.sum(ranking[:, :k], axis=1) 175 | # Calculate precision at k 176 | precision = np.mean(num_tp / k) 177 | 178 | return precision 179 | 180 | 181 | def mean_reciprocal_rank(y_true, y_prob): 182 | """ 183 | Gets a positional score about how well you did at rank 1, rank 2, 184 | etc. The resulting vector is of size (n_items,) but element 0 corresponds to 185 | rank 1 not label 0. 186 | 187 | Args: 188 | y_true (~np.ndarray): Flags (0, 1) which indicate whether a column is 189 | relevant or not. size=(n_samples, n_items) 190 | y_prob (~np.ndarray): The predicted probability that the given flag 191 | is relevant. size=(n_samples, n_items) 192 | 193 | Returns: 194 | mrr (~np.ma.array): The positional ranking score. This will be masked 195 | for ranks where there were no relevant values. 
size=(n_items,) 196 | """ 197 | 198 | check_arrays(y_true, y_prob) 199 | 200 | # Extract shape components 201 | n_samples, n_items = y_true.shape 202 | 203 | # Determine the ranking order 204 | rank_true = np.flip(np.argsort(y_true, axis=1), axis=1) 205 | rank_prob = np.flip(np.argsort(y_prob, axis=1), axis=1) 206 | 207 | # Compute reciprocal ranks 208 | reciprocal = 1.0 / (np.argsort(rank_prob, axis=1) + 1) 209 | 210 | # Now order the reciprocal ranks by the true order 211 | rows = np.reshape(np.arange(n_samples), (-1, 1)) 212 | cols = rank_true 213 | ordered = reciprocal[rows, cols] 214 | 215 | # Create a masked array of true labels only 216 | ma = np.ma.array(ordered, mask=np.isclose(y_true[rows, cols], 0)) 217 | return ma.mean(axis=0) 218 | 219 | 220 | def label_mean_reciprocal_rank(y_true, y_prob): 221 | """ 222 | Determines the average rank each label was placed across samples. Only labels that are 223 | relevant in the true data set are considered in the calculation. 224 | 225 | Args: 226 | y_true (~np.ndarray): Flags (0, 1) which indicate whether a column is 227 | relevant or not. size=(n_samples, n_items) 228 | y_prob (~np.ndarray): The predicted probability that the given flag 229 | is relevant. size=(n_samples, n_items) 230 | Returns: 231 | mrr (~np.ma.array): The positional ranking score. This will be masked 232 | for ranks where there were no relevant values. size=(n_items,) 233 | """ 234 | 235 | check_arrays(y_true, y_prob) 236 | 237 | rank_prob = np.flip(np.argsort(y_prob, axis=1), axis=1) 238 | reciprocal = 1 / (np.argsort(rank_prob, axis=1) + 1) 239 | ma = np.ma.array(reciprocal, mask=~y_true.astype(bool)) 240 | 241 | return ma.mean(axis=0) 242 | 243 | 244 | def ndcg(y_true, y_prob, k=0): 245 | """ 246 | A score for measuring the quality of a set of ranked results. The resulting score is between 0 and 1.0; 247 | results that are relevant and appear earlier in the result set are given a heavier weight, so the 248 | higher the score, the more relevant your results are. 249 | 250 | The optional k param is recommended for data sets where the first few labels are almost always ranked first, 251 | and hence skew the overall score. To compute this "NDCG after k" metric, we remove the top k (by predicted 252 | probability) labels and compute NDCG as usual for the remaining labels. 253 | 254 | Args: 255 | y_true (~np.ndarray): Flags (0, 1) which indicate whether a column is 256 | relevant or not. size=(n_samples, n_items) 257 | y_prob (~np.ndarray): The predicted probability that the given flag 258 | is relevant.
size=(n_samples, n_items) 259 | k (int): Optional, the top k classes to exclude 260 | Returns: 261 | ndcg (~np.float64): The normalized dcg score across all queries, excluding the top k 262 | """ 263 | # Get the sorted prob indices in descending order 264 | rank_prob = np.flip(np.argsort(y_prob, axis=1), axis=1) 265 | 266 | # Get the sorted true indices in descending order 267 | rank_true = np.flip(np.argsort(y_true, axis=1), axis=1) 268 | 269 | prob_samples, prob_items = y_prob.shape 270 | true_samples, true_items = y_true.shape 271 | 272 | # Compute DCG 273 | 274 | # Order y_true and y_prob by y_prob order indices 275 | prob_vals = y_prob[np.arange(prob_samples).reshape(prob_samples, 1), rank_prob] 276 | true_vals = y_true[np.arange(true_samples).reshape(true_samples, 1), rank_prob] 277 | 278 | # Remove the first k columns 279 | prob_vals = prob_vals[:, k:] 280 | true_vals = true_vals[:, k:] 281 | 282 | rank_prob_k = np.flip(np.argsort(prob_vals, axis=1), axis=1) 283 | 284 | n_samples, n_items = true_vals.shape 285 | 286 | values = np.arange(n_samples).reshape(n_samples, 1) 287 | 288 | # Construct the dcg numerator, which are the relevant items for each rank 289 | dcg_numerator = true_vals[values, rank_prob_k] 290 | 291 | # Construct the denominator, which is the log2 of the current rank + 1 292 | position = np.arange(1, n_items + 1) 293 | denominator = np.log2(np.tile(position, (n_samples, 1)) + 1.0) 294 | 295 | dcg = np.sum(dcg_numerator / denominator, axis=1) 296 | 297 | # Compute IDCG 298 | rank_true_idcg = np.flip(np.argsort(true_vals, axis=1), axis=1) 299 | 300 | idcg_true_samples, idcg_true_items = rank_true_idcg.shape 301 | 302 | # Order y_true indices 303 | idcg_true_vals = true_vals[np.arange(idcg_true_samples).reshape(idcg_true_samples, 1), rank_true_idcg] 304 | 305 | rank_true_k = np.flip(np.argsort(idcg_true_vals, axis=1), axis=1) 306 | 307 | idcg_numerator = idcg_true_vals[values, rank_true_k] 308 | 309 | idcg = np.sum(idcg_numerator / denominator, axis=1) 310 | 311 | with np.warnings.catch_warnings(): 312 | np.warnings.filterwarnings('ignore') 313 | sample_ndcg = np.divide(dcg, idcg) 314 | 315 | # ndcg may be NaN if idcg is 0; this happens when there are no relevant documents 316 | # in this case, showing anything in any order should be considered correct 317 | where_are_nans = np.isnan(sample_ndcg) 318 | where_are_infs = np.isinf(sample_ndcg) 319 | sample_ndcg[where_are_nans] = 1.0 320 | sample_ndcg[where_are_infs] = 0.0 321 | 322 | return np.mean(sample_ndcg, dtype=np.float64) 323 | 324 | 325 | def generate_y_pred_at_k(y_prob, k): 326 | """ 327 | Generates a matrix of binary predictions from a matrix of probabilities 328 | by evaluating the top k items (in ranked order by y_prob) as true. 329 | 330 | In the case where multiple probabilities for a sample are identical, the 331 | behavior is undefined in terms of how the probabilities are ranked by argsort. 332 | 333 | Args: 334 | y_prob (~np.ndarray): The predicted probability that the given flag 335 | is relevant. size=(n_samples, n_items) 336 | k (int): Number of items to evaluate as true, in descending 337 | sorted order by y_prob 338 | 339 | Returns: 340 | y_pred (~np.ndarray): A binary prediction that the given flag is 341 | relevant. 
size=(n_samples, n_items) 342 | 343 | Example: 344 | >>> y_prob = np.array([ 345 | [0.4, 0.6, 0.3], 346 | [0.1, 0.2, 0.9], 347 | [0.9, 0.6, 0.3], 348 | ]) 349 | >>> generate_y_pred_at_k(y_prob, 2) 350 | array([ 351 | [1, 1, 0], 352 | [0, 1, 1], 353 | [1, 1, 0] 354 | ]) 355 | 356 | For the first sample, the top 2 values for y_prob are 0.6 and 0.4, so y_pred 357 | at those positions is 1. For the second sample, the top 2 values for y_prob 358 | are 0.9 and 0.2, so y_pred at these positions is 1. For the third sample, the 359 | top 2 values for y_prob are 0.9 and 0.6, so y_pred at these positions in 1. 360 | """ 361 | n_items = y_prob.shape[1] 362 | index_array = np.argsort(y_prob, axis=1) 363 | col_idx = np.arange(y_prob.shape[0]).reshape(-1, 1) 364 | y_pred = np.zeros(np.shape(y_prob)) 365 | y_pred[col_idx, index_array[:, n_items-k:n_items]] = 1 366 | return y_pred 367 | 368 | 369 | def confusion_matrix_at_k(y_true, y_prob, k): 370 | """ 371 | Generates binary predictions from probabilities by evaluating the top k items 372 | (in ranked order by y_prob) as true. Uses these binary predictions along with 373 | true flags to calculate the confusion matrix per label for binary 374 | classification problems. 375 | 376 | Args: 377 | y_true (~np.ndarray): Flags (0, 1) which indicate whether a column is 378 | relevant or not. size=(n_samples, n_items) 379 | y_prob (~np.ndarray): The predicted probability that the given flag 380 | is relevant. size=(n_samples, n_items) 381 | k (int): Number of items to evaluate as true, in descending 382 | sorted order by y_prob 383 | 384 | Returns: 385 | tn, fp, fn, tp (tuple of ~np.ndarrays): A tuple of ndarrays containing 386 | the number of true negatives (tn), false positives (fp), 387 | false negatives (fn), and true positives (tp) for each item. The 388 | length of each ndarray is equal to n_items 389 | 390 | Example: 391 | >>> y_true = np.array([ 392 | [0, 0, 1], 393 | [0, 1, 0], 394 | [1, 0, 0], 395 | ]) 396 | >>> y_prob = np.array([ 397 | [0.4, 0.6, 0.3], 398 | [0.1, 0.2, 0.9], 399 | [0.9, 0.6, 0.3], 400 | ]) 401 | >>> y_pred = np.array([ 402 | [1, 1, 0], 403 | [0, 1, 1], 404 | [1, 1, 0] 405 | ]) 406 | >>> label_names = ['moved', 'hadAJob', 'farmIncome'] 407 | >>> confusion_matrix_at_k(y_true, y_prob, 2) 408 | ( 409 | np.array([1, 0, 1]), 410 | np.array([1, 2, 1]), 411 | np.array([0, 0, 1]), 412 | np.array([1, 1, 0]) 413 | ) 414 | 415 | In the example above, y_pred is not passed into the function, but is 416 | generated by calling generate_y_pred_at_k with y_prob and k. 417 | 418 | For the first item (moved), the first sample is a false positive, the 419 | second is a true negative, and the third is a true positive. 420 | For the second item (hadAJob), the first and third samples are false 421 | positives, and the second is a true positive. 422 | For the third item (farmIncome), the first item is a false negative, the 423 | second is a false positive, and the third is a true positive. 
424 | """ 425 | check_arrays(y_true, y_prob) 426 | check_k(y_true.shape[1], k) 427 | 428 | y_pred = generate_y_pred_at_k(y_prob, k) 429 | 430 | tp = np.count_nonzero(y_pred * y_true, axis=0) 431 | tn = np.count_nonzero((y_pred - 1) * (y_true - 1), axis=0) 432 | fp = np.count_nonzero(y_pred * (y_true - 1), axis=0) 433 | fn = np.count_nonzero((y_pred - 1) * y_true, axis=0) 434 | 435 | return tn, fp, fn, tp 436 | 437 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open('README.rst', encoding='utf-8') as fh: 4 | long_description = fh.read() 5 | 6 | setup( 7 | # Package information 8 | name='metriks', 9 | description='metriks is a Python package of commonly used metrics for evaluating information retrieval models.', 10 | long_description=long_description, 11 | long_description_content_type='text/x-rst', 12 | url='https://github.com/intuit/metriks', 13 | classifiers=[ 14 | 'Programming Language :: Python :: 3', 15 | 'License :: OSI Approved :: MIT License' 16 | ], 17 | python_requires='>=3.6', 18 | 19 | # Package data 20 | use_scm_version={ 21 | "write_to": "metriks/__version.py", 22 | "write_to_template": '__version__ = "{version}"\n', 23 | }, 24 | packages=find_packages(exclude=('tests*',)), 25 | 26 | # Insert dependencies list here 27 | install_requires=[ 28 | 'numpy', 29 | ], 30 | setup_requires=["setuptools-scm"], 31 | extras_require={ 32 | 'dev': [ 33 | 'pytest', 34 | 'tox', 35 | ] 36 | } 37 | ) 38 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intuit/metriks/a0e62df847b8997bfc32b337d7e41db16d0a8ce4/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_confusion_matrix_at_k.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import metriks 3 | 4 | 5 | def test_confusion_matrix_at_k_wikipedia(): 6 | """ 7 | Tests the wikipedia example. 8 | 9 | https://en.wikipedia.org/wiki/Mean_reciprocal_rank 10 | """ 11 | 12 | # Flag indicating which is actually relevant 13 | y_true = np.array([ 14 | [0, 0, 1], 15 | [0, 1, 0], 16 | [1, 0, 0], 17 | ]) 18 | 19 | # The predicted probability 20 | y_prob = np.array([ 21 | [0.4, 0.6, 0.3], 22 | [0.1, 0.2, 0.9], 23 | [0.9, 0.6, 0.3], 24 | ]) 25 | 26 | expected = ( 27 | np.array([1, 0, 1]), 28 | np.array([1, 2, 1]), 29 | np.array([0, 0, 1]), 30 | np.array([1, 1, 0]) 31 | ) 32 | result = metriks.confusion_matrix_at_k(y_true, y_prob, 2) 33 | for i in range(4): 34 | assert expected[i].all() == result[i].all(), AssertionError( 35 | 'Expected:\n{expected}.
\nGot:\n{actual}' 36 | .format(expected=expected, actual=result) 37 | ) 38 | 39 | print(result) 40 | 41 | 42 | def test_confusion_matrix_at_k_with_errors(): 43 | """Test confusion_matrix_at_k where there are errors in the probabilities""" 44 | y_true = np.array([ 45 | [1, 1, 0], 46 | [1, 0, 1], 47 | [1, 1, 0], 48 | [1, 0, 1], 49 | [0, 1, 0], 50 | [0, 1, 1], 51 | [0, 1, 1], 52 | [1, 0, 1], 53 | [1, 1, 1], 54 | ]) 55 | 56 | y_prob = np.array([ 57 | [0.7, 0.6, 0.3], 58 | [0.9, 0.2, 0.1], # 3rd probability is an error 59 | [0.7, 0.8, 0.9], # 3rd probability is an error 60 | [0.9, 0.8, 0.3], # 2nd and 3rd probability are swapped 61 | [0.4, 0.6, 0.3], 62 | [0.1, 0.6, 0.9], 63 | [0.1, 0.6, 0.9], 64 | [0.1, 0.6, 0.5], 65 | [0.9, 0.8, 0.7], 66 | ]) 67 | 68 | # Check results 69 | expected = ( 70 | np.array([2, 0, 2]), 71 | np.array([1, 3, 1]), 72 | np.array([2, 0, 3]), 73 | np.array([4, 6, 3]) 74 | ) 75 | result = metriks.confusion_matrix_at_k(y_true, y_prob, 2) 76 | for i in range(4): 77 | assert expected[i].all() == result[i].all(), AssertionError( 78 | 'Expected:\n{expected}. \nGot:\n{actual}' 79 | .format(expected=expected, actual=result) 80 | ) 81 | 82 | print(result) 83 | 84 | 85 | def test_confusion_matrix_at_k_perfect(): 86 | """Test confusion_matrix_at_k where the probabilities are perfect""" 87 | y_true = np.array([ 88 | [0, 1, 0], 89 | [1, 0, 0], 90 | [0, 1, 0], 91 | [1, 0, 0], 92 | [0, 0, 1], 93 | ]) 94 | 95 | y_prob = np.array([ 96 | [0.3, 0.7, 0.0], 97 | [0.1, 0.0, 0.0], 98 | [0.1, 0.5, 0.3], 99 | [0.6, 0.2, 0.4], 100 | [0.1, 0.2, 0.3], 101 | ]) 102 | 103 | # Check results 104 | expected = ( 105 | np.array([3, 3, 4]), 106 | np.array([0, 0, 0]), 107 | np.array([0, 0, 0]), 108 | np.array([2, 2, 1]) 109 | ) 110 | result = metriks.confusion_matrix_at_k(y_true, y_prob, 1) 111 | for i in range(4): 112 | assert expected[i].all() == result[i].all(), AssertionError( 113 | 'Expected:\n{expected}. \nGot:\n{actual}' 114 | .format(expected=expected, actual=result) 115 | ) 116 | 117 | print(result) 118 | 119 | 120 | def test_confusion_matrix_at_k_perfect_multiple_true(): 121 | """Test confusion_matrix_at_k where the probabilities are perfect and there 122 | are multiple true labels for some samples""" 123 | y_true = np.array([ 124 | [1, 1, 0], 125 | [1, 0, 0], 126 | [0, 1, 1], 127 | [1, 0, 1], 128 | [0, 1, 1], 129 | ]) 130 | 131 | y_prob = np.array([ 132 | [0.3, 0.7, 0.0], 133 | [0.1, 0.05, 0.0], 134 | [0.1, 0.5, 0.3], 135 | [0.6, 0.2, 0.4], 136 | [0.1, 0.2, 0.3], 137 | ]) 138 | 139 | # Check results for k=0 140 | expected = ( 141 | np.array([2, 2, 2]), 142 | np.array([0, 0, 0]), 143 | np.array([3, 3, 3]), 144 | np.array([0, 0, 0]) 145 | ) 146 | result = metriks.confusion_matrix_at_k(y_true, y_prob, 0) 147 | for i in range(4): 148 | assert expected[i].all() == result[i].all(), AssertionError( 149 | 'Expected:\n{expected}. \nGot:\n{actual}' 150 | .format(expected=expected, actual=result) 151 | ) 152 | 153 | print(result) 154 | 155 | # Check results for k=2 156 | expected = ( 157 | np.array([2, 1, 2]), 158 | np.array([0, 1, 0]), 159 | np.array([0, 0, 0]), 160 | np.array([3, 3, 3]) 161 | ) 162 | result = metriks.confusion_matrix_at_k(y_true, y_prob, 2) 163 | for i in range(4): 164 | assert expected[i].all() == result[i].all(), AssertionError( 165 | 'Expected:\n{expected}. 
\nGot:\n{actual}' 166 | .format(expected=expected, actual=result) 167 | ) 168 | 169 | print(result) 170 | 171 | 172 | def test_confusion_matrix_at_k_few_zeros(): 173 | """Test confusion_matrix_at_k where there are very few zeros""" 174 | y_true = np.array([ 175 | [0, 1, 1, 1, 1], 176 | [1, 1, 1, 1, 1], 177 | [1, 1, 1, 1, 0] 178 | ]) 179 | 180 | y_prob = np.array([ 181 | [0.1, 0.4, 0.35, 0.8, 0.9], 182 | [0.3, 0.2, 0.7, 0.8, 0.6], 183 | [0.1, 0.2, 0.3, 0.4, 0.5] 184 | ]) 185 | 186 | # Check results for k=2 187 | expected = ( 188 | np.array([1, 0, 0, 0, 0]), 189 | np.array([0, 0, 0, 0, 1]), 190 | np.array([2, 3, 2, 0, 1]), 191 | np.array([0, 0, 1, 3, 1]) 192 | ) 193 | result = metriks.confusion_matrix_at_k(y_true, y_prob, 2) 194 | for i in range(4): 195 | assert expected[i].all() == result[i].all(), AssertionError( 196 | 'Expected:\n{expected}. \nGot:\n{actual}' 197 | .format(expected=expected, actual=result) 198 | ) 199 | 200 | print(result) 201 | 202 | # Check results for k=5 203 | expected = { 204 | 'moved': {'fp': 1, 'tn': 0, 'fn': 0, 'tp': 2}, 205 | 'hadAJob': {'fp': 0, 'tn': 0, 'fn': 0, 'tp': 3}, 206 | 'farmIncome': {'fp': 0, 'tn': 0, 'fn': 0, 'tp': 3}, 207 | 'married': {'fp': 0, 'tn': 0, 'fn': 0, 'tp': 3}, 208 | 'alimony': {'fp': 1, 'tn': 0, 'fn': 0, 'tp': 2} 209 | } 210 | expected = ( 211 | np.array([0, 0, 0, 0, 0]), 212 | np.array([1, 0, 0, 0, 1]), 213 | np.array([0, 0, 0, 0, 0]), 214 | np.array([2, 3, 3, 3, 2]) 215 | ) 216 | result = metriks.confusion_matrix_at_k(y_true, y_prob, 5) 217 | for i in range(4): 218 | assert expected[i].all() == result[i].all(), AssertionError( 219 | 'Expected:\n{expected}. \nGot:\n{actual}' 220 | .format(expected=expected, actual=result) 221 | ) 222 | 223 | print(result) 224 | 225 | 226 | def test_confusion_matrix_at_k_zeros(): 227 | """Test confusion_matrix_at_k where there is a sample of all zeros""" 228 | y_true = np.array([ 229 | [0, 0, 1, 1], 230 | [1, 1, 1, 0], 231 | [0, 0, 0, 0] 232 | ]) 233 | 234 | y_prob = np.array([ 235 | [0.1, 0.4, 0.35, 0.8], 236 | [0.3, 0.2, 0.7, 0.8], 237 | [0.1, 0.2, 0.3, 0.4] 238 | ]) 239 | 240 | # Check results 241 | expected = { 242 | 'moved': {'fp': 0, 'tn': 2, 'fn': 1, 'tp': 0}, 243 | 'hadAJob': {'fp': 1, 'tn': 1, 'fn': 1, 'tp': 0}, 244 | 'farmIncome': {'fp': 1, 'tn': 0, 'fn': 1, 'tp': 1}, 245 | 'married': {'fp': 2, 'tn': 0, 'fn': 0, 'tp': 1} 246 | } 247 | expected = ( 248 | np.array([2, 1, 0, 0]), 249 | np.array([0, 1, 1, 2]), 250 | np.array([1, 1, 1, 0]), 251 | np.array([0, 0, 1, 1]) 252 | ) 253 | result = metriks.confusion_matrix_at_k(y_true, y_prob, 2) 254 | for i in range(4): 255 | assert expected[i].all() == result[i].all(), AssertionError( 256 | 'Expected:\n{expected}. 
\nGot:\n{actual}' 257 | .format(expected=expected, actual=result) 258 | ) 259 | 260 | print(result) 261 | 262 | 263 | if __name__ == '__main__': 264 | test_confusion_matrix_at_k_wikipedia() 265 | test_confusion_matrix_at_k_with_errors() 266 | test_confusion_matrix_at_k_perfect() 267 | test_confusion_matrix_at_k_perfect_multiple_true() 268 | test_confusion_matrix_at_k_few_zeros() 269 | test_confusion_matrix_at_k_zeros() 270 | -------------------------------------------------------------------------------- /tests/test_label_mean_reciprocal_rank.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import metriks 4 | 5 | 6 | def test_label_mrr_with_errors(): 7 | y_true = np.array([[0, 1, 1, 0], 8 | [1, 1, 1, 0], 9 | [0, 1, 1, 1]]) 10 | y_prob = np.array([[0.2, 0.9, 0.8, 0.4], 11 | [0.9, 0.2, 0.8, 0.4], 12 | [0.2, 0.8, 0.9, 0.4]]) 13 | 14 | result = metriks.label_mean_reciprocal_rank(y_true, y_prob) 15 | 16 | expected = np.ma.array([1.0, (7/4) / 3, 2/3, 1/3]) 17 | 18 | np.testing.assert_allclose(result, expected) 19 | print(result) 20 | 21 | 22 | def test_label_mrr_perfect(): 23 | """Test MRR where the probabilities are perfect""" 24 | y_true = np.array([ 25 | [0, 1, 0], 26 | [1, 0, 0], 27 | [0, 1, 0], 28 | [1, 0, 0], 29 | [0, 0, 1], 30 | ]) 31 | 32 | y_prob = np.array([ 33 | [0.3, 0.7, 0.0], 34 | [0.1, 0.0, 0.0], 35 | [0.1, 0.5, 0.3], 36 | [0.6, 0.2, 0.4], 37 | [0.1, 0.2, 0.3], 38 | ]) 39 | 40 | # Check results 41 | expected = np.ma.array([1.0, 0.0, 0.0], mask=[False, True, True]) 42 | result = metriks.label_mean_reciprocal_rank(y_true, y_prob) 43 | np.testing.assert_allclose(result, expected) 44 | 45 | print(result) 46 | 47 | 48 | def test_label_mrr_zeros(): 49 | """Test MRR where no relevant labels""" 50 | y_true = np.array([ 51 | [0, 0, 0], 52 | [0, 0, 0], 53 | [0, 0, 0] 54 | ]) 55 | 56 | y_prob = np.array([ 57 | [0.3, 0.7, 0.0], 58 | [0.1, 0.0, 0.0], 59 | [0.1, 0.5, 0.3] 60 | ]) 61 | 62 | # Check results 63 | result = metriks.label_mean_reciprocal_rank(y_true, y_prob) 64 | np.testing.assert_equal(result.mask.all(), True) 65 | 66 | print(result) 67 | 68 | 69 | def test_label_mrr_some_zeros(): 70 | """Test MRR where some relevant labels""" 71 | y_true = np.array([ 72 | [0, 1, 0], 73 | [0, 0, 0], 74 | [0, 1, 0] 75 | ]) 76 | 77 | y_prob = np.array([ 78 | [0.3, 0.7, 0.0], 79 | [0.1, 0.0, 0.0], 80 | [0.1, 0.5, 0.3] 81 | ]) 82 | 83 | # Check results 84 | expected = np.ma.array([0.0, 1.0, 0.0], mask=[True, False, True]) 85 | result = metriks.label_mean_reciprocal_rank(y_true, y_prob) 86 | np.testing.assert_allclose(result, expected) 87 | 88 | print(result) 89 | 90 | 91 | def test_label_mrr_ones(): 92 | """Test MRR where all labels are relevant and predictions are perfect""" 93 | y_true = np.array([ 94 | [1, 1, 1], 95 | [1, 1, 1], 96 | [1, 1, 1] 97 | ]) 98 | 99 | y_prob = np.array([ 100 | [0.3, 0.7, 0.0], 101 | [0.1, 0.9, 0.0], 102 | [0.2, 0.5, 0.1] 103 | ]) 104 | 105 | # Check results 106 | expected = np.ma.array([0.5, 1.0, 1/3]) 107 | result = metriks.label_mean_reciprocal_rank(y_true, y_prob) 108 | np.testing.assert_allclose(result, expected) 109 | 110 | print(result) 111 | 112 | 113 | if __name__ == '__main__': 114 | test_label_mrr_with_errors() 115 | test_label_mrr_perfect() 116 | test_label_mrr_zeros() 117 | test_label_mrr_some_zeros() 118 | test_label_mrr_ones() 119 | 120 | 121 | -------------------------------------------------------------------------------- /tests/test_mean_reciprocal_rank.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import metriks 4 | 5 | 6 | def test_mrr_wikipedia(): 7 | """ 8 | Tests the wikipedia example. 9 | 10 | https://en.wikipedia.org/wiki/Mean_reciprocal_rank 11 | """ 12 | values = [ 13 | ['catten', 'cati', 'cats'], 14 | ['torii', 'tori', 'toruses'], 15 | ['viruses', 'virii', 'viri'], 16 | ] 17 | 18 | # Flag indicating which is actually relevant 19 | y_true = np.array([ 20 | [0, 0, 1], 21 | [0, 1, 0], 22 | [1, 0, 0], 23 | ]) 24 | 25 | # The predicted probability 26 | y_prob = np.array([ 27 | [0.4, 0.6, 0.3], 28 | [0.1, 0.2, 0.9], 29 | [0.9, 0.6, 0.3], 30 | ]) 31 | 32 | # Check results 33 | expected = np.ma.array([11.0/18.0, 0.0, 0.0], mask=[False, True, True]) 34 | result = metriks.mean_reciprocal_rank(y_true, y_prob) 35 | np.testing.assert_allclose(result, expected) 36 | 37 | print(result) 38 | 39 | 40 | def test_mrr_with_errors(): 41 | """Test MRR where there are errors in the probabilities""" 42 | y_true = np.array([ 43 | [1, 1, 0], 44 | [1, 0, 1], 45 | [1, 1, 0], 46 | [1, 0, 1], 47 | [0, 1, 0], 48 | [0, 1, 1], 49 | [0, 1, 1], 50 | [1, 0, 1], 51 | [1, 1, 1], 52 | ]) 53 | 54 | y_prob = np.array([ 55 | [0.7, 0.6, 0.3], 56 | [0.9, 0.2, 0.1], # 3rd probability is an error 57 | [0.7, 0.8, 0.9], # 3rd probability is an error 58 | [0.9, 0.8, 0.3], # 2nd and 3rd probability are swapped 59 | [0.4, 0.6, 0.3], 60 | [0.1, 0.6, 0.9], 61 | [0.1, 0.6, 0.9], 62 | [0.1, 0.6, 0.5], 63 | [0.9, 0.8, 0.7], 64 | ]) 65 | 66 | # Check results 67 | expected = np.ma.array([11.0 / 18.0, 31.0 / 48.0, 1.0]) 68 | result = metriks.mean_reciprocal_rank(y_true, y_prob) 69 | np.testing.assert_allclose(result, expected) 70 | 71 | print(result) 72 | 73 | 74 | def test_mrr_perfect(): 75 | """Test MRR where the probabilities are perfect""" 76 | y_true = np.array([ 77 | [0, 1, 0], 78 | [1, 0, 0], 79 | [0, 1, 0], 80 | [1, 0, 0], 81 | [0, 0, 1], 82 | ]) 83 | 84 | y_prob = np.array([ 85 | [0.3, 0.7, 0.0], 86 | [0.1, 0.0, 0.0], 87 | [0.1, 0.5, 0.3], 88 | [0.6, 0.2, 0.4], 89 | [0.1, 0.2, 0.3], 90 | ]) 91 | 92 | # Check results 93 | expected = np.ma.array([1.0, 0.0, 0.0], mask=[False, True, True]) 94 | result = metriks.mean_reciprocal_rank(y_true, y_prob) 95 | np.testing.assert_allclose(result, expected) 96 | 97 | print(result) 98 | 99 | 100 | def test_mrr_zeros(): 101 | """Test MRR where there is a sample of all zeros""" 102 | y_true = np.array([ 103 | [0, 0, 1, 1], 104 | [1, 1, 1, 0], 105 | [0, 0, 0, 0] 106 | ]) 107 | 108 | y_prob = np.array([ 109 | [0.1, 0.4, 0.35, 0.8], 110 | [0.3, 0.2, 0.7, 0.8], 111 | [0.1, 0.2, 0.3, 0.4] 112 | ]) 113 | 114 | # Check results 115 | expected = np.ma.array([0.75, 7.0/24.0, 1.0/3.0, 0.0], mask=[False, False, False, True]) 116 | result = metriks.mean_reciprocal_rank(y_true, y_prob) 117 | np.testing.assert_allclose(result, expected) 118 | 119 | print(result) 120 | 121 | 122 | if __name__ == '__main__': 123 | test_mrr_wikipedia() 124 | test_mrr_with_errors() 125 | test_mrr_perfect() 126 | test_mrr_zeros() -------------------------------------------------------------------------------- /tests/test_ndcg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import metriks 3 | 4 | 5 | def test_ndcg_perfect(): 6 | """ 7 | All predicted rankings match the expected ranking 8 | """ 9 | y_true = np.array([[0, 1, 1, 0], 10 | [0, 1, 1, 1], 11 | [1, 1, 1, 0]]) 12 | 13 | y_prob = np.array([[0.2, 0.9, 0.8, 0.4], 14 | [0.4, 0.8, 0.7, 0.5], 15 | [0.9, 
0.6, 0.7, 0.2]]) 16 | 17 | expected = 1.0 18 | actual = metriks.ndcg(y_true, y_prob) 19 | 20 | np.testing.assert_equal([actual], [expected]) 21 | 22 | 23 | def test_ndcg_errors(): 24 | """ 25 | Some samples predicted the order incorrectly 26 | """ 27 | y_true = np.array([[1, 1, 1, 0], 28 | [1, 0, 1, 1], 29 | [0, 1, 1, 0]]) 30 | 31 | y_prob = np.array([[0.2, 0.9, 0.8, 0.4], 32 | [0.5, 0.8, 0.7, 0.4], 33 | [0.9, 0.6, 0.7, 0.2]]) 34 | 35 | expected = 0.7979077 36 | actual = metriks.ndcg(y_true, y_prob) 37 | 38 | np.testing.assert_allclose([actual], [expected]) 39 | 40 | 41 | def test_ndcg_all_ones(): 42 | """ 43 | Every item in each sample is relevant 44 | """ 45 | y_true = np.array([[1, 1, 1, 1], 46 | [1, 1, 1, 1], 47 | [1, 1, 1, 1]]) 48 | 49 | y_prob = np.array([[0.2, 0.9, 0.8, 0.4], 50 | [0.5, 0.8, 0.7, 0.4], 51 | [0.9, 0.6, 0.7, 0.2]]) 52 | 53 | expected = 1.0 54 | actual = metriks.ndcg(y_true, y_prob) 55 | 56 | np.testing.assert_allclose([actual], [expected]) 57 | 58 | 59 | def test_all_zeros(): 60 | """ 61 | There are no relevant items in any sample 62 | """ 63 | y_true = np.array([[0, 0, 0, 0], 64 | [0, 0, 0, 0], 65 | [0, 0, 0, 0]]) 66 | 67 | y_prob = np.array([[0.2, 0.9, 0.8, 0.4], 68 | [0.5, 0.8, 0.7, 0.4], 69 | [0.9, 0.6, 0.7, 0.2]]) 70 | 71 | expected = 1.0 72 | actual = metriks.ndcg(y_true, y_prob) 73 | 74 | np.testing.assert_allclose([actual], [expected]) 75 | 76 | 77 | def test_combination(): 78 | """ 79 | Mixture of all relevant, no relevant, and some relevant samples 80 | """ 81 | y_true = np.array([[0, 0, 0, 0], 82 | [0, 1, 1, 1], 83 | [1, 1, 1, 1]]) 84 | 85 | y_prob = np.array([[0.2, 0.9, 0.8, 0.4], 86 | [0.5, 0.8, 0.7, 0.4], 87 | [0.9, 0.6, 0.7, 0.2]]) 88 | 89 | expected = 0.989156 90 | actual = metriks.ndcg(y_true, y_prob) 91 | 92 | np.testing.assert_allclose([actual], [expected]) 93 | 94 | 95 | def test_ndcg_after_0(): 96 | """ 97 | ndcg with no classes removed 98 | """ 99 | y_true = np.array([[0, 1, 1, 0, 0, 1], 100 | [0, 1, 1, 0, 1, 1], 101 | [1, 1, 1, 1, 1, 0], 102 | [0, 1, 1, 0, 1, 0], 103 | [1, 1, 0, 0, 1, 0]]) 104 | 105 | y_prob = np.array([[0.5, 0.9, 0.8, 0.4, 0.2, 0.2], 106 | [0.9, 0.8, 0.2, 0.5, 0.2, 0.4], 107 | [0.2, 0.9, 0.8, 0.4, 0.6, 0.5], 108 | [0.2, 0.7, 0.8, 0.4, 0.6, 0.9], 109 | [0.2, 0.9, 0.8, 0.5, 0.3, 0.2]]) 110 | 111 | expected = 0.8395053 112 | actual = metriks.ndcg(y_true, y_prob) 113 | 114 | np.testing.assert_allclose([actual], [expected]) 115 | 116 | 117 | def test_ndcg_after_1(): 118 | """ 119 | ndcg with top 1 class removed 120 | """ 121 | y_true = np.array([[0, 1, 1, 0, 0, 1], 122 | [1, 1, 1, 0, 1, 0], 123 | [1, 1, 1, 1, 1, 0], 124 | [0, 1, 1, 0, 0, 1], 125 | [1, 1, 0, 0, 1, 1]]) 126 | 127 | y_prob = np.array([[0.5, 0.9, 0.8, 0.4, 0.2, 0.2], 128 | [0.9, 0.8, 0.2, 0.5, 0.2, 0.4], 129 | [0.2, 0.9, 0.8, 0.4, 0.6, 0.5], 130 | [0.2, 0.7, 0.8, 0.4, 0.6, 0.9], 131 | [0.2, 0.2, 0.8, 0.5, 0.3, 0.9]]) 132 | 133 | expected = 0.8554782 134 | actual = metriks.ndcg(y_true, y_prob, 1) 135 | 136 | np.testing.assert_allclose([actual], [expected]) 137 | 138 | 139 | def test_ndcg_after_2_perfect(): 140 | """ 141 | ndcg with top 2 classes removed 142 | """ 143 | y_true = np.array([[1, 1, 1, 0, 0, 0], 144 | [1, 1, 0, 1, 0, 1], 145 | [0, 1, 1, 1, 1, 1], 146 | [0, 1, 1, 0, 1, 1], 147 | [0, 0, 1, 1, 1, 1]]) 148 | 149 | y_prob = np.array([[0.5, 0.9, 0.8, 0.4, 0.2, 0.2], 150 | [0.9, 0.8, 0.2, 0.5, 0.2, 0.4], 151 | [0.2, 0.9, 0.8, 0.4, 0.6, 0.5], 152 | [0.2, 0.7, 0.8, 0.4, 0.6, 0.9], 153 | [0.2, 0.2, 0.8, 0.5, 0.3, 0.9]]) 154 | 155 | expected = 1.0 156 | actual = 
metriks.ndcg(y_true, y_prob, 2) 157 | 158 | np.testing.assert_allclose([actual], [expected]) 159 | 160 | 161 | def test_ndcg_after_2_errors(): 162 | """ 163 | ndcg with top 2 classes removed 164 | """ 165 | y_true = np.array([[0, 1, 1, 0, 0, 1], 166 | [1, 1, 1, 0, 1, 1], 167 | [1, 1, 1, 1, 1, 0], 168 | [0, 1, 1, 0, 1, 1], 169 | [0, 0, 1, 1, 1, 1]]) 170 | 171 | y_prob = np.array([[0.5, 0.9, 0.8, 0.4, 0.2, 0.2], 172 | [0.9, 0.8, 0.2, 0.5, 0.2, 0.4], 173 | [0.2, 0.9, 0.8, 0.4, 0.6, 0.5], 174 | [0.2, 0.7, 0.8, 0.4, 0.6, 0.9], 175 | [0.2, 0.2, 0.8, 0.5, 0.3, 0.9]]) 176 | 177 | expected = 0.8139061 178 | actual = metriks.ndcg(y_true, y_prob, 2) 179 | 180 | np.testing.assert_allclose([actual], [expected]) 181 | 182 | 183 | def test_ndcg_1_over_0_error(): 184 | """ 185 | ndcg with division by 0 186 | """ 187 | y_true = np.array([[0, 1, 1, 1]]) 188 | 189 | y_prob = np.array([[0.9, 0.8, 0.7, 0.6]]) 190 | 191 | expected = 1.0 192 | actual = metriks.ndcg(y_true, y_prob, 1) 193 | 194 | np.testing.assert_allclose([actual], [expected]) 195 | 196 | 197 | def test_ndcg_0_over_0_error(): 198 | """ 199 | ndcg with division by 0 200 | """ 201 | y_true = np.array([[0, 0, 0, 0]]) 202 | 203 | y_prob = np.array([[0.9, 0.8, 0.7, 0.6]]) 204 | 205 | expected = 1.0 206 | actual = metriks.ndcg(y_true, y_prob, 1) 207 | 208 | np.testing.assert_allclose([actual], [expected]) 209 | 210 | if __name__ == '__main__': 211 | test_ndcg_perfect() 212 | test_ndcg_errors() 213 | test_ndcg_all_ones() 214 | test_all_zeros() 215 | test_combination() 216 | test_ndcg_after_0() 217 | test_ndcg_after_1() 218 | test_ndcg_after_2_perfect() 219 | test_ndcg_after_2_errors() 220 | test_ndcg_1_over_0_error() 221 | test_ndcg_0_over_0_error() 222 | -------------------------------------------------------------------------------- /tests/test_precision_at_k.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import metriks 4 | 5 | 6 | def test_precision_at_k_wikipedia(): 7 | """ 8 | Tests the wikipedia example. 
9 | 10 | https://en.wikipedia.org/wiki/Mean_reciprocal_rank 11 | """ 12 | 13 | # Flag indicating which is actually relevant 14 | y_true = np.array([ 15 | [0, 0, 1], 16 | [0, 1, 0], 17 | [1, 0, 0], 18 | ]) 19 | 20 | # The predicted probability 21 | y_prob = np.array([ 22 | [0.4, 0.6, 0.3], 23 | [0.1, 0.2, 0.9], 24 | [0.9, 0.6, 0.3], 25 | ]) 26 | 27 | # Check results 28 | expected = 2.0/6.0 29 | result = metriks.precision_at_k(y_true, y_prob, 2) 30 | np.testing.assert_allclose(result, expected) 31 | 32 | print(result) 33 | 34 | 35 | def test_precision_at_k_with_errors(): 36 | """Test precision_at_k where there are errors in the probabilities""" 37 | y_true = np.array([ 38 | [1, 1, 0], 39 | [1, 0, 1], 40 | [1, 1, 0], 41 | [1, 0, 1], 42 | [0, 1, 0], 43 | [0, 1, 1], 44 | [0, 1, 1], 45 | [1, 0, 1], 46 | [1, 1, 1], 47 | ]) 48 | 49 | y_prob = np.array([ 50 | [0.7, 0.6, 0.3], 51 | [0.9, 0.2, 0.1], # 3rd probability is an error 52 | [0.7, 0.8, 0.9], # 3rd probability is an error 53 | [0.9, 0.8, 0.3], # 2nd and 3rd probability are swapped 54 | [0.4, 0.6, 0.3], 55 | [0.1, 0.6, 0.9], 56 | [0.1, 0.6, 0.9], 57 | [0.1, 0.6, 0.5], 58 | [0.9, 0.8, 0.7], 59 | ]) 60 | 61 | # Check results 62 | expected = 18.0/27.0 63 | result = metriks.precision_at_k(y_true, y_prob, 3) 64 | np.testing.assert_allclose(result, expected) 65 | 66 | print(result) 67 | 68 | 69 | def test_precision_at_k_perfect(): 70 | """Test precision_at_k where the probabilities are perfect""" 71 | y_true = np.array([ 72 | [0, 1, 0], 73 | [1, 0, 0], 74 | [0, 1, 0], 75 | [1, 0, 0], 76 | [0, 0, 1], 77 | ]) 78 | 79 | y_prob = np.array([ 80 | [0.3, 0.7, 0.0], 81 | [0.1, 0.0, 0.0], 82 | [0.1, 0.5, 0.3], 83 | [0.6, 0.2, 0.4], 84 | [0.1, 0.2, 0.3], 85 | ]) 86 | 87 | # Check results 88 | expected = 1.0 89 | result = metriks.precision_at_k(y_true, y_prob, 1) 90 | np.testing.assert_allclose(result, expected) 91 | 92 | print(result) 93 | 94 | 95 | def test_precision_at_k_perfect_multiple_true(): 96 | """Test precision_at_k where the probabilities are perfect and there 97 | are multiple true labels for some samples""" 98 | y_true = np.array([ 99 | [1, 1, 0], 100 | [1, 0, 0], 101 | [0, 1, 1], 102 | [1, 0, 1], 103 | [0, 1, 1], 104 | ]) 105 | 106 | y_prob = np.array([ 107 | [0.3, 0.7, 0.0], 108 | [0.1, 0.0, 0.0], 109 | [0.1, 0.5, 0.3], 110 | [0.6, 0.2, 0.4], 111 | [0.1, 0.2, 0.3], 112 | ]) 113 | 114 | # Check results for k=1 115 | expected = 1.0 116 | result = metriks.precision_at_k(y_true, y_prob, 1) 117 | np.testing.assert_allclose(result, expected) 118 | 119 | print(result) 120 | 121 | # Check results for k=2 122 | expected = 9.0/10.0 123 | result = metriks.precision_at_k(y_true, y_prob, 2) 124 | np.testing.assert_allclose(result, expected) 125 | 126 | print(result) 127 | 128 | 129 | def test_precision_at_k_few_zeros(): 130 | """Test precision_at_k where there are very few zeros""" 131 | y_true = np.array([ 132 | [0, 1, 1, 1, 1], 133 | [1, 1, 1, 1, 1], 134 | [1, 1, 1, 1, 0] 135 | ]) 136 | 137 | y_prob = np.array([ 138 | [0.1, 0.4, 0.35, 0.8, 0.9], 139 | [0.3, 0.2, 0.7, 0.8, 0.6], 140 | [0.1, 0.2, 0.3, 0.4, 0.5] 141 | ]) 142 | 143 | # Check results 144 | expected = 5.0/6.0 145 | result = metriks.precision_at_k(y_true, y_prob, 2) 146 | np.testing.assert_allclose(result, expected) 147 | 148 | print(result) 149 | 150 | 151 | def test_precision_at_k_zeros(): 152 | """Test precision_at_k where there is a sample of all zeros""" 153 | y_true = np.array([ 154 | [0, 0, 1, 1], 155 | [1, 1, 1, 0], 156 | [0, 0, 0, 0] 157 | ]) 158 | 159 | y_prob = np.array([ 160 | [0.1, 
0.4, 0.35, 0.8], 161 | [0.3, 0.2, 0.7, 0.8], 162 | [0.1, 0.2, 0.3, 0.4] 163 | ]) 164 | 165 | # Check results 166 | expected = 4.0/9.0 167 | result = metriks.precision_at_k(y_true, y_prob, 3) 168 | np.testing.assert_allclose(result, expected) 169 | 170 | print(result) 171 | 172 | 173 | if __name__ == '__main__': 174 | test_precision_at_k_wikipedia() 175 | test_precision_at_k_with_errors() 176 | test_precision_at_k_perfect() 177 | test_precision_at_k_perfect_multiple_true() 178 | test_precision_at_k_few_zeros() 179 | test_precision_at_k_zeros() 180 | -------------------------------------------------------------------------------- /tests/test_recall_at_k.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import metriks 4 | 5 | 6 | def test_recall_at_k_wikipedia(): 7 | """ 8 | Tests the wikipedia example. 9 | 10 | https://en.wikipedia.org/wiki/Mean_reciprocal_rank 11 | """ 12 | 13 | # Flag indicating which is actually relevant 14 | y_true = np.array([ 15 | [0, 0, 1], 16 | [0, 1, 0], 17 | [1, 0, 0], 18 | ]) 19 | 20 | # The predicted probability 21 | y_prob = np.array([ 22 | [0.4, 0.6, 0.3], 23 | [0.1, 0.2, 0.9], 24 | [0.9, 0.6, 0.3], 25 | ]) 26 | 27 | # Check results 28 | expected = 2.0/3.0 29 | result = metriks.recall_at_k(y_true, y_prob, 2) 30 | np.testing.assert_allclose(result, expected) 31 | 32 | print(result) 33 | 34 | 35 | def test_recall_at_k_with_errors(): 36 | """Test recall_at_k where there are errors in the probabilities""" 37 | y_true = np.array([ 38 | [1, 1, 0], 39 | [1, 0, 1], 40 | [1, 1, 0], 41 | [1, 0, 1], 42 | [0, 1, 0], 43 | [0, 1, 1], 44 | [0, 1, 1], 45 | [1, 0, 1], 46 | [1, 1, 1], 47 | ]) 48 | 49 | y_prob = np.array([ 50 | [0.7, 0.6, 0.3], 51 | [0.9, 0.2, 0.1], # 3rd probability is an error 52 | [0.7, 0.8, 0.9], # 3rd probability is an error 53 | [0.9, 0.8, 0.3], # 2nd and 3rd probability are swapped 54 | [0.4, 0.6, 0.3], 55 | [0.1, 0.6, 0.9], 56 | [0.1, 0.6, 0.9], 57 | [0.1, 0.6, 0.5], 58 | [0.9, 0.8, 0.7], 59 | ]) 60 | 61 | # Check results 62 | expected = 40.0/54.0 63 | result = metriks.recall_at_k(y_true, y_prob, 2) 64 | np.testing.assert_allclose(result, expected) 65 | 66 | print(result) 67 | 68 | 69 | def test_recall_at_k_perfect(): 70 | """Test recall_at_k where the probabilities are perfect""" 71 | y_true = np.array([ 72 | [0, 1, 0], 73 | [1, 0, 0], 74 | [0, 1, 0], 75 | [1, 0, 0], 76 | [0, 0, 1], 77 | ]) 78 | 79 | y_prob = np.array([ 80 | [0.3, 0.7, 0.0], 81 | [0.1, 0.0, 0.0], 82 | [0.1, 0.5, 0.3], 83 | [0.6, 0.2, 0.4], 84 | [0.1, 0.2, 0.3], 85 | ]) 86 | 87 | # Check results 88 | expected = 1.0 89 | result = metriks.recall_at_k(y_true, y_prob, 1) 90 | np.testing.assert_allclose(result, expected) 91 | 92 | print(result) 93 | 94 | 95 | def test_recall_at_k_perfect_multiple_true(): 96 | """Test recall_at_k where the probabilities are perfect and there 97 | are multiple true labels for some samples""" 98 | y_true = np.array([ 99 | [1, 1, 0], 100 | [1, 0, 0], 101 | [0, 1, 1], 102 | [1, 0, 1], 103 | [0, 1, 1], 104 | ]) 105 | 106 | y_prob = np.array([ 107 | [0.3, 0.7, 0.0], 108 | [0.1, 0.0, 0.0], 109 | [0.1, 0.5, 0.3], 110 | [0.6, 0.2, 0.4], 111 | [0.1, 0.2, 0.3], 112 | ]) 113 | 114 | # Check results for k=1 115 | expected = 6.0/10.0 116 | result = metriks.recall_at_k(y_true, y_prob, 1) 117 | np.testing.assert_allclose(result, expected) 118 | 119 | print(result) 120 | 121 | # Check results for k=2 122 | expected = 1.0 123 | result = metriks.recall_at_k(y_true, y_prob, 2) 124 | 
np.testing.assert_allclose(result, expected) 125 | 126 | print(result) 127 | 128 | 129 | def test_recall_at_k_few_zeros(): 130 | """Test recall_at_k where there are very few zeros""" 131 | y_true = np.array([ 132 | [0, 1, 1, 1, 1], 133 | [1, 1, 1, 1, 1], 134 | [1, 1, 1, 1, 0] 135 | ]) 136 | 137 | y_prob = np.array([ 138 | [0.1, 0.4, 0.35, 0.8, 0.9], 139 | [0.3, 0.2, 0.7, 0.8, 0.6], 140 | [0.1, 0.2, 0.3, 0.4, 0.5] 141 | ]) 142 | 143 | # Check results 144 | expected = 23.0/60.0 145 | result = metriks.recall_at_k(y_true, y_prob, 2) 146 | np.testing.assert_allclose(result, expected) 147 | 148 | print(result) 149 | 150 | 151 | def test_recall_at_k_remove_zeros(): 152 | """Test recall_at_k where there is a sample of all zeros, which is filtered out""" 153 | y_true = np.array([ 154 | [0, 0, 1, 1], 155 | [1, 1, 1, 0], 156 | [0, 0, 0, 0] 157 | ]) 158 | 159 | y_prob = np.array([ 160 | [0.1, 0.4, 0.35, 0.8], 161 | [0.3, 0.2, 0.7, 0.8], 162 | [0.1, 0.2, 0.3, 0.4] 163 | ]) 164 | 165 | # Check results 166 | expected = 5.0/12.0 167 | result = metriks.recall_at_k(y_true, y_prob, 2) 168 | np.testing.assert_allclose(result, expected) 169 | 170 | print(result) 171 | 172 | 173 | if __name__ == '__main__': 174 | test_recall_at_k_wikipedia() 175 | test_recall_at_k_with_errors() 176 | test_recall_at_k_perfect() 177 | test_recall_at_k_perfect_multiple_true() 178 | test_recall_at_k_few_zeros() 179 | test_recall_at_k_remove_zeros() 180 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | 2 | # tox (https://tox.readthedocs.io/) is a tool for running tests 3 | # in multiple virtualenvs. This configuration file will run the 4 | # test suite on all supported python versions. To use it, "pip install tox" 5 | # and then run "tox" from this directory. 6 | 7 | [tox] 8 | envlist = py36 9 | 10 | [testenv] 11 | deps = coverage 12 | pytest 13 | commands = 14 | coverage run -m pytest 15 | coverage html --include='metriks/*' 16 | --------------------------------------------------------------------------------
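Putting the pieces above together, the following is a minimal usage sketch of the public API documented in README.rst and metriks/ranking.py. It reuses the small Wikipedia-style y_true/y_prob arrays from the test suite; the variable names, print statements, and expected values in the comments are illustrative only and are taken from the docstrings and tests above. It assumes metriks and numpy are installed (pip install metriks)::

    import numpy as np
    import metriks

    # Relevance flags (0/1), one row per sample, one column per item
    y_true = np.array([
        [0, 0, 1],
        [0, 1, 0],
        [1, 0, 0],
    ])

    # Predicted probability that each item is relevant
    y_prob = np.array([
        [0.4, 0.6, 0.3],
        [0.1, 0.2, 0.9],
        [0.9, 0.6, 0.3],
    ])

    # Scalar metrics over the top-k items (ranked by y_prob)
    print(metriks.recall_at_k(y_true, y_prob, 2))     # 0.666..., as in the recall_at_k docstring
    print(metriks.precision_at_k(y_true, y_prob, 2))  # 0.333..., as in the precision_at_k docstring
    print(metriks.ndcg(y_true, y_prob))               # NDCG over all items; k > 0 would drop the top-k labels first

    # Positional scores: masked arrays of size (n_items,), masked where no relevant labels exist
    print(metriks.mean_reciprocal_rank(y_true, y_prob))
    print(metriks.label_mean_reciprocal_rank(y_true, y_prob))

    # Per-label confusion counts, treating the top-k predictions as positive
    tn, fp, fn, tp = metriks.confusion_matrix_at_k(y_true, y_prob, 2)
    print(tn, fp, fn, tp)  # [1 0 1] [1 2 1] [0 0 1] [1 1 0], per the wikipedia test above

Note that mean_reciprocal_rank and label_mean_reciprocal_rank return numpy masked arrays rather than scalars, so downstream code should handle the mask (for example via .filled() or .mean()).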