├── .devcontainer
│   ├── Dockerfile
│   └── devcontainer.json
├── .github
│   └── workflows
│       ├── checks.yml
│       ├── gh-pages.yml
│       └── release.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .vscode
│   ├── cspell.json
│   ├── extensions.json
│   └── settings.json
├── LICENSE
├── README.md
├── docs
│   ├── content
│   │   ├── SUMMARY.md
│   │   ├── citation.md
│   │   ├── contributing.md
│   │   ├── css
│   │   │   ├── custom_formatting.css
│   │   │   └── material_extra.css
│   │   ├── demo.ipynb
│   │   ├── index.md
│   │   ├── javascript
│   │   │   ├── custom_formatting.js
│   │   │   └── mathjax.js
│   │   └── pre-process-datasets.ipynb
│   └── gen_ref_pages.py
├── mkdocs.yml
├── poetry.lock
├── pyproject.toml
└── sparse_autoencoder
    ├── __init__.py
    ├── activation_resampler
    │   ├── __init__.py
    │   ├── activation_resampler.py
    │   ├── tests
    │   │   └── test_activation_resampler.py
    │   └── utils
    │       ├── __init__.py
    │       └── component_slice_tensor.py
    ├── activation_store
    │   ├── __init__.py
    │   ├── base_store.py
    │   ├── tensor_store.py
    │   └── tests
    │       └── test_tensor_store.py
    ├── autoencoder
    │   ├── __init__.py
    │   ├── components
    │   │   ├── __init__.py
    │   │   ├── linear_encoder.py
    │   │   ├── tests
    │   │   │   ├── __snapshots__
    │   │   │   │   └── test_linear_encoder.ambr
    │   │   │   ├── test_compare_neel_implementation.py
    │   │   │   ├── test_linear_encoder.py
    │   │   │   ├── test_tied_bias.py
    │   │   │   └── test_unit_norm_decoder.py
    │   │   ├── tied_bias.py
    │   │   └── unit_norm_decoder.py
    │   ├── lightning.py
    │   ├── model.py
    │   ├── tests
    │   │   ├── __snapshots__
    │   │   │   └── test_model.ambr
    │   │   └── test_model.py
    │   └── types.py
    ├── metrics
    │   ├── __init__.py
    │   ├── loss
    │   │   ├── __init__.py
    │   │   ├── l1_absolute_loss.py
    │   │   ├── l2_reconstruction_loss.py
    │   │   ├── sae_loss.py
    │   │   └── tests
    │   │       ├── test_l1_absolute_loss.py
    │   │       ├── test_l2_reconstruction_loss.py
    │   │       └── test_sae_loss.py
    │   ├── train
    │   │   ├── __init__.py
    │   │   ├── capacity.py
    │   │   ├── feature_density.py
    │   │   ├── l0_norm.py
    │   │   ├── neuron_activity.py
    │   │   ├── neuron_fired_count.py
    │   │   └── tests
    │   │       ├── test_feature_density.py
    │   │       ├── test_l0_norm.py
    │   │       ├── test_neuron_activity.py
    │   │       └── test_neuron_fired_count.py
    │   ├── validate
    │   │   ├── __init__.py
    │   │   ├── reconstruction_score.py
    │   │   └── tests
    │   │       └── __snapshots__
    │   │           └── test_model_reconstruction_score.ambr
    │   └── wrappers
    │       ├── __init__.py
    │       ├── classwise.py
    │       └── tests
    │           └── test_classwise.py
    ├── optimizer
    │   ├── __init__.py
    │   ├── adam_with_reset.py
    │   └── tests
    │       └── test_adam_with_reset.py
    ├── source_data
    │   ├── __init__.py
    │   ├── abstract_dataset.py
    │   ├── mock_dataset.py
    │   ├── pretokenized_dataset.py
    │   ├── tests
    │   │   ├── test_abstract_dataset.py
    │   │   ├── test_mock_dataset.py
    │   │   ├── test_pretokenized_dataset.py
    │   │   └── test_text_dataset.py
    │   └── text_dataset.py
    ├── source_model
    │   ├── __init__.py
    │   ├── replace_activations_hook.py
    │   ├── reshape_activations.py
    │   ├── store_activations_hook.py
    │   ├── tests
    │   │   ├── test_replace_activations_hook.py
    │   │   ├── test_store_activations_hook.py
    │   │   └── test_zero_ablate_hook.py
    │   └── zero_ablate_hook.py
    ├── tensor_types.py
    ├── train
    │   ├── __init__.py
    │   ├── join_sweep.py
    │   ├── pipeline.py
    │   ├── sweep.py
    │   ├── sweep_config.py
    │   ├── tests
    │   │   ├── __snapshots__
    │   │   │   └── test_sweep.ambr
    │   │   ├── test_pipeline.py
    │   │   └── test_sweep.py
    │   └── utils
    │       ├── __init__.py
    │       ├── get_model_device.py
    │       ├── round_down.py
    │       ├── tests
    │       │   ├── test_get_model_device.py
    │       │   └── test_wandb_sweep_types.py
    │       └── wandb_sweep_types.py
    ├── training_runs
    │   ├── __init__.py
    │   └── gpt2.py
    └── utils
        ├── __init__.py
        ├── data_parallel.py
        └── tensor_shape.py

/.devcontainer/Dockerfile:
--------------------------------------------------------------------------------
# Use the Nvidia CUDA Ubuntu 22.04 base (includes CUDA if a supported GPU is present)
# https://hub.docker.com/r/nvidia/cuda
FROM nvidia/cuda:12.2.2-devel-ubuntu22.04

ARG USERNAME
ARG USER_UID=1000
ARG USER_GID=$USER_UID

# Create the user
# https://code.visualstudio.com/remote/advancedcontainers/add-nonroot-user
RUN groupadd --gid $USER_GID $USERNAME \
    && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME \
    && usermod -a -G video $USERNAME \
    && apt-get update \
    && apt-get install -y sudo \
    && echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \
    && chmod 0440 /etc/sudoers.d/$USERNAME

# Add the deadsnakes PPA to get the latest Python version
RUN sudo apt-get install -y software-properties-common \
    && sudo add-apt-repository ppa:deadsnakes/ppa

# Install dependencies
RUN sudo apt-get update && \
    DEBIAN_FRONTEND=noninteractive apt-get -qq -y install \
    build-essential \
    curl \
    git \
    poppler-utils \
    python3 \
    python3.11 \
    python3-dev \
    python3-distutils \
    python3-venv \
    expat \
    rsync

# Install pip (we need the latest version, not the standard Ubuntu version, to
# support modern wheels)
RUN sudo curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && python3 get-pip.py

# Install poetry
RUN sudo curl -sSL https://install.python-poetry.org | python3 -
ENV PATH="/root/.local/bin:$PATH"

# Set python aliases
RUN sudo update-alternatives --install /usr/bin/python python /usr/bin/python3 1

# Switch to the new user
USER $USERNAME

--------------------------------------------------------------------------------
/.devcontainer/devcontainer.json:
--------------------------------------------------------------------------------
{
    "name": "Python 3",
    "build": {
        "dockerfile": "Dockerfile",
        "args": {
            "USERNAME": "user"
        }
    },
    "runArgs": [
        "--gpus",
        "all"
    ],
    "containerUser": "user",
    "postCreateCommand": "poetry env use 3.11 && poetry install --with dev,jupyter"
}

--------------------------------------------------------------------------------
/.github/workflows/checks.yml:
--------------------------------------------------------------------------------
name: Checks

on:
  push:
    branches:
      - main
    paths-ignore:
      - '.devcontainer/**'
      - '.vscode/**'
      - '.gitignore'
      - 'README.md'
  pull_request:
    branches:
      - main
    paths-ignore:
      - '.devcontainer/**'
      - '.vscode/**'
      - '.gitignore'
      - 'README.md'
  # Allow this workflow to be called from other workflows
  workflow_call:
    inputs:
      # Requires at least one input to be valid, but in practice we don't need any
      dummy:
        type: string
        required: false

jobs:
  checks:
    name: Checks
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version:
          - "3.10"
          - "3.11"
    steps:
      - uses: actions/checkout@v4
      - name: Install Poetry
        uses: snok/install-poetry@v1
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
          cache: 'poetry'
          allow-prereleases: true
      - name: Check lockfile
        run: poetry check
      - name: Install dependencies
        run: poetry install --with dev
      - name: Pyright type check
        run: poetry run pyright
      - name: Ruff lint
        run: poetry run ruff check . --output-format=github
      - name: Docstrings lint
        run: poetry run pydoclint .
      - name: Ruff format
        run: poetry run ruff format . --check
      - name: Pytest
        run: poetry run pytest
      - name: Build check
        run: poetry build

--------------------------------------------------------------------------------
/.github/workflows/gh-pages.yml:
--------------------------------------------------------------------------------
name: Docs
on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - '*'

permissions:
  contents: write

jobs:
  build-docs:
    # When running on a PR, this just checks we can build the docs without errors
    # When running on merge to main, it builds the docs and then another job deploys them
    name: ${{ github.event_name == 'pull_request' && 'Check Build Docs' || 'Build Docs' }}
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Install Poetry
        uses: snok/install-poetry@v1
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: "3.11"
          cache: "poetry"
      - name: Install poe
        run: pip install poethepoet
      - name: Install dependencies
        run: poetry install --with docs
      - name: Generate docs
        run: poe gen-docs
      - name: Build Docs
        run: poe make-docs
      - name: Upload Docs Artifact
        uses: actions/upload-artifact@v3
        with:
          name: documentation
          path: docs/generated

  deploy-docs:
    name: Deploy Docs
    runs-on: ubuntu-latest
    # Only run if merging a PR into main
    if: github.event_name == 'push' && github.ref == 'refs/heads/main'
    needs: build-docs
    steps:
      - uses: actions/checkout@v4
      - name: Download Docs Artifact
        uses: actions/download-artifact@v3
        with:
          name: documentation
          path: docs/generated
      - name: Upload to GitHub Pages
        uses: JamesIves/github-pages-deploy-action@v4
        with:
          folder: docs/generated
          clean-exclude: |
            *.*.*/

--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
name: Release

on:
  release:
    types:
      - published

jobs:
  checks:
    name: Run checks workflow
    uses: alan-cooney/sparse_autoencoder/.github/workflows/checks.yml@main

  semver-parser:
    name: Parse the semantic version from the release
    runs-on: ubuntu-latest
    steps:
      - name: Parse semver string
        id: semver_parser
        uses: booxmedialtd/ws-action-parse-semver@v1.4.7
        with:
          input_string: ${{ github.event.release.tag_name }}
    outputs:
      major: "${{ steps.semver_parser.outputs.major }}"
      minor: "${{ steps.semver_parser.outputs.minor }}"
      patch: "${{ steps.semver_parser.outputs.patch }}"
      semver: "${{ steps.semver_parser.outputs.fullversion }}"

  release-python:
    name: Release Python package to PyPi
    needs:
      - checks
      - semver-parser
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Install Poetry
        uses: snok/install-poetry@v1
      - name: Poetry config
        run: poetry self add 'poethepoet[poetry_plugin]'
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.10'
          cache: 'poetry'
      - name: Install dependencies
        run: poetry install --with dev
      - name: Set the version
        run: poetry version ${{needs.semver-parser.outputs.semver}}
      - name: Build
        run: poetry build
      - name: Publish
        run: poetry publish
        env:
          POETRY_PYPI_TOKEN_PYPI: ${{ secrets.POETRY_PYPI_TOKEN }}

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
manifest
.DS_Store

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Docs stuff
docs/content/reference
docs/generated

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# Generated docs
docs/content/reference

# Checkpoints directory
.checkpoints

# Wandb
wandb/
artifacts/

# Lightning
lightning_logs

# Scratch files
scratch.py
scratch.ipynb
scratch/

--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-toml
      - id: check-json
      - id: check-added-large-files
      - id: check-merge-conflict
      - id: check-symlinks
      - id: destroyed-symlinks
      - id: detect-private-key
  - repo: local
    hooks:
      - id: ruff_lint
        name: Ruff Lint
        entry: poetry run ruff check sparse_autoencoder --fix
        language: system
        types: [python]
        require_serial: true
      - id: ruff_format
        name: Ruff Format
        entry: poetry run ruff format sparse_autoencoder --check --diff
        language: system
        types: [python]
        require_serial: true
      - id: typecheck
        name: Pyright Type Check
        entry: poetry run pyright
        language: system
        types: [python]
        require_serial: true

--------------------------------------------------------------------------------
/.vscode/cspell.json:
--------------------------------------------------------------------------------
{
    "language": "en,en-GB",
    "words": [
        "alancooney",
        "allclose",
        "arange",
        "arithmatex",
        "astroid",
        "autocast",
        "autoencoder",
        "autoencoders",
        "autoencoding",
        "autofix",
        "capturable",
        "categoricalwprobabilities",
        "circuitsvis",
        "Classwise",
        "coeff",
        "colab",
        "cooney",
        "cuda",
        "cudnn",
        "datapoints",
        "davidanson",
        "devcontainer",
        "devel",
        "dmypy",
        "docstrings",
        "donjayamanne",
        "dtype",
        "dtypes",
        "dunder",
        "earlyterminate",
        "einops",
        "einsum",
        "endoftext",
        "gelu",
        "githistory",
        "hobbhahn",
        "htmlproofer",
        "huggingface",
        "hyperband",
        "hyperparameters",
        "imageuri",
        "imputewhilerunning",
        "interpretability",
        "intuniform",
        "invloguniform",
        "invloguniformvalues",
        "ipynb",
        "itemwise",
        "jaxtyping",
        "kaiming",
        "keepdim",
        "logit",
        "lognormal",
        "loguniform",
        "loguniformvalues",
        "mathbb",
        "mathbf",
        "maxiter",
        "miniter",
        "mkdocs",
        "mkdocstrings",
        "mknotebooks",
        "monosemantic",
        "monosemanticity",
        "multipled",
        "nanda",
        "ncols",
"ndarray", 72 | "ndim", 73 | "neel", 74 | "nelement", 75 | "neox", 76 | "nonlinerity", 77 | "numel", 78 | "onebit", 79 | "openwebtext", 80 | "optim", 81 | "penality", 82 | "perp", 83 | "pickleable", 84 | "polysemantic", 85 | "polysemantically", 86 | "polysemanticity", 87 | "precommit", 88 | "pretokenized", 89 | "pydantic", 90 | "pyproject", 91 | "pyright", 92 | "pytest", 93 | "qbeta", 94 | "qlognormal", 95 | "qloguniform", 96 | "qloguniformvalues", 97 | "qnormal", 98 | "quniform", 99 | "randn", 100 | "randperm", 101 | "relu", 102 | "resampler", 103 | "resid", 104 | "roneneldan", 105 | "rtol", 106 | "runcap", 107 | "sharded", 108 | "snapshottest", 109 | "solu", 110 | "tinystories", 111 | "torchmetrics", 112 | "tqdm", 113 | "transformer_lens", 114 | "typecheck", 115 | "ultralow", 116 | "uncopyright", 117 | "uncopyrighted", 118 | "ungraphed", 119 | "unsqueeze", 120 | "unsync", 121 | "venv", 122 | "virtualenv", 123 | "virtualenvs", 124 | "wandb", 125 | "zoadam" 126 | ] 127 | } 128 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "davidanson.vscode-markdownlint", 4 | "donjayamanne.githistory", 5 | "github.copilot", 6 | "github.vscode-github-actions", 7 | "github.vscode-pull-request-github", 8 | "ionutvmi.path-autocomplete", 9 | "littlefoxteam.vscode-python-test-adapter", 10 | "ms-python.python", 11 | "ms-python.vscode-pylance", 12 | "ms-toolsai.jupyter-keymap", 13 | "ms-toolsai.jupyter-renderers", 14 | "ms-toolsai.jupyter", 15 | "richie5um2.vscode-sort-json", 16 | "stkb.rewrap", 17 | "streetsidesoftware.code-spell-checker-british-english", 18 | "streetsidesoftware.code-spell-checker", 19 | "yzhang.markdown-all-in-one", 20 | "kevinrose.vsc-python-indent", 21 | "donjayamanne.python-environment-manager", 22 | "tamasfe.even-better-toml" 23 | ] 24 | } 25 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "[jsonc]": { 3 | "editor.defaultFormatter": "vscode.json-language-features" 4 | }, 5 | "[python]": { 6 | "editor.defaultFormatter": "charliermarsh.ruff" 7 | }, 8 | "[toml]": { 9 | "editor.defaultFormatter": "tamasfe.even-better-toml" 10 | }, 11 | "editor.codeActionsOnSave": { 12 | "source.fixAll.eslint": "explicit", 13 | "source.organizeImports": "explicit" 14 | }, 15 | "editor.formatOnSave": true, 16 | "evenBetterToml.formatter.allowedBlankLines": 1, 17 | "evenBetterToml.formatter.arrayAutoCollapse": true, 18 | "evenBetterToml.formatter.arrayAutoExpand": true, 19 | "evenBetterToml.formatter.arrayTrailingComma": true, 20 | "evenBetterToml.formatter.columnWidth": 100, 21 | "evenBetterToml.formatter.compactArrays": true, 22 | "evenBetterToml.formatter.compactEntries": true, 23 | "evenBetterToml.formatter.compactInlineTables": true, 24 | "evenBetterToml.formatter.indentEntries": true, 25 | "evenBetterToml.formatter.indentString": " ", 26 | "evenBetterToml.formatter.indentTables": true, 27 | "evenBetterToml.formatter.inlineTableExpand": true, 28 | "evenBetterToml.formatter.reorderArrays": true, 29 | "evenBetterToml.formatter.reorderKeys": true, 30 | "evenBetterToml.formatter.trailingNewline": true, 31 | "evenBetterToml.schema.enabled": true, 32 | "evenBetterToml.schema.links": true, 33 | "evenBetterToml.syntax.semanticTokens": false, 34 | "notebook.formatOnCellExecution": true, 35 | 
"notebook.formatOnSave.enabled": true, 36 | "python.analysis.autoFormatStrings": true, 37 | "python.analysis.inlayHints.functionReturnTypes": false, 38 | "python.analysis.typeCheckingMode": "basic", 39 | "python.languageServer": "Pylance", 40 | "python.terminal.activateEnvInCurrentTerminal": true, 41 | "python.testing.pytestEnabled": true, 42 | "rewrap.autoWrap.enabled": true, 43 | "rewrap.wrappingColumn": 100, 44 | "pylint.ignorePatterns": [ 45 | "*" 46 | ], 47 | "pythonTestExplorer.testFramework": "pytest" 48 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Alan Cooney 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sparse Autoencoder 2 | 3 | [![PyPI](https://img.shields.io/pypi/v/sparse_autoencoder?color=blue)](https://pypi.org/project/transformer-lens/) 4 | ![PyPI - 5 | License](https://img.shields.io/pypi/l/sparse_autoencoder?color=blue) [![Checks](https://github.com/alan-cooney/sparse_autoencoder/actions/workflows/checks.yml/badge.svg)](https://github.com/alan-cooney/sparse_autoencoder/actions/workflows/checks.yml) 6 | [![Release](https://github.com/alan-cooney/sparse_autoencoder/actions/workflows/release.yml/badge.svg)](https://github.com/alan-cooney/sparse_autoencoder/actions/workflows/release.yml) 7 | 8 | A sparse autoencoder for mechanistic interpretability research. 9 | 10 | [![Read the Docs 11 | Here](https://img.shields.io/badge/-Read%20the%20Docs%20Here-blue?style=for-the-badge&logo=Read-the-Docs&logoColor=white&link=https://ai-safety-foundation.github.io/sparse_autoencoder/)](https://ai-safety-foundation.github.io/sparse_autoencoder/) 12 | 13 | Train a Sparse Autoencoder [in colab](https://colab.research.google.com/github/ai-safety-foundation/sparse_autoencoder/blob/main/docs/content/demo.ipynb), or install for your project: 14 | 15 | ```shell 16 | pip install sparse_autoencoder 17 | ``` 18 | 19 | ## Features 20 | 21 | This library contains: 22 | 23 | 1. **A sparse autoencoder model**, along with all the underlying PyTorch components you need to 24 | customise and/or build your own: 25 | - Encoder, constrained unit norm decoder and tied bias PyTorch modules in `autoencoder`. 
   - L1 and L2 loss modules in `loss`.
   - Adam module with helper method to reset state in `optimizer`.
2. **Activations data generator** using TransformerLens, with the underlying steps in case you
   want to customise the approach:
   - Activation store options (in-memory or on disk) in `activation_store`.
   - Hook to get the activations from TransformerLens in an efficient way in `source_model`.
   - Source dataset (i.e. prompts to generate these activations) utils in `source_data`, that
     stream data from HuggingFace and pre-process (tokenize & shuffle).
3. **Activation resampler** to help reduce the number of dead neurons.
4. **Metrics** that log at various stages of training (e.g. during training, resampling and
   validation), and integrate with wandb.
5. **Training pipeline** that combines everything together, allowing you to run hyperparameter
   sweeps and view progress on wandb.

## Designed for Research

The library is designed to be modular. By default it takes the approach from [Towards
Monosemanticity: Decomposing Language Models With Dictionary Learning
](https://transformer-circuits.pub/2023/monosemantic-features/index.html), so you can pip install
the library and get started quickly. Then when you need to customise something, you can just extend
the class for that component (e.g. you can extend `SparseAutoencoder` if you want to customise the
model, and then drop it back into the training pipeline). Every component is fully documented, so
it's nice and easy to do this.
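
As a minimal sketch of that extension pattern (the override below is just a no-op placeholder; see
the reference docs for the methods you'd actually want to change):

```python
from sparse_autoencoder import SparseAutoencoder


class MyCustomAutoencoder(SparseAutoencoder):
    """Drop-in replacement for the default model."""

    def forward(self, x):
        # Add any custom behaviour here, then fall back to the default. The
        # model is a standard PyTorch module, so any `torch.nn.Module`
        # override works.
        return super().forward(x)
```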

## Demo

Check out the demo notebook [docs/content/demo.ipynb](https://github.com/ai-safety-foundation/sparse_autoencoder/blob/main/docs/content/demo.ipynb) for a guide to using this library.

## Contributing

This project uses [Poetry](https://python-poetry.org) for dependency management, and
[PoeThePoet](https://poethepoet.natn.io/installation.html) for scripts. After checking out the repo,
we recommend setting poetry's config to create the `.venv` in the root directory (note this is a
global setting) and then installing with the dev and demos dependencies.

```shell
poetry config virtualenvs.in-project true
poetry install --with dev,demos
```

### Checks

For a full list of available commands (e.g. `test` or `typecheck`), run this in your terminal
(assumes the venv is active already).

```shell
poe
```

--------------------------------------------------------------------------------
/docs/content/SUMMARY.md:
--------------------------------------------------------------------------------
* [Home](index.md)
* [Demo](demo.ipynb)
* [Source dataset pre-processing](pre-process-datasets.ipynb)
* [Reference](reference/)
* [Contributing](contributing.md)
* [Citation](citation.md)

--------------------------------------------------------------------------------
/docs/content/citation.md:
--------------------------------------------------------------------------------
# Citation

Please cite this library as:

```BibTeX
@misc{cooney2023SparseAutoencoder,
    title = {Sparse Autoencoder Library},
    author = {Alan Cooney},
    year = {2023},
    howpublished = {\url{https://github.com/ai-safety-foundation/sparse_autoencoder}},
}
```

--------------------------------------------------------------------------------
/docs/content/contributing.md:
--------------------------------------------------------------------------------
# Contributing

## Setup

This project uses [Poetry](https://python-poetry.org) for dependency management, and
[PoeThePoet](https://poethepoet.natn.io/installation.html) for scripts. After checking out the repo,
we recommend setting poetry's config to create the `.venv` in the root directory (note this is a
global setting) and then installing with the dev and demos dependencies.

```shell
poetry config virtualenvs.in-project true
poetry install --with dev,demos
```

If you are using VSCode we highly recommend installing the recommended extensions as well (it will
prompt you to do this when you checkout the repo).

## Checks

For a full list of available commands (e.g. `test` or `typecheck`), run this in your terminal
(assumes the venv is active already).

```shell
poe
```

## Documentation

Please make sure to add thorough documentation for any features you add. You should do this directly
in the docstring, and this will then automatically generate the API docs when merged into `main`.
They will also be automatically checked with [pytest](https://docs.pytest.org/) (via
[doctest](https://docs.python.org/3/library/doctest.html)).

If you want to view your documentation changes, run `poe docs-hot-reload`. This will give you
hot-reloading docs (they change in real time as you edit docstrings).

### Docstring Style Guide

We follow the [Google Python Docstring Style](https://google.github.io/styleguide/pyguide.html) for
writing docstrings. Some important details below:

#### Sections and Order

You should follow this order:

```python
"""Title In Title Case.

A description of what the function/class does, including as much detail as is necessary to fully
understand it.

Warning:

Any warnings to the user (e.g. common pitfalls).

Examples:

Include any examples here. They will be checked with doctest.

>>> print(1 + 2)
3

Args:
    param_without_type_signature:
        Each description should be indented once more.
    param_2:
        Another example parameter.

Returns:
    Returns description without type signature.

Raises:
    Information about the error it may raise (if any).
73 | """ 74 | ``` 75 | 76 | #### LaTeX support 77 | 78 | You can use LaTeX, inside `$$` for blocks or `$` for inline 79 | 80 | ```markdown 81 | Some text $(a + b)^2 = a^2 + 2ab + b^2$ 82 | ``` 83 | 84 | ```markdown 85 | Some text: 86 | 87 | $$ 88 | y & = & ax^2 + bx + c \\ 89 | f(x) & = & x^2 + 2xy + y^2 90 | $$ 91 | ``` 92 | 93 | #### Markup 94 | 95 | - Italics - `*text*` 96 | - Bold - `**text**` 97 | - Code - ` ``code`` ` 98 | - List items - `*item` 99 | - Numbered items - `1. Item` 100 | - Quotes - indent one level 101 | - External links = ``` `Link text ` ``` 102 | -------------------------------------------------------------------------------- /docs/content/css/custom_formatting.css: -------------------------------------------------------------------------------- 1 | /* Show the full equation without vertical scroll */ 2 | .md-typeset div.arithmatex { 3 | overflow: visible !important; 4 | } 5 | 6 | /* Preserve new lines in code example blocks */ 7 | .example blockquote blockquote blockquote p { 8 | white-space: pre-wrap; 9 | font-family: monospace; 10 | } -------------------------------------------------------------------------------- /docs/content/css/material_extra.css: -------------------------------------------------------------------------------- 1 | /* Indentation. */ 2 | div.doc-contents:not(.first) { 3 | padding-left: 25px; 4 | border-left: .05rem solid var(--md-typeset-table-color); 5 | } 6 | 7 | /* Mark external links as such. */ 8 | a.external::after, 9 | a.autorefs-external::after { 10 | /* https://primer.style/octicons/arrow-up-right-24 */ 11 | mask-image: url('data:image/svg+xml,'); 12 | -webkit-mask-image: url('data:image/svg+xml,'); 13 | content: ' '; 14 | 15 | display: inline-block; 16 | vertical-align: middle; 17 | position: relative; 18 | 19 | height: 1em; 20 | width: 1em; 21 | background-color: var(--md-typeset-a-color); 22 | } 23 | 24 | a.external:hover::after, 25 | a.autorefs-external:hover::after { 26 | background-color: var(--md-accent-fg-color); 27 | } -------------------------------------------------------------------------------- /docs/content/demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "\n", 8 | " \"Open\n", 9 | "" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "# Training Demo\n", 17 | "\n", 18 | "This is a quick start demo to get training a SAE right away. All you need to do is choose a few\n", 19 | "hyperparameters (like the model to train on), and then set it off.\n", 20 | "\n", 21 | "In this demo we'll train a sparse autoencoder on all MLP layer outputs in GPT-2 small (effectively\n", 22 | "training an SAE on each layer in parallel)." 
23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": {}, 28 | "source": [ 29 | "## Setup" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": {}, 35 | "source": [ 36 | "### Imports" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 1, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "# Check if we're in Colab\n", 46 | "try:\n", 47 | " import google.colab # noqa: F401 # type: ignore\n", 48 | "\n", 49 | " in_colab = True\n", 50 | "except ImportError:\n", 51 | " in_colab = False\n", 52 | "\n", 53 | "# Install if in Colab\n", 54 | "if in_colab:\n", 55 | " %pip install sparse_autoencoder transformer_lens transformers wandb\n", 56 | "\n", 57 | "# Otherwise enable hot reloading in dev mode\n", 58 | "if not in_colab:\n", 59 | " %load_ext autoreload\n", 60 | " %autoreload 2" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 2, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "import os\n", 70 | "\n", 71 | "from sparse_autoencoder import (\n", 72 | " ActivationResamplerHyperparameters,\n", 73 | " AutoencoderHyperparameters,\n", 74 | " Hyperparameters,\n", 75 | " LossHyperparameters,\n", 76 | " Method,\n", 77 | " OptimizerHyperparameters,\n", 78 | " Parameter,\n", 79 | " PipelineHyperparameters,\n", 80 | " SourceDataHyperparameters,\n", 81 | " SourceModelHyperparameters,\n", 82 | " SweepConfig,\n", 83 | " sweep,\n", 84 | ")\n", 85 | "\n", 86 | "os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"demo.ipynb\"" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": {}, 92 | "source": [ 93 | "### Hyperparameters" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": {}, 99 | "source": [ 100 | "Customize any hyperparameters you want below (by default we're sweeping over l1 coefficient and\n", 101 | "learning rate).\n", 102 | "\n", 103 | "Note we are using the RANDOM sweep approach (try random combinations of hyperparameters), which\n", 104 | "works surprisingly well but will need to be stopped at some point (as otherwise it will continue\n", 105 | "forever). If you want to run pre-defined runs consider using `Parameter(values=[0.01, 0.05...])` for\n", 106 | "example rather than `Parameter(max=0.03, min=0.008)` for each parameter you are sweeping over. You\n", 107 | "can then set the strategy to `Method.GRID`." 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 3, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "def train_gpt_small_mlp_layers(\n", 117 | " expansion_factor: int = 4,\n", 118 | " n_layers: int = 12,\n", 119 | ") -> None:\n", 120 | " \"\"\"Run a new sweep experiment on GPT 2 Small's MLP layers.\n", 121 | "\n", 122 | " Args:\n", 123 | " expansion_factor: Expansion factor for the autoencoder.\n", 124 | " n_layers: Number of layers to train on. 
    "\n",
    "    \"\"\"\n",
    "    sweep_config = SweepConfig(\n",
    "        parameters=Hyperparameters(\n",
    "            loss=LossHyperparameters(\n",
    "                l1_coefficient=Parameter(max=0.03, min=0.008),\n",
    "            ),\n",
    "            optimizer=OptimizerHyperparameters(\n",
    "                lr=Parameter(max=0.001, min=0.00001),\n",
    "            ),\n",
    "            source_model=SourceModelHyperparameters(\n",
    "                name=Parameter(\"gpt2\"),\n",
    "                cache_names=Parameter(\n",
    "                    [f\"blocks.{layer}.hook_mlp_out\" for layer in range(n_layers)]\n",
    "                ),\n",
    "                hook_dimension=Parameter(768),\n",
    "            ),\n",
    "            source_data=SourceDataHyperparameters(\n",
    "                dataset_path=Parameter(\"alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2\"),\n",
    "                context_size=Parameter(256),\n",
    "                pre_tokenized=Parameter(value=True),\n",
    "                pre_download=Parameter(value=False),  # Default to streaming the dataset\n",
    "            ),\n",
    "            autoencoder=AutoencoderHyperparameters(\n",
    "                expansion_factor=Parameter(value=expansion_factor)\n",
    "            ),\n",
    "            pipeline=PipelineHyperparameters(\n",
    "                max_activations=Parameter(1_000_000_000),\n",
    "                checkpoint_frequency=Parameter(100_000_000),\n",
    "                validation_frequency=Parameter(100_000_000),\n",
    "                max_store_size=Parameter(1_000_000),\n",
    "            ),\n",
    "            activation_resampler=ActivationResamplerHyperparameters(\n",
    "                resample_interval=Parameter(200_000_000),\n",
    "                n_activations_activity_collate=Parameter(100_000_000),\n",
    "                threshold_is_dead_portion_fires=Parameter(1e-6),\n",
    "                max_n_resamples=Parameter(4),\n",
    "            ),\n",
    "        ),\n",
    "        method=Method.RANDOM,\n",
    "    )\n",
    "\n",
    "    sweep(sweep_config=sweep_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run the sweep"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This will start a sweep with just one agent (the current machine). If you have multiple GPUs, it\n",
    "will use them automatically. Similarly it will work on Apple silicon devices by automatically using MPS."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_gpt_small_mlp_layers()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Want to speed things up? You can trivially add extra machines to the sweep, each of which will peel\n",
    "off some runs from the sweep agent (stored on Wandb). To do this, on another machine simply run:\n",
    "\n",
    "```bash\n",
    "pip install sparse_autoencoder\n",
    "join-sae-sweep --id=SWEEP_ID_SHOWN_ON_WANDB\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "31186ba1239ad81afeb3c631b4833e71f34259d3b92eebb37a9091b916e08620"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}

--------------------------------------------------------------------------------
/docs/content/index.md:
--------------------------------------------------------------------------------
# Sparse Autoencoder

[![PyPI](https://img.shields.io/pypi/v/sparse_autoencoder?color=blue)](https://pypi.org/project/sparse_autoencoder/)
![PyPI - License](https://img.shields.io/pypi/l/sparse_autoencoder?color=blue)
[![Checks](https://github.com/alan-cooney/sparse_autoencoder/actions/workflows/checks.yml/badge.svg)](https://github.com/alan-cooney/sparse_autoencoder/actions/workflows/checks.yml)
[![Release](https://github.com/alan-cooney/sparse_autoencoder/actions/workflows/release.yml/badge.svg)](https://github.com/alan-cooney/sparse_autoencoder/actions/workflows/release.yml)

A sparse autoencoder for mechanistic interpretability research.

```shell
pip install sparse_autoencoder
```

## Quick Start

Check out the [demo notebook](demo) for a guide to using this library.

We also highly recommend skimming the reference docs to see all the features that are available.

## Features

This library contains:

1. **A sparse autoencoder model**, along with all the underlying PyTorch components you need to
   customise and/or build your own:
    - Encoder, constrained unit norm decoder and tied bias PyTorch modules in
      [sparse_autoencoder.autoencoder][].
    - Adam module with helper method to reset state in [sparse_autoencoder.optimizer][].
2. **Activations data generator** using TransformerLens, with the underlying steps in case you
   want to customise the approach:
    - Activation store options (in-memory or on disk) in [sparse_autoencoder.activation_store][]
      (see the sketch after this list).
    - Hook to get the activations from TransformerLens in an efficient way in
      [sparse_autoencoder.source_model][].
    - Source dataset (i.e. prompts to generate these activations) utils in
      [sparse_autoencoder.source_data][], that stream data from HuggingFace and pre-process
      (tokenize & shuffle).
3. **Activation resampler** to help reduce the number of dead neurons.
4. **Metrics** that log at various stages of training (loss, train metrics and validation
   metrics), based on torchmetrics.
5. **Training pipeline** that combines everything together, allowing you to run hyperparameter
   sweeps and view progress on wandb.
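
As a minimal sketch of the activation store API (the sizes here are arbitrary; the calls are taken
from the store's own docstrings and tests):

```python
import torch

from sparse_autoencoder import TensorActivationStore

# An in-memory store for one source model component, sized for 1,000
# activation vectors of 256 neurons each.
store = TensorActivationStore(max_items=1_000, n_neurons=256, n_components=1)

# Add a batch of 16 activation vectors for the first (and only) component.
store.extend(torch.rand(16, 256), component_idx=0)
assert len(store) == 16
```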

## Designed for Research

The library is designed to be modular. By default it takes the approach from [Towards
Monosemanticity: Decomposing Language Models With Dictionary Learning
](https://transformer-circuits.pub/2023/monosemantic-features/index.html), so you can pip install
the library and get started quickly. Then when you need to customise something, you can just extend
the abstract class for that component (every component is documented so that it's easy to do this).

--------------------------------------------------------------------------------
/docs/content/javascript/custom_formatting.js:
--------------------------------------------------------------------------------
/**
 * Expand the 'Reference' section of the navigation
 *
 * @param {boolean} remove_icon - Whether to remove the '>' icon
 */
function expandReferenceSection(remove_icon = false) {
    // Find all labels in the navigation menu
    const navLabels = document.querySelectorAll('.md-nav__item label');

    // Search for the label with the text 'Reference'
    Array.from(navLabels).forEach(label => {
        if (label.textContent.trim() === 'Reference') {
            // Find the associated checkbox to expand the section
            const toggleInput = label.previousElementSibling;
            if (toggleInput && toggleInput.tagName === 'INPUT') {
                toggleInput.checked = true;
            }

            // Find the '>' icon and hide it
            if (remove_icon) {
                const icon = label.querySelector('.md-nav__icon');
                if (icon) {
                    icon.style.display = 'none';
                }
            }
        }
    });
}

/**
 * Hides the Table of Contents (TOC) section if it only contains one link.
 */
function hideSingleItemTOC() {
    // Find the TOC list
    const tocList = document.querySelector('.md-nav--secondary .md-nav__list');

    if (tocList) {
        // Count the number of list items (links) in the TOC
        const itemCount = tocList.querySelectorAll('li').length;

        // If there is only one item, hide the entire TOC section
        if (itemCount === 1) {
            const tocSection = document.querySelector('.md-sidebar--secondary[data-md-component="sidebar"][data-md-type="toc"]');
            if (tocSection) {
                tocSection.style.display = 'none';
            }
        }
    }
}

function main() {
    document.addEventListener("DOMContentLoaded", function () {
        // Expand the 'Reference' section of the navigation
        expandReferenceSection();

        // Hide the Table of Contents (TOC) section if it only contains one link
        hideSingleItemTOC();
    })
}

main();

--------------------------------------------------------------------------------
/docs/content/javascript/mathjax.js:
--------------------------------------------------------------------------------
window.MathJax = {
    tex: {
        inlineMath: [["\\(", "\\)"]],
        displayMath: [["\\[", "\\]"]],
        processEscapes: true,
        processEnvironments: true
    },
    options: {
        ignoreHtmlClass: ".*|",
        processHtmlClass: "arithmatex"
    }
};

document$.subscribe(() => {
    MathJax.typesetPromise()
})

--------------------------------------------------------------------------------
/docs/gen_ref_pages.py:
--------------------------------------------------------------------------------
"""Auto-generate reference documentation based on docstrings."""
from pathlib import Path
import shutil

import mkdocs_gen_files


CURRENT_DIR = Path(__file__).parent
REPO_ROOT = CURRENT_DIR.parent
PROJECT_ROOT = REPO_ROOT / "sparse_autoencoder"
REFERENCE_DIR = CURRENT_DIR / "content/reference"


def reset_reference_dir() -> None:
    """Reset the reference directory to its initial state."""
    # Unlink the directory including all files
    shutil.rmtree(REFERENCE_DIR, ignore_errors=True)
    REFERENCE_DIR.mkdir(parents=True, exist_ok=True)


def is_source_file(file: Path) -> bool:
    """Check if the provided file is a source file for Sparse Encoder.

    Args:
        file: The file path to check.

    Returns:
        bool: True if the file is a source file, False otherwise.
    """
    return "test" not in str(file)


def process_path(path: Path) -> tuple[Path, Path, Path]:
    """Process the given path for documentation generation.

    Args:
        path: The file path to process.

    Returns:
        A tuple containing module path, documentation path, and full documentation path.
    """
    module_path = path.relative_to(PROJECT_ROOT).with_suffix("")
    doc_path = path.relative_to(PROJECT_ROOT).with_suffix(".md")
    full_doc_path = Path(REFERENCE_DIR, doc_path)

    if module_path.name == "__init__":
        module_path = module_path.parent
        doc_path = doc_path.with_name("index.md")
        full_doc_path = full_doc_path.with_name("index.md")

    return module_path, doc_path, full_doc_path


def generate_documentation(path: Path, module_path: Path, full_doc_path: Path) -> None:
    """Generate documentation for the given source file.

    Args:
        path: The source file path.
        module_path: The module path.
        full_doc_path: The full documentation file path.
    """
    if module_path.name == "__main__":
        return

    # Get the mkdocstrings identifier for the module
    parts = list(module_path.parts)
    parts.insert(0, "sparse_autoencoder")
    identifier = ".".join(parts)

    # Read the first line of the file docstring, and set as the header
    with path.open() as fd:
        first_line = fd.readline()
        first_line_without_docstring = first_line.replace('"""', "").strip()
        first_line_without_last_dot = first_line_without_docstring.rstrip(".")
        title = first_line_without_last_dot or module_path.name

    with mkdocs_gen_files.open(full_doc_path, "w") as fd:
        fd.write(f"# {title}" + "\n\n" + f"::: {identifier}")

    mkdocs_gen_files.set_edit_path(full_doc_path, path)


def generate_nav_file(nav: mkdocs_gen_files.nav.Nav, reference_dir: Path) -> None:
    """Generate the navigation file for the documentation.

    Args:
        nav: The navigation object.
        reference_dir: The directory to write the navigation file.
    """
    with mkdocs_gen_files.open(reference_dir / "SUMMARY.md", "w") as nav_file:
        nav_file.writelines(nav.build_literate_nav())


def run() -> None:
    """Handle the generation of reference documentation for Sparse Encoder."""
    reset_reference_dir()
    nav = mkdocs_gen_files.Nav()  # type: ignore

    python_files = PROJECT_ROOT.rglob("*.py")
    source_files = filter(is_source_file, python_files)

    for path in sorted(source_files):
        module_path, doc_path, full_doc_path = process_path(path)
        generate_documentation(path, module_path, full_doc_path)

        url_slug_parts = list(module_path.parts)

        # Don't create a page for the main __init__.py file (as this includes most exports).
        if not url_slug_parts:
            continue

        nav[url_slug_parts] = doc_path.as_posix()  # type: ignore

    generate_nav_file(nav, REFERENCE_DIR)


if __name__ == "__main__":
    run()

--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
site_name: Sparse Autoencoder
site_description: Sparse Autoencoder for Mechanistic Interpretability
docs_dir: docs/content
site_dir: docs/generated
repo_url: https://github.com/ai-safety-foundation/sparse_autoencoder
repo_name: ai-safety-foundation/sparse_autoencoder
edit_uri: ""  # Disabled as we use mkdocstrings which auto-generates some pages
strict: true

theme:
  name: material
  palette:
    - scheme: default
      primary: teal
      accent: amber
      toggle:
        icon: material/weather-night
        name: Switch to dark mode
    - scheme: slate
      primary: black
      accent: amber
      toggle:
        icon: material/weather-sunny
        name: Switch to light mode
  icon:
    repo: fontawesome/brands/github  # GitHub logo in top right
  features:
    - content.action.edit

extra_javascript:
  - javascript/custom_formatting.js
  # The below three make MathJax work, see https://squidfunk.github.io/mkdocs-material/reference/mathjax/
  - javascript/mathjax.js
  - https://polyfill.io/v3/polyfill.min.js?features=es6
  - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js

extra_css:
  - "css/material_extra.css"
  - "css/custom_formatting.css"

markdown_extensions:
  - pymdownx.arithmatex:  # Render LaTeX via MathJax
      generic: true
  - pymdownx.superfences  # Seems to enable syntax highlighting when used with the Material theme.
  - pymdownx.magiclink
  - pymdownx.saneheaders
  - pymdownx.details  # Allowing hidden expandable regions denoted by ???
  - pymdownx.snippets:  # Include one Markdown file into another
      base_path: docs/content
  - admonition  # Adds admonition blocks (e.g. warning, note, tip, etc.)
  - toc:
      permalink: "¤"  # Adds a clickable permalink to each section heading
      toc_depth: 4

plugins:
  - search
  - autorefs
  - section-index
  - literate-nav:
      nav_file: SUMMARY.md
  - mknotebooks
  - mkdocstrings:  # https://mkdocstrings.github.io
      handlers:
        python:
          setup_commands:
            - import pytkdocs_tweaks
            - pytkdocs_tweaks.main()
            - import jaxtyping
            - jaxtyping.set_array_name_format("array")
          options:
            docstring_style: google
            line_length: 100
            show_symbol_type_heading: true
            edit_uri: ""
  - htmlproofer:
      raise_error: True

--------------------------------------------------------------------------------
/sparse_autoencoder/__init__.py:
--------------------------------------------------------------------------------
"""Sparse Autoencoder Library."""
from sparse_autoencoder.activation_resampler.activation_resampler import ActivationResampler
from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore
from sparse_autoencoder.autoencoder.model import SparseAutoencoder, SparseAutoencoderConfig
from sparse_autoencoder.metrics.loss.l1_absolute_loss import L1AbsoluteLoss
from sparse_autoencoder.metrics.loss.l2_reconstruction_loss import L2ReconstructionLoss
from sparse_autoencoder.metrics.loss.sae_loss import SparseAutoencoderLoss
from sparse_autoencoder.metrics.train.capacity import CapacityMetric
from sparse_autoencoder.metrics.train.feature_density import FeatureDensityMetric
from sparse_autoencoder.metrics.train.l0_norm import L0NormMetric
from sparse_autoencoder.metrics.train.neuron_activity import NeuronActivityMetric
from sparse_autoencoder.metrics.train.neuron_fired_count import NeuronFiredCountMetric
from sparse_autoencoder.metrics.validate.reconstruction_score import ReconstructionScoreMetric
from sparse_autoencoder.optimizer.adam_with_reset import AdamWithReset
from sparse_autoencoder.source_data.pretokenized_dataset import PreTokenizedDataset
from sparse_autoencoder.source_data.text_dataset import TextDataset
from sparse_autoencoder.train.pipeline import Pipeline
from sparse_autoencoder.train.sweep import sweep
from sparse_autoencoder.train.sweep_config import (
    ActivationResamplerHyperparameters,
    AutoencoderHyperparameters,
    Hyperparameters,
    LossHyperparameters,
    OptimizerHyperparameters,
    PipelineHyperparameters,
    SourceDataHyperparameters,
    SourceModelHyperparameters,
    SourceModelRuntimeHyperparameters,
    SweepConfig,
)
from sparse_autoencoder.train.utils.wandb_sweep_types import (
    Controller,
    ControllerType,
    Distribution,
    Goal,
    HyperbandStopping,
    HyperbandStoppingType,
    Impute,
    ImputeWhileRunning,
    Kind,
    Method,
    Metric,
    NestedParameter,
    Parameter,
)


__all__ = [
    "ActivationResampler",
    "ActivationResamplerHyperparameters",
    "AdamWithReset",
    "AutoencoderHyperparameters",
    "CapacityMetric",
    "Controller",
    "ControllerType",
    "Distribution",
    "FeatureDensityMetric",
    "Goal",
    "HyperbandStopping",
    "HyperbandStoppingType",
    "Hyperparameters",
    "Impute",
    "ImputeWhileRunning",
    "Kind",
    "L0NormMetric",
    "L1AbsoluteLoss",
    "L2ReconstructionLoss",
    "LossHyperparameters",
    "Method",
    "Metric",
    "NestedParameter",
    "NeuronActivityMetric",
"NeuronFiredCountMetric", 74 | "OptimizerHyperparameters", 75 | "Parameter", 76 | "Pipeline", 77 | "PipelineHyperparameters", 78 | "PreTokenizedDataset", 79 | "ReconstructionScoreMetric", 80 | "SourceDataHyperparameters", 81 | "SourceModelHyperparameters", 82 | "SourceModelRuntimeHyperparameters", 83 | "SparseAutoencoder", 84 | "SparseAutoencoderConfig", 85 | "SparseAutoencoderLoss", 86 | "sweep", 87 | "SweepConfig", 88 | "TensorActivationStore", 89 | "TextDataset", 90 | ] 91 | -------------------------------------------------------------------------------- /sparse_autoencoder/activation_resampler/__init__.py: -------------------------------------------------------------------------------- 1 | """Activation Resampler.""" 2 | -------------------------------------------------------------------------------- /sparse_autoencoder/activation_resampler/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Activation resampler utils.""" 2 | -------------------------------------------------------------------------------- /sparse_autoencoder/activation_resampler/utils/component_slice_tensor.py: -------------------------------------------------------------------------------- 1 | """Component slice tensor utils.""" 2 | from torch import Tensor 3 | 4 | 5 | def get_component_slice_tensor( 6 | input_tensor: Tensor, 7 | n_dim_with_component: int, 8 | component_dim: int, 9 | component_idx: int, 10 | ) -> Tensor: 11 | """Get a slice of a tensor for a specific component. 12 | 13 | Examples: 14 | >>> import torch 15 | >>> input_tensor = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]]) 16 | >>> get_component_slice_tensor(input_tensor, 2, 1, 0) 17 | tensor([1, 3, 5, 7]) 18 | 19 | >>> input_tensor = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]]) 20 | >>> get_component_slice_tensor(input_tensor, 3, 1, 0) 21 | tensor([[1, 2], 22 | [3, 4], 23 | [5, 6], 24 | [7, 8]]) 25 | 26 | Args: 27 | input_tensor: Input tensor. 28 | n_dim_with_component: Number of dimensions in the input tensor with the component axis. 29 | component_dim: Dimension of the component axis. 30 | component_idx: Index of the component to get the slice for. 31 | 32 | Returns: 33 | Tensor slice. 34 | 35 | Raises: 36 | ValueError: If the input tensor does not have the expected number of dimensions. 37 | """ 38 | if n_dim_with_component - 1 == input_tensor.ndim: 39 | return input_tensor 40 | 41 | if n_dim_with_component != input_tensor.ndim: 42 | error_message = ( 43 | f"Cannot get component slice for tensor with {input_tensor.ndim} dimensions " 44 | f"and {n_dim_with_component} dimensions with component." 
    if n_dim_with_component - 1 == input_tensor.ndim:
        return input_tensor

    if n_dim_with_component != input_tensor.ndim:
        error_message = (
            f"Cannot get component slice for tensor with {input_tensor.ndim} dimensions "
            f"and {n_dim_with_component} dimensions with component."
        )
        raise ValueError(error_message)

    # Create a tuple of slices for each dimension
    slice_tuple = tuple(
        component_idx if i == component_dim else slice(None) for i in range(input_tensor.ndim)
    )

    return input_tensor[slice_tuple]

--------------------------------------------------------------------------------
/sparse_autoencoder/activation_store/__init__.py:
--------------------------------------------------------------------------------
"""Activation Stores."""

--------------------------------------------------------------------------------
/sparse_autoencoder/activation_store/base_store.py:
--------------------------------------------------------------------------------
"""Activation Store Base Class."""
from abc import ABC, abstractmethod
from concurrent.futures import Future
from typing import final

from jaxtyping import Float
from pydantic import PositiveInt, validate_call
import torch
from torch import Tensor
from torch.utils.data import Dataset

from sparse_autoencoder.tensor_types import Axis


class ActivationStore(
    Dataset[Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)]], ABC
):
    """Activation Store Abstract Class.

    Extends the `torch.utils.data.Dataset` class to provide an activation store, with additional
    :meth:`append` and :meth:`extend` methods (the latter of which should typically be
    non-blocking). The resulting activation store can be used with a `torch.utils.data.DataLoader`
    to iterate over the dataset.

    Extend this class if you want to create a new activation store (noting you also need to create
    `__getitem__` and `__len__` methods from the underlying `torch.utils.data.Dataset` class).

    Example:
        >>> import torch
        >>> class MyActivationStore(ActivationStore):
        ...
        ...     @property
        ...     def current_activations_stored_per_component(self):
        ...         raise NotImplementedError
        ...
        ...     @property
        ...     def n_components(self):
        ...         raise NotImplementedError
        ...
        ...     def __init__(self):
        ...         super().__init__()
        ...         self._data = []  # In this example, we just store in a list
        ...
        ...     def append(self, item) -> None:
        ...         self._data.append(item)
        ...
        ...     def extend(self, batch):
        ...         self._data.extend(batch)
        ...
        ...     def empty(self):
        ...         self._data = []
        ...
        ...     def __getitem__(self, index: int):
        ...         return self._data[index]
        ...
        ...     def __len__(self) -> int:
        ...         return len(self._data)
        ...
59 |         >>> store = MyActivationStore()
60 |         >>> store.append(torch.randn(100))
61 |         >>> print(len(store))
62 |         1
63 |     """
64 | 
65 |     @abstractmethod
66 |     def append(
67 |         self,
68 |         item: Float[Tensor, Axis.names(Axis.INPUT_OUTPUT_FEATURE)],
69 |         component_idx: int,
70 |     ) -> Future | None:
71 |         """Add a Single Item to the Store."""
72 | 
73 |     @abstractmethod
74 |     def extend(
75 |         self,
76 |         batch: Float[Tensor, Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE)],
77 |         component_idx: int,
78 |     ) -> Future | None:
79 |         """Add a Batch to the Store."""
80 | 
81 |     @abstractmethod
82 |     def empty(self) -> None:
83 |         """Empty the Store."""
84 | 
85 |     @property
86 |     @abstractmethod
87 |     def n_components(self) -> int:
88 |         """Number of components."""
89 | 
90 |     @property
91 |     @abstractmethod
92 |     def current_activations_stored_per_component(self) -> list[int]:
93 |         """Current activations stored per component."""
94 | 
95 |     @abstractmethod
96 |     def __len__(self) -> int:
97 |         """Get the Length of the Store."""
98 | 
99 |     @abstractmethod
100 |     def __getitem__(
101 |         self, index: tuple[int, ...] | slice | int
102 |     ) -> Float[Tensor, Axis.names(Axis.ANY)]:
103 |         """Get an Item from the Store."""
104 | 
105 |     def shuffle(self) -> None:
106 |         """Optional shuffle method."""
107 | 
108 |     @final
109 |     @validate_call
110 |     def fill_with_test_data(
111 |         self,
112 |         n_batches: PositiveInt = 1,
113 |         batch_size: PositiveInt = 16,
114 |         n_components: PositiveInt = 1,
115 |         input_features: PositiveInt = 256,
116 |     ) -> None:
117 |         """Fill the store with test data.
118 | 
119 |         For use when testing your code, to ensure it works with a real activation store.
120 | 
121 |         Warning:
122 |             You may want to use `torch.manual_seed(0)` to make the random data deterministic,
123 |             if your test requires inspecting the data itself.
124 | 
125 |         Example:
126 |             >>> from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore
127 |             >>> store = TensorActivationStore(max_items=100, n_neurons=256, n_components=1)
128 |             >>> store.fill_with_test_data(batch_size=100)
129 |             >>> len(store)
130 |             100
131 | 
132 |         Args:
133 |             n_batches: Number of batches to fill the store with.
134 |             batch_size: Number of items per batch.
135 |             n_components: Number of source model components the SAE is trained on.
136 |             input_features: Number of input features per item.
137 |         """
138 |         for _ in range(n_batches):
139 |             for component_idx in range(n_components):
140 |                 sample = torch.rand(batch_size, input_features)
141 |                 self.extend(sample, component_idx)
142 | 
143 | 
144 | class StoreFullError(IndexError):
145 |     """Exception raised when the activation store is full."""
146 | 
147 |     def __init__(self, message: str = "Activation store is full"):
148 |         """Initialise the exception.
149 | 
150 |         Args:
151 |             message: Override the default message.
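        Example:
            A minimal sketch of catching the error (editor's illustration; assumes the
            default message is used):

            >>> try:
            ...     raise StoreFullError()
            ... except IndexError as error:
            ...     print(error)
            Activation store is full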
152 | """ 153 | super().__init__(message) 154 | -------------------------------------------------------------------------------- /sparse_autoencoder/activation_store/tests/test_tensor_store.py: -------------------------------------------------------------------------------- 1 | """Tests for the TensorActivationStore.""" 2 | import pytest 3 | import torch 4 | 5 | from sparse_autoencoder.activation_store.base_store import StoreFullError 6 | from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore 7 | 8 | 9 | def test_works_with_2_components() -> None: 10 | """Test that it works with 2 components.""" 11 | n_neurons: int = 128 12 | n_batches: int = 10 13 | batch_size: int = 16 14 | n_components: int = 2 15 | 16 | store = TensorActivationStore( 17 | max_items=int(n_batches * batch_size), n_neurons=n_neurons, n_components=n_components 18 | ) 19 | 20 | # Fill the store 21 | for component_idx in range(n_components): 22 | for _ in range(n_batches): 23 | store.extend(torch.rand(batch_size, n_neurons), component_idx=component_idx) 24 | 25 | # Check the size 26 | assert len(store) == int(n_batches * batch_size) 27 | 28 | 29 | def test_appending_more_than_max_items_raises() -> None: 30 | """Test that appending more than the max items raises an error.""" 31 | n_neurons: int = 128 32 | store = TensorActivationStore(max_items=1, n_neurons=n_neurons, n_components=1) 33 | store.append(torch.rand(n_neurons), component_idx=0) 34 | 35 | with pytest.raises(StoreFullError): 36 | store.append(torch.rand(n_neurons), component_idx=0) 37 | 38 | 39 | def test_extending_more_than_max_items_raises() -> None: 40 | """Test that extending more than the max items raises an error.""" 41 | n_neurons: int = 128 42 | store = TensorActivationStore(max_items=6, n_neurons=n_neurons, n_components=1) 43 | store.extend(torch.rand(4, n_neurons), component_idx=0) 44 | 45 | with pytest.raises(StoreFullError): 46 | store.extend(torch.rand(4, n_neurons), component_idx=0) 47 | 48 | 49 | def test_getting_out_of_range_raises() -> None: 50 | """Test that getting an out of range index raises an error.""" 51 | n_neurons: int = 128 52 | store = TensorActivationStore(max_items=1, n_neurons=n_neurons, n_components=1) 53 | store.append(torch.rand(n_neurons), component_idx=0) 54 | 55 | with pytest.raises(IndexError): 56 | store[1, 0] 57 | -------------------------------------------------------------------------------- /sparse_autoencoder/autoencoder/__init__.py: -------------------------------------------------------------------------------- 1 | """Sparse autoencoder model & components.""" 2 | -------------------------------------------------------------------------------- /sparse_autoencoder/autoencoder/components/__init__.py: -------------------------------------------------------------------------------- 1 | """Sparse autoencoder components.""" 2 | -------------------------------------------------------------------------------- /sparse_autoencoder/autoencoder/components/tests/__snapshots__/test_linear_encoder.ambr: -------------------------------------------------------------------------------- 1 | # serializer version: 1 2 | # name: test_extra_repr 3 | ''' 4 | LinearEncoder( 5 | input_features=2, learnt_features=4, n_components=2 6 | (activation_function): ReLU() 7 | ) 8 | ''' 9 | # --- 10 | # name: test_forward_pass_result_matches_the_snapshot[1] 11 | tensor([[[0.0000, 0.0166, 0.2664, 0.1434]], 12 | 13 | [[0.2154, 0.0000, 0.2007, 0.0000]], 14 | 15 | [[0.3107, 0.0000, 0.1137, 0.0000]]], grad_fn=) 16 | # --- 17 | # 
name: test_forward_pass_result_matches_the_snapshot[3]
18 |   tensor([[[1.0238, 0.0000, 0.0000, 0.0583],
19 |          [0.2885, 0.2002, 0.5567, 0.0000],
20 |          [0.0000, 0.0000, 0.0000, 0.0000]],
21 | 
22 |         [[0.8789, 0.0000, 0.0000, 0.1781],
23 |          [0.2351, 0.0000, 0.3539, 0.0000],
24 |          [0.0000, 0.1716, 0.2969, 0.2289]],
25 | 
26 |         [[0.9740, 0.0000, 0.0000, 0.0802],
27 |          [0.1881, 0.0000, 0.3415, 0.0000],
28 |          [0.0000, 0.1251, 0.2831, 0.1824]]], grad_fn=<ReluBackward0>)
29 | # ---
30 | # name: test_forward_pass_result_matches_the_snapshot[None]
31 |   tensor([[0.0000, 0.0595, 0.4903, 0.2879],
32 |          [0.4522, 0.0000, 0.3589, 0.0000],
33 |          [0.6428, 0.0000, 0.1849, 0.0000]], grad_fn=<ReluBackward0>)
34 | # ---
35 | 
--------------------------------------------------------------------------------
/sparse_autoencoder/autoencoder/components/tests/test_linear_encoder.py:
--------------------------------------------------------------------------------
1 | """Linear encoder tests."""
2 | from jaxtyping import Float, Int64
3 | import pytest
4 | from syrupy.session import SnapshotSession
5 | import torch
6 | from torch import Tensor
7 | 
8 | from sparse_autoencoder.autoencoder.components.linear_encoder import LinearEncoder
9 | from sparse_autoencoder.tensor_types import Axis
10 | from sparse_autoencoder.utils.tensor_shape import shape_with_optional_dimensions
11 | 
12 | 
13 | # Constants for testing
14 | INPUT_FEATURES = 2
15 | LEARNT_FEATURES = 4
16 | N_COMPONENTS = 2
17 | BATCH_SIZE = 3
18 | 
19 | 
20 | @pytest.fixture()
21 | def encoder() -> LinearEncoder:
22 |     """Fixture to create a LinearEncoder instance."""
23 |     torch.manual_seed(0)
24 |     return LinearEncoder(
25 |         input_features=INPUT_FEATURES, learnt_features=LEARNT_FEATURES, n_components=N_COMPONENTS
26 |     )
27 | 
28 | 
29 | def test_reset_parameters(encoder: LinearEncoder) -> None:
30 |     """Test resetting of parameters."""
31 |     old_weight = encoder.weight.clone()
32 |     old_bias = encoder.bias.clone()
33 |     encoder.reset_parameters()
34 |     assert not torch.equal(encoder.weight, old_weight)
35 |     assert not torch.equal(encoder.bias, old_bias)
36 | 
37 | 
38 | def test_forward_pass(encoder: LinearEncoder) -> None:
39 |     """Test the forward pass of the LinearEncoder."""
40 |     input_tensor = torch.randn(BATCH_SIZE, N_COMPONENTS, INPUT_FEATURES)
41 |     output = encoder.forward(input_tensor)
42 |     assert output.shape == (BATCH_SIZE, N_COMPONENTS, LEARNT_FEATURES)
43 | 
44 | 
45 | def test_extra_repr(encoder: LinearEncoder, snapshot: SnapshotSession) -> None:
46 |     """Test the string representation of the LinearEncoder."""
47 |     assert snapshot == str(encoder), "Model string representation has changed."
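# --- Editor's note (illustrative sketch, not part of the original test suite) ---
# The parametrized snapshot test below pins exact forward-pass values. For
# orientation, the encoder is assumed to compute ReLU(x @ W.T + b) per
# component; a hypothetical single-component reference implementation:
def _reference_encode(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
    """Hypothetical reference encoder: ReLU(x @ weight.T + bias)."""
    return torch.relu(x @ weight.T + bias)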
48 | 49 | 50 | @pytest.mark.parametrize("n_components", [None, 1, 3]) 51 | def test_forward_pass_result_matches_the_snapshot( 52 | n_components: int | None, snapshot: SnapshotSession 53 | ) -> None: 54 | """Test the forward pass of the LinearEncoder.""" 55 | torch.manual_seed(1) 56 | input_tensor = torch.rand( 57 | shape_with_optional_dimensions(BATCH_SIZE, n_components, INPUT_FEATURES) 58 | ) 59 | encoder = LinearEncoder( 60 | input_features=INPUT_FEATURES, learnt_features=LEARNT_FEATURES, n_components=n_components 61 | ) 62 | output = encoder.forward(input_tensor) 63 | assert snapshot == output 64 | 65 | 66 | def test_output_same_without_component_dim_vs_with_1_component() -> None: 67 | """Test the forward pass gives identical results for None and 1 component.""" 68 | # Create the layers to compare 69 | encoder_without_components_dim = LinearEncoder( 70 | input_features=INPUT_FEATURES, learnt_features=LEARNT_FEATURES, n_components=None 71 | ) 72 | encoder_with_1_component = LinearEncoder( 73 | input_features=INPUT_FEATURES, learnt_features=LEARNT_FEATURES, n_components=1 74 | ) 75 | 76 | # Set the weight and value parameters to be the same 77 | encoder_with_1_component.weight = torch.nn.Parameter( 78 | encoder_without_components_dim.weight.unsqueeze(0) 79 | ) 80 | encoder_with_1_component.bias = torch.nn.Parameter( 81 | encoder_without_components_dim.bias.unsqueeze(0) 82 | ) 83 | 84 | # Create the input 85 | input_tensor = torch.rand(BATCH_SIZE, INPUT_FEATURES) 86 | input_with_components_dim = input_tensor.unsqueeze(1) 87 | 88 | # Check the output is the same 89 | output_without_components_dim = encoder_without_components_dim(input_tensor) 90 | output_with_1_component = encoder_with_1_component(input_with_components_dim) 91 | 92 | assert torch.allclose(output_without_components_dim, output_with_1_component.squeeze(1)) 93 | 94 | 95 | def test_update_dictionary_vectors_with_no_neurons(encoder: LinearEncoder) -> None: 96 | """Test update_dictionary_vectors with 0 neurons to update.""" 97 | torch.random.manual_seed(0) 98 | original_weight = encoder.weight.clone() # Save original weight for comparison 99 | 100 | dictionary_vector_indices: Int64[Tensor, Axis.names(Axis.LEARNT_FEATURE_IDX)] = torch.empty( 101 | 0, dtype=torch.int64 102 | ) 103 | 104 | updates: Float[ 105 | Tensor, Axis.names(Axis.LEARNT_FEATURE_IDX, Axis.INPUT_OUTPUT_FEATURE) 106 | ] = torch.empty((0, INPUT_FEATURES), dtype=torch.float) 107 | 108 | encoder.update_dictionary_vectors(dictionary_vector_indices, updates, component_idx=0) 109 | 110 | # Ensure weight did not change when no indices were provided 111 | assert torch.equal( 112 | encoder.weight, original_weight 113 | ), "Weights should not change when no indices are provided." 
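# --- Editor's note (illustrative, an assumption about the implementation) ---
# update_dictionary_vectors is assumed to overwrite only the selected rows of
# the (component, learnt_feature, input_feature) weight tensor, roughly:
#
#     with torch.no_grad():
#         encoder.weight[component_idx, dictionary_vector_indices, :] = updates
#
# The parametrized test below asserts exactly this post-condition.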
114 | 
115 | 
116 | @pytest.mark.parametrize(
117 |     ("dictionary_vector_indices", "updates"),
118 |     [
119 |         pytest.param(torch.tensor([1]), torch.rand((1, INPUT_FEATURES)), id="update 1 neuron"),
120 |         pytest.param(
121 |             torch.tensor([0, 1]),
122 |             torch.rand((2, INPUT_FEATURES)),
123 |             id="update 2 neurons with different values",
124 |         ),
125 |     ],
126 | )
127 | def test_update_dictionary_vectors_with_neurons(
128 |     encoder: LinearEncoder,
129 |     dictionary_vector_indices: Int64[Tensor, Axis.names(Axis.LEARNT_FEATURE_IDX)],
130 |     updates: Float[Tensor, Axis.names(Axis.LEARNT_FEATURE_IDX, Axis.INPUT_OUTPUT_FEATURE)],
131 | ) -> None:
132 |     """Test update_dictionary_vectors with 1 or 2 neurons to update."""
133 |     with torch.no_grad():
134 |         component_idx = 0
135 |         encoder.update_dictionary_vectors(
136 |             dictionary_vector_indices, updates, component_idx=component_idx
137 |         )
138 | 
139 |     # Check if the specified neurons are updated correctly
140 |     assert torch.allclose(
141 |         encoder.weight[component_idx, dictionary_vector_indices, :], updates
142 |     ), "update_dictionary_vectors should update the weights correctly."
143 | 
--------------------------------------------------------------------------------
/sparse_autoencoder/autoencoder/components/tests/test_tied_bias.py:
--------------------------------------------------------------------------------
1 | """Tied Bias Tests."""
2 | from jaxtyping import Float
3 | import torch
4 | from torch import Tensor
5 | from torch.nn import Parameter
6 | 
7 | from sparse_autoencoder.autoencoder.components.tied_bias import TiedBias, TiedBiasPosition
8 | from sparse_autoencoder.tensor_types import Axis
9 | 
10 | 
11 | def test_pre_encoder_subtracts_bias() -> None:
12 |     """Check that the pre-encoder bias subtracts the bias."""
13 |     encoder_input: Float[Tensor, Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE)] = torch.tensor(
14 |         [[5.0, 3.0, 1.0]]
15 |     )
16 |     bias = Parameter(torch.tensor([2.0, 4.0, 6.0]))
17 |     expected = encoder_input - bias
18 | 
19 |     pre_encoder = TiedBias(bias, TiedBiasPosition.PRE_ENCODER)
20 |     output = pre_encoder(encoder_input)
21 | 
22 |     assert torch.allclose(output, expected)
23 | 
24 | 
25 | def test_post_decoder_adds_bias() -> None:
26 |     """Check that the post-decoder bias adds the bias."""
27 |     decoder_output: Float[Tensor, Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE)] = torch.tensor(
28 |         [[5.0, 3.0, 1.0]]
29 |     )
30 |     bias = Parameter(torch.tensor([2.0, 4.0, 6.0]))
31 |     expected = decoder_output + bias
32 | 
33 |     post_decoder = TiedBias(bias, TiedBiasPosition.POST_DECODER)
34 |     output = post_decoder(decoder_output)
35 | 
36 |     assert torch.allclose(output, expected)
37 | 
--------------------------------------------------------------------------------
/sparse_autoencoder/autoencoder/components/tests/test_unit_norm_decoder.py:
--------------------------------------------------------------------------------
1 | """Tests for the constrained unit norm linear layer."""
2 | from jaxtyping import Float, Int64
3 | import pytest
4 | import torch
5 | from torch import Tensor
6 | 
7 | from sparse_autoencoder.autoencoder.components.unit_norm_decoder import UnitNormDecoder
8 | from sparse_autoencoder.tensor_types import Axis
9 | 
10 | 
11 | DEFAULT_N_LEARNT_FEATURES = 3
12 | DEFAULT_N_DECODED_FEATURES = 4
13 | DEFAULT_N_COMPONENTS = 2
14 | 
15 | 
16 | @pytest.fixture()
17 | def decoder() -> UnitNormDecoder:
18 |     """Pytest fixture to provide a UnitNormDecoder instance."""
19 |     return UnitNormDecoder(
20 |         learnt_features=DEFAULT_N_LEARNT_FEATURES,
21 | 
decoded_features=DEFAULT_N_DECODED_FEATURES, 22 | n_components=DEFAULT_N_COMPONENTS, 23 | ) 24 | 25 | 26 | def test_initialization() -> None: 27 | """Test that the weights are initialized with unit norm.""" 28 | layer = UnitNormDecoder(learnt_features=3, decoded_features=4, n_components=None) 29 | weight_norms = torch.norm(layer.weight, dim=0) 30 | assert torch.allclose(weight_norms, torch.ones_like(weight_norms)) 31 | 32 | 33 | def test_forward_pass() -> None: 34 | """Test the forward pass of the layer.""" 35 | layer = UnitNormDecoder(learnt_features=3, decoded_features=4, n_components=None) 36 | input_tensor = torch.randn(5, 3) # Batch size of 5, learnt features of 3 37 | output = layer(input_tensor) 38 | assert output.shape == (5, 4) # Batch size of 5, decoded features of 4 39 | 40 | 41 | def test_multiple_training_steps() -> None: 42 | """Test the unit norm constraint over multiple training iterations.""" 43 | torch.random.manual_seed(0) 44 | layer = UnitNormDecoder(learnt_features=3, decoded_features=4, n_components=None) 45 | optimizer = torch.optim.Adam(layer.parameters(), lr=0.01) 46 | for _ in range(4): 47 | data = torch.randn(5, 3) 48 | optimizer.zero_grad() 49 | logits = layer(data) 50 | 51 | loss = torch.mean(logits**2) 52 | loss.backward() 53 | optimizer.step() 54 | layer.constrain_weights_unit_norm() 55 | 56 | columns_norms = torch.norm(layer.weight, dim=0) 57 | assert torch.allclose(columns_norms, torch.ones_like(columns_norms)) 58 | 59 | 60 | def test_unit_norm_decreases() -> None: 61 | """Check that the unit norm is applied after each gradient step.""" 62 | for _ in range(4): 63 | data = torch.randn((1, 3), requires_grad=True) 64 | 65 | # run with the backward hook 66 | layer = UnitNormDecoder(learnt_features=3, decoded_features=4, n_components=None) 67 | layer_weights = torch.nn.Parameter(layer.weight.clone()) 68 | optimizer = torch.optim.SGD(layer.parameters(), lr=0.1, momentum=0) 69 | logits = layer(data) 70 | loss = torch.mean(logits**2) 71 | loss.backward() 72 | optimizer.step() 73 | weight_norms_with_hook = torch.sum(layer.weight**2, dim=0).clone() 74 | 75 | # Run without the hook 76 | layer_without_hook = UnitNormDecoder( 77 | learnt_features=3, decoded_features=4, n_components=None, enable_gradient_hook=False 78 | ) 79 | layer_without_hook.weight = layer_weights 80 | optimizer_without_hook = torch.optim.SGD( 81 | layer_without_hook.parameters(), lr=0.1, momentum=0 82 | ) 83 | logits_without_hook = layer_without_hook(data) 84 | loss_without_hook = torch.mean(logits_without_hook**2) 85 | loss_without_hook.backward() 86 | optimizer_without_hook.step() 87 | weight_norms_without_hook = torch.sum(layer_without_hook.weight**2, dim=0).clone() 88 | 89 | # Check that the norm with the hook is closer to 1 than without the hook 90 | target_norms = torch.ones_like(weight_norms_with_hook) 91 | absolute_diff_with_hook = torch.abs(weight_norms_with_hook - target_norms) 92 | absolute_diff_without_hook = torch.abs(weight_norms_without_hook - target_norms) 93 | assert torch.all(absolute_diff_with_hook < absolute_diff_without_hook) 94 | 95 | 96 | def test_output_same_without_component_dim_vs_with_1_component() -> None: 97 | """Test the forward pass gives identical results for None and 1 component.""" 98 | decoded_features = 2 99 | learnt_features = 4 100 | batch_size = 1 101 | 102 | # Create the layers to compare 103 | torch.manual_seed(1) 104 | decoder_without_components_dim = UnitNormDecoder( 105 | decoded_features=decoded_features, learnt_features=learnt_features, n_components=None 
106 | ) 107 | torch.manual_seed(1) 108 | decoder_with_1_component = UnitNormDecoder( 109 | decoded_features=decoded_features, learnt_features=learnt_features, n_components=1 110 | ) 111 | 112 | # Create the input 113 | input_tensor = torch.randn(batch_size, learnt_features) 114 | input_with_components_dim = input_tensor.unsqueeze(1) 115 | 116 | # Check the output is the same 117 | output_without_components_dim = decoder_without_components_dim(input_tensor) 118 | output_with_1_component = decoder_with_1_component(input_with_components_dim) 119 | assert torch.allclose(output_without_components_dim, output_with_1_component.squeeze(1)) 120 | 121 | 122 | def test_update_dictionary_vectors_with_no_neurons(decoder: UnitNormDecoder) -> None: 123 | """Test update_dictionary_vectors with 0 neurons to update.""" 124 | original_weight = decoder.weight.clone() # Save original weight for comparison 125 | 126 | dictionary_vector_indices: Int64[ 127 | Tensor, Axis.names(Axis.COMPONENT, Axis.INPUT_OUTPUT_FEATURE) 128 | ] = torch.empty((0, 0), dtype=torch.int64) 129 | 130 | updates: Float[ 131 | Tensor, Axis.names(Axis.COMPONENT, Axis.INPUT_OUTPUT_FEATURE, Axis.DEAD_FEATURE) 132 | ] = torch.empty((0, 0, 0), dtype=torch.float) 133 | 134 | decoder.update_dictionary_vectors(dictionary_vector_indices, updates) 135 | 136 | # Ensure weight did not change when no indices were provided 137 | assert torch.equal( 138 | decoder.weight, original_weight 139 | ), "Weights should not change when no indices are provided." 140 | 141 | 142 | @pytest.mark.parametrize( 143 | ("dictionary_vector_indices", "updates"), 144 | [ 145 | pytest.param(torch.tensor([1]), torch.rand(4, 1), id="One neuron to update"), 146 | pytest.param( 147 | torch.tensor([0, 2]), 148 | torch.rand(4, 2), 149 | id="Two neurons to update", 150 | ), 151 | ], 152 | ) 153 | def test_update_dictionary_vectors_with_neurons( 154 | decoder: UnitNormDecoder, 155 | dictionary_vector_indices: Int64[Tensor, Axis.INPUT_OUTPUT_FEATURE], 156 | updates: Float[Tensor, Axis.names(Axis.INPUT_OUTPUT_FEATURE, Axis.DEAD_FEATURE)], 157 | ) -> None: 158 | """Test update_dictionary_vectors with 1 or 2 neurons to update.""" 159 | decoder.update_dictionary_vectors(dictionary_vector_indices, updates, component_idx=0) 160 | 161 | # Check if the specified neurons are updated correctly 162 | assert torch.allclose( 163 | decoder.weight[0, :, dictionary_vector_indices], updates 164 | ), "update_dictionary_vectors should update the weights correctly." 165 | -------------------------------------------------------------------------------- /sparse_autoencoder/autoencoder/components/tied_bias.py: -------------------------------------------------------------------------------- 1 | """Tied Biases (Pre-Encoder and Post-Decoder).""" 2 | from enum import Enum 3 | from typing import final 4 | 5 | from jaxtyping import Float 6 | from torch import Tensor 7 | from torch.nn import Module, Parameter 8 | 9 | from sparse_autoencoder.tensor_types import Axis 10 | 11 | 12 | class TiedBiasPosition(str, Enum): 13 | """Tied Bias Position.""" 14 | 15 | PRE_ENCODER = "pre_encoder" 16 | POST_DECODER = "post_decoder" 17 | 18 | 19 | @final 20 | class TiedBias(Module): 21 | """Tied Bias Layer. 22 | 23 | The tied pre-encoder bias is a learned bias term that is subtracted from the input before 24 | encoding, and added back after decoding. 25 | 26 | The bias parameter must be initialised in the parent module, and then passed to this layer. 
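    Example:
        An illustrative sketch of the pre-encoder behaviour (editor's addition;
        hypothetical values):

        >>> import torch
        >>> from torch.nn import Parameter
        >>> bias = Parameter(torch.tensor([1.0, 2.0]))
        >>> pre_encoder = TiedBias(bias, TiedBiasPosition.PRE_ENCODER)
        >>> pre_encoder(torch.tensor([[3.0, 4.0]]))
        tensor([[2., 2.]], grad_fn=<SubBackward0>)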
27 | 28 | https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder-bias 29 | """ 30 | 31 | _bias_position: TiedBiasPosition 32 | 33 | _bias_reference: Float[ 34 | Parameter, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE) 35 | ] 36 | 37 | @property 38 | def bias( 39 | self, 40 | ) -> Float[Parameter, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)]: 41 | """Bias.""" 42 | return self._bias_reference 43 | 44 | def __init__( 45 | self, 46 | bias_reference: Float[ 47 | Parameter, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE) 48 | ], 49 | position: TiedBiasPosition, 50 | ) -> None: 51 | """Initialize the bias layer. 52 | 53 | Args: 54 | bias_reference: Tied bias parameter (initialised in the parent module), used for both 55 | the pre-encoder and post-encoder bias. The original paper initialised this using the 56 | geometric median of the dataset. 57 | position: Whether this is the pre-encoder or post-encoder bias. 58 | """ 59 | super().__init__() 60 | 61 | self._bias_reference = bias_reference 62 | 63 | # Support string literals as well as enums 64 | self._bias_position = position 65 | 66 | def forward( 67 | self, 68 | x: Float[ 69 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE) 70 | ], 71 | ) -> Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)]: 72 | """Forward Pass. 73 | 74 | Args: 75 | x: Input tensor. 76 | 77 | Returns: 78 | Output of the forward pass. 79 | """ 80 | # If this is the pre-encoder bias, we subtract the bias from the input. 81 | if self._bias_position == TiedBiasPosition.PRE_ENCODER: 82 | return x - self.bias 83 | 84 | # If it's the post-encoder bias, we add the bias to the input. 85 | return x + self.bias 86 | 87 | def extra_repr(self) -> str: 88 | """String extra representation of the module.""" 89 | return f"position={self._bias_position.value}" 90 | -------------------------------------------------------------------------------- /sparse_autoencoder/autoencoder/tests/__snapshots__/test_model.ambr: -------------------------------------------------------------------------------- 1 | # serializer version: 1 2 | # name: test_representation 3 | ''' 4 | SparseAutoencoder( 5 | (pre_encoder_bias): TiedBias(position=pre_encoder) 6 | (encoder): LinearEncoder( 7 | input_features=3, learnt_features=6, n_components=None 8 | (activation_function): ReLU() 9 | ) 10 | (decoder): UnitNormDecoder(learnt_features=6, decoded_features=3, n_components=None) 11 | (post_decoder_bias): TiedBias(position=post_decoder) 12 | ) 13 | ''' 14 | # --- 15 | -------------------------------------------------------------------------------- /sparse_autoencoder/autoencoder/types.py: -------------------------------------------------------------------------------- 1 | """Autoencoder types.""" 2 | from typing import NamedTuple 3 | 4 | from torch.nn import Parameter 5 | 6 | 7 | class ResetOptimizerParameterDetails(NamedTuple): 8 | """Reset Optimizer Parameter Details. 9 | 10 | Details of a parameter that should be reset in the optimizer, when resetting 11 | its corresponding dictionary vectors. 12 | """ 13 | 14 | parameter: Parameter 15 | """Parameter to reset.""" 16 | 17 | axis: int 18 | """Axis of the parameter to reset.""" 19 | -------------------------------------------------------------------------------- /sparse_autoencoder/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | """Metrics. 
2 | 3 | All metrics are based on torchmetrics, which means they support distributed training and can be 4 | combined with other metrics easily. 5 | """ 6 | -------------------------------------------------------------------------------- /sparse_autoencoder/metrics/loss/__init__.py: -------------------------------------------------------------------------------- 1 | """Loss metrics.""" 2 | -------------------------------------------------------------------------------- /sparse_autoencoder/metrics/loss/l1_absolute_loss.py: -------------------------------------------------------------------------------- 1 | """L1 (absolute error) loss.""" 2 | from typing import Any 3 | 4 | from jaxtyping import Float, Int64 5 | from pydantic import PositiveInt, validate_call 6 | import torch 7 | from torch import Tensor 8 | from torchmetrics import Metric 9 | 10 | from sparse_autoencoder.tensor_types import Axis 11 | 12 | 13 | class L1AbsoluteLoss(Metric): 14 | """L1 (absolute error) loss. 15 | 16 | L1 loss penalty is the absolute sum of the learned activations, averaged over the number of 17 | activation vectors. 18 | 19 | Example: 20 | >>> l1_loss = L1AbsoluteLoss() 21 | >>> learned_activations = torch.tensor([ 22 | ... [ # Batch 1 23 | ... [1., 0., 1.] # Component 1: learned features (L1 of 2) 24 | ... ], 25 | ... [ # Batch 2 26 | ... [0., 1., 0.] # Component 1: learned features (L1 of 1) 27 | ... ] 28 | ... ]) 29 | >>> l1_loss.forward(learned_activations=learned_activations) 30 | tensor(1.5000) 31 | """ 32 | 33 | # Torchmetrics settings 34 | is_differentiable: bool | None = True 35 | full_state_update: bool | None = False 36 | plot_lower_bound: float | None = 0.0 37 | 38 | # Settings 39 | _num_components: int 40 | _keep_batch_dim: bool 41 | 42 | @property 43 | def keep_batch_dim(self) -> bool: 44 | """Whether to keep the batch dimension in the loss output.""" 45 | return self._keep_batch_dim 46 | 47 | @keep_batch_dim.setter 48 | def keep_batch_dim(self, keep_batch_dim: bool) -> None: 49 | """Set whether to keep the batch dimension in the loss output. 50 | 51 | When setting this we need to change the state to either a list if keeping the batch 52 | dimension (so we can accumulate all the losses and concatenate them at the end along this 53 | dimension). Alternatively it should be a tensor if not keeping the batch dimension (so we 54 | can sum the losses over the batch dimension during update and then take the mean). 55 | 56 | By doing this in a setter we allow changing of this setting after the metric is initialised. 57 | """ 58 | self._keep_batch_dim = keep_batch_dim 59 | self.reset() # Reset the metric to update the state 60 | if keep_batch_dim and not isinstance(self.absolute_loss, list): 61 | self.add_state( 62 | "absolute_loss", 63 | default=[], 64 | dist_reduce_fx="sum", 65 | ) 66 | elif not isinstance(self.absolute_loss, Tensor): 67 | self.add_state( 68 | "absolute_loss", 69 | default=torch.zeros(self._num_components), 70 | dist_reduce_fx="sum", 71 | ) 72 | 73 | # State 74 | absolute_loss: Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL)] | list[ 75 | Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL)] 76 | ] | None = None 77 | num_activation_vectors: Int64[Tensor, Axis.SINGLE_ITEM] 78 | 79 | @validate_call 80 | def __init__( 81 | self, 82 | num_components: PositiveInt = 1, 83 | *, 84 | keep_batch_dim: bool = False, 85 | ) -> None: 86 | """Initialize the metric. 87 | 88 | Args: 89 | num_components: Number of components. 
90 |             keep_batch_dim: Whether to keep the batch dimension in the loss output.
91 |         """
92 |         super().__init__()
93 |         self._num_components = num_components
94 |         self.keep_batch_dim = keep_batch_dim
95 |         self.add_state(
96 |             "num_activation_vectors",
97 |             default=torch.tensor(0, dtype=torch.int64),
98 |             dist_reduce_fx="sum",
99 |         )
100 | 
101 |     @staticmethod
102 |     def calculate_abs_sum(
103 |         learned_activations: Float[
104 |             Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
105 |         ],
106 |     ) -> Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL)]:
107 |         """Calculate the absolute sum of the learned activations.
108 | 
109 |         Args:
110 |             learned_activations: Learned activations (intermediate activations in the autoencoder).
111 | 
112 |         Returns:
113 |             Absolute sum of the learned activations (keeping the batch and component axis).
114 |         """
115 |         return torch.abs(learned_activations).sum(dim=-1)
116 | 
117 |     def update(
118 |         self,
119 |         learned_activations: Float[
120 |             Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
121 |         ],
122 |         **kwargs: Any,  # type: ignore # noqa: ARG002, ANN401 (allows combining with other metrics)
123 |     ) -> None:
124 |         """Update the metric state.
125 | 
126 |         If we're keeping the batch dimension, we simply take the absolute sum of the activations
127 |         (over the features dimension) and then append this tensor to a list. Then during compute we
128 |         just concatenate and return this list. This is useful for e.g. getting L1 loss by batch item
129 |         when resampling neurons (see the neuron resampler for details).
130 | 
131 |         By contrast if we're averaging over the batch dimension, we sum the activations over the
132 |         batch dimension during update (on each process), and then divide by the number of activation
133 |         vectors on compute to get the mean.
134 | 
135 |         Args:
136 |             learned_activations: Learned activations (intermediate activations in the autoencoder).
137 |             **kwargs: Ignored keyword arguments (to allow use with other metrics in a collection).
138 |         """
139 |         absolute_loss = self.calculate_abs_sum(learned_activations)
140 | 
141 |         if self.keep_batch_dim:
142 |             self.absolute_loss.append(absolute_loss)  # type: ignore
143 |         else:
144 |             self.absolute_loss += absolute_loss.sum(dim=0)
145 |             self.num_activation_vectors += learned_activations.shape[0]
146 | 
147 |     def compute(self) -> Tensor:
148 |         """Compute the metric."""
149 |         return (
150 |             torch.cat(self.absolute_loss)  # type: ignore
151 |             if self.keep_batch_dim
152 |             else self.absolute_loss / self.num_activation_vectors
153 |         )
154 | 
--------------------------------------------------------------------------------
/sparse_autoencoder/metrics/loss/l2_reconstruction_loss.py:
--------------------------------------------------------------------------------
1 | """L2 Reconstruction loss."""
2 | from typing import Any
3 | 
4 | from jaxtyping import Float, Int64
5 | from pydantic import PositiveInt, validate_call
6 | import torch
7 | from torch import Tensor
8 | from torchmetrics import Metric
9 | 
10 | from sparse_autoencoder.tensor_types import Axis
11 | 
12 | 
13 | class L2ReconstructionLoss(Metric):
14 |     """L2 Reconstruction loss (MSE).
15 | 
16 |     L2 reconstruction loss is calculated as the mean squared error between each input vector
17 |     and its corresponding decoded vector.
The original paper found that models trained with some 18 | loss functions such as cross-entropy loss generally prefer to represent features 19 | polysemantically, whereas models trained with L2 may achieve the same loss for both 20 | polysemantic and monosemantic representations of true features. 21 | 22 | Example: 23 | >>> import torch 24 | >>> loss = L2ReconstructionLoss(num_components=1) 25 | >>> source_activations = torch.tensor([ 26 | ... [ # Batch 1 27 | ... [4., 2.] # Component 1 28 | ... ], 29 | ... [ # Batch 2 30 | ... [2., 0.] # Component 1 31 | ... ] 32 | ... ]) 33 | >>> decoded_activations = torch.tensor([ 34 | ... [ # Batch 1 35 | ... [2., 0.] # Component 1 (MSE of 4) 36 | ... ], 37 | ... [ # Batch 2 38 | ... [0., 0.] # Component 1 (MSE of 2) 39 | ... ] 40 | ... ]) 41 | >>> loss.forward( 42 | ... decoded_activations=decoded_activations, source_activations=source_activations 43 | ... ) 44 | tensor(3.) 45 | """ 46 | 47 | # Torchmetrics settings 48 | is_differentiable: bool | None = True 49 | higher_is_better = False 50 | full_state_update: bool | None = False 51 | plot_lower_bound: float | None = 0.0 52 | 53 | # Settings 54 | _num_components: int 55 | _keep_batch_dim: bool 56 | 57 | @property 58 | def keep_batch_dim(self) -> bool: 59 | """Whether to keep the batch dimension in the loss output.""" 60 | return self._keep_batch_dim 61 | 62 | @keep_batch_dim.setter 63 | def keep_batch_dim(self, keep_batch_dim: bool) -> None: 64 | """Set whether to keep the batch dimension in the loss output. 65 | 66 | When setting this we need to change the state to either a list if keeping the batch 67 | dimension (so we can accumulate all the losses and concatenate them at the end along this 68 | dimension). Alternatively it should be a tensor if not keeping the batch dimension (so we 69 | can sum the losses over the batch dimension during update and then take the mean). 70 | 71 | By doing this in a setter we allow changing of this setting after the metric is initialised. 
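        Example:
            An illustrative sketch of the per-batch-item mode (editor's addition; values
            follow from the mean-squared-error definition used by this class):

            >>> metric = L2ReconstructionLoss(num_components=1, keep_batch_dim=True)
            >>> decoded = torch.tensor([[[1., 1.]], [[0., 0.]]])
            >>> source = torch.tensor([[[0., 0.]], [[0., 0.]]])
            >>> metric.forward(decoded_activations=decoded, source_activations=source)
            tensor([[1.],
                    [0.]])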
72 | """ 73 | self._keep_batch_dim = keep_batch_dim 74 | self.reset() # Reset the metric to update the state 75 | if keep_batch_dim and not isinstance(self.mse, list): 76 | self.add_state( 77 | "mse", 78 | default=[], 79 | dist_reduce_fx="sum", 80 | ) 81 | elif not isinstance(self.mse, Tensor): 82 | self.add_state( 83 | "mse", 84 | default=torch.zeros(self._num_components), 85 | dist_reduce_fx="sum", 86 | ) 87 | 88 | # State 89 | mse: Float[Tensor, Axis.COMPONENT_OPTIONAL] | list[ 90 | Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL)] 91 | ] | None = None 92 | num_activation_vectors: Int64[Tensor, Axis.SINGLE_ITEM] 93 | 94 | @validate_call 95 | def __init__( 96 | self, 97 | num_components: PositiveInt = 1, 98 | *, 99 | keep_batch_dim: bool = False, 100 | ) -> None: 101 | """Initialise the L2 reconstruction loss.""" 102 | super().__init__() 103 | self._num_components = num_components 104 | self.keep_batch_dim = keep_batch_dim 105 | self.add_state( 106 | "num_activation_vectors", 107 | default=torch.tensor(0, dtype=torch.int64), 108 | dist_reduce_fx="sum", 109 | ) 110 | 111 | @staticmethod 112 | def calculate_mse( 113 | decoded_activations: Float[ 114 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE) 115 | ], 116 | source_activations: Float[ 117 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE) 118 | ], 119 | ) -> Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL)]: 120 | """Calculate the MSE.""" 121 | return (decoded_activations - source_activations).pow(2).mean(dim=-1) 122 | 123 | def update( 124 | self, 125 | decoded_activations: Float[ 126 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE) 127 | ], 128 | source_activations: Float[ 129 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE) 130 | ], 131 | **kwargs: Any, # type: ignore # noqa: ARG002, ANN401 (allows combining with other metrics) 132 | ) -> None: 133 | """Update the metric state. 134 | 135 | If we're keeping the batch dimension, we simply take the mse of the activations 136 | (over the features dimension) and then append this tensor to a list. Then during compute we 137 | just concatenate and return this list. This is useful for e.g. getting L1 loss by batch item 138 | when resampling neurons (see the neuron resampler for details). 139 | 140 | By contrast if we're averaging over the batch dimension, we sum the activations over the 141 | batch dimension during update (on each process), and then divide by the number of activation 142 | vectors on compute to get the mean. 143 | 144 | Args: 145 | decoded_activations: The decoded activations from the autoencoder. 146 | source_activations: The source activations from the autoencoder. 147 | **kwargs: Ignored keyword arguments (to allow use with other metrics in a collection). 
148 | """ 149 | mse = self.calculate_mse(decoded_activations, source_activations) 150 | 151 | if self.keep_batch_dim: 152 | self.mse.append(mse) # type: ignore 153 | else: 154 | self.mse += mse.sum(dim=0) 155 | self.num_activation_vectors += source_activations.shape[0] 156 | 157 | def compute(self) -> Float[Tensor, Axis.COMPONENT_OPTIONAL]: 158 | """Compute the metric.""" 159 | return ( 160 | torch.cat(self.mse) # type: ignore 161 | if self.keep_batch_dim 162 | else self.mse / self.num_activation_vectors 163 | ) 164 | -------------------------------------------------------------------------------- /sparse_autoencoder/metrics/loss/sae_loss.py: -------------------------------------------------------------------------------- 1 | """Sparse Autoencoder loss.""" 2 | from typing import Any 3 | 4 | from jaxtyping import Float, Int64 5 | from pydantic import PositiveFloat, PositiveInt, validate_call 6 | import torch 7 | from torch import Tensor 8 | from torchmetrics import Metric 9 | 10 | from sparse_autoencoder.metrics.loss.l1_absolute_loss import L1AbsoluteLoss 11 | from sparse_autoencoder.metrics.loss.l2_reconstruction_loss import L2ReconstructionLoss 12 | from sparse_autoencoder.tensor_types import Axis 13 | 14 | 15 | class SparseAutoencoderLoss(Metric): 16 | """Sparse Autoencoder loss. 17 | 18 | This is the same as composing `L1AbsoluteLoss() * l1_coefficient + L2ReconstructionLoss()`. It 19 | is separated out so that you can use all three metrics (l1, l2, total loss) in the same 20 | `MetricCollection` and they will then share state (to avoid calculating the same thing twice). 21 | """ 22 | 23 | # Torchmetrics settings 24 | is_differentiable: bool | None = True 25 | higher_is_better = False 26 | full_state_update: bool | None = False 27 | plot_lower_bound: float | None = 0.0 28 | 29 | # Settings 30 | _num_components: int 31 | _keep_batch_dim: bool 32 | _l1_coefficient: float 33 | 34 | @property 35 | def keep_batch_dim(self) -> bool: 36 | """Whether to keep the batch dimension in the loss output.""" 37 | return self._keep_batch_dim 38 | 39 | @keep_batch_dim.setter 40 | def keep_batch_dim(self, keep_batch_dim: bool) -> None: 41 | """Set whether to keep the batch dimension in the loss output. 42 | 43 | When setting this we need to change the state to either a list if keeping the batch 44 | dimension (so we can accumulate all the losses and concatenate them at the end along this 45 | dimension). Alternatively it should be a tensor if not keeping the batch dimension (so we 46 | can sum the losses over the batch dimension during update and then take the mean). 47 | 48 | By doing this in a setter we allow changing of this setting after the metric is initialised. 
49 | """ 50 | self._keep_batch_dim = keep_batch_dim 51 | self.reset() # Reset the metric to update the state 52 | if keep_batch_dim and not isinstance(self.mse, list): 53 | self.add_state( 54 | "mse", 55 | default=[], 56 | dist_reduce_fx="sum", 57 | ) 58 | self.add_state( 59 | "absolute_loss", 60 | default=[], 61 | dist_reduce_fx="sum", 62 | ) 63 | elif not isinstance(self.mse, Tensor): 64 | self.add_state( 65 | "mse", 66 | default=torch.zeros(self._num_components), 67 | dist_reduce_fx="sum", 68 | ) 69 | self.add_state( 70 | "absolute_loss", 71 | default=torch.zeros(self._num_components), 72 | dist_reduce_fx="sum", 73 | ) 74 | 75 | # State 76 | absolute_loss: Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL)] | list[ 77 | Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL)] 78 | ] | None = None 79 | mse: Float[Tensor, Axis.COMPONENT_OPTIONAL] | list[ 80 | Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL)] 81 | ] | None = None 82 | num_activation_vectors: Int64[Tensor, Axis.SINGLE_ITEM] 83 | 84 | @validate_call 85 | def __init__( 86 | self, 87 | num_components: PositiveInt = 1, 88 | l1_coefficient: PositiveFloat = 0.001, 89 | *, 90 | keep_batch_dim: bool = False, 91 | ): 92 | """Initialise the metric.""" 93 | super().__init__() 94 | self._num_components = num_components 95 | self.keep_batch_dim = keep_batch_dim 96 | self._l1_coefficient = l1_coefficient 97 | 98 | # Add the state 99 | self.add_state( 100 | "num_activation_vectors", 101 | default=torch.tensor(0, dtype=torch.int64), 102 | dist_reduce_fx="sum", 103 | ) 104 | 105 | def update( 106 | self, 107 | source_activations: Float[ 108 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE) 109 | ], 110 | learned_activations: Float[ 111 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE) 112 | ], 113 | decoded_activations: Float[ 114 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE) 115 | ], 116 | **kwargs: Any, # type: ignore # noqa: ARG002, ANN401 (allows combining with other metrics)) 117 | ) -> None: 118 | """Update the metric.""" 119 | absolute_loss = L1AbsoluteLoss.calculate_abs_sum(learned_activations) 120 | mse = L2ReconstructionLoss.calculate_mse(decoded_activations, source_activations) 121 | 122 | if self.keep_batch_dim: 123 | self.absolute_loss.append(absolute_loss) # type: ignore 124 | self.mse.append(mse) # type: ignore 125 | else: 126 | self.absolute_loss += absolute_loss.sum(dim=0) 127 | self.mse += mse.sum(dim=0) 128 | self.num_activation_vectors += learned_activations.shape[0] 129 | 130 | def compute(self) -> Tensor: 131 | """Compute the metric.""" 132 | l1 = ( 133 | torch.cat(self.absolute_loss) # type: ignore 134 | if self.keep_batch_dim 135 | else self.absolute_loss / self.num_activation_vectors 136 | ) 137 | 138 | l2 = ( 139 | torch.cat(self.mse) # type: ignore 140 | if self.keep_batch_dim 141 | else self.mse / self.num_activation_vectors 142 | ) 143 | 144 | return l1 * self._l1_coefficient + l2 145 | 146 | def forward( # type: ignore[override] (narrowing) 147 | self, 148 | source_activations: Float[ 149 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE) 150 | ], 151 | learned_activations: Float[ 152 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE) 153 | ], 154 | decoded_activations: Float[ 155 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE) 156 | ], 157 | ) -> Tensor: 158 | """Forward pass.""" 159 | return 
super().forward( 160 | source_activations=source_activations, 161 | learned_activations=learned_activations, 162 | decoded_activations=decoded_activations, 163 | ) 164 | -------------------------------------------------------------------------------- /sparse_autoencoder/metrics/loss/tests/test_l1_absolute_loss.py: -------------------------------------------------------------------------------- 1 | """Test the L1 absolute loss metric.""" 2 | from jaxtyping import Float 3 | import pytest 4 | from torch import Tensor, allclose, ones, tensor, zeros 5 | 6 | from sparse_autoencoder.metrics.loss.l1_absolute_loss import L1AbsoluteLoss 7 | from sparse_autoencoder.tensor_types import Axis 8 | 9 | 10 | @pytest.mark.parametrize( 11 | # Each source/decoded tensor is of the form (batch_size, num_components, num_features) 12 | ("learned_activations", "expected_loss"), 13 | [ 14 | pytest.param( 15 | zeros(2, 3), 16 | tensor(0.0), 17 | id="All zero activations -> zero loss (single component)", 18 | ), 19 | pytest.param( 20 | zeros(2, 2, 3), 21 | tensor([0.0, 0.0]), 22 | id="All zero activations -> zero loss (2 components)", 23 | ), 24 | pytest.param( 25 | ones(2, 3), # 3 features -> 3.0 loss 26 | tensor(3.0), 27 | id="All 1.0 activations -> 3.0 loss (single component)", 28 | ), 29 | pytest.param( 30 | ones(2, 2, 3), 31 | tensor([3.0, 3.0]), 32 | id="All 1.0 activations -> 3.0 loss (2 components)", 33 | ), 34 | pytest.param( 35 | ones(2, 2, 3) * -1, # Loss is absolute so the same as +ve 1s 36 | tensor([3.0, 3.0]), 37 | id="All -ve 1.0 activations -> 3.0 loss (2 components)", 38 | ), 39 | ], 40 | ) 41 | def test_l1_absolute_loss( 42 | learned_activations: Float[ 43 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE) 44 | ], 45 | expected_loss: Float[Tensor, Axis.COMPONENT_OPTIONAL], 46 | ) -> None: 47 | """Test the L1 absolute loss.""" 48 | num_components: int = learned_activations.shape[1] if learned_activations.ndim == 3 else 1 # noqa: PLR2004 49 | l1 = L1AbsoluteLoss(num_components) 50 | 51 | res = l1.forward(learned_activations=learned_activations) 52 | 53 | assert allclose(res, expected_loss) 54 | -------------------------------------------------------------------------------- /sparse_autoencoder/metrics/loss/tests/test_l2_reconstruction_loss.py: -------------------------------------------------------------------------------- 1 | """Test the L2 reconstruction loss metric.""" 2 | from jaxtyping import Float 3 | import pytest 4 | from torch import Tensor, allclose, ones, tensor, zeros 5 | 6 | from sparse_autoencoder.metrics.loss.l2_reconstruction_loss import L2ReconstructionLoss 7 | from sparse_autoencoder.tensor_types import Axis 8 | 9 | 10 | @pytest.mark.parametrize( 11 | # Each source/decoded tensor is of the form (batch_size, num_components, num_features) 12 | ("source_activations", "decoded_activations", "expected_loss"), 13 | [ 14 | pytest.param( 15 | ones(2, 3), 16 | ones(2, 3), 17 | tensor(0.0), 18 | id="Perfect reconstruction -> zero loss (single component)", 19 | ), 20 | pytest.param( 21 | ones(2, 2, 3), 22 | ones(2, 2, 3), 23 | tensor([0.0, 0.0]), 24 | id="Perfect reconstruction -> zero loss (2 components)", 25 | ), 26 | pytest.param( 27 | zeros(2, 3), 28 | ones(2, 3), 29 | tensor(1.0), 30 | id="All errors 1.0 -> 1.0 loss (single component)", 31 | ), 32 | pytest.param( 33 | zeros(2, 2, 3), 34 | ones(2, 2, 3), 35 | tensor([1.0, 1.0]), 36 | id="All errors 1.0 -> 1.0 loss (2 components)", 37 | ), 38 | ], 39 | ) 40 | def test_l2_reconstruction_loss( 41 | 
source_activations: Float[ 42 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE) 43 | ], 44 | decoded_activations: Float[ 45 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE) 46 | ], 47 | expected_loss: Float[Tensor, Axis.COMPONENT_OPTIONAL], 48 | ) -> None: 49 | """Test the L2 reconstruction loss.""" 50 | num_components: int = source_activations.shape[1] if source_activations.ndim == 3 else 1 # noqa: PLR2004 51 | l2 = L2ReconstructionLoss(num_components) 52 | 53 | res = l2.forward(decoded_activations=decoded_activations, source_activations=source_activations) 54 | 55 | assert allclose(res, expected_loss) 56 | -------------------------------------------------------------------------------- /sparse_autoencoder/metrics/loss/tests/test_sae_loss.py: -------------------------------------------------------------------------------- 1 | """Test the sparse autoencoder loss metric.""" 2 | from jaxtyping import Float 3 | import pytest 4 | from torch import Tensor, allclose, ones, rand, tensor, zeros 5 | 6 | from sparse_autoencoder.metrics.loss.l1_absolute_loss import L1AbsoluteLoss 7 | from sparse_autoencoder.metrics.loss.l2_reconstruction_loss import L2ReconstructionLoss 8 | from sparse_autoencoder.metrics.loss.sae_loss import SparseAutoencoderLoss 9 | from sparse_autoencoder.tensor_types import Axis 10 | 11 | 12 | @pytest.mark.parametrize( 13 | # Each source/decoded tensor is of the form (batch_size, num_components, num_features) 14 | ( 15 | "source_activations", 16 | "learned_activations", 17 | "decoded_activations", 18 | "l1_coefficient", 19 | "expected_loss", 20 | ), 21 | [ 22 | pytest.param( 23 | ones(2, 3), 24 | zeros(2, 4), # Fully sparse = no activity 25 | ones(2, 3), 26 | 0.01, 27 | tensor(0.0), 28 | id="Perfect reconstruction & perfect sparsity -> zero loss (single component)", 29 | ), 30 | pytest.param( 31 | ones(2, 2, 3), 32 | zeros(2, 2, 4), 33 | ones(2, 2, 3), 34 | 0.01, 35 | tensor([0.0, 0.0]), 36 | id="Perfect reconstruction & perfect sparsity -> zero loss (2 components)", 37 | ), 38 | pytest.param( 39 | ones(2, 3), 40 | ones(2, 4), # Abs error of 1.0 per component => average of 4 loss 41 | ones(2, 3), 42 | 0.01, 43 | tensor(0.04), 44 | id="Just sparsity error (single component)", 45 | ), 46 | pytest.param( 47 | ones(2, 2, 3), 48 | ones(2, 2, 4), 49 | ones(2, 2, 3), 50 | 0.01, 51 | tensor([0.04, 0.04]), 52 | id="Just sparsity error (2 components)", 53 | ), 54 | pytest.param( 55 | zeros(2, 3), 56 | zeros(2, 4), 57 | ones(2, 3), 58 | 0.01, 59 | tensor(1.0), 60 | id="Just reconstruction error (single component)", 61 | ), 62 | pytest.param( 63 | zeros(2, 2, 3), 64 | zeros(2, 2, 4), 65 | ones(2, 2, 3), 66 | 0.01, 67 | tensor([1.0, 1.0]), 68 | id="Just reconstruction error (2 components)", 69 | ), 70 | pytest.param( 71 | zeros(2, 3), 72 | ones(2, 4), 73 | ones(2, 3), 74 | 0.01, 75 | tensor(1.04), 76 | id="Sparsity and reconstruction error (single component)", 77 | ), 78 | pytest.param( 79 | zeros(2, 2, 3), 80 | ones(2, 2, 4), 81 | ones(2, 2, 3), 82 | 0.01, 83 | tensor([1.04, 1.04]), 84 | id="Sparsity and reconstruction error (2 components)", 85 | ), 86 | ], 87 | ) 88 | def test_sae_loss( 89 | source_activations: Float[ 90 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE) 91 | ], 92 | learned_activations: Float[ 93 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE) 94 | ], 95 | decoded_activations: Float[ 96 | Tensor, Axis.names(Axis.BATCH, 
Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE) 97 | ], 98 | l1_coefficient: float, 99 | expected_loss: Float[Tensor, Axis.COMPONENT_OPTIONAL], 100 | ) -> None: 101 | """Test the SAE loss.""" 102 | num_components: int = source_activations.shape[1] if source_activations.ndim == 3 else 1 # noqa: PLR2004 103 | metric = SparseAutoencoderLoss(num_components, l1_coefficient) 104 | 105 | res = metric.forward( 106 | source_activations=source_activations, 107 | learned_activations=learned_activations, 108 | decoded_activations=decoded_activations, 109 | ) 110 | 111 | assert allclose(res, expected_loss) 112 | 113 | 114 | def test_compare_sae_loss_to_composition() -> None: 115 | """Test the SAE loss metric against composition of l1 and l2.""" 116 | num_components = 3 117 | l1_coefficient = 0.01 118 | l1 = L1AbsoluteLoss(num_components) 119 | l2 = L2ReconstructionLoss(num_components) 120 | composition_loss = l1 * l1_coefficient + l2 121 | 122 | sae_loss = SparseAutoencoderLoss(num_components, l1_coefficient) 123 | 124 | source_activations = rand(2, num_components, 3) 125 | learned_activations = rand(2, num_components, 4) 126 | decoded_activations = rand(2, num_components, 3) 127 | 128 | composition_res = composition_loss.forward( 129 | source_activations=source_activations, 130 | learned_activations=learned_activations, 131 | decoded_activations=decoded_activations, 132 | ) 133 | 134 | sae_res = sae_loss.forward( 135 | source_activations=source_activations, 136 | learned_activations=learned_activations, 137 | decoded_activations=decoded_activations, 138 | ) 139 | 140 | assert allclose(composition_res, sae_res) 141 | -------------------------------------------------------------------------------- /sparse_autoencoder/metrics/train/__init__.py: -------------------------------------------------------------------------------- 1 | """Train step metrics.""" 2 | -------------------------------------------------------------------------------- /sparse_autoencoder/metrics/train/capacity.py: -------------------------------------------------------------------------------- 1 | """Capacity Metrics.""" 2 | from typing import Any 3 | 4 | import einops 5 | from jaxtyping import Float 6 | import torch 7 | from torch import Tensor 8 | from torchmetrics import Metric 9 | 10 | from sparse_autoencoder.tensor_types import Axis 11 | 12 | 13 | class CapacityMetric(Metric): 14 | """Capacities metric. 15 | 16 | Measure the capacity of a set of features as defined in [Polysemanticity and Capacity in Neural 17 | Networks](https://arxiv.org/pdf/2210.01892.pdf). 18 | 19 | Capacity is intuitively measuring the 'proportion of a dimension' assigned to a feature. 20 | Formally it's the ratio of the squared dot product of a feature with itself to the sum of its 21 | squared dot products of all features. 22 | 23 | Warning: 24 | This is memory intensive as it requires caching all learned activations for a batch. 25 | 26 | Examples: 27 | If the features are orthogonal, the capacity is 1. 28 | 29 | >>> metric = CapacityMetric() 30 | >>> learned_activations = torch.tensor([ 31 | ... [ # Batch 1 32 | ... [1., 0., 1.] # Component 1: learned features 33 | ... ], 34 | ... [ # Batch 2 35 | ... [0., 1., 0.] # Component 1: learned features (orthogonal) 36 | ... ] 37 | ... ]) 38 | >>> metric.forward(learned_activations) 39 | tensor([[1., 1.]]) 40 | 41 | If they are all the same, the capacity is 1/n. 42 | 43 | >>> learned_activations = torch.tensor([ 44 | ... [ # Batch 1 45 | ... [1., 1., 1.] # Component 1: learned features 46 | ... ], 47 | ... 
[ # Batch 2 48 | ... [1., 1., 1.] # Component 1: learned features (same) 49 | ... ] 50 | ... ]) 51 | >>> metric.forward(learned_activations) 52 | tensor([[0.5000, 0.5000]]) 53 | """ 54 | 55 | # Torchmetrics settings 56 | is_differentiable: bool | None = False 57 | full_state_update: bool | None = False 58 | plot_lower_bound: float | None = 0.0 59 | plot_upper_bound: float | None = 1.0 60 | 61 | # State 62 | learned_activations: list[ 63 | Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)] 64 | ] 65 | 66 | def __init__(self) -> None: 67 | """Initialize the metric.""" 68 | super().__init__() 69 | self.add_state("learned_activations", default=[]) 70 | 71 | def update( 72 | self, 73 | learned_activations: Float[ 74 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE) 75 | ], 76 | **kwargs: Any, # type: ignore # noqa: ARG002, ANN401 (allows combining with other metrics) 77 | ) -> None: 78 | """Update the metric state. 79 | 80 | Args: 81 | learned_activations: The learned activations. 82 | **kwargs: Ignored keyword arguments (to allow use with other metrics in a collection). 83 | """ 84 | self.learned_activations.append(learned_activations) 85 | 86 | @staticmethod 87 | def capacities( 88 | features: Float[ 89 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE) 90 | ], 91 | ) -> Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.BATCH)]: 92 | r"""Calculate capacities. 93 | 94 | Example: 95 | >>> import torch 96 | >>> orthogonal_features = torch.tensor([[[1., 0., 0.]], [[0., 1., 0.]], [[0., 0., 1.]]]) 97 | >>> orthogonal_caps = CapacityMetric.capacities(orthogonal_features) 98 | >>> orthogonal_caps 99 | tensor([[1., 1., 1.]]) 100 | 101 | Args: 102 | features: A collection of features. 103 | 104 | Returns: 105 | A 1D tensor of capacities, where each element is the capacity of the corresponding 106 | feature. 107 | """ 108 | squared_dot_products: Float[ 109 | Tensor, Axis.names(Axis.BATCH, Axis.BATCH, Axis.COMPONENT_OPTIONAL) 110 | ] = ( 111 | einops.einsum( 112 | features, 113 | features, 114 | f"batch_1 ... {Axis.LEARNT_FEATURE}, \ 115 | batch_2 ... {Axis.LEARNT_FEATURE} \ 116 | -> ... 
batch_1 batch_2", 117 | ) 118 | ** 2 119 | ) 120 | 121 | sum_of_sq_dot: Float[ 122 | Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.BATCH) 123 | ] = squared_dot_products.sum(dim=-1) 124 | 125 | diagonal: Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.BATCH)] = torch.diagonal( 126 | squared_dot_products, dim1=1, dim2=2 127 | ) 128 | 129 | return diagonal / sum_of_sq_dot 130 | 131 | def compute( 132 | self, 133 | ) -> Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.BATCH)]: 134 | """Compute the metric.""" 135 | batch_learned_activations: Float[ 136 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE) 137 | ] = torch.cat(self.learned_activations) 138 | 139 | return self.capacities(batch_learned_activations) 140 | -------------------------------------------------------------------------------- /sparse_autoencoder/metrics/train/feature_density.py: -------------------------------------------------------------------------------- 1 | """Train batch feature density.""" 2 | from typing import Any 3 | 4 | from jaxtyping import Bool, Float, Int64 5 | from pydantic import PositiveInt, validate_call 6 | import torch 7 | from torch import Tensor 8 | from torchmetrics import Metric 9 | 10 | from sparse_autoencoder.tensor_types import Axis 11 | from sparse_autoencoder.utils.tensor_shape import shape_with_optional_dimensions 12 | 13 | 14 | class FeatureDensityMetric(Metric): 15 | """Feature density metric. 16 | 17 | Percentage of samples in which each feature was active (i.e. the neuron has "fired"), in a 18 | training batch. 19 | 20 | Generally we want a small number of features to be active in each batch, so average feature 21 | density should be low. By contrast if the average feature density is high, it means that the 22 | features are not sparse enough. 23 | 24 | Example: 25 | >>> metric = FeatureDensityMetric(num_learned_features=3, num_components=1) 26 | >>> learned_activations = torch.tensor([ 27 | ... [ # Batch 1 28 | ... [1., 0., 1.] # Component 1: learned features (2 active neurons) 29 | ... ], 30 | ... [ # Batch 2 31 | ... [0., 0., 0.] # Component 1: learned features (0 active neuron) 32 | ... ] 33 | ... 
]) 34 | >>> metric.forward(learned_activations) 35 | tensor([[0.5000, 0.0000, 0.5000]]) 36 | """ 37 | 38 | # Torchmetrics settings 39 | is_differentiable: bool | None = False 40 | full_state_update: bool | None = True 41 | plot_lower_bound: float | None = 0.0 42 | plot_upper_bound: float | None = 1.0 43 | 44 | # State 45 | neuron_fired_count: Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)] 46 | num_activation_vectors: Int64[Tensor, Axis.SINGLE_ITEM] 47 | 48 | @validate_call 49 | def __init__( 50 | self, num_learned_features: PositiveInt, num_components: PositiveInt | None = None 51 | ) -> None: 52 | """Initialise the metric.""" 53 | super().__init__() 54 | 55 | self.add_state( 56 | "neuron_fired_count", 57 | default=torch.zeros( 58 | size=shape_with_optional_dimensions(num_components, num_learned_features), 59 | dtype=torch.float, # Float is needed for dist reduce to work 60 | ), 61 | dist_reduce_fx="sum", 62 | ) 63 | 64 | self.add_state( 65 | "num_activation_vectors", 66 | default=torch.tensor(0, dtype=torch.int64), 67 | dist_reduce_fx="sum", 68 | ) 69 | 70 | def update( 71 | self, 72 | learned_activations: Float[ 73 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE) 74 | ], 75 | **kwargs: Any, # type: ignore # noqa: ARG002, ANN401 (allows combining with other metrics) 76 | ) -> None: 77 | """Update the metric state. 78 | 79 | Args: 80 | learned_activations: The learned activations. 81 | **kwargs: Ignored keyword arguments (to allow use with other metrics in a collection). 82 | """ 83 | # Increment the counter of activations seen since the last compute step 84 | self.num_activation_vectors += learned_activations.shape[0] 85 | 86 | # Count the number of active neurons in the batch 87 | neuron_has_fired: Bool[ 88 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE) 89 | ] = torch.gt(learned_activations, 0) 90 | 91 | self.neuron_fired_count += neuron_has_fired.sum(dim=0, dtype=torch.int64) 92 | 93 | def compute( 94 | self, 95 | ) -> Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)]: 96 | """Compute the metric.""" 97 | return self.neuron_fired_count / self.num_activation_vectors 98 | -------------------------------------------------------------------------------- /sparse_autoencoder/metrics/train/l0_norm.py: -------------------------------------------------------------------------------- 1 | """L0 norm sparsity metric.""" 2 | from typing import Any 3 | 4 | from jaxtyping import Float, Int64 5 | from pydantic import PositiveInt, validate_call 6 | import torch 7 | from torch import Tensor 8 | from torchmetrics import Metric 9 | 10 | from sparse_autoencoder.tensor_types import Axis 11 | from sparse_autoencoder.utils.tensor_shape import shape_with_optional_dimensions 12 | 13 | 14 | class L0NormMetric(Metric): 15 | """Learned activations L0 norm metric. 16 | 17 | The L0 norm is the number of non-zero elements in a learned activation vector, averaged over the 18 | number of activation vectors. 19 | 20 | Examples: 21 | >>> metric = L0NormMetric() 22 | >>> learned_activations = torch.tensor([ 23 | ... [1., 0., 1.], # Batch 1 (single component): learned features (2 active neurons) 24 | ... [0., 1., 0.] # Batch 2 (single component): learned features (1 active neuron) 25 | ... ]) 26 | >>> metric.forward(learned_activations) 27 | tensor(1.5000) 28 | 29 | With 2 components, the metric will return the average number of active (non-zero) 30 | neurons as a 1d tensor. 
31 | 
32 |         >>> metric = L0NormMetric(num_components=2)
33 |         >>> learned_activations = torch.tensor([
34 |         ...     [ # Batch 1
35 |         ...         [1., 0., 1.], # Component 1: learned features (2 active neurons)
36 |         ...         [1., 0., 1.] # Component 2: learned features (2 active neurons)
37 |         ...     ],
38 |         ...     [ # Batch 2
39 |         ...         [0., 1., 0.], # Component 1: learned features (1 active neuron)
40 |         ...         [1., 0., 1.] # Component 2: learned features (2 active neurons)
41 |         ...     ]
42 |         ... ])
43 |         >>> metric.forward(learned_activations)
44 |         tensor([1.5000, 2.0000])
45 |     """
46 | 
47 |     # Torchmetrics settings
48 |     is_differentiable: bool | None = False
49 |     full_state_update: bool | None = False
50 |     plot_lower_bound: float | None = 0.0
51 | 
52 |     # State
53 |     active_neurons_count: Float[Tensor, Axis.COMPONENT_OPTIONAL]
54 |     num_activation_vectors: Int64[Tensor, Axis.SINGLE_ITEM]
55 | 
56 |     @validate_call
57 |     def __init__(self, num_components: PositiveInt | None = None) -> None:
58 |         """Initialize the metric."""
59 |         super().__init__()
60 | 
61 |         self.add_state(
62 |             "active_neurons_count",
63 |             default=torch.zeros(shape_with_optional_dimensions(num_components), dtype=torch.float),  # Float is needed for dist reduce to work
64 |             dist_reduce_fx="sum",
65 |         )
66 | 
67 |         self.add_state(
68 |             "num_activation_vectors",
69 |             default=torch.tensor(0, dtype=torch.int64),
70 |             dist_reduce_fx="sum",
71 |         )
72 | 
73 |     def update(
74 |         self,
75 |         learned_activations: Float[
76 |             Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
77 |         ],
78 |         **kwargs: Any,  # type: ignore # noqa: ARG002, ANN401 (allows combining with other metrics)
79 |     ) -> None:
80 |         """Update the metric state.
81 | 
82 |         Args:
83 |             learned_activations: The learned activations.
84 |             **kwargs: Ignored keyword arguments (to allow use with other metrics in a collection).
85 |         """
86 |         self.num_activation_vectors += learned_activations.shape[0]
87 | 
88 |         self.active_neurons_count += torch.count_nonzero(learned_activations, dim=-1).sum(
89 |             dim=0, dtype=torch.int64
90 |         )
91 | 
92 |     def compute(
93 |         self,
94 |     ) -> Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL)]:
95 |         """Compute the metric.
96 | 
97 |         Note that torchmetrics converts shape `[0]` tensors into scalars (shape `0`).
98 |         """
99 |         return self.active_neurons_count / self.num_activation_vectors
100 | 
-------------------------------------------------------------------------------- /sparse_autoencoder/metrics/train/neuron_activity.py: --------------------------------------------------------------------------------
1 | """Neuron activity metric."""
2 | from typing import Annotated, Any
3 | 
4 | from jaxtyping import Bool, Float, Int64
5 | from pydantic import Field, NonNegativeFloat, PositiveInt, validate_call
6 | import torch
7 | from torch import Tensor
8 | from torchmetrics import Metric
9 | 
10 | from sparse_autoencoder.tensor_types import Axis
11 | from sparse_autoencoder.utils.tensor_shape import shape_with_optional_dimensions
12 | 
13 | 
14 | class NeuronActivityMetric(Metric):
15 |     """Neuron activity metric.
16 | 
17 |     Example:
18 |         With a single component, the metric returns the number of dead neurons: those that
19 |         fired for at most the threshold portion of the activation vectors seen. In the example
20 |         below the middle neuron never fires in either batch, so one dead neuron is reported.
21 |         The breakdown by component isn't shown here as there is just one component.
22 | 
23 |         >>> metric = NeuronActivityMetric(num_learned_features=3)
24 |         >>> learned_activations = torch.tensor([
25 |         ... 
[1., 0., 1.], # Batch 1 (single component): learned features (2 active neurons) 26 | ... [0., 0., 0.] # Batch 2 (single component): learned features (0 active neuron) 27 | ... ]) 28 | >>> metric.forward(learned_activations) 29 | tensor(1) 30 | """ 31 | 32 | # Torchmetrics settings 33 | is_differentiable: bool | None = False 34 | full_state_update: bool | None = True 35 | plot_lower_bound: float | None = 0.0 36 | 37 | # Metric settings 38 | _threshold_is_dead_portion_fires: NonNegativeFloat 39 | 40 | # State 41 | neuron_fired_count: Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)] 42 | num_activation_vectors: Int64[Tensor, Axis.SINGLE_ITEM] 43 | 44 | @validate_call 45 | def __init__( 46 | self, 47 | num_learned_features: PositiveInt, 48 | num_components: PositiveInt | None = None, 49 | threshold_is_dead_portion_fires: Annotated[float, Field(strict=True, ge=0, le=1)] = 0.0, 50 | ) -> None: 51 | """Initialise the metric. 52 | 53 | Args: 54 | num_learned_features: Number of learned features. 55 | num_components: Number of components. 56 | threshold_is_dead_portion_fires: Thresholds for counting a neuron as dead (portion of 57 | activation vectors that it fires for must be less than or equal to this number). 58 | Commonly used values are 0.0, 1e-5 and 1e-6. 59 | """ 60 | super().__init__() 61 | self._threshold_is_dead_portion_fires = threshold_is_dead_portion_fires 62 | 63 | self.add_state( 64 | "neuron_fired_count", 65 | default=torch.zeros( 66 | shape_with_optional_dimensions(num_components, num_learned_features), 67 | dtype=torch.float, # Float is needed for dist reduce to work 68 | ), 69 | dist_reduce_fx="sum", 70 | ) 71 | 72 | self.add_state( 73 | "num_activation_vectors", 74 | default=torch.tensor(0, dtype=torch.int64), 75 | dist_reduce_fx="sum", 76 | ) 77 | 78 | def update( 79 | self, 80 | learned_activations: Float[ 81 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE) 82 | ], 83 | **kwargs: Any, # type: ignore # noqa: ARG002, ANN401 (allows combining with other metrics) 84 | ) -> None: 85 | """Update the metric state. 86 | 87 | Args: 88 | learned_activations: The learned activations. 89 | **kwargs: Ignored keyword arguments (to allow use with other metrics in a collection). 90 | """ 91 | # Increment the counter of activations seen since the last compute step 92 | self.num_activation_vectors += learned_activations.shape[0] 93 | 94 | # Count the number of active neurons in the batch 95 | neuron_has_fired: Bool[ 96 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE) 97 | ] = torch.gt(learned_activations, 0) 98 | 99 | self.neuron_fired_count += neuron_has_fired.sum(dim=0, dtype=torch.float) 100 | 101 | def compute(self) -> Int64[Tensor, Axis.COMPONENT_OPTIONAL]: 102 | """Compute the metric. 103 | 104 | Note that torchmetrics converts shape `[0]` tensors into scalars (shape `0`). 
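        As a worked example (illustrative numbers only): with
        `threshold_is_dead_portion_fires=0.01` and 1000 activation vectors seen, the firing
        threshold below is 0.01 * 1000 = 10, so any neuron that fired 10 or fewer times is
        counted as dead.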
105 | """ 106 | threshold_activations: Float[Tensor, Axis.SINGLE_ITEM] = ( 107 | self._threshold_is_dead_portion_fires * self.num_activation_vectors 108 | ) 109 | 110 | return torch.sum( 111 | self.neuron_fired_count <= threshold_activations, dim=-1, dtype=torch.int64 112 | ) 113 | -------------------------------------------------------------------------------- /sparse_autoencoder/metrics/train/neuron_fired_count.py: -------------------------------------------------------------------------------- 1 | """Neuron fired count metric.""" 2 | from typing import Any 3 | 4 | from jaxtyping import Bool, Float, Int 5 | from pydantic import PositiveInt, validate_call 6 | import torch 7 | from torch import Tensor 8 | from torchmetrics import Metric 9 | 10 | from sparse_autoencoder.tensor_types import Axis 11 | from sparse_autoencoder.utils.tensor_shape import shape_with_optional_dimensions 12 | 13 | 14 | class NeuronFiredCountMetric(Metric): 15 | """Neuron activity metric. 16 | 17 | Example: 18 | >>> metric = NeuronFiredCountMetric(num_learned_features=3) 19 | >>> learned_activations = torch.tensor([ 20 | ... [1., 0., 1.], # Batch 1 (single component): learned features (2 active neurons) 21 | ... [0., 0., 0.] # Batch 2 (single component): learned features (0 active neuron) 22 | ... ]) 23 | >>> metric.forward(learned_activations) 24 | tensor([1, 0, 1]) 25 | """ 26 | 27 | # Torchmetrics settings 28 | is_differentiable: bool | None = True 29 | full_state_update: bool | None = True 30 | plot_lower_bound: float | None = 0.0 31 | 32 | # State 33 | neuron_fired_count: Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)] 34 | 35 | @validate_call 36 | def __init__( 37 | self, 38 | num_learned_features: PositiveInt, 39 | num_components: PositiveInt | None = None, 40 | ) -> None: 41 | """Initialise the metric. 42 | 43 | Args: 44 | num_learned_features: Number of learned features. 45 | num_components: Number of components. 46 | """ 47 | super().__init__() 48 | self.add_state( 49 | "neuron_fired_count", 50 | default=torch.zeros( 51 | shape_with_optional_dimensions(num_components, num_learned_features), 52 | dtype=torch.float, # Float is needed for dist reduce to work 53 | ), 54 | dist_reduce_fx="sum", 55 | ) 56 | 57 | def update( 58 | self, 59 | learned_activations: Float[ 60 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE) 61 | ], 62 | **kwargs: Any, # type: ignore # noqa: ARG002, ANN401 (allows combining with other metrics) 63 | ) -> None: 64 | """Update the metric state. 65 | 66 | Args: 67 | learned_activations: The learned activations. 68 | **kwargs: Ignored keyword arguments (to allow use with other metrics in a collection). 
69 | """ 70 | neuron_has_fired: Bool[ 71 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE) 72 | ] = torch.gt(learned_activations, 0) 73 | 74 | self.neuron_fired_count += neuron_has_fired.sum(dim=0, dtype=torch.float) 75 | 76 | def compute(self) -> Int[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)]: 77 | """Compute the metric.""" 78 | return self.neuron_fired_count.to(dtype=torch.int64) 79 | -------------------------------------------------------------------------------- /sparse_autoencoder/metrics/train/tests/test_feature_density.py: -------------------------------------------------------------------------------- 1 | """Test the feature density metric.""" 2 | import pytest 3 | import torch 4 | 5 | from sparse_autoencoder.metrics.train.feature_density import FeatureDensityMetric 6 | 7 | 8 | @pytest.mark.parametrize( 9 | ("num_learned_features", "num_components", "learned_activations", "expected_output"), 10 | [ 11 | pytest.param( 12 | 3, 13 | 1, 14 | torch.tensor([[[1.0, 1.0, 1.0]]]), 15 | torch.tensor([[1.0, 1.0, 1.0]]), 16 | id="Single component axis, all neurons active", 17 | ), 18 | pytest.param( 19 | 3, 20 | None, 21 | torch.tensor([[1.0, 1.0, 1.0]]), 22 | torch.tensor([1.0, 1.0, 1.0]), 23 | id="No component axis, all neurons active", 24 | ), 25 | pytest.param( 26 | 3, 27 | 1, 28 | torch.tensor([[[0.0, 0.0, 0.0]]]), 29 | torch.tensor([[0.0, 0.0, 0.0]]), 30 | id="Single component, no neurons active", 31 | ), 32 | pytest.param( 33 | 3, 34 | 2, 35 | torch.tensor( 36 | [ 37 | [ # Batch 1 38 | [1.0, 0.0, 1.0], # Component 1: learned features 39 | [0.0, 1.0, 0.0], # Component 2: learned features 40 | ], 41 | [ # Batch 2 42 | [0.0, 1.0, 0.0], # Component 1: learned features 43 | [1.0, 0.0, 1.0], # Component 2: learned features 44 | ], 45 | ], 46 | ), 47 | torch.tensor( 48 | [ 49 | [0.5, 0.5, 0.5], # Component 1: learned features 50 | [0.5, 0.5, 0.5], # Component 2: learned features 51 | ] 52 | ), 53 | id="Multiple components, mixed activity", 54 | ), 55 | ], 56 | ) 57 | def test_feature_density_metric( 58 | num_learned_features: int, 59 | num_components: int, 60 | learned_activations: torch.Tensor, 61 | expected_output: torch.Tensor, 62 | ) -> None: 63 | """Test the FeatureDensityMetric for different scenarios. 64 | 65 | Args: 66 | num_learned_features: Number of learned features. 67 | num_components: Number of components. 68 | learned_activations: Learned activations tensor. 69 | expected_output: Expected output tensor. 
70 | """ 71 | metric = FeatureDensityMetric(num_learned_features, num_components) 72 | result = metric.forward(learned_activations) 73 | assert result.shape == expected_output.shape 74 | assert torch.allclose(result, expected_output) 75 | -------------------------------------------------------------------------------- /sparse_autoencoder/metrics/train/tests/test_l0_norm.py: -------------------------------------------------------------------------------- 1 | """Test the L0 norm sparsity metric.""" 2 | import pytest 3 | import torch 4 | 5 | from sparse_autoencoder.metrics.train.l0_norm import ( 6 | L0NormMetric, # Adjust the import path as needed 7 | ) 8 | 9 | 10 | @pytest.mark.parametrize( 11 | ("num_components", "learned_activations", "expected_output"), 12 | [ 13 | pytest.param( 14 | 1, 15 | torch.tensor([[[1.0, 0.0, 1.0]]]), 16 | torch.tensor(2.0), 17 | id="Single component, mixed activity", 18 | ), 19 | pytest.param( 20 | None, 21 | torch.tensor([[1.0, 0.0, 1.0]]), 22 | torch.tensor(2.0), 23 | id="No component axis, mixed activity", 24 | ), 25 | pytest.param( 26 | 1, 27 | torch.tensor([[[0.0, 0.0, 0.0]]]), 28 | torch.tensor(0.0), 29 | id="Single component, no neurons active", 30 | ), 31 | pytest.param( 32 | 2, 33 | torch.tensor( 34 | [ 35 | [[1.0, 0.0, 1.0], [1.0, 0.0, 1.0]], 36 | [[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]], 37 | ] 38 | ), 39 | torch.tensor([1.5, 2.0]), 40 | id="Multiple components, mixed activity", 41 | ), 42 | ], 43 | ) 44 | def test_l0_norm_metric( 45 | num_components: int, 46 | learned_activations: torch.Tensor, 47 | expected_output: torch.Tensor, 48 | ) -> None: 49 | """Test the L0NormMetric for different scenarios. 50 | 51 | Args: 52 | num_components: Number of components. 53 | learned_activations: Learned activations tensor. 54 | expected_output: Expected output tensor. 
55 | """ 56 | metric = L0NormMetric(num_components) 57 | result = metric.forward(learned_activations) 58 | 59 | assert result.shape == expected_output.shape 60 | assert torch.allclose(result, expected_output) 61 | -------------------------------------------------------------------------------- /sparse_autoencoder/metrics/train/tests/test_neuron_activity.py: -------------------------------------------------------------------------------- 1 | """Test the Neuron Activity Metric.""" 2 | import pytest 3 | import torch 4 | 5 | from sparse_autoencoder.metrics.train.neuron_activity import NeuronActivityMetric 6 | 7 | 8 | @pytest.mark.parametrize( 9 | ( 10 | "num_learned_features", 11 | "num_components", 12 | "threshold_is_dead_portion_fires", 13 | "learned_activations", 14 | "expected_output", 15 | ), 16 | [ 17 | pytest.param( 18 | 3, 19 | 1, 20 | 0, 21 | torch.tensor( 22 | [ 23 | [ # Batch 1 24 | [1.0, 0.0, 1.0] # Component 1: learned features (2 active neurons) 25 | ], 26 | [ # Batch 2 27 | [0.0, 0.0, 0.0] # Component 1: learned features (0 active neuron) 28 | ], 29 | ] 30 | ), 31 | torch.tensor(1), 32 | id="Single component, one dead neuron", 33 | ), 34 | pytest.param( 35 | 3, 36 | None, 37 | 0, 38 | torch.tensor([[1.0, 0.0, 1.0], [0.0, 0.0, 0.0]]), 39 | torch.tensor(1), 40 | id="No component axis, one dead neuron", 41 | ), 42 | pytest.param( 43 | 3, 44 | 1, 45 | 0.0, 46 | torch.tensor([[[1.0, 1.0, 1.0]], [[0.0, 0.0, 0.0]]]), 47 | torch.tensor(0), 48 | id="Single component, no dead neurons", 49 | ), 50 | pytest.param( 51 | 3, 52 | 2, 53 | 0, 54 | torch.tensor( 55 | [ 56 | [ # Batch 1 57 | [1.0, 0.0, 1.0], # Component 1: learned features 58 | [1.0, 0.0, 1.0], # Component 2: learned features 59 | ], 60 | [ # Batch 2 61 | [0.0, 1.0, 0.0], # Component 1: learned features 62 | [1.0, 0.0, 1.0], # Component 2: learned features 63 | ], 64 | ] 65 | ), 66 | torch.tensor([0, 1]), 67 | id="Multiple components, mixed dead neurons", 68 | ), 69 | ], 70 | ) 71 | def test_neuron_activity_metric( 72 | num_learned_features: int, 73 | num_components: int, 74 | threshold_is_dead_portion_fires: float, 75 | learned_activations: torch.Tensor, 76 | expected_output: torch.Tensor, 77 | ) -> None: 78 | """Test the NeuronActivityMetric for different scenarios. 79 | 80 | Args: 81 | num_learned_features: Number of learned features. 82 | num_components: Number of components. 83 | threshold_is_dead_portion_fires: Threshold for counting a neuron as dead. 84 | learned_activations: Learned activations tensor. 85 | expected_output: Expected number of dead neurons. 
86 | """ 87 | metric = NeuronActivityMetric( 88 | num_learned_features, num_components, threshold_is_dead_portion_fires 89 | ) 90 | result = metric.forward(learned_activations) 91 | assert result.shape == expected_output.shape 92 | assert torch.equal(result, expected_output) 93 | -------------------------------------------------------------------------------- /sparse_autoencoder/metrics/train/tests/test_neuron_fired_count.py: -------------------------------------------------------------------------------- 1 | """Test the Neuron Fired Count Metric.""" 2 | import pytest 3 | import torch 4 | 5 | from sparse_autoencoder.metrics.train.neuron_fired_count import NeuronFiredCountMetric 6 | 7 | 8 | @pytest.mark.parametrize( 9 | ( 10 | "num_learned_features", 11 | "num_components", 12 | "threshold_is_dead_portion_fires", 13 | "learned_activations", 14 | "expected_output", 15 | ), 16 | [ 17 | pytest.param( 18 | 3, 19 | 1, 20 | 0, 21 | torch.tensor( 22 | [ 23 | [ # Batch 1 24 | [1.0, 0.0, 1.0] # Component 1: learned features (2 active neurons) 25 | ], 26 | [ # Batch 2 27 | [0.0, 0.0, 0.0] # Component 1: learned features (0 active neuron) 28 | ], 29 | ] 30 | ), 31 | torch.tensor([[1, 0, 1]]), 32 | id="Single component, one dead neuron", 33 | ), 34 | pytest.param( 35 | 3, 36 | None, 37 | 0, 38 | torch.tensor([[1.0, 0.0, 1.0], [0.0, 0.0, 0.0]]), 39 | torch.tensor([1, 0, 1]), 40 | id="No component axis, one dead neuron", 41 | ), 42 | pytest.param( 43 | 3, 44 | 1, 45 | 0.0, 46 | torch.tensor([[[1.0, 1.0, 1.0]], [[0.0, 0.0, 0.0]]]), 47 | torch.tensor([[1, 1, 1]]), 48 | id="Single component, no dead neurons", 49 | ), 50 | pytest.param( 51 | 3, 52 | 2, 53 | 0, 54 | torch.tensor( 55 | [ 56 | [ # Batch 1 57 | [1.0, 0.0, 1.0], # Component 1: learned features 58 | [1.0, 0.0, 1.0], # Component 2: learned features 59 | ], 60 | [ # Batch 2 61 | [0.0, 1.0, 0.0], # Component 1: learned features 62 | [1.0, 0.0, 1.0], # Component 2: learned features 63 | ], 64 | ] 65 | ), 66 | torch.tensor([[1, 1, 1], [2, 0, 2]]), 67 | id="Multiple components, mixed dead neurons", 68 | ), 69 | ], 70 | ) 71 | def test_neuron_fired_count_metric( 72 | num_learned_features: int, 73 | num_components: int, 74 | threshold_is_dead_portion_fires: float, 75 | learned_activations: torch.Tensor, 76 | expected_output: torch.Tensor, 77 | ) -> None: 78 | """Test the NeuronFiredCount for different scenarios. 79 | 80 | Args: 81 | num_learned_features: Number of learned features. 82 | num_components: Number of components. 83 | threshold_is_dead_portion_fires: Threshold for counting a neuron as dead. 84 | learned_activations: Learned activations tensor. 85 | expected_output: Expected number of dead neurons. 
86 | """ 87 | metric = NeuronFiredCountMetric(num_learned_features, num_components) 88 | result = metric.forward(learned_activations) 89 | assert result.shape == expected_output.shape 90 | assert torch.equal(result, expected_output) 91 | -------------------------------------------------------------------------------- /sparse_autoencoder/metrics/validate/__init__.py: -------------------------------------------------------------------------------- 1 | """Validate step metrics.""" 2 | -------------------------------------------------------------------------------- /sparse_autoencoder/metrics/validate/reconstruction_score.py: -------------------------------------------------------------------------------- 1 | """Reconstruction score metric.""" 2 | from jaxtyping import Float, Int64 3 | from pydantic import PositiveInt, validate_call 4 | import torch 5 | from torch import Tensor 6 | from torchmetrics import Metric 7 | 8 | from sparse_autoencoder.tensor_types import Axis 9 | 10 | 11 | class ReconstructionScoreMetric(Metric): 12 | r"""Model reconstruction score. 13 | 14 | Creates a score that measures how well the model can reconstruct the data. 15 | 16 | $$ 17 | \begin{align*} 18 | v &= \text{number of validation items} \\ 19 | l \in{\mathbb{R}^v} &= \text{loss with no changes to the source model} \\ 20 | l_\text{recon} \in{\mathbb{R}^v} &= \text{loss with reconstruction} \\ 21 | l_\text{zero} \in{\mathbb{R}^v} &= \text{loss with zero ablation} \\ 22 | s &= \text{reconstruction score} \\ 23 | s_\text{itemwise} &= \frac{l_\text{zero} - l_\text{recon}}{l_\text{zero} - l} \\ 24 | s &= \sum_{i=1}^v s_\text{itemwise} / v 25 | \end{align*} 26 | $$ 27 | 28 | Example: 29 | >>> metric = ReconstructionScoreMetric(num_components=1) 30 | >>> source_model_loss=torch.tensor([2.0, 2.0, 2.0]) 31 | >>> source_model_loss_with_reconstruction=torch.tensor([3.0, 3.0, 3.0]) 32 | >>> source_model_loss_with_zero_ablation=torch.tensor([5.0, 5.0, 5.0]) 33 | >>> metric.forward( 34 | ... source_model_loss=source_model_loss, 35 | ... source_model_loss_with_reconstruction=source_model_loss_with_reconstruction, 36 | ... source_model_loss_with_zero_ablation=source_model_loss_with_zero_ablation 37 | ... 
) 38 | tensor(0.6667) 39 | """ 40 | 41 | # Torchmetrics settings 42 | is_differentiable: bool | None = False 43 | full_state_update: bool | None = True 44 | 45 | # State 46 | source_model_loss: Float[Tensor, Axis.COMPONENT_OPTIONAL] 47 | source_model_loss_with_zero_ablation: Float[Tensor, Axis.COMPONENT_OPTIONAL] 48 | source_model_loss_with_reconstruction: Float[Tensor, Axis.COMPONENT_OPTIONAL] 49 | num_activation_vectors: Int64[Tensor, Axis.SINGLE_ITEM] 50 | 51 | @validate_call 52 | def __init__(self, num_components: PositiveInt = 1) -> None: 53 | """Initialise the metric.""" 54 | super().__init__() 55 | 56 | self.add_state( 57 | "source_model_loss", default=torch.zeros(num_components), dist_reduce_fx="sum" 58 | ) 59 | self.add_state( 60 | "source_model_loss_with_zero_ablation", 61 | default=torch.zeros(num_components), 62 | dist_reduce_fx="sum", 63 | ) 64 | self.add_state( 65 | "source_model_loss_with_reconstruction", 66 | default=torch.zeros(num_components), 67 | dist_reduce_fx="sum", 68 | ) 69 | 70 | def update( 71 | self, 72 | source_model_loss: Float[Tensor, Axis.COMPONENT_OPTIONAL], 73 | source_model_loss_with_reconstruction: Float[Tensor, Axis.COMPONENT_OPTIONAL], 74 | source_model_loss_with_zero_ablation: Float[Tensor, Axis.COMPONENT_OPTIONAL], 75 | component_idx: int = 0, 76 | ) -> None: 77 | """Update the metric state. 78 | 79 | Args: 80 | source_model_loss: Loss with no changes to the source model. 81 | source_model_loss_with_reconstruction: Loss with SAE reconstruction. 82 | source_model_loss_with_zero_ablation: Loss with zero ablation. 83 | component_idx: Component idx. 84 | """ 85 | self.source_model_loss[component_idx] += source_model_loss.sum() 86 | self.source_model_loss_with_zero_ablation[ 87 | component_idx 88 | ] += source_model_loss_with_zero_ablation.sum() 89 | self.source_model_loss_with_reconstruction[ 90 | component_idx 91 | ] += source_model_loss_with_reconstruction.sum() 92 | 93 | def compute( 94 | self, 95 | ) -> Float[Tensor, Axis.COMPONENT_OPTIONAL]: 96 | """Compute the metric.""" 97 | zero_ablate_loss_minus_reconstruction_loss: Float[Tensor, Axis.COMPONENT_OPTIONAL] = ( 98 | self.source_model_loss_with_zero_ablation - self.source_model_loss_with_reconstruction 99 | ) 100 | 101 | zero_ablate_loss_minus_default_loss: Float[Tensor, Axis.COMPONENT_OPTIONAL] = ( 102 | self.source_model_loss_with_zero_ablation - self.source_model_loss 103 | ) 104 | 105 | return zero_ablate_loss_minus_reconstruction_loss / zero_ablate_loss_minus_default_loss 106 | -------------------------------------------------------------------------------- /sparse_autoencoder/metrics/validate/tests/__snapshots__/test_model_reconstruction_score.ambr: -------------------------------------------------------------------------------- 1 | # serializer version: 1 2 | # name: test_weights_biases_log_matches_snapshot 3 | list([ 4 | dict({ 5 | 'component_0/validate/reconstruction_score/baseline_loss': tensor(0.3800), 6 | 'component_1/validate/reconstruction_score/baseline_loss': tensor(0.5251), 7 | 'component_2/validate/reconstruction_score/baseline_loss': tensor(0.4923), 8 | 'component_3/validate/reconstruction_score/baseline_loss': tensor(0.4598), 9 | 'component_4/validate/reconstruction_score/baseline_loss': tensor(0.4281), 10 | 'component_5/validate/reconstruction_score/baseline_loss': tensor(0.4961), 11 | 'validate/reconstruction_score/baseline_loss/component_mean': tensor(0.4636), 12 | }), 13 | dict({ 14 | 'component_0/validate/reconstruction_score/loss_with_reconstruction': tensor(0.6111), 15 | 
'component_1/validate/reconstruction_score/loss_with_reconstruction': tensor(0.5219), 16 | 'component_2/validate/reconstruction_score/loss_with_reconstruction': tensor(0.4063), 17 | 'component_3/validate/reconstruction_score/loss_with_reconstruction': tensor(0.6497), 18 | 'component_4/validate/reconstruction_score/loss_with_reconstruction': tensor(0.4929), 19 | 'component_5/validate/reconstruction_score/loss_with_reconstruction': tensor(0.3723), 20 | 'validate/reconstruction_score/loss_with_reconstruction/component_mean': tensor(0.5090), 21 | }), 22 | dict({ 23 | 'component_0/validate/reconstruction_score/loss_with_zero_ablation': tensor(0.2891), 24 | 'component_1/validate/reconstruction_score/loss_with_zero_ablation': tensor(0.3879), 25 | 'component_2/validate/reconstruction_score/loss_with_zero_ablation': tensor(0.5850), 26 | 'component_3/validate/reconstruction_score/loss_with_zero_ablation': tensor(0.4740), 27 | 'component_4/validate/reconstruction_score/loss_with_zero_ablation': tensor(0.5452), 28 | 'component_5/validate/reconstruction_score/loss_with_zero_ablation': tensor(0.3733), 29 | 'validate/reconstruction_score/loss_with_zero_ablation/component_mean': tensor(0.4424), 30 | }), 31 | dict({ 32 | 'component_0/validate/reconstruction_score': tensor(3.5422), 33 | 'component_1/validate/reconstruction_score': tensor(0.9767), 34 | 'component_2/validate/reconstruction_score': tensor(1.9278), 35 | 'component_3/validate/reconstruction_score': tensor(-12.3338), 36 | 'component_4/validate/reconstruction_score': tensor(0.4468), 37 | 'component_5/validate/reconstruction_score': tensor(-0.0081), 38 | 'validate/reconstruction_score/component_mean': tensor(-0.9081), 39 | }), 40 | ]) 41 | # --- 42 | -------------------------------------------------------------------------------- /sparse_autoencoder/metrics/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | """Metric wrappers.""" 2 | -------------------------------------------------------------------------------- /sparse_autoencoder/metrics/wrappers/classwise.py: -------------------------------------------------------------------------------- 1 | """Classwise metrics wrapper.""" 2 | import torch 3 | from torch import Tensor 4 | from torchmetrics import ClasswiseWrapper, Metric 5 | 6 | 7 | class ClasswiseWrapperWithMean(ClasswiseWrapper): 8 | """Classwise wrapper with mean. 9 | 10 | This metric works together with classification metrics that returns multiple values (one value 11 | per class) such that label information can be automatically included in the output. It extends 12 | the standard torchmetrics wrapper that does this, adding in an additional mean value (across all 13 | classes). 14 | """ 15 | 16 | _prefix: str 17 | 18 | labels: list[str] 19 | 20 | def __init__( 21 | self, 22 | metric: Metric, 23 | component_names: list[str] | None = None, 24 | prefix: str | None = None, 25 | ) -> None: 26 | """Initialise the classwise wrapper. 27 | 28 | Args: 29 | metric: Metric to wrap. 30 | component_names: Component names. 31 | prefix: Prefix for the name (will replace the default of the class name). 32 | """ 33 | super().__init__(metric, component_names, prefix) 34 | 35 | # Default prefix 36 | if not self._prefix: 37 | self._prefix = f"{self.metric.__class__.__name__.lower()}" 38 | 39 | def _convert(self, x: Tensor) -> dict[str, Tensor]: 40 | """Convert the input tensor to a dictionary of metrics. 41 | 42 | Args: 43 | x: The input tensor. 44 | 45 | Returns: 46 | A dictionary of metric results. 
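            For example (illustrative values), with prefix "l0" and component labels
            ["mlp_0", "mlp_1"], the returned keys are "l0/mlp_0", "l0/mlp_1" and "l0/mean".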
47 | """ 48 | # Add a component axis if not present (as Metric squeezes it out) 49 | if x.ndim == 0: 50 | x = x.unsqueeze(dim=0) 51 | 52 | # Same splitting as the original classwise wrapper 53 | res = {f"{self._prefix}/{lab}": val for lab, val in zip(self.labels, x)} 54 | 55 | # Add in the mean 56 | res[f"{self._prefix}/mean"] = x.mean(0, dtype=torch.float) 57 | 58 | return res 59 | -------------------------------------------------------------------------------- /sparse_autoencoder/metrics/wrappers/tests/test_classwise.py: -------------------------------------------------------------------------------- 1 | """Test the classwise wrapper.""" 2 | import pytest 3 | import torch 4 | 5 | from sparse_autoencoder.metrics.train.feature_density import FeatureDensityMetric 6 | from sparse_autoencoder.metrics.train.l0_norm import L0NormMetric 7 | from sparse_autoencoder.metrics.wrappers.classwise import ClasswiseWrapperWithMean 8 | 9 | 10 | @pytest.mark.parametrize( 11 | ("num_components"), 12 | [ 13 | pytest.param(1, id="Single component"), 14 | pytest.param(2, id="Multiple components"), 15 | ], 16 | ) 17 | def test_feature_density_classwise_wrapper(num_components: int) -> None: 18 | """Test the classwise wrapper.""" 19 | metric = FeatureDensityMetric(3, num_components) 20 | component_names = [f"mlp_{n}" for n in range(num_components)] 21 | wrapped_metric = ClasswiseWrapperWithMean(metric, component_names, prefix="feature_density") 22 | 23 | learned_activations = torch.tensor([[[1.0, 1.0, 1.0]] * num_components]) 24 | expected_output = torch.tensor([[1.0, 1.0, 1.0]]) 25 | res = wrapped_metric.forward(learned_activations=learned_activations) 26 | 27 | for component in component_names: 28 | assert torch.allclose(res[f"feature_density/{component}"], expected_output) 29 | 30 | assert torch.allclose(res["feature_density/mean"], expected_output.mean(0)) 31 | 32 | 33 | @pytest.mark.parametrize( 34 | ("num_components"), 35 | [ 36 | pytest.param(1, id="Single component"), 37 | pytest.param(2, id="Multiple components"), 38 | ], 39 | ) 40 | def test_l0_norm_classwise_wrapper(num_components: int) -> None: 41 | """Test the classwise wrapper.""" 42 | metric = L0NormMetric(num_components) 43 | component_names = [f"mlp_{n}" for n in range(num_components)] 44 | wrapped_metric = ClasswiseWrapperWithMean(metric, component_names, prefix="l0") 45 | 46 | learned_activations = torch.tensor([[[1.0, 0.0, 1.0]] * num_components]) 47 | expected_output = torch.tensor([2.0]) 48 | res = wrapped_metric.forward(learned_activations=learned_activations) 49 | 50 | for component in component_names: 51 | assert torch.allclose(res[f"l0/{component}"], expected_output) 52 | 53 | assert torch.allclose(res["l0/mean"], expected_output.mean(0)) 54 | -------------------------------------------------------------------------------- /sparse_autoencoder/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | """Optimizers for Sparse Autoencoders.""" 2 | -------------------------------------------------------------------------------- /sparse_autoencoder/optimizer/tests/test_adam_with_reset.py: -------------------------------------------------------------------------------- 1 | """Tests for AdamWithReset optimizer.""" 2 | import pytest 3 | import torch 4 | 5 | from sparse_autoencoder.autoencoder.model import SparseAutoencoder, SparseAutoencoderConfig 6 | from sparse_autoencoder.optimizer.adam_with_reset import AdamWithReset 7 | 8 | 9 | @pytest.fixture() 10 | def model_and_optimizer() -> 
tuple[torch.nn.Module, AdamWithReset]: 11 | """Model and optimizer fixture.""" 12 | torch.random.manual_seed(0) 13 | model = SparseAutoencoder(SparseAutoencoderConfig(n_input_features=5, n_learned_features=10)) 14 | optimizer = AdamWithReset( 15 | model.parameters(), 16 | named_parameters=model.named_parameters(), 17 | lr=0.0001, 18 | has_components_dim=False, 19 | ) 20 | 21 | # Initialise adam state by doing some steps 22 | for _ in range(3): 23 | source_activations = torch.rand((10, 5)) 24 | optimizer.zero_grad() 25 | _, decoded_activations = model.forward(source_activations) 26 | dummy_loss = torch.nn.functional.mse_loss(decoded_activations, source_activations) 27 | dummy_loss.backward() 28 | optimizer.step() 29 | 30 | # Force all state values to be non-zero 31 | optimizer.state[model.encoder.weight]["exp_avg"] = ( 32 | optimizer.state[model.encoder.weight]["exp_avg"] + 1.0 33 | ) 34 | optimizer.state[model.encoder.weight]["exp_avg_sq"] = ( 35 | optimizer.state[model.encoder.weight]["exp_avg_sq"] + 1.0 36 | ) 37 | 38 | return model, optimizer 39 | 40 | 41 | def test_initialization(model_and_optimizer: tuple[torch.nn.Module, AdamWithReset]) -> None: 42 | """Test initialization of AdamWithReset optimizer.""" 43 | model, optimizer = model_and_optimizer 44 | assert len(optimizer.parameter_names) == len(list(model.named_parameters())) 45 | 46 | 47 | def test_reset_state_all_parameters( 48 | model_and_optimizer: tuple[torch.nn.Module, AdamWithReset], 49 | ) -> None: 50 | """Test reset_state_all_parameters method.""" 51 | _, optimizer = model_and_optimizer 52 | optimizer.reset_state_all_parameters() 53 | 54 | for group in optimizer.param_groups: 55 | for parameter in group["params"]: 56 | # Get the state 57 | parameter_state = optimizer.state[parameter] 58 | for state_name in parameter_state: 59 | if state_name in ["exp_avg", "exp_avg_sq", "max_exp_avg_sq"]: 60 | # Check all state values are reset to zero 61 | state = parameter_state[state_name] 62 | assert torch.all(state == 0) 63 | 64 | 65 | def test_reset_neurons_state(model_and_optimizer: tuple[torch.nn.Module, AdamWithReset]) -> None: 66 | """Test reset_neurons_state method.""" 67 | model, optimizer = model_and_optimizer 68 | optimizer.reset_neurons_state(model.encoder.weight, torch.tensor([1]), axis=0) 69 | 70 | res = optimizer.state[model.encoder.weight] 71 | 72 | assert torch.all(res["exp_avg"][1, :] == 0) 73 | assert not torch.any(res["exp_avg"][2, :] == 0) 74 | assert not torch.any(res["exp_avg"][2:, 1] == 0) 75 | 76 | 77 | def test_reset_neurons_state_no_dead_neurons( 78 | model_and_optimizer: tuple[torch.nn.Module, AdamWithReset], 79 | ) -> None: 80 | """Test reset_neurons_state method with 0 dead neurons.""" 81 | model, optimizer = model_and_optimizer 82 | 83 | res = optimizer.state[model.encoder.weight] 84 | 85 | # Example usage 86 | optimizer.reset_neurons_state(model.encoder.weight, torch.tensor([], dtype=torch.int64), axis=0) 87 | 88 | res = optimizer.state[model.encoder.weight] 89 | 90 | assert not torch.any(res["exp_avg"] == 0) 91 | assert not torch.any(res["exp_avg_sq"] == 0) 92 | -------------------------------------------------------------------------------- /sparse_autoencoder/source_data/__init__.py: -------------------------------------------------------------------------------- 1 | """Source Data.""" 2 | -------------------------------------------------------------------------------- /sparse_autoencoder/source_data/mock_dataset.py: -------------------------------------------------------------------------------- 1 | 
"""Mock dataset. 2 | 3 | For use with tests and simple examples. 4 | """ 5 | from collections.abc import Iterator 6 | from typing import Literal, final 7 | 8 | from datasets import IterableDataset 9 | from jaxtyping import Int 10 | from pydantic import PositiveInt, validate_call 11 | import torch 12 | from torch import Tensor 13 | from transformers import PreTrainedTokenizerFast 14 | 15 | from sparse_autoencoder.source_data.abstract_dataset import ( 16 | SourceDataset, 17 | TokenizedPrompts, 18 | TorchTokenizedPrompts, 19 | ) 20 | 21 | 22 | class ConsecutiveIntHuggingFaceDataset(IterableDataset): 23 | """Consecutive integers Hugging Face dataset for testing. 24 | 25 | Creates a dataset where the first item is [0,1,2...], and the second item is [1,2,3...] and so 26 | on. 27 | """ 28 | 29 | _data: Int[Tensor, "items context_size"] 30 | """Generated data.""" 31 | 32 | _length: int 33 | """Size of the dataset.""" 34 | 35 | _format: Literal["torch", "list"] = "list" 36 | """Format of the data.""" 37 | 38 | def create_data(self, n_items: int, context_size: int) -> Int[Tensor, "items context_size"]: 39 | """Create the data. 40 | 41 | Args: 42 | n_items: The number of items in the dataset. 43 | context_size: The number of tokens in the context window. 44 | 45 | Returns: 46 | The generated data. 47 | """ 48 | rows = torch.arange(n_items).unsqueeze(1) 49 | columns = torch.arange(context_size).unsqueeze(0) 50 | return rows + columns 51 | 52 | def __init__(self, context_size: int, vocab_size: int = 50_000, n_items: int = 10_000) -> None: 53 | """Initialize the mock HF dataset. 54 | 55 | Args: 56 | context_size: The number of tokens in the context window 57 | vocab_size: The size of the vocabulary to use. 58 | n_items: The number of items in the dataset. 59 | 60 | Raises: 61 | ValueError: If more items are requested than we can create with the vocab size (given 62 | that each item is a consecutive list of integers and unique). 63 | """ 64 | self._length = n_items 65 | 66 | # Check we can create the data 67 | if n_items + context_size > vocab_size: 68 | error_message = ( 69 | f"n_items ({n_items}) + context_size ({context_size}) must be less than " 70 | f"vocab_size ({vocab_size})" 71 | ) 72 | raise ValueError(error_message) 73 | 74 | # Initialise the data 75 | self._data = self.create_data(n_items, context_size) 76 | 77 | def __iter__(self) -> Iterator: # type: ignore (HF typing is incorrect) 78 | """Initialize the iterator. 79 | 80 | Returns: 81 | Iterator. 82 | """ 83 | self._index = 0 84 | return self 85 | 86 | def __next__(self) -> TokenizedPrompts | TorchTokenizedPrompts: 87 | """Return the next item in the dataset. 88 | 89 | Returns: 90 | TokenizedPrompts: The next item in the dataset. 91 | 92 | Raises: 93 | StopIteration: If the end of the dataset is reached. 
94 | """ 95 | if self._index < self._length: 96 | item = self[self._index] 97 | self._index += 1 98 | return item 99 | 100 | raise StopIteration 101 | 102 | def __len__(self) -> int: 103 | """Len Dunder Method.""" 104 | return self._length 105 | 106 | def __getitem__(self, index: int) -> TokenizedPrompts | TorchTokenizedPrompts: 107 | """Get Item.""" 108 | item = self._data[index] 109 | 110 | if self._format == "torch": 111 | return {"input_ids": item} 112 | 113 | return {"input_ids": item.tolist()} 114 | 115 | def with_format( # type: ignore (only support 2 types) 116 | self, 117 | type: Literal["torch", "list"], # noqa: A002 118 | ) -> "ConsecutiveIntHuggingFaceDataset": 119 | """With Format.""" 120 | self._format = type 121 | return self 122 | 123 | 124 | @final 125 | class MockDataset(SourceDataset[TokenizedPrompts]): 126 | """Mock dataset for testing. 127 | 128 | For use with tests and simple examples. 129 | """ 130 | 131 | tokenizer: PreTrainedTokenizerFast 132 | 133 | def preprocess( 134 | self, 135 | source_batch: TokenizedPrompts, 136 | *, 137 | context_size: int, # noqa: ARG002 138 | ) -> TokenizedPrompts: 139 | """Preprocess a batch of prompts.""" 140 | # Nothing to do here 141 | return source_batch 142 | 143 | @validate_call 144 | def __init__( 145 | self, 146 | context_size: PositiveInt = 250, 147 | buffer_size: PositiveInt = 1000, # noqa: ARG002 148 | preprocess_batch_size: PositiveInt = 1000, # noqa: ARG002 149 | dataset_path: str = "dummy", # noqa: ARG002 150 | dataset_split: str = "train", # noqa: ARG002 151 | ): 152 | """Initialize the Random Int Dummy dataset. 153 | 154 | Example: 155 | >>> data = MockDataset() 156 | >>> first_item = next(iter(data)) 157 | >>> len(first_item["input_ids"]) 158 | 250 159 | 160 | Args: 161 | context_size: The context size to use when returning a list of tokenized prompts. 162 | *Towards Monosemanticity: Decomposing Language Models With Dictionary Learning* used 163 | a context size of 250. 164 | buffer_size: The buffer size to use when shuffling the dataset. As the dataset is 165 | streamed, this just pre-downloads at least `buffer_size` items and then shuffles 166 | just that buffer. Note that the generated activations should also be shuffled before 167 | training the sparse autoencoder, so a large buffer may not be strictly necessary 168 | here. Note also that this is the number of items in the dataset (e.g. number of 169 | prompts) and is typically significantly less than the number of tokenized prompts 170 | once the preprocessing function has been applied. 171 | preprocess_batch_size: The batch size to use just for preprocessing the dataset (e.g. 172 | tokenizing prompts). 173 | dataset_path: The path to the dataset on Hugging Face. 174 | dataset_split: Dataset split (e.g. `train`). 175 | """ 176 | self.dataset = ConsecutiveIntHuggingFaceDataset(context_size=context_size) # type: ignore 177 | self.context_size = context_size 178 | -------------------------------------------------------------------------------- /sparse_autoencoder/source_data/pretokenized_dataset.py: -------------------------------------------------------------------------------- 1 | """Pre-Tokenized Dataset from Hugging Face. 
2 | 3 | PreTokenizedDataset should work with any of the following tokenized datasets: 4 | - NeelNanda/pile-small-tokenized-2b 5 | - NeelNanda/pile-tokenized-10b 6 | - NeelNanda/openwebtext-tokenized-9b 7 | - NeelNanda/c4-tokenized-2b 8 | - NeelNanda/code-tokenized 9 | - NeelNanda/c4-code-tokenized-2b 10 | - NeelNanda/pile-old-tokenized-2b 11 | - alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2 12 | 13 | """ 14 | from collections.abc import Mapping, Sequence 15 | from typing import final 16 | 17 | from pydantic import PositiveInt, validate_call 18 | 19 | from sparse_autoencoder.source_data.abstract_dataset import SourceDataset, TokenizedPrompts 20 | 21 | 22 | @final 23 | class PreTokenizedDataset(SourceDataset[dict]): 24 | """General Pre-Tokenized Dataset from Hugging Face. 25 | 26 | Can be used for various datasets available on Hugging Face. 27 | """ 28 | 29 | def preprocess( 30 | self, 31 | source_batch: dict, 32 | *, 33 | context_size: int, 34 | ) -> TokenizedPrompts: 35 | """Preprocess a batch of prompts. 36 | 37 | The method splits each pre-tokenized item based on the context size. 38 | 39 | Args: 40 | source_batch: A batch of source data. 41 | context_size: The context size to use for tokenized prompts. 42 | 43 | Returns: 44 | Tokenized prompts. 45 | 46 | Raises: 47 | ValueError: If the context size is larger than the tokenized prompt size. 48 | """ 49 | tokenized_prompts: list[list[int]] = source_batch[self._dataset_column_name] 50 | 51 | # Check the context size is not too large 52 | if context_size > len(tokenized_prompts[0]): 53 | error_message = ( 54 | f"The context size ({context_size}) is larger than the " 55 | f"tokenized prompt size ({len(tokenized_prompts[0])})." 56 | ) 57 | raise ValueError(error_message) 58 | 59 | # Chunk each tokenized prompt into blocks of context_size, 60 | # discarding the last block if too small. 61 | context_size_prompts = [] 62 | for encoding in tokenized_prompts: 63 | chunks = [ 64 | encoding[i : i + context_size] 65 | for i in range(0, len(encoding), context_size) 66 | if len(encoding[i : i + context_size]) == context_size 67 | ] 68 | context_size_prompts.extend(chunks) 69 | 70 | return {"input_ids": context_size_prompts} 71 | 72 | @validate_call 73 | def __init__( 74 | self, 75 | dataset_path: str, 76 | context_size: PositiveInt = 256, 77 | buffer_size: PositiveInt = 1000, 78 | dataset_dir: str | None = None, 79 | dataset_files: str | Sequence[str] | Mapping[str, str | Sequence[str]] | None = None, 80 | dataset_split: str = "train", 81 | dataset_column_name: str = "input_ids", 82 | preprocess_batch_size: PositiveInt = 1000, 83 | *, 84 | pre_download: bool = False, 85 | ): 86 | """Initialize a pre-tokenized dataset from Hugging Face. 87 | 88 | Args: 89 | dataset_path: The path to the dataset on Hugging Face (e.g. 90 | `alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2). 91 | context_size: The context size for tokenized prompts. 92 | buffer_size: The buffer size to use when shuffling the dataset when streaming. When 93 | streaming a dataset, this just pre-downloads at least `buffer_size` items and then 94 | shuffles just that buffer. Note that the generated activations should also be 95 | shuffled before training the sparse autoencoder, so a large buffer may not be 96 | strictly necessary here. Note also that this is the number of items in the dataset 97 | (e.g. number of prompts) and is typically significantly less than the number of 98 | tokenized prompts once the preprocessing function has been applied. 
99 | dataset_dir: Defining the `data_dir` of the dataset configuration. 100 | dataset_files: Path(s) to source data file(s). 101 | dataset_split: Dataset split (e.g. `train`). 102 | dataset_column_name: The column name for the tokenized prompts. 103 | preprocess_batch_size: The batch size to use just for preprocessing the dataset (e.g. 104 | tokenizing prompts). 105 | pre_download: Whether to pre-download the whole dataset. 106 | """ 107 | super().__init__( 108 | buffer_size=buffer_size, 109 | context_size=context_size, 110 | dataset_dir=dataset_dir, 111 | dataset_files=dataset_files, 112 | dataset_path=dataset_path, 113 | dataset_split=dataset_split, 114 | dataset_column_name=dataset_column_name, 115 | pre_download=pre_download, 116 | preprocess_batch_size=preprocess_batch_size, 117 | ) 118 | -------------------------------------------------------------------------------- /sparse_autoencoder/source_data/tests/test_abstract_dataset.py: -------------------------------------------------------------------------------- 1 | """Test the abstract dataset.""" 2 | 3 | import pytest 4 | 5 | from sparse_autoencoder.source_data.abstract_dataset import SourceDataset 6 | from sparse_autoencoder.source_data.mock_dataset import MockDataset 7 | 8 | 9 | @pytest.fixture() 10 | def mock_dataset() -> MockDataset: 11 | """Fixture to create a default ConsecutiveIntHuggingFaceDataset for testing. 12 | 13 | Returns: 14 | ConsecutiveIntHuggingFaceDataset: An instance of the dataset for testing. 15 | """ 16 | return MockDataset(context_size=10, buffer_size=100) 17 | 18 | 19 | def test_extended_dataset_initialization(mock_dataset: MockDataset) -> None: 20 | """Test the initialization of the extended dataset.""" 21 | assert mock_dataset is not None 22 | assert isinstance(mock_dataset, SourceDataset) 23 | 24 | 25 | def test_extended_dataset_iterator(mock_dataset: MockDataset) -> None: 26 | """Test the iterator of the extended dataset.""" 27 | iterator = iter(mock_dataset) 28 | assert iterator is not None 29 | 30 | 31 | def test_get_dataloader(mock_dataset: MockDataset) -> None: 32 | """Test the get_dataloader method of the extended dataset.""" 33 | batch_size = 3 34 | dataloader = mock_dataset.get_dataloader(batch_size=batch_size) 35 | first_item = next(iter(dataloader))["input_ids"] 36 | assert first_item.shape[0] == batch_size 37 | -------------------------------------------------------------------------------- /sparse_autoencoder/source_data/tests/test_mock_dataset.py: -------------------------------------------------------------------------------- 1 | """Tests for the mock dataset.""" 2 | import pytest 3 | import torch 4 | from torch import Tensor 5 | 6 | from sparse_autoencoder.source_data.mock_dataset import ConsecutiveIntHuggingFaceDataset 7 | 8 | 9 | class TestConsecutiveIntHuggingFaceDataset: 10 | """Tests for the ConsecutiveIntHuggingFaceDataset.""" 11 | 12 | @pytest.fixture(scope="class") 13 | def create_dataset(self) -> ConsecutiveIntHuggingFaceDataset: 14 | """Fixture to create a default ConsecutiveIntHuggingFaceDataset for testing. 15 | 16 | Returns: 17 | ConsecutiveIntHuggingFaceDataset: An instance of the dataset for testing. 
18 | """ 19 | return ConsecutiveIntHuggingFaceDataset(context_size=10, vocab_size=1000, n_items=100) 20 | 21 | def test_dataset_initialization_failure(self) -> None: 22 | """Test invalid initialization failure of the ConsecutiveIntHuggingFaceDataset.""" 23 | with pytest.raises( 24 | ValueError, 25 | match=r"n_items \(\d+\) \+ context_size \(\d+\) must be less than vocab_size \(\d+\)", 26 | ): 27 | ConsecutiveIntHuggingFaceDataset(context_size=40, vocab_size=50, n_items=20) 28 | 29 | def test_dataset_len(self, create_dataset: ConsecutiveIntHuggingFaceDataset) -> None: 30 | """Test the __len__ method of the dataset. 31 | 32 | Args: 33 | create_dataset: Fixture to create a test dataset instance. 34 | """ 35 | expected_length = 100 36 | assert len(create_dataset) == expected_length, "Dataset length is not as expected." 37 | 38 | def test_dataset_getitem(self, create_dataset: ConsecutiveIntHuggingFaceDataset) -> None: 39 | """Test the __getitem__ method of the dataset. 40 | 41 | Args: 42 | create_dataset: Fixture to create a test dataset instance. 43 | """ 44 | item = create_dataset[0] 45 | assert isinstance(item, dict), "Item should be a dictionary." 46 | assert "input_ids" in item, "Item should have 'input_ids' key." 47 | assert isinstance(item["input_ids"], list), "input_ids should be a list." 48 | 49 | def test_create_data(self, create_dataset: ConsecutiveIntHuggingFaceDataset) -> None: 50 | """Test the create_data method of the dataset. 51 | 52 | Args: 53 | create_dataset: Fixture to create a test dataset instance. 54 | """ 55 | data: Tensor = create_dataset.create_data(n_items=10, context_size=5) 56 | assert data.shape == (10, 5), "Data shape is not as expected." 57 | 58 | def test_dataset_iteration(self, create_dataset: ConsecutiveIntHuggingFaceDataset) -> None: 59 | """Test the iteration functionality of the dataset. 60 | 61 | Args: 62 | create_dataset: Fixture to create a test dataset instance. 63 | """ 64 | items = [item["input_ids"] for item in create_dataset] 65 | 66 | # Check they are all unique 67 | items_tensor = torch.tensor(items) 68 | unique_items = torch.unique(items_tensor, dim=0) 69 | assert items_tensor.shape == unique_items.shape, "Items are not unique." 
70 | -------------------------------------------------------------------------------- /sparse_autoencoder/source_data/tests/test_pretokenized_dataset.py: -------------------------------------------------------------------------------- 1 | """Tests for General Pre-Tokenized Dataset.""" 2 | import pytest 3 | 4 | from sparse_autoencoder.source_data.pretokenized_dataset import PreTokenizedDataset 5 | 6 | 7 | TEST_DATASET = "alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2" 8 | 9 | 10 | @pytest.mark.integration_test() 11 | @pytest.mark.parametrize("context_size", [128, 256]) 12 | def test_tokenized_prompts_correct_size(context_size: int) -> None: 13 | """Test that the tokenized prompts have the correct context size.""" 14 | data = PreTokenizedDataset(dataset_path=TEST_DATASET, context_size=context_size) 15 | 16 | # Check the first k items 17 | iterable = iter(data.dataset) 18 | for _ in range(2): 19 | item = next(iterable) 20 | assert len(item["input_ids"]) == context_size 21 | 22 | # Check the tokens are integers 23 | for token in item["input_ids"]: 24 | assert isinstance(token, int) 25 | 26 | 27 | @pytest.mark.integration_test() 28 | def test_fails_context_size_too_large() -> None: 29 | """Test that it fails if the context size is set as larger than the source dataset on HF.""" 30 | data = PreTokenizedDataset(dataset_path=TEST_DATASET, context_size=512) 31 | with pytest.raises(ValueError, match=r"larger than the tokenized prompt size"): 32 | next(iter(data)) 33 | -------------------------------------------------------------------------------- /sparse_autoencoder/source_data/tests/test_text_dataset.py: -------------------------------------------------------------------------------- 1 | """Pile Uncopyrighted Dataset Tests.""" 2 | import pytest 3 | from transformers import GPT2Tokenizer 4 | 5 | from sparse_autoencoder.source_data.text_dataset import TextDataset 6 | 7 | 8 | @pytest.mark.integration_test() 9 | @pytest.mark.parametrize("context_size", [50, 250]) 10 | def test_tokenized_prompts_correct_size(context_size: int) -> None: 11 | """Test that the tokenized prompts have the correct context size.""" 12 | tokenizer = GPT2Tokenizer.from_pretrained("gpt2") 13 | 14 | data = TextDataset( 15 | tokenizer=tokenizer, context_size=context_size, dataset_path="monology/pile-uncopyrighted" 16 | ) 17 | 18 | # Check the first 100 items 19 | iterable = iter(data.dataset) 20 | for _ in range(100): 21 | item = next(iterable) 22 | assert len(item["input_ids"]) == context_size 23 | 24 | # Check the tokens are integers 25 | for token in item["input_ids"]: 26 | assert isinstance(token, int) 27 | 28 | 29 | @pytest.mark.integration_test() 30 | def test_dataloader_correct_size_items() -> None: 31 | """Test the dataloader returns the correct number & sized items.""" 32 | batch_size = 10 33 | context_size = 250 34 | tokenizer = GPT2Tokenizer.from_pretrained("gpt2") 35 | data = TextDataset( 36 | tokenizer=tokenizer, context_size=context_size, dataset_path="monology/pile-uncopyrighted" 37 | ) 38 | dataloader = data.get_dataloader(batch_size=batch_size) 39 | 40 | checks = 100 41 | for item in dataloader: 42 | checks -= 1 43 | if checks == 0: 44 | break 45 | 46 | tokens = item["input_ids"] 47 | assert tokens.shape[0] == batch_size 48 | assert tokens.shape[1] == context_size 49 | -------------------------------------------------------------------------------- /sparse_autoencoder/source_data/text_dataset.py: -------------------------------------------------------------------------------- 1 | """Generic Text Dataset 
Module for Hugging Face Datasets. 2 | 3 | TextDataset should work with the following datasets: 4 | - monology/pile-uncopyrighted 5 | - the_pile_openwebtext2 6 | - roneneldan/TinyStories 7 | """ 8 | from collections.abc import Mapping, Sequence 9 | from typing import TypedDict, final 10 | 11 | from datasets import IterableDataset 12 | from pydantic import PositiveInt, validate_call 13 | from transformers import PreTrainedTokenizerBase 14 | 15 | from sparse_autoencoder.source_data.abstract_dataset import SourceDataset, TokenizedPrompts 16 | 17 | 18 | class GenericTextDataBatch(TypedDict): 19 | """Generic Text Dataset Batch. 20 | 21 | Assumes the dataset provides a 'text' field with a list of strings. 22 | """ 23 | 24 | text: list[str] 25 | meta: list[dict[str, dict[str, str]]] # Optional, depending on the dataset structure. 26 | 27 | 28 | @final 29 | class TextDataset(SourceDataset[GenericTextDataBatch]): 30 | """Generic Text Dataset for any text-based dataset from Hugging Face.""" 31 | 32 | tokenizer: PreTrainedTokenizerBase 33 | 34 | def preprocess( 35 | self, 36 | source_batch: GenericTextDataBatch, 37 | *, 38 | context_size: int, 39 | ) -> TokenizedPrompts: 40 | """Preprocess a batch of prompts. 41 | 42 | Tokenizes and chunks text data into lists of tokenized prompts with specified context size. 43 | 44 | Args: 45 | source_batch: A batch of source data, including 'text' with a list of strings. 46 | context_size: Context size for tokenized prompts. 47 | 48 | Returns: 49 | Tokenized prompts. 50 | """ 51 | prompts: list[str] = source_batch["text"] 52 | 53 | tokenized_prompts = self.tokenizer(prompts, truncation=True, padding=False) 54 | 55 | # Chunk each tokenized prompt into blocks of context_size, discarding incomplete blocks. 56 | context_size_prompts = [] 57 | for encoding in list(tokenized_prompts[self._dataset_column_name]): # type: ignore 58 | chunks = [ 59 | encoding[i : i + context_size] 60 | for i in range(0, len(encoding), context_size) 61 | if len(encoding[i : i + context_size]) == context_size 62 | ] 63 | context_size_prompts.extend(chunks) 64 | 65 | return {"input_ids": context_size_prompts} 66 | 67 | @validate_call(config={"arbitrary_types_allowed": True}) 68 | def __init__( 69 | self, 70 | dataset_path: str, 71 | tokenizer: PreTrainedTokenizerBase, 72 | buffer_size: PositiveInt = 1000, 73 | context_size: PositiveInt = 256, 74 | dataset_dir: str | None = None, 75 | dataset_files: str | Sequence[str] | Mapping[str, str | Sequence[str]] | None = None, 76 | dataset_split: str = "train", 77 | dataset_column_name: str = "input_ids", 78 | n_processes_preprocessing: PositiveInt | None = None, 79 | preprocess_batch_size: PositiveInt = 1000, 80 | *, 81 | pre_download: bool = False, 82 | ): 83 | """Initialize a generic text dataset from Hugging Face. 84 | 85 | Args: 86 | dataset_path: Path to the dataset on Hugging Face (e.g. `'monology/pile-uncopyrighted'`). 87 | tokenizer: Tokenizer to process text data. 88 | buffer_size: The buffer size to use when shuffling the dataset when streaming. When 89 | streaming a dataset, this just pre-downloads at least `buffer_size` items and then 90 | shuffles just that buffer. Note that the generated activations should also be 91 | shuffled before training the sparse autoencoder, so a large buffer may not be 92 | strictly necessary here. Note also that this is the number of items in the dataset 93 | (e.g. 
number of prompts) and is typically significantly less than the number of 94 | tokenized prompts once the preprocessing function has been applied. 95 | context_size: The context size to use when returning a list of tokenized prompts. 96 | *Towards Monosemanticity: Decomposing Language Models With Dictionary Learning* used 97 | a context size of 250. 98 | dataset_dir: The `data_dir` of the dataset configuration. 99 | dataset_files: Path(s) to source data file(s). 100 | dataset_split: Dataset split (e.g., 'train'). 101 | dataset_column_name: The column name for the prompts. 102 | n_processes_preprocessing: Number of processes to use for preprocessing. 103 | preprocess_batch_size: Batch size for preprocessing (tokenizing prompts). 104 | pre_download: Whether to pre-download the whole dataset. 105 | """ 106 | self.tokenizer = tokenizer 107 | 108 | super().__init__( 109 | buffer_size=buffer_size, 110 | context_size=context_size, 111 | dataset_dir=dataset_dir, 112 | dataset_files=dataset_files, 113 | dataset_path=dataset_path, 114 | dataset_split=dataset_split, 115 | dataset_column_name=dataset_column_name, 116 | n_processes_preprocessing=n_processes_preprocessing, 117 | pre_download=pre_download, 118 | preprocess_batch_size=preprocess_batch_size, 119 | ) 120 | 121 | @validate_call 122 | def push_to_hugging_face_hub( 123 | self, 124 | repo_id: str, 125 | commit_message: str = "Upload preprocessed dataset using sparse_autoencoder.", 126 | max_shard_size: str | None = None, 127 | n_shards: PositiveInt = 64, 128 | revision: str = "main", 129 | *, 130 | private: bool = False, 131 | ) -> None: 132 | """Share preprocessed dataset to Hugging Face hub. 133 | 134 | Motivation: 135 | Pre-processing a dataset can be time-consuming, so it is useful to be able to share the 136 | pre-processed dataset with others. This function allows you to do that by pushing the 137 | pre-processed dataset to the Hugging Face hub. 138 | 139 | Warning: 140 | You must be logged into HuggingFace (e.g. with `huggingface-cli login` from the terminal) 141 | to use this. 142 | 143 | Warning: 144 | This will only work if the dataset is not streamed (i.e. if `pre_download=True` when 145 | initializing the dataset). 146 | 147 | Args: 148 | repo_id: Hugging Face repo ID to save the dataset to (e.g. `username/dataset_name`). 149 | commit_message: Commit message. 150 | max_shard_size: Maximum shard size (e.g. `'500MB'`). Should not be set if `n_shards` 151 | is set. 152 | n_shards: Number of shards to split the dataset into. A high number is recommended 153 | here to allow for flexible distributed training of SAEs across nodes (where e.g. 154 | each node fetches its own shard). 155 | revision: Branch to push to. 156 | private: Whether to save the dataset privately. 157 | 158 | Raises: 159 | TypeError: If the dataset is streamed. 160 | """ 161 | if isinstance(self.dataset, IterableDataset): 162 | error_message = ( 163 | "Cannot share a streamed dataset to Hugging Face. " 164 | "Please use `pre_download=True` when initializing the dataset." 
165 | ) 166 | raise TypeError(error_message) 167 | 168 | self.dataset.push_to_hub( 169 | repo_id=repo_id, 170 | commit_message=commit_message, 171 | max_shard_size=max_shard_size, 172 | num_shards=n_shards, 173 | private=private, 174 | revision=revision, 175 | ) 176 | -------------------------------------------------------------------------------- /sparse_autoencoder/source_model/__init__.py: -------------------------------------------------------------------------------- 1 | """Source Model.""" 2 | -------------------------------------------------------------------------------- /sparse_autoencoder/source_model/replace_activations_hook.py: -------------------------------------------------------------------------------- 1 | """Replace activations hook.""" 2 | from typing import TYPE_CHECKING 3 | 4 | from jaxtyping import Float 5 | from torch import Tensor 6 | from torch.nn import Module 7 | from torch.nn.parallel import DataParallel 8 | from transformer_lens.hook_points import HookPoint 9 | 10 | from sparse_autoencoder.autoencoder.lightning import LitSparseAutoencoder 11 | from sparse_autoencoder.autoencoder.model import SparseAutoencoder 12 | 13 | 14 | if TYPE_CHECKING: 15 | from sparse_autoencoder.tensor_types import Axis 16 | 17 | 18 | def replace_activations_hook( 19 | value: Tensor, 20 | hook: HookPoint, # noqa: ARG001 21 | sparse_autoencoder: SparseAutoencoder 22 | | DataParallel[SparseAutoencoder] 23 | | LitSparseAutoencoder 24 | | Module, 25 | component_idx: int | None = None, 26 | n_components: int | None = None, 27 | ) -> Tensor: 28 | """Replace activations hook. 29 | 30 | This should be pre-initialised with `functools.partial`. 31 | 32 | Args: 33 | value: The activations to replace. 34 | hook: The hook point. 35 | sparse_autoencoder: The sparse autoencoder. 36 | component_idx: The component index to replace the activations with, if just replacing 37 | activations for a single component. Requires the model to have a component axis. 38 | n_components: The number of components that the SAE is trained on. 39 | 40 | Returns: 41 | Replaced activations. 42 | 43 | Raises: 44 | RuntimeError: If `component_idx` is specified but `n_components` is not set. 45 | """ 46 | # Squash all leading dimensions into a single batch axis, giving (items, input_output_feature) 47 | original_shape = value.shape 48 | 49 | squashed_value: Float[Tensor, Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE)] = value.view( 50 | -1, value.size(-1) 51 | ) 52 | 53 | if component_idx is not None: 54 | if n_components is None: 55 | error_message = "The number of model components must be set if component_idx is set." 56 | raise RuntimeError(error_message) 57 | 58 | # The approach here is to run a forward pass with dummy values for all components other than 59 | # the one we want to replace. This is done by expanding the inputs to the SAE for a specific 60 | # component across all components. We then simply discard the activations for all other 61 | # components. 
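# Illustrative shapes (hypothetical sizes, for clarity): with 6 squashed activation vectors
# of width 512 and n_components=3, `squashed_value` is (6, 512); `expanded` below is
# (6, 3, 512) (the same vectors repeated along the component axis); the SAE output is
# (6, 3, 512); and indexing with `component_idx` keeps the (6, 512) slice, which is then
# reshaped back to the original activation shape.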
62 | expanded_shape = [ 63 | squashed_value.shape[0], 64 | n_components, 65 | squashed_value.shape[-1], 66 | ] 67 | expanded = squashed_value.unsqueeze(1).expand(*expanded_shape) 68 | 69 | _learned_activations, output_activations = sparse_autoencoder.forward(expanded) 70 | component_output_activations = output_activations[:, component_idx] 71 | 72 | return component_output_activations.view(*original_shape) 73 | 74 | # Get the output activations from a forward pass of the SAE 75 | _learned_activations, output_activations = sparse_autoencoder.forward(squashed_value) 76 | 77 | # Reshape to the original shape 78 | return output_activations.view(*original_shape) 79 | -------------------------------------------------------------------------------- /sparse_autoencoder/source_model/reshape_activations.py: -------------------------------------------------------------------------------- 1 | """Methods to reshape activation tensors.""" 2 | from collections.abc import Callable 3 | from functools import reduce 4 | from typing import TypeAlias 5 | 6 | from einops import rearrange 7 | from jaxtyping import Float 8 | from torch import Tensor 9 | 10 | from sparse_autoencoder.tensor_types import Axis 11 | 12 | 13 | ReshapeActivationsFunction: TypeAlias = Callable[ 14 | [Float[Tensor, Axis.names(Axis.ANY)]], 15 | Float[Tensor, Axis.names(Axis.STORE_BATCH, Axis.INPUT_OUTPUT_FEATURE)], 16 | ] 17 | """Reshape Activations Function. 18 | 19 | Used within hooks to e.g. reshape activations before storing them in the activation store. 20 | """ 21 | 22 | 23 | def reshape_to_last_dimension( 24 | batch_activations: Float[Tensor, Axis.names(Axis.ANY)], 25 | ) -> Float[Tensor, Axis.names(Axis.STORE_BATCH, Axis.INPUT_OUTPUT_FEATURE)]: 26 | """Reshape to Last Dimension. 27 | 28 | Takes a tensor of activation vectors, with arbitrary numbers of dimensions (the last of which is 29 | the neurons dimension), and returns a single tensor of size [item, neurons]. 30 | 31 | Examples: 32 | With 2 axis (e.g. pos neuron): 33 | 34 | >>> import torch 35 | >>> input = torch.rand(3, 100) 36 | >>> res = reshape_to_last_dimension(input) 37 | >>> res.shape 38 | torch.Size([3, 100]) 39 | 40 | With 3 axis (e.g. batch, pos, neuron): 41 | 42 | >>> input = torch.randn(3, 3, 100) 43 | >>> res = reshape_to_last_dimension(input) 44 | >>> res.shape 45 | torch.Size([9, 100]) 46 | 47 | With 4 axis (e.g. batch, pos, head_idx, neuron) 48 | 49 | >>> input = torch.rand(3, 3, 3, 100) 50 | >>> res = reshape_to_last_dimension(input) 51 | >>> res.shape 52 | torch.Size([27, 100]) 53 | 54 | Args: 55 | batch_activations: Input Activation Store Batch 56 | 57 | Returns: 58 | Single Tensor of Activation Store Items 59 | """ 60 | return rearrange(batch_activations, "... input_output_feature -> (...) input_output_feature") 61 | 62 | 63 | def reshape_concat_last_dimensions( 64 | batch_activations: Float[Tensor, Axis.names(Axis.ANY)], 65 | concat_dims: int, 66 | ) -> Float[Tensor, Axis.names(Axis.STORE_BATCH, Axis.INPUT_OUTPUT_FEATURE)]: 67 | """Reshape to Last Dimension, Concatenating the Specified Dimensions. 68 | 69 | Takes a tensor of activation vectors, with arbitrary numbers of dimensions (the last 70 | `concat_dims` of which are the neuron dimensions), and returns a single tensor of size 71 | [item, neurons]. 72 | 73 | Examples: 74 | With 3 axis (e.g. 
batch, pos, neuron), concatenating the last 2 dimensions: 75 | 76 | >>> import torch 77 | >>> input = torch.randn(3, 4, 5) 78 | >>> res = reshape_concat_last_dimensions(input, 2) 79 | >>> res.shape 80 | torch.Size([3, 20]) 81 | 82 | With 4 axis (e.g. batch, pos, head_idx, neuron), concatenating the last 3 dimensions: 83 | 84 | >>> input = torch.rand(2, 3, 4, 5) 85 | >>> res = reshape_concat_last_dimensions(input, 3) 86 | >>> res.shape 87 | torch.Size([2, 60]) 88 | 89 | Args: 90 | batch_activations: Input Activation Store Batch 91 | concat_dims: Number of dimensions to concatenate 92 | 93 | Returns: 94 | Single Tensor of Activation Store Items 95 | """ 96 | neurons = reduce(lambda x, y: x * y, batch_activations.shape[-concat_dims:]) 97 | items = reduce(lambda x, y: x * y, batch_activations.shape[:-concat_dims]) 98 | 99 | return batch_activations.reshape(items, neurons) 100 | -------------------------------------------------------------------------------- /sparse_autoencoder/source_model/store_activations_hook.py: -------------------------------------------------------------------------------- 1 | """TransformerLens Hook for storing activations.""" 2 | from jaxtyping import Float 3 | from torch import Tensor 4 | from transformer_lens.hook_points import HookPoint 5 | 6 | from sparse_autoencoder.activation_store.base_store import ActivationStore 7 | from sparse_autoencoder.source_model.reshape_activations import ( 8 | ReshapeActivationsFunction, 9 | reshape_to_last_dimension, 10 | ) 11 | from sparse_autoencoder.tensor_types import Axis 12 | 13 | 14 | def store_activations_hook( 15 | value: Float[Tensor, Axis.names(Axis.ANY)], 16 | hook: HookPoint, # noqa: ARG001 17 | store: ActivationStore, 18 | reshape_method: ReshapeActivationsFunction = reshape_to_last_dimension, 19 | component_idx: int = 0, 20 | ) -> Float[Tensor, Axis.names(Axis.ANY)]: 21 | """Store Activations Hook. 22 | 23 | Useful for getting just the specific activations wanted, rather than the full cache. 24 | 25 | Example: 26 | First we'll need a source model from TransformerLens and an activation store. 27 | 28 | >>> from functools import partial 29 | >>> from transformer_lens import HookedTransformer 30 | >>> from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore 31 | >>> store = TensorActivationStore(max_items=1000, n_neurons=64, n_components=1) 32 | >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 33 | Loaded pretrained model tiny-stories-1M into HookedTransformer 34 | 35 | Next we can add the hook to specific neurons (in this case the first MLP neurons), and 36 | create the tokens for a forward pass. 37 | 38 | >>> model.add_hook( 39 | ... "blocks.0.hook_mlp_out", partial(store_activations_hook, store=store) 40 | ... ) 41 | >>> tokens = model.to_tokens("Hello world") 42 | >>> tokens.shape 43 | torch.Size([1, 3]) 44 | 45 | Then when we run the model, we should get one activation vector for each token (as we just 46 | have one batch item). Note we also set `stop_at_layer=1` as we don't need the logits or any 47 | other activations after the hook point that we've specified (in this case the first MLP 48 | layer). 49 | 50 | >>> _output = model.forward("Hello world", stop_at_layer=1) # Change this layer as required 51 | >>> len(store) 52 | 3 53 | 54 | Args: 55 | value: The activations to store. 56 | hook: The hook point. 57 | store: The activation store. This should be pre-initialised with `functools.partial`. 58 | reshape_method: The method to reshape the activations before storing them. 
59 | component_idx: The component index of the activations to store. 60 | 61 | Returns: 62 | Unmodified activations. 63 | """ 64 | reshaped: Float[ 65 | Tensor, Axis.names(Axis.STORE_BATCH, Axis.INPUT_OUTPUT_FEATURE) 66 | ] = reshape_method(value) 67 | 68 | store.extend(reshaped, component_idx=component_idx) 69 | 70 | # Return the unmodified value 71 | return value 72 | -------------------------------------------------------------------------------- /sparse_autoencoder/source_model/tests/test_replace_activations_hook.py: -------------------------------------------------------------------------------- 1 | """Replace activations hook tests.""" 2 | from functools import partial 3 | 4 | from jaxtyping import Int 5 | import pytest 6 | import torch 7 | from torch import Tensor 8 | from transformer_lens import HookedTransformer 9 | 10 | from sparse_autoencoder.autoencoder.model import SparseAutoencoder, SparseAutoencoderConfig 11 | from sparse_autoencoder.source_model.replace_activations_hook import replace_activations_hook 12 | from sparse_autoencoder.tensor_types import Axis 13 | 14 | 15 | @pytest.mark.integration_test() 16 | def test_hook_replaces_activations() -> None: 17 | """Test that the hook replaces activations.""" 18 | torch.random.manual_seed(0) 19 | source_model = HookedTransformer.from_pretrained("tiny-stories-1M", device="cpu") 20 | autoencoder = SparseAutoencoder( 21 | SparseAutoencoderConfig( 22 | n_input_features=source_model.cfg.d_model, 23 | n_learned_features=source_model.cfg.d_model * 2, 24 | ) 25 | ) 26 | 27 | tokens: Int[Tensor, Axis.names(Axis.SOURCE_DATA_BATCH, Axis.POSITION)] = source_model.to_tokens( 28 | "Hello world" 29 | ) 30 | loss_without_hook = source_model.forward(tokens, return_type="loss") 31 | loss_with_hook = source_model.run_with_hooks( 32 | tokens, 33 | return_type="loss", 34 | fwd_hooks=[ 35 | ( 36 | "blocks.0.hook_mlp_out", 37 | partial(replace_activations_hook, sparse_autoencoder=autoencoder), 38 | ) 39 | ], 40 | ) 41 | 42 | # Check it decreases performance (as the SAE is untrained, it will output nonsense). 43 | assert torch.all(torch.gt(loss_with_hook, loss_without_hook)) 44 | 45 | 46 | @pytest.mark.integration_test() 47 | def test_hook_replaces_activations_2_components() -> None: 48 | """Test that the hook replaces activations when the SAE has multiple components.""" 49 | torch.random.manual_seed(0) 50 | source_model = HookedTransformer.from_pretrained("tiny-stories-1M", device="cpu") 51 | autoencoder = SparseAutoencoder( 52 | SparseAutoencoderConfig( 53 | n_input_features=source_model.cfg.d_model, 54 | n_learned_features=source_model.cfg.d_model * 2, 55 | n_components=2, 56 | ) 57 | ) 58 | 59 | tokens: Int[Tensor, Axis.names(Axis.SOURCE_DATA_BATCH, Axis.POSITION)] = source_model.to_tokens( 60 | "Hello world" 61 | ) 62 | loss_without_hook = source_model.forward(tokens, return_type="loss") 63 | loss_with_hook = source_model.run_with_hooks( 64 | tokens, 65 | return_type="loss", 66 | fwd_hooks=[ 67 | ( 68 | "blocks.0.hook_mlp_out", 69 | partial( 70 | replace_activations_hook, 71 | sparse_autoencoder=autoencoder, 72 | component_idx=1, 73 | n_components=2, 74 | ), 75 | ) 76 | ], 77 | ) 78 | 79 | # Check it decreases performance (as the SAE is untrained, it will output nonsense). 
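# (Note: with `return_type="loss"` TransformerLens returns the mean cross-entropy loss by
# default, so this compares scalar losses; `torch.all` also covers the per-token case.)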
80 | assert torch.all(torch.gt(loss_with_hook, loss_without_hook)) 81 | -------------------------------------------------------------------------------- /sparse_autoencoder/source_model/tests/test_store_activations_hook.py: -------------------------------------------------------------------------------- 1 | """Store Activations Hook Tests.""" 2 | from functools import partial 3 | 4 | from jaxtyping import Int 5 | import pytest 6 | import torch 7 | from torch import Tensor 8 | from transformer_lens import HookedTransformer 9 | 10 | from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore 11 | from sparse_autoencoder.source_model.store_activations_hook import store_activations_hook 12 | from sparse_autoencoder.tensor_types import Axis 13 | 14 | 15 | @pytest.mark.integration_test() 16 | def test_hook_stores_activations() -> None: 17 | """Test that the hook stores activations correctly.""" 18 | store = TensorActivationStore(max_items=100, n_neurons=256, n_components=1) 19 | 20 | model = HookedTransformer.from_pretrained("tiny-stories-1M") 21 | 22 | model.add_hook( 23 | "blocks.0.mlp.hook_post", 24 | partial(store_activations_hook, store=store), 25 | ) 26 | 27 | tokens: Int[Tensor, Axis.names(Axis.SOURCE_DATA_BATCH, Axis.POSITION)] = model.to_tokens( 28 | "Hello world" 29 | ) 30 | logits = model.forward(tokens, stop_at_layer=2) # type: ignore 31 | 32 | n_of_tokens = tokens.numel() 33 | mlp_size: int = model.cfg.d_mlp # type: ignore 34 | 35 | assert len(store) == n_of_tokens 36 | assert store[0, 0].shape[0] == mlp_size 37 | assert torch.is_tensor(logits) # Check the forward pass completed 38 | -------------------------------------------------------------------------------- /sparse_autoencoder/source_model/tests/test_zero_ablate_hook.py: -------------------------------------------------------------------------------- 1 | """Test the zero ablate hook.""" 2 | import pytest 3 | import torch 4 | from transformer_lens.hook_points import HookPoint 5 | 6 | from sparse_autoencoder.source_model.zero_ablate_hook import zero_ablate_hook 7 | 8 | 9 | class MockHookPoint(HookPoint): 10 | """Mock HookPoint class.""" 11 | 12 | 13 | @pytest.fixture() 14 | def mock_hook_point() -> MockHookPoint: 15 | """Fixture to provide a mock HookPoint instance.""" 16 | return MockHookPoint() 17 | 18 | 19 | def test_zero_ablate_hook_with_standard_tensor(mock_hook_point: MockHookPoint) -> None: 20 | """Test zero_ablate_hook with a standard tensor. 21 | 22 | Args: 23 | mock_hook_point: A mock HookPoint instance. 24 | """ 25 | value = torch.ones(3, 4) 26 | expected = torch.zeros(3, 4) 27 | result = zero_ablate_hook(value, mock_hook_point) 28 | assert torch.equal(result, expected), "The output tensor should contain only zeros." 29 | 30 | 31 | @pytest.mark.parametrize("shape", [(10,), (5, 5), (2, 3, 4)]) 32 | def test_zero_ablate_hook_with_various_shapes( 33 | mock_hook_point: MockHookPoint, shape: tuple[int, ...] 34 | ) -> None: 35 | """Test zero_ablate_hook with tensors of various shapes. 36 | 37 | Args: 38 | mock_hook_point: A mock HookPoint instance. 39 | shape: A tuple representing the shape of the tensor. 40 | """ 41 | value = torch.ones(*shape) 42 | expected = torch.zeros(*shape) 43 | result = zero_ablate_hook(value, mock_hook_point) 44 | assert torch.equal( 45 | result, expected 46 | ), f"The output tensor should be of shape {shape} with zeros." 47 | 48 | 49 | def test_float_dtype_maintained(mock_hook_point: MockHookPoint) -> None: 50 | """Test that the float dtype is maintained. 
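(`zero_ablate_hook` builds its output with `torch.zeros_like`, which preserves the input's dtype and device, so no explicit cast is needed.)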
51 | 52 | Args: 53 | mock_hook_point: A mock HookPoint instance. 54 | """ 55 | value = torch.ones(3, 4, dtype=torch.float) 56 | result = zero_ablate_hook(value, mock_hook_point) 57 | assert result.dtype == torch.float, "The output tensor should be of dtype float." 58 | -------------------------------------------------------------------------------- /sparse_autoencoder/source_model/zero_ablate_hook.py: -------------------------------------------------------------------------------- 1 | """Zero ablate hook.""" 2 | import torch 3 | from torch import Tensor 4 | from transformer_lens.hook_points import HookPoint 5 | 6 | 7 | def zero_ablate_hook( 8 | value: Tensor, 9 | hook: HookPoint, # noqa: ARG001 10 | ) -> Tensor: 11 | """Zero ablate hook. 12 | 13 | Args: 14 | value: The activations to ablate. 15 | hook: The hook point. 16 | 17 | Example: 18 | >>> dummy_hook_point = HookPoint() 19 | >>> value = torch.ones(2, 3) 20 | >>> zero_ablate_hook(value, dummy_hook_point) 21 | tensor([[0., 0., 0.], 22 | [0., 0., 0.]]) 23 | 24 | Returns: 25 | Replaced activations. 26 | """ 27 | return torch.zeros_like(value) 28 | -------------------------------------------------------------------------------- /sparse_autoencoder/tensor_types.py: -------------------------------------------------------------------------------- 1 | """Tensor Axis Types.""" 2 | from enum import auto 3 | 4 | from strenum import LowercaseStrEnum 5 | 6 | 7 | class Axis(LowercaseStrEnum): 8 | """Tensor axis names. 9 | 10 | Used to annotate tensor types. 11 | 12 | Example: 13 | When used directly it prints a string: 14 | 15 | >>> print(Axis.INPUT_OUTPUT_FEATURE) 16 | input_output_feature 17 | 18 | The primary use is to annotate tensor types: 19 | 20 | >>> from jaxtyping import Float 21 | >>> from torch import Tensor 22 | >>> from typing import TypeAlias 23 | >>> batch: TypeAlias = Float[Tensor, Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE)] 24 | >>> print(batch) 25 | <class 'jaxtyping.Float[Tensor, 'batch input_output_feature']'> 26 | 27 | You can also join multiple axis together to represent the dimensions of a tensor: 28 | 29 | >>> print(Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE)) 30 | batch input_output_feature 31 | """ 32 | 33 | # Component idx 34 | COMPONENT = auto() 35 | """Component index.""" 36 | 37 | COMPONENT_OPTIONAL = "*component" 38 | """Optional component index.""" 39 | 40 | # Batches 41 | SOURCE_DATA_BATCH = auto() 42 | """Batch of prompts used to generate source model activations.""" 43 | 44 | BATCH = auto() 45 | """Batch of items that the SAE is being trained on.""" 46 | 47 | STORE_BATCH = auto() 48 | """Batch of items to be written to the store.""" 49 | 50 | ITEMS = auto() 51 | """Arbitrary number of items.""" 52 | 53 | # Features 54 | INPUT_OUTPUT_FEATURE = auto() 55 | """Input or output feature (e.g. feature in activation vector from source model).""" 56 | 57 | LEARNT_FEATURE = auto() 58 | """Learnt feature (e.g. feature in learnt activation vector).""" 59 | 60 | DEAD_FEATURE = auto() 61 | """Dead feature.""" 62 | 63 | ALIVE_FEATURE = auto() 64 | """Alive feature.""" 65 | 66 | # Feature indices 67 | INPUT_OUTPUT_FEATURE_IDX = auto() 68 | """Input or output feature index.""" 69 | 70 | LEARNT_FEATURE_IDX = auto() 71 | """Learnt feature index.""" 72 | 73 | # Other 74 | POSITION = auto() 75 | """Token position.""" 76 | 77 | SINGLE_ITEM = "" 78 | """Single item axis.""" 79 | 80 | ANY = "..." 81 | """Any number of axis.""" 82 | 83 | @staticmethod 84 | def names(*axis: "Axis") -> str: 85 | """Join multiple axis together, to represent the dimensions of a tensor. 
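The result is a single space-separated string of axis names, which is the shape-annotation format that jaxtyping expects.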
86 | 87 | Example: 88 | >>> print(Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE)) 89 | batch input_output_feature 90 | 91 | Args: 92 | *axis: Axis to join. 93 | 94 | Returns: 95 | Joined axis string. 96 | """ 97 | return " ".join(a.value for a in axis) 98 | -------------------------------------------------------------------------------- /sparse_autoencoder/train/__init__.py: -------------------------------------------------------------------------------- 1 | """Train.""" 2 | -------------------------------------------------------------------------------- /sparse_autoencoder/train/join_sweep.py: -------------------------------------------------------------------------------- 1 | """Join an existing Weights and Biases sweep, as a new agent.""" 2 | import argparse 3 | 4 | from sparse_autoencoder.train.sweep import sweep 5 | 6 | 7 | def parse_arguments() -> argparse.Namespace: 8 | """Parse command line arguments. 9 | 10 | Returns: 11 | argparse.Namespace: Parsed command line arguments. 12 | """ 13 | parser = argparse.ArgumentParser(description="Join an existing W&B sweep.") 14 | parser.add_argument( 15 | "--id", type=str, default=None, help="Sweep ID for the existing sweep.", required=True 16 | ) 17 | return parser.parse_args() 18 | 19 | 20 | def run() -> None: 21 | """Run the join_sweep script.""" 22 | args = parse_arguments() 23 | 24 | sweep(sweep_id=args.id) 25 | 26 | 27 | if __name__ == "__main__": 28 | run() 29 | -------------------------------------------------------------------------------- /sparse_autoencoder/train/tests/__snapshots__/test_sweep.ambr: -------------------------------------------------------------------------------- 1 | # serializer version: 1 2 | # name: test_setup_autoencoder 3 | ''' 4 | LitSparseAutoencoder( 5 | (sparse_autoencoder): SparseAutoencoder( 6 | (pre_encoder_bias): TiedBias(position=pre_encoder) 7 | (encoder): LinearEncoder( 8 | input_features=512, learnt_features=2048, n_components=1 9 | (activation_function): ReLU() 10 | ) 11 | (decoder): UnitNormDecoder(learnt_features=2048, decoded_features=512, n_components=1) 12 | (post_decoder_bias): TiedBias(position=post_decoder) 13 | ) 14 | (loss_fn): SparseAutoencoderLoss() 15 | (train_metrics): MetricCollection( 16 | (activity): ClasswiseWrapperWithMean( 17 | (metric): NeuronActivityMetric() 18 | ) 19 | (l0): ClasswiseWrapperWithMean( 20 | (metric): L0NormMetric() 21 | ) 22 | (l1): ClasswiseWrapperWithMean( 23 | (metric): L1AbsoluteLoss() 24 | ) 25 | (l2): ClasswiseWrapperWithMean( 26 | (metric): L2ReconstructionLoss() 27 | ) 28 | (loss): ClasswiseWrapperWithMean( 29 | (metric): SparseAutoencoderLoss() 30 | ) 31 | ) 32 | (activation_resampler): ActivationResampler() 33 | ) 34 | ''' 35 | # --- 36 | -------------------------------------------------------------------------------- /sparse_autoencoder/train/tests/test_sweep.py: -------------------------------------------------------------------------------- 1 | """Tests for sweep functionality.""" 2 | import pytest 3 | from syrupy.session import SnapshotSession 4 | 5 | from sparse_autoencoder.train.sweep import setup_autoencoder 6 | from sparse_autoencoder.train.sweep_config import ( 7 | RuntimeHyperparameters, 8 | ) 9 | 10 | 11 | @pytest.fixture() 12 | def dummy_hyperparameters() -> RuntimeHyperparameters: 13 | """Sweep config dummy fixture.""" 14 | return { 15 | "activation_resampler": { 16 | "threshold_is_dead_portion_fires": 0.0, 17 | "max_n_resamples": 4, 18 | "n_activations_activity_collate": 100_000_000, 19 | "resample_dataset_size": 819_200, 20 | 
"resample_interval": 200_000_000, 21 | }, 22 | "autoencoder": {"expansion_factor": 4}, 23 | "loss": {"l1_coefficient": 0.0001}, 24 | "optimizer": { 25 | "adam_beta_1": 0.9, 26 | "adam_beta_2": 0.99, 27 | "adam_weight_decay": 0.0, 28 | "amsgrad": False, 29 | "fused": False, 30 | "lr": 1e-05, 31 | "lr_scheduler": None, 32 | }, 33 | "pipeline": { 34 | "checkpoint_frequency": 100000000, 35 | "log_frequency": 100, 36 | "max_activations": 2000000000, 37 | "max_store_size": 3145728, 38 | "source_data_batch_size": 12, 39 | "train_batch_size": 4096, 40 | "validation_frequency": 314572800, 41 | "validation_n_activations": 1024, 42 | "num_workers_data_loading": 0, 43 | }, 44 | "random_seed": 49, 45 | "source_data": { 46 | "context_size": 128, 47 | "dataset_column_name": "input_ids", 48 | "dataset_dir": None, 49 | "dataset_files": None, 50 | "dataset_path": "NeelNanda/c4-code-tokenized-2b", 51 | "pre_download": False, 52 | "pre_tokenized": True, 53 | "tokenizer_name": None, 54 | }, 55 | "source_model": { 56 | "dtype": "float32", 57 | "hook_dimension": 512, 58 | "cache_names": ["mlp_out"], 59 | "name": "gelu-2l", 60 | }, 61 | } 62 | 63 | 64 | def test_setup_autoencoder( 65 | dummy_hyperparameters: RuntimeHyperparameters, snapshot: SnapshotSession 66 | ) -> None: 67 | """Test the setup_autoencoder function.""" 68 | autoencoder = setup_autoencoder(dummy_hyperparameters) 69 | assert snapshot == str(autoencoder), "Autoencoder string representation has changed." 70 | -------------------------------------------------------------------------------- /sparse_autoencoder/train/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Train Utils.""" 2 | -------------------------------------------------------------------------------- /sparse_autoencoder/train/utils/get_model_device.py: -------------------------------------------------------------------------------- 1 | """Get the device that the model is on.""" 2 | from lightning import LightningModule 3 | import torch 4 | from torch.nn import Module 5 | from torch.nn.parallel import DataParallel 6 | 7 | 8 | def get_model_device(model: Module | DataParallel | LightningModule) -> torch.device | None: 9 | """Get the device on which a PyTorch model is on. 10 | 11 | Args: 12 | model: The PyTorch model. 13 | 14 | Returns: 15 | The device ('cuda' or 'cpu') where the model is located. 16 | 17 | Raises: 18 | ValueError: If the model has no parameters. 19 | """ 20 | # Deepspeed models already have a device property, so just return that 21 | if hasattr(model, "device"): 22 | return model.device 23 | 24 | # Tensors for lightning should not have device set (as lightning will handle this) 25 | if isinstance(model, LightningModule): 26 | return None 27 | 28 | # Check if the model has parameters 29 | if len(list(model.parameters())) == 0: 30 | exception_message = "The model has no parameters." 31 | raise ValueError(exception_message) 32 | 33 | # Return the device of the first parameter 34 | return next(model.parameters()).device 35 | -------------------------------------------------------------------------------- /sparse_autoencoder/train/utils/round_down.py: -------------------------------------------------------------------------------- 1 | """Round down to the nearest multiple.""" 2 | 3 | 4 | def round_to_multiple(value: int | float, multiple: int) -> int: # noqa: PYI041 5 | """Round down to the nearest multiple. 6 | 7 | Helper function for creating default values. 
8 | 9 | Example: 10 | >>> round_to_multiple(1023, 100) 11 | 1000 12 | 13 | Args: 14 | value: The value to round down. 15 | multiple: The multiple to round down to. 16 | 17 | Returns: 18 | The value rounded down to the nearest multiple. 19 | 20 | Raises: 21 | ValueError: If `value` is less than `multiple`. 22 | """ 23 | int_value = int(value) 24 | 25 | if int_value < multiple: 26 | error_message = f"{value=} must be greater than or equal to {multiple=}" 27 | raise ValueError(error_message) 28 | 29 | return int_value - int_value % multiple 30 | -------------------------------------------------------------------------------- /sparse_autoencoder/train/utils/tests/test_get_model_device.py: -------------------------------------------------------------------------------- 1 | """Test get_model_device.py.""" 2 | import pytest 3 | import torch 4 | from torch import Tensor 5 | from torch.nn import Linear, Module 6 | 7 | from sparse_autoencoder.train.utils.get_model_device import get_model_device 8 | 9 | 10 | def test_model_on_cpu() -> None: 11 | """Test that it returns CPU when the model is on CPU.""" 12 | 13 | class TestModel(Module): 14 | """Test model.""" 15 | 16 | def __init__(self) -> None: 17 | """Initialize the model.""" 18 | super().__init__() 19 | self.fc = Linear(10, 5) 20 | 21 | def forward(self, x: Tensor) -> Tensor: 22 | """Forward pass.""" 23 | return self.fc(x) 24 | 25 | model = TestModel() 26 | model.to("cpu") 27 | assert get_model_device(model) == torch.device("cpu"), "Device should be CPU" 28 | 29 | 30 | # Test with a model that has no parameters 31 | def test_model_no_parameters() -> None: 32 | """Test that it raises a ValueError when the model has no parameters.""" 33 | 34 | class EmptyModel(Module): 35 | def forward(self, x: Tensor) -> Tensor: 36 | return x 37 | 38 | model = EmptyModel() 39 | with pytest.raises(ValueError, match="The model has no parameters."): 40 | _ = get_model_device(model) 41 | -------------------------------------------------------------------------------- /sparse_autoencoder/train/utils/tests/test_wandb_sweep_types.py: -------------------------------------------------------------------------------- 1 | """Test wandb sweep types.""" 2 | from dataclasses import dataclass, field 3 | 4 | from sparse_autoencoder.train.utils.wandb_sweep_types import ( 5 | Method, 6 | Metric, 7 | NestedParameter, 8 | Parameter, 9 | Parameters, 10 | WandbSweepConfig, 11 | ) 12 | 13 | 14 | class TestNestedParameter: 15 | """NestedParameter tests.""" 16 | 17 | def test_to_dict(self) -> None: 18 | """Test to_dict method.""" 19 | 20 | @dataclass(frozen=True) 21 | class DummyNestedParameter(NestedParameter): 22 | nested_property: Parameter[float] = field(default=Parameter(1.0)) 23 | 24 | dummy = DummyNestedParameter() 25 | 26 | # It should be in the nested "parameters" key. 
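# (W&B sweep configs nest grouped hyperparameters under a "parameters" key at each
# level, which is the wrapper that `NestedParameter.to_dict()` adds.)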
27 | assert dummy.to_dict() == {"parameters": {"nested_property": {"value": 1.0}}} 28 | 29 | 30 | class TestWandbSweepConfig: 31 | """WandbSweepConfig tests.""" 32 | 33 | def test_to_dict(self) -> None: 34 | """Test to_dict method.""" 35 | 36 | @dataclass(frozen=True) 37 | class DummyNestedParameter(NestedParameter): 38 | nested_property: Parameter[float] = field(default=Parameter(1.0)) 39 | 40 | @dataclass 41 | class DummyParameters(Parameters): 42 | nested: DummyNestedParameter = field(default=DummyNestedParameter()) 43 | top_level: Parameter[float] = field(default=Parameter(1.0)) 44 | 45 | dummy = WandbSweepConfig( 46 | parameters=DummyParameters(), method=Method.GRID, metric=Metric(name="total_loss") 47 | ) 48 | 49 | assert dummy.to_dict() == { 50 | "method": "grid", 51 | "metric": {"goal": "minimize", "name": "total_loss"}, 52 | "parameters": { 53 | "nested": { 54 | "parameters": {"nested_property": {"value": 1.0}}, 55 | }, 56 | "top_level": {"value": 1.0}, 57 | }, 58 | } 59 | -------------------------------------------------------------------------------- /sparse_autoencoder/training_runs/__init__.py: -------------------------------------------------------------------------------- 1 | """Training runs.""" 2 | -------------------------------------------------------------------------------- /sparse_autoencoder/training_runs/gpt2.py: -------------------------------------------------------------------------------- 1 | """Run a sweep on all layers of GPT2 Small. 2 | 3 | Command: 4 | 5 | ```bash 6 | git clone https://github.com/ai-safety-foundation/sparse_autoencoder.git && cd sparse_autoencoder && 7 | poetry env use python3.11 && poetry install && 8 | poetry run python sparse_autoencoder/training_runs/gpt2.py 9 | ``` 10 | """ 11 | import os 12 | 13 | from sparse_autoencoder import ( 14 | ActivationResamplerHyperparameters, 15 | AutoencoderHyperparameters, 16 | Hyperparameters, 17 | LossHyperparameters, 18 | Method, 19 | OptimizerHyperparameters, 20 | Parameter, 21 | PipelineHyperparameters, 22 | SourceDataHyperparameters, 23 | SourceModelHyperparameters, 24 | SweepConfig, 25 | sweep, 26 | ) 27 | 28 | 29 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 30 | 31 | 32 | def train() -> None: 33 | """Train.""" 34 | sweep_config = SweepConfig( 35 | parameters=Hyperparameters( 36 | loss=LossHyperparameters( 37 | l1_coefficient=Parameter(values=[0.0001]), 38 | ), 39 | optimizer=OptimizerHyperparameters( 40 | lr=Parameter(value=0.0001), 41 | ), 42 | source_model=SourceModelHyperparameters( 43 | name=Parameter("gpt2"), 44 | cache_names=Parameter( 45 | value=[f"blocks.{layer}.hook_mlp_out" for layer in range(12)] 46 | ), 47 | hook_dimension=Parameter(768), 48 | ), 49 | source_data=SourceDataHyperparameters( 50 | dataset_path=Parameter("alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2"), 51 | context_size=Parameter(256), 52 | pre_tokenized=Parameter(value=True), 53 | pre_download=Parameter(value=True), 54 | # Total dataset is c. 7bn activations (64 files) 55 | # C. 
1.5TB needed to store all activations 56 | dataset_files=Parameter( 57 | [f"data/train-{str(i).zfill(5)}-of-00064.parquet" for i in range(20)] 58 | ), 59 | ), 60 | autoencoder=AutoencoderHyperparameters(expansion_factor=Parameter(values=[32, 64])), 61 | pipeline=PipelineHyperparameters(), 62 | activation_resampler=ActivationResamplerHyperparameters( 63 | threshold_is_dead_portion_fires=Parameter(1e-5), 64 | ), 65 | ), 66 | method=Method.GRID, 67 | ) 68 | 69 | sweep(sweep_config=sweep_config) 70 | 71 | 72 | if __name__ == "__main__": 73 | train() 74 | -------------------------------------------------------------------------------- /sparse_autoencoder/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Shared utils.""" 2 | -------------------------------------------------------------------------------- /sparse_autoencoder/utils/data_parallel.py: -------------------------------------------------------------------------------- 1 | """Data parallel utils.""" 2 | from typing import Any, Generic, TypeVar 3 | 4 | from torch.nn import DataParallel, Module 5 | 6 | 7 | T = TypeVar("T", bound=Module) 8 | 9 | 10 | class DataParallelWithModelAttributes(DataParallel[T], Generic[T]): 11 | """Data parallel with access to underlying model attributes/methods. 12 | 13 | Allows access to underlying model attributes/methods, which is not possible with the default 14 | `DataParallel` class. Based on: 15 | https://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html 16 | 17 | Example: 18 | >>> from sparse_autoencoder import SparseAutoencoder, SparseAutoencoderConfig 19 | >>> model = SparseAutoencoder(SparseAutoencoderConfig( 20 | ... n_input_features=2, 21 | ... n_learned_features=4, 22 | ... )) 23 | >>> distributed_model = DataParallelWithModelAttributes(model) 24 | >>> distributed_model.config.n_learned_features 25 | 4 26 | """ 27 | 28 | def __getattr__(self, name: str) -> Any: # noqa: ANN401 29 | """Allow access to underlying model attributes/methods. 30 | 31 | Args: 32 | name: Attribute/method name. 33 | 34 | Returns: 35 | Attribute value/method. 36 | """ 37 | try: 38 | return super().__getattr__(name) 39 | except AttributeError: 40 | return getattr(self.module, name) 41 | -------------------------------------------------------------------------------- /sparse_autoencoder/utils/tensor_shape.py: -------------------------------------------------------------------------------- 1 | """Tensor shape utilities.""" 2 | 3 | 4 | def shape_with_optional_dimensions(*shape: int | None) -> tuple[int, ...]: 5 | """Create a shape from a tuple of optional dimensions. 6 | 7 | Motivation: 8 | By default PyTorch tensor shapes will error if you set an axis to `None`. This allows 9 | you to set that size and then the resulting output simply removes that axis. 10 | 11 | Examples: 12 | >>> shape_with_optional_dimensions(1, 2, 3) 13 | (1, 2, 3) 14 | 15 | >>> shape_with_optional_dimensions(1, None, 3) 16 | (1, 3) 17 | 18 | >>> shape_with_optional_dimensions(1, None, None) 19 | (1,) 20 | 21 | >>> shape_with_optional_dimensions(None, None, None) 22 | () 23 | 24 | Args: 25 | *shape: Axis sizes, with `None` representing an optional axis. 26 | 27 | Returns: 28 | Axis sizes. 29 | """ 30 | return tuple(dimension for dimension in shape if dimension is not None) 31 | --------------------------------------------------------------------------------