├── CKA.ipynb ├── CKA.py └── README.md /CKA.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# CKA - Toy Example" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import torch\n", 17 | "import numpy as np\n", 18 | "from CKA import CKA, CudaCKA" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "# Numpy" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 6, 31 | "metadata": {}, 32 | "outputs": [ 33 | { 34 | "name": "stdout", 35 | "output_type": "stream", 36 | "text": [ 37 | "Linear CKA, between X and Y: 0.010065926085323442\n", 38 | "Linear CKA, between X and X: 1.0\n", 39 | "RBF Kernel CKA, between X and Y: 0.01682517317497278\n", 40 | "RBF Kernel CKA, between X and X: 1.0\n", 41 | "CPU times: user 1h 15min 15s, sys: 26min 15s, total: 1h 41min 30s\n", 42 | "Wall time: 3min 8s\n" 43 | ] 44 | } 45 | ], 46 | "source": [ 47 | "%%time\n", 48 | "\n", 49 | "np_cka = CKA()\n", 50 | "\n", 51 | "X = np.random.randn(10000, 100)\n", 52 | "Y = np.random.randn(10000, 100)\n", 53 | "\n", 54 | "print('Linear CKA, between X and Y: {}'.format(np_cka.linear_CKA(X, Y)))\n", 55 | "print('Linear CKA, between X and X: {}'.format(np_cka.linear_CKA(X, X)))\n", 56 | "\n", 57 | "print('RBF Kernel CKA, between X and Y: {}'.format(np_cka.kernel_CKA(X, Y)))\n", 58 | "print('RBF Kernel CKA, between X and X: {}'.format(np_cka.kernel_CKA(X, X)))" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "# PyTorch with CUDA" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 4, 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "name": "stdout", 75 | "output_type": "stream", 76 | "text": [ 77 | "Linear CKA, between X and Y: 0.009900251403450966\n", 78 | "Linear CKA, between X and X: 0.9999998807907104\n", 79 | "RBF Kernel CKA, between X and Y: 0.016650838777422905\n", 80 | "RBF Kernel CKA, between X and X: 0.9999999403953552\n", 81 | "CPU times: user 10.9 s, sys: 4.39 s, total: 15.3 s\n", 82 | "Wall time: 15.3 s\n" 83 | ] 84 | } 85 | ], 86 | "source": [ 87 | "%%time\n", 88 | "\n", 89 | "device = torch.device('cuda:5')\n", 90 | "cuda_cka = CudaCKA(device)\n", 91 | "\n", 92 | "X = torch.randn(10000, 100, device=device)\n", 93 | "Y = torch.randn(10000, 100, device=device)\n", 94 | "\n", 95 | "print('Linear CKA, between X and Y: {}'.format(cuda_cka.linear_CKA(X, Y)))\n", 96 | "print('Linear CKA, between X and X: {}'.format(cuda_cka.linear_CKA(X, X)))\n", 97 | "\n", 98 | "print('RBF Kernel CKA, between X and Y: {}'.format(cuda_cka.kernel_CKA(X, Y)))\n", 99 | "print('RBF Kernel CKA, between X and X: {}'.format(cuda_cka.kernel_CKA(X, X)))" 100 | ] 101 | } 102 | ], 103 | "metadata": { 104 | "kernelspec": { 105 | "display_name": "Python 3", 106 | "language": "python", 107 | "name": "python3" 108 | }, 109 | "language_info": { 110 | "codemirror_mode": { 111 | "name": "ipython", 112 | "version": 3 113 | }, 114 | "file_extension": ".py", 115 | "mimetype": "text/x-python", 116 | "name": "python", 117 | "nbconvert_exporter": "python", 118 | "pygments_lexer": "ipython3", 119 | "version": "3.6.9" 120 | } 121 | }, 122 | "nbformat": 4, 123 | "nbformat_minor": 4 124 | } 125 | -------------------------------------------------------------------------------- /CKA.py: -------------------------------------------------------------------------------- 1 | # inspired by 2 | # https://github.com/yuanli2333/CKA-Centered-Kernel-Alignment/blob/master/CKA.py 3 | 4 | import math 5 | import torch 6 | import numpy as np 7 | 8 | class CKA(object): 9 | def __init__(self): 10 | pass 11 | 12 | def centering(self, K): 13 | n = K.shape[0] 14 | unit = np.ones([n, n]) 15 | I = np.eye(n) 16 | H = I - unit / n 17 | return np.dot(np.dot(H, K), H) 18 | 19 | def rbf(self, X, sigma=None): 20 | GX = np.dot(X, X.T) 21 | KX = np.diag(GX) - GX + (np.diag(GX) - GX).T 22 | if sigma is None: 23 | mdist = np.median(KX[KX != 0]) 24 | sigma = math.sqrt(mdist) 25 | KX *= - 0.5 / (sigma * sigma) 26 | KX = np.exp(KX) 27 | return KX 28 | 29 | def kernel_HSIC(self, X, Y, sigma): 30 | return np.sum(self.centering(self.rbf(X, sigma)) * self.centering(self.rbf(Y, sigma))) 31 | 32 | def linear_HSIC(self, X, Y): 33 | L_X = X @ X.T 34 | L_Y = Y @ Y.T 35 | return np.sum(self.centering(L_X) * self.centering(L_Y)) 36 | 37 | def linear_CKA(self, X, Y): 38 | hsic = self.linear_HSIC(X, Y) 39 | var1 = np.sqrt(self.linear_HSIC(X, X)) 40 | var2 = np.sqrt(self.linear_HSIC(Y, Y)) 41 | 42 | return hsic / (var1 * var2) 43 | 44 | def kernel_CKA(self, X, Y, sigma=None): 45 | hsic = self.kernel_HSIC(X, Y, sigma) 46 | var1 = np.sqrt(self.kernel_HSIC(X, X, sigma)) 47 | var2 = np.sqrt(self.kernel_HSIC(Y, Y, sigma)) 48 | 49 | return hsic / (var1 * var2) 50 | 51 | 52 | class CudaCKA(object): 53 | def __init__(self, device): 54 | self.device = device 55 | 56 | def centering(self, K): 57 | n = K.shape[0] 58 | unit = torch.ones([n, n], device=self.device) 59 | I = torch.eye(n, device=self.device) 60 | H = I - unit / n 61 | return torch.matmul(torch.matmul(H, K), H) 62 | 63 | def rbf(self, X, sigma=None): 64 | GX = torch.matmul(X, X.T) 65 | KX = torch.diag(GX) - GX + (torch.diag(GX) - GX).T 66 | if sigma is None: 67 | mdist = torch.median(KX[KX != 0]) 68 | sigma = math.sqrt(mdist) 69 | KX *= - 0.5 / (sigma * sigma) 70 | KX = torch.exp(KX) 71 | return KX 72 | 73 | def kernel_HSIC(self, X, Y, sigma): 74 | return torch.sum(self.centering(self.rbf(X, sigma)) * self.centering(self.rbf(Y, sigma))) 75 | 76 | def linear_HSIC(self, X, Y): 77 | L_X = torch.matmul(X, X.T) 78 | L_Y = torch.matmul(Y, Y.T) 79 | return torch.sum(self.centering(L_X) * self.centering(L_Y)) 80 | 81 | def linear_CKA(self, X, Y): 82 | hsic = self.linear_HSIC(X, Y) 83 | var1 = torch.sqrt(self.linear_HSIC(X, X)) 84 | var2 = torch.sqrt(self.linear_HSIC(Y, Y)) 85 | 86 | return hsic / (var1 * var2) 87 | 88 | def kernel_CKA(self, X, Y, sigma=None): 89 | hsic = self.kernel_HSIC(X, Y, sigma) 90 | var1 = torch.sqrt(self.kernel_HSIC(X, X, sigma)) 91 | var2 = torch.sqrt(self.kernel_HSIC(Y, Y, sigma)) 92 | return hsic / (var1 * var2) 93 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CKA-similarity 2 | An PyTorch Implementation of CKA-similarity with CUDA support. 3 | Inspired by: 4 | https://github.com/yuanli2333/CKA-Centered-Kernel-Alignment 5 | 6 | The Centered Kernel Alignment (CKA) method is from paper: 7 | 8 | *Similarity of Neural Network Representations Revisited*. [https://arxiv.org/abs/1905.00414] 9 | 10 | ## Running time for the example notebook 11 | Numpy: 3min 8s 12 | 13 | PyTorch: **15.3s** 14 | 15 | For large matrices, use PyTorch with CUDA to accelerate. 16 | 17 | # Dependencies 18 | ``` 19 | python3 20 | numpy 21 | gzip 22 | torch 23 | ``` 24 | 25 | Tested environment: *PyTorch v1.7.0*, *Python 3.6.9* 26 | 27 | --------------------------------------------------------------------------------