├── .gitignore ├── LICENSE ├── README.md ├── cka.py ├── example.ipynb ├── hook_manager.py └── metrics.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | # Custom 4 | 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # Environments 90 | .env 91 | .venv 92 | env/ 93 | venv/ 94 | ENV/ 95 | env.bak/ 96 | venv.bak/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | 
.mypy_cache/ 110 | ### JetBrains template 111 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 112 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 113 | 114 | # User-specific stuff 115 | .idea/**/workspace.xml 116 | .idea/**/tasks.xml 117 | .idea/**/usage.statistics.xml 118 | .idea/**/dictionaries 119 | .idea/**/shelf 120 | 121 | # Sensitive or high-churn files 122 | .idea/**/dataSources/ 123 | .idea/**/dataSources.ids 124 | .idea/**/dataSources.local.xml 125 | .idea/**/sqlDataSources.xml 126 | .idea/**/dynamic.xml 127 | .idea/**/uiDesigner.xml 128 | .idea/**/dbnavigator.xml 129 | 130 | # Gradle 131 | .idea/**/gradle.xml 132 | .idea/**/libraries 133 | 134 | # Gradle and Maven with auto-import 135 | # When using Gradle or Maven with auto-import, you should exclude module files, 136 | # since they will be recreated, and may cause churn. Uncomment if using 137 | # auto-import. 138 | # .idea/modules.xml 139 | # .idea/*.iml 140 | # .idea/modules 141 | 142 | # CMake 143 | cmake-build-*/ 144 | 145 | # Mongo Explorer plugin 146 | .idea/**/mongoSettings.xml 147 | 148 | # File-based project format 149 | *.iws 150 | 151 | # IntelliJ 152 | out/ 153 | 154 | # mpeltonen/sbt-idea plugin 155 | .idea_modules/ 156 | 157 | # JIRA plugin 158 | atlassian-ide-plugin.xml 159 | 160 | # Cursive Clojure plugin 161 | .idea/replstate.xml 162 | 163 | # Crashlytics plugin (for Android Studio and IntelliJ) 164 | com_crashlytics_export_strings.xml 165 | crashlytics.properties 166 | crashlytics-build.properties 167 | fabric.properties 168 | 169 | # Editor-based Rest Client 170 | .idea/httpRequests 171 | .idea/ 172 | 173 | # Custom 174 | experiments/ 175 | wandb/ 176 | **/.DS_Store 177 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 
3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022 Dongwan Kim 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Centered Kernel Alignment (CKA) - PyTorch Implementation 2 | 3 | **A PyTorch implementation of Centered Kernel Alignment (CKA) with GPU support.** 4 | 5 | This code was used for the CKA analysis in our CVPR 2023 paper, "[On the Stability-Plasticity Dilemma of Class-Incremental Learning](https://openaccess.thecvf.com/content/CVPR2023/papers/Kim_On_the_Stability-Plasticity_Dilemma_of_Class-Incremental_Learning_CVPR_2023_paper.pdf)". 6 | 7 | ## Usage 8 | ```python 9 | model1 = ... # Some model, moved to GPU 10 | model2 = ... # Another model, moved to GPU 11 | dataloader = ... # Your dataloader 12 | 13 | calculator = CKACalculator(model1, model2, dataloader) 14 | cka_matrix = calculator.calculate_cka_matrix() 15 | ``` 16 | 17 | Rather than caching intermediate feature representations, this code computes CKA on-the-fly (simultaneously with the model forward pass) by using the mini-batch CKA, as described in the [paper by Nguyen et
al.](https://openreview.net/pdf?id=KJNcAkY8tY4) 18 | By leveraging GPU superiority, **this implementation runs much faster than any Numpy implementation.** 19 | 20 | ## Setup 21 | I haven't added a `requirements.txt` since the exact version of each package is not that important :man_shrugging: 22 | 23 | #### Required packages to use the class/functions: 24 | * python3.7+ 25 | * torch (any relatively recent version should be O.K.) 26 | * torchvision 27 | * tqdm 28 | * torchmetrics 29 | 30 | #### To run the `example.ipynb`: 31 | * jupyter 32 | * matplotlib 33 | * numpy 34 | 35 | ## Example notebook 36 | Try out the example notebook in `example.ipynb`. 37 | 38 | ## Other 39 | * If you found this repo helpful, please give it a :star: 40 | * If you find any bugs/improvements, feel free to create a new issue. 41 | * This code is mostly tested on ResNets 42 | 43 | ### TODO (when I feel like it) 44 | * Ditch hooks; change to `torch.fx` implementation 45 | -------------------------------------------------------------------------------- /cka.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tool to compute Centered Kernel Alignment (CKA) in PyTorch w/ GPU (single or multi). 3 | 4 | Repo: https://github.com/numpee/CKA.pytorch 5 | Author: Dongwan Kim (Github: Numpee) 6 | Year: 2022 7 | """ 8 | 9 | from __future__ import annotations 10 | 11 | from typing import Tuple, Optional, Callable, Type, Union, TYPE_CHECKING, List 12 | 13 | import torch 14 | import torch.nn as nn 15 | from tqdm.autonotebook import tqdm 16 | 17 | from hook_manager import HookManager, _HOOK_LAYER_TYPES 18 | from metrics import AccumTensor 19 | 20 | if TYPE_CHECKING: 21 | from torch.utils.data import DataLoader 22 | 23 | 24 | class CKACalculator: 25 | def __init__(self, model1: nn.Module, model2: nn.Module, dataloader: DataLoader, 26 | hook_fn: Optional[Union[str, Callable]] = None, 27 | hook_layer_types: Tuple[Type[nn.Module], ...] 
= _HOOK_LAYER_TYPES, num_epochs: int = 10, 28 | group_size: int = 512, epsilon: float = 1e-4, is_main_process: bool = True) -> None: 29 | """ 30 | Class to extract intermediate features and calculate CKA Matrix. 31 | :param model1: model to evaluate. __call__ function should be implemented if NOT instance of `nn.Module`. 32 | :param model2: second model to evaluate. __call__ function should be implemented if NOT instance of `nn.Module`. 33 | :param dataloader: Torch DataLoader for dataloading. Assumes first return value contains input images. 34 | :param hook_fn: Optional - Hook function or hook name string for the HookManager. Options: [flatten, avgpool]. Default: flatten 35 | :param hook_layer_types: Types of layers (modules) to add hooks to. 36 | :param num_epochs: Number of epochs for cka_batch. Default: 10 37 | :param group_size: group_size for GPU acceleration. Default: 512 38 | :param epsilon: Small multiplicative value for HSIC. Default: 1e-4 39 | :param is_main_process: is current instance main process. 
Default: True 40 | """ 41 | self.model1 = model1 42 | self.model2 = model2 43 | self.dataloader = dataloader 44 | self.num_epochs = num_epochs 45 | self.group_size = group_size 46 | self.epsilon = epsilon 47 | self.is_main_process = is_main_process 48 | 49 | self.model1.eval() 50 | self.model2.eval() 51 | self.hook_manager1 = HookManager(self.model1, hook_fn, hook_layer_types, calculate_gram=True) 52 | self.hook_manager2 = HookManager(self.model2, hook_fn, hook_layer_types, calculate_gram=True) 53 | self.module_names_X = None 54 | self.module_names_Y = None 55 | self.num_layers_X = None 56 | self.num_layers_Y = None 57 | self.num_elements = None 58 | 59 | # Metrics to track 60 | self.cka_matrix = None 61 | self.hsic_matrix = None 62 | self.self_hsic_x = None 63 | self.self_hsic_y = None 64 | 65 | @torch.no_grad() 66 | def calculate_cka_matrix(self) -> torch.Tensor: 67 | curr_hsic_matrix = None 68 | curr_self_hsic_x = None 69 | curr_self_hsic_y = None 70 | for epoch in range(self.num_epochs): 71 | loader = tqdm(self.dataloader, desc=f"Epoch {epoch}", disable=not self.is_main_process) 72 | for it, (imgs, *_) in enumerate(loader): 73 | imgs = imgs.cuda(non_blocking=True) 74 | self.model1(imgs) 75 | self.model2(imgs) 76 | all_layer_X, all_layer_Y = self.extract_layer_list_from_hook_manager() 77 | 78 | # Initialize values on first loop 79 | if self.num_layers_X is None: 80 | curr_hsic_matrix, curr_self_hsic_x, curr_self_hsic_y = self._init_values(all_layer_X, all_layer_Y) 81 | 82 | # Get self HSIC values --> HSIC(K, K), HSIC(L, L) 83 | self._calculate_self_hsic(all_layer_X, all_layer_Y, curr_self_hsic_x, curr_self_hsic_y) 84 | 85 | # Get cross HSIC values --> HSIC(K, L) 86 | self._calculate_cross_hsic(all_layer_X, all_layer_Y, curr_hsic_matrix) 87 | 88 | self.hook_manager1.clear_features() 89 | self.hook_manager2.clear_features() 90 | curr_hsic_matrix.fill_(0) 91 | curr_self_hsic_x.fill_(0) 92 | curr_self_hsic_y.fill_(0) 93 | 94 | # Update values across GPUs 95 | 
hsic_matrix = self.hsic_matrix.compute() 96 | hsic_x = self.self_hsic_x.compute() 97 | hsic_y = self.self_hsic_y.compute() 98 | self.cka_matrix = hsic_matrix.reshape(self.num_layers_Y, self.num_layers_X) / torch.sqrt(hsic_x * hsic_y) 99 | # print(self.cka_matrix.diagonal()) 100 | # self.cka_matrix = self.cka_matrix.flip(0) 101 | return self.cka_matrix 102 | 103 | def extract_layer_list_from_hook_manager(self) -> Tuple[List, List]: 104 | all_layer_X, all_layer_Y = self.hook_manager1.get_features(), self.hook_manager2.get_features() 105 | return all_layer_X, all_layer_Y 106 | 107 | def hsic1(self, K: torch.Tensor, L: torch.Tensor) -> torch.Tensor: 108 | ''' 109 | Batched version of HSIC. 110 | :param K: Size = (B, N, N) where N is the number of examples and B is the group/batch size 111 | :param L: Size = (B, N, N) where N is the number of examples and B is the group/batch size 112 | :return: HSIC tensor, Size = (B) 113 | ''' 114 | assert K.size() == L.size() 115 | assert K.dim() == 3 116 | K = K.clone() 117 | L = L.clone() 118 | n = K.size(1) 119 | 120 | # K, L --> K~, L~ by setting diagonals to zero 121 | K.diagonal(dim1=-1, dim2=-2).fill_(0) 122 | L.diagonal(dim1=-1, dim2=-2).fill_(0) 123 | 124 | KL = torch.bmm(K, L) 125 | trace_KL = KL.diagonal(dim1=-1, dim2=-2).sum(-1).unsqueeze(-1).unsqueeze(-1) 126 | middle_term = K.sum((-1, -2), keepdim=True) * L.sum((-1, -2), keepdim=True) 127 | middle_term /= (n - 1) * (n - 2) 128 | right_term = KL.sum((-1, -2), keepdim=True) 129 | right_term *= 2 / (n - 2) 130 | main_term = trace_KL + middle_term - right_term 131 | hsic = main_term / (n ** 2 - 3 * n) 132 | return hsic.squeeze(-1).squeeze(-1) 133 | 134 | def reset(self) -> None: 135 | # Set values to none, clear feature and hooks 136 | self.cka_matrix = None 137 | self.hsic_matrix = None 138 | self.self_hsic_x = None 139 | self.self_hsic_y = None 140 | self.hook_manager1.clear_all() 141 | self.hook_manager2.clear_all() 142 | 143 | def _init_values(self, all_layer_X, 
all_layer_Y): 144 | self.num_layers_X = len(all_layer_X) 145 | self.num_layers_Y = len(all_layer_Y) 146 | self.module_names_X = self.hook_manager1.get_module_names() 147 | self.module_names_Y = self.hook_manager2.get_module_names() 148 | self.num_elements = self.num_layers_Y * self.num_layers_X 149 | curr_hsic_matrix = torch.zeros(self.num_elements).cuda() 150 | curr_self_hsic_x = torch.zeros(1, self.num_layers_X).cuda() 151 | curr_self_hsic_y = torch.zeros(self.num_layers_Y, 1).cuda() 152 | self.hsic_matrix = AccumTensor(torch.zeros_like(curr_hsic_matrix)).cuda() 153 | self.self_hsic_x = AccumTensor(torch.zeros_like(curr_self_hsic_x)).cuda() 154 | self.self_hsic_y = AccumTensor(torch.zeros_like(curr_self_hsic_y)).cuda() 155 | return curr_hsic_matrix, curr_self_hsic_x, curr_self_hsic_y 156 | 157 | def _calculate_self_hsic(self, all_layer_X, all_layer_Y, curr_self_hsic_x, curr_self_hsic_y): 158 | for start_idx in range(0, self.num_layers_X, self.group_size): 159 | end_idx = min(start_idx + self.group_size, self.num_layers_X) 160 | K = torch.stack([all_layer_X[i] for i in range(start_idx, end_idx)], dim=0) 161 | curr_self_hsic_x[0, start_idx:end_idx] += self.hsic1(K, K) * self.epsilon 162 | for start_idx in range(0, self.num_layers_Y, self.group_size): 163 | end_idx = min(start_idx + self.group_size, self.num_layers_Y) 164 | L = torch.stack([all_layer_Y[i] for i in range(start_idx, end_idx)], dim=0) 165 | curr_self_hsic_y[start_idx:end_idx, 0] += self.hsic1(L, L) * self.epsilon 166 | 167 | self.self_hsic_x.update(curr_self_hsic_x) 168 | self.self_hsic_y.update(curr_self_hsic_y) 169 | 170 | def _calculate_cross_hsic(self, all_layer_X, all_layer_Y, curr_hsic_matrix): 171 | for start_idx in range(0, self.num_elements, self.group_size): 172 | end_idx = min(start_idx + self.group_size, self.num_elements) 173 | K = torch.stack([all_layer_X[i % self.num_layers_X] for i in range(start_idx, end_idx)], dim=0) 174 | L = torch.stack([all_layer_Y[j // self.num_layers_X] for j in 
range(start_idx, end_idx)], dim=0) 175 | curr_hsic_matrix[start_idx:end_idx] += self.hsic1(K, L) * self.epsilon 176 | self.hsic_matrix.update(curr_hsic_matrix) 177 | 178 | 179 | def gram(x: torch.Tensor) -> torch.Tensor: 180 | return x.matmul(x.t()) 181 | -------------------------------------------------------------------------------- /example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "a12e75d5", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "from torch.utils.data import DataLoader\n", 12 | "from torchvision.datasets import CIFAR10\n", 13 | "from torchvision.transforms import Compose, ToTensor, Normalize\n", 14 | "from torchvision.models import resnet18\n", 15 | "from tqdm.autonotebook import tqdm\n", 16 | "from copy import deepcopy" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "id": "1bb2e232", 22 | "metadata": {}, 23 | "source": [ 24 | "## Setup DataLoader and Models \n", 25 | "\n", 26 | "An important detail is that although we are using the Validation set for `CIFAR10`, we **shuffle** and drop the last batch. This is to ensure that 1) the batches of each epoch are mixed, and 2) each iteration has the same batch size." 
27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 2, 32 | "id": "47f97572", 33 | "metadata": {}, 34 | "outputs": [ 35 | { 36 | "name": "stdout", 37 | "output_type": "stream", 38 | "text": [ 39 | "Files already downloaded and verified\n" 40 | ] 41 | } 42 | ], 43 | "source": [ 44 | "transforms = Compose([ToTensor(), \n", 45 | " Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])\n", 46 | "\n", 47 | "dataset = CIFAR10(root='./', train=False, download=True, transform=transforms)\n", 48 | "dataloader = DataLoader(dataset, batch_size=256, shuffle=True, drop_last=True, num_workers=4, pin_memory=True)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 3, 54 | "id": "48229bb6", 55 | "metadata": {}, 56 | "outputs": [ 57 | { 58 | "name": "stdout", 59 | "output_type": "stream", 60 | "text": [ 61 | "Dummy models created\n" 62 | ] 63 | } 64 | ], 65 | "source": [ 66 | "model1 = resnet18(pretrained=True).cuda()\n", 67 | "model1.eval()\n", 68 | "model2 = deepcopy(model1)\n", 69 | "model2.eval()\n", 70 | "print('Dummy models created')" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "id": "7f59e737", 76 | "metadata": {}, 77 | "source": [ 78 | "## Compute CKA " 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 4, 84 | "id": "5fed41fb", 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "from cka import CKACalculator" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "id": "926f66b6", 94 | "metadata": {}, 95 | "source": [ 96 | "### Basic Usage \n", 97 | "\n", 98 | "Initializing the `CKACalculator` object will add forward hooks to both `model1` and `model2`. \n", 99 | "The default modules that are hooked are: `Bottleneck`, `BasicBlock`, `Conv2d`, `AdaptiveAvgPool2d`, `MaxPool2d`, and all instances of `BatchNorm`. 
\n", 100 | "Note that `Bottleneck` and `BasicBlock` are from the `torchvision` implementation, and will not add hooks to any custom implementations of `Bottleneck/BasicBlock`.\n", 101 | "\n", 102 | "For ResNet18, a total of 50 hooks are added.\n", 103 | "\n", 104 | "By default, the intermediate features are flattened with `flatten_hook_fn` and 10 epochs are run." 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 5, 110 | "id": "d2212ba3", 111 | "metadata": {}, 112 | "outputs": [ 113 | { 114 | "name": "stdout", 115 | "output_type": "stream", 116 | "text": [ 117 | "No hook function provided. Using flatten_hook_fn.\n", 118 | "50 Hooks registered. Total hooks: 50\n", 119 | "No hook function provided. Using flatten_hook_fn.\n", 120 | "50 Hooks registered. Total hooks: 50\n" 121 | ] 122 | } 123 | ], 124 | "source": [ 125 | "calculator = CKACalculator(model1=model1, model2=model2, dataloader=dataloader)" 126 | ] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "id": "6282ebb3", 131 | "metadata": {}, 132 | "source": [ 133 | "Now we can calculate the CKA matrix " 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 6, 139 | "id": "e9525f51", 140 | "metadata": { 141 | "scrolled": false 142 | }, 143 | "outputs": [ 144 | { 145 | "data": { 146 | "application/vnd.jupyter.widget-view+json": { 147 | "model_id": "f2799b606e9248fca8f378a5f05633f8", 148 | "version_major": 2, 149 | "version_minor": 0 150 | }, 151 | "text/plain": [ 152 | "Epoch 0: 0%| | 0/39 [00:00" 328 | ] 329 | }, 330 | "execution_count": 8, 331 | "metadata": {}, 332 | "output_type": "execute_result" 333 | }, 334 | { 335 | "data": { 336 | "image/png": "\n", 337 | "text/plain": [ 338 | "
" 339 | ] 340 | }, 341 | "metadata": { 342 | "needs_background": "light" 343 | }, 344 | "output_type": "display_data" 345 | } 346 | ], 347 | "source": [ 348 | "plt.imshow(cka_output.cpu().numpy(), cmap='inferno')" 349 | ] 350 | }, 351 | { 352 | "cell_type": "markdown", 353 | "id": "e8e6cab2", 354 | "metadata": {}, 355 | "source": [ 356 | "### Advanced Usage \n", 357 | "\n", 358 | "We can customize other parameters of the `CKACalculator`. \n", 359 | "Most importantly, we can select which modules to hook. \n", 360 | "\n", 361 | "Before instantiating a new instance of `CKACalculator` on, make sure to first call the `reset` method. \n", 362 | "This clears all hooks registered in the models." 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": 9, 368 | "id": "fe4a1711", 369 | "metadata": {}, 370 | "outputs": [ 371 | { 372 | "name": "stdout", 373 | "output_type": "stream", 374 | "text": [ 375 | "50 handles removed.\n", 376 | "50 handles removed.\n" 377 | ] 378 | } 379 | ], 380 | "source": [ 381 | "# Reset calculator to clear hooks\n", 382 | "calculator.reset()\n", 383 | "torch.cuda.empty_cache()" 384 | ] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "execution_count": 10, 389 | "id": "ff5c78e6", 390 | "metadata": {}, 391 | "outputs": [], 392 | "source": [ 393 | "import torch.nn as nn" 394 | ] 395 | }, 396 | { 397 | "cell_type": "markdown", 398 | "id": "f168ff41", 399 | "metadata": {}, 400 | "source": [ 401 | "Let's consider outputs of `Conv2d` and `BatchNorm2d` only. This will create 40 hooks.\n", 402 | "\n", 403 | "For custom layers, add the custom modules in the same manner as shown below." 
404 | ] 405 | }, 406 | { 407 | "cell_type": "code", 408 | "execution_count": 11, 409 | "id": "9e768045", 410 | "metadata": {}, 411 | "outputs": [], 412 | "source": [ 413 | "layers = (nn.Conv2d, nn.BatchNorm2d)" 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": 12, 419 | "id": "fd75a7d6", 420 | "metadata": {}, 421 | "outputs": [ 422 | { 423 | "name": "stdout", 424 | "output_type": "stream", 425 | "text": [ 426 | "No hook function provided. Using flatten_hook_fn.\n", 427 | "40 Hooks registered. Total hooks: 40\n", 428 | "No hook function provided. Using flatten_hook_fn.\n", 429 | "40 Hooks registered. Total hooks: 40\n" 430 | ] 431 | } 432 | ], 433 | "source": [ 434 | "calculator = CKACalculator(model1=model1, model2=model2, dataloader=dataloader, hook_layer_types=layers)" 435 | ] 436 | }, 437 | { 438 | "cell_type": "code", 439 | "execution_count": 13, 440 | "id": "d27ee2fb", 441 | "metadata": {}, 442 | "outputs": [ 443 | { 444 | "data": { 445 | "application/vnd.jupyter.widget-view+json": { 446 | "model_id": "b70e3f403590421491fc7f829567f9e3", 447 | "version_major": 2, 448 | "version_minor": 0 449 | }, 450 | "text/plain": [ 451 | "Epoch 0: 0%| | 0/39 [00:00" 606 | ] 607 | }, 608 | "execution_count": 14, 609 | "metadata": {}, 610 | "output_type": "execute_result" 611 | }, 612 | { 613 | "data": { 614 | "image/png": "\n", 615 | "text/plain": [ 616 | "
" 617 | ] 618 | }, 619 | "metadata": { 620 | "needs_background": "light" 621 | }, 622 | "output_type": "display_data" 623 | } 624 | ], 625 | "source": [ 626 | "plt.imshow(cka_output.cpu().numpy(), cmap='inferno')" 627 | ] 628 | }, 629 | { 630 | "cell_type": "markdown", 631 | "id": "afaa68d2", 632 | "metadata": {}, 633 | "source": [ 634 | "#### Extract module names " 635 | ] 636 | }, 637 | { 638 | "cell_type": "code", 639 | "execution_count": 15, 640 | "id": "b13097c6", 641 | "metadata": {}, 642 | "outputs": [ 643 | { 644 | "name": "stdout", 645 | "output_type": "stream", 646 | "text": [ 647 | "Layer 0: \tconv1\n", 648 | "Layer 1: \tbn1\n", 649 | "Layer 2: \tlayer1.0.conv1\n", 650 | "Layer 3: \tlayer1.0.bn1\n", 651 | "Layer 4: \tlayer1.0.conv2\n", 652 | "Layer 5: \tlayer1.0.bn2\n", 653 | "Layer 6: \tlayer1.1.conv1\n", 654 | "Layer 7: \tlayer1.1.bn1\n", 655 | "Layer 8: \tlayer1.1.conv2\n", 656 | "Layer 9: \tlayer1.1.bn2\n", 657 | "Layer 10: \tlayer2.0.conv1\n", 658 | "Layer 11: \tlayer2.0.bn1\n", 659 | "Layer 12: \tlayer2.0.conv2\n", 660 | "Layer 13: \tlayer2.0.bn2\n", 661 | "Layer 14: \tlayer2.0.downsample.0\n", 662 | "Layer 15: \tlayer2.0.downsample.1\n", 663 | "Layer 16: \tlayer2.1.conv1\n", 664 | "Layer 17: \tlayer2.1.bn1\n", 665 | "Layer 18: \tlayer2.1.conv2\n", 666 | "Layer 19: \tlayer2.1.bn2\n", 667 | "Layer 20: \tlayer3.0.conv1\n", 668 | "Layer 21: \tlayer3.0.bn1\n", 669 | "Layer 22: \tlayer3.0.conv2\n", 670 | "Layer 23: \tlayer3.0.bn2\n", 671 | "Layer 24: \tlayer3.0.downsample.0\n", 672 | "Layer 25: \tlayer3.0.downsample.1\n", 673 | "Layer 26: \tlayer3.1.conv1\n", 674 | "Layer 27: \tlayer3.1.bn1\n", 675 | "Layer 28: \tlayer3.1.conv2\n", 676 | "Layer 29: \tlayer3.1.bn2\n", 677 | "Layer 30: \tlayer4.0.conv1\n", 678 | "Layer 31: \tlayer4.0.bn1\n", 679 | "Layer 32: \tlayer4.0.conv2\n", 680 | "Layer 33: \tlayer4.0.bn2\n", 681 | "Layer 34: \tlayer4.0.downsample.0\n", 682 | "Layer 35: \tlayer4.0.downsample.1\n", 683 | "Layer 36: \tlayer4.1.conv1\n", 684 | "Layer 
37: \tlayer4.1.bn1\n", 685 | "Layer 38: \tlayer4.1.conv2\n", 686 | "Layer 39: \tlayer4.1.bn2\n" 687 | ] 688 | } 689 | ], 690 | "source": [ 691 | "for i, name in enumerate(calculator.module_names_X):\n", 692 | " print(f\"Layer {i}: \\t{name}\")" 693 | ] 694 | } 695 | ], 696 | "metadata": { 697 | "kernelspec": { 698 | "display_name": "Python 3", 699 | "language": "python", 700 | "name": "python3" 701 | }, 702 | "language_info": { 703 | "codemirror_mode": { 704 | "name": "ipython", 705 | "version": 3 706 | }, 707 | "file_extension": ".py", 708 | "mimetype": "text/x-python", 709 | "name": "python", 710 | "nbconvert_exporter": "python", 711 | "pygments_lexer": "ipython3", 712 | "version": "3.9.9" 713 | } 714 | }, 715 | "nbformat": 4, 716 | "nbformat_minor": 5 717 | } 718 | -------------------------------------------------------------------------------- /hook_manager.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper for CKA in PyTorch. 3 | Adds hooks to modules of a given model. 4 | 5 | Repo: https://github.com/numpee/CKA.pytorch 6 | Author: Dongwan Kim (Github: Numpee) 7 | Year: 2022 8 | """ 9 | 10 | from typing import Optional, Union, Callable, Tuple, Type, List 11 | 12 | import torch 13 | from torch import nn as nn 14 | from torchvision.models.resnet import Bottleneck, BasicBlock 15 | 16 | _HOOK_LAYER_TYPES = ( 17 | Bottleneck, BasicBlock, nn.Conv2d, nn.AdaptiveAvgPool2d, nn.MaxPool2d, nn.modules.batchnorm._BatchNorm) 18 | 19 | 20 | class HookManager: 21 | def __init__(self, model: nn.Module, hook_fn: Optional[Union[str, Callable]] = None, 22 | hook_layer_types: Tuple[Type[nn.Module], ...] = _HOOK_LAYER_TYPES, 23 | calculate_gram: bool = True) -> None: 24 | """ 25 | Add hooks to models. 26 | Mainly supports ResNets. 27 | :param model: model to attach hooks to 28 | :param hook_fn: the hook function or string. Options: ("avgpool", "flatten"). 
Default: flatten 29 | :param hook_layer_types: layer types to register hooks. Should be nn.Module 30 | """ 31 | self.model = model 32 | self.hook_fn = hook_fn 33 | self.hook_layer_types = hook_layer_types 34 | self.calculate_gram = calculate_gram 35 | for layer in self.hook_layer_types: 36 | if not issubclass(layer, nn.Module): 37 | raise TypeError(f"Class {layer} is not an nn.Module.") 38 | 39 | if self.hook_fn is None: 40 | self.hook_fn = self.flatten_hook_fn 41 | print("No hook function provided. Using flatten_hook_fn.") 42 | elif type(self.hook_fn) == str: 43 | hook_fn_dict = {'flatten': self.flatten_hook_fn, 'avgpool': self.avgpool_hook_fn} 44 | if self.hook_fn in hook_fn_dict: 45 | self.hook_fn = hook_fn_dict[self.hook_fn] 46 | else: 47 | raise ValueError(f"No hook function named {self.hook_fn}. Options: {list(hook_fn_dict.keys())}") 48 | 49 | # Not using dictionary because a single module may be used multiple times in forward 50 | self.features = [] 51 | self.module_names = [] 52 | self.handles = [] 53 | 54 | self.register_hooks(self.hook_fn) 55 | 56 | def get_features(self) -> List[torch.Tensor]: 57 | return self.features 58 | 59 | def get_module_names(self) -> List[str]: 60 | return self.module_names 61 | 62 | def clear_features(self) -> None: 63 | self.features = [] 64 | self.module_names = [] 65 | 66 | def clear_all(self) -> None: 67 | self.clear_hooks() 68 | self.clear_features() 69 | 70 | def clear_hooks(self) -> None: 71 | num_handles = len(self.handles) 72 | for handle in self.handles: 73 | handle.remove() 74 | self.handles = [] 75 | for m in self.model.modules(): 76 | if hasattr(m, 'module_name'): 77 | delattr(m, 'module_name') 78 | print(f"{num_handles} handles removed.") 79 | 80 | def register_hooks(self, hook_fn: Callable) -> None: 81 | prev_num_handles = len(self.handles) 82 | self._register_hook_recursive(self.model, hook_fn, prev_name="") 83 | new_num_handles = len(self.handles) 84 | print(f"{new_num_handles - prev_num_handles} Hooks 
registered. Total hooks: {new_num_handles}") 85 | 86 | def _register_hook_recursive(self, module: nn.Module, hook_fn: Callable, prev_name: str = "") -> None: 87 | for name, child in module.named_children(): 88 | curr_name = f"{prev_name}.{name}" if prev_name else name 89 | curr_name = curr_name.replace("_model.", "") 90 | num_grandchildren = len(list(child.children())) 91 | if num_grandchildren > 0: 92 | self._register_hook_recursive(child, hook_fn, prev_name=curr_name) 93 | if isinstance(child, self.hook_layer_types): 94 | handle = child.register_forward_hook(hook_fn) 95 | self.handles.append(handle) 96 | setattr(child, 'module_name', curr_name) 97 | 98 | def flatten_hook_fn(self, module: nn.Module, inp: torch.Tensor, out: torch.Tensor) -> None: 99 | batch_size = out.size(0) 100 | feature = out.reshape(batch_size, -1) 101 | if self.calculate_gram: 102 | feature = gram(feature) 103 | module_name = getattr(module, 'module_name') 104 | self.features.append(feature) 105 | self.module_names.append(module_name) 106 | 107 | def avgpool_hook_fn(self, module: nn.Module, inp: torch.Tensor, out: torch.Tensor) -> None: 108 | if out.dim() == 4: 109 | feature = out.mean(dim=(-1, -2)) 110 | elif out.dim() == 3: 111 | feature = out.mean(dim=-1) 112 | else: 113 | feature = out 114 | if self.calculate_gram: 115 | feature = gram(feature) 116 | module_name = getattr(module, 'module_name') 117 | self.features.append(feature) 118 | self.module_names.append(module_name) 119 | 120 | 121 | def gram(x: torch.Tensor) -> torch.Tensor: 122 | return x.matmul(x.t()) 123 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchmetrics import Metric 3 | 4 | 5 | class AccumTensor(Metric): 6 | def __init__(self, default_value: torch.Tensor): 7 | super().__init__() 8 | 9 | self.add_state("val", default=default_value, dist_reduce_fx="sum") 10 | 11 | 
def update(self, input_tensor: torch.Tensor): 12 | self.val += input_tensor 13 | 14 | def compute(self): 15 | return self.val 16 | --------------------------------------------------------------------------------