├── .gitignore ├── LICENSE ├── README.md ├── cka.py ├── example.ipynb ├── hook_manager.py └── metrics.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | # Custom 4 | 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # Environments 90 | .env 91 | .venv 92 | env/ 93 | venv/ 94 | ENV/ 95 | env.bak/ 96 | venv.bak/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | 
.mypy_cache/ 110 | ### JetBrains template 111 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 112 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 113 | 114 | # User-specific stuff 115 | .idea/**/workspace.xml 116 | .idea/**/tasks.xml 117 | .idea/**/usage.statistics.xml 118 | .idea/**/dictionaries 119 | .idea/**/shelf 120 | 121 | # Sensitive or high-churn files 122 | .idea/**/dataSources/ 123 | .idea/**/dataSources.ids 124 | .idea/**/dataSources.local.xml 125 | .idea/**/sqlDataSources.xml 126 | .idea/**/dynamic.xml 127 | .idea/**/uiDesigner.xml 128 | .idea/**/dbnavigator.xml 129 | 130 | # Gradle 131 | .idea/**/gradle.xml 132 | .idea/**/libraries 133 | 134 | # Gradle and Maven with auto-import 135 | # When using Gradle or Maven with auto-import, you should exclude module files, 136 | # since they will be recreated, and may cause churn. Uncomment if using 137 | # auto-import. 138 | # .idea/modules.xml 139 | # .idea/*.iml 140 | # .idea/modules 141 | 142 | # CMake 143 | cmake-build-*/ 144 | 145 | # Mongo Explorer plugin 146 | .idea/**/mongoSettings.xml 147 | 148 | # File-based project format 149 | *.iws 150 | 151 | # IntelliJ 152 | out/ 153 | 154 | # mpeltonen/sbt-idea plugin 155 | .idea_modules/ 156 | 157 | # JIRA plugin 158 | atlassian-ide-plugin.xml 159 | 160 | # Cursive Clojure plugin 161 | .idea/replstate.xml 162 | 163 | # Crashlytics plugin (for Android Studio and IntelliJ) 164 | com_crashlytics_export_strings.xml 165 | crashlytics.properties 166 | crashlytics-build.properties 167 | fabric.properties 168 | 169 | # Editor-based Rest Client 170 | .idea/httpRequests 171 | .idea/ 172 | 173 | # Custom 174 | experiments/ 175 | wandb/ 176 | **/.DS_Store 177 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 
3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022 Dongwan Kim 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Centered Kernel Alignment (CKA) - PyTorch Implementation 2 | 3 | **A PyTorch implementation of Centered Kernel Alignment (CKA) with GPU support.** 4 | 5 | This code was used for the CKA analysis in our CVPR 2023 paper, "[On the Stability-Plasticity Dilemma of Class-Incremental Learning](https://openaccess.thecvf.com/content/CVPR2023/papers/Kim_On_the_Stability-Plasticity_Dilemma_of_Class-Incremental_Learning_CVPR_2023_paper.pdf)". 6 | 7 | ## Usage 8 | ```python 9 | model1 = ... # Some model, casted to GPU 10 | model2 = ... # Another model, casted to GPU 11 | dataloader = ... # Your dataloader 12 | 13 | calculator = CKACalculator(model1, model2, dataloader) 14 | cka_matrix = calculator.calculate_cka_matrix() 15 | ``` 16 | 17 | Rather than caching intermediate feature representations, this code computes CKA on-the-fly (simultaneously with the model forward pass) by using the mini-batch CKA, as described in the [paper by Nguyen et. 
al.](https://openreview.net/pdf?id=KJNcAkY8tY4) 18 | By leveraging GPU superiority, **this implementation runs much faster than any Numpy implementation.** 19 | 20 | ## Setup 21 | I haven't added a `requirements.txt` since the exact version of each package is not that important :man_shrugging: 22 | 23 | #### Required packages to use the class/functions: 24 | * python3.7+ 25 | * torch (any relatively recent version should be O.K.) 26 | * torchvision 27 | * tqdm 28 | * torchmetrics 29 | 30 | #### To run the `example.ipynb`: 31 | * jupyter 32 | * matplotlib 33 | * numpy 34 | 35 | ## Example notebook 36 | Try out the example notebook in `example.ipynb`. 37 | 38 | ## Other 39 | * If you found this repo helpful, please give it a :star: 40 | * If you find any bugs/improvements, feel free to create a new issue. 41 | * This code is mostly tested on ResNets 42 | 43 | ### TODO (when I feel like it) 44 | * Ditch hooks; change to `torch.fx` implementation 45 | -------------------------------------------------------------------------------- /cka.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tool to compute Centered Kernel Alignment (CKA) in PyTorch w/ GPU (single or multi). 3 | 4 | Repo: https://github.com/numpee/CKA.pytorch 5 | Author: Dongwan Kim (Github: Numpee) 6 | Year: 2022 7 | """ 8 | 9 | from __future__ import annotations 10 | 11 | from typing import Tuple, Optional, Callable, Type, Union, TYPE_CHECKING, List 12 | 13 | import torch 14 | import torch.nn as nn 15 | from tqdm.autonotebook import tqdm 16 | 17 | from hook_manager import HookManager, _HOOK_LAYER_TYPES 18 | from metrics import AccumTensor 19 | 20 | if TYPE_CHECKING: 21 | from torch.utils.data import DataLoader 22 | 23 | 24 | class CKACalculator: 25 | def __init__(self, model1: nn.Module, model2: nn.Module, dataloader: DataLoader, 26 | hook_fn: Optional[Union[str, Callable]] = None, 27 | hook_layer_types: Tuple[Type[nn.Module], ...] 
= _HOOK_LAYER_TYPES, num_epochs: int = 10, 28 | group_size: int = 512, epsilon: float = 1e-4, is_main_process: bool = True) -> None: 29 | """ 30 | Class to extract intermediate features and calculate CKA Matrix. 31 | :param model1: model to evaluate. __call__ function should be implemented if NOT instance of `nn.Module`. 32 | :param model2: second model to evaluate. __call__ function should be implemented if NOT instance of `nn.Module`. 33 | :param dataloader: Torch DataLoader for dataloading. Assumes first return value contains input images. 34 | :param hook_fn: Optional - Hook function or hook name string for the HookManager. Options: [flatten, avgpool]. Default: flatten 35 | :param hook_layer_types: Types of layers (modules) to add hooks to. 36 | :param num_epochs: Number of epochs for cka_batch. Default: 10 37 | :param group_size: group_size for GPU acceleration. Default: 512 38 | :param epsilon: Small multiplicative value for HSIC. Default: 1e-4 39 | :param is_main_process: is current instance main process. 
Default: True 40 | """ 41 | self.model1 = model1 42 | self.model2 = model2 43 | self.dataloader = dataloader 44 | self.num_epochs = num_epochs 45 | self.group_size = group_size 46 | self.epsilon = epsilon 47 | self.is_main_process = is_main_process 48 | 49 | self.model1.eval() 50 | self.model2.eval() 51 | self.hook_manager1 = HookManager(self.model1, hook_fn, hook_layer_types, calculate_gram=True) 52 | self.hook_manager2 = HookManager(self.model2, hook_fn, hook_layer_types, calculate_gram=True) 53 | self.module_names_X = None 54 | self.module_names_Y = None 55 | self.num_layers_X = None 56 | self.num_layers_Y = None 57 | self.num_elements = None 58 | 59 | # Metrics to track 60 | self.cka_matrix = None 61 | self.hsic_matrix = None 62 | self.self_hsic_x = None 63 | self.self_hsic_y = None 64 | 65 | @torch.no_grad() 66 | def calculate_cka_matrix(self) -> torch.Tensor: 67 | curr_hsic_matrix = None 68 | curr_self_hsic_x = None 69 | curr_self_hsic_y = None 70 | for epoch in range(self.num_epochs): 71 | loader = tqdm(self.dataloader, desc=f"Epoch {epoch}", disable=not self.is_main_process) 72 | for it, (imgs, *_) in enumerate(loader): 73 | imgs = imgs.cuda(non_blocking=True) 74 | self.model1(imgs) 75 | self.model2(imgs) 76 | all_layer_X, all_layer_Y = self.extract_layer_list_from_hook_manager() 77 | 78 | # Initialize values on first loop 79 | if self.num_layers_X is None: 80 | curr_hsic_matrix, curr_self_hsic_x, curr_self_hsic_y = self._init_values(all_layer_X, all_layer_Y) 81 | 82 | # Get self HSIC values --> HSIC(K, K), HSIC(L, L) 83 | self._calculate_self_hsic(all_layer_X, all_layer_Y, curr_self_hsic_x, curr_self_hsic_y) 84 | 85 | # Get cross HSIC values --> HSIC(K, L) 86 | self._calculate_cross_hsic(all_layer_X, all_layer_Y, curr_hsic_matrix) 87 | 88 | self.hook_manager1.clear_features() 89 | self.hook_manager2.clear_features() 90 | curr_hsic_matrix.fill_(0) 91 | curr_self_hsic_x.fill_(0) 92 | curr_self_hsic_y.fill_(0) 93 | 94 | # Update values across GPUs 95 | 
hsic_matrix = self.hsic_matrix.compute() 96 | hsic_x = self.self_hsic_x.compute() 97 | hsic_y = self.self_hsic_y.compute() 98 | self.cka_matrix = hsic_matrix.reshape(self.num_layers_Y, self.num_layers_X) / torch.sqrt(hsic_x * hsic_y) 99 | # print(self.cka_matrix.diagonal()) 100 | # self.cka_matrix = self.cka_matrix.flip(0) 101 | return self.cka_matrix 102 | 103 | def extract_layer_list_from_hook_manager(self) -> Tuple[List, List]: 104 | all_layer_X, all_layer_Y = self.hook_manager1.get_features(), self.hook_manager2.get_features() 105 | return all_layer_X, all_layer_Y 106 | 107 | def hsic1(self, K: torch.Tensor, L: torch.Tensor) -> torch.Tensor: 108 | ''' 109 | Batched version of HSIC. 110 | :param K: Size = (B, N, N) where N is the number of examples and B is the group/batch size 111 | :param L: Size = (B, N, N) where N is the number of examples and B is the group/batch size 112 | :return: HSIC tensor, Size = (B) 113 | ''' 114 | assert K.size() == L.size() 115 | assert K.dim() == 3 116 | K = K.clone() 117 | L = L.clone() 118 | n = K.size(1) 119 | 120 | # K, L --> K~, L~ by setting diagonals to zero 121 | K.diagonal(dim1=-1, dim2=-2).fill_(0) 122 | L.diagonal(dim1=-1, dim2=-2).fill_(0) 123 | 124 | KL = torch.bmm(K, L) 125 | trace_KL = KL.diagonal(dim1=-1, dim2=-2).sum(-1).unsqueeze(-1).unsqueeze(-1) 126 | middle_term = K.sum((-1, -2), keepdim=True) * L.sum((-1, -2), keepdim=True) 127 | middle_term /= (n - 1) * (n - 2) 128 | right_term = KL.sum((-1, -2), keepdim=True) 129 | right_term *= 2 / (n - 2) 130 | main_term = trace_KL + middle_term - right_term 131 | hsic = main_term / (n ** 2 - 3 * n) 132 | return hsic.squeeze(-1).squeeze(-1) 133 | 134 | def reset(self) -> None: 135 | # Set values to none, clear feature and hooks 136 | self.cka_matrix = None 137 | self.hsic_matrix = None 138 | self.self_hsic_x = None 139 | self.self_hsic_y = None 140 | self.hook_manager1.clear_all() 141 | self.hook_manager2.clear_all() 142 | 143 | def _init_values(self, all_layer_X, 
all_layer_Y): 144 | self.num_layers_X = len(all_layer_X) 145 | self.num_layers_Y = len(all_layer_Y) 146 | self.module_names_X = self.hook_manager1.get_module_names() 147 | self.module_names_Y = self.hook_manager2.get_module_names() 148 | self.num_elements = self.num_layers_Y * self.num_layers_X 149 | curr_hsic_matrix = torch.zeros(self.num_elements).cuda() 150 | curr_self_hsic_x = torch.zeros(1, self.num_layers_X).cuda() 151 | curr_self_hsic_y = torch.zeros(self.num_layers_Y, 1).cuda() 152 | self.hsic_matrix = AccumTensor(torch.zeros_like(curr_hsic_matrix)).cuda() 153 | self.self_hsic_x = AccumTensor(torch.zeros_like(curr_self_hsic_x)).cuda() 154 | self.self_hsic_y = AccumTensor(torch.zeros_like(curr_self_hsic_y)).cuda() 155 | return curr_hsic_matrix, curr_self_hsic_x, curr_self_hsic_y 156 | 157 | def _calculate_self_hsic(self, all_layer_X, all_layer_Y, curr_self_hsic_x, curr_self_hsic_y): 158 | for start_idx in range(0, self.num_layers_X, self.group_size): 159 | end_idx = min(start_idx + self.group_size, self.num_layers_X) 160 | K = torch.stack([all_layer_X[i] for i in range(start_idx, end_idx)], dim=0) 161 | curr_self_hsic_x[0, start_idx:end_idx] += self.hsic1(K, K) * self.epsilon 162 | for start_idx in range(0, self.num_layers_Y, self.group_size): 163 | end_idx = min(start_idx + self.group_size, self.num_layers_Y) 164 | L = torch.stack([all_layer_Y[i] for i in range(start_idx, end_idx)], dim=0) 165 | curr_self_hsic_y[start_idx:end_idx, 0] += self.hsic1(L, L) * self.epsilon 166 | 167 | self.self_hsic_x.update(curr_self_hsic_x) 168 | self.self_hsic_y.update(curr_self_hsic_y) 169 | 170 | def _calculate_cross_hsic(self, all_layer_X, all_layer_Y, curr_hsic_matrix): 171 | for start_idx in range(0, self.num_elements, self.group_size): 172 | end_idx = min(start_idx + self.group_size, self.num_elements) 173 | K = torch.stack([all_layer_X[i % self.num_layers_X] for i in range(start_idx, end_idx)], dim=0) 174 | L = torch.stack([all_layer_Y[j // self.num_layers_X] for j in 
range(start_idx, end_idx)], dim=0) 175 | curr_hsic_matrix[start_idx:end_idx] += self.hsic1(K, L) * self.epsilon 176 | self.hsic_matrix.update(curr_hsic_matrix) 177 | 178 | 179 | def gram(x: torch.Tensor) -> torch.Tensor: 180 | return x.matmul(x.t()) 181 | -------------------------------------------------------------------------------- /example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "a12e75d5", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "from torch.utils.data import DataLoader\n", 12 | "from torchvision.datasets import CIFAR10\n", 13 | "from torchvision.transforms import Compose, ToTensor, Normalize\n", 14 | "from torchvision.models import resnet18\n", 15 | "from tqdm.autonotebook import tqdm\n", 16 | "from copy import deepcopy" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "id": "1bb2e232", 22 | "metadata": {}, 23 | "source": [ 24 | "## Setup DataLoader and Models \n", 25 | "\n", 26 | "An important detail is that although we are using the Validation set for `CIFAR10`, we **shuffle** and drop the last batch. This is to ensure that 1) the batches of each epoch are mixed, and 2) each iteration has the same batch size." 
27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 2, 32 | "id": "47f97572", 33 | "metadata": {}, 34 | "outputs": [ 35 | { 36 | "name": "stdout", 37 | "output_type": "stream", 38 | "text": [ 39 | "Files already downloaded and verified\n" 40 | ] 41 | } 42 | ], 43 | "source": [ 44 | "transforms = Compose([ToTensor(), \n", 45 | " Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])\n", 46 | "\n", 47 | "dataset = CIFAR10(root='./', train=False, download=True, transform=transforms)\n", 48 | "dataloader = DataLoader(dataset, batch_size=256, shuffle=True, drop_last=True, num_workers=4, pin_memory=True)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 3, 54 | "id": "48229bb6", 55 | "metadata": {}, 56 | "outputs": [ 57 | { 58 | "name": "stdout", 59 | "output_type": "stream", 60 | "text": [ 61 | "Dummy models created\n" 62 | ] 63 | } 64 | ], 65 | "source": [ 66 | "model1 = resnet18(pretrained=True).cuda()\n", 67 | "model1.eval()\n", 68 | "model2 = deepcopy(model1)\n", 69 | "model2.eval()\n", 70 | "print('Dummy models created')" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "id": "7f59e737", 76 | "metadata": {}, 77 | "source": [ 78 | "## Compute CKA " 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 4, 84 | "id": "5fed41fb", 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "from cka import CKACalculator" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "id": "926f66b6", 94 | "metadata": {}, 95 | "source": [ 96 | "### Basic Usage \n", 97 | "\n", 98 | "Initializing the `CKACalculator` object will add forward hooks to both `model1` and `model2`. \n", 99 | "The default modules that are hooked are: `Bottleneck`, `BasicBlock`, `Conv2d`, `AdaptiveAvgPool2d`, `MaxPool2d`, and all instances of `BatchNorm`. 
\n", 100 | "Note that `Bottleneck` and `BasicBlock` are from the `torchvision` implementation, and will not add hooks to any custom implementations of `Bottleneck/BasicBlock`.\n", 101 | "\n", 102 | "For ResNet18, a total of 50 hooks are added.\n", 103 | "\n", 104 | "By default, the intermediate features are flattened with `flatten_hook_fn` and 10 epochs are run." 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 5, 110 | "id": "d2212ba3", 111 | "metadata": {}, 112 | "outputs": [ 113 | { 114 | "name": "stdout", 115 | "output_type": "stream", 116 | "text": [ 117 | "No hook function provided. Using flatten_hook_fn.\n", 118 | "50 Hooks registered. Total hooks: 50\n", 119 | "No hook function provided. Using flatten_hook_fn.\n", 120 | "50 Hooks registered. Total hooks: 50\n" 121 | ] 122 | } 123 | ], 124 | "source": [ 125 | "calculator = CKACalculator(model1=model1, model2=model2, dataloader=dataloader)" 126 | ] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "id": "6282ebb3", 131 | "metadata": {}, 132 | "source": [ 133 | "Now we can calculate the CKA matrix " 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 6, 139 | "id": "e9525f51", 140 | "metadata": { 141 | "scrolled": false 142 | }, 143 | "outputs": [ 144 | { 145 | "data": { 146 | "application/vnd.jupyter.widget-view+json": { 147 | "model_id": "f2799b606e9248fca8f378a5f05633f8", 148 | "version_major": 2, 149 | "version_minor": 0 150 | }, 151 | "text/plain": [ 152 | "Epoch 0: 0%| | 0/39 [00:00" 328 | ] 329 | }, 330 | "execution_count": 8, 331 | "metadata": {}, 332 | "output_type": "execute_result" 333 | }, 334 | { 335 | "data": { 336 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAZ4AAAGcCAYAAADptMYEAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAA0AUlEQVR4nO3de6xc53nf+9+7Zs/sC+8USV1ISZRFSZZgy5Kj40udOqkTAU7t2G5r+CRIHKUwYDRIgzRNT+oWSNtz2h4kAU5jt2iLGLETBUmjuFFQG01T27DdOHaOL7pYtWxVEmVRIilq87Lve+5rvf2D243i6v295GzxJUV9P4BhcT/7XbPWmjXzzJDPs54QYxQAAKVUF3sHAACvLCQeAEBRJB4AQFEkHgBAUSQeAEBRJB4AQFFTm1kcQni7pI9Iakn6zRjjr7jf37NnWzx4cG8yHmOdjNVH5+2+DHrTydio9oe544pFG68HnWSsqhq7dml5u407o+g/F9Tmobd3Rnbt2qht4+MYkrGhP2Q1mRL9PdPp53mcOebGbLo73tznqLU4TD+uxnbtjmouGRu7nZaUa2gYx/QJr4N/MrZW/nl2+u4Ck9QP3WQsd0x1M8j8RnoLc2GnXTkI/cy20+omfUw5O6r0e9u56Mtdf5kXnVGb6/pchAm/nzRxqCaOX/SNZOLEE0JoSfq3ku6WdEzS10MIn4oxfju15uDBvfrq1/95cpujej0Z6/78r9n9OfytW5Kx48u77Np3/NT9Nr52eH8y1p7zL6BP/sndNu6cGqSTqSQtDtJP3w9fe8yu/fKJa2z8zKCVjB3PvDZ7tX/b+elD6US/2J+xa1fH6WN+eMGfr1zL2lcG6XO2otN27Y/O3ZGMnR74N41RJjGdrNMnfC2kXzOS9JYtV9m4e+Qn1/wb+KPVN9LbNclSkhb7h21cZv2rZ95hlx4J3/LbNha6j0y89q1z77PxXOp4PB5NxgahZ9dGs/XFUXq756JdzU60brX/RDK2mY+Ib5B0OMb4nRjjUNJ9kt69ie0BAF4BNpN49kt6YSo9tvEzAACSLnhxQQjhgyGEB0IID5w6tXKhHw4AcInbTOI5LunaF/z5wMbP/pIY40djjHfFGO/au3fyf2gHAFweNpN4vi7pphDCDSGEjqQfk/Spl2a3AACXq4mr2mKM4xDC35X0aZ0tp/54jHHychIAwCvCpvp4Yoz/RdJ/Offfr23JdLu1JRmrWpmeglG612aQ6eOpV325YG8p/VeErfaCXTto/JfK6Pplar+236TXDsa+f2NxmC6XlqSlYXrbuf6OXp3u05Gkcea4nNqcr36mjDu7bdOrU8v3RbmHboX0PkuSeRolSdNKP5ej6EvIq8y23X53gr9GWma/Yqa/qAr+Nen6ptqZt63ctr3MCTO2tSfvu5Ok2VG6FyzH9flMt7ZOvF1JaofJyqnXQvp8cOcCAEBRJB4AQFEkHgBAUSQeAEBRJB4AQFEkHgBAUZsqpz5f9dF5e5dpVzK97SPpu1pL0h1nvpKO1f4uu+1Httn4lW84lYzFOV+q+JP777Nx5zP3vcvGT/bS5ZevvfUxu/bEut/vfp0upd0y5UcELGbuqv1Xv++hZGzlzE67dmzuTr1/7nq7tsrcrH/3/MFkbGnot/3u69LXyLop9Zekcabk/mg3Xc66Mtpp175x7xkb743S5/OEub4kae/C9yVjw8wdtx+ZudLGXXnwD+z01+7BtTfauKtu/+bcQbvWuX2nbyMYmVYASdq+vi8ZWx358+nK4vu1v3Zzd82eztXkJ3xeF+bu1AAAnDcSDwCgKBIPAKAoEg8AoCgSDwCgKBIPAKAoEg8AoKiifTyD3rQOf+uWZNyNNnB9OpK05Yo3JWN1M8js2WdstN63PxlrPfOkXVvN+tvpu9RfR/+5wEXduAVJunIuPZ5CknpmrMKW9tCubVe+M6AZpXuEVtd8j8bcbC8Zu2qrH61eZ/pldnV2JGMh8xmt00r3cKy
YERNSvo9ixZyvfu23PTT9WJI0MtfY6jgz0sP0RdXR950Mg7+GnLWxP+ZBkxvJkI41mXEOzhOr/u10lNn0wjB9Da3WmbEc5ipaD76PMafKXGMpw5g+Hr7xAACKIvEAAIoi8QAAiiLxAACKIvEAAIoi8QAAiipaTj2qp3R8eVcyPqjTu5MbbeBKpluVv01/PD5v4y1Xnrm45rdd+9wezO3j3S3rJT+6YGBupS9Ja0N/q35Xhtuu/O3f3X5J0qA/k35cU8YtSS1TctrLjB/IVPhq1KTLRnOlsAMzrqHOlLbnyrw3Y5TZdi5+KcpMXLDl0jlVpoXB2cTDnn3sCWO535iK/vWYEyY+svS6l99VBwB4WSPxAACKIvEAAIoi8QAAiiLxAACKIvEAAIoi8QAAiirax7PjikW946fuT8br1XTvSfuRbZmtp0cb5Pp0wk9/zMa73WeSselHft+uHf3xSf/YVbop4ea9fr8X1tPnZO/NR+zaW9fnbHxlPT2eoJfpAeqbnhZJ2nHV6WSsHvueg+mZdL/WILNfVfANIFum9iZjw8bv187ZbjK2mukvUmaMhF1qRhNI0kxrPPG2W5n2jX7txiL4tVG+F8zJ7Ve3nvx8LldnJl4707rKxnP9RSPTaDbOPM9uRMV65Ueg5EzFydJENCMm+MYDACiKxAMAKIrEAwAoisQDACiKxAMAKIrEAwAoisQDACiqaB9PPeho7fD+ZLy3tD0Zu/INp/y296W3a+fpyPfpSNLc3PXpx618f0d7z4qNh1a6/r7TGdq12+t078jU9nRMknbt9f0KM7Pp+UenTl9h144zPS/RzL0ZZnpxZubS+9Vp+56VxjyuJLVNT1XHxCQpmB6h3HyinNXR5HOCRpnnws0KWh3lZvmkH9z1pEjSSOl+LEmqQnq/e5kWoEHjf6EK6WNuNtFftDj0xzzIbHqtGSVjK8H34kSln4tuWPUPnBHMc+HU5lzyjQcAUBSJBwBQFIkHAFAUiQcAUBSJBwBQFIkHAFBU0XLqqmrUnkuXUbbaC8lYnEvfpl+SWs88mQ4urtm1udEGrmS69cb/y64d7LzXxmW2fe3hx+3S/sqWZCz8wI127bZlP3JhrpsuA3fl0JLUXvFlzdvfli7l3nrsuF1rqkbVfNV/jnIlz5J01YkDydh05cu8bzyYLsmf6/jS4XGm3HpbO12+vj5q27Wv2u3bEPpmZMO2ti/nr+PuZGyUuUbWlm/y2zaluHfu8uezU03buBtP0F15lV3rvHVfutRf8qXrknS0O5OMLQzSI2MkaWwv7X12bU5nwq8n9/cfSMb4xgMAKIrEAwAoisQDACiKxAMAKIrEAwAoisQDACiKxAMAKKpoH8/S8nZ98k/uTsYHTToP/uT+++y2q9n0LcVjnbm9+x+ftHE32iDXpzN9yz023sR0z0vrmp+zazu7zO3Ox752v+3WSmpvT9+Gffb0Lrt2lOktcULL3+c/ms9K7ZnMrfYz297STl9DuR6MVsdcf5m1OY1ZX2V6k9YH6d4QSeqaPp7cSIW2eVnl9mtK/pxU5q0p91zkNGbX2pt4Szwz2Nzb6dCMTVj3rXEam4Oandrc+boQ+MYDACiKxAMAKIrEAwAoisQDACiKxAMAKIrEAwAoisQDACiqaB9Pzqb6HUwKDa5wX1KoMvGWiZt5OpLv05GkKqSfgvHA98M0I/P0VZnZNB3TNCBJTToeWn7t1FSm6WBoel5Gk38Wyj6Ple/jqZRen5vlsxkx09MSN/HQuf0O5phz3DOVubo2tW3/LF64x81pMs9j5m0ou35Suesndz7rCS8Rt4xvPACAokg8AICiSDwAgKJIPACAokg8AICiSDwAgKKy5dQhhI9LeqekkzHG12z8bLekP5B0UNIRSe+LMS7mtjWKlU4NppPxoRlf8Jn73mW3Xcf02p4rO5Z08955G+90hsnYtYcft2tzow1cyXTnn/xru7Y/MPv90G/ZtXHoy8DVpEs7nzu63y5d7m6
x8YOrDyRjoeVLSsdLc8nYwvweu9Y9j5K0NExfmwtDX9o+7KbHDxxb3m3X1uZcS9JTa1vTazOlrrly6t44/dpYzrxunjSTNepMDe98tWDjzrH1q2382dwMAePp1jOTr127ZeK1knSqny5CX6j9yI/GFEWPMwXTMVNSv1Xp14XTr9OPey7feH5b0tu/52cfkvS5GONNkj638WcAALKyiSfG+EVJ3/vx5N2SvjsB7V5J73lpdwsAcLma9N94rowxntj47+clXfkS7Q8A4DK36eKCGGOUuTtCCOGDIYQHQggPrNe9zT4cAOBlbtLEMx9CuFqSNv7/ZOoXY4wfjTHeFWO8a0trdsKHAwBcLiZNPJ+SdM/Gf98j6ZMvze4AAC532cQTQvh9Sf+/pFtCCMdCCB+Q9CuS7g4hPCnphzf+DABAVraPJ8b444nQD53vg9WNtDhIP2Tf9DOc7KX7NySfQfu171lZWN9m49vrbnrbK75npbPLNDvIjzawfTqSZqbTNR2529K39mR6HUwfRt34zyu527uPT6d7XsaZ81mbvqduz/9V7sj0rEhS11wnPdNjJknra+n9PtNPH6+UHwdyZpDer9yt9nd1OjbeN8d1xrxWJWl9nL7Kcrfa74X0aypneeTPV7eZvI9nVPl+GefMYHMDG9bq9H6vZs5XNGfc9fhIUhMy+z3hYTWmP4g7FwAAiiLxAACKIvEAAIoi8QAAiiLxAACKIvEAAIrKllO/lLZ3Rvrha48l44NxulT2tbc+ZrftSlIHXV9mu/fmIzY+tT1dyhh+4Ea7VuN9Pl6Z3J8ZbeBKpltv9jcMH+y+18ar/noyduPXDtu1veX0bfwlqXrrgWRs5pkjdq1m0uXBh6K/RkLL14UeOpMeq7DQ9+X8+29Ln5PXrPpyfTfS42z8qmRskCltf+0Vp2y8N0q/5k50/fO4Nk6XkI8zJbhn1jKvC+NVW3259LDx5evOUj99beYc2ubLvDOTIuzImFlTUi9J49zGnczS6dZk30+eDOl1fOMBABRF4gEAFEXiAQAUReIBABRF4gEAFEXiAQAUReIBABRVtI9nbdTWl09ck4wvDtO16ifWfU/BlXPpvpO1ob81/K3rvkdj194zydi2ZT+6oJ0ZixA66W6caM6H5Ecb5Pp0pm+5x8brJn17+OH6RzJr/eeZtd8bJWPHn36TXbv/xmeSsaXn99q1OUfXtidj65mRCk8/cmsy9tlnD9q1VaaR4sumFacVfO9Ib5zuAZKkgRlF8uiSfx6PDFeSsZg5pif1kI0H85n48yffbNceDc9OvO1j9aN2bWV6UxZXrrdra/n+ozqmXxfr9Wm71o5FiJsb11BlXs8p6zH93sc3HgBAUSQeAEBRJB4AQFEkHgBAUSQeAEBRJB4AQFEkHgBAUUX7eMYx6IyZK7E0TPcU9Gvf09Izs3yGmbUrmR6hmdl+MjbXTfcySFJ7e7q/SJLUmKk6psdCkh3w4ebpSL5PR5JaVXo2yNjMcJGkOtPzMh6m17u5SpJUm7XDge/XqirfW9KYx24yM0vG5phHmeexlXmaswNTjFHmfI7NvuWO2fXqNKavZLOa3OyZ3MtmE/vmemJa0b8ucvvVmAlbIfj3MHeJtEzv0bnIPnZqnTlgvvEAAIoi8QAAiiLxAACKIvEAAIoi8QAAiiLxAACKKlpOPWyk4910vF+nSxW3TPlbim9pD5OxdmVKliX1MmMTTp2+IhmLmVLZ2dO7bDy00vv23NH9dq0bP3Dj1w7btbnRBq5k+prf/wW7djBMj5GQpGf+5h8lY/NLu+3a237h8WRsZyd9W3lJUuPLaP/G7y4mY721LXbttW9/OBk7+OB37Now5a/P7//6ncnYTCd93UvSDTccsfG+GQly8tQeu/b0+rZkLDca44+P/TUbr8zL6oevXrBrn1m72sadby9fN/HaHW1f5u1GUEj+mE/2NlNSP/layZdFO58bpEfG8I0HAFAUiQcAUBSJBwBQFIkHAFA
UiQcAUBSJBwBQFIkHAFBU0T6eJkb16nRNea9O9zMsDtK36ZekdpXu0ciNVOjnbuPfpNe3V3x/0SgzQmDK9Cctd33vSGPq63vLftRDrs/CjTbI9elMd9J9T5K0uLo9GTvdTfeVSJJW02MoQsf3Y2nsn6t6tDMZy41raFbT53N9MX28kjTV9v1Hi/3ZZGw2c+12V/x10O2lt73U89ffinlN1tFfXz3fumRHRawM/XvB6njyz9Pu/SmnCv4aGWaO2S3vmh5HSXbQQ73JPp5JNWZWA994AABFkXgAAEWReAAARZF4AABFkXgAAEWReAAARZF4AABFFe3j2TNd66cPpWeejOt0Hvyr3/eQ3XYzSvfaDPozdu2Oq07buJu5s/1tvqcla5ju4Ti4+oBdOj6dPq7qrQfs2rXf870j42G6/8jN05F8n44kvfFP/89k7LXdZ+xafXlnMtT9tO9XqNq+j+ebT9yUjJ3O9LT8rdufSMbGmV6bfs9fn7tmeslYb+z7xNYzfVE989rIzbF6ai3dIzTKzJ6Z7/vrz/XEtEzPniStDP3nadcv81x/YNc6P3H95L02kvTsevq5aAXfizg2G29lvl40mTafLRNmiQcH6RPNNx4AQFEkHgBAUSQeAEBRJB4AQFEkHgBAUSQeAEBRRcupx7HSYqa0OWXlzE4bXzWlncNMyWk99qWKw2H6dvtbjx23a0PLF1HGUTr3B3dveEnjlXSJ78wzR+za40+/ye+XGQMwv7Tbrs2NNnAl03Nz1/v9OpMuX188dqtd2+kMbfykKZk+nbluh4vbkrHuuj8fo9z1aUYMDDMjP5bXM2MRMiMGnHUzfmCUqR2uzS3zz65Pb2CUGekxyJRy28fNFD03Jt7OfIwPwR+ze7ln3grkplB0Nvn1YqY12ViFyuwz33gAAEWReAAARZF4AABFkXgAAEWReAAARZF4AABFkXgAAEUV7eNporRqbhFfm96R3K3l52bTt45vDf3t3adn/K3QZ+b66WCmXyFuIrePl3z/Rz0w/R8z6d4jSdp/ox8/UJuxCLf9wuN2rVZXfNyMNnB9OpIUfuw3krGrbvk3dm3T2WXjP/obf5qMdRd22LXtv3NzMnbzZ75o17qeKUl67ku327iz56ZnbXy0mr7G+ku+B8iNJ6gzvTbSlZl42p1XH7HxSn4kSGX6aUaNP2Zna3th4rWStN+9z8j3kdWm1Sb3TOTGNcxlehFTWuY8840HAFAUiQcAUBSJBwBQFIkHAFAUiQcAUBSJBwBQVLacOoRwraTf0dn6xyjpozHGj4QQdkv6A0kHJR2R9L4Y46LbVndc6eGF9G3Y+6YmcH/mdvlXbU2X8PZGvrR4YMYeSFKnPU7Gmq/63N3OlGqHKn3MC/N77NpubzYZOxQfs2uXnt9r48NB+pzs7Izs2tDx57P76fQx50YbuJLpqTt/zq6tG/9cbHvdzyRjc6dP2bUyFeTjxfS4BSk/OuP481clY7Mdf0zbl315+sCM1jg5v8+uHdfpa9+NcpCkXu1LyN2t/J9f3mnXnjHXbs7a2I8ACErv9xfn/biQYaYq2Y0RmO/5/YpmzEQz2VSD/6UKk30/6abfNs/pG89Y0i/GGG+T9CZJPxtCuE3ShyR9LsZ4k6TPbfwZAAArm3hijCdijA9t/PeqpMck7Zf0bkn3bvzavZLec4H2EQBwGTmv71AhhIOS7pT0VUlXxhhPbISe12ZakQEArxjnnHhCCFsl3S/p78UY/9LfaMcYo/Tif8kYQvhgCOGBEMID/Zi+rQ0A4JXhnBJPCKGts0nn92KMf7Tx4/kQwtUb8aslnXyxtTHGj8YY74ox3jUT0v8YDgB4ZcgmnhBCkPQxSY/FGP/VC0KfknTPxn/fI+mTL/3uAQAuN+dyd+q3SHq/pG+GEL6x8bN/LOlXJH0ihPABSc9Iet8F2UMAwGUlm3hijF+SksXrP3S+DxgnrCmvTJ265G/DnntMd5t0SWqadIF9yKytMj0awdx
avtMZ2rUjMyoi1xuSU5n+IjWZbY9NAb+kyvRF5Y7ZjTbI9em0qnQPmSTFtnmep33vUj23Pb3WHO/ZuB/bMVWl41NTmXOdu/5MfKrl9yuYvpOQedG1fBuPvaV+O3NMLXftZrTcQWVMtzLvBZvato83ZrTGpO+539Wa8DYD7ni5cwEAoCgSDwCgKBIPAKAoEg8AoCgSDwCgKBIPAKAoEg8AoKhzaSB9yazFob4yOJaM10r3JOyeP2i3vauzIxkbmT4cSdoy5WfTtE1fwFUnDvhtt33/h+tPWhr6vpNunS7uP3TGz/I5upbuO5F8X8Df+F07dkn1aKeNf/OJm5Kxkz0/u+ZHf+NPkzE3T0fyfTqSFH7qY8nYaDBv13b+/N8lY/1Ff67bO9ZsfKU/l4zlem2W56+w8WE/fY3l5lS1TX9RO/ORdt+M3+8Z09/mrs1zEc363Cdx14rzfM/v19DMG5Ok2an0+qWhP1/OIG6yp8/MIHJGpoGIbzwAgKJIPACAokg8AICiSDwAgKJIPACAokg8AICiipZTNxprRaeT8Vrp0uOl4fV228Hk0FGmmnDY+HuOd0w59XTlS07rTOmnG6uwMGzbtb06fcwLpgRXktbNSAVJakzlZ2/Nlzy7clVJOm1Kpk/3Z+za7kK6bH7u9Cm7NjfawJVMz0xfadfG5XSJeX9pn12bszJIlzxPT/ljml3dauM9c75X+35icG/sr0+nX/trxJXwrprzIUnrY/96dtd2N1Py7D6pD9yGJQ0z40TGMb3fa40ff+GMTJvKuagm/H7SUE4NALhUkHgAAEWReAAARZF4AABFkXgAAEWReAAARZF4AABFFe3j2VHN6Ufn7kjGXQn9u6/zPRodc3v4QaZnZeds18Zdr82NB5+xa1sd32fhDLu+p2Xd9NPsv+2wXfv0I7fa+Nics2vf/rBd26z6zzN/6/YnkrHh4ja7tv13bk4HV+xS1XN+PIEbbeD6dCQpvOdfJ2O7pn/Jro0zvl/mr5x8JBmbnuvbtVuuf97GR0vpPp+673vUjj9xQ3pt46+BJ5Z32njbvOZef+hJu3b6SHq/cga1Hyfi3LDVPxejzDmJpt/muZ5/Llzv3KDxfU850xN+PXlqPr2QbzwAgKJIPACAokg8AICiSDwAgKJIPACAokg8AICiipZTj5uo04P0rcFbIV0SuD7y5YQrw/Ta3GiC1cy2+3X6duVznYFdmxsR4Bxb3m3jZ8wt7V+z6suSP/vsQRsfNen9Pvjgd+za9UVftuxKtbvrfpzDzZ/5Ynq7i35cQ2j728P3zX7nRhu4kunwI79m147rdRu/YubfJGOxkx4TIUnDT/htj9bTpdxb3+9HKhz68lfT+5UZTTC/5K/tSuly6uve8ZBdu/0rvvTdevD1Ey+99VrfWuGu+5wTZ3yZt3ufyY2vaDLvUTNTk41V+I8L6VYSvvEAAIoi8QAAiiLxAACKIvEAAIoi8QAAiiLxAACKIvEAAIoq2scTJY2adH2+aR3ROHNL8XR3UP4W7arcam9senzORZTpP3InRL52v47+mF2fhCS1zEOHqfQICkmaavtREP1euv9olOk5cMccWv55DG2/3+0dazbuuNEGuT6ddsv3HzVds77y11+o/PNs4+PJR3oo87hTlX8uWu41mWkr2Uzv3GYMBn78QO7arsw56Y0yrwvzPpLrgcyZ9Hy6/iC+8QAAiiLxAACKIvEAAIoi8QAAiiLxAACKIvEAAIoi8QAAiio7jyc2Oll3k/FppWvVj3bTfRKStDLaXD+NszpK16Nva19h1+ZmXUTT7vDUmp+HcmaQPuY6XmXXfvmUDUumz+f7v36nXbnY98/VrpleMpbrP3ruS7cnY8ef98ec6x1Z6adnAa1kejT+yslHkjE3T0fK9OlIqn7onydjvf5xu/bpf/ZpG++ZmU6v/uXr7Nq5o2b+jOnXk6R9O/3MHNcf17zmZrt25/AJG3duPeXnBDkH3/WA/4VMu2Dspt8rbji
612/avP8NVn2fWE5nS/r16mw7nu6L4xsPAKAoEg8AoCgSDwCgKBIPAKAoEg8AoCgSDwCgqKLl1HVotBbSpaOjmC5ZXRnttNvu15PfCj03ImBkyiDXM7crr4LftlNnlrqK1UFmFEQrTH6+ZjpDG58d+8uqZ24PP9zEmInZzsDGp6b8/fSnWuly6+kpPyJgeq6fjMXODrs2N9rAlUzPzuy3a2em/Tmx1+dUutRakuKVvnzd2b5jxcbdKJPY8fsVNvGu1sqM/HBGR337QzP0Oxam0m80ayd32bXRvG566769IWem71sJUprx88kY33gAAEWReAAARZF4AABFkXgAAEWReAAARZF4AABFkXgAAEUV7ePZWrX1li3p2v/KtJa8ce8Zu23X/zHK9LTMtHx/x6hJb/tVu/18gfVBpufA9FG4mCTt6nSSsdde4ferN/Y9GCMzzuGGG47Ytd0V38+w3k2PH1he92v33PRsMrZ92V8jVcvfl355Pj3iYnbV79eW69M9C8NP+LEHofLPsxttkOvTedV/+hkbdz1C3b97r1373Hy6h2iU6eX69JO32LjzI7/8P2z8m8d/fOJtP7yw3cbdM/W6x2+1a3M9aqvmnB3vpl/rkt+vzfQ4SlInc32mzK+mry2+8QAAiiLxAACKIvEAAIoi8QAAiiLxAACKIvEAAIoi8QAAisr28YQQZiR9UdL0xu//YYzxn4YQbpB0n6QrJD0o6f0xRj+oRb7e3M2f6Y38ro5iOofm+nhyatPT0h/5+vpuJh7MGelleiH6dfq4epk5QYPG1/aPTby/nu7DkaRuz8//6PXTvU3doZ/9MVpNP/ZgZYtdGzJ9PEMzd8TtsySNltJ9PqPMPJRcH4977Ny8J9enI/l5Pktdf8yDYfraHmeuXdcnluMeV5LGmde7uwrG/hKxcq/X3PuQez0PM69XN5urP/mIoY1tT/ZcuX06l3fkgaS3xRhfJ+kOSW8PIbxJ0q9K+vUY4yFJi5I+MNHeAQBeUbKJJ561tvHH9sb/oqS3SfrDjZ/fK+k9F2IHAQCXl3P6O6gQQiuE8A1JJyV9VtJTkpZijN+918wxSS/6nT2E8MEQwgMhhAf6Te8l2GUAwMvZOSWeGGMdY7xD0gFJb5D06nN9gBjjR2OMd8UY75qpNjf7GwDw8nde/+oeY1yS9AVJb5a0M4Tw3X9NOyDJ/ysmAAA6h8QTQtgbQti58d+zku6W9JjOJqD3bvzaPZI+eYH2EQBwGTmXsQhXS7o3hNDS2UT1iRjjfw4hfFvSfSGEfyHpYUkfy22oXzd6cq2fjHdC+rbhJ3q+hHd1PHnJdCtTLbg6Sm97W9tXkLuRCjnLmRLyM4N0/ETX38b/0aVMyakphTx5ao9du9TzZc3tavL6zr4pWz45v8+unWr5x3Vluqt9/9fEdT+9duv7/XOh8ciGX/3L16WDU77kOTfawJVMX/0f/r5du/fr/18yFhpfl3zNv8/8BYkpH976a6+3S2//0v1+28bdnzo48dpdt6RHdkhSNOXSktTa2U3G1g9f47c9Tr/P9Ff96zFnemt6v5wv3b+cjGUTT4zxv0u680V+/h2d/fceAADOGXcuAAAUReIBABRF4gEAFEXiAQAUReIBABRF4gEAFHUufTwvmX7o6tHqG8l4S+lb+e9d+D677WgHLnh9N49B0sj0JNRxt13bzqR2F35y1a9dH6f7UtbGvnb/yHDFxt35PL2+za5dGfjRBk+tpfta1jP9WK0q/VyMM30SIdOv5fqLemM/ZuL4EzckY4e+/FX/wBlzR59JxuKVV9m1z82nxx5IvnfJ9elI0tT/8YvJWN0M7No93/gZG4+99HMZj2Z646rMi868brZeteDXGt/58ztsfJgZkXLdoaeTsZNHr7ZrGzNyIW5iBIUkVYs7Jlo3Hj6X3uakOwMAwCRIPACAokg8AICiSDwAgKJIPACAokg8AICiipZTR0kxpsthY0jHhu4+/ZLqOHk5daaaWiO
z7ZG5fbskVSGz3y6WOSZ34/mxvyt9tvy8MVuvTemmJNXRx905G2X22z127nFD5nzmSt8dt1/ulvWSpCpzAWaufWc09i/xsYnnRhu4kulW5Uvqo68sthdwrCYfNSJJqsxr1rwHSVIwz1WubLnJvFfItApUmWvEPVUh8x6UM/F6c7h84wEAFEXiAQAUReIBABRF4gEAFEXiAQAUReIBABRF4gEAFFW0j6duBlrsH07Gq5DenUdmrrTbHobhxPsVbTeNNFK6X2Ft+Sa7dsoVs2fMV/4W7b3QTcbOrO2za5/UQxPtkyT98bG/ZuM9fzo13x8lY3V2vEX6OujV/ly3Mk/Fvpn0jvcz235ieWcyNr/kR2dMmXEMkrRv52Iytn2HH2/x6SdvsfGR6T255t8ft2vdaINcn0742x+38d76U8nYzO/8S7t2eMLfxr+p031A0cTOxtOxo6f32rX92o/W2HZsLRl77vQeu3ZUp987Z9qTvzduxtj0r/GNBwBQFIkHAFAUiQcAUBSJBwBQFIkHAFAUiQcAUBSJBwBQVNE+HilKZh5Po7GJZQa1XEBVSNej15keoCpzii9W5g+beGQ3zkTK98tUwc3jmfx57mQOqZWZKzLj5qFk+rHaZttVpjepZR5XksamtyQ3G2lTMvNjYs88dmYglOvTkaS5LTemd2ucmXsz8q8516szWNli1zqL/Vkb75teG0laXduajK1ktj0y18HKwM9GysnNFEsZN/TxAAAuESQeAEBRJB4AQFEkHgBAUSQeAEBRJB4AQFEhxslK5SaxpboivnrmHcl425Qe/8DudKmhJK2ZEssmc4i58l93m/87d6VHJkhSbW47L8kWiR9b97dRXx6lt/2qrenSdEn6/Elf7tqY6+Jnb163a1eGvnzTlQ+7slBJuvPqY8nY82Y0gSS1W7ljTp/P1UxJ6usPPZmMXfeOzAgK/1Spec3NyVjszNi1h395m40Phun5BTf97tV2befoI+lg5ccLVF/77zYu83qufu5jdung8Xv9to2pP/mzideOl/x7VBz7a3v69enxBaNH/dpoRhAMl/1+5Uxt7U207q2/dVoPnXjxNym+8QAAiiLxAACKIvEAAIoi8QAAiiLxAACKIvEAAIoi8QAAiio6FmEQ+joSvpWMVyG9OwfX3ui3bW6nn7uNf7f2/R2DJt3I06k2d8tx59l13+DRbdLxYeP7O46GZ/2Dm3P2zJrv71jN9CusDNPxQeZW/JUOJGNnBumeFElqVZP3rK2bPglJmj5yQzK2/SuLdm3M9HrtHD6RjJmXjCTpm8d/3MbHpm/q9i/d7zdeTf65dXhih4270QZVpk9n+pZ7JtonSRqvnPa/YPqTlj7sX6/DzPV54I1HkrHW9hW71o2wyI2JyAnhpR9JwzceAEBRJB4AQFEkHgBAUSQeAEBRJB4AQFEkHgBAUSQeAEBRRft4NiNkenFyvTqbUZkHzz1ubhbQxRIynzminRS0ycfexHNVhQt3Ql0/zaX6PG7Wpp7lsRlUlXlhNLXvi4qZ+IUSxiMbjxfrHTPT3+bETayVtKl+reQmX/ItAgBgkHgAAEWReAAARZF4AABFkXgAAEWReAAARRUtDqybrha6j5jfSJf9fXPuoN12Y27dXUWfX5erM37bSpeNdldeZde2M6fY7dnTrWfs2lE1SMaW+unxAZJ0rH7Uxp1vL19n473a1x4/10/v9yhT4DtqtiZja2P/uK1MHbd7LrqZYxrUe9LBB19v1+bcemp3MtaaMiXNkh5e2G7jY3O67/7UQbt261UL6WDmVvq5cunBypZkbMef/Jldmxtt4EqmW2/+kF3rPHXkP9p4f5QZi3D6G8lY05uza2OdvnpzYxFCZlxIM2zbeHKfzMgNvvEAAIoi8QAAiiLxAACKIvEAAIoi8QAAiiLxAACKIvEAAIo65z6eEEJL0gOSjscY3xlCuEHSfZKukPSgpPfHGIcXZjfxUquC/8zRxAs3FsE+7iZu1B9
MH9g5rTfLX66f0C7WNIdcb0j07UebU2VGLlyg7kU3VkOSmkx8U/NCzLiQ3HOxmW1P6nxeTz8v6bEX/PlXJf16jPGQpEVJH3gpdwwAcHk6p8QTQjgg6R2SfnPjz0HS2yT94cav3CvpPRdg/wAAl5lz/cbzYUm/pL8YWHiFpKUY43jjz8ck7X9pdw0AcDnKJp4QwjslnYwxPjjJA4QQPhhCeCCE8MAk6wEAl5dz+We2t0h6Vwjhr0uakbRd0kck7QwhTG186zkg6fiLLY4xflTSRyUphAvwr1QAgJeV7DeeGOM/ijEeiDEelPRjkj4fY/wJSV+Q9N6NX7tH0icv2F4CAC4bmyks/IeS7gsh/AtJD0v6WG7Bjmqv3jr3vmR8WzudB2/f6esvn1hNH0quSHGmdZWNLw7TX9Teuq9v154Z+FPcmL17eu2WzLbTpceHtvmjXly53sZbMX0r9B1t/8W1ypSF/sT16f02l4AkaWs7fSv+L86nxwdI0nTL7/fzvfR+Dxq/9oat6evg1mv9eIvBYNrGD74r/bfUo6PpMRGS9LrHb7Xx3jh9fe665Vm79jt/fkcylistPnp6r40v9meTsXdf48ceLH14bONObrSBO663fOm9yZgkNTGzX/f/12Sotbvr92uQfuHU3Rm/tvHP1fS+FRtPqdrp4z2vxBNj/G+S/tvGf39H0hsm2iMAwCvWy7UvDgDwMkXiAQAUReIBABRF4gEAFEXiAQAUReIBABR1gW4QnuZuel+b4CjTFzDaxF38q0yjz8C0ENW5W51nZNpDJhYz262V6SkwhzXI1P0PM7e8d0/VZm5uMcxcA7n+omGdfuxh4zc+atKf4camV0aSRuN0z5Qke8Kaod/2sPYjAtx+x9p/Lh2OOun9ylwj/dofc79OH1ccZ/ZrkN6vnL45JsmPNsj16VTBP1f2ys+9v2XOt5MbmxA6E/ZFmdcy33gAAEWReAAARZF4AABFkXgAAEWReAAARZF4AABFkXgAAEUV7ePpa6jH49FkfHY0l4xtX99nt71gmkdy2XWUaXpZa0bJ2NHMrIt8T0u6/v5U3y9eq9P19acyM17qmD6ms/tlzmemZSDTLqNn19PnrJVZu38uPfcmt185s1PpDYyj74eJub4oo6r88xy76f0KU77BYzXTQ9Q3vTqtnX4GzHWHnk4HK79f246t2fjqWnrO0PTrh3btgTcesXG79vQ3/C+4i9vM05EyfTqSwnv/XTLWPfMVv3aUPp+do4/4tZketeH+H7TxlPhbn07G+MYDACiKxAMAKIrEAwAoisQDACiKxAMAKIrEAwAoqmg5daNGg9CbaO3qyBcjrtbp8uBcdh1nCh1XwnoytjCYtWvXJ6+y1UI9sPHVkC53nR348t/1+rSNh5Bef7Lnz1fXzbeQ1DLbzpVTS+lS7PnMfk37U6IlU/u+1vgn8rle+nb6J87ssWt7Iz8i4Iaje9P7dXKXXXu862/zPzS3018/fI1de/Lo1clYlbnV/nOn/TlZ6adfV7c+6l/Rre0rNu5GCDS9dEtHTmu3Lz/PjTZwJdNbrniT37QZyRCf+ppdmyunbi0/Z+NJdbrsnW88AICiSDwAgKJIPACAokg8AICiSDwAgKJIPACAokg8AICiivbxSFI0xeyNidWZe4rXtkje59eY6eNx+zzO7Ne4yd0MPc2dDym3X5Mf08YvTCyzZY3NL8TMRyF3HeSexyZucm6CEc22XUySohmNIUnNKN2AFOvcuAbPXZ5x7LfdNOknK9MaolHt33pGZtu5/XJ9OjnRjImQJIX0CYuDzNrMfrnRBq5PR5KqkD6fsetHUOSerGrG9yqmhGbyUTUAALykSDwAgKJIPACAokg8AICiSDwAgKJIPACAokg8AICiivbx1HGoxdHRZHy6tTUZ69fX222vh34yNhUzs2mq9LwdSeqGVRPdZ9fOTmV6OEwfxTjTEZPr87FrY25mTvozySjTI1Rn4i3zcaeT+SjkwrmWqcx
uaWDOyUi+j2LQTCdjvbGft7M+8jNzBqtb0tte9z0W/dpff/10q4X65nEl358UTL+LJM2007NaJGllkD6fw+X0+4QkNSP/thbdPJ7M2mDmDNXd9Kyoc9E5+kgylpup43p1wt3/78T7JEmDtccnWtfMPp+M8Y0HAFAUiQcAUBSJBwBQFIkHAFAUiQcAUBSJBwBQVPGxCJOavHD40uaOK3ub/+DmC0y2PyVsYlLEy/I6uJDjGPDScuXSOa5M+1y2Hcx4AheTlJ9DcYnhGw8AoCgSDwCgKBIPAKAoEg8AoCgSDwCgKBIPAKAoEg8AoKiifTxBldpV+jbu7ZCOTVe+Rr4yt38P8munYuZW6CE9ViF3G/+c2pT2b1X61vCSbFPLtJs9IKlqfNwdc+585mwxp3um5Xsd5lrpg67MKAfJj2OQpMocV5X5jDZtwjNTfqSCGy8gSZ0tvfS2+/4a6WR6R1yP0fTWrl1bLe5IxnJjEXIqs35qa/p8nH3sTE9LlX6ymqEfYSGzX9P7VvzSjr8Ohvt/MBlrLT9n11Yz6ffOSccafNeWrbdMtK5VpcdE8I0HAFAUiQcAUBSJBwBQFIkHAFAUiQcAUBSJBwBQFIkHAFAUiQcAUBSJBwBQFIkHAFAUiQcAUBSJBwBQFIkHAFAUiQcAUFSIcXO3Lz+vBwvhlKRnXvCjPZJOF9uBlz/O1/nhfJ0fztf54Xx518cY975YoGji+d8ePIQHYox3XbQdeJnhfJ0fztf54XydH87X5PirNgBAUSQeAEBRFzvxfPQiP/7LDefr/HC+zg/n6/xwviZ0Uf+NBwDwynOxv/EAAF5hLkriCSG8PYTweAjhcAjhQxdjHy51IYSPhxBOhhAefcHPdocQPhtCeHLj/3ddzH28VIQQrg0hfCGE8O0QwrdCCD+/8XPO14sIIcyEEL4WQnhk43z93xs/vyGE8NWN1+UfhBA6F3tfLyUhhFYI4eEQwn/e+DPna0LFE08IoSXp30r6EUm3SfrxEMJtpffjZeC3Jb39e372IUmfizHeJOlzG3+GNJb0izHG2yS9SdLPblxTnK8XN5D0thjj6yTdIentIYQ3SfpVSb8eYzwkaVHSBy7eLl6Sfl7SYy/4M+drQhfjG88bJB2OMX4nxjiUdJ+kd1+E/bikxRi/KGnhe378bkn3bvz3vZLeU3KfLlUxxhMxxoc2/ntVZ98c9ovz9aLiWWsbf2xv/C9KepukP9z4OefrBUIIByS9Q9Jvbvw5iPM1sYuRePZLOvqCPx/b+Bnyrowxntj47+clXXkxd+ZSFEI4KOlOSV8V5ytp46+NviHppKTPSnpK0lKMcbzxK7wu/7IPS/olSc3Gn68Q52tiFBe8TMWz5YiUJL5ACGGrpPsl/b0Y48oLY5yvvyzGWMcY75B0QGf/FuLVF3ePLl0hhHdKOhljfPBi78vlYuoiPOZxSde+4M8HNn6GvPkQwtUxxhMhhKt19tMqJIUQ2jqbdH4vxvhHGz/mfGXEGJdCCF+Q9GZJO0MIUxuf4nld/oW3SHpXCOGvS5qRtF3SR8T5mtjF+MbzdUk3bVSEdCT9mKRPXYT9eDn6lKR7Nv77HkmfvIj7csnY+Pv2j0l6LMb4r14Q4ny9iBDC3hDCzo3/npV0t87+u9gXJL1349c4XxtijP8oxnggxnhQZ9+vPh9j/AlxviZ2URpINz45fFhSS9LHY4z/svhOXOJCCL8v6Qd19g6485L+qaT/JOkTkq7T2bt8vy/G+L0FCK84IYTvl/Rnkr6pv/g7+H+ss//Ow/n6HiGE23X2H8NbOvvh8xMxxv8nhPAqnS322S3pYUk/GWMcXLw9vfSEEH5Q0j+IMb6T8zU57lwAACiK4gIAQFEkHgBAUSQeAEBRJB4AQFEkHgBAUSQeAEBRJB4AQFEkHgBAUf8TG59AhzI63YsAAAAASUVORK5CYII=\n", 337 | "text/plain": [ 338 | "
" 339 | ] 340 | }, 341 | "metadata": { 342 | "needs_background": "light" 343 | }, 344 | "output_type": "display_data" 345 | } 346 | ], 347 | "source": [ 348 | "plt.imshow(cka_output.cpu().numpy(), cmap='inferno')" 349 | ] 350 | }, 351 | { 352 | "cell_type": "markdown", 353 | "id": "e8e6cab2", 354 | "metadata": {}, 355 | "source": [ 356 | "### Advanced Usage \n", 357 | "\n", 358 | "We can customize other parameters of the `CKACalculator`. \n", 359 | "Most importantly, we can select which modules to hook. \n", 360 | "\n", 361 | "Before instantiating a new `CKACalculator` on the same models, make sure to first call the `reset` method of the existing one. \n", 362 | "This clears all hooks registered in the models." 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": 9, 368 | "id": "fe4a1711", 369 | "metadata": {}, 370 | "outputs": [ 371 | { 372 | "name": "stdout", 373 | "output_type": "stream", 374 | "text": [ 375 | "50 handles removed.\n", 376 | "50 handles removed.\n" 377 | ] 378 | } 379 | ], 380 | "source": [ 381 | "# Reset calculator to clear hooks\n", 382 | "calculator.reset()\n", 383 | "torch.cuda.empty_cache()" 384 | ] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "execution_count": 10, 389 | "id": "ff5c78e6", 390 | "metadata": {}, 391 | "outputs": [], 392 | "source": [ 393 | "import torch.nn as nn" 394 | ] 395 | }, 396 | { 397 | "cell_type": "markdown", 398 | "id": "f168ff41", 399 | "metadata": {}, 400 | "source": [ 401 | "Let's consider outputs of `Conv2d` and `BatchNorm2d` only. This will create 40 hooks.\n", 402 | "\n", 403 | "For custom layers, add the custom modules in the same manner as shown below." 
404 | ] 405 | }, 406 | { 407 | "cell_type": "code", 408 | "execution_count": 11, 409 | "id": "9e768045", 410 | "metadata": {}, 411 | "outputs": [], 412 | "source": [ 413 | "layers = (nn.Conv2d, nn.BatchNorm2d)" 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": 12, 419 | "id": "fd75a7d6", 420 | "metadata": {}, 421 | "outputs": [ 422 | { 423 | "name": "stdout", 424 | "output_type": "stream", 425 | "text": [ 426 | "No hook function provided. Using flatten_hook_fn.\n", 427 | "40 Hooks registered. Total hooks: 40\n", 428 | "No hook function provided. Using flatten_hook_fn.\n", 429 | "40 Hooks registered. Total hooks: 40\n" 430 | ] 431 | } 432 | ], 433 | "source": [ 434 | "calculator = CKACalculator(model1=model1, model2=model2, dataloader=dataloader, hook_layer_types=layers)" 435 | ] 436 | }, 437 | { 438 | "cell_type": "code", 439 | "execution_count": 13, 440 | "id": "d27ee2fb", 441 | "metadata": {}, 442 | "outputs": [ 443 | { 444 | "data": { 445 | "application/vnd.jupyter.widget-view+json": { 446 | "model_id": "b70e3f403590421491fc7f829567f9e3", 447 | "version_major": 2, 448 | "version_minor": 0 449 | }, 450 | "text/plain": [ 451 | "Epoch 0: 0%| | 0/39 [00:00" 606 | ] 607 | }, 608 | "execution_count": 14, 609 | "metadata": {}, 610 | "output_type": "execute_result" 611 | }, 612 | { 613 | "data": { 614 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAZ4AAAGbCAYAAAD0sfa8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAArSElEQVR4nO3deZDc9Xnn8c/TPaduCQkBAiwsLh8B4VJICC4fOE6R02bL8YbNxngrVSSpeMuupFJxUpvE2aw3x8Z2stktp/BCzOYibHyu441D2VQclxOMwAIDsjkMlC50II3m6vPX3/1jGkchmvl+kFrf0QzvVxWF1PPo+X3717/pp3umn+8TKSUBAFBKbbEXAAB4eaHwAACKovAAAIqi8AAAiqLwAACKGip5sI0bV6etWzdl41KqsjGzj09Yx+xU+bu49pxjVq6qNZKNmZxaZeVydJL3uqDq5WNq4R2zm/KBbeN4ktQzPjG5cTT/WEtS1zgX053BvY6aTm0rrqduNmalxq1c7udLuyn/AFThPUirasPmUfMmetPZGPc+Vr2WEeVlq9fGzKPmVb3ZgeVaW8s/F7pm1RxYrsq89nN6qaOUqpM+oRQtPFu3btJ99/9WNq5TzWRjHv7+z1jH3Hd8fTbmh9/1CSvX9JNbsjGf//s3WLkch1ujVtyxVv5hXDnkPREdbtWzMfvM771GlX9iePelXtE/1sw/eXz50Eorl9NB8E+tvVauSR3JxlxXv8rKZZwuSdJz3fz3x3TkYyTp+pXnZWPcYvGp2X/K5zKKpiQdaz6ZDzJzrR271IpzHJ19aGC53rDindkY8zWeHtbj2ZhkZjvW2WMedWGN9rPzfu20XiJGxI0R8a2IeDIi3n86uQAALw+nXHgioi7pf0r6QUmvlnRzRLx6UAsDACxPp/OO51pJT6aUvp1Saku6S9LbBrMsAMBydTqFZ4ukE38YuLd/278QEbdGxM6I2Hn48ORpHA4AsByc8Y9Tp5RuSyntSCnt2LRpzZk+HADgLHc6hWefpItO+PuF/dsAAJjX6RSe+yVdFhGXRMSIpJ+Q9NnBLAsAsFydch9PSqkbEe+R9AVJdUl3pJQeHdjKAADL0mk1kKaUPi/p8358ZTWHDtfzjYDNTn4XAUlqGTsXVFNeZ3ljIv87qlbPexOZnB0CKi9Xs5fP1WznG0MlaaJt5HK2SpDUqPK7EnTN++ho97xWRyesMnYkmIvrZGPq4W0b4ez0IEmjyu820Ele87Gzo4Xb2Fo31pXMHRVqkf++dXaNcHPV5H1/SOYWIIbVw/lr3/xW03g3/5zZNc/XaH0wu680Y/5zyl5tAICiKDwAgKIoPACAoig8AICiKDwAgKIoPACAoig8AICiKDwAgKKKTiCdfXzCmhzqNIde/5V3WMd0RujWvnivlWvjj+TX9a6v/qmVyyn5n/vYj1upJlr56ZwXrJqycjW6+SbAVSPOWGLpaGNFNub6q7yJjgf2n5+N2bbeawSujCbfbfsutnJNtLdmY67b5O3KXhlNxZK0+3i+wa9ZeeOer9s0YeTymitn9r/OyOV1oz44np8c7Noer8zGuKPhd6/MTyF2jdXzB+2Y6zqvfW42pjInkK6Jdd5BM76lZ+b9Gu94AABFUXgAAEVReAAARVF4AABFUXgAAEVReAAARVF4AABFUXgAAEVReAAARRXduaBTDWnf8XxHsjWu2tiRQJLqtfwI4LTvoJerl+/87ZmjnMOYv9zseg+P01k+3fZGhbeNXMO1/EhrN1er6XXYt40dFRrmOHRnwnTHGCcueWOhnTHnkrejwqBVxv3sLMK6BsndlcDKlQZ3Lga4LPMdhBc1lNwx4AuLBa77pX1FAQCWHAoPAKAoCg8AoCgKDwCgKAoPAKAoCg8AoCgKDwCgKAoPAKCoog2ka885ph9+1yeycdVUfoSxO67aaQ6Nd99u5ZqdfTYbU/uHP7RyRS3feXj5Jq+x9ejM6mzMlnMOW7kmZ/JjlRtmM6r
TALv2vCNWrqqbb2obmvAaWx0rhzZZce1efl0R3rjnes0bTeyoyTvmSD1/znpmq6Mz1tqcfK2kwT2Ws9Xgzuvx2vMDyzVaPy8b4za/do3HO5nXxExtxjtoRi/mP++84wEAFEXhAQAUReEBABRF4QEAFEXhAQAUReEBABRF4QEAFEXhAQAUReEBABRVdOeCqjWi6Se3ZOMaE2uyMRt/xOued8ZVOzsSSNKKFa/IxlQbJ61cUc93EY+MtK1ca6rZbMyGc70dAsammtmYw0fOsXJ1ja7+ZI6Ybhu7JQwNed3uPeOYo8bOEpI0YsQ5o8lfiqmOM67ay9UxHqPKHN3dMca5d5y545I6yo+2r4V3Xlu9/HVRC+8+9ga4o8JEO38uWubhJiO/20CSd1HMxpR30IyFzhXveAAARZ3WO56IeEbSlKRKUjeltGMQiwIALF+D+FHbm1NK3s9xAAAve/yoDQBQ1OkWniTp7yLigYi49WQBEXFrROyMiJ3PNwb3izkAwNJ0uj9qe31KaV9EnCvpnoj4ZkrpyycGpJRuk3SbJG0/d8ycxgEAWK5O6x1PSmlf//+HJH1K0rWDWBQAYPk65cITESsjYvULf5b0A5IeGdTCAADL0+n8qG2zpE/FXOPVkKS/SCn97UL/YHJqlT7/92/IJm718vXwXV/9U2uRvSqfyx1X7TSH1t/3v6xcvdTNxlw++x+tXFVjNBtTX9Wwcq03ztfYI9usXNOT+ZHca27wRgmvPrAnG/P8zsutXLV6vpFu25FzrVzntPPn/urN+61clXHdS9LKoY3ZmJbZtLp5Zb5ZcNq4j5L0mnX5senu6Ovm0fxj6Q60vmK1t35HZ8q7xhyvWTu433mPHc9fE12jwVeSxoe8az/n/za+Pu/XTrnwpJS+LenqU/33AICXJz5ODQAoisIDACiKwgMAKIrCAwAoisIDACiKwgMAKIrCAwAoisIDACiq6OhrV3JG7ZolM4xu3TDHHDvjqp0dCSSpFvlT320NW7l6nXyuoRGzS9oYExx1L1etZsS1O1au1Mk/4PbjWHN73o1cMbh9b5O88cvm9GiLs/56eOfL+ZYc5P70i/GqeZDH7BmPt7nZwEA515dzRSyUhnc8AICiKDwAgKIoPACAoig8AICiKDwAgKIoPACAoig8AICiKDwAgKKKNpB2Uk2HW/kxtG1j/PLnPvbj1jGb3fxdvHzTQSvXyEg7n8scV+00h478+n+3cjVb+fWnP/p1K5d6+aa2/Xu2WKmOz67Mxmyd2mnlinp+XUcP5sf/St7jOGGOez7azj+Oe49vsHJVxrmXpKemBzdi2tE0x2g/kZ+ircrsfj1YO2rFOVbMDGaUsyQ9XX92cLmmrxhYrmeas9mYnjksvDOgNt/mAs3ovOMBABRF4QEAFEXhAQAUReEBABRF4QEAFEXhAQAUReEBABRF4QEAFEXhAQAUVXTngqonHWvlD9k0OrgnWmPWMZ2u66Mzq61ca6p8d3DV8DrenXHVzo4EkjQ2ujkbkzZ6I7mdubdVz3u94oz27R5aYeXqTo9nY2Yb+RhJ6hi7Wcya3foNY5eN55vetWqNfJf0fCu/Nndk8rrhkWxMy9xRYaab73jvmjsXNGr57zXXbGVe+4ZOrTWwXM+3BjeCfSry5yuZOxc4Oxz0jHHoC8XwjgcAUBSFBwBQFIUHAFAUhQcAUBSFBwBQFIUHAFAUhQcAUBSFBwBQVNEG0lpIK4fyjUfNdr5B7oJVxpxdSdPtfIPclnMOW7k2nHskG1Nf1bByDY3km+3ccdVOc2i8+3YrV9XLN8hddO8fWrnaxrmf3bfJyrXv6YuyMevXT1i5HKM1r9lutXE9H2h4DaQ1ec2Vj0/mH+96eE2fF4znX3u6DaSHezPZmGTex2PVnmxMmK+b99fy16HreHf/wHLtGs43m1fyml8neweyMXYDaRpMY2s7zd/UyjseAEBR2cITEXdExKGIeOSE2zZExD0R8UT//+vP7DIBAMu
F847n45JufNFt75f0xZTSZZK+2P87AABZ2cKTUvqypKMvuvltku7s//lOSW8f7LIAAMvVqf6OZ3NK6YXfZj0nad7tkSPi1ojYGRE7ZyrvF+8AgOXrtD9ckFJK0vwfVUkp3ZZS2pFS2rGy7m1bDwBYvk618ByMiPMlqf//Q4NbEgBgOTvVwvNZSbf0/3yLpM8MZjkAgOXO+Tj1X0r6R0lXRMTeiPhpSb8j6a0R8YSk7+//HQCArOzOBSmlm+f50lte6sG6KXTYGNs70c53Sje6w9Yx28YI48mZVVausalmNma9MQpZktTL71wgs2PcG1ftjeyt1/Ld1M6OBJJUGSOmu23vcXTGQrdb3rpqtfz56pljqJ0R0x3zcaybD/cCv1J9yTrG/eya63d2JXDGKr9c1JNx7ZvXRIQxqt28bOoxqH0F5l88OxcAAIqi8AAAiqLwAACKovAAAIqi8AAAiqLwAACKovAAAIqi8AAAiio6+rrdk/bNPw31O5pVvsls1YjXEDlcyzdqNsyGyMNHzsnGjD2yzcoV9fy69u/ZYuWqevnXD4McV33hXe+zcrXaz2djnv03n7RyHZzYkI15028/aOVSL3993fRnx6xUjemV2Zi1m148VeTkYshoKpb0+vuvycaMjbStXJdc8kw2pjmzwsr1Pc9cko1xrlVJ+pu9b87G1MzmystWd7xAw2PHLx5YrrXD+Y5Od+z4keblp7uc7+gYDemOL7UPz/s13vEAAIqi8AAAiqLwAACKovAAAIqi8AAAiqLwAACKovAAAIqi8AAAiqLwAACKKrpzQS8lNap8V2yjyndwH2143dTO6OumMaJZkrq9fK7pydVWrpqxo8Lx2XxXvCT1jPm4gxxX7exIIEmjI/mdHo5NrbFyHZnNP94xddzKpW43G1J11lmpnJHcM8e8+zg07HXYH2uOZ2PGzWt6djI/9n22kT+eJE228mPTq+S91m0Ymzi4o8KnuoN7fe08f7lqYXzfeptZaNbY7cUdOl4NaOeChY7HOx4AQFEUHgBAURQeAEBRFB4AQFEUHgBAURQeAEBRFB4AQFEUHgBAUUUbSDeOVnr3pfmRwt0qXw+vv+oh65it5lg2Zu15R6xcyRhDu+YGr7lS7Xyz4NapnVaq7qF8c+Xsvk1ervZwNsYdV+00h37P3/9bK9d3zT6bjZn99S9ZuWrD+QbSbzx+mZXrSCPf5Hv9lY9ZuZqN/LUqSevHGtmYRjf/OErSjNGY2zC+hyTpqel8M2rHHOV8sDm4cdUbRvJPc0YvpyRpf7N1mqv5Zz/5isE1fX7tSP5x7JrJ6sbbkZ7RYzrcnP+k8o4HAFAUhQcAUBSFBwBQFIUHAFAUhQcAUBSFBwBQFIUHAFAUhQcAUBSFBwBQVNGdC7qppmNmF3TOgf3nW3Fto4O76uZHWkve+OjVB/ZYuVInX/PDnO3bnc6PJt739EVWLmeU88GJDVYuZ1y1syOBJK1Y8YpszN69m61cIyPtbMwhY0cCSTpiXM+zM96Y9o6524AzPtoZ+S5Jx2eM0dft/EhrSZoxRkx3zO75SvnW+J45orll7pbg6Bh7CfTM/QaGjZf9Ed59dJ4qzKnjGhnQ25HaAmviHQ8AoKhs4YmIOyLiUEQ8csJtH4iIfRGxq//fD53ZZQIAlgvnHc/HJd14kts/klLa3v/v84NdFgBgucoWnpTSlyUdLbAWAMDLwOn8juc9EfFw/0dx6+cLiohbI2JnROyc7ua3cwcALG+nWng+KmmbpO2SDkj60HyBKaXbUko7Uko7Vg3lP30FAFjeTqnwpJQOppSqlFJP0sckXTvYZQEAlqtTKjwRcWITzU2SHpkvFgCAE2UbSCPiLyW9SdLGiNgr6TckvSkitktKkp6R9DPOwaY7NX35UL4xr23MVd223vuxXaOTb/ocmqisXEND+bjnd15u5Ypa/j4ePbjRyjXbyJ+L9esnrFztVv58vem3H7RyxdTxbIw7rtppDr3wrvdZuap
efnzxv/uTn/NyHck3V/ZaXmNo1L3Gw6m/fWM25pxVk1aurVc+mY1pTXrNtE4/p9P8KkmN6txsjNvouGV8cGO0j7Xz32shr2H1AeMjW27D7fOt/MlPRlOu5I21djIt9DyeLTwppZtPcvPtxnEBAPhX2LkAAFAUhQcAUBSFBwBQFIUHAFAUhQcAUBSFBwBQFIUHAFAUhQcAUFTR0deS193sdM5WPa9mmtNxLT1jhG7N7D6PWj7OGdEsSZ3u4B7GmrGjgnpmO3W3mz/ecD5G8s6FsyOBJNVr+d0G0rDXfR6j+a74MEerO7kkaaSeP2dDQ955dXbQcHdUCOOUhfkN6YxyrrtjoY376Ix8nzvm4MZoj9Xz63KPN2pcYj3zPlrP0Uae2gJr5x0PAKAoCg8AoCgKDwCgKAoPAKAoCg8AoCgKDwCgKAoPAKAoCg8AoKiiDaTTqa1/au3NxlXKN79t23exdcyO0fS5cmiTlWvUaETbdiQ/stc10c43OkrSbJXvHhs1GlYlr8nspj87ZuWqOuuyMd94/DIr16FGfvyyO67aaQ6Nd3lDdjutg9mY3q99xMo1vHbaiptsrsjGDNW9ce6ThzZkY9pN7zocruWPOWy+1D13LJ9rzLymB8lZvttj+lwjH9iuvCbZqa73eDtaaTDntbtAJyrveAAARVF4AABFUXgAAEVReAAARVF4AABFUXgAAEVReAAARVF4AABFUXgAAEUV3bmgp64mdSQbVyk/AniivdU6ptP42+55o4lHjJ0LzjF3GwhjbO/R9rCVq1HlXz+sHnJ3LjCON53fRUDyxgkfMXYkkKQjzbFsTHXEPPfGiGlnRwJJGhvdnI05OrHKyuWabOXv5+iQN0Z7fCq/toZx7iWp0fWuV0ezMsbMy9sioNfLf384170kzRpPKO6r+ZZx0LY5Zn665406d3SMnWMcvQV2QOAdDwCgKAoPAKAoCg8AoCgKDwCgKAoPAKAoCg8AoCgKDwCgKAoPAKCoog2kKzWu6+pXZePqxuzY6zZNWsd0mhidZk5Jahojpq/evN/K5dh7PD+WWJKeNxr8DjS8JkBnVPjaTUetXDPH1mRjrr/yMSvX7Ex+3HOv5TUwRjf/OLrjqp3m0A23/ScrV6easeJ+9KY/ysakEe/xbt+dv147U16T73ddmh9rn4xzL0mjn7oxG1OT93171VWPWHGOf3zgdQPL9aqLns3GdLveU/RzR8/JxjiNtJLXCNwznld/4ZuMvgYAnCWyhSciLoqIeyPisYh4NCLe2799Q0TcExFP9P+//swvFwCw1DnveLqSfjGl9GpJ3yvp5yPi1ZLeL+mLKaXLJH2x/3cAABaULTwppQMppQf7f56StFvSFklvk3RnP+xOSW8/Q2sEACwjL+nDBRGxVdI1ku6TtDmldKD/pecknXSb3oi4VdKtkjQeg92lFwCw9NgfLoiIVZI+Iel9KaV/8ZGylFKSTv4Rk5TSbSmlHSmlHaMxflqLBQAsfVbhiYhhzRWdP08pfbJ/88GIOL//9fMlHTozSwQALCfOp9pC0u2SdqeUPnzClz4r6Zb+n2+R9JnBLw8AsNw4v+O5XtJPSfpGROzq3/arkn5H0t0R8dOSnpX0zjOyQgDAspItPCmlr0jzzph9y0s5WJI3irqX8kGV0TkrSZXRrVuveeNlB3U8SUrG2N7K2EVA8nZncLu868YhY6iycg0N58cvN90dFYxu6qh7j6Mz+np47bSVy+HuSDBc93YI6M0a+WreDgFhjHN3H2+LcTxJGqrlj+l+3zrfH4NUM3dCaRkjzJ3rXpJm2yPZGOc5R5JmOvlcjoV2N2DnAgBAURQeAEBRFB4AQFEUHgBAURQeAEBRFB4AQFEUHgBAURQeAEBRFB4AQFEvaSzC6eqmnp7r5ruuR5Xv1t19vPyIhalOvvN35dBGK5exOYOemvbu4/OtfJf645NdK9c8m4z/C6+//xor07Fmfjfy9WMNK1eV8q+Rpv72jVaukXr
+XEw2V1i5Jo3u8x+96Y+sXNaOBJJqb/mtbEyjuc/K9fQHvmDk8naXeN33HcwH9byu/s3rj2ZjOl3v6Wvd1d+24hyvOrxhYLm2/tjOfJC5qcq2b52bT9XxdrNoTXk7aOSs3jf/9cw7HgBAURQeAEBRFB4AQFEUHgBAURQeAEBRFB4AQFEUHgBAURQeAEBRRRtIq+hpOvJNcp2Ub8prVl5Tm8MdC90xmrlaldek5XDGhEteT149Bjf+d2ykbcWNGw1+DXO0b9s4r+esmrRyDQ3lG0iH6t6459Gh/BjtNGJeq+a4aqc5dHxsi5VrbLSVjXFHOafN51lxjtVrprIx7pj5QT7L1Qc4BryzJ98g3mt7i58+tD4bk8znpsZMvvHbUS1wPN7xAACKovAAAIqi8AAAiqLwAACKovAAAIqi8AAAiqLwAACKovAAAIqi8AAAiiq6c8Gq2rCuX5nvbq4ZTfbXbZqwjln18slGzC71Ti/f+bt5Zb7jWpLC7AZ3rBseycZcMO69xuik/Pm65JJnrFyzk/nO7JlZb8T08Zl8rq1XPmnlilr+3E8e8kYcj0/l19W+e7+Vy1mX5I2rdnYkkKRXfvrnsjHuGO2nfvxvsjHuuOovPHGFFee4cPdrBpbr60fXZGPc7+yrv/WqbIyzY4ckfXMy/33krqtZDWaXk0PTT837Nd7xAACKovAAAIqi8AAAiqLwAACKovAAAIqi8AAAiqLwAACKovAAAIoq2kAqeU1MzsjnptlY1THG4/bkNUxVRnPldDs/tluS6pGfo+3ex5bRJOvESFLXiGvOeE2fs438CN1G0xsLPWuc19bkSitX1PPnvt30Hkdn/Z0pc13mWGXnmO646kGO0W61843MXbOB1GlkdnWt5wA31+mt5UQN41w4z1+S1Da+b3tmB2lzQNO9FzpVvOMBABSVLTwRcVFE3BsRj0XEoxHx3v7tH4iIfRGxq//fD5355QIAljrnfW9X0i+mlB6MiNWSHoiIe/pf+0hK6ffP3PIAAMtNtvCklA5IOtD/81RE7Jbk/dAXAIAXeUm/44mIrZKukXRf/6b3RMTDEXFHRKyf59/cGhE7I2Jns9c4vdUCAJY8u/BExCpJn5D0vpTSpKSPStomabvm3hF96GT/LqV0W0ppR0ppx1gt/yknAMDyZhWeiBjWXNH585TSJyUppXQwpVSllHqSPibp2jO3TADAcuF8qi0k3S5pd0rpwyfcfv4JYTdJemTwywMALDfOp9qul/RTkr4REbv6t/2qpJsjYrvmekKfkfQzZ2B9AIBlJlIa3AjmnOH6yrRu/LXZuLqGszFvGXndIJYkSWo6WyVI6hitv69Z5/3azIl6wpuirZluvtX4cG/WypWMvSU+8GpvrPJkK9/9/9R0fnS0JM1082fsjecdtHKF0RQ/XDN3Eejmr9W33vwZK5erdkH+XKTN+RHzkvTUh/MfUHV2JJCk137h3dmYquddO/Gx/Eju1DC/1y42dscwvockqbFzrRXn2PPopdmYdsc792OjzWxMz9wFIQ1o14h3fnWnHj0+edJk7FwAACiKwgMAKIrCAwAoisIDACiKwgMAKIrCAwAoisIDACiKwgMAKKro6OskaW5rt0ycNRZ6cI2vbqqO0Wzr5nLa1SqzubdrxDmNoZLUM4YAV2YjWpXycR1zJHfHGDnsHE+SwjhfwwN8SZa63ghz1cyLx51hbOgY45fdcdVOc2i95o0UT07fpDuHumY8mDUzl/HcFObj6DRq9szvj5pxzJ57F82x6dk8Czzn8I4HAFAUhQcAUBSFBwBQFIUHAFAUhQcAUBSFBwBQFIUHAFAUhQcAUBSFBwBQVNGdC6peS8eaT2bjapFf1oPj6wexJElSsvYRkDrKd2Y3j15+usv5joO1o1Zco5Yfa32s2nO6y/mOv9n7ZiuuYZzWg82Olasydl5oVOdauepGM/i5Y9410azyyUY/daOVa8gct715ff66WL3Gm5v+hSeuyMZ0zFH
IVznjqr1Jzor/cEc2pjHzlJWr/rv/IxvTq7yFpSq/C0XyHkbtObIpG9Os8qPVJWldI/8c0Km8p/ux4bYVl9Nd4FzxjgcAUBSFBwBQFIUHAFAUhQcAUBSFBwBQFIUHAFAUhQcAUBSFBwBQVNEGUilJxujrnroF1vLS1SLfPGZOly1e8WOAR6x5/YRWo6arZ4yrHjHvYt0Y7TtmjkKuKX8na+bY8bp5TGdctTuefJBSwzimOa7aaQ5dsXKbl6uTP19OY6gktSZXWnGOY83xbEzTbPp0rrGOeU1Mtrzx5Pnj0UAKADhLUHgAAEVReAAARVF4AABFUXgAAEVReAAARVF4AABFUXgAAEVReAAARRXduaBeG9PasUuzcc7o6+3xSuuYTpf9bOV1U7d6+Zm2V6weTNevJK2Y8UY5z1b5nR7218yZw4bLVnvjqqe6+dc1G0a8S7DVyz+QW8a9ddVr3k4Cjp7RDX7VVY9YuZI5Ynrd1d/OB5nf2Rfufk02pmt2vNcuHjOCvFzOuGpnRwJJGv/t/2bFOUbu/1A+qObtgvDGxkPZmHbL+7694PsezgcZ30OS1D68zorL+eDB+cevZ6+CiBiLiK9FxEMR8WhE/Gb/9ksi4r6IeDIi/ioiBvfMBgBYtpyXHy1JN6SUrpa0XdKNEfG9kn5X0kdSSpdKOibpp8/YKgEAy0a28KQ50/2/Dvf/S5JukPTX/dvvlPT2M7FAAMDyYv3ANSLqEbFL0iFJ90h6StJESumFXy7slbRlnn97a0TsjIid/xwOAHi5sgpPSqlKKW2XdKGkayVd6R4gpXRbSmlHSmlHGB8aAAAsby/p49QppQlJ90q6TtK6+OdKcqGkfYNdGgBgOXI+1bYpItb1/zwu6a2SdmuuAL2jH3aLpM+coTUCAJYR52df50u6MyLqmitUd6eUPhcRj0m6KyL+i6SvS7r9DK4TALBMZAtPSulhSdec5PZva+73PQNXU74Byx2/PEi1WISDLnOLcUrdRk1Hb3C9qGctd5y7uvkGa5njvXtVvi3QHVc9SNHLr39RLgmzOdSRBpVrgRPBljkAgKIoPACAoig8AICiKDwAgKIoPACAoig8AICiKDwAgKIoPACAoig8AICiim4XXfVmdXQ2P+5VynfO7l550ikM/0ot5Wvr8drzVq6e8p3ZnanLrVxOxX+6/qyVq1NrZWOOd/dbuRyPHb/YimtU+R7u/c382iWpY/TPH2uPW7nqxnYJ7iuyWeM+/uMDrzOzeV51eEM2pj5k7CIg6etH12RjuubWBTftXJsPCi+ZsytBa3KllcsZV+3sSCBJ9e/5JSvO8dQz/ycb0+x4g53Pu/rxbEyqvKu6Z4wUD2N8/EIRvOMBABRF4QEAFEXhAQAUReEBABRF4QEAFEXhAQAUReEBABRF4QEAFFW0gRQ4VT2jgTSMxmOXO5J7kK/calF+aHLpIzqNh5KUvP5XTy3fjLoY46qdEey9AY5pl3l9uY/R6eAdDwCgKAoPAKAoCg8AoCgKDwCgKAoPAKAoCg8AoCgKDwCgKAoPAKAoCg8AoKiiOxesrW3SG1a8Mxu3ejhfD8fqXkevEzVaP8/KNdHOd/S+Zq3Xct0zVvb09BVWrudb+a7+XcOjVq56Gs7GrB32OptrRvv/T77CGzlsXBJ64KiVSmP1/Pqfa3jXV6uXz/Wqi7wR5q2W9xht/bGd2ZjOnlVWrqu/9apsTKPrPU3sefTSbIzTrS9Je45sysYca3qjzt/YeMiKczjjqt37eP1X3pGN6aWulUt//OlsSGp57zOq2bF8rl7+Pi4UwTseAEBRFB4AQFEUHgBAURQeAEBRFB4AQFEUHgBAURQeAEBRFB4AQFHFR1877YKVEdQZ4ETYmpmrNcBxvEbf4UBVMhvRjHPRMprHJKltnC+vfVQKY2xvx0xWNxpb25X3ALV7+YN2zQbMTjffvCvJOmm9tnfMdpUfC93pea9P252RbEzPvHaaVf5cNCvzPrby63I
1nftoNpA6zaG18O6jcRlK5rl3WOOxFzgc73gAAEVlC09EjEXE1yLioYh4NCJ+s3/7xyPi6YjY1f9v+xlfLQBgyXPex7Uk3ZBSmo6IYUlfiYj/1//aL6WU/vrMLQ8AsNxkC09KKUma7v91uP9f4d9QAACWC+t3PBFRj4hdkg5JuieldF//Sx+MiIcj4iMRcdKtdSPi1ojYGRE726kxmFUDAJYsq/CklKqU0nZJF0q6NiJeK+lXJF0p6bslbZD0y/P829tSSjtSSjtGwtvGHACwfL2kT7WllCYk3SvpxpTSgTSnJelPJF17BtYHAFhmnE+1bYqIdf0/j0t6q6RvRsT5/dtC0tslPXLmlgkAWC6cT7WdL+nOiKhrrlDdnVL6XER8KSI2aa5NaJeknz1zywQALBfOp9oelnTNSW6/4aUebFZNPazHs3Hj3ZXZmPPa51rHdH6W2DU/pDcZM9mYseMbrVyOZ5qzVtxU5OMmewesXHOvLxZ2pHm5lWvW2ILia0dWWLmcSefPt7zHcTR/FzXV9bapmO7lu8+fO3qOlWu27XXYb/tW/tqfPrTeyvXNyfz5b5sd728ebWZjak7Hu6R1jfw1XTO/by/4vofzQeZ9PO/q/POXzRhXbe1IIKn283fkc5ljtEe++nvZmDAWFp8+Nu/X2LkAAFAUhQcAUBSFBwBQFIUHAFAUhQcAUBSFBwBQFIUHAFAUhQcAUFTx0dfJmNvbNcY0V/bQ5HxtTWYjmrX2Ac607pn30VmXE9MPHBjniF13WcZLJPdxdEcTD0rPHB2dnLnjknqdfAdsMkZazx3TOJ55TTj3022I7Bhjrd2R3IMc+Zwq45jGmHZJSi0jl7n2QY7RTrPT+SDngezN34TNOx4AQFEUHgBAURQeAEBRFB4AQFEUHgBAURQeAEBRFB4AQFEUHgBAURQeAEBRRXcuqFJbxzp7snGj9VXZmDWxzjrmUMp3cM/U8iOtJWk2prIx40PeSO5kNDd35I1fdnY46CWvZbwe+dciHWfxkiojrm6+9BlxmrzNDntn+S3zfHWMXTYa3WEr10zHG33dmsqPhm/MjFu5mlW+M77pXYZKxo4QYXb1jw23szGTrVErV/vwumxMcncI6OSfMsMc713NjllxDmdctbUjgaR463893eXM5Vnza/N+jXc8AICiKDwAgKIoPACAoig8AICiKDwAgKIoPACAoig8AICiKDwAgKKKj77GHHdwt5UrBpmtrAFOCrendpc+W6VHbePMcZtDHU7Tqnu8sEZRnz3PE7zjAQAUReEBABRF4QEAFEXhAQAUReEBABRF4QEAFEXhAQAUReEBABRF4QEAFBXJHGM8kINFHJb07Itu3ijpSLFFDNZSXru0tNe/lNcusf7FtJTXLi2d9b8ipbTpZF8oWnhOuoCInSmlHYu6iFO0lNcuLe31L+W1S6x/MS3ltUtLf/0SP2oDABRG4QEAFHU2FJ7bFnsBp2Epr11a2utfymuXWP9iWsprl5b++hf/dzwAgJeXs+EdDwDgZYTCAwAoatEKT0TcGBHfiognI+L9i7WOUxURz0TENyJiV0TsXOz15ETEHRFxKCIeOeG2DRFxT0Q80f//+sVc43zmWfsHImJf//zviogfWsw1ziciLoqIeyPisYh4NCLe2799qZz7+da/VM7/WER8LSIe6q//N/u3XxIR9/Wff/4qIkYWe60vtsDaPx4RT59w7rcv8lJfskX5HU9E1CU9LumtkvZKul/SzSmlx4ov5hRFxDOSdqSUlkIjlyLiDZKmJf3vlNJr+7f9nqSjKaXf6Rf/9SmlX17MdZ7MPGv/gKTplNLvL+baciLifEnnp5QejIjVkh6Q9HZJ79bSOPfzrf+dWhrnPyStTClNR8SwpK9Ieq+kX5D0yZTSXRHxx5IeSil9dDHX+mILrP1nJX0upfTXi7rA07BY73iulfRkSunbKaW2pLskvW2R1vKykFL6sqSjL7r5bZLu7P/5Ts09oZx15ln
7kpBSOpBSerD/5ylJuyVt0dI59/Otf0lIc6b7fx3u/5ck3SDphSfus/L8L7D2JW+xCs8WSXtO+PteLaGLuS9J+ruIeCAibl3sxZyizSmlA/0/Pydp82Iu5hS8JyIe7v8o7qz8UdWJImKrpGsk3acleO5ftH5piZz/iKhHxC5JhyTdI+kpSRMppW4/5Kx9/nnx2lNKL5z7D/bP/UciYnTxVnhq+HDBqXt9Sul1kn5Q0s/3fxy0ZKW5n7kupVdTH5W0TdJ2SQckfWhRV5MREaskfULS+1JKkyd+bSmc+5Osf8mc/5RSlVLaLulCzf205crFXZHvxWuPiNdK+hXN3YfvlrRB0ln3I9qcxSo8+yRddMLfL+zftmSklPb1/39I0qc0d0EvNQf7P8N/4Wf5hxZ5PbaU0sH+N2VP0sd0Fp///s/nPyHpz1NKn+zfvGTO/cnWv5TO/wtSShOS7pV0naR1ETHU/9JZ//xzwtpv7P/4M6WUWpL+REvg3L/YYhWe+yVd1v9kyYikn5D02UVay0sWESv7v2hVRKyU9AOSHln4X52VPivplv6fb5H0mUVcy0vywpN23006S89//xfEt0vanVL68AlfWhLnfr71L6Hzvyki1vX/PK65DzTt1tyT+Dv6YWfl+Z9n7d884QVLaO53U2fluV/Iou1c0P/45R9Iqku6I6X0wUVZyCmIiFdq7l2OJA1J+ouzff0R8ZeS3qS5LdUPSvoNSZ+WdLekizU3ruKdKaWz7pf486z9TZr7MU+S9IyknznhdyZnjYh4vaR/kPQNSb3+zb+qud+TLIVzP9/6b9bSOP9Xae7DA3XNvdC+O6X0n/vfw3dp7kdVX5f07/vvIM4aC6z9S5I2SQpJuyT97AkfQlgS2DIHAFAUHy4AABRF4QEAFEXhAQAUReEBABRF4QEAFEXhAQAUReEBABT1/wGOw1OuSeir5QAAAABJRU5ErkJggg==\n", 615 | "text/plain": [ 616 | "
" 617 | ] 618 | }, 619 | "metadata": { 620 | "needs_background": "light" 621 | }, 622 | "output_type": "display_data" 623 | } 624 | ], 625 | "source": [ 626 | "plt.imshow(cka_output.cpu().numpy(), cmap='inferno')" 627 | ] 628 | }, 629 | { 630 | "cell_type": "markdown", 631 | "id": "afaa68d2", 632 | "metadata": {}, 633 | "source": [ 634 | "#### Extract module names " 635 | ] 636 | }, 637 | { 638 | "cell_type": "code", 639 | "execution_count": 15, 640 | "id": "b13097c6", 641 | "metadata": {}, 642 | "outputs": [ 643 | { 644 | "name": "stdout", 645 | "output_type": "stream", 646 | "text": [ 647 | "Layer 0: \tconv1\n", 648 | "Layer 1: \tbn1\n", 649 | "Layer 2: \tlayer1.0.conv1\n", 650 | "Layer 3: \tlayer1.0.bn1\n", 651 | "Layer 4: \tlayer1.0.conv2\n", 652 | "Layer 5: \tlayer1.0.bn2\n", 653 | "Layer 6: \tlayer1.1.conv1\n", 654 | "Layer 7: \tlayer1.1.bn1\n", 655 | "Layer 8: \tlayer1.1.conv2\n", 656 | "Layer 9: \tlayer1.1.bn2\n", 657 | "Layer 10: \tlayer2.0.conv1\n", 658 | "Layer 11: \tlayer2.0.bn1\n", 659 | "Layer 12: \tlayer2.0.conv2\n", 660 | "Layer 13: \tlayer2.0.bn2\n", 661 | "Layer 14: \tlayer2.0.downsample.0\n", 662 | "Layer 15: \tlayer2.0.downsample.1\n", 663 | "Layer 16: \tlayer2.1.conv1\n", 664 | "Layer 17: \tlayer2.1.bn1\n", 665 | "Layer 18: \tlayer2.1.conv2\n", 666 | "Layer 19: \tlayer2.1.bn2\n", 667 | "Layer 20: \tlayer3.0.conv1\n", 668 | "Layer 21: \tlayer3.0.bn1\n", 669 | "Layer 22: \tlayer3.0.conv2\n", 670 | "Layer 23: \tlayer3.0.bn2\n", 671 | "Layer 24: \tlayer3.0.downsample.0\n", 672 | "Layer 25: \tlayer3.0.downsample.1\n", 673 | "Layer 26: \tlayer3.1.conv1\n", 674 | "Layer 27: \tlayer3.1.bn1\n", 675 | "Layer 28: \tlayer3.1.conv2\n", 676 | "Layer 29: \tlayer3.1.bn2\n", 677 | "Layer 30: \tlayer4.0.conv1\n", 678 | "Layer 31: \tlayer4.0.bn1\n", 679 | "Layer 32: \tlayer4.0.conv2\n", 680 | "Layer 33: \tlayer4.0.bn2\n", 681 | "Layer 34: \tlayer4.0.downsample.0\n", 682 | "Layer 35: \tlayer4.0.downsample.1\n", 683 | "Layer 36: \tlayer4.1.conv1\n", 684 | "Layer 
37: \tlayer4.1.bn1\n", 685 | "Layer 38: \tlayer4.1.conv2\n", 686 | "Layer 39: \tlayer4.1.bn2\n" 687 | ] 688 | } 689 | ], 690 | "source": [ 691 | "for i, name in enumerate(calculator.module_names_X):\n", 692 | " print(f\"Layer {i}: \\t{name}\")" 693 | ] 694 | } 695 | ], 696 | "metadata": { 697 | "kernelspec": { 698 | "display_name": "Python 3", 699 | "language": "python", 700 | "name": "python3" 701 | }, 702 | "language_info": { 703 | "codemirror_mode": { 704 | "name": "ipython", 705 | "version": 3 706 | }, 707 | "file_extension": ".py", 708 | "mimetype": "text/x-python", 709 | "name": "python", 710 | "nbconvert_exporter": "python", 711 | "pygments_lexer": "ipython3", 712 | "version": "3.9.9" 713 | } 714 | }, 715 | "nbformat": 4, 716 | "nbformat_minor": 5 717 | } 718 | -------------------------------------------------------------------------------- /hook_manager.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper for CKA in PyTorch. 3 | Adds hooks to modules of a given model. 4 | 5 | Repo: https://github.com/numpee/CKA.pytorch 6 | Author: Dongwan Kim (Github: Numpee) 7 | Year: 2022 8 | """ 9 | 10 | from typing import Optional, Union, Callable, Tuple, Type, List 11 | 12 | import torch 13 | from torch import nn as nn 14 | from torchvision.models.resnet import Bottleneck, BasicBlock 15 | 16 | _HOOK_LAYER_TYPES = ( 17 | Bottleneck, BasicBlock, nn.Conv2d, nn.AdaptiveAvgPool2d, nn.MaxPool2d, nn.modules.batchnorm._BatchNorm) 18 | 19 | 20 | class HookManager: 21 | def __init__(self, model: nn.Module, hook_fn: Optional[Union[str, Callable]] = None, 22 | hook_layer_types: Tuple[Type[nn.Module], ...] = _HOOK_LAYER_TYPES, 23 | calculate_gram: bool = True) -> None: 24 | """ 25 | Add hooks to models. 26 | Mainly supports ResNets. 27 | :param model: model to attach hooks to 28 | :param hook_fn: the hook function or string. Options: ("avgpool", "flatten"). 
Default: flatten 29 | :param hook_layer_types: layer types to register hooks. Should be nn.Module 30 | """ 31 | self.model = model 32 | self.hook_fn = hook_fn 33 | self.hook_layer_types = hook_layer_types 34 | self.calculate_gram = calculate_gram 35 | for layer in self.hook_layer_types: 36 | if not issubclass(layer, nn.Module): 37 | raise TypeError(f"Class {layer} is not an nn.Module.") 38 | 39 | if self.hook_fn is None: 40 | self.hook_fn = self.flatten_hook_fn 41 | print("No hook function provided. Using flatten_hook_fn.") 42 | elif type(self.hook_fn) == str: 43 | hook_fn_dict = {'flatten': self.flatten_hook_fn, 'avgpool': self.avgpool_hook_fn} 44 | if self.hook_fn in hook_fn_dict: 45 | self.hook_fn = hook_fn_dict[self.hook_fn] 46 | else: 47 | raise ValueError(f"No hook function named {self.hook_fn}. Options: {list(hook_fn_dict.keys())}") 48 | 49 | # Not using dictionary because a single module may be used multiple times in forward 50 | self.features = [] 51 | self.module_names = [] 52 | self.handles = [] 53 | 54 | self.register_hooks(self.hook_fn) 55 | 56 | def get_features(self) -> List[torch.Tensor]: 57 | return self.features 58 | 59 | def get_module_names(self) -> List[str]: 60 | return self.module_names 61 | 62 | def clear_features(self) -> None: 63 | self.features = [] 64 | self.module_names = [] 65 | 66 | def clear_all(self) -> None: 67 | self.clear_hooks() 68 | self.clear_features() 69 | 70 | def clear_hooks(self) -> None: 71 | num_handles = len(self.handles) 72 | for handle in self.handles: 73 | handle.remove() 74 | self.handles = [] 75 | for m in self.model.modules(): 76 | if hasattr(m, 'module_name'): 77 | delattr(m, 'module_name') 78 | print(f"{num_handles} handles removed.") 79 | 80 | def register_hooks(self, hook_fn: Callable) -> None: 81 | prev_num_handles = len(self.handles) 82 | self._register_hook_recursive(self.model, hook_fn, prev_name="") 83 | new_num_handles = len(self.handles) 84 | print(f"{new_num_handles - prev_num_handles} Hooks 
registered. Total hooks: {new_num_handles}") 85 | 86 | def _register_hook_recursive(self, module: nn.Module, hook_fn: Callable, prev_name: str = "") -> None: 87 | for name, child in module.named_children(): 88 | curr_name = f"{prev_name}.{name}" if prev_name else name 89 | curr_name = curr_name.replace("_model.", "") 90 | num_grandchildren = len(list(child.children())) 91 | if num_grandchildren > 0: 92 | self._register_hook_recursive(child, hook_fn, prev_name=curr_name) 93 | if isinstance(child, self.hook_layer_types): 94 | handle = child.register_forward_hook(hook_fn) 95 | self.handles.append(handle) 96 | setattr(child, 'module_name', curr_name) 97 | 98 | def flatten_hook_fn(self, module: nn.Module, inp: torch.Tensor, out: torch.Tensor) -> None: 99 | batch_size = out.size(0) 100 | feature = out.reshape(batch_size, -1) 101 | if self.calculate_gram: 102 | feature = gram(feature) 103 | module_name = getattr(module, 'module_name') 104 | self.features.append(feature) 105 | self.module_names.append(module_name) 106 | 107 | def avgpool_hook_fn(self, module: nn.Module, inp: torch.Tensor, out: torch.Tensor) -> None: 108 | if out.dim() == 4: 109 | feature = out.mean(dim=(-1, -2)) 110 | elif out.dim() == 3: 111 | feature = out.mean(dim=-1) 112 | else: 113 | feature = out 114 | if self.calculate_gram: 115 | feature = gram(feature) 116 | module_name = getattr(module, 'module_name') 117 | self.features.append(feature) 118 | self.module_names.append(module_name) 119 | 120 | 121 | def gram(x: torch.Tensor) -> torch.Tensor: 122 | return x.matmul(x.t()) 123 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchmetrics import Metric 3 | 4 | 5 | class AccumTensor(Metric): 6 | def __init__(self, default_value: torch.Tensor): 7 | super().__init__() 8 | 9 | self.add_state("val", default=default_value, dist_reduce_fx="sum") 10 | 11 | 
def update(self, input_tensor: torch.Tensor): 12 | self.val += input_tensor 13 | 14 | def compute(self): 15 | return self.val 16 | --------------------------------------------------------------------------------