├── .gitignore ├── LICENSE ├── README.md ├── pyproject.toml └── src └── spark_calibration ├── __init__.py ├── betacal.py ├── metrics.py └── visualisation.py /.gitignore: -------------------------------------------------------------------------------- 1 | .pre-commit-config.yaml 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | 190 | Copyright [2023] [Fashnear Technologies Private Limited] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Model calibration with pyspark 2 | 3 | Screenshot 2023-10-10 at 3 19 39 PM 4 | 5 | 6 | This package provides a Betacal class which allows the user to fit/train the default beta calibration model on pyspark dataframes and predict calibrated scores 7 | 8 | 9 | ## Setup 10 | 11 | spark-calibration package is [uploaded to PyPi](https://pypi.org/project/spark-calibration/) and can be installed with this command: 12 | 13 | ``` 14 | pip install spark-calibration 15 | ``` 16 | 17 | ## Usage 18 | 19 | ### Training 20 | 21 | train_df should be a pyspark dataframe containing: 22 | - A column with raw model scores (default name: `score`) 23 | - A column with binary labels (default name: `label`) 24 | 25 | You can specify different column names when calling `fit()`. In some tree-based models like LightGBM, the predicted scores may fall outside the [0, 1] range and can even be negative. 
Please apply a sigmoid function to normalize the outputs accordingly. 26 | 27 | ```python 28 | from spark_calibration import Betacal 29 | from spark_calibration import display_classification_calib_metrics 30 | from spark_calibration import plot_calibration_curve 31 | 32 | # Initialize model 33 | bc = Betacal(parameters="abm") 34 | 35 | # Load training data 36 | train_df = spark.read.parquet("s3://train/") 37 | 38 | # Fit the model 39 | bc.fit(train_df) 40 | 41 | # Or specify custom column names 42 | # bc.fit(train_df, score_col="raw_score", label_col="actual_label") 43 | 44 | # Access model parameters 45 | print(f"Model coefficients: a={bc.a}, b={bc.b}, c={bc.c}") 46 | ``` 47 | 48 | The model learns three parameters: 49 | - a: Coefficient for log(score) 50 | - b: Coefficient for -log(1-score) (the calibrated logit is a*log(score) - b*log(1-score) + c) 51 | - c: Intercept term 52 | 53 | ### Saving and Loading Models 54 | 55 | You can save the trained model to disk and load it later: 56 | 57 | ```python 58 | # Save model 59 | save_path = bc.save("/path/to/save/") 60 | 61 | # Load model 62 | loaded_model = Betacal.load("/path/to/save/") 63 | ``` 64 | 65 | ### Prediction 66 | 67 | test_df should be a pyspark dataframe containing a column with raw model scores. By default, this column should be named `score`, but you can specify a different column name when calling `predict()`. The `predict` function adds a new column `prediction` which has the calibrated score. 68 | 69 | ```python 70 | test_df = spark.read.parquet("s3://test/") 71 | 72 | # Using default column name 'score' 73 | test_df = bc.predict(test_df) 74 | 75 | # Or specify a custom score column name 76 | # test_df = bc.predict(test_df, score_col="raw_score") 77 | ``` 78 | 79 | ### Pre & Post Calibration Classification Metrics 80 | 81 | The test_df should have `score`, `prediction` & `label` columns.
82 | The `display_classification_calib_metrics` function displays `brier_score_loss`, `log_loss`, `area_under_PR_curve` and `area_under_ROC_curve` for both the raw and the calibrated scores, together with the relative change (delta) 83 | ```python 84 | display_classification_calib_metrics(test_df) 85 | ``` 86 | #### Output 87 | ``` 88 | model brier score loss: 0.08072683729933376 89 | calibrated model brier score loss: 0.01014015353257748 90 | delta: -87.44% 91 | 92 | model log loss: 0.3038106859864252 93 | calibrated model log loss: 0.053275633947890755 94 | delta: -82.46% 95 | 96 | model aucpr: 0.03471287564672635 97 | calibrated model aucpr: 0.03471240518472563 98 | delta: -0.0% 99 | 100 | model roc_auc: 0.7490639506966398 101 | calibrated model roc_auc: 0.7490649764289607 102 | delta: 0.0% 103 | ``` 104 | 105 | ### Plot the Calibration Curve 106 | 107 | Computes the true and predicted probabilities (pre- and post-calibration) using a quantile binning strategy with 50 bins, and plots both calibration curves against the perfect-calibration baseline 108 | 109 | ```python 110 | plot_calibration_curve(test_df) 111 | ``` 112 | Screenshot 2023-10-10 at 3 19 39 PM 113 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "spark_calibration" 7 | version = "2.0.0" 8 | authors = [ 9 | { name="Jaya", email="jaya.kommuru@meesho.com"} 10 | ] 11 | 12 | description = "Calibrating model scores/probabilities with PySpark DataFrames" 13 | license = "Apache-2.0" 14 | readme = "README.md" 15 | requires-python = ">=3.7" 16 | dependencies = [ 17 | "pyspark>=3.2.1", 18 | "numpy>=1.20.3", 19 | "plotly>=5.9.0", 20 | "scikit-learn>=1.0.2", 21 | "pandas>=1.3.0" 22 | ] 23 | classifiers = [ 24 | "Programming Language :: Python :: 3", 25 | "License :: OSI Approved :: Apache Software License", 26 | "Operating System :: OS Independent", 27 | ] 28 | 29 | [project.urls] 30 | "Homepage" =
"https://github.com/Meesho/spark_calibration" 31 | "Bug Tracker" = "https://github.com/Meesho/spark_calibration/issues" 32 | "Documentation" = "https://github.com/Meesho/spark_calibration#readme" -------------------------------------------------------------------------------- /src/spark_calibration/__init__.py: -------------------------------------------------------------------------------- 1 | from .betacal import Betacal 2 | from .metrics import display_classification_calib_metrics 3 | from .visualisation import plot_calibration_curve 4 | 5 | __all__ = ["Betacal", "display_classification_calib_metrics", "plot_calibration_curve"] 6 | -------------------------------------------------------------------------------- /src/spark_calibration/betacal.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import tempfile 5 | from typing import Optional 6 | 7 | import pyspark.sql.functions as F 8 | from pyspark.ml.classification import LogisticRegression 9 | from pyspark.ml.feature import VectorAssembler 10 | from pyspark.sql import DataFrame 11 | 12 | 13 | def get_logger(): 14 | """Configure logger for Spark environment.""" 15 | logger = logging.getLogger(__name__) 16 | 17 | if not logger.handlers: 18 | handler = logging.StreamHandler() 19 | formatter = logging.Formatter( 20 | "%(asctime)s - %(name)s - %(levelname)s - %(message)s" 21 | ) 22 | handler.setFormatter(formatter) 23 | logger.addHandler(handler) 24 | logger.setLevel(logging.INFO) 25 | 26 | return logger 27 | 28 | 29 | logger = get_logger() 30 | 31 | 32 | class Betacal: 33 | """ 34 | Beta calibration using a logistic transformation of raw model scores. 
35 | 36 | Formula: 37 | logit = a * log(score) + b * log(1 - score) + c 38 | prediction = 1 / (1 + exp(-logit)) 39 | 40 | Attributes: 41 | a (float): Coefficient for log(score) 42 | b (float): Coefficient for log(1 - score) 43 | c (float): Intercept 44 | """ 45 | 46 | EPSILON = 1e-12 47 | 48 | def __init__(self, parameters: str = "abm"): 49 | assert parameters == "abm", "Only 'abm' parameterization is supported." 50 | self.parameters = parameters 51 | self.a: Optional[float] = None 52 | self.b: Optional[float] = None 53 | self.c: Optional[float] = None 54 | 55 | def get_params(self) -> dict: 56 | """ 57 | Get the model parameters. 58 | 59 | Returns: 60 | dict: Dictionary containing model parameters a, b, c. 61 | """ 62 | return {"a": self.a, "b": self.b, "c": self.c, "parameters": self.parameters} 63 | 64 | def _log_expr(self, col: F.Column) -> F.Column: 65 | """Numerically stable log transformation.""" 66 | return F.log(F.when(col < self.EPSILON, self.EPSILON).otherwise(col)) 67 | 68 | def _validate_input_df(self, df: DataFrame, score_col: str, label_col: str) -> None: 69 | """ 70 | Validate input DataFrame and required columns. 71 | 72 | Args: 73 | df (DataFrame): Input dataframe. 74 | score_col (str): Column containing raw model scores. 75 | label_col (str): Column containing binary labels. 76 | 77 | Raises: 78 | ValueError: If DataFrame is empty or required columns are missing. 79 | """ 80 | if df.count() == 0: 81 | raise ValueError("Cannot fit model on empty DataFrame") 82 | 83 | assert ( 84 | score_col in df.columns and label_col in df.columns 85 | ), f"Columns {score_col} and {label_col} must be present." 86 | 87 | def _handle_null_values(self, df: DataFrame, score_col: str) -> DataFrame: 88 | """ 89 | Handle null values in the score column. 90 | 91 | Args: 92 | df (DataFrame): Input dataframe. 93 | score_col (str): Column containing raw model scores. 94 | 95 | Returns: 96 | DataFrame: Cleaned DataFrame with null values removed. 
97 | 98 | Raises: 99 | ValueError: If all rows contain null values. 100 | """ 101 | total_rows = df.count() 102 | df_clean = df.dropna(subset=[score_col]) 103 | rows_after_drop = df_clean.count() 104 | 105 | if rows_after_drop == 0: 106 | raise ValueError(f"All rows contained null values in {score_col} column") 107 | 108 | dropped_rows = total_rows - rows_after_drop 109 | if dropped_rows > 0: 110 | logger.info( 111 | f"Dropped {dropped_rows}/{total_rows} rows ({(dropped_rows/total_rows)*100:.2f}%) " 112 | f"with null values in column '{score_col}'" 113 | ) 114 | 115 | return df_clean 116 | 117 | def _prepare_features( 118 | self, df: DataFrame, score_col: str, label_col: str 119 | ) -> DataFrame: 120 | """ 121 | Prepare features for logistic regression with all possible combinations. 122 | 123 | Args: 124 | df (DataFrame): Input dataframe. 125 | score_col (str): Column containing raw model scores. 126 | label_col (str): Column containing binary labels. 127 | 128 | Returns: 129 | DataFrame: Transformed DataFrame with features ready for training. 
130 | Contains three feature vectors: 131 | - features_both: Both log(score) and -log(1-score) 132 | - features_score: Only log(score) 133 | - features_complement: Only -log(1-score) 134 | """ 135 | log_score = self._log_expr(F.col(score_col)) 136 | log_one_minus_score = self._log_expr(1 - F.col(score_col)) 137 | 138 | df_transformed = df.select( 139 | F.col(label_col).alias("label"), 140 | log_score.alias("log_score"), 141 | (-1 * log_one_minus_score).alias("log_score_complement"), 142 | ) 143 | 144 | # Prepare all possible feature combinations 145 | assembler_both = VectorAssembler( 146 | inputCols=["log_score", "log_score_complement"], outputCol="features_both" 147 | ) 148 | assembler_score = VectorAssembler( 149 | inputCols=["log_score"], outputCol="features_score" 150 | ) 151 | assembler_complement = VectorAssembler( 152 | inputCols=["log_score_complement"], outputCol="features_complement" 153 | ) 154 | 155 | df_with_both = assembler_both.transform(df_transformed) 156 | df_with_score = assembler_score.transform(df_with_both) 157 | return assembler_complement.transform(df_with_score) 158 | 159 | def _fit_logistic_regression(self, train_data: DataFrame) -> None: 160 | """ 161 | Fit logistic regression model and set coefficients. 162 | 163 | Args: 164 | train_data (DataFrame): Prepared training data with features. 
165 | """ 166 | lr = LogisticRegression() 167 | 168 | # First try with both features 169 | model = lr.fit( 170 | train_data.select("label", F.col("features_both").alias("features")) 171 | ) 172 | coef = model.coefficients 173 | 174 | if coef[0] < 0: 175 | # Use only complement feature if first coefficient is negative 176 | model = lr.fit( 177 | train_data.select( 178 | "label", F.col("features_complement").alias("features") 179 | ) 180 | ) 181 | self.a = 0.0 182 | self.b = float(model.coefficients[0]) 183 | elif coef[1] < 0: 184 | # Use only score feature if second coefficient is negative 185 | model = lr.fit( 186 | train_data.select("label", F.col("features_score").alias("features")) 187 | ) 188 | self.a = float(model.coefficients[0]) 189 | self.b = 0.0 190 | else: 191 | self.a = float(coef[0]) 192 | self.b = float(coef[1]) 193 | 194 | self.c = float(model.intercept) 195 | 196 | def _validate_score_range(self, df: DataFrame, score_col: str) -> None: 197 | """ 198 | Validate that scores are within valid range (0,1). 199 | 200 | Args: 201 | df (DataFrame): Input dataframe. 202 | score_col (str): Column containing raw model scores. 203 | 204 | Raises: 205 | ValueError: If scores are outside valid range. 206 | """ 207 | stats = df.select( 208 | F.min(score_col).alias("min"), F.max(score_col).alias("max") 209 | ).collect()[0] 210 | 211 | if stats.min < 0 or stats.max > 1: 212 | raise ValueError( 213 | f"Scores must be in range [0,1], got range [{stats.min:.3f}, {stats.max:.3f}]" 214 | ) 215 | 216 | def fit( 217 | self, df: DataFrame, score_col: str = "score", label_col: str = "label" 218 | ) -> "Betacal": 219 | """ 220 | Fit a beta calibration model using logistic regression. 221 | 222 | Args: 223 | df (DataFrame): Input dataframe. 224 | score_col (str): Column containing raw model scores. 225 | label_col (str): Column containing binary labels. 226 | 227 | Returns: 228 | Betacal: The fitted model instance (self). 
229 | 230 | Raises: 231 | ValueError: If input DataFrame is empty or contains all null values. 232 | """ 233 | self._validate_input_df(df, score_col, label_col) 234 | self._validate_score_range(df, score_col) 235 | df_clean = self._handle_null_values(df, score_col) 236 | train_data = self._prepare_features(df_clean, score_col, label_col) 237 | self._fit_logistic_regression(train_data) 238 | return self 239 | 240 | def predict( 241 | self, 242 | df: DataFrame, 243 | score_col: str = "score", 244 | prediction_col: str = "prediction", 245 | ) -> DataFrame: 246 | """ 247 | Apply the learned beta calibration model to predict calibrated scores. 248 | 249 | Args: 250 | df (DataFrame): Input dataframe with raw scores. 251 | score_col (str): Column name for raw score. 252 | prediction_col (str): Name for the output prediction column. 253 | 254 | Returns: 255 | DataFrame: Original dataframe with an added prediction column. 256 | Null values in score_col will result in null predictions. 257 | 258 | Raises: 259 | ValueError: If calibration coefficients are not set or scores are outside valid range. 260 | """ 261 | if self.a is None or self.b is None or self.c is None: 262 | raise ValueError( 263 | "Model coefficients a, b, and c must be set. Call `.fit()` or `.load()` before prediction." 264 | ) 265 | 266 | assert score_col in df.columns, f"{score_col} must be present." 
267 | 268 | self._validate_score_range(df.filter(F.col(score_col).isNotNull()), score_col) 269 | 270 | log_score = self._log_expr(F.col(score_col)) 271 | log_one_minus_score = self._log_expr(1 - F.col(score_col)) 272 | 273 | logit = ( 274 | F.lit(self.a) * log_score 275 | + F.lit(self.b) * (-1 * log_one_minus_score) 276 | + F.lit(self.c) 277 | ) 278 | 279 | prediction = F.when(F.col(score_col).isNull(), None).otherwise( 280 | 1 / (1 + F.exp(-logit)) 281 | ) 282 | 283 | return df.withColumn(prediction_col, prediction) 284 | 285 | def save(self, path: Optional[str] = None, prefix: str = "betacal_") -> str: 286 | """ 287 | Save the model coefficients to disk. 288 | 289 | Args: 290 | path (str, optional): Directory to save into. Creates temp dir if None. 291 | prefix (str): Prefix for temp folder name if path is None. 292 | 293 | Returns: 294 | str: The final save path. 295 | """ 296 | if path is None: 297 | path = tempfile.mkdtemp(prefix=prefix) 298 | 299 | os.makedirs(path, exist_ok=True) 300 | 301 | with open(os.path.join(path, "coeffs.json"), "w") as f: 302 | json.dump( 303 | {"a": self.a, "b": self.b, "c": self.c, "parameters": self.parameters}, 304 | f, 305 | ) 306 | 307 | return path 308 | 309 | @classmethod 310 | def load(cls, path: str) -> "Betacal": 311 | """ 312 | Load model coefficients from disk. 313 | 314 | Args: 315 | path (str): Directory containing 'coeffs.json'. 316 | 317 | Returns: 318 | Betacal: The loaded model. 
319 | """ 320 | with open(os.path.join(path, "coeffs.json"), "r") as f: 321 | coeffs = json.load(f) 322 | 323 | model = cls(parameters=coeffs["parameters"]) 324 | model.a = coeffs["a"] 325 | model.b = coeffs["b"] 326 | model.c = coeffs["c"] 327 | return model 328 | -------------------------------------------------------------------------------- /src/spark_calibration/metrics.py: -------------------------------------------------------------------------------- 1 | from pyspark.ml.evaluation import BinaryClassificationEvaluator 2 | import pyspark.sql.functions as F 3 | 4 | from pyspark.sql.dataframe import DataFrame 5 | 6 | 7 | def display_classification_calib_metrics(df: DataFrame): 8 | """Print pre and post calibration metrics for comparison 9 | 10 | Args: 11 | df: dataframe with score, label and prediction(calibratied score) columns 12 | """ 13 | 14 | assert ( 15 | "score" in df.columns and "label" in df.columns and "prediction" in df.columns 16 | ), "score and label columns should be present in the dataframe" 17 | 18 | model_bs = df.select(F.avg(F.pow(df["label"] - df["score"], 2))).collect()[0][0] 19 | model_ll = df.select( 20 | F.avg( 21 | -F.col("label") * F.log(F.col("score")) 22 | - (1 - F.col("label")) * F.log(1 - F.col("score")) 23 | ) 24 | ).collect()[0][0] 25 | 26 | model_aucpr = BinaryClassificationEvaluator( 27 | rawPredictionCol="score", metricName="areaUnderPR" 28 | ).evaluate(df) 29 | model_roc_auc = BinaryClassificationEvaluator( 30 | rawPredictionCol="score", metricName="areaUnderROC" 31 | ).evaluate(df) 32 | iso_bs = df.select(F.avg(F.pow(df["label"] - df["prediction"], 2))).collect()[0][0] 33 | 34 | print(f"model brier score loss: {model_bs}") 35 | print(f"calibrated model brier score loss: {iso_bs}") 36 | 37 | print(f"delta: {round((iso_bs/model_bs - 1) * 100, 2)}%") 38 | iso_ll = df.select( 39 | F.avg( 40 | -F.col("label") * F.log(F.col("prediction")) 41 | - (1 - F.col("label")) * F.log(1 - F.col("prediction")) 42 | ) 43 | ).collect()[0][0] 44 
| 45 | print("") 46 | 47 | print(f"model log loss: {model_ll}") 48 | print(f"calibrated model log loss: {iso_ll}") 49 | print(f"delta: {round((iso_ll/model_ll - 1) * 100, 2)}%") 50 | iso_aucpr = BinaryClassificationEvaluator( 51 | rawPredictionCol="prediction", metricName="areaUnderPR" 52 | ).evaluate(df) 53 | 54 | print("") 55 | 56 | print(f"model aucpr: {model_aucpr}") 57 | print(f"calibrated model aucpr: {iso_aucpr}") 58 | print(f"delta: {round((iso_aucpr/model_aucpr - 1) * 100, 2)}%") 59 | iso_roc_auc = BinaryClassificationEvaluator( 60 | rawPredictionCol="prediction", metricName="areaUnderROC" 61 | ).evaluate(df) 62 | 63 | print("") 64 | 65 | print(f"model roc_auc: {model_roc_auc}") 66 | print(f"calibrated model roc_auc: {iso_roc_auc}") 67 | print(f"delta: {round((iso_roc_auc/model_roc_auc - 1) * 100, 2)}%") 68 | -------------------------------------------------------------------------------- /src/spark_calibration/visualisation.py: -------------------------------------------------------------------------------- 1 | import plotly.graph_objects as go 2 | from sklearn.calibration import calibration_curve 3 | from pyspark.sql.dataframe import DataFrame 4 | 5 | 6 | def plot_calibration_curve(df: DataFrame): 7 | assert ( 8 | "score" in df.columns and "label" in df.columns and "prediction" in df.columns 9 | ), "score and label columns should be present in the dataframe" 10 | 11 | df_p_v = df.select("label", "score", "prediction").toPandas().values 12 | 13 | y_test_true, y_test_pred, y_test_pred_cal = df_p_v[:, 0], df_p_v[:, 1], df_p_v[:, 2] 14 | 15 | fig = go.Figure() 16 | 17 | fig.add_trace( 18 | go.Scatter( 19 | x=[0, 1], 20 | y=[0, 1], 21 | mode="lines", 22 | name="Perfect Calibration Baseline (y=x)", 23 | line=dict(dash="dash"), 24 | ) 25 | ) 26 | 27 | prob_true, prob_pred = calibration_curve( 28 | y_test_true, y_test_pred, n_bins=50, strategy="quantile" 29 | ) 30 | print(f"number of bins for pre-calibration scores: {prob_true.shape[0]}") 31 | 32 | 
fig.add_trace( 33 | go.Scatter(x=prob_pred, y=prob_true, mode="lines+markers", name="Model") 34 | ) 35 | 36 | prob_true, prob_pred = calibration_curve( 37 | y_test_true, y_test_pred_cal, n_bins=50, strategy="quantile" 38 | ) 39 | print(f"number of bins for post-calibration scores: {prob_true.shape[0]}") 40 | 41 | fig.add_trace( 42 | go.Scatter( 43 | x=prob_pred, y=prob_true, mode="lines+markers", name="Calibrated Model" 44 | ) 45 | ) 46 | 47 | fig.update_layout( 48 | title=dict(text="Calibration Curve on test data (quantile bins)", x=0.5), 49 | xaxis_title="Mean Predicted Probability", 50 | yaxis_title="Actual Fraction of Positives", 51 | ) 52 | 53 | fig.show() 54 | --------------------------------------------------------------------------------