├── Cargo.toml ├── LICENSE.md ├── README.md └── src ├── classification.rs ├── display.rs ├── error.rs ├── lib.rs ├── numeric.rs ├── regression.rs └── util.rs /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "eval-metrics" 3 | version = "1.0.2" 4 | authors = ["Benjamin Harrison "] 5 | description = "Evaluation metrics for machine learning" 6 | keywords = ["evaluation-metrics", "machine-learning"] 7 | homepage = "https://github.com/benjarison/eval-metrics" 8 | repository = "https://github.com/benjarison/eval-metrics" 9 | documentation = "https://docs.rs/eval-metrics" 10 | readme = "README.md" 11 | license = "MIT OR Apache-2.0" 12 | edition = "2018" 13 | 14 | [dev-dependencies] 15 | assert_approx_eq = "1.1.0" 16 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | eval-metrics is dual-licensed under The MIT License [1] and 2 | Apache 2.0 License [2]. Copyright (c) 2020, Benjamin Harrison 3 | 4 | [1] 5 | 6 | The MIT License 7 | 8 | Copyright 2020 Benjamin Harrison 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 11 | documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 12 | rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit 13 | persons to whom the Software is furnished to do so, subject to the following conditions: 14 | 15 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the 16 | Software. 17 | 18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE 19 | WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR 20 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 21 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 22 | 23 | [2] 24 | 25 | Apache License 26 | Version 2.0, January 2004 27 | http://www.apache.org/licenses/ 28 | 29 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 30 | 31 | 1. Definitions. 32 | 33 | "License" shall mean the terms and conditions for use, reproduction, 34 | and distribution as defined by Sections 1 through 9 of this document. 35 | 36 | "Licensor" shall mean the copyright owner or entity authorized by 37 | the copyright owner that is granting the License. 38 | 39 | "Legal Entity" shall mean the union of the acting entity and all 40 | other entities that control, are controlled by, or are under common 41 | control with that entity. For the purposes of this definition, 42 | "control" means (i) the power, direct or indirect, to cause the 43 | direction or management of such entity, whether by contract or 44 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 45 | outstanding shares, or (iii) beneficial ownership of such entity. 46 | 47 | "You" (or "Your") shall mean an individual or Legal Entity 48 | exercising permissions granted by this License. 49 | 50 | "Source" form shall mean the preferred form for making modifications, 51 | including but not limited to software source code, documentation 52 | source, and configuration files. 53 | 54 | "Object" form shall mean any form resulting from mechanical 55 | transformation or translation of a Source form, including but 56 | not limited to compiled object code, generated documentation, 57 | and conversions to other media types. 
58 | 59 | "Work" shall mean the work of authorship, whether in Source or 60 | Object form, made available under the License, as indicated by a 61 | copyright notice that is included in or attached to the work 62 | (an example is provided in the Appendix below). 63 | 64 | "Derivative Works" shall mean any work, whether in Source or Object 65 | form, that is based on (or derived from) the Work and for which the 66 | editorial revisions, annotations, elaborations, or other modifications 67 | represent, as a whole, an original work of authorship. For the purposes 68 | of this License, Derivative Works shall not include works that remain 69 | separable from, or merely link (or bind by name) to the interfaces of, 70 | the Work and Derivative Works thereof. 71 | 72 | "Contribution" shall mean any work of authorship, including 73 | the original version of the Work and any modifications or additions 74 | to that Work or Derivative Works thereof, that is intentionally 75 | submitted to Licensor for inclusion in the Work by the copyright owner 76 | or by an individual or Legal Entity authorized to submit on behalf of 77 | the copyright owner. For the purposes of this definition, "submitted" 78 | means any form of electronic, verbal, or written communication sent 79 | to the Licensor or its representatives, including but not limited to 80 | communication on electronic mailing lists, source code control systems, 81 | and issue tracking systems that are managed by, or on behalf of, the 82 | Licensor for the purpose of discussing and improving the Work, but 83 | excluding communication that is conspicuously marked or otherwise 84 | designated in writing by the copyright owner as "Not a Contribution." 85 | 86 | "Contributor" shall mean Licensor and any individual or Legal Entity 87 | on behalf of whom a Contribution has been received by Licensor and 88 | subsequently incorporated within the Work. 89 | 90 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 91 | this License, each Contributor hereby grants to You a perpetual, 92 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 93 | copyright license to reproduce, prepare Derivative Works of, 94 | publicly display, publicly perform, sublicense, and distribute the 95 | Work and such Derivative Works in Source or Object form. 96 | 97 | 3. Grant of Patent License. Subject to the terms and conditions of 98 | this License, each Contributor hereby grants to You a perpetual, 99 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 100 | (except as stated in this section) patent license to make, have made, 101 | use, offer to sell, sell, import, and otherwise transfer the Work, 102 | where such license applies only to those patent claims licensable 103 | by such Contributor that are necessarily infringed by their 104 | Contribution(s) alone or by combination of their Contribution(s) 105 | with the Work to which such Contribution(s) was submitted. If You 106 | institute patent litigation against any entity (including a 107 | cross-claim or counterclaim in a lawsuit) alleging that the Work 108 | or a Contribution incorporated within the Work constitutes direct 109 | or contributory patent infringement, then any patent licenses 110 | granted to You under this License for that Work shall terminate 111 | as of the date such litigation is filed. 112 | 113 | 4. Redistribution. 
You may reproduce and distribute copies of the 114 | Work or Derivative Works thereof in any medium, with or without 115 | modifications, and in Source or Object form, provided that You 116 | meet the following conditions: 117 | 118 | (a) You must give any other recipients of the Work or 119 | Derivative Works a copy of this License; and 120 | 121 | (b) You must cause any modified files to carry prominent notices 122 | stating that You changed the files; and 123 | 124 | (c) You must retain, in the Source form of any Derivative Works 125 | that You distribute, all copyright, patent, trademark, and 126 | attribution notices from the Source form of the Work, 127 | excluding those notices that do not pertain to any part of 128 | the Derivative Works; and 129 | 130 | (d) If the Work includes a "NOTICE" text file as part of its 131 | distribution, then any Derivative Works that You distribute must 132 | include a readable copy of the attribution notices contained 133 | within such NOTICE file, excluding those notices that do not 134 | pertain to any part of the Derivative Works, in at least one 135 | of the following places: within a NOTICE text file distributed 136 | as part of the Derivative Works; within the Source form or 137 | documentation, if provided along with the Derivative Works; or, 138 | within a display generated by the Derivative Works, if and 139 | wherever such third-party notices normally appear. The contents 140 | of the NOTICE file are for informational purposes only and 141 | do not modify the License. You may add Your own attribution 142 | notices within Derivative Works that You distribute, alongside 143 | or as an addendum to the NOTICE text from the Work, provided 144 | that such additional attribution notices cannot be construed 145 | as modifying the License. 
146 | 147 | You may add Your own copyright statement to Your modifications and 148 | may provide additional or different license terms and conditions 149 | for use, reproduction, or distribution of Your modifications, or 150 | for any such Derivative Works as a whole, provided Your use, 151 | reproduction, and distribution of the Work otherwise complies with 152 | the conditions stated in this License. 153 | 154 | 5. Submission of Contributions. Unless You explicitly state otherwise, 155 | any Contribution intentionally submitted for inclusion in the Work 156 | by You to the Licensor shall be under the terms and conditions of 157 | this License, without any additional terms or conditions. 158 | Notwithstanding the above, nothing herein shall supersede or modify 159 | the terms of any separate license agreement you may have executed 160 | with Licensor regarding such Contributions. 161 | 162 | 6. Trademarks. This License does not grant permission to use the trade 163 | names, trademarks, service marks, or product names of the Licensor, 164 | except as required for reasonable and customary use in describing the 165 | origin of the Work and reproducing the content of the NOTICE file. 166 | 167 | 7. Disclaimer of Warranty. Unless required by applicable law or 168 | agreed to in writing, Licensor provides the Work (and each 169 | Contributor provides its Contributions) on an "AS IS" BASIS, 170 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 171 | implied, including, without limitation, any warranties or conditions 172 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 173 | PARTICULAR PURPOSE. You are solely responsible for determining the 174 | appropriateness of using or redistributing the Work and assume any 175 | risks associated with Your exercise of permissions under this License. 176 | 177 | 8. Limitation of Liability. 
In no event and under no legal theory, 178 | whether in tort (including negligence), contract, or otherwise, 179 | unless required by applicable law (such as deliberate and grossly 180 | negligent acts) or agreed to in writing, shall any Contributor be 181 | liable to You for damages, including any direct, indirect, special, 182 | incidental, or consequential damages of any character arising as a 183 | result of this License or out of the use or inability to use the 184 | Work (including but not limited to damages for loss of goodwill, 185 | work stoppage, computer failure or malfunction, or any and all 186 | other commercial damages or losses), even if such Contributor 187 | has been advised of the possibility of such damages. 188 | 189 | 9. Accepting Warranty or Additional Liability. While redistributing 190 | the Work or Derivative Works thereof, You may choose to offer, 191 | and charge a fee for, acceptance of support, warranty, indemnity, 192 | or other liability obligations and/or rights consistent with this 193 | License. However, in accepting such obligations, You may act only 194 | on Your own behalf and on Your sole responsibility, not on behalf 195 | of any other Contributor, and only if You agree to indemnify, 196 | defend, and hold each Contributor harmless for any liability 197 | incurred by, or claims asserted against, such Contributor by reason 198 | of your accepting any such warranty or additional liability. 
199 | 200 | END OF TERMS AND CONDITIONS 201 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # eval-metrics 2 | Evaluation metrics for machine learning 3 | 4 | [![crates.io](https://img.shields.io/crates/v/eval-metrics.svg)](https://crates.io/crates/eval-metrics) 5 | ![License](https://img.shields.io/crates/l/eval-metrics.svg) 6 | 7 | ---------------------------------------- 8 | 9 | ## Design 10 | 11 | The goal of this library is to provide an intuitive collection of functions for computing evaluation metrics commonly 12 | encountered in machine learning. Metrics are separated into modules for either `classification` or `regression`, with 13 | the classification module supporting both binary and multi-class tasks. This distinction between binary and multi-class 14 | classification is made explicit to underscore the fact that there are subtle differences in certain metrics between the 15 | two cases (i.e. multi-class metrics often require averaging methods). Metrics can often fail to be defined for a 16 | variety of numerical reasons, and in these cases `Result` types are used to make this fact apparent. 
17 | 18 | ## Supported Metrics 19 | 20 | | Metric | Task | Description | 21 | |-------------|----------------------------|--------------------------------------------------------------------| 22 | | Accuracy | Binary Classification | Binary Class Accuracy | 23 | | Precision | Binary Classification | Binary Class Precision | 24 | | Recall | Binary Classification | Binary Class Recall | 25 | | F-1 | Binary Classification | Harmonic Mean of Precision and Recall | 26 | | MCC | Binary Classification | Matthews Correlation Coefficient | 27 | | ROC Curve | Binary Classification | Receiver Operating Characteristic Curve | 28 | | AUC | Binary Classification | Area Under ROC Curve | 29 | | PR Curve | Binary Classification | Precision-Recall Curve | 30 | | AP | Binary Classification | Average Precision | 31 | | Accuracy | Multi-Class Classification | Multi-Class Accuracy | 32 | | Precision | Multi-Class Classification | Multi-Class Precision | 33 | | Recall | Multi-Class Classification | Multi-Class Recall | 34 | | F-1 | Multi-Class Classification | Multi-Class F1 | 35 | | Rk | Multi-Class Classification | K-Category Correlation Coefficient as described by Gorodkin (2004) | 36 | | M-AUC | Multi-Class Classification | Multi-Class AUC as described by Hand and Till (2001) | 37 | | RMSE | Regression | Root Mean Squared Error | 38 | | MSE | Regression | Mean Squared Error | 39 | | MAE | Regression | Mean Absolute Error | 40 | | R-Square | Regression | Coefficient of Determination | 41 | | Correlation | Regression | Linear Correlation Coefficient | 42 | 43 | ## Usage 44 | 45 | ### Binary Classification 46 | 47 | The `BinaryConfusionMatrix` struct provides functionality for computing common binary classification metrics. 
48 | 49 | ```rust 50 | use eval_metrics::error::EvalError; 51 | use eval_metrics::classification::BinaryConfusionMatrix; 52 | 53 | fn main() -> Result<(), EvalError> { 54 | // note: these scores could also be f32 values 55 | let scores = vec![0.5, 0.2, 0.7, 0.4, 0.1, 0.3, 0.8, 0.9]; 56 | let labels = vec![false, false, true, false, true, false, false, true]; 57 | let threshold = 0.5; 58 | 59 | // compute confusion matrix from scores and labels 60 | let matrix = BinaryConfusionMatrix::compute(&scores, &labels, threshold)?; 61 | 62 | // counts 63 | let tpc = matrix.tp_count; 64 | let fpc = matrix.fp_count; 65 | let tnc = matrix.tn_count; 66 | let fnc = matrix.fn_count; 67 | 68 | // metrics 69 | let acc = matrix.accuracy()?; 70 | let pre = matrix.precision()?; 71 | let rec = matrix.recall()?; 72 | let f1 = matrix.f1()?; 73 | let mcc = matrix.mcc()?; 74 | 75 | // print matrix to console 76 | println!("{}", matrix); 77 | Ok(()) 78 | } 79 | ``` 80 | ``` 81 | o=========================o 82 | | Label | 83 | o=========================o 84 | | Positive | Negative | 85 | o==============o============o============|============o 86 | | | Positive | 2 | 2 | 87 | | Prediction |============|------------|------------| 88 | | | Negative | 1 | 3 | 89 | o==============o============o=========================o 90 | ``` 91 | 92 | In addition to the metrics derived from the confusion matrix, ROC curves and PR curves can be computed, providing metrics 93 | such as AUC and AP. 
94 | 95 | ```rust 96 | use eval_metrics::error::EvalError; 97 | use eval_metrics::classification::{RocCurve, RocPoint, PrCurve, PrPoint}; 98 | 99 | fn main() -> Result<(), EvalError> { 100 | // note: these scores could also be f32 values 101 | let scores = vec![0.5, 0.2, 0.7, 0.4, 0.1, 0.3, 0.8, 0.9]; 102 | let labels = vec![false, false, true, false, true, false, false, true]; 103 | 104 | // construct roc curve 105 | let roc = RocCurve::compute(&scores, &labels)?; 106 | // compute auc 107 | let auc = roc.auc(); 108 | // inspect roc curve points 109 | roc.points.iter().for_each(|point| { 110 | let tpr = point.tp_rate; 111 | let fpr = point.fp_rate; 112 | let thresh = point.threshold; 113 | }); 114 | 115 | // construct pr curve 116 | let pr = PrCurve::compute(&scores, &labels)?; 117 | // compute average precision 118 | let ap = pr.ap(); 119 | // inspect pr curve points 120 | pr.points.iter().for_each(|point| { 121 | let pre = point.precision; 122 | let rec = point.recall; 123 | let thresh = point.threshold; 124 | }); 125 | Ok(()) 126 | } 127 | ``` 128 | 129 | ### Multi-Class Classification 130 | 131 | The `MultiConfusionMatrix` struct provides functionality for computing common multi-class classification metrics. 132 | Additionally, averaging methods must be explicitly provided for several of these metrics. 
133 | 134 | ```rust 135 | use eval_metrics::error::EvalError; 136 | use eval_metrics::classification::{MultiConfusionMatrix, Averaging}; 137 | 138 | fn main() -> Result<(), EvalError> { 139 | // note: these scores could also be f32 values 140 | let scores = vec![ 141 | vec![0.3, 0.1, 0.6], 142 | vec![0.5, 0.2, 0.3], 143 | vec![0.2, 0.7, 0.1], 144 | vec![0.3, 0.3, 0.4], 145 | vec![0.5, 0.1, 0.4], 146 | vec![0.8, 0.1, 0.1], 147 | vec![0.3, 0.5, 0.2] 148 | ]; 149 | let labels = vec![2, 1, 1, 2, 0, 2, 0]; 150 | 151 | // compute confusion matrix from scores and labels 152 | let matrix = MultiConfusionMatrix::compute(&scores, &labels)?; 153 | 154 | // get counts 155 | let counts = &matrix.counts; 156 | 157 | // metrics 158 | let acc = matrix.accuracy()?; 159 | let mac_pre = matrix.precision(&Averaging::Macro)?; 160 | let wgt_pre = matrix.precision(&Averaging::Weighted)?; 161 | let mac_rec = matrix.recall(&Averaging::Macro)?; 162 | let wgt_rec = matrix.recall(&Averaging::Weighted)?; 163 | let mac_f1 = matrix.f1(&Averaging::Macro)?; 164 | let wgt_f1 = matrix.f1(&Averaging::Weighted)?; 165 | let rk = matrix.rk()?; 166 | 167 | // print matrix to console 168 | println!("{}", matrix); 169 | Ok(()) 170 | } 171 | ``` 172 | ``` 173 | o===================================o 174 | | Label | 175 | o===================================o 176 | | Class-1 | Class-2 | Class-3 | 177 | o==============o===========o===========|===========|===========o 178 | | | Class-1 | 1 | 1 | 1 | 179 | | |===========|-----------|-----------|-----------| 180 | | Prediction | Class-2 | 1 | 1 | 0 | 181 | | |===========|-----------|-----------|-----------| 182 | | | Class-3 | 0 | 0 | 2 | 183 | o==============o===========o===================================o 184 | ``` 185 | 186 | In addition to these global metrics, per-class metrics can be obtained as well. 
187 | 188 | ```rust 189 | use eval_metrics::error::EvalError; 190 | use eval_metrics::classification::{MultiConfusionMatrix}; 191 | 192 | fn main() -> Result<(), EvalError> { 193 | // note: these scores could also be f32 values 194 | let scores = vec![ 195 | vec![0.3, 0.1, 0.6], 196 | vec![0.5, 0.2, 0.3], 197 | vec![0.2, 0.7, 0.1], 198 | vec![0.3, 0.3, 0.4], 199 | vec![0.5, 0.1, 0.4], 200 | vec![0.8, 0.1, 0.1], 201 | vec![0.3, 0.5, 0.2] 202 | ]; 203 | let labels = vec![2, 1, 1, 2, 0, 2, 0]; 204 | 205 | // compute confusion matrix from scores and labels 206 | let matrix = MultiConfusionMatrix::compute(&scores, &labels)?; 207 | 208 | // per-class metrics 209 | let pca = matrix.per_class_accuracy(); 210 | let pcp = matrix.per_class_precision(); 211 | let pcr = matrix.per_class_recall(); 212 | let pcf = matrix.per_class_f1(); 213 | let pcm = matrix.per_class_mcc(); 214 | 215 | // print per-class metrics to console 216 | println!("{:?}", pca); 217 | println!("{:?}", pcp); 218 | println!("{:?}", pcr); 219 | println!("{:?}", pcf); 220 | println!("{:?}", pcm); 221 | Ok(()) 222 | } 223 | ``` 224 | ``` 225 | [Ok(0.5714285714285714), Ok(0.7142857142857143), Ok(0.8571428571428571)] 226 | [Ok(0.3333333333333333), Ok(0.5), Ok(1.0)] 227 | [Ok(0.5), Ok(0.5), Ok(0.6666666666666666)] 228 | [Ok(0.4), Ok(0.5), Ok(0.8)] 229 | [Ok(0.09128709291752773), Ok(0.3), Ok(0.7302967433402215)] 230 | ``` 231 | 232 | 233 | In addition to the metrics derived from the confusion matrix, the M-AUC (multi-class AUC) metric as described by 234 | Hand and Till (2001) is provided as a standalone function: 235 | 236 | ```rust 237 | let mauc = m_auc(&scores, &labels)?; 238 | ``` 239 | 240 | ### Regression 241 | 242 | All regression metrics operate on a pair of scores and labels. 
243 | 244 | ```rust 245 | use eval_metrics::error::EvalError; 246 | use eval_metrics::regression::*; 247 | 248 | fn main() -> Result<(), EvalError> { 249 | 250 | // note: these could also be f32 values 251 | let scores = vec![0.4, 0.7, -1.2, 2.5, 0.3]; 252 | let labels = vec![0.2, 1.1, -0.9, 1.3, -0.2]; 253 | 254 | // root mean squared error 255 | let rmse = rmse(&scores, &labels)?; 256 | // mean squared error 257 | let mse = mse(&scores, &labels)?; 258 | // mean absolute error 259 | let mae = mae(&scores, &labels)?; 260 | // coefficient of determination 261 | let rsq = rsq(&scores, &labels)?; 262 | // pearson correlation coefficient 263 | let corr = corr(&scores, &labels)?; 264 | Ok(()) 265 | } 266 | ``` 267 | -------------------------------------------------------------------------------- /src/classification.rs: -------------------------------------------------------------------------------- 1 | //! 2 | //! Provides support for both binary and multi-class classification metrics 3 | //! 
4 | 5 | use std::cmp::Ordering; 6 | use crate::util; 7 | use crate::numeric::Scalar; 8 | use crate::error::EvalError; 9 | use crate::display; 10 | 11 | /// 12 | /// Confusion matrix for binary classification 13 | /// 14 | #[derive(Copy, Clone, Debug, Eq, PartialEq)] 15 | pub struct BinaryConfusionMatrix { 16 | /// true positive count 17 | pub tp_count: usize, 18 | /// false positive count 19 | pub fp_count: usize, 20 | /// true negative count 21 | pub tn_count: usize, 22 | /// false negative count 23 | pub fn_count: usize, 24 | /// count sum 25 | sum: usize 26 | } 27 | 28 | impl BinaryConfusionMatrix { 29 | 30 | /// 31 | /// Computes a new binary confusion matrix from the provided scores and labels 32 | /// 33 | /// # Arguments 34 | /// 35 | /// * `scores` - vector of scores 36 | /// * `labels` - vector of boolean labels 37 | /// * `threshold` - decision threshold value for classifying scores 38 | /// 39 | /// # Errors 40 | /// 41 | /// An invalid input error will be returned if either scores or labels are empty, or if their 42 | /// lengths do not match. An undefined metric error will be returned if scores contain any value 43 | /// that is not finite. 
44 | /// 45 | /// # Examples 46 | /// 47 | /// ``` 48 | /// # use eval_metrics::error::EvalError; 49 | /// # fn main() -> Result<(), EvalError> { 50 | /// use eval_metrics::classification::BinaryConfusionMatrix; 51 | /// let scores = vec![0.4, 0.7, 0.1, 0.3, 0.9]; 52 | /// let labels = vec![false, true, false, true, true]; 53 | /// let matrix = BinaryConfusionMatrix::compute(&scores, &labels, 0.5)?; 54 | /// # Ok(())} 55 | /// ``` 56 | /// 57 | pub fn compute(scores: &Vec, 58 | labels: &Vec, 59 | threshold: T) -> Result { 60 | util::validate_input_dims(scores, labels).and_then(|()| { 61 | let mut counts = [0, 0, 0, 0]; 62 | for (&score, &label) in scores.iter().zip(labels) { 63 | if !score.is_finite() { 64 | return Err(EvalError::infinite_value()) 65 | } else if score >= threshold && label { 66 | counts[3] += 1; 67 | } else if score >= threshold { 68 | counts[2] += 1; 69 | } else if score < threshold && !label { 70 | counts[0] += 1; 71 | } else { 72 | counts[1] += 1; 73 | } 74 | }; 75 | let sum = counts.iter().sum(); 76 | Ok(BinaryConfusionMatrix { 77 | tp_count: counts[3], 78 | fp_count: counts[2], 79 | tn_count: counts[0], 80 | fn_count: counts[1], 81 | sum 82 | }) 83 | }) 84 | } 85 | 86 | /// 87 | /// Constructs a binary confusion matrix with the provided counts 88 | /// 89 | /// # Arguments 90 | /// 91 | /// * `tp_count` - true positive count 92 | /// * `fp_count` - false positive count 93 | /// * `tn_count` - true negative count 94 | /// * `fn_count` - false negative count 95 | /// 96 | /// # Errors 97 | /// 98 | /// An invalid input error will be returned if all provided counts are zero 99 | /// 100 | pub fn from_counts(tp_count: usize, 101 | fp_count: usize, 102 | tn_count: usize, 103 | fn_count: usize) -> Result { 104 | match tp_count + fp_count + tn_count + fn_count { 105 | 0 => Err(EvalError::invalid_input("Confusion matrix has all zero counts")), 106 | sum => Ok(BinaryConfusionMatrix {tp_count, fp_count, tn_count, fn_count, sum}) 107 | } 108 | } 109 | 
110 | /// 111 | /// Computes accuracy 112 | /// 113 | pub fn accuracy(&self) -> Result { 114 | let num = self.tp_count + self.tn_count; 115 | match self.sum { 116 | // This should never happen as long as we prevent empty confusion matrices 117 | 0 => Err(EvalError::undefined_metric("Accuracy")), 118 | sum => Ok(num as f64 / sum as f64) 119 | } 120 | } 121 | 122 | /// 123 | /// Computes precision 124 | /// 125 | pub fn precision(&self) -> Result { 126 | match self.tp_count + self.fp_count { 127 | 0 => Err(EvalError::undefined_metric("Precision")), 128 | den => Ok((self.tp_count as f64) / den as f64) 129 | } 130 | } 131 | 132 | /// 133 | /// Computes recall 134 | /// 135 | pub fn recall(&self) -> Result { 136 | match self.tp_count + self.fn_count { 137 | 0 => Err(EvalError::undefined_metric("Recall")), 138 | den => Ok((self.tp_count as f64) / den as f64) 139 | } 140 | } 141 | 142 | /// 143 | /// Computes F1 144 | /// 145 | pub fn f1(&self) -> Result { 146 | match (self.precision(), self.recall()) { 147 | (Ok(p), Ok(r)) if p == 0.0 && r == 0.0 => Ok(0.0), 148 | (Ok(p), Ok(r)) => Ok(2.0 * (p * r) / (p + r)), 149 | (Err(e), _) => Err(e), 150 | (_, Err(e)) => Err(e) 151 | } 152 | } 153 | 154 | /// 155 | /// Computes Matthews correlation coefficient (phi) 156 | /// 157 | pub fn mcc(&self) -> Result { 158 | let n = self.sum as f64; 159 | let s = (self.tp_count + self.fn_count) as f64 / n; 160 | let p = (self.tp_count + self.fp_count) as f64 / n; 161 | match (p * s * (1.0 - s) * (1.0 - p)).sqrt() { 162 | den if den == 0.0 => Err(EvalError::undefined_metric("MCC")), 163 | den => Ok(((self.tp_count as f64 / n) - s * p) / den) 164 | } 165 | } 166 | } 167 | 168 | impl std::fmt::Display for BinaryConfusionMatrix { 169 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 170 | let counts = vec![ 171 | vec![self.tp_count, self.fp_count], 172 | vec![self.fn_count, self.tn_count] 173 | ]; 174 | let outcomes = vec![String::from("Positive"), 
String::from("Negative")]; 175 | write!(f, "{}", display::stringify_confusion_matrix(&counts, &outcomes)) 176 | } 177 | } 178 | 179 | /// 180 | /// Represents a single point along a roc curve 181 | /// 182 | #[derive(Copy, Clone, Debug, PartialEq)] 183 | pub struct RocPoint { 184 | /// True positive rate 185 | pub tp_rate: T, 186 | /// False positive rate 187 | pub fp_rate: T, 188 | /// Score threshold 189 | pub threshold: T 190 | } 191 | 192 | /// 193 | /// Represents a full roc curve 194 | /// 195 | #[derive(Clone, Debug)] 196 | pub struct RocCurve { 197 | /// Roc curve points 198 | pub points: Vec>, 199 | /// Length 200 | dim: usize 201 | } 202 | 203 | impl RocCurve { 204 | 205 | /// 206 | /// Computes the roc curve from the provided data 207 | /// 208 | /// # Arguments 209 | /// 210 | /// * `scores` - vector of scores 211 | /// * `labels` - vector of labels 212 | /// 213 | /// # Errors 214 | /// 215 | /// An invalid input error will be returned if either scores or labels are empty or contain a 216 | /// single data point, or if their lengths do not match. An undefined metric error will be 217 | /// returned if scores contain any value that is not finite or if labels are all constant. 
218 | /// 219 | /// # Examples 220 | /// 221 | /// ``` 222 | /// # use eval_metrics::error::EvalError; 223 | /// # fn main() -> Result<(), EvalError> { 224 | /// use eval_metrics::classification::RocCurve; 225 | /// let scores = vec![0.4, 0.7, 0.1, 0.3, 0.9]; 226 | /// let labels = vec![false, true, false, true, true]; 227 | /// let roc = RocCurve::compute(&scores, &labels)?; 228 | /// # Ok(())} 229 | /// ``` 230 | /// 231 | pub fn compute(scores: &Vec, labels: &Vec) -> Result, EvalError> { 232 | util::validate_input_dims(scores, labels).and_then(|()| { 233 | // roc not defined for a single data point 234 | let n = match scores.len() { 235 | 1 => return Err(EvalError::invalid_input( 236 | "Unable to compute roc curve on single data point" 237 | )), 238 | len => len 239 | }; 240 | let (mut pairs, np) = create_pairs(scores, labels)?; 241 | let nn = n - np; 242 | sort_pairs_descending(&mut pairs); 243 | let mut tpc = if pairs[0].1 {1} else {0}; 244 | let mut fpc = 1 - tpc; 245 | let mut points = Vec::>::new(); 246 | let mut last_tpr = T::zero(); 247 | let mut last_fpr = T::zero(); 248 | let mut trend: Option = None; 249 | 250 | for i in 1..n { 251 | if pairs[i].0 != pairs[i-1].0 { 252 | let tp_rate = T::from_usize(tpc) / T::from_usize(np); 253 | let fp_rate = T::from_usize(fpc) / T::from_usize(nn); 254 | if !tp_rate.is_finite() || !fp_rate.is_finite() { 255 | return Err(EvalError::undefined_metric("ROC")) 256 | } 257 | let threshold = pairs[i-1].0; 258 | match trend { 259 | Some(RocTrend::Horizontal) => if tp_rate > last_tpr { 260 | points.push(RocPoint {tp_rate, fp_rate, threshold}); 261 | } else if let Some(mut point) = points.last_mut() { 262 | point.fp_rate = fp_rate; 263 | point.threshold = threshold; 264 | }, 265 | Some(RocTrend::Vertical) => if fp_rate > last_fpr { 266 | points.push(RocPoint {tp_rate, fp_rate, threshold}) 267 | } else if let Some(mut point) = points.last_mut() { 268 | point.tp_rate = tp_rate; 269 | point.threshold = threshold; 270 | }, 271 | _ 
=> points.push(RocPoint {tp_rate, fp_rate, threshold}), 272 | } 273 | 274 | trend = if fp_rate > last_fpr && tp_rate == last_tpr { 275 | Some(RocTrend::Horizontal) 276 | } else if tp_rate > last_tpr && fp_rate == last_fpr { 277 | Some(RocTrend::Vertical) 278 | } else { 279 | Some(RocTrend::Diagonal) 280 | }; 281 | last_tpr = tp_rate; 282 | last_fpr = fp_rate; 283 | } 284 | if pairs[i].1 { 285 | tpc += 1; 286 | } else { 287 | fpc += 1; 288 | } 289 | } 290 | 291 | if let Some(mut point) = points.last_mut() { 292 | if point.tp_rate != T::one() || point.fp_rate != T::one() { 293 | let threshold = pairs.last().unwrap().0; 294 | match trend { 295 | Some(RocTrend::Horizontal) if point.tp_rate == T::one() => { 296 | point.fp_rate = T::one(); 297 | point.threshold = threshold; 298 | }, 299 | Some(RocTrend::Vertical) if point.fp_rate == T::one() => { 300 | point.tp_rate = T::one(); 301 | point.threshold = threshold; 302 | } 303 | _ => points.push(RocPoint { 304 | tp_rate: T::one(), fp_rate: T::one(), threshold 305 | }) 306 | } 307 | } 308 | } 309 | 310 | match points.len() { 311 | 0 => Err(EvalError::constant_input_data()), 312 | dim => Ok(RocCurve {points, dim}) 313 | } 314 | }) 315 | } 316 | 317 | /// 318 | /// Computes AUC from the roc curve 319 | /// 320 | pub fn auc(&self) -> T { 321 | let mut val = self.points[0].tp_rate * self.points[0].fp_rate / T::from_f64(2.0); 322 | for i in 1..self.dim { 323 | let fpr_diff = self.points[i].fp_rate - self.points[i-1].fp_rate; 324 | let a = self.points[i-1].tp_rate * fpr_diff; 325 | let tpr_diff = self.points[i].tp_rate - self.points[i-1].tp_rate; 326 | let b = tpr_diff * fpr_diff / T::from_f64(2.0); 327 | val += a + b; 328 | } 329 | return val 330 | } 331 | } 332 | 333 | /// 334 | /// Represents a single point along a precision-recall curve 335 | /// 336 | #[derive(Copy, Clone, Debug, PartialEq)] 337 | pub struct PrPoint { 338 | /// Precision value 339 | pub precision: T, 340 | /// Recall value 341 | pub recall: T, 342 | /// Score 
threshold 343 | pub threshold: T 344 | } 345 | 346 | /// 347 | /// Represents a full precision-recall curve 348 | /// 349 | #[derive(Clone, Debug)] 350 | pub struct PrCurve { 351 | /// PR curve points 352 | pub points: Vec>, 353 | /// Length 354 | dim: usize 355 | } 356 | 357 | impl PrCurve { 358 | 359 | /// 360 | /// Computes the precision-recall curve from the provided data 361 | /// 362 | /// # Arguments 363 | /// 364 | /// * `scores` - vector of scores 365 | /// * `labels` - vector of labels 366 | /// 367 | /// # Errors 368 | /// 369 | /// An invalid input error will be returned if either scores or labels are empty or contain a 370 | /// single data point, or if their lengths do not match. An undefined metric error will be 371 | /// returned if scores contain any value that is not finite, or if labels are all false. 372 | /// 373 | /// # Examples 374 | /// 375 | /// ``` 376 | /// # use eval_metrics::error::EvalError; 377 | /// # fn main() -> Result<(), EvalError> { 378 | /// use eval_metrics::classification::PrCurve; 379 | /// let scores = vec![0.4, 0.7, 0.1, 0.3, 0.9]; 380 | /// let labels = vec![false, true, false, true, true]; 381 | /// let pr = PrCurve::compute(&scores, &labels)?; 382 | /// # Ok(())} 383 | /// ``` 384 | /// 385 | pub fn compute(scores: &Vec, labels: &Vec) -> Result, EvalError> { 386 | util::validate_input_dims(scores, labels).and_then(|()| { 387 | let n = match scores.len() { 388 | 1 => return Err(EvalError::invalid_input( 389 | "Unable to compute pr curve on single data point" 390 | )), 391 | len => len 392 | }; 393 | let (mut pairs, mut fnc) = create_pairs(scores, labels)?; 394 | sort_pairs_descending(&mut pairs); 395 | let mut tpc = 0; 396 | let mut fpc = 0; 397 | let mut points = Vec::>::new(); 398 | let mut last_rec = T::zero(); 399 | 400 | for i in 0..n { 401 | if pairs[i].1 { 402 | tpc += 1; 403 | fnc -= 1; 404 | } else { 405 | fpc += 1; 406 | } 407 | if (i < n-1 && pairs[i].0 != pairs[i+1].0) || i == n-1 { 408 | let precision = 
T::from_usize(tpc) / T::from_usize(tpc + fpc); 409 | let recall = T::from_usize(tpc) / T::from_usize(tpc + fnc); 410 | if !precision.is_finite() || !recall.is_finite() { 411 | return Err(EvalError::undefined_metric("PR")) 412 | } 413 | let threshold = pairs[i].0; 414 | if recall != last_rec { 415 | points.push(PrPoint {precision, recall, threshold}); 416 | } 417 | last_rec = recall; 418 | } 419 | } 420 | 421 | let dim = points.len(); 422 | Ok(PrCurve {points, dim}) 423 | }) 424 | } 425 | 426 | /// 427 | /// Computes average precision from the PR curve 428 | /// 429 | pub fn ap(&self) -> T { 430 | let mut val = self.points[0].precision * self.points[0].recall; 431 | for i in 1..self.dim { 432 | let rec_diff = self.points[i].recall - self.points[i-1].recall; 433 | val += rec_diff * self.points[i].precision; 434 | } 435 | return val; 436 | } 437 | } 438 | 439 | 440 | /// 441 | /// Confusion matrix for multi-class classification, in which rows represent predicted counts and 442 | /// columns represent labeled counts 443 | /// 444 | #[derive(Clone, Debug, Eq, PartialEq)] 445 | pub struct MultiConfusionMatrix { 446 | /// output dimension 447 | pub dim: usize, 448 | /// count data 449 | pub counts: Vec>, 450 | /// count sum 451 | sum: usize 452 | } 453 | 454 | impl MultiConfusionMatrix { 455 | 456 | /// 457 | /// Computes a new confusion matrix from the provided scores and labels 458 | /// 459 | /// # Arguments 460 | /// 461 | /// * `scores` - vector of class scores 462 | /// * `labels` - vector of class labels (indexed at zero) 463 | /// 464 | /// # Errors 465 | /// 466 | /// An invalid input error will be returned if either scores or labels are empty, or if their 467 | /// lengths do not match. An undefined metric error will be returned if scores contain any value 468 | /// that is not finite. 
469 | /// 470 | /// # Examples 471 | /// 472 | /// ``` 473 | /// # use eval_metrics::error::EvalError; 474 | /// # fn main() -> Result<(), EvalError> { 475 | /// use eval_metrics::classification::MultiConfusionMatrix; 476 | /// let scores = vec![ 477 | /// vec![0.3, 0.1, 0.6], 478 | /// vec![0.5, 0.2, 0.3], 479 | /// vec![0.2, 0.7, 0.1], 480 | /// vec![0.3, 0.3, 0.4], 481 | /// vec![0.5, 0.1, 0.4], 482 | /// vec![0.8, 0.1, 0.1], 483 | /// vec![0.3, 0.5, 0.2] 484 | /// ]; 485 | /// let labels = vec![2, 1, 1, 2, 0, 2, 0]; 486 | /// let matrix = MultiConfusionMatrix::compute(&scores, &labels)?; 487 | /// # Ok(())} 488 | /// ``` 489 | /// 490 | pub fn compute(scores: &Vec>, 491 | labels: &Vec) -> Result { 492 | util::validate_input_dims(scores, labels).and_then(|()| { 493 | let dim = scores[0].len(); 494 | let mut counts = vec![vec![0; dim]; dim]; 495 | let mut sum = 0; 496 | for (i, s) in scores.iter().enumerate() { 497 | if s.iter().any(|v| !v.is_finite()) { 498 | return Err(EvalError::infinite_value()) 499 | } else if s.len() != dim { 500 | return Err(EvalError::invalid_input("Inconsistent score dimension")) 501 | } else if labels[i] >= dim { 502 | return Err(EvalError::invalid_input("Labels have more classes than scores")) 503 | } 504 | let ind = s.iter().enumerate().max_by(|(_, a), (_, b)| { 505 | a.partial_cmp(b).unwrap_or(Ordering::Equal) 506 | }).map(|(mi, _)| mi).ok_or(EvalError::constant_input_data())?; 507 | counts[ind][labels[i]] += 1; 508 | sum += 1; 509 | } 510 | Ok(MultiConfusionMatrix {dim, counts, sum}) 511 | }) 512 | } 513 | 514 | /// 515 | /// Constructs a multi confusion matrix with the provided counts 516 | /// 517 | /// # Arguments 518 | /// 519 | /// * `counts` - vector of vector of counts, where each inner vector represents a row in the 520 | /// confusion matrix 521 | /// 522 | /// # Errors 523 | /// 524 | /// An invalid input error will be returned if the counts are not a square matrix, or if the 525 | /// counts are all zero 526 | /// 527 | 
/// # Examples 528 | /// 529 | /// ``` 530 | /// # use eval_metrics::error::EvalError; 531 | /// # fn main() -> Result<(), EvalError> { 532 | /// use eval_metrics::classification::MultiConfusionMatrix; 533 | /// let counts = vec![ 534 | /// vec![8, 3, 2], 535 | /// vec![1, 5, 3], 536 | /// vec![2, 1, 9] 537 | /// ]; 538 | /// let matrix = MultiConfusionMatrix::from_counts(counts)?; 539 | /// # Ok(())} 540 | /// ``` 541 | /// 542 | pub fn from_counts(counts: Vec>) -> Result { 543 | let dim = counts.len(); 544 | let mut sum = 0; 545 | for row in &counts { 546 | sum += row.iter().sum::(); 547 | if row.len() != dim { 548 | let msg = format!("Inconsistent column length ({})", row.len()); 549 | return Err(EvalError::invalid_input(msg.as_str())); 550 | } 551 | } 552 | if sum == 0 { 553 | Err(EvalError::invalid_input("Confusion matrix has all zero counts")) 554 | } else { 555 | Ok(MultiConfusionMatrix {dim, counts, sum}) 556 | } 557 | } 558 | 559 | /// 560 | /// Computes accuracy 561 | /// 562 | pub fn accuracy(&self) -> Result { 563 | match self.sum { 564 | // This should never happen as long as we prevent empty confusion matrices 565 | 0 => Err(EvalError::undefined_metric("Accuracy")), 566 | sum => { 567 | let mut correct = 0; 568 | for i in 0..self.dim { 569 | correct += self.counts[i][i]; 570 | } 571 | Ok(correct as f64 / sum as f64) 572 | } 573 | } 574 | } 575 | 576 | /// 577 | /// Computes precision, which necessarily requires a specified averaging method 578 | /// 579 | /// # Arguments 580 | /// 581 | /// * `avg` - averaging method, which can be either 'Macro' or 'Weighted' 582 | /// 583 | pub fn precision(&self, avg: &Averaging) -> Result { 584 | self.agg_metric(&self.per_class_precision(), avg) 585 | } 586 | 587 | /// 588 | /// Computes recall, which necessarily requires a specified averaging method 589 | /// 590 | /// # Arguments 591 | /// 592 | /// * `avg` - averaging method, which can be either 'Macro' or 'Weighted' 593 | /// 594 | pub fn recall(&self, avg: 
&Averaging) -> Result { 595 | self.agg_metric(&self.per_class_recall(), avg) 596 | } 597 | 598 | /// 599 | /// Computes F1, which necessarily requires a specified averaging method 600 | /// 601 | /// # Arguments 602 | /// 603 | /// * `avg` - averaging method, which can be either 'Macro' or 'Weighted' 604 | /// 605 | pub fn f1(&self, avg: &Averaging) -> Result { 606 | self.agg_metric(&self.per_class_f1(), avg) 607 | } 608 | 609 | /// 610 | /// Computes Rk, also known as the multi-class Matthews correlation coefficient following the 611 | /// approach of Gorodkin in "Comparing two K-category assignments by a K-category correlation 612 | /// coefficient" (2004) 613 | /// 614 | pub fn rk(&self) -> Result { 615 | let mut t = vec![0.0; self.dim]; 616 | let mut p = vec![0.0; self.dim]; 617 | let mut c = 0.0; 618 | let s = self.sum as f64; 619 | 620 | for i in 0..self.dim { 621 | c += self.counts[i][i] as f64; 622 | for j in 0..self.dim { 623 | t[j] += self.counts[i][j] as f64; 624 | p[i] += self.counts[i][j] as f64; 625 | } 626 | } 627 | 628 | let tt = t.iter().fold(0.0, |acc, val| acc + (val * val)); 629 | let pp = p.iter().fold(0.0, |acc, val| acc + (val * val)); 630 | let tp = t.iter().zip(p).fold(0.0, |acc, (t_val, p_val)| acc + t_val * p_val); 631 | let num = c * s - tp; 632 | let den = (s * s - pp).sqrt() * (s * s - tt).sqrt(); 633 | 634 | if den == 0.0 { 635 | Err(EvalError::undefined_metric("Rk")) 636 | } else { 637 | Ok(num / den) 638 | } 639 | } 640 | 641 | /// 642 | /// Computes per-class accuracy, resulting in a vector of values for each class 643 | /// 644 | pub fn per_class_accuracy(&self) -> Vec> { 645 | self.per_class_binary_metric("accuracy") 646 | } 647 | 648 | /// 649 | /// Computes per-class precision, resulting in a vector of values for each class 650 | /// 651 | pub fn per_class_precision(&self) -> Vec> { 652 | self.per_class_binary_metric("precision") 653 | } 654 | 655 | /// 656 | /// Computes per-class recall, resulting in a vector of values for 
each class 657 | /// 658 | pub fn per_class_recall(&self) -> Vec> { 659 | self.per_class_binary_metric("recall") 660 | } 661 | 662 | /// 663 | /// Computes per-class F1, resulting in a vector of values for each class 664 | /// 665 | pub fn per_class_f1(&self) -> Vec> { 666 | self.per_class_binary_metric("f1") 667 | } 668 | 669 | /// 670 | /// Computes per-class MCC, resulting in a vector of values for each class 671 | /// 672 | pub fn per_class_mcc(&self) -> Vec> { 673 | self.per_class_binary_metric("mcc") 674 | } 675 | 676 | fn per_class_binary_metric(&self, metric: &str) -> Vec> { 677 | (0..self.dim).map(|k| { 678 | let (mut tpc, mut fpc, mut tnc, mut fnc) = (0, 0, 0, 0); 679 | for i in 0..self.dim { 680 | for j in 0..self.dim { 681 | let count = self.counts[i][j]; 682 | if i == k && j == k { 683 | tpc = count; 684 | } else if i == k { 685 | fpc += count; 686 | } else if j == k { 687 | fnc += count; 688 | } else { 689 | tnc += count; 690 | } 691 | } 692 | } 693 | let matrix = BinaryConfusionMatrix::from_counts(tpc, fpc, tnc, fnc)?; 694 | match metric { 695 | "accuracy" => matrix.accuracy(), 696 | "precision" => matrix.precision(), 697 | "recall" => matrix.recall(), 698 | "f1" => matrix.f1(), 699 | "mcc" => matrix.mcc(), 700 | other => Err(EvalError::invalid_metric(other)) 701 | } 702 | }).collect() 703 | } 704 | 705 | fn agg_metric(&self, pcm: &Vec>, 706 | avg: &Averaging) -> Result { 707 | match avg { 708 | Averaging::Macro => self.macro_metric(pcm), 709 | Averaging::Weighted => self.weighted_metric(pcm) 710 | } 711 | } 712 | 713 | fn macro_metric(&self, pcm: &Vec>) -> Result { 714 | pcm.iter().try_fold(0.0, |sum, metric| { 715 | match metric { 716 | Ok(m) => Ok(sum + m), 717 | Err(e) => Err(e.clone()) 718 | } 719 | }).map(|sum| {sum / pcm.len() as f64}) 720 | } 721 | 722 | fn weighted_metric(&self, pcm: &Vec>) -> Result { 723 | pcm.iter() 724 | .zip(self.class_counts().iter()) 725 | .try_fold(0.0, |val, (metric, &class)| { 726 | match metric { 727 | Ok(m) => 
Ok(val + (m * (class as f64) / (self.sum as f64))), 728 | Err(e) => Err(e.clone()) 729 | } 730 | }) 731 | } 732 | 733 | fn class_counts(&self) -> Vec { 734 | let mut counts = vec![0; self.dim]; 735 | for i in 0..self.dim { 736 | for j in 0..self.dim { 737 | counts[j] += self.counts[i][j]; 738 | } 739 | } 740 | counts 741 | } 742 | } 743 | 744 | impl std::fmt::Display for MultiConfusionMatrix { 745 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 746 | if self.dim <= 25 { 747 | let outcomes = (0..self.dim).map(|i| format!("Class-{}", i + 1)).collect(); 748 | write!(f, "{}", display::stringify_confusion_matrix(&self.counts, &outcomes)) 749 | } else { 750 | write!(f, "[Confusion matrix is too large to display]") 751 | } 752 | } 753 | } 754 | 755 | /// 756 | /// Computes multi-class AUC as described by Hand and Till in "A Simple Generalisation of the Area 757 | /// Under the ROC Curve for Multiple Class Classification Problems" (2001) 758 | /// 759 | /// # Arguments 760 | /// 761 | /// * `scores` - vector of class scores 762 | /// * `labels` - vector of class labels 763 | /// 764 | /// # Errors 765 | /// 766 | /// An invalid input error will be returned if either scores or labels are empty or contain a 767 | /// single data point, or if their lengths do not match. An undefined metric error will be 768 | /// returned if scores contain any value that is not finite, or if any pairwise roc curve is not 769 | /// defined for all distinct class label pairs. 
770 | /// 771 | /// # Examples 772 | /// 773 | /// ``` 774 | /// # use eval_metrics::error::EvalError; 775 | /// # fn main() -> Result<(), EvalError> { 776 | /// use eval_metrics::classification::m_auc; 777 | /// let scores = vec![ 778 | /// vec![0.3, 0.1, 0.6], 779 | /// vec![0.5, 0.2, 0.3], 780 | /// vec![0.2, 0.7, 0.1], 781 | /// vec![0.3, 0.3, 0.4], 782 | /// vec![0.5, 0.1, 0.4], 783 | /// vec![0.8, 0.1, 0.1], 784 | /// vec![0.3, 0.5, 0.2] 785 | /// ]; 786 | /// let labels = vec![2, 1, 1, 2, 0, 2, 0]; 787 | /// let metric = m_auc(&scores, &labels)?; 788 | /// # Ok(())} 789 | /// ``` 790 | 791 | pub fn m_auc(scores: &Vec>, labels: &Vec) -> Result { 792 | util::validate_input_dims(scores, labels).and_then(|()| { 793 | let dim = scores[0].len(); 794 | let mut m_sum = T::zero(); 795 | 796 | fn subset(scr: &Vec>, 797 | lab: &Vec, 798 | j: usize, 799 | k: usize) -> (Vec, Vec) { 800 | 801 | scr.iter().zip(lab.iter()).filter(|(_, &l)| { 802 | l == j || l == k 803 | }).map(|(s, &l)| { 804 | (s[k], l == k) 805 | }).unzip() 806 | } 807 | 808 | for j in 0..dim { 809 | for k in 0..j { 810 | let (k_scores, k_labels) = subset(scores, labels, j, k); 811 | let ajk = RocCurve::compute(&k_scores, &k_labels)?.auc(); 812 | let (j_scores, j_labels) = subset(scores, labels, k, j); 813 | let akj = RocCurve::compute(&j_scores, &j_labels)?.auc(); 814 | m_sum += (ajk + akj) / T::from_f64(2.0); 815 | } 816 | } 817 | Ok(m_sum * T::from_f64(2.0) / (T::from_usize(dim) * (T::from_usize(dim) - T::one()))) 818 | }) 819 | } 820 | 821 | /// 822 | /// Specifies the averaging method to use for computing multi-class metrics 823 | /// 824 | #[derive(Copy, Clone, Debug, Eq, PartialEq)] 825 | pub enum Averaging { 826 | /// Macro average, in which the individual metrics for each class are weighted uniformly 827 | Macro, 828 | /// Weighted average, in which the individual metrics for each class are weighted by the number 829 | /// of occurrences of that class 830 | Weighted 831 | } 832 | 833 | enum 
RocTrend { 834 | Horizontal, 835 | Vertical, 836 | Diagonal 837 | } 838 | 839 | fn create_pairs(scores: &Vec, 840 | labels: &Vec) -> Result<(Vec<(T, bool)>, usize), EvalError> { 841 | let n = scores.len(); 842 | let mut pairs = Vec::with_capacity(n); 843 | let mut num_pos = 0; 844 | 845 | for i in 0..n { 846 | if !scores[i].is_finite() { 847 | return Err(EvalError::infinite_value()) 848 | } else if labels[i] { 849 | num_pos += 1; 850 | } 851 | pairs.push((scores[i], labels[i])) 852 | } 853 | Ok((pairs, num_pos)) 854 | } 855 | 856 | fn sort_pairs_descending(pairs: &mut Vec<(T, bool)>) { 857 | pairs.sort_unstable_by(|(s1, _), (s2, _)| { 858 | if s1 > s2 { 859 | Ordering::Less 860 | } else if s1 < s2 { 861 | Ordering::Greater 862 | } else { 863 | Ordering::Equal 864 | } 865 | }); 866 | } 867 | 868 | #[cfg(test)] 869 | mod tests { 870 | use assert_approx_eq::assert_approx_eq; 871 | use super::*; 872 | 873 | fn binary_data() -> (Vec, Vec) { 874 | let scores = vec![0.5, 0.2, 0.7, 0.4, 0.1, 0.3, 0.8, 0.9]; 875 | let labels = vec![false, false, true, false, true, false, false, true]; 876 | (scores, labels) 877 | } 878 | 879 | fn multi_class_data() -> (Vec>, Vec) { 880 | 881 | let scores = vec![ 882 | vec![0.3, 0.1, 0.6], 883 | vec![0.5, 0.2, 0.3], 884 | vec![0.2, 0.7, 0.1], 885 | vec![0.3, 0.3, 0.4], 886 | vec![0.5, 0.1, 0.4], 887 | vec![0.8, 0.1, 0.1], 888 | vec![0.3, 0.5, 0.2] 889 | ]; 890 | let labels = vec![2, 1, 1, 2, 0, 2, 0]; 891 | (scores, labels) 892 | } 893 | 894 | #[test] 895 | fn test_binary_confusion_matrix() { 896 | let (scores, labels) = binary_data(); 897 | let matrix = BinaryConfusionMatrix::compute(&scores, &labels, 0.5).unwrap(); 898 | assert_eq!(matrix.tp_count, 2); 899 | assert_eq!(matrix.fp_count, 2); 900 | assert_eq!(matrix.tn_count, 3); 901 | assert_eq!(matrix.fn_count, 1); 902 | } 903 | 904 | #[test] 905 | fn test_binary_confusion_matrix_empty() { 906 | assert!(BinaryConfusionMatrix::compute( 907 | &Vec::::new(), 908 | &Vec::::new(), 909 | 0.5 910 
| ).is_err()); 911 | } 912 | 913 | #[test] 914 | fn test_binary_confusion_matrix_unequal_length() { 915 | assert!(BinaryConfusionMatrix::compute( 916 | &vec![0.1, 0.2], 917 | &vec![true, false, true], 918 | 0.5 919 | ).is_err()); 920 | } 921 | 922 | #[test] 923 | fn test_binary_confusion_matrix_nan() { 924 | assert!(BinaryConfusionMatrix::compute( 925 | &vec![f64::NAN, 0.2, 0.4], 926 | &vec![true, false, true], 927 | 0.5 928 | ).is_err()); 929 | } 930 | 931 | #[test] 932 | fn test_binary_confusion_matrix_with_counts() { 933 | let matrix = BinaryConfusionMatrix::from_counts(2, 4, 5, 3).unwrap(); 934 | assert_eq!(matrix.tp_count, 2); 935 | assert_eq!(matrix.fp_count, 4); 936 | assert_eq!(matrix.tn_count, 5); 937 | assert_eq!(matrix.fn_count, 3); 938 | assert_eq!(matrix.sum, 14); 939 | assert!(BinaryConfusionMatrix::from_counts(0, 0, 0, 0).is_err()) 940 | } 941 | 942 | #[test] 943 | fn test_binary_accuracy() { 944 | let (scores, labels) = binary_data(); 945 | let matrix = BinaryConfusionMatrix::compute(&scores, &labels, 0.5).unwrap(); 946 | assert_approx_eq!(matrix.accuracy().unwrap(), 0.625); 947 | } 948 | 949 | #[test] 950 | fn test_binary_precision() { 951 | let (scores, labels) = binary_data(); 952 | let matrix = BinaryConfusionMatrix::compute(&scores, &labels, 0.5).unwrap(); 953 | assert_approx_eq!(matrix.precision().unwrap(), 0.5); 954 | 955 | // test edge case where we never predict a positive 956 | assert!(BinaryConfusionMatrix::compute( 957 | &vec![0.4, 0.3, 0.1, 0.2, 0.1], 958 | &vec![true, false, true, false, true], 959 | 0.5 960 | ).unwrap().precision().is_err()); 961 | } 962 | 963 | #[test] 964 | fn test_binary_precision_empty() { 965 | assert!(BinaryConfusionMatrix::compute( 966 | &Vec::::new(), 967 | &Vec::::new(), 968 | 0.5 969 | ).is_err()); 970 | } 971 | 972 | #[test] 973 | fn test_binary_precision_unequal_length() { 974 | assert!(BinaryConfusionMatrix::compute( 975 | &vec![0.1, 0.2], 976 | &vec![true, false, true], 977 | 0.5 978 | ).is_err()); 979 | 
} 980 | 981 | #[test] 982 | fn test_binary_recall() { 983 | let (scores, labels) = binary_data(); 984 | let matrix = BinaryConfusionMatrix::compute(&scores, &labels, 0.5).unwrap(); 985 | assert_approx_eq!(matrix.recall().unwrap(), 2.0 / 3.0); 986 | 987 | // test edge case where we have no positive class 988 | assert!(BinaryConfusionMatrix::compute( 989 | &vec![0.4, 0.3, 0.1, 0.8, 0.7], 990 | &vec![false, false, false, false, false], 991 | 0.5 992 | ).unwrap().recall().is_err()); 993 | } 994 | 995 | #[test] 996 | fn test_binary_recall_empty() { 997 | assert!(BinaryConfusionMatrix::compute( 998 | &Vec::::new(), 999 | &Vec::::new(), 1000 | 0.5 1001 | ).is_err()); 1002 | } 1003 | 1004 | #[test] 1005 | fn test_binary_recall_unequal_length() { 1006 | assert!(BinaryConfusionMatrix::compute( 1007 | &vec![0.1, 0.2], 1008 | &vec![true, false, true], 1009 | 0.5 1010 | ).is_err()); 1011 | } 1012 | 1013 | #[test] 1014 | fn test_binary_f1() { 1015 | let (scores, labels) = binary_data(); 1016 | let matrix = BinaryConfusionMatrix::compute(&scores, &labels, 0.5).unwrap(); 1017 | assert_approx_eq!(matrix.f1().unwrap(), 0.5714285714285715); 1018 | 1019 | // test edge case where we never predict a positive 1020 | assert!(BinaryConfusionMatrix::compute( 1021 | &vec![0.4, 0.3, 0.1, 0.2, 0.1], 1022 | &vec![true, false, true, false, true], 1023 | 0.5 1024 | ).unwrap().f1().is_err()); 1025 | 1026 | // test edge case where we have no positive class 1027 | assert!(BinaryConfusionMatrix::compute( 1028 | &vec![0.4, 0.3, 0.1, 0.8, 0.7], 1029 | &vec![false, false, false, false, false], 1030 | 0.5 1031 | ).unwrap().f1().is_err()); 1032 | } 1033 | 1034 | #[test] 1035 | fn test_binary_f1_empty() { 1036 | assert!(BinaryConfusionMatrix::compute( 1037 | &Vec::::new(), 1038 | &Vec::::new(), 1039 | 0.5 1040 | ).is_err()); 1041 | } 1042 | 1043 | #[test] 1044 | fn test_binary_f1_unequal_length() { 1045 | assert!(BinaryConfusionMatrix::compute( 1046 | &vec![0.1, 0.2], 1047 | &vec![true, false, true], 1048 
| 0.5 1049 | ).is_err()); 1050 | } 1051 | 1052 | #[test] 1053 | fn test_binary_f1_0p_0r() { 1054 | let scores = vec![0.1, 0.2, 0.7, 0.8]; 1055 | let labels = vec![false, true, false, false]; 1056 | 1057 | assert_eq!(BinaryConfusionMatrix::compute(&scores, &labels, 0.5) 1058 | .unwrap() 1059 | .f1() 1060 | .unwrap(), 0.0 1061 | ) 1062 | } 1063 | 1064 | #[test] 1065 | fn test_mcc() { 1066 | let (scores, labels) = binary_data(); 1067 | let matrix = BinaryConfusionMatrix::compute(&scores, &labels, 0.5).unwrap(); 1068 | assert_approx_eq!(matrix.mcc().unwrap(), 0.2581988897471611) 1069 | } 1070 | 1071 | #[test] 1072 | fn test_roc() { 1073 | let (scores, labels) = binary_data(); 1074 | let roc = RocCurve::compute(&scores, &labels).unwrap(); 1075 | 1076 | assert_eq!(roc.dim, 5); 1077 | assert_approx_eq!(roc.points[0].tp_rate, 1.0 / 3.0); 1078 | assert_approx_eq!(roc.points[0].fp_rate, 0.0); 1079 | assert_approx_eq!(roc.points[0].threshold, 0.9); 1080 | assert_approx_eq!(roc.points[1].tp_rate, 1.0 / 3.0); 1081 | assert_approx_eq!(roc.points[1].fp_rate, 0.2); 1082 | assert_approx_eq!(roc.points[1].threshold, 0.8); 1083 | assert_approx_eq!(roc.points[2].tp_rate, 2.0 / 3.0); 1084 | assert_approx_eq!(roc.points[2].fp_rate, 0.2); 1085 | assert_approx_eq!(roc.points[2].threshold, 0.7); 1086 | assert_approx_eq!(roc.points[3].tp_rate, 2.0 / 3.0); 1087 | assert_approx_eq!(roc.points[3].fp_rate, 1.0); 1088 | assert_approx_eq!(roc.points[3].threshold, 0.2); 1089 | assert_approx_eq!(roc.points[4].tp_rate, 1.0); 1090 | assert_approx_eq!(roc.points[4].fp_rate, 1.0); 1091 | assert_approx_eq!(roc.points[4].threshold, 0.1); 1092 | } 1093 | 1094 | #[test] 1095 | fn test_roc_tied_scores() { 1096 | let scores = vec![1.0, 0.1, 1.0, 0.9, 0.5, 0.1, 0.8, 0.9, 1.0, 0.4]; 1097 | let labels = vec![true, false, false, false, false, false, true, true, false, false]; 1098 | let roc = RocCurve::compute(&scores, &labels).unwrap(); 1099 | assert_approx_eq!(roc.points[0].tp_rate, 1.0 / 3.0); 1100 | 
assert_approx_eq!(roc.points[0].fp_rate, 0.2857142857142857); 1101 | assert_approx_eq!(roc.points[0].threshold, 1.0); 1102 | assert_approx_eq!(roc.points[1].tp_rate, 2.0 / 3.0); 1103 | assert_approx_eq!(roc.points[1].fp_rate, 0.42857142857142855); 1104 | assert_approx_eq!(roc.points[1].threshold, 0.9); 1105 | assert_approx_eq!(roc.points[2].tp_rate, 1.0); 1106 | assert_approx_eq!(roc.points[2].fp_rate, 0.42857142857142855); 1107 | assert_approx_eq!(roc.points[2].threshold, 0.8); 1108 | assert_approx_eq!(roc.points[3].tp_rate, 1.0); 1109 | assert_approx_eq!(roc.points[3].fp_rate, 1.0); 1110 | assert_approx_eq!(roc.points[3].threshold, 0.1); 1111 | } 1112 | 1113 | #[test] 1114 | fn test_roc_empty() { 1115 | assert!(RocCurve::compute(&Vec::::new(), &Vec::::new()).is_err()); 1116 | } 1117 | 1118 | #[test] 1119 | fn test_roc_unequal_length() { 1120 | assert!(RocCurve::compute( 1121 | &vec![0.4, 0.5, 0.2], 1122 | &vec![true, false, true, false] 1123 | ).is_err()); 1124 | } 1125 | 1126 | #[test] 1127 | fn test_roc_nan() { 1128 | assert!(RocCurve::compute( 1129 | &vec![0.4, 0.5, 0.2, f64::NAN], 1130 | &vec![true, false, true, false] 1131 | ).is_err()); 1132 | } 1133 | 1134 | #[test] 1135 | fn test_roc_constant_label() { 1136 | let scores = vec![0.1, 0.4, 0.5, 0.7]; 1137 | let labels_true = vec![true; 4]; 1138 | let labels_false = vec![false; 4]; 1139 | assert!(match RocCurve::compute(&scores, &labels_true) { 1140 | Err(err) if err.msg.contains("Undefined") => true, 1141 | _ => false 1142 | }); 1143 | assert!(match RocCurve::compute(&scores, &labels_false) { 1144 | Err(err) if err.msg.contains("Undefined") => true, 1145 | _ => false 1146 | }); 1147 | } 1148 | 1149 | #[test] 1150 | fn test_roc_constant_score() { 1151 | let scores = vec![0.4, 0.4, 0.4, 0.4]; 1152 | let labels = vec![true, false, true, false]; 1153 | assert!(match RocCurve::compute(&scores, &labels) { 1154 | Err(err) if err.msg.contains("Constant") => true, 1155 | _ => false 1156 | }); 1157 | } 1158 | 1159 | 
#[test] 1160 | fn test_auc() { 1161 | let (scores, labels) = binary_data(); 1162 | assert_approx_eq!(RocCurve::compute(&scores, &labels).unwrap().auc(), 0.6); 1163 | 1164 | let scores2 = vec![0.2, 0.5, 0.5, 0.3]; 1165 | let labels2 = vec![false, true, false, true]; 1166 | assert_approx_eq!(RocCurve::compute(&scores2, &labels2).unwrap().auc(), 0.625); 1167 | } 1168 | 1169 | #[test] 1170 | fn test_auc_tied_scores() { 1171 | let scores = vec![0.1, 0.2, 0.3, 0.3, 0.3, 0.7, 0.8]; 1172 | let labels1 = vec![false, false, true, false, true, false, true]; 1173 | let labels2 = vec![false, false, true, true, false, false, true]; 1174 | let labels3 = vec![false, false, false, true, true, false, true]; 1175 | assert_approx_eq!(RocCurve::compute(&scores, &labels1).unwrap().auc(), 0.75); 1176 | assert_approx_eq!(RocCurve::compute(&scores, &labels2).unwrap().auc(), 0.75); 1177 | assert_approx_eq!(RocCurve::compute(&scores, &labels3).unwrap().auc(), 0.75); 1178 | 1179 | let scores2 = vec![1.0, 0.1, 1.0, 0.9, 0.5, 0.1, 0.8, 0.9, 1.0, 0.4]; 1180 | let labels4 = vec![true, false, false, false, false, false, true, true, false, false]; 1181 | assert_approx_eq!(RocCurve::compute(&scores2, &labels4).unwrap().auc(), 0.6904761904761905); 1182 | } 1183 | 1184 | #[test] 1185 | fn test_pr() { 1186 | let (scores, labels) = binary_data(); 1187 | let pr = PrCurve::compute(&scores, &labels).unwrap(); 1188 | assert_approx_eq!(pr.points[0].precision, 1.0); 1189 | assert_approx_eq!(pr.points[0].recall, 1.0 / 3.0); 1190 | assert_approx_eq!(pr.points[0].threshold, 0.9); 1191 | assert_approx_eq!(pr.points[1].precision, 2.0 / 3.0); 1192 | assert_approx_eq!(pr.points[1].recall, 2.0 / 3.0); 1193 | assert_approx_eq!(pr.points[1].threshold, 0.7); 1194 | assert_approx_eq!(pr.points[2].precision, 0.375); 1195 | assert_approx_eq!(pr.points[2].recall, 1.0); 1196 | assert_approx_eq!(pr.points[2].threshold, 0.1); 1197 | } 1198 | 1199 | #[test] 1200 | fn test_pr_empty() { 1201 | 
assert!(PrCurve::compute(&Vec::::new(), &Vec::::new()).is_err()); 1202 | } 1203 | 1204 | #[test] 1205 | fn test_pr_unequal_length() { 1206 | assert!(PrCurve::compute(&vec![0.4, 0.5, 0.2], &vec![true, false, true, false]).is_err()); 1207 | } 1208 | 1209 | #[test] 1210 | fn test_pr_nan() { 1211 | assert!(PrCurve::compute( 1212 | &vec![0.4, 0.5, 0.2, f64::NAN], 1213 | &vec![true, false, true, false] 1214 | ).is_err()); 1215 | } 1216 | 1217 | #[test] 1218 | fn test_pr_constant_label() { 1219 | let scores = vec![0.1, 0.4, 0.5, 0.7]; 1220 | let labels_true = vec![true; 4]; 1221 | let labels_false = vec![false; 4]; 1222 | assert!(PrCurve::compute(&scores, &labels_true).is_ok()); 1223 | assert!(match PrCurve::compute(&scores, &labels_false) { 1224 | Err(err) if err.msg.contains("Undefined") => true, 1225 | _ => false 1226 | }); 1227 | } 1228 | 1229 | #[test] 1230 | fn test_pr_constant_score() { 1231 | let scores = vec![0.4, 0.4, 0.4, 0.4]; 1232 | let labels = vec![true, false, true, false]; 1233 | assert!(PrCurve::compute(&scores, &labels).is_ok()); 1234 | } 1235 | 1236 | #[test] 1237 | fn test_ap() { 1238 | let (scores, labels) = binary_data(); 1239 | assert_approx_eq!(PrCurve::compute(&scores, &labels).unwrap().ap(), 0.6805555555555556); 1240 | 1241 | let scores2 = vec![0.2, 0.5, 0.5, 0.3]; 1242 | let labels2 = vec![false, true, false, true]; 1243 | assert_approx_eq!(PrCurve::compute(&scores2, &labels2).unwrap().ap(), 0.58333333333333); 1244 | } 1245 | 1246 | #[test] 1247 | fn test_ap_tied_scores() { 1248 | let scores = vec![0.1, 0.2, 0.3, 0.3, 0.3, 0.7, 0.8]; 1249 | let labels1 = vec![false, false, true, false, true, false, true]; 1250 | let labels2 = vec![false, false, true, true, false, false, true]; 1251 | let labels3 = vec![false, false, false, true, true, false, true]; 1252 | assert_approx_eq!(PrCurve::compute(&scores, &labels1).unwrap().ap(), 0.7333333333333); 1253 | assert_approx_eq!(PrCurve::compute(&scores, &labels2).unwrap().ap(), 0.7333333333333); 1254 | 
assert_approx_eq!(PrCurve::compute(&scores, &labels3).unwrap().ap(), 0.7333333333333);
    }

    #[test]
    fn test_multi_confusion_matrix() {
        let (scores, labels) = multi_class_data();
        let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
        assert_eq!(matrix.counts, vec![vec![1, 1, 1], vec![1, 1, 0], vec![0, 0, 2]]);
        assert_eq!(matrix.dim, 3);
        assert_eq!(matrix.sum, 7);
    }

    #[test]
    fn test_multi_confusion_matrix_empty() {
        // NOTE(review): generic parameters restored; the dump had stripped all `<...>` text
        let scores: Vec<Vec<f64>> = vec![];
        let labels = Vec::<usize>::new();
        assert!(MultiConfusionMatrix::compute(&scores, &labels).is_err());
    }

    #[test]
    fn test_multi_confusion_matrix_unequal_length() {
        assert!(MultiConfusionMatrix::compute(&vec![vec![0.2, 0.4, 0.4], vec![0.5, 0.1, 0.4]],
                                              &vec![2, 1, 0]).is_err());
    }

    #[test]
    fn test_multi_confusion_matrix_nan() {
        assert!(MultiConfusionMatrix::compute(
            &vec![vec![0.2, 0.4, 0.4], vec![0.5, 0.1, 0.4], vec![0.3, 0.7, f64::NAN]],
            &vec![2, 1, 0]
        ).is_err());
    }

    #[test]
    fn test_multi_confusion_matrix_inconsistent_score_dims() {
        // Third score row has only two entries
        let scores = vec![vec![0.2, 0.4, 0.4], vec![0.5, 0.1, 0.4], vec![0.3, 0.7]];
        let labels = vec![2, 1, 0];
        assert!(MultiConfusionMatrix::compute(&scores, &labels).is_err());
    }

    #[test]
    fn test_multi_confusion_matrix_score_label_dim_mismatch() {
        // Label '3' is out of range for 3-dimensional scores
        let scores = vec![vec![0.2, 0.4, 0.4], vec![0.5, 0.1, 0.4], vec![0.3, 0.2, 0.5]];
        let labels = vec![2, 3, 0];
        assert!(MultiConfusionMatrix::compute(&scores, &labels).is_err());
    }

    #[test]
    fn test_multi_confusion_matrix_counts() {
        let counts = vec![vec![6, 3, 1], vec![4, 2, 7], vec![5, 2, 8]];
        let matrix = MultiConfusionMatrix::from_counts(counts).unwrap();
        assert_eq!(matrix.dim, 3);
        assert_eq!(matrix.sum, 38);
        assert_eq!(matrix.counts, vec![vec![6, 3, 1], vec![4, 2, 7], vec![5, 2, 8]]);
    }

    #[test]
    fn test_multi_confusion_matrix_bad_counts() {
        // Non-square counts must be rejected
        let counts = vec![vec![6, 3, 1], vec![4, 2], vec![5, 2, 8]];
        assert!(MultiConfusionMatrix::from_counts(counts).is_err())
    }

    #[test]
    fn test_multi_accuracy() {
        let (scores, labels) = multi_class_data();
        let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
        assert_approx_eq!(matrix.accuracy().unwrap(), 0.5714285714285714)
    }

    #[test]
    fn test_multi_precision() {
        let (scores, labels) = multi_class_data();
        let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
        assert_approx_eq!(matrix.precision(&Averaging::Macro).unwrap(), 0.611111111111111);
        assert_approx_eq!(matrix.precision(&Averaging::Weighted).unwrap(), 2.0 / 3.0);

        // Class 2 is never predicted, so macro precision is undefined
        assert!(MultiConfusionMatrix::compute(
            &vec![vec![0.6, 0.4, 0.0],
                  vec![0.2, 0.8, 0.0],
                  vec![0.9, 0.1, 0.0],
                  vec![0.3, 0.7, 0.0]],
            &vec![0, 1, 2, 1]
        ).unwrap().precision(&Averaging::Macro).is_err())
    }

    #[test]
    fn test_multi_recall() {
        let (scores, labels) = multi_class_data();
        let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
        assert_approx_eq!(matrix.recall(&Averaging::Macro).unwrap(), 0.5555555555555555);
        assert_approx_eq!(matrix.recall(&Averaging::Weighted).unwrap(), 0.5714285714285714);

        // Class 2 never occurs as a label, so macro recall is undefined
        assert!(MultiConfusionMatrix::compute(
            &vec![vec![0.6, 0.3, 0.1],
                  vec![0.2, 0.5, 0.3],
                  vec![0.8, 0.1, 0.1],
                  vec![0.3, 0.5, 0.2]],
            &vec![0, 1, 0, 1]
        ).unwrap().recall(&Averaging::Macro).is_err())
    }

    #[test]
    fn test_multi_f1() {
        let (scores, labels) = multi_class_data();
        let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
        assert_approx_eq!(matrix.f1(&Averaging::Macro).unwrap(), 0.5666666666666668);
        assert_approx_eq!(matrix.f1(&Averaging::Weighted).unwrap(), 0.6);

        // Undefined precision propagates to f1
        assert!(MultiConfusionMatrix::compute(
            &vec![vec![0.6, 0.4, 0.0],
                  vec![0.2, 0.8, 0.0],
                  vec![0.3, 0.7, 0.0]],
            &vec![0, 2, 1]
        ).unwrap().f1(&Averaging::Macro).is_err());

        // Undefined recall propagates to f1
        assert!(MultiConfusionMatrix::compute(
            &vec![vec![0.6, 0.3, 0.1],
                  vec![0.2, 0.5, 0.3],
                  vec![0.3, 0.5, 0.2]],
            &vec![1, 0, 1]
        ).unwrap().f1(&Averaging::Macro).is_err());
    }

    #[test]
    fn test_multi_f1_0p_0r() {
        let scores = multi_class_data().0;
        // every prediction is wrong
        let labels = vec![1, 2, 0, 0, 1, 1, 0];

        assert_eq!(MultiConfusionMatrix::compute(&scores, &labels)
            .unwrap()
            .f1(&Averaging::Macro)
            .unwrap(), 0.0
        )
    }

    #[test]
    fn test_rk() {
        let (scores, labels) = multi_class_data();
        let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
        assert_approx_eq!(matrix.rk().unwrap(), 0.375)
    }

    #[test]
    fn test_per_class_accuracy() {
        let (scores, labels) = multi_class_data();
        let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
        let pca = matrix.per_class_accuracy();
        assert_eq!(pca.len(), 3);
        assert_approx_eq!(pca[0].as_ref().unwrap(), 0.5714285714285714);
        assert_approx_eq!(pca[1].as_ref().unwrap(), 0.7142857142857143);
        assert_approx_eq!(pca[2].as_ref().unwrap(), 0.8571428571428571);
    }

    #[test]
    fn test_per_class_precision() {
        let (scores, labels) = multi_class_data();
        let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
        let pcp = matrix.per_class_precision();
        assert_eq!(pcp.len(), 3);
        assert_approx_eq!(pcp[0].as_ref().unwrap(), 0.3333333333333333);
        assert_approx_eq!(pcp[1].as_ref().unwrap(), 0.5);
        assert_approx_eq!(pcp[2].as_ref().unwrap(), 1.0);
        println!("{}", matrix);
    }

    #[test]
    fn test_per_class_recall() {
        let (scores, labels) = multi_class_data();
        let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
        let pcr = matrix.per_class_recall();
        assert_eq!(pcr.len(), 3);
        assert_approx_eq!(pcr[0].as_ref().unwrap(), 0.5);
        assert_approx_eq!(pcr[1].as_ref().unwrap(), 0.5);
        assert_approx_eq!(pcr[2].as_ref().unwrap(), 0.6666666666666666);
    }

    #[test]
    fn test_per_class_f1() {
        let (scores, labels) = multi_class_data();
        let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
        let pcf = matrix.per_class_f1();
        assert_eq!(pcf.len(), 3);
        assert_approx_eq!(pcf[0].as_ref().unwrap(), 0.4);
        assert_approx_eq!(pcf[1].as_ref().unwrap(), 0.5);
        assert_approx_eq!(pcf[2].as_ref().unwrap(), 0.8);
    }

    #[test]
    fn test_per_class_mcc() {
        let (scores, labels) = multi_class_data();
        let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
        let pcm = matrix.per_class_mcc();
        assert_eq!(pcm.len(), 3);
        assert_approx_eq!(pcm[0].as_ref().unwrap(), 0.09128709291752773);
        assert_approx_eq!(pcm[1].as_ref().unwrap(), 0.3);
        assert_approx_eq!(pcm[2].as_ref().unwrap(), 0.7302967433402215);
    }

    #[test]
    fn test_m_auc() {
        let (scores, labels) = multi_class_data();
        assert_approx_eq!(m_auc(&scores, &labels).unwrap(), 0.673611111111111)
    }

    #[test]
    fn test_m_auc_empty() {
        assert!(m_auc(&Vec::<Vec<f64>>::new(), &Vec::<usize>::new()).is_err());
    }

    #[test]
    fn test_m_auc_unequal_length() {
        assert!(m_auc(&Vec::<Vec<f64>>::new(), &vec![3, 0, 1, 2]).is_err());
    }

    #[test]
    fn test_m_auc_nan() {
        let scores = vec![
            vec![0.3, 0.1, 0.6],
            vec![0.5, f64::NAN, 0.3],
            vec![0.2, 0.7, 0.1],
        ];
        let labels = vec![1, 2, 0];
        assert!(m_auc(&scores, &labels).is_err());
    }

    #[test]
    fn test_m_auc_constant_label() {
        let scores = vec![
            vec![0.3, 0.1, 0.6],
            vec![0.5, 0.2, 0.3],
            vec![0.2, 0.7, 0.1],
            vec![0.8, 0.1, 0.1],
        ];
        // AUC is undefined when only one class is ever the label
        let labels = vec![1, 1, 1, 1];
        assert!(m_auc(&scores, &labels).is_err())
    }
}
--------------------------------------------------------------------------------
/src/display.rs:
--------------------------------------------------------------------------------

///
/// Creates a string representation of a confusion matrix with the provided counts and class names
///
/// # Arguments
///
/// * `counts` the confusion matrix counts (row-major, `counts[prediction][label]`)
/// * `classes` the class outcome names
///
pub fn stringify_confusion_matrix(counts: &Vec<Vec<usize>>, classes: &Vec<String>) -> String {
    // Compute max length of outcome names (in chars)
    let max_class_length = classes.iter().fold(0, |max, outcome| {
        match outcome.chars().count() {
            length if length > max => length,
            _ => max
        }
    });
    // Two spaces on either side, plus a leading pipe character
    let padded_class_length = max_class_length + 5;
    // Two spaces on either side of "Prediction", plus a leading pipe character
    let prediction_wing_length = padded_class_length + 15;
    // Build the output string
    let mut output = String::new();
    write_cm_top_rows(classes, prediction_wing_length, padded_class_length, &mut output);
    write_cm_data_rows(counts, classes, prediction_wing_length, padded_class_length, &mut output);
    output
}

fn write_cm_top_rows(outcomes: &Vec<String>,
prediction_wing_length: usize, 31 | padded_outcome_length: usize, 32 | buffer: &mut String) { 33 | // 1st row 34 | fill_char(' ', prediction_wing_length, buffer); 35 | buffer.push('o'); 36 | fill_char('=', outcomes.len() * padded_outcome_length - 1, buffer); 37 | buffer.push_str("o\n"); 38 | 39 | // 2nd row 40 | fill_char(' ', prediction_wing_length, buffer); 41 | buffer.push('|'); 42 | buffer.push_str(center("Label", outcomes.len() * padded_outcome_length - 1).as_str()); 43 | buffer.push_str("|\n"); 44 | 45 | // 3rd row 46 | fill_char(' ', prediction_wing_length, buffer); 47 | buffer.push('o'); 48 | fill_char('=', outcomes.len() * padded_outcome_length - 1, buffer); 49 | buffer.push_str("o\n"); 50 | 51 | // 4th row 52 | fill_char(' ', prediction_wing_length, buffer); 53 | buffer.push('|'); 54 | for i in 0..outcomes.len() { 55 | let content = center(outcomes[i].as_str(), padded_outcome_length - 1); 56 | buffer.push_str(format!("{}|", content).as_str()); 57 | } 58 | buffer.push('\n'); 59 | 60 | // 5th row 61 | buffer.push('o'); 62 | buffer.push_str("==============o"); 63 | fill_char('=', prediction_wing_length - 16, buffer); 64 | buffer.push('o'); 65 | for _ in 1..outcomes.len() { 66 | fill_char('=', padded_outcome_length - 1, buffer); 67 | buffer.push('|'); 68 | } 69 | fill_char('=', padded_outcome_length - 1, buffer); 70 | buffer.push_str("o\n"); 71 | } 72 | 73 | fn write_cm_data_rows(counts: &Vec>, 74 | outcomes: &Vec, 75 | prediction_wing_length: usize, 76 | padded_outcome_length: usize, 77 | buffer: &mut String) { 78 | 79 | for (i, outcome) in outcomes.iter().enumerate() { 80 | buffer.push('|'); 81 | if i == outcomes.len() / 2 && outcomes.len() % 2 != 0 { 82 | buffer.push_str(" Prediction |"); 83 | } else { 84 | buffer.push_str(" |"); 85 | } 86 | buffer.push_str(center(outcome.as_str(), padded_outcome_length - 1).as_str()); 87 | buffer.push('|'); 88 | for j in 0..outcomes.len() { 89 | let fmt_count = center(counts[i][j].to_string().as_str(), 
padded_outcome_length - 1); 90 | buffer.push_str(format!("{}|", fmt_count).as_str()); 91 | } 92 | buffer.push('\n'); 93 | if i < outcomes.len() - 1 { 94 | if i == outcomes.len() / 2 - 1 && outcomes.len() % 2 == 0 { 95 | buffer.push_str("| Prediction |"); 96 | } else { 97 | buffer.push_str("| |"); 98 | } 99 | fill_char('=', padded_outcome_length - 1, buffer); 100 | buffer.push('|'); 101 | for _ in 0..outcomes.len() { 102 | fill_char('-', padded_outcome_length - 1, buffer); 103 | buffer.push('|'); 104 | } 105 | buffer.push('\n'); 106 | } else { 107 | buffer.push('o'); 108 | buffer.push_str("==============o"); 109 | fill_char('=', prediction_wing_length - 16, buffer); 110 | buffer.push('o'); 111 | fill_char('=', padded_outcome_length * outcomes.len() - 1, buffer); 112 | buffer.push('o'); 113 | } 114 | } 115 | } 116 | 117 | fn fill_char(c: char, length: usize, buffer: &mut String) { 118 | for _ in 0..length { 119 | buffer.push(c); 120 | } 121 | } 122 | 123 | fn center(s: &str, length: usize) -> String { 124 | let len: usize = s.chars().count(); 125 | let diff = length - len; 126 | let left_pad_length = diff / 2; 127 | let right_pad_length = diff - left_pad_length; 128 | let left_pad: String = vec![' '; left_pad_length].iter().collect(); 129 | let right_pad: String = vec![' '; right_pad_length].iter().collect(); 130 | format!("{}{}{}", left_pad, s, right_pad) 131 | } 132 | -------------------------------------------------------------------------------- /src/error.rs: -------------------------------------------------------------------------------- 1 | //! 2 | //! Details evaluation error types 3 | //! 
4 | 5 | /// 6 | /// Represents a generic evaluation error 7 | /// 8 | #[derive(Clone, Debug)] 9 | pub struct EvalError { 10 | /// The error message 11 | pub msg: String 12 | } 13 | 14 | impl EvalError { 15 | 16 | /// 17 | /// Alerts than an invalid input was provided 18 | /// 19 | /// # Arguments 20 | /// 21 | /// * `msg` - detailed error message 22 | /// 23 | pub fn invalid_input(msg: &str) -> EvalError { 24 | EvalError {msg: format!("Invalid input: {}", msg)} 25 | } 26 | 27 | /// 28 | /// Alerts that an undefined metric was encountered 29 | /// 30 | /// # Arguments 31 | /// 32 | /// * `name` - metric name 33 | /// 34 | pub fn undefined_metric(name: &str) -> EvalError { 35 | EvalError {msg: format!("Undefined metric: {}", name)} 36 | } 37 | 38 | /// 39 | /// Alerts than an infinite/NaN value was encountered 40 | /// 41 | pub fn infinite_value() -> EvalError { 42 | EvalError {msg: String::from("Infinite or NaN value")} 43 | } 44 | 45 | /// 46 | /// Alerts that constant input data was encountered 47 | /// 48 | pub fn constant_input_data() -> EvalError { 49 | EvalError {msg: String::from("Constant input data")} 50 | } 51 | 52 | /// 53 | /// Alerts than an invalid metric was provided 54 | /// 55 | /// # Arguments 56 | /// 57 | /// * `name` - metric name 58 | /// 59 | pub fn invalid_metric(name: &str) -> EvalError { 60 | EvalError {msg: format!("Invalid metric: {}", name)} 61 | } 62 | } 63 | 64 | impl std::fmt::Display for EvalError { 65 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 66 | write!(f, "{}", self.msg) 67 | } 68 | } 69 | 70 | impl std::error::Error for EvalError {} 71 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! 2 | //! This crate provides functionality for evaluation metrics commonly found in machine learning 3 | //! 
4 | 5 | pub mod classification; 6 | pub mod regression; 7 | pub mod error; 8 | mod display; 9 | mod numeric; 10 | mod util; 11 | -------------------------------------------------------------------------------- /src/numeric.rs: -------------------------------------------------------------------------------- 1 | use std::ops::{Add, Sub, Mul, Div, AddAssign}; 2 | 3 | /// 4 | /// Represents a scalar value which can either be single (f32) or double (f64) precision 5 | /// 6 | pub trait Scalar: 7 | private::Sealed + 8 | Copy + 9 | Add + 10 | Sub + 11 | Mul + 12 | Div + 13 | AddAssign + 14 | PartialOrd { 15 | 16 | /// 17 | /// Computes the absolute value 18 | /// 19 | fn abs(self) -> Self; 20 | 21 | /// 22 | /// Compute the square root 23 | /// 24 | fn sqrt(self) -> Self; 25 | 26 | /// 27 | /// Indicates whether or not the value is finite 28 | /// 29 | fn is_finite(self) -> bool; 30 | 31 | /// 32 | /// Provides a representation of the number zero 33 | /// 34 | fn zero() -> Self; 35 | 36 | /// 37 | /// Provides a representation of the number one 38 | /// 39 | fn one() -> Self; 40 | 41 | /// 42 | /// Constructs a float from an f32 value 43 | /// 44 | fn from_f32(x: f32) -> Self; 45 | 46 | /// 47 | /// Constructs a float from an f64 value 48 | /// 49 | fn from_f64(x: f64) -> Self; 50 | 51 | /// 52 | /// Constructs a float from a usize value 53 | /// 54 | fn from_usize(x: usize) -> Self; 55 | } 56 | 57 | /// 58 | /// Implementation for f32 single-precision values 59 | /// 60 | impl Scalar for f32 { 61 | fn abs(self) -> Self {self.abs()} 62 | fn sqrt(self) -> Self {self.sqrt()} 63 | fn is_finite(self) -> bool {self.is_finite()} 64 | fn zero() -> Self {0.0_f32} 65 | fn one() -> Self {1.0_f32} 66 | fn from_f32(x: f32) -> Self {x} 67 | fn from_f64(x: f64) -> Self {x as f32} 68 | fn from_usize(x: usize) -> Self {x as f32} 69 | } 70 | 71 | /// 72 | /// Implementation for f64 double-precision values 73 | /// 74 | impl Scalar for f64 { 75 | fn abs(self) -> Self {self.abs()} 76 | fn 
sqrt(self) -> Self {self.sqrt()} 77 | fn is_finite(self) -> bool {self.is_finite()} 78 | fn zero() -> Self {0.0} 79 | fn one() -> Self {1.0} 80 | fn from_f32(x: f32) -> Self {x as f64} 81 | fn from_f64(x: f64) -> Self {x} 82 | fn from_usize(x: usize) -> Self {x as f64} 83 | } 84 | 85 | mod private { 86 | 87 | pub trait Sealed {} 88 | 89 | impl Sealed for f32 {} 90 | impl Sealed for f64 {} 91 | } 92 | -------------------------------------------------------------------------------- /src/regression.rs: -------------------------------------------------------------------------------- 1 | //! 2 | //! Provides support for regression metrics 3 | //! 4 | 5 | use crate::util; 6 | use crate::numeric::Scalar; 7 | use crate::error::EvalError; 8 | 9 | /// 10 | /// Computes the mean squared error between scores and labels 11 | /// 12 | /// # Arguments 13 | /// 14 | /// * `scores` - score vector 15 | /// * `labels` - label vector 16 | /// 17 | /// # Examples 18 | /// 19 | /// ``` 20 | /// # use eval_metrics::error::EvalError; 21 | /// # fn main() -> Result<(), EvalError> { 22 | /// use eval_metrics::regression::mse; 23 | /// let scores = vec![2.3, 5.1, -3.2, 7.1, -4.4]; 24 | /// let labels = vec![1.7, 4.3, -4.1, 6.5, -3.2]; 25 | /// let metric = mse(&scores, &labels)?; 26 | /// # Ok(())} 27 | /// ``` 28 | /// 29 | pub fn mse(scores: &Vec, labels: &Vec) -> Result { 30 | util::validate_input_dims(scores, labels).and_then(|()| { 31 | Ok(scores.iter().zip(labels.iter()).fold(T::zero(), |sum, (&a, &b)| { 32 | let diff = a - b; 33 | sum + (diff * diff) 34 | }) / T::from_usize(scores.len())) 35 | }).and_then(util::check_finite) 36 | } 37 | 38 | /// 39 | /// Computes the root mean squared error between scores and labels 40 | /// 41 | /// # Arguments 42 | /// 43 | /// * `scores` - score vector 44 | /// * `labels` - label vector 45 | /// 46 | /// # Examples 47 | /// 48 | /// ``` 49 | /// # use eval_metrics::error::EvalError; 50 | /// # fn main() -> Result<(), EvalError> { 51 | /// use 
eval_metrics::regression::rmse; 52 | /// let scores = vec![2.3, 5.1, -3.2, 7.1, -4.4]; 53 | /// let labels = vec![1.7, 4.3, -4.1, 6.5, -3.2]; 54 | /// let metric = rmse(&scores, &labels)?; 55 | /// # Ok(())} 56 | /// ``` 57 | /// 58 | pub fn rmse(scores: &Vec, labels: &Vec) -> Result { 59 | mse(scores, labels).map(|m| m.sqrt()) 60 | } 61 | 62 | /// 63 | /// Computes the mean absolute error between scores and labels 64 | /// 65 | /// # Arguments 66 | /// 67 | /// * `scores` - score vector 68 | /// * `labels` - label vector 69 | /// 70 | /// # Examples 71 | /// 72 | /// ``` 73 | /// # use eval_metrics::error::EvalError; 74 | /// # fn main() -> Result<(), EvalError> { 75 | /// use eval_metrics::regression::mae; 76 | /// let scores = vec![2.3, 5.1, -3.2, 7.1, -4.4]; 77 | /// let labels = vec![1.7, 4.3, -4.1, 6.5, -3.2]; 78 | /// let metric = mae(&scores, &labels)?; 79 | /// # Ok(())} 80 | /// ``` 81 | /// 82 | pub fn mae(scores: &Vec, labels: &Vec) -> Result { 83 | util::validate_input_dims(scores, labels).and_then(|()| { 84 | Ok(scores.iter().zip(labels.iter()).fold(T::zero(), |sum, (&a, &b)| { 85 | sum + (a - b).abs() 86 | }) / T::from_usize(scores.len())) 87 | }).and_then(util::check_finite) 88 | } 89 | 90 | /// 91 | /// Computes the coefficient of determination between scores and labels 92 | /// 93 | /// # Arguments 94 | /// 95 | /// * `scores` - score vector 96 | /// * `labels` - label vector 97 | /// 98 | /// # Examples 99 | /// 100 | /// ``` 101 | /// # use eval_metrics::error::EvalError; 102 | /// # fn main() -> Result<(), EvalError> { 103 | /// use eval_metrics::regression::rsq; 104 | /// let scores = vec![2.3, 5.1, -3.2, 7.1, -4.4]; 105 | /// let labels = vec![1.7, 4.3, -4.1, 6.5, -3.2]; 106 | /// let metric = rsq(&scores, &labels)?; 107 | /// # Ok(())} 108 | /// ``` 109 | /// 110 | pub fn rsq(scores: &Vec, labels: &Vec) -> Result { 111 | util::validate_input_dims(scores, labels).and_then(|()| { 112 | let length = scores.len(); 113 | let label_sum = 
labels.iter().fold(T::zero(), |s, &v| {s + v}); 114 | let label_mean = label_sum / T::from_usize(length); 115 | let den = labels.iter().fold(T::zero(), |sse, &label| { 116 | sse + (label - label_mean) * (label - label_mean) 117 | }) / T::from_usize(length); 118 | if den == T::zero() { 119 | Err(EvalError::constant_input_data()) 120 | } else { 121 | mse(scores, labels).map(|m| T::one() - (m / den)) 122 | } 123 | }) 124 | } 125 | 126 | /// 127 | /// Computes the linear correlation between scores and labels 128 | /// 129 | /// # Arguments 130 | /// 131 | /// * `scores` - score vector 132 | /// * `labels` - label vector 133 | /// 134 | /// # Examples 135 | /// 136 | /// ``` 137 | /// # use eval_metrics::error::EvalError; 138 | /// # fn main() -> Result<(), EvalError> { 139 | /// use eval_metrics::regression::corr; 140 | /// let scores = vec![2.3, 5.1, -3.2, 7.1, -4.4]; 141 | /// let labels = vec![1.7, 4.3, -4.1, 6.5, -3.2]; 142 | /// let metric = corr(&scores, &labels)?; 143 | /// # Ok(())} 144 | /// ``` 145 | /// 146 | pub fn corr(scores: &Vec, labels: &Vec) -> Result { 147 | util::validate_input_dims(scores, labels).and_then(|()| { 148 | let length = scores.len(); 149 | let x_mean = scores.iter().fold(T::zero(), |sum, &v| {sum + v}) / T::from_usize(length); 150 | let y_mean = labels.iter().fold(T::zero(), |sum, &v| {sum + v}) / T::from_usize(length); 151 | let mut sxx = T::zero(); 152 | let mut syy = T::zero(); 153 | let mut sxy = T::zero(); 154 | 155 | scores.iter().zip(labels.iter()).for_each(|(&x, &y)| { 156 | let x_diff = x - x_mean; 157 | let y_diff = y - y_mean; 158 | sxx += x_diff * x_diff; 159 | syy += y_diff * y_diff; 160 | sxy += x_diff * y_diff; 161 | }); 162 | 163 | match (sxx * syy).sqrt() { 164 | den if den == T::zero() => Err(EvalError::constant_input_data()), 165 | den => util::check_finite(sxy / den) 166 | } 167 | }) 168 | } 169 | 170 | #[cfg(test)] 171 | mod tests { 172 | 173 | use assert_approx_eq::assert_approx_eq; 174 | use super::*; 175 | 176 | 
fn data() -> (Vec<f64>, Vec<f64>) {
        let scores = vec![0.5, 0.2, 0.7, 0.4, 0.1, 0.3, 0.8, 0.9];
        let labels = vec![0.3, 0.1, 0.5, 0.6, 0.2, 0.5, 0.7, 0.6];
        (scores, labels)
    }

    #[test]
    fn test_mse() {
        let (scores, labels) = data();
        assert_approx_eq!(mse(&scores, &labels).unwrap(), 0.035)
    }

    #[test]
    fn test_mse_empty() {
        assert!(mse(&Vec::<f64>::new(), &Vec::<f64>::new()).is_err())
    }

    #[test]
    fn test_mse_unequal_length() {
        assert!(mse(&vec![0.1, 0.2], &vec![0.3, 0.5, 0.8]).is_err())
    }

    #[test]
    fn test_mse_constant() {
        assert_approx_eq!(mse(&vec![1.0; 10], &vec![1.0; 10]).unwrap(), 0.0)
    }

    #[test]
    fn test_mse_nan() {
        assert!(mse(&vec![0.2, 0.5, 0.4], &vec![0.1, 0.4, f64::NAN]).is_err())
    }

    #[test]
    fn test_rmse() {
        let (scores, labels) = data();
        assert_approx_eq!(rmse(&scores, &labels).unwrap(), 0.035.sqrt())
    }

    #[test]
    fn test_rmse_empty() {
        assert!(rmse(&Vec::<f64>::new(), &Vec::<f64>::new()).is_err())
    }

    #[test]
    fn test_rmse_unequal_length() {
        assert!(rmse(&vec![0.1, 0.2], &vec![0.3, 0.5, 0.8]).is_err())
    }

    #[test]
    fn test_rmse_constant() {
        assert_approx_eq!(rmse(&vec![1.0; 10], &vec![1.0; 10]).unwrap(), 0.0)
    }

    #[test]
    fn test_rmse_nan() {
        assert!(rmse(&vec![0.2, 0.5, 0.4], &vec![0.1, 0.4, f64::NAN]).is_err())
    }

    #[test]
    fn test_mae() {
        let (scores, labels) = data();
        assert_approx_eq!(mae(&scores, &labels).unwrap(), 0.175)
    }

    #[test]
    fn test_mae_empty() {
        assert!(mae(&Vec::<f64>::new(), &Vec::<f64>::new()).is_err())
    }

    #[test]
    fn test_mae_unequal_length() {
        assert!(mae(&vec![0.1, 0.2], &vec![0.3, 0.5, 0.8]).is_err())
    }

    #[test]
    fn test_mae_constant() {
        assert_approx_eq!(mae(&vec![1.0; 10], &vec![1.0; 10]).unwrap(), 0.0)
    }

    #[test]
    fn test_mae_nan() {
        assert!(mae(&vec![0.2, 0.5, 0.4], &vec![0.1, 0.4, f64::NAN]).is_err())
    }

    #[test]
    fn test_rsq() {
        let (scores, labels) = data();
        assert_approx_eq!(rsq(&scores, &labels).unwrap(), 0.12156862745098007)
    }

    #[test]
    fn test_rsq_empty() {
        assert!(rsq(&Vec::<f64>::new(), &Vec::<f64>::new()).is_err())
    }

    #[test]
    fn test_rsq_unequal_length() {
        assert!(rsq(&vec![0.1, 0.2], &vec![0.3, 0.5, 0.8]).is_err())
    }

    #[test]
    fn test_rsq_constant() {
        assert!(rsq(&vec![1.0; 10], &vec![1.0; 10]).is_err())
    }

    #[test]
    fn test_rsq_nan() {
        assert!(rsq(&vec![0.2, 0.5, 0.4], &vec![0.1, 0.4, f64::NAN]).is_err())
    }

    #[test]
    fn test_corr() {
        let (scores, labels) = data();
        assert_approx_eq!(corr(&scores, &labels).unwrap(), 0.7473417080949364)
    }

    #[test]
    fn test_corr_empty() {
        assert!(corr(&Vec::<f64>::new(), &Vec::<f64>::new()).is_err())
    }

    #[test]
    fn test_corr_unequal_length() {
        assert!(corr(&vec![0.1, 0.2], &vec![0.3, 0.5, 0.8]).is_err())
    }

    #[test]
    fn test_corr_constant() {
        assert!(corr(&vec![1.0; 10], &vec![1.0; 10]).is_err())
    }

    #[test]
    fn test_corr_nan() {
        assert!(corr(&vec![0.2, 0.5, 0.4], &vec![0.1, 0.4, f64::NAN]).is_err())
    }
}
--------------------------------------------------------------------------------
/src/util.rs:
--------------------------------------------------------------------------------
use crate::error::EvalError;
use crate::numeric::Scalar;

///
/// Validates a pair of scores and labels, returning an error if either scores or labels are
/// empty, or if they have lengths that differ
///
/// # Arguments
///
/// * `scores` - vector of scores
/// * `labels` - vector of labels
///
pub fn validate_input_dims<T>(scores: &Vec<T>, labels: &Vec<T>) -> Result<(), EvalError> {
    if scores.is_empty() {
        Err(EvalError::invalid_input("Scores are empty"))
    } else if labels.is_empty() {
        Err(EvalError::invalid_input("Labels are empty"))
    } else if scores.len() != labels.len() {
        Err(EvalError::invalid_input("Scores and labels have different lengths"))
    } else {
        Ok(())
    }
}

///
/// Check if the provided value is finite
///
/// # Arguments
///
/// * `value` - float value to check
///
pub fn check_finite<T: Scalar>(value: T) -> Result<T, EvalError> {
    if value.is_finite() {
        Ok(value)
    } else {
        Err(EvalError::infinite_value())
    }
}
--------------------------------------------------------------------------------