├── .gitignore ├── Notebook.ipynb ├── README.md ├── Sigmoid regression.ipynb ├── calibration.py ├── model.py ├── plotting.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A guide to model calibration 2 | 3 | A simple implementation of two methods of calibrating machine learning models - Platt scaling and isotonic regression. 4 | For the main notebook with a demo of the calibration on MNIST dataset, see [Notebook.ipynb](Notebook.ipynb). The modules used are contained in [model.py](model.py), [calibration.py](calibration.py), and [plotting.py](plotting.py). 5 | 6 | For a short description of fitting a logistic regression line using linear regression, see [Sigmoid regression.ipynb](Sigmoid%20regression.ipynb). 7 | -------------------------------------------------------------------------------- /Sigmoid regression.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "id": "8268718f", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as np\n", 11 | "from sklearn.linear_model import LinearRegression\n", 12 | "import matplotlib.pyplot as plt" 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "id": "20d2d4a1", 18 | "metadata": {}, 19 | "source": [ 20 | "Creating some dummy, noised, S-shaped points:" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 19, 26 | "id": "06109975", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "x = np.linspace(-5, 5, 10)\n", 31 | "x += (np.random.random(10)-0.5)/1.5\n", 32 | "y = 1 / (1 + np.exp(-x))\n", 33 | "y += (np.random.random(10)-0.5)/6\n", 34 | "x = np.clip((x+5)/10, 0.01, 0.99)\n", 35 | "y = np.clip(y, 0.01, 0.99)" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "id": "5c683818", 41 | "metadata": {}, 42 | "source": [ 43 | "Fitting the regresssion model:" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 20, 49 | "id": "2aee6c6a", 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "data": { 54 | "image/png": "\n", 55 | "text/plain": [ 56 | "
" 57 | ] 58 | }, 59 | "metadata": { 60 | "needs_background": "light" 61 | }, 62 | "output_type": "display_data" 63 | } 64 | ], 65 | "source": [ 66 | "plt.figure(figsize=(12, 10))\n", 67 | "plt.subplots_adjust(hspace=.5)\n", 68 | "plt.subplot(2,2,1)\n", 69 | "plt.title('1. Start with some points forming an \"S\" shape\\n\\n', fontsize=12)\n", 70 | "plt.scatter(x, y)\n", 71 | "\n", 72 | "plt.subplot(2,2,2)\n", 73 | "y_logit = np.log(y / (1 - y))\n", 74 | "regressor = LinearRegression().fit(x.reshape(-1, 1), y_logit.reshape(-1, 1))\n", 75 | "y_predicted = regressor.predict(x.reshape(-1, 1))\n", 76 | "plt.title('2. Transforming the points into logits with $\\\\ln{\\\\left(\\\\frac{x}{1-x}\\\\right)}$\\nyields a more or less straight line.\\nA linear regression model can be fit to this data.', fontsize=12)\n", 77 | "plt.scatter(x, y_logit)\n", 78 | "plt.plot(x, y_predicted, color='red')\n", 79 | "\n", 80 | "plt.subplot(2,2,3)\n", 81 | "plt.title('3. The regression line can be transformed back to\\nthe original domain using sigmoid function $\\\\frac{1}{1+e^{-x}}$', fontsize=12)\n", 82 | "x_new = np.linspace(0,1,100)\n", 83 | "y_new = regressor.predict(x_new.reshape(-1, 1))\n", 84 | "y_new_sigmoid = 1 / (1 + np.exp(-y_new))\n", 85 | "plt.scatter(x, y)\n", 86 | "plt.plot(x_new, y_new_sigmoid, color='red')\n", 87 | "plt.savefig('sigmoid-regression.svg', bbox_inches='tight')\n", 88 | "plt.show()" 89 | ] 90 | } 91 | ], 92 | "metadata": { 93 | "kernelspec": { 94 | "display_name": "Python [conda env:calibration]", 95 | "language": "python", 96 | "name": "conda-env-calibration-py" 97 | }, 98 | "language_info": { 99 | "codemirror_mode": { 100 | "name": "ipython", 101 | "version": 3 102 | }, 103 | "file_extension": ".py", 104 | "mimetype": "text/x-python", 105 | "name": "python", 106 | "nbconvert_exporter": "python", 107 | "pygments_lexer": "ipython3", 108 | "version": "3.8.11" 109 | } 110 | }, 111 | "nbformat": 4, 112 | "nbformat_minor": 5 113 | } 114 | -------------------------------------------------------------------------------- /calibration.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.isotonic import IsotonicRegression 3 | from sklearn.linear_model import LinearRegression 4 | 5 | 6 | class SigmoidCalibrator: 7 | def __init__(self, prob_pred, prob_true): 8 | prob_pred, prob_true = self._filter_out_of_domain(prob_pred, prob_true) 9 | prob_true = np.log(prob_true / (1 - prob_true)) 10 | self.regressor = LinearRegression().fit( 11 | prob_pred.reshape(-1, 1), prob_true.reshape(-1, 1) 12 | ) 13 | 14 | def calibrate(self, probabilities): 15 | return 1 / (1 + np.exp(-self.regressor.predict(probabilities.reshape(-1, 1)).flatten())) 16 | 17 | def _filter_out_of_domain(self, prob_pred, prob_true): 18 | filtered = list(zip(*[p for p in zip(prob_pred, prob_true) if 0 < p[1] < 1])) 19 | return np.array(filtered) 20 | 21 | 22 | class IsotonicCalibrator: 23 | def __init__(self, prob_pred, prob_true): 24 | self.regressor = IsotonicRegression(out_of_bounds="clip") 25 | self.regressor.fit(prob_pred, prob_true) 26 | 27 | def calibrate(self, probabilities): 28 | return self.regressor.predict(probabilities) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import h2o 2 | import numpy as np 3 | import pandas as pd 4 | from h2o.estimators import H2OEstimator as H2OClassifier 5 | from sklearn.base import ClassifierMixin as ScikitClassifier 6 | from sklearn.calibration import calibration_curve 7 | from tensorflow.keras import Model as KerasBaseModel 8 | 9 | from calibration import IsotonicCalibrator, SigmoidCalibrator 10 | 11 | 12 | class CalibratableModelFactory: 13 | def get_model(self, base_model): 14 | if isinstance(base_model, H2OClassifier): 15 | return H2OModel(base_model) 16 | elif isinstance(base_model, ScikitClassifier): 17 | return ScikitModel(base_model) 18 | elif isinstance(base_model, KerasBaseModel): 19 | return KerasModel(base_model) 20 | raise ValueError("Unsupported model passed as an argument") 21 | 22 | 23 | class CalibratableModelMixin: 24 | def __init__(self, model): 25 | self.model = model 26 | self.name = model.__class__.__name__ 27 | self.sigmoid_calibrator = None 28 | self.isotonic_calibrator = None 29 | self.calibrators = { 30 | "sigmoid": None, 31 | "isotonic": None, 32 | } 33 | 34 | def calibrate(self, X, y): 35 | predictions = self.predict(X) 36 | prob_true, prob_pred = calibration_curve(y, predictions, n_bins=10) 37 | self.calibrators["sigmoid"] = SigmoidCalibrator(prob_pred, prob_true) 38 | self.calibrators["isotonic"] = IsotonicCalibrator(prob_pred, prob_true) 39 | 40 | def calibrate_probabilities(self, probabilities, method="isotonic"): 41 | if method not in self.calibrators: 42 | raise ValueError("Method has to be either 'sigmoid' or 'isotonic'") 43 | if self.calibrators[method] is None: 44 | raise ValueError("Fit the calibrators first") 45 | return self.calibrators[method].calibrate(probabilities) 46 | 47 | def predict_calibrated(self, X, method="isotonic"): 48 | return self.calibrate_probabilities(self.predict(X), method) 49 | 50 | def score(self, X, y): 51 | return self._get_accuracy(y, self.predict(X)) 52 | 53 | def score_calibrated(self, X, y, method="isotonic"): 54 | return self._get_accuracy(y, self.predict_calibrated(X, method)) 55 | 56 | def _get_accuracy(self, y, preds): 57 | return np.mean(np.equal(y.astype(np.bool), preds >= 0.5)) 58 | 59 | 60 | class H2OModel(CalibratableModelMixin): 61 | def train(self, X, y): 62 | self.features = list(range(len(X[0]))) 63 | self.target = "target" 64 | train_frame = self._to_h2o_frame(X, y) 65 | self.model.train(x=self.features, y=self.target, training_frame=train_frame) 66 | 67 | def predict(self, X): 68 | predict_frame = self._to_h2o_frame(X) 69 | return self.model.predict(predict_frame).as_data_frame()["p1"].to_numpy() 70 | 71 | def _to_h2o_frame(self, X, y=None): 72 | df = pd.DataFrame(data=X, columns=self.features) 73 | if y is not None: 74 | df[self.target] = y 75 | h2o_frame = h2o.H2OFrame(df) 76 | if y is not None: 77 | h2o_frame[self.target] = h2o_frame[self.target].asfactor() 78 | return h2o_frame 79 | 80 | 81 | class ScikitModel(CalibratableModelMixin): 82 | def train(self, X, y): 83 | self.model.fit(X, y) 84 | 85 | def predict(self, X): 86 | return self.model.predict_proba(X)[:, 1] 87 | 88 | 89 | class KerasModel(CalibratableModelMixin): 90 | def train(self, X, y): 91 | self.model.fit(X, y, batch_size=128, epochs=10, verbose=0) 92 | 93 | def predict(self, X): 94 | return self.model.predict(X).flatten() 95 | -------------------------------------------------------------------------------- /plotting.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | from sklearn.calibration import calibration_curve 4 | from sklearn.metrics import brier_score_loss 5 | 6 | 7 | def plot_sample(X, title=None, size=8): 8 | fig = plt.figure() 9 | st = fig.suptitle(title) 10 | for i, im in enumerate(np.random.permutation(X)[:size]): 11 | fig.add_subplot(1, size, i + 1) 12 | plt.axis("off") 13 | plt.imshow(im.reshape(28, 28), cmap="gray") 14 | fig.tight_layout() 15 | st.set_y(0.59) 16 | 17 | 18 | def plot_sample_predictions(models, X, X_unscaled, size=8): 19 | indexes = np.random.permutation(len(X))[:size] 20 | fig = plt.figure(figsize=(6, len(models) * 1.6), constrained_layout=True) 21 | subfigs = fig.subfigures(len(models)) 22 | for j, model in enumerate(models): 23 | subfig = subfigs.flat[j] 24 | subfig.suptitle(model.name) 25 | for i, idx in enumerate(indexes): 26 | prediction = model.predict(np.array([X[idx]]))[0] 27 | subfig.add_subplot(1, size, i + 1) 28 | plt.axis("off") 29 | plt.title(f"{round(100*prediction,2)}%", fontsize=10) 30 | plt.imshow(X_unscaled[idx].reshape(28, 28), cmap="gray") 31 | 32 | 33 | def plot_calibration_curve(y, probs, title): 34 | brier_score = brier_score_loss(y, probs) 35 | prob_true, prob_pred = calibration_curve(y, probs, n_bins=10) 36 | plt.plot([0, 1], [0, 1], linestyle="--") 37 | plt.plot( 38 | prob_pred, 39 | prob_true, 40 | marker=".", 41 | color="orange", 42 | ) 43 | plt.title(f"{title}\nBrier score: {round(brier_score, 3)}") 44 | plt.ylabel("Fraction of positives") 45 | plt.xlabel("Mean predicted value") 46 | return prob_true, prob_pred 47 | 48 | 49 | def plot_fitted_calibrator(prob_true, prob_pred, prob_calibrated, title=None): 50 | plt.plot([0, 1], [0, 1], linestyle="--") 51 | plt.plot(prob_pred, prob_true, marker=".", color="orange") 52 | plt.plot(prob_pred, prob_calibrated, color="red") 53 | plt.title(title) 54 | plt.ylabel("Fraction of positives") 55 | plt.xlabel("Mean predicted value") 56 | 57 | 58 | def plot_calibration_details_for_models( 59 | models, X, y, calibrated=False, method="isotonic" 60 | ): 61 | plt.figure(figsize=(10, 10)) 62 | ax1 = plt.subplot2grid((3, 1), (0, 0), rowspan=2) 63 | ax2 = plt.subplot2grid((3, 1), (2, 0)) 64 | ax1.plot([0, 1], [0, 1], "k:", label="Perfectly calibrated") 65 | for model in models: 66 | name = model.name 67 | probabilities = ( 68 | model.predict(X) 69 | if not calibrated 70 | else model.predict_calibrated(X, method=method) 71 | ) 72 | prob_true, prob_pred = calibration_curve(y, probabilities, n_bins=10) 73 | brier_score = brier_score_loss(y, probabilities) 74 | 75 | ax1.plot( 76 | prob_pred, 77 | prob_true, 78 | marker=".", 79 | label=f"{name} (BS={round(brier_score, 3)})", 80 | ) 81 | 82 | ax2.hist( 83 | probabilities, range=(0, 1), bins=10, label=name, histtype="step", lw=2 84 | ) 85 | 86 | ax1.set_ylabel("Fraction of positives") 87 | ax1.set_ylim([-0.05, 1.05]) 88 | ax1.legend(loc="upper left") 89 | ax1.set_title("Calibration plots") 90 | 91 | ax2.set_xlabel("Mean predicted value") 92 | ax2.set_ylabel("Count") 93 | ax2.legend(loc="upper center", ncol=2) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.14.0 2 | argon2-cffi==20.1.0 3 | astunparse==1.6.3 4 | async-generator==1.10 5 | attrs==21.2.0 6 | backcall==0.2.0 7 | bleach==4.0.0 8 | cachetools==4.2.2 9 | certifi==2021.5.30 10 | cffi==1.14.6 11 | charset-normalizer==2.0.6 12 | clang==5.0 13 | colorama==0.4.4 14 | cycler==0.10.0 15 | debugpy==1.4.1 16 | decorator==5.0.9 17 | defusedxml==0.7.1 18 | entrypoints==0.3 19 | flatbuffers==1.12 20 | future==0.18.2 21 | gast==0.4.0 22 | google-auth==1.35.0 23 | google-auth-oauthlib==0.4.6 24 | google-pasta==0.2.0 25 | grpcio==1.40.0 26 | h2o==3.34.0.1 27 | h5py==3.1.0 28 | idna==3.2 29 | importlib-metadata==4.8.1 30 | ipykernel==6.2.0 31 | ipython==7.27.0 32 | ipython-genutils==0.2.0 33 | jedi==0.18.0 34 | Jinja2==3.0.1 35 | joblib==1.0.1 36 | jsonschema==3.2.0 37 | jupyter-client==7.0.1 38 | jupyter-core==4.7.1 39 | jupyterlab-pygments==0.1.2 40 | keras==2.6.0 41 | Keras-Preprocessing==1.1.2 42 | kiwisolver==1.3.2 43 | Markdown==3.3.4 44 | MarkupSafe==2.0.1 45 | matplotlib==3.4.3 46 | matplotlib-inline==0.1.2 47 | mistune==0.8.4 48 | nb-conda==2.2.1 49 | nb-conda-kernels==2.3.1 50 | nbclient==0.5.3 51 | nbconvert==6.1.0 52 | nbformat==5.1.3 53 | nest-asyncio==1.5.1 54 | notebook==6.4.3 55 | numpy==1.19.5 56 | oauthlib==3.1.1 57 | opt-einsum==3.3.0 58 | packaging==21.0 59 | pandas==1.3.3 60 | pandocfilters==1.4.3 61 | parso==0.8.2 62 | pickleshare==0.7.5 63 | Pillow==8.3.2 64 | pip==21.0.1 65 | prometheus-client==0.11.0 66 | prompt-toolkit==3.0.17 67 | protobuf==3.18.0 68 | pyasn1==0.4.8 69 | pyasn1-modules==0.2.8 70 | pycparser==2.20 71 | Pygments==2.10.0 72 | pyparsing==2.4.7 73 | pyrsistent==0.17.3 74 | python-dateutil==2.8.2 75 | pytz==2021.1 76 | pywin32==228 77 | pywinpty==0.5.7 78 | pyzmq==22.2.1 79 | requests==2.26.0 80 | requests-oauthlib==1.3.0 81 | rsa==4.7.2 82 | scikit-learn==0.24.2 83 | scipy==1.7.1 84 | Send2Trash==1.5.0 85 | setuptools==58.0.4 86 | six==1.15.0 87 | sklearn==0.0 88 | tabulate==0.8.9 89 | tensorboard==2.6.0 90 | tensorboard-data-server==0.6.1 91 | tensorboard-plugin-wit==1.8.0 92 | tensorflow==2.6.0 93 | tensorflow-estimator==2.6.0 94 | termcolor==1.1.0 95 | terminado==0.9.4 96 | testpath==0.5.0 97 | threadpoolctl==2.2.0 98 | tornado==6.1 99 | traitlets==5.0.5 100 | typing-extensions==3.7.4.3 101 | urllib3==1.26.6 102 | wcwidth==0.2.5 103 | webencodings==0.5.1 104 | Werkzeug==2.0.1 105 | wheel==0.37.0 106 | wincertstore==0.2 107 | wrapt==1.12.1 108 | zipp==3.5.0 109 | --------------------------------------------------------------------------------