├── LICENSE
├── README.md
└── ece_kde.py


/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Teodora Popordanoska

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Calibration Error Estimator
This is the official code repository for ["A Consistent and Differentiable Lp Canonical Calibration Error Estimator"](https://arxiv.org/abs/2210.07810), published at NeurIPS 2022.

The paper proposes $ECE^{KDE}$, a consistent and differentiable estimator of the Lp calibration error. To model a density
over the simplex, we use Kernel Density Estimation (KDE) with Dirichlet kernels. This estimator can tractably capture the
strongest form of calibration, called canonical (or distribution) calibration, which requires the entire probability
vector to be calibrated.

## Usage
$ECE^{KDE}$ can be directly optimized alongside any loss function in a calibration-regularized training objective:

$$f = \arg\min_{f\in \mathcal{F}}\, \Bigl(\operatorname{Risk}(f) + \lambda \cdot \operatorname{CE}(f)\Bigr). $$

The weight $\lambda$ is chosen via cross-validation. A sketch of such a training loop is shown below.

Additionally, the estimator can be used as a metric to evaluate canonical (distribution), marginal (classwise) and top-label (confidence) calibration.


## To use it in your project
Copy the file `ece_kde.py` to your repo. You can obtain an estimate of CE with the method `get_ece_kde`. The bandwidth
of the kernel can either be set manually, or chosen by maximizing the leave-one-out likelihood with the method
`get_bandwidth`.
For example, an estimate of canonical CE as defined in Equation 9 in the paper can be obtained with:
```
import torch

from ece_kde import get_ece_kde

# Generate dummy probability scores and labels
f = torch.rand((50, 3))
f = f / torch.sum(f, dim=1).unsqueeze(-1)
y = torch.randint(0, 3, (50,))

get_ece_kde(f, y, bandwidth=0.001, p=1, mc_type='canonical', device='cpu')
```
The code is still in its preliminary version. A demo will be available soon.
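
In the meantime, the bandwidth can also be selected automatically, and the same call evaluates marginal (classwise) or top-label (confidence) calibration by changing `mc_type`. A minimal sketch, reusing the dummy `f` and `y` from above:
```
from ece_kde import get_bandwidth

# Bandwidth chosen by maximizing the leave-one-out likelihood
bandwidth = get_bandwidth(f, device='cpu')

# L2 estimates of marginal (classwise) and top-label (confidence) calibration error
get_ece_kde(f, y, bandwidth=bandwidth, p=2, mc_type='marginal', device='cpu')
get_ece_kde(f, y, bandwidth=bandwidth, p=2, mc_type='top_label', device='cpu')
```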
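
The calibration-regularized objective from the Usage section can be sketched as follows. This is only an illustration, not part of the repository: `model`, `optimizer` and `loader` are placeholders for your own setup, and the weight `lam` (i.e. $\lambda$) and the bandwidth would be tuned, e.g. by cross-validation:
```
# Sketch: Risk(f) + lambda * CE(f), with cross-entropy as the risk term
lam = 0.01  # placeholder value for the regularization weight
for x, y_batch in loader:
    logits = model(x)
    scores = torch.softmax(logits, dim=1)
    risk = torch.nn.functional.cross_entropy(logits, y_batch)
    reg = get_ece_kde(scores, y_batch, bandwidth=0.001, p=1, mc_type='canonical', device='cpu')
    loss = risk + lam * reg
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
Note the comment in `ece_kde.py`: when the slower log-space estimator is used (many classes), every class present in a batch should have at least two examples, otherwise the backward pass can return NaNs.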

## Reference
If you found this work or code useful, please cite:

```
@inproceedings{Popordanoska2022b,
  title = {A Consistent and Differentiable $L_p$ Canonical Calibration Error Estimator},
  author = {Popordanoska, Teodora and Sayer, Raphael and Blaschko, Matthew B.},
  booktitle = {Advances in Neural Information Processing Systems},
  year = {2022},
}
```

## License

Everything is licensed under the [MIT License](https://opensource.org/licenses/MIT).

--------------------------------------------------------------------------------
/ece_kde.py:
--------------------------------------------------------------------------------
import math

import torch
from torch import nn


def get_bandwidth(f, device):
    """
    Select a bandwidth for the kernel based on maximizing the leave-one-out likelihood (LOO MLE).

    :param f: The vector containing the probability scores, shape [num_samples, num_classes]
    :param device: The device type: 'cpu' or 'cuda'

    :return: The bandwidth of the kernel
    """
    bandwidths = torch.cat((torch.logspace(start=-5, end=-1, steps=15), torch.linspace(0.2, 1, steps=5)))
    max_b = -1
    max_l = float('-inf')
    n = len(f)
    for b in bandwidths:
        log_kern = get_kernel(f, b, device)
        log_fhat = torch.logsumexp(log_kern, 1) - math.log(n - 1)
        l = torch.sum(log_fhat)
        if l > max_l:
            max_l = l
            max_b = b

    return max_b


def get_ece_kde(f, y, bandwidth, p, mc_type, device):
    """
    Calculate an estimate of Lp calibration error.

    :param f: The vector containing the probability scores, shape [num_samples, num_classes]
    :param y: The vector containing the labels, shape [num_samples]
    :param bandwidth: The bandwidth of the kernel
    :param p: The p-norm. Typically, p=1 or p=2
    :param mc_type: The type of multiclass calibration: canonical, marginal or top_label
    :param device: The device type: 'cpu' or 'cuda'

    :return: An estimate of Lp calibration error
    """
    check_input(f, bandwidth, mc_type)
    if f.shape[1] == 1:
        return 2 * get_ratio_binary(f, y, bandwidth, p, device)
    else:
        if mc_type == 'canonical':
            return get_ratio_canonical(f, y, bandwidth, p, device)
        elif mc_type == 'marginal':
            return get_ratio_marginal_vect(f, y, bandwidth, p, device)
        elif mc_type == 'top_label':
            return get_ratio_toplabel(f, y, bandwidth, p, device)


def get_ratio_binary(f, y, bandwidth, p, device):
    assert f.shape[1] == 1

    log_kern = get_kernel(f, bandwidth, device)

    return get_kde_for_ece(f, y, log_kern, p)


def get_ratio_canonical(f, y, bandwidth, p, device):
    if f.shape[1] > 60:
        # Slower but more numerically stable implementation for a larger number of classes
        return get_ratio_canonical_log(f, y, bandwidth, p, device)

    log_kern = get_kernel(f, bandwidth, device)
    kern = torch.exp(log_kern)

    y_onehot = nn.functional.one_hot(y, num_classes=f.shape[1]).to(torch.float32)
    kern_y = torch.matmul(kern, y_onehot)
    den = torch.sum(kern, dim=1)
    # to avoid division by 0
    den = torch.clamp(den, min=1e-10)

    # Kernel estimate of E[y | f] at each score vector; the -inf diagonal of the kernel makes it leave-one-out
    ratio = kern_y / den.unsqueeze(-1)
    ratio = torch.sum(torch.abs(ratio - f)**p, dim=1)

    return torch.mean(ratio)


# Note for training: Make sure there are at least two examples for every class present in the batch, otherwise
# LogsumexpBackward returns nans.
def get_ratio_canonical_log(f, y, bandwidth, p, device='cpu'):
    log_kern = get_kernel(f, bandwidth, device)
    y_onehot = nn.functional.one_hot(y, num_classes=f.shape[1]).to(torch.float32)
    # log of the one-hot labels: 1 -> 0 and 0 -> -inf, so only same-class neighbours contribute to the numerator
    log_y = torch.log(y_onehot)
    log_den = torch.logsumexp(log_kern, dim=1)
    final_ratio = 0
    for k in range(f.shape[1]):
        log_kern_y = log_kern + (torch.ones([f.shape[0], 1], device=device) * log_y[:, k].unsqueeze(0))
        log_inner_ratio = torch.logsumexp(log_kern_y, dim=1) - log_den
        inner_ratio = torch.exp(log_inner_ratio)
        inner_diff = torch.abs(inner_ratio - f[:, k])**p
        final_ratio += inner_diff

    return torch.mean(final_ratio)


def get_ratio_marginal_vect(f, y, bandwidth, p, device):
    y_onehot = nn.functional.one_hot(y, num_classes=f.shape[1]).to(torch.float32)
    log_kern_vect = beta_kernel(f, f, bandwidth).squeeze()
    log_kern_diag = torch.diag(torch.finfo(torch.float).min * torch.ones(len(f))).to(device)
    # Multiclass case
    log_kern_diag_repeated = f.shape[1] * [log_kern_diag]
    log_kern_diag_repeated = torch.stack(log_kern_diag_repeated, dim=2)
    log_kern_vect = log_kern_vect + log_kern_diag_repeated

    return get_kde_for_ece_vect(f, y_onehot, log_kern_vect, p)


def get_ratio_toplabel(f, y, bandwidth, p, device):
    f_max, indices = torch.max(f, 1)
    f_max = f_max.unsqueeze(-1)
    y_max = (y == indices).to(torch.int)

    return get_ratio_binary(f_max, y_max, bandwidth, p, device)


def get_kde_for_ece_vect(f, y, log_kern, p):
    log_kern_y = log_kern * y
    # Trick: -inf instead of 0 in log space
    log_kern_y[log_kern_y == 0] = torch.finfo(torch.float).min

    log_num = torch.logsumexp(log_kern_y, dim=1)
    log_den = torch.logsumexp(log_kern, dim=1)

    log_ratio = log_num - log_den
    ratio = torch.exp(log_ratio)
    ratio = torch.abs(ratio - f)**p

    return torch.sum(torch.mean(ratio, dim=0))


def get_kde_for_ece(f, y, log_kern, p):
    f = f.squeeze()
    N = len(f)
    # Select the entries where y = 1
    idx = torch.where(y == 1)[0]
    if not idx.numel():
        return torch.sum((torch.abs(-f))**p) / N

    if idx.numel() == 1:
        # because of -inf in the vector
        log_kern = torch.cat((log_kern[:idx], log_kern[idx+1:]))
        f_one = f[idx]
        f = torch.cat((f[:idx], f[idx+1:]))

    log_kern_y = torch.index_select(log_kern, 1, idx)

    log_num = torch.logsumexp(log_kern_y, dim=1)
    log_den = torch.logsumexp(log_kern, dim=1)

    log_ratio = log_num - log_den
    ratio = torch.exp(log_ratio)
    ratio = torch.abs(ratio - f)**p

    if idx.numel() == 1:
        return (ratio.sum() + f_one ** p)/N

    return torch.mean(ratio)


def get_kernel(f, bandwidth, device):
    # if num_classes == 1
    if f.shape[1] == 1:
        log_kern = beta_kernel(f, f, bandwidth).squeeze()
    else:
        log_kern = dirichlet_kernel(f, bandwidth).squeeze()
    # Trick: -inf on the diagonal, so each point is excluded from its own (leave-one-out) density estimate
    return log_kern + torch.diag(torch.finfo(torch.float).min * torch.ones(len(f))).to(device)


def beta_kernel(z, zi, bandwidth=0.1):
    # Log-density of Beta kernels centered at the scores zi, evaluated at the scores z
    p = zi / bandwidth + 1
    q = (1-zi) / bandwidth + 1
    z = z.unsqueeze(-2)

    log_beta = torch.lgamma(p) + torch.lgamma(q) - torch.lgamma(p + q)
    log_num = (p-1) * torch.log(z) + (q-1) * torch.log(1-z)
    log_beta_pdf = log_num - log_beta

    return log_beta_pdf


def dirichlet_kernel(z, bandwidth=0.1):
    # Entry [i, j] is the log-density of a Dirichlet kernel centered at z[j], evaluated at z[i]
    alphas = z / bandwidth + 1

    log_beta = (torch.sum((torch.lgamma(alphas)), dim=1) - torch.lgamma(torch.sum(alphas, dim=1)))
    log_num = torch.matmul(torch.log(z), (alphas-1).T)
    log_dir_pdf = log_num - log_beta

    return log_dir_pdf


def check_input(f, bandwidth, mc_type):
    assert not isnan(f)
    assert len(f.shape) == 2
    assert bandwidth > 0
    assert torch.min(f) >= 0
    assert torch.max(f) <= 1


def isnan(a):
    return torch.any(torch.isnan(a))

--------------------------------------------------------------------------------