├── .gitignore
├── README.md
├── cka.py
├── example.py
├── img
│   ├── cka_minibatch.png
│   ├── hsic.png
│   └── r18_cka_new.png
└── test.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | .idea/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CKA_minibatch_pytorch
2 |
3 | PyTorch implementation of Centered Kernel Alignment (CKA) and its minibatch version.
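In brief (see [1] and [2] for the full derivations): given Gram matrices $K = X X^\top$ and $L = Y Y^\top$ computed over the same $m$ examples, linear CKA is

$$\mathrm{CKA}(K, L) = \frac{\mathrm{HSIC}(K, L)}{\sqrt{\mathrm{HSIC}(K, K)\,\mathrm{HSIC}(L, L)}}.$$

The minibatch version accumulates an unbiased HSIC estimator over $k$ minibatches and takes the ratio of the averages, which is what `CKA_Minibatch.compute()` returns:

$$\mathrm{CKA}_{\text{minibatch}} = \frac{\tfrac{1}{k}\sum_{i=1}^{k} \mathrm{HSIC}_1(K_i, L_i)}{\sqrt{\tfrac{1}{k}\sum_{i=1}^{k} \mathrm{HSIC}_1(K_i, K_i)}\;\sqrt{\tfrac{1}{k}\sum_{i=1}^{k} \mathrm{HSIC}_1(L_i, L_i)}}$$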
4 |
5 |
6 |
7 |
8 |
9 | ## Example
10 |
11 | A simple example comparing the layer outputs of a pretrained ResNet-18 on the ImageNet validation set can be found in `example.py`.
12 |
13 |
14 |
15 |
16 |
17 | ## Get Started
18 |
19 | **Linear CKA**
20 |
21 | ```python
22 | from cka import cka_score
23 |
24 | # compute CKA from features
25 | score = cka_score(x1, x2)
26 |
27 | # compute CKA from Gram matrices
28 | g1 = x1 @ x1.transpose(0, 1)
29 | g2 = x2 @ x2.transpose(0, 1)
30 | score = cka_score(g1, g2, gram=True)
31 | ```
32 |
33 | **Minibatch CKA**
34 |
35 | ```python
36 | from cka import CKA_Minibatch
37 |
38 | cka_logger = CKA_Minibatch()
39 | for data_batch in data_loader:
40 | x1 = model_1(data_batch)
41 | x2 = model_2(data_batch)
42 | cka_logger.update(x1, x2)
43 | # aggregate HSIC statistics over minibatches to get a more accurate estimate
44 | cka_score = cka_logger.compute()
45 | ```
46 |
47 | **Grid of Minibatch CKA**
48 |
49 | ```python
50 | from cka import CKA_Minibatch_Grid
51 |
52 | cka_logger = CKA_Minibatch_Grid(d1, d2)
53 | for data_batch in data_loader:
54 | feature_list_1 = model_1(data_batch) # len(feature_list_1) = d1
55 | feature_list_2 = model_2(data_batch) # len(feature_list_2) = d2
56 | cka_logger.update(feature_list_1, feature_list_2)
57 | cka_score_grid = cka_logger.compute() # [d1, d2]
58 | ```
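Entry `[i, j]` of the returned grid is the minibatch CKA between the `i`-th feature in the first list and the `j`-th feature in the second; `example.py` uses this to compare the four residual stages of a pretrained ResNet-18 with themselves.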
59 |
60 |
61 |
62 | ## Reference
63 |
64 | [1] Kornblith, Simon, et al. "Similarity of neural network representations revisited." *International Conference on Machine Learning*. PMLR, 2019.
65 |
66 | [2] Nguyen, Thao, Maithra Raghu, and Simon Kornblith. "Do wide and deep networks learn the same things? Uncovering how neural network representations vary with width and depth." *arXiv preprint arXiv:2010.15327* (2020).
67 |
--------------------------------------------------------------------------------
/cka.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 | from torch import Tensor
3 | import torch
4 | from torch.nn import Module
5 |
6 |
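# For a symmetric Gram matrix k, this returns the double-centered matrix H @ k @ H with
# H = I - (1/m) * ones(m, m) (see the commented-out reference implementation below),
# computed from the column means without materialising H. With inplace=True the input
# tensor itself is modified.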
7 | def centering(k: Tensor, inplace: bool = True) -> Tensor:
8 | if not inplace:
9 | k = torch.clone(k)
10 | means = k.mean(dim=0)
11 | means -= means.mean() / 2
12 | k -= means.view(-1, 1)
13 | k -= means.view(1, -1)
14 | return k
15 |
16 |
17 | # def centering(k: Tensor) -> Tensor:
18 | # m = k.shape[0]
19 | # h = torch.eye(m) - torch.ones(m, m) / m
20 | # return torch.matmul(h, torch.matmul(k, h))
21 |
22 |
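# HSIC between two Gram matrices. With unbiased=True this is the unbiased estimator
# HSIC_1 used for minibatch CKA (Nguyen et al., 2020):
#   HSIC_1(K, L) = [tr(K'L') + sum(K') * sum(L') / ((m-1)(m-2)) - 2 * sum(K'L') / (m-2)] / (m(m-3))
# where K', L' are K and L with their diagonals set to zero. Note that the diagonals of
# the input tensors are zeroed in place.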
23 | def linear_hsic(k: Tensor, l: Tensor, unbiased: bool = True) -> Tensor:
24 | assert k.shape[0] == l.shape[0], 'Input must have the same size'
25 | m = k.shape[0]
26 | if unbiased:
27 | k.fill_diagonal_(0)
28 | l.fill_diagonal_(0)
29 | kl = torch.matmul(k, l)
30 | score = torch.trace(kl) + k.sum() * l.sum() / ((m - 1) * (m - 2)) - 2 * kl.sum() / (m - 2)
31 | return score / (m * (m - 3))
32 | else:
33 | k, l = centering(k), centering(l)
34 | return (k * l).sum() / ((m - 1) ** 2)
35 |
36 |
37 | def cka_score(x1: Tensor, x2: Tensor, gram: bool = False) -> Tensor:
38 | assert x1.shape[0] == x2.shape[0], 'Input must have the same batch size'
39 | if not gram:
40 | x1 = torch.matmul(x1, x1.transpose(0, 1))
41 | x2 = torch.matmul(x2, x2.transpose(0, 1))
42 | cross_score = linear_hsic(x1, x2)
43 | self_score1 = linear_hsic(x1, x1)
44 | self_score2 = linear_hsic(x2, x2)
45 | return cross_score / torch.sqrt(self_score1 * self_score2)
46 |
47 |
48 | class CKA_Minibatch(Module):
49 | """
50 | Minibatch Centered Kernel Alignment
51 | Reference: https://arxiv.org/pdf/2010.15327
52 | """
53 |
54 | def __init__(self):
55 | super().__init__()
56 | self.total = 0
57 | self.cross_hsic, self.self_hsic1, self.self_hsic2 = [], [], []
58 |
59 | def reset(self):
60 | self.total = 0
61 | self.cross_hsic, self.self_hsic1, self.self_hsic2 = [], [], []
62 |
63 | def update(self, x1: Tensor, x2: Tensor, gram: bool = False) -> None:
64 | """
65 | gram: if true, the method takes gram matrix as input
66 | """
67 | assert x1.shape[0] == x2.shape[0], 'Input must have the same batch size'
68 | self.total += 1
69 | if not gram:
70 | x1 = torch.matmul(x1, x1.transpose(0, 1))
71 | x2 = torch.matmul(x2, x2.transpose(0, 1))
72 | self.cross_hsic.append(linear_hsic(x1, x2))
73 | self.self_hsic1.append(linear_hsic(x1, x1))
74 | self.self_hsic2.append(linear_hsic(x2, x2))
75 |
76 | def compute(self) -> Tensor:
77 | assert self.total > 0, 'Please call method update(x1, x2) first!'
78 | cross_score = sum(self.cross_hsic) / self.total
79 | self_score1 = sum(self.self_hsic1) / self.total
80 | self_score2 = sum(self.self_hsic2) / self.total
81 | return cross_score / torch.sqrt(self_score1 * self_score2)
82 |
83 |
84 | class CKA_Minibatch_Grid(Module):
85 | '''
86 | Compute CKA for a 2D grid of features
87 | '''
88 |
89 | def __init__(self, dim1: int, dim2: int):
90 | super().__init__()
91 | self.cka_loggers = [[CKA_Minibatch() for _ in range(dim2)] for _ in range(dim1)]
92 | self.dim1 = dim1
93 | self.dim2 = dim2
94 |
95 | def reset(self):
96 | for i in range(self.dim1):
97 | for j in range(self.dim2):
98 | self.cka_loggers[i][j].reset()
99 |
100 | def update(self, x1: Sequence[Tensor], x2: Sequence[Tensor], gram: bool = False) -> None:
101 | assert len(x1) == self.dim1, 'Grid dim0 mismatch'
102 | assert len(x2) == self.dim2, 'Grid dim1 mismatch'
103 | if not gram:
104 | x1 = [torch.matmul(x, x.transpose(0, 1)) for x in x1]
105 | x2 = [torch.matmul(x, x.transpose(0, 1)) for x in x2]
106 | for i in range(self.dim1):
107 | for j in range(self.dim2):
108 | self.cka_loggers[i][j].update(x1[i], x2[j], gram=True)
109 |
110 | def compute(self) -> Tensor:
111 | result = torch.zeros(self.dim1, self.dim2)
112 | for i in range(self.dim1):
113 | for j in range(self.dim2):
114 | result[i, j] = self.cka_loggers[i][j].compute()
115 | return result
116 |
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import matplotlib.pyplot as plt
3 | import torch
4 | import torch.utils.data
5 | import torchvision.datasets as datasets
6 | from torchvision.models import resnet18
7 | import torchvision.transforms as transforms
8 |
9 | from cka import CKA_Minibatch_Grid
10 |
11 |
12 | def forward_features(model, x):
13 | _b = x.shape[0]
14 | x = model.conv1(x)
15 | x = model.bn1(x)
16 | x = model.relu(x)
17 | x = model.maxpool(x)
18 |
19 | x = model.layer1(x)
20 | x1 = x
21 | x = model.layer2(x)
22 | x2 = x
23 | x = model.layer3(x)
24 | x3 = x
25 | x = model.layer4(x)
26 | x4 = x
27 | return x1.view(_b, -1), x2.view(_b, -1), x3.view(_b, -1), x4.view(_b, -1)
28 |
29 |
30 |
31 | def main():
32 | DATA_ROOT = '/home/data/ImageNet/val'
33 | batch_size = 128
34 | dataset_size = 1280
35 | num_sweep = 10
36 | num_features = 4
37 |
38 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
39 | std=[0.229, 0.224, 0.225])
40 | torch.random.manual_seed(0)
41 | perms = [torch.randperm(dataset_size) for _ in range(num_sweep)]
42 | dataset = datasets.ImageFolder(DATA_ROOT, transforms.Compose([
43 | transforms.Resize(256),
44 | transforms.CenterCrop(224),
45 | transforms.ToTensor(),
46 | normalize,
47 | ]))
48 |
49 | model = resnet18(pretrained=True)
50 | model.cuda()
51 | model.eval()
52 | cka_logger = CKA_Minibatch_Grid(num_features, num_features)
53 | with torch.no_grad():
54 | for sweep in range(num_sweep):
55 | dataset_sweep = torch.utils.data.Subset(dataset, perms[sweep])
56 | data_loader = torch.utils.data.DataLoader(
57 | dataset_sweep,
58 | batch_size=batch_size, shuffle=False,
59 | num_workers=4, pin_memory=True)
60 | for images, targets in tqdm(data_loader):
61 | images = images.cuda()
62 | features = forward_features(model, images)
63 | cka_logger.update(features, features)
64 | torch.cuda.empty_cache()
65 |
66 | cka_matrix = cka_logger.compute()
67 |
68 | plt.title('Pretrained Resnet18 Layer CKA')
69 | plt.xticks([0, 1, 2, 3], ['Layer 1', 'Layer 2', 'Layer 3', 'Layer 4'])
70 | plt.yticks([0, 1, 2, 3], ['Layer 1', 'Layer 2', 'Layer 3', 'Layer 4'])
71 | plt.imshow(cka_matrix.numpy(), origin='lower', cmap='magma')
72 | plt.clim(0, 1)
73 | plt.colorbar()
74 | plt.savefig('r18_cka_new.png')
75 |
76 | if __name__ == '__main__':
77 | main()
--------------------------------------------------------------------------------
/img/cka_minibatch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yueyang2000/CKA_minibatch_pytorch/8e27dd4b38f6081aa6a1409e8a4dd89ea6f06bc0/img/cka_minibatch.png
--------------------------------------------------------------------------------
/img/hsic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yueyang2000/CKA_minibatch_pytorch/8e27dd4b38f6081aa6a1409e8a4dd89ea6f06bc0/img/hsic.png
--------------------------------------------------------------------------------
/img/r18_cka_new.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yueyang2000/CKA_minibatch_pytorch/8e27dd4b38f6081aa6a1409e8a4dd89ea6f06bc0/img/r18_cka_new.png
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from cka import cka_score, CKA_Minibatch
3 |
4 | bs = 64
5 | split = 20
6 | n_epoch = 20
7 |
8 | x1 = torch.rand(bs * split, 128)
9 | x2 = 2 * x1 + torch.randn(x1.shape) * 1
10 |
11 | cka_minibatch = CKA_Minibatch()
12 | for epoch in range(n_epoch):
13 | perm = torch.randperm(bs*split)
14 | x1 = x1[perm]
15 | x2 = x2[perm]
16 | for i in range(split):
17 | b1 = x1[i*bs:(i+1)*bs]
18 | b2 = x2[i*bs:(i+1)*bs]
19 | cka_minibatch.update(b1, b2)
20 | score = cka_minibatch.compute()  # running estimate over all minibatches seen so far (no reset between epochs)
21 | print(f'Minibatch CKA at epoch {epoch}: {score.item()}')
22 |
23 | print('Full CKA:', cka_score(x1, x2).item())
24 |
25 |
26 |
27 |
--------------------------------------------------------------------------------