├── .env.example ├── .github ├── pull_request_template.md └── workflows │ └── checks.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── .vscode ├── extensions.json ├── launch.json └── settings-example.json ├── ACCESS.md ├── LICENSE ├── Makefile ├── README.md ├── conftest.py ├── demos └── train_saes.ipynb ├── e2e_sae ├── __init__.py ├── data.py ├── hooks.py ├── loader.py ├── log.py ├── losses.py ├── metrics.py ├── models │ ├── __init__.py │ ├── mlp.py │ ├── sparsifiers.py │ └── transformers.py ├── plotting.py ├── scripts │ ├── __init__.py │ ├── analysis │ │ ├── activation_analysis.py │ │ ├── autointerp.py │ │ ├── geometric_analysis.py │ │ ├── pca_dir0_explore.py │ │ ├── plot_performance.py │ │ ├── plot_settings.py │ │ ├── resample_direction.py │ │ └── utils.py │ ├── train_mlp_saes │ │ ├── max_act_data_mnist.py │ │ ├── max_act_data_mnist.yaml │ │ ├── mnist_saes.yaml │ │ └── train_mnist_saes.py │ ├── train_mnist │ │ ├── mnist.yaml │ │ └── run_train_mnist.py │ ├── train_tlens │ │ ├── run_train_tlens.py │ │ ├── sample_models │ │ │ └── tiny-gpt2_lr-0.001_bs-16_2024-04-21_14-01-14 │ │ │ │ ├── epoch_1.pt │ │ │ │ └── final_config.yaml │ │ └── tiny_gpt2.yaml │ ├── train_tlens_saes │ │ ├── gpt2_e2e.yaml │ │ ├── gpt2_e2e_recon.yaml │ │ ├── gpt2_e2e_recon_sweep.yaml │ │ ├── gpt2_e2e_sweep.yaml │ │ ├── gpt2_local.yaml │ │ ├── gpt2_local_sweep.yaml │ │ ├── pythia_14m_e2e.yaml │ │ ├── run_sweep.py │ │ ├── run_sweep_mp.py │ │ ├── run_train_tlens_saes.py │ │ ├── run_wandb_sweep.py │ │ ├── tinystories_1M_e2e.yaml │ │ ├── tinystories_1M_e2e_sweep.yaml │ │ ├── tinystories_1M_local.yaml │ │ └── tinystories_1M_local_sweep.yaml │ └── upload_hf_dataset.py ├── settings.py ├── types.py └── utils.py ├── pyproject.toml └── tests ├── test_mlp.py ├── test_sae.py ├── test_train_tlens_saes.py ├── test_transformer.py ├── test_utils.py └── utils.py /.env.example: -------------------------------------------------------------------------------- 1 | WANDB_API_KEY=your_api_key 2 | WANDB_ENTITY=your_entity_name 3 | OPENAI_API_KEY=your_api_key # For autointerp 4 | NEURONPEDIA_API_KEY=your_api_key # For neuronpedia 5 | HF_HOME=\my_drive\huggingface\misc 6 | HF_DATASETS_CACHE=\my_drive\huggingface\datasets -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | 3 | 4 | ## Related Issue 5 | 6 | 7 | ## Motivation and Context 8 | 9 | 10 | ## How Has This Been Tested? 11 | 12 | 13 | ## Does this PR introduce a breaking change? 
14 | 15 | -------------------------------------------------------------------------------- /.github/workflows/checks.yaml: -------------------------------------------------------------------------------- 1 | name: Checks 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - '**' 10 | # Allows you to run this workflow manually from the Actions tab 11 | workflow_dispatch: 12 | 13 | jobs: 14 | build: 15 | runs-on: ubuntu-latest 16 | timeout-minutes: 15 17 | steps: 18 | - uses: actions/checkout@v4 19 | - name: Set up Python 3.11 20 | uses: actions/setup-python@v5 21 | with: 22 | python-version: '3.11' 23 | # pyright needs node 24 | - name: Setup node 25 | uses: actions/setup-node@v4 26 | with: 27 | node-version: 21 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install ".[dev]" 32 | - name: Run pyright type check 33 | run: | 34 | pyright 35 | - name: Run ruff lint 36 | run: | 37 | ruff check --fix-only . 38 | - name: Run ruff format 39 | run: | 40 | ruff format . 41 | - name: Run tests 42 | run: | 43 | python -m pytest tests/ --runslow --durations=10 44 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/out/ 2 | neuronpedia_outputs/ 3 | .env 4 | .vscode/settings.json 5 | wandb/ 6 | .data/ 7 | .checkpoints/ 8 | .DS_Store 9 | **/.DS_Store 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | cover/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | db.sqlite3-journal 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | .pybuilder/ 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | # For a library or package, you might want to ignore these files since the code is 97 | # intended to run in multiple environments; otherwise, check them in: 98 | # .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 104 | # install all needed dependencies. 
105 | #Pipfile.lock 106 | 107 | # poetry 108 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 109 | # This is especially recommended for binary packages to ensure reproducibility, and is more 110 | # commonly ignored for libraries. 111 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 112 | #poetry.lock 113 | 114 | # pdm 115 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 116 | #pdm.lock 117 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 118 | # in version control. 119 | # https://pdm.fming.dev/#use-with-ide 120 | .pdm.toml 121 | 122 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 123 | __pypackages__/ 124 | 125 | # Celery stuff 126 | celerybeat-schedule 127 | celerybeat.pid 128 | 129 | # SageMath parsed files 130 | *.sage.py 131 | 132 | # Environments 133 | .env 134 | .venv 135 | env/ 136 | venv/ 137 | ENV/ 138 | env.bak/ 139 | venv.bak/ 140 | 141 | # Spyder project settings 142 | .spyderproject 143 | .spyproject 144 | 145 | # Rope project settings 146 | .ropeproject 147 | 148 | # mkdocs documentation 149 | /site 150 | 151 | # mypy 152 | .mypy_cache/ 153 | .dmypy.json 154 | dmypy.json 155 | 156 | # Pyre type checker 157 | .pyre/ 158 | 159 | # pytype static type analyzer 160 | .pytype/ 161 | 162 | # Cython debug symbols 163 | cython_debug/ 164 | 165 | # PyCharm 166 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 167 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 168 | # and can be added to the global gitignore or merged into this file. For a more nuclear 169 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 170 | #.idea/ -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | # If you want branch protection for the main branch, use this hook 3 | # - repo: https://github.com/pre-commit/pre-commit-hooks 4 | # rev: v4.5.0 5 | # hooks: 6 | # - id: no-commit-to-branch 7 | # args: ["--branch=main"] 8 | # stages: 9 | # - commit 10 | - repo: local 11 | hooks: 12 | - id: pyright 13 | name: Pyright 14 | entry: pyright 15 | language: system 16 | types: [python] 17 | stages: 18 | - commit 19 | 20 | - id: ruff-lint 21 | name: Ruff lint 22 | entry: ruff check 23 | args: ["--fix-only"] 24 | language: system 25 | types: [python] 26 | stages: 27 | - commit 28 | 29 | - id: ruff-format 30 | name: Ruff format 31 | entry: ruff format 32 | language: system 33 | types: [python] 34 | stages: 35 | - commit -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "ms-toolsai.jupyter", 4 | "github.copilot", 5 | "ms-python.python", 6 | "charliermarsh.ruff", 7 | "stkb.rewrap" 8 | ] 9 | } -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "train gpt2 sae", 9 | "type": "debugpy", 10 | "request": "launch", 11 | "program": "${workspaceFolder}/e2e_sae/scripts/train_tlens_saes/run_train_tlens_saes.py", 12 | "args": "${workspaceFolder}/e2e_sae/scripts/train_tlens_saes/gpt2_e2e.yaml", 13 | "console": "integratedTerminal", 14 | "justMyCode": true, 15 | "env": { 16 | "PYDEVD_DISABLE_FILE_VALIDATION": "1" 17 | } 18 | }, 19 | { 20 | "name": "train pythia-14m sae", 21 | "type": "debugpy", 22 | "request": "launch", 23 | "program": "${workspaceFolder}/e2e_sae/scripts/train_tlens_saes/run_train_tlens_saes.py", 24 | "args": "${workspaceFolder}/e2e_sae/scripts/train_tlens_saes/pythia_14m_e2e.yaml", 25 | "console": "integratedTerminal", 26 | "justMyCode": true, 27 | "env": { 28 | "PYDEVD_DISABLE_FILE_VALIDATION": "1" 29 | } 30 | }, 31 | { 32 | "name": "train tinystories-1M sae", 33 | "type": "debugpy", 34 | "request": "launch", 35 | "program": "${workspaceFolder}/e2e_sae/scripts/train_tlens_saes/run_train_tlens_saes.py", 36 | "args": "${workspaceFolder}/e2e_sae/scripts/train_tlens_saes/tinystories_1M_e2e.yaml", 37 | "console": "integratedTerminal", 38 | "justMyCode": true, 39 | "env": { 40 | "PYDEVD_DISABLE_FILE_VALIDATION": "1" 41 | } 42 | }, 43 | { 44 | "name": "train tinystories-1M local sae", 45 | "type": "debugpy", 46 | "request": "launch", 47 | "program": "${workspaceFolder}/e2e_sae/scripts/train_tlens_saes/run_train_tlens_saes_local.py", 48 | "args": "${workspaceFolder}/e2e_sae/scripts/train_tlens_saes/tinystories_1M_local.yaml", 49 | "console": "integratedTerminal", 50 | "justMyCode": true, 51 | "env": { 52 | "PYDEVD_DISABLE_FILE_VALIDATION": "1" 53 | } 54 | }, 55 | { 56 | "name": "train mnist sae", 57 | "type": "debugpy", 58 | "request": "launch", 59 | "program": "${workspaceFolder}/e2e_sae/scripts/train_mlp_saes/train_mnist_saes.py", 60 | "args": "${workspaceFolder}/e2e_sae/scripts/train_mlp_saes/mnist_saes.yaml", 61 | "console": "integratedTerminal", 62 | "justMyCode": true, 63 | "env": { 64 | "PYDEVD_DISABLE_FILE_VALIDATION": "1" 65 | } 66 | }, 67 | { 68 | "name": "Python: Attach", 69 | "type": "debugpy", 70 | "request": "attach", 71 | "connect": { 72 | "host": "localhost", 73 | "port": 5678 74 | }, 75 | } 76 | ] 77 | } -------------------------------------------------------------------------------- /.vscode/settings-example.json: -------------------------------------------------------------------------------- 1 | { 2 | "[python]": { 3 | "editor.formatOnSave": true, 4 | "editor.codeActionsOnSave": { 5 | "source.fixAll": "explicit", 6 | "source.organizeImports": "explicit" 7 | }, 8 | "editor.defaultFormatter": "charliermarsh.ruff" 9 | }, 10 | "github.copilot.enable": { 11 | "*": true 12 | }, 13 | "rewrap.autoWrap.enabled": true, 14 | "rewrap.wrappingColumn": 100, 15 | "notebook.formatOnSave.enabled": true 16 | } -------------------------------------------------------------------------------- /ACCESS.md: -------------------------------------------------------------------------------- 1 | # Disclosure Level - Public 2 | 3 | See Privacy Levels [here](https://www.apolloresearch.ai/blog/security) 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 ApolloResearch 4 | 5 | Permission is hereby granted, free of charge, to any 
person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: install 2 | install: 3 | pip install -e . 4 | 5 | .PHONY: install-dev 6 | install-dev: 7 | pip install -e .[dev] 8 | pre-commit install 9 | 10 | .PHONY: type 11 | type: 12 | SKIP=no-commit-to-branch pre-commit run -a pyright 13 | 14 | .PHONY: format 15 | format: 16 | # Fix all autofixable problems (which sorts imports) then format errors 17 | SKIP=no-commit-to-branch pre-commit run -a ruff-lint 18 | SKIP=no-commit-to-branch pre-commit run -a ruff-format 19 | 20 | .PHONY: check 21 | check: 22 | SKIP=no-commit-to-branch pre-commit run -a --hook-stage commit 23 | 24 | .PHONY: test 25 | test: 26 | python -m pytest tests/ 27 | 28 | .PHONY: test-all 29 | test-all: 30 | python -m pytest tests/ --runslow -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # e2e_sae 2 | 3 | This library is used to train and evaluate Sparse Autoencoders (SAEs). It handles the following 4 | training types: 5 | - e2e (end-to-end): Loss function includes sparsity and final model kl_divergence. 6 | - e2e + downstream reconstruction: Loss function includes sparsity, final model kl_divergence, and MSE 7 | at downstream layers. 8 | - local (i.e. vanilla SAEs): Loss function includes sparsity and MSE at the SAE layer 9 | - Any combination of the above. 10 | 11 | See our [paper](https://publications.apolloresearch.ai/end_to_end_sparse_dictionary_learning) which argues for training SAEs e2e rather than locally. All SAEs presented in the paper can be found at https://wandb.ai/sparsify/gpt2 and can be loaded using this library. 12 | 13 | ## Usage 14 | ### Installation 15 | ```bash 16 | pip install e2e_sae 17 | ``` 18 | 19 | ### Train SAEs on any [TransformerLens](https://github.com/neelnanda-io/TransformerLens) model 20 | If you would like to track your run with Weights and Biases, place your api key and entity name in 21 | a new file called `.env`. An example is provided in [.env.example](.env.example). 22 | 23 | Create a config file (see gpt2 configs [here](e2e_sae/scripts/train_tlens_saes/) for examples). 24 | Then run 25 | ```bash 26 | python e2e_sae/scripts/train_tlens_saes/run_train_tlens_saes.py 27 | ``` 28 | 29 | If using a Colab notebook, see [this example](demos/train_saes.ipynb). 
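The script takes the path to the config YAML as its argument (the `.vscode/launch.json` debug configurations invoke it the same way). For example, to train the end-to-end GPT-2 SAE with one of the provided configs:

```bash
python e2e_sae/scripts/train_tlens_saes/run_train_tlens_saes.py e2e_sae/scripts/train_tlens_saes/gpt2_e2e.yaml
```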
30 | 31 | Sample wandb sweep configs are provided in [e2e_sae/scripts/train_tlens_saes/](e2e_sae/scripts/train_tlens_saes/). 32 | 33 | The library also contains scripts for training mlps and SAEs on mlps, as well as training 34 | custom transformerlens models and SAEs on these models (see [here](e2e_sae/scripts/)). 35 | ### Load a Pre-trained SAE 36 | You can load any pre-trained SAE (and accompanying TransformerLens model) trained using this library 37 | from Weights and Biases or locally by running 38 | ```python 39 | from e2e_sae import SAETransformer 40 | model = SAETransformer.from_wandb("") 41 | # or, if stored locally 42 | model = SAETransformer.from_local_path("/path/to/checkpoint/dir") 43 | ``` 44 | All runs in our 45 | [paper](https://publications.apolloresearch.ai/end_to_end_sparse_dictionary_learning) 46 | can be loaded this way (e.g.[sparsify/gpt2/tvj2owza](https://wandb.ai/sparsify/gpt2/runs/tvj2owza)). 47 | 48 | 49 | This will instantiate a `SAETransformer` class, which contains a TransformerLens model with SAEs 50 | attached. To do a forward pass without SAEs, use the `forward_raw` method, to do a forward pass with 51 | SAEs, use the `forward` method (or simply call the SAETansformer instance). 52 | 53 | The dictionary elements of an SAE can be accessed via `SAE.dict_elements`. This is will normalize 54 | the decoder elements to have norm 1. 55 | 56 | ### Analysis 57 | To reproduce all of the analysis in our 58 | [paper](https://publications.apolloresearch.ai/end_to_end_sparse_dictionary_learning) use the 59 | scripts in `e2e_sae/scripts/analysis/`. 60 | 61 | ## Contributing 62 | Developer dependencies are installed with `make install-dev`, which will also install pre-commit 63 | hooks. 64 | 65 | Suggested extensions and settings for VSCode are provided in `.vscode/`. To use the suggested 66 | settings, copy `.vscode/settings-example.json` to `.vscode/settings.json`. 67 | 68 | There are various `make` commands that may be helpful 69 | 70 | ```bash 71 | make check # Run pre-commit checks on all files (i.e. pyright, ruff linter, and ruff formatter) 72 | make type # Run pyright on all files 73 | make format # Run ruff linter and formatter on all files 74 | make test # Run tests that aren't marked `slow` 75 | make test-all # Run all tests 76 | ``` 77 | 78 | This library is maintained by [Dan Braun](https://danbraunai.github.io/). 79 | 80 | Join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-2hk7rcm8g-IIuaxpte_1GHp5joc~1kww) 81 | to chat about this library and other projects in the space! 82 | -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | """Add runslow option and skip slow tests if not specified. 2 | 3 | Taken from https://docs.pytest.org/en/latest/example/simple.html. 
4 | """ 5 | from collections.abc import Iterable 6 | 7 | import pytest 8 | import torch 9 | from pytest import Config, Item, Parser 10 | 11 | 12 | def pytest_addoption(parser: Parser) -> None: 13 | parser.addoption("--runslow", action="store_true", default=False, help="run slow tests") 14 | 15 | 16 | def pytest_configure(config: Config) -> None: 17 | config.addinivalue_line("markers", "slow: mark test as slow to run") 18 | config.addinivalue_line("markers", "cpuslow: mark test as slow to run if no gpu available") 19 | 20 | 21 | def pytest_collection_modifyitems(config: Config, items: Iterable[Item]) -> None: 22 | if config.getoption("--runslow"): 23 | # --runslow given in cli: do not skip slow tests 24 | return 25 | skip_slow = pytest.mark.skip(reason="need --runslow option to run") 26 | skip_cpuslow = pytest.mark.skip(reason="need --runslow option to run since no gpu available") 27 | for item in items: 28 | if "slow" in item.keywords: 29 | item.add_marker(skip_slow) 30 | if "cpuslow" in item.keywords and not torch.cuda.is_available(): 31 | item.add_marker(skip_cpuslow) 32 | -------------------------------------------------------------------------------- /demos/train_saes.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from e2e_sae.scripts.train_tlens_saes.run_train_tlens_saes import Config\n", 10 | "from e2e_sae.scripts.train_tlens_saes.run_train_tlens_saes import main as run_training" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": {}, 16 | "source": [ 17 | "## Define a Config\n", 18 | "The sample config below will train a single SAE on layer 6 of gpt2 using an e2e loss.\n", 19 | "\n", 20 | "Note that this will take 10-11 hours on an A100 to run. 
See\n", 21 | "[e2e_sae/scripts/train_tlens_saes/tinystories_1M_e2e.yaml](../e2e_sae/scripts/train_tlens_saes/tinystories_1M_e2e.yaml)\n", 22 | "for a tinystories-1m config, or simply choose a smaller model to train on and adjust the\n", 23 | "n_ctx and dataset accordingly (some other pre-tokenized datasets can be found [here](https://huggingface.co/apollo-research)).\n" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "config = Config(\n", 33 | " wandb_project=\"gpt2-e2e_play\",\n", 34 | " wandb_run_name=None, # If not set, will use a name based on important config values\n", 35 | " wandb_run_name_prefix=\"\",\n", 36 | " seed=0,\n", 37 | " tlens_model_name=\"gpt2-small\",\n", 38 | " tlens_model_path=None,\n", 39 | " n_samples=400_000,\n", 40 | " save_every_n_samples=None,\n", 41 | " eval_every_n_samples=40_000,\n", 42 | " eval_n_samples=500,\n", 43 | " log_every_n_grad_steps=20,\n", 44 | " collect_act_frequency_every_n_samples=40_000,\n", 45 | " act_frequency_n_tokens=500_000,\n", 46 | " batch_size=8,\n", 47 | " effective_batch_size=16, # Number of samples before each optimizer step\n", 48 | " lr=5e-4,\n", 49 | " lr_schedule=\"cosine\",\n", 50 | " min_lr_factor=0.1, # Minimum learning rate as a fraction of the initial learning rate\n", 51 | " warmup_samples=20_000, # Linear warmup over this many samples\n", 52 | " max_grad_norm=10.0, # Gradient norms get clipped to this value before optimizer steps\n", 53 | " loss={\n", 54 | " # Note that \"original acts\" below refers to the activations in a model without SAEs\n", 55 | " \"sparsity\": {\n", 56 | " \"p_norm\": 1.0, # p value in Lp norm\n", 57 | " \"coeff\": 1.5, # Multiplies the Lp norm in the loss (sparsity coefficient)\n", 58 | " },\n", 59 | " \"in_to_orig\": None, # Used for e2e+future recon. MSE between the input to the SAE and original acts\n", 60 | " \"out_to_orig\": None, # Not commonly used. MSE between the output of the SAE and original acts\n", 61 | " \"out_to_in\": {\n", 62 | " # Multiplies the MSE between the output and input of the SAE. 
Setting to 0 lets us track this\n", 63 | " # loss during training without optimizing it\n", 64 | " \"coeff\": 0.0,\n", 65 | " },\n", 66 | " \"logits_kl\": {\n", 67 | " \"coeff\": 1.0, # Multiplies the KL divergence between the logits of the SAE model and original model\n", 68 | " },\n", 69 | " },\n", 70 | " train_data={\n", 71 | " # See https://huggingface.co/apollo-research for other pre-tokenized datasets\n", 72 | " \"dataset_name\": \"apollo-research/Skylion007-openwebtext-tokenizer-gpt2\",\n", 73 | " \"is_tokenized\": True,\n", 74 | " \"tokenizer_name\": \"gpt2\",\n", 75 | " \"streaming\": True,\n", 76 | " \"split\": \"train\",\n", 77 | " \"n_ctx\": 1024,\n", 78 | " },\n", 79 | " eval_data={\n", 80 | " # By default this will use a different seed to the training data, but can be set with `seed`\n", 81 | " \"dataset_name\": \"apollo-research/Skylion007-openwebtext-tokenizer-gpt2\",\n", 82 | " \"is_tokenized\": True,\n", 83 | " \"tokenizer_name\": \"gpt2\",\n", 84 | " \"streaming\": True,\n", 85 | " \"split\": \"train\",\n", 86 | " \"n_ctx\": 1024,\n", 87 | " },\n", 88 | " saes={\n", 89 | " \"retrain_saes\": False, # Determines whether to continue training the SAEs in pretrained_sae_paths\n", 90 | " \"pretrained_sae_paths\": None, # Path or paths to pretrained SAEs\n", 91 | " \"sae_positions\": [ # Position or positions to place SAEs in the model\n", 92 | " \"blocks.6.hook_resid_pre\",\n", 93 | " ],\n", 94 | " \"dict_size_to_input_ratio\": 60.0, # Size of the dictionary relative to the activations at the SAE positions\n", 95 | " },\n", 96 | ")" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "## Train" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "run_training(config)" 113 | ] 114 | } 115 | ], 116 | "metadata": { 117 | "kernelspec": { 118 | "display_name": "sp-env", 119 | "language": "python", 120 | "name": "python3" 121 | }, 122 | "language_info": { 123 | "codemirror_mode": { 124 | "name": "ipython", 125 | "version": 3 126 | }, 127 | "file_extension": ".py", 128 | "mimetype": "text/x-python", 129 | "name": "python", 130 | "nbconvert_exporter": "python", 131 | "pygments_lexer": "ipython3", 132 | "version": "3.11.8" 133 | } 134 | }, 135 | "nbformat": 4, 136 | "nbformat_minor": 2 137 | } 138 | -------------------------------------------------------------------------------- /e2e_sae/__init__.py: -------------------------------------------------------------------------------- 1 | from .models.transformers import SAETransformer # noqa 2 | -------------------------------------------------------------------------------- /e2e_sae/data.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import einops 4 | import numpy as np 5 | import torch 6 | from datasets import Dataset, IterableDataset, load_dataset 7 | from numpy.typing import NDArray 8 | from pydantic import BaseModel, ConfigDict 9 | from torch.utils.data import DataLoader 10 | from transformers import AutoTokenizer 11 | 12 | from e2e_sae.types import Samples 13 | 14 | 15 | class DatasetConfig(BaseModel): 16 | model_config = ConfigDict(extra="forbid", frozen=True) 17 | dataset_name: str 18 | is_tokenized: bool = True 19 | tokenizer_name: str 20 | streaming: bool = True 21 | split: str 22 | n_ctx: int 23 | seed: int | None = None 24 | column_name: str = "input_ids" 25 | """The name of the column in the dataset that contains 
the data (tokenized or non-tokenized). 26 | Typically 'input_ids' for datasets stored with e2e_sae/scripts/upload_hf_dataset.py, or "tokens" 27 | for datasets tokenized in TransformerLens (e.g. NeelNanda/pile-10k).""" 28 | 29 | 30 | def _keep_single_column(dataset: Dataset, col_name: str) -> Dataset: 31 | """ 32 | Acts on a HuggingFace dataset to delete all columns apart from a single column name - useful 33 | when we want to tokenize and mix together different strings. 34 | """ 35 | for key in dataset.features: # pyright: ignore[reportAttributeAccessIssue] 36 | if key != col_name: 37 | dataset = dataset.remove_columns(key) 38 | return dataset 39 | 40 | 41 | def tokenize_and_concatenate( 42 | dataset: Dataset, 43 | tokenizer: AutoTokenizer, 44 | max_length: int = 1024, 45 | column_name: str = "text", 46 | add_bos_token: bool = False, 47 | num_proc: int = 10, 48 | ) -> Dataset: 49 | """Helper function to tokenizer and concatenate a dataset of text. This converts the text to 50 | tokens, concatenates them (separated by EOS tokens) and then reshapes them into a 2D array of 51 | shape (____, sequence_length), dropping the last batch. Tokenizers are much faster if 52 | parallelised, so we chop the string into 20, feed it into the tokenizer, in parallel with 53 | padding, then remove padding at the end. 54 | 55 | NOTE: Adapted from 56 | https://github.com/TransformerLensOrg/TransformerLens/blob/main/transformer_lens/utils.py#L267 57 | to handle IterableDataset. 58 | 59 | TODO: Fix typing of tokenizer 60 | 61 | This tokenization is useful for training language models, as it allows us to efficiently train 62 | on a large corpus of text of varying lengths (without, eg, a lot of truncation or padding). 63 | Further, for models with absolute positional encodings, this avoids privileging early tokens 64 | (eg, news articles often begin with CNN, and models may learn to use early positional 65 | encodings to predict these) 66 | 67 | Args: 68 | dataset: The dataset to tokenize, assumed to be a HuggingFace text dataset. Can be a regular 69 | Dataset or an IterableDataset. 70 | tokenizer: The tokenizer. Assumed to have a bos_token_id and an eos_token_id. 71 | max_length: The length of the context window of the sequence. Defaults to 1024. 72 | column_name: The name of the text column in the dataset. Defaults to 'text'. 73 | add_bos_token: Add BOS token at the beginning of each sequence. Defaults to False as this 74 | is not done during training. 75 | 76 | Returns: 77 | Dataset or IterableDataset: Returns the tokenized dataset, as a dataset of tensors, with a 78 | single column called "input_ids". 79 | 80 | Note: There is a bug when inputting very small datasets (eg, <1 batch per process) where it 81 | just outputs nothing. I'm not super sure why 82 | """ 83 | dataset = _keep_single_column(dataset, column_name) 84 | if tokenizer.pad_token is None: # pyright: ignore[reportAttributeAccessIssue] 85 | # We add a padding token, purely to implement the tokenizer. This will be removed before 86 | # inputting tokens to the model, so we do not need to increment d_vocab in the model. 
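# (The pad token exists only so the chunked tokenizer call below can return a rectangular, padded batch; all pad tokens are filtered out again before the tokens are reshaped into sequences.)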
87 | tokenizer.add_special_tokens({"pad_token": ""}) # pyright: ignore[reportAttributeAccessIssue] 88 | # Define the length to chop things up into - leaving space for a bos_token if required 89 | seq_len = max_length - 1 if add_bos_token else max_length 90 | 91 | def tokenize_function( 92 | examples: dict[str, list[str]], 93 | ) -> dict[ 94 | str, 95 | NDArray[np.signedinteger[Any]], 96 | ]: 97 | text = examples[column_name] 98 | # Concatenate it all into an enormous string, separated by eos_tokens 99 | full_text = tokenizer.eos_token.join(text) # pyright: ignore[reportAttributeAccessIssue] 100 | # Divide into 20 chunks of ~ equal length 101 | num_chunks = 20 102 | chunk_length = (len(full_text) - 1) // num_chunks + 1 103 | chunks = [full_text[i * chunk_length : (i + 1) * chunk_length] for i in range(num_chunks)] 104 | # Tokenize the chunks in parallel. Uses no because HF map doesn't want tensors returned 105 | tokens = tokenizer(chunks, return_tensors="np", padding=True)["input_ids"].flatten() # type: ignore 106 | # Drop padding tokens 107 | tokens = tokens[tokens != tokenizer.pad_token_id] # pyright: ignore[reportAttributeAccessIssue] 108 | num_tokens = len(tokens) 109 | num_batches = num_tokens // (seq_len) 110 | # Drop the final tokens if not enough to make a full sequence 111 | tokens = tokens[: seq_len * num_batches] 112 | tokens = einops.rearrange( 113 | tokens, "(batch seq) -> batch seq", batch=num_batches, seq=seq_len 114 | ) 115 | if add_bos_token: 116 | prefix = np.full((num_batches, 1), tokenizer.bos_token_id) # pyright: ignore[reportAttributeAccessIssue] 117 | tokens = np.concatenate([prefix, tokens], axis=1) 118 | return {"input_ids": tokens} 119 | 120 | if isinstance(dataset, IterableDataset): 121 | tokenized_dataset = dataset.map( 122 | tokenize_function, batched=True, remove_columns=[column_name] 123 | ) 124 | else: 125 | tokenized_dataset = dataset.map( 126 | tokenize_function, batched=True, remove_columns=[column_name], num_proc=num_proc 127 | ) 128 | 129 | tokenized_dataset = tokenized_dataset.with_format("torch") 130 | 131 | return tokenized_dataset 132 | 133 | 134 | def create_data_loader( 135 | dataset_config: DatasetConfig, batch_size: int, buffer_size: int = 1000, global_seed: int = 0 136 | ) -> tuple[DataLoader[Samples], AutoTokenizer]: 137 | """Create a DataLoader for the given dataset. 138 | 139 | Args: 140 | dataset_config: The configuration for the dataset. 141 | batch_size: The batch size. 142 | buffer_size: The buffer size for streaming datasets. 143 | global_seed: Used for shuffling if dataset_config.seed is None. 144 | 145 | Returns: 146 | A tuple of the DataLoader and the tokenizer. 
147 | """ 148 | dataset = load_dataset( 149 | dataset_config.dataset_name, streaming=dataset_config.streaming, split=dataset_config.split 150 | ) 151 | seed = dataset_config.seed if dataset_config.seed is not None else global_seed 152 | if dataset_config.streaming: 153 | assert isinstance(dataset, IterableDataset) 154 | dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size) 155 | else: 156 | dataset = dataset.shuffle(seed=seed) 157 | 158 | tokenizer = AutoTokenizer.from_pretrained(dataset_config.tokenizer_name) 159 | torch_dataset: Dataset 160 | if dataset_config.is_tokenized: 161 | torch_dataset = dataset.with_format("torch") # type: ignore 162 | # Get a sample from the dataset and check if it's tokenized and what the n_ctx is 163 | # Note that the dataset may be streamed, so we can't just index into it 164 | sample = next(iter(torch_dataset))[dataset_config.column_name] # type: ignore 165 | assert ( 166 | isinstance(sample, torch.Tensor) and sample.ndim == 1 167 | ), "Expected the dataset to be tokenized." 168 | assert len(sample) == dataset_config.n_ctx, "n_ctx does not match the tokenized length." 169 | 170 | else: 171 | torch_dataset = tokenize_and_concatenate( 172 | dataset, # type: ignore 173 | tokenizer, 174 | max_length=dataset_config.n_ctx, 175 | add_bos_token=True, 176 | ) 177 | 178 | # Note that a pre-tokenized dataset was shuffled when generated 179 | # see e2e_sae.scripts.upload_hf_dataset.TextDataset.__init__ 180 | loader = DataLoader[Samples]( 181 | torch_dataset, # type: ignore 182 | batch_size=batch_size, 183 | shuffle=False, 184 | ) 185 | return loader, tokenizer 186 | -------------------------------------------------------------------------------- /e2e_sae/hooks.py: -------------------------------------------------------------------------------- 1 | from typing import Any, NamedTuple 2 | 3 | import torch 4 | from jaxtyping import Float 5 | from transformer_lens.hook_points import HookPoint 6 | 7 | from e2e_sae.models.sparsifiers import SAE 8 | 9 | 10 | class CacheActs(NamedTuple): 11 | input: Float[torch.Tensor, "... dim"] 12 | 13 | 14 | class SAEActs(NamedTuple): 15 | input: Float[torch.Tensor, "... dim"] 16 | c: Float[torch.Tensor, "... c"] 17 | output: Float[torch.Tensor, "... dim"] 18 | 19 | 20 | def sae_hook( 21 | x: Float[torch.Tensor, "... dim"], 22 | hook: HookPoint | None, 23 | sae: SAE | torch.nn.Module, 24 | hook_acts: dict[str, Any], 25 | hook_key: str, 26 | ) -> Float[torch.Tensor, "... dim"]: 27 | """Runs the SAE on the input and stores the input, output and c in hook_acts under hook_key. 28 | 29 | Args: 30 | x: The input. 31 | hook: HookPoint object. Unused. 32 | sae: The SAE to run the input through. 33 | hook_acts: Dictionary of SAEActs and CacheActs objects to store the input, c, and output in. 34 | hook_key: The key in hook_acts to store the input, c, and output in. 35 | 36 | Returns: 37 | The output of the SAE. 38 | """ 39 | output, c = sae(x) 40 | hook_acts[hook_key] = SAEActs(input=x, c=c, output=output) 41 | return output 42 | 43 | 44 | def cache_hook( 45 | x: Float[torch.Tensor, "... dim"], 46 | hook: HookPoint | None, 47 | hook_acts: dict[str, Any], 48 | hook_key: str, 49 | ) -> Float[torch.Tensor, "... dim"]: 50 | """Stores the input in hook_acts under hook_key. 51 | 52 | Args: 53 | x: The input. 54 | hook: HookPoint object. Unused. 55 | hook_acts: CacheActs object to store the input in. 56 | 57 | Returns: 58 | The input. 
59 | """ 60 | hook_acts[hook_key] = CacheActs(input=x) 61 | return x 62 | -------------------------------------------------------------------------------- /e2e_sae/loader.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | import yaml 5 | from transformer_lens import HookedTransformer, HookedTransformerConfig 6 | 7 | from e2e_sae.scripts.train_tlens.run_train_tlens import HookedTransformerPreConfig 8 | from e2e_sae.types import RootPath 9 | 10 | 11 | def load_tlens_model( 12 | tlens_model_name: str | None, tlens_model_path: RootPath | None 13 | ) -> HookedTransformer: 14 | """Load transformerlens model from either HuggingFace or local path.""" 15 | if tlens_model_name is not None: 16 | tlens_model = HookedTransformer.from_pretrained(tlens_model_name) 17 | else: 18 | assert tlens_model_path is not None, "tlens_model_path is None." 19 | # Load the tlens_config 20 | assert ( 21 | tlens_model_path.parent / "final_config.yaml" 22 | ).exists(), "final_config.yaml does not exist." 23 | with open(tlens_model_path.parent / "final_config.yaml") as f: 24 | tlens_config = HookedTransformerPreConfig(**yaml.safe_load(f)["tlens_config"]) 25 | hooked_transformer_config = HookedTransformerConfig(**tlens_config.model_dump()) 26 | 27 | # Load the model 28 | tlens_model = HookedTransformer(hooked_transformer_config) 29 | tlens_model.load_state_dict(torch.load(tlens_model_path, map_location="cpu")) 30 | 31 | assert tlens_model.tokenizer is not None 32 | return tlens_model 33 | 34 | 35 | def load_pretrained_saes( 36 | saes: torch.nn.ModuleDict, 37 | pretrained_sae_paths: list[Path], 38 | all_param_names: list[str], 39 | retrain_saes: bool, 40 | ) -> list[str]: 41 | """Load in the pretrained SAEs to saes (in place) and return the trainable param names. 42 | 43 | Args: 44 | saes: The base SAEs to load the pretrained SAEs into. Updated in place. 45 | pretrained_sae_paths: List of paths to the pretrained SAEs. 46 | all_param_names: List of all the parameter names in saes. 47 | retrain_saes: Whether to retrain the pretrained SAEs. 48 | 49 | Returns: 50 | The updated all_param_names. 51 | """ 52 | pretrained_sae_params = {} 53 | for pretrained_sae_path in pretrained_sae_paths: 54 | # Add new sae params (note that this will overwrite existing SAEs with the same name) 55 | pretrained_sae_params = {**pretrained_sae_params, **torch.load(pretrained_sae_path)} 56 | sae_state_dict = {**dict(saes.named_parameters()), **pretrained_sae_params} 57 | 58 | saes.load_state_dict(sae_state_dict) 59 | if not retrain_saes: 60 | # Don't retrain the pretrained SAEs 61 | trainable_param_names = [ 62 | name for name in all_param_names if name not in pretrained_sae_params 63 | ] 64 | else: 65 | trainable_param_names = all_param_names 66 | 67 | return trainable_param_names 68 | -------------------------------------------------------------------------------- /e2e_sae/log.py: -------------------------------------------------------------------------------- 1 | """Setup a logger to be used in all modules in the library. 
2 | 3 | To use the logger, import it in any module and use it as follows: 4 | 5 | ``` 6 | from e2e_sae.log import logger 7 | logger.info("Info message") 8 | logger.warning("Warning message") 9 | ``` 10 | """ 11 | 12 | import logging 13 | from logging.config import dictConfig 14 | from pathlib import Path 15 | 16 | DEFAULT_LOGFILE = Path(__file__).resolve().parent.parent / "logs" / "logs.log" 17 | 18 | 19 | def setup_logger(logfile: Path = DEFAULT_LOGFILE) -> logging.Logger: 20 | """Setup a logger to be used in all modules in the library. 21 | 22 | Sets up logging configuration with a console handler and a file handler. 23 | Console handler logs messages with INFO level, file handler logs WARNING level. 24 | The root logger is configured to use both handlers. 25 | 26 | Returns: 27 | logging.Logger: A configured logger object. 28 | 29 | Example: 30 | >>> logger = setup_logger() 31 | >>> logger.debug("Debug message") 32 | >>> logger.info("Info message") 33 | >>> logger.warning("Warning message") 34 | """ 35 | if not logfile.parent.exists(): 36 | logfile.parent.mkdir(parents=True, exist_ok=True) 37 | 38 | logging_config = { 39 | "version": 1, 40 | "formatters": { 41 | "default": { 42 | "format": "%(asctime)s - %(levelname)s - %(message)s", 43 | "datefmt": "%Y-%m-%d %H:%M:%S", 44 | }, 45 | }, 46 | "handlers": { 47 | "console": { 48 | "class": "logging.StreamHandler", 49 | "formatter": "default", 50 | "level": "INFO", 51 | }, 52 | "file": { 53 | "class": "logging.FileHandler", 54 | "filename": str(logfile), 55 | "formatter": "default", 56 | "level": "WARNING", 57 | }, 58 | }, 59 | "root": { 60 | "handlers": ["console", "file"], 61 | "level": "INFO", 62 | }, 63 | } 64 | 65 | dictConfig(logging_config) 66 | return logging.getLogger() 67 | 68 | 69 | logger = setup_logger() 70 | -------------------------------------------------------------------------------- /e2e_sae/losses.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated 2 | 3 | import einops 4 | import torch 5 | import torch.nn.functional as F 6 | from jaxtyping import Float 7 | from pydantic import BaseModel, BeforeValidator, ConfigDict, Field 8 | from torch import Tensor 9 | 10 | from e2e_sae.hooks import CacheActs, SAEActs 11 | 12 | 13 | def _layer_norm_pre(x: Float[Tensor, "... dim"], eps: float = 1e-5) -> Float[Tensor, "... dim"]: 14 | """Layernorm without the affine transformation.""" 15 | x = x - x.mean(dim=-1, keepdim=True) 16 | scale = (x.pow(2).mean(-1, keepdim=True) + eps).sqrt() 17 | return x / scale 18 | 19 | 20 | def calc_explained_variance( 21 | pred: Float[Tensor, "... dim"], target: Float[Tensor, "... dim"], layer_norm: bool = False 22 | ) -> Float[Tensor, "..."]: 23 | """Calculate the explained variance of the pred and target. 24 | 25 | Args: 26 | pred: The prediction to compare to the target. 27 | target: The target to compare the prediction to. 28 | layer_norm: Whether to apply layer norm to the pred and target before calculating the loss. 29 | 30 | Returns: 31 | The explained variance between the prediction and target for each batch and sequence pos. 
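Computed per token as 1 - ||pred - target||^2 / ||target - mean(target)||^2, where the mean is taken over the batch and sequence dimensions. If layer_norm is True, both tensors are first centered and scaled with a pre-affine layernorm.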
32 | """ 33 | if layer_norm: 34 | pred = _layer_norm_pre(pred) 35 | target = _layer_norm_pre(target) 36 | sample_dims = tuple(range(pred.ndim - 1)) 37 | per_token_l2_loss = (pred - target).pow(2).sum(dim=-1) 38 | total_variance = (target - target.mean(dim=sample_dims)).pow(2).sum(dim=-1) 39 | return 1 - per_token_l2_loss / total_variance 40 | 41 | 42 | class SparsityLoss(BaseModel): 43 | model_config = ConfigDict(extra="forbid", frozen=True) 44 | coeff: float 45 | p_norm: float = 1.0 46 | 47 | def calc_loss(self, c: Float[Tensor, "... c"], dense_dim: int) -> Float[Tensor, ""]: 48 | """Calculate the sparsity loss. 49 | Note that we divide by the dimension of the input to the SAE. This helps with using the same 50 | hyperparameters across different model sizes (input dimension is more relevant than the c 51 | dimension for Lp loss). 52 | Args: 53 | c: The activations after the non-linearity in the SAE. 54 | dense_dim: The dimension of the input to the SAE. Used to normalize the loss. 55 | Returns: 56 | The L_p norm of the activations. 57 | """ 58 | return torch.norm(c, p=self.p_norm, dim=-1).mean() / dense_dim 59 | 60 | 61 | class InToOrigLoss(BaseModel): 62 | """Config for the loss between the input and original activations. 63 | 64 | The input activations may come from the input to an SAE or the activations at a cache_hook. 65 | 66 | Note that `run_train_tlens_saes.evaluate` will automatically log the in_to_orig loss for all 67 | residual stream positions, so you do not need to set values here with coeff=0.0 for logging. 68 | """ 69 | 70 | model_config = ConfigDict(extra="forbid", frozen=True) 71 | total_coeff: float = Field( 72 | ..., description="The sum of coefficients equally weighted across all hook_positions." 73 | ) 74 | hook_positions: Annotated[ 75 | list[str], BeforeValidator(lambda x: [x] if isinstance(x, str) else x) 76 | ] = Field( 77 | ..., 78 | description="The exact hook positions at which to compare raw and SAE-augmented " 79 | "activations. E.g. 'blocks.3.hook_resid_post' or " 80 | "['blocks.3.hook_resid_post', 'blocks.5.hook_resid_post'].", 81 | ) 82 | 83 | @property 84 | def coeff(self) -> float: 85 | """The coefficient for the loss of each hook position.""" 86 | return self.total_coeff / len(self.hook_positions) 87 | 88 | def calc_loss( 89 | self, input: Float[Tensor, "... dim"], orig: Float[Tensor, "... dim"] 90 | ) -> Float[Tensor, ""]: 91 | """Calculate the MSE between the input and orig.""" 92 | return F.mse_loss(input, orig) 93 | 94 | 95 | class OutToOrigLoss(BaseModel): 96 | model_config = ConfigDict(extra="forbid", frozen=True) 97 | coeff: float 98 | 99 | def calc_loss( 100 | self, output: Float[Tensor, "... dim"], orig: Float[Tensor, "... dim"] 101 | ) -> Float[Tensor, ""]: 102 | """Calculate loss between the output of the SAE and the non-SAE-augmented activations.""" 103 | return F.mse_loss(output, orig) 104 | 105 | 106 | class OutToInLoss(BaseModel): 107 | model_config = ConfigDict(extra="forbid", frozen=True) 108 | coeff: float 109 | 110 | def calc_loss( 111 | self, input: Float[Tensor, "... dim"], output: Float[Tensor, "... dim"] 112 | ) -> Float[Tensor, ""]: 113 | """Calculate loss between the input and output of the SAE.""" 114 | return F.mse_loss(input, output) 115 | 116 | 117 | class LogitsKLLoss(BaseModel): 118 | model_config = ConfigDict(extra="forbid", frozen=True) 119 | coeff: float 120 | 121 | def calc_loss( 122 | self, new_logits: Float[Tensor, "... vocab"], orig_logits: Float[Tensor, "... 
vocab"] 123 | ) -> Float[Tensor, ""]: 124 | """Calculate KL divergence between SAE-augmented and non-SAE-augmented logits. 125 | 126 | Important: new_logits should be passed first as we want the relative entropy from 127 | new_logits to orig_logits - KL(new_logits || orig_logits). 128 | 129 | We flatten all but the last dimensions and take the mean over this new dimension. 130 | """ 131 | new_logits_flat = einops.rearrange(new_logits, "... vocab -> (...) vocab") 132 | orig_logits_flat = einops.rearrange(orig_logits, "... vocab -> (...) vocab") 133 | 134 | return F.kl_div( 135 | F.log_softmax(new_logits_flat, dim=-1), 136 | F.log_softmax(orig_logits_flat, dim=-1), 137 | log_target=True, 138 | reduction="batchmean", 139 | ) 140 | 141 | 142 | class LossConfigs(BaseModel): 143 | model_config = ConfigDict(extra="forbid", frozen=True) 144 | sparsity: SparsityLoss 145 | in_to_orig: InToOrigLoss | None 146 | out_to_orig: OutToOrigLoss | None 147 | out_to_in: OutToInLoss | None 148 | logits_kl: LogitsKLLoss | None 149 | 150 | @property 151 | def activation_loss_configs( 152 | self, 153 | ) -> dict[str, SparsityLoss | InToOrigLoss | OutToOrigLoss | OutToInLoss | None]: 154 | return { 155 | "sparsity": self.sparsity, 156 | "in_to_orig": self.in_to_orig, 157 | "out_to_orig": self.out_to_orig, 158 | "out_to_in": self.out_to_in, 159 | } 160 | 161 | 162 | def calc_loss( 163 | orig_acts: dict[str, Tensor], 164 | new_acts: dict[str, SAEActs | CacheActs], 165 | orig_logits: Float[Tensor, "batch pos vocab"] | None, 166 | new_logits: Float[Tensor, "batch pos vocab"] | None, 167 | loss_configs: LossConfigs, 168 | is_log_step: bool = False, 169 | train: bool = True, 170 | ) -> tuple[Float[Tensor, ""], dict[str, Float[Tensor, ""]]]: 171 | """Compute losses. 172 | 173 | Note that some losses may be computed on the final logits, while others may be computed on 174 | intermediate activations. 175 | 176 | Additionally, for cache activations, only the in_to_orig loss is computed. 177 | 178 | Args: 179 | orig_acts: Dictionary of original activations, keyed by tlens attribute. 180 | new_acts: Dictionary of SAE or cache activations. Keys should match orig_acts. 181 | orig_logits: Logits from non-SAE-augmented model. 182 | new_logits: Logits from SAE-augmented model. 183 | loss_configs: Config for the losses to be computed. 184 | is_log_step: Whether to store additional loss information for logging. 185 | train: Whether in train or evaluation mode. Only affects the keys of the loss_dict. 186 | 187 | Returns: 188 | loss: Scalar tensor representing the loss. 189 | loss_dict: Dictionary of losses, keyed by loss type and name. 190 | """ 191 | assert set(orig_acts.keys()) == set(new_acts.keys()), ( 192 | f"Keys of orig_acts and new_acts must match, got {orig_acts.keys()} and " 193 | f"{new_acts.keys()}" 194 | ) 195 | 196 | prefix = "loss/train" if train else "loss/eval" 197 | 198 | loss: Float[Tensor, ""] = torch.zeros( 199 | 1, device=next(iter(orig_acts.values())).device, dtype=next(iter(orig_acts.values())).dtype 200 | ) 201 | loss_dict = {} 202 | 203 | if loss_configs.logits_kl and orig_logits is not None and new_logits is not None: 204 | loss_dict[f"{prefix}/logits_kl"] = loss_configs.logits_kl.calc_loss( 205 | new_logits=new_logits, orig_logits=orig_logits 206 | ) 207 | loss = loss + loss_configs.logits_kl.coeff * loss_dict[f"{prefix}/logits_kl"] 208 | 209 | for name, orig_act in orig_acts.items(): 210 | # Convert from inference tensor. 
211 | orig_act = orig_act.detach().clone() 212 | new_act = new_acts[name] 213 | 214 | for config_type, loss_config in loss_configs.activation_loss_configs.items(): 215 | if isinstance(new_act, CacheActs) and not isinstance(loss_config, InToOrigLoss): 216 | # Cache acts are only used for in_to_orig loss 217 | continue 218 | 219 | var: Float[Tensor, "batch_token"] | None = None # noqa: F821 220 | var_ln: Float[Tensor, "batch_token"] | None = None # noqa: F821 221 | if isinstance(loss_config, InToOrigLoss) and name in loss_config.hook_positions: 222 | # Note that out_to_in can calculate losses using CacheActs or SAEActs. 223 | loss_val = loss_config.calc_loss(new_act.input, orig_act) 224 | var = calc_explained_variance( 225 | new_act.input.detach().clone(), orig_act, layer_norm=False 226 | ) 227 | var_ln = calc_explained_variance( 228 | new_act.input.detach().clone(), orig_act, layer_norm=True 229 | ) 230 | elif isinstance(loss_config, OutToOrigLoss): 231 | assert isinstance(new_act, SAEActs) 232 | loss_val = loss_config.calc_loss(new_act.output, orig_act) 233 | var = calc_explained_variance( 234 | new_act.output.detach().clone(), orig_act, layer_norm=False 235 | ) 236 | var_ln = calc_explained_variance( 237 | new_act.output.detach().clone(), orig_act, layer_norm=True 238 | ) 239 | elif isinstance(loss_config, OutToInLoss): 240 | assert isinstance(new_act, SAEActs) 241 | loss_val = loss_config.calc_loss(new_act.input, new_act.output) 242 | var = calc_explained_variance( 243 | new_act.input.detach().clone(), new_act.output, layer_norm=False 244 | ) 245 | var_ln = calc_explained_variance( 246 | new_act.input.detach().clone(), new_act.output, layer_norm=True 247 | ) 248 | elif isinstance(loss_config, SparsityLoss): 249 | assert isinstance(new_act, SAEActs) 250 | loss_val = loss_config.calc_loss(new_act.c, dense_dim=new_act.input.shape[-1]) 251 | else: 252 | assert loss_config is None or ( 253 | isinstance(loss_config, InToOrigLoss) and name not in loss_config.hook_positions 254 | ), f"Unexpected loss_config {loss_config} for name {name}" 255 | continue 256 | 257 | loss = loss + loss_config.coeff * loss_val 258 | loss_dict[f"{prefix}/{config_type}/{name}"] = loss_val.detach().clone() 259 | 260 | if ( 261 | var is not None 262 | and var_ln is not None 263 | and is_log_step 264 | and isinstance(loss_config, InToOrigLoss | OutToOrigLoss | OutToInLoss) 265 | ): 266 | loss_dict[f"{prefix}/{config_type}/explained_variance/{name}"] = var.mean() 267 | loss_dict[f"{prefix}/{config_type}/explained_variance_std/{name}"] = var.std() 268 | loss_dict[f"{prefix}/{config_type}/explained_variance_ln/{name}"] = var_ln.mean() 269 | loss_dict[f"{prefix}/{config_type}/explained_variance_ln_std/{name}"] = var_ln.std() 270 | 271 | return loss, loss_dict 272 | -------------------------------------------------------------------------------- /e2e_sae/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import wandb 4 | from einops import einsum, repeat 5 | from jaxtyping import Float, Int 6 | from torch import Tensor 7 | from tqdm import tqdm 8 | from transformer_lens.utils import lm_cross_entropy_loss 9 | 10 | from e2e_sae.data import DatasetConfig, create_data_loader 11 | from e2e_sae.hooks import CacheActs, SAEActs 12 | from e2e_sae.models.transformers import SAETransformer 13 | 14 | 15 | def topk_accuracy( 16 | logits: Float[Tensor, "... 
vocab"], 17 | tokens: Int[Tensor, "batch pos"] | Int[Tensor, "pos"], # noqa: F821 18 | k: int = 1, 19 | per_token: bool = False, 20 | ) -> Tensor: 21 | """The proportion of the time that the true token lies within the top k predicted tokens.""" 22 | top_predictions = logits.topk(k=k, dim=-1).indices 23 | tokens_repeated = repeat(tokens, "... -> ... k", k=k) 24 | correct_matches = (top_predictions == tokens_repeated).any(dim=-1) 25 | if per_token: 26 | return correct_matches 27 | else: 28 | return correct_matches.sum() / correct_matches.numel() 29 | 30 | 31 | def top1_consistency( 32 | orig_logits: Float[Tensor, "... vocab"], 33 | new_logits: Float[Tensor, "... vocab"], 34 | per_token: bool = False, 35 | ) -> Tensor: 36 | """The proportion of the time the original model and SAE-model predict the same next token.""" 37 | orig_prediction = orig_logits.argmax(dim=-1) 38 | sae_prediction = new_logits.argmax(dim=-1) 39 | correct_matches = orig_prediction == sae_prediction 40 | if per_token: 41 | return correct_matches 42 | else: 43 | return correct_matches.sum() / correct_matches.numel() 44 | 45 | 46 | def statistical_distance( 47 | orig_logits: Float[Tensor, "... vocab"], new_logits: Float[Tensor, "... vocab"] 48 | ) -> Tensor: 49 | """The sum of absolute differences between the original model and SAE-model probabilities.""" 50 | orig_probs = torch.exp(F.log_softmax(orig_logits, dim=-1)) 51 | new_probs = torch.exp(F.log_softmax(new_logits, dim=-1)) 52 | return 0.5 * torch.abs(orig_probs - new_probs).sum(dim=-1).mean() 53 | 54 | 55 | class ActFrequencyMetrics: 56 | """Manages activation frequency metrics calculated over fixed spans of training batches.""" 57 | 58 | def __init__(self, dict_sizes: dict[str, int], device: torch.device) -> None: 59 | """Initialize ActFrequencyMetrics. Create the dictionary element frequency tensors. 60 | 61 | Args: 62 | dict_sizes: Sizes of the dictionaries for each sae position. 63 | device: Device to store dictionary element activation frequencies on. 64 | """ 65 | self.tokens_used = 0 # Number of tokens used in dict_el_frequencies 66 | self.dict_el_frequencies: dict[str, Float[Tensor, "dims"]] = { # noqa: F821 67 | sae_pos: torch.zeros(dict_size, device=device) 68 | for sae_pos, dict_size in dict_sizes.items() 69 | } 70 | self.dict_el_frequency_history: dict[str, list[Float[Tensor, "dims"]]] = { # noqa: F821 71 | sae_pos: [] for sae_pos, dict_size in dict_sizes.items() 72 | } 73 | 74 | def update_dict_el_frequencies( 75 | self, new_acts: dict[str, SAEActs | CacheActs], batch_tokens: int 76 | ) -> None: 77 | """Update the dictionary element frequencies with the new batch frequencies. 78 | 79 | Args: 80 | new_acts: Dictionary of activations for each hook position. 81 | batch_tokens: Number of tokens used to produce the sae acts. 82 | """ 83 | for sae_pos in self.dict_el_frequencies: 84 | new_acts_pos = new_acts[sae_pos] 85 | if isinstance(new_acts_pos, SAEActs): 86 | self.dict_el_frequencies[sae_pos] += einsum(new_acts_pos.c != 0, "... dim -> dim") 87 | self.tokens_used += batch_tokens 88 | 89 | def collect_for_logging( 90 | self, log_wandb_histogram: bool = True, post_training: bool = False 91 | ) -> dict[str, list[float] | int]: 92 | """Collect the activation frequency metrics for logging. 93 | 94 | Currently collects: 95 | - The number of alive dictionary elements for each hook. 96 | - The indices of the alive dictionary elements for each hook. 97 | - The histogram of dictionary element activation frequencies for each hook (if 98 | log_wandb_histogram is True). 
99 | 100 | Note that the dictionary element frequencies are divided by the number of tokens used to 101 | calculate them. 102 | 103 | Args: 104 | log_wandb_histogram: Whether to log the dictionary element activation frequency 105 | histograms to wandb. 106 | post_training: Whether the metrics are being collected post-training. Affects the name 107 | of the metrics. 108 | """ 109 | log_dict = {} 110 | for sae_pos in self.dict_el_frequencies: 111 | self.dict_el_frequencies[sae_pos] /= self.tokens_used 112 | self.dict_el_frequency_history[sae_pos].append( 113 | self.dict_el_frequencies[sae_pos].detach().cpu() 114 | ) 115 | alive_elements_name = "alive_dict_elements" 116 | alive_indices_name = "alive_dict_elements_indices" 117 | if post_training: 118 | alive_elements_name += "_final" 119 | alive_indices_name += "_final" 120 | 121 | log_dict[f"sparsity/{alive_indices_name}/{sae_pos}"] = [ 122 | i for i, v in enumerate(self.dict_el_frequencies[sae_pos]) if v > 0 123 | ] 124 | log_dict[f"sparsity/{alive_elements_name}/{sae_pos}"] = len( 125 | log_dict[f"sparsity/{alive_indices_name}/{sae_pos}"] 126 | ) 127 | 128 | if log_wandb_histogram: 129 | data = [[s] for s in self.dict_el_frequencies[sae_pos]] 130 | data_log = [[torch.log10(s + 1e-10)] for s in self.dict_el_frequencies[sae_pos]] 131 | plot = wandb.plot.histogram( 132 | wandb.Table(data=data, columns=["dict element activation frequency"]), 133 | "dict element activation frequency", 134 | title=f"{sae_pos} (most_recent_n_tokens={self.tokens_used} " 135 | f"dict_size={self.dict_el_frequencies[sae_pos].shape[0]})", 136 | ) 137 | plot_log10 = wandb.plot.histogram( 138 | wandb.Table( 139 | data=data_log, columns=["log10(dict element activation frequency)"] 140 | ), 141 | "log10(dict element activation frequency)", 142 | title=f"{sae_pos} (most_recent_n_tokens={self.tokens_used} " 143 | f"dict_size={self.dict_el_frequencies[sae_pos].shape[0]})", 144 | ) 145 | log_dict[f"sparsity/dict_el_frequency_hist/{sae_pos}"] = plot 146 | log_dict[f"sparsity/dict_el_frequency_hist/log10/{sae_pos}"] = plot_log10 147 | 148 | log_dict[f"sparsity/dict_el_frequency_hist/over_time/{sae_pos}"] = wandb.Histogram( 149 | self.dict_el_frequency_history[sae_pos] 150 | ) 151 | log_dict[ 152 | f"sparsity/dict_el_frequency_hist/over_time/log10/{sae_pos}" 153 | ] = wandb.Histogram( 154 | [torch.log10(s + 1e-10) for s in self.dict_el_frequency_history[sae_pos]] 155 | ) 156 | return log_dict 157 | 158 | 159 | @torch.inference_mode() 160 | def calc_sparsity_metrics( 161 | new_acts: dict[str, SAEActs | CacheActs], train: bool = True 162 | ) -> dict[str, float]: 163 | """Collect sparsity metrics for logging. 164 | 165 | Args: 166 | new_acts: Dictionary of activations for each hook position (may include SAE or cache acts). 167 | train: Whether in train or evaluation mode. Only affects the keys of the metrics. 168 | 169 | Returns: 170 | Dictionary of sparsity metrics. 
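For each SAE position this includes the mean L_0 norm of the dictionary activations (average number of non-zero elements per token) and the fraction of activations that are exactly zero.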
171 | """ 172 | prefix = "sparsity/train" if train else "sparsity/eval" 173 | sparsity_metrics = {} 174 | for name, new_act in new_acts.items(): 175 | if isinstance(new_act, SAEActs): 176 | # Record L_0 norm of the cs 177 | l_0_norm = torch.norm(new_act.c, p=0, dim=-1).mean().item() 178 | sparsity_metrics[f"{prefix}/L_0/{name}"] = l_0_norm 179 | 180 | # Record fraction of zeros in the cs 181 | frac_zeros = ((new_act.c == 0).sum() / new_act.c.numel()).item() 182 | sparsity_metrics[f"{prefix}/frac_zeros/{name}"] = frac_zeros 183 | 184 | return sparsity_metrics 185 | 186 | 187 | @torch.inference_mode() 188 | def calc_output_metrics( 189 | tokens: Int[Tensor, "batch pos"] | Int[Tensor, "pos"], # noqa: F821 190 | orig_logits: Float[Tensor, "... vocab"], 191 | new_logits: Float[Tensor, "... vocab"], 192 | train: bool = True, 193 | ) -> dict[str, float]: 194 | """Get metrics on the outputs of the SAE-augmented model and the original model. 195 | 196 | Args: 197 | tokens: The tokens used to produce the logits. 198 | orig_logits: The logits produced by the original model. 199 | new_logits: The logits produced by the SAE model. 200 | train: Whether in train or evaluation mode. Only affects the keys of the metrics. 201 | 202 | Returns: 203 | Dictionary of output metrics 204 | """ 205 | orig_model_ce_loss = lm_cross_entropy_loss(orig_logits, tokens, per_token=False).item() 206 | sae_model_ce_loss = lm_cross_entropy_loss(new_logits, tokens, per_token=False).item() 207 | 208 | orig_model_top1_accuracy = topk_accuracy(orig_logits, tokens, k=1, per_token=False).item() 209 | sae_model_top1_accuracy = topk_accuracy(new_logits, tokens, k=1, per_token=False).item() 210 | orig_vs_sae_top1_consistency = top1_consistency(orig_logits, new_logits, per_token=False).item() 211 | orig_vs_sae_stat_distance = statistical_distance(orig_logits, new_logits).item() 212 | 213 | prefix = "performance/train" if train else "performance/eval" 214 | metrics = { 215 | f"{prefix}/orig_model_ce_loss": orig_model_ce_loss, 216 | f"{prefix}/sae_model_ce_loss": sae_model_ce_loss, 217 | f"{prefix}/difference_ce_loss": orig_model_ce_loss - sae_model_ce_loss, 218 | f"{prefix}/orig_model_top1_accuracy": orig_model_top1_accuracy, 219 | f"{prefix}/sae_model_top1_accuracy": sae_model_top1_accuracy, 220 | f"{prefix}/difference_top1_accuracy": orig_model_top1_accuracy - sae_model_top1_accuracy, 221 | f"{prefix}/orig_vs_sae_top1_consistency": orig_vs_sae_top1_consistency, 222 | f"{prefix}/orig_vs_sae_statistical_distance": orig_vs_sae_stat_distance, 223 | } 224 | return metrics 225 | 226 | 227 | def collect_act_frequency_metrics( 228 | model: SAETransformer, 229 | data_config: DatasetConfig, 230 | batch_size: int, 231 | global_seed: int, 232 | device: torch.device, 233 | n_tokens: int, 234 | ) -> dict[str, int | list[float]]: 235 | """Collect SAE activation frequency metrics for a SAETransformer model. 236 | Args: 237 | model: The SAETransformer model. 238 | data_config: The data configuration to use for calculating the frequency metrics. 239 | batch_size: The batch size. 240 | global_seed: The global seed. Only matters when data_config.seed is None. 241 | device: The device to use. 242 | n_tokens: The number of tokens to use for calculating the frequency metrics. 243 | 244 | Returns: 245 | A dictionary of the collected metrics (e.g. alive dictionary elements and their indices). 
246 | """ 247 | 248 | data_loader = create_data_loader(data_config, batch_size=batch_size, global_seed=global_seed)[0] 249 | 250 | act_frequency_metrics = ActFrequencyMetrics( 251 | dict_sizes={ 252 | raw_pos: model.saes[all_pos].decoder.in_features 253 | for raw_pos, all_pos in zip( 254 | model.raw_sae_positions, model.all_sae_positions, strict=True 255 | ) 256 | }, 257 | device=device, 258 | ) 259 | n_samples = n_tokens // data_config.n_ctx 260 | n_batches = n_samples // batch_size 261 | 262 | # Iterate over the data loader and calculate the frequency metrics 263 | for batch_idx, batch in tqdm(enumerate(data_loader), total=n_batches, desc="Batches"): 264 | if batch_idx >= n_batches: 265 | break 266 | tokens = batch[data_config.column_name].to(device=device) 267 | _, new_acts = model.forward( 268 | tokens=tokens, 269 | sae_positions=model.raw_sae_positions, 270 | cache_positions=None, 271 | ) 272 | act_frequency_metrics.update_dict_el_frequencies( 273 | new_acts, batch_tokens=tokens.shape[0] * tokens.shape[1] 274 | ) 275 | metrics = act_frequency_metrics.collect_for_logging( 276 | log_wandb_histogram=False, post_training=True 277 | ) 278 | return metrics 279 | -------------------------------------------------------------------------------- /e2e_sae/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ApolloResearch/e2e_sae/83b4add4d2652a4e5c6775a0e2b752897c871a87/e2e_sae/models/__init__.py -------------------------------------------------------------------------------- /e2e_sae/models/mlp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines a generic MLP. 3 | """ 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from e2e_sae.models.sparsifiers import SAE, Codebook 9 | 10 | 11 | class Layer(nn.Module): 12 | """ 13 | Neural network layer consisting of a linear layer followed by RELU. 14 | 15 | Args: 16 | in_features: The size of each input. 17 | out_features: The size of each output. 18 | has_activation_fn: Whether to use an activation function. Default is True. 19 | bias: Whether to add a bias term to the linear transformation. Default is True. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | in_features: int, 25 | out_features: int, 26 | has_activation_fn: bool = True, 27 | bias: bool = True, 28 | ): 29 | super().__init__() 30 | self.linear = nn.Linear(in_features, out_features, bias=bias) 31 | if has_activation_fn: 32 | self.activation = nn.GELU() 33 | 34 | def forward(self, x: torch.Tensor) -> torch.Tensor: 35 | x = self.linear(x) 36 | if hasattr(self, "activation"): 37 | x = self.activation(x) 38 | return x 39 | 40 | 41 | class MLP(nn.Module): 42 | """ 43 | This class defines an MLP with a variable number of hidden layers. 44 | 45 | All layers use a linear transformation followed by RELU, except for the final layer which 46 | uses a linear transformation followed by no activation function. 47 | 48 | Args: 49 | hidden_sizes: A list of integers specifying the sizes of the hidden layers. 50 | input_size: The size of each input sample. 51 | output_size: The size of each output sample. 52 | bias: Whether to add a bias term to the linear transformations. Default is True. 
53 | """ 54 | 55 | def __init__( 56 | self, 57 | hidden_sizes: list[int] | None, 58 | input_size: int, 59 | output_size: int, 60 | bias: bool = True, 61 | ): 62 | super().__init__() 63 | 64 | if hidden_sizes is None: 65 | hidden_sizes = [] 66 | 67 | # Size of each layer (including input and output) 68 | sizes = [input_size] + hidden_sizes + [output_size] 69 | 70 | layers = nn.ModuleList() 71 | for i in range(len(sizes) - 1): 72 | # No activation for final layer 73 | has_activation_fn = i < len(sizes) - 2 74 | layers.append( 75 | Layer( 76 | in_features=sizes[i], 77 | out_features=sizes[i + 1], 78 | has_activation_fn=has_activation_fn, 79 | bias=bias, 80 | ) 81 | ) 82 | self.layers = nn.Sequential(*layers) 83 | 84 | def forward(self, x: torch.Tensor) -> torch.Tensor: 85 | return self.layers(x) 86 | 87 | 88 | class MLPMod(nn.Module): 89 | """ 90 | This class defines an MLP with a variable number of hidden layers. 91 | 92 | All layers use a linear transformation followed by RELU, except for the final layer which 93 | uses a linear transformation followed by no activation function. 94 | 95 | Args: 96 | hidden_sizes: A list of integers specifying the sizes of the hidden layers. 97 | input_size: The size of each input sample. 98 | output_size: The size of each output sample. 99 | bias: Whether to add a bias term to the linear transformations. Default is True. 100 | """ 101 | 102 | def __init__( 103 | self, 104 | hidden_sizes: list[int] | None, 105 | input_size: int, 106 | output_size: int, 107 | bias: bool = True, 108 | type_of_sparsifier: str = "sae", 109 | dict_eles_to_input_ratio: float = 2, 110 | k: int = 0, 111 | ): 112 | super().__init__() 113 | 114 | if hidden_sizes is None: 115 | hidden_sizes = [] 116 | 117 | # Size of each layer (including input and output) 118 | sizes = [input_size] + hidden_sizes + [output_size] 119 | 120 | self.layers = nn.ModuleDict() 121 | self.sparsifiers = nn.ModuleDict() 122 | self.dict_eles_to_input_ratio = dict_eles_to_input_ratio 123 | for i in range(len(sizes) - 1): 124 | has_activation_fn = i < len(sizes) - 2 125 | # Add layers with custom keys 126 | self.layers[f"{i}"] = Layer( 127 | in_features=sizes[i], 128 | out_features=sizes[i + 1], 129 | has_activation_fn=has_activation_fn, 130 | bias=bias, 131 | ) 132 | if has_activation_fn: 133 | if type_of_sparsifier == "sae": 134 | sparsifier = SAE( 135 | input_size=sizes[i + 1], 136 | n_dict_components=int(sizes[i + 1] * self.dict_eles_to_input_ratio), 137 | ) 138 | elif type_of_sparsifier == "codebook": 139 | assert k > 0, "k must be greater than 0" 140 | sparsifier = Codebook( 141 | input_size=sizes[i + 1], 142 | n_dict_components=int(sizes[i + 1] * dict_eles_to_input_ratio), 143 | k=k, 144 | ) 145 | else: 146 | raise ValueError("type_of_sparsifier must be either 'sae' or 'codebook'") 147 | self.sparsifiers[f"{i}"] = sparsifier 148 | 149 | def forward( 150 | self, x: torch.Tensor 151 | ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], dict[str, torch.Tensor]]: 152 | outs = {} 153 | cs = {} 154 | sparsifiers_outs = {} 155 | for i in range(len(self.layers)): 156 | x = self.layers[f"{i}"](x) 157 | outs[f"{i}"] = x 158 | if f"{i}" in self.sparsifiers: 159 | x, c = self.sparsifiers[f"{i}"](x) 160 | cs[f"{i}"] = c 161 | sparsifiers_outs[f"{i}"] = x 162 | return outs, cs, sparsifiers_outs 163 | -------------------------------------------------------------------------------- /e2e_sae/models/sparsifiers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines a 
set of sparsifier modules: a sparse autoencoder (SAE) and a codebook.
3 | """
4 | 
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn
8 | 
9 | 
10 | class SAE(nn.Module):
11 |     """
12 |     Sparse AutoEncoder
13 |     """
14 | 
15 |     def __init__(
16 |         self, input_size: int, n_dict_components: int, init_decoder_orthogonal: bool = True
17 |     ):
18 |         """Initialize the SAE.
19 | 
20 |         Args:
21 |             input_size: Dimensionality of input data
22 |             n_dict_components: Number of dictionary components
23 |             init_decoder_orthogonal: Initialize the decoder weights to be orthonormal
24 |         """
25 | 
26 |         super().__init__()
27 |         # self.encoder[0].weight has shape: (n_dict_components, input_size)
28 |         # self.decoder.weight has shape: (input_size, n_dict_components)
29 | 
30 |         self.encoder = nn.Sequential(nn.Linear(input_size, n_dict_components, bias=True), nn.ReLU())
31 |         self.decoder = nn.Linear(n_dict_components, input_size, bias=True)
32 |         self.n_dict_components = n_dict_components
33 |         self.input_size = input_size
34 | 
35 |         if init_decoder_orthogonal:
36 |             # Initialize so that there are n_dict_components orthonormal vectors
37 |             self.decoder.weight.data = nn.init.orthogonal_(self.decoder.weight.data.T).T
38 | 
39 |     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
40 |         """Pass input through the encoder and normalized decoder."""
41 |         c = self.encoder(x)
42 |         x_hat = F.linear(c, self.dict_elements, bias=self.decoder.bias)
43 |         return x_hat, c
44 | 
45 |     @property
46 |     def dict_elements(self):
47 |         """Dictionary elements are simply the normalized decoder weights."""
48 |         return F.normalize(self.decoder.weight, dim=0)
49 | 
50 |     @property
51 |     def device(self):
52 |         return next(self.parameters()).device
53 | 
54 | 
55 | class Codebook(nn.Module):
56 |     """
57 |     Codebook from Tamkin et al. (2023)
58 | 
59 |     It computes the cosine similarity between an input and a dictionary of features of
60 |     size n_dict_components. Then it simply takes the top k most similar codebook features
61 |     and outputs their sum. The output thus has size input_size and consists of a simple sum of
62 |     the top k codebook features. There is no encoder, just the dictionary of codebook features.
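    Example (illustrative; the sizes below are arbitrary placeholders):

        >>> codebook = Codebook(input_size=768, n_dict_components=1536, k=32)
        >>> x_hat, c = codebook(torch.randn(4, 768))
        >>> x_hat.shape, c
        (torch.Size([4, 768]), None)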
63 | """ 64 | 65 | def __init__(self, input_size: int, n_dict_components: int, k: int): 66 | super().__init__() 67 | 68 | self.codebook = nn.Parameter( 69 | torch.randn(n_dict_components, input_size) 70 | ) # (n_dict_components, input_size) 71 | self.n_dict_components = n_dict_components 72 | self.input_size = input_size 73 | self.k = k 74 | 75 | def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, None]: 76 | # Compute cosine similarity between input and codebook features (batch_size, dict_size) 77 | cos_sim = F.cosine_similarity(x.unsqueeze(1), self.codebook, dim=2) 78 | 79 | # Take the top k most similar codebook features 80 | _, topk = torch.topk(cos_sim, self.k, dim=1) # (batch_size, k) 81 | 82 | # Sum the top k codebook features 83 | x_hat = torch.sum(self.codebook[topk], dim=1) # (batch_size, input_size) 84 | 85 | return x_hat, None 86 | 87 | @property 88 | def device(self): 89 | return next(self.parameters()).device 90 | -------------------------------------------------------------------------------- /e2e_sae/plotting.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable, Mapping, Sequence 2 | from pathlib import Path 3 | from typing import Any, Literal 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import pandas as pd 8 | import seaborn as sns 9 | from numpy.typing import NDArray 10 | 11 | from e2e_sae.log import logger 12 | 13 | 14 | def plot_per_layer_metric( 15 | df: pd.DataFrame, 16 | run_ids: Mapping[int, Mapping[str, str]], 17 | metric: str, 18 | final_layer: int, 19 | run_types: Sequence[str], 20 | out_file: str | Path | None = None, 21 | ylim: tuple[float | None, float | None] = (None, None), 22 | legend_label_cols_and_precision: list[tuple[str, int]] | None = None, 23 | legend_title: str | None = None, 24 | styles: Mapping[str, Mapping[str, Any]] | None = None, 25 | horz_layout: bool = False, 26 | show_ax_titles: bool = True, 27 | save_svg: bool = True, 28 | ) -> None: 29 | """ 30 | Plot the per-layer metric (explained variance or reconstruction loss) for different run types. 31 | 32 | Args: 33 | df: DataFrame containing the filtered data for the specific layer. 34 | run_ids: The run IDs to use. Format: {layer: {run_type: run_id}}. 35 | metric: The metric to plot ('explained_var' or 'recon_loss'). 36 | final_layer: The final layer to plot up to. 37 | run_types: The run types to include in the plot. 38 | out_file: The filename which the plot will be saved as. 39 | ylim: The y-axis limits. 40 | legend_label_cols_and_precision: Columns in df that should be used for the legend, along 41 | with their precision. Added in addition to the run type. 42 | legend_title: The title of the legend. 43 | styles: The styles to use. 44 | horz_layout: Whether to use a horizontal layout for the subplots. Requires sae_layers to be 45 | exactly [2, 6, 10]. Ignores legend_label_cols_and_precision if True. 46 | show_ax_titles: Whether to show titles for each subplot. 47 | save_svg: Whether to save the plot as an SVG file in addition to PNG. Default is True. 
48 | """ 49 | metric_names = { 50 | "explained_var": "Explained Variance", 51 | "explained_var_ln": "Explained Variance\nof Normalized Activations", 52 | "recon_loss": "Reconstruction MSE", 53 | } 54 | metric_name = metric_names.get(metric, metric) 55 | 56 | sae_layers = list(run_ids.keys()) 57 | n_sae_layers = len(sae_layers) 58 | 59 | if horz_layout: 60 | assert sae_layers == [2, 6, 10] 61 | fig, axs = plt.subplots( 62 | 1, n_sae_layers, figsize=(10, 4), gridspec_kw={"width_ratios": [3, 2, 1.2]} 63 | ) 64 | legend_label_cols_and_precision = None 65 | else: 66 | fig, axs = plt.subplots(n_sae_layers, 1, figsize=(5, 3.5 * n_sae_layers)) 67 | axs = np.atleast_1d(axs) # type: ignore 68 | 69 | def plot_metric( 70 | ax: plt.Axes, 71 | plot_df: pd.DataFrame, 72 | sae_layer: int, 73 | xs: NDArray[np.signedinteger[Any]], 74 | ) -> None: 75 | for _, row in plot_df.iterrows(): 76 | run_type = row["run_type"] 77 | assert isinstance(run_type, str) 78 | legend_label = styles[run_type]["label"] if styles is not None else run_type 79 | if legend_label_cols_and_precision is not None: 80 | assert all( 81 | col in row for col, _ in legend_label_cols_and_precision 82 | ), f"Legend label cols not found in row: {row}" 83 | metric_strings = [ 84 | f"{col}={format(row[col], f'.{prec}f')}" 85 | for col, prec in legend_label_cols_and_precision 86 | ] 87 | legend_label += f" ({', '.join(metric_strings)})" 88 | ys = [row[f"{metric}_layer-{i}"] for i in range(sae_layer, final_layer + 1)] 89 | kwargs = styles[run_type] if styles is not None else {} 90 | ax.plot(xs, ys, **kwargs) 91 | 92 | for i, sae_layer in enumerate(sae_layers): 93 | layer_df = df.loc[df["id"].isin(list(run_ids[sae_layer].values()))] 94 | 95 | ax = axs[i] 96 | 97 | xs = np.arange(sae_layer, final_layer + 1) 98 | for run_type in run_types: 99 | plot_metric(ax, layer_df.loc[layer_df["run_type"] == run_type], sae_layer, xs) 100 | 101 | if show_ax_titles: 102 | ax.set_title(f"SAE Layer {sae_layer}", fontweight="bold") 103 | ax.set_xlabel("Model Layer") 104 | if (not horz_layout) or i == 0: 105 | ax.legend(title=legend_title, loc="best") 106 | ax.set_ylabel(metric_name) 107 | ax.set_xticks(xs) 108 | ax.set_xticklabels([str(x) for x in xs]) 109 | ax.set_ylim(ylim) 110 | 111 | plt.tight_layout() 112 | if out_file is not None: 113 | plt.savefig(out_file, dpi=400) 114 | logger.info(f"Saved to {out_file}") 115 | if save_svg: 116 | plt.savefig(Path(out_file).with_suffix(".svg")) 117 | plt.close() 118 | 119 | 120 | def plot_facet( 121 | df: pd.DataFrame, 122 | xs: Sequence[str], 123 | y: str, 124 | facet_by: str, 125 | line_by: str, 126 | line_by_vals: Sequence[str] | None = None, 127 | sort_by: str | None = None, 128 | xlabels: Sequence[str | None] | None = None, 129 | ylabel: str | None = None, 130 | suptitle: str | None = None, 131 | facet_vals: Sequence[Any] | None = None, 132 | xlims: Sequence[Mapping[Any, tuple[float | None, float | None]] | None] | None = None, 133 | xticks: Sequence[tuple[list[float], list[str]] | None] | None = None, 134 | yticks: tuple[list[float], list[str]] | None = None, 135 | ylim: Mapping[Any, tuple[float | None, float | None]] | None = None, 136 | styles: Mapping[Any, Mapping[str, Any]] | None = None, 137 | title: Mapping[Any, str] | None = None, 138 | legend_title: str | None = None, 139 | legend_pos: str = "lower right", 140 | axis_formatter: Callable[[Sequence[plt.Axes]], None] | None = None, 141 | out_file: str | Path | None = None, 142 | plot_type: Literal["line", "scatter"] = "line", 143 | save_svg: bool = True, 144 | ) -> 
None:
145 |     """Line plot with multiple x-axes and a shared y variable. One line for each value of line_by.
146 | 
147 |     Args:
148 |         df: DataFrame containing the data.
149 |         xs: The variables to plot on the x-axes.
150 |         y: The variable to plot on the y-axis.
151 |         facet_by: The variable to facet the plot by.
152 |         line_by: The variable to draw lines for.
153 |         line_by_vals: The values to draw lines for. If None, all unique values will be used.
154 |         sort_by: The variable governing how lines are drawn between points. If None, lines will be
155 |             drawn based on the y value.
156 |         xlabels: The labels for the x-axes. If None, the variable names in xs are used.
157 |         ylabel: The label for the y-axis. If None, the variable name in y is used.
158 |         suptitle: The title for the entire figure.
159 |         facet_vals: The facet values to plot, in order. If None, all unique values are used.
160 |         xlims: The x-axis limits for each x-axis for each facet value.
161 |         xticks: The x-ticks for each x-axis.
162 |         yticks: The y-ticks for the y-axis.
163 |         ylim: The y-axis limits for each facet value.
164 |         styles: The styles to use for each line. If None, default styles will be used.
165 |         title: The title for each row (facet value) of the plot.
166 |         legend_title: The title for the legend. If None, the line_by variable name is used.
167 |         legend_pos: The position of the legend.
168 |         axis_formatter: A function to format the axes, e.g. to add "better" labels.
169 |         out_file: The filename which the plot will be saved as.
170 |         plot_type: The type of plot to create. Either "line" or "scatter".
171 |         save_svg: Whether to save the plot as an SVG file in addition to PNG. Default is True.
172 |     """
173 | 
174 |     num_axes = len(xs)
175 |     if facet_vals is None:
176 |         facet_vals = sorted(df[facet_by].unique())
177 |     if sort_by is None:
178 |         sort_by = y
179 | 
180 |     # TODO: For some reason the title is not centered at x=0.5. Fix
181 |     xtitle_pos = 0.513
182 | 
183 |     sns.set_theme(style="darkgrid", rc={"axes.facecolor": "#f5f6fc"})
184 |     fig_width = 4 * num_axes
185 |     fig = plt.figure(figsize=(fig_width, 4 * len(facet_vals)), constrained_layout=True)
186 |     subfigs = fig.subfigures(len(facet_vals))
187 |     subfigs = np.atleast_1d(subfigs)  # type: ignore
188 | 
189 |     # Get all unique line values from the entire DataFrame
190 |     all_line_vals = df[line_by].unique()
191 |     if line_by_vals is not None:
192 |         assert all(
193 |             val in all_line_vals for val in line_by_vals
194 |         ), f"Invalid line values: {line_by_vals}"
195 |         sorted_line_vals = line_by_vals
196 |     else:
197 |         sorted_line_vals = sorted(all_line_vals, key=str if df[line_by].dtype == object else float)
198 | 
199 |     colors = sns.color_palette("tab10", n_colors=len(sorted_line_vals))
200 |     for subfig, facet_val in zip(subfigs, facet_vals, strict=False):
201 |         axs = subfig.subplots(1, num_axes)
202 |         facet_df = df.loc[df[facet_by] == facet_val]
203 |         for line_val, color in zip(sorted_line_vals, colors, strict=True):
204 |             data = facet_df.loc[facet_df[line_by] == line_val]
205 |             line_style = {
206 |                 "label": line_val,
207 |                 "marker": "o",
208 |                 "linewidth": 1.1,
209 |                 "color": color,
210 |                 "linestyle": "-" if plot_type == "line" else "None",
211 |             }  # default
212 |             line_style.update(
213 |                 {} if styles is None else styles.get(line_val, {})
214 |             )  # specific overrides
215 |             if not data.empty:
216 |                 # draw the lines between points based on the y value
217 |                 data = data.sort_values(sort_by)
218 |                 for i in range(num_axes):
219 |                     if plot_type == "scatter":
220 |                         axs[i].scatter(data[xs[i]], data[y], **line_style)
221 |                     elif plot_type == "line":
222 |                         axs[i].plot(data[xs[i]], data[y], **line_style)
223 |                     else:
224 |                         raise ValueError(f"Unknown plot type: {plot_type}")
225 |             else:
226 | # Add empty plots for missing line values to ensure they appear in the legend 227 | for i in range(num_axes): 228 | axs[i].plot([], [], **line_style) 229 | 230 | if facet_val == facet_vals[-1]: 231 | axs[0].legend(title=legend_title or line_by, loc=legend_pos) 232 | 233 | for i in range(num_axes): 234 | if xlims is not None and xlims[i] is not None: 235 | xmin, xmax = xlims[i][facet_val] # type: ignore 236 | axs[i].set_xlim(xmin=xmin, xmax=xmax) 237 | if ylim is not None: 238 | ymin, ymax = ylim[facet_val] 239 | axs[i].set_ylim(ymin=ymin, ymax=ymax) 240 | 241 | # Set a title above the subplots to show the layer number 242 | row_title = title[facet_val] if title is not None else None 243 | subfig.suptitle(row_title, fontweight="bold", x=xtitle_pos) 244 | for i in range(num_axes): 245 | axs[i].set_xlabel(xlabels[i] if xlabels is not None else xs[i]) 246 | if i == 0: 247 | axs[i].set_ylabel(ylabel or y) 248 | if xticks is not None and xticks[i] is not None: 249 | ticks, labels = xticks[i] # type: ignore 250 | axs[i].set_xticks(ticks, labels=labels) 251 | if yticks is not None: 252 | axs[i].set_yticks(yticks[0], yticks[1]) 253 | 254 | if axis_formatter is not None: 255 | axis_formatter(axs) 256 | 257 | if suptitle is not None: 258 | fig.suptitle(suptitle, fontweight="bold", x=xtitle_pos) 259 | 260 | if out_file is not None: 261 | Path(out_file).parent.mkdir(parents=True, exist_ok=True) 262 | plt.savefig(out_file, dpi=400) 263 | logger.info(f"Saved to {out_file}") 264 | if save_svg: 265 | plt.savefig(Path(out_file).with_suffix(".svg")) 266 | 267 | plt.close(fig) 268 | -------------------------------------------------------------------------------- /e2e_sae/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ApolloResearch/e2e_sae/83b4add4d2652a4e5c6775a0e2b752897c871a87/e2e_sae/scripts/__init__.py -------------------------------------------------------------------------------- /e2e_sae/scripts/analysis/activation_analysis.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | import tqdm 8 | import wandb 9 | from jaxtyping import Float, Int 10 | from matplotlib import pyplot as plt 11 | from pydantic import BaseModel, ConfigDict 12 | from wandb.apis.public import Run 13 | 14 | from e2e_sae.data import DatasetConfig, create_data_loader 15 | from e2e_sae.hooks import SAEActs 16 | from e2e_sae.log import logger 17 | from e2e_sae.models.transformers import SAETransformer 18 | from e2e_sae.scripts.analysis.geometric_analysis import create_subplot_hists 19 | from e2e_sae.scripts.analysis.plot_settings import SIMILAR_CE_RUNS, STYLE_MAP 20 | 21 | ActTensor = Float[torch.Tensor, "batch seq hidden"] 22 | LogitTensor = Float[torch.Tensor, "batch seq vocab"] 23 | 24 | OUT_DIR = Path(__file__).parent / "out" / "activation_analysis" 25 | 26 | 27 | class Acts(BaseModel): 28 | model_config = ConfigDict(arbitrary_types_allowed=True) 29 | 30 | tokens: Int[torch.Tensor, "batch seq"] = torch.empty(0, 1024, dtype=torch.long) 31 | orig: ActTensor = torch.empty(0, 1024, 768) 32 | recon: ActTensor = torch.empty(0, 1024, 768) 33 | kl: Float[torch.Tensor, "batch seq"] = torch.empty(0, 1024) 34 | c_idxs: Int[torch.Tensor, "idxs"] | None = None # noqa: F821 35 | c: Float[torch.Tensor, "batch seq c"] | None = None 36 | orig_pred_tok_ids: Int[torch.Tensor, "batch seq k"] | None = 
None 37 | orig_pred_log_probs: Float[torch.Tensor, "batch seq k"] | None = None 38 | recon_pred_tok_ids: Int[torch.Tensor, "batch seq k"] | None = None 39 | recon_pred_log_probs: Float[torch.Tensor, "batch seq k"] | None = None 40 | k: int = 5 41 | 42 | def add( 43 | self, 44 | tokens: torch.Tensor, 45 | acts: SAEActs, 46 | kl: Float[torch.Tensor, "batch seq"], 47 | c_idxs: Int[torch.Tensor, "idxs"] | None = None, # noqa: F821 48 | orig_logits: LogitTensor | None = None, 49 | sae_logits: LogitTensor | None = None, 50 | ): 51 | self.tokens = torch.cat([self.tokens, tokens.cpu()]) 52 | self.orig = torch.cat([self.orig, acts.input.cpu()]) 53 | self.recon = torch.cat([self.recon, acts.output.cpu()]) 54 | self.kl = torch.cat([self.kl, kl.cpu()]) 55 | if c_idxs is not None: 56 | if self.c is None: 57 | self.c_idxs = c_idxs 58 | self.c = torch.empty(0, 1024, len(c_idxs), dtype=acts.c.dtype) 59 | self.c = torch.cat([self.c, acts.c[:, :, c_idxs].cpu()]) 60 | if orig_logits is not None: 61 | if self.orig_pred_tok_ids is None: 62 | self.orig_pred_tok_ids = torch.empty(0, 1024, self.k, dtype=torch.long) 63 | if self.orig_pred_log_probs is None: 64 | self.orig_pred_log_probs = torch.empty(0, 1024, self.k) 65 | orig_logprobs = F.log_softmax(orig_logits, dim=-1) 66 | orig_pred_log_probs, orig_pred_tok_ids = orig_logprobs.topk(self.k, dim=-1) 67 | self.orig_pred_tok_ids = torch.cat([self.orig_pred_tok_ids, orig_pred_tok_ids.cpu()]) 68 | self.orig_pred_log_probs = torch.cat( 69 | [self.orig_pred_log_probs, orig_pred_log_probs.cpu()] 70 | ) 71 | if sae_logits is not None: 72 | if self.recon_pred_tok_ids is None: 73 | self.recon_pred_tok_ids = torch.empty(0, 1024, self.k, dtype=torch.long) 74 | if self.recon_pred_log_probs is None: 75 | self.recon_pred_log_probs = torch.empty(0, 1024, self.k) 76 | recon_logprobs = F.log_softmax(sae_logits, dim=-1) 77 | recon_pred_log_probs, recon_pred_tok_ids = recon_logprobs.topk(self.k, dim=-1) 78 | self.recon_pred_tok_ids = torch.cat([self.recon_pred_tok_ids, recon_pred_tok_ids.cpu()]) 79 | self.recon_pred_log_probs = torch.cat( 80 | [self.recon_pred_log_probs, recon_pred_log_probs.cpu()] 81 | ) 82 | 83 | def __len__(self): 84 | return len(self.tokens) 85 | 86 | def __str__(self) -> str: 87 | return f"Acts(len={len(self)})" 88 | 89 | 90 | def kl_div(new_logits: LogitTensor, old_logits: LogitTensor) -> Float[torch.Tensor, "batch seq"]: 91 | return F.kl_div( 92 | F.log_softmax(new_logits, dim=-1), F.softmax(old_logits, dim=-1), reduction="none" 93 | ).sum(-1) 94 | 95 | 96 | @torch.no_grad() 97 | def get_acts( 98 | run: Run, 99 | batch_size=5, 100 | batches=1, 101 | device: str = "cuda", 102 | c_idxs: Int[torch.Tensor, "idxs"] | None = None, # noqa: F821 103 | load_cache: bool = True, 104 | save_cache: bool = True, 105 | ) -> Acts: 106 | cache_file = OUT_DIR / "cache" / f"{run.id}.pt" 107 | if load_cache and cache_file.exists(): 108 | cached_acts = Acts(**torch.load(cache_file)) 109 | if len(cached_acts) >= batches * batch_size: 110 | logger.info(f"Loaded cached acts from {cache_file}") 111 | return cached_acts 112 | logger.info(f"Cache file {cache_file} is incomplete, generating new acts...") 113 | 114 | model = SAETransformer.from_wandb(f"{run.project}/{run.id}") 115 | model.to(device) 116 | data_config = DatasetConfig(**run.config["eval_data"]) 117 | loader, _ = create_data_loader(data_config, batch_size=batch_size, global_seed=22) 118 | acts = Acts() 119 | assert len(model.raw_sae_positions) == 1 120 | sae_pos = model.raw_sae_positions[0] 121 | 122 | loader_iter = 
iter(loader) 123 | for _ in tqdm.trange(batches, disable=(batches == 1)): 124 | tokens = next(loader_iter)["input_ids"].to(device) 125 | orig_logits, _ = model.forward_raw(tokens, run_entire_model=True, final_layer=None) 126 | sae_logits, sae_cache = model.forward(tokens, [sae_pos]) 127 | sae_acts = sae_cache[sae_pos] 128 | assert isinstance(sae_acts, SAEActs) 129 | assert sae_logits is not None 130 | acts.add( 131 | tokens, 132 | sae_acts, 133 | kl=kl_div(sae_logits, orig_logits), 134 | c_idxs=c_idxs, 135 | orig_logits=orig_logits, 136 | sae_logits=sae_logits, 137 | ) 138 | 139 | if save_cache: 140 | torch.save(acts.model_dump(), cache_file) 141 | 142 | return acts 143 | 144 | 145 | def norm_scatterplot( 146 | acts: Acts, 147 | xlim: tuple[float | None, float | None] = (0, None), 148 | ylim: tuple[float | None, float | None] = (0, None), 149 | inset_extent: int = 150, 150 | out_file: Path | None = None, 151 | inset_pos: tuple[float, float, float, float] = (0.2, 0.2, 0.6, 0.7), 152 | figsize: tuple[float, float] = (6, 5), 153 | main_plot_diag_line: bool = False, 154 | scatter_alphas: tuple[float, float] = (0.3, 0.1), 155 | ): 156 | orig_norm = torch.norm(acts.orig.flatten(0, 1), dim=-1) 157 | recon_norms = torch.norm(acts.recon.flatten(0, 1), dim=-1) 158 | 159 | plt.subplots(figsize=figsize) 160 | ax = plt.gca() 161 | ax.scatter(orig_norm, recon_norms, alpha=scatter_alphas[0], s=3, c="k") 162 | ax.set_xlim(xlim) # type: ignore 163 | ax.set_ylim(ylim) # type: ignore 164 | ax.set_aspect("equal") 165 | if xlim[1] is not None: 166 | ax.set_xticks(range(0, xlim[1] + 1, 1000)) # type: ignore[reportCallIssue] 167 | if ylim[1] is not None: 168 | ax.set_yticks(range(0, ylim[1] + 1, 1000)) # type: ignore[reportCallIssue] 169 | 170 | axins = plt.gca().inset_axes( # type: ignore 171 | inset_pos, 172 | xlim=(0, inset_extent), 173 | ylim=(0, inset_extent), 174 | xticks=[0, inset_extent], 175 | yticks=[0, inset_extent], 176 | ) 177 | axins.scatter(orig_norm, recon_norms, alpha=scatter_alphas[1], s=1, c="k") 178 | axins.plot([0, inset_extent], [0, inset_extent], "k--", alpha=1, lw=0.8) 179 | axins.set_aspect("equal") 180 | if main_plot_diag_line: 181 | assert xlim[1] is not None 182 | ax.plot([0, xlim[1]], [0, xlim[1]], "k--", alpha=0.5, lw=0.8) 183 | 184 | ax.indicate_inset_zoom(axins, edgecolor="black") 185 | 186 | plt.xlabel("Norm of Original Acts, $||a(x)||_2$") 187 | plt.ylabel("Norm of Reconstructed Acts, $||\\hat{a}(x)||_2$") 188 | 189 | if out_file is not None: 190 | plt.savefig(out_file, bbox_inches="tight") 191 | plt.savefig(Path(out_file).with_suffix(".svg")) 192 | logger.info(f"Saved plot to {out_file}") 193 | 194 | 195 | def get_norm_ratios(acts: Acts) -> dict[str, float]: 196 | norm_ratios = torch.norm(acts.recon, dim=-1) / torch.norm(acts.orig, dim=-1) 197 | return {"pos0": norm_ratios[:, 0].mean().item(), "pos_gt_0": norm_ratios[:, 1:].mean().item()} 198 | 199 | 200 | ActsDict = dict[tuple[int, str], Acts] 201 | 202 | 203 | def get_acts_from_layer_type(layer: int, run_type: str, n_batches: int = 1): 204 | run_id = SIMILAR_CE_RUNS[layer][run_type] 205 | run = wandb.Api().run(f"sparsify/gpt2/{run_id}") 206 | device = "cuda" if torch.cuda.is_available() else "cpu" 207 | return get_acts(run, batch_size=5, batches=n_batches, device=device) 208 | 209 | 210 | def cosine_sim_plot(acts_dict: ActsDict, run_types: list[str], out_file: Path | None = None): 211 | colors = [STYLE_MAP[run_type]["color"] for run_type in run_types] 212 | 213 | def get_sims(acts: Acts): 214 | orig = acts.orig.flatten(0, 1) 
215 | recon = acts.recon.flatten(0, 1) 216 | return F.cosine_similarity(orig, recon, dim=-1) 217 | 218 | fig = plt.figure(figsize=(8, 4), layout="constrained") 219 | subfigs = fig.subfigures(1, 3, wspace=0.05) 220 | 221 | for subfig, layer in zip(subfigs, [2, 6, 10], strict=True): 222 | sims = [get_sims(acts_dict[layer, run_type]) for run_type in run_types] 223 | create_subplot_hists( 224 | sim_list=sims, 225 | titles=[STYLE_MAP[run_type]["label"] for run_type in run_types], 226 | colors=colors, 227 | fig=subfig, 228 | suptitle=f"Layer {layer}", 229 | ) 230 | # subfigs[i].suptitle(f"Layer {layer_num}") 231 | fig.suptitle("Input-Output Similarities", fontweight="bold") 232 | if out_file is not None: 233 | plt.savefig(out_file) 234 | plt.savefig(out_file.with_suffix(".svg")) 235 | logger.info(f"Saved plot to {out_file}") 236 | 237 | 238 | def pca(x: Float[torch.Tensor, "n emb"], n_dims: int | None) -> Float[torch.Tensor, "emb emb"]: 239 | x = x - x.mean(0) 240 | cov_matrix = torch.cov(x.T) 241 | eigenvalues, eigenvectors = torch.linalg.eigh(cov_matrix) 242 | sorted_indices = torch.argsort(eigenvalues, descending=True) 243 | eigenvalues, eigenvectors = eigenvalues[sorted_indices], eigenvectors[:, sorted_indices] 244 | 245 | explained_var = eigenvalues[:n_dims].sum() / eigenvalues.sum() 246 | print(f"Explaining {explained_var.item():.2%} of variance") 247 | 248 | return eigenvectors[:, :n_dims] 249 | 250 | 251 | def pos_dir_plot(acts: Acts, out_file: Path | None = None): 252 | pca_dirs = pca(acts.orig.flatten(0, 1), n_dims=None).T 253 | 254 | seqpos_arr = torch.arange(acts.orig.shape[1]).expand((len(acts), -1)) 255 | 256 | fig, axs = plt.subplots(3, 1, figsize=(6, 4), sharex=True) 257 | axs = np.atleast_1d(axs) # type: ignore 258 | for dir_idx, ax in zip(range(1, 4), axs, strict=True): 259 | ax.plot( 260 | seqpos_arr.flatten(), 261 | acts.orig.flatten(0, 1) @ pca_dirs[dir_idx], 262 | ".k", 263 | ms=1, 264 | alpha=0.02, 265 | label="orig", 266 | ) 267 | ax.plot( 268 | seqpos_arr.flatten(), 269 | acts.recon.flatten(0, 1) @ pca_dirs[dir_idx], 270 | ".r", 271 | ms=1, 272 | alpha=0.02, 273 | label="recon", 274 | ) 275 | ax.set_ylabel(f"PCA dir {dir_idx}") 276 | ax.set_yticks([]) 277 | 278 | plt.xlabel("Seqence Position") 279 | 280 | leg = axs[0].legend() 281 | for lh in leg.legendHandles: 282 | lh.set_alpha(1) 283 | lh.set_markersize(5) 284 | 285 | plt.tight_layout() 286 | if out_file is not None: 287 | plt.savefig(out_file) 288 | plt.savefig(out_file.with_suffix(".svg")) 289 | logger.info(f"Saved plot to {out_file}") 290 | 291 | 292 | def create_latex_table(data: dict[int, dict[str, dict[str, float]]]): 293 | """Formats norms into the appropriate latex table body""" 294 | body = "" 295 | for sae_type in ["local", "e2e", "downstream"]: 296 | row = [sae_type] 297 | for pos in ["pos0", "pos_gt_0"]: 298 | row.extend([f"{data[layer][sae_type][pos]:.2f}" for layer in [2, 6, 10]]) 299 | body += " & ".join(row) + " \\\\\n" 300 | return body 301 | 302 | 303 | if __name__ == "__main__": 304 | run_types = list(SIMILAR_CE_RUNS[6].keys()) 305 | acts_dict: ActsDict = { 306 | (layer, run_type): get_acts_from_layer_type(layer, run_type, n_batches=20) 307 | for layer in SIMILAR_CE_RUNS 308 | for run_type in run_types 309 | } 310 | 311 | acts_6_e2e = acts_dict[6, "e2e"] 312 | assert acts_6_e2e is not None 313 | # norm_scatterplot( 314 | # acts_6_e2e, xlim=(0, 3200), ylim=(0, 2000), out_file=OUT_DIR / "norm_scatter.png" 315 | # ) 316 | 317 | norm_scatterplot( 318 | acts_dict[6, "downstream"], 319 | xlim=(0, 3200), 
320 | ylim=(0, 3200), 321 | figsize=(5, 5), 322 | main_plot_diag_line=True, 323 | inset_pos=(0.2, 0.34, 0.6, 0.6), 324 | inset_extent=150, 325 | scatter_alphas=(0.1, 0.05), 326 | out_file=OUT_DIR / "norm_scatter_downstream.png", 327 | ) 328 | 329 | norms = { 330 | layer: {run_type: get_norm_ratios(acts_dict[(layer, run_type)]) for run_type in run_types} 331 | for layer in [2, 6, 10] 332 | } 333 | 334 | # Generate LaTeX table 335 | print(create_latex_table(norms)) 336 | 337 | with open(OUT_DIR / "norm_ratios.json", "w") as f: 338 | json.dump(norms, f) 339 | 340 | cosine_sim_plot(acts_dict, run_types, out_file=OUT_DIR / "cosine_similarity.png") 341 | 342 | pos_dir_plot(acts_dict[6, "e2e"], out_file=OUT_DIR / "pos_dir_e2e_l6.png") 343 | pos_dir_plot(acts_dict[6, "downstream"], out_file=OUT_DIR / "pos_dir_downstream_l6.png") 344 | pos_dir_plot(acts_dict[10, "e2e"], out_file=OUT_DIR / "pos_dir_e2e_l10.png") 345 | pos_dir_plot(acts_dict[10, "downstream"], out_file=OUT_DIR / "pos_dir_downstream_l10.png") 346 | -------------------------------------------------------------------------------- /e2e_sae/scripts/analysis/pca_dir0_explore.py: -------------------------------------------------------------------------------- 1 | """Analysis of PCA dir 0 in layer 10""" 2 | 3 | # %% 4 | from pathlib import Path 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import torch 9 | import wandb 10 | from einops import repeat 11 | from jaxtyping import Float 12 | from scipy.stats import pearsonr, spearmanr 13 | from torch.nn.functional import cosine_similarity, normalize 14 | 15 | from e2e_sae.log import logger 16 | from e2e_sae.scripts.analysis.activation_analysis import Acts, get_acts, pca 17 | from e2e_sae.scripts.analysis.geometric_analysis import ( 18 | EmbedInfo, 19 | get_alive_dict_elements, 20 | ) 21 | from e2e_sae.scripts.analysis.plot_settings import SIMILAR_CE_RUNS, STYLE_MAP 22 | 23 | # %% 24 | local_run_id = SIMILAR_CE_RUNS[10]["local"] 25 | downstream_run_id = SIMILAR_CE_RUNS[10]["downstream"] 26 | 27 | analysis_dir = Path(__file__).parent 28 | umap_data_dir = Path(analysis_dir / "out/umap") 29 | out_dir = analysis_dir / "out/pca_dir_0" 30 | 31 | umap_file = umap_data_dir / "constant_CE/downstream_local_umap_blocks.10.hook_resid_pre.pt" 32 | umap_info = EmbedInfo(**torch.load(umap_file)) 33 | 34 | api = wandb.Api() 35 | 36 | local_acts = get_acts(api.run(f"sparsify/gpt2/{local_run_id}")) 37 | local_dictionary = get_alive_dict_elements(api, "gpt2", local_run_id) 38 | local_embeds = umap_info.embedding[umap_info.alive_elements_per_dict[0] :, :] 39 | 40 | downstream_acts = get_acts(api.run(f"sparsify/gpt2/{downstream_run_id}")) 41 | downstream_dictionary = get_alive_dict_elements(api, "gpt2", downstream_run_id) 42 | downstream_embeds = umap_info.embedding[: umap_info.alive_elements_per_dict[0]] 43 | 44 | pca_dirs = pca(local_acts.orig.flatten(0, 1), n_dims=None).T 45 | 46 | outlier_pos_0_dir = normalize(local_acts.orig[:, 0, :].mean(0), p=2, dim=0) 47 | print(cosine_similarity(pca_dirs[0], outlier_pos_0_dir).item()) 48 | 49 | # %% 50 | ######## UMAP IN DIR PLOT ######## 51 | 52 | 53 | def umaps_in_dir( 54 | dir: Float[torch.Tensor, "emb"], # noqa: F821 55 | vabs: float | None = None, 56 | outfile: Path | None = None, 57 | ): 58 | fig, axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(10, 5), layout="constrained") 59 | axs = np.atleast_1d(axs) # type: ignore 60 | local_sims = cosine_similarity(local_dictionary.alive_dict_elements.T, dir, dim=-1) 61 | downstream_sims = 
cosine_similarity(downstream_dictionary.alive_dict_elements.T, dir, dim=-1) 62 | 63 | vabs = vabs or max(local_sims.abs().max().item(), downstream_sims.abs().max().item()) 64 | 65 | dir = pca_dirs[1] 66 | kwargs = {"s": 1, "vmin": -vabs, "vmax": vabs, "alpha": 0.3, "cmap": "coolwarm_r"} 67 | axs[0].scatter(local_embeds[:, 0], local_embeds[:, 1], c=local_sims, **kwargs) 68 | mappable = axs[1].scatter( 69 | downstream_embeds[:, 0], downstream_embeds[:, 1], c=downstream_sims, **kwargs 70 | ) 71 | cbar = plt.colorbar(mappable=mappable, label="Similarity to 0th PCA direction", shrink=0.8) 72 | cbar.solids.set(alpha=1) # type: ignore[reportOptionalMemberAccess] 73 | 74 | axs[0].set_xticks([]) 75 | axs[0].set_yticks([]) 76 | axs[1].set_xticks([]) 77 | axs[1].set_yticks([]) 78 | 79 | axs[0].set_title("Local SAE in layer 10") 80 | axs[1].set_title("Downstream SAE in layer 10") 81 | 82 | axs[0].set_box_aspect(1) 83 | axs[1].set_box_aspect(1) 84 | 85 | if outfile is not None: 86 | plt.savefig(outfile, dpi=300, bbox_inches="tight") 87 | plt.savefig(outfile.with_suffix(".svg"), bbox_inches="tight") 88 | logger.info(f"Saved UMAP plot to {outfile}") 89 | 90 | 91 | umaps_in_dir(pca_dirs[0], outfile=out_dir / "umap_pos0_dir.png", vabs=0.5) 92 | 93 | # %% 94 | ######## ACTIVATION HISTOGRAM ######## 95 | 96 | fig, axs = plt.subplots(1, 2, figsize=(7, 2), sharey=True) 97 | axs = np.atleast_1d(axs) # type: ignore 98 | fig.subplots_adjust(wspace=0.05) # adjust space between axes 99 | 100 | acts_in_dir = local_acts.orig @ pca_dirs[0] 101 | eot_token = local_acts.tokens.max() 102 | eot_mask = local_acts.tokens == eot_token 103 | n_batch, n_seq = local_acts.tokens.shape 104 | pos0_mask = repeat(torch.arange(n_seq), "seq -> batch seq", batch=n_batch) == 0 105 | 106 | ax0_xlim = (-20, 220) 107 | ax1_xlim = (2980, 3220) 108 | 109 | colors = plt.get_cmap("tab10").colors # type: ignore[reportAttributeAccessIssue] 110 | 111 | axs[1].hist( 112 | acts_in_dir[pos0_mask], 113 | label="position 0", 114 | density=True, 115 | bins=40, 116 | range=ax1_xlim, 117 | color=colors[4], 118 | ) 119 | axs[0].hist( 120 | acts_in_dir[eot_mask & ~pos0_mask], 121 | label="end-of-text", 122 | density=True, 123 | bins=40, 124 | range=ax0_xlim, 125 | color=colors[6], 126 | ) 127 | axs[0].hist( 128 | acts_in_dir[~eot_mask & ~pos0_mask], 129 | label="other", 130 | density=True, 131 | bins=40, 132 | range=ax0_xlim, 133 | color=colors[5], 134 | ) 135 | 136 | # hide the spines between ax and ax2 137 | axs[0].spines.right.set_visible(False) 138 | axs[1].spines.left.set_visible(False) 139 | axs[1].set_yticks([]) 140 | 141 | axs[0].set_xlim(*ax0_xlim) 142 | axs[1].set_xlim(*ax1_xlim) 143 | 144 | # axis break symbols 145 | d = 0.5 # proportion of vertical to horizontal extent of the slanted line 146 | kwargs = dict( 147 | marker=[(-d, -1), (d, 1)], 148 | markersize=12, 149 | linestyle="none", 150 | color="k", 151 | mec="k", 152 | mew=1, 153 | clip_on=False, 154 | ) 155 | axs[0].plot([1, 1], [0, 1], transform=axs[0].transAxes, **kwargs) 156 | axs[1].plot([0, 0], [0, 1], transform=axs[1].transAxes, **kwargs) 157 | 158 | # legend 159 | lines = axs[1].get_legend_handles_labels()[0] + axs[0].get_legend_handles_labels()[0] 160 | labels = axs[1].get_legend_handles_labels()[1] + axs[0].get_legend_handles_labels()[1] 161 | axs[1].legend(lines, labels) 162 | 163 | axs[1].set_xlabel("Activation in PCA direction 0", x=-0.025, horizontalalignment="center") 164 | 165 | plt.savefig(out_dir / "activation_hist.png", dpi=300, bbox_inches="tight") 166 | 
plt.savefig(out_dir / "activation_hist.svg", bbox_inches="tight") 167 | 168 | # %% 169 | ######## INPUT-OUTPUT CORRELATION IN PCA DIR 0 ######## 170 | 171 | 172 | def corr_in_dir( 173 | acts: Acts, direction: torch.Tensor, spearman: bool = False, include_pos0: bool = True 174 | ) -> float: 175 | pos_idx = 0 if include_pos0 else 1 176 | orig_in_d = (acts.orig[:, pos_idx:] @ direction).flatten() 177 | recon_in_d = (acts.recon[:, pos_idx:] @ direction).flatten() 178 | if spearman: 179 | return spearmanr(orig_in_d, recon_in_d).statistic # type: ignore[reportAttributeAccessIssue] 180 | else: 181 | return pearsonr(orig_in_d, recon_in_d).statistic # type: ignore[reportAttributeAccessIssue] 182 | 183 | 184 | print( 185 | "Local SAE, corr in 0th pca:", 186 | f"{corr_in_dir(local_acts, pca_dirs[0], include_pos0=False):.3f}", 187 | ) 188 | print( 189 | "Downstream SAE, corr in 0th pca:", 190 | f"{corr_in_dir(downstream_acts, pca_dirs[0], include_pos0=False):.3f}", 191 | ) 192 | 193 | print( 194 | "Downstream SAE, corr in 0th pca at position 0", 195 | pearsonr( 196 | (downstream_acts.orig[:, 0] @ pca_dirs[0]).flatten(), 197 | (downstream_acts.recon[:, 0] @ pca_dirs[0]).flatten(), 198 | ).statistic, # type: ignore[reportAttributeAccessIssue] 199 | ) 200 | 201 | # %% 202 | ######## INPUT-OUTPUT CORRELATION IN PCA DIRS PLOT ######## 203 | 204 | xs = range(25) 205 | corrs = { 206 | "local": [corr_in_dir(local_acts, pca_dirs[i], include_pos0=False) for i in range(50)], 207 | "downstream": [ 208 | corr_in_dir(downstream_acts, pca_dirs[i], include_pos0=False) for i in range(50) 209 | ], 210 | } 211 | 212 | plt.plot(xs, corrs["local"], **STYLE_MAP["local"]) # type: ignore[reportArgumentType] 213 | plt.plot(xs, corrs["downstream"], **STYLE_MAP["downstream"]) # type: ignore[reportArgumentType] 214 | plt.ylabel("input-output correlation") 215 | plt.xlabel("PCA direction") 216 | plt.legend(loc="lower right", title="SAE type") 217 | plt.ylim(0, 1) 218 | plt.gcf().set_size_inches(4, 3) 219 | plt.xlim(-1, None) 220 | plt.savefig(out_dir / "input_output_corr.png", dpi=300, bbox_inches="tight") 221 | plt.savefig(out_dir / "input_output_corr.svg", bbox_inches="tight") 222 | -------------------------------------------------------------------------------- /e2e_sae/scripts/analysis/plot_settings.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | # Runs with constant CE loss increase for each layer. Values represent wandb run IDs. 4 | SIMILAR_CE_RUNS = { 5 | 2: {"local": "ue3lz0n7", "e2e": "ovhfts9n", "downstream": "visi12en"}, 6 | 6: {"local": "1jy3m5j0", "e2e": "zgdpkafo", "downstream": "2lzle2f0"}, 7 | 10: {"local": "m2hntlav", "e2e": "8crnit9h", "downstream": "cvj5um2h"}, 8 | } 9 | # Runs with similar L0 loss increase for each layer. Values represent wandb run IDs. 10 | SIMILAR_L0_RUNS = { 11 | 2: {"local": "6vtk4k51", "e2e": "bst0prdd", "downstream": "e26jflpq"}, 12 | 6: {"local": "jup3glm9", "e2e": "tvj2owza", "downstream": "2lzle2f0"}, 13 | 10: {"local": "5vmpdgaz", "e2e": "8crnit9h", "downstream": "cvj5um2h"}, 14 | } 15 | # Runs with similar alive dictionary elements. Values represent wandb run IDs. 
16 | SIMILAR_ALIVE_ELEMENTS_RUNS = { 17 | 2: {"local": "6vtk4k51", "e2e": "0z98g9pf", "downstream": "visi12en"}, 18 | 6: {"local": "h9hrelni", "e2e": "tvj2owza", "downstream": "p9zmh62k"}, 19 | 10: {"local": "5vmpdgaz", "e2e": "vnfh4vpi", "downstream": "f2fs7hk3"}, 20 | } 21 | 22 | SIMILAR_RUN_INFO = { 23 | "CE": SIMILAR_CE_RUNS, 24 | "l0": SIMILAR_L0_RUNS, 25 | "alive_elements": SIMILAR_ALIVE_ELEMENTS_RUNS, 26 | } 27 | 28 | STYLE_MAP = { 29 | "local": {"marker": "^", "color": "#f0a70a", "label": "local"}, 30 | "e2e": {"marker": "o", "color": "#518c31", "label": "e2e"}, 31 | "downstream": {"marker": "X", "color": plt.get_cmap("tab20b").colors[2], "label": "e2e+ds"}, # type: ignore[reportAttributeAccessIssue] 32 | } 33 | -------------------------------------------------------------------------------- /e2e_sae/scripts/analysis/resample_direction.py: -------------------------------------------------------------------------------- 1 | """Tests how robust network outputs are to resampling-ablating PCA directions""" 2 | from collections.abc import Callable 3 | from functools import partial 4 | from pathlib import Path 5 | 6 | import matplotlib.pyplot as plt 7 | import torch 8 | import tqdm 9 | import transformer_lens as tl 10 | import wandb 11 | from jaxtyping import Float, Int 12 | from transformer_lens import HookedTransformer 13 | 14 | from e2e_sae.data import DatasetConfig, create_data_loader 15 | from e2e_sae.scripts.analysis.activation_analysis import get_acts, kl_div, pca 16 | from e2e_sae.scripts.analysis.plot_settings import SIMILAR_CE_RUNS 17 | 18 | ActTensor = Float[torch.Tensor, "batch seq hidden"] 19 | LogitTensor = Float[torch.Tensor, "batch seq vocab"] 20 | DirTensor = Float[torch.Tensor, "hidden"] 21 | TokenTensor = Int[torch.Tensor, "batch seq"] 22 | 23 | 24 | def shuffle_tensor(x: torch.Tensor) -> torch.Tensor: 25 | return x[torch.randperm(x.shape[0])] 26 | 27 | 28 | def apply_fn_to_mask( 29 | x: torch.Tensor, mask: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor] 30 | ) -> torch.Tensor: 31 | x_out = x[:] 32 | x_out[mask] = fn(x[mask]) 33 | return x_out 34 | 35 | 36 | def resample_direction_hook( 37 | x: ActTensor, hook: tl.hook_points.HookPoint, dir: DirTensor 38 | ) -> ActTensor: 39 | seqpos_arr = torch.arange(x.shape[1]).repeat([len(x), 1]).to(device) 40 | mask = seqpos_arr > 0 41 | x_within_dir = (x @ dir).unsqueeze(-1) * dir 42 | resid = x - x_within_dir 43 | shuffled_x_within_dir = apply_fn_to_mask(x_within_dir, mask, shuffle_tensor) 44 | return resid + shuffled_x_within_dir 45 | 46 | 47 | @torch.no_grad() 48 | def get_kl_diff_permuting_dir( 49 | model: HookedTransformer, 50 | hook_point_name: str, 51 | input_ids: TokenTensor, 52 | dir: DirTensor, 53 | device: str, 54 | ): 55 | input_ids = input_ids.to(device) 56 | orig_logits = model(input_ids) 57 | partial_hook = partial(resample_direction_hook, dir=dir.to(device)) 58 | hooked_logits = model.run_with_hooks(input_ids, fwd_hooks=[(hook_point_name, partial_hook)]) 59 | 60 | return kl_div(orig_logits, hooked_logits).mean().item() 61 | 62 | 63 | def get_batch(): 64 | dataset_config = DatasetConfig( 65 | dataset_name="apollo-research/Skylion007-openwebtext-tokenizer-gpt2", 66 | is_tokenized=True, 67 | tokenizer_name="gpt2", 68 | streaming=True, 69 | split="train", 70 | n_ctx=1024, 71 | seed=100, 72 | column_name="input_ids", 73 | ) 74 | 75 | data_loader, _ = create_data_loader(dataset_config, batch_size=30) 76 | return next(iter(data_loader))["input_ids"] 77 | 78 | 79 | def get_pca_dirs(): 80 | api = wandb.Api() 81 | 
local_run_id = SIMILAR_CE_RUNS[10]["local"] 82 | run = api.run(f"sparsify/gpt2/{local_run_id}") 83 | acts = get_acts(run) 84 | return pca(acts.orig.flatten(0, 1), n_dims=None).T 85 | 86 | 87 | if __name__ == "__main__": 88 | device = "cuda" if torch.cuda.is_available() else "cpu" 89 | gpt2 = HookedTransformer.from_pretrained("gpt2") 90 | hook_point_name = "blocks.10.hook_resid_pre" 91 | batch = get_batch() 92 | pca_dirs = get_pca_dirs() 93 | 94 | get_kl = partial( 95 | get_kl_diff_permuting_dir, 96 | model=gpt2, 97 | hook_point_name=hook_point_name, 98 | device=device, 99 | input_ids=batch, 100 | ) 101 | xs = range(25) 102 | kls = [get_kl(dir=pca_dirs[x]) for x in tqdm.tqdm(xs)] 103 | 104 | plt.plot(xs, kls, "o-k") 105 | plt.ylim(0, None) 106 | plt.xlabel("PCA direction") 107 | plt.ylabel("KL divergence") 108 | # plt.title("How much does permuting activations mess up the model?") 109 | plt.gcf().set_size_inches(4, 3) 110 | 111 | scripts_dir = Path(__file__).parent 112 | out_dir = scripts_dir / "out/pca_dir_0" 113 | 114 | plt.savefig(out_dir / "resample_sensitivity.png", dpi=300, bbox_inches="tight") 115 | plt.savefig(out_dir / "resample_sensitivity.svg", bbox_inches="tight") 116 | -------------------------------------------------------------------------------- /e2e_sae/scripts/analysis/utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import tqdm 3 | import wandb 4 | from wandb.apis.public import Run 5 | from wandb.apis.public.runs import Runs 6 | 7 | 8 | def _get_run_type( 9 | kl_coeff: float | None, in_to_orig_coeff: float | None, out_to_in_coeff: float | None 10 | ) -> str: 11 | if ( 12 | kl_coeff is not None 13 | and in_to_orig_coeff is not None 14 | and kl_coeff > 0 15 | and in_to_orig_coeff > 0 16 | ): 17 | if out_to_in_coeff is not None and out_to_in_coeff > 0: 18 | return "downstream_all" 19 | else: 20 | return "downstream" 21 | if ( 22 | kl_coeff is not None 23 | and out_to_in_coeff is not None 24 | and in_to_orig_coeff is None 25 | and kl_coeff > 0 26 | and out_to_in_coeff > 0 27 | ): 28 | return "e2e_local" 29 | if ( 30 | kl_coeff is not None 31 | and kl_coeff > 0 32 | and (out_to_in_coeff is None or out_to_in_coeff == 0) 33 | and in_to_orig_coeff is None 34 | ): 35 | return "e2e" 36 | return "local" 37 | 38 | 39 | def _get_run_type_using_names(run_name: str) -> str: 40 | if "logits-kl-1.0" in run_name and "in-to-orig" not in run_name: 41 | return "e2e" 42 | if "logits-kl-" in run_name and "in-to-orig" in run_name: 43 | return "downstream" 44 | return "local" 45 | 46 | 47 | def _extract_per_layer_metrics( 48 | run: Run, metric_key: str, metric_name_prefix: str, sae_layer: int, sae_pos: str 49 | ) -> dict[str, float]: 50 | """Extract the per layer metrics from the run summary metrics. 51 | 52 | Note that the layer indices correspond to those collected before that layer. E.g. those from 53 | hook_resid_pre rather than hook_resid_post. 54 | 55 | Args: 56 | run: The run to extract the metrics from. 57 | metric_key: The key to use to extract the metrics from the run summary. 58 | metric_name_prefix: The prefix to use for the metric names in the returned dictionary. 59 | sae_layer: The layer number of the SAE. 60 | sae_pos: The position of the SAE. 
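    Example (illustrative; the key and value are placeholders): a summary entry of
        "loss/eval/in_to_orig/blocks.5.hook_resid_post": 0.12
    with metric_key="loss/eval/in_to_orig" and metric_name_prefix="recon_loss_layer" is
    returned as {"recon_loss_layer-6": 0.12}, since hook_resid_post of block 5 is treated
    as the input to layer 6.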
61 | """ 62 | 63 | results: dict[str, float] = {} 64 | for key, value in run.summary_metrics.items(): 65 | if not key.startswith(f"{metric_key}/blocks"): 66 | # We don't care about the other metrics 67 | continue 68 | layer_num_str, hook_pos = key.split(f"{metric_key}/blocks.")[1].split(".") 69 | if "pre" in hook_pos: 70 | layer_num = int(layer_num_str) 71 | elif "post" in hook_pos: 72 | layer_num = int(layer_num_str) + 1 73 | else: 74 | raise ValueError(f"Unknown hook position: {hook_pos}") 75 | results[f"{metric_name_prefix}-{layer_num}"] = value 76 | 77 | # Overwrite the SAE layer with the out_to_in for that layer. This is so that we get the 78 | # reconstruction/variance at the output of the SAE rather than the input 79 | out_to_in_prefix = metric_key.replace("in_to_orig", "out_to_in") 80 | results[f"{metric_name_prefix}-{sae_layer}"] = run.summary_metrics[ 81 | f"{out_to_in_prefix}/{sae_pos}" 82 | ] 83 | return results 84 | 85 | 86 | def create_run_df( 87 | runs: Runs, per_layer_metrics: bool = True, use_run_name: bool = False, grad_norm: bool = True 88 | ) -> pd.DataFrame: 89 | run_info = [] 90 | for run in tqdm.tqdm(runs, total=len(runs), desc="Processing runs"): 91 | if run.state != "finished": 92 | print(f"Run {run.name} is not finished, skipping") 93 | continue 94 | sae_pos = run.config["saes"]["sae_positions"] 95 | if isinstance(sae_pos, list): 96 | if len(sae_pos) > 1: 97 | raise ValueError("More than one SAE position found") 98 | sae_pos = sae_pos[0] 99 | sae_layer = int(sae_pos.split(".")[1]) 100 | 101 | kl_coeff = None 102 | in_to_orig_coeff = None 103 | out_to_in_coeff = None 104 | if "logits_kl" in run.config["loss"] and run.config["loss"]["logits_kl"] is not None: 105 | kl_coeff = run.config["loss"]["logits_kl"]["coeff"] 106 | if "in_to_orig" in run.config["loss"] and run.config["loss"]["in_to_orig"] is not None: 107 | in_to_orig_coeff = run.config["loss"]["in_to_orig"]["total_coeff"] 108 | if "out_to_in" in run.config["loss"] and run.config["loss"]["out_to_in"] is not None: 109 | out_to_in_coeff = run.config["loss"]["out_to_in"]["coeff"] 110 | 111 | if use_run_name: 112 | run_type = _get_run_type_using_names(run.name) 113 | else: 114 | run_type = _get_run_type(kl_coeff, in_to_orig_coeff, out_to_in_coeff) 115 | 116 | explained_var_layers = {} 117 | explained_var_ln_layers = {} 118 | recon_loss_layers = {} 119 | if per_layer_metrics: 120 | # The out_to_in in the below is to handle the e2e+recon loss runs which specified 121 | # future layers in the in_to_orig but not the output of the SAE at the current layer 122 | # (i.e. at hook_resid_post). Note that now if you leave in_to_orig as None, it will 123 | # default to calculating in_to_orig at all layers at hook_resid_post. 
124 | # The explained variance at each layer 125 | explained_var_layers = _extract_per_layer_metrics( 126 | run=run, 127 | metric_key="loss/eval/in_to_orig/explained_variance", 128 | metric_name_prefix="explained_var_layer", 129 | sae_layer=sae_layer, 130 | sae_pos=sae_pos, 131 | ) 132 | 133 | explained_var_ln_layers = _extract_per_layer_metrics( 134 | run=run, 135 | metric_key="loss/eval/in_to_orig/explained_variance_ln", 136 | metric_name_prefix="explained_var_ln_layer", 137 | sae_layer=sae_layer, 138 | sae_pos=sae_pos, 139 | ) 140 | 141 | recon_loss_layers = _extract_per_layer_metrics( 142 | run=run, 143 | metric_key="loss/eval/in_to_orig", 144 | metric_name_prefix="recon_loss_layer", 145 | sae_layer=sae_layer, 146 | sae_pos=sae_pos, 147 | ) 148 | 149 | if "dict_size_to_input_ratio" in run.config["saes"]: 150 | ratio = float(run.config["saes"]["dict_size_to_input_ratio"]) 151 | else: 152 | # local runs didn't store the ratio in the config for these runs 153 | ratio = float(run.name.split("ratio-")[1].split("_")[0]) 154 | 155 | out_to_in = None 156 | explained_var = None 157 | explained_var_ln = None 158 | if f"loss/eval/out_to_in/{sae_pos}" in run.summary_metrics: 159 | out_to_in = run.summary_metrics[f"loss/eval/out_to_in/{sae_pos}"] 160 | explained_var = run.summary_metrics[f"loss/eval/out_to_in/explained_variance/{sae_pos}"] 161 | try: 162 | explained_var_ln = run.summary_metrics[ 163 | f"loss/eval/out_to_in/explained_variance_ln/{sae_pos}" 164 | ] 165 | except KeyError: 166 | explained_var_ln = None 167 | 168 | try: 169 | kl = run.summary_metrics["loss/eval/logits_kl"] 170 | except KeyError: 171 | kl = None 172 | 173 | mean_grad_norm = None 174 | if grad_norm: 175 | # Check if "mean_grad_norm" is in the run summary, if not, we need to calculate it 176 | if "mean_grad_norm" in run.summary: 177 | mean_grad_norm = run.summary["mean_grad_norm"] 178 | else: 179 | grad_norm_history = run.history(keys=["grad_norm"], samples=2000) 180 | # Get the mean of grad norms after the first 10000 steps 181 | mean_grad_norm = grad_norm_history.loc[ 182 | grad_norm_history["_step"] > 10000, "grad_norm" 183 | ].mean() 184 | 185 | run.summary["mean_grad_norm"] = mean_grad_norm 186 | run.summary.update() 187 | 188 | run_info.append( 189 | { 190 | "name": run.name, 191 | "id": run.id, 192 | "sae_pos": sae_pos, 193 | "model_name": run.config["tlens_model_name"], 194 | "run_type": run_type, 195 | "layer": sae_layer, 196 | "seed": run.config["seed"], 197 | "n_samples": run.config["n_samples"], 198 | "lr": run.config["lr"], 199 | "ratio": ratio, 200 | "sparsity_coeff": run.config["loss"]["sparsity"]["coeff"], 201 | "in_to_orig_coeff": in_to_orig_coeff, 202 | "kl_coeff": kl_coeff, 203 | "out_to_in": out_to_in, 204 | "L0": run.summary_metrics[f"sparsity/eval/L_0/{sae_pos}"], 205 | "explained_var": explained_var, 206 | "explained_var_ln": explained_var_ln, 207 | "CE_diff": run.summary_metrics["performance/eval/difference_ce_loss"], 208 | "CELossIncrease": -run.summary_metrics["performance/eval/difference_ce_loss"], 209 | "alive_dict_elements": run.summary_metrics[ 210 | f"sparsity/alive_dict_elements/{sae_pos}" 211 | ], 212 | "mean_grad_norm": mean_grad_norm, 213 | **explained_var_layers, 214 | **explained_var_ln_layers, 215 | **recon_loss_layers, 216 | "sum_recon_loss": sum(recon_loss_layers.values()), 217 | "kl": kl, 218 | } 219 | ) 220 | df = pd.DataFrame(run_info) 221 | return df 222 | 223 | 224 | def get_df_gpt2() -> pd.DataFrame: 225 | api = wandb.Api() 226 | project = "sparsify/gpt2" 227 | runs = 
api.runs(project) 228 | 229 | d_resid = 768 230 | 231 | df = create_run_df(runs) 232 | 233 | assert df["model_name"].nunique() == 1 234 | 235 | # Ignore runs that have an L0 bigger than d_resid 236 | df = df.loc[df["L0"] <= d_resid] 237 | return df 238 | -------------------------------------------------------------------------------- /e2e_sae/scripts/train_mlp_saes/max_act_data_mnist.py: -------------------------------------------------------------------------------- 1 | """Collect and analyze max activating dataset examples. 2 | 3 | Usage: 4 | python max_act_data_mnist.py 5 | """ 6 | 7 | from datetime import datetime 8 | from pathlib import Path 9 | 10 | import fire 11 | import torch 12 | from pydantic import ( 13 | BaseModel, 14 | ConfigDict, 15 | NonNegativeInt, 16 | PositiveInt, 17 | ) 18 | from torch.utils.data import DataLoader 19 | from torchvision import datasets, transforms 20 | 21 | from e2e_sae.log import logger 22 | from e2e_sae.models.mlp import MLPMod 23 | from e2e_sae.settings import REPO_ROOT 24 | from e2e_sae.utils import load_config 25 | 26 | 27 | class ModelConfig(BaseModel): 28 | model_config = ConfigDict(extra="forbid", frozen=True) 29 | hidden_sizes: list[PositiveInt] | None 30 | 31 | 32 | class InferenceConfig(BaseModel): 33 | model_config = ConfigDict(extra="forbid", frozen=True) 34 | batch_size: PositiveInt 35 | model_name: str 36 | save_dir: Path | None 37 | 38 | 39 | class Config(BaseModel): 40 | model_config = ConfigDict(extra="forbid", frozen=True) 41 | seed: NonNegativeInt 42 | model: ModelConfig 43 | infer: InferenceConfig 44 | 45 | 46 | def max_act(config: Config) -> None: 47 | """Collect and analyze max activating dataset examples.""" 48 | torch.manual_seed(config.seed) 49 | device = "cuda" if torch.cuda.is_available() else "cpu" 50 | logger.info("Using device: %s", device) 51 | 52 | # Load the MNIST dataset 53 | transform = transforms.ToTensor() 54 | train_data = datasets.MNIST( 55 | root=str(REPO_ROOT / ".data"), train=True, download=True, transform=transform 56 | ) 57 | DataLoader(train_data, batch_size=config.infer.batch_size, shuffle=True) 58 | test_data = datasets.MNIST( 59 | root=str(REPO_ROOT / ".data"), train=False, download=True, transform=transform 60 | ) 61 | test_loader = DataLoader(test_data, batch_size=config.infer.batch_size, shuffle=False) 62 | 63 | # Define model path to load 64 | model_path = REPO_ROOT / "models" / config.infer.model_name 65 | model_path = max(model_path.glob("*.pt"), key=lambda x: int(x.stem.split("_")[-1])) 66 | model_mod = MLPMod(config.model.hidden_sizes, input_size=784, output_size=10) 67 | model_mod_state_dict = torch.load(model_path) 68 | model_mod.load_state_dict(model_mod_state_dict) 69 | model_mod = model_mod.to(device) 70 | 71 | datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 72 | 73 | # Get number of dataset samples in testset 74 | num_dataset_samples = len(test_loader.dataset) # type: ignore 75 | # Get size of each autoencoder dictionary 76 | sae_sizes = {} 77 | for key in model_mod.sparsifiers: 78 | sae_sizes[key] = model_mod.sparsifiers[key].n_dict_components 79 | 80 | # Make a dictionary for each autoencoder to store the cs for each dataset sample 81 | cs_dict = {} 82 | for key in model_mod.sparsifiers: 83 | cs_dict[key] = torch.zeros(num_dataset_samples, sae_sizes[key]) 84 | 85 | # Make a dictionary to count the dead cs for each autoencoder 86 | dead_dict = {} 87 | for key in model_mod.sparsifiers: 88 | dead_dict[key] = torch.zeros(sae_sizes[key]) 89 | 90 | samples = 0 91 | for i, (images, labels) in 
enumerate(test_loader): 92 | images, labels = images.to(device), labels.to(device) 93 | samples += images.shape[0] 94 | images = images.view(images.shape[0], -1) 95 | mod_acts, cs, saes_outs = model_mod(images) 96 | 97 | # Identify the dead dictionary elements (the cs with no activations) by summing over the 98 | # dataset samples and counting the number of zeros for each dictionary element 99 | for key in cs: 100 | dead_dict[key] += torch.sum(cs[key], dim=0).cpu() 101 | 102 | # Store the cs for each dataset sample 103 | for key in cs: 104 | cs_dict[key][i * config.infer.batch_size : (i + 1) * config.infer.batch_size] = cs[key] 105 | 106 | # Divide the dead dictionary elements by the number of dataset samples to get the percentage of 107 | # dead dictionary elements 108 | for key in dead_dict: 109 | dead_dict[key] = dead_dict[key] / num_dataset_samples 110 | 111 | # Plot a histogram of the dead dictionary elements for each autoencoder 112 | # import matplotlib.pyplot as plt 113 | # fig, axs = plt.subplots(1, len(dead_dict.keys())) 114 | # for i, key in enumerate(dead_dict.keys()): 115 | # axs[i].hist(dead_dict[key].detach().numpy(), bins=20) 116 | # axs[i].set_title(key) 117 | # plt.show() 118 | # plt.savefig("dead_dict_elements.png") 119 | 120 | 121 | def main(config_path_str: str) -> None: 122 | config_path = Path(config_path_str) # TODO make separate config for model_mod 123 | config = load_config(config_path, config_model=Config) 124 | max_act(config) 125 | 126 | 127 | if __name__ == "__main__": 128 | fire.Fire(main) 129 | -------------------------------------------------------------------------------- /e2e_sae/scripts/train_mlp_saes/max_act_data_mnist.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | model: 3 | hidden_sizes: 4 | - 100 5 | - 100 6 | infer: 7 | batch_size: 1024 8 | model_name: sparse-lambda-50.0-lr-0.001_bs-1024-[100, 100]_2024-01-14_22-09-22 -------------------------------------------------------------------------------- /e2e_sae/scripts/train_mlp_saes/mnist_saes.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | saved_model_dir: /path/to/saved_model_dir 3 | train: 4 | learning_rate: 0.001 5 | batch_size: 1024 6 | n_epochs: 8 7 | type_of_sparsifier: sae 8 | sparsity_lambda: 0.005 9 | dict_eles_to_input_ratio: 10.0 10 | sparsifier_in_out_recon_loss_scale: 0.0 11 | k: 128 12 | wandb_project: mlp-mnist-sae 13 | -------------------------------------------------------------------------------- /e2e_sae/scripts/train_mlp_saes/train_mnist_saes.py: -------------------------------------------------------------------------------- 1 | """Train SAEs on an MNIST model. 2 | 3 | NOTE: To run this, you must first train an MLP model on MNIST. You can do this with the 4 | `e2e_sae/scripts/train_mnist/run_train_mnist.py` script. 
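A typical sequence (the paths are the ones in this repo; adjust to your setup):
    python e2e_sae/scripts/train_mnist/run_train_mnist.py e2e_sae/scripts/train_mnist/mnist.yaml
    # then point `saved_model_dir` in mnist_saes.yaml at the resulting save directory
    python e2e_sae/scripts/train_mlp_saes/train_mnist_saes.py e2e_sae/scripts/train_mlp_saes/mnist_saes.yaml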
5 | 6 | Usage: 7 | python train_mnist_saes.py 8 | """ 9 | 10 | import os 11 | from collections import OrderedDict 12 | from collections.abc import Callable 13 | from datetime import datetime 14 | from pathlib import Path 15 | 16 | import fire 17 | import torch 18 | import wandb 19 | import yaml 20 | from dotenv import load_dotenv 21 | from jaxtyping import Float 22 | from pydantic import ( 23 | BaseModel, 24 | ConfigDict, 25 | NonNegativeFloat, 26 | NonNegativeInt, 27 | PositiveFloat, 28 | PositiveInt, 29 | ) 30 | from torch import Tensor, nn 31 | from torch.utils.data import DataLoader 32 | from torchvision import datasets, transforms 33 | from tqdm import tqdm 34 | 35 | from e2e_sae.log import logger 36 | from e2e_sae.models.mlp import MLP, MLPMod 37 | from e2e_sae.settings import REPO_ROOT 38 | from e2e_sae.types import RootPath 39 | from e2e_sae.utils import load_config, save_module, set_seed 40 | 41 | 42 | class TrainConfig(BaseModel): 43 | model_config = ConfigDict(extra="forbid", frozen=True) 44 | learning_rate: float 45 | batch_size: PositiveInt 46 | n_epochs: PositiveInt 47 | save_dir: RootPath | None = Path(__file__).parent / "out" 48 | type_of_sparsifier: str 49 | sparsity_lambda: NonNegativeFloat 50 | dict_eles_to_input_ratio: PositiveFloat 51 | sparsifier_in_out_recon_loss_scale: NonNegativeFloat 52 | k: PositiveInt 53 | 54 | 55 | class Config(BaseModel): 56 | model_config = ConfigDict(extra="forbid", frozen=True) 57 | seed: NonNegativeInt 58 | saved_model_dir: RootPath 59 | train: TrainConfig 60 | wandb_project: str | None # If None, don't log to Weights & Biases 61 | 62 | 63 | def get_activation( 64 | name: str, activations: OrderedDict[str, torch.Tensor] 65 | ) -> Callable[[nn.Module, tuple[torch.Tensor, ...], torch.Tensor], None]: 66 | """function to be called when the forward pass reaches a layer""" 67 | 68 | def hook(model: nn.Module, input: tuple[torch.Tensor, ...], output: torch.Tensor) -> None: 69 | activations[name] = output.detach() 70 | 71 | return hook 72 | 73 | 74 | def load_data(config: Config) -> tuple[DataLoader[datasets.MNIST], DataLoader[datasets.MNIST]]: 75 | transform = transforms.ToTensor() 76 | train_data = datasets.MNIST( 77 | root=str(REPO_ROOT / ".data"), train=True, download=True, transform=transform 78 | ) 79 | train_loader = DataLoader(train_data, batch_size=config.train.batch_size, shuffle=True) 80 | test_data = datasets.MNIST( 81 | root=str(REPO_ROOT / ".data"), train=False, download=True, transform=transform 82 | ) 83 | test_loader = DataLoader(test_data, batch_size=config.train.batch_size, shuffle=False) 84 | return train_loader, test_loader 85 | 86 | 87 | def get_models( 88 | config: Config, device: str | torch.device 89 | ) -> tuple[MLP, MLPMod, OrderedDict[str, torch.Tensor]]: 90 | # Load the hidden_sizes form the trained model 91 | with open(config.saved_model_dir / "final_config.yaml") as f: 92 | hidden_sizes = yaml.safe_load(f)["model"]["hidden_sizes"] 93 | 94 | latest_model_path = max( 95 | config.saved_model_dir.glob("*.pt"), key=lambda x: int(x.stem.split("_")[-1]) 96 | ) 97 | # Initialize the MLP model 98 | model = MLP(hidden_sizes, input_size=784, output_size=10) 99 | model = model.to(device) 100 | model_trained_statedict = torch.load(latest_model_path) 101 | model.load_state_dict(model_trained_statedict) 102 | model.eval() 103 | 104 | # Add hooks to the model so we can get all intermediate activations 105 | activations = OrderedDict() 106 | for name, layer in model.layers.named_children(): 107 | 
layer.register_forward_hook(get_activation(name, activations)) 108 | 109 | # Get the SAEs from the model_mod and put them in the statedict of model_trained 110 | model_mod = MLPMod( 111 | hidden_sizes=hidden_sizes, 112 | input_size=784, 113 | output_size=10, 114 | type_of_sparsifier=config.train.type_of_sparsifier, 115 | k=config.train.k, 116 | dict_eles_to_input_ratio=config.train.dict_eles_to_input_ratio, 117 | ) 118 | for k, v in model_mod.state_dict().items(): 119 | if k.startswith("sparsifiers"): 120 | model_trained_statedict[k] = v 121 | 122 | model_mod.load_state_dict(model_trained_statedict) 123 | model_mod = model_mod.to(device) 124 | return model, model_mod, activations 125 | 126 | 127 | def train(config: Config) -> None: 128 | """Train the MLP on MNIST. 129 | 130 | If config.wandb is not None, log the results to Weights & Biases. 131 | """ 132 | device = "cuda" if torch.cuda.is_available() else "cpu" 133 | logger.info("Using device: %s", device) 134 | 135 | # Load the MNIST dataset 136 | train_loader, test_loader = load_data(config) 137 | 138 | # Initialize the MLP model and modified model 139 | model, model_mod, activations = get_models(config, device) 140 | 141 | model_mod.sparsifiers.train() 142 | 143 | for param in model.layers.parameters(): 144 | param.requires_grad = False 145 | 146 | # Define the loss and optimizer 147 | criterion = nn.MSELoss() 148 | # Note: only pass the SAE parameters to the optimizer 149 | optimizer = torch.optim.Adam(model_mod.sparsifiers.parameters(), lr=config.train.learning_rate) 150 | 151 | run_name = ( 152 | f"sae_lambda-{config.train.sparsity_lambda}_lr-{config.train.learning_rate}" 153 | f"_bs-{config.train.batch_size}" 154 | ) 155 | if config.wandb_project: 156 | load_dotenv() 157 | wandb.init( 158 | name=run_name, 159 | project=config.wandb_project, 160 | entity=os.getenv("WANDB_ENTITY"), 161 | config=config.model_dump(mode="json"), 162 | ) 163 | 164 | timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 165 | save_dir = config.train.save_dir / f"{run_name}_{timestamp}" if config.train.save_dir else None 166 | 167 | samples = 0 168 | # Training loop 169 | for epoch in tqdm( 170 | range(1, config.train.n_epochs + 1), total=config.train.n_epochs, desc="Epochs" 171 | ): 172 | for i, (images, labels) in enumerate(train_loader): 173 | images, labels = images.to(device), labels.to(device) 174 | 175 | samples += images.shape[0] 176 | # Flatten the images 177 | images = images.view(images.shape[0], -1) 178 | 179 | model(images) # Consider passing orig outputs as input, so that you can interpolate. 
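# The forward pass above runs the frozen original MLP; the hooks registered in
# get_models fill the `activations` dict with each layer's output. The SAE-augmented
# model below returns its own per-layer activations (mod_acts), the dictionary codes
# (cs), and the sparsifier reconstructions (sparsifiers_outs), which the losses further
# down compare against those hooked original activations.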
180 | mod_acts, cs, sparsifiers_outs = model_mod(images) 181 | # Get final item of dict 182 | mod_out = mod_acts[list(mod_acts.keys())[-1]] 183 | 184 | # Get loss that compares each item of the hooked activations with the corresponding item 185 | # of the outputs 186 | loss: Float[Tensor, ""] = torch.zeros(1, requires_grad=True, device=device) 187 | assert len(activations) == len( 188 | mod_acts 189 | ), "Number of activations and modified activations must be the same" 190 | 191 | # Create dictionaries for the different losses so we can log in wand later 192 | sp_orig_losses: dict[str, torch.Tensor] = {} 193 | new_orig_losses: dict[str, torch.Tensor] = {} 194 | sp_new_losses: dict[str, torch.Tensor] = {} 195 | sparsity_losses: dict[str, torch.Tensor] = {} 196 | zeros_counts: dict[str, torch.Tensor] = {} 197 | zeros_fracs: dict[str, torch.Tensor] = {} 198 | 199 | for layer in range(len(activations)): 200 | if layer < len(activations) - 1: 201 | # sae-orig reconstruction loss 202 | sp_orig_losses[str(layer)] = criterion( 203 | sparsifiers_outs[str(layer)], activations[str(layer)] 204 | ) 205 | loss = loss + sp_orig_losses[str(layer)] 206 | 207 | # new-orig reconstruction loss 208 | new_orig_losses[str(layer)] = criterion( 209 | mod_acts[str(layer)], activations[str(layer)] 210 | ) 211 | loss = loss + new_orig_losses[str(layer)] 212 | 213 | # sae-new reconstruction loss 214 | if layer < len(activations) - 1: 215 | if config.train.type_of_sparsifier == "sae": 216 | sp_new_losses[str(layer)] = criterion( 217 | sparsifiers_outs[str(layer)], mod_acts[str(layer)] 218 | ) 219 | elif config.train.type_of_sparsifier == "codebook": 220 | # Auxiliary recon loss described in Tamkin et al. (2023) p. 3 221 | sp_new_losses[str(layer)] = criterion( 222 | sparsifiers_outs[str(layer)], mod_acts[str(layer)].detach() 223 | ) 224 | loss = loss + ( 225 | sp_new_losses[str(layer)] * config.train.sparsifier_in_out_recon_loss_scale 226 | ) 227 | 228 | # Add L_p norm loss 229 | if config.train.type_of_sparsifier == "sae": 230 | for layer in range(len(cs)): 231 | sparsity_losses[str(layer)] = torch.norm(cs[str(layer)], p=0.6, dim=1).mean() 232 | loss = loss + config.train.sparsity_lambda * sparsity_losses[str(layer)] 233 | 234 | # Calculate counts and fractions of zero entries in the saes per batch 235 | for layer in range(len(cs)): 236 | zeros_counts[str(layer)] = torch.sum(cs[str(layer)] == 0) 237 | zeros_fracs[str(layer)] = ( 238 | torch.sum(cs[str(layer)] == 0) / cs[str(layer)].numel() 239 | ) 240 | 241 | # Calculate accuracy 242 | _, argmax = torch.max(mod_out, 1) 243 | 244 | accuracy = (labels == argmax.squeeze()).float().mean() 245 | 246 | optimizer.zero_grad() 247 | loss.backward() 248 | optimizer.step() 249 | 250 | if i % 10 == 0: 251 | logger.info( 252 | "Epoch [%d/%d], Step [%d/%d], Loss: %f, Accuracy: %f", 253 | epoch, 254 | config.train.n_epochs, 255 | i + 1, 256 | len(train_loader), 257 | loss.item(), 258 | accuracy, 259 | ) 260 | 261 | if config.wandb_project: 262 | wandb.log({"train/loss": loss.item()}, step=samples) 263 | wandb.log({"train/accuracy": accuracy}, step=samples) 264 | for k, v in sp_orig_losses.items(): 265 | wandb.log({f"train/loss-sae-orig-{k}": v.item()}, step=samples) 266 | for k, v in new_orig_losses.items(): 267 | wandb.log({f"train/loss-new-orig-{k}": v.item()}, step=samples) 268 | for k, v in sp_new_losses.items(): 269 | wandb.log({f"train/loss-sae-new-{k}": v.item()}, step=samples) 270 | for k, v in sparsity_losses.items(): 271 | wandb.log({f"train/loss-sparsity-loss-{k}": 
v.item()}, step=samples) 272 | for k, v in zeros_counts.items(): 273 | wandb.log({f"train/zero-counts-{k}": v.item()}, step=samples) 274 | for k, v in zeros_fracs.items(): 275 | wandb.log({f"train/fraction-zeros-{k}": v.item()}, step=samples) 276 | 277 | # Validate the model 278 | model_mod.sparsifiers.eval() 279 | with torch.no_grad(): 280 | correct = 0 281 | total = 0 282 | for images, labels in test_loader: 283 | images, labels = images.to(device), labels.to(device) 284 | images = images.view(images.shape[0], -1) 285 | mod_acts, cs, sparsifiers_outs = model_mod(images) 286 | mod_out = mod_acts[list(mod_acts.keys())[-1]] 287 | _, argmax = torch.max(mod_out, 1) 288 | total += labels.size(0) 289 | correct += (labels == argmax.squeeze()).sum().item() 290 | 291 | accuracy = correct / total 292 | logger.info("Accuracy of the network on the 10000 test images: %f %%", 100 * accuracy) 293 | 294 | if config.wandb_project: 295 | wandb.log({"valid/accuracy": accuracy}, step=samples) 296 | model_mod.sparsifiers.train() 297 | 298 | if save_dir: 299 | save_module( 300 | config_dict=config.model_dump(mode="json"), 301 | save_dir=save_dir, 302 | module=model_mod.sparsifiers, 303 | model_filename=f"epoch_{config.train.n_epochs}.pt", 304 | ) 305 | if config.wandb_project: 306 | wandb.finish() 307 | 308 | 309 | def main(config_path_str: str) -> None: 310 | config_path = Path(config_path_str) 311 | config = load_config(config_path, config_model=Config) 312 | set_seed(config.seed) 313 | train(config) 314 | 315 | 316 | if __name__ == "__main__": 317 | fire.Fire(main) 318 | -------------------------------------------------------------------------------- /e2e_sae/scripts/train_mnist/mnist.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | model: 3 | hidden_sizes: 4 | - 100 5 | - 100 6 | train: 7 | learning_rate: 0.001 8 | batch_size: 1024 9 | n_epochs: 30 10 | wandb_project: mlp-mnist-sae 11 | -------------------------------------------------------------------------------- /e2e_sae/scripts/train_mnist/run_train_mnist.py: -------------------------------------------------------------------------------- 1 | """Train a model on MNIST. 2 | 3 | This script takes ~40 seconds to run for 3 layers and 15 epochs on a CPU. 
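The single command-line argument is the path to a YAML config (parsed with fire and
validated by the pydantic `Config` model below); `mnist.yaml` in this directory is an
example.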
4 | 5 | Usage: 6 | python run_train_mnist.py 7 | """ 8 | 9 | import os 10 | from datetime import datetime 11 | from pathlib import Path 12 | 13 | import fire 14 | import torch 15 | import wandb 16 | from dotenv import load_dotenv 17 | from pydantic import BaseModel, ConfigDict, NonNegativeInt, PositiveFloat, PositiveInt 18 | from torch import nn 19 | from torch.utils.data import DataLoader 20 | from torchvision import datasets, transforms 21 | from tqdm import tqdm 22 | 23 | from e2e_sae.log import logger 24 | from e2e_sae.models.mlp import MLP 25 | from e2e_sae.settings import REPO_ROOT 26 | from e2e_sae.types import RootPath 27 | from e2e_sae.utils import load_config, save_module, set_seed 28 | 29 | 30 | class ModelConfig(BaseModel): 31 | model_config = ConfigDict(extra="forbid", frozen=True) 32 | hidden_sizes: list[PositiveInt] | None 33 | 34 | 35 | class TrainConfig(BaseModel): 36 | model_config = ConfigDict(extra="forbid", frozen=True) 37 | learning_rate: PositiveFloat 38 | batch_size: PositiveInt 39 | n_epochs: PositiveInt 40 | save_dir: RootPath | None = Path(__file__).parent / "out" 41 | 42 | 43 | class Config(BaseModel): 44 | model_config = ConfigDict(extra="forbid", frozen=True) 45 | seed: NonNegativeInt 46 | model: ModelConfig 47 | train: TrainConfig 48 | wandb_project: str | None # If None, don't log to Weights & Biases 49 | 50 | 51 | def train(config: Config) -> None: 52 | """Train the MLP on MNIST. 53 | 54 | If config.wandb is not None, log the results to Weights & Biases. 55 | """ 56 | device = "cuda" if torch.cuda.is_available() else "cpu" 57 | logger.info("Using device: %s", device) 58 | 59 | # Load the MNIST dataset 60 | data_path = str(REPO_ROOT / ".data") 61 | transform = transforms.ToTensor() 62 | train_data = datasets.MNIST(root=data_path, train=True, download=True, transform=transform) 63 | train_loader = DataLoader(train_data, batch_size=config.train.batch_size, shuffle=True) 64 | test_data = datasets.MNIST(root=data_path, train=False, download=True, transform=transform) 65 | test_loader = DataLoader(test_data, batch_size=config.train.batch_size, shuffle=False) 66 | valid_data = datasets.MNIST(root=data_path, train=False, download=True, transform=transform) 67 | valid_loader = DataLoader(valid_data, batch_size=config.train.batch_size, shuffle=False) 68 | 69 | # Initialize the MLP model 70 | model = MLP(config.model.hidden_sizes, input_size=784, output_size=10) 71 | model = model.to(device) 72 | model.train() 73 | 74 | # Define the loss and optimizer 75 | criterion = nn.CrossEntropyLoss() 76 | optimizer = torch.optim.Adam(model.parameters(), lr=config.train.learning_rate) 77 | 78 | hidden_repr = ( 79 | "-".join(str(x) for x in config.model.hidden_sizes) if config.model.hidden_sizes else None 80 | ) 81 | 82 | run_name = ( 83 | f"orig-train-lr-{config.train.learning_rate}_bs-{config.train.batch_size}" 84 | f"_hidden-{hidden_repr}" 85 | ) 86 | if config.wandb_project: 87 | load_dotenv() 88 | wandb.init( 89 | name=run_name, 90 | project=config.wandb_project, 91 | entity=os.getenv("WANDB_ENTITY"), 92 | config=config.model_dump(mode="json"), 93 | ) 94 | 95 | timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 96 | save_dir = config.train.save_dir / f"{run_name}_{timestamp}" if config.train.save_dir else None 97 | 98 | samples = 0 99 | # Training loop 100 | for epoch in tqdm( 101 | range(1, config.train.n_epochs + 1), total=config.train.n_epochs, desc="Epochs" 102 | ): 103 | for i, (images, labels) in enumerate(train_loader): 104 | images, labels = images.to(device), 
labels.to(device) 105 | 106 | samples += images.shape[0] 107 | # Flatten the images 108 | images = images.view(images.shape[0], -1) 109 | 110 | outputs = model(images) 111 | loss = criterion(outputs, labels) 112 | 113 | # Calculate accuracy 114 | _, argmax = torch.max(outputs, 1) 115 | accuracy = (labels == argmax.squeeze()).float().mean() 116 | 117 | optimizer.zero_grad() 118 | loss.backward() 119 | optimizer.step() 120 | 121 | if (i + 1) % config.train.n_epochs // 10 == 0: 122 | logger.info( 123 | "Epoch [%d/%d], Step [%d/%d], Loss: %f, Accuracy: %f", 124 | epoch, 125 | config.train.n_epochs, 126 | i + 1, 127 | len(train_loader), 128 | loss.item(), 129 | accuracy.item(), 130 | ) 131 | 132 | if config.wandb_project: 133 | wandb.log({"train/loss": loss.item()}, step=samples) 134 | 135 | # Validate the model 136 | model.eval() 137 | with torch.no_grad(): 138 | correct = 0 139 | total = 0 140 | for images, labels in valid_loader: 141 | images, labels = images.to(device), labels.to(device) 142 | images = images.view(images.shape[0], -1) 143 | outputs = model(images) 144 | _, argmax = torch.max(outputs, 1) 145 | total += labels.size(0) 146 | correct += (labels == argmax.squeeze()).sum().item() 147 | 148 | accuracy = correct / total 149 | logger.info("Accuracy of the network on the 10000 test images: %f %%", 100 * accuracy) 150 | 151 | if config.wandb_project: 152 | wandb.log({"valid/accuracy": accuracy}, step=samples) 153 | model.train() 154 | 155 | # Test the model 156 | model.eval() 157 | with torch.no_grad(): 158 | correct = 0 159 | total = 0 160 | for images, labels in test_loader: 161 | images, labels = images.to(device), labels.to(device) 162 | images = images.view(images.shape[0], -1) 163 | outputs = model(images) 164 | _, argmax = torch.max(outputs, 1) 165 | total += labels.size(0) 166 | correct += (labels == argmax.squeeze()).sum().item() 167 | 168 | accuracy = correct / total 169 | logger.info("Accuracy of the network on the 10000 test images: %f %%", 100 * accuracy) 170 | 171 | if config.wandb_project: 172 | wandb.log({"test/accuracy": accuracy}, step=samples) 173 | model.train() 174 | 175 | if save_dir: 176 | save_module( 177 | config_dict=config.model_dump(mode="json"), 178 | save_dir=save_dir, 179 | module=model, 180 | model_filename=f"epoch_{config.train.n_epochs}.pt", 181 | ) 182 | if config.wandb_project: 183 | wandb.finish() 184 | 185 | 186 | def main(config_path_str: str) -> None: 187 | config_path = Path(config_path_str) 188 | config = load_config(config_path, config_model=Config) 189 | set_seed(config.seed) 190 | train(config) 191 | 192 | 193 | if __name__ == "__main__": 194 | fire.Fire(main) 195 | -------------------------------------------------------------------------------- /e2e_sae/scripts/train_tlens/run_train_tlens.py: -------------------------------------------------------------------------------- 1 | """Train a custom transformerlens model. 
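The config is a YAML file validated by the pydantic `Config` model below; `tiny_gpt2.yaml`
in this directory is a small working example. Training data comes from transformer_lens's
`evals.make_pile_data_loader`, so no dataset needs to be configured.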
2 | 3 | Usage: 4 | python run_train_tlens.py 5 | """ 6 | 7 | import os 8 | from datetime import datetime 9 | from pathlib import Path 10 | from typing import Self 11 | 12 | import fire 13 | import torch 14 | import wandb 15 | from dotenv import load_dotenv 16 | from jaxtyping import Int 17 | from pydantic import ( 18 | BaseModel, 19 | ConfigDict, 20 | NonNegativeInt, 21 | PositiveFloat, 22 | PositiveInt, 23 | model_validator, 24 | ) 25 | from torch import Tensor 26 | from tqdm import tqdm 27 | from transformer_lens import HookedTransformer, HookedTransformerConfig, evals 28 | 29 | from e2e_sae.types import RootPath, TorchDtype 30 | from e2e_sae.utils import load_config, save_module, set_seed 31 | 32 | 33 | class HookedTransformerPreConfig(BaseModel): 34 | """Pydantic model whose arguments will be passed to a HookedTransformerConfig.""" 35 | 36 | model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True, frozen=True) 37 | d_model: PositiveInt 38 | n_layers: PositiveInt 39 | n_ctx: PositiveInt 40 | d_head: PositiveInt 41 | d_vocab: PositiveInt 42 | act_fn: str 43 | dtype: TorchDtype | None 44 | tokenizer_name: str 45 | 46 | 47 | class TrainConfig(BaseModel): 48 | model_config = ConfigDict(extra="forbid", frozen=True) 49 | n_epochs: PositiveInt 50 | batch_size: PositiveInt 51 | effective_batch_size: PositiveInt | None = None 52 | lr: PositiveFloat 53 | warmup_samples: NonNegativeInt = 0 54 | save_dir: RootPath | None = Path(__file__).parent / "out" 55 | save_every_n_epochs: PositiveInt | None 56 | 57 | @model_validator(mode="after") 58 | def check_effective_batch_size(self) -> Self: 59 | if self.effective_batch_size is not None: 60 | assert ( 61 | self.effective_batch_size % self.batch_size == 0 62 | ), "effective_batch_size must be a multiple of batch_size." 
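# train() uses effective_batch_size // batch_size as the number of gradient-accumulation
# steps, so a non-multiple would make the true effective batch size differ from the one
# requested here.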
63 | return self 64 | 65 | 66 | class Config(BaseModel): 67 | model_config = ConfigDict(extra="forbid", frozen=True) 68 | seed: int = 0 69 | name: str 70 | tlens_config: HookedTransformerPreConfig 71 | train: TrainConfig 72 | wandb_project: str | None # If None, don't log to Weights & Biases 73 | 74 | 75 | def train(config: Config, model: HookedTransformer, device: torch.device) -> None: 76 | model.train() 77 | optimizer = torch.optim.Adam(model.parameters(), lr=config.train.lr) 78 | 79 | effective_batch_size = config.train.effective_batch_size or config.train.batch_size 80 | n_gradient_accumulation_steps = effective_batch_size // config.train.batch_size 81 | 82 | scheduler = None 83 | if config.train.warmup_samples > 0: 84 | scheduler = torch.optim.lr_scheduler.LambdaLR( 85 | optimizer, 86 | lr_lambda=lambda step: min( 87 | 1.0, (step + 1) / (config.train.warmup_samples // effective_batch_size) 88 | ), 89 | ) 90 | 91 | train_loader = evals.make_pile_data_loader(model.tokenizer, batch_size=config.train.batch_size) 92 | 93 | # Initialize wandb 94 | run_name = f"{config.name}_lr-{config.train.lr}_bs-{config.train.batch_size}" 95 | if config.wandb_project: 96 | load_dotenv() 97 | wandb.init( 98 | name=run_name, 99 | project=config.wandb_project, 100 | entity=os.getenv("WANDB_ENTITY"), 101 | config=config.model_dump(mode="json"), 102 | ) 103 | timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 104 | save_dir = config.train.save_dir / f"{run_name}_{timestamp}" if config.train.save_dir else None 105 | 106 | samples = 0 107 | grad_updates = 0 108 | for epoch in tqdm( 109 | range(1, config.train.n_epochs + 1), total=config.train.n_epochs, desc="Epochs" 110 | ): 111 | for step, batch in tqdm(enumerate(train_loader), total=len(train_loader), desc="Steps"): 112 | tokens: Int[Tensor, "batch pos"] = batch["tokens"].to(device=device) 113 | loss = model(tokens, return_type="loss") 114 | 115 | loss = loss / n_gradient_accumulation_steps 116 | loss.backward() 117 | 118 | if (step + 1) % n_gradient_accumulation_steps == 0: 119 | optimizer.step() 120 | optimizer.zero_grad() 121 | grad_updates += 1 122 | 123 | if config.train.warmup_samples > 0: 124 | assert scheduler is not None 125 | scheduler.step() 126 | 127 | samples += tokens.shape[0] 128 | if step == 0 or step % 20 == 0: 129 | tqdm.write( 130 | f"Epoch {epoch} Samples {samples} Step {step} GradUpdates {grad_updates} " 131 | f"Loss {loss.item()}" 132 | ) 133 | 134 | if config.wandb_project: 135 | wandb.log( 136 | { 137 | "train_loss": loss.item(), 138 | "epoch": epoch, 139 | "grad_updates": grad_updates, 140 | "lr": optimizer.param_groups[0]["lr"], 141 | }, 142 | step=samples, 143 | ) 144 | if save_dir and ( 145 | (config.train.save_every_n_epochs and epoch % config.train.save_every_n_epochs == 0) 146 | or epoch == config.train.n_epochs # Save the last epoch 147 | ): 148 | save_module( 149 | config_dict=config.model_dump(mode="json"), 150 | save_dir=save_dir, 151 | module=model, 152 | model_filename=f"epoch_{epoch}.pt", 153 | ) 154 | # TODO: Add evaluation loop 155 | if config.wandb_project: 156 | wandb.finish() 157 | 158 | 159 | def main(config_path_str: str) -> None: 160 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 161 | config = load_config(config_path_str, config_model=Config) 162 | set_seed(config.seed) 163 | 164 | hooked_transformer_config = HookedTransformerConfig(**config.tlens_config.model_dump()) 165 | model = HookedTransformer(hooked_transformer_config) 166 | model.to(device) 167 | 168 | train(config, model, 
device=device) 169 | 170 | 171 | if __name__ == "__main__": 172 | fire.Fire(main) 173 | -------------------------------------------------------------------------------- /e2e_sae/scripts/train_tlens/sample_models/tiny-gpt2_lr-0.001_bs-16_2024-04-21_14-01-14/epoch_1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ApolloResearch/e2e_sae/83b4add4d2652a4e5c6775a0e2b752897c871a87/e2e_sae/scripts/train_tlens/sample_models/tiny-gpt2_lr-0.001_bs-16_2024-04-21_14-01-14/epoch_1.pt -------------------------------------------------------------------------------- /e2e_sae/scripts/train_tlens/sample_models/tiny-gpt2_lr-0.001_bs-16_2024-04-21_14-01-14/final_config.yaml: -------------------------------------------------------------------------------- 1 | name: tiny-gpt2 2 | seed: 0 3 | tlens_config: 4 | act_fn: gelu 5 | d_head: 4 6 | d_model: 4 7 | d_vocab: 50257 8 | dtype: float32 9 | n_ctx: 1024 10 | n_layers: 2 11 | tokenizer_name: gpt2 12 | train: 13 | batch_size: 16 14 | effective_batch_size: null 15 | lr: 0.001 16 | n_epochs: 3 17 | save_dir: /path/to/save_dir 18 | save_every_n_epochs: 1 19 | warmup_samples: 0 20 | wandb_project: e2e_sae-custom 21 | -------------------------------------------------------------------------------- /e2e_sae/scripts/train_tlens/tiny_gpt2.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | name: tiny-gpt2 3 | tlens_config: 4 | d_model: 4 5 | n_layers: 2 6 | n_ctx: 1024 7 | d_head: 4 8 | d_vocab: 50257 9 | act_fn: gelu 10 | dtype: float32 11 | tokenizer_name: gpt2 12 | train: 13 | n_epochs: 3 14 | batch_size: 16 15 | effective_batch_size: null 16 | lr: 1e-3 17 | warmup_samples: 0 18 | save_every_n_epochs: 1 19 | wandb_project: e2e_sae-custom -------------------------------------------------------------------------------- /e2e_sae/scripts/train_tlens_saes/gpt2_e2e.yaml: -------------------------------------------------------------------------------- 1 | wandb_project: gpt2-e2e_play 2 | wandb_run_name: null # If not set, will use a name based on important config values 3 | wandb_run_name_prefix: "" 4 | 5 | seed: 0 6 | tlens_model_name: gpt2-small 7 | tlens_model_path: null 8 | 9 | n_samples: 400_000 10 | save_every_n_samples: null 11 | eval_every_n_samples: 40_000 12 | eval_n_samples: 500 13 | log_every_n_grad_steps: 20 14 | collect_act_frequency_every_n_samples: 40_000 15 | act_frequency_n_tokens: 500_000 16 | batch_size: 8 17 | effective_batch_size: 16 # Number of samples before each optimizer step 18 | lr: 5e-4 19 | lr_schedule: cosine 20 | min_lr_factor: 0.1 # Minimum learning rate as a fraction of the initial learning rate 21 | warmup_samples: 20_000 # Linear warmup over this many samples 22 | max_grad_norm: 10.0 # Gradient norms get clipped to this value before optimizer steps 23 | 24 | loss: 25 | # Note that "original acts" below refers to the activations in a model without SAEs 26 | sparsity: 27 | p_norm: 1.0 # p value in Lp norm 28 | coeff: 1.5 # Multiplies the Lp norm in the loss (sparsity coefficient) 29 | in_to_orig: null # Used for e2e+future recon. MSE between the input to the SAE and original acts 30 | out_to_orig: null # Not commonly used. MSE between the output of the SAE and original acts 31 | out_to_in: 32 | # Multiplies the MSE between the output and input of the SAE. 
Setting to 0 lets us track this 33 | # loss during training without optimizing it 34 | coeff: 0.0 35 | logits_kl: 36 | coeff: 1.0 # Multiplies the KL divergence between the logits of the SAE model and original model 37 | train_data: 38 | # See https://huggingface.co/apollo-research for other pre-tokenized datasets 39 | dataset_name: apollo-research/Skylion007-openwebtext-tokenizer-gpt2 40 | is_tokenized: true 41 | tokenizer_name: gpt2 42 | streaming: true 43 | split: train 44 | n_ctx: 1024 45 | eval_data: # By default this will use a different seed to the training data, but can be set with `seed` 46 | dataset_name: apollo-research/Skylion007-openwebtext-tokenizer-gpt2 47 | is_tokenized: true 48 | tokenizer_name: gpt2 49 | streaming: true 50 | split: train 51 | n_ctx: 1024 52 | saes: 53 | retrain_saes: false # Determines whether to continue training the SAEs in pretrained_sae_paths 54 | pretrained_sae_paths: null # Path or paths to pretrained SAEs 55 | sae_positions: # Position or positions to place SAEs in the model 56 | - blocks.6.hook_resid_pre 57 | dict_size_to_input_ratio: 60.0 # Size of the dictionary relative to the activations at the SAE positions -------------------------------------------------------------------------------- /e2e_sae/scripts/train_tlens_saes/gpt2_e2e_recon.yaml: -------------------------------------------------------------------------------- 1 | wandb_project: gpt2-e2e 2 | wandb_run_name: null # If not set, will use a name based on important config values 3 | wandb_run_name_prefix: "" 4 | 5 | seed: 0 6 | tlens_model_name: gpt2-small 7 | tlens_model_path: null 8 | 9 | n_samples: 400_000 10 | save_every_n_samples: null 11 | eval_every_n_samples: 40_000 12 | eval_n_samples: 500 13 | log_every_n_grad_steps: 20 14 | collect_act_frequency_every_n_samples: 40_000 15 | act_frequency_n_tokens: 500_000 16 | batch_size: 8 17 | effective_batch_size: 16 # Number of samples before each optimizer step 18 | lr: 5e-4 19 | lr_schedule: cosine 20 | min_lr_factor: 0.1 # Minimum learning rate as a fraction of the initial learning rate 21 | warmup_samples: 20_000 # Linear warmup over this many samples 22 | max_grad_norm: 10.0 # Gradient norms get clipped to this value before optimizer steps 23 | 24 | loss: 25 | # Note that "original acts" below refers to the activations in a model without SAEs 26 | sparsity: 27 | p_norm: 1.0 # p value in Lp norm 28 | coeff: 1.5 # Multiplies the Lp norm in the loss (sparsity coefficient) 29 | in_to_orig: 30 | # Used for e2e+recon. Positions in which to calculate the MSE between the activations of the 31 | # model with SAEs and the original model 32 | hook_positions: 33 | - blocks.7.hook_resid_pre 34 | - blocks.8.hook_resid_pre 35 | - blocks.9.hook_resid_pre 36 | - blocks.10.hook_resid_pre 37 | - blocks.11.hook_resid_pre 38 | total_coeff: 2.5 # Coefficient for the above MSE loss. Is split evenly between all hook_positions 39 | out_to_orig: null # Not commonly used. MSE between the output of the SAE and original acts 40 | out_to_in: 41 | # Multiplies the MSE between the output and input of the SAE. 
Setting to 0 lets us track this 42 | # loss during training without optimizing it 43 | coeff: 0.0 44 | logits_kl: 45 | coeff: 0.5 # Multiplies the KL divergence between the logits of the SAE model and original model 46 | train_data: 47 | # See https://huggingface.co/apollo-research for other pre-tokenized datasets 48 | dataset_name: apollo-research/Skylion007-openwebtext-tokenizer-gpt2 49 | is_tokenized: true 50 | tokenizer_name: gpt2 51 | streaming: true 52 | split: train 53 | n_ctx: 1024 54 | eval_data: # By default this will use a different seed to the training data, but can be set with `seed` 55 | dataset_name: apollo-research/Skylion007-openwebtext-tokenizer-gpt2 56 | is_tokenized: true 57 | tokenizer_name: gpt2 58 | streaming: true 59 | split: train 60 | n_ctx: 1024 61 | saes: 62 | retrain_saes: false # Determines whether to continue training the SAEs in pretrained_sae_paths 63 | pretrained_sae_paths: null # Path or paths to pretrained SAEs 64 | sae_positions: # Position or positions to place SAEs in the model 65 | - blocks.6.hook_resid_pre 66 | dict_size_to_input_ratio: 60.0 # Size of the dictionary relative to the activations at the SAE positions -------------------------------------------------------------------------------- /e2e_sae/scripts/train_tlens_saes/gpt2_e2e_recon_sweep.yaml: -------------------------------------------------------------------------------- 1 | program: e2e_sae/scripts/train_tlens_saes/run_train_tlens_saes.py 2 | name: gpt2-e2e 3 | method: grid 4 | metric: 5 | name: val_loss 6 | goal: minimize 7 | parameters: 8 | seed: 9 | values: [0] 10 | n_samples: 11 | values: [400_000] 12 | lr: 13 | values: [5e-4] 14 | loss: 15 | parameters: 16 | sparsity: 17 | parameters: 18 | coeff: 19 | values: [0.8, 2.5, 5] 20 | in_to_orig: 21 | parameters: 22 | hook_positions: 23 | values: [[blocks.11.hook_resid_pre]] 24 | total_coeff: 25 | values: [0.05, 0.1] 26 | out_to_orig: 27 | values: [null] 28 | out_to_in: 29 | parameters: 30 | coeff: 31 | values: [0.0] 32 | logits_kl: 33 | parameters: 34 | coeff: 35 | values: [0.5] 36 | saes: 37 | parameters: 38 | sae_positions: 39 | values: 40 | - blocks.10.hook_resid_pre 41 | dict_size_to_input_ratio: 42 | values: [60.0] 43 | 44 | train_data: 45 | parameters: 46 | dataset_name: 47 | values: [apollo-research/Skylion007-openwebtext-tokenizer-gpt2] 48 | is_tokenized: 49 | values: [true] 50 | tokenizer_name: 51 | values: [gpt2] 52 | streaming: 53 | values: [true] 54 | split: 55 | values: [train] 56 | n_ctx: 57 | values: [1024] 58 | eval_data: 59 | parameters: 60 | dataset_name: 61 | values: [apollo-research/Skylion007-openwebtext-tokenizer-gpt2] 62 | is_tokenized: 63 | values: [true] 64 | tokenizer_name: 65 | values: [gpt2] 66 | streaming: 67 | values: [true] 68 | split: 69 | values: [train] 70 | n_ctx: 71 | values: [1024] 72 | command: 73 | - ${env} 74 | - ${interpreter} 75 | - ${program} 76 | - e2e_sae/scripts/train_tlens_saes/gpt2_e2e_recon.yaml -------------------------------------------------------------------------------- /e2e_sae/scripts/train_tlens_saes/gpt2_e2e_sweep.yaml: -------------------------------------------------------------------------------- 1 | program: e2e_sae/scripts/train_tlens_saes/run_train_tlens_saes.py 2 | name: gpt2-e2e 3 | method: grid 4 | metric: 5 | name: val_loss 6 | goal: minimize 7 | parameters: 8 | seed: 9 | values: [0] 10 | n_samples: 11 | values: [400_000] 12 | lr: 13 | values: [5e-4] 14 | loss: 15 | parameters: 16 | sparsity: 17 | parameters: 18 | coeff: 19 | values: [0.08, 0.2, 0.8, 2] 20 | in_to_orig: 
21 | values: [null] 22 | out_to_orig: 23 | values: [null] 24 | out_to_in: 25 | parameters: 26 | coeff: 27 | values: [0.0] 28 | logits_kl: 29 | parameters: 30 | coeff: 31 | values: [1.0] 32 | saes: 33 | parameters: 34 | sae_positions: 35 | values: 36 | - blocks.2.hook_resid_pre 37 | dict_size_to_input_ratio: 38 | values: [60.0] 39 | 40 | train_data: 41 | parameters: 42 | dataset_name: 43 | values: [apollo-research/Skylion007-openwebtext-tokenizer-gpt2] 44 | is_tokenized: 45 | values: [true] 46 | tokenizer_name: 47 | values: [gpt2] 48 | streaming: 49 | values: [true] 50 | split: 51 | values: [train] 52 | n_ctx: 53 | values: [1024] 54 | eval_data: 55 | parameters: 56 | dataset_name: 57 | values: [apollo-research/Skylion007-openwebtext-tokenizer-gpt2] 58 | is_tokenized: 59 | values: [true] 60 | tokenizer_name: 61 | values: [gpt2] 62 | streaming: 63 | values: [true] 64 | split: 65 | values: [train] 66 | n_ctx: 67 | values: [1024] 68 | command: 69 | - ${env} 70 | - ${interpreter} 71 | - ${program} 72 | - e2e_sae/scripts/train_tlens_saes/gpt2_e2e.yaml -------------------------------------------------------------------------------- /e2e_sae/scripts/train_tlens_saes/gpt2_local.yaml: -------------------------------------------------------------------------------- 1 | wandb_project: gpt2-layerwise_play 2 | wandb_run_name: null 3 | wandb_run_name_prefix: "" 4 | 5 | seed: 0 6 | tlens_model_name: gpt2-small 7 | tlens_model_path: null 8 | 9 | n_samples: 400_000 10 | save_every_n_samples: null 11 | eval_every_n_samples: 40_000 12 | eval_n_samples: 500 13 | log_every_n_grad_steps: 20 14 | collect_act_frequency_every_n_samples: 40_000 15 | act_frequency_n_tokens: 500_000 16 | batch_size: 8 17 | effective_batch_size: 16 # Number of samples before each optimizer step 18 | lr: 5e-4 19 | lr_schedule: cosine 20 | min_lr_factor: 0.1 # Minimum learning rate as a fraction of the initial learning rate 21 | warmup_samples: 20_000 # Linear warmup over this many samples 22 | max_grad_norm: 10.0 # Gradient norms get clipped to this value before optimizer steps 23 | 24 | loss: 25 | # Note that "original acts" below refers to the activations in a model without SAEs 26 | sparsity: 27 | p_norm: 1.0 # p value in Lp norm 28 | coeff: 6.0 # Multiplies the Lp norm in the loss (sparsity coefficient) 29 | in_to_orig: null # Used for e2e+future recon. MSE between the input to the SAE and original acts 30 | out_to_orig: null # Not commonly used. 
MSE between the output of the SAE and original acts 31 | out_to_in: 32 | coeff: 1.0 # Multiplies the MSE between the output and input of the SAE 33 | logits_kl: null # Multiplies the KL divergence between the logits of the SAE model and original model 34 | train_data: 35 | # See https://huggingface.co/apollo-research for other pre-tokenized datasets 36 | dataset_name: apollo-research/Skylion007-openwebtext-tokenizer-gpt2 37 | is_tokenized: true 38 | tokenizer_name: gpt2 39 | streaming: true 40 | split: train 41 | n_ctx: 1024 42 | eval_data: # By default this will use a different seed to the training data, but can be set with `seed` 43 | dataset_name: apollo-research/Skylion007-openwebtext-tokenizer-gpt2 44 | is_tokenized: true 45 | tokenizer_name: gpt2 46 | streaming: true 47 | split: train 48 | n_ctx: 1024 49 | saes: 50 | retrain_saes: false # Determines whether to continue training the SAEs in pretrained_sae_paths 51 | pretrained_sae_paths: null # Path or paths to pretrained SAEs 52 | sae_positions: # Position or positions to place SAEs in the model 53 | - blocks.6.hook_resid_pre 54 | dict_size_to_input_ratio: 60.0 # Size of the dictionary relative to the activations at the SAE positions -------------------------------------------------------------------------------- /e2e_sae/scripts/train_tlens_saes/gpt2_local_sweep.yaml: -------------------------------------------------------------------------------- 1 | program: e2e_sae/scripts/train_tlens_saes/run_train_tlens_saes.py 2 | name: gpt2_local 3 | method: grid 4 | metric: 5 | name: val_loss 6 | goal: minimize 7 | parameters: 8 | seed: 9 | values: [0, 1] 10 | n_samples: 11 | values: [400_000] 12 | lr: 13 | values: [5e-4] 14 | loss: 15 | parameters: 16 | sparsity: 17 | parameters: 18 | coeff: 19 | values: [1, 3, 6, 10, 15, 20, 30] 20 | in_to_orig: 21 | values: [null] 22 | out_to_orig: 23 | values: [null] 24 | out_to_in: 25 | parameters: 26 | coeff: 27 | values: [1.0] 28 | logits_kl: 29 | values: [null] 30 | saes: 31 | parameters: 32 | sae_positions: 33 | values: 34 | - blocks.10.hook_resid_pre 35 | dict_size_to_input_ratio: 36 | values: [60.0] 37 | 38 | train_data: 39 | parameters: 40 | dataset_name: 41 | values: [apollo-research/Skylion007-openwebtext-tokenizer-gpt2] 42 | is_tokenized: 43 | values: [true] 44 | tokenizer_name: 45 | values: [gpt2] 46 | streaming: 47 | values: [true] 48 | split: 49 | values: [train] 50 | n_ctx: 51 | values: [1024] 52 | eval_data: 53 | parameters: 54 | dataset_name: 55 | values: [apollo-research/Skylion007-openwebtext-tokenizer-gpt2] 56 | is_tokenized: 57 | values: [true] 58 | tokenizer_name: 59 | values: [gpt2] 60 | streaming: 61 | values: [true] 62 | split: 63 | values: [train] 64 | n_ctx: 65 | values: [1024] 66 | command: 67 | - ${env} 68 | - ${interpreter} 69 | - ${program} 70 | - e2e_sae/scripts/train_tlens_saes/gpt2_local.yaml -------------------------------------------------------------------------------- /e2e_sae/scripts/train_tlens_saes/pythia_14m_e2e.yaml: -------------------------------------------------------------------------------- 1 | wandb_project: pythia-14m 2 | wandb_run_name: null 3 | wandb_run_name_prefix: "" 4 | 5 | seed: 0 6 | tlens_model_name: pythia-14m 7 | tlens_model_path: null 8 | 9 | n_samples: 200_000 10 | save_every_n_samples: null 11 | eval_every_n_samples: 40_000 12 | eval_n_samples: 500 13 | log_every_n_grad_steps: 20 14 | collect_act_frequency_every_n_samples: 40_000 15 | act_frequency_n_tokens: 1_000_000 16 | batch_size: 8 17 | effective_batch_size: 8 18 | lr: 5e-4 19 | 
lr_schedule: cosine 20 | min_lr_factor: 0.1 21 | warmup_samples: 20_000 22 | max_grad_norm: null 23 | 24 | loss: 25 | sparsity: 26 | p_norm: 1.0 27 | coeff: 1.0 28 | in_to_orig: null 29 | out_to_orig: null 30 | out_to_in: 31 | coeff: 0.0 32 | logits_kl: 33 | coeff: 1.0 34 | train_data: 35 | dataset_name: apollo-research/monology-pile-uncopyrighted-tokenizer-EleutherAI-gpt-neox-20b 36 | is_tokenized: true 37 | tokenizer_name: EleutherAI/gpt-neox-20B 38 | streaming: true 39 | split: train 40 | n_ctx: 2048 41 | eval_data: 42 | dataset_name: apollo-research/monology-pile-uncopyrighted-tokenizer-EleutherAI-gpt-neox-20b 43 | is_tokenized: true 44 | tokenizer_name: EleutherAI/gpt-neox-20B 45 | streaming: true 46 | split: train 47 | n_ctx: 2048 48 | saes: 49 | retrain_saes: false 50 | pretrained_sae_paths: null 51 | sae_positions: 52 | - blocks.3.hook_resid_pre 53 | dict_size_to_input_ratio: 60.0 -------------------------------------------------------------------------------- /e2e_sae/scripts/train_tlens_saes/run_sweep.py: -------------------------------------------------------------------------------- 1 | """Run the training script with different config params. 2 | 3 | TODO: Replace this with wandb sweeps. 4 | Usage: 5 | python run_sweep.py 6 | """ 7 | 8 | import yaml 9 | from fire import Fire 10 | 11 | from e2e_sae.scripts.train_tlens_saes.run_train_tlens_saes import Config 12 | from e2e_sae.scripts.train_tlens_saes.run_train_tlens_saes import main as run_train 13 | from e2e_sae.utils import replace_pydantic_model 14 | 15 | 16 | def main(config_path_str: str) -> None: 17 | """Run the training script with different sae_position values.""" 18 | sweep_name = "tinystories-1m_sparsity-coeff" 19 | values = [1.0, 0.1, 0.01, 0.001, 0.0001, 0.00001] 20 | 21 | with open(config_path_str) as f: 22 | base_config = Config(**yaml.safe_load(f)) 23 | 24 | for value in values: 25 | update_dict = { 26 | "train": {"loss": {"sparsity": {"coeff": value}}}, 27 | "wandb_project": sweep_name, 28 | } 29 | new_config = replace_pydantic_model(base_config, update_dict) 30 | print(new_config) 31 | run_train(new_config) 32 | 33 | 34 | if __name__ == "__main__": 35 | Fire(main) 36 | -------------------------------------------------------------------------------- /e2e_sae/scripts/train_tlens_saes/run_sweep_mp.py: -------------------------------------------------------------------------------- 1 | """Run the training script with different config params over multiple GPUs. 2 | 3 | Usage: 4 | python run_sweep_mp.py 5 | """ 6 | 7 | import subprocess 8 | from tempfile import NamedTemporaryFile 9 | 10 | import yaml 11 | from fire import Fire 12 | 13 | from e2e_sae.scripts.train_tlens_saes.run_train_tlens_saes import Config 14 | from e2e_sae.settings import REPO_ROOT 15 | from e2e_sae.utils import replace_pydantic_model 16 | 17 | SCRIPT_PATH = f"{REPO_ROOT}/e2e_sae/scripts/train_tlens_saes/run_train_tlens_saes.py" 18 | 19 | 20 | def main(config_path_str: str) -> None: 21 | """Run the training script with different sae_position values. 22 | 23 | NOTE: You must specify the GPU indices to use in the `gpu_idxs` list. 
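    For example, with the hard-coded settings below of values = [0.01, 0.001, 0.0001, 0.00001]
    and gpu_idxs = [0, 1, 2, 3], four tmux sessions are created and each one runs the training
    script for a single sparsity coefficient with CUDA_VISIBLE_DEVICES pinned to its GPU.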
24 | """ 25 | sweep_name = "tinystories-1m_sparsity-coeff" 26 | values = [0.01, 0.001, 0.0001, 0.00001] 27 | gpu_idxs = [0, 1, 2, 3] 28 | 29 | assert len(values) == len( 30 | gpu_idxs 31 | ), "Currently only supports having the same number of values and gpu_idxs" 32 | 33 | with open(config_path_str) as f: 34 | base_config = Config(**yaml.safe_load(f)) 35 | 36 | for idx, value in zip(gpu_idxs, values, strict=True): 37 | update_dict = { 38 | "train": {"loss": {"sparsity": {"coeff": value}}}, 39 | "wandb_project": sweep_name, 40 | } 41 | new_config = replace_pydantic_model(base_config, update_dict) 42 | # Write the config to a temporary file and then call a subprocess to run the training script 43 | print(new_config) 44 | with NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: 45 | yaml.dump(new_config.model_dump(mode="json"), f) 46 | config_path = f.name 47 | 48 | session_exists = subprocess.run(f"tmux has-session -t {idx}".split(), capture_output=True) 49 | if session_exists.returncode == 0: 50 | # Session exists, kill it 51 | subprocess.run(f"tmux kill-session -t {idx}".split()) 52 | 53 | # Create a new tmux session 54 | subprocess.run(f"tmux new-session -d -s {idx}".split()) 55 | 56 | train_command = f"CUDA_VISIBLE_DEVICES={idx} python {SCRIPT_PATH} {config_path}" 57 | tmux_send_keys_cuda_command = f"tmux send-keys -t {idx} '{train_command}' Enter" 58 | subprocess.run(tmux_send_keys_cuda_command, shell=True) 59 | 60 | 61 | if __name__ == "__main__": 62 | Fire(main) 63 | -------------------------------------------------------------------------------- /e2e_sae/scripts/train_tlens_saes/run_wandb_sweep.py: -------------------------------------------------------------------------------- 1 | """Run wandb sweep agents in parallel using tmux sessions. 2 | 3 | Usage: 4 | The script can be invoked from the command line, providing the wandb agent ID and optionally, 5 | GPU indices separated by a comma. If no GPU indices are provided, the script will automatically 6 | use all available GPUs. 7 | 8 | Example command to run the script specifying GPUs: 9 | python run_wandb_sweep.py your_agent_id 0,1,4 10 | 11 | Example command to run the script using all available GPUs: 12 | python run_wandb_sweep.py your_agent_id 13 | 14 | """ 15 | import subprocess 16 | 17 | import torch 18 | from fire import Fire 19 | 20 | 21 | def main(agent_id: str, gpu_idxs: tuple[int, ...] | int | None = None) -> None: 22 | """Run the training script with specified GPU indices. 23 | 24 | Args: 25 | agent_id: The wandb agent ID. 26 | gpu_idxs: The GPU indices to use for training. If None, all available GPUs will be used. 
27 | """ 28 | if isinstance(gpu_idxs, int): 29 | gpu_idxs = (gpu_idxs,) 30 | elif gpu_idxs is None: 31 | gpu_idxs = tuple(range(torch.cuda.device_count())) 32 | 33 | assert isinstance(gpu_idxs, tuple), "gpu_idxs must be a tuple of integers" 34 | 35 | print(f"Running wandb agent {agent_id} on GPUs {gpu_idxs}") 36 | for idx in gpu_idxs: 37 | session_exists = subprocess.run(f"tmux has-session -t {idx}".split(), capture_output=True) 38 | if session_exists.returncode == 0: 39 | # Session exists, kill it 40 | subprocess.run(f"tmux kill-session -t {idx}".split()) 41 | 42 | # Create a new tmux session 43 | subprocess.run(f"tmux new-session -d -s {idx}".split()) 44 | 45 | train_command = f"CUDA_VISIBLE_DEVICES={idx} wandb agent {agent_id}" 46 | tmux_send_keys_cuda_command = f"tmux send-keys -t {idx} '{train_command}' Enter" 47 | subprocess.run(tmux_send_keys_cuda_command, shell=True) 48 | 49 | 50 | if __name__ == "__main__": 51 | Fire(main) 52 | -------------------------------------------------------------------------------- /e2e_sae/scripts/train_tlens_saes/tinystories_1M_e2e.yaml: -------------------------------------------------------------------------------- 1 | wandb_project: tinystories-1m-2 2 | wandb_run_name: null 3 | wandb_run_name_prefix: "" 4 | 5 | seed: 0 6 | tlens_model_name: roneneldan/TinyStories-1M 7 | tlens_model_path: null 8 | 9 | n_samples: 400_000 # 918604 samples of 512 tokens 10 | save_every_n_samples: null 11 | eval_every_n_samples: 40_000 12 | eval_n_samples: 500 13 | log_every_n_grad_steps: 20 14 | collect_act_frequency_every_n_samples: 40_000 15 | act_frequency_n_tokens: 500_000 #500k tokens is ~977 samples 16 | batch_size: 20 17 | effective_batch_size: 20 18 | lr: 1e-3 19 | lr_schedule: cosine 20 | min_lr_factor: 0.1 21 | warmup_samples: 20_000 22 | max_grad_norm: 1.0 23 | 24 | loss: 25 | sparsity: 26 | p_norm: 1.0 27 | coeff: 3.0 28 | in_to_orig: null 29 | out_to_orig: null 30 | out_to_in: 31 | coeff: 0.0 32 | logits_kl: 33 | coeff: 1.0 34 | train_data: 35 | dataset_name: apollo-research/roneneldan-TinyStories-tokenizer-gpt2 36 | is_tokenized: true 37 | tokenizer_name: gpt2 38 | streaming: true 39 | split: train 40 | n_ctx: 512 41 | eval_data: 42 | dataset_name: apollo-research/roneneldan-TinyStories-tokenizer-gpt2 43 | is_tokenized: true 44 | tokenizer_name: gpt2 45 | streaming: true 46 | split: validation 47 | n_ctx: 512 48 | saes: 49 | retrain_saes: false 50 | pretrained_sae_paths: null 51 | sae_positions: blocks.4.hook_resid_pre 52 | dict_size_to_input_ratio: 50.0 53 | -------------------------------------------------------------------------------- /e2e_sae/scripts/train_tlens_saes/tinystories_1M_e2e_sweep.yaml: -------------------------------------------------------------------------------- 1 | program: e2e_sae/scripts/train_tlens_saes/run_train_tlens_saes.py 2 | name: logits_kl 3 | method: grid 4 | metric: 5 | name: val_loss 6 | goal: minimize 7 | parameters: 8 | n_samples: 9 | values: [450_000] 10 | lr: 11 | values: [1e-2, 5e-3, 1e-3] 12 | loss: 13 | parameters: 14 | sparsity: 15 | parameters: 16 | coeff: 17 | values: [50, 30, 20] 18 | saes: 19 | parameters: 20 | sae_positions: 21 | values: [blocks.1.hook_resid_pre] 22 | dict_size_to_input_ratio: 23 | values: [5.0, 10.0] 24 | 25 | command: 26 | - ${env} 27 | - ${interpreter} 28 | - ${program} 29 | - e2e_sae/scripts/train_tlens_saes/tinystories_1M_e2e.yaml -------------------------------------------------------------------------------- /e2e_sae/scripts/train_tlens_saes/tinystories_1M_local.yaml: 
-------------------------------------------------------------------------------- 1 | wandb_project: tinystories-1m_play 2 | wandb_run_name: null 3 | wandb_run_name_prefix: "" 4 | 5 | seed: 0 6 | tlens_model_name: roneneldan/TinyStories-1M 7 | tlens_model_path: null 8 | 9 | n_samples: 400_000 # 918604 samples of 512 tokens 10 | save_every_n_samples: null 11 | eval_every_n_samples: 40_000 12 | eval_n_samples: 500 13 | log_every_n_grad_steps: 20 14 | collect_act_frequency_every_n_samples: 40_000 15 | act_frequency_n_tokens: 500_000 #500k tokens is ~977 samples 16 | batch_size: 20 17 | effective_batch_size: 20 18 | lr: 1e-3 19 | lr_schedule: cosine 20 | min_lr_factor: 0.1 21 | warmup_samples: 20_000 22 | max_grad_norm: 1.0 23 | 24 | loss: 25 | sparsity: 26 | p_norm: 1.0 27 | coeff: 1e-2 28 | in_to_orig: null 29 | out_to_orig: null 30 | out_to_in: 31 | coeff: 1.0 32 | logits_kl: null 33 | train_data: 34 | dataset_name: apollo-research/roneneldan-TinyStories-tokenizer-gpt2 35 | is_tokenized: true 36 | tokenizer_name: gpt2 37 | streaming: true 38 | split: train 39 | n_ctx: 512 40 | eval_data: 41 | dataset_name: apollo-research/roneneldan-TinyStories-tokenizer-gpt2 42 | is_tokenized: true 43 | tokenizer_name: gpt2 44 | streaming: true 45 | split: validation 46 | n_ctx: 512 47 | saes: 48 | retrain_saes: false 49 | pretrained_sae_paths: null 50 | sae_positions: 51 | - blocks.4.hook_resid_pre 52 | dict_size_to_input_ratio: 50.0 53 | -------------------------------------------------------------------------------- /e2e_sae/scripts/train_tlens_saes/tinystories_1M_local_sweep.yaml: -------------------------------------------------------------------------------- 1 | program: e2e_sae/scripts/train_tlens_saes/run_train_tlens_saes.py 2 | name: tinystories_1m_local 3 | method: grid 4 | metric: 5 | name: val_loss 6 | goal: minimize 7 | parameters: 8 | n_samples: 9 | values: [400_000] 10 | lr: 11 | values: [1e-3] 12 | loss: 13 | parameters: 14 | sparsity: 15 | parameters: 16 | coeff: 17 | values: [0.001, 0.005, 0.008, 0.01, 0.02, 0.05] 18 | saes: 19 | parameters: 20 | sae_positions: 21 | values: [blocks.4.hook_resid_pre] 22 | dict_size_to_input_ratio: 23 | values: [5, 20, 60, 100] 24 | 25 | command: 26 | - ${env} 27 | - ${interpreter} 28 | - ${program} 29 | - e2e_sae/scripts/train_tlens_saes/tinystories_1M_local.yaml -------------------------------------------------------------------------------- /e2e_sae/settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | REPO_ROOT = ( 5 | Path(os.environ["GITHUB_WORKSPACE"]) if os.environ.get("CI") else Path(__file__).parent.parent 6 | ) 7 | -------------------------------------------------------------------------------- /e2e_sae/types.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Annotated, Any, Literal, TypedDict 3 | 4 | import torch 5 | from jaxtyping import Int 6 | from pydantic import BeforeValidator, PlainSerializer 7 | from torch import Tensor 8 | 9 | from e2e_sae.utils import to_root_path 10 | 11 | StrDtype = Literal["float32", "float64", "bfloat16"] 12 | TORCH_DTYPES: dict[StrDtype, torch.dtype] = { 13 | "float32": torch.float32, 14 | "float64": torch.float64, 15 | "bfloat16": torch.bfloat16, 16 | } 17 | 18 | 19 | class Samples(TypedDict): 20 | """Tokenized samples.""" 21 | 22 | input_ids: Int[Tensor, "batch pos"] 23 | 24 | 25 | def convert_str_to_torch_dtype(v: Any) -> torch.dtype: 26 
| """Convert dtype from str to a supported torch dtype.""" 27 | if v in TORCH_DTYPES: 28 | return TORCH_DTYPES[v] 29 | elif v in TORCH_DTYPES.values(): 30 | return v 31 | else: 32 | raise ValueError(f"Invalid dtype: {v}") 33 | 34 | 35 | def serialize_torch_dtype_to_str(v: torch.dtype) -> str: 36 | """Convert dtype from torch dtype to str.""" 37 | for k, v2 in TORCH_DTYPES.items(): 38 | if v == v2: 39 | return k 40 | raise ValueError(f"Invalid dtype found during serialization: {v}") 41 | 42 | 43 | # Pydantic magic for: 44 | # 1. If given a string as input (e.g. "float32"), convert it to a torch dtype (e.g. torch.float32) 45 | # 2. model_dump(mode="json") will serialize the torch dtype to a string, model_dump() leaves it 46 | # as a torch dtype 47 | # 48 | TorchDtype = Annotated[ 49 | torch.dtype, 50 | BeforeValidator(convert_str_to_torch_dtype), 51 | PlainSerializer(serialize_torch_dtype_to_str, when_used="json"), 52 | ] 53 | 54 | # This is a type for pydantic configs that will convert all relative paths 55 | # to be relative to the ROOT_DIR of e2e_sae 56 | RootPath = Annotated[Path, BeforeValidator(to_root_path), PlainSerializer(lambda x: str(x))] 57 | -------------------------------------------------------------------------------- /e2e_sae/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | from collections.abc import Callable 5 | from functools import partial 6 | from pathlib import Path 7 | from typing import Any, TypeVar 8 | 9 | import numpy as np 10 | import torch 11 | import wandb 12 | import yaml 13 | from dotenv import load_dotenv 14 | from pydantic import BaseModel 15 | from pydantic.v1.utils import deep_update 16 | from torch import nn 17 | from torch.optim.lr_scheduler import LambdaLR 18 | from transformer_lens import HookedTransformer 19 | from transformer_lens.hook_points import HookPoint 20 | 21 | from e2e_sae.log import logger 22 | from e2e_sae.settings import REPO_ROOT 23 | 24 | T = TypeVar("T", bound=BaseModel) 25 | 26 | 27 | def to_numpy(tensor: Any) -> np.ndarray[Any, np.dtype[Any]]: 28 | """ 29 | Helper function to convert a tensor to a numpy array. Also works on lists, tuples, and numpy arrays. 30 | Adapted from TransformerLens/transformer_lens/utils.py 31 | """ 32 | if isinstance(tensor, np.ndarray): 33 | return tensor 34 | elif isinstance(tensor, list | tuple): 35 | array = np.array(tensor) 36 | return array 37 | elif isinstance(tensor, torch.Tensor | torch.nn.parameter.Parameter): 38 | return tensor.detach().cpu().numpy() 39 | elif isinstance(tensor, int | float | bool | str): 40 | return np.array(tensor) 41 | else: 42 | raise ValueError(f"Input to to_numpy has invalid type: {type(tensor)}") 43 | 44 | 45 | def to_root_path(path: str | Path): 46 | """Converts relative paths to absolute ones, assuming they are relative to the rib root.""" 47 | return Path(path) if Path(path).is_absolute() else Path(REPO_ROOT / path) 48 | 49 | 50 | def save_module( 51 | config_dict: dict[str, Any], 52 | save_dir: Path, 53 | module: nn.Module, 54 | model_filename: str, 55 | config_filename: str = "final_config.yaml", 56 | ) -> None: 57 | """Save the pytorch module and config to the save_dir. 58 | 59 | The config will only be saved if the save_dir doesn't exist (i.e. the first time the module is 60 | saved assuming the save_dir is unique to the module). 61 | 62 | Args: 63 | config_dict: Dictionary representation of the config to save. 64 | save_dir: Directory to save the module. 
65 | module: The module to save. 66 | model_filename: The filename to save the model to. 67 | config_filename: The filename to save the config to. 68 | """ 69 | # If the save_dir doesn't exist, create it and save the config 70 | if not save_dir.exists(): 71 | save_dir.mkdir(parents=True) 72 | with open(save_dir / config_filename, "w") as f: 73 | yaml.dump(config_dict, f) 74 | logger.info("Saved config to %s", save_dir / config_filename) 75 | 76 | torch.save(module.state_dict(), save_dir / model_filename) 77 | logger.info("Saved model to %s", save_dir / model_filename) 78 | 79 | 80 | def load_config(config_path_or_obj: Path | str | T, config_model: type[T]) -> T: 81 | """Load the config of class `config_model`, either from YAML file or existing config object. 82 | 83 | Args: 84 | config_path_or_obj (Union[Path, str, `config_model`]): if config object, must be instance 85 | of `config_model`. If str or Path, this must be the path to a .yaml. 86 | config_model: the class of the config that we are loading 87 | """ 88 | if isinstance(config_path_or_obj, config_model): 89 | return config_path_or_obj 90 | 91 | if isinstance(config_path_or_obj, str): 92 | config_path_or_obj = Path(config_path_or_obj) 93 | 94 | assert isinstance( 95 | config_path_or_obj, Path 96 | ), f"passed config is of invalid type {type(config_path_or_obj)}" 97 | assert ( 98 | config_path_or_obj.suffix == ".yaml" 99 | ), f"Config file {config_path_or_obj} must be a YAML file." 100 | assert Path(config_path_or_obj).exists(), f"Config file {config_path_or_obj} does not exist." 101 | with open(config_path_or_obj) as f: 102 | config_dict = yaml.safe_load(f) 103 | return config_model(**config_dict) 104 | 105 | 106 | def set_seed(seed: int | None) -> None: 107 | """Set the random seed for random, PyTorch and NumPy""" 108 | if seed is not None: 109 | torch.manual_seed(seed) 110 | np.random.seed(seed) 111 | random.seed(seed) 112 | 113 | 114 | BaseModelType = TypeVar("BaseModelType", bound=BaseModel) 115 | 116 | 117 | def replace_pydantic_model(model: BaseModelType, *updates: dict[str, Any]) -> BaseModelType: 118 | """Create a new model with (potentially nested) updates in the form of dictionaries. 119 | 120 | Args: 121 | model: The model to update. 122 | updates: The zero or more dictionaries of updates that will be applied sequentially. 123 | 124 | Returns: 125 | A replica of the model with the updates applied. 126 | 127 | Examples: 128 | >>> class Foo(BaseModel): 129 | ... a: int 130 | ... b: int 131 | >>> foo = Foo(a=1, b=2) 132 | >>> foo2 = replace_pydantic_model(foo, {"a": 3}) 133 | >>> foo2 134 | Foo(a=3, b=2) 135 | >>> class Bar(BaseModel): 136 | ... foo: Foo 137 | >>> bar = Bar(foo={"a": 1, "b": 2}) 138 | >>> bar2 = replace_pydantic_model(bar, {"foo": {"a": 3}}) 139 | >>> bar2 140 | Bar(foo=Foo(a=3, b=2)) 141 | """ 142 | return model.__class__(**deep_update(model.model_dump(), *updates)) 143 | 144 | 145 | def filter_names(all_names: list[str], filter_names: list[str] | str) -> list[str]: 146 | """Use filter_names to filter `all_names` by partial match. 147 | 148 | The filtering is done by checking if any of the filter_names are in the all_names. Partial 149 | matches are allowed. E.g. "hook_resid_pre" matches ["blocks.0.hook_resid_pre", 150 | "blocks.1.hook_resid_pre", ...]. 151 | 152 | Args: 153 | all_names: The names to filter. 154 | filter_names: The names to use to filter all_names by partial match. 155 | Returns: 156 | The filtered names. 
157 | """ 158 | if isinstance(filter_names, str): 159 | filter_names = [filter_names] 160 | return [name for name in all_names if any(filter_name in name for filter_name in filter_names)] 161 | 162 | 163 | def get_hook_shapes(tlens_model: HookedTransformer, hook_names: list[str]) -> dict[str, list[int]]: 164 | """Get the shapes of activations at the hook points labelled by hook_names""" 165 | # Sadly I can't see any way to easily get the shapes of activations at hook_points without 166 | # actually running the model. 167 | hook_shapes = {} 168 | 169 | def get_activation_shape_hook_function(activation: torch.Tensor, hook: HookPoint) -> None: 170 | hook_shapes[hook.name] = activation.shape 171 | 172 | def hook_names_filter(name: str) -> bool: 173 | return name in hook_names 174 | 175 | test_prompt = torch.tensor([0]) 176 | tlens_model.run_with_hooks( 177 | test_prompt, 178 | return_type=None, 179 | fwd_hooks=[(hook_names_filter, get_activation_shape_hook_function)], 180 | ) 181 | return hook_shapes 182 | 183 | 184 | def get_linear_lr_schedule( 185 | warmup_samples: int, 186 | cooldown_samples: int, 187 | n_samples: int | None, 188 | effective_batch_size: int, 189 | min_lr_factor: float = 0.0, 190 | ) -> Callable[[int], float]: 191 | """ 192 | Generates a linear learning rate schedule function that incorporates warmup and cooldown phases. 193 | If warmup_samples and cooldown_samples are both 0, the learning rate will be constant at 1.0 194 | throughout training. 195 | 196 | Args: 197 | warmup_samples: The number of samples to use for warmup. 198 | cooldown_samples: The number of samples to use for cooldown. 199 | effective_batch_size: The effective batch size used during training. 200 | min_lr_factor: The minimum learning rate as a fraction of the maximum learning rate. Used 201 | in the cooldown phase. 202 | 203 | Returns: 204 | A function that takes a training step as input and returns the corresponding learning rate. 205 | 206 | Raises: 207 | ValueError: If the cooldown period starts before the warmup period ends. 208 | AssertionError: If a cooldown is requested but the total number of samples is not provided. 209 | """ 210 | warmup_steps = warmup_samples // effective_batch_size 211 | cooldown_steps = cooldown_samples // effective_batch_size 212 | 213 | if n_samples is None: 214 | assert cooldown_samples == 0, "Cooldown requested but total number of samples not provided." 215 | cooldown_start = float("inf") 216 | else: 217 | # NOTE: There may be 1 fewer steps if batch_size < effective_batch_size, but this won't 218 | # make a big difference for most learning setups. The + 1 is to account for the scheduler 219 | # step that occurs after training has finished 220 | total_steps = math.ceil(n_samples / effective_batch_size) + 1 221 | # Calculate the start step for cooldown 222 | cooldown_start = total_steps - cooldown_steps 223 | 224 | # Check for overlap between warmup and cooldown 225 | assert ( 226 | cooldown_start > warmup_steps 227 | ), "Cooldown starts before warmup ends. Adjust your parameters." 
228 | 229 | def lr_schedule(step: int) -> float: 230 | if step < warmup_steps: 231 | # Warmup phase: linearly increase learning rate 232 | return (step + 1) / warmup_steps 233 | elif step >= cooldown_start: 234 | # Cooldown phase: linearly decrease learning rate 235 | # Calculate how many steps have been taken in the cooldown phase 236 | steps_into_cooldown = step - cooldown_start 237 | # Linearly decrease the learning rate 238 | return max(min_lr_factor, 1 - (steps_into_cooldown / cooldown_steps)) 239 | else: 240 | # Maintain maximum learning rate after warmup and before cooldown 241 | return 1.0 242 | 243 | return lr_schedule 244 | 245 | 246 | def _get_cosine_schedule_with_warmup_lr_lambda( 247 | current_step: int, 248 | *, 249 | num_warmup_steps: int, 250 | num_training_steps: int, 251 | num_cycles: float, 252 | min_lr_factor: float, 253 | ): 254 | if current_step < num_warmup_steps: 255 | return float(current_step) / float(max(1, num_warmup_steps)) 256 | progress = float(current_step - num_warmup_steps) / float( 257 | max(1, num_training_steps - num_warmup_steps) 258 | ) 259 | return max( 260 | min_lr_factor, 261 | min_lr_factor 262 | + (1 - min_lr_factor) 263 | * 0.5 264 | * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)), 265 | ) 266 | 267 | 268 | def get_cosine_schedule_with_warmup( 269 | optimizer: torch.optim.Optimizer, 270 | num_warmup_steps: int, 271 | num_training_steps: int, 272 | num_cycles: float = 0.5, 273 | min_lr_factor: float = 0.0, 274 | last_epoch: int = -1, 275 | ): 276 | """ 277 | Create a schedule with a learning rate that decreases following the values of the cosine 278 | function between the initial lr set in the optimizer to 0, after a warmup period during which it 279 | increases linearly between 0 and the initial lr set in the optimizer. 280 | 281 | The min_lr_factor is used to set a minimum learning rate that is a fraction of the initial 282 | learning rate. 283 | 284 | Adapted from `transformers.get_cosine_schedule_with_warmup` to support a minimum learning rate. 285 | 286 | Args: 287 | optimizer ([`~torch.optim.Optimizer`]): 288 | The optimizer for which to schedule the learning rate. 289 | num_warmup_steps (`int`): 290 | The number of steps for the warmup phase. 291 | num_training_steps (`int`): 292 | The total number of training steps. 293 | num_cycles (`float`, *optional*, defaults to 0.5): 294 | The number of waves in the cosine schedule (the defaults is to just decrease from the 295 | max value to 0 following a half-cosine). 296 | min_lr_factor (`float`, *optional*, defaults to 0.0): 297 | last_epoch (`int`, *optional*, defaults to -1): 298 | The index of the last epoch when resuming training. 299 | 300 | Return: 301 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 302 | """ 303 | 304 | lr_lambda = partial( 305 | _get_cosine_schedule_with_warmup_lr_lambda, 306 | num_warmup_steps=num_warmup_steps, 307 | num_training_steps=num_training_steps, 308 | num_cycles=num_cycles, 309 | min_lr_factor=min_lr_factor, 310 | ) 311 | return LambdaLR(optimizer, lr_lambda, last_epoch) 312 | 313 | 314 | def init_wandb(config: T, project: str, sweep_config_path: Path | str | None) -> T: 315 | """Initialize Weights & Biases and return a config updated with sweep hyperparameters. 316 | 317 | If no sweep config is provided, the config is returned as is. 318 | 319 | If a sweep config is provided, wandb is first initialized with the sweep config. 
This will 320 | cause wandb to choose specific hyperparameters for this instance of the sweep and store them 321 | in wandb.config. We then update the config with these hyperparameters. 322 | 323 | Args: 324 | config: The base config. 325 | project: The name of the wandb project. 326 | sweep_config_path: The path to the sweep config file. If provided, updates the config with 327 | the hyperparameters from this instance of the sweep. 328 | 329 | Returns: 330 | Config updated with sweep hyperparameters (if any). 331 | """ 332 | if sweep_config_path is not None: 333 | with open(sweep_config_path) as f: 334 | sweep_data = yaml.safe_load(f) 335 | wandb.init(config=sweep_data, save_code=True) 336 | else: 337 | load_dotenv(override=True) 338 | wandb.init(project=project, entity=os.getenv("WANDB_ENTITY"), save_code=True) 339 | 340 | # Update the config with the hyperparameters for this sweep (if any) 341 | config = replace_pydantic_model(config, wandb.config) 342 | 343 | # Update the non-frozen keys in the wandb config (only relevant for sweeps) 344 | wandb.config.update(config.model_dump(mode="json")) 345 | return config 346 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "e2e_sae" 3 | version = "2.1.1" 4 | description = "Repo for training sparse autoencoders end-to-end" 5 | requires-python = ">=3.11" 6 | readme = "README.md" 7 | dependencies = [ 8 | "torch", 9 | "torchvision", 10 | "einops", 11 | "pydantic", 12 | "wandb", 13 | "fire", 14 | "tqdm", 15 | "pytest", 16 | "ipykernel", 17 | "transformer-lens", 18 | "jaxtyping", 19 | "python-dotenv", 20 | "zstandard", 21 | "matplotlib", 22 | "seaborn", 23 | "umap-learn", 24 | "tenacity", 25 | "statsmodels", 26 | "automated-interpretability" 27 | ] 28 | 29 | [project.urls] 30 | repository = "https://github.com/ApolloResearch/e2e_sae" 31 | 32 | [project.optional-dependencies] 33 | dev = [ 34 | "ruff", 35 | "pyright", 36 | "pre-commit", 37 | ] 38 | 39 | [build-system] 40 | requires = ["setuptools", "wheel"] 41 | build-backend = "setuptools.build_meta" 42 | 43 | [tool.setuptools] 44 | packages = ["e2e_sae", "e2e_sae.models", "e2e_sae.scripts"] 45 | 46 | [tool.ruff] 47 | line-length = 100 48 | fix = true 49 | ignore = [ 50 | "F722" # Incompatible with jaxtyping 51 | ] 52 | 53 | [tool.ruff.lint] 54 | select = [ 55 | # pycodestyle 56 | "E", 57 | # Pyflakes 58 | "F", 59 | # pyupgrade 60 | "UP", 61 | # flake8-bugbear 62 | "B", 63 | # flake8-simplify 64 | "SIM", 65 | # isort 66 | "I", 67 | ] 68 | 69 | [tool.ruff.format] 70 | # Enable reformatting of code snippets in docstrings. 
71 | docstring-code-format = true 72 | 73 | [tool.ruff.isort] 74 | known-third-party = ["wandb"] 75 | 76 | [tool.pyright] 77 | include = ["e2e_sae", "tests"] 78 | 79 | strictListInference = true 80 | strictDictionaryInference = true 81 | strictSetInference = true 82 | reportFunctionMemberAccess = true 83 | reportUnknownParameterType = true 84 | reportIncompatibleMethodOverride = true 85 | reportIncompatibleVariableOverride = true 86 | reportInconsistentConstructorType = true 87 | reportOverlappingOverload = true 88 | reportConstantRedefinition = true 89 | reportImportCycles = false 90 | reportPropertyTypeMismatch = true 91 | reportMissingTypeArgument = true 92 | reportUnnecessaryCast = true 93 | reportUnnecessaryComparison = true 94 | reportUnnecessaryContains = true 95 | reportUnusedExpression = true 96 | reportMatchNotExhaustive = true 97 | reportShadowedImports = true 98 | reportPrivateImportUsage = false 99 | 100 | [tool.pytest.ini_options] 101 | filterwarnings = [ 102 | # https://github.com/google/python-fire/pull/447 103 | "ignore::DeprecationWarning:fire:59", 104 | ] 105 | -------------------------------------------------------------------------------- /tests/test_mlp.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from torch import nn 3 | 4 | from e2e_sae.models.mlp import MLP, Layer 5 | 6 | 7 | @pytest.mark.parametrize( 8 | "hidden_sizes, bias, expected_layer_sizes", 9 | [ 10 | # 2 hidden layers, bias False 11 | ([4, 3], False, [(3, 4), (4, 3), (3, 4)]), 12 | # no hidden layers, bias True 13 | ([], True, [(3, 4)]), 14 | # 1 hidden layer, bias True 15 | ([4], True, [(3, 4), (4, 4)]), 16 | ], 17 | ) 18 | def test_mlp_layers( 19 | hidden_sizes: list[int], 20 | bias: bool, 21 | expected_layer_sizes: list[tuple[int, int]], 22 | ) -> None: 23 | """Test the MLP constructor for fixed input and output sizes. 24 | 25 | Verifies the created layers' types, sizes and bias. 26 | 27 | Args: 28 | hidden_sizes: A list of hidden layer sizes. If None, no hidden layers are added. 29 | bias: Whether to add a bias to the Linear layers. 30 | expected_layer_sizes: A list of tuples where each tuple is a pair of in_features and 31 | out_features of a layer. 
32 | """ 33 | input_size = 3 34 | output_size = 4 35 | model = MLP( 36 | hidden_sizes, 37 | input_size, 38 | output_size, 39 | bias=bias, 40 | ) 41 | 42 | assert isinstance(model, MLP) 43 | 44 | for i, layer in enumerate(model.layers): 45 | assert isinstance(layer, Layer) 46 | assert isinstance(layer.linear, nn.Linear) 47 | 48 | # Check the in/out feature sizes of Linear layers 49 | assert layer.linear.in_features == expected_layer_sizes[i][0] 50 | assert layer.linear.out_features == expected_layer_sizes[i][1] 51 | # Check bias is not None when bias is True, and None otherwise 52 | assert layer.linear.bias is not None if bias else layer.linear.bias is None 53 | 54 | if i < len(model.layers) - 1: 55 | # Activation layers at indices before the last layer 56 | assert isinstance(layer.activation, nn.GELU) 57 | else: 58 | # No activation function for the last layer 59 | assert not hasattr(layer, "activation") 60 | -------------------------------------------------------------------------------- /tests/test_sae.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | import torch 5 | from transformer_lens import HookedTransformer 6 | 7 | from e2e_sae.loader import load_pretrained_saes 8 | from e2e_sae.models.sparsifiers import SAE 9 | from e2e_sae.models.transformers import SAETransformer 10 | from e2e_sae.utils import save_module, set_seed 11 | from tests.utils import get_tinystories_config 12 | 13 | 14 | def test_orthonormal_initialization(): 15 | """After initialising an SAE, the dictionary components should be orthonormal.""" 16 | set_seed(0) 17 | input_size = 2 18 | n_dict_components = 4 19 | sae = SAE(input_size, n_dict_components) 20 | assert sae.decoder.weight.shape == (input_size, n_dict_components) 21 | # If vectors are orthonormal, the gram matrix (X X^T) should be the identity matrix 22 | assert torch.allclose( 23 | sae.decoder.weight @ sae.decoder.weight.T, torch.eye(input_size), atol=1e-6 24 | ) 25 | 26 | 27 | @pytest.mark.parametrize("retrain_saes", [True, False]) 28 | def test_load_single_pretrained_sae(tmp_path: Path, retrain_saes: bool): 29 | """Test that loading a single pretrained SAE into a tlens model works. 30 | 31 | First create an SAETransformer with a single SAE position. We will save this model to a file 32 | and then load it back in to a new SAETransformer. 
33 | 34 | Checks that: 35 | - The trainable parameters don't include the pretrained SAE if retrain_saes is False and 36 | include it if retrain_saes is True 37 | - The new SAE params are the same as the pretrained SAE for the position that was copied 38 | """ 39 | 40 | # Make a dummy tlens config with two layers 41 | tlens_config = { 42 | "n_layers": 2, 43 | "d_model": 2, 44 | "n_ctx": 3, 45 | "d_head": 2, 46 | "act_fn": "gelu", 47 | "d_vocab": 2, 48 | } 49 | tlens_model = HookedTransformer(tlens_config) 50 | 51 | sae_position = "blocks.0.hook_resid_post" 52 | pretrained_config = get_tinystories_config({"saes": {"sae_positions": sae_position}}) 53 | 54 | model = SAETransformer( 55 | tlens_model=tlens_model, 56 | raw_sae_positions=[sae_position], 57 | dict_size_to_input_ratio=pretrained_config.saes.dict_size_to_input_ratio, 58 | init_decoder_orthogonal=False, 59 | ) 60 | # Save the model.saes to a temp file 61 | save_module( 62 | config_dict=pretrained_config.model_dump(mode="json"), 63 | save_dir=tmp_path, 64 | module=model.saes, 65 | model_filename="sae.pth", 66 | ) 67 | 68 | # Get a new config that has more than one sae_position (including the one we saved) 69 | sae_positions = ["blocks.0.hook_resid_post", "blocks.1.hook_resid_post"] 70 | new_config = get_tinystories_config( 71 | { 72 | "saes": { 73 | "sae_positions": sae_positions, 74 | "pretrained_sae_paths": tmp_path / "sae.pth", 75 | } 76 | } 77 | ) 78 | new_tlens_model = HookedTransformer(tlens_config) 79 | new_model = SAETransformer( 80 | tlens_model=new_tlens_model, 81 | raw_sae_positions=sae_positions, 82 | dict_size_to_input_ratio=new_config.saes.dict_size_to_input_ratio, 83 | init_decoder_orthogonal=False, 84 | ) 85 | 86 | assert isinstance(new_config.saes.pretrained_sae_paths, list) 87 | 88 | # Now load in the pretrained SAE to the new model 89 | trainable_param_names = load_pretrained_saes( 90 | saes=new_model.saes, 91 | pretrained_sae_paths=new_config.saes.pretrained_sae_paths, 92 | all_param_names=[name for name, _ in new_model.saes.named_parameters()], 93 | retrain_saes=retrain_saes, 94 | ) 95 | suffixes = ["encoder.0.weight", "encoder.0.bias", "decoder.weight", "decoder.bias"] 96 | block_0_params = [f"blocks-0-hook_resid_post.{suffix}" for suffix in suffixes] 97 | block_1_params = [f"blocks-1-hook_resid_post.{suffix}" for suffix in suffixes] 98 | if retrain_saes: 99 | assert trainable_param_names == block_0_params + block_1_params 100 | else: 101 | assert trainable_param_names == block_1_params 102 | 103 | model_named_params = dict(model.saes.named_parameters()) 104 | new_model_named_params = dict(new_model.saes.named_parameters()) 105 | for suffix in suffixes: 106 | # Check that the params for block 0 are the same as the pretrained SAE 107 | assert torch.allclose( 108 | model_named_params[f"blocks-0-hook_resid_post.{suffix}"], 109 | new_model_named_params[f"blocks-0-hook_resid_post.{suffix}"], 110 | ) 111 | 112 | 113 | def test_load_multiple_pretrained_sae(tmp_path: Path): 114 | """Test that loading multiple pretrained SAE into a tlens model works. 115 | 116 | - Creates an SAETransformer with SAEs in blocks 0 and 1 and save to file. 117 | - Creates another SAETransformer with SAEs in blocks 1 and 2 and save to file. 118 | - Creates a new SAETransformer with SAEs in blocks 0, 1 and 2 and load in the saved SAEs. 
119 | 120 | Checks that we have: 121 | - Block 0 should have the params from the first saved SAE 122 | - Blocks 1 and 2 should have the params from the second saved SAE 123 | """ 124 | 125 | # Make a dummy tlens config with four layers 126 | tlens_config = { 127 | "n_layers": 4, 128 | "d_model": 2, 129 | "n_ctx": 3, 130 | "d_head": 2, 131 | "act_fn": "gelu", 132 | "d_vocab": 2, 133 | } 134 | all_positions = [ 135 | "blocks.0.hook_resid_post", 136 | "blocks.1.hook_resid_post", 137 | "blocks.2.hook_resid_post", 138 | "blocks.3.hook_resid_post", 139 | ] 140 | 141 | filenames = ["sae_0.pth", "sae_1.pth"] 142 | sae_position_lists = [ 143 | ["blocks.0.hook_resid_post", "blocks.1.hook_resid_post"], 144 | ["blocks.1.hook_resid_post", "blocks.2.hook_resid_post"], 145 | ] 146 | sae_params = [] 147 | for filename, sae_positions in zip(filenames, sae_position_lists, strict=True): 148 | pretrained_config = get_tinystories_config({"saes": {"sae_positions": sae_positions}}) 149 | tlens_model = HookedTransformer(tlens_config) 150 | model = SAETransformer( 151 | tlens_model=tlens_model, 152 | raw_sae_positions=sae_positions, 153 | dict_size_to_input_ratio=pretrained_config.saes.dict_size_to_input_ratio, 154 | init_decoder_orthogonal=False, 155 | ) 156 | # Save the model.saes to a temp file 157 | save_module( 158 | config_dict=pretrained_config.model_dump(mode="json"), 159 | save_dir=tmp_path, 160 | module=model.saes, 161 | model_filename=filename, 162 | ) 163 | sae_params.append(model.saes) 164 | 165 | new_config = get_tinystories_config( 166 | { 167 | "saes": { 168 | "sae_positions": all_positions, 169 | "pretrained_sae_paths": [tmp_path / filename for filename in filenames], 170 | } 171 | } 172 | ) 173 | # Create a new model to load in the pretrained SAEs 174 | new_tlens_model = HookedTransformer(tlens_config) 175 | new_model = SAETransformer( 176 | tlens_model=new_tlens_model, 177 | raw_sae_positions=all_positions, 178 | dict_size_to_input_ratio=new_config.saes.dict_size_to_input_ratio, 179 | init_decoder_orthogonal=False, 180 | ) 181 | 182 | assert isinstance(new_config.saes.pretrained_sae_paths, list) 183 | # Now load in the pretrained SAE to the new model 184 | trainable_param_names = load_pretrained_saes( 185 | saes=new_model.saes, 186 | pretrained_sae_paths=new_config.saes.pretrained_sae_paths, 187 | all_param_names=[name for name, _ in new_model.saes.named_parameters()], 188 | retrain_saes=False, 189 | ) 190 | 191 | model_named_params = dict(new_model.saes.named_parameters()) 192 | suffixes = ["encoder.0.weight", "encoder.0.bias", "decoder.weight", "decoder.bias"] 193 | assert trainable_param_names == [ 194 | f"blocks-3-hook_resid_post.{suffix}" for suffix in suffixes 195 | ], "Only block 3 should be trainable" 196 | 197 | for suffix in suffixes: 198 | # Check that the params for block 0 are the same as the first pretrained SAE 199 | assert torch.allclose( 200 | model_named_params[f"blocks-0-hook_resid_post.{suffix}"], 201 | sae_params[0].state_dict()[f"blocks-0-hook_resid_post.{suffix}"], 202 | ) 203 | # Check that the params for blocks 1 and 2 are the same as the second pretrained SAE 204 | for block in [1, 2]: 205 | assert torch.allclose( 206 | model_named_params[f"blocks-{block}-hook_resid_post.{suffix}"], 207 | sae_params[1].state_dict()[f"blocks-{block}-hook_resid_post.{suffix}"], 208 | ) 209 | -------------------------------------------------------------------------------- /tests/test_train_tlens_saes.py: -------------------------------------------------------------------------------- 1 | 
from pathlib import Path 2 | 3 | import pytest 4 | 5 | from e2e_sae.data import DatasetConfig 6 | from e2e_sae.losses import ( 7 | LogitsKLLoss, 8 | LossConfigs, 9 | OutToInLoss, 10 | SparsityLoss, 11 | ) 12 | from e2e_sae.scripts.train_tlens_saes.run_train_tlens_saes import Config, SAEsConfig 13 | from e2e_sae.scripts.train_tlens_saes.run_train_tlens_saes import main as run_training 14 | 15 | 16 | @pytest.mark.cpuslow 17 | def test_train_tiny_gpt(): 18 | """Test training an SAE on a custom tiny 2-layer GPT-style model. 19 | 20 | NOTE: This could be sped up using a custom dataset stored locally but we don't yet support 21 | this. 22 | """ 23 | model_path = Path( 24 | "e2e_sae/scripts/train_tlens/sample_models/tiny-gpt2_lr-0.001_bs-16_2024-04-21_14-01-14/epoch_1.pt" 25 | ) 26 | config = Config( 27 | wandb_project=None, 28 | wandb_run_name=None, # If not set, will use a name based on important config values 29 | wandb_run_name_prefix="", 30 | seed=0, 31 | tlens_model_name=None, 32 | tlens_model_path=model_path, 33 | save_dir=None, 34 | n_samples=3, 35 | save_every_n_samples=None, 36 | eval_every_n_samples=2, # Just eval once at start and once during training 37 | eval_n_samples=2, 38 | log_every_n_grad_steps=20, 39 | collect_act_frequency_every_n_samples=2, 40 | act_frequency_n_tokens=2000, 41 | batch_size=2, 42 | effective_batch_size=2, 43 | lr=5e-4, 44 | lr_schedule="cosine", 45 | min_lr_factor=0.1, 46 | warmup_samples=2, 47 | max_grad_norm=10.0, 48 | loss=LossConfigs( 49 | sparsity=SparsityLoss(p_norm=1.0, coeff=1.5), 50 | in_to_orig=None, 51 | out_to_orig=None, 52 | out_to_in=OutToInLoss(coeff=0.0), 53 | logits_kl=LogitsKLLoss(coeff=1.0), 54 | ), 55 | train_data=DatasetConfig( 56 | dataset_name="apollo-research/Skylion007-openwebtext-tokenizer-gpt2", 57 | is_tokenized=True, 58 | tokenizer_name="gpt2", 59 | streaming=True, 60 | split="train", 61 | n_ctx=1024, 62 | ), 63 | eval_data=DatasetConfig( 64 | dataset_name="apollo-research/Skylion007-openwebtext-tokenizer-gpt2", 65 | is_tokenized=True, 66 | tokenizer_name="gpt2", 67 | streaming=True, 68 | split="train", 69 | n_ctx=1024, 70 | ), 71 | saes=SAEsConfig( 72 | retrain_saes=False, 73 | pretrained_sae_paths=None, 74 | sae_positions=[ 75 | "blocks.1.hook_resid_pre", 76 | ], 77 | dict_size_to_input_ratio=1.0, 78 | ), 79 | ) 80 | run_training(config) 81 | -------------------------------------------------------------------------------- /tests/test_transformer.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from e2e_sae.loader import load_tlens_model 5 | from e2e_sae.models.transformers import SAETransformer 6 | from tests.utils import get_tinystories_config 7 | 8 | 9 | @pytest.fixture(scope="module") 10 | def tinystories_model() -> SAETransformer: 11 | tlens_model = load_tlens_model( 12 | tlens_model_name="roneneldan/TinyStories-1M", tlens_model_path=None 13 | ) 14 | sae_positions = ["blocks.2.hook_resid_pre"] 15 | config = get_tinystories_config({"saes": {"sae_positions": sae_positions}}) 16 | model = SAETransformer( 17 | tlens_model=tlens_model, 18 | raw_sae_positions=sae_positions, 19 | dict_size_to_input_ratio=config.saes.dict_size_to_input_ratio, 20 | ) 21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 22 | model.to(device) 23 | return model 24 | 25 | 26 | def test_generate(tinystories_model: SAETransformer, prompt: str = "One", max_new_tokens: int = 2): 27 | completion = tinystories_model.generate( 28 | input=prompt, 
sae_positions=None, max_new_tokens=max_new_tokens, temperature=0 29 | ) 30 | assert completion == "One day," 31 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from e2e_sae.utils import get_linear_lr_schedule 4 | 5 | 6 | class TestLinearLRSchedule: 7 | @pytest.mark.parametrize( 8 | "warmup_samples, cooldown_samples, n_samples, effective_batch_size, expected_error", 9 | [ 10 | (100, 50, None, 10, AssertionError), # Cooldown requested without setting n_samples 11 | (100, 100, 150, 10, AssertionError), # Cooldown starts before warmup ends 12 | ], 13 | ) 14 | def test_value_errors( 15 | self, 16 | warmup_samples: int, 17 | cooldown_samples: int, 18 | n_samples: int | None, 19 | effective_batch_size: int, 20 | expected_error: type[BaseException], 21 | ): 22 | with pytest.raises(expected_error): 23 | get_linear_lr_schedule( 24 | warmup_samples, cooldown_samples, n_samples, effective_batch_size 25 | ) 26 | 27 | def test_constant_lr(self): 28 | lr_schedule = get_linear_lr_schedule( 29 | warmup_samples=0, cooldown_samples=0, n_samples=None, effective_batch_size=10 30 | ) 31 | for step in range(0, 100): 32 | assert lr_schedule(step) == 1.0, "Learning rate should be constant at 1.0" 33 | 34 | @pytest.mark.parametrize( 35 | "warmup_samples, cooldown_samples, n_samples, effective_batch_size, step, expected_lr", 36 | [ 37 | (100, 0, None, 10, 5, 0.6), # During warmup 38 | (100, 200, 1000, 10, 50, 1.0), # After warmup, before cooldown 39 | (100, 100, 1000, 10, 92, 0.9), # During cooldown 40 | (100, 100, 1000, 10, 101, 0.0), # After cooldown 41 | ], 42 | ) 43 | def test_learning_rate_transitions( 44 | self, 45 | warmup_samples: int, 46 | cooldown_samples: int, 47 | n_samples: int, 48 | effective_batch_size: int, 49 | step: int, 50 | expected_lr: float, 51 | ): 52 | # Note that a `step` corresponds to the upcoming sample. 53 | lr_schedule = get_linear_lr_schedule( 54 | warmup_samples, cooldown_samples, n_samples, effective_batch_size 55 | ) 56 | assert lr_schedule(step) == pytest.approx( 57 | expected_lr 58 | ), f"Learning rate at step {step} should be {expected_lr}" 59 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import yaml 4 | 5 | from e2e_sae.scripts.train_tlens_saes.run_train_tlens_saes import Config 6 | from e2e_sae.settings import REPO_ROOT 7 | from e2e_sae.utils import replace_pydantic_model 8 | 9 | TINYSTORIES_CONFIG = f"{REPO_ROOT}/e2e_sae/scripts/train_tlens_saes/tinystories_1M_e2e.yaml" 10 | 11 | 12 | def get_tinystories_config(*updates: dict[str, Any]) -> Config: 13 | """Load the tinystories config and update it with the given updates.""" 14 | # Set the wandb_project to null since we never want to log tests to wandb 15 | updates = updates + ({"wandb_project": None, "save_dir": None},) 16 | with open(TINYSTORIES_CONFIG) as f: 17 | config = Config(**yaml.safe_load(f)) 18 | return replace_pydantic_model(config, *updates) 19 | --------------------------------------------------------------------------------
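A minimal usage sketch for the learning-rate schedule helpers defined in e2e_sae/utils.py (get_linear_lr_schedule and get_cosine_schedule_with_warmup). This sketch is illustrative only and is not a file from the repository: the stand-in model is an assumption, and the hyperparameter values are borrowed loosely from the tinystories config above (batch size 20, 400_000 samples, 20_000 warmup samples, min_lr_factor 0.1), with the 100_000-sample cooldown invented purely for illustration.

# Illustrative sketch only (not part of the repository).
import torch
from torch.optim.lr_scheduler import LambdaLR

from e2e_sae.utils import get_linear_lr_schedule

model = torch.nn.Linear(8, 8)  # stand-in for an SAE module (assumption)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

effective_batch_size = 20
n_samples = 400_000

# Warm up linearly over 20_000 samples, cool down over the final 100_000 samples,
# and clamp the cooldown at 10% of the base learning rate.
lr_lambda = get_linear_lr_schedule(
    warmup_samples=20_000,
    cooldown_samples=100_000,
    n_samples=n_samples,
    effective_batch_size=effective_batch_size,
    min_lr_factor=0.1,
)
scheduler = LambdaLR(optimizer, lr_lambda)

# When the config sets lr_schedule: cosine, get_cosine_schedule_with_warmup(optimizer,
# num_warmup_steps, num_training_steps, min_lr_factor=0.1) is the drop-in alternative.

for _ in range(n_samples // effective_batch_size):
    loss = model(torch.randn(effective_batch_size, 8)).pow(2).mean()  # dummy objective
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()  # LambdaLR multiplies the base lr (1e-3) by the schedule factor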