├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── configs
│   ├── demo.yaml
│   ├── pico-decoder-large-ov.yaml
│   ├── pico-decoder-medium-ov.yaml
│   ├── pico-decoder-small-ov.yaml
│   └── pico-decoder-tiny-ov.yaml
├── lib
│   ├── cka.py
│   └── svcca
│       ├── cca_core.py
│       └── pwcca.py
├── pyproject.toml
├── scripts
│   └── analyze.py
├── setup.sh
└── src
    ├── components
    │   ├── __init__.py
    │   ├── _registry.py
    │   ├── base.py
    │   ├── ov_circuit.py
    │   └── simple.py
    ├── config
    │   ├── _registry.py
    │   ├── base.py
    │   ├── learning_dynamics.py
    │   ├── metrics.py
    │   └── monitoring.py
    ├── metrics
    │   ├── __init__.py
    │   ├── _registry.py
    │   ├── base.py
    │   ├── cka.py
    │   ├── condition_number.py
    │   ├── gini.py
    │   ├── hoyer.py
    │   ├── norm.py
    │   ├── per.py
    │   └── pwcca.py
    └── utils
        ├── __init__.py
        ├── data.py
        ├── exceptions.py
        ├── initialization.py
        └── logging.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ *.py[cod] 3 | *$py.class 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | .Python 10 | build/ 11 | develop-eggs/ 12 | dist/ 13 | downloads/ 14 | eggs/ 15 | .eggs/ 16 | lib/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | var/ 21 | wheels/ 22 | share/python-wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .nox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | *.py,cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | cover/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | .pybuilder/ 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | # For a library or package, you might want to ignore these files since the code is 86 | # intended to run in multiple environments; otherwise, check them in: 87 | # .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # UV 97 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 98 | # This is especially recommended for binary packages to ensure reproducibility, and is more 99 | # commonly ignored for libraries. 100 | #uv.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries.
106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 111 | #pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 115 | .pdm.toml 116 | .pdm-python 117 | .pdm-build/ 118 | 119 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 120 | __pypackages__/ 121 | 122 | # Celery stuff 123 | celerybeat-schedule 124 | celerybeat.pid 125 | 126 | # SageMath parsed files 127 | *.sage.py 128 | 129 | # Environments 130 | .env 131 | .venv 132 | env/ 133 | venv/ 134 | ENV/ 135 | env.bak/ 136 | venv.bak/ 137 | 138 | # Spyder project settings 139 | .spyderproject 140 | .spyproject 141 | 142 | # Rope project settings 143 | .ropeproject 144 | 145 | # mkdocs documentation 146 | /site 147 | 148 | # mypy 149 | .mypy_cache/ 150 | .dmypy.json 151 | dmypy.json 152 | 153 | # Pyre type checker 154 | .pyre/ 155 | 156 | # pytype static type analyzer 157 | .pytype/ 158 | 159 | # Cython debug symbols 160 | cython_debug/ 161 | 162 | # PyCharm 163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 165 | # and can be added to the global gitignore or merged into this file. For a more nuclear 166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 167 | #.idea/ 168 | 169 | # PyPI configuration file 170 | .pypirc 171 | 172 | # ignoring all pycache 173 | **/__pycache__/ 174 | **/pycache/ 175 | 176 | # ignoring wandb logs 177 | wandb/ 178 | 179 | analysis_results/ 180 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | # Ruff version. 4 | rev: v0.7.1 5 | hooks: 6 | # Run the linter. 7 | - id: ruff 8 | args: [ --fix, --extend-select, I ] 9 | # Run the formatter. 10 | - id: ruff-format -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 📊 **Pico Analyze** 2 | 3 | **Pico Analyze** is a companion toolkit to [**pico-train**](https://github.com/pico-lm/pico-train), designed specifically for studying and visualizing the **learning dynamics** of language models. Whether you want to track activation sparsity, compare layers across checkpoints, or probe the evolution of specific attention heads, Pico Analyze has you covered. 4 | 5 | > For a **detailed run-through**, check out the **full tutorial** on our website at [picolm.io](https://picolm.io). 6 | 7 | --- 8 | 9 | ## **Key Features** 10 | 11 | 1. **Rich Checkpoint Compatibility** 12 | - Seamlessly loads model states, gradients, and activations stored by **pico-train** 13 | - Automatically handles standard PyTorch and Hugging Face–compatible checkpoints 14 | 15 | 2. **Modular Analysis System** 16 | - **Components**: Specify which parts of the model (e.g., weights, gradients, activations) to analyze 17 | - **Metrics**: Apply built-in metrics like **CKA**, **PWCCA**, **PER**, **Gini**, **Hoyer**, and more 18 | 19 | 3. **Deep Learning Dynamics Insights** 20 | - Compare **multiple checkpoints** from different training steps 21 | - Visualize how parameters evolve over time using comprehensive logs or Weights & Biases integration 22 | 23 | 4. **Config-Driven & Extensible** 24 | - Simple YAML config to define which steps, layers, metrics, and components to analyze 25 | - Easily register custom metrics or components by subclassing and decorating with `@register_metric` or `@register_component` 26 | 27 | --- 28 | 29 | ## **Installation** 30 | 31 | 1. **Clone the Repository** 32 | 33 | ```bash 34 | git clone https://github.com/pico-lm/pico-analyze 35 | cd pico-analyze 36 | ``` 37 | 38 | 2. **Configure Environment** 39 | 40 | Create a `.env` file at the root with your Hugging Face and Weights & Biases tokens: 41 | ```bash 42 | export HF_TOKEN=your_huggingface_token 43 | export WANDB_API_KEY=your_wandb_key 44 | ``` 45 | 46 | 3. **Install Dependencies** 47 | 48 | ```bash 49 | source setup.sh 50 | ``` 51 | This script checks your environment, installs necessary tools, and sets up a Poetry virtual environment. 52 | 53 | --- 54 | 55 | ## **Basic Usage** 56 | 57 | 1. **Prepare Your Checkpoints** 58 | Make sure you have checkpoints generated by **pico-train**—either locally or hosted on Hugging Face. 59 | 60 | 2. **Create an Analysis Config** 61 | Define a YAML file specifying: 62 | - **Which checkpoints** to analyze (by step or revision tag) 63 | - **Which components** (weights, activations, gradients) 64 | - **Which metrics** (CKA, Gini, etc.) 65 | 66 | ```yaml 67 | # configs/my_analysis_config.yaml 68 | 69 | analysis_name: "my_analysis" 70 | steps: 71 | - 0 72 | - 1000 73 | - 5000 74 | metrics: 75 | - metric_name: cka 76 | data_split: "val" 77 | target_checkpoint: 5000 78 | components: 79 | - component_name: simple 80 | data_type: "weights" 81 | layer_suffixes: "attention.o_proj" 82 | layers: [0, 1, 2] 83 | monitoring: 84 | output_dir: "analysis_results" 85 | save_to_wandb: true 86 | wandb: 87 | entity: "pico-lm" 88 | project: "pico-analysis" 89 | ``` 90 | 91 | 3. 
**Run the Analysis** 92 | ```bash 93 | poetry run analyze \ 94 | --config_path configs/my_analysis_config.yaml \ 95 | --repo_id pico-lm/pico-decoder-small \ 96 | --branch pico-decoder-small-1 97 | ``` 98 | - **`--repo_id`**: The Hugging Face repository hosting your checkpoints (e.g., `pico-lm/pico-decoder-small`) 99 | - **`--branch`**: The repo branch or “revision” (e.g., `pico-decoder-small-1`) 100 | - Or use `--run_path` to analyze local checkpoints 101 | 102 | 4. **Review Output** 103 | - Results are saved to `analysis_results/my_analysis` 104 | - Inspect JSON logs for each step, or open Weights & Biases to see dynamic charts 105 | 106 | --- 107 | 108 | ## **Configurable Metrics & Components** 109 | 110 | - **Metrics** 111 | - Single-checkpoint (e.g., **norm**, **gini**, **condition_number**) 112 | - Comparative (e.g., **cka**, **pwcca**) 113 | 114 | - **Components** 115 | - **simple**: Directly extracts a single tensor (weights, gradients, or activations) 116 | - **ov_circuit**: Combines attention value and output projections for interpretability 117 | 118 | Add custom metrics or components by registering them in the code: 119 | ```python 120 | # src/metrics/custom.py 121 | 122 | @register_metric("my_custom_metric") 123 | class MyCustomMetric(BaseMetric): 124 | ... 125 | ``` 126 | 127 | --- 128 | 129 | ## **Extensibility** 130 | 131 | 1. **Add New Metrics** 132 | Create a class inheriting from `BaseMetric` (or `BaseComparativeMetric`) and register it with `@register_metric(...)`. 133 | 134 | 2. **Add New Components** 135 | Subclass `BaseComponent` to define a new data extraction strategy and register with `@register_component(...)`. 136 | 137 | --- 138 | 139 | ## **Community & Contributions** 140 | 141 | - Report issues or request features via [GitHub Issues](https://github.com/pico-lm/pico-analyze/issues) 142 | - We welcome contributions! Feel free to open a Pull Request 143 | 144 | --- 145 | 146 | ## **License & Citation** 147 | 148 | Pico Analyze is open-source under the Apache 2.0 License. If you use it in academic or professional work, please cite: 149 | 150 | ```bibtex 151 | @software{pico2025, 152 | author = {Diehl Martinez, Richard}, 153 | title = {Pico: A Lightweight Framework for Studying Language Model Learning Dynamics}, 154 | year = {2025}, 155 | url = {https://github.com/pico-lm} 156 | } 157 | ``` 158 | 159 | --- 160 | 161 | **Happy Analyzing!** 162 | Check out our [website](https://picolm.io) or star our repos for updates, tutorials, and more on the Pico ecosystem. 163 | -------------------------------------------------------------------------------- /configs/demo.yaml: -------------------------------------------------------------------------------- 1 | # This is a demo config for the pico-analyze tool. Feel free to use this as a template 2 | # to analyze your own models. 
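# NOTE: the `steps` field at the bottom of this file is given as a
# start/end/step range; the README example instead lists explicit checkpoint
# steps. Both forms appear in this repository's configs.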
3 | 4 | analysis_name: "pico-decoder-demo-1-analysis" 5 | 6 | metrics: 7 | 8 | # CKA metric (Comparative) 9 | - metric_name: cka 10 | target_checkpoint: 100 11 | data_split: "val" 12 | components: 13 | - component_name: ov_circuit 14 | data_type: "activations" 15 | layer_suffixes: 16 | output_layer: "attention.o_proj" 17 | value_layer: "attention.v_proj" 18 | layers: [0,11] 19 | 20 | # PWCCA metric (Comparative) 21 | - metric_name: pwcca 22 | target_checkpoint: 100 23 | data_split: "val" 24 | components: 25 | - component_name: ov_circuit 26 | data_type: "activations" 27 | layer_suffixes: 28 | output_layer: "attention.o_proj" 29 | value_layer: "attention.v_proj" 30 | layers: [0,11] 31 | 32 | # PER metric 33 | - metric_name: per 34 | data_split: "train" 35 | components: 36 | - component_name: simple 37 | data_type: "gradients" 38 | layer_suffixes: "swiglu.w_2" 39 | layers: [0,11] 40 | 41 | # Condition number metric 42 | - metric_name: condition_number 43 | data_split: "val" 44 | components: 45 | - component_name: simple 46 | data_type: "weights" 47 | layer_suffixes: "swiglu.w_2" 48 | layers: [0,11] 49 | 50 | # Gini metric 51 | - metric_name: gini 52 | data_split: "val" 53 | components: 54 | - component_name: simple 55 | data_type: "weights" 56 | layer_suffixes: "swiglu.w_2" 57 | layers: [0,11] 58 | 59 | # Hoyer metric 60 | - metric_name: hoyer 61 | data_split: "train" 62 | components: 63 | - component_name: simple 64 | data_type: "gradients" 65 | layer_suffixes: "swiglu.w_2" 66 | layers: [0,11] 67 | 68 | # Nuclear norm metric 69 | - metric_name: norm 70 | data_split: "val" 71 | norm_type: "nuclear" 72 | components: 73 | - component_name: ov_circuit 74 | data_type: "weights" 75 | layer_suffixes: 76 | output_layer: "attention.o_proj" 77 | value_layer: "attention.v_proj" 78 | layers: [0,11] 79 | 80 | 81 | steps: 82 | start: 0 83 | end: 100 84 | step: 50 85 | 86 | monitoring: 87 | output_dir: "analysis_results" 88 | 89 | save_to_wandb: true 90 | wandb: 91 | entity: "pico-lm" 92 | project: "pico-demo" 93 | -------------------------------------------------------------------------------- /configs/pico-decoder-large-ov.yaml: -------------------------------------------------------------------------------- 1 | analysis_name: "pico-decoder-large-ov-analysis" 2 | 3 | metrics: 4 | 5 | - metric_name: cka 6 | target_checkpoint: 50_000 7 | data_split: "val" 8 | components: 9 | - component_name: ov_circuit 10 | data_type: "activations" 11 | layer_suffixes: 12 | output_layer: "attention.o_proj" 13 | value_layer: "attention.v_proj" 14 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 15 | 16 | - metric_name: per 17 | data_split: "val" 18 | components: 19 | - component_name: ov_circuit 20 | data_type: "weights" 21 | layer_suffixes: 22 | output_layer: "attention.o_proj" 23 | value_layer: "attention.v_proj" 24 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 25 | - component_name: ov_circuit 26 | data_type: "gradients" 27 | layer_suffixes: 28 | output_layer: "attention.o_proj" 29 | value_layer: "attention.v_proj" 30 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 31 | 32 | - metric_name: condition_number 33 | data_split: "val" 34 | components: 35 | - component_name: ov_circuit 36 | data_type: "weights" 37 | layer_suffixes: 38 | output_layer: "attention.o_proj" 39 | value_layer: "attention.v_proj" 40 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 41 | - component_name: ov_circuit 42 | data_type: "gradients" 43 | layer_suffixes: 44 | output_layer: "attention.o_proj" 45 | value_layer: "attention.v_proj" 46 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 47 | 48 | - 
metric_name: gini 49 | data_split: "val" 50 | components: 51 | - component_name: ov_circuit 52 | data_type: "weights" 53 | layer_suffixes: 54 | output_layer: "attention.o_proj" 55 | value_layer: "attention.v_proj" 56 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 57 | - component_name: ov_circuit 58 | data_type: "gradients" 59 | layer_suffixes: 60 | output_layer: "attention.o_proj" 61 | value_layer: "attention.v_proj" 62 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 63 | 64 | - metric_name: norm 65 | data_split: "val" 66 | norm_type: "frobenius" 67 | components: 68 | - component_name: ov_circuit 69 | data_type: "weights" 70 | layer_suffixes: 71 | output_layer: "attention.o_proj" 72 | value_layer: "attention.v_proj" 73 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 74 | - component_name: ov_circuit 75 | data_type: "gradients" 76 | layer_suffixes: 77 | output_layer: "attention.o_proj" 78 | value_layer: "attention.v_proj" 79 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 80 | 81 | 82 | steps: 83 | start: 0 84 | end: 50_000 85 | step: 5000 86 | 87 | monitoring: 88 | output_dir: "analysis_results" 89 | 90 | save_to_wandb: true 91 | wandb: 92 | entity: "pico-lm" 93 | project: "pico-analysis" 94 | -------------------------------------------------------------------------------- /configs/pico-decoder-medium-ov.yaml: -------------------------------------------------------------------------------- 1 | analysis_name: "pico-decoder-medium-ov-analysis" 2 | 3 | metrics: 4 | 5 | - metric_name: cka 6 | target_checkpoint: 50_000 7 | data_split: "val" 8 | components: 9 | - component_name: ov_circuit 10 | data_type: "activations" 11 | layer_suffixes: 12 | output_layer: "attention.o_proj" 13 | value_layer: "attention.v_proj" 14 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 15 | 16 | - metric_name: per 17 | data_split: "val" 18 | components: 19 | - component_name: ov_circuit 20 | data_type: "weights" 21 | layer_suffixes: 22 | output_layer: "attention.o_proj" 23 | value_layer: "attention.v_proj" 24 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 25 | - component_name: ov_circuit 26 | data_type: "gradients" 27 | layer_suffixes: 28 | output_layer: "attention.o_proj" 29 | value_layer: "attention.v_proj" 30 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 31 | 32 | - metric_name: condition_number 33 | data_split: "val" 34 | components: 35 | - component_name: ov_circuit 36 | data_type: "weights" 37 | layer_suffixes: 38 | output_layer: "attention.o_proj" 39 | value_layer: "attention.v_proj" 40 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 41 | - component_name: ov_circuit 42 | data_type: "gradients" 43 | layer_suffixes: 44 | output_layer: "attention.o_proj" 45 | value_layer: "attention.v_proj" 46 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 47 | 48 | - metric_name: gini 49 | data_split: "val" 50 | components: 51 | - component_name: ov_circuit 52 | data_type: "weights" 53 | layer_suffixes: 54 | output_layer: "attention.o_proj" 55 | value_layer: "attention.v_proj" 56 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 57 | - component_name: ov_circuit 58 | data_type: "gradients" 59 | layer_suffixes: 60 | output_layer: "attention.o_proj" 61 | value_layer: "attention.v_proj" 62 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 63 | 64 | - metric_name: norm 65 | data_split: "val" 66 | norm_type: "frobenius" 67 | components: 68 | - component_name: ov_circuit 69 | data_type: "weights" 70 | layer_suffixes: 71 | output_layer: "attention.o_proj" 72 | value_layer: "attention.v_proj" 73 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 74 | - component_name: ov_circuit 75 | data_type: "gradients" 76 | layer_suffixes: 77 | output_layer: 
"attention.o_proj" 78 | value_layer: "attention.v_proj" 79 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 80 | 81 | 82 | steps: 83 | start: 0 84 | end: 50_000 85 | step: 5000 86 | 87 | monitoring: 88 | output_dir: "analysis_results" 89 | 90 | save_to_wandb: true 91 | wandb: 92 | entity: "pico-lm" 93 | project: "pico-analysis" 94 | -------------------------------------------------------------------------------- /configs/pico-decoder-small-ov.yaml: -------------------------------------------------------------------------------- 1 | analysis_name: "pico-decoder-small-ov-analysis" 2 | 3 | metrics: 4 | 5 | - metric_name: cka 6 | target_checkpoint: 50_000 7 | data_split: "val" 8 | components: 9 | - component_name: ov_circuit 10 | data_type: "activations" 11 | layer_suffixes: 12 | output_layer: "attention.o_proj" 13 | value_layer: "attention.v_proj" 14 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 15 | 16 | - metric_name: per 17 | data_split: "val" 18 | components: 19 | - component_name: ov_circuit 20 | data_type: "weights" 21 | layer_suffixes: 22 | output_layer: "attention.o_proj" 23 | value_layer: "attention.v_proj" 24 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 25 | - component_name: ov_circuit 26 | data_type: "gradients" 27 | layer_suffixes: 28 | output_layer: "attention.o_proj" 29 | value_layer: "attention.v_proj" 30 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 31 | 32 | - metric_name: condition_number 33 | data_split: "val" 34 | components: 35 | - component_name: ov_circuit 36 | data_type: "weights" 37 | layer_suffixes: 38 | output_layer: "attention.o_proj" 39 | value_layer: "attention.v_proj" 40 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 41 | - component_name: ov_circuit 42 | data_type: "gradients" 43 | layer_suffixes: 44 | output_layer: "attention.o_proj" 45 | value_layer: "attention.v_proj" 46 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 47 | 48 | - metric_name: gini 49 | data_split: "val" 50 | components: 51 | - component_name: ov_circuit 52 | data_type: "weights" 53 | layer_suffixes: 54 | output_layer: "attention.o_proj" 55 | value_layer: "attention.v_proj" 56 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 57 | - component_name: ov_circuit 58 | data_type: "gradients" 59 | layer_suffixes: 60 | output_layer: "attention.o_proj" 61 | value_layer: "attention.v_proj" 62 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 63 | 64 | - metric_name: norm 65 | data_split: "val" 66 | norm_type: "frobenius" 67 | components: 68 | - component_name: ov_circuit 69 | data_type: "weights" 70 | layer_suffixes: 71 | output_layer: "attention.o_proj" 72 | value_layer: "attention.v_proj" 73 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 74 | - component_name: ov_circuit 75 | data_type: "gradients" 76 | layer_suffixes: 77 | output_layer: "attention.o_proj" 78 | value_layer: "attention.v_proj" 79 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 80 | 81 | 82 | steps: 83 | start: 0 84 | end: 50_000 85 | step: 5000 86 | 87 | monitoring: 88 | output_dir: "analysis_results" 89 | 90 | save_to_wandb: true 91 | wandb: 92 | entity: "pico-lm" 93 | project: "pico-analysis" 94 | -------------------------------------------------------------------------------- /configs/pico-decoder-tiny-ov.yaml: -------------------------------------------------------------------------------- 1 | analysis_name: "pico-decoder-tiny-ov-analysis" 2 | 3 | metrics: 4 | 5 | - metric_name: cka 6 | target_checkpoint: 50_000 7 | data_split: "val" 8 | components: 9 | - component_name: ov_circuit 10 | data_type: "activations" 11 | layer_suffixes: 12 | output_layer: "attention.o_proj" 13 | value_layer: "attention.v_proj" 14 | layers: 
[0,1,2,3,4,5,6,7,8,9,10,11] 15 | 16 | - metric_name: per 17 | data_split: "val" 18 | components: 19 | - component_name: ov_circuit 20 | data_type: "weights" 21 | layer_suffixes: 22 | output_layer: "attention.o_proj" 23 | value_layer: "attention.v_proj" 24 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 25 | - component_name: ov_circuit 26 | data_type: "gradients" 27 | layer_suffixes: 28 | output_layer: "attention.o_proj" 29 | value_layer: "attention.v_proj" 30 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 31 | 32 | - metric_name: condition_number 33 | data_split: "val" 34 | components: 35 | - component_name: ov_circuit 36 | data_type: "weights" 37 | layer_suffixes: 38 | output_layer: "attention.o_proj" 39 | value_layer: "attention.v_proj" 40 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 41 | - component_name: ov_circuit 42 | data_type: "gradients" 43 | layer_suffixes: 44 | output_layer: "attention.o_proj" 45 | value_layer: "attention.v_proj" 46 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 47 | 48 | - metric_name: gini 49 | data_split: "val" 50 | components: 51 | - component_name: ov_circuit 52 | data_type: "weights" 53 | layer_suffixes: 54 | output_layer: "attention.o_proj" 55 | value_layer: "attention.v_proj" 56 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 57 | - component_name: ov_circuit 58 | data_type: "gradients" 59 | layer_suffixes: 60 | output_layer: "attention.o_proj" 61 | value_layer: "attention.v_proj" 62 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 63 | 64 | - metric_name: norm 65 | data_split: "val" 66 | norm_type: "frobenius" 67 | components: 68 | - component_name: ov_circuit 69 | data_type: "weights" 70 | layer_suffixes: 71 | output_layer: "attention.o_proj" 72 | value_layer: "attention.v_proj" 73 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 74 | - component_name: ov_circuit 75 | data_type: "gradients" 76 | layer_suffixes: 77 | output_layer: "attention.o_proj" 78 | value_layer: "attention.v_proj" 79 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 80 | 81 | 82 | steps: 83 | start: 0 84 | end: 50_000 85 | step: 5000 86 | 87 | monitoring: 88 | output_dir: "analysis_results" 89 | 90 | save_to_wandb: true 91 | wandb: 92 | entity: "pico-lm" 93 | project: "pico-analysis" 94 | -------------------------------------------------------------------------------- /lib/cka.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | """ 4 | Original CKA Implementation, from the authors of the CKA paper: 5 | Similarity of Neural Network Representations Revisited (https://arxiv.org/abs/1905.00414) 6 | 7 | The source code was uploaded only as a colab notebook, so we have copied it here. 8 | Copied from https://colab.research.google.com/github/google-research/google-research/blob/master/representation_similarity/Demo.ipynb#scrollTo=MkucRi3yn7UJ 9 | """ 10 | 11 | 12 | def gram_linear(x): 13 | """Compute Gram (kernel) matrix for a linear kernel. 14 | 15 | Args: 16 | x: A num_examples x num_features matrix of features. 17 | 18 | Returns: 19 | A num_examples x num_examples Gram matrix of examples. 20 | """ 21 | return x.dot(x.T) 22 | 23 | 24 | def gram_rbf(x, threshold=1.0): 25 | """Compute Gram (kernel) matrix for an RBF kernel. 26 | 27 | Args: 28 | x: A num_examples x num_features matrix of features. 29 | threshold: Fraction of median Euclidean distance to use as RBF kernel 30 | bandwidth. (This is the heuristic we use in the paper. There are other 31 | possible ways to set the bandwidth; we didn't try them.) 32 | 33 | Returns: 34 | A num_examples x num_examples Gram matrix of examples. 
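  Note: as implemented below, the kernel evaluates to
  exp(-d2_ij / (2 * threshold**2 * m)), where d2_ij is the squared
  Euclidean distance between examples i and j and m is the median of the
  squared distances, so `threshold` rescales a median-based bandwidth.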
35 | """ 36 | dot_products = x.dot(x.T) 37 | sq_norms = np.diag(dot_products) 38 | sq_distances = -2 * dot_products + sq_norms[:, None] + sq_norms[None, :] 39 | sq_median_distance = np.median(sq_distances) 40 | return np.exp(-sq_distances / (2 * threshold**2 * sq_median_distance)) 41 | 42 | 43 | def center_gram(gram, unbiased=False): 44 | """Center a symmetric Gram matrix. 45 | 46 | This is equvialent to centering the (possibly infinite-dimensional) features 47 | induced by the kernel before computing the Gram matrix. 48 | 49 | Args: 50 | gram: A num_examples x num_examples symmetric matrix. 51 | unbiased: Whether to adjust the Gram matrix in order to compute an unbiased 52 | estimate of HSIC. Note that this estimator may be negative. 53 | 54 | Returns: 55 | A symmetric matrix with centered columns and rows. 56 | """ 57 | if not np.allclose(gram, gram.T): 58 | raise ValueError("Input must be a symmetric matrix.") 59 | gram = gram.copy() 60 | 61 | if unbiased: 62 | # This formulation of the U-statistic, from Szekely, G. J., & Rizzo, M. 63 | # L. (2014). Partial distance correlation with methods for dissimilarities. 64 | # The Annals of Statistics, 42(6), 2382-2412, seems to be more numerically 65 | # stable than the alternative from Song et al. (2007). 66 | n = gram.shape[0] 67 | np.fill_diagonal(gram, 0) 68 | means = np.sum(gram, 0, dtype=np.float64) / (n - 2) 69 | means -= np.sum(means) / (2 * (n - 1)) 70 | gram -= means[:, None] 71 | gram -= means[None, :] 72 | np.fill_diagonal(gram, 0) 73 | else: 74 | means = np.mean(gram, 0, dtype=np.float64) 75 | means -= np.mean(means) / 2 76 | gram -= means[:, None] 77 | gram -= means[None, :] 78 | 79 | return gram 80 | 81 | 82 | def cka(gram_x, gram_y, debiased=False): 83 | """Compute CKA. 84 | 85 | Args: 86 | gram_x: A num_examples x num_examples Gram matrix. 87 | gram_y: A num_examples x num_examples Gram matrix. 88 | debiased: Use unbiased estimator of HSIC. CKA may still be biased. 89 | 90 | Returns: 91 | The value of CKA between X and Y. 92 | """ 93 | gram_x = center_gram(gram_x, unbiased=debiased) 94 | gram_y = center_gram(gram_y, unbiased=debiased) 95 | 96 | # Note: To obtain HSIC, this should be divided by (n-1)**2 (biased variant) or 97 | # n*(n-3) (unbiased variant), but this cancels for CKA. 98 | scaled_hsic = gram_x.ravel().dot(gram_y.ravel()) 99 | 100 | normalization_x = np.linalg.norm(gram_x) 101 | normalization_y = np.linalg.norm(gram_y) 102 | return float(scaled_hsic / (normalization_x * normalization_y)) 103 | 104 | 105 | def _debiased_dot_product_similarity_helper( 106 | xty, sum_squared_rows_x, sum_squared_rows_y, squared_norm_x, squared_norm_y, n 107 | ): 108 | """Helper for computing debiased dot product similarity (i.e. linear HSIC).""" 109 | # This formula can be derived by manipulating the unbiased estimator from 110 | # Song et al. (2007). 111 | return ( 112 | xty 113 | - n / (n - 2.0) * sum_squared_rows_x.dot(sum_squared_rows_y) 114 | + squared_norm_x * squared_norm_y / ((n - 1) * (n - 2)) 115 | ) 116 | 117 | 118 | def feature_space_linear_cka(features_x, features_y, debiased=False): 119 | """Compute CKA with a linear kernel, in feature space. 120 | 121 | This is typically faster than computing the Gram matrix when there are fewer 122 | features than examples. 123 | 124 | Args: 125 | features_x: A num_examples x num_features matrix of features. 126 | features_y: A num_examples x num_features matrix of features. 127 | debiased: Use unbiased estimator of dot product similarity. CKA may still be 128 | biased. 
Note that this estimator may be negative. 129 | 130 | Returns: 131 | The value of CKA between X and Y. 132 | """ 133 | features_x = features_x - np.mean(features_x, 0, keepdims=True) 134 | features_y = features_y - np.mean(features_y, 0, keepdims=True) 135 | 136 | dot_product_similarity = np.linalg.norm(features_x.T.dot(features_y)) ** 2 137 | normalization_x = np.linalg.norm(features_x.T.dot(features_x)) 138 | normalization_y = np.linalg.norm(features_y.T.dot(features_y)) 139 | 140 | if debiased: 141 | n = features_x.shape[0] 142 | # Equivalent to np.sum(features_x ** 2, 1) but avoids an intermediate array. 143 | sum_squared_rows_x = np.einsum("ij,ij->i", features_x, features_x) 144 | sum_squared_rows_y = np.einsum("ij,ij->i", features_y, features_y) 145 | squared_norm_x = np.sum(sum_squared_rows_x) 146 | squared_norm_y = np.sum(sum_squared_rows_y) 147 | 148 | dot_product_similarity = _debiased_dot_product_similarity_helper( 149 | dot_product_similarity, 150 | sum_squared_rows_x, 151 | sum_squared_rows_y, 152 | squared_norm_x, 153 | squared_norm_y, 154 | n, 155 | ) 156 | normalization_x = np.sqrt( 157 | _debiased_dot_product_similarity_helper( 158 | normalization_x**2, 159 | sum_squared_rows_x, 160 | sum_squared_rows_x, 161 | squared_norm_x, 162 | squared_norm_x, 163 | n, 164 | ) 165 | ) 166 | normalization_y = np.sqrt( 167 | _debiased_dot_product_similarity_helper( 168 | normalization_y**2, 169 | sum_squared_rows_y, 170 | sum_squared_rows_y, 171 | squared_norm_y, 172 | squared_norm_y, 173 | n, 174 | ) 175 | ) 176 | 177 | return float(dot_product_similarity / (normalization_x * normalization_y)) 178 | -------------------------------------------------------------------------------- /lib/svcca/cca_core.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | """ 17 | The core code for applying Canonical Correlation Analysis to deep networks. 18 | 19 | This module contains the core functions to apply canonical correlation analysis 20 | to deep neural networks. The main function is get_cca_similarity, which takes in 21 | two sets of activations, typically the neurons in two layers and their outputs 22 | on all of the datapoints D = [d_1,...,d_m] that have been passed through. 23 | 24 | Inputs have shape (num_neurons1, m), (num_neurons2, m). This can be directly 25 | applied to fully connected networks. For convolutional layers, the 3d block 26 | of neurons can either be flattened entirely, along channels, or alternatively, 27 | the dft_ccas (Discrete Fourier Transform) module can be used. 28 | 29 | See: 30 | https://arxiv.org/abs/1706.05806 31 | https://arxiv.org/abs/1806.05759 32 | for full details.
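Minimal usage sketch (illustrative only; shapes are hypothetical, and
activations recorded as (num_datapoints, num_neurons) should be transposed
into the (num_neurons, num_datapoints) orientation expected here):

    import numpy as np
    from lib.svcca.cca_core import get_cca_similarity

    acts1 = np.random.randn(64, 2000)  # (num_neurons1, num_datapoints)
    acts2 = np.random.randn(48, 2000)  # (num_neurons2, num_datapoints)
    result = get_cca_similarity(acts1, acts2, epsilon=1e-6, verbose=False)
    print(result["mean"])  # summary statistics of the cca coefficients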
33 | 34 | """ 35 | 36 | from __future__ import absolute_import, division, print_function 37 | 38 | import numpy as np 39 | 40 | num_cca_trials = 10 41 | 42 | 43 | def positivedef_matrix_sqrt(array): 44 | """Stable method for computing matrix square roots, supports complex matrices. 45 | 46 | Args: 47 | array: A numpy 2d array, can be complex valued that is a positive 48 | definite symmetric (or hermitian) matrix 49 | 50 | Returns: 51 | sqrtarray: The matrix square root of array 52 | """ 53 | w, v = np.linalg.eigh(array) 54 | # A - np.dot(v, np.dot(np.diag(w), v.T)) 55 | wsqrt = np.sqrt(w) 56 | sqrtarray = np.dot(v, np.dot(np.diag(wsqrt), np.conj(v).T)) 57 | return sqrtarray 58 | 59 | 60 | def remove_small(sigma_xx, sigma_xy, sigma_yx, sigma_yy, epsilon): 61 | """Takes covariance between X, Y, and removes values of small magnitude. 62 | 63 | Args: 64 | sigma_xx: 2d numpy array, variance matrix for x 65 | sigma_xy: 2d numpy array, crossvariance matrix for x,y 66 | sigma_yx: 2d numpy array, crossvariance matrixy for x,y, 67 | (conjugate) transpose of sigma_xy 68 | sigma_yy: 2d numpy array, variance matrix for y 69 | epsilon : cutoff value for norm below which directions are thrown 70 | away 71 | 72 | Returns: 73 | sigma_xx_crop: 2d array with low x norm directions removed 74 | sigma_xy_crop: 2d array with low x and y norm directions removed 75 | sigma_yx_crop: 2d array with low x and y norm directiosn removed 76 | sigma_yy_crop: 2d array with low y norm directions removed 77 | x_idxs: indexes of sigma_xx that were removed 78 | y_idxs: indexes of sigma_yy that were removed 79 | """ 80 | 81 | x_diag = np.abs(np.diagonal(sigma_xx)) 82 | y_diag = np.abs(np.diagonal(sigma_yy)) 83 | x_idxs = x_diag >= epsilon 84 | y_idxs = y_diag >= epsilon 85 | 86 | sigma_xx_crop = sigma_xx[x_idxs][:, x_idxs] 87 | sigma_xy_crop = sigma_xy[x_idxs][:, y_idxs] 88 | sigma_yx_crop = sigma_yx[y_idxs][:, x_idxs] 89 | sigma_yy_crop = sigma_yy[y_idxs][:, y_idxs] 90 | 91 | return (sigma_xx_crop, sigma_xy_crop, sigma_yx_crop, sigma_yy_crop, x_idxs, y_idxs) 92 | 93 | 94 | def compute_ccas(sigma_xx, sigma_xy, sigma_yx, sigma_yy, epsilon, verbose=True): 95 | """Main cca computation function, takes in variances and crossvariances. 96 | 97 | This function takes in the covariances and cross covariances of X, Y, 98 | preprocesses them (removing small magnitudes) and outputs the raw results of 99 | the cca computation, including cca directions in a rotated space, and the 100 | cca correlation coefficient values. 101 | 102 | Args: 103 | sigma_xx: 2d numpy array, (num_neurons_x, num_neurons_x) 104 | variance matrix for x 105 | sigma_xy: 2d numpy array, (num_neurons_x, num_neurons_y) 106 | crossvariance matrix for x,y 107 | sigma_yx: 2d numpy array, (num_neurons_y, num_neurons_x) 108 | crossvariance matrix for x,y (conj) transpose of sigma_xy 109 | sigma_yy: 2d numpy array, (num_neurons_y, num_neurons_y) 110 | variance matrix for y 111 | epsilon: small float to help with stabilizing computations 112 | verbose: boolean on whether to print intermediate outputs 113 | 114 | Returns: 115 | [ux, sx, vx]: [numpy 2d array, numpy 1d array, numpy 2d array] 116 | ux and vx are (conj) transposes of each other, being 117 | the canonical directions in the X subspace. 118 | sx is the set of canonical correlation coefficients- 119 | how well corresponding directions in vx, Vy correlate 120 | with each other. 
121 | [uy, sy, vy]: Same as above, but for Y space 122 | invsqrt_xx: Inverse square root of sigma_xx to transform canonical 123 | directions back to original space 124 | invsqrt_yy: Same as above but for sigma_yy 125 | x_idxs: Boolean mask over the input sigma_xx directions, True 126 | where kept by remove_small 127 | y_idxs: Same as above but for sigma_yy 128 | """ 129 | 130 | (sigma_xx, sigma_xy, sigma_yx, sigma_yy, x_idxs, y_idxs) = remove_small( 131 | sigma_xx, sigma_xy, sigma_yx, sigma_yy, epsilon 132 | ) 133 | 134 | numx = sigma_xx.shape[0] 135 | numy = sigma_yy.shape[0] 136 | 137 | if numx == 0 or numy == 0: 138 | return ( 139 | [0, 0, 0], 140 | [0, 0, 0], 141 | np.zeros_like(sigma_xx), 142 | np.zeros_like(sigma_yy), 143 | x_idxs, 144 | y_idxs, 145 | ) 146 | 147 | if verbose: 148 | print("adding eps to diagonal and taking inverse") 149 | sigma_xx += epsilon * np.eye(numx) 150 | sigma_yy += epsilon * np.eye(numy) 151 | inv_xx = np.linalg.pinv(sigma_xx) 152 | inv_yy = np.linalg.pinv(sigma_yy) 153 | 154 | if verbose: 155 | print("taking square root") 156 | invsqrt_xx = positivedef_matrix_sqrt(inv_xx) 157 | invsqrt_yy = positivedef_matrix_sqrt(inv_yy) 158 | 159 | if verbose: 160 | print("dot products...") 161 | arr = np.dot(invsqrt_xx, np.dot(sigma_xy, invsqrt_yy)) 162 | 163 | if verbose: 164 | print("trying to take final svd") 165 | u, s, v = np.linalg.svd(arr) 166 | 167 | if verbose: 168 | print("computed everything!") 169 | 170 | return [u, np.abs(s), v], invsqrt_xx, invsqrt_yy, x_idxs, y_idxs 171 | 172 | 173 | def sum_threshold(array, threshold): 174 | """Computes threshold index of decreasing nonnegative array by summing. 175 | 176 | This function takes in a decreasing array of nonnegative floats, and a 177 | threshold between 0 and 1. It returns the index i at which the sum of the 178 | array up to i is threshold*total mass of the array. 179 | 180 | Args: 181 | array: a 1d numpy array of decreasing, nonnegative floats 182 | threshold: a number between 0 and 1 183 | 184 | Returns: 185 | i: first index at which np.sum(array[:i]) / np.sum(array) >= threshold 186 | """ 187 | assert (threshold >= 0) and (threshold <= 1), "incorrect threshold" 188 | 189 | for i in range(len(array)): 190 | if np.sum(array[:i]) / np.sum(array) >= threshold: 191 | return i 192 | 193 | 194 | def create_zero_dict(compute_dirns, dimension): 195 | """Outputs a zero dict when neuron activation norms are too small. 196 | 197 | This function creates a return_dict with appropriately shaped zero entries 198 | when all neuron activations are very small. 199 | 200 | Args: 201 | compute_dirns: boolean, whether to have zero vectors for directions 202 | dimension: int, defines shape of directions 203 | 204 | Returns: 205 | return_dict: a dict of appropriately shaped zero entries 206 | """ 207 | return_dict = {} 208 | return_dict["mean"] = (np.asarray(0), np.asarray(0)) 209 | return_dict["sum"] = (np.asarray(0), np.asarray(0)) 210 | return_dict["cca_coef1"] = np.asarray(0) 211 | return_dict["cca_coef2"] = np.asarray(0) 212 | return_dict["idx1"] = 0 213 | return_dict["idx2"] = 0 214 | 215 | if compute_dirns: 216 | return_dict["cca_dirns1"] = np.zeros((1, dimension)) 217 | return_dict["cca_dirns2"] = np.zeros((1, dimension)) 218 | 219 | return return_dict 220 | 221 | 222 | def get_cca_similarity( 223 | acts1, 224 | acts2, 225 | epsilon=0.0, 226 | threshold=0.98, 227 | compute_coefs=True, 228 | compute_dirns=False, 229 | verbose=True, 230 | ): 231 | """The main function for computing cca similarities.
232 | 233 | This function computes the cca similarity between two sets of activations, 234 | returning a dict with the cca coefficients, a few statistics of the cca 235 | coefficients, and (optionally) the actual directions. 236 | 237 | Args: 238 | acts1: (num_neurons1, data_points) a 2d numpy array of neurons by 239 | datapoints where entry (i,j) is the output of neuron i on 240 | datapoint j. 241 | acts2: (num_neurons2, data_points) same as above, but (potentially) 242 | for a different set of neurons. Note that acts1 and acts2 243 | can have different numbers of neurons, but must agree on the 244 | number of datapoints 245 | 246 | epsilon: small float to help stabilize computations 247 | 248 | threshold: float between 0, 1 used to get rid of trailing zeros in 249 | the cca correlation coefficients to output more accurate 250 | summary statistics of correlations. 251 | 252 | 253 | compute_coefs: boolean value determining whether coefficients 254 | over neurons are computed. Needed for computing 255 | directions 256 | 257 | compute_dirns: boolean value determining whether actual cca 258 | directions are computed. (For very large numbers of neurons 259 | and datasets, it may be better to compute these on the fly 260 | instead of storing them in memory.) 261 | 262 | verbose: Boolean, whether intermediate outputs are printed 263 | 264 | Returns: 265 | return_dict: A dictionary with outputs from the cca computations. 266 | Contains neuron coefficients (combinations of neurons 267 | that correspond to cca directions), the cca correlation 268 | coefficients (how well aligned directions correlate), 269 | x and y idxs (for computing cca directions on the fly 270 | if compute_dirns=False), and summary statistics. If 271 | compute_dirns=True, the cca directions are also 272 | computed.
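    Note: the assertions below require acts1.shape[0] < acts1.shape[1],
    i.e. inputs in (num_neurons, num_datapoints) orientation; transpose
    activations collected as (num_datapoints, num_neurons) before calling.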
273 | """ 274 | 275 | # assert dimensionality equal 276 | assert acts1.shape[1] == acts2.shape[1], "dimensions don't match" 277 | # check that acts1, acts2 are transposition 278 | assert acts1.shape[0] < acts1.shape[1], ( 279 | "input must be number of neurons" "by datapoints" 280 | ) 281 | return_dict = {} 282 | 283 | # compute covariance with numpy function for extra stability 284 | numx = acts1.shape[0] 285 | numy = acts2.shape[0] 286 | 287 | covariance = np.cov(acts1, acts2) 288 | sigmaxx = covariance[:numx, :numx] 289 | sigmaxy = covariance[:numx, numx:] 290 | sigmayx = covariance[numx:, :numx] 291 | sigmayy = covariance[numx:, numx:] 292 | 293 | # rescale covariance to make cca computation more stable 294 | xmax = np.max(np.abs(sigmaxx)) 295 | ymax = np.max(np.abs(sigmayy)) 296 | sigmaxx /= xmax 297 | sigmayy /= ymax 298 | sigmaxy /= np.sqrt(xmax * ymax) 299 | sigmayx /= np.sqrt(xmax * ymax) 300 | 301 | ([u, s, v], invsqrt_xx, invsqrt_yy, x_idxs, y_idxs) = compute_ccas( 302 | sigmaxx, sigmaxy, sigmayx, sigmayy, epsilon=epsilon, verbose=verbose 303 | ) 304 | 305 | # if x_idxs or y_idxs is all false, return_dict has zero entries 306 | if (not np.any(x_idxs)) or (not np.any(y_idxs)): 307 | return create_zero_dict(compute_dirns, acts1.shape[1]) 308 | 309 | if compute_coefs: 310 | # also compute full coefficients over all neurons 311 | x_mask = np.dot(x_idxs.reshape((-1, 1)), x_idxs.reshape((1, -1))) 312 | y_mask = np.dot(y_idxs.reshape((-1, 1)), y_idxs.reshape((1, -1))) 313 | 314 | return_dict["coef_x"] = u.T 315 | return_dict["invsqrt_xx"] = invsqrt_xx 316 | return_dict["full_coef_x"] = np.zeros((numx, numx)) 317 | np.place(return_dict["full_coef_x"], x_mask, return_dict["coef_x"]) 318 | return_dict["full_invsqrt_xx"] = np.zeros((numx, numx)) 319 | np.place(return_dict["full_invsqrt_xx"], x_mask, return_dict["invsqrt_xx"]) 320 | 321 | return_dict["coef_y"] = v 322 | return_dict["invsqrt_yy"] = invsqrt_yy 323 | return_dict["full_coef_y"] = np.zeros((numy, numy)) 324 | np.place(return_dict["full_coef_y"], y_mask, return_dict["coef_y"]) 325 | return_dict["full_invsqrt_yy"] = np.zeros((numy, numy)) 326 | np.place(return_dict["full_invsqrt_yy"], y_mask, return_dict["invsqrt_yy"]) 327 | 328 | # compute means 329 | neuron_means1 = np.mean(acts1, axis=1, keepdims=True) 330 | neuron_means2 = np.mean(acts2, axis=1, keepdims=True) 331 | return_dict["neuron_means1"] = neuron_means1 332 | return_dict["neuron_means2"] = neuron_means2 333 | 334 | if compute_dirns: 335 | # orthonormal directions that are CCA directions 336 | cca_dirns1 = ( 337 | np.dot( 338 | np.dot(return_dict["full_coef_x"], return_dict["full_invsqrt_xx"]), 339 | (acts1 - neuron_means1), 340 | ) 341 | + neuron_means1 342 | ) 343 | cca_dirns2 = ( 344 | np.dot( 345 | np.dot(return_dict["full_coef_y"], return_dict["full_invsqrt_yy"]), 346 | (acts2 - neuron_means2), 347 | ) 348 | + neuron_means2 349 | ) 350 | 351 | # get rid of trailing zeros in the cca coefficients 352 | idx1 = sum_threshold(s, threshold) 353 | idx2 = sum_threshold(s, threshold) 354 | 355 | return_dict["cca_coef1"] = s 356 | return_dict["cca_coef2"] = s 357 | return_dict["x_idxs"] = x_idxs 358 | return_dict["y_idxs"] = y_idxs 359 | # summary statistics 360 | return_dict["mean"] = (np.mean(s[:idx1]), np.mean(s[:idx2])) 361 | return_dict["sum"] = (np.sum(s), np.sum(s)) 362 | 363 | if compute_dirns: 364 | return_dict["cca_dirns1"] = cca_dirns1 365 | return_dict["cca_dirns2"] = cca_dirns2 366 | 367 | return return_dict 368 | 369 | 370 | def robust_cca_similarity(acts1, 
acts2, epsilon=1e-6, **kwargs): 371 | """Calls get_cca_similarity multiple times while adding noise. 372 | 373 | This function is very similar to get_cca_similarity, and can be used if 374 | get_cca_similarity doesn't converge for some pair of inputs. This function 375 | adds some noise to the activations to help convergence. 376 | 377 | Args: 378 | acts1: (num_neurons1, data_points) a 2d numpy array of neurons by 379 | datapoints where entry (i,j) is the output of neuron i on 380 | datapoint j. 381 | acts2: (num_neurons2, data_points) same as above, but (potentially) 382 | for a different set of neurons. Note that acts1 and acts2 383 | can have different numbers of neurons, but must agree on the 384 | number of datapoints 385 | 386 | threshold: float between 0, 1 used to get rid of trailing zeros in 387 | the cca correlation coefficients to output more accurate 388 | summary statistics of correlations. 389 | 390 | epsilon: small float to help stabilize computations 391 | 392 | compute_dirns: boolean value determining whether actual cca 393 | directions are computed. (For very large numbers of neurons 394 | and datasets, it may be better to compute these on the fly 395 | instead of storing them in memory.) 396 | 397 | Returns: 398 | return_dict: A dictionary with outputs from the cca computations. 399 | Contains neuron coefficients (combinations of neurons 400 | that correspond to cca directions), the cca correlation 401 | coefficients (how well aligned directions correlate), 402 | x and y idxs (for computing cca directions on the fly 403 | if compute_dirns=False), and summary statistics. If 404 | compute_dirns=True, the cca directions are also 405 | computed. 406 | """ 407 | 408 | for trial in range(num_cca_trials): 409 | try: 410 | return_dict = get_cca_similarity(acts1, acts2, **kwargs) 411 | except np.linalg.LinAlgError: 412 | acts1 = acts1 * 1e-1 + np.random.normal(size=acts1.shape) * epsilon 413 | acts2 = acts2 * 1e-1 + np.random.normal(size=acts2.shape) * epsilon 414 | if trial + 1 == num_cca_trials: 415 | raise 416 | 417 | return return_dict 418 | -------------------------------------------------------------------------------- /lib/svcca/pwcca.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | """ 17 | The function for computing projection weightings. 18 | 19 | See: 20 | https://arxiv.org/abs/1806.05759 21 | for full details.
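Minimal usage sketch (illustrative only; hypothetical shapes, same
(neurons, num_datapoints) orientation as lib/svcca/cca_core.py):

    import numpy as np
    from lib.svcca.pwcca import compute_pwcca

    acts1 = np.random.randn(64, 2000)
    acts2 = np.random.randn(64, 2000)
    pwcca_mean, weights, coefs = compute_pwcca(acts1, acts2, epsilon=1e-6)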
22 | 
23 | """
24 | 
25 | import numpy as np
26 | 
27 | from lib.svcca.cca_core import robust_cca_similarity
28 | 
29 | 
30 | def compute_pwcca(acts1, acts2, epsilon=1e-6):
31 |     """Computes projection weighting for weighting CCA coefficients
32 | 
33 |     Args:
34 |         acts1: 2d numpy array, shaped (neurons, num_datapoints)
35 |         acts2: 2d numpy array, shaped (neurons, num_datapoints)
36 | 
37 |     Returns:
38 |         Original cca coefficient mean and weighted mean
39 | 
40 |     """
41 |     sresults = robust_cca_similarity(
42 |         acts1,
43 |         acts2,
44 |         epsilon=epsilon,
45 |         compute_dirns=False,
46 |         compute_coefs=True,
47 |         verbose=False,
48 |     )
49 | 
50 |     if np.sum(sresults["x_idxs"]) <= np.sum(sresults["y_idxs"]):
51 |         dirns = (
52 |             np.dot(
53 |                 sresults["coef_x"],
54 |                 (
55 |                     acts1[sresults["x_idxs"]]
56 |                     - sresults["neuron_means1"][sresults["x_idxs"]]
57 |                 ),
58 |             )
59 |             + sresults["neuron_means1"][sresults["x_idxs"]]
60 |         )
61 |         coefs = sresults["cca_coef1"]
62 |         acts = acts1
63 |         idxs = sresults["x_idxs"]
64 |     else:
65 |         dirns = (
66 |             np.dot(
67 |                 sresults["coef_y"],
68 |                 (
69 |                     acts2[sresults["y_idxs"]]
70 |                     - sresults["neuron_means2"][sresults["y_idxs"]]
71 |                 ),
72 |             )
73 |             + sresults["neuron_means2"][sresults["y_idxs"]]
74 |         )
75 |         coefs = sresults["cca_coef2"]
76 |         acts = acts2
77 |         idxs = sresults["y_idxs"]
78 |     P, _ = np.linalg.qr(dirns.T)
79 |     weights = np.sum(np.abs(np.dot(P.T, acts[idxs].T)), axis=1)
80 |     weights = weights / np.sum(weights)
81 | 
82 |     return np.sum(weights * coefs), weights, coefs
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "pico-analyze"
3 | version = "1.0.0"
4 | description = "A companion toolkit to pico-train for quantifying, comparing, and visualizing how language models evolve during training."
5 | authors = ["Richard Diehl Martinez "]
6 | license = "Apache 2.0"
7 | readme = "README.md"
8 | packages = [{include = "src"}]
9 | 
10 | [tool.poetry.dependencies]
11 | python = "^3.10"
12 | huggingface-hub = {extras = ["cli"], version = "^0.25.1"}
13 | torch = "^2.5.1"
14 | datasets = "^3.2.0"
15 | rich = "^13.9.4"
16 | pre-commit = "^4.2.0"
17 | wandb = "^0.19.5"
18 | 
19 | [tool.poetry.scripts]
20 | analyze = "scripts.analyze:main"
21 | 
22 | [tool.poetry.group.dev.dependencies]
23 | ipykernel = "^6.29.5"
24 | jupyter = "^1.1.1"
25 | 
26 | [build-system]
27 | requires = ["poetry-core"]
28 | build-backend = "poetry.core.masonry.api"
--------------------------------------------------------------------------------
/scripts/analyze.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | The main script for running learning dynamics analysis.
4 | 
5 | Given a metrics config and a trained model, this script will load in the model at different
6 | checkpoints and compute the specified learning dynamics metrics.
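Example invocation (the config path and repository id are illustrative; the
flags are defined on main() below):

    poetry run analyze --config_path configs/demo.yaml --repo_id pico-lm/demo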
7 | """ 8 | 9 | import json 10 | import os 11 | from dataclasses import asdict 12 | 13 | import click 14 | 15 | from src.metrics import BaseComparativeMetric, get_metric 16 | from src.utils.data import get_checkpoint_states, get_training_config 17 | from src.utils.exceptions import InvalidStepError 18 | from src.utils.initialization import ( 19 | CheckpointLocation, 20 | initialize_config, 21 | initialize_logging, 22 | initialize_output_dir, 23 | initialize_wandb, 24 | ) 25 | from src.utils.logging import pretty_print_component_metrics, pretty_print_config 26 | 27 | 28 | @click.command() 29 | @click.option( 30 | "--config_path", 31 | type=str, 32 | required=True, 33 | help="Path to the metrics configuration file.", 34 | ) 35 | @click.option("--repo_id", type=str, help="Repository ID.") 36 | @click.option("--branch", type=str, help="Branch name.") 37 | @click.option("--run_path", type=str, help="Path to the run directory.") 38 | def main(config_path: str, repo_id: str, branch: str, run_path: str): 39 | """ 40 | The main function for running learning dynamics analysis. Also note that config_path is a 41 | required argument, AND either repo_id must be provided or branch and run_path must be provided. 42 | If this is not specified here, it will raise an error when the checkpoint location is 43 | initialized. 44 | 45 | Args: 46 | config_path: str -- the path to the metrics configuration file. (required) 47 | 48 | repo_id: str -- the repository id. 49 | branch: str -- the branch name. 50 | run_path: str -- the path to the run directory. 51 | """ 52 | 53 | # Loads in the metrics config (the config that specifies the metrics to compute) 54 | metrics_config = initialize_config(config_path) 55 | 56 | # A helper class that stores the checkpoint location (either a local run or a remote run on HF) 57 | # NOTE: this will raise an error if repo_id is not provided and branch and run_path are not 58 | # provided. 59 | checkpoint_location = CheckpointLocation(repo_id, branch, run_path) 60 | 61 | # Loads in the training config (the config that specifies the model architecture, etc.) for the 62 | # given checkpoint location. NOTE: we use this to automatically determine parts of the model 63 | # architecture (e.g. the hidden dimension, number of attention heads, etc.) 
64 |     training_config = get_training_config(checkpoint_location)
65 | 
66 |     ############################################################
67 |     #
68 |     # Monitoring Setup (Logging and Wandb)
69 |     #
70 |     ############################################################
71 | 
72 |     # Set up the output directory
73 |     output_dir = initialize_output_dir(metrics_config, training_config)
74 |     logger = initialize_logging(output_dir)
75 | 
76 |     # Log the learning dynamics and training configurations to the logger
77 |     logger.info("=" * 80)
78 |     logger.info("Initializing Pico Analysis")
79 |     logger.info("=" * 80)
80 | 
81 |     pretty_print_config(logger, "Learning Dynamics Config", asdict(metrics_config))
82 |     pretty_print_config(logger, "Training Config", training_config)
83 | 
84 |     logger.info("=" * 80 + "\n")
85 | 
86 |     # Set up the wandb run
87 |     if metrics_config.monitoring.save_to_wandb:
88 |         wandb_run = initialize_wandb(metrics_config)
89 | 
90 |     ############################################################
91 |     #
92 |     # Setting up Metrics
93 |     #
94 |     ############################################################
95 | 
96 |     metrics = {}
97 | 
98 |     # Setup all of the metrics
99 |     for metric_config in metrics_config.metrics:
100 |         # Sets up the metric specified in the metrics config
101 |         metric = get_metric(metric_config, training_config)
102 | 
103 |         # NOTE: if the metric is a comparative metric, we need to set the target checkpoint
104 |         # for the metric.
105 |         if isinstance(metric, BaseComparativeMetric):
106 |             target_checkpoint_states = get_checkpoint_states(
107 |                 checkpoint_location=checkpoint_location,
108 |                 step=metric_config.target_checkpoint,
109 |                 data_split=metric_config.data_split,
110 |             )
111 |             metric.set_target(target_checkpoint_states)
112 | 
113 |         metrics[metric_config.metric_name] = metric
114 | 
115 |     ############################################################
116 |     #
117 |     # Computing and Logging Metrics over the checkpoint steps
118 |     #
119 |     ############################################################
120 | 
121 |     # Computing the metrics for each step
122 |     for step in metrics_config.steps:
123 |         step_directory = os.path.join(output_dir, f"step_{step}")
124 |         os.makedirs(step_directory, exist_ok=True)
125 | 
126 |         step_metrics = {}
127 | 
128 |         for metric_name, metric in metrics.items():
129 |             try:
130 |                 checkpoint_states = get_checkpoint_states(
131 |                     checkpoint_location=checkpoint_location,
132 |                     step=step,
133 |                     data_split=metric.metric_config.data_split,
134 |                 )
135 |             except InvalidStepError:
136 |                 # NOTE: this can happen if the step is not available for the given data split;
137 |                 # e.g. most likely to happen for the last step of the training run if a metric
138 |                 # was not computed on the training data.
139 |                 logger.warning(
140 |                     f"Skipping step {step} for metric {metric_name} on split {metric.metric_config.data_split} because the checkpoint does not exist"
141 |                 )
142 |                 continue
143 | 
144 |             # NOTE: metric returns a list of dictionaries which corresponds to metric data
145 |             # for each component specified in the metrics config.
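            # e.g. (layer names and values are hypothetical):
            #   [{"model.0.swiglu.w_2.weights": 0.93, "model.1.swiglu.w_2.weights": 0.91}, ...]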
146 | component_metrics_list = metric(checkpoint_states) 147 | 148 | component_metrics_dict = {} 149 | for component_metrics in component_metrics_list: 150 | component_metrics_dict.update(component_metrics) 151 | 152 | step_metrics[metric_name] = component_metrics_dict 153 | 154 | # store out the data to the output directory 155 | with open( 156 | os.path.join( 157 | step_directory, 158 | f"{metric_name}_{metric.metric_config.data_split}.json", 159 | ), 160 | "w", 161 | ) as f: 162 | json.dump(component_metrics_dict, f) 163 | 164 | if metrics_config.monitoring.save_to_wandb: 165 | # Create a nested dictionary with metric name as prefix 166 | wandb_formatted_data = { 167 | f"{metric_name}_{metric.metric_config.data_split}/{layer}": value 168 | for layer, value in component_metrics_dict.items() 169 | } 170 | # Add the step information 171 | wandb_run.log(wandb_formatted_data, step=step) 172 | 173 | # Log out all of the metrics at the current step 174 | pretty_print_component_metrics(logger, step, step_metrics) 175 | 176 | 177 | if __name__ == "__main__": 178 | main() 179 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script sets up the project by installing dependencies, checking for a poetry environment, 3 | # and installing pre-commit hooks. 4 | 5 | # Add color and formatting variables at the top 6 | GREEN='\033[0;32m' 7 | BLUE='\033[0;34m' 8 | YELLOW='\033[1;33m' 9 | RED='\033[0;31m' 10 | NC='\033[0m' # No Color 11 | BOLD='\033[1m' 12 | 13 | # Initialize error tracking 14 | ERRORS_FOUND=0 15 | 16 | # Function for section headers 17 | print_section() { 18 | echo -e "\n${BOLD}${BLUE}=== $1 ===${NC}\n" 19 | } 20 | 21 | # Function for success messages 22 | print_success() { 23 | echo -e "${GREEN}✓ $1${NC}" 24 | } 25 | 26 | # Function for warnings 27 | print_warning() { 28 | echo -e "${YELLOW}⚠ $1${NC}" 29 | } 30 | 31 | # --- GIT LFS SETUP --- # 32 | print_section "Git LFS Setup" 33 | if ! command -v git-lfs &> /dev/null; then 34 | print_warning "git-lfs is not installed. Some model checkpointing functionality may not work correctly." 35 | ERRORS_FOUND=$((ERRORS_FOUND + 1)) 36 | 37 | # Check the operating system 38 | if [[ "$OSTYPE" == "darwin"* ]]; then 39 | # macOS 40 | echo -e "${YELLOW} You can install it using Homebrew:${NC}" 41 | echo " brew install git-lfs" 42 | elif [[ "$OSTYPE" == "linux-gnu"* ]]; then 43 | # Linux 44 | echo -e "${YELLOW} You can install it using your package manager:${NC}" 45 | if command -v apt-get &> /dev/null; then 46 | # Ubuntu/Debian 47 | echo " curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash" 48 | echo " sudo apt-get install git-lfs" 49 | elif command -v yum &> /dev/null; then 50 | # CentOS/RHEL 51 | echo " curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.rpm.sh | sudo bash" 52 | echo " sudo yum install git-lfs" 53 | else 54 | print_warning "Could not detect package manager. Please install git-lfs manually." 55 | fi 56 | else 57 | print_warning "Unsupported operating system. Please install git-lfs manually." 
58 | fi 59 | else 60 | git-lfs install 61 | print_success "git-lfs installed and initialized" 62 | fi 63 | 64 | # --- CUDA VERSION CHECK --- # 65 | print_section "CUDA Version Check" 66 | if command -v nvidia-smi &> /dev/null; then 67 | CUDA_VERSION=$(nvidia-smi | sed -n 's/.*CUDA Version: \([0-9.]*\).*/\1/p') 68 | 69 | if [[ -z "$CUDA_VERSION" ]]; then 70 | ERRORS_FOUND=$((ERRORS_FOUND + 1)) 71 | print_warning "nvidia-smi failed to communicate with the NVIDIA driver." 72 | echo -e "${YELLOW} Ensure that the latest NVIDIA driver is installed and running.${NC}" 73 | else 74 | MAJOR_VERSION=${CUDA_VERSION%.*} 75 | MINOR_VERSION=${CUDA_VERSION#*.} 76 | 77 | if [ "$MAJOR_VERSION" -lt 12 ] || ([ "$MAJOR_VERSION" -eq 12 ] && [ "$MINOR_VERSION" -lt 1 ]); then 78 | ERRORS_FOUND=$((ERRORS_FOUND + 1)) 79 | print_warning "CUDA version ${MAJOR_VERSION}.${MINOR_VERSION} detected." 80 | echo -e "${YELLOW} Some multi-node communication GPU features may not work properly.${NC}" 81 | echo -e "${YELLOW} CUDA version 12.1 or newer is recommended.${NC}" 82 | else 83 | print_success "CUDA version ${MAJOR_VERSION}.${MINOR_VERSION} detected" 84 | fi 85 | fi 86 | else 87 | ERRORS_FOUND=$((ERRORS_FOUND + 1)) 88 | print_warning "nvidia-smi not found. Unable to check CUDA version." 89 | echo -e "${YELLOW} Ensure that NVIDIA drivers and CUDA version at 12.1 or newer are installed for GPU support.${NC}" 90 | fi 91 | 92 | 93 | # ---- ENVIRONMENT VARIABLES ---- # 94 | print_section "Environment Variables" 95 | if [ -f .env ]; then 96 | print_success "Loading environment variables from .env..." 97 | source .env 98 | if [[ -n "$HF_TOKEN" && -n "$WANDB_API_KEY" ]]; then 99 | print_success "Both HF_TOKEN and WANDB_API_KEY are set and loaded!" 100 | else 101 | print_warning "One or both of HF_TOKEN and WANDB_API_KEY are not set." 102 | fi 103 | else 104 | print_warning "No .env file found." 105 | echo -e "${YELLOW} You might need to create one with HF_TOKEN and WANDB_API_KEY${NC}" 106 | echo -e "${YELLOW} Example .env contents:${NC}" 107 | echo " export HF_TOKEN=your_huggingface_token" 108 | echo " export WANDB_API_KEY=your_wandb_key" 109 | ERRORS_FOUND=$((ERRORS_FOUND + 1)) 110 | fi 111 | 112 | # ---- POETRY SETUP ---- # 113 | print_section "Poetry Setup" 114 | 115 | # First check if Poetry is installed 116 | if ! command -v poetry &> /dev/null; then 117 | echo "Poetry not found. Installing..." 118 | curl -sSL https://install.python-poetry.org | python3 - 119 | print_success "Poetry installed successfully" 120 | else 121 | print_success "Poetry already installed" 122 | fi 123 | 124 | # Then check for virtual environment 125 | if [ ! -d ".venv" ]; then 126 | echo "No virtual environment found. Creating one..." 127 | poetry config virtualenvs.in-project true 128 | poetry install --with dev 129 | print_success "Poetry environment created successfully" 130 | else 131 | print_success "Poetry environment already exists" 132 | fi 133 | 134 | # ---- PRE-COMMIT SETUP ---- # 135 | print_section "Pre-commit Setup" 136 | 137 | # Install pre-commit hooks 138 | echo "Installing pre-commit hooks..." 139 | poetry run pre-commit install 140 | print_success "Pre-commit hooks installed" 141 | 142 | # Run pre-commit hooks on all files 143 | echo "Running pre-commit hooks on all files..." 
144 | poetry run pre-commit run --all-files 145 | print_success "Pre-commit initial run complete" 146 | 147 | # --- Final Status Message --- # 148 | 149 | # Final status message 150 | print_section "Setup Status" 151 | if [ $ERRORS_FOUND -eq 0 ]; then 152 | print_success "Setup Complete! 🎉" 153 | print_success "To activate the virtual environment, run: poetry shell" 154 | else 155 | print_warning "Setup completed with warnings! Please check the messages above." 156 | echo -e "${YELLOW} Some features might not work as expected.${NC}" 157 | fi 158 | -------------------------------------------------------------------------------- /src/components/__init__.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: F401 2 | 3 | # NOTE: Importing these components automatically adds them to the COMPONENT_REGISTRY 4 | from typing import Any, Dict 5 | 6 | from src.components._registry import COMPONENT_REGISTRY 7 | from src.components.base import BaseComponent 8 | from src.components.ov_circuit import OVComponent 9 | from src.components.simple import SimpleComponent 10 | 11 | # typing imports 12 | from src.config.base import BaseComponentConfig 13 | 14 | COMPONENT_CACHE = {} 15 | 16 | 17 | def get_component( 18 | component_config: BaseComponentConfig, run_config: Dict[str, Any] 19 | ) -> BaseComponent: 20 | """ 21 | Get a component from the registry. We cache components to avoid re-initializing them, 22 | if the same component is used across multiple metrics. One benefit of doing so, is that 23 | components can maintain an internal state to avoid re-computing the same component multiple 24 | times on the same layer. 25 | 26 | Args: 27 | component_config: BaseComponentConfig -- the component configuration. 28 | run_config: Dict[str, Any] -- the run config. 29 | 30 | Returns: 31 | BaseComponent -- the component. 32 | """ 33 | 34 | component_name = component_config.component_name 35 | # track if the component name and type 36 | if component_name not in COMPONENT_CACHE: 37 | COMPONENT_CACHE[component_name] = COMPONENT_REGISTRY[component_name](run_config) 38 | 39 | return COMPONENT_CACHE[component_name] 40 | -------------------------------------------------------------------------------- /src/components/_registry.py: -------------------------------------------------------------------------------- 1 | from typing import Mapping, Type, TypeVar 2 | 3 | from src.components.base import BaseComponent 4 | 5 | T = TypeVar("T", bound=BaseComponent) 6 | 7 | COMPONENT_REGISTRY: Mapping[str, Type[T]] = {} 8 | 9 | 10 | def register_component(name: str): 11 | """ 12 | Decorator to register a component class with the ComponentRegistry. 13 | """ 14 | 15 | def _register(cls: Type[T]) -> Type[T]: 16 | COMPONENT_REGISTRY[name] = cls 17 | return cls 18 | 19 | return _register 20 | -------------------------------------------------------------------------------- /src/components/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base class for components. 3 | """ 4 | 5 | import os 6 | from abc import ABC, abstractmethod 7 | from typing import Any, Dict 8 | 9 | import torch 10 | 11 | from src.config.base import BaseComponentConfig 12 | 13 | 14 | class BaseComponent(ABC): 15 | """ 16 | Base class for components. There are two types of components: 17 | 1. Simple components: these are components that are a single weight, activation or gradient 18 | tensor from a given layer; e.g. 
the weight matrix of a layer, or the gradients of the loss 19 | wrt. the activations of a single weight matrix. 20 | 2. Compound components: these are components that are made up of multiple single components. 21 | For example, the OV-Circuit is a compound component that is made up of the value and 22 | output projection layers. 23 | 24 | 25 | Components are functional objects that are used to generate a component from a given checkpoint 26 | state and a component configuration. 27 | """ 28 | 29 | def __init__(self, training_config: Dict[str, Any]): 30 | self.training_config = training_config 31 | 32 | def get_model_prefix(self, data: Dict[str, Any]) -> str: 33 | """ 34 | Simple helper function to get the model prefix from the checkpoint activation keys. 35 | 36 | The model prefix is the part of the key that is common to all of the layers in the model. 37 | 38 | For example, if we have the following list of layer names: 39 | ``` 40 | { 41 | "model.0.weight": torch.Tensor, 42 | "model.1.weight": torch.Tensor, 43 | } 44 | ``` 45 | 46 | The model prefix is "model.". 47 | """ 48 | 49 | # NOTE: this should be the same for activations and weights 50 | _activation_layernames = list(data["activations"].keys()) 51 | 52 | return os.path.commonprefix(_activation_layernames) 53 | 54 | @abstractmethod 55 | def validate_component(self, component_config: BaseComponentConfig) -> None: 56 | """ 57 | Check the component config; components should specify the required keys in the component 58 | config by overriding this method. This function should be called by the metric to ensure 59 | that the component config is valid. 60 | 61 | Args: 62 | component_config: BaseComponentConfig -- the component configuration. 63 | 64 | Raises: 65 | InvalidComponentError -- if the component config is invalid. 66 | """ 67 | raise NotImplementedError 68 | 69 | @abstractmethod 70 | def __call__( 71 | self, 72 | checkpoint_states: Dict[str, Dict[str, torch.Tensor]], 73 | component_config: BaseComponentConfig, 74 | ) -> Dict[str, torch.Tensor]: 75 | """ 76 | Generate a component. For compound components, this will likely involve some form of matrix 77 | multiplication of different activations, weights, or gradients to produce a desired 78 | component. For simple components, this will just return the activations, weights, 79 | or gradients for the given layers. 80 | 81 | Args: 82 | checkpoint_states: Dict[str, Dict[str, torch.Tensor]] -- the checkpoint states for 83 | a given checkpoint. 84 | component_config: BaseComponentConfig -- the component configuration. 85 | 86 | Returns: 87 | Dict[str, torch.Tensor] -- a dictionary mapping layer names to the component at that layer; i.e. 88 | { 89 | "model.0.component_name": torch.Tensor, 90 | "model.1.component_name": torch.Tensor, 91 | } 92 | """ 93 | raise NotImplementedError 94 | -------------------------------------------------------------------------------- /src/components/ov_circuit.py: -------------------------------------------------------------------------------- 1 | """ 2 | Output-Value Circuit component: a compound component that is composed of the value and output 3 | projections of the attention modules of a transformer model. 
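As a sketch (assuming square (d_model, d_model) value and output projection
weights), the per-head OV weight is W_V[head_rows, :] @ W_O[:, head_cols], a
(head_dim, head_dim) matrix; the per-head results are then concatenated along
the last dimension.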
4 | """ 5 | 6 | from functools import lru_cache 7 | from typing import Any, Dict, Tuple 8 | 9 | import torch 10 | 11 | from src.components._registry import register_component 12 | from src.components.base import BaseComponent 13 | from src.config.base import BaseComponentConfig 14 | from src.utils.exceptions import InvalidComponentError 15 | 16 | 17 | @register_component("ov_circuit") 18 | class OVComponent(BaseComponent): 19 | """ 20 | Compute the 'OV-Circuit' component. The idea of an OV-Circuit stems from the observation 21 | that in an attention module, the value and the output projections always operate jointly, 22 | and write into the 'residual stream'. Thus, it makes sense to treat the Output and Value 23 | matrices as a single 'OV-Circuit' matrix, which we can compute the activations and weights for. 24 | 25 | Also note that the OV-Circuit operates 'per head' of the attention module. 26 | 27 | To read more about the OV-Circuit, see: 28 | https://transformer-circuits.pub/2021/framework/index.html 29 | """ 30 | 31 | def __init__(self, training_config: Dict[str, Any]): 32 | super().__init__(training_config) 33 | 34 | self.d_model = training_config["model"]["d_model"] 35 | 36 | self.attention_n_heads = training_config["model"]["attention_n_heads"] 37 | self.attention_n_kv_heads = training_config["model"]["attention_n_kv_heads"] 38 | self.attention_head_dim = self.d_model // self.attention_n_heads 39 | 40 | @lru_cache(maxsize=50) 41 | def compute_ov_activations( 42 | self, 43 | layer_value_activation: torch.Tensor, 44 | layer_output_projection: torch.Tensor, 45 | ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: 46 | """ 47 | Compute the OV activations for a single layer. Uses a cache to speed up the computation, 48 | if the component is used across multiple metrics. 49 | 50 | NOTE: the OV-Circuit operates 'per head' of the attention module, so we compute the OV 51 | activations for each head separately and then concatenate them together. 52 | 53 | Args: 54 | layer_value_activation: The value activations for the layer. 55 | layer_output_projection: The output projection for the layer. 56 | 57 | Returns: 58 | A tuple: 59 | - A dictionary mapping head indices to OV component activations. 60 | - A concatenated tensor of the OV component activations. 61 | """ 62 | layer_ov_activation_per_head = {} 63 | 64 | for head_idx in range(self.attention_n_heads): 65 | kv_head_idx = head_idx // ( 66 | self.attention_n_heads // self.attention_n_kv_heads 67 | ) 68 | 69 | if layer_value_activation.dtype != layer_output_projection.dtype: 70 | # NOTE: activations might be stored as memory efficient floats (e.g. 
bfloat16) 71 | # so we need to make sure we cast to the same type as the weights 72 | layer_value_activation = layer_value_activation.to( 73 | layer_output_projection.dtype 74 | ) 75 | 76 | start_value_activation = kv_head_idx * self.attention_head_dim 77 | end_value_activation = (kv_head_idx + 1) * self.attention_head_dim 78 | 79 | ov_activation_per_head = ( 80 | layer_value_activation[:, start_value_activation:end_value_activation] 81 | @ layer_output_projection[ 82 | :, 83 | head_idx * self.attention_head_dim : (head_idx + 1) 84 | * self.attention_head_dim, 85 | ].T 86 | ) 87 | 88 | layer_ov_activation_per_head[f"{head_idx}"] = ov_activation_per_head 89 | 90 | layer_ov_activation = torch.cat( 91 | list(layer_ov_activation_per_head.values()), dim=1 92 | ) 93 | 94 | return layer_ov_activation_per_head, layer_ov_activation 95 | 96 | @lru_cache(maxsize=50) 97 | def compute_ov_weights( 98 | self, 99 | layer_value_projection: torch.Tensor, 100 | layer_output_projection: torch.Tensor, 101 | ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: 102 | """ 103 | Compute the OV weights for a single layer. Uses a cache to speed up the computation, 104 | if the component is used across multiple metrics. 105 | 106 | NOTE: the OV-Circuit operates 'per head' of the attention module, so we compute the OV 107 | weights for each head separately and then concatenate them together. 108 | 109 | Args: 110 | layer_value_projection: The value projection for the layer. 111 | layer_output_projection: The output projection for the layer. 112 | 113 | Returns: 114 | A tuple: 115 | - A dictionary mapping head indices to OV component weights. 116 | - A concatenated tensor of the OV component weights. 117 | """ 118 | 119 | layer_ov_weights_per_head = {} 120 | 121 | for head_idx in range(self.attention_n_heads): 122 | kv_head_idx = head_idx // ( 123 | self.attention_n_heads // self.attention_n_kv_heads 124 | ) 125 | 126 | start_value_projection = kv_head_idx * self.attention_head_dim 127 | end_value_projection = (kv_head_idx + 1) * self.attention_head_dim 128 | 129 | start_output_projection = head_idx * self.attention_head_dim 130 | end_output_projection = (head_idx + 1) * self.attention_head_dim 131 | 132 | ov_weights_per_head = ( 133 | layer_value_projection[start_value_projection:end_value_projection, :] 134 | @ layer_output_projection[ 135 | :, start_output_projection:end_output_projection 136 | ] 137 | ) 138 | 139 | layer_ov_weights_per_head[f"{head_idx}"] = ov_weights_per_head 140 | 141 | layer_ov_weights = torch.cat(list(layer_ov_weights_per_head.values()), dim=1) 142 | 143 | return layer_ov_weights_per_head, layer_ov_weights 144 | 145 | @lru_cache(maxsize=50) 146 | def compute_ov_gradients( 147 | self, 148 | layer_value_gradient: torch.Tensor, 149 | layer_output_gradient: torch.Tensor, 150 | layer_value_projection: torch.Tensor, 151 | layer_output_projection: torch.Tensor, 152 | ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: 153 | """ 154 | Compute the OV gradients for a single layer. Uses a cache to speed up the computation, 155 | if the component is used across multiple metrics. 156 | 157 | NOTE: the OV-Circuit operates 'per head' of the attention module, so we compute the OV 158 | gradients for each head separately and then concatenate them together. 159 | 160 | The gradients are computed using the chain rule and is composed of two parts: 161 | 1. dL/dOutput * dOutput/dOV 162 | 2. dL/dValue * dValue/dOV 163 | 164 | Args: 165 | layer_value_gradient: The gradient of the value projection for the layer. 
166 | layer_output_gradient: The gradient of the output projection for the layer. 167 | layer_value_projection: The value projection for the layer. 168 | layer_output_projection: The output projection for the layer. 169 | 170 | Returns: 171 | A tuple: 172 | - A dictionary mapping head indices to OV component gradients. 173 | - A concatenated tensor of the OV component gradients. 174 | """ 175 | 176 | layer_ov_gradients_per_head = {} 177 | 178 | for head_idx in range(self.attention_n_heads): 179 | kv_head_idx = head_idx // ( 180 | self.attention_n_heads // self.attention_n_kv_heads 181 | ) 182 | 183 | start_value_projection = kv_head_idx * self.attention_head_dim 184 | end_value_projection = (kv_head_idx + 1) * self.attention_head_dim 185 | 186 | start_output_projection = head_idx * self.attention_head_dim 187 | end_output_projection = (head_idx + 1) * self.attention_head_dim 188 | 189 | head_value_gradient = layer_value_gradient[ 190 | start_value_projection:end_value_projection, : 191 | ] 192 | 193 | head_output_gradient = layer_output_gradient[ 194 | :, start_output_projection:end_output_projection 195 | ] 196 | 197 | head_value_projection = layer_value_projection[ 198 | start_value_projection:end_value_projection, : 199 | ] 200 | 201 | head_output_projection = layer_output_projection[ 202 | :, start_output_projection:end_output_projection 203 | ] 204 | 205 | # NOTE: chain rule for gradients; dL/dOV = dL/dOutput * dOutput/dOV + dL/dValue * dValue/dOV 206 | head_ov_gradient = (head_output_projection @ head_value_gradient) + ( 207 | head_output_gradient @ head_value_projection 208 | ) 209 | 210 | layer_ov_gradients_per_head[f"{head_idx}"] = head_ov_gradient 211 | 212 | layer_ov_gradients = torch.cat( 213 | list(layer_ov_gradients_per_head.values()), dim=1 214 | ) 215 | 216 | return layer_ov_gradients_per_head, layer_ov_gradients 217 | 218 | def validate_component(self, component_config: BaseComponentConfig) -> None: 219 | """ 220 | OV circuit components require value and output layer suffixes to be specified. 221 | """ 222 | if ( 223 | "value_layer" not in component_config.layer_suffixes 224 | or "output_layer" not in component_config.layer_suffixes 225 | ): 226 | raise InvalidComponentError( 227 | "OV circuit component requires value and output layer suffixes." 228 | ) 229 | 230 | def __call__( 231 | self, 232 | checkpoint_states: Dict[str, Dict[str, torch.Tensor]], 233 | component_config: BaseComponentConfig, 234 | ) -> Dict[str, torch.Tensor]: 235 | """ 236 | Generates the OV circuit component. The OV circuit component is a compound component 237 | that is composed of the value and output projections of the model. 238 | 239 | NOTE: the OV-Circuit operates 'per head' of the attention module, so we compute and return 240 | the OV activations, weights, and gradients for each head separately and then concatenate 241 | them together as well into one aggregated tensor. 242 | 243 | Args: 244 | checkpoint_states: Dict[str, Dict[str, torch.Tensor]] -- the checkpoint states 245 | to generate the component for. 246 | component_config: BaseComponentConfig -- the component configuration. 247 | 248 | Returns: 249 | Dict[str, torch.Tensor] -- the OV circuit component; mapping layer names to OV circuit 250 | activations. 
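            e.g. for layer 0 with data_type="weights" (the "model." prefix is
            inferred from the checkpoint keys and may differ):
            {
                "model.0.ov_circuit.weights.heads.0": torch.Tensor,
                ...,
                "model.0.ov_circuit.weights": torch.Tensor,
            }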
251 | """ 252 | 253 | layer_suffixes = component_config.layer_suffixes 254 | 255 | checkpoint_layer_component = {} 256 | 257 | _model_prefix = self.get_model_prefix(checkpoint_states) 258 | 259 | for layer_idx in component_config.layers: 260 | layer_output_projection = checkpoint_states["weights"][ 261 | f"{_model_prefix}{layer_idx}.{layer_suffixes['output_layer']}" 262 | ] 263 | 264 | # NOTE: depending on whether we are computing activations, weights, or gradients, we 265 | # need to fetch the appropriate tensors from the checkpoint states 266 | 267 | if component_config.data_type == "activations": 268 | layer_value_activation = checkpoint_states["activations"][ 269 | f"{_model_prefix}{layer_idx}.{layer_suffixes['value_layer']}" 270 | ] 271 | elif component_config.data_type in ["weights", "gradients"]: 272 | layer_value_projection = checkpoint_states["weights"][ 273 | f"{_model_prefix}{layer_idx}.{layer_suffixes['value_layer']}" 274 | ] 275 | 276 | if component_config.data_type == "gradients": 277 | layer_output_gradient = checkpoint_states["gradients"][ 278 | f"{_model_prefix}{layer_idx}.{layer_suffixes['output_layer']}" 279 | ] 280 | layer_value_gradient = checkpoint_states["gradients"][ 281 | f"{_model_prefix}{layer_idx}.{layer_suffixes['value_layer']}" 282 | ] 283 | 284 | # NOTE: each computation is done 'per head' of the attention module, and return a 285 | # tuple of a dictionary mapping head indices to the OV component and a concatenated 286 | # tensor of the OV component 287 | 288 | if component_config.data_type == "activations": 289 | ov_component_per_head, ov_component = self.compute_ov_activations( 290 | layer_value_activation, layer_output_projection 291 | ) 292 | elif component_config.data_type == "weights": 293 | ov_component_per_head, ov_component = self.compute_ov_weights( 294 | layer_value_projection, layer_output_projection 295 | ) 296 | elif component_config.data_type == "gradients": 297 | ov_component_per_head, ov_component = self.compute_ov_gradients( 298 | layer_value_gradient, 299 | layer_output_gradient, 300 | layer_value_projection, 301 | layer_output_projection, 302 | ) 303 | 304 | for head_idx, ov_component_head in ov_component_per_head.items(): 305 | checkpoint_layer_component[ 306 | f"{_model_prefix}{layer_idx}.ov_circuit.{component_config.data_type}.heads.{head_idx}" 307 | ] = ov_component_head 308 | 309 | checkpoint_layer_component[ 310 | f"{_model_prefix}{layer_idx}.ov_circuit.{component_config.data_type}" 311 | ] = ov_component 312 | 313 | return checkpoint_layer_component 314 | -------------------------------------------------------------------------------- /src/components/simple.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple components are those that are a single layer. For example, the weight matrix of a layer is 3 | a single component. In other words, simple components are those that can just be extracted directly 4 | from the stored out checkpoint data without much additional computation. 5 | """ 6 | 7 | from typing import Any, Dict 8 | 9 | import torch 10 | 11 | from src.components._registry import register_component 12 | from src.components.base import BaseComponent 13 | from src.config.base import BaseComponentConfig 14 | from src.utils.exceptions import InvalidComponentError 15 | 16 | 17 | @register_component("simple") 18 | class SimpleComponent(BaseComponent): 19 | """ 20 | Simple component is a component that is a single layer. 
For example, the weight matrix of a layer is 21 | a single component. 22 | """ 23 | 24 | def validate_component(self, component_config: BaseComponentConfig) -> None: 25 | """ 26 | Simple components can be any of the data types (activations, weights, or gradients). 27 | """ 28 | if component_config.data_type not in ["activations", "weights", "gradients"]: 29 | raise InvalidComponentError( 30 | f"Simple component only supports activations, weights, or gradients, not {component_config.data_type}." 31 | ) 32 | 33 | def __call__( 34 | self, 35 | checkpoint_states: Dict[str, Dict[str, torch.Tensor]], 36 | component_config: BaseComponentConfig, 37 | ) -> Dict[str, Any]: 38 | """ 39 | Given a dictionary of checkpoint data, extract the activations, weights, or gradients for 40 | the given layer suffix and layer. 41 | 42 | Args: 43 | checkpoint_states: Checkpoint data (activations, weights, gradients) 44 | component_config: The component configuration. 45 | 46 | Returns: 47 | A dictionary mapping layer names to MLP activations. 48 | """ 49 | 50 | checkpoint_layer_component = {} 51 | 52 | _data = checkpoint_states[component_config.data_type] 53 | _model_prefix = self.get_model_prefix(checkpoint_states) 54 | 55 | for layer_idx in component_config.layers: 56 | layer_component = _data[ 57 | f"{_model_prefix}{layer_idx}.{component_config.layer_suffixes}" 58 | ] 59 | checkpoint_layer_component[ 60 | f"{_model_prefix}{layer_idx}.{component_config.layer_suffixes}.{component_config.data_type}" 61 | ] = layer_component 62 | 63 | return checkpoint_layer_component 64 | -------------------------------------------------------------------------------- /src/config/_registry.py: -------------------------------------------------------------------------------- 1 | from typing import Type, TypeVar 2 | 3 | from src.config.base import BaseMetricConfig 4 | 5 | T = TypeVar("T", bound=BaseMetricConfig) 6 | 7 | METRIC_CONFIG_REGISTRY: dict[str, Type[T]] = {} 8 | 9 | 10 | def register_metric_config(metric_name: str): 11 | """ 12 | Decorator to register a metric config class with the MetricConfigRegistry. 13 | """ 14 | 15 | def decorator(cls): 16 | METRIC_CONFIG_REGISTRY[metric_name] = cls 17 | return cls 18 | 19 | return decorator 20 | -------------------------------------------------------------------------------- /src/config/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base configuration classes for the config system which is composed primarily of metrics that 3 | operate on top of components. 4 | """ 5 | 6 | from dataclasses import dataclass, field 7 | from typing import Dict, List 8 | 9 | # NOTE: Base class for components of a metrics. 10 | 11 | 12 | @dataclass 13 | class BaseComponentConfig: 14 | """ 15 | Base configuration for a component of a model. 16 | 17 | A component can be a single layer, group of layers, activations etc, the choice is arbitrary, 18 | as long as the given metric defines how to compute the metric for the component. 19 | 20 | Example: 21 | 22 | component_name: ov_circuit # name of the component 23 | layer_suffixes: 24 | output_layer: "attention.out_proj" # suffix of the layer to compute the metric for 25 | value_layer: "attention.v_proj" # suffix of the layer to compute the metric for 26 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] # layers to compute the metric for 27 | data_type: "weights" # type of checkpoint data to compute the component for (e.g. 
"weights", "activations", "gradients") 28 | 29 | """ 30 | 31 | component_name: str # name of the component 32 | layer_suffixes: ( 33 | str | Dict[str, str] 34 | ) # suffixes of the layers to compute the metric for 35 | layers: List[int] # layers to compute the metric for 36 | data_type: str = None # type of checkpoint data to compute the component for (e.g. "weights", "activations", "gradients") 37 | 38 | 39 | """ 40 | NOTE: Base class for metrics. 41 | 42 | There are two types of metrics: 43 | 1. BaseMetricConfig: these are metrics that are computed on a single checkpoint; e.g. the 44 | norm of a layer at a given checkpoint. 45 | 2. BaseComparativeMetricConfig: these are metrics that are computed on a pair of checkpoints; e.g. 46 | the CKA between two layers at two different checkpoints to measure how similar the two layers 47 | are between the two checkpoints. 48 | """ 49 | 50 | 51 | @dataclass 52 | class BaseMetricConfig: 53 | """ 54 | Base configuration for a metric. All metrics should implement this class. Requires 55 | specifying the components to compute the metric for, the metric name, and the data split. 56 | 57 | Args: 58 | metric_name: str -- the name of the metric. 59 | components: List[BaseComponentConfig] -- the components to compute the metric for. 60 | data_split: str -- the data split to compute the metric for (e.g. "train", "val", "test"). 61 | 62 | """ 63 | 64 | metric_name: str = None 65 | components: List[BaseComponentConfig] = field(default_factory=list) 66 | data_split: str = None 67 | 68 | def __post_init__(self): 69 | """ 70 | Post-initialization method to convert yaml dictionaries of components to proper 71 | BaseComponentConfig objects. 72 | """ 73 | _process_components = [] 74 | 75 | for component_config in self.components: 76 | if isinstance(component_config, dict): 77 | _process_components.append(BaseComponentConfig(**component_config)) 78 | else: 79 | _process_components.append(component_config) 80 | 81 | self.components = _process_components 82 | 83 | 84 | @dataclass 85 | class BaseComparativeMetricConfig(BaseMetricConfig): 86 | """ 87 | Base configuration for a comparative metric (which is a subclass of BaseMetricConfig). 88 | 89 | A comparative metric is a metric that is computed on a pair of checkpoints to compare how 90 | a model's activations or weights change between two different checkpoints. 91 | 92 | Args: 93 | target_checkpoint: int -- the checkpoint to compare the source checkpoint to. 94 | """ 95 | 96 | target_checkpoint: int = None 97 | -------------------------------------------------------------------------------- /src/config/learning_dynamics.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: F403 2 | 3 | """ 4 | Configuration classes for learning dynamics analysis. 5 | """ 6 | 7 | from dataclasses import dataclass, field 8 | from typing import List 9 | 10 | from src.config._registry import METRIC_CONFIG_REGISTRY 11 | from src.config.base import ( 12 | BaseComparativeMetricConfig, 13 | BaseMetricConfig, 14 | ) 15 | from src.config.metrics import * 16 | from src.config.monitoring import MonitoringConfig 17 | 18 | 19 | @dataclass 20 | class LearningDynamicsConfig: 21 | """ 22 | Root configuration for specifying what learning dynamics metrics to compute, and at which 23 | checkpoint steps to compute them for. Metrics can be single-checkpoint or comparative metrics, 24 | and are computed on components of the model. 
25 | """ 26 | 27 | # Name for the analysis, used to store/bookkeep the analysis results 28 | analysis_name: str = None 29 | 30 | metrics: List[BaseMetricConfig | BaseComparativeMetricConfig] = field( 31 | default_factory=list 32 | ) 33 | steps: List[int] = field(default_factory=list) 34 | 35 | monitoring: MonitoringConfig = field(default_factory=MonitoringConfig) 36 | 37 | def __post_init__(self): 38 | """ 39 | Post-initialization method to convert metric dictionaries to proper config objects. Used 40 | for loading in metrics from a yaml file where the metrics are specified as dictionaries. 41 | 42 | Example yaml file: 43 | metrics: 44 | - metric_name: cka 45 | target_checkpoint: 1000 46 | data_split: "val" 47 | components: 48 | - component_name: simple 49 | data_type: "weights" 50 | layer_suffixes: "swiglu.w_2" 51 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 52 | - component_name: ov_circuit 53 | data_type: "weights" 54 | layer_suffixes: 55 | output_layer: "attention.o_proj" 56 | value_layer: "attention.v_proj" 57 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 58 | - metric_name: norm 59 | data_split: "train" 60 | norm_type: "nuclear" 61 | components: 62 | - component_name: simple 63 | data_type: "weights" 64 | layer_suffixes: "swiglu.w_2" 65 | layers: [0,1,2,3,4,5,6,7,8,9,10,11] 66 | 67 | This will be converted to the following config object: 68 | LearningDynamicsConfig( 69 | metrics=[ 70 | CKAConfig(metric_name="cka", target_checkpoint=1000, ...), 71 | NormConfig(metric_name="norm", data_split="train", ...) 72 | ] 73 | ) 74 | 75 | Also note that we specify for which steps we want to compute the metrics for. Either 76 | a list of steps can be specified, or a range of steps can be specified. If a range is 77 | specified, we will compute the metrics for all steps in the range (including the end-step). 78 | 79 | Example yaml file: 80 | steps: 81 | start: 0 82 | end: 100 83 | step: 50 84 | 85 | This will be converted to the following config object: 86 | LearningDynamicsConfig(steps=[0,50,100]) 87 | 88 | """ 89 | # Convert metric dictionaries to proper config objects 90 | if isinstance(self.metrics, list): 91 | processed_metrics = [] 92 | for metric in self.metrics: 93 | if isinstance(metric, dict): 94 | metric_name = metric.get("metric_name") 95 | if metric_name is None: 96 | raise ValueError( 97 | "metric_name must be specified for each metric" 98 | ) 99 | 100 | if metric_name in METRIC_CONFIG_REGISTRY: 101 | processed_metrics.append( 102 | METRIC_CONFIG_REGISTRY[metric_name](**metric) 103 | ) 104 | else: 105 | raise ValueError(f"Unknown metric_name: {metric_name}") 106 | else: 107 | processed_metrics.append(metric) 108 | self.metrics = processed_metrics 109 | 110 | if isinstance(self.steps, dict): 111 | self.steps = list( 112 | range( 113 | self.steps["start"], 114 | self.steps["end"] 115 | + self.steps["step"], # + step to include end step 116 | self.steps["step"], 117 | ) 118 | ) 119 | elif isinstance(self.steps, list): 120 | self.steps = [int(step) for step in self.steps] 121 | else: 122 | raise ValueError("steps must be a list of integers or a StepRangeConfig") 123 | 124 | self.monitoring = MonitoringConfig(**self.monitoring) 125 | -------------------------------------------------------------------------------- /src/config/metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Configurations for metrics. Recall there are two types of metrics: 3 | 4 | 1. BaseMetricConfig: these are metrics that are computed on a single checkpoint; e.g. 
the 5 | norm of a layer at a given checkpoint. 6 | 2. BaseComparativeMetricConfig: these are metrics that are computed on a pair of checkpoints; e.g. 7 | the CKA between two layers at two different checkpoints to measure how similar the two layers 8 | are between the two checkpoints. 9 | """ 10 | 11 | from dataclasses import dataclass 12 | 13 | from src.config._registry import register_metric_config 14 | from src.config.base import BaseComparativeMetricConfig, BaseMetricConfig 15 | 16 | # ----------------- 17 | # Single Checkpoint (Base) Metrics 18 | # ----------------- 19 | 20 | 21 | @dataclass 22 | @register_metric_config("condition_number") 23 | class ConditionNumberConfig(BaseMetricConfig): 24 | """ 25 | Configuration for the condition number metric. 26 | """ 27 | 28 | ... 29 | 30 | 31 | @dataclass 32 | @register_metric_config("norm") 33 | class NormConfig(BaseMetricConfig): 34 | """ 35 | Configuration for a norm metric. 36 | """ 37 | 38 | # NOTE: used to specify what type of norm to compute: 39 | # options are "Frobenius", "spectral", "max" 40 | norm_type: str = None 41 | 42 | 43 | @dataclass 44 | @register_metric_config("per") 45 | class PERConfig(BaseMetricConfig): 46 | """ 47 | Configuration for the Proportional Effective Rank (PER) metric. 48 | The PER is a metric that measures the effective rank of a matrix, and is defined in: 49 | Tending Towards Stability: Convergence Challenges in Small Language Models 50 | https://aclanthology.org/2024.findings-emnlp.187/ 51 | """ 52 | 53 | ... 54 | 55 | 56 | @dataclass 57 | @register_metric_config("gini") 58 | class GiniConfig(BaseMetricConfig): 59 | """ 60 | Configuration for the Gini metric. 61 | """ 62 | 63 | ... 64 | 65 | 66 | @dataclass 67 | @register_metric_config("hoyer") 68 | class HoyerConfig(BaseMetricConfig): 69 | """ 70 | Configuration for the Hoyer metric. 71 | """ 72 | 73 | ... 74 | 75 | 76 | # ----------------- 77 | # Multi-Checkpoint (Comparative) Metrics 78 | # ----------------- 79 | 80 | 81 | @dataclass 82 | @register_metric_config("pwcca") 83 | class PWCCAConfig(BaseComparativeMetricConfig): 84 | """ 85 | Configuration for the PWCCA metric; a comparative metric that computes the similarity between two 86 | layers' activations at two different checkpoints. 87 | """ 88 | 89 | ... 90 | 91 | 92 | @dataclass 93 | @register_metric_config("cka") 94 | class CKAConfig(BaseComparativeMetricConfig): 95 | """ 96 | Configuration for the CKA metric; a comparative metric that computes the similarity between two 97 | layers' activations at two different checkpoints. 98 | """ 99 | 100 | ... 101 | -------------------------------------------------------------------------------- /src/config/monitoring.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | 4 | @dataclass 5 | class WandbConfig: 6 | """ 7 | Configuration for the Wandb experiment tracker. 8 | """ 9 | 10 | entity: str = None 11 | project: str = None 12 | 13 | 14 | @dataclass 15 | class MonitoringConfig: 16 | """ 17 | Configuration for the monitoring/logging of learning dynamics metrics. 18 | """ 19 | 20 | output_dir: str = "analysis_results" 21 | 22 | save_to_wandb: bool = False 23 | wandb: WandbConfig = field(default_factory=WandbConfig) 24 | 25 | def __post_init__(self): 26 | """ 27 | Post-initialization method to convert metric dictionaries to proper config objects. Used 28 | for loading in metrics from a yaml file where the metrics are specified as dictionaries. 
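        Example yaml (the entity and project values are illustrative):

            monitoring:
                output_dir: "analysis_results"
                save_to_wandb: true
                wandb:
                    entity: "my-entity"
                    project: "pico-analyze"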
29 | """ 30 | if isinstance(self.wandb, dict): 31 | self.wandb = WandbConfig(**self.wandb) 32 | -------------------------------------------------------------------------------- /src/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: F401 2 | 3 | # NOTE: Importing these metrics automatically adds them to the METRIC_REGISTRY 4 | from typing import Any, Dict 5 | 6 | # Typing 7 | from src.config.learning_dynamics import BaseMetricConfig 8 | 9 | # Registry 10 | from ._registry import METRIC_REGISTRY 11 | from .base import BaseComparativeMetric, BaseMetric 12 | 13 | # Comparative Metrics (between two checkpoints) 14 | from .cka import CKAMetric 15 | from .condition_number import ConditionNumberMetric 16 | from .gini import GiniMetric 17 | from .hoyer import HoyerMetric 18 | 19 | # Implemented Metrics 20 | # Base Metrics 21 | from .norm import NormMetric 22 | from .per import PERMetric 23 | from .pwcca import PWCCAMetric 24 | 25 | 26 | def get_metric( 27 | metric_config: BaseMetricConfig, training_config: Dict[str, Any] 28 | ) -> BaseMetric: 29 | """ 30 | Loads a metric from the metrics directory. 31 | """ 32 | return METRIC_REGISTRY[metric_config.metric_name](metric_config, training_config) 33 | -------------------------------------------------------------------------------- /src/metrics/_registry.py: -------------------------------------------------------------------------------- 1 | from typing import Mapping, Type, TypeVar 2 | 3 | from src.metrics.base import BaseMetric 4 | 5 | T = TypeVar("T", bound=BaseMetric) 6 | 7 | METRIC_REGISTRY: Mapping[str, Type[T]] = {} 8 | 9 | 10 | def register_metric(name: str): 11 | """ 12 | Decorator to register a metric class with the MetricRegistry. 13 | """ 14 | 15 | def _register(cls: Type[T]) -> Type[T]: 16 | METRIC_REGISTRY[name] = cls 17 | return cls 18 | 19 | return _register 20 | -------------------------------------------------------------------------------- /src/metrics/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base class for all metrics. 3 | """ 4 | 5 | from abc import ABC, abstractmethod 6 | from typing import Any, Dict, List 7 | 8 | import torch 9 | 10 | from src.components import get_component 11 | from src.config.base import BaseComponentConfig 12 | from src.config.learning_dynamics import BaseMetricConfig 13 | 14 | 15 | class BaseMetric(ABC): 16 | """ 17 | Base class for all metrics. 18 | """ 19 | 20 | def __init__(self, metric_config: BaseMetricConfig, run_config: Dict[str, Any]): 21 | """ 22 | Initialize the metric with the given metric config and a run config (the config used during 23 | training that specifies the model architecture, etc.). 24 | 25 | To see an example of a run_config, see the training_config.yaml file in the demo run of Pico: 26 | https://huggingface.co/pico-lm/demo/blob/demo-1/training_config.yaml 27 | 28 | The run_config is used to setup and compute the components that metrics are computed on. 29 | 30 | Args: 31 | metric_config: BaseMetricConfig -- the metric config to use for the metric. 32 | run_config: Dict[str, Any] -- the run config to use for the metric. 
33 | """ 34 | self.metric_config = metric_config 35 | 36 | # Setup components 37 | self.components = [] 38 | for component_config in self.metric_config.components: 39 | component = get_component(component_config, run_config) 40 | 41 | # NOTE: We verify that the component config is valid for the component by calling the 42 | # component's validate_component method and that the component is compatible with the 43 | # metric by calling the metric's validate_component method. 44 | 45 | # For instance, some metrics only make sense for weights or activations, so we check 46 | # that the component config specifies a valid data_type. 47 | component.validate_component(component_config) 48 | self.validate_component(component_config) 49 | 50 | self.components.append(component) 51 | 52 | @abstractmethod 53 | def validate_component(self, component_config: BaseComponentConfig) -> None: 54 | """ 55 | Check that the components used in the metric are valid; i.e. that the metric can be 56 | computed on the components specified in the config. Should raise an InvalidComponentError 57 | if the component is not valid. 58 | 59 | Args: 60 | component_config: The component configuration. 61 | 62 | Raises: 63 | InvalidComponentError -- if the component is not valid. 64 | """ 65 | raise NotImplementedError 66 | 67 | def compute_components( 68 | self, checkpoint_states: Dict[str, Dict[str, torch.Tensor]] 69 | ) -> List[Dict[str, torch.Tensor]]: 70 | """ 71 | Preprocesses the checkpoint states to generate the components. This is a helper function that 72 | is used by the compute method. 73 | 74 | Args: 75 | checkpoint_states: Dict[str, Dict[str, torch.Tensor]] -- the initial checkpoint states 76 | to preprocess and generate the components. 77 | **kwargs: Any -- additional arguments to pass to the component setup. 78 | 79 | Returns: 80 | components_data_list: List[Dict[str, torch.Tensor]] -- a list of dictionaries mapping component 81 | names to component data; each component data is a dictionary mapping layer names to tensors 82 | that are the preprocessed components at that layer. 83 | """ 84 | component_data_list = [] 85 | for component, component_config in zip( 86 | self.components, self.metric_config.components 87 | ): 88 | component_data_list.append(component(checkpoint_states, component_config)) 89 | 90 | return component_data_list 91 | 92 | @abstractmethod 93 | def compute_metric(self, component_layer_data: torch.Tensor) -> float: 94 | """ 95 | Computes the desired metric on a single component at a given layer. 96 | 97 | Args: 98 | component_layer_data: torch.Tensor -- the data for a component at a given layer. 99 | 100 | Returns: 101 | torch.Tensor -- the computed metric for the given component at the given layer. 102 | """ 103 | pass 104 | 105 | def __call__( 106 | self, checkpoint_states: Dict[str, Dict[str, torch.Tensor]] 107 | ) -> List[Dict[str, float]]: 108 | """ 109 | Computed the desired metrics on the specified components in the metric config. Reads in the 110 | original data, preprocesses it into components, and then computes the desired metrics on each 111 | component. 112 | 113 | Args: 114 | checkpoint_states: Dict[str, Any] -- the checkpoint states to compute the metric on. 115 | 116 | Returns: 117 | component_metrics_list: List[Dict[str, float]] -- a list of dictionaries mapping component names 118 | across layers to computed metrics. 119 | 120 | NOTE: The order in which the components are returned is the same as the order in 121 | which they are specified in the metric config. 
122 | """ 123 | 124 | component_data_list = self.compute_components(checkpoint_states) 125 | 126 | component_metrics_list = [] 127 | 128 | for component_data in component_data_list: 129 | # component_data will be a dictionary mapping layer names to tensors 130 | component_metric = {} 131 | 132 | for _component_layer_name, _component_layer_data in component_data.items(): 133 | component_metric[_component_layer_name] = self.compute_metric( 134 | _component_layer_data 135 | ) 136 | 137 | component_metrics_list.append(component_metric) 138 | 139 | return component_metrics_list 140 | 141 | 142 | class BaseComparativeMetric(BaseMetric): 143 | """ 144 | Base class for comparative metrics. 145 | 146 | The idea of comparative metrics is that these metrics compare the current checkpoint to a 147 | given target checkpoint. For example, we can compute the CKA between the current checkpoint and 148 | the target checkpoint. 149 | 150 | NOTE: the target data (e.g. precomputed target components) are set using the set_target method. 151 | """ 152 | 153 | def __init__(self, metric_config: BaseMetricConfig, run_config: Dict[str, Any]): 154 | """ 155 | Initialize the comparative metric as a base metric, but with a target checkpoint. 156 | """ 157 | super().__init__(metric_config, run_config) 158 | 159 | self._target_component_data_list = None 160 | 161 | def set_target(self, checkpoint_states: Dict[str, Dict[str, torch.Tensor]]) -> None: 162 | """ 163 | Set the target checkpoint data which is stored as the already preprocessed components. 164 | We set the target data before computing metrics so that we can compute the metric between 165 | the source and target checkpoints. 166 | 167 | Args: 168 | checkpoint_states: Dict[str, Dict[str, torch.Tensor]] -- the target checkpoint states. 169 | """ 170 | self._target_component_data_list = self.compute_components(checkpoint_states) 171 | 172 | @abstractmethod 173 | def compute_metric( 174 | self, 175 | source_component_layer_data: torch.Tensor, 176 | target_component_layer_data: torch.Tensor, 177 | ) -> Dict[str, Any]: 178 | """ 179 | Computes the given metric between two components. Unlike the BaseMetric class, this method 180 | takes in two tensors (source and target) and computes the metric between them. 181 | 182 | Args: 183 | source_component_layer_data: torch.Tensor -- the source component data at a given layer. 184 | target_component_layer_data: torch.Tensor -- the target component data at a given layer. 185 | 186 | Returns: 187 | float -- the computed metric. 188 | """ 189 | raise NotImplementedError 190 | 191 | def __call__( 192 | self, source_checkpoint_states: Dict[str, Dict[str, torch.Tensor]] 193 | ) -> List[Dict[str, float]]: 194 | """ 195 | Computes the metric between two sets of checkpoint states. 196 | 197 | NOTE: Can only be called if the target data is set; that is, that set_target() has been called 198 | with the checkpoint states of the target checkpoint. 199 | 200 | Args: 201 | source_checkpoint_states: Dict[str, Dict[str, torch.Tensor]] -- the source checkpoint 202 | states. 203 | 204 | Returns: 205 | component_metrics_list: List[Dict[str, float]] -- a list of dictionaries mapping components 206 | across layers to computed metrics. 207 | 208 | NOTE: The order in which the components are returned is the same as the order in 209 | which they are specified in the metric config. 210 | """ 211 | if self._target_component_data_list is None: 212 | raise ValueError("Target data is not set. 
--------------------------------------------------------------------------------
/src/metrics/cka.py:
--------------------------------------------------------------------------------
1 | """
2 | CKA (Centered Kernel Alignment) is a comparative metric for comparing how similar two
3 | sets of activations are between two different checkpoints.
4 | """
5 | 
6 | import torch
7 | 
8 | from lib import cka
9 | from src.config.base import BaseComponentConfig
10 | from src.metrics._registry import register_metric
11 | from src.metrics.base import BaseComparativeMetric
12 | from src.utils.exceptions import InvalidComponentError
13 | 
14 | 
15 | @register_metric("cka")
16 | class CKAMetric(BaseComparativeMetric):
17 |     """
18 |     Class for computing CKA (Centered Kernel Alignment) between two sets of activations.
19 | 
20 |     The CKA can be computed between any two activation matrices, e.g. the OV circuit activations and the MLP activations. For the linear (feature-space) variant used here,
21 |     CKA(A, B) = ||A^T B||_F^2 / (||A^T A||_F * ||B^T B||_F)
22 | 
23 |     where A and B are (num_samples, num_features) activation matrices with centered columns, and ||.||_F is the Frobenius norm.
24 | 
25 |     Reference: https://arxiv.org/pdf/1905.00414.pdf
26 | 
27 |     """
28 | 
29 |     def validate_component(self, component_config: BaseComponentConfig) -> None:
30 |         """
31 |         The CKA metric is only valid for activations.
32 |         """
33 |         if component_config.data_type not in ["activations"]:
34 |             raise InvalidComponentError(
35 |                 f"CKA metric only supports activations, not {component_config.data_type} "
36 |                 f"(component: {component_config.component_name})."
37 |             )
38 | 
39 |     def compute_metric(
40 |         self,
41 |         source_component_layer_data: torch.Tensor,
42 |         target_component_layer_data: torch.Tensor,
43 |     ) -> float:
44 |         """
45 |         Computes the CKA between two sets of source and target component layer activations.
46 | 
47 |         Args:
48 |             source_component_layer_data: Tensor containing the source data to analyze
49 |             target_component_layer_data: Tensor containing the target data to analyze
50 | 
51 |         Returns:
52 |             float: The computed CKA
53 |         """
54 | 
55 |         # NOTE: The CKA implementation expects float32 numpy ndarrays
56 |         np_src_component_layer_data = source_component_layer_data.to(
57 |             dtype=torch.float32
58 |         ).numpy()
59 |         np_tgt_component_layer_data = target_component_layer_data.to(
60 |             dtype=torch.float32
61 |         ).numpy()
62 | 
63 |         cka_value = cka.feature_space_linear_cka(
64 |             np_src_component_layer_data, np_tgt_component_layer_data
65 |         )
66 | 
67 |         return cka_value
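
A usage sketch of the comparative workflow (hypothetical driver code; `metric_config`, `run_config`, and the checkpoint-state dicts are assumed to come from the config and data-loading utilities):

    cka_metric = CKAMetric(metric_config, run_config)

    cka_metric.set_target(target_checkpoint_states)  # fix the target checkpoint once

    # ...then compare any number of source checkpoints against it.
    for source_checkpoint_states in checkpoint_states_by_step:
        per_layer_cka = cka_metric(source_checkpoint_states)  # List[Dict[str, float]]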
--------------------------------------------------------------------------------
/src/metrics/condition_number.py:
--------------------------------------------------------------------------------
1 | """
2 | Condition number is the ratio of the largest to smallest singular value of the input.
3 | """
4 | 
5 | import torch
6 | 
7 | from src.config.base import BaseComponentConfig
8 | from src.metrics._registry import register_metric
9 | from src.metrics.base import BaseMetric
10 | 
11 | 
12 | @register_metric("condition_number")
13 | class ConditionNumberMetric(BaseMetric):
14 |     """
15 |     This metric computes the condition number of some component data. The condition number is the
16 |     ratio of the largest to smallest singular value of the input. It gives a measure of how
17 |     sensitive the output is to small changes in the input.
18 |     """
19 | 
20 |     # NOTE: Any component is valid for the condition number metric.
21 |     def validate_component(self, component_config: BaseComponentConfig) -> None: ...
22 | 
23 |     def compute_metric(self, component_layer_data: torch.Tensor) -> float:
24 |         """
25 |         Computes the condition number of the given input.
26 | 
27 |         Args:
28 |             component_layer_data: Tensor containing the data to analyze
29 | 
30 |         Returns:
31 |             float: The computed condition number
32 |         """
33 | 
34 |         # Compute the singular values of the input (torch.linalg.svdvals is the
35 |         # non-deprecated replacement for torch.svd(...).S)
36 |         singular_values = torch.linalg.svdvals(component_layer_data)
37 | 
38 |         # Compute the condition number
39 |         condition_number = torch.max(singular_values) / torch.min(singular_values)
40 | 
41 |         return condition_number.item()
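
A quick sanity check of the definition (illustrative values, not from the repository):

    import torch

    # A diagonal matrix has singular values equal to the absolute values of its
    # diagonal entries, so diag(10, 1) has condition number 10 / 1 = 10.
    singular_values = torch.linalg.svdvals(torch.diag(torch.tensor([10.0, 1.0])))
    condition_number = (torch.max(singular_values) / torch.min(singular_values)).item()
    # condition_number == 10.0; larger values mean a more ill-conditioned matrix.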
--------------------------------------------------------------------------------
/src/metrics/gini.py:
--------------------------------------------------------------------------------
1 | """
2 | Gini coefficient is a measure of the 'inequality' of a distribution; we use it indirectly to
3 | measure the sparsity of a matrix.
4 | """
5 | 
6 | import torch
7 | 
8 | from src.config.base import BaseComponentConfig
9 | from src.metrics._registry import register_metric
10 | from src.metrics.base import BaseMetric
11 | 
12 | 
13 | @register_metric("gini")
14 | class GiniMetric(BaseMetric):
15 |     """
16 |     Compute the Gini coefficient of some component data, which is a rough approximation of the
17 |     sparsity of a matrix.
18 | 
19 |     Formally, if we let x_(1) <= ... <= x_(n) be the absolute values of the data sorted in ascending order, then the Gini coefficient is given by:
20 | 
21 |     G(x) = sum_{i=1}^{n} (2i - n - 1) * x_(i) / (n * sum_{j=1}^{n} x_(j))
22 | 
23 |     where x_(i) is the i-th smallest absolute value; G is 0 when all entries are equal and approaches 1 when all the mass is in a single entry.
24 |     """
25 | 
26 |     # NOTE: Any component is valid for the Gini metric.
27 |     def validate_component(self, component_config: BaseComponentConfig) -> None: ...
28 | 
29 |     def compute_metric(self, component_layer_data: torch.Tensor) -> float:
30 |         """
31 |         Compute the Gini coefficient of some component data.
32 | 
33 |         The Gini coefficient measures inequality in a distribution, with values ranging from 0
34 |         (perfect equality) to 1 (perfect inequality).
35 | 
36 |         This implementation uses a memory-efficient algorithm that avoids creating
37 |         the full pairwise difference matrix.
38 | 
39 |         Args:
40 |             component_layer_data: Tensor containing the data to analyze
41 | 
42 |         Returns:
43 |             float: The computed Gini coefficient
44 |         """
45 |         # Reshape the input tensor to a 1D array
46 |         x = component_layer_data.flatten()
47 |         x = torch.abs(x)
48 | 
49 |         # Sort the flattened vector in ascending order
50 |         x_sorted, _ = torch.sort(x)
51 |         n = x_sorted.shape[0]
52 |         if n == 0:
53 |             return 0.0  # Edge case if the matrix is empty
54 | 
55 |         # Compute the total mass (the denominator of the formula)
56 |         total = x_sorted.sum()
57 |         if total == 0:
58 |             return 0.0  # If all entries are zero, Gini is 0 by convention
59 | 
60 |         # Apply the formula:
61 |         # sum_{i=1 to n} of (2i - n - 1) * x_sorted[i-1]
62 |         idx = torch.arange(1, n + 1, dtype=x.dtype, device=x.device)
63 |         numerator = ((2 * idx - n - 1) * x_sorted).sum()
64 | 
65 |         G = numerator / (n * total)
66 |         return G.item()
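
Two boundary cases of the formula, worked by hand (illustrative tensors):

    import torch

    # One-hot vector: all mass in a single entry, maximal inequality.
    x_sorted, _ = torch.sort(torch.tensor([0.0, 0.0, 0.0, 1.0]).abs())
    n = x_sorted.shape[0]
    idx = torch.arange(1, n + 1, dtype=x_sorted.dtype)
    gini = (((2 * idx - n - 1) * x_sorted).sum() / (n * x_sorted.sum())).item()
    # gini == (2 * 4 - 4 - 1) / 4 == 0.75, i.e. (n - 1) / n for a one-hot vector.
    # A constant vector such as torch.ones(4) gives gini == 0.0.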
19 | """ 20 | 21 | def __init__(self, metric_config: BaseMetricConfig, *args): 22 | super().__init__(metric_config, *args) 23 | 24 | # NOTE: We use the torch.norm function to compute the norm of the data. 25 | if self.metric_config.norm_type == "frobenius": 26 | self.norm_function = partial(torch.norm, p="fro") 27 | elif self.metric_config.norm_type == "nuclear": 28 | self.norm_function = partial(torch.norm, p="nuc") 29 | elif self.metric_config.norm_type == "inf": 30 | self.norm_function = partial(torch.norm, p=float("inf")) 31 | else: 32 | raise ValueError(f"Invalid norm_type: {self.metric_config.norm_type}") 33 | 34 | # NOTE: Any component is valid for the norm metric. 35 | def validate_component(self, component_config: BaseComponentConfig) -> None: ... 36 | 37 | def compute_metric(self, component_layer_data: torch.Tensor) -> float: 38 | """ 39 | Computes the norm of the given component data. 40 | 41 | Args: 42 | component_layer_data: The component data to compute the norm of. 43 | 44 | Returns: 45 | The norm of the component data. 46 | """ 47 | return self.norm_function(component_layer_data).item() 48 | -------------------------------------------------------------------------------- /src/metrics/per.py: -------------------------------------------------------------------------------- 1 | """ 2 | PER metric computes the Proportional Effective Rank (PER) of activations or gradients. 3 | """ 4 | 5 | import torch 6 | 7 | from src.config.base import BaseComponentConfig 8 | from src.metrics._registry import register_metric 9 | from src.metrics.base import BaseMetric 10 | from src.utils.exceptions import InvalidComponentError 11 | 12 | 13 | @register_metric("per") 14 | class PERMetric(BaseMetric): 15 | """ 16 | Compute the Proportional Effective Rank (PER) of some component data. The PER is defined as the 17 | entropy over the normalised singular values of a given matrix. 18 | 19 | Formally, if we let S = [s_1, ..., s_n] be the singular values of a parameter matrix P, then the PER is given by: 20 | 21 | PER(P) = - sum(s_i / sum(s)) * log2(s_i / sum(s)) for i = 1 to n 22 | 23 | where s = sum(s) is the sum of the singular values. 24 | """ 25 | 26 | def validate_component(self, component_config: BaseComponentConfig) -> None: 27 | """ 28 | NOTE: The PER metric is only valid for weights and gradients. 29 | """ 30 | if component_config.data_type not in ["weights", "gradients"]: 31 | raise InvalidComponentError( 32 | f"PER metric only supports weights and gradients, not {component_config.data_type} " 33 | f"(component: {component_config.component_name})." 34 | ) 35 | 36 | def compute_metric(self, component_layer_data: torch.Tensor) -> float: 37 | """ 38 | Computes the PER of a given layer. 
--------------------------------------------------------------------------------
/src/metrics/norm.py:
--------------------------------------------------------------------------------
1 | """
2 | Norm metrics compute the norm of weights, activations, gradients, etc.
3 | """
4 | 
5 | from functools import partial
6 | 
7 | import torch
8 | 
9 | from src.config.base import BaseComponentConfig
10 | from src.config.learning_dynamics import BaseMetricConfig
11 | from src.metrics._registry import register_metric
12 | from src.metrics.base import BaseMetric
13 | 
14 | 
15 | @register_metric("norm")
16 | class NormMetric(BaseMetric):
17 |     """
18 |     Metric that computes a configurable norm of some component data; the norm type (Frobenius, nuclear, or infinity) is read from the metric config.
19 |     """
20 | 
21 |     def __init__(self, metric_config: BaseMetricConfig, *args):
22 |         super().__init__(metric_config, *args)
23 | 
24 |         # NOTE: We use the torch.norm function to compute the norm of the data.
25 |         if self.metric_config.norm_type == "frobenius":
26 |             self.norm_function = partial(torch.norm, p="fro")
27 |         elif self.metric_config.norm_type == "nuclear":
28 |             self.norm_function = partial(torch.norm, p="nuc")
29 |         elif self.metric_config.norm_type == "inf":
30 |             self.norm_function = partial(torch.norm, p=float("inf"))
31 |         else:
32 |             raise ValueError(f"Invalid norm_type: {self.metric_config.norm_type}")
33 | 
34 |     # NOTE: Any component is valid for the norm metric.
35 |     def validate_component(self, component_config: BaseComponentConfig) -> None: ...
36 | 
37 |     def compute_metric(self, component_layer_data: torch.Tensor) -> float:
38 |         """
39 |         Computes the norm of the given component data.
40 | 
41 |         Args:
42 |             component_layer_data: The component data to compute the norm of.
43 | 
44 |         Returns:
45 |             The norm of the component data.
46 |         """
47 |         return self.norm_function(component_layer_data).item()
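
How the three supported norm_type values behave, on an illustrative diagonal matrix:

    import torch

    w = torch.tensor([[3.0, 0.0], [0.0, 4.0]])

    torch.norm(w, p="fro")          # Frobenius: sqrt(3**2 + 4**2) == 5.0
    torch.norm(w, p="nuc")          # nuclear: sum of singular values == 3 + 4 == 7.0
    torch.norm(w, p=float("inf"))   # inf-norm over the flattened tensor: max|w_ij| == 4.0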
--------------------------------------------------------------------------------
/src/metrics/per.py:
--------------------------------------------------------------------------------
1 | """
2 | PER metric computes the Proportional Effective Rank (PER) of weights or gradients.
3 | """
4 | 
5 | import torch
6 | 
7 | from src.config.base import BaseComponentConfig
8 | from src.metrics._registry import register_metric
9 | from src.metrics.base import BaseMetric
10 | from src.utils.exceptions import InvalidComponentError
11 | 
12 | 
13 | @register_metric("per")
14 | class PERMetric(BaseMetric):
15 |     """
16 |     Compute the Proportional Effective Rank (PER) of some component data. The PER is the
17 |     effective rank of a matrix (the exponentiated entropy of its normalised singular values), expressed as a proportion of the total number of singular values.
18 | 
19 |     Formally, if we let S = [s_1, ..., s_n] be the singular values of a parameter matrix P and p_i = s_i / sum_{j=1}^{n} s_j be the normalised singular values, then the PER is given by:
20 | 
21 |     PER(P) = exp(- sum_{i=1}^{n} p_i * log(p_i)) / n
22 | 
23 |     where the exp(...) term is the effective rank (ER) of P.
24 |     """
25 | 
26 |     def validate_component(self, component_config: BaseComponentConfig) -> None:
27 |         """
28 |         NOTE: The PER metric is only valid for weights and gradients.
29 |         """
30 |         if component_config.data_type not in ["weights", "gradients"]:
31 |             raise InvalidComponentError(
32 |                 f"PER metric only supports weights and gradients, not {component_config.data_type} "
33 |                 f"(component: {component_config.component_name})."
34 |             )
35 | 
36 |     def compute_metric(self, component_layer_data: torch.Tensor) -> float:
37 |         """
38 |         Computes the PER of a given layer.
39 | 
40 |         Args:
41 |             component_layer_data: Tensor containing the data to analyze
42 | 
43 |         Returns:
44 |             float: The computed PER
45 |         """
46 | 
47 |         layer_singular_values = torch.linalg.svdvals(component_layer_data)
48 | 
49 |         # standardize singular values
50 |         layer_singular_values = layer_singular_values / layer_singular_values.sum()
51 | 
52 |         # compute effective rank (ER) and proportional effective rank (PER)
53 |         layer_er = torch.exp(
54 |             -torch.sum(layer_singular_values * torch.log(layer_singular_values))
55 |         ).item()
56 |         layer_per = layer_er / len(layer_singular_values)
57 | 
58 |         return layer_per
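
A worked example of the definition (illustrative matrix):

    import torch

    # The identity matrix has a uniform singular-value spectrum, so the entropy of
    # the normalised spectrum is log(n), the effective rank is exp(log(n)) == n,
    # and the PER is n / n == 1 (full proportional rank).
    s = torch.linalg.svdvals(torch.eye(4))  # tensor([1., 1., 1., 1.])
    p = s / s.sum()                         # uniform: [0.25, 0.25, 0.25, 0.25]
    effective_rank = torch.exp(-torch.sum(p * torch.log(p))).item()  # 4.0
    per = effective_rank / len(s)           # 1.0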
--------------------------------------------------------------------------------
/src/metrics/pwcca.py:
--------------------------------------------------------------------------------
1 | """
2 | PWCCA (Projection Weighted Canonical Correlation Analysis) is a comparative metric for comparing
3 | how similar two sets of activations are between two different checkpoints.
4 | """
5 | 
6 | import torch
7 | 
8 | from lib.svcca.pwcca import compute_pwcca
9 | from src.config.base import BaseComponentConfig
10 | from src.metrics._registry import register_metric
11 | from src.metrics.base import BaseComparativeMetric
12 | from src.utils.exceptions import InvalidComponentError
13 | 
14 | 
15 | @register_metric("pwcca")
16 | class PWCCAMetric(BaseComparativeMetric):
17 |     """
18 |     This metric computes the PWCCA of the given data.
19 | 
20 |     PWCCA is a variant of Canonical Correlation Analysis (CCA) that uses projection weights to
21 |     compute the similarity between two sets of activations.
22 | 
23 |     Reference: https://arxiv.org/abs/1806.05759
24 |     """
25 | 
26 |     def validate_component(self, component_config: BaseComponentConfig) -> None:
27 |         """
28 |         The PWCCA metric is only valid for activations.
29 |         """
30 |         if component_config.data_type not in ["activations"]:
31 |             raise InvalidComponentError(
32 |                 f"PWCCA metric only supports activations, not {component_config.data_type} "
33 |                 f"(component: {component_config.component_name})."
34 |             )
35 | 
36 |     def compute_metric(
37 |         self,
38 |         source_component_layer_data: torch.Tensor,
39 |         target_component_layer_data: torch.Tensor,
40 |     ) -> float:
41 |         """
42 |         Computes the PWCCA between the source and target component layer activations.
43 | 
44 |         Args:
45 |             source_component_layer_data: Tensor containing the source data to analyze
46 |             target_component_layer_data: Tensor containing the target data to analyze
47 | 
48 |         Returns:
49 |             float: The computed PWCCA
50 |         """
51 |         # transform the data to numpy
52 |         # NOTE: pwcca expects the data to be of shape '(num_neurons, num_samples)', so
53 |         # we need to transpose the data
54 |         np_source_component_layer_data = (
55 |             source_component_layer_data.to(dtype=torch.float32).transpose(0, 1).numpy()
56 |         )
57 |         np_target_component_layer_data = (
58 |             target_component_layer_data.to(dtype=torch.float32).transpose(0, 1).numpy()
59 |         )
60 | 
61 |         pwcca_metric, _, _ = compute_pwcca(
62 |             np_source_component_layer_data, np_target_component_layer_data, epsilon=1e-6
63 |         )
64 | 
65 |         return float(pwcca_metric)
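
The shape convention, spelled out (hypothetical sizes):

    import torch

    # Checkpoint activations are stored as (num_samples, num_neurons); compute_pwcca
    # instead expects (num_neurons, num_samples), hence the transpose above.
    acts = torch.randn(512, 64)                                    # 512 samples, 64 neurons
    acts_for_pwcca = acts.to(dtype=torch.float32).transpose(0, 1)
    acts_for_pwcca.shape                                           # torch.Size([64, 512])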
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pico-lm/pico-analyze/c4a4e844abfec54908a10aa12a5b3a002d8945df/src/utils/__init__.py
--------------------------------------------------------------------------------
/src/utils/data.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for downloading checkpoint data from HuggingFace or from a local run.
3 | 
4 | NOTE: Assumes that models have been uploaded to HuggingFace using pico-train.
5 | """
6 | 
7 | import os
8 | import re
9 | from functools import lru_cache
10 | 
11 | import torch
12 | import yaml
13 | from datasets import load_from_disk
14 | from huggingface_hub import HfApi, hf_hub_download, snapshot_download
15 | 
16 | from src.utils.exceptions import InvalidStepError
17 | from src.utils.initialization import CheckpointLocation
18 | 
19 | 
20 | def get_checkpoint_states(
21 |     checkpoint_location: CheckpointLocation, step: int = None, data_split: str = "val"
22 | ) -> dict:
23 |     """
24 |     Returns all the available checkpoint states for a given step and data split, in a given
25 |     run path or a given HuggingFace repository and branch. We assume that the checkpoint states
26 |     are stored in a learning_dynamics folder inside the step's checkpoint folder (generated by
27 |     pico-train) with the following structure:
28 | 
29 |     <run_path>/
30 |         checkpoints/
31 |             step_<step>/
32 |                 learning_dynamics/
33 |                     train_activations.pt
34 |                     train_weights.pt
35 |                     train_gradients.pt
36 |                     train_data/
37 |                         [...]
38 |                     val_activations.pt
39 |                     val_weights.pt
40 |                     val_gradients.pt
41 | 
42 |     Args:
43 |         checkpoint_location: CheckpointLocation
44 |         step: Step to get data from
45 |         data_split: Data split to get data from (i.e. "train" or "val")
46 | 
47 |     Returns:
48 |         dict: Dictionary containing the checkpoint states for a given step and data split.
49 | 
50 |         For instance, if the data split is "train", the dictionary will have the following
51 |         structure:
52 |         {
53 |             "activations": {
54 |                 "model.0.mlp": torch.Tensor,  # model activations across layer
55 |                 [...]
56 |             },
57 |             "weights": {
58 |                 "model.0.mlp": torch.Tensor,  # model weights across layer
59 |                 [...]
60 |             },
61 |             "gradients": {
62 |                 "model.0.mlp": torch.Tensor,  # model gradients across layer
63 |                 [...]
64 |             },
65 |             "dataset": torch.utils.data.Dataset,
66 |         }
67 |     """
68 |     if checkpoint_location.is_remote:
69 |         return _download_checkpoint_states(
70 |             checkpoint_location.repo_id, checkpoint_location.branch, step, data_split
71 |         )
72 |     else:
73 |         return _load_checkpoint_states(checkpoint_location.run_path, step, data_split)
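# A usage sketch (hypothetical run path and step; CheckpointLocation is defined
# in src/utils/initialization.py):
#
#     location = CheckpointLocation(repo_id=None, branch=None, run_path="runs/my-run")
#     states = get_checkpoint_states(location, step=1000, data_split="val")
#     val_weights = states["weights"]  # maps layer names to torch.Tensor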
217 | """ 218 | 219 | # ensure that the run_path is a valid path 220 | if not os.path.exists(run_path): 221 | raise ValueError(f"Run path {run_path} does not exist") 222 | 223 | checkpoint_path = os.path.join(run_path, "checkpoints") 224 | # ensure that the run_path contains a checkpoint folder 225 | if not os.path.exists(checkpoint_path): 226 | raise ValueError(f"Run path {run_path} does not contain a checkpoint folder") 227 | 228 | step_path = os.path.join(checkpoint_path, f"step_{step}") 229 | 230 | # ensure that the step exists 231 | if not os.path.exists(step_path): 232 | raise InvalidStepError(step) 233 | 234 | # states to compute learning dynamics are stored in the learning_dynamics folder 235 | learning_dynamics_path = os.path.join(step_path, "learning_dynamics") 236 | 237 | checkpoint_states = _get_checkpoint_states_dict(learning_dynamics_path, data_split) 238 | if len(checkpoint_states) == 0: 239 | # NOTE: this can happen if a checkpoint folder exists for a given step but no 240 | # learning dynamics data was saved for the specified data split. 241 | raise InvalidStepError(step) 242 | 243 | return checkpoint_states 244 | 245 | 246 | def _download_checkpoint_states( 247 | repo_id: str, branch: str, step: int, data_split: str 248 | ) -> dict: 249 | """ 250 | Download checkpoint states for a specific commit and step and data split. 251 | 252 | Args: 253 | repo_id: HuggingFace repository ID 254 | branch: Branch to get commits from 255 | step: Step to get data from 256 | data_split: Data split to get data from (i.e. "train" or "val") 257 | 258 | Returns: 259 | dict: Dictionary containing the loaded learning dynamics data 260 | """ 261 | 262 | # get all of the commits in the branch 263 | learning_dynamics_commits = _get_learning_dynamics_commits( 264 | repo_id, branch, data_split 265 | ) 266 | 267 | if step not in learning_dynamics_commits: 268 | raise InvalidStepError(step) 269 | 270 | commit = learning_dynamics_commits[step] 271 | checkpoint_dir = snapshot_download( 272 | repo_id=repo_id, 273 | revision=commit["commit_id"], 274 | ) 275 | 276 | # states to compute learning dynamics are stored in the learning_dynamics folder 277 | learning_dynamics_path = os.path.join(checkpoint_dir, "learning_dynamics") 278 | return _get_checkpoint_states_dict(learning_dynamics_path, data_split) 279 | -------------------------------------------------------------------------------- /src/utils/exceptions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom exceptions for the analysis package. 3 | """ 4 | 5 | from typing import Optional 6 | 7 | 8 | class InvalidStepError(Exception): 9 | """ 10 | Exception class that is raised when a requested step is not available in the repository. 11 | """ 12 | 13 | def __init__(self, step: Optional[int] = None, message: Optional[str] = None): 14 | if message is None: 15 | self.message = f"Step {step if step is not None else ''} is not a valid checkpoint step." 16 | else: 17 | self.message = message 18 | 19 | super().__init__(self.message) 20 | 21 | 22 | class InvalidRunLocationError(Exception): 23 | """ 24 | Exception class that is raised when a requested location for a model run is not valid. 25 | """ 26 | 27 | def __init__(self, message: Optional[str] = None): 28 | if message is None: 29 | self.message = """ 30 | Location must be either a remote repository or a local path. 31 | 32 | To specify a remote repository, provide the repo_id and branch: 33 | python run analyze [...] 
217 |     """
218 | 
219 |     # ensure that the run_path is a valid path
220 |     if not os.path.exists(run_path):
221 |         raise ValueError(f"Run path {run_path} does not exist")
222 | 
223 |     checkpoint_path = os.path.join(run_path, "checkpoints")
224 |     # ensure that the run_path contains a checkpoints folder
225 |     if not os.path.exists(checkpoint_path):
226 |         raise ValueError(f"Run path {run_path} does not contain a checkpoints folder")
227 | 
228 |     step_path = os.path.join(checkpoint_path, f"step_{step}")
229 | 
230 |     # ensure that the step exists
231 |     if not os.path.exists(step_path):
232 |         raise InvalidStepError(step)
233 | 
234 |     # states to compute learning dynamics are stored in the learning_dynamics folder
235 |     learning_dynamics_path = os.path.join(step_path, "learning_dynamics")
236 | 
237 |     checkpoint_states = _get_checkpoint_states_dict(learning_dynamics_path, data_split)
238 |     if len(checkpoint_states) == 0:
239 |         # NOTE: this can happen if a checkpoint folder exists for a given step but no
240 |         # learning dynamics data was saved for the specified data split.
241 |         raise InvalidStepError(step)
242 | 
243 |     return checkpoint_states
244 | 
245 | 
246 | def _download_checkpoint_states(
247 |     repo_id: str, branch: str, step: int, data_split: str
248 | ) -> dict:
249 |     """
250 |     Download checkpoint states for a specific commit, step, and data split.
251 | 
252 |     Args:
253 |         repo_id: HuggingFace repository ID
254 |         branch: Branch to get commits from
255 |         step: Step to get data from
256 |         data_split: Data split to get data from (i.e. "train" or "val")
257 | 
258 |     Returns:
259 |         dict: Dictionary containing the loaded learning dynamics data
260 |     """
261 | 
262 |     # get all of the commits in the branch
263 |     learning_dynamics_commits = _get_learning_dynamics_commits(
264 |         repo_id, branch, data_split
265 |     )
266 | 
267 |     if step not in learning_dynamics_commits:
268 |         raise InvalidStepError(step)
269 | 
270 |     commit = learning_dynamics_commits[step]
271 |     checkpoint_dir = snapshot_download(
272 |         repo_id=repo_id,
273 |         revision=commit["commit_id"],
274 |     )
275 | 
276 |     # states to compute learning dynamics are stored in the learning_dynamics folder
277 |     learning_dynamics_path = os.path.join(checkpoint_dir, "learning_dynamics")
278 |     return _get_checkpoint_states_dict(learning_dynamics_path, data_split)
--------------------------------------------------------------------------------
/src/utils/exceptions.py:
--------------------------------------------------------------------------------
1 | """
2 | Custom exceptions for the analysis package.
3 | """
4 | 
5 | from typing import Optional
6 | 
7 | 
8 | class InvalidStepError(Exception):
9 |     """
10 |     Exception class that is raised when a requested step is not available in the repository.
11 |     """
12 | 
13 |     def __init__(self, step: Optional[int] = None, message: Optional[str] = None):
14 |         if message is None:
15 |             self.message = f"Step {step} is not a valid checkpoint step." if step is not None else "The requested step is not a valid checkpoint step."
16 |         else:
17 |             self.message = message
18 | 
19 |         super().__init__(self.message)
20 | 
21 | 
22 | class InvalidRunLocationError(Exception):
23 |     """
24 |     Exception class that is raised when a requested location for a model run is not valid.
25 |     """
26 | 
27 |     def __init__(self, message: Optional[str] = None):
28 |         if message is None:
29 |             self.message = """
30 |             Location must be either a remote repository or a local path.
31 | 
32 |             To specify a remote repository, provide the repo_id and branch:
33 |                 python run analyze [...] --repo_id <repo_id> --branch <branch>
34 | 
35 |             To specify a local path, provide the run_path:
36 |                 python run analyze [...] --run_path <run_path>
37 |             """
38 |         else:
39 |             self.message = message
40 | 
41 |         super().__init__(self.message)
42 | 
43 | 
44 | class InvalidComponentError(Exception):
45 |     """
46 |     Exception class that is raised when a requested component is not valid because it is either
47 |     incompatible with the metric or the component config is invalid.
48 |     """
49 | 
50 |     def __init__(self, message: Optional[str] = None):
51 |         if message is None:
52 |             self.message = "Component is not a valid component."
53 |         else:
54 |             self.message = message
55 | 
56 |         super().__init__(self.message)
--------------------------------------------------------------------------------
/src/utils/initialization.py:
--------------------------------------------------------------------------------
1 | """
2 | Initialize configuration objects from a YAML file.
3 | """
4 | 
5 | import logging
6 | import os
7 | from datetime import datetime
8 | 
9 | # typing imports
10 | from typing import Any, Dict
11 | 
12 | import yaml
13 | 
14 | import wandb
15 | from src.config.learning_dynamics import LearningDynamicsConfig
16 | from src.utils.exceptions import InvalidRunLocationError
17 | 
18 | ####################
19 | #
20 | # Monitoring Setup (Logging and Wandb)
21 | #
22 | ####################
23 | 
24 | 
25 | def initialize_output_dir(
26 |     config: LearningDynamicsConfig, training_config: Dict[str, Any]
27 | ) -> str:
28 |     """
29 |     Creates the output directory for the analysis. If no analysis name is specified, we will use
30 |     the run name and the current date and time as a unique identifier.
31 | 
32 |     Args:
33 |         config: LearningDynamicsConfig -- the learning dynamics config.
34 |         training_config: Dict[str, Any] -- the training config.
35 | 
36 |     Returns:
37 |         str -- the output directory.
38 |     """
39 | 
40 |     _analysis_name = config.analysis_name
41 |     if _analysis_name is None or _analysis_name == "":
42 |         # if no analysis name is specified, use the run name and the current date and time
43 |         # as a unique identifier
44 |         _analysis_name = (
45 |             training_config["checkpointing"]["run_name"]
46 |             + "_analysis_"
47 |             + datetime.now().strftime("%Y%m%d_%H%M%S")
48 |         )
49 | 
50 |     config.analysis_name = _analysis_name
51 | 
52 |     analysis_dir = os.path.join(config.monitoring.output_dir, _analysis_name)
53 |     os.makedirs(analysis_dir, exist_ok=True)
54 |     return analysis_dir
55 | 
56 | 
57 | def initialize_logging(analysis_dir: str) -> logging.Logger:
58 |     """
59 |     Sets up the logging for the analysis. Logs are written both to the console and to a file in the analysis directory.
60 | 
61 |     Args:
62 |         analysis_dir: str -- the analysis directory to save the logs to
63 | 
64 |     Returns:
65 |         logging.Logger: Configured logger instance
66 |     """
67 |     # Create logger
68 |     logger = logging.getLogger("pico-analyze")
69 |     logger.setLevel(logging.INFO)
70 | 
71 |     # Remove any existing handlers
72 |     if logger.hasHandlers():
73 |         logger.handlers.clear()
74 | 
75 |     # Create console handler
76 |     console_handler = logging.StreamHandler()
77 |     console_handler.setLevel(logging.INFO)
78 | 
79 |     # Create formatter
80 |     formatter = logging.Formatter(
81 |         "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
82 |         datefmt="%Y-%m-%d %H:%M:%S",
83 |     )
84 |     console_handler.setFormatter(formatter)
85 | 
86 |     file_handler = logging.FileHandler(os.path.join(analysis_dir, "analysis.log"))
87 |     file_handler.setLevel(logging.INFO)
88 |     file_handler.setFormatter(formatter)
89 | 
90 |     # Add handlers to logger
91 |     logger.addHandler(console_handler)
92 |     logger.addHandler(file_handler)
93 | 
94 |     return logger
162 | """ 163 | if self.run_path is not None: 164 | if not os.path.exists(self.run_path): 165 | raise InvalidRunLocationError() 166 | self.is_remote = False 167 | else: 168 | if self.repo_id is None or self.branch is None: 169 | raise InvalidRunLocationError() 170 | self.is_remote = True 171 | 172 | 173 | #################### 174 | # 175 | # Configuration Setup 176 | # 177 | #################### 178 | 179 | 180 | def initialize_config(config_path: str) -> dict: 181 | """Initialize configuration objects with optional overrides from a YAML file. 182 | 183 | This function initializes the configuration objects with the default values, and then 184 | applies any overrides from the config_path file. 185 | 186 | Args: 187 | config_path: Path to a YAML file containing configuration overrides. 188 | 189 | Returns: 190 | A dictionary containing the initialized configuration objects. 191 | """ 192 | overrides = yaml.safe_load(open(config_path, "r")) 193 | config = LearningDynamicsConfig(**overrides) 194 | return config 195 | -------------------------------------------------------------------------------- /src/utils/logging.py: -------------------------------------------------------------------------------- 1 | """ 2 | Miscellaneous logging utilities. 3 | """ 4 | 5 | import logging 6 | from io import StringIO 7 | 8 | import yaml 9 | from rich.console import Console 10 | from rich.panel import Panel 11 | 12 | 13 | def pretty_print_config(logger: logging.Logger, title: str, config: dict) -> None: 14 | """ 15 | Pretty print config with rich formatting. We use the rich library to create a panel 16 | with a blue border and a title. We then convert the config to a YAML string and print 17 | it to the buffer. 18 | 19 | Args: 20 | logger: The logger instance 21 | title: The title of the config 22 | config: The config to print 23 | """ 24 | # Create string buffer 25 | output = StringIO() 26 | console = Console(file=output, force_terminal=False) 27 | 28 | # Convert to YAML string first 29 | yaml_str = yaml.dump( 30 | config, default_flow_style=False, sort_keys=False, Dumper=yaml.SafeDumper 31 | ) 32 | 33 | # Create formatted panel 34 | panel = Panel( 35 | yaml_str, 36 | title=f"[bold blue]{title}[/bold blue]", 37 | border_style="blue", 38 | padding=(0, 1), # Reduced padding 39 | expand=False, # Don't expand to terminal width 40 | ) 41 | 42 | # Print to buffer 43 | console.print(panel) 44 | 45 | # Log the formatted output 46 | for line in output.getvalue().splitlines(): 47 | logger.info(line) 48 | 49 | 50 | def pretty_print_component_metrics( 51 | logger: logging.Logger, step: int, step_metrics: dict 52 | ): 53 | """ 54 | Log multiple component metrics in a grouped and aligned format. We group the components by 55 | layer and sort them alphabetically. We also sort the metrics alphabetically. Finally, 56 | for each metric, we display the components in a simple bar chart. 
--------------------------------------------------------------------------------
/src/utils/logging.py:
--------------------------------------------------------------------------------
1 | """
2 | Miscellaneous logging utilities.
3 | """
4 | 
5 | import logging
6 | from io import StringIO
7 | 
8 | import yaml
9 | from rich.console import Console
10 | from rich.panel import Panel
11 | 
12 | 
13 | def pretty_print_config(logger: logging.Logger, title: str, config: dict) -> None:
14 |     """
15 |     Pretty print config with rich formatting. We use the rich library to create a panel
16 |     with a blue border and a title. We then convert the config to a YAML string and print
17 |     it to the buffer.
18 | 
19 |     Args:
20 |         logger: The logger instance
21 |         title: The title of the config
22 |         config: The config to print
23 |     """
24 |     # Create string buffer
25 |     output = StringIO()
26 |     console = Console(file=output, force_terminal=False)
27 | 
28 |     # Convert to YAML string first
29 |     yaml_str = yaml.dump(
30 |         config, default_flow_style=False, sort_keys=False, Dumper=yaml.SafeDumper
31 |     )
32 | 
33 |     # Create formatted panel
34 |     panel = Panel(
35 |         yaml_str,
36 |         title=f"[bold blue]{title}[/bold blue]",
37 |         border_style="blue",
38 |         padding=(0, 1),  # Reduced padding
39 |         expand=False,  # Don't expand to terminal width
40 |     )
41 | 
42 |     # Print to buffer
43 |     console.print(panel)
44 | 
45 |     # Log the formatted output
46 |     for line in output.getvalue().splitlines():
47 |         logger.info(line)
48 | 
49 | 
50 | def pretty_print_component_metrics(
51 |     logger: logging.Logger, step: int, step_metrics: dict
52 | ):
53 |     """
54 |     Log multiple component metrics in a grouped and aligned format. We group the components by
55 |     layer and sort them in a natural order (numeric name parts compare numerically). We also sort
56 |     the metrics alphabetically. Finally, for each metric, we display the components in a simple bar chart.
57 | 
58 |     Args:
59 |         logger: The logger instance
60 |         step: The training step
61 |         step_metrics: Dictionary of dictionaries, where each key is a metric name and
62 |             each value is a dictionary of component metrics
63 |     """
64 |     # Create a header for all metrics
65 |     header = f"📊 Component Metrics at Step {step}"
66 |     separator = "=" * len(header)
67 | 
68 |     logger.info(separator)
69 |     logger.info(header)
70 |     logger.info(separator)
71 | 
72 |     # Sort metrics alphabetically for consistent display
73 |     sorted_metric_names = sorted(step_metrics.keys())
74 | 
75 |     for metric_name in sorted_metric_names:
76 |         component_metrics_dict = step_metrics[metric_name]
77 | 
78 |         # Create a subheader for each metric
79 |         metric_header = f"🔍 {metric_name.upper()} Metric"
80 |         metric_separator = "-" * len(metric_header)
81 | 
82 |         logger.info(f"{metric_separator}")
83 |         logger.info(metric_header)
84 |         logger.info(metric_separator)
85 | 
86 |         # Group components by layer
87 |         layer_groups = {}
88 |         other_components = {}
89 | 
90 |         for component_key, metric_value in component_metrics_dict.items():
91 |             # Extract layer number if present
92 |             parts = component_key.split(".")
93 |             layer_match = None
94 |             layer_idx = None
95 | 
96 |             for i, part in enumerate(parts):
97 |                 if part == "layers" and i + 1 < len(parts):
98 |                     try:
99 |                         layer_idx = int(parts[i + 1])
100 |                         layer_match = f"layers.{layer_idx}"
101 |                         break
102 |                     except ValueError:
103 |                         pass
104 | 
105 |             if layer_match:
106 |                 # Component belongs to a specific layer
107 |                 layer_name = f"Layer {layer_idx}"
108 |                 if layer_name not in layer_groups:
109 |                     layer_groups[layer_name] = {
110 |                         "_layer_idx": layer_idx,
111 |                         "components": {},
112 |                     }
113 | 
114 |                 layer_groups[layer_name]["components"][component_key] = metric_value
115 |             else:
116 |                 # Component doesn't belong to a specific layer
117 |                 # NOTE: this really shouldn't happen, but is a fallback for any components
118 |                 # that don't have a layer index.
119 |                 other_components[component_key] = metric_value
120 | 
121 |         # Sort layers by index
122 |         sorted_layer_names = sorted(
123 |             layer_groups.keys(), key=lambda x: layer_groups[x]["_layer_idx"]
124 |         )
125 | 
126 |         # Display components by layer
127 |         for layer_name in sorted_layer_names:
128 |             logger.info(f"📌 {layer_name}")
129 | 
130 |             # Get components for this layer and sort them
131 |             components = layer_groups[layer_name]["components"]
132 | 
133 |             # Sort components so numeric name parts compare numerically (head_2 before head_10)
134 |             def sort_key(name):
135 |                 parts = name.replace("_", ".").split(".")
136 |                 result = []
137 |                 for part in parts:
138 |                     try:
139 |                         result.append((0, int(part)))
140 |                     except ValueError:
141 |                         result.append((1, part))
142 |                 return result
143 | 
144 |             sorted_component_names = sorted(components.keys(), key=sort_key)
145 | 
146 |             # Find max length for alignment
147 |             max_name_length = (
148 |                 max(len(name) for name in sorted_component_names)
149 |                 if sorted_component_names
150 |                 else 0
151 |             )
152 | 
153 |             # Display each component
154 |             for component_name in sorted_component_names:
155 |                 value = components[component_name]
156 |                 bar_length = max(0, min(round(value * 20), 20))  # clamp to [0, 20] cells
157 |                 bar = "█" * bar_length + "░" * (20 - bar_length)
158 |                 logger.info(
159 |                     f"  {component_name.ljust(max_name_length)} │ {bar} │ {value:.4f}"
160 |                 )
161 | 
162 |         # Display other components
163 |         if other_components:
164 |             logger.info("📌 Other Components")
165 |             sorted_other_names = sorted(other_components.keys())
166 |             max_other_length = (
167 |                 max(len(name) for name in sorted_other_names)
168 |                 if sorted_other_names
169 |                 else 0
170 |             )
171 | 
172 |             for component_name in sorted_other_names:
173 |                 value = other_components[component_name]
174 |                 bar_length = max(0, min(round(value * 20), 20))
175 |                 bar = "█" * bar_length + "░" * (20 - bar_length)
176 |                 logger.info(
177 |                     f"  {component_name.ljust(max_other_length)} │ {bar} │ {value:.4f}"
178 |                 )
179 | 
--------------------------------------------------------------------------------