├── .gitignore
├── LICENSE
├── README.md
├── calibration
│   ├── __init__.py
│   ├── calibrators.py
│   ├── util_test.py
│   └── utils.py
├── codalab_commands.bash
├── data
│   ├── cifar10vgg.h5
│   └── cifar_probs.dat
├── examples
│   ├── advanced_example.py
│   ├── bootstrap_example.py
│   ├── multiclass_example.py
│   └── simple_example.py
├── experiments
│   ├── debiased_estimator
│   │   ├── estimation_error_vs_bins.py
│   │   └── mse_vs_ce_tradeoff.py
│   ├── platt_not_calibrated
│   │   ├── cifar10vgg.py
│   │   ├── lower_bounds.py
│   │   ├── save_cifar_logits.py
│   │   └── save_imagenet_logits.py
│   ├── scaling_binning_calibrator
│   │   └── compare_calibrators.py
│   └── synthetic
│       └── synthetic.py
└── setup.py

/.gitignore:
--------------------------------------------------------------------------------
*.pyc
saved_files/*
*pycache*

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Ananya Kumar

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Uncertainty Calibration Library

This repository contains library code to measure the calibration error of models, including confidence intervals computed by bootstrap resampling, and code to recalibrate models. Our functions estimate the calibration error and ECE more accurately than prior "plugin" estimators, and we provide bootstrap confidence intervals. See [Verified Uncertainty Calibration](https://arxiv.org/abs/1909.10155) for more details.

Motivating example for uncertainty calibration: Calster and Vickers (2015) train a random forest that takes in features such as tumor size and presence of teratoma, and tries to predict the probability that a patient has testicular cancer. They note that for a large number of patients, the model predicts around a 20% chance of cancer. In reality, around 40% of these patients had cancer. This underestimation can lead to doctors prescribing the wrong treatment---in a situation where many lives are at stake.

*The high-level point here is that the uncertainties that models output matter, not just the model's accuracy*.
Calibration is a popular way to measure the quality of a model's uncertainties, and recalibration is a way to take an existing model and correct its uncertainties to make them better.

This library focuses on calibration for classification tasks. For regression tasks, check out: https://github.com/uncertainty-toolbox/uncertainty-toolbox

## Installation

```bash
pip3 install uncertainty-calibration
```

The calibration library requires Python 3.6 or higher at the moment, because we make use of Python 3's optional typing mechanism.
If your project requires an earlier version of Python and you wish to use our library, please contact us.


## Overview

Measuring the calibration error of a model is as simple as:

```python
import calibration as cal
calibration_error = cal.get_calibration_error(model_probs, labels)
```

Recalibrating a model is very simple as well. Recalibration requires a small labeled dataset, on which we train a recalibrator:

```python
calibrator = cal.PlattBinnerMarginalCalibrator(num_points, num_bins=10)
calibrator.train_calibration(model_probs, labels)
```

Now whenever the model outputs a prediction, we pass it through the calibrator to produce better probabilities.

```python
calibrated_probs = calibrator.calibrate(test_probs)
```

Our library makes it very easy to measure confidence intervals on the calibration error as well, using bootstrap resamples.

```python
[lower, _, upper] = cal.get_calibration_error_uncertainties(model_probs, labels)
```


## Examples

You can find complete code examples in the examples folder. Refer to:
- examples/simple_example.py for a simple example in the binary classification setting.
- examples/multiclass_example.py for the multiclass (more than 2 classes) setting.
- examples/advanced_example.py --- our calibration library also exposes a more customizable interface for advanced users.


## Citation

If you find this library useful, please consider citing our paper:

    @inproceedings{kumar2019calibration,
      author = {Ananya Kumar and Percy Liang and Tengyu Ma},
      booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
      title = {Verified Uncertainty Calibration},
      year = {2019},
    }


## Advanced: ECE, Debiasing, and Top-Label Calibration Error

By default, our library measures the per-class, root-mean-squared calibration error, and uses the techniques in [Verified Uncertainty Calibration](https://arxiv.org/abs/1909.10155) to accurately estimate the calibration error. However, we also support measuring the ECE [(Guo et al)](https://arxiv.org/abs/1706.04599) and using older, less accurate ways of estimating the calibration error.

To measure the ECE as in [Guo et al](https://arxiv.org/abs/1706.04599):

```python
calibration_error = cal.get_ece(model_probs, labels)
```

To estimate it more accurately, using a more stable binning scheme, use:

```python
calibration_error = cal.get_top_calibration_error(model_probs, labels, p=1)
```

Multiclass calibration vs ECE / Top-label: When measuring the calibration error of a multiclass model, we can measure the calibration error of all classes (per-class calibration error), or only of the top prediction, as sketched below.
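The snippet below is an illustrative sketch (not from the original README) contrasting the two measurements; here `model_probs` is assumed to be an `(n, k)` array of predicted probabilities and `labels` an array of `n` integer labels:

```python
import calibration as cal

# Per-class (marginal) calibration error: computed for every class, then averaged.
marginal_error = cal.get_calibration_error(model_probs, labels)

# Top-label calibration error: only considers the class the model is most confident about.
top_label_error = cal.get_top_calibration_error(model_probs, labels, p=2)
```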
As an example of why per-class calibration matters, imagine that a medical diagnosis system says there is a 50% chance a patient has a benign tumor, a 10% chance she has an aggressive form of cancer, and a 40% chance she has one of a long list of other conditions. We would like the system to be calibrated on each of these predictions (especially cancer!), and not just on the top prediction of a benign tumor. [Nixon et al](https://arxiv.org/abs/1909.10155), [Kumar et al](https://arxiv.org/abs/1909.10155), and [Kull et al](https://arxiv.org/abs/1910.12656) measure per-class calibration to account for this.


## Questions, bugs, and contributions

Please feel free to ask us questions, submit bug reports, or contribute pull requests.
Feel free to submit a brief description of a pull request before you implement it, to get feedback on it or to see how it can fit into our project.


## Verified Uncertainty Calibration paper

This repository also contains code for the NeurIPS 2019 (Spotlight) paper [Verified Uncertainty Calibration](https://arxiv.org/abs/1909.10155).

In our paper, we show that:
- The calibration error of methods like Platt scaling and temperature scaling is typically underestimated, and cannot be easily measured.
- We propose an efficient recalibration method where the calibration error can be measured.
- We show that we can estimate the calibration error with fewer samples than the standard method, using an estimator from the meteorological literature.


## Experiments

See our CodaLab worksheet https://worksheets.codalab.org/worksheets/0xb6d027ee127e422989ab9115726c5411 which contains all the experiment runs and the exact code used to produce them. You can download the ImageNet logits from CodaLab as well: download imagenet_logits.dat from https://worksheets.codalab.org/bundles/0x81c9c8a9bf6c47f59f45f6fc80790c3c and put it into the data folder.

The experiments folder contains the experiments for the paper.

We have 4 sets of experiments:
- Showing that Platt scaling is less calibrated than reported (Section 3)
- Comparing the scaling-binning calibrator with histogram binning on CIFAR-10 and ImageNet (Section 4)
- Synthetic experiments to validate our theoretical bounds (Section 4)
- Experiments showing that the debiased estimator can estimate the calibration error with fewer samples than the standard estimator (Section 5)

Running each experiment saves plots in the corresponding folder in saved_files.

--------------------------------------------------------------------------------
/calibration/__init__.py:
--------------------------------------------------------------------------------

from .utils import *
from .calibrators import *
--------------------------------------------------------------------------------
/calibration/calibrators.py:
--------------------------------------------------------------------------------

import numpy as np

from .
import utils 5 | 6 | 7 | class HistogramCalibrator: 8 | def __init__(self, num_calibration, num_bins): 9 | self._num_calibration = num_calibration 10 | self._num_bins = num_bins 11 | 12 | def train_calibration(self, zs, ys): 13 | bins = utils.get_equal_bins(zs, num_bins=self._num_bins) 14 | self._calibrator = utils.get_histogram_calibrator(zs, ys, bins) 15 | 16 | def calibrate(self, zs): 17 | return self._calibrator(zs) 18 | 19 | 20 | class PlattBinnerCalibrator: 21 | def __init__(self, num_calibration, num_bins): 22 | self._num_calibration = num_calibration 23 | self._num_bins = num_bins 24 | 25 | def train_calibration(self, zs, ys): 26 | self._platt = utils.get_platt_scaler(zs, ys) 27 | platt_probs = self._platt(zs) 28 | bins = utils.get_equal_bins(platt_probs, num_bins=self._num_bins) 29 | self._discrete_calibrator = utils.get_discrete_calibrator(platt_probs, bins) 30 | 31 | def calibrate(self, zs): 32 | platt_probs = self._platt(zs) 33 | return self._discrete_calibrator(platt_probs) 34 | 35 | 36 | class PlattCalibrator: 37 | def __init__(self, num_calibration, num_bins): 38 | self._num_calibration = num_calibration 39 | self._num_bins = num_bins 40 | 41 | def train_calibration(self, zs, ys): 42 | self._platt = utils.get_platt_scaler(zs, ys) 43 | 44 | def calibrate(self, zs): 45 | return self._platt(zs) 46 | 47 | 48 | class HistogramTopCalibrator: 49 | 50 | def __init__(self, num_calibration, num_bins): 51 | self._num_calibration = num_calibration 52 | self._num_bins = num_bins 53 | 54 | def train_calibration(self, probs, labels): 55 | assert(len(probs) >= self._num_calibration) 56 | top_probs = utils.get_top_probs(probs) 57 | predictions = utils.get_top_predictions(probs) 58 | correct = (predictions == labels) 59 | bins = utils.get_equal_bins(top_probs, num_bins=self._num_bins) 60 | self._calibrator = utils.get_histogram_calibrator( 61 | top_probs, correct, bins) 62 | 63 | def calibrate(self, probs): 64 | top_probs = utils.get_top_probs(probs) 65 | return self._calibrator(top_probs) 66 | 67 | 68 | class PlattBinnerTopCalibrator: 69 | 70 | def __init__(self, num_calibration, num_bins): 71 | self._num_calibration = num_calibration 72 | self._num_bins = num_bins 73 | 74 | def train_calibration(self, probs, labels): 75 | assert(len(probs) >= self._num_calibration) 76 | predictions = utils.get_top_predictions(probs) 77 | top_probs = utils.get_top_probs(probs) 78 | correct = (predictions == labels) 79 | self._platt = utils.get_platt_scaler( 80 | top_probs, correct) 81 | platt_probs = self._platt(top_probs) 82 | bins = utils.get_equal_bins(platt_probs, num_bins=self._num_bins) 83 | self._discrete_calibrator = utils.get_discrete_calibrator( 84 | platt_probs, bins) 85 | 86 | def calibrate(self, probs): 87 | top_probs = self._platt(utils.get_top_probs(probs)) 88 | return self._discrete_calibrator(top_probs) 89 | 90 | 91 | class PlattTopCalibrator: 92 | 93 | def __init__(self, num_calibration, num_bins): 94 | self._num_calibration = num_calibration 95 | self._num_bins = num_bins 96 | 97 | def train_calibration(self, probs, labels): 98 | assert(len(probs) >= self._num_calibration) 99 | predictions = utils.get_top_predictions(probs) 100 | top_probs = utils.get_top_probs(probs) 101 | correct = (predictions == labels) 102 | self._platt = utils.get_platt_scaler( 103 | top_probs, correct) 104 | 105 | def calibrate(self, probs): 106 | return self._platt(utils.get_top_probs(probs)) 107 | 108 | 109 | class IdentityTopCalibrator: 110 | 111 | def __init__(self, num_calibration, num_bins): 112 | pass 113 | 
114 | def train_calibration(self, probs, labels): 115 | pass 116 | 117 | def calibrate(self, probs): 118 | return utils.get_top_probs(probs) 119 | 120 | 121 | class HistogramMarginalCalibrator: 122 | 123 | def __init__(self, num_calibration, num_bins): 124 | self._num_calibration = num_calibration 125 | self._num_bins = num_bins 126 | 127 | def train_calibration(self, probs, labels): 128 | """Train a calibrator given probs and labels. 129 | 130 | Args: 131 | probs: A sequence of dimension (n, k) where n is the number of 132 | data points, and k is the number of classes, representing 133 | the output probabilities/confidences of the uncalibrated 134 | model. 135 | labels: A sequence of length n, where n is the number of data points, 136 | representing the ground truth label for each data point. 137 | """ 138 | assert(len(probs) >= self._num_calibration) 139 | probs = np.array(probs) 140 | self._k = probs.shape[1] # Number of classes. 141 | assert self._k == np.max(labels) - np.min(labels) + 1 142 | labels_one_hot = utils.get_labels_one_hot(np.array(labels), self._k) 143 | self._calibrators = [] 144 | for c in range(self._k): 145 | # For each class c, get the probabilities the model output for that class, and whether 146 | # the data point was actually class c, or not. 147 | probs_c = probs[:, c] 148 | labels_c = labels_one_hot[:, c] 149 | bins = utils.get_equal_bins(probs_c, num_bins=self._num_bins) 150 | calibrator_c = utils.get_histogram_calibrator(probs_c, labels_c, bins) 151 | self._calibrators.append(calibrator_c) 152 | 153 | def calibrate(self, probs): 154 | probs = np.array(probs) 155 | assert self._k == probs.shape[1] 156 | calibrated_probs = np.zeros(probs.shape) 157 | for c in range(self._k): 158 | probs_c = probs[:, c] 159 | calibrated_probs[:, c] = self._calibrators[c](probs_c) 160 | return calibrated_probs 161 | 162 | 163 | class PlattBinnerMarginalCalibrator: 164 | 165 | def __init__(self, num_calibration, num_bins): 166 | self._num_calibration = num_calibration 167 | self._num_bins = num_bins 168 | 169 | def train_calibration(self, probs, labels): 170 | """Train a calibrator given probs and labels. 171 | 172 | Args: 173 | probs: A sequence of dimension (n, k) where n is the number of 174 | data points, and k is the number of classes, representing 175 | the output probabilities/confidences of the uncalibrated 176 | model. 177 | labels: A sequence of length n, where n is the number of data points, 178 | representing the ground truth label for each data point. 179 | """ 180 | assert(len(probs) >= self._num_calibration) 181 | probs = np.array(probs) 182 | self._k = probs.shape[1] # Number of classes. 183 | assert self._k == np.max(labels) - np.min(labels) + 1 184 | labels_one_hot = utils.get_labels_one_hot(np.array(labels), self._k) 185 | assert labels_one_hot.shape == probs.shape 186 | self._platts = [] 187 | self._calibrators = [] 188 | for c in range(self._k): 189 | # For each class c, get the probabilities the model output for that class, and whether 190 | # the data point was actually class c, or not. 
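            # Descriptive note: each per-class calibrator below follows the scaling-binning
            # recipe: Platt-scale the class-c probabilities, pick equal-count bins on the
            # scaled values, then map each value to the average Platt-scaled probability in
            # its bin (see get_discrete_calibrator in utils.py).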
191 | probs_c = probs[:, c] 192 | labels_c = labels_one_hot[:, c] 193 | platt_c = utils.get_platt_scaler(probs_c, labels_c) 194 | self._platts.append(platt_c) 195 | platt_probs_c = platt_c(probs_c) 196 | bins = utils.get_equal_bins(platt_probs_c, num_bins=self._num_bins) 197 | calibrator_c = utils.get_discrete_calibrator(platt_probs_c, bins) 198 | self._calibrators.append(calibrator_c) 199 | 200 | 201 | def calibrate(self, probs): 202 | probs = np.array(probs) 203 | assert self._k == probs.shape[1] 204 | calibrated_probs = np.zeros(probs.shape) 205 | for c in range(self._k): 206 | probs_c = probs[:, c] 207 | platt_probs_c = self._platts[c](probs_c) 208 | calibrated_probs[:, c] = self._calibrators[c](platt_probs_c) 209 | return calibrated_probs 210 | -------------------------------------------------------------------------------- /calibration/util_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from parameterized import parameterized 3 | from calibration.utils import * 4 | from calibration.utils import _get_ce 5 | import numpy as np 6 | 7 | 8 | import collections 9 | 10 | 11 | def multiset_equal(list1, list2): 12 | return collections.Counter(list1) == collections.Counter(list2) 13 | 14 | 15 | def list_to_tuple(l): 16 | if not isinstance(l, list): 17 | return l 18 | return tuple(list_to_tuple(x) for x in l) 19 | 20 | 21 | class TestUtilMethods(unittest.TestCase): 22 | 23 | def test_split(self): 24 | self.assertEqual(split([1, 3, 2, 4], parts=2), [[1, 3], [2, 4]]) 25 | self.assertEqual(split([1], parts=1), [[1]]) 26 | self.assertEqual(split([2, 3], parts=1), [[2, 3]]) 27 | self.assertEqual(split([2, 3], parts=2), [[2], [3]]) 28 | self.assertEqual(split([1, 2, 3], parts=1), [[1, 2, 3]]) 29 | self.assertEqual(split([1, 2, 3], parts=2), [[1, 2], [3]]) 30 | self.assertEqual(split([1, 2, 3], parts=3), [[1], [2], [3]]) 31 | 32 | def test_get_1_equal_bin(self): 33 | probs = [0.3, 0.5, 0.2, 0.3, 0.5, 0.7] 34 | bins = get_equal_bins(probs, num_bins=1) 35 | self.assertEqual(bins, [1.0]) 36 | 37 | def test_get_2_equal_bins(self): 38 | probs = [0.3, 0.5, 0.2, 0.3, 0.5, 0.7] 39 | bins = get_equal_bins(probs, num_bins=2) 40 | self.assertEqual(bins, [0.4, 1.0]) 41 | 42 | def test_get_3_equal_bins(self): 43 | probs = [0.3, 0.5, 0.2, 0.3, 0.5, 0.7] 44 | bins = get_equal_bins(probs, num_bins=3) 45 | self.assertEqual(bins, [0.3, 0.5, 1.0]) 46 | 47 | def test_get_3_equal_bins_lots_of_1s(self): 48 | probs = [0.3, 0.5, 1.0, 1.0, 1.0, 1.0] 49 | bins = get_equal_bins(probs, num_bins=3) 50 | self.assertEqual(bins, [0.75, 1.0]) 51 | 52 | def test_get_3_equal_bins_uneven_sizes(self): 53 | probs = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] 54 | bins = np.array(get_equal_bins(probs, num_bins=3)) 55 | self.assertTrue(np.allclose(bins, np.array([0.55, 0.75, 1.0]))) 56 | 57 | def test_equal_bins_more_bins_points(self): 58 | probs = [0.3] 59 | bins = get_equal_bins(probs, num_bins=2) 60 | self.assertEqual(bins, [1.0]) 61 | bins = get_equal_bins(probs, num_bins=5) 62 | self.assertEqual(bins, [1.0]) 63 | probs = [0.3, 0.5] 64 | bins = get_equal_bins(probs, num_bins=5) 65 | self.assertEqual(bins, [0.4, 1.0]) 66 | 67 | def test_equal_bin_num_bins(self): 68 | for n in [1,2,3,5,10,20]: 69 | for num_bins in range(1,n): 70 | bins = split(np.arange(n) / float(n), num_bins) 71 | self.assertEqual(len(bins), num_bins) 72 | 73 | def test_get_1_equal_prob_bins(self): 74 | probs = [0.3, 0.5, 0.2, 0.3, 0.5, 0.7] 75 | bins = get_equal_prob_bins(probs, num_bins=1) 76 | 
self.assertEqual(bins, [1.0]) 77 | 78 | def test_get_2_equal_prob_bins(self): 79 | probs = [0.3, 0.5, 0.2, 0.3, 0.5, 0.7] 80 | bins = get_equal_prob_bins(probs, num_bins=2) 81 | self.assertEqual(bins, [0.5, 1.0]) 82 | 83 | def test_get_discrete_bins(self): 84 | probs = [0.3, 0.5, 0.2, 0.3, 0.5, 0.7] 85 | bins = get_discrete_bins(probs) 86 | self.assertEqual(bins, [0.25, 0.4, 0.6, 1.0]) 87 | 88 | def test_enough_duplicates(self): 89 | probs = np.array([0.1, 0.3, 0.5]) 90 | self.assertFalse(enough_duplicates(probs)) 91 | probs = np.array([0.1, 0.1, 0.5]) 92 | self.assertFalse(enough_duplicates(probs)) 93 | probs = np.array([0.1, 0.1, 0.1, 0.1, 0.6, 0.6, 0.6, 0.6, 0.6]) 94 | self.assertTrue(enough_duplicates(probs)) 95 | 96 | def test_get_bin(self): 97 | bins = [0.2, 0.5, 1.0] 98 | self.assertEqual(get_bin(0.0, bins), 0) 99 | self.assertEqual(get_bin(0.19, bins), 0) 100 | self.assertEqual(get_bin(0.21, bins), 1) 101 | self.assertEqual(get_bin(0.49, bins), 1) 102 | self.assertEqual(get_bin(0.51, bins), 2) 103 | self.assertEqual(get_bin(1.0, bins), 2) 104 | 105 | def test_get_bin_size_1(self): 106 | bins = [1.0] 107 | self.assertEqual(get_bin(0.0, bins), 0) 108 | self.assertEqual(get_bin(0.5, bins), 0) 109 | self.assertEqual(get_bin(1.0, bins), 0) 110 | 111 | def test_bin_all_same(self): 112 | for n in range(1,10): 113 | for num_bins in range(1,min(3,n)): 114 | data = [(0.5, 1.0)] * n 115 | probs = [p for p, y in data] 116 | bins = get_equal_bins(probs, num_bins=num_bins) 117 | binned_data = bin(data, bins) 118 | self.assertTrue( 119 | np.all(np.array(binned_data[0]) == np.array(data))) 120 | for j in range(1, num_bins): 121 | self.assertEqual(len(binned_data[j]), 0) 122 | 123 | def test_bin(self): 124 | data = [(0.3, 1.0), (0.5, 0.0), (0.2, 1.0), (0.3, 0.0), (0.5, 1.0), (0.7, 0.0)] 125 | bins = [0.4, 1.0] 126 | binned_data = tuple(np.array(bin(data, bins)).tolist()) 127 | self.assertTrue(multiset_equal( 128 | list_to_tuple(binned_data[0]), ((0.3, 1.0), (0.2, 1.0), (0.3, 0.0)))) 129 | self.assertTrue(multiset_equal( 130 | list_to_tuple(binned_data[1]), ((0.5, 1.0), (0.5, 0.0), (0.7, 0.0)))) 131 | 132 | @parameterized.expand([ 133 | [[(0.3, 0.5)], -0.2], 134 | [[(0.5, 0.3)], 0.2], 135 | [[(0.3, 0.5), (0.8, 0.4)], 0.1], 136 | [[(0.3, 0.5), (0.8, 0.4), (0.4, 0.0)], 0.2] 137 | ]) 138 | def test_difference_mean(self, data, true_value): 139 | self.assertAlmostEqual(difference_mean(data), true_value) 140 | 141 | 142 | @parameterized.expand([ 143 | [[[(0.3, 0.5)]], [1.0]], 144 | [[[(0.3, 0.5)], [(0.4, 0.7)]], [0.5, 0.5]], 145 | [[[(0.3, 0.5)], [(0.4, 0.7)], [(0.0, 1.0), (0.6, 0.0)]], [0.25, 0.25, 0.5]], 146 | ]) 147 | def test_get_bin_probs(self, binned_data, probs): 148 | self.assertAlmostEqual(get_bin_probs(binned_data), probs) 149 | 150 | @parameterized.expand([ 151 | [[[(0.3, 1.0)]], 1, 0.7], 152 | [[[(0.3, 0.0)]], 1, 0.3], 153 | [[[(0.3, 1.0)]], 2, 0.7], 154 | [[[(0.3, 1.0)], [(0.6, 0.0)]], 1, 0.65], 155 | [[[(0.3, 1.0)], [(0.6, 0.0), (0.6, 1.0)]], 1, 0.3], 156 | [[[(0.3, 1.0)], [(0.6, 0.0)]], 2, 0.6519202405], 157 | ]) 158 | def test_plugin_ce(self, binned_data, power, true_ce): 159 | self.assertAlmostEqual(plugin_ce(binned_data, power), true_ce) 160 | 161 | @parameterized.expand([ 162 | [1, 4/9*0.25+5/9*0.2], 163 | [2, (4/9*(0.25**2)+5/9*(0.2**2))**(1/2.0)], 164 | [3, (4/9*(0.25**3)+5/9*(0.2**3))**(1/3.0)], 165 | ]) 166 | def test_get_binary_ce(self, p, true_ce): 167 | probs = [0.5, 0.5, 0.5, 0.6, 0.5, 0.6, 0.6, 0.6, 0.6] 168 | labels = [0, 1, 0, 1, 0, 1, 1, 1, 0] 169 | pred_ce = _get_ce(probs, 
labels, p, debias=False, num_bins=None, 170 | binning_scheme=get_discrete_bins) 171 | self.assertAlmostEqual(pred_ce, true_ce) 172 | wrapper_ce = get_calibration_error(probs, labels, p=p, debias=False) 173 | self.assertAlmostEqual(pred_ce, wrapper_ce) 174 | pred_ce = _get_ce(probs, labels, p, debias=True, num_bins=None, 175 | binning_scheme=get_discrete_bins) 176 | self.assertLess(pred_ce, true_ce) 177 | 178 | @parameterized.expand([ 179 | [1, 4/9*0.25+5/9*0.2], 180 | [2, (4/9*(0.25**2)+5/9*(0.2**2))**(1/2.0)], 181 | [3, (4/9*(0.25**3)+5/9*(0.2**3))**(1/3.0)], 182 | ]) 183 | def test_get_two_label_ce(self, p, true_ce): 184 | # Same as the previous test, except probs is now multi-dimensional. 185 | pt6 = [0.4, 0.6] 186 | pt5 = [0.5, 0.5] 187 | probs = [pt5, pt5, pt5, pt6, pt5, pt6, pt6, pt6, pt6] 188 | labels = [0, 1, 0, 1, 0, 1, 1, 1, 0] 189 | pred_ce = _get_ce(probs, labels, p, debias=False, num_bins=None, 190 | binning_scheme=get_discrete_bins) 191 | self.assertAlmostEqual(pred_ce, true_ce) 192 | # Check that the wrapper calls _get_ce with the right options. 193 | wrapper_ce = get_calibration_error(probs, labels, p=p, debias=False) 194 | self.assertAlmostEqual(pred_ce, wrapper_ce) 195 | # For the 2 label case, marginal calibration and top-label calibration should be the same. 196 | top_label_ce = get_calibration_error(probs, labels, p=p, debias=False, mode='top-label') 197 | self.assertAlmostEqual(top_label_ce, pred_ce) 198 | debiased_top_label_ce = get_calibration_error(probs, labels, p=p, debias=True, mode='top-label') 199 | self.assertLess(debiased_top_label_ce, pred_ce) 200 | debiased_pred_ce = _get_ce(probs, labels, p, debias=True, num_bins=None, 201 | binning_scheme=get_discrete_bins) 202 | self.assertLess(debiased_pred_ce, true_ce) 203 | 204 | @parameterized.expand([ 205 | [1, 4/42.0*0.1 + 6/42.0*(5/6.0-0.6) + 8/42.0*0.1 + 4/42.0*0.05 + 6/42.0*(0.3-1/6.0), 206 | 6/14.0*(5/6.0-0.6) + 4/14.0*0.05 + 4/14.0*0.1], 207 | [2, (4/42.0*0.1**2 + 6/42.0*(5/6.0-0.6)**2 + 8/42.0*0.1**2 + 208 | 4/42.0*0.05**2 + 6/42.0*(0.3-1/6.0)**2)**(1/2.0), 209 | (6/14.0*(5/6.0-0.6)**2 + 4/14.0*0.05**2 + 4/14.0*0.1**2)**(1/2.0)], 210 | [3, (4/42.0*0.1**3 + 6/42.0*(5/6.0-0.6)**3 + 8/42.0*0.1**3 + 211 | 4/42.0*0.05**3 + 6/42.0*(0.3-1/6.0)**3)**(1/3.0), 212 | (6/14.0*(5/6.0-0.6)**3 + 4/14.0*0.05**3 + 4/14.0*0.1**3)**(1/3.0)] 213 | ]) 214 | def test_get_three_label_ce(self, p, true_marginal_ce, true_top_ce): 215 | # Same as the previous test, except probs is now multi-dimensional. 216 | l0 = [0.6, 0.3, 0.1] 217 | l1 = [0.1, 0.8, 0.1] 218 | l2 = [0.1, 0.0, 0.9] 219 | probs = np.array([l0, l0, l0, l0, l0, l0, l1, l1, l1, l1, l2, l2, l2, l2]) 220 | labels = np.array([ 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]) 221 | perm = np.random.permutation(len(labels)) 222 | probs, labels = probs[perm], labels[perm] 223 | pred_ce = _get_ce(probs, labels, p, debias=False, num_bins=None, 224 | binning_scheme=get_discrete_bins) 225 | self.assertAlmostEqual(pred_ce, true_marginal_ce) 226 | # Check that the wrapper calls _get_ce with the right options. 
227 | wrapper_ce = get_calibration_error(probs, labels, p=p, debias=False) 228 | self.assertAlmostEqual(pred_ce, wrapper_ce) 229 | top_label_ce = get_calibration_error(probs, labels, p=p, debias=False, mode='top-label') 230 | self.assertAlmostEqual(top_label_ce, true_top_ce) 231 | debiased_top_label_ce = get_calibration_error(probs, labels, p=p, debias=True, mode='top-label') 232 | self.assertLess(debiased_top_label_ce, true_top_ce) 233 | debiased_pred_ce = _get_ce(probs, labels, p, debias=True, num_bins=None, 234 | binning_scheme=get_discrete_bins) 235 | self.assertLess(debiased_pred_ce, true_marginal_ce) 236 | 237 | 238 | @parameterized.expand([ 239 | [1, 0.5*(2/3.0-0.6) + 0.5*(1-2.75/3)], 240 | [2, (0.5*(2/3.0-0.6)**2 + 0.5*(1-2.75/3)**2)**(1/2.0)], 241 | [3, (0.5*(2/3.0-0.6)**3 + 0.5*(1-2.75/3)**3)**(1/3.0)], 242 | ]) 243 | def test_three_label_top_ce_lower_bound(self, p, true_top_ce): 244 | # Same as the previous test, except probs is now multi-dimensional. 245 | probs = np.array([[0.8, 0.1, 0.1], 246 | [0.6, 0.2, 0.2], 247 | [0.0, 0.9, 0.1], 248 | [0.0, 1.0, 0.0], 249 | [0.3, 0.3, 0.4], 250 | [0.2, 0.0, 0.85]]) 251 | labels = np.array([0, 0, 1, 1, 0, 2]) 252 | perm = np.random.permutation(len(labels)) 253 | probs, labels = probs[perm], labels[perm] 254 | top_label_ce = lower_bound_scaling_ce(probs, labels, p=p, debias=False, num_bins=2, 255 | binning_scheme=get_equal_bins, mode='top-label') 256 | self.assertAlmostEqual(top_label_ce, true_top_ce) 257 | debiased_top_label_ce = lower_bound_scaling_ce(probs, labels, p=p, debias=True, num_bins=2, 258 | binning_scheme=get_equal_bins, mode='top-label') 259 | self.assertLess(debiased_top_label_ce, true_top_ce) 260 | 261 | def test_ece(self): 262 | probs = np.array([[0.8, 0.1, 0.1], 263 | [0.6, 0.2, 0.2], 264 | [0.0, 0.9, 0.1], 265 | [0.0, 1.0, 0.0], 266 | [0.3, 0.3, 0.4], 267 | [0.2, 0.0, 0.85]]) 268 | labels = np.array([0, 0, 1, 1, 0, 2]) 269 | perm = np.random.permutation(len(labels)) 270 | probs, labels = probs[perm], labels[perm] 271 | true_ece = 4/6.0 * (1 - (0.8+0.85+0.9+1.0)/4) 272 | pred_ece = get_ece(probs, labels, num_bins=3) 273 | self.assertAlmostEqual(pred_ece, true_ece) 274 | probs = [0.6, 0.7, 0.8, 0.9] 275 | labels = [0, 0, 1, 1] 276 | pred_ece = get_ece(probs, labels, num_bins=2) 277 | true_ece = 0.25 278 | self.assertAlmostEqual(pred_ece, true_ece) 279 | 280 | @parameterized.expand([ 281 | [[0.1], [1], 1, 0.9], 282 | [[0.1], [0], 1, 0.1], 283 | [[0.1, 0.7], [0, 1], 1, 0.1], 284 | [[0.7, 0.1], [1, 0], 1, 0.1], 285 | [[0.1, 0.7, 0.4], [0, 0, 0], 1, 0.4], 286 | [[0.1, 0.9], [0, 1], 1, 0.0], 287 | [[0.1, 0.7], [0, 1], 2, 0.2], 288 | [[0.1, 0.1, 0.7], [0, 1, 1], 2, 0.4*2/3+0.3*1/3], 289 | [[0.1, 0.1, 0.1, 0.1, 0.7], [0, 1, 0, 0, 1], 2, 0.15*4/5+0.3*1/5], 290 | [[0.1, 0.7, 0.5, 0.9], [0, 1, 0, 1], 2, 0.25], 291 | [[0.1, 0.7, 0.5, 0.9], [0, 1, 0, 1], 4, 0.25], 292 | [[0.6, 0.7, 0.8, 0.9], [0, 0, 1, 1], 2, 0.4], 293 | [[0.1, 0.7, 0.5, 0.9], [0, 1, 1, 1], 2, 0.2], 294 | [[0.1, 0.7, 0.5, 0.9], [0, 1, 1, 1], 4, 0.25], 295 | ]) 296 | def test_1d_ece_em(self, probs, correct, num_bins, true_ece): 297 | pred_ece = get_ece_em(probs, correct, num_bins=num_bins) 298 | self.assertAlmostEqual(pred_ece, true_ece) 299 | # If number of bins is 1, then test that the regular ece is the same too. 
300 | if num_bins == 1: 301 | pred_ece_ew = get_ece(probs, correct, num_bins=num_bins) 302 | self.assertAlmostEqual(pred_ece_ew, true_ece) 303 | 304 | def test_missing_classes_ece(self): 305 | pred_ece = get_ece([[0.9,0.1], [0.8,0.2]], [0,0]) 306 | true_ece = 0.15 307 | self.assertAlmostEqual(pred_ece, true_ece) 308 | 309 | def test_missing_class_binary_ece(self): 310 | pred_ece = get_ece([0.9, 0.1, 0.3], [0, 0, 0], num_bins=1) 311 | true_ece = 1.3 / 3 312 | self.assertAlmostEqual(pred_ece, true_ece) 313 | pred_ece = get_ece([0.9, 0.1, 0.3], [1, 1, 1], num_bins=1) 314 | true_ece = 1.7 / 3 315 | self.assertAlmostEqual(pred_ece, true_ece) 316 | 317 | @parameterized.expand([ 318 | [[0.1], [1], 1.0, 1.0], 319 | [[0.6], [0], 0.0, 0.0], 320 | [[0.3, 0.9], [0, 1], 0.75, 1.0], 321 | [[0.9, 0.3], [1, 0], 0.75, 1.0], 322 | [[0.3, 0.9], [1, 0], 0.25, 0.0], 323 | [[0.6, 0.0, 0.8], [1, 0, 1], 8/9.0, 1.0], 324 | [[0.6, 0.0, 0.8], [1, 1, 1], 1.0, 1.0], 325 | [[0.6, 0.0, 0.8], [0, 0, 0], 0.0, 0.0], 326 | [[0.6, 0.0, 0.8], [0, 0, 1], 11/18.0, 1.0], 327 | [[0.6, 0.0, 0.8], [0, 1, 0], 1/9.0, 0.0], 328 | [[0.1]*10+[0.6,0.7], [0]*10+[0,0], 0.0, 0.0], 329 | [[0.1]*10+[0.6,0.7], [0]*10+[0,1], np.mean(1.0/np.arange(1,13)), 0.5], 330 | [[0.1]*10+[0.6,0.7], [0]*10+[1,0], np.mean(1.0/np.arange(1,13))-1.0/12, 0.5], 331 | [[0.1]*9+[0.5,0.6,0.7], [0]*9+[1,0,0], np.mean(1.0/np.arange(1,13))-1.5/12, 0.0] 332 | ]) 333 | def test_selective_stats(self, probs, correct, sel_acc, sel_90): 334 | pred_sel_acc, pred_sel_90 = get_selective_stats(probs, correct) 335 | self.assertAlmostEqual(sel_acc, pred_sel_acc) 336 | self.assertAlmostEqual(sel_90, pred_sel_90) 337 | 338 | if __name__ == '__main__': 339 | unittest.main() 340 | -------------------------------------------------------------------------------- /calibration/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import bisect 3 | from typing import List, Tuple, NewType, TypeVar 4 | import numpy as np 5 | import pickle 6 | from sklearn.linear_model import LogisticRegression 7 | 8 | # Define data types. 9 | 10 | Data = List[Tuple[float, float]] # List of (predicted_probability, true_label). 11 | Bins = List[float] # List of bin boundaries, excluding 0.0, but including 1.0. 12 | BinnedData = List[Data] # binned_data[i] contains the data in bin i. 13 | T = TypeVar('T') 14 | 15 | eps = 1e-6 16 | 17 | 18 | # Functions the produce bins from data. 
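# Illustrative note (values taken from util_test.py): get_equal_bins sorts the
# probabilities, splits them into num_bins roughly equal-sized groups, and uses the
# midpoint between adjacent groups as each bin boundary; e.g.
# get_equal_bins([0.3, 0.5, 0.2, 0.3, 0.5, 0.7], num_bins=2) == [0.4, 1.0], whereas
# get_equal_prob_bins(probs, num_bins=2) == [0.5, 1.0] regardless of probs.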
19 | 20 | def split(sequence: List[T], parts: int) -> List[List[T]]: 21 | assert parts <= len(sequence) 22 | array_splits = np.array_split(sequence, parts) 23 | splits = [list(l) for l in array_splits] 24 | assert len(splits) == parts 25 | return splits 26 | 27 | def get_equal_bins(probs: List[float], num_bins: int=10) -> Bins: 28 | """Get bins that contain approximately an equal number of data points.""" 29 | sorted_probs = sorted(probs) 30 | if num_bins > len(sorted_probs): 31 | num_bins = len(sorted_probs) 32 | binned_data = split(sorted_probs, num_bins) 33 | bins: Bins = [] 34 | for i in range(len(binned_data) - 1): 35 | last_prob = binned_data[i][-1] 36 | next_first_prob = binned_data[i + 1][0] 37 | bins.append((last_prob + next_first_prob) / 2.0) 38 | bins.append(1.0) 39 | bins = sorted(list(set(bins))) 40 | return bins 41 | 42 | 43 | def get_equal_prob_bins(probs: List[float], num_bins: int=10) -> Bins: 44 | return [i * 1.0 / num_bins for i in range(1, num_bins + 1)] 45 | 46 | 47 | def get_discrete_bins(data: List[float]) -> Bins: 48 | sorted_values = sorted(np.unique(data)) 49 | bins = [] 50 | for i in range(len(sorted_values) - 1): 51 | mid = (sorted_values[i] + sorted_values[i+1]) / 2.0 52 | bins.append(mid) 53 | bins.append(1.0) 54 | return bins 55 | 56 | 57 | # User facing functions to measure calibration error. 58 | 59 | def get_top_calibration_error_uncertainties(probs, labels, p=2, alpha=0.1): 60 | return get_calibration_error_uncertainties(probs, labels, p, alpha, mode='top-label') 61 | 62 | 63 | def get_calibration_error_uncertainties(probs, labels, p=2, alpha=0.1, mode='marginal'): 64 | """Get confidence intervals for the calibration error. 65 | 66 | Args: 67 | probs: A numpy array of shape (n,) or (n, k). If the shape is (n,) then 68 | we assume binary classification and probs[i] is the model's confidence 69 | the i-th example is 1. Otherwise, probs[i][j] is the model's confidence 70 | the i-th example is j, with 0 <= probs[i][j] <= 1. 71 | labels: A numpy array of shape (n,). labels[i] denotes the label of the i-th 72 | example. In the binary classification setting, labels[i] must be 0 or 1, 73 | in the k class setting labels[i] is an integer with 0 <= labels[i] <= k-1. 74 | p: We measure the lp calibration error, where p >= 1 is an integer. 75 | mode: 'marginal' or 'top-label'. 'marginal' calibration means we compute the 76 | calibraton error for each class and then average them. Top-label means 77 | we compute the calibration error of the prediction that the model is most 78 | confident about. 79 | 80 | Returns: 81 | [lower, mid, upper]: 1-alpha confidence intervals produced by bootstrap resampling. 82 | [lower, upper] represents the confidence interval. mid represents the median of 83 | the bootstrap estimates. When p is not 2 (e.g. for the ECE where p = 1), this 84 | can be used as a debiased estimate as well. 85 | """ 86 | data = list(zip(probs, labels)) 87 | def ce_functional(data): 88 | probs, labels = zip(*data) 89 | return get_calibration_error(probs, labels, p, debias=False, mode=mode) 90 | [lower, mid, upper] = bootstrap_uncertainty(data, ce_functional, num_samples=100, alpha=alpha) 91 | return [lower, mid, upper] 92 | 93 | 94 | def get_top_calibration_error(probs, labels, p=2, debias=True): 95 | return get_calibration_error(probs, labels, p, debias, mode='top-label') 96 | 97 | 98 | def get_calibration_error(probs, labels, p=2, debias=True, mode='marginal'): 99 | """Get the calibration error. 
100 | 101 | Args: 102 | probs: A numpy array of shape (n,) or (n, k). If the shape is (n,) then 103 | we assume binary classification and probs[i] is the model's confidence 104 | the i-th example is 1. Otherwise, probs[i][j] is the model's confidence 105 | the i-th example is j, with 0 <= probs[i][j] <= 1. 106 | labels: A numpy array of shape (n,). labels[i] denotes the label of the i-th 107 | example. In the binary classification setting, labels[i] must be 0 or 1, 108 | in the k class setting labels[i] is an integer with 0 <= labels[i] <= k-1. 109 | p: We measure the lp calibration error, where p >= 1 is an integer. 110 | debias: Should we try to debias the estimates? For p = 2, the debiasing 111 | has provably better sample complexity. 112 | mode: 'marginal' or 'top-label'. 'marginal' calibration means we compute the 113 | calibraton error for each class and then average them. Top-label means 114 | we compute the calibration error of the prediction that the model is most 115 | confident about. 116 | 117 | Returns: 118 | Estimated calibration error, a floating point value. 119 | The method first uses heuristics to check if the values came from a scaling 120 | method or binning method, and then calls the corresponding function. For 121 | more explicit control, use lower_bound_scaling_ce or get_binning_ce. 122 | """ 123 | if is_discrete(probs): 124 | return get_binning_ce(probs, labels, p, debias, mode=mode) 125 | else: 126 | return lower_bound_scaling_ce(probs, labels, p, debias, mode=mode) 127 | 128 | 129 | def lower_bound_scaling_top_ce(probs, labels, p=2, debias=True, num_bins=15, 130 | binning_scheme=get_equal_bins): 131 | return lower_bound_scaling_ce(probs, labels, p, debias, num_bins, binning_scheme, 132 | mode='top-label') 133 | 134 | 135 | def lower_bound_scaling_ce(probs, labels, p=2, debias=True, num_bins=15, 136 | binning_scheme=get_equal_bins, mode='marginal'): 137 | """Lower bound the calibration error of a model with continuous outputs. 138 | 139 | Args: 140 | probs: A numpy array of shape (n,) or (n, k). If the shape is (n,) then 141 | we assume binary classification and probs[i] is the model's confidence 142 | the i-th example is 1. Otherwise, probs[i][j] is the model's confidence 143 | the i-th example is j, with 0 <= probs[i][j] <= 1. 144 | labels: A numpy array of shape (n,). labels[i] denotes the label of the i-th 145 | example. In the binary classification setting, labels[i] must be 0 or 1, 146 | in the k class setting labels[i] is an integer with 0 <= labels[i] <= k-1. 147 | p: We measure the lp calibration error, where p >= 1 is an integer. 148 | debias: Should we try to debias the estimates? For p = 2, the debiasing 149 | has provably better sample complexity. 150 | num_bins: Integer number of bins used to estimate the calibration error. 151 | binning_scheme: A function that takes in a list of probabilities and number of bins, 152 | and outputs a list of bins. See get_equal_bins, get_equal_prob_bins for examples. 153 | mode: 'marginal' or 'top-label'. 'marginal' calibration means we compute the 154 | calibraton error for each class and then average them. Top-label means 155 | we compute the calibration error of the prediction that the model is most 156 | confident about. 157 | 158 | Returns: 159 | Estimated lower bound for calibration error, a floating point value. 160 | For scaling methods we cannot estimate the calibration error, but only a 161 | lower bound. 
162 | """ 163 | return _get_ce(probs, labels, p, debias, num_bins, binning_scheme, mode=mode) 164 | 165 | 166 | def get_binning_top_ce(probs, labels, p=2, debias=True, mode='marginal'): 167 | return get_binning_ce(probs, labels, p, debias, mode='top-label') 168 | 169 | 170 | def get_binning_ce(probs, labels, p=2, debias=True, mode='marginal'): 171 | """Estimate the calibration error of a binned model. 172 | 173 | Args: 174 | probs: A numpy array of shape (n,) or (n, k). If the shape is (n,) then 175 | we assume binary classification and probs[i] is the model's confidence 176 | the i-th example is 1. Otherwise, probs[i][j] is the model's confidence 177 | the i-th example is j, with 0 <= probs[i][j] <= 1. 178 | labels: A numpy array of shape (n,). labels[i] denotes the label of the i-th 179 | example. In the binary classification setting, labels[i] must be 0 or 1, 180 | in the k class setting labels[i] is an integer with 0 <= labels[i] <= k-1. 181 | p: We measure the lp calibration error, where p >= 1 is an integer. 182 | debias: Should we try to debias the estimates? For p = 2, the debiasing 183 | has provably better sample complexity. 184 | mode: 'marginal' or 'top-label'. 'marginal' calibration means we compute the 185 | calibraton error for each class and then average them. Top-label means 186 | we compute the calibration error of the prediction that the model is most 187 | confident about. 188 | 189 | Returns: 190 | Estimated calibration error, a floating point value. 191 | """ 192 | return _get_ce(probs, labels, p, debias, None, binning_scheme=get_discrete_bins, mode=mode) 193 | 194 | 195 | def get_ece(probs, labels, debias=False, num_bins=15, mode='top-label'): 196 | """Get ECE as computed by Guo et al.""" 197 | return lower_bound_scaling_ce(probs, labels, p=1, debias=debias, num_bins=num_bins, 198 | binning_scheme=get_equal_prob_bins, mode=mode) 199 | 200 | 201 | def get_ece_em(probs, labels, debias=False, num_bins=15, mode='top-label'): 202 | """Get ECE, but use equal mass binning.""" 203 | return lower_bound_scaling_ce(probs, labels, p=1, debias=debias, num_bins=num_bins, 204 | binning_scheme=get_equal_bins, mode=mode) 205 | 206 | 207 | def get_selective_stats(probs, correct): 208 | """Return area under coverage-accuracy curve, and acc for 10% most confident predictions.""" 209 | # Sort in descending order. 
210 | probs = np.array(probs) 211 | correct = np.array(correct) 212 | sort_indices = np.argsort(-probs) 213 | sorted_correct = correct[sort_indices] 214 | accs = np.cumsum(sorted_correct) / np.arange(1, len(sorted_correct) + 1) 215 | coverage_acc_area = np.mean(accs) 216 | acc_percentile_90 = accs[int(0.1 * len(sorted_correct))] 217 | return coverage_acc_area, acc_percentile_90 218 | 219 | 220 | def _get_ce(probs, labels, p, debias, num_bins, binning_scheme, mode='marginal'): 221 | def ce_1d(probs, labels): 222 | assert probs.shape == labels.shape 223 | assert len(probs.shape) == 1 224 | data = list(zip(probs, labels)) 225 | if binning_scheme == get_discrete_bins: 226 | assert(num_bins is None) 227 | bins = binning_scheme(probs) 228 | else: 229 | bins = binning_scheme(probs, num_bins=num_bins) 230 | if p == 2 and debias: 231 | return unbiased_l2_ce(bin(data, bins)) 232 | elif debias: 233 | return normal_debiased_ce(bin(data, bins), power=p) 234 | else: 235 | return plugin_ce(bin(data, bins), power=p) 236 | if mode != 'marginal' and mode != 'top-label': 237 | raise ValueError("mode must be 'marginal' or 'top-label'.") 238 | probs = np.array(probs) 239 | labels = np.array(labels) 240 | if not(np.issubdtype(labels.dtype, np.integer)): 241 | raise ValueError('labels should an integer numpy array.') 242 | if len(labels.shape) != 1: 243 | raise ValueError('labels should be a 1D numpy array.') 244 | if probs.shape[0] != labels.shape[0]: 245 | raise ValueError('labels and probs should have the same number of entries.') 246 | if len(probs.shape) == 1: 247 | # If 1D (2-class setting), compute the regular calibration error. 248 | if np.min(labels) < 0 or np.max(labels) > 1: 249 | raise ValueError('If probs is 1D, each label should be 0 or 1.') 250 | return ce_1d(probs, labels) 251 | elif len(probs.shape) == 2: 252 | if np.min(labels) < 0 or np.max(labels) > probs.shape[1] - 1: 253 | raise ValueError('labels should be between 0 and num_classes - 1.') 254 | if mode == 'marginal': 255 | labels_one_hot = get_labels_one_hot(labels, k=probs.shape[1]) 256 | assert probs.shape == labels_one_hot.shape 257 | marginal_ces = [] 258 | for k in range(probs.shape[1]): 259 | cur_probs = probs[:, k] 260 | cur_labels = labels_one_hot[:, k] 261 | marginal_ces.append(ce_1d(cur_probs, cur_labels) ** p) 262 | return np.mean(marginal_ces) ** (1.0 / p) 263 | elif mode == 'top-label': 264 | preds = get_top_predictions(probs) 265 | correct = (preds == labels).astype(probs.dtype) 266 | confidences = get_top_probs(probs) 267 | return ce_1d(confidences, correct) 268 | else: 269 | raise ValueError('probs should be a 1D or 2D numpy array.') 270 | 271 | 272 | def is_discrete(probs): 273 | probs = np.array(probs) 274 | if len(probs.shape) == 1: 275 | return enough_duplicates(probs) 276 | elif len(probs.shape) == 2: 277 | for k in range(probs.shape[1]): 278 | if not enough_duplicates(probs[:, k]): 279 | return False 280 | return True 281 | else: 282 | raise ValueError('probs must be a 1D or 2D numpy array.') 283 | 284 | 285 | def enough_duplicates(array): 286 | # TODO: instead check that we have at least 2 values in each bin. 287 | num_bins = get_discrete_bins(array) 288 | if len(num_bins) < array.shape[0] / 4.0: 289 | return True 290 | return False 291 | 292 | 293 | # Functions that bin data. 
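# Illustrative note (values taken from util_test.py): bin() groups (prob, label) pairs
# by bin. With bins == [0.4, 1.0], the data [(0.3, 1.0), (0.5, 0.0), (0.2, 1.0),
# (0.3, 0.0), (0.5, 1.0), (0.7, 0.0)] is split into [(0.3, 1.0), (0.2, 1.0), (0.3, 0.0)]
# for the first bin and [(0.5, 0.0), (0.5, 1.0), (0.7, 0.0)] for the second.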
294 | 295 | def get_bin(pred_prob: float, bins: List[float]) -> int: 296 | """Get the index of the bin that pred_prob belongs in.""" 297 | assert 0.0 <= pred_prob <= 1.0 298 | assert bins[-1] == 1.0 299 | return bisect.bisect_left(bins, pred_prob) 300 | 301 | 302 | def bin(data: Data, bins: Bins): 303 | return fast_bin(data, bins) 304 | 305 | 306 | def fast_bin(data, bins): 307 | prob_label = np.array(data) 308 | bin_indices = np.searchsorted(bins, prob_label[:, 0]) 309 | bin_sort_indices = np.argsort(bin_indices) 310 | sorted_bins = bin_indices[bin_sort_indices] 311 | splits = np.searchsorted(sorted_bins, list(range(1, len(bins)))) 312 | binned_data = np.split(prob_label[bin_sort_indices], splits) 313 | return binned_data 314 | 315 | 316 | def equal_bin(data: Data, num_bins : int) -> BinnedData: 317 | sorted_probs = sorted(data) 318 | return split(sorted_probs, num_bins) 319 | 320 | 321 | # Calibration error estimators. 322 | 323 | def difference_mean(data : Data) -> float: 324 | """Returns average pred_prob - average label.""" 325 | data = np.array(data) 326 | ave_pred_prob = np.mean(data[:, 0]) 327 | ave_label = np.mean(data[:, 1]) 328 | return ave_pred_prob - ave_label 329 | 330 | 331 | def get_bin_probs(binned_data: BinnedData) -> List[float]: 332 | bin_sizes = list(map(len, binned_data)) 333 | num_data = sum(bin_sizes) 334 | bin_probs = list(map(lambda b: b * 1.0 / num_data, bin_sizes)) 335 | assert(abs(sum(bin_probs) - 1.0) < eps) 336 | return list(bin_probs) 337 | 338 | 339 | def plugin_ce(binned_data: BinnedData, power=2) -> float: 340 | def bin_error(data: Data): 341 | if len(data) == 0: 342 | return 0.0 343 | return abs(difference_mean(data)) ** power 344 | bin_probs = get_bin_probs(binned_data) 345 | bin_errors = list(map(bin_error, binned_data)) 346 | return np.dot(bin_probs, bin_errors) ** (1.0 / power) 347 | 348 | 349 | def unbiased_square_ce(binned_data: BinnedData) -> float: 350 | # Note, this is not the l2 CE. It does not take the square root. 351 | def bin_error(data: Data): 352 | if len(data) < 2: 353 | return 0.0 354 | # raise ValueError('Too few values in bin, use fewer bins or get more data.') 355 | biased_estimate = abs(difference_mean(data)) ** 2 356 | label_values = list(map(lambda x: x[1], data)) 357 | mean_label = np.mean(label_values) 358 | variance = mean_label * (1.0 - mean_label) / (len(data) - 1.0) 359 | return biased_estimate - variance 360 | bin_probs = get_bin_probs(binned_data) 361 | bin_errors = list(map(bin_error, binned_data)) 362 | return np.dot(bin_probs, bin_errors) 363 | 364 | 365 | def unbiased_l2_ce(binned_data: BinnedData) -> float: 366 | return max(unbiased_square_ce(binned_data), 0.0) ** 0.5 367 | 368 | 369 | def normal_debiased_ce(binned_data : BinnedData, power=1, resamples=1000) -> float: 370 | bin_sizes = np.array(list(map(len, binned_data))) 371 | if np.min(bin_sizes) <= 1: 372 | raise ValueError('Every bin must have at least 2 points for debiased estimator. 
' 373 | 'Try adding the argument debias=False to your function call.') 374 | label_means = np.array(list(map(lambda l: np.mean([b for a, b in l]), binned_data))) 375 | label_stddev = np.sqrt(label_means * (1 - label_means) / bin_sizes) 376 | model_vals = np.array(list(map(lambda l: np.mean([a for a, b in l]), binned_data))) 377 | assert(label_means.shape == (len(binned_data),)) 378 | assert(model_vals.shape == (len(binned_data),)) 379 | ce = plugin_ce(binned_data, power=power) 380 | bin_probs = get_bin_probs(binned_data) 381 | resampled_ces = [] 382 | for i in range(resamples): 383 | label_samples = np.random.normal(loc=label_means, scale=label_stddev) 384 | # TODO: we can also correct the bias for the model_vals, although this is 385 | # smaller. 386 | diffs = np.power(np.abs(label_samples - model_vals), power) 387 | cur_ce = np.power(np.dot(bin_probs, diffs), 1.0 / power) 388 | resampled_ces.append(cur_ce) 389 | mean_resampled = np.mean(resampled_ces) 390 | bias_corrected_ce = 2 * ce - mean_resampled 391 | return bias_corrected_ce 392 | 393 | 394 | # MSE Estimators. 395 | 396 | def eval_top_mse(calibrated_probs, probs, labels): 397 | correct = (get_top_predictions(probs) == labels) 398 | return np.mean(np.square(calibrated_probs - correct)) 399 | 400 | 401 | def eval_marginal_mse(calibrated_probs, probs, labels): 402 | assert calibrated_probs.shape == probs.shape 403 | k = probs.shape[1] 404 | labels_one_hot = get_labels_one_hot(np.array(labels), k) 405 | return np.mean(np.square(calibrated_probs - labels_one_hot)) * calibrated_probs.shape[1] / 2.0 406 | 407 | 408 | # Bootstrap utilities. 409 | 410 | def resample(data: List[T]) -> List[T]: 411 | indices = np.random.choice(list(range(len(data))), size=len(data), replace=True) 412 | return [data[i] for i in indices] 413 | 414 | 415 | def bootstrap_uncertainty(data: List[T], functional, estimator=None, alpha=10.0, 416 | num_samples=1000) -> Tuple[float, float]: 417 | """Return boostrap uncertained for 1 - alpha percent confidence interval.""" 418 | if estimator is None: 419 | estimator = functional 420 | estimate = estimator(data) 421 | plugin = functional(data) 422 | bootstrap_estimates = [] 423 | for _ in range(num_samples): 424 | bootstrap_estimates.append(estimator(resample(data))) 425 | return (plugin + estimate - np.percentile(bootstrap_estimates, 100 - alpha / 2.0), 426 | plugin + estimate - np.percentile(bootstrap_estimates, 50), 427 | plugin + estimate - np.percentile(bootstrap_estimates, alpha / 2.0)) 428 | 429 | 430 | def precentile_bootstrap_uncertainty(data: List[T], functional, estimator=None, alpha=10.0, 431 | num_samples=1000) -> Tuple[float, float]: 432 | """Return boostrap uncertained for 1 - alpha percent confidence interval.""" 433 | if estimator is None: 434 | estimator = functional 435 | plugin = functional(data) 436 | estimate = estimator(data) 437 | bootstrap_estimates = [] 438 | for _ in range(num_samples): 439 | bootstrap_estimates.append(estimator(resample(data))) 440 | bias = 2 * np.percentile(bootstrap_estimates, 50) - plugin - estimate 441 | return (np.percentile(bootstrap_estimates, alpha / 2.0) - bias, 442 | np.percentile(bootstrap_estimates, 50) - bias, 443 | np.percentile(bootstrap_estimates, 100 - alpha / 2.0) - bias) 444 | 445 | 446 | def bootstrap_std(data: List[T], estimator=None, num_samples=100) -> Tuple[float, float]: 447 | """Return boostrap uncertained for 1 - alpha percent confidence interval.""" 448 | bootstrap_estimates = [] 449 | for _ in range(num_samples): 450 | 
bootstrap_estimates.append(estimator(resample(data))) 451 | return np.std(bootstrap_estimates) 452 | 453 | 454 | # Re-Calibration utilities. 455 | 456 | def get_platt_scaler(model_probs, labels, get_clf=False): 457 | clf = LogisticRegression(C=1e10, solver='lbfgs') 458 | eps = 1e-12 459 | model_probs = model_probs.astype(dtype=np.float64) 460 | model_probs = np.expand_dims(model_probs, axis=-1) 461 | model_probs = np.clip(model_probs, eps, 1 - eps) 462 | model_probs = np.log(model_probs / (1 - model_probs)) 463 | clf.fit(model_probs, labels) 464 | def calibrator(probs): 465 | x = np.array(probs, dtype=np.float64) 466 | x = np.clip(x, eps, 1 - eps) 467 | x = np.log(x / (1 - x)) 468 | x = x * clf.coef_[0] + clf.intercept_ 469 | output = 1 / (1 + np.exp(-x)) 470 | return output 471 | if get_clf: 472 | return calibrator, clf 473 | return calibrator 474 | 475 | 476 | def get_histogram_calibrator(model_probs, values, bins): 477 | binned_values = [[] for _ in range(len(bins))] 478 | for prob, value in zip(model_probs, values): 479 | bin_idx = get_bin(prob, bins) 480 | binned_values[bin_idx].append(float(value)) 481 | def safe_mean(values, bin_idx): 482 | if len(values) == 0: 483 | if bin_idx == 0: 484 | return float(bins[0]) / 2.0 485 | return float(bins[bin_idx] + bins[bin_idx - 1]) / 2.0 486 | return np.mean(values) 487 | bin_means = [safe_mean(values, bidx) for values, bidx in zip(binned_values, range(len(bins)))] 488 | bin_means = np.array(bin_means) 489 | def calibrator(probs): 490 | indices = np.searchsorted(bins, probs) 491 | return bin_means[indices] 492 | return calibrator 493 | 494 | 495 | def get_discrete_calibrator(model_probs, bins): 496 | return get_histogram_calibrator(model_probs, model_probs, bins) 497 | 498 | 499 | # Utils to load and save files. 500 | 501 | def save_test_probs_labels(dataset, model, filename): 502 | (x_train, y_train), (x_test, y_test) = dataset.load_data() 503 | probs = model.predict(x_test) 504 | pickle.dump((probs, y_test), open(filename, "wb")) 505 | 506 | 507 | def load_test_probs_labels(filename): 508 | probs, labels = pickle.load(open(filename, "rb")) 509 | if len(labels.shape) > 1: 510 | labels = labels[:, 0] 511 | indices = np.random.choice(list(range(len(probs))), size=len(probs), replace=False) 512 | probs = np.array([probs[i] for i in indices]) 513 | labels = np.array([labels[i] for i in indices]) 514 | return probs, labels 515 | 516 | 517 | def get_top_predictions(probs): 518 | return np.argmax(probs, 1) 519 | 520 | 521 | def get_top_probs(probs): 522 | return np.max(probs, 1) 523 | 524 | 525 | def get_accuracy(probs, labels): 526 | return sum(labels == predictions) * 1.0 / len(labels) 527 | 528 | 529 | def get_labels_one_hot(labels, k): 530 | assert np.min(labels) >= 0 531 | assert np.max(labels) <= k - 1 532 | num_labels = labels.shape[0] 533 | labels_one_hot = np.zeros((num_labels, k)) 534 | labels_one_hot[np.arange(num_labels), labels] = 1 535 | return labels_one_hot 536 | -------------------------------------------------------------------------------- /codalab_commands.bash: -------------------------------------------------------------------------------- 1 | 2 | # Lower bound experiments (Section 3). 
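# Each `cl run` below mounts the experiments, calibration, and data bundles, sets
# PYTHONPATH, and runs lower_bounds.py for one dataset, lp norm, and binning scheme;
# the resulting plot is written to the file given by --plot_save_file.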
3 | 4 | # CIFAR: 5 | 6 | cl run :experiments :calibration :data \ 7 | 'export PYTHONPATH="."; \ 8 | python3 experiments/platt_not_calibrated/lower_bounds.py \ 9 | --probs_path=data/cifar_probs.dat --bins_list="2, 4, 8, 16, 32, 64, 128" --lp=2 \ 10 | --calibration_data_size=1000 --bin_data_size=1000 --plot_save_file=l2_lower_bound_cifar_plot.png \ 11 | --binning=equal_bins --num_samples=1000' \ 12 | --request-queue tag=nlp -n cifar_l2_lower_bound 13 | 14 | cl run :experiments :calibration :data \ 15 | 'export PYTHONPATH="."; \ 16 | python3 experiments/platt_not_calibrated/lower_bounds.py \ 17 | --probs_path=data/cifar_probs.dat --bins_list="2, 4, 8, 16, 32, 64, 128" --lp=1 \ 18 | --calibration_data_size=1000 --bin_data_size=1000 --plot_save_file=l1_lower_bound_cifar_plot.png \ 19 | --binning=equal_bins --num_samples=1000' \ 20 | --request-queue tag=nlp -n cifar_l1_lower_bound 21 | 22 | cl run :experiments :calibration :data \ 23 | 'export PYTHONPATH="."; \ 24 | python3 experiments/platt_not_calibrated/lower_bounds.py \ 25 | --probs_path=data/cifar_probs.dat --bins_list="2, 4, 8, 16, 32, 64, 128" --lp=1 \ 26 | --calibration_data_size=1000 --bin_data_size=1000 --plot_save_file=prob_lower_bound_cifar_plot.png \ 27 | --binning=equal_prob_bins --num_samples=1000' \ 28 | --request-queue tag=nlp -n cifar_prob_lower_bound 29 | 30 | # ImageNet: 31 | 32 | cl run :experiments :calibration :data \ 33 | 'export PYTHONPATH="."; \ 34 | python3 experiments/platt_not_calibrated/lower_bounds.py \ 35 | --probs_path=data/imagenet_probs.dat --bins_list="2, 4, 8, 16, 32, 64, 128, 256, 512" --lp=2 \ 36 | --calibration_data_size=20000 --bin_data_size=5000 --plot_save_file=l2_lower_bound_imnet_plot.png \ 37 | --binning=equal_bins --num_samples=1000' \ 38 | --request-queue tag=nlp -n imagenet_l2_lower_bound 39 | 40 | cl run :experiments :calibration :data \ 41 | 'export PYTHONPATH="."; \ 42 | python3 experiments/platt_not_calibrated/lower_bounds.py \ 43 | --probs_path=data/imagenet_probs.dat --bins_list="2, 4, 8, 16, 32, 64, 128, 256, 512" --lp=1 \ 44 | --calibration_data_size=20000 --bin_data_size=5000 --plot_save_file=l1_lower_bound_imnet_plot.png \ 45 | --binning=equal_bins --num_samples=1000' \ 46 | --request-queue tag=nlp -n imagenet_l1_lower_bound 47 | 48 | cl run :experiments :calibration :data \ 49 | 'export PYTHONPATH="."; \ 50 | python3 experiments/platt_not_calibrated/lower_bounds.py \ 51 | --probs_path=data/imagenet_probs.dat --bins_list="2, 4, 8, 16, 32, 64, 128, 256, 512" --lp=1 \ 52 | --calibration_data_size=20000 --bin_data_size=5000 --plot_save_file=prob_lower_bound_imnet_plot.png \ 53 | --binning=equal_prob_bins --num_samples=1000' \ 54 | --request-queue tag=nlp -n imagenet_prob_lower_bound 55 | 56 | 57 | # Comparing calibrators experiment (Section 4). 58 | 59 | cl run :experiments :calibration :data \ 60 | 'export PYTHONPATH="."; \ 61 | python3 experiments/scaling_binning_calibrator/compare_calibrators.py' \ 62 | --request-queue tag=nlp -n compare_calibrators 63 | 64 | 65 | # Synthetic experiments (Section 4). 
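# Each run below calls synthetic.py with a different --experiment_name, varying the
# number of samples (vary_n_*) or the number of bins (vary_b_*) for the two parameter
# settings a1_b0 and a2_b1, plus a noisy variant of the a2_b1 setting.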
66 | 67 | cl run :experiments :calibration :data \ 68 | 'export PYTHONPATH="."; \ 69 | python3 experiments/synthetic/synthetic.py \ 70 | --experiment_name=vary_n_a1_b0' \ 71 | --request-queue tag=nlp -n vary_n_a1_b0 72 | 73 | cl run :experiments :calibration :data \ 74 | 'export PYTHONPATH="."; \ 75 | python3 experiments/synthetic/synthetic.py \ 76 | --experiment_name=vary_b_a1_b0' \ 77 | --request-queue tag=nlp -n vary_b_a1_b0 78 | 79 | cl run :experiments :calibration :data \ 80 | 'export PYTHONPATH="."; \ 81 | python3 experiments/synthetic/synthetic.py \ 82 | --experiment_name=vary_n_a2_b1' \ 83 | --request-queue tag=nlp -n vary_n_a2_b1 84 | 85 | cl run :experiments :calibration :data \ 86 | 'export PYTHONPATH="."; \ 87 | python3 experiments/synthetic/synthetic.py \ 88 | --experiment_name=vary_b_a2_b1' \ 89 | --request-queue tag=nlp -n vary_b_a2_b1 90 | 91 | cl run :experiments :calibration :data \ 92 | 'export PYTHONPATH="."; \ 93 | python3 experiments/synthetic/synthetic.py \ 94 | --experiment_name=noisy_vary_n_a2_b1' \ 95 | --request-queue tag=nlp -n noisy_vary_n_a2_b1 96 | 97 | 98 | # Comparing plugin with debiased calibration error estimators (Section 5). 99 | 100 | cl run :experiments :calibration :data \ 101 | 'export PYTHONPATH="."; \ 102 | python3 experiments/debiased_estimator/estimation_error_vs_bins.py' \ 103 | --request-queue tag=nlp -n estimator_comparison -------------------------------------------------------------------------------- /data/cifar10vgg.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p-lambda/verified_calibration/ee81c346895e3377653bd347c429a95bd631058d/data/cifar10vgg.h5 -------------------------------------------------------------------------------- /data/cifar_probs.dat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p-lambda/verified_calibration/ee81c346895e3377653bd347c429a95bd631058d/data/cifar_probs.dat -------------------------------------------------------------------------------- /examples/advanced_example.py: -------------------------------------------------------------------------------- 1 | """More complete example showing features of the calibration library. 2 | 3 | For more advanced users our calibration library is fairly customizable. 4 | 5 | We offer a variety of different ways to estimate the calibration 6 | error (e.g. plugin vs debiased), and can measure ECE, or the standard l2 error. 7 | 8 | The user can also choose a different binning scheme, or write their own binning 9 | scheme. For example, in the literature some people use equal probability binning, 10 | splitting the interval [0, 1] into B equal parts. Others split the interval [0, 1] 11 | into B bin so that each bin has an equal number of data points. Alternative 12 | binning schemes are also possible. 13 | """ 14 | 15 | 16 | import calibration 17 | import numpy as np 18 | 19 | 20 | def main(): 21 | # Make synthetic dataset. 22 | np.random.seed(0) # Keep results consistent. 23 | num_points = 1000 24 | (zs, ys) = synthetic_data_1d(num_points=num_points) 25 | 26 | # Estimate a lower bound on the calibration error. 27 | # Here z_i is the confidence of the uncalibrated model, y_i is the true label. 
28 | # In simple_example.py we used get_calibration_error, but for advanced users 29 | # we recommend using the more explicit lower_bound_scaling_ce to have 30 | # more control over functionality, and be explicit about the semantics - 31 | # that we are only estimating a lower bound. 32 | l2_calibration_error = calibration.lower_bound_scaling_ce(zs, ys) 33 | print("Uncalibrated model l2 calibration error is > %.2f%%" % (100 * l2_calibration_error)) 34 | 35 | # We can break this down into multiple steps. 1. We choose a binning scheme, 36 | # 2. we bin the data, and 3. we measure the calibration error. 37 | # Each of these steps can be customized, and users can substitute the component 38 | # with their own code. 39 | data = list(zip(zs, ys)) 40 | bins = calibration.get_equal_bins(zs, num_bins=10) 41 | l2_calibration_error = calibration.unbiased_l2_ce(calibration.bin(data, bins)) 42 | print("Uncalibrated model l2 calibration error is > %.2f%%" % (100 * l2_calibration_error)) 43 | 44 | # Use Platt binning to train a recalibrator. 45 | calibrator = calibration.PlattBinnerCalibrator(num_points, num_bins=10) 46 | calibrator.train_calibration(np.array(zs), ys) 47 | 48 | # Measure the calibration error of recalibrated model. 49 | # In this case we have a binning model, so we can estimate the true calibration error. 50 | # Again, for advanced users we recommend being explicit and using get_binning_ce instead 51 | # of get_calibration_error. 52 | (test_zs, test_ys) = synthetic_data_1d(num_points=num_points) 53 | calibrated_zs = list(calibrator.calibrate(test_zs)) 54 | l2_calibration_error = calibration.get_binning_ce(calibrated_zs, test_ys) 55 | print("Scaling-binning l2 calibration error is %.2f%%" % (100 * l2_calibration_error)) 56 | 57 | # As above we can break this down into 3 steps. Notice here we have a binning model, 58 | # so we use get_discrete_bins to get all the bins (all possible values the model 59 | # outputs). 60 | data = list(zip(calibrated_zs, test_ys)) 61 | bins = calibration.get_discrete_bins(calibrated_zs) 62 | binned = calibration.bin(data, bins) 63 | l2_calibration_error = calibration.unbiased_l2_ce(calibration.bin(data, bins)) 64 | print("Scaling-binning l2 calibration error is %.2f%%" % (100 * l2_calibration_error)) 65 | 66 | # Compute calibration error and confidence interval. 67 | # In the simple_example.py we just called get_calibration_error_uncertainties. 68 | # This function uses the bootstrap to estimate confidence intervals. 69 | # The bootstrap first requires us to define the functional we are trying to 70 | # estimate, and then resamples the data multiple times to estimate confidence intervals. 71 | def estimate_ce(data, estimator): 72 | zs = [z for z, y in data] 73 | binned_data = calibration.bin(data, calibration.get_discrete_bins(zs)) 74 | return estimator(binned_data) 75 | functional = lambda data: estimate_ce(data, lambda x: calibration.plugin_ce(x)) 76 | [lower, _, upper] = calibration.bootstrap_uncertainty(data, functional, num_samples=100) 77 | print(" Confidence interval is [%.2f%%, %.2f%%]" % (100 * lower, 100 * upper)) 78 | 79 | # Advanced: boostrap can be used to debias the l1-calibration error (ECE) as well. 80 | # This is a heuristic, which does not (yet) come with a formal guarantee. 
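    # In the code below, plugin_ce(..., power=1) is the plugin estimate of the
    # l1 calibration error (the ECE-style error), and bootstrap_uncertainty
    # resamples `data` to return a (lower, middle, upper) triple: the middle value
    # is printed as the debiased estimate and [lower, upper] as the confidence
    # interval.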
81 | functional = lambda data: estimate_ce(data, lambda x: calibration.plugin_ce(x, power=1)) 82 | [lower, mid, upper] = calibration.bootstrap_uncertainty(data, functional, num_samples=100) 83 | print("Debiased estimate of L1 calibration error is %.2f%%" % (100 * mid)) 84 | print(" Confidence interval is [%.2f%%, %.2f%%]" % (100 * lower, 100 * upper)) 85 | 86 | 87 | # Helper functions used to generate synthetic data. 88 | 89 | def synthetic_data_1d(num_points): 90 | f_true = platt_function(1, 1) 91 | return sample(f_true, np.random.uniform, num_points) 92 | 93 | def platt_function(a, b): 94 | """Return a (vectorized) platt function f: [0, 1] -> [0, 1] parameterized by a, b.""" 95 | def eval(x): 96 | x = np.log(x / (1 - x)) 97 | x = a * x + b 98 | return 1 / (1 + np.exp(-x)) 99 | return np.vectorize(eval) 100 | 101 | def sample(f, z_dist, n): 102 | """Returns ([z_1, ..., z_n], [y_1, ..., y_n]) where z_i ~ z_dist, y_i ~ Bernoulli(f(z_i)).""" 103 | zs = list(z_dist(size=n)) 104 | ps = f(zs) 105 | ys = list(np.random.binomial(1, p=ps)) 106 | return (zs, ys) 107 | 108 | if __name__ == "__main__": 109 | main() 110 | -------------------------------------------------------------------------------- /examples/bootstrap_example.py: -------------------------------------------------------------------------------- 1 | 2 | import calibration as cal 3 | import numpy as np 4 | 5 | # Keep the results consistent. 6 | np.random.seed(0) 7 | 8 | def mean_test(): 9 | p = 0.9 10 | num_trials = 100 11 | num_samples = 100 12 | bootstrap_valid = 0 13 | means = [] 14 | for i in range(num_trials): 15 | samples = np.random.binomial(n=1, p=p, size=num_samples) 16 | means.append(np.mean(samples)) 17 | # lower, _, upper = cal.bootstrap_uncertainty( 18 | # list(samples), np.mean, num_samples=1000) 19 | # if lower <= p <= upper: 20 | # bootstrap_valid += 1 21 | # print("Valid percent is {}".format(float(bootstrap_valid) / num_trials)) 22 | print(np.mean(means), np.std(means)) 23 | 24 | def ce_test(): 25 | x_low, x_high = 0.5, 0.9 26 | error = 0.04 27 | ce = error ** 2 28 | # The x values will be uniformly distributed. 29 | num_trials = 1000 30 | num_samples = 30000 31 | bootstrap_valid = 0 32 | num_bins = 1 33 | total_len = 0.0 34 | sum_est = 0.0 35 | stds = [] 36 | for num_bins in [1, 2, 4, 8, 16, 32, 64]: 37 | # Define the estimator. 
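        # What the code below does: bins are chosen once per num_bins setting from
        # an independent uniform sample via cal.get_equal_bins, and the estimator
        # returns the debiased (unbiased) squared calibration error of data binned
        # with those fixed bins. That estimate can come out slightly negative in
        # finite samples, hence the max(0.0, ...) before taking the square root.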
38 | bin_xs = np.random.uniform(size=num_samples * 2, low=x_low, high=x_high) 39 | bins = cal.get_equal_bins(bin_xs, num_bins=num_bins) 40 | def estimate_unbiased_ce(data): 41 | binned_data = cal.bin(data, bins) 42 | return np.sqrt(max(0.0, cal.unbiased_square_ce(binned_data))) 43 | ces = [] 44 | for i in range(num_trials): 45 | xs = np.random.uniform(size=num_samples, low=x_low, high=x_high) 46 | ys = np.random.binomial(n=1, p=xs-error) 47 | data = list(zip(xs, ys)) 48 | ces.append(estimate_unbiased_ce(data)) 49 | print(np.mean(ces), np.std(ces)) 50 | stds.append(np.std(ces)) 51 | print(stds) 52 | # def estimate_plugin_ce(data): 53 | # binned_data = cal.bin(data, bins) 54 | # return cal.plugin_ce(binned_data) ** 2 55 | # for i in range(num_trials): 56 | # xs = np.random.uniform(size=num_samples, low=x_low, high=x_high) 57 | # ys = np.random.binomial(n=1, p=xs+error) 58 | # data = list(zip(xs, ys)) 59 | # lower, mid, upper = cal.precentile_bootstrap_uncertainty( 60 | # data, estimate_plugin_ce, estimate_plugin_ce, num_samples=1000) 61 | # # est = estimate_unbiased_ce(data) 62 | # # mid = est 63 | # # print('mean', np.mean(ys), np.mean(xs)) 64 | # # print(abs(mid - ce) / ce, mid, ce) 65 | # est = estimate_plugin_ce(data) 66 | # sum_est += est 67 | # total_len += upper - lower 68 | # if lower <= ce <= upper: 69 | # bootstrap_valid += 1 70 | # else: 71 | # print('est', est) 72 | # print('interval', lower, upper) 73 | # print("Valid percent is {}".format(float(bootstrap_valid) / num_trials)) 74 | # print("Average length is {}".format(total_len / num_trials)) 75 | # print("Average est is {}".format(sum_est / num_trials)) 76 | 77 | if __name__ == "__main__": 78 | ce_test() 79 | -------------------------------------------------------------------------------- /examples/multiclass_example.py: -------------------------------------------------------------------------------- 1 | """Mneasuring calibration and calibrating a model for multiclass classification.""" 2 | 3 | import numpy as np 4 | import calibration 5 | import scipy 6 | 7 | 8 | def main(): 9 | # Make synthetic dataset. 10 | np.random.seed(0) # Keep results consistent. 11 | num_points = 10000 12 | d = 10 13 | (zs, ys) = synthetic_data(num_points=num_points, d=d) 14 | 15 | # Estimate a lower bound on the calibration error. 16 | # Here z_i are the per-class confidences of the uncalibrated model, y_i is the true label. 17 | calibration_error = calibration.get_calibration_error(zs, ys) 18 | print("Uncalibrated model calibration error is > %.2f%%" % (100 * calibration_error)) 19 | 20 | # Use Platt binning to train a recalibrator. 21 | calibrator = calibration.PlattBinnerMarginalCalibrator(num_points, num_bins=10) 22 | calibrator.train_calibration(zs, ys) 23 | 24 | # Measure the calibration error of recalibrated model. 25 | (test_zs, test_ys) = synthetic_data(num_points=num_points, d=d) 26 | calibrated_zs = calibrator.calibrate(test_zs) 27 | calibration_error = calibration.get_calibration_error(calibrated_zs, test_ys) 28 | print("Scaling-binning L2 calibration error is %.2f%%" % (100 * calibration_error)) 29 | 30 | # Get confidence intervals for the calibration error. 
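    # get_calibration_error_uncertainties returns a (lower, middle, upper) triple
    # computed by bootstrap resampling; the middle value (the point estimate) is
    # discarded here and [lower, upper] is reported as the confidence interval.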
31 | [lower, _, upper] = calibration.get_calibration_error_uncertainties(calibrated_zs, test_ys) 32 | print(" Confidence interval is [%.2f%%, %.2f%%]" % (100 * lower, 100 * upper)) 33 | 34 | 35 | def synthetic_data(num_points, d): 36 | true_probs = np.random.dirichlet([1] * d, size=num_points) 37 | samples = vectorized_sample(true_probs, np.array(list(range(d)))) 38 | model_probs = sharpen(true_probs, T=5) 39 | return (model_probs, samples) 40 | 41 | def vectorized_sample(probs, items): 42 | s = probs.cumsum(axis=1) 43 | r = np.random.rand(probs.shape[0]) 44 | r = np.expand_dims(r, axis=-1) 45 | r = np.tile(r, (1, probs.shape[1])) 46 | k = (s < r).sum(axis=1) 47 | return items[k] 48 | 49 | def sharpen(probs, T): 50 | probs = np.log(np.clip(probs, 1e-6, 1-1e-6)) 51 | probs = probs * T 52 | return scipy.special.softmax(probs, axis=1) 53 | 54 | if __name__ == "__main__": 55 | main() -------------------------------------------------------------------------------- /examples/simple_example.py: -------------------------------------------------------------------------------- 1 | """Mneasuring calibration and calibrating a model for binary classification.""" 2 | 3 | import numpy as np 4 | import calibration 5 | 6 | 7 | def main(): 8 | # Make synthetic dataset. 9 | np.random.seed(0) # Keep results consistent. 10 | num_points = 1000 11 | (zs, ys) = synthetic_data_1d(num_points=num_points) 12 | 13 | # Estimate a lower bound on the calibration error. 14 | # Here z_i is the confidence of the uncalibrated model, y_i is the true label. 15 | calibration_error = calibration.get_calibration_error(zs, ys) 16 | print("Uncalibrated model calibration error is > %.2f%%" % (100 * calibration_error)) 17 | 18 | # Estimate the ECE. 19 | ece = calibration.get_ece(zs, ys) 20 | print("Uncalibrated model ECE is > %.2f%%" % (100 * ece)) 21 | 22 | # Use Platt binning to train a recalibrator. 23 | calibrator = calibration.PlattBinnerCalibrator(num_points, num_bins=10) 24 | calibrator.train_calibration(np.array(zs), ys) 25 | 26 | # Measure the calibration error of recalibrated model. 27 | (test_zs, test_ys) = synthetic_data_1d(num_points=num_points) 28 | calibrated_zs = calibrator.calibrate(test_zs) 29 | calibration_error = calibration.get_calibration_error(calibrated_zs, test_ys) 30 | print("Scaling-binning L2 calibration error is %.2f%%" % (100 * calibration_error)) 31 | 32 | # Get confidence intervals for the calibration error. 33 | [lower, _, upper] = calibration.get_calibration_error_uncertainties(calibrated_zs, test_ys) 34 | print(" Confidence interval is [%.2f%%, %.2f%%]" % (100 * lower, 100 * upper)) 35 | 36 | 37 | # Helper functions used to generate synthetic data. 
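# In this synthetic setup each confidence z_i is drawn uniformly from [0, 1] and
# the label y_i is sampled as Bernoulli(f(z_i)), where f is the Platt function
# below with a=1, b=1, so the raw confidences are miscalibrated by a known
# logistic transformation that the recalibrator should approximately undo.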
38 | 39 | def synthetic_data_1d(num_points): 40 | f_true = platt_function(1, 1) 41 | return sample(f_true, np.random.uniform, num_points) 42 | 43 | def platt_function(a, b): 44 | """Return a (vectorized) platt function f: [0, 1] -> [0, 1] parameterized by a, b.""" 45 | def eval(x): 46 | x = np.log(x / (1 - x)) 47 | x = a * x + b 48 | return 1 / (1 + np.exp(-x)) 49 | return np.vectorize(eval) 50 | 51 | def sample(f, z_dist, n): 52 | """Returns ([z_1, ..., z_n], [y_1, ..., y_n]) where z_i ~ z_dist, y_i ~ Bernoulli(f(z_i)).""" 53 | zs = list(z_dist(size=n)) 54 | ps = f(zs) 55 | ys = list(np.random.binomial(1, p=ps)) 56 | return (zs, ys) 57 | 58 | if __name__ == "__main__": 59 | main() 60 | -------------------------------------------------------------------------------- /experiments/debiased_estimator/estimation_error_vs_bins.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from matplotlib.ticker import PercentFormatter 6 | import os 7 | import calibration as cal 8 | 9 | 10 | np.random.seed(0) # Keep results consistent. 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--probs_file', default='data/cifar_probs.dat', type=str, 14 | help='Name of file to load probs, labels pair.') 15 | parser.add_argument('--platt_data_size', default=2000, type=int, 16 | help='Number of examples to use for Platt Scaling.') 17 | parser.add_argument('--bin_data_size', default=2000, type=int, 18 | help='Number of examples to use for binning.') 19 | parser.add_argument('--num_bins', default=100, type=int, 20 | help='Bins to test estimators with.') 21 | 22 | 23 | def compare_scaling_binning_squared_ce( 24 | probs, labels, platt_data_size, bin_data_size, num_bins, ver_base_size=2000, 25 | ver_size_increment=1000, max_ver_size=7000, num_resamples=1000, 26 | save_prefix='./saved_files/debiased_estimator/', lp=2, 27 | Calibrator=cal.PlattBinnerTopCalibrator): 28 | calibrator = Calibrator(num_calibration=platt_data_size, num_bins=num_bins) 29 | calibrator.train_calibration(probs[:platt_data_size], labels[:platt_data_size]) 30 | predictions = cal.get_top_predictions(probs) 31 | correct = (predictions == labels).astype(np.int32) 32 | verification_correct = correct[bin_data_size:] 33 | verification_probs = calibrator.calibrate(probs[bin_data_size:]) 34 | verification_sizes = list(range(ver_base_size, 1 + min(max_ver_size, len(verification_probs)), 35 | ver_size_increment)) 36 | estimators = [lambda p, l: cal.get_calibration_error(p, l, p=lp, debias=False) ** lp, 37 | lambda p, l: cal.get_calibration_error(p, l, p=lp, debias=True) ** lp] 38 | estimates = get_estimates( 39 | estimators, verification_probs, verification_correct, verification_sizes, 40 | num_resamples) 41 | true_calibration = cal.get_calibration_error(verification_probs, verification_correct, p=lp, debias=False) ** lp 42 | print(true_calibration) 43 | print(np.sqrt(np.mean(estimates[1, -1, :]))) 44 | errors = np.abs(estimates - true_calibration) 45 | plot_mse_curve(errors, verification_sizes, num_resamples, save_prefix, num_bins) 46 | plot_histograms(errors, num_resamples, save_prefix, num_bins) 47 | 48 | 49 | def compare_scaling_ce( 50 | probs, labels, platt_data_size, bin_data_size, num_bins, ver_base_size=2000, 51 | ver_size_increment=1000, max_ver_size=7000, num_resamples=1000, 52 | save_prefix='./saved_files/debiased_estimator/', lp=1, Calibrator=cal.PlattTopCalibrator): 53 | calibrator = 
Calibrator(num_calibration=platt_data_size, num_bins=num_bins) 54 | calibrator.train_calibration(probs[:platt_data_size], labels[:platt_data_size]) 55 | predictions = cal.get_top_predictions(probs) 56 | correct = (predictions == labels).astype(np.int32) 57 | verification_correct = correct[bin_data_size:] 58 | verification_probs = calibrator.calibrate(probs[bin_data_size:]) 59 | verification_sizes = list(range(ver_base_size, 1 + min(max_ver_size, len(verification_probs)), 60 | ver_size_increment)) 61 | binning_probs = calibrator.calibrate(probs[:bin_data_size]) 62 | bins = cal.get_equal_bins(binning_probs, num_bins=num_bins) 63 | def plugin_estimator(p, l): 64 | data = list(zip(p, l)) 65 | binned_data = cal.bin(data, bins) 66 | return cal.plugin_ce(binned_data, power=lp) 67 | def debiased_estimator(p, l): 68 | data = list(zip(p, l)) 69 | binned_data = cal.bin(data, bins) 70 | if lp == 2: 71 | return cal.unbiased_l2_ce(binned_data) 72 | else: 73 | return cal.normal_debiased_ce(binned_data, power=lp) 74 | estimators = [plugin_estimator, debiased_estimator] 75 | estimates = get_estimates( 76 | estimators, verification_probs, verification_correct, verification_sizes, 77 | num_resamples) 78 | true_calibration = plugin_estimator(verification_probs, verification_correct) 79 | print(true_calibration) 80 | print(np.sqrt(np.mean(estimates[1, -1, :]))) 81 | errors = np.abs(estimates - true_calibration) 82 | plot_mse_curve(errors, verification_sizes, num_resamples, save_prefix, num_bins) 83 | plot_histograms(errors, num_resamples, save_prefix, num_bins) 84 | 85 | 86 | def get_estimates(estimators, verification_probs, verification_labels, verification_sizes, 87 | num_resamples=1000): 88 | # We want to compare the two estimators when varying the number of samples. 89 | # However, a single point of comparison does not tell us much about the estimators. 90 | # So we use resampling - we resample from the test set many times, and run the estimators 91 | # on the resamples. We stores these values. This gives us a sense of the range of values 92 | # the estimator might output. 93 | # So estimates[i][j][k] stores the estimate when using estimator i, with verification_sizes[j] 94 | # samples, in the k-th resampling. 
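    # The resampling below draws indices with replacement from the held-out
    # verification set (a bootstrap sample of the requested size), and the
    # resampled estimates for each (estimator, size) pair are sorted before being
    # returned.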
95 | estimates = np.zeros((len(estimators), len(verification_sizes), num_resamples)) 96 | for ver_idx, verification_size in zip(range(len(verification_sizes)), verification_sizes): 97 | for k in range(num_resamples): 98 | # Resample 99 | indices = np.random.choice(list(range(len(verification_probs))), 100 | size=verification_size, replace=True) 101 | cur_verification_probs = [verification_probs[i] for i in indices] 102 | cur_verification_correct = [verification_labels[i] for i in indices] 103 | for i in range(len(estimators)): 104 | estimates[i][ver_idx][k] = estimators[i](cur_verification_probs, 105 | cur_verification_correct) 106 | 107 | estimates = np.sort(estimates, axis=-1) 108 | return estimates 109 | 110 | 111 | def plot_mse_curve(errors, verification_sizes, num_resamples, save_prefix, num_bins): 112 | plt.clf() 113 | errors = np.square(errors) 114 | accumulated_errors = np.mean(errors, axis=-1) 115 | error_bars_90 = 1.645 * np.std(errors, axis=-1) / np.sqrt(num_resamples) 116 | print(accumulated_errors) 117 | plt.errorbar( 118 | verification_sizes, accumulated_errors[0], yerr=[error_bars_90[0], error_bars_90[0]], 119 | barsabove=True, color='red', capsize=4, label='plugin') 120 | plt.errorbar( 121 | verification_sizes, accumulated_errors[1], yerr=[error_bars_90[1], error_bars_90[1]], 122 | barsabove=True, color='blue', capsize=4, label='debiased') 123 | plt.ylabel("MSE of Calibration Error") 124 | plt.xlabel("Number of Samples") 125 | plt.legend(loc='upper right') 126 | plt.tight_layout() 127 | save_name = save_prefix + "curve_" + str(num_bins) 128 | plt.ylim(bottom=0.0) 129 | plt.savefig(save_name) 130 | 131 | 132 | def plot_histograms(errors, num_resamples, save_prefix, num_bins): 133 | plt.clf() 134 | plt.ylabel("Number of estimates") 135 | plt.xlabel("Absolute deviation from ground truth") 136 | bins = np.linspace(np.min(errors[:, 0, :]), np.max(errors[:, 0, :]), 40) 137 | plt.hist(errors[0][0], bins, alpha=0.5, label='plugin') 138 | plt.hist(errors[1][0], bins, alpha=0.5, label='debiased') 139 | plt.legend(loc='upper right') 140 | plt.gca().yaxis.set_major_formatter(PercentFormatter(xmax=num_resamples)) 141 | plt.tight_layout() 142 | save_name = save_prefix + "histogram_" + str(num_bins) 143 | plt.savefig(save_name) 144 | 145 | 146 | def cifar_experiments(): 147 | probs, labels = cal.load_test_probs_labels('data/cifar_probs.dat') 148 | if not os.path.exists('./saved_files'): 149 | os.mkdir('./saved_files') 150 | if not os.path.exists('./saved_files/debiased_estimator/'): 151 | os.mkdir('./saved_files/debiased_estimator/') 152 | save_prefix = './saved_files/debiased_estimator/' 153 | compare_scaling_binning_squared_ce( 154 | probs, labels, platt_data_size=3000, bin_data_size=3000, num_bins=100, 155 | save_prefix=save_prefix+"cifar_scaling_binning_") 156 | compare_scaling_binning_squared_ce( 157 | probs, labels, platt_data_size=3000, bin_data_size=3000, num_bins=15, 158 | save_prefix=save_prefix+"cifar_scaling_binning_") 159 | compare_scaling_ce( 160 | probs, labels, platt_data_size=3000, bin_data_size=3000, num_bins=100, 161 | save_prefix=save_prefix+"cifar_scaling_ece_") 162 | compare_scaling_ce( 163 | probs, labels, platt_data_size=3000, bin_data_size=3000, num_bins=15, 164 | save_prefix=save_prefix+"cifar_scaling_ece_") 165 | 166 | 167 | def imagenet_experiments(): 168 | probs, labels = cal.load_test_probs_labels('data/imagenet_probs.dat') 169 | if not os.path.exists('./saved_files'): 170 | os.mkdir('./saved_files') 171 | if not 
os.path.exists('./saved_files/debiased_estimator/'): 172 | os.mkdir('./saved_files/debiased_estimator/') 173 | save_prefix = './saved_files/debiased_estimator/' 174 | compare_scaling_binning_squared_ce( 175 | probs, labels, platt_data_size=3000, bin_data_size=3000, num_bins=100, 176 | save_prefix=save_prefix+"imnet_scaling_binning_") 177 | compare_scaling_binning_squared_ce( 178 | probs, labels, platt_data_size=3000, bin_data_size=3000, num_bins=15, 179 | save_prefix=save_prefix+"imnet_scaling_binning_") 180 | compare_scaling_ce( 181 | probs, labels, platt_data_size=3000, bin_data_size=3000, num_bins=100, 182 | save_prefix=save_prefix+"imnet_scaling_ece_") 183 | compare_scaling_ce( 184 | probs, labels, platt_data_size=3000, bin_data_size=3000, num_bins=15, 185 | save_prefix=save_prefix+"imnet_scaling_ece_") 186 | 187 | 188 | def parse_input(): 189 | args = parser.parse_args() 190 | probs, labels = cal.load_test_probs_labels(args.probs_file) 191 | if not os.path.exists('./saved_files'): 192 | os.mkdir('./saved_files') 193 | if not os.path.exists('./saved_files/debiased_estimator/'): 194 | os.mkdir('./saved_files/debiased_estimator/') 195 | save_prefix = './saved_files/debiased_estimator/' 196 | compare_scaling_binning_squared_ce( 197 | probs, labels, args.platt_data_size, args.bin_data_size, args.num_bins, 198 | save_prefix=save_prefix) 199 | 200 | 201 | if __name__ == "__main__": 202 | cifar_experiments() 203 | imagenet_experiments() 204 | -------------------------------------------------------------------------------- /experiments/debiased_estimator/mse_vs_ce_tradeoff.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import bisect 4 | import matplotlib.pyplot as plt 5 | from sklearn.linear_model import LogisticRegression 6 | import os 7 | import lib.utils as utils 8 | 9 | 10 | np.random.seed(0) # Keep results consistent. 11 | 12 | 13 | def calibrate_marginals_experiment(probs, labels, k): 14 | num_calib = 3000 15 | num_bin = 3000 16 | num_cert = 4000 17 | assert(probs.shape[0] == num_calib + num_bin + num_cert) 18 | num_bins = 100 19 | bootstrap_samples = 100 20 | # First split by label? To ensure equal class numbers? Do this later. 21 | labels = utils.get_labels_one_hot(labels[:], k) 22 | mse = np.mean(np.square(labels - probs)) 23 | print('original mse is ', mse) 24 | calib_probs = probs[:num_calib, :] 25 | calib_labels = labels[:num_calib, :] 26 | bin_probs = probs[num_calib:num_calib + num_bin, :] 27 | bin_labels = labels[num_calib:num_calib + num_bin, :] 28 | cert_probs = probs[num_calib + num_bin:, :] 29 | cert_labels = labels[num_calib + num_bin:, :] 30 | mses = [] 31 | unbiased_ces = [] 32 | biased_ces = [] 33 | std_unbiased_ces = [] 34 | std_biased_ces = [] 35 | for num_bins in range(10, 101, 10): 36 | # Train a platt scaler and binner. 
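        # For each class l this assembles the scaling-binning recalibrator from the
        # utilities in utils.py: fit a Platt scaler on the calibration split, choose
        # bins from the Platt-scaled calibration probabilities with get_equal_bins
        # (roughly equal numbers of points per bin), then fit a discrete calibrator
        # on the Platt-scaled binning split so that every output is a bin mean.
        # Because the outputs are discrete, get_discrete_bins can recover the bins
        # exactly when the calibration error is measured further down.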
37 | platts = [] 38 | platt_binners_equal_points = [] 39 | for l in range(k): 40 | platt_l = utils.get_platt_scaler(calib_probs[:, l], calib_labels[:, l]) 41 | platts.append(platt_l) 42 | cal_probs_l = platt_l(calib_probs[:, l]) 43 | bins_l = utils.get_equal_bins(cal_probs_l, num_bins=num_bins) 44 | cal_bin_probs_l = platt_l(bin_probs[:, l]) 45 | platt_binner_l = utils.get_discrete_calibrator(cal_bin_probs_l, bins_l) 46 | platt_binners_equal_points.append(platt_binner_l) 47 | 48 | # Write a function that takes data and outputs the mse, ce 49 | def get_mse_ce(probs, labels, ce_est): 50 | mses = [] 51 | ces = [] 52 | probs = np.array(probs) 53 | labels = np.array(labels) 54 | for l in range(k): 55 | cal_probs_l = platt_binners_equal_points[l](platts[l](probs[:, l])) 56 | data = list(zip(cal_probs_l, labels[:, l])) 57 | bins_l = utils.get_discrete_bins(cal_probs_l) 58 | binned_data = utils.bin(data, bins_l) 59 | # probs = platts[l](probs[:, l]) 60 | # for p in [1, 5, 10, 20, 50, 85, 88.5, 92, 94, 96, 98, 100]: 61 | # print(np.percentile(probs, p)) 62 | # import time 63 | # time.sleep(100) 64 | # print('lengths') 65 | # print([len(d) for d in binned_data]) 66 | ces.append(ce_est(binned_data)) 67 | mses.append(np.mean([(prob - label) ** 2 for prob, label in data])) 68 | return np.mean(mses), np.mean(ces) 69 | 70 | def plugin_ce_squared(data): 71 | probs, labels = zip(*data) 72 | return get_mse_ce(probs, labels, lambda x: utils.plugin_ce(x) ** 2)[1] 73 | def mse(data): 74 | probs, labels = zip(*data) 75 | return get_mse_ce(probs, labels, lambda x: utils.plugin_ce(x) ** 2)[0] 76 | def unbiased_ce_squared(data): 77 | probs, labels = zip(*data) 78 | return get_mse_ce(probs, labels, utils.unbiased_square_ce)[1] 79 | 80 | mse, unbiased_ce = get_mse_ce( 81 | cert_probs, cert_labels, utils.unbiased_square_ce) 82 | mse, biased_ce = get_mse_ce( 83 | cert_probs, cert_labels, lambda x: utils.plugin_ce(x) ** 2) 84 | mses.append(mse) 85 | unbiased_ces.append(unbiased_ce) 86 | biased_ces.append(biased_ce) 87 | print('biased ce: ', np.sqrt(biased_ce)) 88 | print('mse: ', mse) 89 | print('improved ce: ', np.sqrt(unbiased_ce)) 90 | data = list(zip(list(cert_probs), list(cert_labels))) 91 | std_biased_ces.append( 92 | utils.bootstrap_std(data, plugin_ce_squared, num_samples=bootstrap_samples)) 93 | std_unbiased_ces.append( 94 | utils.bootstrap_std(data, unbiased_ce_squared, num_samples=bootstrap_samples)) 95 | 96 | std_multiplier = 1.3 # For one sided 90% confidence interval. 97 | upper_unbiased_ces = list(map(lambda p: np.sqrt(p[0] + std_multiplier * p[1]), 98 | zip(unbiased_ces, std_unbiased_ces))) 99 | upper_biased_ces = list(map(lambda p: np.sqrt(p[0] + std_multiplier * p[1]), 100 | zip(biased_ces, std_biased_ces))) 101 | # Get points on the Pareto curve, and plot them. 
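    # In get_pareto_points below, dominated(p1, p2) is true when p1 is at least as
    # large as p2 in both coordinates; a point is kept only if the sole point it
    # dominates in this sense is itself, i.e. no other (ce, mse) pair is at least
    # as good in both coordinates.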
102 | def get_pareto_points(data): 103 | pareto_points = [] 104 | def dominated(p1, p2): 105 | return p1[0] >= p2[0] and p1[1] >= p2[1] 106 | for datum in data: 107 | num_dominated = sum(map(lambda x: dominated(datum, x), data)) 108 | if num_dominated == 1: 109 | pareto_points.append(datum) 110 | return pareto_points 111 | print(get_pareto_points(list(zip(upper_unbiased_ces, mses, list(range(5, 101, 5)))))) 112 | print(get_pareto_points(list(zip(upper_biased_ces, mses, list(range(5, 101, 5)))))) 113 | plot_unbiased_ces, plot_unbiased_mses = zip(*get_pareto_points(list(zip(upper_unbiased_ces, mses)))) 114 | plot_biased_ces, plot_biased_mses = zip(*get_pareto_points(list(zip(upper_biased_ces, mses)))) 115 | plt.title("MSE vs Calibration Error") 116 | plt.scatter(plot_unbiased_ces, plot_unbiased_mses, c='red', marker='o', label='Ours') 117 | plt.scatter(plot_biased_ces, plot_biased_mses, c='blue', marker='s', label='Plugin') 118 | plt.legend(loc='upper left') 119 | plt.ylim(0.0, 0.013) 120 | plt.xlabel("Squared Calibration Error") 121 | plt.ylabel("Mean-Squared Error") 122 | plt.tight_layout() 123 | save_name = "./saved_files/debiased_estimator/mse_vs_ce" 124 | plt.savefig(save_name) 125 | 126 | 127 | if __name__ == "__main__": 128 | if not os.path.exists('./saved_files'): 129 | os.mkdir('./saved_files') 130 | if not os.path.exists('./saved_files/debiased_estimator/'): 131 | os.mkdir('./saved_files/debiased_estimator/') 132 | 133 | probs, labels = utils.load_test_probs_labels('data/cifar_probs.dat') 134 | predictions = np.argmax(probs, 1) 135 | probabilities = np.max(probs, 1) 136 | accuracy = np.mean(labels[:] == predictions) 137 | print('accuracy is ' + str(accuracy)) 138 | calibrate_marginals_experiment(probs, labels, k=10) 139 | -------------------------------------------------------------------------------- /experiments/platt_not_calibrated/cifar10vgg.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import print_function 3 | import tensorflow.keras 4 | from tensorflow.keras.datasets import cifar10 5 | from tensorflow.keras.preprocessing.image import ImageDataGenerator 6 | from tensorflow.keras.models import Sequential 7 | from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten 8 | from tensorflow.keras.layers import Conv2D, MaxPooling2D, BatchNormalization 9 | from tensorflow.keras import optimizers 10 | import numpy as np 11 | from tensorflow.keras import regularizers 12 | 13 | class cifar10vgg: 14 | def __init__(self, load_path="./data/cifar10vgg.h5", train=False): 15 | self.num_classes = 10 16 | self.weight_decay = 0.0005 17 | self.x_shape = [32,32,3] 18 | 19 | self.model = self.build_model() 20 | if train: 21 | self.model = self.train(self.model) 22 | else: 23 | print('loading weights') 24 | self.model.load_weights(load_path) 25 | print('loaded weights') 26 | 27 | 28 | def build_model(self): 29 | # Build the network of vgg for 10 classes with massive dropout and weight decay as described in the paper. 
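        # Summary of the architecture assembled below: five blocks of 3x3
        # convolutions (64, 128, 256, 512 and 512 filters), each convolution
        # followed by ReLU and batch normalization, with dropout inside the blocks
        # and 2x2 max-pooling between them, then a 512-unit dense layer and a
        # 10-way softmax output. Every conv/dense kernel carries l2 weight decay
        # (self.weight_decay).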
30 | 31 | model = Sequential() 32 | weight_decay = self.weight_decay 33 | 34 | model.add(Conv2D(64, (3, 3), padding='same', 35 | input_shape=self.x_shape,kernel_regularizer=regularizers.l2(weight_decay))) 36 | model.add(Activation('relu')) 37 | model.add(BatchNormalization()) 38 | model.add(Dropout(0.3)) 39 | 40 | model.add(Conv2D(64, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 41 | model.add(Activation('relu')) 42 | model.add(BatchNormalization()) 43 | 44 | model.add(MaxPooling2D(pool_size=(2, 2))) 45 | 46 | model.add(Conv2D(128, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 47 | model.add(Activation('relu')) 48 | model.add(BatchNormalization()) 49 | model.add(Dropout(0.4)) 50 | 51 | model.add(Conv2D(128, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 52 | model.add(Activation('relu')) 53 | model.add(BatchNormalization()) 54 | 55 | model.add(MaxPooling2D(pool_size=(2, 2))) 56 | 57 | model.add(Conv2D(256, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 58 | model.add(Activation('relu')) 59 | model.add(BatchNormalization()) 60 | model.add(Dropout(0.4)) 61 | 62 | model.add(Conv2D(256, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 63 | model.add(Activation('relu')) 64 | model.add(BatchNormalization()) 65 | model.add(Dropout(0.4)) 66 | 67 | model.add(Conv2D(256, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 68 | model.add(Activation('relu')) 69 | model.add(BatchNormalization()) 70 | 71 | model.add(MaxPooling2D(pool_size=(2, 2))) 72 | 73 | 74 | model.add(Conv2D(512, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 75 | model.add(Activation('relu')) 76 | model.add(BatchNormalization()) 77 | model.add(Dropout(0.4)) 78 | 79 | model.add(Conv2D(512, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 80 | model.add(Activation('relu')) 81 | model.add(BatchNormalization()) 82 | model.add(Dropout(0.4)) 83 | 84 | model.add(Conv2D(512, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 85 | model.add(Activation('relu')) 86 | model.add(BatchNormalization()) 87 | 88 | model.add(MaxPooling2D(pool_size=(2, 2))) 89 | 90 | 91 | model.add(Conv2D(512, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 92 | model.add(Activation('relu')) 93 | model.add(BatchNormalization()) 94 | model.add(Dropout(0.4)) 95 | 96 | model.add(Conv2D(512, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 97 | model.add(Activation('relu')) 98 | model.add(BatchNormalization()) 99 | model.add(Dropout(0.4)) 100 | 101 | model.add(Conv2D(512, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 102 | model.add(Activation('relu')) 103 | model.add(BatchNormalization()) 104 | 105 | model.add(MaxPooling2D(pool_size=(2, 2))) 106 | model.add(Dropout(0.5)) 107 | 108 | model.add(Flatten()) 109 | model.add(Dense(512,kernel_regularizer=regularizers.l2(weight_decay))) 110 | model.add(Activation('relu')) 111 | model.add(BatchNormalization()) 112 | 113 | model.add(Dropout(0.5)) 114 | model.add(Dense(self.num_classes)) 115 | model.add(Activation('softmax')) 116 | return model 117 | 118 | 119 | def normalize(self,X_train,X_test): 120 | #this function normalize inputs for zero mean and unit variance 121 | # it is used when training a model. 
122 | # Input: training set and test set 123 | # Output: normalized training set and test set according to the trianing set statistics. 124 | mean = np.mean(X_train,axis=(0,1,2,3)) 125 | std = np.std(X_train, axis=(0, 1, 2, 3)) 126 | X_train = (X_train-mean)/(std+1e-7) 127 | X_test = (X_test-mean)/(std+1e-7) 128 | return X_train, X_test 129 | 130 | def normalize_production(self,x): 131 | #this function is used to normalize instances in production according to saved training set statistics 132 | # Input: X - a training set 133 | # Output X - a normalized training set according to normalization constants. 134 | 135 | #these values produced during first training and are general for the standard cifar10 training set normalization 136 | mean = 120.707 137 | std = 64.15 138 | return (x-mean)/(std+1e-7) 139 | 140 | def predict(self,x,normalize=True,batch_size=50): 141 | if normalize: 142 | x = self.normalize_production(x) 143 | return self.model.predict(x,batch_size) 144 | 145 | def train(self,model): 146 | 147 | #training parameters 148 | batch_size = 128 149 | maxepoches = 250 150 | learning_rate = 0.1 151 | lr_decay = 1e-6 152 | lr_drop = 20 153 | # The data, shuffled and split between train and test sets: 154 | (x_train, y_train), (x_test, y_test) = cifar10.load_data() 155 | x_train = x_train.astype('float32') 156 | x_test = x_test.astype('float32') 157 | x_train, x_test = self.normalize(x_train, x_test) 158 | 159 | y_train = keras.utils.to_categorical(y_train, self.num_classes) 160 | y_test = keras.utils.to_categorical(y_test, self.num_classes) 161 | 162 | def lr_scheduler(epoch): 163 | return learning_rate * (0.5 ** (epoch // lr_drop)) 164 | reduce_lr = keras.callbacks.LearningRateScheduler(lr_scheduler) 165 | 166 | #data augmentation 167 | datagen = ImageDataGenerator( 168 | featurewise_center=False, # set input mean to 0 over the dataset 169 | samplewise_center=False, # set each sample mean to 0 170 | featurewise_std_normalization=False, # divide inputs by std of the dataset 171 | samplewise_std_normalization=False, # divide each input by its std 172 | zca_whitening=False, # apply ZCA whitening 173 | rotation_range=15, # randomly rotate images in the range (degrees, 0 to 180) 174 | width_shift_range=0.1, # randomly shift images horizontally (fraction of total width) 175 | height_shift_range=0.1, # randomly shift images vertically (fraction of total height) 176 | horizontal_flip=True, # randomly flip images 177 | vertical_flip=False) # randomly flip images 178 | # (std, mean, and principal components if ZCA whitening is applied). 179 | datagen.fit(x_train) 180 | 181 | 182 | 183 | #optimization details 184 | sgd = optimizers.SGD(lr=learning_rate, decay=lr_decay, momentum=0.9, nesterov=True) 185 | model.compile(loss='categorical_crossentropy', optimizer=sgd,metrics=['accuracy']) 186 | 187 | 188 | # training process in a for loop with learning rate drop every 25 epoches. 
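        # (fit_generator below runs the epoch loop itself; with lr_drop = 20 set
        # above, the LearningRateScheduler halves the learning rate every 20
        # epochs.)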
189 | 190 | historytemp = model.fit_generator(datagen.flow(x_train, y_train, 191 | batch_size=batch_size), 192 | steps_per_epoch=x_train.shape[0] // batch_size, 193 | epochs=maxepoches, 194 | validation_data=(x_test, y_test),callbacks=[reduce_lr],verbose=2) 195 | model.save_weights('cifar10vgg.h5') 196 | return model 197 | 198 | def evaluate(self, x_valid, y_valid): 199 | predicted_x = self.model.predict(x_valid) 200 | residuals = np.argmax(predicted_x,1)!=np.argmax(y_valid,1) 201 | return [-1.0, sum(residuals) * 1.0 /len(residuals)] 202 | 203 | 204 | if __name__ == '__main__': 205 | 206 | 207 | (x_train, y_train), (x_test, y_test) = cifar10.load_data() 208 | x_train = x_train.astype('float32') 209 | x_test = x_test.astype('float32') 210 | 211 | y_train = keras.utils.to_categorical(y_train, 10) 212 | y_test = keras.utils.to_categorical(y_test, 10) 213 | 214 | model = cifar10vgg(train=False) 215 | 216 | predicted_x = model.predict(x_test) 217 | residuals = np.argmax(predicted_x,1)!=np.argmax(y_test,1) 218 | 219 | loss = sum(residuals) * 1.0 /len(residuals) 220 | print("the validation 0/1 loss is: ",loss) 221 | -------------------------------------------------------------------------------- /experiments/platt_not_calibrated/lower_bounds.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from matplotlib import rc 6 | import os 7 | import calibration as cal 8 | 9 | # Keep the results consistent. 10 | np.random.seed(0) 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--probs_path', default='data/cifar_probs.dat', type=str, 14 | help='Name of file to load probs, labels pair.') 15 | parser.add_argument('--calibration_data_size', default=1000, type=int, 16 | help='Number of examples to use for Platt Scaling.') 17 | parser.add_argument('--bin_data_size', default=1000, type=int, 18 | help='Number of examples to use for binning.') 19 | parser.add_argument('--plot_save_file', default='lower_bound_plot.png', type=str, 20 | help='File to save lower bound plot.') 21 | parser.add_argument('--binning', default='equal_bins', type=str, 22 | help='The binning strategy to use.') 23 | parser.add_argument('--lp', default=2, type=int, 24 | help='Use the lp-calibration error.') 25 | parser.add_argument('--bins_list', default="2, 4, 8, 16, 32, 64, 128", 26 | type=lambda s: [int(t) for t in s.split(',')], 27 | help='Bin sizes to evaluate calibration error at.') 28 | parser.add_argument('--num_samples', default=1000, type=int, 29 | help='Number of resamples for bootstrap confidence intervals.') 30 | 31 | 32 | def lower_bound_experiment(probs, labels, calibration_data_size, bin_data_size, bins_list, 33 | save_name='cmp_est', binning_func=cal.get_equal_bins, lp=2, 34 | num_samples=1000): 35 | # Shuffle the probs and labels. 36 | np.random.seed(0) # Keep results consistent. 
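    # Data layout used below: after shuffling, the first calibration_data_size
    # points fit the Platt scaler, the bins are chosen from the Platt-scaled
    # probabilities of the first calibration_data_size + bin_data_size points, and
    # everything after that is held out as verification data for the plugin
    # calibration error estimate and its bootstrap confidence interval.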
37 | indices = np.random.choice(list(range(len(probs))), size=len(probs), replace=False) 38 | probs = [probs[i] for i in indices] 39 | labels = [labels[i] for i in indices] 40 | predictions = cal.get_top_predictions(probs) 41 | probs = cal.get_top_probs(probs) 42 | correct = (predictions == labels) 43 | print('num_correct: ', sum(correct)) 44 | # Platt scale on first chunk of data 45 | platt = cal.get_platt_scaler(probs[:calibration_data_size], correct[:calibration_data_size]) 46 | platt_probs = platt(probs) 47 | lower, middle, upper = [], [], [] 48 | for num_bins in bins_list: 49 | bins = binning_func( 50 | platt_probs[:calibration_data_size+bin_data_size], num_bins=num_bins) 51 | verification_probs = platt_probs[calibration_data_size+bin_data_size:] 52 | verification_correct = correct[calibration_data_size+bin_data_size:] 53 | verification_data = list(zip(verification_probs, verification_correct)) 54 | def estimator(data): 55 | binned_data = cal.bin(data, bins) 56 | return cal.plugin_ce(binned_data, power=lp) 57 | print('estimate: ', estimator(verification_data)) 58 | estimate_interval = cal.bootstrap_uncertainty( 59 | verification_data, estimator, num_samples=1000) 60 | lower.append(estimate_interval[0]) 61 | middle.append(estimate_interval[1]) 62 | upper.append(estimate_interval[2]) 63 | print('interval: ', estimate_interval) 64 | # Plot the results. 65 | lower_errors = np.array(middle) - np.array(lower) 66 | upper_errors = np.array(upper) - np.array(middle) 67 | plt.clf() 68 | font = {'family' : 'normal', 'size': 18} 69 | rc('font', **font) 70 | plt.errorbar( 71 | bins_list, middle, yerr=[lower_errors, upper_errors], 72 | barsabove=True, fmt = 'none', color='black', capsize=4) 73 | plt.scatter(bins_list, middle, color='black') 74 | plt.xlabel(r"No. 
of bins") 75 | if lp == 2: 76 | plt.ylabel("Calibration error") 77 | else: 78 | plt.ylabel("l%d Calibration error" % lp) 79 | plt.xscale('log', basex=2) 80 | ax = plt.gca() 81 | ax.spines['right'].set_visible(False) 82 | ax.spines['top'].set_visible(False) 83 | ax.yaxis.set_ticks_position('left') 84 | ax.xaxis.set_ticks_position('bottom') 85 | plt.tight_layout() 86 | plt.savefig(save_name) 87 | 88 | 89 | def cifar_experiment(savefile, binning_func=cal.get_equal_bins, lp=2): 90 | np.random.seed(0) 91 | calibration_data_size = 1000 92 | bin_data_size = 1000 93 | probs, labels = cal.load_test_probs_labels('cifar_probs.dat') 94 | lower_bound_experiment(probs, labels, calibration_data_size, bin_data_size, 95 | bins_list=[2, 4, 8, 16, 32, 64, 128], save_name=savefile, 96 | binning_func=binning_func, lp=lp) 97 | 98 | 99 | def imagenet_experiment(savefile, binning_func=cal.get_equal_bins, lp=2): 100 | np.random.seed(0) 101 | calibration_data_size = 20000 102 | bin_data_size = 5000 103 | probs, labels = cal.load_test_probs_labels('imagenet_probs.dat') 104 | lower_bound_experiment(probs, labels, calibration_data_size, bin_data_size, 105 | bins_list=[2, 4, 8, 16, 32, 64, 128, 256, 512], save_name=savefile, 106 | binning_func=binning_func, lp=lp) 107 | 108 | 109 | if __name__ == "__main__": 110 | if not os.path.exists('./saved_files'): 111 | os.mkdir('./saved_files') 112 | if not os.path.exists('./saved_files/platt_not_calibrated/'): 113 | os.mkdir('./saved_files/platt_not_calibrated/') 114 | prefix = './saved_files/platt_not_calibrated/' 115 | args = parser.parse_args() 116 | 117 | if args.binning == 'equal_prob_bins': 118 | binning = cal.get_equal_prob_bins 119 | else: 120 | binning = cal.get_equal_bins 121 | 122 | probs, labels = cal.load_test_probs_labels(args.probs_path) 123 | print(args.bins_list) 124 | lower_bound_experiment( 125 | probs, labels, args.calibration_data_size, args.bin_data_size, args.bins_list, 126 | save_name=prefix+args.plot_save_file, binning_func=binning, lp=args.lp, 127 | num_samples=args.num_samples) 128 | -------------------------------------------------------------------------------- /experiments/platt_not_calibrated/save_cifar_logits.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | from tensorflow.keras.datasets import cifar10 4 | 5 | import cifar10vgg 6 | import lib.utils as utils 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--save_file_path', default='cifar_probs.dat', type=str, 10 | help='Name of file to save probs, labels pair.') 11 | 12 | if __name__ == "__main__": 13 | args = parser.parse_args() 14 | utils.save_test_probs_labels(cifar10, cifar10vgg.cifar10vgg(), args.save_file_path) 15 | -------------------------------------------------------------------------------- /experiments/platt_not_calibrated/save_imagenet_logits.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | from tensorflow.keras.applications import vgg16 4 | from tensorflow.python.keras.applications.imagenet_utils import decode_predictions 5 | from tensorflow.keras.preprocessing.image import ImageDataGenerator 6 | import numpy as np 7 | import pickle 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--save_file_path', default='data/imagenet_probs.dat', type=str, 11 | help='Name of file to save probs, labels pair.') 12 | parser.add_argument('--load_folder', default='.', type=str, 13 | help='Path to folder containing ImageNet images.') 14 
| parser.add_argument('--batch_size', default=32, type=int, 15 | help='Batch size of input data to model.') 16 | 17 | if __name__ == "__main__": 18 | args = parser.parse_args() 19 | vgg_model = vgg16.VGG16(weights='imagenet') 20 | eval_datagen = ImageDataGenerator(preprocessing_function=vgg16.preprocess_input) 21 | generator = eval_datagen.flow_from_directory( 22 | args.load_folder, shuffle=False, batch_size=args.batch_size, target_size=(224, 224)) 23 | labels = generator.classes 24 | num_steps = int(np.ceil(len(labels) * 1.0 / args.batch_size)) 25 | probs = vgg_model.predict_generator(generator, steps=num_steps) 26 | pickle.dump((probs, labels), open(args.save_file_path, "wb")) 27 | -------------------------------------------------------------------------------- /experiments/scaling_binning_calibrator/compare_calibrators.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from matplotlib import rc 5 | import time 6 | import os 7 | 8 | import calibration as cal 9 | 10 | 11 | def eval_top_calibration(probs, eval_probs, labels): 12 | correct = (cal.get_top_predictions(eval_probs) == labels) 13 | data = list(zip(probs, correct)) 14 | bins = cal.get_discrete_bins(probs) 15 | binned_data = cal.bin(data, bins) 16 | return cal.plugin_ce(binned_data) ** 2 17 | 18 | 19 | def eval_marginal_calibration(probs, eval_probs, labels, plugin=True): 20 | ces = [] # Compute the calibration error per class, then take the average. 21 | k = eval_probs.shape[1] 22 | labels_one_hot = cal.get_labels_one_hot(np.array(labels), k) 23 | for c in range(k): 24 | probs_c = probs[:, c] 25 | labels_c = labels_one_hot[:, c] 26 | data_c = list(zip(probs_c, labels_c)) 27 | bins_c = cal.get_discrete_bins(probs_c) 28 | binned_data_c = cal.bin(data_c, bins_c) 29 | if plugin: 30 | ce_c = cal.plugin_ce(binned_data_c) ** 2 31 | else: 32 | ce_c = cal.unbiased_square_ce(binned_data_c) 33 | ces.append(ce_c) 34 | return np.mean(ces) 35 | 36 | 37 | def upper_bound_marginal_calibration_unbiased(probs, eval_probs, labels, samples=30): 38 | data = list(zip(probs, eval_probs, labels)) 39 | def evaluator(data): 40 | probs, eval_probs, labels = list(zip(*data)) 41 | probs, eval_probs, labels = np.array(probs), np.array(eval_probs), np.array(labels) 42 | return eval_marginal_calibration(probs, eval_probs, labels, plugin=False) 43 | estimate = evaluator(data) 44 | conf_interval = cal.bootstrap_std(data, evaluator, num_samples=samples) 45 | return estimate + 1.3 * conf_interval 46 | 47 | 48 | def upper_bound_marginal_calibration_biased(probs, eval_probs, labels, samples=30): 49 | data = list(zip(probs, eval_probs, labels)) 50 | def evaluator(data): 51 | probs, eval_probs, labels = list(zip(*data)) 52 | probs, eval_probs, labels = np.array(probs), np.array(eval_probs), np.array(labels) 53 | return eval_marginal_calibration(probs, eval_probs, labels, plugin=True) 54 | estimate = evaluator(data) 55 | conf_interval = cal.bootstrap_std(data, evaluator, num_samples=samples) 56 | return estimate + 1.3 * conf_interval 57 | 58 | 59 | def compare_calibrators(data_sampler, num_bins, Calibrators, calibration_evaluators, 60 | eval_mse): 61 | """Get one sample of the calibration error and MSE for a set of calibrators. 
62 | 63 | Args: 64 | data_sampler: A function that takes in 0 arguments 65 | and returns calib_probs, calib_labels, eval_probs, eval_labels, mse_probs, 66 | mse_labels, where calib_probs and calib_labels should be used by the calibrator 67 | to calibrate, eval_probs and eval_labels should be used to measure the calibration 68 | error, and mse_probs, mse_labels should be used to measure the mean-squared error. 69 | num_bins: integer number of bins. 70 | Calibrators: calibrator classes from e.g. calibrators.py. 71 | calibration_evaluators: a list of functions. calibration_evaluators[i] takes 72 | the output from the calibration method of calibrator i, eval_probs, 73 | eval_labels, and returns a float representing the calibration error 74 | (or an upper bound of it) of calibrator i. We suppose multiple calibration 75 | evaluators because different calibrators may require different ways 76 | of estimating/upper bounding calibration error. 77 | eval_mse: a function that takes in the output of the calibration method, 78 | mse_probs, mse_labels, and returns a float representing the MSE. 79 | """ 80 | calib_probs, calib_labels, eval_probs, eval_labels, mse_probs, mse_labels = data_sampler() 81 | l2_ces = [] 82 | mses = [] 83 | train_time = 0.0 84 | eval_time = 0.0 85 | start_total = time.time() 86 | for Calibrator, i in zip(Calibrators, range(len(Calibrators))): 87 | calibrator = Calibrator(1, num_bins) 88 | start_time = time.time() 89 | calibrator.train_calibration(calib_probs, calib_labels) 90 | train_time += (time.time() - start_time) 91 | calibrated_probs = calibrator.calibrate(eval_probs) 92 | start_time = time.time() 93 | mid = calibration_evaluators[i](calibrated_probs, eval_probs, eval_labels) 94 | eval_time += time.time() - start_time 95 | cal_mse_probs = calibrator.calibrate(mse_probs) 96 | mse = eval_mse(cal_mse_probs, mse_probs, mse_labels) 97 | l2_ces.append(mid) 98 | mses.append(mse) 99 | # print('train_time: ', train_time) 100 | # print('eval_time: ', eval_time) 101 | # print('total_time: ', time.time() - start_total) 102 | return l2_ces, mses 103 | 104 | 105 | def average_calibration(data_sampler, num_bins, Calibrators, calibration_evaluators, 106 | eval_mse, num_trials=100): 107 | l2_ces, mses = [], [] 108 | for i in range(num_trials): 109 | cur_l2_ces, cur_mses = compare_calibrators( 110 | data_sampler, num_bins, Calibrators, 111 | calibration_evaluators, eval_mse) 112 | l2_ces.append(cur_l2_ces) 113 | mses.append(cur_mses) 114 | l2_ce_means = np.mean(l2_ces, axis=0) 115 | l2_ce_stddevs = np.std(l2_ces, axis=0) / np.sqrt(num_trials) 116 | mses = np.mean(mses, axis=0) 117 | mse_stddevs = np.std(mses, axis=0) / np.sqrt(num_trials) 118 | return l2_ce_means, l2_ce_stddevs, mses, mse_stddevs 119 | 120 | 121 | def vary_bin_calibration(data_sampler, num_bins_list, Calibrators, calibration_evaluators, 122 | eval_mse, num_trials=100): 123 | ce_list = [] 124 | stddev_list = [] 125 | mse_list = [] 126 | for num_bins in num_bins_list: 127 | l2_ce_means, l2_ce_stddevs, mses, mse_stddevs = average_calibration( 128 | data_sampler, num_bins, Calibrators, 129 | calibration_evaluators, eval_mse, num_trials) 130 | ce_list.append(l2_ce_means) 131 | stddev_list.append(l2_ce_stddevs) 132 | mse_list.append(mses) 133 | return np.transpose(ce_list), np.transpose(stddev_list), np.transpose(mse_list) 134 | 135 | 136 | def plot_ces(bins_list, l2_ces, l2_ce_stddevs, save_path='marginal_ces.png'): 137 | plt.clf() 138 | font = {'family' : 'normal', 139 | 'size' : 16} 140 | rc('font', **font) 141 | 
plt.ticklabel_format(style='sci', axis='y', scilimits=(0,0)) 142 | # 90% confidence intervals. 143 | error_bars_90 = 1.645 * l2_ce_stddevs 144 | plt.errorbar( 145 | bins_list, l2_ces[0], yerr=[error_bars_90[0], error_bars_90[0]], 146 | barsabove=True, color='red', capsize=4, label='histogram', linestyle='--') 147 | plt.errorbar( 148 | bins_list, l2_ces[1], yerr=[error_bars_90[1], error_bars_90[1]], 149 | barsabove=True, color='blue', capsize=4, label='scaling-binning') 150 | plt.ylabel("Squared Calibration Error") 151 | plt.xlabel("Number of Bins") 152 | plt.ylim(bottom=0.0) 153 | plt.legend(loc='lower right') 154 | plt.tight_layout() 155 | plt.savefig(save_path) 156 | 157 | 158 | def plot_mse_ce_curve(bins_list, l2_ces, mses, xlim=None, ylim=None, 159 | save_path='marginal_mse_vs_ces.png'): 160 | plt.clf() 161 | font = {'family' : 'normal', 162 | 'size' : 16} 163 | rc('font', **font) 164 | plt.ticklabel_format(style='sci', axis='x', scilimits=(0,0)) 165 | def get_pareto_points(data): 166 | pareto_points = [] 167 | def dominated(p1, p2): 168 | return p1[0] >= p2[0] and p1[1] >= p2[1] 169 | for datum in data: 170 | num_dominated = sum(map(lambda x: dominated(datum, x), data)) 171 | if num_dominated == 1: 172 | pareto_points.append(datum) 173 | return pareto_points 174 | print(get_pareto_points(list(zip(l2_ces[0], mses[0], bins_list)))) 175 | print(get_pareto_points(list(zip(l2_ces[1], mses[1], bins_list)))) 176 | l2ces0, mses0 = zip(*get_pareto_points(list(zip(l2_ces[0], mses[0])))) 177 | l2ces1, mses1 = zip(*get_pareto_points(list(zip(l2_ces[1], mses[1])))) 178 | plt.scatter(l2ces0, mses0, c='red', marker='o', label='histogram') 179 | plt.scatter(l2ces1, mses1, c='blue', marker='x', label='scaling-binning') 180 | plt.legend(loc='upper right') 181 | if xlim is not None: 182 | plt.xlim(xlim) 183 | if ylim is not None: 184 | plt.ylim(ylim) 185 | plt.xlabel("Squared Calibration Error") 186 | plt.ylabel("Mean-Squared Error") 187 | plt.tight_layout() 188 | plt.savefig(save_path) 189 | 190 | 191 | def make_calibration_data_sampler(probs, labels, num_calibration): 192 | def data_sampler(): 193 | assert len(probs) == len(labels) 194 | indices = np.random.choice(list(range(len(probs))), 195 | size=num_calibration, replace=True) 196 | calib_probs = np.array([probs[i] for i in indices]) 197 | calib_labels = np.array([labels[i] for i in indices]) 198 | eval_probs = probs 199 | eval_labels = labels 200 | return calib_probs, calib_labels, eval_probs, eval_labels, probs, labels 201 | return data_sampler 202 | 203 | 204 | def make_calibration_eval_data_sampler(probs, labels, num_calib, num_eval): 205 | def data_sampler(): 206 | assert len(probs) == len(labels) 207 | calib_indices = np.random.choice( 208 | list(range(len(probs))), size=num_calib, replace=True) 209 | eval_indices = np.random.choice( 210 | list(range(len(probs))), size=num_eval, replace=True) 211 | calib_probs = np.array([probs[i] for i in calib_indices]) 212 | calib_labels = np.array([labels[i] for i in calib_indices]) 213 | eval_probs = np.array([probs[i] for i in eval_indices]) 214 | eval_labels = np.array([labels[i] for i in eval_indices]) 215 | return calib_probs, calib_labels, eval_probs, eval_labels, probs, labels 216 | return data_sampler 217 | 218 | 219 | def cifar10_experiment_top(probs_path, ce_save_path, mse_ce_save_path, num_trials=100): 220 | probs, labels = cal.load_test_probs_labels(probs_path) 221 | bins_list = list(range(10, 101, 10)) 222 | num_calibration = 1000 223 | l2_ces, l2_stddevs, mses = vary_bin_calibration( 224 | 
189 | 
190 | 
191 | def make_calibration_data_sampler(probs, labels, num_calibration):
192 |     def data_sampler():
193 |         assert len(probs) == len(labels)
194 |         indices = np.random.choice(list(range(len(probs))),
195 |                                    size=num_calibration, replace=True)
196 |         calib_probs = np.array([probs[i] for i in indices])
197 |         calib_labels = np.array([labels[i] for i in indices])
198 |         eval_probs = probs
199 |         eval_labels = labels
200 |         return calib_probs, calib_labels, eval_probs, eval_labels, probs, labels
201 |     return data_sampler
202 | 
203 | 
204 | def make_calibration_eval_data_sampler(probs, labels, num_calib, num_eval):
205 |     def data_sampler():
206 |         assert len(probs) == len(labels)
207 |         calib_indices = np.random.choice(
208 |             list(range(len(probs))), size=num_calib, replace=True)
209 |         eval_indices = np.random.choice(
210 |             list(range(len(probs))), size=num_eval, replace=True)
211 |         calib_probs = np.array([probs[i] for i in calib_indices])
212 |         calib_labels = np.array([labels[i] for i in calib_indices])
213 |         eval_probs = np.array([probs[i] for i in eval_indices])
214 |         eval_labels = np.array([labels[i] for i in eval_indices])
215 |         return calib_probs, calib_labels, eval_probs, eval_labels, probs, labels
216 |     return data_sampler
217 | 
218 | 
219 | def cifar10_experiment_top(probs_path, ce_save_path, mse_ce_save_path, num_trials=100):
220 |     probs, labels = cal.load_test_probs_labels(probs_path)
221 |     bins_list = list(range(10, 101, 10))
222 |     num_calibration = 1000
223 |     l2_ces, l2_stddevs, mses = vary_bin_calibration(
224 |         data_sampler=make_calibration_data_sampler(probs, labels, num_calibration),
225 |         num_bins_list=bins_list,
226 |         Calibrators=[cal.HistogramTopCalibrator, cal.PlattBinnerTopCalibrator],
227 |         calibration_evaluators=[eval_top_calibration, eval_top_calibration],
228 |         eval_mse=cal.eval_top_mse,
229 |         num_trials=num_trials)
230 |     plot_mse_ce_curve(bins_list, l2_ces, mses, xlim=(0.0, 0.002), ylim=(0.0425, 0.045),
231 |                       save_path=mse_ce_save_path)
232 |     plot_ces(bins_list, l2_ces, l2_stddevs, save_path=ce_save_path)
233 | 
234 | 
235 | def cifar10_experiment_marginal(probs_path, ce_save_path, mse_ce_save_path, num_trials=100):
236 |     probs, labels = cal.load_test_probs_labels(probs_path)
237 |     bins_list = list(range(10, 101, 10))
238 |     num_calibration = 1000
239 |     l2_ces, l2_stddevs, mses = vary_bin_calibration(
240 |         data_sampler=make_calibration_data_sampler(probs, labels, num_calibration),
241 |         num_bins_list=bins_list,
242 |         Calibrators=[cal.HistogramMarginalCalibrator,
243 |                      cal.PlattBinnerMarginalCalibrator],
244 |         calibration_evaluators=[eval_marginal_calibration, eval_marginal_calibration],
245 |         eval_mse=cal.eval_marginal_mse,
246 |         num_trials=num_trials)
247 |     plot_mse_ce_curve(bins_list, l2_ces, mses, xlim=(0.0, 0.0006), ylim=(0.04, 0.08),
248 |                       save_path=mse_ce_save_path)
249 |     plot_ces(bins_list, l2_ces, l2_stddevs, save_path=ce_save_path)
250 | 
251 | 
252 | def imagenet_experiment_top(probs_path, ce_save_path, mse_ce_save_path, num_trials=100):
253 |     probs, labels = cal.load_test_probs_labels(probs_path)
254 |     bins_list = list(range(10, 101, 10))
255 |     num_calibration = 1000
256 |     l2_ces, l2_stddevs, mses = vary_bin_calibration(
257 |         data_sampler=make_calibration_data_sampler(probs, labels, num_calibration),
258 |         num_bins_list=bins_list,
259 |         Calibrators=[cal.HistogramTopCalibrator, cal.PlattBinnerTopCalibrator],
260 |         calibration_evaluators=[eval_top_calibration, eval_top_calibration],
261 |         eval_mse=cal.eval_top_mse,
262 |         num_trials=num_trials)
263 |     plot_mse_ce_curve(bins_list, l2_ces, mses, save_path=mse_ce_save_path)
264 |     plot_ces(bins_list, l2_ces, l2_stddevs, save_path=ce_save_path)
265 | 
266 | 
267 | def imagenet_experiment_marginal(probs_path, ce_save_path, mse_ce_save_path, num_trials=20):
268 |     probs, labels = cal.load_test_probs_labels(probs_path)
269 |     bins_list = list(range(10, 101, 10))
270 |     num_calibration = 25000
271 |     l2_ces, l2_stddevs, mses = vary_bin_calibration(
272 |         data_sampler=make_calibration_data_sampler(probs, labels, num_calibration),
273 |         num_bins_list=bins_list,
274 |         Calibrators=[cal.HistogramMarginalCalibrator,
275 |                      cal.PlattBinnerMarginalCalibrator],
276 |         calibration_evaluators=[eval_marginal_calibration, eval_marginal_calibration],
277 |         eval_mse=cal.eval_marginal_mse,
278 |         num_trials=num_trials)
279 |     plot_mse_ce_curve(bins_list, l2_ces, mses, save_path=mse_ce_save_path)
280 |     plot_ces(bins_list, l2_ces, l2_stddevs, save_path=ce_save_path)
281 | 
282 | 
283 | if __name__ == "__main__":
284 |     if not os.path.exists('./saved_files'):
285 |         os.mkdir('./saved_files')
286 |     if not os.path.exists('./saved_files/scaling_binning_calibrator/'):
287 |         os.mkdir('./saved_files/scaling_binning_calibrator/')
288 |     prefix = './saved_files/scaling_binning_calibrator/'
289 |     # Main marginal calibration CIFAR-10 experiment in the paper.
290 |     np.random.seed(0) # Keep results consistent.
291 |     cifar10_experiment_marginal(
292 |         probs_path='data/cifar_probs.dat',
293 |         ce_save_path=prefix+'cifar_marginal_ce_plot',
294 |         mse_ce_save_path=prefix+'cifar_marginal_mse_ce_plot')
295 |     # Top-label calibration CIFAR experiment in the Appendix, 1000 points.
296 |     np.random.seed(0) # Keep results consistent.
297 |     cifar10_experiment_top(
298 |         probs_path='data/cifar_probs.dat',
299 |         ce_save_path=prefix+'cifar_top_ce_plot',
300 |         mse_ce_save_path=prefix+'cifar_top_mse_ce_plot')
301 |     # Top-label calibration ImageNet experiment in the Appendix, 1000 points.
302 |     np.random.seed(0) # Keep results consistent.
303 |     imagenet_experiment_top(
304 |         probs_path='data/imagenet_probs.dat',
305 |         ce_save_path=prefix+'imagenet_top_ce_plot',
306 |         mse_ce_save_path=prefix+'imagenet_top_mse_ce_plot')
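    # Note: imagenet_experiment_marginal (defined above, with num_calibration=25000 and
    # num_trials=20 by default) is not run here. A possible invocation, with illustrative
    # save-path names, would mirror the calls above:
    #
    #     np.random.seed(0) # Keep results consistent.
    #     imagenet_experiment_marginal(
    #         probs_path='data/imagenet_probs.dat',
    #         ce_save_path=prefix+'imagenet_marginal_ce_plot',
    #         mse_ce_save_path=prefix+'imagenet_marginal_mse_ce_plot')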
307 | 
--------------------------------------------------------------------------------
/experiments/synthetic/synthetic.py:
--------------------------------------------------------------------------------
1 | 
2 | import argparse
3 | import numpy as np
4 | import calibration as cal
5 | import matplotlib.pyplot as plt
6 | import pickle
7 | import os
8 | from pathlib import Path
9 | 
10 | # Keep the results consistent.
11 | np.random.seed(0)
12 | 
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--experiment_name', default='vary_n_a1_b0', type=str,
15 |                     help='Name of the experiment to run.')
16 | 
17 | 
18 | def platt_function(a, b):
19 |     """Return a (vectorized) Platt function f: [0, 1] -> [0, 1] parameterized by a, b."""
20 |     def eval(x):
21 |         x = np.log(x / (1 - x))
22 |         x = a * x + b
23 |         return 1 / (1 + np.exp(-x))
24 |     return np.vectorize(eval)
25 | 
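# A useful special case: platt_function(1, 0) is the identity on (0, 1), since
# sigmoid(1 * log(x / (1 - x)) + 0) == x. Labels drawn via sample(f, ...) then satisfy
# P(Y = 1 | z) = z, i.e. the raw score z is already perfectly calibrated; this is the
# setting used by the 'vary_n_a1_b0' and 'vary_b_a1_b0' experiments at the bottom of
# this file.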
26 | 
27 | def noisy_platt_function(a, b, eps, l, u, num_bins=100000):
28 |     """Return a (vectorized) noisy Platt function f: [l, u] -> [0, 1] parameterized by a, b.
29 | 
30 |     Let g denote the Platt function parameterized by a, b. f accepts input x, where l <= x <= u.
31 |     The range [l, u] is split into num_bins equally spaced intervals. If x is in interval j,
32 |     f(x) = g(x) + noise[j], where noise[j] is sampled to be {-eps, +eps} with equal probability.
33 |     f is vectorized, that is, it can also accept a numpy array as argument, and will apply f
34 |     to each element in the array."""
35 |     def platt(x):
36 |         x = np.log(x / (1 - x))
37 |         x = a * x + b
38 |         return 1 / (1 + np.exp(-x))
39 |     assert(1 - eps >= platt(l) >= eps)
40 |     assert(1 - eps >= platt(u) >= eps)
41 |     noise = (np.random.binomial(1, np.ones(num_bins + 1) * 0.5) * 2 - 1) * eps
42 |     def eval(x):
43 |         assert l <= x <= u
44 |         bin_idx = np.floor((x - l) / (u - l) * num_bins).astype(np.int32)
45 |         assert(np.all(bin_idx <= num_bins))
46 |         bin_idx -= (bin_idx == num_bins)
47 |         return platt(x) + noise[bin_idx]
48 |     return np.vectorize(eval)
49 | 
50 | 
51 | def sample(f, z_dist, n):
52 |     """Returns ([z_1, ..., z_n], [y_1, ..., y_n]) where z_i ~ z_dist, y_i ~ Bernoulli(f(z_i))."""
53 |     zs = z_dist(size=n)
54 |     ps = f(zs)
55 |     ys = np.random.binomial(1, p=ps)
56 |     return (zs, ys)
57 | 
58 | 
59 | def evaluate_l2ce(f, calibrator, z_dist, n):
60 |     """Returns the squared calibration error of the calibrator on z_dist, f using n samples."""
61 |     zs = z_dist(size=n)
62 |     ps = f(zs)
63 |     phats = calibrator.calibrate(zs)
64 |     bins = cal.get_discrete_bins(phats)
65 |     data = list(zip(phats, ps))
66 |     binned_data = cal.bin(data, bins)
67 |     return cal.plugin_ce(binned_data) ** 2
68 | 
69 | 
70 | def evaluate_mse(f, calibrator, z_dist, n):
71 |     """Returns the MSE of the calibrator on z_dist, f using n samples."""
72 |     zs = z_dist(size=n)
73 |     ps = f(zs)
74 |     phats = calibrator.calibrate(zs)
75 |     return np.mean(np.square(ps - phats))
76 | 
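# Note that evaluate_l2ce and evaluate_mse compare the calibrated outputs phats against
# the true probabilities ps = f(zs) rather than against sampled labels; this is only
# possible because the synthetic data-generating function f is known exactly.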
77 | 
78 | def get_errors(f, Calibrators, z_dist, nb_args, num_trials, num_evaluation,
79 |                evaluate_error=evaluate_l2ce):
80 |     """Get the errors (+std-devs) of calibrators for each (n, b) in nb_args."""
81 |     means = np.zeros((len(Calibrators), len(nb_args)))
82 |     std_devs = np.zeros((len(Calibrators), len(nb_args)))
83 |     for i, Calibrator in enumerate(Calibrators):
84 |         for j, (num_calibration, num_bins) in enumerate(nb_args):
85 |             current_errors = []
86 |             for k in range(num_trials):
87 |                 zs, ys = sample(f, z_dist=z_dist, n=num_calibration)
88 |                 calibrator = Calibrator(num_calibration=num_calibration, num_bins=num_bins)
89 |                 calibrator.train_calibration(zs, ys)
90 |                 error = evaluate_error(f, calibrator, z_dist, n=num_evaluation)
91 |                 assert(error >= 0.0)
92 |                 current_errors.append(error)
93 |             means[i][j] = np.mean(current_errors)
94 |             std_devs[i][j] = np.std(current_errors) / np.sqrt(num_trials)
95 |     return means, std_devs
96 | 
97 | 
98 | def sweep_n_platt(a, b, save_file, base_n=500, max_n_multiplier=9, bins=10,
99 |                   num_trials=1000, num_evaluation=1000):
100 |     f = platt_function(a, b)
101 |     Calibrators = [cal.PlattCalibrator,
102 |                    cal.HistogramCalibrator,
103 |                    cal.PlattBinnerCalibrator]
104 |     names = ['scaling', 'binning', 'scaling-binning']
105 |     dist = np.random.uniform
106 |     nb_args = [(base_n * i, bins) for i in range(1, max_n_multiplier)]
107 |     means, stddevs = get_errors(f, Calibrators, dist, nb_args, num_trials, num_evaluation)
108 |     pickle.dump((names, nb_args, means, stddevs), open(save_file, "wb"))
109 | 
110 | 
111 | def sweep_b_platt(a, b, save_file, n=2000, base_bins=10, max_bin_multiplier=9,
112 |                   num_trials=1000, num_evaluation=1000):
113 |     f = platt_function(a, b)
114 |     Calibrators = [cal.PlattCalibrator,
115 |                    cal.HistogramCalibrator,
116 |                    cal.PlattBinnerCalibrator]
117 |     names = ['scaling', 'binning', 'scaling-binning']
118 |     dist = np.random.uniform
119 |     nb_args = [(n, base_bins * i) for i in range(1, max_bin_multiplier)]
120 |     means, stddevs = get_errors(f, Calibrators, dist, nb_args, num_trials, num_evaluation)
121 |     pickle.dump((names, nb_args, means, stddevs), open(save_file, "wb"))
122 | 
123 | 
124 | def sweep_n_noisy_platt(a, b, save_file, base_n=500, max_n_multiplier=9, bins=10,
125 |                         num_trials=1000, num_evaluation=100):
126 |     l, u = 0.25, 0.75  # Probably needs to change.
127 |     f = noisy_platt_function(a, b, eps=0.02, l=l, u=u)  # Probably needs to change.
128 |     Calibrators = [cal.PlattCalibrator,
129 |                    cal.HistogramCalibrator,
130 |                    cal.PlattBinnerCalibrator]
131 |     names = ['scaling', 'binning', 'scaling-binning']
132 |     def dist(size):
133 |         return np.random.uniform(low=l, high=u, size=size)
134 |     nb_args = [(base_n * i, bins) for i in range(1, max_n_multiplier)]
135 |     means, stddevs = get_errors(f, Calibrators, dist, nb_args, num_trials, num_evaluation)
136 |     pickle.dump((names, nb_args, means, stddevs), open(save_file, "wb"))
137 | 
138 | 
139 | # Plots 1/eps^2 against n.
140 | def plot_sweep_n(load_file, save_file):
141 |     # TODO: add legends.
142 |     (names, nb_args, means, stddevs) = pickle.load(open(load_file, "rb"))
143 |     error_bars_90 = (1.645 / (np.square(means))) * stddevs
144 |     plt.clf()
145 |     def plot_inv_err_n(method_means, method_error_bars_90, color, calibrator_name):
146 |         plt.errorbar([n for (n, b) in nb_args], 1 / method_means, color=color,
147 |                      yerr=[method_error_bars_90, method_error_bars_90],
148 |                      barsabove=True, capsize=4, label=calibrator_name)
149 |         plt.ylabel("1 / epsilon^2")
150 |         plt.xlabel("n (number of calibration points)")
151 |         plt.tight_layout()
152 |         plt.savefig(save_file+'_'+calibrator_name)
153 |     colors = ['red', 'green', 'blue']
154 |     for i in range(len(names)):
155 |         plt.clf()
156 |         plot_inv_err_n(means[i], error_bars_90[i], color=colors[i],
157 |                        calibrator_name=names[i])
158 |     print('calibrators', names)
159 |     print('nb_args', nb_args)
160 |     print('means', means)
161 |     print('stddevs', stddevs)
162 |     # # TODO: include method names to avoid confusion.
163 |     # for i in range(means.shape[0]):
164 |     #     factor_1_3 = divide(means[i][1], means[i][3], stddevs[i][1], stddevs[i][3])
165 |     #     print('Method', i, 'going from', nb_args[1], nb_args[3], 'error goes', factor_1_3)
166 | 
167 | 
168 | def plot_noisy_eps_n(load_file, save_file, skip_binning=True):
169 |     (names, nb_args, means, stddevs) = pickle.load(open(load_file, "rb"))
170 |     error_bars_90 = 1.645 * stddevs
171 |     plt.clf()
172 |     def plot_err(method_means, method_error_bars_90, color, calibrator_name):
173 |         plt.errorbar([n for (n, b) in nb_args], method_means, color=color,
174 |                      yerr=[method_error_bars_90, method_error_bars_90],
175 |                      barsabove=True, capsize=4, label=calibrator_name)
176 |         plt.ylabel("epsilon^2")
177 |         plt.xlabel("n (number of calibration points)")
178 |     colors = ['red', 'green', 'blue']
179 |     for i in range(len(names)):
180 |         if names[i] == 'binning' and skip_binning:
181 |             pass
182 |         else:
183 |             plot_err(means[i], error_bars_90[i], color=colors[i],
184 |                      calibrator_name=names[i])
185 |     plt.tight_layout()
186 |     plt.legend(loc='upper right')
187 |     plt.savefig(save_file)
188 |     print('calibrators', names)
189 |     print('nb_args', nb_args)
190 |     print('means', means)
191 |     print('stddevs', stddevs)
192 | 
193 | 
194 | def plot_sweep_b(load_file, save_file):
195 |     # TODO: Add legends.
196 |     (names, nb_args, means, stddevs) = pickle.load(open(load_file, "rb"))
197 |     error_bars_90 = (1.645 / (np.square(means))) * stddevs
198 |     def plot_inv_err_b(method_means, method_error_bars_90, color, calibrator_name):
199 |         plt.errorbar([b for (n, b) in nb_args], 1 / method_means, color=color,
200 |                      yerr=[method_error_bars_90, method_error_bars_90],
201 |                      barsabove=True, capsize=4, label=calibrator_name)
202 |         plt.ylabel("1 / epsilon^2")
203 |         plt.xlabel("b (number of bins)")
204 |         plt.tight_layout()
205 |         plt.savefig(save_file+'_'+calibrator_name)
206 |     colors = ['red', 'green', 'blue']
207 |     for i in range(len(names)):
208 |         plt.clf()
209 |         plot_inv_err_b(means[i], error_bars_90[i], color=colors[i],
210 |                        calibrator_name=names[i])
211 |     print('calibrators', names)
212 |     print('nb_args', nb_args)
213 |     print('means', means)
214 |     print('stddevs', stddevs)
215 | 
216 | 
217 | def divide(mu1, mu2, sigma1, sigma2):
218 |     # Use the delta method to compute confidence intervals for the ratio mu1 / mu2.
219 |     mu = mu1 / mu2
220 |     sigma = np.sqrt(1.0 / (mu2 ** 2) * (sigma1 ** 2) + (mu1 ** 2) / (mu2 ** 4) * (sigma2 ** 2))
221 |     return mu, sigma
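# Worked example of the delta-method approximation above: divide(0.02, 0.01, 0.002, 0.001)
# returns mu = 2.0 and sigma = sqrt((1 / 0.01**2) * 0.002**2 + (0.02**2 / 0.01**4) * 0.001**2)
# = sqrt(0.04 + 0.04) ≈ 0.283.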
222 | 
223 | 
224 | def plot_curve(f, save_file, l=1e-8, u=1.0-1e-8):
225 |     xs = np.arange(l, u, 1 / 1000.0)
226 |     ys = f(xs)
227 |     plt.clf()
228 |     plt.plot(xs, ys)
229 |     plt.ylabel("P(Y = 1 | z)")
230 |     plt.xlabel("z")
231 |     plt.tight_layout()
232 |     parent = Path(save_file).parent
233 |     plt.savefig(save_file)
234 | 
235 | 
236 | if __name__ == "__main__":
237 |     if not os.path.exists('./saved_files'):
238 |         os.mkdir('./saved_files')
239 |     if not os.path.exists('./saved_files/synthetic/'):
240 |         os.mkdir('./saved_files/synthetic/')
241 |     prefix = './saved_files/synthetic/'
242 |     args = parser.parse_args()
243 | 
244 |     if args.experiment_name == 'vary_n_a1_b0':
245 |         f = platt_function(1, 0)
246 |         plot_curve(f, prefix+'curve_vary_n_a1_b0')
247 |         sweep_n_platt(1, 0, prefix+'vary_n_a1_b0')
248 |         plot_sweep_n(load_file=prefix+'vary_n_a1_b0',
249 |                      save_file=prefix+'vary_n_a1_b0')
250 |     elif args.experiment_name == 'vary_b_a1_b0':
251 |         f = platt_function(1, 0)
252 |         plot_curve(f, prefix+'curve_vary_b_a1_b0')
253 |         sweep_b_platt(1, 0, prefix+'vary_b_a1_b0')
254 |         plot_sweep_b(load_file=prefix+'vary_b_a1_b0',
255 |                      save_file=prefix+'vary_b_a1_b0')
256 |     elif args.experiment_name == 'vary_n_a2_b1':
257 |         f = platt_function(2, 1)
258 |         plot_curve(f, prefix+'curve_vary_n_a2_b1')
259 |         sweep_n_platt(2, 1, prefix+'vary_n_a2_b1')
260 |         plot_sweep_n(load_file=prefix+'vary_n_a2_b1',
261 |                      save_file=prefix+'vary_n_a2_b1')
262 |     elif args.experiment_name == 'vary_b_a2_b1':
263 |         f = platt_function(2, 1)
264 |         plot_curve(f, prefix+'curve_vary_b_a2_b1')
265 |         sweep_b_platt(2, 1, prefix+'vary_b_a2_b1')
266 |         plot_sweep_b(load_file=prefix+'vary_b_a2_b1',
267 |                      save_file=prefix+'vary_b_a2_b1')
268 |     elif args.experiment_name == 'noisy_vary_n_a2_b1':
269 |         f = noisy_platt_function(2, 1, eps=0.02, l=0.25, u=0.75)
270 |         plot_curve(f, prefix+'noisy_curve_vary_n_a2_b1', l=0.25, u=0.75)
271 |         sweep_n_noisy_platt(2, 1, prefix+'noisy_vary_n_a2_b1')
272 |         plot_noisy_eps_n(load_file=prefix+'noisy_vary_n_a2_b1',
273 |                          save_file=prefix+'noisy_error_plot_vary_n_a2_b1')
274 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 | 
3 | with open("README.md", "r") as fh:
4 |     long_description = fh.read()
5 | 
6 | setuptools.setup(
7 |     name="uncertainty-calibration",
8 |     version="0.1.4",
9 |     author="Ananya Kumar",
10 |     author_email="skywalker94@gmail.com",
11 |     description="Utilities to calibrate model uncertainties and measure calibration.",
12 |     long_description=long_description,
13 |     long_description_content_type="text/markdown",
14 |     url="https://github.com/AnanyaKumar/verified_calibration",
15 |     packages=setuptools.find_packages(),
16 |     install_requires=['numpy', 'scikit-learn', 'parameterized'],
17 |     classifiers=[
18 |         "Programming Language :: Python :: 3",
19 |         "License :: OSI Approved :: MIT License",
20 |         "Operating System :: OS Independent",
21 |     ],
22 |     python_requires='>=3.6',
23 | )
24 | 
--------------------------------------------------------------------------------