├── .env.example ├── .github ├── pull_request_template.md └── workflows │ └── checks.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── .vscode ├── extensions.json ├── launch.json └── settings-example.json ├── ACCESS.md ├── Makefile ├── README.md ├── conftest.py ├── pyproject.toml ├── spd ├── __init__.py ├── experiments │ ├── resid_mlp │ │ ├── model_interp.py │ │ ├── models.py │ │ ├── plotting.py │ │ ├── resid_mlp_dataset.py │ │ ├── resid_mlp_decomposition.py │ │ ├── resid_mlp_sweep_config.yaml │ │ ├── resid_mlp_topk_config.yaml │ │ ├── spd_interp.py │ │ └── train_resid_mlp.py │ └── tms │ │ ├── models.py │ │ ├── spd_interp.py │ │ ├── tms_decomposition.py │ │ ├── tms_lp_config.yaml │ │ ├── tms_sweep_config.yaml │ │ ├── tms_topk_config.yaml │ │ └── train_tms.py ├── hooks.py ├── log.py ├── models │ ├── __init__.py │ ├── base.py │ └── components.py ├── module_utils.py ├── plotting.py ├── run_spd.py ├── settings.py ├── types.py ├── utils.py └── wandb_utils.py └── tests ├── test_resid_mlp.py ├── test_spd_losses.py ├── test_spd_model.py ├── test_tms.py └── test_utils.py /.env.example: -------------------------------------------------------------------------------- 1 | WANDB_API_KEY=your_api_key 2 | WANDB_ENTITY=your_entity_name -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | 3 | 4 | ## Related Issue 5 | 6 | 7 | ## Motivation and Context 8 | 9 | 10 | ## How Has This Been Tested? 11 | 12 | 13 | ## Does this PR introduce a breaking change? 
14 | 15 | -------------------------------------------------------------------------------- /.github/workflows/checks.yaml: -------------------------------------------------------------------------------- 1 | name: Checks 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - '**' 10 | # Allows you to run this workflow manually from the Actions tab 11 | workflow_dispatch: 12 | 13 | jobs: 14 | build: 15 | runs-on: ubuntu-latest 16 | timeout-minutes: 15 17 | steps: 18 | - name: Checkout 19 | uses: actions/checkout@v4 20 | 21 | - name: Set up Python 3.11 22 | uses: actions/setup-python@v5 23 | with: 24 | python-version: '3.11' 25 | 26 | - name: Cache Python Dependencies 27 | uses: actions/cache@v3 28 | id: cache-env 29 | with: 30 | path: | 31 | ./venv 32 | key: ${{ runner.os }}-${{ runner.environment }}-cache-v1-${{ hashFiles('**/pyproject.toml') }} 33 | 34 | - name: Install Dependencies 35 | if: steps.cache-env.outputs.cache-hit != 'true' 36 | run: | 37 | python -m venv venv 38 | source ./venv/bin/activate 39 | pip install ".[dev]" 40 | echo PATH=$PATH >> $GITHUB_ENV 41 | echo "./venv/bin" >> $GITHUB_PATH 42 | 43 | - name: Restore Venv 44 | if: steps.cache-env.outputs.cache-hit == 'true' 45 | run: | 46 | source ./venv/bin/activate 47 | echo PATH=$PATH >> $GITHUB_ENV 48 | echo "./venv/bin" >> $GITHUB_PATH 49 | 50 | - name: Print dependencies 51 | run: pip freeze 52 | 53 | - name: Run pyright 54 | run: pyright --pythonpath ./venv/bin/python 55 | 56 | - name: Run ruff lint 57 | run: ruff check --fix-only . 58 | 59 | - name: Run ruff format 60 | run: ruff format . 
61 | 62 | - name: Run tests 63 | run: python -m pytest tests/ --runslow --durations=10 64 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/out/ 2 | neuronpedia_outputs/ 3 | .env 4 | .vscode/settings.json 5 | 6 | wandb/ 7 | .data/ 8 | .checkpoints/ 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | cover/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | .pybuilder/ 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | # For a library or package, you might want to ignore these files since the code is 95 | # intended to run in multiple environments; otherwise, check them in: 96 | # .python-version 97 
| 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # poetry 106 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 107 | # This is especially recommended for binary packages to ensure reproducibility, and is more 108 | # commonly ignored for libraries. 109 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 110 | #poetry.lock 111 | 112 | # pdm 113 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 114 | #pdm.lock 115 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 116 | # in version control. 117 | # https://pdm.fming.dev/#use-with-ide 118 | .pdm.toml 119 | 120 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
168 | #.idea/ 169 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | # If you want branch protection for the main branch, use this hook 3 | # - repo: https://github.com/pre-commit/pre-commit-hooks 4 | # rev: v4.5.0 5 | # hooks: 6 | # - id: no-commit-to-branch 7 | # args: ["--branch=main"] 8 | # stages: 9 | # - commit 10 | - repo: local 11 | hooks: 12 | - id: pyright 13 | name: Pyright 14 | entry: pyright 15 | language: system 16 | types: [python] 17 | stages: 18 | - pre-commit 19 | 20 | - id: ruff-lint 21 | name: Ruff lint 22 | entry: ruff check 23 | args: ["--fix-only"] 24 | language: system 25 | types: [python] 26 | stages: 27 | - pre-commit 28 | 29 | - id: ruff-format 30 | name: Ruff format 31 | entry: ruff format 32 | language: system 33 | types: [python] 34 | stages: 35 | - pre-commit -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "ms-toolsai.jupyter", 4 | "ms-python.python", 5 | "charliermarsh.ruff", 6 | "stkb.rewrap" 7 | ] 8 | } 9 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: Attach", 9 | "type": "debugpy", 10 | "request": "attach", 11 | "connect": { 12 | "host": "localhost", 13 | "port": 5678 14 | }, 15 | }, 16 | { 17 | "name": "tms_lp", 18 | "type": "debugpy", 19 | "request": "launch", 20 | "program": "${workspaceFolder}/spd/experiments/tms/tms_decomposition.py", 21 | "args": "${workspaceFolder}/spd/experiments/tms/tms_lp_config.yaml", 22 | "console": "integratedTerminal", 23 | "justMyCode": true, 24 | "env": { 25 | "PYDEVD_DISABLE_FILE_VALIDATION": "1" 26 | } 27 | }, 28 | { 29 | "name": "tms_topk", 30 | "type": "debugpy", 31 | "request": "launch", 32 | "program": "${workspaceFolder}/spd/experiments/tms/tms_decomposition.py", 33 | "args": "${workspaceFolder}/spd/experiments/tms/tms_topk_config.yaml", 34 | "console": "integratedTerminal", 35 | "justMyCode": true, 36 | "env": { 37 | "PYDEVD_DISABLE_FILE_VALIDATION": "1" 38 | } 39 | }, 40 | { 41 | "name": "resid_mlp_topk", 42 | "type": "debugpy", 43 | "request": "launch", 44 | "program": "${workspaceFolder}/spd/experiments/resid_mlp/resid_mlp_decomposition.py", 45 | "args": "${workspaceFolder}/spd/experiments/resid_mlp/resid_mlp_topk_config.yaml", 46 | "console": "integratedTerminal", 47 | "justMyCode": true, 48 | "env": { 49 | "PYDEVD_DISABLE_FILE_VALIDATION": "1" 50 | } 51 | }, 52 | ] 53 | } -------------------------------------------------------------------------------- /.vscode/settings-example.json: -------------------------------------------------------------------------------- 1 | { 2 | "[python]": { 3 | "editor.formatOnSave": true, 4 | "editor.codeActionsOnSave": { 5 | "source.fixAll": "explicit", 6 | "source.organizeImports": "explicit" 7 | }, 8 | "editor.defaultFormatter": "charliermarsh.ruff" 9 | }, 10 | "rewrap.autoWrap.enabled": true, 11 | "rewrap.wrappingColumn": 100, 12 | "notebook.formatOnSave.enabled": true 13 | } 
-------------------------------------------------------------------------------- /ACCESS.md: -------------------------------------------------------------------------------- 1 | # Disclosure Level - Public 2 | 3 | See Privacy Levels [here](https://www.apolloresearch.ai/blog/security) 4 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: install 2 | install: 3 | pip install -e . 4 | 5 | .PHONY: install-dev 6 | install-dev: 7 | pip install -e .[dev] 8 | pre-commit install 9 | 10 | .PHONY: type 11 | type: 12 | SKIP=no-commit-to-branch pre-commit run -a pyright 13 | 14 | .PHONY: format 15 | format: 16 | # Fix all autofixable lint problems (including import sorting), then format the code 17 | SKIP=no-commit-to-branch pre-commit run -a ruff-lint 18 | SKIP=no-commit-to-branch pre-commit run -a ruff-format 19 | 20 | .PHONY: check 21 | check: 22 | SKIP=no-commit-to-branch pre-commit run -a --hook-stage commit 23 | 24 | .PHONY: test 25 | test: 26 | python -m pytest tests/ 27 | 28 | .PHONY: test-all 29 | test-all: 30 | python -m pytest tests/ --runslow -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # APD - Attribution-based Parameter Decomposition 2 | Code used in the paper [Interpretability in Parameter Space: Minimizing 3 | Mechanistic Description Length with 4 | Attribution-based Parameter Decomposition](https://publications.apolloresearch.ai/apd) 5 | 6 | Weights & Biases report accompanying the paper: https://api.wandb.ai/links/apollo-interp/h5ekyxm7 7 | 8 | Note: previously called Sparse Parameter Decomposition (SPD). The package name will remain as `spd` 9 | for now, but the repository has been renamed to `apd`.
10 | 11 | ## Installation 12 | From the root of the repository, run one of 13 | 14 | ```bash 15 | make install-dev # To install the package, dev requirements and pre-commit hooks 16 | make install # To just install the package (runs `pip install -e .`) 17 | ``` 18 | 19 | ## Usage 20 | Place your wandb API key and entity in a `.env` file, using `.env.example` as a template. 21 | 22 | The repository consists of several `experiments`, each of which contains scripts to train target 23 | models and run APD. 24 | - `spd/experiments/tms` - Toy model of superposition 25 | - `spd/experiments/resid_mlp` - Toy model of compressed computation and toy model of distributed 26 | representations 27 | 28 | Deprecated: 29 | - `spd/experiments/piecewise` - Hand-coded gated function model. Use [this commit](https://github.com/ApolloResearch/apd/tree/117284172497ca420f22c29cef3ddcd5e4bcceb8) if you need to use 30 | this experiment. 31 | 32 | ### Train a target model 33 | All experiments require training a target model. Look for the `train_*.py` script in the experiment 34 | directory. Your trained model will be saved locally and uploaded to wandb. 35 | 36 | ### Run APD 37 | APD can be run by executing any of the `*_decomposition.py` scripts defined in the experiment 38 | subdirectories. A config file is required for each experiment, which can be found in the same 39 | directory. For example: 40 | ```bash 41 | python spd/experiments/tms/tms_decomposition.py spd/experiments/tms/tms_topk_config.yaml 42 | ``` 43 | will run APD on TMS with the config file `tms_topk_config.yaml` (which is the main config file used 44 | for the TMS experiments in the paper). 45 | 46 | Wandb sweep files are also provided in the experiment subdirectories, and can be run with e.g.: 47 | ```bash 48 | wandb sweep spd/experiments/tms/tms_sweep_config.yaml 49 | ``` 50 | 51 | All experiments call the `optimize` function in `spd/run_spd.py`, which contains the main APD logic.
52 | 53 | ### Analyze results 54 | Experiments contain `*_interp.py` scripts which generate the plots used in the paper. 55 | 56 | ## Development 57 | 58 | Suggested extensions and settings for VSCode/Cursor are provided in `.vscode/`. To use the suggested 59 | settings, copy `.vscode/settings-example.json` to `.vscode/settings.json`. 60 | 61 | There are various `make` commands that may be helpful 62 | 63 | ```bash 64 | make check # Run pre-commit on all files (i.e. pyright, ruff linter, and ruff formatter) 65 | make type # Run pyright on all files 66 | make format # Run ruff linter and formatter on all files 67 | make test # Run tests that aren't marked `slow` 68 | make test-all # Run all tests 69 | ``` -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | """Add runslow option and skip slow tests if not specified. 2 | 3 | Taken from https://docs.pytest.org/en/latest/example/simple.html. 
4 | """ 5 | 6 | from collections.abc import Iterable 7 | 8 | import pytest 9 | from pytest import Config, Item, Parser 10 | 11 | 12 | def pytest_addoption(parser: Parser) -> None: 13 | parser.addoption("--runslow", action="store_true", default=False, help="run slow tests") 14 | 15 | 16 | def pytest_configure(config: Config) -> None: 17 | config.addinivalue_line("markers", "slow: mark test as slow to run") 18 | 19 | 20 | def pytest_collection_modifyitems(config: Config, items: Iterable[Item]) -> None: 21 | if config.getoption("--runslow"): 22 | # --runslow given in cli: do not skip slow tests 23 | return 24 | skip_slow = pytest.mark.skip(reason="need --runslow option to run") 25 | for item in items: 26 | if "slow" in item.keywords: 27 | item.add_marker(skip_slow) 28 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "spd" 3 | version = "0.0.1" 4 | description = "Sparse Parameter Decomposition" 5 | requires-python = ">=3.11" 6 | readme = "README.md" 7 | dependencies = [ 8 | "torch<2.6.0", 9 | "torchvision", 10 | "pydantic", 11 | "wandb", 12 | "fire", 13 | "tqdm", 14 | "pytest", 15 | "ipykernel", 16 | "transformers", 17 | "transformer-lens", 18 | "matplotlib==3.9.1", # Avoid frequent pyright errors with new matplotlib versions 19 | "numpy", 20 | "python-dotenv", 21 | "wandb<=0.17.7", # due to https://github.com/wandb/wandb/issues/8248 22 | "sympy", 23 | ] 24 | 25 | [project.optional-dependencies] 26 | dev = [ 27 | "ruff", 28 | "pyright", 29 | "pre-commit", 30 | ] 31 | 32 | [build-system] 33 | requires = ["setuptools", "wheel"] 34 | build-backend = "setuptools.build_meta" 35 | 36 | [tool.setuptools] 37 | packages = ["spd", "spd.models", "spd.experiments"] 38 | 39 | [tool.ruff] 40 | line-length = 100 41 | fix = true 42 | ignore = [ 43 | "F722", # Incompatible with jaxtyping 44 | "E731" # I think lambda functions are 
fine in several places 45 | ] 46 | 47 | [tool.ruff.lint] 48 | select = [ 49 | # pycodestyle 50 | "E", 51 | # Pyflakes 52 | "F", 53 | # pyupgrade 54 | "UP", 55 | # flake8-bugbear 56 | "B", 57 | # flake8-simplify 58 | "SIM", 59 | # isort 60 | "I", 61 | ] 62 | 63 | [tool.ruff.format] 64 | # Enable reformatting of code snippets in docstrings. 65 | docstring-code-format = true 66 | 67 | [tool.ruff.isort] 68 | known-third-party = ["wandb"] 69 | 70 | [tool.pyright] 71 | include = ["spd", "tests"] 72 | 73 | strictListInference = true 74 | strictDictionaryInference = true 75 | strictSetInference = true 76 | reportFunctionMemberAccess = true 77 | reportUnknownParameterType = true 78 | reportIncompatibleMethodOverride = true 79 | reportIncompatibleVariableOverride = true 80 | reportInconsistentConstructorType = true 81 | reportOverlappingOverload = true 82 | reportConstantRedefinition = true 83 | reportImportCycles = true 84 | reportPropertyTypeMismatch = true 85 | reportMissingTypeArgument = true 86 | reportUnnecessaryCast = true 87 | reportUnnecessaryComparison = true 88 | reportUnnecessaryContains = true 89 | reportUnusedExpression = true 90 | reportMatchNotExhaustive = true 91 | reportShadowedImports = true 92 | reportPrivateImportUsage = false 93 | reportCallIssue = true 94 | 95 | [tool.pytest.ini_options] 96 | filterwarnings = [ 97 | # https://github.com/google/python-fire/pull/447 98 | "ignore::DeprecationWarning:fire:59", 99 | ] -------------------------------------------------------------------------------- /spd/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ApolloResearch/apd/4e29d4375c333dff3a1046d36da726cffb25af65/spd/__init__.py -------------------------------------------------------------------------------- /spd/experiments/resid_mlp/model_interp.py: -------------------------------------------------------------------------------- 1 | # %% Imports 2 | 3 | import einops 4 | import 
matplotlib.pyplot as plt 5 | import torch 6 | 7 | from spd.experiments.resid_mlp.models import ( 8 | ResidualMLPModel, 9 | ) 10 | from spd.experiments.resid_mlp.plotting import ( 11 | calculate_virtual_weights, 12 | plot_2d_snr, 13 | plot_all_relu_curves, 14 | plot_individual_feature_response, 15 | plot_resid_vs_mlp_out, 16 | plot_single_feature_response, 17 | plot_single_relu_curve, 18 | relu_contribution_plot, 19 | ) 20 | from spd.experiments.resid_mlp.resid_mlp_dataset import ResidualMLPDataset 21 | from spd.experiments.resid_mlp.train_resid_mlp import ResidMLPTrainConfig 22 | from spd.plotting import plot_matrix 23 | from spd.settings import REPO_ROOT 24 | from spd.types import ModelPath 25 | from spd.utils import set_seed 26 | 27 | # %% Load model and config 28 | 29 | out_dir = REPO_ROOT / "spd/experiments/resid_mlp/out/figures/" 30 | out_dir.mkdir(parents=True, exist_ok=True) 31 | 32 | set_seed(0) 33 | device = "cpu" if torch.cuda.is_available() else "cpu" 34 | path: ModelPath = "wandb:spd-train-resid-mlp/runs/zas5yjdl" # 1 layer 35 | # path: ModelPath = "wandb:spd-train-resid-mlp/runs/sv23xrhj" # 2 layers 36 | model, train_config_dict, label_coeffs = ResidualMLPModel.from_pretrained(path) 37 | model = model.to(device) 38 | train_config = ResidMLPTrainConfig(**train_config_dict) 39 | dataset = ResidualMLPDataset( 40 | n_instances=train_config.resid_mlp_config.n_instances, 41 | n_features=train_config.resid_mlp_config.n_features, 42 | feature_probability=train_config.feature_probability, 43 | device=device, 44 | calc_labels=False, 45 | label_type=train_config.label_type, 46 | act_fn_name=train_config.resid_mlp_config.act_fn_name, 47 | label_fn_seed=train_config.label_fn_seed, 48 | label_coeffs=label_coeffs, 49 | data_generation_type=train_config.data_generation_type, 50 | ) 51 | if train_config.data_generation_type == "at_least_zero_active": 52 | # In the future this will be merged into generate_batch 53 | batch = 
dataset._generate_multi_feature_batch_no_zero_samples( 54 | train_config.batch_size, buffer_ratio=2 55 | ) 56 | if isinstance(dataset, ResidualMLPDataset) and dataset.label_fn is not None: 57 | labels = dataset.label_fn(batch) 58 | else: 59 | labels = batch.clone().detach() 60 | else: 61 | batch, labels = dataset.generate_batch(train_config.batch_size) 62 | 63 | n_layers = train_config.resid_mlp_config.n_layers 64 | # %% Plot feature response with one active feature 65 | fig = plot_individual_feature_response( 66 | lambda batch: model(batch), 67 | model_config=train_config.resid_mlp_config, 68 | device=device, 69 | sweep=False, 70 | plot_type="line", 71 | ) 72 | fig = plot_individual_feature_response( 73 | lambda batch: model(batch), 74 | model_config=train_config.resid_mlp_config, 75 | device=device, 76 | sweep=True, 77 | plot_type="line", 78 | ) 79 | plt.show() 80 | 81 | # %% Simple plot for paper appendix 82 | 83 | fig, axes = plt.subplots(ncols=2, figsize=(10, 5), constrained_layout=True, sharey=True) 84 | ax1, ax2 = axes # type: ignore 85 | plot_single_feature_response( 86 | lambda batch: model(batch), 87 | model_config=train_config.resid_mlp_config, 88 | device=device, 89 | subtract_inputs=False, 90 | feature_idx=42, 91 | ax=ax1, 92 | ) 93 | plot_single_relu_curve( 94 | lambda batch: model(batch), 95 | model_config=train_config.resid_mlp_config, 96 | device=device, 97 | subtract_inputs=False, 98 | feature_idx=42, 99 | ax=ax2, 100 | ) 101 | fig.savefig( 102 | out_dir / f"resid_mlp_feature_response_single_{n_layers}layers.png", 103 | bbox_inches="tight", 104 | dpi=300, 105 | ) 106 | print(f"Saved figure to {out_dir / f'resid_mlp_feature_response_single_{n_layers}layers.png'}") 107 | 108 | fig, axes = plt.subplots(ncols=2, figsize=(10, 5), constrained_layout=True, sharey=True) 109 | ax1, ax2 = axes # type: ignore 110 | plot_individual_feature_response( 111 | lambda batch: model(batch), 112 | model_config=train_config.resid_mlp_config, 113 | device=device, 114 | 
sweep=False, 115 | subtract_inputs=False, 116 | ax=ax1, 117 | cbar=False, 118 | ) 119 | ax1.set_title("Outputs one-hot inputs (coloured by input index)") 120 | plot_all_relu_curves( 121 | lambda batch: model(batch), 122 | model_config=train_config.resid_mlp_config, 123 | ax=ax2, 124 | device=device, 125 | subtract_inputs=False, 126 | ) 127 | # Colorbar 128 | cmap_viridis = plt.get_cmap("viridis") 129 | sm = plt.cm.ScalarMappable( 130 | cmap=cmap_viridis, norm=plt.Normalize(0, train_config.resid_mlp_config.n_features) 131 | ) 132 | sm.set_array([]) 133 | cbar = plt.colorbar(sm, ax=ax2, orientation="vertical") 134 | cbar.set_label("Active input feature index") 135 | 136 | ax2.plot([], [], color="red", ls="--", label=r"Label ($x+\mathrm{ReLU}(x)$)") 137 | ax2.legend(loc="upper left") 138 | 139 | fig.savefig( 140 | out_dir / f"resid_mlp_feature_response_multi_{n_layers}layers.png", bbox_inches="tight", dpi=300 141 | ) 142 | print(f"Saved figure to {out_dir / f'resid_mlp_feature_response_multi_{n_layers}layers.png'}") 143 | 144 | # %% 145 | 146 | 147 | instance_idx = 0 148 | nrows = 10 149 | fig, axs = plt.subplots(nrows=nrows, ncols=1, constrained_layout=True, figsize=(10, 3 + 4 * nrows)) 150 | fig.suptitle(f"Model {path}") 151 | for i in range(nrows): 152 | ax = axs[i] # type: ignore 153 | plot_resid_vs_mlp_out( 154 | target_model=model, device=device, ax=ax, instance_idx=instance_idx, feature_idx=i 155 | ) 156 | plt.show() 157 | 158 | 159 | # %% Show connection strength between ReLUs and features 160 | virtual_weights = calculate_virtual_weights(target_model=model, device=device) 161 | fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 5), constrained_layout=True) # type: ignore 162 | 163 | relu_contribution_plot( 164 | ax1=ax1, 165 | ax2=ax2, 166 | all_diag_relu_conns=virtual_weights["diag_relu_conns"], 167 | model=model, 168 | device=device, 169 | instance_idx=0, 170 | ) 171 | plt.show() 172 | 173 | # %% Calculate S/N ratio for 1 and 2 active features. 
174 | fig = plot_2d_snr(model, device) 175 | plt.show() 176 | 177 | # %% Plot virtual weights 178 | 179 | fig = plt.figure(constrained_layout=True, figsize=(20, 20)) 180 | gs = fig.add_gridspec(ncols=2, nrows=3) 181 | ax1 = fig.add_subplot(gs[0, 0]) 182 | ax2 = fig.add_subplot(gs[0, 1]) 183 | ax3 = fig.add_subplot(gs[1:, :]) 184 | virtual_weights = calculate_virtual_weights(target_model=model, device=device) 185 | instance_idx = 0 186 | in_conns = virtual_weights["in_conns"][instance_idx].cpu().detach() 187 | out_conns = virtual_weights["out_conns"][instance_idx].cpu().detach() 188 | W_E_W_U = einops.einsum( 189 | virtual_weights["W_E"][instance_idx], 190 | virtual_weights["W_U"][instance_idx], 191 | "n_features1 d_embed, d_embed n_features2 -> n_features1 n_features2", 192 | ) 193 | plot_matrix( 194 | ax1, 195 | in_conns.T, 196 | "Virtual input weights $(W_E W_{in})^T$", 197 | "Features", 198 | "Neurons", 199 | colorbar_format="%.2f", 200 | ) 201 | plot_matrix( 202 | ax2, 203 | out_conns, 204 | "Virtual output weights $W_{out} W_U$", 205 | "Features", 206 | "Neurons", 207 | colorbar_format="%.2f", 208 | ) 209 | ax2.xaxis.set_label_position("top") 210 | plot_matrix( 211 | ax3, 212 | W_E_W_U, 213 | "Virtual weights $W_E W_U$", 214 | "Features", 215 | "Features", 216 | colorbar_format="%.2f", 217 | ) 218 | plt.show() 219 | 220 | # %% 221 | -------------------------------------------------------------------------------- /spd/experiments/resid_mlp/models.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections.abc import Callable 3 | from pathlib import Path 4 | from typing import Any, Literal 5 | 6 | import einops 7 | import torch 8 | import torch.nn.functional as F 9 | import wandb 10 | import yaml 11 | from jaxtyping import Bool, Float 12 | from pydantic import BaseModel, ConfigDict, Field, PositiveInt 13 | from torch import Tensor, nn 14 | from wandb.apis.public import Run 15 | 16 | from spd.hooks 
import HookedRootModule 17 | from spd.log import logger 18 | from spd.models.base import SPDModel 19 | from spd.models.components import Linear, LinearComponent 20 | from spd.module_utils import init_param_ 21 | from spd.run_spd import Config, ResidualMLPTaskConfig 22 | from spd.types import WANDB_PATH_PREFIX, ModelPath 23 | from spd.utils import replace_deprecated_param_names 24 | from spd.wandb_utils import download_wandb_file, fetch_latest_wandb_checkpoint, fetch_wandb_run_dir 25 | 26 | 27 | class MLP(nn.Module): 28 | """An MLP with an optional n_instances dimension.""" 29 | 30 | def __init__( 31 | self, 32 | d_model: int, 33 | d_mlp: int, 34 | act_fn: Callable[[Tensor], Tensor], 35 | in_bias: bool, 36 | out_bias: bool, 37 | init_scale: float, 38 | init_type: Literal["kaiming_uniform", "xavier_normal"] = "kaiming_uniform", 39 | n_instances: int | None = None, 40 | spd_kwargs: dict[str, Any] | None = None, 41 | ): 42 | super().__init__() 43 | self.n_instances = n_instances 44 | self.d_model = d_model 45 | self.d_mlp = d_mlp 46 | self.act_fn = act_fn 47 | 48 | if spd_kwargs: 49 | self.mlp_in = LinearComponent( 50 | d_in=d_model, 51 | d_out=d_mlp, 52 | n_instances=n_instances, 53 | init_type=init_type, 54 | init_scale=init_scale, 55 | C=spd_kwargs["C"], 56 | m=spd_kwargs["m"], 57 | ) 58 | self.mlp_out = LinearComponent( 59 | d_in=d_mlp, 60 | d_out=d_model, 61 | n_instances=n_instances, 62 | init_type=init_type, 63 | init_scale=init_scale, 64 | C=spd_kwargs["C"], 65 | m=spd_kwargs["m"], 66 | ) 67 | else: 68 | self.mlp_in = Linear( 69 | d_in=d_model, 70 | d_out=d_mlp, 71 | n_instances=n_instances, 72 | init_type=init_type, 73 | init_scale=init_scale, 74 | ) 75 | self.mlp_out = Linear( 76 | d_in=d_mlp, 77 | d_out=d_model, 78 | n_instances=n_instances, 79 | init_type=init_type, 80 | init_scale=init_scale, 81 | ) 82 | 83 | self.bias1 = None 84 | self.bias2 = None 85 | if in_bias: 86 | shape = (n_instances, d_mlp) if n_instances is not None else d_mlp 87 | self.bias1 = 
nn.Parameter(torch.zeros(shape)) 88 | if out_bias: 89 | shape = (n_instances, d_model) if n_instances is not None else d_model 90 | self.bias2 = nn.Parameter(torch.zeros(shape)) 91 | 92 | def forward( 93 | self, 94 | x: Float[Tensor, "batch ... d_model"], 95 | topk_mask: Bool[Tensor, "batch ... C"] | None = None, 96 | ) -> tuple[Float[Tensor, "batch ... d_model"],]: 97 | """Run a forward pass and cache pre and post activations for each parameter. 98 | 99 | Note that we don't need to cache pre activations for the biases. We also don't care about 100 | the output bias which is always zero. 101 | """ 102 | mid_pre_act_fn = self.mlp_in(x, topk_mask=topk_mask) 103 | if self.bias1 is not None: 104 | mid_pre_act_fn = mid_pre_act_fn + self.bias1 105 | mid = self.act_fn(mid_pre_act_fn) 106 | out = self.mlp_out(mid, topk_mask=topk_mask) 107 | if self.bias2 is not None: 108 | out = out + self.bias2 109 | return out 110 | 111 | 112 | class ResidualMLPPaths(BaseModel): 113 | """Paths to output files from a ResidualMLPModel training run.""" 114 | 115 | resid_mlp_train_config: Path 116 | label_coeffs: Path 117 | checkpoint: Path 118 | 119 | 120 | class ResidualMLPConfig(BaseModel): 121 | model_config = ConfigDict(extra="forbid", frozen=True) 122 | n_instances: PositiveInt 123 | n_features: PositiveInt 124 | d_embed: PositiveInt 125 | d_mlp: PositiveInt 126 | n_layers: PositiveInt 127 | act_fn_name: Literal["gelu", "relu"] = Field( 128 | description="Defines the activation function in the model. Also used in the labeling " 129 | "function if label_type is act_plus_resid." 
130 | ) 131 | apply_output_act_fn: bool 132 | in_bias: bool 133 | out_bias: bool 134 | init_scale: float = 1.0 135 | 136 | 137 | class ResidualMLPModel(HookedRootModule): 138 | def __init__(self, config: ResidualMLPConfig): 139 | super().__init__() 140 | self.config = config 141 | self.W_E = nn.Parameter(torch.empty(config.n_instances, config.n_features, config.d_embed)) 142 | init_param_(self.W_E, scale=config.init_scale) 143 | self.W_U = nn.Parameter(torch.empty(config.n_instances, config.d_embed, config.n_features)) 144 | init_param_(self.W_U, scale=config.init_scale) 145 | 146 | assert config.act_fn_name in ["gelu", "relu"] 147 | self.act_fn = F.gelu if config.act_fn_name == "gelu" else F.relu 148 | self.layers = nn.ModuleList( 149 | [ 150 | MLP( 151 | n_instances=config.n_instances, 152 | d_model=config.d_embed, 153 | d_mlp=config.d_mlp, 154 | act_fn=self.act_fn, 155 | in_bias=config.in_bias, 156 | out_bias=config.out_bias, 157 | init_scale=config.init_scale, 158 | ) 159 | for _ in range(config.n_layers) 160 | ] 161 | ) 162 | self.setup() 163 | 164 | def forward( 165 | self, 166 | x: Float[Tensor, "batch n_instances n_features"], 167 | return_residual: bool = False, 168 | ) -> Float[Tensor, "batch n_instances n_features"] | Float[Tensor, "batch n_instances d_embed"]: 169 | # Make sure that n_instances are correct to avoid unintended broadcasting 170 | assert x.shape[1] == self.config.n_instances, "n_instances mismatch" 171 | assert x.shape[2] == self.config.n_features, "n_features mismatch" 172 | residual = einops.einsum( 173 | x, 174 | self.W_E, 175 | "batch n_instances n_features, n_instances n_features d_embed -> batch n_instances d_embed", 176 | ) 177 | for layer in self.layers: 178 | out = layer(residual) 179 | residual = residual + out 180 | out = einops.einsum( 181 | residual, 182 | self.W_U, 183 | "batch n_instances d_embed, n_instances d_embed n_features -> batch n_instances n_features", 184 | ) 185 | if self.config.apply_output_act_fn: 186 | out = 
self.act_fn(out) 187 | return residual if return_residual else out 188 | 189 | @staticmethod 190 | def _download_wandb_files(wandb_project_run_id: str) -> ResidualMLPPaths: 191 | """Download the relevant files from a wandb run.""" 192 | api = wandb.Api() 193 | run: Run = api.run(wandb_project_run_id) 194 | 195 | checkpoint = fetch_latest_wandb_checkpoint(run) 196 | 197 | run_dir = fetch_wandb_run_dir(run.id) 198 | 199 | resid_mlp_train_config_path = download_wandb_file( 200 | run, run_dir, "resid_mlp_train_config.yaml" 201 | ) 202 | label_coeffs_path = download_wandb_file(run, run_dir, "label_coeffs.json") 203 | checkpoint_path = download_wandb_file(run, run_dir, checkpoint.name) 204 | logger.info(f"Downloaded checkpoint from {checkpoint_path}") 205 | return ResidualMLPPaths( 206 | resid_mlp_train_config=resid_mlp_train_config_path, 207 | label_coeffs=label_coeffs_path, 208 | checkpoint=checkpoint_path, 209 | ) 210 | 211 | @classmethod 212 | def from_pretrained( 213 | cls, path: ModelPath 214 | ) -> tuple["ResidualMLPModel", dict[str, Any], Float[Tensor, "n_instances n_features"]]: 215 | """Fetch a pretrained model from wandb or a local path to a checkpoint. 216 | 217 | Args: 218 | path: The path to local checkpoint or wandb project. If a wandb project, format must be 219 | `wandb://` or `wandb://runs/`. 220 | If `api.entity` is set (e.g. via setting WANDB_ENTITY in .env), can be 221 | omitted, and if `api.project` is set, can be omitted. If local path, 222 | assumes that `resid_mlp_train_config.yaml` and `label_coeffs.json` are in the same 223 | directory as the checkpoint. 
224 | 225 | Returns: 226 | model: The pretrained ResidualMLPModel 227 | resid_mlp_train_config_dict: The config dict used to train the model (we don't 228 | instantiate a train config due to circular import issues) 229 | label_coeffs: The label coefficients used to train the model 230 | """ 231 | if isinstance(path, str) and path.startswith(WANDB_PATH_PREFIX): 232 | wandb_path = path.removeprefix(WANDB_PATH_PREFIX) 233 | paths = cls._download_wandb_files(wandb_path) 234 | else: 235 | # `path` should be a local path to a checkpoint 236 | paths = ResidualMLPPaths( 237 | resid_mlp_train_config=Path(path).parent / "resid_mlp_train_config.yaml", 238 | label_coeffs=Path(path).parent / "label_coeffs.json", 239 | checkpoint=Path(path), 240 | ) 241 | 242 | with open(paths.resid_mlp_train_config) as f: 243 | resid_mlp_train_config_dict = yaml.safe_load(f) 244 | 245 | with open(paths.label_coeffs) as f: 246 | label_coeffs = torch.tensor(json.load(f)) 247 | 248 | resid_mlp_config = ResidualMLPConfig(**resid_mlp_train_config_dict["resid_mlp_config"]) 249 | resid_mlp = cls(resid_mlp_config) 250 | params = torch.load(paths.checkpoint, weights_only=True, map_location="cpu") 251 | 252 | params = replace_deprecated_param_names( 253 | params, 254 | name_map={"linear1": "mlp_in.weight", "linear2": "mlp_out.weight"}, 255 | ) 256 | resid_mlp.load_state_dict(params) 257 | 258 | return resid_mlp, resid_mlp_train_config_dict, label_coeffs 259 | 260 | 261 | class ResidualMLPSPDPaths(BaseModel): 262 | """Paths to output files from a ResidualMLPSPDModel training run.""" 263 | 264 | final_config: Path 265 | resid_mlp_train_config: Path 266 | label_coeffs: Path 267 | checkpoint: Path 268 | 269 | 270 | class ResidualMLPSPDConfig(BaseModel): 271 | model_config = ConfigDict(extra="forbid", frozen=True) 272 | n_instances: PositiveInt 273 | n_features: PositiveInt 274 | d_embed: PositiveInt 275 | d_mlp: PositiveInt 276 | n_layers: PositiveInt 277 | act_fn_name: Literal["gelu", "relu"] 278 | 
apply_output_act_fn: bool 279 | in_bias: bool 280 | out_bias: bool 281 | init_scale: float 282 | C: PositiveInt 283 | m: PositiveInt | None = None 284 | init_type: Literal["kaiming_uniform", "xavier_normal"] = "xavier_normal" 285 | 286 | 287 | class ResidualMLPSPDModel(SPDModel): 288 | def __init__( 289 | self, 290 | config: ResidualMLPSPDConfig, 291 | ): 292 | super().__init__() 293 | self.config = config 294 | self.n_features = config.n_features # Required for backward compatibility 295 | self.n_instances = config.n_instances # Required for backward compatibility 296 | self.C = config.C # Required for backward compatibility 297 | 298 | assert config.act_fn_name in ["gelu", "relu"] 299 | self.act_fn = F.gelu if config.act_fn_name == "gelu" else F.relu 300 | 301 | self.W_E = nn.Parameter(torch.empty(config.n_instances, config.n_features, config.d_embed)) 302 | self.W_U = nn.Parameter(torch.empty(config.n_instances, config.d_embed, config.n_features)) 303 | init_param_(self.W_E, init_type=config.init_type) 304 | init_param_(self.W_U, init_type=config.init_type) 305 | 306 | self.m = min(config.d_embed, config.d_mlp) if config.m is None else config.m 307 | 308 | self.layers = nn.ModuleList( 309 | [ 310 | MLP( 311 | n_instances=config.n_instances, 312 | d_model=config.d_embed, 313 | d_mlp=config.d_mlp, 314 | init_type=config.init_type, 315 | init_scale=config.init_scale, 316 | in_bias=config.in_bias, 317 | out_bias=config.out_bias, 318 | act_fn=self.act_fn, 319 | spd_kwargs={"C": config.C, "m": self.m}, 320 | ) 321 | for _ in range(config.n_layers) 322 | ] 323 | ) 324 | self.setup() 325 | 326 | def forward( 327 | self, 328 | x: Float[Tensor, "batch n_instances n_features"], 329 | topk_mask: Bool[Tensor, "batch n_instances C"] | None = None, 330 | ) -> Float[Tensor, "batch n_instances d_embed"]: 331 | """ 332 | Returns: 333 | x: The output of the model 334 | """ 335 | residual = einops.einsum( 336 | x, 337 | self.W_E, 338 | "batch n_instances n_features, n_instances 
n_features d_embed -> batch n_instances d_embed", 339 | ) 340 | for layer in self.layers: 341 | residual = residual + layer(residual, topk_mask) 342 | out = einops.einsum( 343 | residual, 344 | self.W_U, 345 | "batch n_instances d_embed, n_instances d_embed n_features -> batch n_instances n_features", 346 | ) 347 | if self.config.apply_output_act_fn: 348 | out = self.act_fn(out) 349 | return out 350 | 351 | @staticmethod 352 | def _download_wandb_files(wandb_project_run_id: str) -> ResidualMLPSPDPaths: 353 | """Download the relevant files from a wandb run.""" 354 | api = wandb.Api() 355 | run: Run = api.run(wandb_project_run_id) 356 | 357 | checkpoint = fetch_latest_wandb_checkpoint(run, prefix="spd_model") 358 | 359 | run_dir = fetch_wandb_run_dir(run.id) 360 | 361 | final_config_path = download_wandb_file(run, run_dir, "final_config.yaml") 362 | resid_mlp_train_config_path = download_wandb_file( 363 | run, run_dir, "resid_mlp_train_config.yaml" 364 | ) 365 | label_coeffs_path = download_wandb_file(run, run_dir, "label_coeffs.json") 366 | checkpoint_path = download_wandb_file(run, run_dir, checkpoint.name) 367 | logger.info(f"Downloaded checkpoint from {checkpoint_path}") 368 | return ResidualMLPSPDPaths( 369 | final_config=final_config_path, 370 | resid_mlp_train_config=resid_mlp_train_config_path, 371 | label_coeffs=label_coeffs_path, 372 | checkpoint=checkpoint_path, 373 | ) 374 | 375 | @classmethod 376 | def from_pretrained( 377 | cls, path: str | Path 378 | ) -> tuple["ResidualMLPSPDModel", Config, Float[Tensor, "n_instances n_features"]]: 379 | """Fetch a pretrained model from wandb or a local path to a checkpoint. 380 | 381 | Args: 382 | path: The path to local checkpoint or wandb project. If a wandb project, the format 383 | must be `wandb:entity/project/run_id`. If `api.entity` is set (e.g. 
via setting 384 | WANDB_ENTITY in .env), this can be in the form `wandb:project/run_id` and if 385 | form `wandb:project/run_id` and if `api.project` is set this can just be 386 | `wandb:run_id`. If local path, assumes that `resid_mlp_train_config.yaml` and 387 | `label_coeffs.json` are in the same directory as the checkpoint. 388 | """ 389 | if isinstance(path, str) and path.startswith(WANDB_PATH_PREFIX): 390 | wandb_path = path.removeprefix(WANDB_PATH_PREFIX) 391 | paths = cls._download_wandb_files(wandb_path) 392 | else: 393 | paths = ResidualMLPSPDPaths( 394 | final_config=Path(path).parent / "final_config.yaml", 395 | resid_mlp_train_config=Path(path).parent / "resid_mlp_train_config.yaml", 396 | label_coeffs=Path(path).parent / "label_coeffs.json", 397 | checkpoint=Path(path), 398 | ) 399 | 400 | with open(paths.final_config) as f: 401 | final_config_dict = yaml.safe_load(f) 402 | config = Config(**final_config_dict) 403 | 404 | with open(paths.resid_mlp_train_config) as f: 405 | resid_mlp_train_config_dict = yaml.safe_load(f) 406 | 407 | with open(paths.label_coeffs) as f: 408 | label_coeffs = torch.tensor(json.load(f)) 409 | 410 | assert isinstance(config.task_config, ResidualMLPTaskConfig) 411 | resid_mlp_spd_config = ResidualMLPSPDConfig( 412 | **resid_mlp_train_config_dict["resid_mlp_config"], C=config.C, m=config.m 413 | ) 414 | model = cls(config=resid_mlp_spd_config) 415 | params = torch.load(paths.checkpoint, weights_only=True, map_location="cpu") 416 | 417 | params = replace_deprecated_param_names( 418 | params, name_map={"linear1": "mlp_in", "linear2": "mlp_out"} 419 | ) 420 | 421 | model.load_state_dict(params) 422 | return model, config, label_coeffs 423 | -------------------------------------------------------------------------------- /spd/experiments/resid_mlp/resid_mlp_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | import einops 4 | import torch 5 | import 
torch.nn.functional as F 6 | from jaxtyping import Float 7 | from torch import Tensor 8 | 9 | from spd.utils import SparseFeatureDataset 10 | 11 | 12 | class ResidualMLPDataset(SparseFeatureDataset): 13 | def __init__( 14 | self, 15 | n_instances: int, 16 | n_features: int, 17 | feature_probability: float, 18 | device: str, 19 | calc_labels: bool = True, # If False, just return the inputs as labels 20 | label_type: Literal["act_plus_resid", "abs"] | None = None, 21 | act_fn_name: Literal["relu", "gelu"] | None = None, 22 | label_fn_seed: int | None = None, 23 | label_coeffs: Float[Tensor, "n_instances n_features"] | None = None, 24 | data_generation_type: Literal[ 25 | "exactly_one_active", "exactly_two_active", "at_least_zero_active" 26 | ] = "at_least_zero_active", 27 | synced_inputs: list[list[int]] | None = None, 28 | ): 29 | """Sparse feature dataset for use in training a resid_mlp model or running SPD on it. 30 | 31 | If calc_labels is True, labels are of the form `act_fn(coeffs*x) + x` or `abs(coeffs*x)`, 32 | depending on label_type, act_fn_name, and label_fn_seed. 33 | 34 | Otherwise, the labels are the same as the inputs. 35 | 36 | Args: 37 | n_instances: The number of instances in the model and dataset. 38 | n_features: The number of features in the model and dataset. 39 | feature_probability: The probability that a feature is active in a given instance. 40 | device: The device to calculate and store the data on. 41 | calc_labels: Whether to calculate labels. If False, labels are the same as the inputs. 42 | label_type: The type of labels to calculate. Ignored if calc_labels is False. 43 | act_fn_name: Used for labels if calc_labels is True and label_type is act_plus_resid. 44 | Ignored if calc_labels is False. 45 | label_fn_seed: The seed to use for generating the label coefficients. Ignored if 46 | calc_labels is False or label_coeffs is not None. 47 | label_coeffs: The label coefficients to use. 
If None, the coefficients are generated 48 | randomly (unless calc_labels is False). 49 | data_generation_type: The number of active features in each sample. 50 | synced_inputs: The indices of the inputs to sync. 51 | """ 52 | super().__init__( 53 | n_instances=n_instances, 54 | n_features=n_features, 55 | feature_probability=feature_probability, 56 | device=device, 57 | data_generation_type=data_generation_type, 58 | value_range=(-1.0, 1.0), 59 | synced_inputs=synced_inputs, 60 | ) 61 | 62 | self.label_fn = None 63 | self.label_coeffs = None 64 | 65 | if calc_labels: 66 | self.label_coeffs = ( 67 | self.calc_label_coeffs(label_fn_seed) if label_coeffs is None else label_coeffs 68 | ).to(self.device) 69 | 70 | assert label_type is not None, "Must provide label_type if calc_labels is True" 71 | if label_type == "act_plus_resid": 72 | assert act_fn_name in ["relu", "gelu"], "act_fn_name must be 'relu' or 'gelu'" 73 | self.label_fn = lambda batch: self.calc_act_plus_resid_labels( 74 | batch=batch, act_fn_name=act_fn_name 75 | ) 76 | elif label_type == "abs": 77 | self.label_fn = lambda batch: self.calc_abs_labels(batch) 78 | 79 | def generate_batch( 80 | self, batch_size: int 81 | ) -> tuple[ 82 | Float[Tensor, "batch n_instances n_features"], Float[Tensor, "batch n_instances n_features"] 83 | ]: 84 | # Note that the parent_labels are just the batch itself 85 | batch, parent_labels = super().generate_batch(batch_size) 86 | labels = self.label_fn(batch) if self.label_fn is not None else parent_labels 87 | return batch, labels 88 | 89 | def calc_act_plus_resid_labels( 90 | self, 91 | batch: Float[Tensor, "batch n_instances n_functions"], 92 | act_fn_name: Literal["relu", "gelu"], 93 | ) -> Float[Tensor, "batch n_instances n_functions"]: 94 | """Calculate the corresponding labels for the batch using `act_fn(coeffs*x) + x`.""" 95 | assert self.label_coeffs is not None 96 | weighted_inputs = einops.einsum( 97 | batch, 98 | self.label_coeffs, 99 | "batch n_instances 
n_functions, n_instances n_functions -> batch n_instances n_functions", 100 | ) 101 | assert act_fn_name in ["relu", "gelu"], "act_fn_name must be 'relu' or 'gelu'" 102 | act_fn = F.relu if act_fn_name == "relu" else F.gelu 103 | labels = act_fn(weighted_inputs) + batch 104 | return labels 105 | 106 | def calc_abs_labels( 107 | self, batch: Float[Tensor, "batch n_instances n_features"] 108 | ) -> Float[Tensor, "batch n_instances n_features"]: 109 | assert self.label_coeffs is not None 110 | weighted_inputs = einops.einsum( 111 | batch, 112 | self.label_coeffs, 113 | "batch n_instances n_functions, n_instances n_functions -> batch n_instances n_functions", 114 | ) 115 | return torch.abs(weighted_inputs) 116 | 117 | def calc_label_coeffs( 118 | self, label_fn_seed: int | None = None 119 | ) -> Float[Tensor, "n_instances n_features"]: 120 | """Create random coeffs between [1, 2] using label_fn_seed if provided.""" 121 | gen = torch.Generator(device=self.device) 122 | if label_fn_seed is not None: 123 | gen.manual_seed(label_fn_seed) 124 | return torch.rand(self.n_instances, self.n_features, generator=gen, device=self.device) + 1 125 | -------------------------------------------------------------------------------- /spd/experiments/resid_mlp/resid_mlp_sweep_config.yaml: -------------------------------------------------------------------------------- 1 | program: spd/experiments/resid_mlp/resid_mlp_decomposition.py 2 | method: grid 3 | metric: 4 | name: total_loss 5 | goal: minimize 6 | parameters: 7 | seed: 8 | values: [0] 9 | lr: 10 | values: [1e-2] 11 | topk_recon_coeff: 12 | values: [1e-1, 1e-2] 13 | 14 | command: 15 | - ${env} 16 | - ${interpreter} 17 | - ${program} 18 | - spd/experiments/resid_mlp/resid_mlp_topk_config.yaml -------------------------------------------------------------------------------- /spd/experiments/resid_mlp/resid_mlp_topk_config.yaml: -------------------------------------------------------------------------------- 1 | # ########## 1 layer 
########## 2 | wandb_project: spd-resid-mlp 3 | wandb_run_name: null 4 | wandb_run_name_prefix: "" 5 | unit_norm_matrices: true 6 | seed: 0 7 | # topk: 1 8 | topk: 1.28 9 | m: null 10 | C: 130 11 | pnorm: null 12 | batch_topk: true 13 | param_match_coeff: 1.0 14 | topk_recon_coeff: 1.0 15 | act_recon_coeff: 1.0 16 | schatten_pnorm: 0.9 17 | schatten_coeff: 1e1 18 | lr: 1e-3 19 | batch_size: 256 20 | steps: 10_000 21 | print_freq: 500 22 | image_freq: 5_000 23 | save_freq: 10_000 24 | lr_warmup_pct: 0.01 25 | lr_schedule: cosine 26 | image_on_first_step: false 27 | task_config: 28 | task_name: residual_mlp 29 | init_scale: 2.0 30 | feature_probability: 0.01 31 | data_generation_type: "at_least_zero_active" 32 | pretrained_model_path: wandb:spd-train-resid-mlp/runs/zas5yjdl # 1 layer 33 | 34 | 35 | ########## 2 layer ########## 36 | # wandb_project: spd-resid-mlp 37 | # wandb_run_name: null 38 | # wandb_run_name_prefix: "" 39 | # unit_norm_matrices: false 40 | # seed: 0 41 | # topk: 1.28 # bs=256 42 | # m: null 43 | # C: 200 44 | # pnorm: null 45 | # batch_topk: true 46 | # param_match_coeff: 1.0 47 | # topk_recon_coeff: 2.0 48 | # act_recon_coeff: 1.0 49 | # schatten_pnorm: 0.9 50 | # schatten_coeff: 7 51 | # lr: 1e-3 52 | # batch_size: 256 53 | # steps: 10_000 54 | # print_freq: 500 55 | # image_freq: 10_000 56 | # save_freq: 10_000 57 | # lr_warmup_pct: 0.01 58 | # lr_schedule: cosine 59 | # image_on_first_step: false 60 | # task_config: 61 | # task_name: residual_mlp 62 | # init_scale: 2.0 63 | # feature_probability: 0.01 64 | # data_generation_type: "at_least_zero_active" 65 | # pretrained_model_path: wandb:spd-train-resid-mlp/runs/sv23xrhj # 2 layer 66 | -------------------------------------------------------------------------------- /spd/experiments/resid_mlp/spd_interp.py: -------------------------------------------------------------------------------- 1 | # %% 2 | from pathlib import Path 3 | 4 | import matplotlib.pyplot as plt 5 | import torch 6 | from 
jaxtyping import Float 7 | from pydantic import PositiveFloat 8 | from torch import Tensor 9 | 10 | from spd.experiments.resid_mlp.models import ResidualMLPModel, ResidualMLPSPDModel 11 | from spd.experiments.resid_mlp.plotting import ( 12 | analyze_per_feature_performance, 13 | collect_average_components_per_feature, 14 | collect_per_feature_losses, 15 | get_feature_subnet_map, 16 | get_scrubbed_losses, 17 | plot_avg_components_scatter, 18 | plot_feature_response_with_subnets, 19 | plot_per_feature_performance_fig, 20 | plot_scrub_losses, 21 | plot_spd_feature_contributions_truncated, 22 | ) 23 | from spd.experiments.resid_mlp.resid_mlp_dataset import ResidualMLPDataset 24 | from spd.experiments.resid_mlp.resid_mlp_decomposition import plot_subnet_categories 25 | from spd.run_spd import ResidualMLPTaskConfig 26 | from spd.settings import REPO_ROOT 27 | from spd.utils import ( 28 | COLOR_PALETTE, 29 | SPDOutputs, 30 | run_spd_forward_pass, 31 | set_seed, 32 | ) 33 | 34 | color_map = { 35 | "target": COLOR_PALETTE[0], 36 | "apd_topk": COLOR_PALETTE[1], 37 | "apd_scrubbed": COLOR_PALETTE[4], 38 | "apd_antiscrubbed": COLOR_PALETTE[2], # alt: 3 39 | "baseline_monosemantic": "grey", 40 | } 41 | 42 | out_dir = REPO_ROOT / "spd/experiments/resid_mlp/out/figures/" 43 | out_dir.mkdir(parents=True, exist_ok=True) 44 | 45 | # %% Loading 46 | device = "cuda" if torch.cuda.is_available() else "cpu" 47 | print(f"Using device: {device}") 48 | set_seed(0) # You can change this seed if needed 49 | 50 | use_data_from_files = True 51 | wandb_path = "wandb:spd-resid-mlp/runs/8qz1si1l" # 1 layer 40k steps (R6) topk=1.28 52 | # wandb_path = "wandb:spd-resid-mlp/runs/9a639c6w" # 1 layer topk=1 53 | # wandb_path = "wandb:spd-resid-mlp/runs/cb0ej7hj" # 2 layer 2LR4 topk=1.28 54 | # wandb_path = "wandb:spd-resid-mlp/runs/wbeghftm" # 2 layer topk=1 55 | # wandb_path = "wandb:spd-resid-mlp/runs/c1q3bs6f" # 2 layer m=1 topk=1.28 (not in paper) 56 | 57 | wandb_id = wandb_path.split("/")[-1] 58 
| 59 | # Load the pretrained SPD model 60 | model, config, label_coeffs = ResidualMLPSPDModel.from_pretrained(wandb_path) 61 | assert isinstance(config.task_config, ResidualMLPTaskConfig) 62 | 63 | # Path must be local 64 | target_model, target_model_train_config_dict, target_label_coeffs = ( 65 | ResidualMLPModel.from_pretrained(config.task_config.pretrained_model_path) 66 | ) 67 | # Print some basic information about the model 68 | print(f"Number of features: {model.config.n_features}") 69 | print(f"Feature probability: {config.task_config.feature_probability}") 70 | print(f"Embedding dimension: {model.config.d_embed}") 71 | print(f"MLP dimension: {model.config.d_mlp}") 72 | print(f"Number of layers: {model.config.n_layers}") 73 | print(f"Number of subnetworks (C): {model.config.C}") 74 | model = model.to(device) 75 | label_coeffs = label_coeffs.to(device) 76 | target_model = target_model.to(device) 77 | target_label_coeffs = target_label_coeffs.to(device) 78 | assert torch.allclose(target_label_coeffs, label_coeffs) 79 | 80 | n_layers = target_model.config.n_layers 81 | 82 | 83 | # Functions used for various plots 84 | def spd_model_fn( 85 | batch: Float[Tensor, "batch n_instances n_features"], 86 | topk: PositiveFloat | None = config.topk, 87 | batch_topk: bool = config.batch_topk, 88 | ) -> SPDOutputs: 89 | assert topk is not None 90 | return run_spd_forward_pass( 91 | spd_model=model, 92 | target_model=target_model, 93 | input_array=batch, 94 | attribution_type=config.attribution_type, 95 | batch_topk=batch_topk, 96 | topk=topk, 97 | distil_from_target=config.distil_from_target, 98 | ) 99 | 100 | 101 | def target_model_fn(batch: Float[Tensor, "batch n_instances"]): 102 | return target_model(batch) 103 | 104 | 105 | def top1_model_fn( 106 | batch: Float[Tensor, "batch n_instances n_features"], 107 | topk_mask: Float[Tensor, "batch n_instances C"] | None, 108 | ) -> SPDOutputs: 109 | """Top1 if topk_mask is None, else just use provided topk_mask""" 110 | 
topk_mask = topk_mask.to(device) if topk_mask is not None else None 111 | assert config.topk is not None 112 | return run_spd_forward_pass( 113 | spd_model=model, 114 | target_model=target_model, 115 | input_array=batch, 116 | attribution_type=config.attribution_type, 117 | batch_topk=False, 118 | topk=1, 119 | distil_from_target=config.distil_from_target, 120 | topk_mask=topk_mask, 121 | ) 122 | 123 | 124 | dataset = ResidualMLPDataset( 125 | n_instances=model.config.n_instances, 126 | n_features=model.config.n_features, 127 | feature_probability=config.task_config.feature_probability, 128 | device=device, 129 | calc_labels=True, 130 | label_type=target_model_train_config_dict["label_type"], 131 | act_fn_name=target_model.config.act_fn_name, 132 | label_coeffs=target_label_coeffs, 133 | data_generation_type="at_least_zero_active", # We will change this in the for loop 134 | ) 135 | 136 | # %% Plot how many subnets are monosemantic, etc. 137 | fig = plot_subnet_categories(model, device, cutoff=4e-2) 138 | # Save the figure 139 | fig.savefig(out_dir / f"resid_mlp_subnet_categories_{n_layers}layers_{wandb_id}.png") 140 | print(f"Saved figure to {out_dir / f'resid_mlp_subnet_categories_{n_layers}layers_{wandb_id}.png'}") 141 | 142 | 143 | # %% 144 | per_feature_losses_path = Path(out_dir) / f"resid_mlp_losses_{n_layers}layers_{wandb_id}.pt" 145 | if not use_data_from_files or not per_feature_losses_path.exists(): 146 | loss_target, loss_spd_batch_topk, loss_spd_sample_topk = collect_per_feature_losses( 147 | target_model=target_model, 148 | spd_model=model, 149 | config=config, 150 | dataset=dataset, 151 | device=device, 152 | batch_size=config.batch_size, 153 | n_samples=100_000, 154 | ) 155 | # Save the losses to a file 156 | torch.save( 157 | (loss_target, loss_spd_batch_topk, loss_spd_sample_topk), 158 | per_feature_losses_path, 159 | ) 160 | 161 | # Load the losses from a file 162 | loss_target, loss_spd_batch_topk, loss_spd_sample_topk = torch.load( 163 | 
per_feature_losses_path, weights_only=True, map_location="cpu" 164 | ) 165 | 166 | fig = plot_per_feature_performance_fig( 167 | loss_target=loss_target, 168 | loss_spd_batch_topk=loss_spd_batch_topk, 169 | loss_spd_sample_topk=loss_spd_sample_topk, 170 | config=config, 171 | color_map=color_map, 172 | ) 173 | fig.show() 174 | fig.savefig(out_dir / f"resid_mlp_per_feature_performance_{n_layers}layers_{wandb_id}.png") 175 | print( 176 | f"Saved figure to {out_dir / f'resid_mlp_per_feature_performance_{n_layers}layers_{wandb_id}.png'}" 177 | ) 178 | 179 | # %% 180 | # Scatter plot of avg active components vs loss difference 181 | avg_components_path = Path(out_dir) / f"avg_components_{n_layers}layers_{wandb_id}.pt" 182 | if not use_data_from_files or not avg_components_path.exists(): # Recompute unless cached data is both requested and present 183 | avg_components = collect_average_components_per_feature( 184 | model_fn=spd_model_fn, 185 | dataset=dataset, 186 | device=device, 187 | n_features=model.config.n_features, 188 | batch_size=config.batch_size, 189 | n_samples=500_000, 190 | ) 191 | # Save the avg_components to a file 192 | torch.save(avg_components.cpu(), avg_components_path) 193 | 194 | # Load the avg_components from a file (both branches end with avg_components loaded onto `device`) 195 | avg_components = torch.load(avg_components_path, map_location=device, weights_only=True) 196 | 197 | # Get the loss of the spd model w.r.t the target model 198 | fn_without_batch_topk = lambda batch: spd_model_fn( # NOTE(review): presumably per-sample (not batch) top-1 masking — confirm in spd_model_fn 199 | batch, topk=1, batch_topk=False 200 | ).spd_topk_model_output # type: ignore 201 | losses_spd_wrt_target = analyze_per_feature_performance( 202 | model_fn=fn_without_batch_topk, 203 | target_model_fn=target_model_fn, 204 | model_config=model.config, 205 | device=device, 206 | batch_size=config.batch_size, 207 | ) 208 | 209 | fig = plot_avg_components_scatter( 210 | losses_spd_wrt_target=losses_spd_wrt_target, avg_components=avg_components 211 | ) 212 | fig.show() 213 | # Save the figure 214 | fig.savefig(out_dir /
f"resid_mlp_avg_components_scatter_{n_layers}layers_{wandb_id}.png") 215 | print( 216 | f"Saved figure to {out_dir / f'resid_mlp_avg_components_scatter_{n_layers}layers_{wandb_id}.png'}" 217 | ) 218 | 219 | # %% 220 | # Plot the main truncated feature contributions figure for the paper 221 | fig = plot_spd_feature_contributions_truncated( 222 | spd_model=model, 223 | target_model=target_model, 224 | device=device, 225 | n_features=10, # presumably caps how many features are drawn in the truncated figure 226 | include_crossterms=False, 227 | ) 228 | fig.savefig(out_dir / f"resid_mlp_weights_{n_layers}layers_{wandb_id}.png") 229 | print(f"Saved figure to {out_dir / f'resid_mlp_weights_{n_layers}layers_{wandb_id}.png'}") 230 | 231 | # Full (untruncated) figure for updating the wandb report; disabled by default 232 | # fig = plot_spd_feature_contributions( 233 | # spd_model=model, 234 | # target_model=target_model, 235 | # device=device, 236 | # ) 237 | # fig.savefig(out_dir / f"resid_mlp_weights_full_{n_layers}layers_{wandb_id}.png") 238 | # plt.close(fig) 239 | # print(f"Saved figure to {out_dir / f'resid_mlp_weights_full_{n_layers}layers_{wandb_id}.png'}") 240 | # import wandb 241 | 242 | # # Restart the run and log the figure 243 | # run = wandb.init(project="spd-resid-mlp", id=wandb_id, resume="must") 244 | # run.log({"neuron_contributions": wandb.Image(fig)}) 245 | # run.finish() 246 | 247 | # %% 248 | # Plot causal scrubbing-esque test 249 | n_batches = 100 # Passed to both the scrubbed-loss computation and the plot 250 | losses = get_scrubbed_losses( 251 | top1_model_fn=top1_model_fn, 252 | spd_model_fn=spd_model_fn, 253 | target_model=target_model, 254 | dataset=dataset, 255 | model=model, 256 | device=device, 257 | config=config, 258 | n_batches=n_batches, 259 | ) 260 | 261 | fig = plot_scrub_losses(losses, config, color_map, n_batches) 262 | fig.savefig( 263 | out_dir / f"resid_mlp_scrub_hist_{n_layers}layers_{wandb_id}.png", bbox_inches="tight", dpi=300 264 | ) 265 | print(f"Saved figure to {out_dir / f'resid_mlp_scrub_hist_{n_layers}layers_{wandb_id}.png'}") 266 | 267 | # %% Linearity test: Enable one subnet after the
other 268 | # candlestick plot 269 | 270 | # Mapping feature_idx -> subnet_idx 271 | subnet_indices = get_feature_subnet_map(top1_model_fn, device, model.config, instance_idx=0) 272 | 273 | n_features = model.config.n_features # NOTE(review): not used below in this cell 274 | feature_idx = 42 275 | subtract_inputs = True # TODO TRUE subnet. NOTE(review): not used below in this cell 276 | 277 | 278 | fig = plot_feature_response_with_subnets( 279 | topk_model_fn=top1_model_fn, 280 | device=device, 281 | model_config=model.config, 282 | feature_idx=feature_idx, 283 | subnet_idx=subnet_indices[feature_idx], 284 | batch_size=1000, 285 | plot_type="errorbar", 286 | color_map=color_map, 287 | )["feature_response_with_subnets"] 288 | fig.savefig( # type: ignore 289 | out_dir / f"feature_response_with_subnets_{feature_idx}_{n_layers}layers_{wandb_id}.png", 290 | bbox_inches="tight", 291 | dpi=300, 292 | ) 293 | print( 294 | f"Saved figure to {out_dir / f'feature_response_with_subnets_{feature_idx}_{n_layers}layers_{wandb_id}.png'}" 295 | ) 296 | plt.show() 297 | -------------------------------------------------------------------------------- /spd/experiments/resid_mlp/train_resid_mlp.py: -------------------------------------------------------------------------------- 1 | """Trains a residual MLP (ResidualMLPModel) on sparse feature inputs generated by ResidualMLPDataset.""" 2 | 3 | import json 4 | from datetime import datetime 5 | from pathlib import Path 6 | from typing import Literal, Self 7 | 8 | import einops 9 | import torch 10 | import wandb 11 | import yaml 12 | from jaxtyping import Float 13 | from pydantic import BaseModel, ConfigDict, PositiveFloat, PositiveInt, model_validator 14 | from torch import Tensor, nn 15 | from tqdm import tqdm 16 | 17 | from spd.experiments.resid_mlp.models import ResidualMLPConfig, ResidualMLPModel 18 | from spd.experiments.resid_mlp.resid_mlp_dataset import ( 19 | ResidualMLPDataset, 20 | ) 21 | from spd.log import logger 22 | from spd.utils import ( 23 | DatasetGeneratedDataLoader, 24 | compute_feature_importances, 25 | get_lr_schedule_fn, 26 | set_seed,
27 | ) 28 | from spd.wandb_utils import init_wandb 29 | 30 | wandb.require("core") 31 | 32 | 33 | class ResidMLPTrainConfig(BaseModel): 34 | model_config = ConfigDict(extra="forbid", frozen=True) 35 | wandb_project: str | None = None # The name of the wandb project (if None, don't log to wandb) 36 | seed: int = 0 37 | resid_mlp_config: ResidualMLPConfig 38 | label_fn_seed: int = 0 39 | label_type: Literal["act_plus_resid", "abs"] = "act_plus_resid" 40 | loss_type: Literal["readoff", "resid"] = "readoff" 41 | use_trivial_label_coeffs: bool = False 42 | feature_probability: PositiveFloat 43 | synced_inputs: list[list[int]] | None = None 44 | importance_val: float | None = None 45 | data_generation_type: Literal[ 46 | "exactly_one_active", "exactly_two_active", "at_least_zero_active" 47 | ] = "at_least_zero_active" 48 | batch_size: PositiveInt 49 | steps: PositiveInt 50 | print_freq: PositiveInt 51 | lr: PositiveFloat 52 | lr_schedule: Literal["linear", "constant", "cosine", "exponential"] = "constant" 53 | fixed_random_embedding: bool = False 54 | fixed_identity_embedding: bool = False 55 | n_batches_final_losses: PositiveInt = 1 56 | 57 | @model_validator(mode="after") 58 | def validate_model(self) -> Self: 59 | assert not ( 60 | self.fixed_random_embedding and self.fixed_identity_embedding 61 | ), "Can't have both fixed_random_embedding and fixed_identity_embedding" 62 | if self.fixed_identity_embedding: 63 | assert ( 64 | self.resid_mlp_config.n_features == self.resid_mlp_config.d_embed 65 | ), "n_features must equal d_embed if we are using an identity embedding matrix" 66 | if self.synced_inputs is not None: 67 | # Ensure that the synced_inputs are non-overlapping with eachother 68 | all_indices = [item for sublist in self.synced_inputs for item in sublist] 69 | if len(all_indices) != len(set(all_indices)): 70 | raise ValueError("Synced inputs must be non-overlapping") 71 | return self 72 | 73 | 74 | def loss_function( 75 | out: Float[Tensor, "batch n_instances 
n_features"] | Float[Tensor, "batch n_instances d_embed"], 76 | labels: Float[Tensor, "batch n_instances n_features"], 77 | feature_importances: Float[Tensor, "batch n_instances n_features"], 78 | model: ResidualMLPModel, 79 | config: ResidMLPTrainConfig, 80 | ) -> Float[Tensor, "batch n_instances d_embed"] | Float[Tensor, "batch n_instances d_embed"]: 81 | if config.loss_type == "readoff": 82 | loss = ((out - labels) ** 2) * feature_importances 83 | elif config.loss_type == "resid": 84 | assert torch.allclose( 85 | feature_importances, torch.ones_like(feature_importances) 86 | ), "feature_importances incompatible with loss_type resid" 87 | resid_out: Float[Tensor, "batch n_instances d_embed"] = out 88 | resid_labels: Float[Tensor, "batch n_instances d_embed"] = einops.einsum( 89 | labels, 90 | model.W_E, 91 | "batch n_instances n_features, n_instances n_features d_embed " 92 | "-> batch n_instances d_embed", 93 | ) 94 | loss = (resid_out - resid_labels) ** 2 95 | else: 96 | raise ValueError(f"Invalid loss_type: {config.loss_type}") 97 | return loss 98 | 99 | 100 | def train( 101 | config: ResidMLPTrainConfig, 102 | model: ResidualMLPModel, 103 | trainable_params: list[nn.Parameter], 104 | dataloader: DatasetGeneratedDataLoader[ 105 | tuple[ 106 | Float[Tensor, "batch n_instances n_features"], 107 | Float[Tensor, "batch n_instances d_embed"], 108 | ] 109 | ], 110 | feature_importances: Float[Tensor, "batch_size n_instances n_features"], 111 | device: str, 112 | out_dir: Path, 113 | run_name: str, 114 | ) -> Float[Tensor, " n_instances"]: 115 | if config.wandb_project: 116 | config = init_wandb(config, config.wandb_project, name=run_name) 117 | 118 | out_dir.mkdir(parents=True, exist_ok=True) 119 | 120 | # Save config 121 | config_path = out_dir / "resid_mlp_train_config.yaml" 122 | with open(config_path, "w") as f: 123 | yaml.dump(config.model_dump(mode="json"), f, indent=2) 124 | logger.info(f"Saved config to {config_path}") 125 | if config.wandb_project: 126 | 
wandb.save(str(config_path), base_path=out_dir, policy="now") 127 | 128 | # Save the coefficients used to generate the labels 129 | assert isinstance(dataloader.dataset, ResidualMLPDataset) 130 | assert dataloader.dataset.label_coeffs is not None 131 | label_coeffs = dataloader.dataset.label_coeffs.tolist() 132 | label_coeffs_path = out_dir / "label_coeffs.json" 133 | with open(label_coeffs_path, "w") as f: 134 | json.dump(label_coeffs, f) 135 | logger.info(f"Saved label coefficients to {label_coeffs_path}") 136 | if config.wandb_project: 137 | wandb.save(str(label_coeffs_path), base_path=out_dir, policy="now") 138 | 139 | optimizer = torch.optim.AdamW(trainable_params, lr=config.lr, weight_decay=0.01) 140 | 141 | # Schedule function returning the multiplier applied to config.lr at each step 142 | lr_schedule_fn = get_lr_schedule_fn(config.lr_schedule) 143 | 144 | current_losses = torch.tensor([]) 145 | pbar = tqdm(range(config.steps), total=config.steps) 146 | for step, (batch, labels) in zip(pbar, dataloader, strict=False): # zip with pbar caps iteration at config.steps 147 | if step >= config.steps: 148 | break 149 | 150 | # Apply the scheduled learning rate to every param group for this step 151 | current_lr = config.lr * lr_schedule_fn(step, config.steps) 152 | for param_group in optimizer.param_groups: 153 | param_group["lr"] = current_lr 154 | 155 | optimizer.zero_grad() 156 | batch: Float[Tensor, "batch n_instances n_features"] = batch.to(device) 157 | labels: Float[Tensor, "batch n_instances n_features"] = labels.to(device) 158 | out = model(batch, return_residual=config.loss_type == "resid") 159 | loss: ( 160 | Float[Tensor, "batch n_instances n_features"] 161 | | Float[Tensor, "batch n_instances d_embed"] 162 | ) = loss_function(out, labels, feature_importances, model, config) 163 | loss = loss.mean(dim=(0, 2)) # Mean over batch and last dim -> one loss per instance 164 | current_losses = loss.detach() 165 | loss = loss.mean(dim=0) # Scalar objective: average over instances 166 | loss.backward() 167 | optimizer.step() 168 | if step % config.print_freq == 0: 169 | tqdm.write(f"step {step}: loss={current_losses.mean():.2e}, lr={current_lr:.2e}") 170 | if
config.wandb_project: 171 | wandb.log({"loss": current_losses.mean(), "lr": current_lr}, step=step) 172 | 173 | model_path = out_dir / "resid_mlp.pth" 174 | torch.save(model.state_dict(), model_path) 175 | if config.wandb_project: 176 | wandb.save(str(model_path), base_path=out_dir, policy="now") 177 | print(f"Saved model to {model_path}") 178 | 179 | # Calculate final losses by averaging many batches 180 | final_losses = [] 181 | for _ in range(config.n_batches_final_losses): 182 | batch, labels = next(iter(dataloader)) 183 | batch = batch.to(device) 184 | labels = labels.to(device) 185 | out = model(batch, return_residual=config.loss_type == "resid") 186 | loss = loss_function(out, labels, feature_importances, model, config) 187 | loss = loss.mean(dim=(0, 2)) 188 | final_losses.append(loss) 189 | final_losses = torch.stack(final_losses).mean(dim=0).cpu().detach() 190 | print(f"Final losses: {final_losses.numpy()}") 191 | return final_losses 192 | 193 | 194 | def run_train(config: ResidMLPTrainConfig, device: str) -> Float[Tensor, " n_instances"]: 195 | model_cfg = config.resid_mlp_config 196 | run_name = ( 197 | f"resid_mlp_identity_{config.label_type}_n-instances{model_cfg.n_instances}_" 198 | f"n-features{model_cfg.n_features}_d-resid{model_cfg.d_embed}_" 199 | f"d-mlp{model_cfg.d_mlp}_n-layers{model_cfg.n_layers}_seed{config.seed}" 200 | f"_p{config.feature_probability}_random_embedding_{config.fixed_random_embedding}_" 201 | f"identity_embedding_{config.fixed_identity_embedding}_bias_{model_cfg.in_bias}_" 202 | f"{model_cfg.out_bias}_loss{config.loss_type}" 203 | ) 204 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3] 205 | out_dir = Path(__file__).parent / "out" / f"{run_name}_{timestamp}" 206 | 207 | model = ResidualMLPModel(config=model_cfg).to(device) 208 | 209 | if config.fixed_random_embedding or config.fixed_identity_embedding: 210 | # Don't train the embedding matrices 211 | model.W_E.requires_grad = False 212 | model.W_U.requires_grad = 
False 213 | if config.fixed_random_embedding: 214 | # Init with randn values and make unit norm 215 | model.W_E.data[:, :, :] = torch.randn( 216 | model_cfg.n_instances, model_cfg.n_features, model_cfg.d_embed, device=device 217 | ) 218 | model.W_E.data /= model.W_E.data.norm(dim=-1, keepdim=True) 219 | # Set W_U to W_E^T 220 | model.W_U.data = model.W_E.data.transpose(-2, -1) 221 | assert torch.allclose(model.W_U.data, model.W_E.data.transpose(-2, -1)) 222 | elif config.fixed_identity_embedding: 223 | assert ( 224 | model_cfg.n_features == model_cfg.d_embed 225 | ), "n_features must equal d_embed for W_E=id" 226 | # Make W_E the identity matrix 227 | model.W_E.data[:, :, :] = einops.repeat( 228 | torch.eye(model_cfg.d_embed, device=device), 229 | "d_features d_embed -> n_instances d_features d_embed", 230 | n_instances=model_cfg.n_instances, 231 | ) 232 | 233 | label_coeffs = None 234 | if config.use_trivial_label_coeffs: 235 | label_coeffs = torch.ones(model_cfg.n_instances, model_cfg.n_features, device=device) 236 | 237 | dataset = ResidualMLPDataset( 238 | n_instances=model_cfg.n_instances, 239 | n_features=model_cfg.n_features, 240 | feature_probability=config.feature_probability, 241 | device=device, 242 | calc_labels=True, 243 | label_type=config.label_type, 244 | act_fn_name=model_cfg.act_fn_name, 245 | label_fn_seed=config.label_fn_seed, 246 | label_coeffs=label_coeffs, 247 | data_generation_type=config.data_generation_type, 248 | synced_inputs=config.synced_inputs, 249 | ) 250 | dataloader = DatasetGeneratedDataLoader(dataset, batch_size=config.batch_size, shuffle=False) 251 | 252 | feature_importances = compute_feature_importances( 253 | batch_size=config.batch_size, 254 | n_instances=model_cfg.n_instances, 255 | n_features=model_cfg.n_features, 256 | importance_val=config.importance_val, 257 | device=device, 258 | ) 259 | 260 | final_losses = train( 261 | config=config, 262 | model=model, 263 | trainable_params=[p for p in model.parameters() if 
p.requires_grad], # Frozen (requires_grad=False) embeddings are excluded from the optimizer 264 | dataloader=dataloader, 265 | feature_importances=feature_importances, 266 | device=device, 267 | out_dir=out_dir, 268 | run_name=run_name, 269 | ) 270 | return final_losses 271 | 272 | 273 | if __name__ == "__main__": 274 | device = "cuda" if torch.cuda.is_available() else "cpu" 275 | config = ResidMLPTrainConfig( 276 | wandb_project="spd-train-resid-mlp", 277 | seed=0, 278 | resid_mlp_config=ResidualMLPConfig( 279 | n_instances=1, 280 | n_features=100, 281 | d_embed=1000, 282 | d_mlp=50, 283 | n_layers=1, 284 | act_fn_name="relu", 285 | apply_output_act_fn=False, 286 | in_bias=False, 287 | out_bias=False, 288 | ), 289 | label_fn_seed=0, 290 | label_type="act_plus_resid", 291 | loss_type="readoff", 292 | use_trivial_label_coeffs=True, 293 | feature_probability=0.01, 294 | # synced_inputs=[[0, 1], [2, 3]], # synced inputs 295 | importance_val=1, 296 | data_generation_type="at_least_zero_active", 297 | batch_size=2048, 298 | steps=10000, 299 | print_freq=500, 300 | lr=3e-3, 301 | lr_schedule="cosine", 302 | fixed_random_embedding=True, 303 | fixed_identity_embedding=False, 304 | n_batches_final_losses=10, 305 | ) 306 | 307 | set_seed(config.seed) # Seed before run_train, which initialises the model and (optionally) random embeddings 308 | 309 | run_train(config, device) 310 | -------------------------------------------------------------------------------- /spd/experiments/tms/models.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any 3 | 4 | import torch 5 | import wandb 6 | import yaml 7 | from jaxtyping import Bool, Float 8 | from pydantic import BaseModel, ConfigDict, NonNegativeInt, PositiveInt 9 | from torch import Tensor, nn 10 | from torch.nn import functional as F 11 | from wandb.apis.public import Run 12 | 13 | from spd.hooks import HookedRootModule 14 | from spd.models.base import SPDModel 15 | from spd.models.components import ( 16 | Linear, 17 | LinearComponent, 18 | TransposedLinear, 19 | TransposedLinearComponent, 20 | ) 21 |
from spd.run_spd import Config, TMSTaskConfig 22 | from spd.types import WANDB_PATH_PREFIX, ModelPath 23 | from spd.utils import replace_deprecated_param_names 24 | from spd.wandb_utils import download_wandb_file, fetch_latest_wandb_checkpoint, fetch_wandb_run_dir 25 | 26 | 27 | class TMSModelPaths(BaseModel): 28 | """Paths to output files from a TMSModel training run.""" 29 | 30 | tms_train_config: Path 31 | checkpoint: Path 32 | 33 | 34 | class TMSModelConfig(BaseModel): 35 | model_config = ConfigDict(extra="forbid", frozen=True) # Architecture config for TMSModel: rejects unknown keys, immutable 36 | n_instances: PositiveInt 37 | n_features: PositiveInt 38 | n_hidden: PositiveInt 39 | n_hidden_layers: NonNegativeInt # Number of extra n_hidden -> n_hidden layers between linear1 and linear2 40 | device: str 41 | 42 | 43 | def _tms_forward( 44 | x: Float[Tensor, "batch n_instances n_features"], 45 | linear1: Linear | LinearComponent, 46 | linear2: TransposedLinear | TransposedLinearComponent, 47 | b_final: Float[Tensor, "n_instances n_features"], 48 | topk_mask: Bool[Tensor, "batch n_instances C"] | None = None, 49 | hidden_layers: nn.ModuleList | None = None, 50 | ) -> Float[Tensor, "batch n_instances n_features"]: 51 | """Forward pass used for TMSModel and TMSSPDModel. 52 | 53 | Note that topk_mask is only used for TMSSPDModel.
54 | """ 55 | hidden = linear1(x, topk_mask=topk_mask) 56 | if hidden_layers is not None: 57 | for layer in hidden_layers: 58 | hidden = layer(hidden, topk_mask=topk_mask) 59 | out_pre_relu = linear2(hidden, topk_mask=topk_mask) + b_final # Project back to n_features and add the final bias before the ReLU 60 | out = F.relu(out_pre_relu) 61 | return out 62 | 63 | 64 | class TMSModel(HookedRootModule): 65 | def __init__(self, config: TMSModelConfig): 66 | super().__init__() 67 | self.config = config 68 | 69 | self.linear1 = Linear( 70 | d_in=config.n_features, 71 | d_out=config.n_hidden, 72 | n_instances=config.n_instances, 73 | init_type="xavier_normal", 74 | ) 75 | # Use tied weights for the second linear layer 76 | self.linear2 = TransposedLinear(self.linear1.weight) 77 | 78 | self.b_final = nn.Parameter(torch.zeros((config.n_instances, config.n_features))) # Zero-initialised output bias 79 | 80 | self.hidden_layers = None 81 | if config.n_hidden_layers > 0: 82 | self.hidden_layers = nn.ModuleList() 83 | for _ in range(config.n_hidden_layers): 84 | layer = Linear( 85 | d_in=config.n_hidden, 86 | d_out=config.n_hidden, 87 | n_instances=config.n_instances, 88 | init_type="xavier_normal", 89 | ) 90 | self.hidden_layers.append(layer) 91 | self.setup() # HookedRootModule setup; presumably registers hook points — confirm in spd.hooks 92 | 93 | def forward( 94 | self, x: Float[Tensor, "... n_instances n_features"], **_: Any # Extra keyword arguments are accepted and ignored 95 | ) -> Float[Tensor, "...
n_instances n_features"]: 96 | return _tms_forward( 97 | x=x, 98 | linear1=self.linear1, 99 | linear2=self.linear2, 100 | b_final=self.b_final, 101 | hidden_layers=self.hidden_layers, 102 | ) 103 | 104 | @staticmethod 105 | def _download_wandb_files(wandb_project_run_id: str) -> TMSModelPaths: 106 | """Download the relevant files from a wandb run.""" 107 | api = wandb.Api() 108 | run: Run = api.run(wandb_project_run_id) 109 | run_dir = fetch_wandb_run_dir(run.id) 110 | 111 | tms_model_config_path = download_wandb_file(run, run_dir, "tms_train_config.yaml") 112 | 113 | checkpoint = fetch_latest_wandb_checkpoint(run) # Latest checkpoint file logged to the run 114 | checkpoint_path = download_wandb_file(run, run_dir, checkpoint.name) 115 | return TMSModelPaths(tms_train_config=tms_model_config_path, checkpoint=checkpoint_path) 116 | 117 | @classmethod 118 | def from_pretrained(cls, path: ModelPath) -> tuple["TMSModel", dict[str, Any]]: 119 | """Fetch a pretrained model from wandb or a local path to a checkpoint. 120 | 121 | Args: 122 | path: The path to local checkpoint or wandb project. If a wandb project, format must be 123 | `wandb:entity/project/run_id` or `wandb:entity/project/runs/run_id`. 124 | If `api.entity` is set (e.g. via setting WANDB_ENTITY in .env), `entity` can be 125 | omitted, and if `api.project` is set, `project` can be omitted. If local path, 126 | assumes that `tms_train_config.yaml` is in the same 127 | directory as the checkpoint.
128 | 129 | Returns: 130 | model: The pretrained TMSModel 131 | tms_train_config_dict: The config dict used to train the model (we don't 132 | instantiate a train config due to circular import issues) 133 | """ 134 | if isinstance(path, str) and path.startswith(WANDB_PATH_PREFIX): 135 | wandb_path = path.removeprefix(WANDB_PATH_PREFIX) 136 | paths = cls._download_wandb_files(wandb_path) 137 | else: 138 | # `path` should be a local path to a checkpoint 139 | paths = TMSModelPaths( 140 | tms_train_config=Path(path).parent / "tms_train_config.yaml", 141 | checkpoint=Path(path), 142 | ) 143 | 144 | with open(paths.tms_train_config) as f: 145 | tms_train_config_dict = yaml.safe_load(f) 146 | 147 | tms_config = TMSModelConfig(**tms_train_config_dict["tms_model_config"]) 148 | tms = cls(config=tms_config) 149 | params = torch.load(paths.checkpoint, weights_only=True, map_location="cpu") # weights_only avoids unpickling arbitrary objects 150 | params = replace_deprecated_param_names(params, {"W": "linear1.weight"}) # Map legacy checkpoint key `W` onto `linear1.weight` 151 | tms.load_state_dict(params) 152 | 153 | return tms, tms_train_config_dict 154 | 155 | 156 | class TMSSPDPaths(BaseModel): 157 | """Paths to output files from a TMSSPDModel training run.""" 158 | 159 | final_config: Path 160 | tms_train_config: Path 161 | checkpoint: Path 162 | 163 | 164 | class TMSSPDModelConfig(BaseModel): 165 | model_config = ConfigDict(extra="forbid", frozen=True) 166 | n_instances: PositiveInt 167 | n_features: PositiveInt 168 | n_hidden: PositiveInt 169 | n_hidden_layers: NonNegativeInt 170 | C: PositiveInt | None = None # None -> n_features (resolved in TMSSPDModel.__init__) 171 | bias_val: float 172 | device: str 173 | m: PositiveInt | None = None # None -> min(n_features, n_hidden) + 1 (resolved in TMSSPDModel.__init__) 174 | 175 | 176 | class TMSSPDModel(SPDModel): 177 | def __init__(self, config: TMSSPDModelConfig): 178 | super().__init__() 179 | self.config = config 180 | self.n_instances = config.n_instances # Required for backwards compatibility 181 | self.n_features = config.n_features # Required for backwards compatibility 182 | self.C = config.C if config.C is not None else config.n_features 183 |
self.bias_val = config.bias_val 184 | 185 | self.m = min(config.n_features, config.n_hidden) + 1 if config.m is None else config.m # Default m when not configured 186 | 187 | self.linear1 = LinearComponent( 188 | d_in=config.n_features, 189 | d_out=config.n_hidden, 190 | n_instances=config.n_instances, 191 | init_type="xavier_normal", 192 | init_scale=1.0, 193 | C=self.C, 194 | m=self.m, 195 | ) 196 | self.linear2 = TransposedLinearComponent(self.linear1.A, self.linear1.B) # Tied to linear1's A and B, mirroring TMSModel's tied weights 197 | 198 | bias_data = ( 199 | torch.zeros((config.n_instances, config.n_features), device=config.device) # Constant bias_val on config.device 200 | + config.bias_val 201 | ) 202 | self.b_final = nn.Parameter(bias_data) 203 | 204 | self.hidden_layers = None 205 | if config.n_hidden_layers > 0: 206 | self.hidden_layers = nn.ModuleList( 207 | [ 208 | LinearComponent( 209 | d_in=config.n_hidden, 210 | d_out=config.n_hidden, 211 | n_instances=config.n_instances, 212 | init_type="xavier_normal", 213 | init_scale=1.0, 214 | C=self.C, 215 | m=self.m, 216 | ) 217 | for _ in range(config.n_hidden_layers) 218 | ] 219 | ) 220 | 221 | self.setup() 222 | 223 | def forward( 224 | self, 225 | x: Float[Tensor, "batch n_instances n_features"], 226 | topk_mask: Bool[Tensor, "batch n_instances C"] | None = None, 227 | ) -> Float[Tensor, "batch n_instances n_features"]: 228 | return _tms_forward( 229 | x=x, 230 | linear1=self.linear1, 231 | linear2=self.linear2, 232 | b_final=self.b_final, 233 | hidden_layers=self.hidden_layers, 234 | topk_mask=topk_mask, 235 | ) 236 | 237 | @staticmethod 238 | def _download_wandb_files(wandb_project_run_id: str) -> TMSSPDPaths: 239 | """Download the relevant files from a wandb run.""" 240 | api = wandb.Api() 241 | run: Run = api.run(wandb_project_run_id) 242 | 243 | checkpoint = fetch_latest_wandb_checkpoint(run, prefix="spd_model") 244 | 245 | run_dir = fetch_wandb_run_dir(run.id) 246 | 247 | final_config_path = download_wandb_file(run, run_dir, "final_config.yaml") 248 | tms_train_config_path = download_wandb_file(run, run_dir,
"tms_train_config.yaml") 249 | checkpoint_path = download_wandb_file(run, run_dir, checkpoint.name) 250 | return TMSSPDPaths( 251 | final_config=final_config_path, 252 | tms_train_config=tms_train_config_path, 253 | checkpoint=checkpoint_path, 254 | ) 255 | 256 | @classmethod 257 | def from_pretrained(cls, path: ModelPath) -> tuple["TMSSPDModel", Config]: 258 | """Fetch a pretrained model from wandb or a local path to a checkpoint. 259 | 260 | Args: 261 | path: The path to local checkpoint or wandb project. If a wandb project, the format 262 | must be `wandb:entity/project/run_id`. If `api.entity` is set (e.g. via setting 263 | WANDB_ENTITY in .env), this can be in the form `wandb:project/run_id`, and if 264 | `api.project` is also set this can just be `wandb:run_id`. If local path, assumes 265 | that `final_config.yaml` and `tms_train_config.yaml` are in the same directory as 266 | the checkpoint. Returns the loaded TMSSPDModel and the SPD `Config` it was trained with. 267 | """ 268 | if isinstance(path, str) and path.startswith(WANDB_PATH_PREFIX): 269 | wandb_path = path.removeprefix(WANDB_PATH_PREFIX) 270 | paths = cls._download_wandb_files(wandb_path) 271 | else: 272 | paths = TMSSPDPaths( 273 | final_config=Path(path).parent / "final_config.yaml", 274 | tms_train_config=Path(path).parent / "tms_train_config.yaml", 275 | checkpoint=Path(path), 276 | ) 277 | 278 | with open(paths.final_config) as f: 279 | final_config_dict = yaml.safe_load(f) 280 | 281 | spd_config = Config(**final_config_dict) 282 | 283 | with open(paths.tms_train_config) as f: 284 | tms_train_config_dict = yaml.safe_load(f) 285 | 286 | assert isinstance(spd_config.task_config, TMSTaskConfig) 287 | tms_spd_config = TMSSPDModelConfig( 288 | **tms_train_config_dict["tms_model_config"], 289 | C=spd_config.C, 290 | m=spd_config.m, 291 | bias_val=spd_config.task_config.bias_val, 292 | ) 293 | model = cls(config=tms_spd_config) 294 | params = torch.load(paths.checkpoint, weights_only=True, map_location="cpu") 295 | params =
replace_deprecated_param_names(params, {"A": "linear1.A", "B": "linear1.B"}) # Map legacy top-level A/B keys onto linear1 296 | model.load_state_dict(params) 297 | return model, spd_config 298 | -------------------------------------------------------------------------------- /spd/experiments/tms/spd_interp.py: -------------------------------------------------------------------------------- 1 | # %% 2 | 3 | import matplotlib.collections as mc 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import numpy.typing as npt 7 | import torch 8 | from jaxtyping import Float 9 | from torch import Tensor 10 | 11 | from spd.experiments.tms.models import TMSModel, TMSSPDModel 12 | from spd.plotting import collect_sparse_dataset_mse_losses, plot_sparse_feature_mse_line_plot 13 | from spd.run_spd import TMSTaskConfig 14 | from spd.settings import REPO_ROOT 15 | from spd.utils import COLOR_PALETTE, DataGenerationType, SparseFeatureDataset 16 | 17 | 18 | def plot_vectors( 19 | subnets: Float[Tensor, "n_instances n_subnets n_features n_hidden"], 20 | axs: npt.NDArray[np.object_], 21 | ) -> None: 22 | """2D polygon plot of each subnetwork: one line from the origin per feature, with `axs` indexed [instance, subnet]. 23 | 24 | Adapted from 25 | https://colab.research.google.com/github/anthropics/toy-models-of-superposition/blob/main/toy_models.ipynb.
26 | """ 27 | n_instances, n_subnets, n_features, n_hidden = subnets.shape 28 | 29 | # Give each feature its own viridis color when there's only one instance; otherwise all features share one color 30 | color_vals = np.linspace(0, 1, n_features) if n_instances == 1 else np.zeros(n_features) 31 | colors = plt.cm.viridis(color_vals) # type: ignore 32 | 33 | for subnet_idx in range(n_subnets): 34 | for instance_idx, ax in enumerate(axs[:, subnet_idx]): 35 | arr = subnets[instance_idx, subnet_idx].cpu().detach().numpy() 36 | 37 | # Plot each feature with its unique color 38 | for j in range(n_features): 39 | ax.scatter(arr[j, 0], arr[j, 1], color=colors[j]) 40 | ax.add_collection( 41 | mc.LineCollection([[(0, 0), (arr[j, 0], arr[j, 1])]], colors=[colors[j]]) 42 | ) 43 | 44 | ax.set_aspect("equal") 45 | z = 1.3 46 | ax.set_facecolor("#f6f6f6") 47 | ax.set_xlim((-z, z)) 48 | ax.set_ylim((-z, z)) 49 | ax.tick_params(left=True, right=False, labelleft=False, labelbottom=False, bottom=True) 50 | for spine in ["top", "right"]: 51 | ax.spines[spine].set_visible(False) 52 | for spine in ["bottom", "left"]: 53 | ax.spines[spine].set_position("center") 54 | 55 | if instance_idx == 0: # Only add labels to the first row 56 | if subnet_idx == 0: # Titles assume the plot_combined layout: [target, sum of components, components...] 57 | label = "Target model" 58 | elif subnet_idx == 1: 59 | label = "Sum of components" 60 | else: 61 | label = f"Component {subnet_idx - 2}" 62 | ax.set_title(label, pad=10, fontsize="large") 63 | 64 | 65 | def plot_networks( 66 | subnets: Float[Tensor, "n_instances n_subnets n_features n_hidden"], 67 | axs: npt.NDArray[np.object_], 68 | ) -> None: 69 | """Plot neural network diagrams for each W matrix in the subnet variable. 70 | 71 | Args: 72 | subnets: Tensor of shape [n_instances, n_subnets, n_features, n_hidden]. 73 | axs: Matplotlib axes to plot on.
74 | """ 75 | 76 | n_instances, n_subnets, n_features, n_hidden = subnets.shape 77 | 78 | # Take the absolute value of the weights 79 | subnets_abs = subnets.abs() 80 | 81 | # Maximum |weight| per instance (over all subnets), used to normalise edge shading 82 | max_weights = subnets_abs.amax(dim=(1, 2, 3)) 83 | 84 | axs = np.atleast_2d(np.array(axs)) 85 | 86 | # axs[0, 0].set_xlabel("Outputs (before ReLU and biases)") 87 | # Added as text instead, because the axis ticks and spines are hidden further below 88 | axs[0, 0].text( 89 | 0.05, 90 | 0.05, 91 | "Outputs (before bias & ReLU)", 92 | ha="left", 93 | va="center", 94 | transform=axs[0, 0].transAxes, 95 | ) 96 | # Also add "input label" 97 | axs[0, 0].text( 98 | 0.05, 99 | 0.95, 100 | "Inputs", 101 | ha="left", 102 | va="center", 103 | transform=axs[0, 0].transAxes, 104 | ) 105 | 106 | # Grayscale colormap. darker for larger weight 107 | cmap = plt.get_cmap("gray_r") 108 | 109 | for subnet_idx in range(n_subnets): 110 | for instance_idx, ax in enumerate(axs[:, subnet_idx]): 111 | arr = subnets_abs[instance_idx, subnet_idx].cpu().detach().numpy() 112 | 113 | # Define node positions (top to bottom) 114 | y_input, y_hidden, y_output = 0, -1, -2 115 | x_input = np.linspace(0.05, 0.95, n_features) 116 | x_hidden = np.linspace(0.25, 0.75, n_hidden) 117 | x_output = np.linspace(0.05, 0.95, n_features) 118 | 119 | # Add transparent grey box around hidden layer 120 | box_width = 0.8 121 | box_height = 0.4 122 | box = plt.Rectangle( 123 | (0.5 - box_width / 2, y_hidden - box_height / 2), 124 | box_width, 125 | box_height, 126 | fill=True, 127 | facecolor="#e4e4e4", 128 | edgecolor="none", 129 | alpha=0.33, 130 | transform=ax.transData, 131 | ) 132 | ax.add_patch(box) 133 | 134 | # Plot nodes 135 | ax.scatter( 136 | x_input, [y_input] * n_features, s=200, color="grey", edgecolors="k", zorder=3 137 | ) 138 | ax.scatter( 139 | x_hidden, [y_hidden] * n_hidden, s=200, color="grey", edgecolors="k", zorder=3 140 | ) 141 | ax.scatter( 142 | x_output, [y_output] * n_features, s=200,
color="grey", edgecolors="k", zorder=3 143 | ) 144 | 145 | # Plot edges from input to hidden layer 146 | for idx_input in range(n_features): 147 | for idx_hidden in range(n_hidden): 148 | weight = arr[idx_input, idx_hidden] 149 | norm_weight = weight / max_weights[instance_idx] # NOTE(review): max_weights is a torch tensor, so this is a 0-d tensor; an all-zero instance gives NaN 150 | color = cmap(norm_weight) 151 | ax.plot( 152 | [x_input[idx_input], x_hidden[idx_hidden]], 153 | [y_input, y_hidden], 154 | color=color, 155 | linewidth=1, 156 | ) 157 | 158 | # Plot edges from hidden to output layer 159 | arr_T = arr.T # Transpose of W for W^T 160 | for idx_hidden in range(n_hidden): 161 | for idx_output in range(n_features): 162 | weight = arr_T[idx_hidden, idx_output] 163 | norm_weight = weight / max_weights[instance_idx] 164 | color = cmap(norm_weight) 165 | ax.plot( 166 | [x_hidden[idx_hidden], x_output[idx_output]], 167 | [y_hidden, y_output], 168 | color=color, 169 | linewidth=1, 170 | ) 171 | 172 | # Remove axes for clarity 173 | # ax.axis("off") 174 | ax.set_xlim(-0.1, 1.1) 175 | ax.set_ylim(y_output - 0.5, y_input + 0.5) 176 | # Remove x and y ticks and bounding boxes 177 | ax.set_xticks([]) 178 | ax.set_yticks([]) 179 | for spine in ["top", "right", "bottom", "left"]: 180 | ax.spines[spine].set_visible(False) 181 | 182 | 183 | def plot_combined( 184 | subnets: Float[Tensor, "n_instances n_subnets n_features n_hidden"], 185 | target_weights: Float[Tensor, "n_instances n_features n_hidden"], 186 | n_instances: int | None = None, 187 | ) -> plt.Figure: 188 | """Create a combined figure: vector plots in the top n_instances rows, network diagrams in the bottom n_instances rows.""" 189 | if n_instances is not None: 190 | subnets = subnets[:n_instances] 191 | target_weights = target_weights[:n_instances] 192 | n_instances, n_subnets, n_features, n_hidden = subnets.shape 193 | 194 | # We wish to add two panels to the left: The target model weights and the sum of the subnets 195 | # Add an extra dimension to the target weights so we can concatenate them 196 | target_subnet = target_weights[:, None, :, :]
197 | summed_subnet = subnets.sum(dim=1, keepdim=True) 198 | subnets = torch.cat([target_subnet, summed_subnet, subnets], dim=1) 199 | n_subnets += 2 200 | 201 | # Create figure with two rows per instance (vectors + network) 202 | fig, axs = plt.subplots( 203 | nrows=n_instances * 2, 204 | ncols=n_subnets, 205 | figsize=(3 * n_subnets, 6 * n_instances), 206 | ) 207 | 208 | plt.subplots_adjust(hspace=0) 209 | 210 | axs = np.atleast_2d(np.array(axs)) 211 | 212 | # Split axes into top rows (vectors) and bottom rows (networks) 213 | axs_vectors = axs[:n_instances, :] 214 | axs_networks = axs[n_instances:, :] 215 | 216 | # Call existing plotting logic with the split axes 217 | plot_vectors(subnets=subnets, axs=axs_vectors) 218 | plot_networks(subnets=subnets, axs=axs_networks) 219 | 220 | return fig 221 | 222 | 223 | # %% 224 | device = "cuda" if torch.cuda.is_available() else "cpu" 225 | # path = "wandb:spd-tms/runs/bft0pgi8" # Old 5-2 run with attributions from spd model # paper run 226 | # instance_idx = 0 227 | # path = "wandb:spd-tms/runs/sv9padmo" # 10-5 228 | # path = "wandb:spd-tms/runs/vt0i4a22" # 20-5 229 | # path = "wandb:spd-tms/runs/tyo4serm" # 40-10 with topk=2, topk_recon_coeff=1e1, schatten_coeff=15# old paper run 230 | # path = "wandb:spd-tms/runs/9zzp2s68" # 40-10 with topk=2, topk_recon_coeff=1e1, schatten_coeff=20 231 | path = "wandb:spd-tms/runs/08no00iq" # 40-10 with topk=1, topk_recon_coeff=1e1, schatten_coeff=20# new paper run 232 | instance_idx = 2 # Instance analysed by the cosine-similarity cells below 233 | # path = "wandb:spd-tms/runs/014t4f9n" # 40-10 with topk=1, topk_recon_coeff=1e1, schatten_coeff=1e1 234 | 235 | run_id = path.split("/")[-1] 236 | 237 | # Plot showing polygons for each subnet 238 | model, config = TMSSPDModel.from_pretrained(path) 239 | subnets = model.linear1.component_weights.detach().cpu() # Per-component linear1 weights, on CPU for analysis/plotting 240 | 241 | assert isinstance(config.task_config, TMSTaskConfig) 242 | target_model, target_model_train_config_dict = TMSModel.from_pretrained( 243 | config.task_config.pretrained_model_path 244 | ) 245 | 246 | out_dir =
REPO_ROOT / "spd/experiments/tms/out/figures/" 247 | out_dir.mkdir(parents=True, exist_ok=True) 248 | 249 | 250 | # %% 251 | # Max cosine similarity between subnets and target model 252 | def plot_max_cosine_sim(max_cosine_sim: Float[Tensor, " n_features"]) -> plt.Figure: 253 | fig, ax = plt.subplots() 254 | # Make a bar plot of the max cosine similarity for each feature 255 | ax.bar(range(max_cosine_sim.shape[0]), max_cosine_sim.cpu().detach().numpy()) 256 | # Add a grey horizontal line at 1 257 | ax.axhline(1, color="grey", linestyle="--") 258 | ax.set_xlabel("Input feature index") 259 | ax.set_ylabel("Max cosine similarity") 260 | # Remove top and right spines 261 | ax.spines["top"].set_visible(False) 262 | ax.spines["right"].set_visible(False) 263 | return fig 264 | 265 | 266 | cosine_sims = torch.einsum( # Cosine similarity per (component, feature) between L2-normalised feature rows 267 | "C f h, f h -> C f", 268 | subnets[instance_idx] / torch.norm(subnets[instance_idx], dim=-1, keepdim=True), 269 | target_model.linear1.weight[instance_idx] 270 | / torch.norm(target_model.linear1.weight[instance_idx], dim=-1, keepdim=True), 271 | ) 272 | max_cosine_sim = cosine_sims.max(dim=0).values # Best-matching component's similarity for each feature 273 | print(f"Max cosine similarity:\n{max_cosine_sim}") 274 | print(f"Mean max cosine similarity: {max_cosine_sim.mean()}") 275 | print(f"std max cosine similarity: {max_cosine_sim.std()}") 276 | 277 | 278 | # Get the subnet weights at the max cosine similarity 279 | subnet_weights_at_max_cosine_sim: Float[Tensor, "n_features n_hidden"] = subnets[ 280 | instance_idx, cosine_sims.max(dim=0).indices, torch.arange(target_model.config.n_features) # For each feature f, take row f of its best-matching component 281 | ] 282 | # Get the norm of the target model weights 283 | target_model_weights_norm = torch.norm( 284 | target_model.linear1.weight[instance_idx], dim=-1, keepdim=True 285 | ) 286 | # Get the norm of subnet_weights_at_max_cosine_sim 287 | subnet_weights_at_max_cosine_sim_norm = torch.norm( 288 | subnet_weights_at_max_cosine_sim, dim=-1, keepdim=True 289 | ) 290 | # Divide the subnet weights by the target model
weights ratio 291 | l2_ratio = subnet_weights_at_max_cosine_sim_norm / target_model_weights_norm 292 | print(f"Mean L2 ratio: {l2_ratio.mean()}") 293 | print(f"std L2 ratio: {l2_ratio.std()}") 294 | 295 | # Mean bias 296 | print(f"Mean bias: {target_model.b_final[instance_idx].mean()}") 297 | 298 | 299 | # fig = plot_max_cosine_sim(max_cosine_sim) 300 | # # Save figure 301 | # fig.savefig(out_dir / f"tms_max_cosine_sim_{run_id}.png", bbox_inches="tight", dpi=400) 302 | # print(f"Saved figure to {out_dir / f'tms_max_cosine_sim_{run_id}.png'}") 303 | # %% 304 | # Only plot if the hidden dimension is 2 305 | if target_model.config.n_hidden == 2: 306 | # We only look at the first instance (NOTE(review): index 0, not instance_idx used above) 307 | fig = plot_combined(subnets, target_model.linear1.weight.detach().cpu(), n_instances=1) 308 | fig.savefig(out_dir / f"tms_combined_diagram_{run_id}.png", bbox_inches="tight", dpi=400) 309 | print(f"Saved figure to {out_dir / f'tms_combined_diagram_{run_id}.png'}") 310 | 311 | # %% 312 | # This doesn't work for TMS.
313 | # # Get the entries for the main loss table in the paper 314 | dataset = SparseFeatureDataset( 315 | n_instances=target_model.config.n_instances, 316 | n_features=target_model.config.n_features, 317 | feature_probability=config.task_config.feature_probability, 318 | device=device, 319 | data_generation_type="at_least_zero_active", # This will be changed in collect_sparse_dataset_mse_losses 320 | value_range=(0.0, 1.0), 321 | ) 322 | gen_types: list[DataGenerationType] = [ 323 | "at_least_zero_active", 324 | "exactly_one_active", 325 | "exactly_two_active", 326 | "exactly_three_active", 327 | "exactly_four_active", 328 | ] 329 | assert config.topk is not None 330 | results = collect_sparse_dataset_mse_losses( 331 | dataset=dataset, 332 | target_model=target_model, 333 | spd_model=model, 334 | batch_size=10000, 335 | device=device, 336 | topk=config.topk, 337 | attribution_type=config.attribution_type, 338 | batch_topk=config.batch_topk, 339 | distil_from_target=config.distil_from_target, 340 | gen_types=gen_types, 341 | ) 342 | 343 | # %% 344 | # Option to plot a single instance 345 | inst = None 346 | if inst is not None: 347 | # We only plot the {inst}th instance 348 | plot_data = { 349 | gen_type: {k: float(v[inst].detach().cpu()) for k, v in results[gen_type].items()} 350 | for gen_type in gen_types 351 | } 352 | else: 353 | # Take the mean over all instances 354 | plot_data = { 355 | gen_type: {k: float(v.mean(dim=0).detach().cpu()) for k, v in results[gen_type].items()} 356 | for gen_type in gen_types 357 | } 358 | 359 | # %% 360 | # Create line plot of results 361 | color_map = { 362 | "target": COLOR_PALETTE[0], 363 | "apd_topk": COLOR_PALETTE[1], 364 | "baseline_monosemantic": "grey", 365 | } 366 | label_map = [ 367 | ("target", "Target model", color_map["target"]), 368 | ("spd", "APD model", color_map["apd_topk"]), 369 | ("baseline_monosemantic", "Monosemantic baseline", color_map["baseline_monosemantic"]), 370 | ] 371 | 372 | fig = 
plot_sparse_feature_mse_line_plot(plot_data, label_map=label_map, log_scale=False) 373 | fig.show() 374 | # fig.savefig(out_dir / f"tms_mse_{run_id}_inst{inst}.png", dpi=400) 375 | # print(f"Saved figure to {out_dir / f'tms_mse_{run_id}_inst{inst}.png'}") 376 | fig.savefig(out_dir / f"tms_mse_{run_id}.png", dpi=400) 377 | print(f"Saved figure to {out_dir / f'tms_mse_{run_id}.png'}") 378 | 379 | # %% 380 | -------------------------------------------------------------------------------- /spd/experiments/tms/tms_decomposition.py: -------------------------------------------------------------------------------- 1 | """Run spd on a TMS model. 2 | 3 | Note that the first instance index is fixed to the identity matrix. This is done so we can compare 4 | the losses of the "correct" solution during training. 5 | """ 6 | 7 | from datetime import datetime 8 | from pathlib import Path 9 | from typing import Any 10 | 11 | import fire 12 | import matplotlib.pyplot as plt 13 | import numpy as np 14 | import torch 15 | import wandb 16 | import yaml 17 | from jaxtyping import Float 18 | from torch import Tensor 19 | from tqdm import tqdm 20 | 21 | from spd.experiments.tms.models import ( 22 | TMSModel, 23 | TMSModelConfig, 24 | TMSSPDModel, 25 | TMSSPDModelConfig, 26 | ) 27 | from spd.log import logger 28 | from spd.run_spd import Config, TMSTaskConfig, get_common_run_name_suffix, optimize 29 | from spd.utils import ( 30 | DatasetGeneratedDataLoader, 31 | SparseFeatureDataset, 32 | collect_subnetwork_attributions, 33 | load_config, 34 | set_seed, 35 | ) 36 | from spd.wandb_utils import init_wandb 37 | 38 | wandb.require("core") 39 | 40 | 41 | def get_run_name(config: Config, tms_model_config: TMSModelConfig) -> str: 42 | """Generate a run name based on the config.""" 43 | if config.wandb_run_name: 44 | run_suffix = config.wandb_run_name 45 | else: 46 | run_suffix = get_common_run_name_suffix(config) 47 | run_suffix += f"ft{tms_model_config.n_features}_" 48 | run_suffix += 
f"hid{tms_model_config.n_hidden}_" # "_" separates this field from the following "hid-layers" field in the run name 49 | run_suffix += f"hid-layers{tms_model_config.n_hidden_layers}" 50 | return config.wandb_run_name_prefix + run_suffix 51 | 52 | 53 | def plot_A_matrix(x: torch.Tensor, pos_only: bool = False) -> plt.Figure: 54 | n_instances = x.shape[0] 55 | 56 | fig, axs = plt.subplots( 57 | 1, n_instances, figsize=(2.5 * n_instances, 2), squeeze=False, sharey=True 58 | ) 59 | 60 | cmap = "Blues" if pos_only else "RdBu" 61 | ims = [] 62 | for i in range(n_instances): 63 | ax = axs[0, i] 64 | instance_data = x[i, :, :].detach().cpu().float().numpy() 65 | max_abs_val = np.abs(instance_data).max() 66 | vmin = 0 if pos_only else -max_abs_val 67 | vmax = max_abs_val 68 | im = ax.matshow(instance_data, vmin=vmin, vmax=vmax, cmap=cmap) 69 | ims.append(im) 70 | ax.xaxis.set_ticks_position("bottom") 71 | if i == 0: 72 | ax.set_ylabel("k", rotation=0, labelpad=10, va="center") 73 | else: 74 | ax.set_yticks([]) # Remove y-axis ticks for all but the first plot 75 | ax.xaxis.set_label_position("top") 76 | ax.set_xlabel("n_features") 77 | 78 | plt.subplots_adjust(wspace=0.1, bottom=0.15, top=0.9) 79 | fig.subplots_adjust(bottom=0.2) 80 | 81 | return fig 82 | 83 | 84 | def plot_subnetwork_attributions_multiple_instances( 85 | attribution_scores: Float[Tensor, "batch n_instances C"], 86 | out_dir: Path, 87 | step: int | None, 88 | ) -> plt.Figure: 89 | """Plot subnetwork attributions for multiple instances in a row.""" 90 | n_instances = attribution_scores.shape[1] 91 | 92 | # Create a wide figure with subplots in a row 93 | fig, axes = plt.subplots(1, n_instances, figsize=(5 * n_instances, 5), constrained_layout=True) 94 | 95 | axes = np.array([axes]) if isinstance(axes, plt.Axes) else axes 96 | 97 | images = [] 98 | for idx, ax in enumerate(axes): 99 | instance_scores = attribution_scores[:, idx, :] 100 | im = ax.matshow(instance_scores.detach().cpu().numpy(), aspect="auto", cmap="Reds") 101 | images.append(im) 102 | 103 | # Annotate each cell with the
numeric value 104 | for i in range(instance_scores.shape[0]): 105 | for j in range(instance_scores.shape[1]): 106 | ax.text( 107 | j, 108 | i, 109 | f"{instance_scores[i, j]:.2f}", 110 | ha="center", 111 | va="center", 112 | color="black", 113 | fontsize=10, 114 | ) 115 | 116 | ax.set_xlabel("Subnetwork Index") 117 | if idx == 0: # Only set ylabel for leftmost plot 118 | ax.set_ylabel("Batch Index") 119 | ax.set_title(f"Instance {idx}") 120 | 121 | # Add a single colorbar that references all plots 122 | norm = plt.Normalize(vmin=attribution_scores.min().item(), vmax=attribution_scores.max().item()) 123 | for im in images: 124 | im.set_norm(norm) 125 | fig.colorbar(images[0], ax=axes) 126 | 127 | fig.suptitle(f"Subnetwork Attributions (Step {step})") 128 | filename = ( 129 | f"subnetwork_attributions_s{step}.png" 130 | if step is not None 131 | else "subnetwork_attributions.png" 132 | ) 133 | fig.savefig(out_dir / filename, dpi=300, bbox_inches="tight") 134 | plt.close(fig) 135 | tqdm.write(f"Saved subnetwork attributions to {out_dir / filename}") 136 | return fig 137 | 138 | 139 | def plot_subnetwork_attributions_statistics_multiple_instances( 140 | topk_mask: Float[Tensor, "batch_size n_instances C"], out_dir: Path, step: int | None 141 | ) -> plt.Figure: 142 | """Plot a row of vertical bar charts showing active subnetworks for each instance.""" 143 | n_instances = topk_mask.shape[1] 144 | fig, axes = plt.subplots(1, n_instances, figsize=(5 * n_instances, 5), constrained_layout=True) 145 | 146 | axes = np.array([axes]) if isinstance(axes, plt.Axes) else axes 147 | 148 | for instance_idx in range(n_instances): 149 | ax = axes[instance_idx] 150 | instance_mask = topk_mask[:, instance_idx] 151 | 152 | values = instance_mask.sum(dim=1).cpu().detach().numpy() 153 | bins = list(range(int(values.min().item()), int(values.max().item()) + 2)) 154 | counts, _ = np.histogram(values, bins=bins) 155 | 156 | bars = ax.bar(bins[:-1], counts, align="center", width=0.8) 157 | 
ax.set_xticks(bins[:-1]) 158 | ax.set_xticklabels([str(b) for b in bins[:-1]]) 159 | ax.set_title(f"Instance {instance_idx}") 160 | 161 | if instance_idx == 0: # Only set y-label for leftmost plot 162 | ax.set_ylabel("Count") 163 | ax.set_xlabel("Number of active subnetworks") 164 | 165 | # Add value annotations on top of each bar 166 | for bar in bars: 167 | height = bar.get_height() 168 | ax.annotate( 169 | f"{height}", 170 | xy=(bar.get_x() + bar.get_width() / 2, height), 171 | xytext=(0, 3), 172 | textcoords="offset points", 173 | ha="center", 174 | va="bottom", 175 | ) 176 | 177 | fig.suptitle(f"Active subnetworks per instance (batch_size={topk_mask.shape[0]})") 178 | filename = ( 179 | f"subnetwork_attributions_statistics_s{step}.png" 180 | if step is not None 181 | else "subnetwork_attributions_statistics.png" 182 | ) 183 | fig.savefig(out_dir / filename, dpi=300, bbox_inches="tight") 184 | plt.close(fig) 185 | tqdm.write(f"Saved subnetwork attributions statistics to {out_dir / filename}") 186 | return fig 187 | 188 | 189 | def plot_component_weights(model: TMSSPDModel, step: int, out_dir: Path, **_) -> plt.Figure: 190 | """Plot the component weight matrices.""" 191 | component_weights = model.linear1.component_weights 192 | 193 | # component_weights: [n_instances, k, n_features, n_hidden] 194 | n_instances, C, dim1, dim2 = component_weights.shape 195 | 196 | fig, axs = plt.subplots( 197 | C, 198 | n_instances, 199 | figsize=(2 * n_instances, 2 * C), 200 | constrained_layout=True, 201 | ) 202 | 203 | for i in range(n_instances): 204 | instance_max = np.abs(component_weights[i].detach().cpu().numpy()).max() 205 | for j in range(C): 206 | ax = axs[j, i] # type: ignore 207 | param = component_weights[i, j].detach().cpu().numpy() 208 | ax.matshow(param, cmap="RdBu", vmin=-instance_max, vmax=instance_max) 209 | ax.set_xticks([]) 210 | 211 | if i == 0: 212 | ax.set_ylabel(f"k={j}", rotation=0, ha="right", va="center") 213 | if j == C - 1: 214 | 
ax.set_xlabel(f"Inst {i}", rotation=45, ha="right") 215 | 216 | fig.suptitle(f"Component Weights (Step {step})") 217 | fig.savefig(out_dir / f"component_weights_{step}.png", dpi=300, bbox_inches="tight") 218 | plt.close(fig) 219 | tqdm.write(f"Saved component weights to {out_dir / f'component_weights_{step}.png'}") 220 | return fig 221 | 222 | 223 | def plot_batch_frequencies( 224 | frequencies: Float[Tensor, "n_instances C"], 225 | xlabel: str, 226 | ax: plt.Axes, 227 | batch_size: int, 228 | title: str | None = None, 229 | ) -> None: 230 | """Plot frequency of C activations for each instance on a given axis. 231 | 232 | Args: 233 | frequencies: Tensor counting frequencies for each instance 234 | xlabel: Label for x-axis 235 | ax: Matplotlib axis to plot on 236 | batch_size: Size of the batch 237 | title: Optional title for the subplot 238 | """ 239 | n_instances = frequencies.shape[0] 240 | C = frequencies.shape[1] 241 | 242 | for instance_idx in range(n_instances): 243 | bars = ax.bar( 244 | np.arange(C) + instance_idx * (C + 1), # Add spacing between instances 245 | frequencies[instance_idx].detach().cpu().numpy(), 246 | align="center", 247 | width=0.8, 248 | label=f"Instance {instance_idx}", 249 | ) 250 | 251 | # Add value annotations on top of each bar 252 | for bar in bars: 253 | height = bar.get_height() 254 | ax.annotate( 255 | f"{int(height)}", 256 | xy=(bar.get_x() + bar.get_width() / 2, height), 257 | xytext=(0, 3), 258 | textcoords="offset points", 259 | ha="center", 260 | va="bottom", 261 | ) 262 | 263 | ax.set_xlabel(xlabel) 264 | ax.set_ylabel(f"Activation Count (batch_size={batch_size})") 265 | if title: 266 | ax.set_title(title) 267 | 268 | # Set x-ticks for each instance group 269 | all_ticks = [] 270 | all_labels = [] 271 | for i in range(n_instances): 272 | ticks = np.arange(C) + i * (C + 1) 273 | all_ticks.extend(ticks) 274 | all_labels.extend([str(j) for j in range(C)]) 275 | ax.set_xticks(all_ticks) 276 | ax.set_xticklabels(all_labels) 277 | 
278 | 279 | def plot_batch_statistics( 280 | batch: Float[Tensor, "batch n_instances n_features"], 281 | topk_mask: Float[Tensor, "batch n_instances C"], 282 | out_dir: Path, 283 | step: int | None, 284 | ) -> dict[str, plt.Figure]: 285 | # Count the number of active features over the batch 286 | active_input_feats = (batch != 0).sum(dim=0) 287 | topk_activations = topk_mask.sum(dim=0) 288 | 289 | # Create figure with two vertically stacked subplots 290 | fig = plt.figure(figsize=(15, 10)) 291 | gs = fig.add_gridspec(2, 1, height_ratios=[1, 1], hspace=0.3) 292 | 293 | # Plot input features 294 | ax1 = fig.add_subplot(gs[0]) 295 | plot_batch_frequencies( 296 | frequencies=active_input_feats, 297 | xlabel="Input feature index", 298 | ax=ax1, 299 | batch_size=batch.shape[0], 300 | title="Input feature frequencies across batch", 301 | ) 302 | 303 | # Plot subnetwork frequencies 304 | ax2 = fig.add_subplot(gs[1]) 305 | plot_batch_frequencies( 306 | frequencies=topk_activations, 307 | xlabel="Component index", 308 | ax=ax2, 309 | batch_size=batch.shape[0], 310 | title="Component frequencies across batch", 311 | ) 312 | 313 | # Ensure that each ax has the same y-axis maximum 314 | y_lims = [ax.get_ylim() for ax in [ax1, ax2]] 315 | y_max = max(y_lims[0][1], y_lims[1][1]) 316 | for ax in [ax1, ax2]: 317 | ax.set_ylim(0, y_max) 318 | 319 | # fig.suptitle(f"Batch Statistics (Step {step})") 320 | 321 | # Save the combined figure 322 | filename = f"batch_statistics_s{step}.png" if step is not None else "batch_statistics.png" 323 | fig.savefig(out_dir / filename, dpi=300, bbox_inches="tight") 324 | plt.close(fig) 325 | tqdm.write(f"Saved batch statistics to {out_dir / filename}") 326 | 327 | return {"batch_statistics": fig} 328 | 329 | 330 | def make_plots( 331 | model: TMSSPDModel, 332 | target_model: TMSModel, 333 | step: int, 334 | out_dir: Path, 335 | device: str, 336 | config: Config, 337 | topk_mask: Float[Tensor, "batch n_instances C"] | None, 338 | batch: Float[Tensor, 
"batch n_instances n_features"], 339 | **_, 340 | ) -> dict[str, plt.Figure]: 341 | plots = {} 342 | if model.hidden_layers is not None: 343 | logger.warning("Only plotting the W matrix params and not the hidden layers.") 344 | plots["component_weights"] = plot_component_weights(model, step, out_dir) 345 | 346 | if config.topk is not None: 347 | assert topk_mask is not None 348 | assert isinstance(config.task_config, TMSTaskConfig) 349 | n_instances = model.config.n_instances if hasattr(model, "config") else model.n_instances 350 | attribution_scores = collect_subnetwork_attributions( 351 | spd_model=model, 352 | target_model=target_model, 353 | device=device, 354 | n_instances=n_instances, 355 | ) 356 | plots["subnetwork_attributions"] = plot_subnetwork_attributions_multiple_instances( 357 | attribution_scores=attribution_scores, out_dir=out_dir, step=step 358 | ) 359 | plots["subnetwork_attributions_statistics"] = ( 360 | plot_subnetwork_attributions_statistics_multiple_instances( 361 | topk_mask=topk_mask, out_dir=out_dir, step=step 362 | ) 363 | ) 364 | 365 | batch_stat_plots = plot_batch_statistics(batch, topk_mask, out_dir, step) 366 | plots.update(batch_stat_plots) 367 | 368 | return plots 369 | 370 | 371 | def save_target_model_info( 372 | save_to_wandb: bool, 373 | out_dir: Path, 374 | tms_model: TMSModel, 375 | tms_model_train_config_dict: dict[str, Any], 376 | ) -> None: 377 | torch.save(tms_model.state_dict(), out_dir / "tms.pth") 378 | 379 | with open(out_dir / "tms_train_config.yaml", "w") as f: 380 | yaml.dump(tms_model_train_config_dict, f, indent=2) 381 | 382 | if save_to_wandb: 383 | wandb.save(str(out_dir / "tms.pth"), base_path=out_dir, policy="now") 384 | wandb.save(str(out_dir / "tms_train_config.yaml"), base_path=out_dir, policy="now") 385 | 386 | 387 | def main( 388 | config_path_or_obj: Path | str | Config, sweep_config_path: Path | str | None = None 389 | ) -> None: 390 | device = "cuda" if torch.cuda.is_available() else "cpu" 391 | 392 | 
config = load_config(config_path_or_obj, config_model=Config) 393 | 394 | if config.wandb_project: 395 | config = init_wandb(config, config.wandb_project, sweep_config_path) 396 | 397 | task_config = config.task_config 398 | assert isinstance(task_config, TMSTaskConfig) 399 | 400 | set_seed(config.seed) 401 | logger.info(config) 402 | 403 | target_model, target_model_train_config_dict = TMSModel.from_pretrained( 404 | task_config.pretrained_model_path 405 | ) 406 | target_model = target_model.to(device) 407 | target_model.eval() 408 | 409 | run_name = get_run_name(config=config, tms_model_config=target_model.config) 410 | if config.wandb_project: 411 | assert wandb.run, "wandb.run must be initialized before training" 412 | wandb.run.name = run_name 413 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3] 414 | out_dir = Path(__file__).parent / "out" / f"{run_name}_{timestamp}" 415 | out_dir.mkdir(parents=True, exist_ok=True) 416 | 417 | with open(out_dir / "final_config.yaml", "w") as f: 418 | yaml.dump(config.model_dump(mode="json"), f, indent=2) 419 | if config.wandb_project: 420 | wandb.save(str(out_dir / "final_config.yaml"), base_path=out_dir, policy="now") 421 | 422 | save_target_model_info( 423 | save_to_wandb=config.wandb_project is not None, 424 | out_dir=out_dir, 425 | tms_model=target_model, 426 | tms_model_train_config_dict=target_model_train_config_dict, 427 | ) 428 | 429 | tms_spd_model_config = TMSSPDModelConfig( 430 | **target_model.config.model_dump(mode="json"), 431 | C=config.C, 432 | m=config.m, 433 | bias_val=task_config.bias_val, 434 | ) 435 | model = TMSSPDModel(config=tms_spd_model_config) 436 | 437 | # Manually set the bias for the SPD model from the bias in the pretrained model 438 | model.b_final.data[:] = target_model.b_final.data.clone() 439 | 440 | if not task_config.train_bias: 441 | model.b_final.requires_grad = False 442 | 443 | param_names = ["linear1", "linear2"] 444 | if model.hidden_layers is not None: 445 | for i in 
range(len(model.hidden_layers)): 446 | param_names.append(f"hidden_layers.{i}") 447 | 448 | synced_inputs = target_model_train_config_dict.get("synced_inputs", None) 449 | dataset = SparseFeatureDataset( 450 | n_instances=target_model.config.n_instances, 451 | n_features=target_model.config.n_features, 452 | feature_probability=task_config.feature_probability, 453 | device=device, 454 | data_generation_type=task_config.data_generation_type, 455 | value_range=(0.0, 1.0), 456 | synced_inputs=synced_inputs, 457 | ) 458 | dataloader = DatasetGeneratedDataLoader(dataset, batch_size=config.batch_size) 459 | 460 | optimize( 461 | model=model, 462 | config=config, 463 | device=device, 464 | dataloader=dataloader, 465 | target_model=target_model, 466 | param_names=param_names, 467 | out_dir=out_dir, 468 | plot_results_fn=make_plots, 469 | ) 470 | 471 | if config.wandb_project: 472 | wandb.finish() 473 | 474 | 475 | if __name__ == "__main__": 476 | fire.Fire(main) 477 | -------------------------------------------------------------------------------- /spd/experiments/tms/tms_lp_config.yaml: -------------------------------------------------------------------------------- 1 | wandb_project: spd-tms 2 | wandb_run_name: null 3 | wandb_run_name_prefix: "" 4 | unit_norm_matrices: true 5 | seed: 0 6 | topk: null 7 | m: 3 8 | C: 5 9 | param_match_coeff: 1.0 10 | lp_sparsity_coeff: 7.0 11 | pnorm: 0.9 12 | schatten_pnorm: 1.0 13 | schatten_coeff: 1.0 14 | batch_size: 2048 15 | steps: 20_000 16 | image_freq: 5000 17 | print_freq: 500 18 | save_freq: 20_000 19 | lr: 0.3 20 | lr_schedule: constant 21 | lr_warmup_pct: 0.1 22 | task_config: 23 | task_name: tms 24 | bias_val: 0.0 25 | train_bias: false 26 | feature_probability: 0.05 27 | data_generation_type: "at_least_zero_active" 28 | # File obtained by running spd/experiments/tms/train_tms.py 29 | pretrained_model_path: spd/experiments/tms/out/tms_n-features5_n-hidden2_n-instances12_seed0.pth/model.pth 
-------------------------------------------------------------------------------- /spd/experiments/tms/tms_sweep_config.yaml: -------------------------------------------------------------------------------- 1 | program: spd/experiments/tms/tms_decomposition.py 2 | method: grid 3 | metric: 4 | name: final_closeness 5 | goal: minimize 6 | parameters: 7 | # topk: 8 | # # values: [0.211, 0.239, 0.25, 0.261, 0.289] 9 | seed: 10 | values: [0, 1, 2, 3, 4] 11 | command: 12 | - ${env} 13 | - ${interpreter} 14 | - ${program} 15 | - spd/experiments/tms/tms_topk_config.yaml 16 | -------------------------------------------------------------------------------- /spd/experiments/tms/tms_topk_config.yaml: -------------------------------------------------------------------------------- 1 | # # TMS 5-2 2 | # wandb_project: spd-tms 3 | # wandb_run_name: null 4 | # wandb_run_name_prefix: "" 5 | # unit_norm_matrices: false 6 | # seed: 0 7 | # C: 5 8 | # topk: 0.211 9 | # batch_topk: true 10 | # param_match_coeff: 1.0 11 | # topk_recon_coeff: 1 12 | # attribution_type: gradient 13 | # pnorm: null 14 | # schatten_pnorm: 1.0 15 | # schatten_coeff: 7e-1 16 | # batch_size: 2048 17 | # steps: 20_000 18 | # image_freq: 5_000 19 | # print_freq: 1_000 20 | # save_freq: 20_000 21 | # lr: 3e-2 22 | # lr_schedule: constant 23 | # lr_warmup_pct: 0.05 24 | # task_config: 25 | # task_name: tms 26 | # bias_val: 0.0 27 | # train_bias: false 28 | # feature_probability: 0.05 29 | # data_generation_type: "at_least_zero_active" 30 | # pretrained_model_path: "wandb:spd-train-tms/runs/cv3g3z9d" # Local or wandb path 31 | 32 | # TMS 40-10 33 | wandb_project: spd-tms 34 | wandb_run_name: null 35 | wandb_run_name_prefix: "" 36 | unit_norm_matrices: false 37 | seed: 0 38 | topk: 2.0 39 | # topk: 0.8 # synced inputs 40 | C: 40 41 | batch_topk: true 42 | param_match_coeff: 1.0 43 | topk_recon_coeff: 10.0 44 | attribution_type: gradient 45 | pnorm: null 46 | schatten_pnorm: 0.9 47 | schatten_coeff: 15.0 48 | 
batch_size: 2048 49 | steps: 20_000 50 | image_freq: 5_000 51 | print_freq: 1_000 52 | save_freq: 20_000 53 | lr: 1e-3 54 | lr_schedule: cosine 55 | lr_warmup_pct: 0.05 56 | task_config: 57 | task_name: tms 58 | bias_val: 0.0 59 | train_bias: false 60 | feature_probability: 0.05 61 | # feature_probability: 0.02 # synced inputs 62 | data_generation_type: "at_least_zero_active" 63 | pretrained_model_path: "wandb:spd-train-tms/runs/tmzweoqk" 64 | # pretrained_model_path: "wandb:spd-train-tms/runs/rkflpubi" # synced inputs -------------------------------------------------------------------------------- /spd/experiments/tms/train_tms.py: -------------------------------------------------------------------------------- 1 | """TMS model, adapted from 2 | https://colab.research.google.com/github/anthropics/toy-models-of-superposition/blob/main/toy_models.ipynb 3 | """ 4 | 5 | from collections.abc import Callable 6 | from datetime import datetime 7 | from pathlib import Path 8 | from typing import Literal, Self 9 | 10 | import einops 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | import torch 14 | import wandb 15 | import yaml 16 | from matplotlib import collections as mc 17 | from pydantic import BaseModel, ConfigDict, PositiveInt, model_validator 18 | from tqdm import tqdm, trange 19 | 20 | from spd.experiments.tms.models import TMSModel, TMSModelConfig 21 | from spd.log import logger 22 | from spd.utils import DatasetGeneratedDataLoader, SparseFeatureDataset, set_seed 23 | 24 | wandb.require("core") 25 | 26 | 27 | class TMSTrainConfig(BaseModel): 28 | model_config = ConfigDict(extra="forbid", frozen=True) 29 | wandb_project: str | None = None # The name of the wandb project (if None, don't log to wandb) 30 | tms_model_config: TMSModelConfig 31 | feature_probability: float 32 | batch_size: PositiveInt 33 | steps: PositiveInt 34 | seed: int = 0 35 | lr: float 36 | data_generation_type: Literal["at_least_zero_active", "exactly_one_active"] 37 | 
fixed_identity_hidden_layers: bool = False 38 | fixed_random_hidden_layers: bool = False 39 | synced_inputs: list[list[int]] | None = None 40 | 41 | @model_validator(mode="after") 42 | def validate_model(self) -> Self: 43 | if self.fixed_identity_hidden_layers and self.fixed_random_hidden_layers: 44 | raise ValueError( 45 | "Cannot set both fixed_identity_hidden_layers and fixed_random_hidden_layers to True" 46 | ) 47 | if self.synced_inputs is not None: 48 | # Ensure that the synced_inputs are non-overlapping with eachother 49 | all_indices = [item for sublist in self.synced_inputs for item in sublist] 50 | if len(all_indices) != len(set(all_indices)): 51 | raise ValueError("Synced inputs must be non-overlapping") 52 | return self 53 | 54 | 55 | def linear_lr(step: int, steps: int) -> float: 56 | return 1 - (step / steps) 57 | 58 | 59 | def constant_lr(*_: int) -> float: 60 | return 1.0 61 | 62 | 63 | def cosine_decay_lr(step: int, steps: int) -> float: 64 | return np.cos(0.5 * np.pi * step / (steps - 1)) 65 | 66 | 67 | def train( 68 | model: TMSModel, 69 | dataloader: DatasetGeneratedDataLoader[tuple[torch.Tensor, torch.Tensor]], 70 | log_wandb: bool, 71 | importance: float = 1.0, 72 | steps: int = 5_000, 73 | print_freq: int = 100, 74 | lr: float = 5e-3, 75 | lr_schedule: Callable[[int, int], float] = linear_lr, 76 | ) -> None: 77 | hooks = [] 78 | 79 | opt = torch.optim.AdamW(list(model.parameters()), lr=lr) 80 | 81 | data_iter = iter(dataloader) 82 | with trange(steps, ncols=0) as t: 83 | for step in t: 84 | step_lr = lr * lr_schedule(step, steps) 85 | for group in opt.param_groups: 86 | group["lr"] = step_lr 87 | opt.zero_grad(set_to_none=True) 88 | batch, labels = next(data_iter) 89 | out = model(batch) 90 | error = importance * (labels.abs() - out) ** 2 91 | loss = einops.reduce(error, "b i f -> i", "mean").sum() 92 | loss.backward() 93 | opt.step() 94 | 95 | if hooks: 96 | hook_data = dict( 97 | model=model, step=step, opt=opt, error=error, loss=loss, 
lr=step_lr 98 | ) 99 | for h in hooks: 100 | h(hook_data) 101 | if step % print_freq == 0 or (step + 1 == steps): 102 | tqdm.write(f"Step {step} Loss: {loss.item() / model.config.n_instances}") 103 | t.set_postfix( 104 | loss=loss.item() / model.config.n_instances, 105 | lr=step_lr, 106 | ) 107 | if log_wandb: 108 | wandb.log( 109 | {"loss": loss.item() / model.config.n_instances, "lr": step_lr}, step=step 110 | ) 111 | 112 | 113 | def plot_intro_diagram(model: TMSModel, filepath: Path) -> None: 114 | """2D polygon plot of the TMS model. 115 | 116 | Adapted from 117 | https://colab.research.google.com/github/anthropics/toy-models-of-superposition/blob/main/toy_models.ipynb. 118 | """ 119 | WA = model.linear1.weight.detach() 120 | sel = range(model.config.n_instances) # can be used to highlight specific sparsity levels 121 | color = plt.cm.viridis(np.array([0.0])) # type: ignore 122 | plt.rcParams["figure.dpi"] = 200 123 | fig, axs = plt.subplots(1, len(sel), figsize=(2 * len(sel), 2)) 124 | axs = np.array(axs) 125 | for i, ax in zip(sel, axs, strict=False): 126 | W = WA[i].cpu().detach().numpy() 127 | ax.scatter(W[:, 0], W[:, 1], c=color) 128 | ax.set_aspect("equal") 129 | ax.add_collection( 130 | mc.LineCollection(np.stack((np.zeros_like(W), W), axis=1), colors=[color]) # type: ignore 131 | ) 132 | 133 | z = 1.5 134 | ax.set_facecolor("#FCFBF8") 135 | ax.set_xlim((-z, z)) 136 | ax.set_ylim((-z, z)) 137 | ax.tick_params(left=True, right=False, labelleft=False, labelbottom=False, bottom=True) 138 | for spine in ["top", "right"]: 139 | ax.spines[spine].set_visible(False) 140 | for spine in ["bottom", "left"]: 141 | ax.spines[spine].set_position("center") 142 | plt.savefig(filepath) 143 | 144 | 145 | def plot_cosine_similarity_distribution( 146 | model: TMSModel, 147 | filepath: Path, 148 | ) -> None: 149 | """Create scatter plots of cosine similarities between feature vectors for each instance. 
150 | 151 | Args: 152 | model: The trained TMS model 153 | filepath: Where to save the plot 154 | """ 155 | # Calculate cosine similarities 156 | rows = model.linear1.weight.detach() 157 | rows /= rows.norm(dim=-1, keepdim=True) 158 | cosine_sims = einops.einsum(rows, rows, "i f1 h, i f2 h -> i f1 f2") 159 | mask = ~torch.eye(rows.shape[1], device=rows.device, dtype=torch.bool) 160 | masked_sims = cosine_sims[:, mask].reshape(rows.shape[0], -1) 161 | 162 | # Create subplot for each instance 163 | fig, axs = plt.subplots(1, model.config.n_instances, figsize=(4 * model.config.n_instances, 4)) 164 | axs = np.array(axs).flatten() # Handle case where n_instances = 1 165 | 166 | for i, ax in enumerate(axs): 167 | sims = masked_sims[i].cpu().numpy() 168 | ax.scatter(sims, np.zeros_like(sims), alpha=0.5) 169 | ax.set_title(f"Instance {i}") 170 | ax.set_xlim(-1, 1) 171 | if i == 0: # Only show x-label for first plot 172 | ax.set_xlabel("Cosine Similarity") 173 | ax.set_yticks([]) # Hide y-axis ticks 174 | 175 | plt.tight_layout() 176 | plt.savefig(filepath) 177 | plt.close() 178 | 179 | 180 | def get_model_and_dataloader( 181 | config: TMSTrainConfig, device: str 182 | ) -> tuple[TMSModel, DatasetGeneratedDataLoader[tuple[torch.Tensor, torch.Tensor]]]: 183 | model = TMSModel(config=config.tms_model_config) 184 | model.to(device) 185 | if ( 186 | config.fixed_identity_hidden_layers or config.fixed_random_hidden_layers 187 | ) and model.hidden_layers is not None: 188 | for i in range(model.config.n_hidden_layers): 189 | if config.fixed_identity_hidden_layers: 190 | model.hidden_layers[i].weight.data[:, :, :] = torch.eye( 191 | model.config.n_hidden, device=device 192 | ) 193 | elif config.fixed_random_hidden_layers: 194 | model.hidden_layers[i].weight.data[:, :, :] = torch.randn_like( 195 | model.hidden_layers[i].weight 196 | ) 197 | model.hidden_layers[i].weight.requires_grad = False 198 | 199 | dataset = SparseFeatureDataset( 200 | 
n_instances=config.tms_model_config.n_instances, 201 | n_features=config.tms_model_config.n_features, 202 | feature_probability=config.feature_probability, 203 | device=device, 204 | data_generation_type=config.data_generation_type, 205 | value_range=(0.0, 1.0), 206 | synced_inputs=config.synced_inputs, 207 | ) 208 | dataloader = DatasetGeneratedDataLoader(dataset, batch_size=config.batch_size) 209 | return model, dataloader 210 | 211 | 212 | def run_train(config: TMSTrainConfig, device: str) -> None: 213 | model, dataloader = get_model_and_dataloader(config, device) 214 | 215 | model_cfg = config.tms_model_config 216 | run_name = ( 217 | f"tms_n-features{model_cfg.n_features}_n-hidden{model_cfg.n_hidden}_" 218 | f"n-hidden-layers{model_cfg.n_hidden_layers}_n-instances{model_cfg.n_instances}_" 219 | f"feat_prob{config.feature_probability}_seed{config.seed}" 220 | ) 221 | if config.fixed_identity_hidden_layers: 222 | run_name += "_fixed-identity" 223 | elif config.fixed_random_hidden_layers: 224 | run_name += "_fixed-random" 225 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3] 226 | out_dir = Path(__file__).parent / "out" / f"{run_name}_{timestamp}" 227 | out_dir.mkdir(parents=True, exist_ok=True) 228 | 229 | if config.wandb_project: 230 | wandb.init(project=config.wandb_project, name=run_name) 231 | 232 | # Save config 233 | config_path = out_dir / "tms_train_config.yaml" 234 | with open(config_path, "w") as f: 235 | yaml.dump(config.model_dump(mode="json"), f, indent=2) 236 | if config.wandb_project: 237 | wandb.save(str(config_path), base_path=out_dir, policy="now") 238 | logger.info(f"Saved config to {config_path}") 239 | 240 | train( 241 | model, 242 | dataloader=dataloader, 243 | log_wandb=config.wandb_project is not None, 244 | steps=config.steps, 245 | ) 246 | 247 | model_path = out_dir / "tms.pth" 248 | torch.save(model.state_dict(), model_path) 249 | if config.wandb_project: 250 | wandb.save(str(model_path), base_path=out_dir, policy="now") 
251 | logger.info(f"Saved model to {model_path}") 252 | 253 | if model_cfg.n_hidden == 2: 254 | plot_intro_diagram(model, filepath=out_dir / "polygon.png") 255 | logger.info(f"Saved diagram to {out_dir / 'polygon.png'}") 256 | 257 | plot_cosine_similarity_distribution( 258 | model, filepath=out_dir / "cosine_similarity_distribution.png" 259 | ) 260 | logger.info( 261 | f"Saved cosine similarity distribution to {out_dir / 'cosine_similarity_distribution.png'}" 262 | ) 263 | logger.info(f"1/sqrt(n_hidden): {1 / np.sqrt(model_cfg.n_hidden)}") 264 | 265 | 266 | if __name__ == "__main__": 267 | device = "cuda" if torch.cuda.is_available() else "cpu" 268 | # TMS 5-2 269 | # config = TMSTrainConfig( 270 | # wandb_project="spd-train-tms", 271 | # tms_model_config=TMSModelConfig( 272 | # n_features=5, 273 | # n_hidden=2, 274 | # n_hidden_layers=0, 275 | # n_instances=12, 276 | # device=device, 277 | # ), 278 | # feature_probability=0.05, 279 | # batch_size=1024, 280 | # steps=5000, 281 | # seed=0, 282 | # lr=5e-3, 283 | # data_generation_type="at_least_zero_active", 284 | # fixed_identity_hidden_layers=False, 285 | # fixed_random_hidden_layers=False, 286 | # ) 287 | # TMS 40-10 288 | config = TMSTrainConfig( 289 | wandb_project="spd-train-tms", 290 | tms_model_config=TMSModelConfig( 291 | n_features=40, 292 | n_hidden=10, 293 | n_hidden_layers=0, 294 | n_instances=3, 295 | device=device, 296 | ), 297 | feature_probability=0.05, 298 | # feature_probability=0.02, # synced inputs 299 | batch_size=2048, 300 | steps=2000, 301 | seed=0, 302 | lr=1e-3, 303 | data_generation_type="at_least_zero_active", 304 | fixed_identity_hidden_layers=False, 305 | fixed_random_hidden_layers=False, 306 | # synced_inputs=[[5, 6], [0, 2, 3]], 307 | ) 308 | set_seed(config.seed) 309 | 310 | run_train(config, device) 311 | -------------------------------------------------------------------------------- /spd/log.py: -------------------------------------------------------------------------------- 1 | 
"""Setup a logger to be used in all modules in the library. 2 | 3 | To use the logger, import it in any module and use it as follows: 4 | 5 | ``` 6 | from spd.log import logger 7 | 8 | logger.info("Info message") 9 | logger.warning("Warning message") 10 | ``` 11 | """ 12 | 13 | import logging 14 | from logging.config import dictConfig 15 | from pathlib import Path 16 | 17 | DEFAULT_LOGFILE = Path(__file__).resolve().parent.parent / "logs" / "logs.log" 18 | 19 | 20 | def setup_logger(logfile: Path = DEFAULT_LOGFILE) -> logging.Logger: 21 | """Setup a logger to be used in all modules in the library. 22 | 23 | Sets up logging configuration with a console handler and a file handler. 24 | Console handler logs messages with INFO level, file handler logs WARNING level. 25 | The root logger is configured to use both handlers. 26 | 27 | Returns: 28 | logging.Logger: A configured logger object. 29 | 30 | Example: 31 | >>> logger = setup_logger() 32 | >>> logger.debug("Debug message") 33 | >>> logger.info("Info message") 34 | >>> logger.warning("Warning message") 35 | """ 36 | if not logfile.parent.exists(): 37 | logfile.parent.mkdir(parents=True, exist_ok=True) 38 | 39 | logging_config = { 40 | "version": 1, 41 | "formatters": { 42 | "default": { 43 | "format": "%(asctime)s - %(levelname)s - %(message)s", 44 | "datefmt": "%Y-%m-%d %H:%M:%S", 45 | }, 46 | }, 47 | "handlers": { 48 | "console": { 49 | "class": "logging.StreamHandler", 50 | "formatter": "default", 51 | "level": "INFO", 52 | }, 53 | "file": { 54 | "class": "logging.FileHandler", 55 | "filename": str(logfile), 56 | "formatter": "default", 57 | "level": "WARNING", 58 | }, 59 | }, 60 | "root": { 61 | "handlers": ["console", "file"], 62 | "level": "INFO", 63 | }, 64 | } 65 | 66 | dictConfig(logging_config) 67 | return logging.getLogger() 68 | 69 | 70 | logger = setup_logger() 71 | -------------------------------------------------------------------------------- /spd/models/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ApolloResearch/apd/4e29d4375c333dff3a1046d36da726cffb25af65/spd/models/__init__.py -------------------------------------------------------------------------------- /spd/models/base.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | from spd.hooks import HookedRootModule 4 | from spd.models.components import TransposedLinearComponent 5 | from spd.module_utils import ( 6 | collect_nested_module_attrs, 7 | get_nested_module_attr, 8 | remove_grad_parallel_to_subnetwork_vecs, 9 | ) 10 | 11 | 12 | class SPDModel(HookedRootModule): 13 | def set_subnet_to_zero(self, subnet_idx: int, has_instance_dim: bool) -> dict[str, Tensor]: 14 | stored_vals = {} 15 | for attr_name in ["A", "B"]: 16 | params = collect_nested_module_attrs(self, attr_name) 17 | for param_name, param in params.items(): 18 | if self.parent_is_transposed_linear(param_name): 19 | continue 20 | if has_instance_dim: 21 | stored_vals[param_name] = param.data[:, subnet_idx, :, :].detach().clone() 22 | param.data[:, subnet_idx, :, :] = 0.0 23 | else: 24 | stored_vals[param_name] = param.data[subnet_idx, :, :].detach().clone() 25 | param.data[subnet_idx, :, :] = 0.0 26 | return stored_vals 27 | 28 | def restore_subnet( 29 | self, subnet_idx: int, stored_vals: dict[str, Tensor], has_instance_dim: bool 30 | ) -> None: 31 | for name, val in stored_vals.items(): 32 | param = get_nested_module_attr(self, name) 33 | if has_instance_dim: 34 | param.data[:, subnet_idx, :, :] = val 35 | else: 36 | param.data[subnet_idx, :, :] = val 37 | 38 | def set_As_to_unit_norm(self) -> None: 39 | """Set all A matrices to unit norm for stability. 40 | 41 | Normalizes over the second last dimension (which is the d_in dimension for A). 42 | 43 | Excludes TransposedLinearComponent matrices. 
44 | """ 45 | params = collect_nested_module_attrs(self, "A") 46 | for param_name, param in params.items(): 47 | if not self.parent_is_transposed_linear(param_name): 48 | param.data /= param.data.norm(p=2, dim=-2, keepdim=True) 49 | 50 | def fix_normalized_adam_gradients(self) -> None: 51 | """Modify the gradient by subtracting it's component parallel to the activation.""" 52 | params = collect_nested_module_attrs(self, "A") 53 | for param_name, param in params.items(): 54 | if not self.parent_is_transposed_linear(param_name): 55 | assert param.grad is not None 56 | remove_grad_parallel_to_subnetwork_vecs(param.data, param.grad) 57 | 58 | def parent_is_transposed_linear(self, param_name: str) -> bool: 59 | """Check if the parent module of the given parameter is a TransposedLinearComponent. 60 | 61 | We use this to avoid operations on a tensor which is tied to another tensor. 62 | """ 63 | parent_module_name = ".".join(param_name.split(".")[:-1]) 64 | parent_module = get_nested_module_attr(self, parent_module_name) 65 | return isinstance(parent_module, TransposedLinearComponent) 66 | -------------------------------------------------------------------------------- /spd/models/components.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Literal 2 | 3 | import einops 4 | import torch 5 | from jaxtyping import Bool, Float 6 | from torch import Tensor, nn 7 | 8 | from spd.hooks import HookPoint 9 | from spd.module_utils import init_param_ 10 | 11 | 12 | class Linear(nn.Module): 13 | """A linear transformation with an optional n_instances dimension.""" 14 | 15 | def __init__( 16 | self, 17 | d_in: int, 18 | d_out: int, 19 | n_instances: int | None = None, 20 | init_type: Literal["kaiming_uniform", "xavier_normal"] = "kaiming_uniform", 21 | init_scale: float = 1.0, 22 | ): 23 | super().__init__() 24 | shape = (n_instances, d_in, d_out) if n_instances is not None else (d_in, d_out) 25 | self.weight = 
nn.Parameter(torch.empty(shape)) 26 | init_param_(self.weight, scale=init_scale, init_type=init_type) 27 | 28 | self.hook_pre = HookPoint() # (batch ... d_in) 29 | self.hook_post = HookPoint() # (batch ... d_out) 30 | 31 | def forward( 32 | self, x: Float[Tensor, "batch ... d_in"], *args: Any, **kwargs: Any 33 | ) -> Float[Tensor, "batch ... d_out"]: 34 | x = self.hook_pre(x) 35 | out = einops.einsum(x, self.weight, "batch ... d_in, ... d_in d_out -> batch ... d_out") 36 | out = self.hook_post(out) 37 | return out 38 | 39 | 40 | class LinearComponent(nn.Module): 41 | """A linear transformation made from A and B matrices for SPD. 42 | 43 | The weight matrix W is decomposed as W = A @ B, where A and B are learned parameters. 44 | """ 45 | 46 | def __init__( 47 | self, 48 | d_in: int, 49 | d_out: int, 50 | C: int, 51 | n_instances: int | None = None, 52 | init_type: Literal["kaiming_uniform", "xavier_normal"] = "kaiming_uniform", 53 | init_scale: float = 1.0, 54 | m: int | None = None, 55 | ): 56 | super().__init__() 57 | self.n_instances = n_instances 58 | self.C = C 59 | self.m = min(d_in, d_out) if m is None else m 60 | 61 | # Initialize A and B matrices 62 | shape_A = (n_instances, C, d_in, self.m) if n_instances is not None else (C, d_in, self.m) 63 | shape_B = (n_instances, C, self.m, d_out) if n_instances is not None else (C, self.m, d_out) 64 | self.A = nn.Parameter(torch.empty(shape_A)) 65 | self.B = nn.Parameter(torch.empty(shape_B)) 66 | self.hook_pre = HookPoint() # (batch d_in) or (batch n_instances d_in) 67 | self.hook_component_acts = HookPoint() # (batch C d_out) or (batch n_instances C d_out) 68 | self.hook_post = HookPoint() # (batch d_out) or (batch n_instances d_out) 69 | 70 | init_param_(self.A, scale=init_scale, init_type=init_type) 71 | init_param_(self.B, scale=init_scale, init_type=init_type) 72 | 73 | @property 74 | def component_weights(self) -> Float[Tensor, "... 
C d_in d_out"]: 75 | """A @ B before summing over the subnetwork dimension.""" 76 | return einops.einsum(self.A, self.B, "... C d_in m, ... C m d_out -> ... C d_in d_out") 77 | 78 | @property 79 | def weight(self) -> Float[Tensor, "... d_in d_out"]: 80 | """A @ B after summing over the subnetwork dimension.""" 81 | return einops.einsum(self.A, self.B, "... C d_in m, ... C m d_out -> ... d_in d_out") 82 | 83 | def forward( 84 | self, 85 | x: Float[Tensor, "batch ... d_in"], 86 | topk_mask: Bool[Tensor, "batch ... C"] | None = None, 87 | ) -> Float[Tensor, "batch ... d_out"]: 88 | """Forward pass through A and B matrices which make up the component for this layer. 89 | 90 | Args: 91 | x: Input tensor 92 | topk_mask: Boolean tensor indicating which subnetworks to keep 93 | Returns: 94 | output: The summed output across all subnetworks 95 | """ 96 | x = self.hook_pre(x) 97 | 98 | # First multiply by A to get to intermediate dimension m 99 | inner_acts = einops.einsum(x, self.A, "batch ... d_in, ... C d_in m -> batch ... C m") 100 | if topk_mask is not None: 101 | assert topk_mask.shape == inner_acts.shape[:-1] 102 | inner_acts = einops.einsum( 103 | inner_acts, topk_mask, "batch ... C m, batch ... C -> batch ... C m" 104 | ) 105 | 106 | # Then multiply by B to get to output dimension 107 | component_acts = einops.einsum( 108 | inner_acts, self.B, "batch ... C m, ... C m d_out -> batch ... C d_out" 109 | ) 110 | self.hook_component_acts(component_acts) 111 | 112 | # Sum over subnetwork dimension 113 | out = einops.einsum(component_acts, "batch ... C d_out -> batch ... d_out") 114 | out = self.hook_post(out) 115 | return out 116 | 117 | 118 | class TransposedLinear(Linear): 119 | """Linear layer that uses a transposed weight from another Linear layer. 120 | 121 | We use 'd_in' and 'd_out' to refer to the dimensions of the original Linear layer. 122 | """ 123 | 124 | def __init__(self, original_weight: nn.Parameter): 125 | # Copy the relevant parts from Linear.__init__. 
Don't copy operations that will call 126 | # TransposedLinear.weight. 127 | nn.Module.__init__(self) 128 | self.hook_pre = HookPoint() # (batch ... d_out) 129 | self.hook_post = HookPoint() # (batch ... d_in) 130 | 131 | self.register_buffer("original_weight", original_weight, persistent=False) 132 | 133 | @property 134 | def weight(self) -> Float[Tensor, "... d_out d_in"]: 135 | return einops.rearrange(self.original_weight, "... d_in d_out -> ... d_out d_in") 136 | 137 | 138 | class TransposedLinearComponent(LinearComponent): 139 | """LinearComponent that uses a transposed weight from another LinearComponent. 140 | 141 | We use 'd_in' and 'd_out' to refer to the dimensions of the original LinearComponent. 142 | """ 143 | 144 | def __init__(self, original_A: nn.Parameter, original_B: nn.Parameter): 145 | # Copy the relevant parts from LinearComponent.__init__. Don't copy operations that will 146 | # call TransposedLinear.A or TransposedLinear.B. 147 | nn.Module.__init__(self) 148 | self.n_instances, self.C, _, self.m = original_A.shape 149 | 150 | self.hook_pre = HookPoint() # (batch ... d_out) 151 | self.hook_component_acts = HookPoint() # (batch ... C d_in) 152 | self.hook_post = HookPoint() # (batch ... d_in) 153 | 154 | self.register_buffer("original_A", original_A, persistent=False) 155 | self.register_buffer("original_B", original_B, persistent=False) 156 | 157 | @property 158 | def A(self) -> Float[Tensor, "... C d_out m"]: 159 | # New A is the transpose of the original B 160 | return einops.rearrange(self.original_B, "... C m d_out -> ... C d_out m") 161 | 162 | @property 163 | def B(self) -> Float[Tensor, "... C d_in m"]: 164 | # New B is the transpose of the original A 165 | return einops.rearrange(self.original_A, "... C d_in m -> ... C m d_in") 166 | 167 | @property 168 | def component_weights(self) -> Float[Tensor, "... C d_out d_in"]: 169 | """A @ B before summing over the subnetwork dimension.""" 170 | return einops.einsum(self.A, self.B, "... 
C d_out m, ... C m d_in -> ... C d_out d_in") 171 | 172 | @property 173 | def weight(self) -> Float[Tensor, "... d_out d_in"]: 174 | """A @ B after summing over the subnetwork dimension.""" 175 | return einops.einsum(self.A, self.B, "... C d_out m, ... C m d_in -> ... d_out d_in") 176 | -------------------------------------------------------------------------------- /spd/module_utils.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | from typing import Any, Literal 3 | 4 | import einops 5 | import torch 6 | import torch.nn as nn 7 | from jaxtyping import Float 8 | from torch import Tensor 9 | 10 | 11 | def get_nested_module_attr(module: nn.Module, access_string: str) -> Any: 12 | """Access a specific attribute by its full, path-like name. 13 | 14 | Taken from https://discuss.pytorch.org/t/how-to-access-to-a-layer-by-module-name/83797/8 15 | 16 | Args: 17 | module: The module to search through. 18 | access_string: The full name of the nested attribute to access, with each object separated 19 | by periods (e.g. "linear1.A"). 20 | """ 21 | names = access_string.split(".") 22 | try: 23 | mod = reduce(getattr, names, module) 24 | except AttributeError as err: 25 | raise AttributeError(f"{module} does not have nested attribute {access_string}") from err 26 | return mod 27 | 28 | 29 | def collect_nested_module_attrs( 30 | module: nn.Module, 31 | attr_name: str, 32 | include_attr_name: bool = True, 33 | ) -> dict[str, Tensor]: 34 | """Collect all attributes matching attr_name from a module and all its submodules. 35 | 36 | Args: 37 | module: The module to collect attributes from 38 | attr_name: Name of the attributes to collect from module and all submodules. E.g. "A". 39 | include_attr_name: If True, the attribute name is included in the key of the dictionary. 40 | E.g. if attr_name is "A", the key will be "root.A" or "linear1.A". 
41 | 42 | Returns: 43 | Dictionary mapping module names to their attribute values 44 | 45 | Raises: 46 | - ValueError: If no modules with the specified attribute are found 47 | - ValueError: If the attribute is not a tensor 48 | """ 49 | attributes: dict[str, Tensor] = {} 50 | 51 | all_modules = module.named_modules() 52 | for name, submodule in all_modules: 53 | if hasattr(submodule, attr_name): 54 | # For root module, name will be empty string 55 | submodule_attr = getattr(submodule, attr_name) 56 | if not isinstance(submodule_attr, Tensor): 57 | raise ValueError( 58 | f"Attribute '{attr_name}' is not a tensor. " 59 | f"Available modules: {[name for name, _ in all_modules]}" 60 | ) 61 | key = name + "." + attr_name if include_attr_name else name 62 | attributes[key] = submodule_attr 63 | 64 | if not attributes: 65 | raise ValueError( 66 | f"No modules found with attribute '{attr_name}'. " 67 | f"Available modules: {[name for name, _ in all_modules]}" 68 | ) 69 | 70 | return attributes 71 | 72 | 73 | @torch.inference_mode() 74 | def remove_grad_parallel_to_subnetwork_vecs( 75 | A: Float[Tensor, "... d_in m"], A_grad: Float[Tensor, "... d_in m"] 76 | ) -> None: 77 | """Modify the gradient by subtracting it's component parallel to the activation. 78 | 79 | I.e. subtract the projection of the gradient vector onto the activation vector. 80 | 81 | This is to stop Adam from changing the norm of A. Note that this will not completely prevent 82 | Adam from changing the norm due to Adam's (m/(sqrt(v) + eps)) term not preserving the norm 83 | direction. 84 | """ 85 | parallel_component = einops.einsum(A_grad, A, "... d_in m, ... d_in m -> ... m") 86 | A_grad -= einops.einsum(parallel_component, A, "... m, ... d_in m -> ... 
d_in m") 87 | 88 | 89 | def init_param_( 90 | param: torch.Tensor, 91 | scale: float = 1.0, 92 | init_type: Literal["kaiming_uniform", "xavier_normal"] = "kaiming_uniform", 93 | ) -> None: 94 | if init_type == "kaiming_uniform": 95 | torch.nn.init.kaiming_uniform_(param) 96 | with torch.no_grad(): 97 | param.mul_(scale) 98 | elif init_type == "xavier_normal": 99 | torch.nn.init.xavier_normal_(param, gain=scale) 100 | -------------------------------------------------------------------------------- /spd/plotting.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Literal 2 | 3 | import einops 4 | import matplotlib.pyplot as plt 5 | import matplotlib.ticker as tkr 6 | import numpy as np 7 | import torch 8 | from jaxtyping import Float 9 | from matplotlib.colors import CenteredNorm 10 | from mpl_toolkits.axes_grid1 import make_axes_locatable 11 | from torch import Tensor 12 | from torch.utils.data import DataLoader 13 | 14 | from spd.experiments.resid_mlp.models import ResidualMLPModel, ResidualMLPSPDModel 15 | from spd.experiments.tms.models import TMSModel, TMSSPDModel 16 | from spd.hooks import HookedRootModule 17 | from spd.models.base import SPDModel 18 | from spd.run_spd import Config 19 | from spd.utils import ( 20 | DataGenerationType, 21 | SparseFeatureDataset, 22 | calc_recon_mse, 23 | calc_topk_mask, 24 | calculate_attributions, 25 | run_spd_forward_pass, 26 | ) 27 | 28 | 29 | def plot_subnetwork_attributions_statistics( 30 | topk_mask: Float[Tensor, "batch_size n_instances C"], 31 | ) -> dict[str, plt.Figure]: 32 | """Plot vertical bar charts of the number of active subnetworks over the batch for each instance.""" 33 | batch_size = topk_mask.shape[0] 34 | if topk_mask.ndim == 2: 35 | n_instances = 1 36 | topk_mask = einops.repeat(topk_mask, "batch C -> batch n_instances C", n_instances=1) 37 | else: 38 | n_instances = topk_mask.shape[1] 39 | 40 | fig, axs = plt.subplots( 41 | ncols=n_instances, 
nrows=1, figsize=(5 * n_instances, 5), constrained_layout=True 42 | ) 43 | 44 | axs = np.array([axs]) if n_instances == 1 else np.array(axs) 45 | for i, ax in enumerate(axs): 46 | values = topk_mask[:, i].sum(dim=1).cpu().detach().numpy() 47 | bins = list(range(int(values.min().item()), int(values.max().item()) + 2)) 48 | counts, _ = np.histogram(values, bins=bins) 49 | bars = ax.bar(bins[:-1], counts, align="center", width=0.8) 50 | ax.set_xticks(bins[:-1]) 51 | ax.set_xticklabels([str(b) for b in bins[:-1]]) 52 | 53 | # Only add y-label to first subplot 54 | if i == 0: 55 | ax.set_ylabel("Count") 56 | 57 | ax.set_xlabel("Number of active subnetworks") 58 | ax.set_title(f"Instance {i+1}") 59 | 60 | # Add value annotations on top of each bar 61 | for bar in bars: 62 | height = bar.get_height() 63 | ax.annotate( 64 | f"{height}", 65 | xy=(bar.get_x() + bar.get_width() / 2, height), 66 | xytext=(0, 3), # 3 points vertical offset 67 | textcoords="offset points", 68 | ha="center", 69 | va="bottom", 70 | ) 71 | 72 | fig.suptitle(f"Active subnetworks on current batch (batch_size={batch_size})") 73 | return {"subnetwork_attributions_statistics": fig} 74 | 75 | 76 | def plot_subnetwork_correlations( 77 | dataloader: DataLoader[ 78 | tuple[Float[Tensor, "batch n_inputs"] | Float[Tensor, "batch n_instances? 
n_inputs"], Any] 79 | ], 80 | target_model: HookedRootModule, 81 | spd_model: SPDModel, 82 | config: Config, 83 | device: str, 84 | n_forward_passes: int = 100, 85 | ) -> dict[str, plt.Figure]: 86 | topk_masks = [] 87 | for batch, _ in dataloader: 88 | batch = batch.to(device=device) 89 | assert config.topk is not None 90 | 91 | # Forward pass on target model 92 | target_cache_filter = lambda k: k.endswith((".hook_pre", ".hook_post")) 93 | target_out, target_cache = target_model.run_with_cache( 94 | batch, names_filter=target_cache_filter 95 | ) 96 | 97 | # Do a forward pass with all subnetworks 98 | spd_cache_filter = lambda k: k.endswith((".hook_post", ".hook_component_acts")) 99 | out, spd_cache = spd_model.run_with_cache(batch, names_filter=spd_cache_filter) 100 | attribution_scores = calculate_attributions( 101 | model=spd_model, 102 | batch=batch, 103 | out=out, 104 | target_out=target_out, 105 | pre_weight_acts={k: v for k, v in target_cache.items() if k.endswith("hook_pre")}, 106 | post_weight_acts={k: v for k, v in target_cache.items() if k.endswith("hook_post")}, 107 | component_acts={ 108 | k: v for k, v in spd_cache.items() if k.endswith("hook_component_acts") 109 | }, 110 | attribution_type=config.attribution_type, 111 | ) 112 | 113 | # We always assume the final subnetwork is the one we want to distil 114 | topk_attrs = ( 115 | attribution_scores[..., :-1] if config.distil_from_target else attribution_scores 116 | ) 117 | if config.exact_topk: 118 | assert spd_model.n_instances == 1, "exact_topk only works if n_instances = 1" 119 | topk = (batch != 0).sum() / batch.shape[0] 120 | topk_mask = calc_topk_mask(topk_attrs, topk, batch_topk=config.batch_topk) 121 | else: 122 | topk_mask = calc_topk_mask(topk_attrs, config.topk, batch_topk=config.batch_topk) 123 | 124 | topk_masks.append(topk_mask) 125 | if len(topk_masks) > n_forward_passes: 126 | break 127 | topk_masks = torch.cat(topk_masks).float() 128 | 129 | if hasattr(spd_model, "n_instances"): 130 | 
n_instances = spd_model.n_instances 131 | else: 132 | n_instances = 1 133 | topk_masks = einops.repeat(topk_masks, "batch C -> batch n_instances C", n_instances=1) 134 | 135 | fig, axs = plt.subplots( 136 | ncols=n_instances, nrows=1, figsize=(5 * n_instances, 5), constrained_layout=True 137 | ) 138 | 139 | axs = np.array([axs]) if n_instances == 1 else np.array(axs) 140 | im, ax = None, None 141 | for i, ax in enumerate(axs): 142 | # Calculate correlation matrix 143 | corr_matrix = torch.corrcoef(topk_masks[:, i].T).cpu() 144 | 145 | im = ax.matshow(corr_matrix) 146 | ax.xaxis.set_ticks_position("bottom") 147 | if corr_matrix.shape[0] * corr_matrix.shape[1] < 200: 148 | for l in range(corr_matrix.shape[0]): 149 | for j in range(corr_matrix.shape[1]): 150 | ax.text( 151 | j, 152 | l, 153 | f"{corr_matrix[l, j]:.2f}", 154 | ha="center", 155 | va="center", 156 | color="#EE7777", 157 | fontsize=8, 158 | ) 159 | if (im is not None) and (ax is not None): 160 | divider = make_axes_locatable(plt.gca()) 161 | cax = divider.append_axes("right", size="5%", pad=0.1) 162 | plt.colorbar(im, cax=cax) 163 | ax.set_title("Subnetwork Correlation Matrix") 164 | ax.set_xlabel("Subnetwork") 165 | ax.set_ylabel("Subnetwork") 166 | return {"subnetwork_correlation_matrix": fig} 167 | 168 | 169 | def collect_sparse_dataset_mse_losses( 170 | dataset: SparseFeatureDataset, 171 | target_model: ResidualMLPModel | TMSModel, 172 | spd_model: TMSSPDModel | ResidualMLPSPDModel, 173 | batch_size: int, 174 | device: str, 175 | topk: float, 176 | attribution_type: Literal["gradient", "ablation", "activation"], 177 | batch_topk: bool, 178 | distil_from_target: bool, 179 | gen_types: list[DataGenerationType], 180 | ) -> dict[str, dict[str, Float[Tensor, ""] | Float[Tensor, " n_instances"]]]: 181 | """Collect the MSE losses for specific number of active features, as well as for 182 | 'at_least_zero_active'. 
183 | 184 | We calculate two baselines: 185 | - baseline_monosemantic: a baseline loss where the first d_mlp feature indices get mapped to the 186 | true labels and the final (n_features - d_mlp) features are either 0 (TMS) or the raw inputs 187 | (ResidualMLP). 188 | 189 | Returns: 190 | A dictionary keyed by generation type and then by model type (target, spd, 191 | baseline_monosemantic), with values being MSE losses. 192 | """ 193 | target_model.to(device) 194 | spd_model.to(device) 195 | # Get the entries for the main loss table in the paper 196 | results = {gen_type: {} for gen_type in gen_types} 197 | word_to_num = {"one": 1, "two": 2, "three": 3, "four": 4, "five": 5} 198 | 199 | for gen_type in gen_types: 200 | dataset.data_generation_type = gen_type 201 | batch, labels = dataset.generate_batch(batch_size) 202 | 203 | batch = batch.to(device) 204 | labels = labels.to(device) 205 | 206 | target_model_output = target_model(batch) 207 | 208 | if gen_type == "at_least_zero_active": 209 | run_batch_topk = batch_topk 210 | run_topk = topk 211 | else: 212 | run_batch_topk = False 213 | assert gen_type.startswith("exactly_") 214 | n_active = word_to_num[gen_type.split("_")[1]] 215 | run_topk = n_active 216 | 217 | spd_outputs = run_spd_forward_pass( 218 | spd_model=spd_model, 219 | target_model=target_model, 220 | input_array=batch, 221 | attribution_type=attribution_type, 222 | batch_topk=run_batch_topk, 223 | topk=run_topk, 224 | distil_from_target=distil_from_target, 225 | ) 226 | # Combine the batch and n_instances dimension for batch, labels, target_model_output, 227 | # spd_outputs.spd_topk_model_output 228 | ein_str = "batch n_instances n_features -> (batch n_instances) n_features" 229 | batch = einops.rearrange(batch, ein_str) 230 | labels = einops.rearrange(labels, ein_str) 231 | target_model_output = einops.rearrange(target_model_output, ein_str) 232 | spd_topk_model_output = einops.rearrange(spd_outputs.spd_topk_model_output, ein_str) 233 | 234 | if 
gen_type == "at_least_zero_active": 235 | # Remove all entries where there are no active features 236 | mask = (batch != 0).any(dim=-1) 237 | batch = batch[mask] 238 | labels = labels[mask] 239 | target_model_output = target_model_output[mask] 240 | spd_topk_model_output = spd_topk_model_output[mask] 241 | 242 | topk_recon_loss_labels = calc_recon_mse( 243 | spd_topk_model_output, labels, has_instance_dim=False 244 | ) 245 | recon_loss = calc_recon_mse(target_model_output, labels, has_instance_dim=False) 246 | baseline_batch = calc_recon_mse(batch, labels, has_instance_dim=False) 247 | 248 | # Monosemantic baseline 249 | monosemantic_out = batch.clone() 250 | # Assumes TMS or ResidualMLP 251 | if isinstance(target_model, ResidualMLPModel): 252 | d_mlp = target_model.config.d_mlp * target_model.config.n_layers # type: ignore 253 | monosemantic_out[..., :d_mlp] = labels[..., :d_mlp] 254 | elif isinstance(target_model, TMSModel): 255 | d_mlp = target_model.config.n_hidden # type: ignore 256 | # The first d_mlp features are the true labels (i.e. 
the batch) and the rest are 0 257 | monosemantic_out[..., d_mlp:] = 0 258 | baseline_monosemantic = calc_recon_mse(monosemantic_out, labels, has_instance_dim=False) 259 | 260 | results[gen_type]["target"] = recon_loss 261 | results[gen_type]["spd"] = topk_recon_loss_labels 262 | results[gen_type]["baseline_batch"] = baseline_batch 263 | results[gen_type]["baseline_monosemantic"] = baseline_monosemantic 264 | return results 265 | 266 | 267 | def plot_sparse_feature_mse_line_plot( 268 | results: dict[str, dict[str, float]], 269 | label_map: list[tuple[str, str, str]], 270 | log_scale: bool = False, 271 | ) -> plt.Figure: 272 | xtick_label_map = { 273 | "at_least_zero_active": "Training distribution", 274 | "exactly_one_active": "Exactly 1 active", 275 | "exactly_two_active": "Exactly 2 active", 276 | "exactly_three_active": "Exactly 3 active", 277 | "exactly_four_active": "Exactly 4 active", 278 | "exactly_five_active": "Exactly 5 active", 279 | } 280 | # Create grouped bar plots for each generation type 281 | fig, ax = plt.subplots(figsize=(12, 6)) 282 | 283 | n_groups = len(results) # number of generation types 284 | n_models = len(label_map) # number of models to compare 285 | width = 0.8 / n_models # width of bars 286 | 287 | # Create bars for each model type 288 | for i, (model_type, label, color) in enumerate(label_map): 289 | x_positions = np.arange(n_groups) + i * width - (n_models - 1) * width / 2 290 | heights = [results[gen_type][model_type] for gen_type in results] 291 | ax.bar(x_positions, heights, width, label=label, color=color) 292 | 293 | # Customize the plot 294 | ax.set_ylabel("MSE w.r.t true labels") 295 | ax.set_xticks(np.arange(n_groups)) 296 | xtick_labels = [xtick_label_map[gen_type] for gen_type in results] 297 | ax.set_xticklabels(xtick_labels) 298 | ax.legend() 299 | ax.grid(True, alpha=0.3, axis="y") 300 | 301 | if log_scale: 302 | ax.set_yscale("log") 303 | 304 | # Remove top and right spines 305 | ax.spines["top"].set_visible(False) 306 
| ax.spines["right"].set_visible(False) 307 | 308 | # Ensure that 0 is the bottom of the y-axis 309 | ax.set_ylim(bottom=0) 310 | 311 | plt.tight_layout() 312 | return fig 313 | 314 | 315 | def plot_matrix( 316 | ax: plt.Axes, 317 | matrix: torch.Tensor, 318 | title: str, 319 | xlabel: str, 320 | ylabel: str, 321 | colorbar_format: str = "%.1f", 322 | norm: plt.Normalize | None = None, 323 | ) -> None: 324 | # Useful to have bigger text for small matrices 325 | fontsize = 8 if matrix.numel() < 50 else 4 326 | norm = norm if norm is not None else CenteredNorm() 327 | im = ax.matshow(matrix.detach().cpu().numpy(), cmap="coolwarm", norm=norm) 328 | # If less than 500 elements, show the values 329 | if matrix.numel() < 500: 330 | for (j, i), label in np.ndenumerate(matrix.detach().cpu().numpy()): 331 | ax.text(i, j, f"{label:.2f}", ha="center", va="center", fontsize=fontsize) 332 | ax.set_xlabel(xlabel) 333 | if ylabel != "": 334 | ax.set_ylabel(ylabel) 335 | else: 336 | ax.set_yticklabels([]) 337 | ax.set_title(title) 338 | divider = make_axes_locatable(ax) 339 | cax = divider.append_axes("right", size=0.1, pad=0.05) 340 | fig = ax.get_figure() 341 | assert fig is not None 342 | fig.colorbar(im, cax=cax, format=tkr.FormatStrFormatter(colorbar_format)) 343 | if ylabel == "Function index": 344 | n_functions = matrix.shape[0] 345 | ax.set_yticks(range(n_functions)) 346 | ax.set_yticklabels([f"{L:.0f}" for L in range(1, n_functions + 1)]) 347 | -------------------------------------------------------------------------------- /spd/settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | REPO_ROOT = ( 5 | Path(os.environ["GITHUB_WORKSPACE"]) if os.environ.get("CI") else Path(__file__).parent.parent 6 | ) 7 | -------------------------------------------------------------------------------- /spd/types.py: -------------------------------------------------------------------------------- 1 | 
from pathlib import Path 2 | from typing import Annotated 3 | 4 | from pydantic import BeforeValidator, Field, PlainSerializer 5 | 6 | from spd.utils import from_root_path, to_root_path 7 | 8 | WANDB_PATH_PREFIX = "wandb:" 9 | 10 | 11 | def validate_path(v: str | Path) -> str | Path: 12 | """Check if wandb path. If not, convert to relative to repo root.""" 13 | if isinstance(v, str) and v.startswith(WANDB_PATH_PREFIX): 14 | return v 15 | return to_root_path(v) 16 | 17 | 18 | # Type for paths that can either be wandb paths (starting with "wandb:") 19 | # or regular paths (converted to be relative to repo root) 20 | ModelPath = Annotated[ 21 | str | Path, 22 | BeforeValidator(validate_path), 23 | PlainSerializer(lambda x: str(from_root_path(x)) if isinstance(x, Path) else x), 24 | ] 25 | 26 | # This is a type for pydantic configs that will convert all relative paths 27 | # to be relative to the root of this repository 28 | RootPath = Annotated[ 29 | Path, BeforeValidator(to_root_path), PlainSerializer(lambda x: str(from_root_path(x))) 30 | ] 31 | 32 | TrigParams = tuple[float, float, float, float, float, float, float] 33 | 34 | Probability = Annotated[float, Field(strict=True, ge=0, le=1)] 35 | -------------------------------------------------------------------------------- /spd/wandb_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import TypeVar 4 | 5 | import wandb 6 | import yaml 7 | from dotenv import load_dotenv 8 | from pydantic import BaseModel 9 | from wandb.apis.public import File, Run 10 | 11 | from spd.settings import REPO_ROOT 12 | from spd.utils import replace_pydantic_model 13 | 14 | T = TypeVar("T", bound=BaseModel) 15 | 16 | 17 | def fetch_latest_wandb_checkpoint(run: Run, prefix: str | None = None) -> File: 18 | """Fetch the latest checkpoint from a wandb run. 19 | 20 | NOTE: Assumes that the only files that end in `.pth` are checkpoints. 
21 | """ 22 | # Get the latest checkpoint. Assume format is _.pth or .pth 23 | checkpoints = [file for file in run.files() if file.name.endswith(".pth")] 24 | if prefix: 25 | checkpoints = [file for file in checkpoints if file.name.startswith(prefix)] 26 | if not checkpoints: 27 | raise ValueError(f"No checkpoint files found in run {run.name}") 28 | 29 | if len(checkpoints) == 1: 30 | latest_checkpoint_remote = checkpoints[0] 31 | else: 32 | # Assume format is _.pth 33 | latest_checkpoint_remote = sorted( 34 | checkpoints, key=lambda x: int(x.name.split(".pth")[0].split("_")[-1]) 35 | )[-1] 36 | return latest_checkpoint_remote 37 | 38 | 39 | def fetch_wandb_run_dir(run_id: str) -> Path: 40 | """Find or create a directory in the W&B cache for a given run. 41 | 42 | We first check if we already have a directory with the suffix "run_id" (if we created the run 43 | ourselves, a directory of the name "run--" should exist). If not, we create a 44 | new wandb_run_dir. 45 | """ 46 | # Default to REPO_ROOT/wandb if SPD_CACHE_DIR not set 47 | base_cache_dir = Path(os.environ.get("SPD_CACHE_DIR", REPO_ROOT / "wandb")) 48 | 49 | # Set default wandb_run_dir 50 | wandb_run_dir = base_cache_dir / run_id / "files" 51 | 52 | # Check if we already have a directory with the suffix "run_id" 53 | presaved_run_dirs = [ 54 | d for d in base_cache_dir.iterdir() if d.is_dir() and d.name.endswith(run_id) 55 | ] 56 | # If there is more than one dir, just ignore the presaved dirs and use the new wandb_run_dir 57 | if presaved_run_dirs and len(presaved_run_dirs) == 1: 58 | presaved_file_path = presaved_run_dirs[0] / "files" 59 | if presaved_file_path.exists(): 60 | # Found a cached run directory, use it 61 | wandb_run_dir = presaved_file_path 62 | 63 | wandb_run_dir.mkdir(parents=True, exist_ok=True) 64 | return wandb_run_dir 65 | 66 | 67 | def download_wandb_file(run: Run, wandb_run_dir: Path, file_name: str) -> Path: 68 | """Download a file from W&B. 
Don't overwrite the file if it already exists. 69 | 70 | Args: 71 | run: The W&B run to download from 72 | file_name: Name of the file to download 73 | wandb_run_dir: The directory to download the file to 74 | Returns: 75 | Path to the downloaded file 76 | """ 77 | file_on_wandb = run.file(file_name) 78 | assert isinstance(file_on_wandb, File) 79 | path = Path(file_on_wandb.download(exist_ok=True, replace=False, root=str(wandb_run_dir)).name) 80 | return path 81 | 82 | 83 | def init_wandb( 84 | config: T, project: str, sweep_config_path: Path | str | None = None, name: str | None = None 85 | ) -> T: 86 | """Initialize Weights & Biases and return a config updated with sweep hyperparameters. 87 | 88 | If no sweep config is provided, the config is returned as is. 89 | 90 | If a sweep config is provided, wandb is first initialized with the sweep config. This will 91 | cause wandb to choose specific hyperparameters for this instance of the sweep and store them 92 | in wandb.config. We then update the config with these hyperparameters. 93 | 94 | Args: 95 | config: The base config. 96 | project: The name of the wandb project. 97 | sweep_config_path: The path to the sweep config file. If provided, updates the config with 98 | the hyperparameters from this instance of the sweep. 99 | name: The name of the wandb run. 100 | 101 | Returns: 102 | Config updated with sweep hyperparameters (if any). 
103 | """ 104 | if sweep_config_path is not None: 105 | with open(sweep_config_path) as f: 106 | sweep_data = yaml.safe_load(f) 107 | wandb.init(config=sweep_data, save_code=True, name=name) 108 | else: 109 | load_dotenv(override=True) 110 | wandb.init(project=project, entity=os.getenv("WANDB_ENTITY"), save_code=True, name=name) 111 | 112 | # Update the config with the hyperparameters for this sweep (if any) 113 | config = replace_pydantic_model(config, wandb.config) 114 | 115 | # Update the non-frozen keys in the wandb config (only relevant for sweeps) 116 | wandb.config.update(config.model_dump(mode="json")) 117 | return config 118 | -------------------------------------------------------------------------------- /tests/test_resid_mlp.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | from jaxtyping import Float 5 | from torch import Tensor 6 | 7 | from spd.experiments.resid_mlp.models import ( 8 | ResidualMLPConfig, 9 | ResidualMLPModel, 10 | ResidualMLPSPDConfig, 11 | ResidualMLPSPDModel, 12 | ) 13 | from spd.experiments.resid_mlp.resid_mlp_dataset import ResidualMLPDataset 14 | from spd.module_utils import get_nested_module_attr 15 | from spd.run_spd import Config, ResidualMLPTaskConfig, optimize 16 | from spd.utils import DatasetGeneratedDataLoader, set_seed 17 | 18 | # Create a simple ResidualMLP config that we can use in multiple tests 19 | RESID_MLP_TASK_CONFIG = ResidualMLPTaskConfig( 20 | task_name="residual_mlp", 21 | feature_probability=0.333, 22 | init_scale=1.0, 23 | data_generation_type="at_least_zero_active", 24 | pretrained_model_path=Path(), # We'll create this later 25 | ) 26 | 27 | 28 | def test_resid_mlp_decomposition_happy_path() -> None: 29 | # Just noting that this test will only work on 98/100 seeds. So it's possible that future 30 | # changes will break this test. 
31 | set_seed(0) 32 | resid_mlp_config = ResidualMLPConfig( 33 | n_instances=2, 34 | n_features=3, 35 | d_embed=2, 36 | d_mlp=3, 37 | n_layers=1, 38 | act_fn_name="relu", 39 | apply_output_act_fn=False, 40 | in_bias=True, 41 | out_bias=True, 42 | ) 43 | 44 | device = "cpu" 45 | config = Config( 46 | seed=0, 47 | C=3, 48 | topk=1, 49 | batch_topk=True, 50 | param_match_coeff=1.0, 51 | topk_recon_coeff=1, 52 | schatten_pnorm=1, 53 | schatten_coeff=1, 54 | attribution_type="gradient", 55 | lr=1e-3, 56 | batch_size=32, 57 | steps=10, # Run only a few steps for the test 58 | print_freq=2, 59 | image_freq=5, 60 | save_freq=None, 61 | lr_warmup_pct=0.01, 62 | lr_schedule="cosine", 63 | task_config=RESID_MLP_TASK_CONFIG, 64 | ) 65 | 66 | assert isinstance(config.task_config, ResidualMLPTaskConfig) 67 | # Create a pretrained model 68 | target_model = ResidualMLPModel(config=resid_mlp_config).to(device) 69 | 70 | # Create the SPD model 71 | spd_config = ResidualMLPSPDConfig(**resid_mlp_config.model_dump(), C=config.C) 72 | model = ResidualMLPSPDModel(config=spd_config).to(device) 73 | 74 | # Use the pretrained model's embedding matrices and don't train them further 75 | model.W_E.data[:, :] = target_model.W_E.data.detach().clone() 76 | model.W_E.requires_grad = False 77 | model.W_U.data[:, :] = target_model.W_U.data.detach().clone() 78 | model.W_U.requires_grad = False 79 | 80 | # Copy the biases from the target model to the SPD model and set requires_grad to False 81 | for i in range(resid_mlp_config.n_layers): 82 | if resid_mlp_config.in_bias: 83 | model.layers[i].bias1.data[:, :] = target_model.layers[i].bias1.data.detach().clone() 84 | model.layers[i].bias1.requires_grad = False 85 | if resid_mlp_config.out_bias: 86 | model.layers[i].bias2.data[:, :] = target_model.layers[i].bias2.data.detach().clone() 87 | model.layers[i].bias2.requires_grad = False 88 | 89 | # Create dataset and dataloader 90 | dataset = ResidualMLPDataset( 91 | n_instances=model.n_instances, 92 | 
n_features=model.n_features, 93 | feature_probability=config.task_config.feature_probability, 94 | device=device, 95 | calc_labels=False, 96 | label_type=None, 97 | act_fn_name=None, 98 | label_fn_seed=None, 99 | label_coeffs=None, 100 | data_generation_type="at_least_zero_active", 101 | ) 102 | dataloader = DatasetGeneratedDataLoader(dataset, batch_size=config.batch_size, shuffle=False) 103 | 104 | # Set up param_map 105 | param_map = {} 106 | for i in range(resid_mlp_config.n_layers): 107 | param_map[f"layers.{i}.mlp_in"] = f"layers.{i}.mlp_in" 108 | param_map[f"layers.{i}.mlp_out"] = f"layers.{i}.mlp_out" 109 | 110 | # Calculate initial loss 111 | with torch.inference_mode(): 112 | batch, _ = next(iter(dataloader)) 113 | initial_out = model(batch) 114 | labels = target_model(batch) 115 | initial_loss = torch.mean((labels - initial_out) ** 2).item() 116 | 117 | param_names = [] 118 | for i in range(target_model.config.n_layers): 119 | param_names.append(f"layers.{i}.mlp_in") 120 | param_names.append(f"layers.{i}.mlp_out") 121 | # Run optimize function 122 | optimize( 123 | model=model, 124 | config=config, 125 | device=device, 126 | dataloader=dataloader, 127 | target_model=target_model, 128 | param_names=param_names, 129 | out_dir=None, 130 | plot_results_fn=None, 131 | ) 132 | 133 | # Calculate final loss 134 | with torch.inference_mode(): 135 | final_out = model(batch) 136 | final_loss = torch.mean((labels - final_out) ** 2).item() 137 | 138 | print(f"Final loss: {final_loss}, initial loss: {initial_loss}") 139 | # Assert that the final loss is lower than the initial loss 140 | assert ( 141 | final_loss < initial_loss 142 | ), f"Expected final loss to be lower than initial loss, but got {final_loss} >= {initial_loss}" 143 | 144 | # Show that W_E is still the same as the target model's W_E 145 | assert torch.allclose(model.W_E, target_model.W_E, atol=1e-6) 146 | 147 | 148 | def test_resid_mlp_equivalent_to_raw_model() -> None: 149 | device = "cpu" 150 | 
set_seed(0) 151 | resid_mlp_config = ResidualMLPConfig( 152 | n_instances=2, 153 | n_features=3, 154 | d_embed=2, 155 | d_mlp=3, 156 | n_layers=2, 157 | act_fn_name="relu", 158 | apply_output_act_fn=False, 159 | in_bias=True, 160 | out_bias=True, 161 | ) 162 | C = 2 163 | 164 | target_model = ResidualMLPModel(config=resid_mlp_config).to(device) 165 | 166 | # Create the SPD model with k=1 167 | resid_mlp_spd_config = ResidualMLPSPDConfig(**resid_mlp_config.model_dump(), C=C) 168 | spd_model = ResidualMLPSPDModel(config=resid_mlp_spd_config).to(device) 169 | 170 | # Init all params to random values 171 | for param in spd_model.parameters(): 172 | param.data = torch.randn_like(param.data) 173 | 174 | # Copy the subnetwork params from the SPD model to the target model 175 | for i in range(target_model.config.n_layers): 176 | for pos in ["mlp_in", "mlp_out"]: 177 | target_pos: Tensor = get_nested_module_attr(target_model, f"layers.{i}.{pos}.weight") 178 | spd_pos: Tensor = get_nested_module_attr(spd_model, f"layers.{i}.{pos}.weight") 179 | target_pos.data[:, :, :] = spd_pos.data 180 | 181 | # Also copy the embeddings and biases 182 | target_model.W_E.data[:, :, :] = spd_model.W_E.data 183 | target_model.W_U.data[:, :, :] = spd_model.W_U.data 184 | for i in range(resid_mlp_config.n_layers): 185 | target_model.layers[i].bias1.data[:, :] = spd_model.layers[i].bias1.data 186 | target_model.layers[i].bias2.data[:, :] = spd_model.layers[i].bias2.data 187 | 188 | # Create a random input 189 | batch_size = 4 190 | input_data: Float[torch.Tensor, "batch n_instances n_features"] = torch.rand( 191 | batch_size, resid_mlp_config.n_instances, resid_mlp_config.n_features, device=device 192 | ) 193 | 194 | with torch.inference_mode(): 195 | # Forward pass on target model 196 | target_cache_filter = lambda k: k.endswith((".hook_pre", ".hook_post")) 197 | target_out, target_cache = target_model.run_with_cache( 198 | input_data, names_filter=target_cache_filter 199 | ) 200 | # Forward 
pass with all subnetworks 201 | spd_cache_filter = lambda k: k.endswith((".hook_post", ".hook_component_acts")) 202 | out, spd_cache = spd_model.run_with_cache(input_data, names_filter=spd_cache_filter) 203 | 204 | # Assert outputs are the same 205 | assert torch.allclose(target_out, out, atol=1e-6), "Outputs do not match" 206 | 207 | # Assert that all post-acts are the same 208 | target_post_weight_acts = {k: v for k, v in target_cache.items() if k.endswith(".hook_post")} 209 | spd_post_weight_acts = {k: v for k, v in spd_cache.items() if k.endswith(".hook_post")} 210 | for key_name in target_post_weight_acts: 211 | assert torch.allclose( 212 | target_post_weight_acts[key_name], spd_post_weight_acts[key_name], atol=1e-6 213 | ), f"post-acts do not match at layer {key_name}" 214 | -------------------------------------------------------------------------------- /tests/test_spd_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from spd.run_spd import _calc_param_mse, calc_act_recon 4 | 5 | 6 | class TestCalcParamMatchLoss: 7 | # Actually testing _calc_param_mse. calc_param_match_loss should fail hard in most cases, and 8 | # testing it would require lots of mocking the way it is currently written. 
9 | def test_calc_param_match_loss_single_instance_single_param(self): 10 | A = torch.ones(2, 3) 11 | B = torch.ones(3, 2) 12 | n_params = 2 * 3 * 2 13 | spd_params = {"layer1": A @ B} 14 | target_params = {"layer1": torch.tensor([[1.0, 1.0], [1.0, 1.0]])} 15 | 16 | result = _calc_param_mse( 17 | params1=target_params, 18 | params2=spd_params, 19 | n_params=n_params, 20 | device="cpu", 21 | ) 22 | 23 | # A: [2, 3], B: [3, 2], both filled with ones 24 | # AB: [[3, 3], [3, 3]] 25 | # (AB - pretrained_weights)^2: [[4, 4], [4, 4]] 26 | # Sum and divide by n_params: 16 / 12 = 4/3 27 | expected = torch.tensor(4.0 / 3.0) 28 | assert torch.allclose(result, expected), f"Expected {expected}, but got {result}" 29 | 30 | def test_calc_param_match_loss_single_instance_multiple_params(self): 31 | As = [torch.ones(2, 3), torch.ones(3, 3)] 32 | Bs = [torch.ones(3, 3), torch.ones(3, 2)] 33 | n_params = 2 * 3 * 3 + 3 * 3 * 2 34 | target_params = { 35 | "layer1": torch.tensor([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), 36 | "layer2": torch.tensor([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]), 37 | } 38 | spd_params = { 39 | "layer1": As[0] @ Bs[0], 40 | "layer2": As[1] @ Bs[1], 41 | } 42 | result = _calc_param_mse( 43 | params1=target_params, 44 | params2=spd_params, 45 | n_params=n_params, 46 | device="cpu", 47 | ) 48 | 49 | # First layer: AB1: [[3, 3, 3], [3, 3, 3]], diff^2: [[1, 1, 1], [1, 1, 1]] 50 | # Second layer: AB2: [[3, 3], [3, 3], [3, 3]], diff^2: [[4, 4], [4, 4], [4, 4]] 51 | # Add together 24 + 6 = 30 52 | # Divide by n_params: 30 / (18+18) = 5/6 53 | expected = torch.tensor(5.0 / 6.0) 54 | assert torch.allclose(result, expected), f"Expected {expected}, but got {result}" 55 | 56 | def test_calc_param_match_loss_multiple_instances(self): 57 | As = [torch.ones(2, 2, 3)] 58 | Bs = [torch.ones(2, 3, 2)] 59 | n_params = 2 * 3 * 2 60 | target_params = { 61 | "layer1": torch.tensor([[[2.0, 2.0], [2.0, 2.0]], [[1.0, 1.0], [1.0, 1.0]]]) 62 | } 63 | spd_params = {"layer1": As[0] @ Bs[0]} 64 | 
result = _calc_param_mse( 65 | params1=target_params, 66 | params2=spd_params, 67 | n_params=n_params, 68 | device="cpu", 69 | ) 70 | 71 | # AB [n_instances=2, d_in=2, d_out=2]: [[[3, 3], [3, 3]], [[3, 3], [3, 3]]] 72 | # diff^2: [[[1, 1], [1, 1]], [[4, 4], [4, 4]]] 73 | # Sum together and divide by n_params: [4, 16] / 12 = [1/3, 4/3] 74 | expected = torch.tensor([1.0 / 3.0, 4.0 / 3.0]) 75 | assert torch.allclose(result, expected), f"Expected {expected}, but got {result}" 76 | 77 | 78 | class TestCalcActReconLoss: 79 | def test_calc_topk_act_recon_simple(self): 80 | # Batch size 2, d_out 2 81 | target_post_weight_acts = {"layer1": torch.tensor([[1.0, 2.0], [3.0, 4.0]])} 82 | layer_acts_topk = {"layer1": torch.tensor([[1.0, 2.0], [3.0, 4.0]])} 83 | expected = torch.tensor(0.0) 84 | 85 | result = calc_act_recon(target_post_weight_acts, layer_acts_topk) 86 | torch.testing.assert_close(result, expected) 87 | 88 | def test_calc_topk_act_recon_different_d_out(self): 89 | # Batch size 2, d_out 2/3 90 | target_post_weight_acts = { 91 | "layer1": torch.tensor([[1.0, 2.0], [3.0, 4.0]]), 92 | "layer2": torch.tensor([[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]]), 93 | } 94 | layer_acts_topk = { 95 | "layer1": torch.tensor([[1.5, 2.5], [4.0, 5.0]]), 96 | "layer2": torch.tensor([[5.5, 6.5, 7.5], [9.0, 10.0, 11.0]]), 97 | } 98 | expected = torch.tensor((0.25 + 1) / 2) # ((0.5^2 * 5) / 5 + (1^2 * 5) / 5) / 2 99 | 100 | result = calc_act_recon(target_post_weight_acts, layer_acts_topk) 101 | torch.testing.assert_close(result, expected) 102 | 103 | def test_calc_topk_act_recon_with_n_instances(self): 104 | target_post_weight_acts = { 105 | "layer1": torch.tensor([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]), 106 | "layer2": torch.tensor([[[9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0]]]), 107 | } 108 | layer_acts_topk = { 109 | "layer1": torch.tensor([[[1.5, 2.5], [3.5, 4.5]], [[5.5, 6.5], [7.5, 8.5]]]), 110 | "layer2": torch.tensor([[[9.5, 10.5], [11.5, 12.5]], [[13.5, 14.5], 
[15.5, 16.5]]]), 111 | } 112 | expected = torch.tensor([0.25, 0.25]) # (0.5^2 * 8) / 8 for each instance 113 | 114 | result = calc_act_recon(target_post_weight_acts, layer_acts_topk) 115 | torch.testing.assert_close(result, expected) 116 | -------------------------------------------------------------------------------- /tests/test_spd_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from spd.experiments.resid_mlp.models import ResidualMLPSPDConfig, ResidualMLPSPDModel 4 | from spd.experiments.tms.models import TMSSPDModel, TMSSPDModelConfig 5 | 6 | 7 | def test_tms_set_and_restore_subnet(): 8 | subnet_idx = 2 9 | config = TMSSPDModelConfig( 10 | n_instances=2, 11 | n_features=4, 12 | n_hidden=3, 13 | C=5, 14 | n_hidden_layers=1, 15 | bias_val=0.0, 16 | device="cpu", 17 | ) 18 | model = TMSSPDModel(config) 19 | assert model.linear1.component_weights.shape == (2, 5, 4, 3) # (n_instances, C, d_in, d_out) 20 | 21 | # Get the original values of the weight_matrix of subnet_idx 22 | original_vals = model.linear1.component_weights[:, subnet_idx, :, :].detach().clone() 23 | 24 | # Now set the 3rd subnet to zero 25 | stored_vals = model.set_subnet_to_zero(subnet_idx=subnet_idx, has_instance_dim=True) 26 | 27 | # Check that model.linear1.component_weights is zero for all instances 28 | assert model.linear1.component_weights[:, subnet_idx, :, :].allclose( 29 | torch.zeros_like(model.linear1.component_weights[:, subnet_idx, :, :]) 30 | ) 31 | assert subnet_idx != 0 32 | # Check that it's not zero in another component 33 | assert not model.linear1.component_weights[:, 0, :, :].allclose( 34 | torch.zeros_like(model.linear1.component_weights[:, 0, :, :]) 35 | ) 36 | 37 | # Now restore the subnet 38 | model.restore_subnet(subnet_idx=subnet_idx, stored_vals=stored_vals, has_instance_dim=True) 39 | assert model.linear1.component_weights[:, subnet_idx, :, :].allclose(original_vals) 40 | 41 | 42 | def 
test_resid_mlp_set_and_restore_subnet(): 43 | subnet_idx = 2 44 | config = ResidualMLPSPDConfig( 45 | n_instances=2, 46 | n_features=4, 47 | d_embed=6, 48 | d_mlp=8, 49 | n_layers=1, 50 | act_fn_name="gelu", 51 | apply_output_act_fn=False, 52 | in_bias=False, 53 | out_bias=False, 54 | init_scale=1.0, 55 | C=5, 56 | init_type="xavier_normal", 57 | ) 58 | model = ResidualMLPSPDModel(config) 59 | 60 | # Check shapes of first layer's component weights 61 | assert model.layers[0].mlp_in.component_weights.shape == (2, 5, 6, 8) # n_inst, C, d_in, d_out 62 | 63 | # Get the original values of the weight_matrix of subnet_idx for both mlp_in and mlp_out 64 | original_vals_in = ( 65 | model.layers[0].mlp_in.component_weights[:, subnet_idx, :, :].detach().clone() 66 | ) 67 | original_vals_out = ( 68 | model.layers[0].mlp_out.component_weights[:, subnet_idx, :, :].detach().clone() 69 | ) 70 | 71 | # Set the subnet to zero 72 | stored_vals = model.set_subnet_to_zero(subnet_idx=subnet_idx, has_instance_dim=True) 73 | 74 | # Check that component_weights are zero for all instances in both mlp_in and mlp_out 75 | assert ( 76 | model.layers[0] 77 | .mlp_in.component_weights[:, subnet_idx, :, :] 78 | .allclose(torch.zeros_like(model.layers[0].mlp_in.component_weights[:, subnet_idx, :, :])) 79 | ) 80 | assert ( 81 | model.layers[0] 82 | .mlp_out.component_weights[:, subnet_idx, :, :] 83 | .allclose(torch.zeros_like(model.layers[0].mlp_out.component_weights[:, subnet_idx, :, :])) 84 | ) 85 | 86 | assert subnet_idx != 0 87 | # Check that it's not zero in another component 88 | assert ( 89 | not model.layers[0] 90 | .mlp_in.component_weights[:, 0, :, :] 91 | .allclose(torch.zeros_like(model.layers[0].mlp_in.component_weights[:, 0, :, :])) 92 | ) 93 | assert ( 94 | not model.layers[0] 95 | .mlp_out.component_weights[:, 0, :, :] 96 | .allclose(torch.zeros_like(model.layers[0].mlp_out.component_weights[:, 0, :, :])) 97 | ) 98 | 99 | # Restore the subnet 100 | 
model.restore_subnet(subnet_idx=subnet_idx, stored_vals=stored_vals, has_instance_dim=True) 101 | 102 | # Verify restoration was successful 103 | assert model.layers[0].mlp_in.component_weights[:, subnet_idx, :, :].allclose(original_vals_in) 104 | assert ( 105 | model.layers[0].mlp_out.component_weights[:, subnet_idx, :, :].allclose(original_vals_out) 106 | ) 107 | -------------------------------------------------------------------------------- /tests/test_tms.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | import torch 5 | from jaxtyping import Float 6 | from torch import Tensor 7 | 8 | from spd.experiments.tms.models import ( 9 | TMSModel, 10 | TMSModelConfig, 11 | TMSSPDModel, 12 | TMSSPDModelConfig, 13 | ) 14 | from spd.experiments.tms.train_tms import TMSTrainConfig, get_model_and_dataloader, train 15 | from spd.module_utils import get_nested_module_attr 16 | from spd.run_spd import Config, TMSTaskConfig, optimize 17 | from spd.utils import ( 18 | DatasetGeneratedDataLoader, 19 | SparseFeatureDataset, 20 | set_seed, 21 | ) 22 | 23 | # Create a simple TMS config that we can use in multiple tests 24 | TMS_TASK_CONFIG = TMSTaskConfig( 25 | task_name="tms", 26 | feature_probability=0.5, 27 | train_bias=False, 28 | bias_val=0.0, 29 | pretrained_model_path=Path(""), # We'll create this later 30 | ) 31 | 32 | 33 | def tms_spd_happy_path(config: Config, n_hidden_layers: int = 0): 34 | set_seed(0) 35 | device = "cpu" 36 | assert isinstance(config.task_config, TMSTaskConfig) 37 | 38 | # For our pretrained model, just use a randomly initialized TMS model 39 | tms_model_config = TMSModelConfig( 40 | n_instances=2, 41 | n_features=5, 42 | n_hidden=2, 43 | n_hidden_layers=n_hidden_layers, 44 | device=device, 45 | ) 46 | target_model = TMSModel(config=tms_model_config) 47 | 48 | tms_spd_model_config = TMSSPDModelConfig( 49 | **tms_model_config.model_dump(mode="json"), 50 | C=config.C, 
51 | bias_val=config.task_config.bias_val, 52 | ) 53 | model = TMSSPDModel(config=tms_spd_model_config) 54 | # Randomly initialize the bias for the pretrained model 55 | target_model.b_final.data = torch.randn_like(target_model.b_final.data) 56 | # Manually set the bias for the SPD model from the bias in the pretrained model 57 | model.b_final.data[:] = target_model.b_final.data.clone() 58 | 59 | if not config.task_config.train_bias: 60 | model.b_final.requires_grad = False 61 | 62 | dataset = SparseFeatureDataset( 63 | n_instances=target_model.config.n_instances, 64 | n_features=target_model.config.n_features, 65 | feature_probability=config.task_config.feature_probability, 66 | device=device, 67 | data_generation_type=config.task_config.data_generation_type, 68 | value_range=(0.0, 1.0), 69 | ) 70 | dataloader = DatasetGeneratedDataLoader(dataset, batch_size=config.batch_size) 71 | 72 | # Pick an arbitrary parameter to check that it changes 73 | initial_param = model.linear1.A.clone().detach() 74 | 75 | param_names = ["linear1", "linear2"] 76 | if model.hidden_layers is not None: 77 | for i in range(len(model.hidden_layers)): 78 | param_names.append(f"hidden_layers.{i}") 79 | 80 | optimize( 81 | model=model, 82 | config=config, 83 | device=device, 84 | dataloader=dataloader, 85 | target_model=target_model, 86 | param_names=param_names, 87 | out_dir=None, 88 | plot_results_fn=None, 89 | ) 90 | 91 | assert not torch.allclose( 92 | initial_param, model.linear1.A 93 | ), "Model A matrix should have changed after optimization" 94 | 95 | 96 | def test_tms_batch_topk_no_schatten(): 97 | config = Config( 98 | C=5, 99 | topk=2, 100 | batch_topk=True, 101 | batch_size=4, 102 | steps=4, 103 | print_freq=2, 104 | save_freq=None, 105 | lr=1e-3, 106 | topk_recon_coeff=1, 107 | schatten_pnorm=None, 108 | schatten_coeff=None, 109 | task_config=TMS_TASK_CONFIG, 110 | ) 111 | tms_spd_happy_path(config) 112 | 113 | 114 | @pytest.mark.parametrize("n_hidden_layers", [0, 2]) 115 | def 
test_tms_batch_topk_and_schatten(n_hidden_layers: int): 116 | config = Config( 117 | C=5, 118 | topk=2, 119 | batch_topk=True, 120 | batch_size=4, 121 | steps=4, 122 | print_freq=2, 123 | save_freq=None, 124 | lr=1e-3, 125 | topk_recon_coeff=1, 126 | schatten_pnorm=0.9, 127 | schatten_coeff=1e-1, 128 | task_config=TMS_TASK_CONFIG, 129 | ) 130 | tms_spd_happy_path(config, n_hidden_layers) 131 | 132 | 133 | def test_tms_topk_and_l2(): 134 | config = Config( 135 | C=5, 136 | topk=2, 137 | batch_topk=False, 138 | batch_size=4, 139 | steps=4, 140 | print_freq=2, 141 | save_freq=None, 142 | lr=1e-3, 143 | topk_recon_coeff=1, 144 | schatten_pnorm=0.9, 145 | schatten_coeff=1e-1, 146 | task_config=TMS_TASK_CONFIG, 147 | ) 148 | tms_spd_happy_path(config) 149 | 150 | 151 | def test_tms_lp(): 152 | config = Config( 153 | C=5, 154 | topk=None, 155 | batch_topk=False, 156 | batch_size=4, 157 | steps=4, 158 | print_freq=2, 159 | save_freq=None, 160 | lr=1e-3, 161 | lp_sparsity_coeff=0.01, 162 | pnorm=0.9, 163 | task_config=TMS_TASK_CONFIG, 164 | ) 165 | tms_spd_happy_path(config) 166 | 167 | 168 | @pytest.mark.parametrize("n_hidden_layers", [0, 2]) 169 | def test_tms_topk_and_lp(n_hidden_layers: int): 170 | config = Config( 171 | C=5, 172 | topk=2, 173 | batch_topk=False, 174 | batch_size=4, 175 | steps=4, 176 | print_freq=2, 177 | save_freq=None, 178 | lr=1e-3, 179 | pnorm=0.9, 180 | topk_recon_coeff=1, 181 | lp_sparsity_coeff=1, 182 | task_config=TMS_TASK_CONFIG, 183 | ) 184 | tms_spd_happy_path(config, n_hidden_layers) 185 | 186 | 187 | def test_train_tms_happy_path(): 188 | device = "cpu" 189 | set_seed(0) 190 | # Set up a small configuration 191 | config = TMSTrainConfig( 192 | tms_model_config=TMSModelConfig( 193 | n_features=3, 194 | n_hidden=2, 195 | n_instances=2, 196 | n_hidden_layers=0, 197 | device=device, 198 | ), 199 | feature_probability=0.1, 200 | batch_size=32, 201 | steps=5, 202 | lr=5e-3, 203 | data_generation_type="at_least_zero_active", 204 | 
fixed_identity_hidden_layers=False, 205 | fixed_random_hidden_layers=False, 206 | ) 207 | 208 | model, dataloader = get_model_and_dataloader(config, device) 209 | 210 | # Calculate initial loss 211 | batch, labels = next(iter(dataloader)) 212 | initial_out = model(batch) 213 | initial_loss = torch.mean((labels.abs() - initial_out) ** 2) 214 | 215 | train(model, dataloader, steps=config.steps, print_freq=1000, log_wandb=False) 216 | 217 | # Calculate final loss 218 | final_out = model(batch) 219 | final_loss = torch.mean((labels.abs() - final_out) ** 2) 220 | 221 | # Assert that the final loss is lower than the initial loss 222 | assert ( 223 | final_loss < initial_loss 224 | ), f"Final loss ({final_loss:.2e}) is not lower than initial loss ({initial_loss:.2e})" 225 | 226 | 227 | def test_tms_train_fixed_identity(): 228 | """Check that hidden layer is identity before and after training.""" 229 | device = "cpu" 230 | set_seed(0) 231 | config = TMSTrainConfig( 232 | tms_model_config=TMSModelConfig( 233 | n_features=3, 234 | n_hidden=2, 235 | n_instances=2, 236 | n_hidden_layers=2, 237 | device=device, 238 | ), 239 | feature_probability=0.1, 240 | batch_size=32, 241 | steps=2, 242 | lr=5e-3, 243 | data_generation_type="at_least_zero_active", 244 | fixed_identity_hidden_layers=True, 245 | fixed_random_hidden_layers=False, 246 | ) 247 | 248 | model, dataloader = get_model_and_dataloader(config, device) 249 | 250 | eye = torch.eye(config.tms_model_config.n_hidden, device=device).expand( 251 | config.tms_model_config.n_instances, -1, -1 252 | ) 253 | 254 | assert model.hidden_layers is not None 255 | # Assert that this is an identity matrix 256 | initial_hidden = model.hidden_layers[0].weight.data.clone() 257 | assert torch.allclose(initial_hidden, eye), "Initial hidden layer is not identity" 258 | 259 | train(model, dataloader, steps=config.steps, print_freq=1000, log_wandb=False) 260 | 261 | # Assert that the hidden layers remains identity 262 | assert 
torch.allclose(model.hidden_layers[0].weight.data, eye), "Hidden layer changed" 263 | 264 | 265 | def test_tms_train_fixed_random(): 266 | """Check that hidden layer is random before and after training.""" 267 | device = "cpu" 268 | set_seed(0) 269 | config = TMSTrainConfig( 270 | tms_model_config=TMSModelConfig( 271 | n_features=3, 272 | n_hidden=2, 273 | n_instances=2, 274 | n_hidden_layers=2, 275 | device=device, 276 | ), 277 | feature_probability=0.1, 278 | batch_size=32, 279 | steps=2, 280 | lr=5e-3, 281 | data_generation_type="at_least_zero_active", 282 | fixed_identity_hidden_layers=False, 283 | fixed_random_hidden_layers=True, 284 | ) 285 | 286 | model, dataloader = get_model_and_dataloader(config, device) 287 | 288 | assert model.hidden_layers is not None 289 | initial_hidden = model.hidden_layers[0].weight.data.clone() 290 | 291 | train(model, dataloader, steps=config.steps, print_freq=1000, log_wandb=False) 292 | 293 | # Assert that the hidden layers are unchanged 294 | assert torch.allclose( 295 | model.hidden_layers[0].weight.data, initial_hidden 296 | ), "Hidden layer changed" 297 | 298 | 299 | def test_tms_equivalent_to_raw_model() -> None: 300 | device = "cpu" 301 | set_seed(0) 302 | tms_config = TMSModelConfig( 303 | n_instances=2, 304 | n_features=3, 305 | n_hidden=2, 306 | n_hidden_layers=1, 307 | device=device, 308 | ) 309 | C = 2 310 | 311 | target_model = TMSModel(config=tms_config).to(device) 312 | 313 | # Create the SPD model 314 | tms_spd_config = TMSSPDModelConfig( 315 | **tms_config.model_dump(), 316 | C=C, 317 | m=3, # Small m for testing 318 | bias_val=0.0, 319 | ) 320 | spd_model = TMSSPDModel(config=tms_spd_config).to(device) 321 | 322 | # Init all params to random values 323 | for param in spd_model.parameters(): 324 | param.data = torch.randn_like(param.data) 325 | 326 | # Copy the subnetwork params from the SPD model to the target model 327 | target_model.linear1.weight.data[:, :, :] = spd_model.linear1.weight.data 328 | if 
target_model.hidden_layers is not None: 329 | for i in range(target_model.config.n_hidden_layers): 330 | target_layer: Tensor = get_nested_module_attr(target_model, f"hidden_layers.{i}.weight") 331 | spd_layer: Tensor = get_nested_module_attr(spd_model, f"hidden_layers.{i}.weight") 332 | target_layer.data[:, :, :] = spd_layer.data 333 | 334 | # Also copy the bias 335 | target_model.b_final.data[:, :] = spd_model.b_final.data 336 | 337 | # Create a random input 338 | batch_size = 4 339 | input_data: Float[torch.Tensor, "batch n_instances n_features"] = torch.rand( 340 | batch_size, tms_config.n_instances, tms_config.n_features, device=device 341 | ) 342 | 343 | with torch.inference_mode(): 344 | # Forward pass on target model 345 | target_cache_filter = lambda k: k.endswith((".hook_pre", ".hook_post")) 346 | target_out, target_cache = target_model.run_with_cache( 347 | input_data, names_filter=target_cache_filter 348 | ) 349 | # Forward pass with all subnetworks 350 | spd_cache_filter = lambda k: k.endswith((".hook_post", ".hook_component_acts")) 351 | out, spd_cache = spd_model.run_with_cache(input_data, names_filter=spd_cache_filter) 352 | 353 | # Assert outputs are the same 354 | assert torch.allclose(target_out, out, atol=1e-6), "Outputs do not match" 355 | 356 | # Assert that all post-acts are the same 357 | target_post_weight_acts = {k: v for k, v in target_cache.items() if k.endswith(".hook_post")} 358 | spd_post_weight_acts = {k: v for k, v in spd_cache.items() if k.endswith(".hook_post")} 359 | for key_name in target_post_weight_acts: 360 | assert torch.allclose( 361 | target_post_weight_acts[key_name], spd_post_weight_acts[key_name], atol=1e-6 362 | ), f"post-acts do not match at layer {key_name}" 363 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | import pytest 4 | import torch 5 | 
from jaxtyping import Float
from torch import Tensor

from spd.utils import (
    SparseFeatureDataset,
    calc_activation_attributions,
    calc_topk_mask,
    compute_feature_importances,
)


def test_calc_topk_mask_without_batch_topk():
    """Without batch_topk, the top `topk` entries are selected independently in each row."""
    attribution_scores = torch.tensor([[1.0, 5.0, 2.0, 1.0, 2.0], [3.0, 3.0, 5.0, 4.0, 4.0]])
    topk = 3
    expected_mask = torch.tensor(
        [[False, True, True, False, True], [False, False, True, True, True]]
    )

    result = calc_topk_mask(attribution_scores, topk, batch_topk=False)
    torch.testing.assert_close(result, expected_mask)


def test_calc_topk_mask_with_batch_topk():
    """With batch_topk, `topk * batch_size` entries are selected jointly across the batch."""
    attribution_scores = torch.tensor([[1.0, 5.0, 2.0, 1.0, 2.0], [3.0, 3.0, 5.0, 4.0, 4.0]])
    topk = 3  # multiplied by batch size to get 6
    expected_mask = torch.tensor(
        [[False, True, False, False, False], [True, True, True, True, True]]
    )

    result = calc_topk_mask(attribution_scores, topk, batch_topk=True)
    torch.testing.assert_close(result, expected_mask)


def test_calc_topk_mask_without_batch_topk_n_instances():
    """attributions have shape [batch, n_instances, n_features]. We take the topk
    over the n_features dim for each instance in each batch."""
    attribution_scores = torch.tensor(
        [[[1.0, 5.0, 3.0, 4.0], [2.0, 4.0, 6.0, 1.0]], [[2.0, 1.0, 5.0, 9.5], [3.0, 4.0, 1.0, 5.0]]]
    )
    topk = 2
    expected_mask = torch.tensor(
        [
            [[False, True, False, True], [False, True, True, False]],
            [[False, False, True, True], [False, True, False, True]],
        ]
    )

    result = calc_topk_mask(attribution_scores, topk, batch_topk=False)
    torch.testing.assert_close(result, expected_mask)


def test_calc_topk_mask_with_batch_topk_n_instances():
    """attributions have shape [batch, n_instances, n_features]. For each instance
    separately, we take the topk over the concatenated batch and n_features dims."""
    attribution_scores = torch.tensor(
        [[[1.0, 5.0, 3.0], [2.0, 4.0, 6.0]], [[2.0, 1.0, 5.0], [3.0, 4.0, 1.0]]]
    )
    topk = 2  # multiplied by batch size to get 4
    expected_mask = torch.tensor(
        [[[False, True, True], [False, True, True]], [[True, False, True], [True, True, False]]]
    )

    result = calc_topk_mask(attribution_scores, topk, batch_topk=True)
    torch.testing.assert_close(result, expected_mask)


def test_calc_activation_attributions_obvious():
    """A component's attribution is the sum of its squared activations."""
    component_acts = {"layer1": torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])}
    expected = torch.tensor([[1.0, 1.0]])

    result = calc_activation_attributions(component_acts)
    torch.testing.assert_close(result, expected)


def test_calc_activation_attributions_different_d_out():
    """Squared activations are summed over d_out and across layers, even when the layers
    have different d_out."""
    component_acts = {
        "layer1": torch.tensor([[[1.0, 2.0], [3.0, 4.0]]]),
        "layer2": torch.tensor([[[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]]]),
    }
    expected = torch.tensor(
        [[1.0**2 + 2**2 + 5**2 + 6**2 + 7**2, 3**2 + 4**2 + 8**2 + 9**2 + 10**2]]
    )

    result = calc_activation_attributions(component_acts)
    torch.testing.assert_close(result, expected)


def test_calc_activation_attributions_with_n_instances():
    """Attributions are computed separately for each instance."""
    # Batch=1, n_instances=2, C=2, d_out=2
    component_acts = {
        "layer1": torch.tensor([[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]),
        "layer2": torch.tensor([[[[9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0]]]]),
    }
    expected = torch.tensor(
        [
            [
                [1.0**2 + 2**2 + 9**2 + 10**2, 3**2 + 4**2 + 11**2 + 12**2],
                [5**2 + 6**2 + 13**2 + 14**2, 7**2 + 8**2 + 15**2 + 16**2],
            ]
        ]
    )

    result = calc_activation_attributions(component_acts)
    torch.testing.assert_close(result, expected)


def test_dataset_at_least_zero_active():
    """"at_least_zero_active" batches have the right shape and value range, and their
    overall fraction of non-zero entries is close to feature_probability."""
    # Seed the global RNG so the statistical proportion check below is deterministic.
    torch.manual_seed(0)
    n_instances = 3
    n_features = 5
    feature_probability = 0.5
    device = "cpu"
    batch_size = 100

    dataset = SparseFeatureDataset(
        n_instances=n_instances,
        n_features=n_features,
        feature_probability=feature_probability,
        device=device,
        data_generation_type="at_least_zero_active",
        value_range=(0.0, 1.0),
    )

    batch, _ = dataset.generate_batch(batch_size)

    # Check shape
    assert batch.shape == (batch_size, n_instances, n_features), "Incorrect batch shape"

    # Check that the values are between 0 and 1
    assert torch.all((batch >= 0) & (batch <= 1)), "Values should be between 0 and 1"

    # Check that the proportion of non-zero elements is close to feature_probability
    non_zero_proportion = torch.count_nonzero(batch).item() / batch.numel()
    assert (
        abs(non_zero_proportion - feature_probability) < 0.05
    ), f"Expected proportion {feature_probability}, but got {non_zero_proportion}"


def test_generate_multi_feature_batch_no_zero_samples():
    """_generate_multi_feature_batch_no_zero_samples never returns an all-zero sample,
    even at a low feature probability."""
    # Seed the global RNG so the test exercises the same batch on every run.
    torch.manual_seed(0)
    n_instances = 3
    n_features = 5
    feature_probability = 0.05  # Low probability to increase chance of zero samples
    device = "cpu"
    batch_size = 100
    buffer_ratio = 1.5

    dataset = SparseFeatureDataset(
        n_instances=n_instances,
        n_features=n_features,
        feature_probability=feature_probability,
        device=device,
        data_generation_type="at_least_zero_active",
        value_range=(0.0, 1.0),
    )

    batch = dataset._generate_multi_feature_batch_no_zero_samples(batch_size, buffer_ratio)

    # Check shape
    assert batch.shape == (batch_size, n_instances, n_features), "Incorrect batch shape"

    # Check that the values are between 0 and 1
    assert torch.all((batch >= 0) & (batch <= 1)), "Values should be between 0 and 1"

    # Check that there are no all-zero samples (values are non-negative, so a zero sum
    # over features means every feature is zero)
    zero_samples = (batch.sum(dim=-1) == 0).sum()
    assert zero_samples == 0, f"Found {zero_samples} samples with all zeros"


@pytest.mark.parametrize("n", [1, 2, 3, 4, 5])
def test_dataset_exactly_n_active(n: int):
    """"exactly_<n>_active" batches have exactly n non-zero features per sample and
    instance, with every non-zero value inside value_range."""
    # Seed the global RNG so the test exercises the same batch on every run.
    torch.manual_seed(0)
    n_instances = 3
    n_features = 10
    # NOTE: presumably ignored by the exactly_*_active generation types; it is passed
    # only because the constructor requires it.
    feature_probability = 0.5
    device = "cpu"
    batch_size = 10
    value_range = (0.0, 1.0)

    n_map: dict[
        int,
        Literal[
            "exactly_one_active",
            "exactly_two_active",
            "exactly_three_active",
            "exactly_four_active",
            "exactly_five_active",
        ],
    ] = {
        1: "exactly_one_active",
        2: "exactly_two_active",
        3: "exactly_three_active",
        4: "exactly_four_active",
        5: "exactly_five_active",
    }
    dataset = SparseFeatureDataset(
        n_instances=n_instances,
        n_features=n_features,
        feature_probability=feature_probability,
        device=device,
        data_generation_type=n_map[n],
        value_range=value_range,
    )

    batch, _ = dataset.generate_batch(batch_size)

    # Check shape
    assert batch.shape == (batch_size, n_instances, n_features), "Incorrect batch shape"

    # Check that there are exactly n non-zero values per sample and instance
    non_zero_counts = torch.count_nonzero(batch, dim=-1)
    assert torch.all(
        non_zero_counts == n
    ), f"Expected {n} non-zero values, but found counts {non_zero_counts.unique().tolist()}"

    # Check that the non-zero values are in the value_range
    non_zero_values = batch[batch != 0]
    assert torch.all(
        (non_zero_values >= value_range[0]) & (non_zero_values <= value_range[1])
    ), f"Non-zero values should be between {value_range[0]} and {value_range[1]}"


@pytest.mark.parametrize(
    "importance_val, expected_tensor",
    [
        (
            1.0,
            torch.tensor([[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]]),
        ),
        (
            0.5,
            torch.tensor(
                [[[1.0, 0.5, 0.25], [1.0, 0.5, 0.25]], [[1.0, 0.5, 0.25], [1.0, 0.5, 0.25]]]
            ),
        ),
        (
            0.0,
            torch.tensor([[[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]], [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]]),
        ),
    ],
)
def test_compute_feature_importances(
    importance_val: float, expected_tensor: Float[Tensor, "batch_size n_instances n_features"]
):
    """Feature i gets importance importance_val ** i, repeated over batch and instances."""
    importances = compute_feature_importances(
        batch_size=2, n_instances=2, n_features=3, importance_val=importance_val, device="cpu"
    )
    torch.testing.assert_close(importances, expected_tensor)


def test_sync_inputs_non_overlapping():
    """Features in the same synced group are always active or inactive together."""
    # Seed the global RNG so the test exercises the same batch on every run.
    torch.manual_seed(0)
    synced_inputs = [[0, 1], [2, 3, 4]]
    dataset = SparseFeatureDataset(
        n_instances=1,
        n_features=6,
        feature_probability=0.5,
        device="cpu",
        data_generation_type="at_least_zero_active",
        value_range=(0.0, 1.0),
        synced_inputs=[[0, 1], [2, 3, 4]],
    )

    batch, _ = dataset.generate_batch(5)
    # Ignore the n_instances dimension
    batch = batch[:, 0, :]
    for group in synced_inputs:
        active = batch[:, group] != 0.0
        # Each sample must have the group either fully active or fully inactive: if any
        # feature in the group is non-zero, every feature in the group must be non-zero.
        all_or_none = active.all(dim=-1) | (~active).all(dim=-1)
        assert torch.all(all_or_none), f"Synced group {group} was only partially active"


def test_sync_inputs_overlapping():
    """Overlapping synced groups are rejected when a batch is generated."""
    dataset = SparseFeatureDataset(
        n_instances=1,
        n_features=6,
        feature_probability=0.5,
        device="cpu",
        data_generation_type="at_least_zero_active",
        value_range=(0.0, 1.0),
        synced_inputs=[[0, 1], [1, 2, 3]],
    )
    # Should raise an assertion error with the word "overlapping"
    with pytest.raises(AssertionError, match="overlapping"):
        dataset.generate_batch(5)