├── .gitignore
├── README.md
├── cka.py
├── example.py
├── img
│   ├── cka_minibatch.png
│   ├── hsic.png
│   └── r18_cka_new.png
└── test.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

.idea/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CKA_minibatch_pytorch

PyTorch implementation of Centered Kernel Alignment (CKA) and its minibatch version.

![cka_minibatch](img/cka_minibatch.png)

![hsic](img/hsic.png)

## Example

A simple example comparing the layer outputs of ResNet-18 can be found in `example.py`.
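
If you only want to sanity-check the API without an ImageNet copy, the same kind of grid comparison can be run on random features. The snippet below is a minimal sketch; the feature dimensions, batch size, and number of minibatches are arbitrary choices for illustration:

```python
import torch

from cka import CKA_Minibatch_Grid

torch.manual_seed(0)
cka_logger = CKA_Minibatch_Grid(3, 4)  # compare 3 feature maps against 4 feature maps
for _ in range(10):  # feed 10 minibatches of 64 samples each
    feats_a = [torch.randn(64, d) for d in (128, 256, 512)]
    feats_b = [torch.randn(64, d) for d in (64, 128, 256, 512)]
    cka_logger.update(feats_a, feats_b)
print(cka_logger.compute())  # a [3, 4] grid of CKA scores
```

The figure below shows the result of running the full `example.py` on an ImageNet validation folder.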

![r18_cka](img/r18_cka_new.png)

## Get Started

**Linear CKA**

```python
from cka import cka_score

# compute CKA from raw features
score = cka_score(x1, x2)

# compute CKA from precomputed Gram matrices
g1 = x1 @ x1.transpose(0, 1)
g2 = x2 @ x2.transpose(0, 1)
score = cka_score(g1, g2, gram=True)
```

**Minibatch CKA**

```python
from cka import CKA_Minibatch

cka_logger = CKA_Minibatch()
for data_batch in data_loader:
    x1 = model_1(data_batch)
    x2 = model_2(data_batch)
    cka_logger.update(x1, x2)
# aggregating over multiple minibatches gives a more accurate estimate
cka_score = cka_logger.compute()
```

**Grid of Minibatch CKA**

```python
from cka import CKA_Minibatch_Grid

cka_logger = CKA_Minibatch_Grid(d1, d2)
for data_batch in data_loader:
    feature_list_1 = model_1(data_batch)  # len(feature_list_1) = d1
    feature_list_2 = model_2(data_batch)  # len(feature_list_2) = d2
    cka_logger.update(feature_list_1, feature_list_2)
cka_score_grid = cka_logger.compute()  # [d1, d2]
```

## References

[1] Kornblith, Simon, et al. "Similarity of neural network representations revisited." *International Conference on Machine Learning*. PMLR, 2019.

[2] Nguyen, Thao, Maithra Raghu, and Simon Kornblith. "Do wide and deep networks learn the same things? Uncovering how neural network representations vary with width and depth." *arXiv preprint arXiv:2010.15327* (2020).
--------------------------------------------------------------------------------
/cka.py:
--------------------------------------------------------------------------------
from typing import Sequence
from torch import Tensor
import torch
from torch.nn import Module


def centering(k: Tensor, inplace: bool = True) -> Tensor:
    # center the rows and columns of a symmetric Gram matrix;
    # equivalent to the explicit H @ k @ H version commented out below
    if not inplace:
        k = torch.clone(k)
    means = k.mean(dim=0)
    means -= means.mean() / 2
    k -= means.view(-1, 1)
    k -= means.view(1, -1)
    return k


# def centering(k: Tensor) -> Tensor:
#     m = k.shape[0]
#     h = torch.eye(m) - torch.ones(m, m) / m
#     return torch.matmul(h, torch.matmul(k, h))


def linear_hsic(k: Tensor, l: Tensor, unbiased: bool = True) -> Tensor:
    assert k.shape[0] == l.shape[0], 'Input must have the same size'
    m = k.shape[0]
    if unbiased:
        # unbiased HSIC estimator; note that it zeroes the diagonals of k and l in place
        k.fill_diagonal_(0)
        l.fill_diagonal_(0)
        kl = torch.matmul(k, l)
        score = torch.trace(kl) + k.sum() * l.sum() / ((m - 1) * (m - 2)) - 2 * kl.sum() / (m - 2)
        return score / (m * (m - 3))
    else:
        # biased HSIC estimator based on double-centered Gram matrices
        k, l = centering(k), centering(l)
        return (k * l).sum() / ((m - 1) ** 2)


def cka_score(x1: Tensor, x2: Tensor, gram: bool = False) -> Tensor:
    assert x1.shape[0] == x2.shape[0], 'Input must have the same batch size'
    if not gram:
        x1 = torch.matmul(x1, x1.transpose(0, 1))
        x2 = torch.matmul(x2, x2.transpose(0, 1))
    cross_score = linear_hsic(x1, x2)
    self_score1 = linear_hsic(x1, x1)
    self_score2 = linear_hsic(x2, x2)
    return cross_score / torch.sqrt(self_score1 * self_score2)


class CKA_Minibatch(Module):
    """
    Minibatch Centered Kernel Alignment
    Reference: https://arxiv.org/pdf/2010.15327
    """

    def __init__(self):
        super().__init__()
        self.total = 0
        self.cross_hsic, self.self_hsic1, self.self_hsic2 = [], [], []

    def reset(self):
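        """Clear all accumulated minibatch HSIC statistics so the logger can be reused."""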
        self.total = 0
        self.cross_hsic, self.self_hsic1, self.self_hsic2 = [], [], []

    def update(self, x1: Tensor, x2: Tensor, gram: bool = False) -> None:
        """
        gram: if True, x1 and x2 are treated as precomputed Gram matrices
        """
        assert x1.shape[0] == x2.shape[0], 'Input must have the same batch size'
        self.total += 1
        if not gram:
            x1 = torch.matmul(x1, x1.transpose(0, 1))
            x2 = torch.matmul(x2, x2.transpose(0, 1))
        self.cross_hsic.append(linear_hsic(x1, x2))
        self.self_hsic1.append(linear_hsic(x1, x1))
        self.self_hsic2.append(linear_hsic(x2, x2))

    def compute(self) -> Tensor:
        assert self.total > 0, 'Please call method update(x1, x2) first!'
        cross_score = sum(self.cross_hsic) / self.total
        self_score1 = sum(self.self_hsic1) / self.total
        self_score2 = sum(self.self_hsic2) / self.total
        return cross_score / torch.sqrt(self_score1 * self_score2)


class CKA_Minibatch_Grid(Module):
    '''
    Compute CKA for a 2D grid of features
    '''

    def __init__(self, dim1: int, dim2: int):
        super().__init__()
        self.cka_loggers = [[CKA_Minibatch() for _ in range(dim2)] for _ in range(dim1)]
        self.dim1 = dim1
        self.dim2 = dim2

    def reset(self):
        for i in range(self.dim1):
            for j in range(self.dim2):
                self.cka_loggers[i][j].reset()

    def update(self, x1: Sequence[Tensor], x2: Sequence[Tensor], gram: bool = False) -> None:
        assert len(x1) == self.dim1, 'Grid dim0 mismatch'
        assert len(x2) == self.dim2, 'Grid dim1 mismatch'
        if not gram:
            x1 = [torch.matmul(x, x.transpose(0, 1)) for x in x1]
            x2 = [torch.matmul(x, x.transpose(0, 1)) for x in x2]
        for i in range(self.dim1):
            for j in range(self.dim2):
                self.cka_loggers[i][j].update(x1[i], x2[j], gram=True)

    def compute(self) -> Tensor:
        result = torch.zeros(self.dim1, self.dim2)
        for i in range(self.dim1):
            for j in range(self.dim2):
                result[i, j] = self.cka_loggers[i][j].compute()
        return result
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import torch.utils.data
import torchvision.datasets as datasets
from torchvision.models import resnet18
import torchvision.transforms as transforms

from cka import CKA_Minibatch_Grid


def forward_features(model, x):
    # run a ResNet-18 forward pass and return the flattened outputs of the four residual stages
    _b = x.shape[0]
    x = model.conv1(x)
    x = model.bn1(x)
    x = model.relu(x)
    x = model.maxpool(x)

    x = model.layer1(x)
    x1 = x
    x = model.layer2(x)
    x2 = x
    x = model.layer3(x)
    x3 = x
    x = model.layer4(x)
    x4 = x
    return x1.view(_b, -1), x2.view(_b, -1), x3.view(_b, -1), x4.view(_b, -1)


def main():
    DATA_ROOT = '/home/data/ImageNet/val'
    batch_size = 128
    dataset_size = 1280
    num_sweep = 10
    num_features = 4

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    torch.random.manual_seed(0)
    perms = [torch.randperm(dataset_size) for _ in range(num_sweep)]
    dataset = datasets.ImageFolder(DATA_ROOT, transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]))

    model = resnet18(pretrained=True)
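    # inference-only setup: move the pretrained model to the GPU and put batchnorm layers in eval mode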
    model.cuda()
    model.eval()
    cka_logger = CKA_Minibatch_Grid(num_features, num_features)
    with torch.no_grad():
        for sweep in range(num_sweep):
            dataset_sweep = torch.utils.data.Subset(dataset, perms[sweep])
            data_loader = torch.utils.data.DataLoader(
                dataset_sweep,
                batch_size=batch_size, shuffle=False,
                num_workers=4, pin_memory=True)
            for images, targets in tqdm(data_loader):
                images = images.cuda()
                features = forward_features(model, images)
                cka_logger.update(features, features)
                torch.cuda.empty_cache()

    cka_matrix = cka_logger.compute()

    plt.title('Pretrained Resnet18 Layer CKA')
    plt.xticks([0, 1, 2, 3], ['Layer 1', 'Layer 2', 'Layer 3', 'Layer 4'])
    plt.yticks([0, 1, 2, 3], ['Layer 1', 'Layer 2', 'Layer 3', 'Layer 4'])
    plt.imshow(cka_matrix.numpy(), origin='lower', cmap='magma')
    plt.clim(0, 1)
    plt.colorbar()
    plt.savefig('r18_cka_new.png')

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/img/cka_minibatch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yueyang2000/CKA_minibatch_pytorch/8e27dd4b38f6081aa6a1409e8a4dd89ea6f06bc0/img/cka_minibatch.png
--------------------------------------------------------------------------------
/img/hsic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yueyang2000/CKA_minibatch_pytorch/8e27dd4b38f6081aa6a1409e8a4dd89ea6f06bc0/img/hsic.png
--------------------------------------------------------------------------------
/img/r18_cka_new.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yueyang2000/CKA_minibatch_pytorch/8e27dd4b38f6081aa6a1409e8a4dd89ea6f06bc0/img/r18_cka_new.png
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import torch
from cka import cka_score, CKA_Minibatch

bs = 64
split = 20
n_epoch = 20

# x2 is a noisy linear transformation of x1
x1 = torch.rand(bs * split, 128)
x2 = 2 * x1 + torch.randn(x1.shape) * 1

cka_minibatch = CKA_Minibatch()
for epoch in range(n_epoch):
    perm = torch.randperm(bs * split)
    x1 = x1[perm]
    x2 = x2[perm]
    for i in range(split):
        b1 = x1[i * bs:(i + 1) * bs]
        b2 = x2[i * bs:(i + 1) * bs]
        cka_minibatch.update(b1, b2)
    # the logger is never reset, so each epoch's estimate aggregates every minibatch seen so far
    score = cka_minibatch.compute()
    print(f'Minibatch CKA at epoch {epoch}: {score.item()}')

print('Full CKA:', cka_score(x1, x2).item())
--------------------------------------------------------------------------------