├── .gitignore
├── LICENSE
├── README.md
├── pyproject.toml
└── src
└── spark_calibration
├── __init__.py
├── betacal.py
├── metrics.py
└── visualisation.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .pre-commit-config.yaml
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 |
190 | Copyright [2023] [Fashnear Technologies Private Limited]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Model calibration with pyspark
2 |
3 |
4 |
5 |
6 | This package provides a Betacal class which allows the user to fit/train the default beta calibration model on pyspark dataframes and predict calibrated scores
7 |
8 |
9 | ## Setup
10 |
11 | spark-calibration package is [uploaded to PyPi](https://pypi.org/project/spark-calibration/) and can be installed with this command:
12 |
13 | ```
14 | pip install spark-calibration
15 | ```
16 |
17 | ## Usage
18 |
19 | ### Training
20 |
21 | train_df should be a pyspark dataframe containing:
22 | - A column with raw model scores (default name: `score`)
23 | - A column with binary labels (default name: `label`)
24 |
25 | You can specify different column names when calling `fit()`. In some tree-based models like LightGBM, the predicted scores may fall outside the [0, 1] range and can even be negative. Please apply a sigmoid function to normalize the outputs accordingly.
26 |
27 | ```python
28 | from spark_calibration import Betacal
29 | from spark_calibration import display_classification_calib_metrics
30 | from spark_calibration import plot_calibration_curve
31 |
32 | # Initialize model
33 | bc = Betacal(parameters="abm")
34 |
35 | # Load training data
36 | train_df = spark.read.parquet("s3://train/")
37 |
38 | # Fit the model
39 | bc.fit(train_df)
40 |
41 | # Or specify custom column names
42 | # bc.fit(train_df, score_col="raw_score", label_col="actual_label")
43 |
44 | # Access model parameters
45 | print(f"Model coefficients: a={bc.a}, b={bc.b}, c={bc.c}")
46 | ```
47 |
48 | The model learns three parameters:
49 | - a: Coefficient for log(score)
50 | - b: Coefficient for log(1-score)
51 | - c: Intercept term
52 |
53 | ### Saving and Loading Models
54 |
55 | You can save the trained model to disk and load it later:
56 |
57 | ```python
58 | # Save model
59 | save_path = bc.save("/path/to/save/")
60 |
61 | # Load model
62 | loaded_model = Betacal.load("/path/to/save/")
63 | ```
64 |
65 | ### Prediction
66 |
67 | test_df should be a pyspark dataframe containing a column with raw model scores. By default, this column should be named `score`, but you can specify a different column name when calling `predict()`. The `predict` function adds a new column `prediction` which has the calibrated score.
68 |
69 | ```python
70 | test_df = spark.read.parquet("s3://test/")
71 |
72 | # Using default column name 'score'
73 | test_df = bc.predict(test_df)
74 |
75 | # Or specify a custom score column name
76 | # test_df = bc.predict(test_df, score_col="raw_score")
77 | ```
78 |
79 | ### Pre & Post Calibration Classification Metrics
80 |
81 | The test_df should have `score`, `prediction` & `label` columns.
82 | The `display_classification_calib_metrics` function displays `brier_score_loss`, `log_loss`, `area_under_PR_curve` and `area_under_ROC_curve`
83 | ```python
84 | display_classification_calib_metrics(test_df)
85 | ```
86 | #### Output
87 | ```
88 | model brier score loss: 0.08072683729933376
89 | calibrated model brier score loss: 0.01014015353257748
90 | delta: -87.44%
91 |
92 | model log loss: 0.3038106859864252
93 | calibrated model log loss: 0.053275633947890755
94 | delta: -82.46%
95 |
96 | model aucpr: 0.03471287564672635
97 | calibrated model aucpr: 0.03471240518472563
98 | delta: -0.0%
99 |
100 | model roc_auc: 0.7490639506966398
101 | calibrated model roc_auc: 0.7490649764289607
102 | delta: 0.0%
103 | ```
104 |
105 | ### Plot the Calibration Curve
106 |
107 | Computes true, predicted probabilities (pre & post calibration) using quantile binning strategy with 50 bins and plots the calibration curve
108 |
109 | ```python
110 | plot_calibration_curve(test_df)
111 | ```
112 |
113 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 |
5 | [project]
6 | name = "spark_calibration"
7 | version = "2.0.0"
8 | authors = [
9 | { name="Jaya", email="jaya.kommuru@meesho.com"}
10 | ]
11 |
12 | description = "Calibrating model scores/probabilities with PySpark DataFrames"
13 | license = "Apache-2.0"
14 | readme = "README.md"
15 | requires-python = ">=3.7"
16 | dependencies = [
17 | "pyspark>=3.2.1",
18 | "numpy>=1.20.3",
19 | "plotly>=5.9.0",
20 | "scikit-learn>=1.0.2",
21 | "pandas>=1.3.0"
22 | ]
23 | classifiers = [
24 | "Programming Language :: Python :: 3",
25 | "License :: OSI Approved :: Apache Software License",
26 | "Operating System :: OS Independent",
27 | ]
28 |
29 | [project.urls]
30 | "Homepage" = "https://github.com/Meesho/spark_calibration"
31 | "Bug Tracker" = "https://github.com/Meesho/spark_calibration/issues"
32 | "Documentation" = "https://github.com/Meesho/spark_calibration#readme"
--------------------------------------------------------------------------------
/src/spark_calibration/__init__.py:
--------------------------------------------------------------------------------
1 | from .betacal import Betacal
2 | from .metrics import display_classification_calib_metrics
3 | from .visualisation import plot_calibration_curve
4 |
5 | __all__ = ["Betacal", "display_classification_calib_metrics", "plot_calibration_curve"]
6 |
--------------------------------------------------------------------------------
/src/spark_calibration/betacal.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | import tempfile
5 | from typing import Optional
6 |
7 | import pyspark.sql.functions as F
8 | from pyspark.ml.classification import LogisticRegression
9 | from pyspark.ml.feature import VectorAssembler
10 | from pyspark.sql import DataFrame
11 |
12 |
def get_logger():
    """Return this module's logger, attaching a stream handler on first use."""
    log = logging.getLogger(__name__)
    if log.handlers:
        # Already configured (e.g. repeated import) — avoid duplicate handlers.
        return log
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(
        logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    )
    log.addHandler(stream_handler)
    log.setLevel(logging.INFO)
    return log


logger = get_logger()
30 |
31 |
class Betacal:
    """
    Beta calibration using a logistic transformation of raw model scores.

    Formula (matching `predict`, where the second feature is -log(1 - score)):
        logit = a * log(score) - b * log(1 - score) + c
        prediction = 1 / (1 + exp(-logit))

    Attributes:
        a (float): Coefficient for log(score)
        b (float): Coefficient for -log(1 - score)
        c (float): Intercept
    """

    # Lower clamp applied before log() so scores of exactly 0 or 1 do not
    # produce log(0) = -inf / null.
    EPSILON = 1e-12

    def __init__(self, parameters: str = "abm"):
        # Raise (not assert) so validation survives `python -O`.
        if parameters != "abm":
            raise ValueError("Only 'abm' parameterization is supported.")
        self.parameters = parameters
        self.a: Optional[float] = None
        self.b: Optional[float] = None
        self.c: Optional[float] = None

    def get_params(self) -> dict:
        """
        Get the model parameters.

        Returns:
            dict: Dictionary containing model parameters a, b, c and the
                parameterization name.
        """
        return {"a": self.a, "b": self.b, "c": self.c, "parameters": self.parameters}

    def _log_expr(self, col: F.Column) -> F.Column:
        """Numerically stable log: clamps the input to at least EPSILON."""
        return F.log(F.when(col < self.EPSILON, self.EPSILON).otherwise(col))

    def _validate_input_df(self, df: DataFrame, score_col: str, label_col: str) -> None:
        """
        Validate input DataFrame and required columns.

        Args:
            df (DataFrame): Input dataframe.
            score_col (str): Column containing raw model scores.
            label_col (str): Column containing binary labels.

        Raises:
            ValueError: If DataFrame is empty or required columns are missing.
        """
        if df.count() == 0:
            raise ValueError("Cannot fit model on empty DataFrame")

        # ValueError (not assert) so the documented contract holds under -O.
        missing = [c for c in (score_col, label_col) if c not in df.columns]
        if missing:
            raise ValueError(
                f"Columns {score_col} and {label_col} must be present. "
                f"Missing: {missing}"
            )

    def _handle_null_values(self, df: DataFrame, score_col: str) -> DataFrame:
        """
        Handle null values in the score column.

        Args:
            df (DataFrame): Input dataframe.
            score_col (str): Column containing raw model scores.

        Returns:
            DataFrame: Cleaned DataFrame with null values removed.

        Raises:
            ValueError: If all rows contain null values.
        """
        total_rows = df.count()
        df_clean = df.dropna(subset=[score_col])
        rows_after_drop = df_clean.count()

        if rows_after_drop == 0:
            raise ValueError(f"All rows contained null values in {score_col} column")

        dropped_rows = total_rows - rows_after_drop
        if dropped_rows > 0:
            logger.info(
                f"Dropped {dropped_rows}/{total_rows} rows ({(dropped_rows/total_rows)*100:.2f}%) "
                f"with null values in column '{score_col}'"
            )

        return df_clean

    def _prepare_features(
        self, df: DataFrame, score_col: str, label_col: str
    ) -> DataFrame:
        """
        Prepare features for logistic regression with all possible combinations.

        Args:
            df (DataFrame): Input dataframe.
            score_col (str): Column containing raw model scores.
            label_col (str): Column containing binary labels.

        Returns:
            DataFrame: Transformed DataFrame with features ready for training.
                Contains three feature vectors:
                - features_both: Both log(score) and -log(1-score)
                - features_score: Only log(score)
                - features_complement: Only -log(1-score)
        """
        log_score = self._log_expr(F.col(score_col))
        log_one_minus_score = self._log_expr(1 - F.col(score_col))

        df_transformed = df.select(
            F.col(label_col).alias("label"),
            log_score.alias("log_score"),
            # Negated so that a non-negative fitted coefficient corresponds
            # to a monotone calibration map (the "abm" constraint below).
            (-1 * log_one_minus_score).alias("log_score_complement"),
        )

        # Prepare all possible feature combinations up front; the fit step
        # picks whichever combination keeps both coefficients non-negative.
        assembler_both = VectorAssembler(
            inputCols=["log_score", "log_score_complement"], outputCol="features_both"
        )
        assembler_score = VectorAssembler(
            inputCols=["log_score"], outputCol="features_score"
        )
        assembler_complement = VectorAssembler(
            inputCols=["log_score_complement"], outputCol="features_complement"
        )

        df_with_both = assembler_both.transform(df_transformed)
        df_with_score = assembler_score.transform(df_with_both)
        return assembler_complement.transform(df_with_score)

    def _fit_logistic_regression(self, train_data: DataFrame) -> None:
        """
        Fit logistic regression model and set coefficients a, b, c.

        If either coefficient of the two-feature fit is negative, refit with
        the single remaining feature and zero out the dropped coefficient.

        Args:
            train_data (DataFrame): Prepared training data with features.
        """
        lr = LogisticRegression()

        # First try with both features
        model = lr.fit(
            train_data.select("label", F.col("features_both").alias("features"))
        )
        coef = model.coefficients

        if coef[0] < 0:
            # Use only complement feature if first coefficient is negative
            model = lr.fit(
                train_data.select(
                    "label", F.col("features_complement").alias("features")
                )
            )
            self.a = 0.0
            self.b = float(model.coefficients[0])
        elif coef[1] < 0:
            # Use only score feature if second coefficient is negative
            model = lr.fit(
                train_data.select("label", F.col("features_score").alias("features"))
            )
            self.a = float(model.coefficients[0])
            self.b = 0.0
        else:
            self.a = float(coef[0])
            self.b = float(coef[1])

        self.c = float(model.intercept)

    def _validate_score_range(self, df: DataFrame, score_col: str) -> None:
        """
        Validate that scores are within valid range [0, 1].

        Args:
            df (DataFrame): Input dataframe.
            score_col (str): Column containing raw model scores.

        Raises:
            ValueError: If scores are outside valid range.
        """
        stats = df.select(
            F.min(score_col).alias("min"), F.max(score_col).alias("max")
        ).collect()[0]

        # min/max are null when the frame is empty or the column is all-null
        # (e.g. predict() after filtering nulls); nothing to range-check then,
        # and comparing None < 0 would raise TypeError.
        if stats["min"] is None or stats["max"] is None:
            return

        if stats["min"] < 0 or stats["max"] > 1:
            raise ValueError(
                f"Scores must be in range [0,1], got range [{stats['min']:.3f}, {stats['max']:.3f}]"
            )

    def fit(
        self, df: DataFrame, score_col: str = "score", label_col: str = "label"
    ) -> "Betacal":
        """
        Fit a beta calibration model using logistic regression.

        Args:
            df (DataFrame): Input dataframe.
            score_col (str): Column containing raw model scores.
            label_col (str): Column containing binary labels.

        Returns:
            Betacal: The fitted model instance (self).

        Raises:
            ValueError: If input DataFrame is empty, required columns are
                missing, scores are out of range, or all scores are null.
        """
        self._validate_input_df(df, score_col, label_col)
        self._validate_score_range(df, score_col)
        df_clean = self._handle_null_values(df, score_col)
        train_data = self._prepare_features(df_clean, score_col, label_col)
        self._fit_logistic_regression(train_data)
        return self

    def predict(
        self,
        df: DataFrame,
        score_col: str = "score",
        prediction_col: str = "prediction",
    ) -> DataFrame:
        """
        Apply the learned beta calibration model to predict calibrated scores.

        Args:
            df (DataFrame): Input dataframe with raw scores.
            score_col (str): Column name for raw score.
            prediction_col (str): Name for the output prediction column.

        Returns:
            DataFrame: Original dataframe with an added prediction column.
                Null values in score_col will result in null predictions.

        Raises:
            ValueError: If calibration coefficients are not set, score_col is
                missing, or scores are outside the valid range.
        """
        if self.a is None or self.b is None or self.c is None:
            raise ValueError(
                "Model coefficients a, b, and c must be set. Call `.fit()` or `.load()` before prediction."
            )

        if score_col not in df.columns:
            raise ValueError(f"{score_col} must be present.")

        # Nulls are excluded from range validation; they become null predictions.
        self._validate_score_range(df.filter(F.col(score_col).isNotNull()), score_col)

        log_score = self._log_expr(F.col(score_col))
        log_one_minus_score = self._log_expr(1 - F.col(score_col))

        # Mirrors training: second feature is -log(1 - score).
        logit = (
            F.lit(self.a) * log_score
            + F.lit(self.b) * (-1 * log_one_minus_score)
            + F.lit(self.c)
        )

        prediction = F.when(F.col(score_col).isNull(), None).otherwise(
            1 / (1 + F.exp(-logit))
        )

        return df.withColumn(prediction_col, prediction)

    def save(self, path: Optional[str] = None, prefix: str = "betacal_") -> str:
        """
        Save the model coefficients to disk.

        Args:
            path (str, optional): Directory to save into. Creates temp dir if None.
            prefix (str): Prefix for temp folder name if path is None.

        Returns:
            str: The final save path.
        """
        if path is None:
            path = tempfile.mkdtemp(prefix=prefix)

        os.makedirs(path, exist_ok=True)

        with open(os.path.join(path, "coeffs.json"), "w") as f:
            json.dump(
                {"a": self.a, "b": self.b, "c": self.c, "parameters": self.parameters},
                f,
            )

        return path

    @classmethod
    def load(cls, path: str) -> "Betacal":
        """
        Load model coefficients from disk.

        Args:
            path (str): Directory containing 'coeffs.json'.

        Returns:
            Betacal: The loaded model.
        """
        with open(os.path.join(path, "coeffs.json"), "r") as f:
            coeffs = json.load(f)

        model = cls(parameters=coeffs["parameters"])
        model.a = coeffs["a"]
        model.b = coeffs["b"]
        model.c = coeffs["c"]
        return model
328 |
--------------------------------------------------------------------------------
/src/spark_calibration/metrics.py:
--------------------------------------------------------------------------------
1 | from pyspark.ml.evaluation import BinaryClassificationEvaluator
2 | import pyspark.sql.functions as F
3 |
4 | from pyspark.sql.dataframe import DataFrame
5 |
6 |
def display_classification_calib_metrics(df: DataFrame):
    """Print pre- and post-calibration classification metrics for comparison.

    Prints brier score loss, log loss, area under the PR curve and area under
    the ROC curve for both the raw `score` and the calibrated `prediction`,
    along with the relative delta of the calibrated model for each metric.

    Args:
        df: dataframe with score, label and prediction (calibrated score)
            columns.

    Raises:
        ValueError: if any of the required columns is missing.
    """
    missing = [c for c in ("score", "label", "prediction") if c not in df.columns]
    if missing:
        raise ValueError(
            f"score, label and prediction columns should be present in the dataframe; missing: {missing}"
        )

    def _log_loss(prob_col: str) -> F.Column:
        # NOTE(review): F.log(0) yields null in Spark and F.avg skips nulls,
        # so probabilities of exactly 0 or 1 silently bias this metric —
        # assumes upstream scores are strictly inside (0, 1); confirm.
        p = F.col(prob_col)
        y = F.col("label")
        return F.avg(-y * F.log(p) - (1 - y) * F.log(1 - p))

    # Single pass over the data for all four averaged losses (the original
    # code launched a separate Spark job per metric).
    losses = df.select(
        F.avg(F.pow(F.col("label") - F.col("score"), 2)).alias("model_bs"),
        _log_loss("score").alias("model_ll"),
        F.avg(F.pow(F.col("label") - F.col("prediction"), 2)).alias("cal_bs"),
        _log_loss("prediction").alias("cal_ll"),
    ).collect()[0]
    model_bs = losses["model_bs"]
    model_ll = losses["model_ll"]
    cal_bs = losses["cal_bs"]
    cal_ll = losses["cal_ll"]

    # Ranking metrics require the evaluator and cannot be folded into the
    # aggregation above.
    model_aucpr = BinaryClassificationEvaluator(
        rawPredictionCol="score", metricName="areaUnderPR"
    ).evaluate(df)
    model_roc_auc = BinaryClassificationEvaluator(
        rawPredictionCol="score", metricName="areaUnderROC"
    ).evaluate(df)
    cal_aucpr = BinaryClassificationEvaluator(
        rawPredictionCol="prediction", metricName="areaUnderPR"
    ).evaluate(df)
    cal_roc_auc = BinaryClassificationEvaluator(
        rawPredictionCol="prediction", metricName="areaUnderROC"
    ).evaluate(df)

    print(f"model brier score loss: {model_bs}")
    print(f"calibrated model brier score loss: {cal_bs}")
    print(f"delta: {round((cal_bs/model_bs - 1) * 100, 2)}%")

    print("")

    print(f"model log loss: {model_ll}")
    print(f"calibrated model log loss: {cal_ll}")
    print(f"delta: {round((cal_ll/model_ll - 1) * 100, 2)}%")

    print("")

    print(f"model aucpr: {model_aucpr}")
    print(f"calibrated model aucpr: {cal_aucpr}")
    print(f"delta: {round((cal_aucpr/model_aucpr - 1) * 100, 2)}%")

    print("")

    print(f"model roc_auc: {model_roc_auc}")
    print(f"calibrated model roc_auc: {cal_roc_auc}")
    print(f"delta: {round((cal_roc_auc/model_roc_auc - 1) * 100, 2)}%")
68 |
--------------------------------------------------------------------------------
/src/spark_calibration/visualisation.py:
--------------------------------------------------------------------------------
1 | import plotly.graph_objects as go
2 | from sklearn.calibration import calibration_curve
3 | from pyspark.sql.dataframe import DataFrame
4 |
5 |
def plot_calibration_curve(df: DataFrame):
    """Plot pre- and post-calibration reliability curves for `df`.

    Computes true vs. predicted probabilities with sklearn's
    `calibration_curve` using quantile binning (50 bins) for both the raw
    `score` and the calibrated `prediction`, and renders them with plotly
    against the y=x perfect-calibration baseline.

    NOTE: collects the three columns to the driver via `toPandas()` —
    intended for test-sized dataframes.

    Args:
        df: dataframe with score, label and prediction columns.

    Raises:
        ValueError: if any of the required columns is missing.
    """
    missing = [c for c in ("score", "label", "prediction") if c not in df.columns]
    if missing:
        raise ValueError(
            f"score, label and prediction columns should be present in the dataframe; missing: {missing}"
        )

    values = df.select("label", "score", "prediction").toPandas().values
    y_test_true = values[:, 0]
    y_test_pred = values[:, 1]
    y_test_pred_cal = values[:, 2]

    fig = go.Figure()

    fig.add_trace(
        go.Scatter(
            x=[0, 1],
            y=[0, 1],
            mode="lines",
            name="Perfect Calibration Baseline (y=x)",
            line=dict(dash="dash"),
        )
    )

    # calibration_curve may merge quantile bins with identical edges, so the
    # realized bin count can be below 50 — report it.
    prob_true, prob_pred = calibration_curve(
        y_test_true, y_test_pred, n_bins=50, strategy="quantile"
    )
    print(f"number of bins for pre-calibration scores: {prob_true.shape[0]}")

    fig.add_trace(
        go.Scatter(x=prob_pred, y=prob_true, mode="lines+markers", name="Model")
    )

    prob_true, prob_pred = calibration_curve(
        y_test_true, y_test_pred_cal, n_bins=50, strategy="quantile"
    )
    print(f"number of bins for post-calibration scores: {prob_true.shape[0]}")

    fig.add_trace(
        go.Scatter(
            x=prob_pred, y=prob_true, mode="lines+markers", name="Calibrated Model"
        )
    )

    fig.update_layout(
        title=dict(text="Calibration Curve on test data (quantile bins)", x=0.5),
        xaxis_title="Mean Predicted Probability",
        yaxis_title="Actual Fraction of Positives",
    )

    fig.show()
54 |
--------------------------------------------------------------------------------