├── figures
│   ├── formulae_1.png
│   ├── formulae_2.png
│   ├── formulae_3.png
│   ├── output_15_1.png
│   ├── output_15_11.png
│   ├── output_15_13.png
│   ├── output_15_15.png
│   ├── output_15_17.png
│   ├── output_15_3.png
│   ├── output_15_5.png
│   ├── output_15_7.png
│   ├── output_15_9.png
│   └── output_7_0.png
├── kmeans.py
├── utils.py
├── neon.py
└── README.md

/figures/formulae_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobkauffmann/neon_demo/HEAD/figures/formulae_1.png
--------------------------------------------------------------------------------
/figures/formulae_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobkauffmann/neon_demo/HEAD/figures/formulae_2.png
--------------------------------------------------------------------------------
/figures/formulae_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobkauffmann/neon_demo/HEAD/figures/formulae_3.png
--------------------------------------------------------------------------------
/figures/output_15_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobkauffmann/neon_demo/HEAD/figures/output_15_1.png
--------------------------------------------------------------------------------
/figures/output_15_11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobkauffmann/neon_demo/HEAD/figures/output_15_11.png
--------------------------------------------------------------------------------
/figures/output_15_13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobkauffmann/neon_demo/HEAD/figures/output_15_13.png
--------------------------------------------------------------------------------
/figures/output_15_15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobkauffmann/neon_demo/HEAD/figures/output_15_15.png
--------------------------------------------------------------------------------
/figures/output_15_17.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobkauffmann/neon_demo/HEAD/figures/output_15_17.png
--------------------------------------------------------------------------------
/figures/output_15_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobkauffmann/neon_demo/HEAD/figures/output_15_3.png
--------------------------------------------------------------------------------
/figures/output_15_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobkauffmann/neon_demo/HEAD/figures/output_15_5.png
--------------------------------------------------------------------------------
/figures/output_15_7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobkauffmann/neon_demo/HEAD/figures/output_15_7.png
--------------------------------------------------------------------------------
/figures/output_15_9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobkauffmann/neon_demo/HEAD/figures/output_15_9.png
--------------------------------------------------------------------------------
/figures/output_7_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobkauffmann/neon_demo/HEAD/figures/output_7_0.png
--------------------------------------------------------------------------------
/kmeans.py:
--------------------------------------------------------------------------------
import torch
from sklearn.cluster import KMeans as KMeans_sk

class KMeans(torch.nn.Module):
    def __init__(self, n_clusters, random_state=None):
        super().__init__()
        self.n_clusters = n_clusters
        self.kmeans = KMeans_sk(n_clusters=self.n_clusters, random_state=random_state)

    def fit(self, X):
        if torch.is_tensor(X):
            X = X.numpy()
        self.kmeans.fit(X)
        self.centroids = torch.from_numpy(self.kmeans.cluster_centers_).double()

    def forward(self, x):
        # logit of the assigned cluster: squared distance to the nearest
        # competitor minus squared distance to the assigned centroid
        distances = torch.cdist(x, self.centroids)**2
        top_val = torch.topk(distances, k=2, dim=1, largest=False).values
        fx = torch.diff(top_val).squeeze()
        return fx

    def decision(self, X):
        # hard cluster assignment: index of the nearest centroid
        distances = torch.cdist(X, self.centroids)**2
        return distances.argmin(-1)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import torch as tr
from sklearn.decomposition import PCA
from sklearn.datasets import load_wine
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from matplotlib import cm
from scipy.optimize import linear_sum_assignment
from sklearn.metrics.cluster import contingency_matrix

cmap = cm.get_cmap('tab20', 20)

def plot_explanation(x, R, feature_names=None, vlim=None):
    if feature_names is None:
        feature_names = list(range(len(x)))
    else:
        feature_names = [fn.replace('_', ' ') for fn in feature_names]
    plt.figure(figsize=(4, 2))

    # left panel: the (min-max scaled) data point
    plt.subplot(121)
    plt.title('data point')
    plt.gca().set_axisbelow(True), plt.grid(linestyle='dashed')
    negative = x.clamp(max=0)
    positive = x.clamp(min=0)
    plt.barh(range(len(feature_names)), negative, color='c')
    plt.barh(range(len(feature_names)), positive, color='m')
    plt.yticks(range(len(feature_names)), feature_names)
    plt.gca().invert_yaxis()
    plt.xlim(0, 1.1)

    # right panel: feature relevances (negative in blue, positive in red)
    plt.subplot(122)
    plt.title('feature relevance')
    plt.gca().set_axisbelow(True), plt.grid(linestyle='dashed')
    negative = R.clamp(max=0)
    positive = R.clamp(min=0)
    if vlim is None:
        vlim = max(abs(negative).max(), positive.max()) + .3
    plt.barh(range(len(feature_names)), negative, color='b')
    plt.barh(range(len(feature_names)), positive, color='r')
    plt.xlim(-vlim, vlim)
    plt.yticks(range(len(feature_names)), [''] * len(feature_names))
    plt.gca().invert_yaxis()
--------------------------------------------------------------------------------
/neon.py:
--------------------------------------------------------------------------------
import torch
import numpy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# soft min-pooling layer (converges to the hard min as s grows)
def smin(X, s, dim=-1):
    return -(1/s)*torch.logsumexp(-s*X, dim=dim) + (1/s)*numpy.log(X.shape[dim])

# soft max-pooling layer (converges to the hard max as s grows)
def smax(X, s, dim=-1):
    return (1/s)*torch.logsumexp(s*X, dim=dim) - (1/s)*numpy.log(X.shape[dim])

class NeuralizedKMeans(torch.nn.Module):
    def __init__(self, kmeans):
        super().__init__()
        self.n_clusters = kmeans.n_clusters
        self.kmeans = kmeans
        K, D = kmeans.centroids.shape
        # linear layer h_ck(x) = w_ck^T x + b_ck, one unit per pair (c, k != c)
        self.W = torch.empty(K, K-1, D, dtype=torch.double)
        self.b = torch.empty(K, K-1, dtype=torch.double)
        for c in range(K):
            for kk in range(K-1):
                k = kk if kk < c else kk + 1
                self.W[c, kk] = 2*(kmeans.centroids[c] - kmeans.centroids[k])
                self.b[c, kk] = (torch.norm(kmeans.centroids[k])**2 -
                                 torch.norm(kmeans.centroids[c])**2)

    def h(self, X):
        z = torch.einsum('ckd,nd->nck', self.W, X) + self.b
        return z

    def forward(self, X, c=None):
        h = self.h(X)
        out = h.min(-1).values
        if c is None:
            return out.max(-1).values
        else:
            return out[:, c]

# shift z away from zero to avoid division by zero during redistribution
def inc(z, eps=1e-9):
    return z + eps*(2*(z >= 0) - 1)

# simple default for the stiffness parameter: beta = 1 / mean logit
def beta_heuristic(model, X):
    fc = model(X)
    return 1/fc.mean()

def neon(model, X, beta):
    R = torch.zeros_like(X)
    if not torch.is_tensor(beta):
        beta = torch.tensor(beta)
    for i in range(X.shape[0]):
        x = X[[i]]
        ### forward
        h = model.h(x)
        out = h.min(-1).values
        c = out.argmax()
        ### backward
        # min-pooling layer: redistribute with a softmin over competitors
        pk = torch.nn.functional.softmin(beta*h[:, c], dim=-1)
        Rk = out[:, c] * pk
        # linear layer: midpoint rule w_ck * (x - (mu_c + mu_k)/2)
        knc = [k for k in range(model.n_clusters) if k != c]
        Z = model.W[c]*(x - .5*(model.kmeans.centroids[[c]] + model.kmeans.centroids[knc]))
        Z = Z / inc(Z.sum(-1, keepdims=True))
        R[i] = (Z * Rk.view(-1, 1)).sum(0)
    return R
--------------------------------------------------------------------------------
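Note: the soft pooling layers `smin` and `smax` defined at the top of `neon.py` converge to the hard min and max as the sharpness `s` grows. A minimal numerical check, not part of the repository, assuming `neon.py` is importable:

```python
import torch
from neon import smin, smax

X = torch.randn(5, 10, dtype=torch.double)
# for large s, the soft pools agree with the hard ones up to ~log(n)/s
assert torch.allclose(smin(X, s=1e4), X.min(-1).values, atol=1e-3)
assert torch.allclose(smax(X, s=1e4), X.max(-1).values, atol=1e-3)
```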
/README.md:
--------------------------------------------------------------------------------
# Explaining K-Means with NEON

This demo illustrates the NEON approach to explaining K-Means clustering predictions. The method is fully described in

> Kauffmann, J., Esders, M., Ruff, L., Montavon, G., Samek, W., & Müller, K.-R.
> [From clustering to cluster explanations via neural networks](https://arxiv.org/abs/1906.07633)
> arXiv:1906.07633v2, 2021


```python
from kmeans import KMeans
from neon import NeuralizedKMeans, neon
from utils import *
```

## Loading the data
First, we load the dataset and normalize the data to a reasonable range.


```python
wine = load_wine()
X, ytrue = wine['data'], tr.tensor(wine['target'])
feature_names = wine['feature_names']

X = MinMaxScaler().fit_transform(X)
X = tr.from_numpy(X)
```

## Training the K-Means model
Then, we train a K-Means model.


```python
# random state for reproducibility
m = KMeans(n_clusters=3, random_state=77)
m.fit(X)
```

The cluster assignments and true labels can be visualized in a 2D PCA embedding.


```python
# find the best match between clusters and classes (Hungarian algorithm)
y = m.decision(X)
C = contingency_matrix(ytrue, y)
_, best_match = linear_sum_assignment(-C.T)
y = tr.tensor([best_match[i] for i in y])

# compute a PCA embedding for visualization
pca = PCA(n_components=2).fit(X)
Z = pca.transform(X)

plt.title('wine clustering --- facecolor: clusters, edgecolor: classes')
plt.scatter(Z[:,0], Z[:,1], facecolor=cmap(2*y.numpy() + 1), edgecolor=cmap(2*ytrue.numpy()), alpha=.5)
plt.gca().set_aspect('equal')
plt.xticks([]), plt.yticks([])
plt.show()
```


![png](figures/output_7_0.png)


## Neuralizing the model

The decision function for a cluster can be recovered as
![formula 1](figures/formulae_1.png)
This function contrasts the distance to cluster *c* with the distance to the nearest competitor.
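For convenience, a plausible plain-text transcription of the formula in the image above (the image remains authoritative), reconstructed from `forward` in `kmeans.py`:

$$f_c(x) \;=\; \min_{k \neq c}\left(\lVert x - \mu_k \rVert^2 - \lVert x - \mu_c \rVert^2\right),$$

where $\mu_1, \dots, \mu_K$ denote the learned centroids; $f_c(x)$ is positive precisely when $x$ is assigned to cluster $c$.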

```python
logits = m(X)
```

As shown in the original paper, the logit can be transformed into a neural network with identical outputs. The layers can be described as
![formula 2](figures/formulae_2.png)
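A plausible transcription of the layer equations, matching the construction in `NeuralizedKMeans` (`neon.py`):

$$h_{ck}(x) = w_{ck}^\top x + b_{ck}, \qquad w_{ck} = 2(\mu_c - \mu_k), \qquad b_{ck} = \lVert\mu_k\rVert^2 - \lVert\mu_c\rVert^2,$$

followed by min-pooling over competitors, $f_c(x) = \min_{k\neq c} h_{ck}(x)$, and max-pooling over clusters for the overall prediction. Expanding the squares gives $h_{ck}(x) = \lVert x-\mu_k\rVert^2 - \lVert x-\mu_c\rVert^2$, so the network reproduces the K-Means logit exactly.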

```python
m = NeuralizedKMeans(m)

# check that the neuralized model reproduces the original logits exactly
assert tr.isclose(logits, m(X)).all(), "Predictions are not equal!"
```

## Explaining the cluster assignment

The neuralized model can be explained with Layer-wise Relevance Propagation (LRP). Here, we use the midpoint rule in the first layer, as described in the paper.

![formula 3](figures/formulae_3.png)
with *R<sub>i</sub>* the relevance of input variable *x<sub>i</sub>*. The hyperparameter 𝛽 controls the contribution of the other competitors to the explanation. Its main purpose is to disambiguate the explanation when more than one competitor is close to the min in layer 2.
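Again, a plausible transcription of the propagation rule shown above, reconstructed from the `neon` function in `neon.py`. Relevance first flows through the min-pooling layer via a softmin over the competitors,

$$R_k \;=\; f_c(x)\,\frac{\exp(-\beta\, h_{ck}(x))}{\sum_{k' \neq c} \exp(-\beta\, h_{ck'}(x))},$$

and then through the linear layer with the midpoint rule,

$$R_i \;=\; \sum_{k \neq c} \frac{w_{ck,i}\left(x_i - \tfrac{1}{2}(\mu_c + \mu_k)_i\right)}{h_{ck}(x)}\, R_k,$$

where the denominator equals $h_{ck}(x)$ because $w_{ck}^\top\bigl(x - \tfrac{1}{2}(\mu_c+\mu_k)\bigr) = h_{ck}(x)$; the code guards it against zero with the small shift `inc`.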

```python
R = neon(m, X, beta=1)
```

The explanations can be visualized in the same way as the inputs, e.g. as a bar plot.
Below, we show an explanation for each misclassified point.

Note that points near the decision boundary (with probability close to 0.5) have low ambiguity regarding the nearest competitors, hence 𝛽 has little effect.
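Instead of fixing 𝛽 by hand, one could also use the `beta_heuristic` helper shipped in `neon.py`, which sets 𝛽 inversely proportional to the mean logit. A minimal sketch of this alternative (the demo itself uses the fixed `beta=1` above):

```python
from neon import beta_heuristic

beta = beta_heuristic(m, X)  # equals 1 / logits.mean()
R = neon(m, X, beta)
```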

```python
I = tr.nonzero(y != ytrue)[:,0]
for i in I:
    logit = logits[i]
    prob = 1 / (1 + tr.exp(-logit))

    print('data point %d'%i)
    print(' cluster assignment: %d (probability %.2f)'%(y[i], prob))
    print(' true class : %d'%ytrue[i])
    print(' sum(R) / logit : %.4f / %.4f'%(sum(R[i]), logit))
    plot_explanation(X[i], R[i], feature_names, vlim=abs(R[I]).max()*1.1)
    plt.show()
    print('-'*80)
```

    data point 60
     cluster assignment: 2 (probability 0.52)
     true class : 1
     sum(R) / logit : 0.0772 / 0.0772

![png](figures/output_15_1.png)

    --------------------------------------------------------------------------------
    data point 61
     cluster assignment: 2 (probability 0.56)
     true class : 1
     sum(R) / logit : 0.2447 / 0.2447

![png](figures/output_15_3.png)

    --------------------------------------------------------------------------------
    data point 68
     cluster assignment: 2 (probability 0.52)
     true class : 1
     sum(R) / logit : 0.0800 / 0.0800

![png](figures/output_15_5.png)

    --------------------------------------------------------------------------------
    data point 70
     cluster assignment: 2 (probability 0.52)
     true class : 1
     sum(R) / logit : 0.0866 / 0.0866

![png](figures/output_15_7.png)

    --------------------------------------------------------------------------------
    data point 73
     cluster assignment: 0 (probability 0.57)
     true class : 1
     sum(R) / logit : 0.2848 / 0.2848

![png](figures/output_15_9.png)

    --------------------------------------------------------------------------------
    data point 83
     cluster assignment: 2 (probability 0.61)
     true class : 1
     sum(R) / logit : 0.4378 / 0.4378

![png](figures/output_15_11.png)

    --------------------------------------------------------------------------------
    data point 92
     cluster assignment: 2 (probability 0.50)
     true class : 1
     sum(R) / logit : 0.0089 / 0.0089

![png](figures/output_15_13.png)

    --------------------------------------------------------------------------------
    data point 95
     cluster assignment: 0 (probability 0.53)
     true class : 1
     sum(R) / logit : 0.1348 / 0.1348

![png](figures/output_15_15.png)

    --------------------------------------------------------------------------------
    data point 118
     cluster assignment: 2 (probability 0.55)
     true class : 1
     sum(R) / logit : 0.2151 / 0.2151

![png](figures/output_15_17.png)

    --------------------------------------------------------------------------------

--------------------------------------------------------------------------------