├── .devcontainer
│   ├── Dockerfile
│   └── devcontainer.json
├── .github
│   └── workflows
│       ├── checks.yml
│       ├── gh-pages.yml
│       └── release.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .vscode
│   ├── cspell.json
│   ├── extensions.json
│   └── settings.json
├── LICENSE
├── README.md
├── docs
│   ├── content
│   │   ├── SUMMARY.md
│   │   ├── citation.md
│   │   ├── contributing.md
│   │   ├── css
│   │   │   ├── custom_formatting.css
│   │   │   └── material_extra.css
│   │   ├── demo.ipynb
│   │   ├── index.md
│   │   ├── javascript
│   │   │   ├── custom_formatting.js
│   │   │   └── mathjax.js
│   │   └── pre-process-datasets.ipynb
│   └── gen_ref_pages.py
├── mkdocs.yml
├── poetry.lock
├── pyproject.toml
└── sparse_autoencoder
    ├── __init__.py
    ├── activation_resampler
    │   ├── __init__.py
    │   ├── activation_resampler.py
    │   ├── tests
    │   │   └── test_activation_resampler.py
    │   └── utils
    │       ├── __init__.py
    │       └── component_slice_tensor.py
    ├── activation_store
    │   ├── __init__.py
    │   ├── base_store.py
    │   ├── tensor_store.py
    │   └── tests
    │       └── test_tensor_store.py
    ├── autoencoder
    │   ├── __init__.py
    │   ├── components
    │   │   ├── __init__.py
    │   │   ├── linear_encoder.py
    │   │   ├── tests
    │   │   │   ├── __snapshots__
    │   │   │   │   └── test_linear_encoder.ambr
    │   │   │   ├── test_compare_neel_implementation.py
    │   │   │   ├── test_linear_encoder.py
    │   │   │   ├── test_tied_bias.py
    │   │   │   └── test_unit_norm_decoder.py
    │   │   ├── tied_bias.py
    │   │   └── unit_norm_decoder.py
    │   ├── lightning.py
    │   ├── model.py
    │   ├── tests
    │   │   ├── __snapshots__
    │   │   │   └── test_model.ambr
    │   │   └── test_model.py
    │   └── types.py
    ├── metrics
    │   ├── __init__.py
    │   ├── loss
    │   │   ├── __init__.py
    │   │   ├── l1_absolute_loss.py
    │   │   ├── l2_reconstruction_loss.py
    │   │   ├── sae_loss.py
    │   │   └── tests
    │   │       ├── test_l1_absolute_loss.py
    │   │       ├── test_l2_reconstruction_loss.py
    │   │       └── test_sae_loss.py
    │   ├── train
    │   │   ├── __init__.py
    │   │   ├── capacity.py
    │   │   ├── feature_density.py
    │   │   ├── l0_norm.py
    │   │   ├── neuron_activity.py
    │   │   ├── neuron_fired_count.py
    │   │   └── tests
    │   │       ├── test_feature_density.py
    │   │       ├── test_l0_norm.py
    │   │       ├── test_neuron_activity.py
    │   │       └── test_neuron_fired_count.py
    │   ├── validate
    │   │   ├── __init__.py
    │   │   ├── reconstruction_score.py
    │   │   └── tests
    │   │       └── __snapshots__
    │   │           └── test_model_reconstruction_score.ambr
    │   └── wrappers
    │       ├── __init__.py
    │       ├── classwise.py
    │       └── tests
    │           └── test_classwise.py
    ├── optimizer
    │   ├── __init__.py
    │   ├── adam_with_reset.py
    │   └── tests
    │       └── test_adam_with_reset.py
    ├── source_data
    │   ├── __init__.py
    │   ├── abstract_dataset.py
    │   ├── mock_dataset.py
    │   ├── pretokenized_dataset.py
    │   ├── tests
    │   │   ├── test_abstract_dataset.py
    │   │   ├── test_mock_dataset.py
    │   │   ├── test_pretokenized_dataset.py
    │   │   └── test_text_dataset.py
    │   └── text_dataset.py
    ├── source_model
    │   ├── __init__.py
    │   ├── replace_activations_hook.py
    │   ├── reshape_activations.py
    │   ├── store_activations_hook.py
    │   ├── tests
    │   │   ├── test_replace_activations_hook.py
    │   │   ├── test_store_activations_hook.py
    │   │   └── test_zero_ablate_hook.py
    │   └── zero_ablate_hook.py
    ├── tensor_types.py
    ├── train
    │   ├── __init__.py
    │   ├── join_sweep.py
    │   ├── pipeline.py
    │   ├── sweep.py
    │   ├── sweep_config.py
    │   ├── tests
    │   │   ├── __snapshots__
    │   │   │   └── test_sweep.ambr
    │   │   ├── test_pipeline.py
    │   │   └── test_sweep.py
    │   └── utils
    │       ├── __init__.py
    │       ├── get_model_device.py
    │       ├── round_down.py
    │       ├── tests
    │       │   ├── test_get_model_device.py
    │       │   └── test_wandb_sweep_types.py
    │       └── wandb_sweep_types.py
    ├── training_runs
    │   ├── __init__.py
    │   └── gpt2.py
    └── utils
        ├── __init__.py
        ├── data_parallel.py
        └── tensor_shape.py
/.devcontainer/Dockerfile:
--------------------------------------------------------------------------------
1 | # Use Nvidia Ubuntu 22.04 base (includes CUDA if a supported GPU is present)
2 | # https://hub.docker.com/r/nvidia/cuda
3 | FROM nvidia/cuda:12.2.2-devel-ubuntu22.04
4 |
5 | ARG USERNAME
6 | ARG USER_UID=1000
7 | ARG USER_GID=$USER_UID
8 |
9 | # Create the user
10 | # https://code.visualstudio.com/remote/advancedcontainers/add-nonroot-user
11 | RUN groupadd --gid $USER_GID $USERNAME \
12 | && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME \
13 | && usermod -a -G video $USERNAME \
14 | && apt-get update \
15 | && apt-get install -y sudo \
16 | && echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \
17 | && chmod 0440 /etc/sudoers.d/$USERNAME
18 |
19 | # Add the deadsnakes PPA to get the latest Python version
20 | RUN sudo apt-get install -y software-properties-common \
21 | && sudo add-apt-repository ppa:deadsnakes/ppa
22 |
23 | # Install dependencies
24 | RUN sudo apt-get update && \
25 | DEBIAN_FRONTEND=noninteractive apt-get -qq -y install \
26 | build-essential \
27 | curl \
28 | git \
29 | poppler-utils \
30 | python3 \
31 | python3.11 \
32 | python3-dev \
33 | python3-distutils \
34 | python3-venv \
35 | expat \
36 | rsync
37 |
38 | # Install pip (we need the latest version not the standard Ubuntu version, to
39 | # support modern wheels)
40 | RUN sudo curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && python3 get-pip.py
41 |
42 | # Install poetry
43 | RUN sudo curl -sSL https://install.python-poetry.org | python3 -
44 | ENV PATH="/root/.local/bin:$PATH"
45 |
46 | # Set python aliases
47 | RUN sudo update-alternatives --install /usr/bin/python python /usr/bin/python3 1
48 |
49 | # Use the new (non-root) user
50 | USER $USERNAME
51 |
--------------------------------------------------------------------------------
/.devcontainer/devcontainer.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "Python 3",
3 | "build": {
4 | "dockerfile": "Dockerfile",
5 | "args": {
6 | "USERNAME": "user"
7 | }
8 | },
9 | "runArgs": [
10 | "--gpus",
11 | "all"
12 | ],
13 | "containerUser": "user",
14 | "postCreateCommand": "poetry env use 3.11 && poetry install --with dev,jupyter"
15 | }
16 |
--------------------------------------------------------------------------------
/.github/workflows/checks.yml:
--------------------------------------------------------------------------------
1 | name: Checks
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | paths-ignore:
8 | - '.devcontainer/**'
9 | - '.vscode/**'
10 | - '.gitignore'
11 | - 'README.md'
12 | pull_request:
13 | branches:
14 | - main
15 | paths-ignore:
16 | - '.devcontainer/**'
17 | - '.vscode/**'
18 | - '.gitignore'
19 | - 'README.md'
20 | # Allow this workflow to be called from other workflows
21 | workflow_call:
22 | inputs:
23 | # Requires at least one input to be valid, but in practice we don't need any
24 | dummy:
25 | type: string
26 | required: false
27 |
28 | jobs:
29 | checks:
30 | name: Checks
31 | runs-on: ubuntu-latest
32 | strategy:
33 | matrix:
34 | python-version:
35 | - "3.10"
36 | - "3.11"
37 | steps:
38 | - uses: actions/checkout@v4
39 | - name: Install Poetry
40 | uses: snok/install-poetry@v1
41 | - name: Set up Python
42 | uses: actions/setup-python@v4
43 | with:
44 | python-version: ${{ matrix.python-version }}
45 | cache: 'poetry'
46 | allow-prereleases: true
47 | - name: Check lockfile
48 | run: poetry check
49 | - name: Install dependencies
50 | run: poetry install --with dev
51 | - name: Pyright type check
52 | run: poetry run pyright
53 | - name: Ruff lint
54 | run: poetry run ruff check . --output-format=github
55 | - name: Docstrings lint
56 | run: poetry run pydoclint .
57 | - name: Ruff format
58 | run: poetry run ruff format . --check
59 | - name: Pytest
60 | run: poetry run pytest
61 | - name: Build check
62 | run: poetry build
63 |
--------------------------------------------------------------------------------
/.github/workflows/gh-pages.yml:
--------------------------------------------------------------------------------
1 | name: Docs
2 | on:
3 | push:
4 | branches:
5 | - main
6 | pull_request:
7 | branches:
8 | - '*'
9 |
10 | permissions:
11 | contents: write
12 |
13 | jobs:
14 | build-docs:
15 | # When running on a PR, this just checks we can build the docs without errors
16 | # When running on merge to main, it builds the docs and then another job deploys them
17 | name: ${{ github.event_name == 'pull_request' && 'Check Build Docs' || 'Build Docs' }}
18 | runs-on: ubuntu-latest
19 | steps:
20 | - uses: actions/checkout@v4
21 | - name: Install Poetry
22 | uses: snok/install-poetry@v1
23 | - name: Set up Python
24 | uses: actions/setup-python@v4
25 | with:
26 | python-version: "3.11"
27 | cache: "poetry"
28 | - name: Install poe
29 | run: pip install poethepoet
30 | - name: Install dependencies
31 | run: poetry install --with docs
32 | - name: Generate docs
33 | run: poe gen-docs
34 | - name: Build Docs
35 | run: poe make-docs
36 | - name: Upload Docs Artifact
37 | uses: actions/upload-artifact@v3
38 | with:
39 | name: documentation
40 | path: docs/generated
41 |
42 | deploy-docs:
43 | name: Deploy Docs
44 | runs-on: ubuntu-latest
45 | # Only run if merging a PR into main
46 | if: github.event_name == 'push' && github.ref == 'refs/heads/main'
47 | needs: build-docs
48 | steps:
49 | - uses: actions/checkout@v4
50 | - name: Download Docs Artifact
51 | uses: actions/download-artifact@v3
52 | with:
53 | name: documentation
54 | path: docs/generated
55 | - name: Upload to GitHub Pages
56 | uses: JamesIves/github-pages-deploy-action@v4
57 | with:
58 | folder: docs/generated
59 | clean-exclude: |
60 | *.*.*/
61 |
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: Release
2 |
3 | on:
4 | release:
5 | types:
6 | - published
7 |
8 | jobs:
9 | checks:
10 | name: Run checks workflow
11 | uses: alan-cooney/sparse_autoencoder/.github/workflows/checks.yml@main
12 |
13 | semver-parser:
14 | name: Parse the semantic version from the release
15 | runs-on: ubuntu-latest
16 | steps:
17 | - name: Parse semver string
18 | id: semver_parser
19 | uses: booxmedialtd/ws-action-parse-semver@v1.4.7
20 | with:
21 | input_string: ${{ github.event.release.tag_name }}
22 | outputs:
23 | major: "${{ steps.semver_parser.outputs.major }}"
24 | minor: "${{ steps.semver_parser.outputs.minor }}"
25 | patch: "${{ steps.semver_parser.outputs.patch }}"
26 | semver: "${{ steps.semver_parser.outputs.fullversion }}"
27 |
28 | release-python:
29 | name: Release Python package to PyPI
30 | needs:
31 | - checks
32 | - semver-parser
33 | runs-on: ubuntu-latest
34 | steps:
35 | - uses: actions/checkout@v3
36 | - name: Install Poetry
37 | uses: snok/install-poetry@v1
38 | - name: Poetry config
39 | run: poetry self add 'poethepoet[poetry_plugin]'
40 | - name: Set up Python
41 | uses: actions/setup-python@v4
42 | with:
43 | python-version: '3.10'
44 | cache: 'poetry'
45 | - name: Install dependencies
46 | run: poetry install --with dev
47 | - name: Set the version
48 | run: poetry version ${{needs.semver-parser.outputs.semver}}
49 | - name: Build
50 | run: poetry build
51 | - name: Publish
52 | run: poetry publish
53 | env:
54 | POETRY_PYPI_TOKEN_PYPI: ${{ secrets.POETRY_PYPI_TOKEN }}
55 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | manifest
29 | .DS_Store
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 |
55 | # Docs stuff
56 | docs/content/reference
57 | docs/generated
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
100 | __pypackages__/
101 |
102 | # Celery stuff
103 | celerybeat-schedule
104 | celerybeat.pid
105 |
106 | # SageMath parsed files
107 | *.sage.py
108 |
109 | # Environments
110 | .env
111 | .venv
112 | env/
113 | venv/
114 | ENV/
115 | env.bak/
116 | venv.bak/
117 |
118 | # Spyder project settings
119 | .spyderproject
120 | .spyproject
121 |
122 | # Rope project settings
123 | .ropeproject
124 |
125 | # mkdocs documentation
126 | /site
127 |
128 | # mypy
129 | .mypy_cache/
130 | .dmypy.json
131 | dmypy.json
132 |
133 | # Pyre type checker
134 | .pyre/
135 |
136 | # Generated docs
137 | docs/content/reference
138 |
139 | # Checkpoints directory
140 | .checkpoints
141 |
142 | # Wandb
143 | wandb/
144 | artifacts/
145 |
146 | # Lightning
147 | lightning_logs
148 |
149 | # Scratch files
150 | scratch.py
151 | scratch.ipynb
152 | scratch/
153 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # See https://pre-commit.com for more information
2 | # See https://pre-commit.com/hooks.html for more hooks
3 | repos:
4 | - repo: https://github.com/pre-commit/pre-commit-hooks
5 | rev: v4.5.0
6 | hooks:
7 | - id: trailing-whitespace
8 | - id: end-of-file-fixer
9 | - id: check-yaml
10 | - id: check-toml
11 | - id: check-json
12 | - id: check-added-large-files
13 | - id: check-merge-conflict
14 | - id: check-symlinks
15 | - id: destroyed-symlinks
16 | - id: detect-private-key
17 | - repo: local
18 | hooks:
19 | - id: ruff_lint
20 | name: Ruff Lint
21 | entry: poetry run ruff check sparse_autoencoder --fix
22 | language: system
23 | types: [python]
24 | require_serial: true
25 | - id: ruff_format
26 | name: Ruff Format
27 | entry: poetry run ruff format sparse_autoencoder --check --diff
28 | language: system
29 | types: [python]
30 | require_serial: true
31 | - id: typecheck
32 | name: Pyright Type Check
33 | entry: poetry run pyright
34 | language: system
35 | types: [python]
36 | require_serial: true
37 |
--------------------------------------------------------------------------------
/.vscode/cspell.json:
--------------------------------------------------------------------------------
1 | {
2 | "language": "en,en-GB",
3 | "words": [
4 | "alancooney",
5 | "allclose",
6 | "arange",
7 | "arithmatex",
8 | "astroid",
9 | "autocast",
10 | "autoencoder",
11 | "autoencoders",
12 | "autoencoding",
13 | "autofix",
14 | "capturable",
15 | "categoricalwprobabilities",
16 | "circuitsvis",
17 | "Classwise",
18 | "coeff",
19 | "colab",
20 | "cooney",
21 | "cuda",
22 | "cudnn",
23 | "datapoints",
24 | "davidanson",
25 | "devcontainer",
26 | "devel",
27 | "dmypy",
28 | "docstrings",
29 | "donjayamanne",
30 | "dtype",
31 | "dtypes",
32 | "dunder",
33 | "earlyterminate",
34 | "einops",
35 | "einsum",
36 | "endoftext",
37 | "gelu",
38 | "githistory",
39 | "hobbhahn",
40 | "htmlproofer",
41 | "huggingface",
42 | "hyperband",
43 | "hyperparameters",
44 | "imageuri",
45 | "imputewhilerunning",
46 | "interpretability",
47 | "intuniform",
48 | "invloguniform",
49 | "invloguniformvalues",
50 | "ipynb",
51 | "itemwise",
52 | "jaxtyping",
53 | "kaiming",
54 | "keepdim",
55 | "logit",
56 | "lognormal",
57 | "loguniform",
58 | "loguniformvalues",
59 | "mathbb",
60 | "mathbf",
61 | "maxiter",
62 | "miniter",
63 | "mkdocs",
64 | "mkdocstrings",
65 | "mknotebooks",
66 | "monosemantic",
67 | "monosemanticity",
68 | "multipled",
69 | "nanda",
70 | "ncols",
71 | "ndarray",
72 | "ndim",
73 | "neel",
74 | "nelement",
75 | "neox",
76 | "nonlinerity",
77 | "numel",
78 | "onebit",
79 | "openwebtext",
80 | "optim",
81 | "penality",
82 | "perp",
83 | "pickleable",
84 | "polysemantic",
85 | "polysemantically",
86 | "polysemanticity",
87 | "precommit",
88 | "pretokenized",
89 | "pydantic",
90 | "pyproject",
91 | "pyright",
92 | "pytest",
93 | "qbeta",
94 | "qlognormal",
95 | "qloguniform",
96 | "qloguniformvalues",
97 | "qnormal",
98 | "quniform",
99 | "randn",
100 | "randperm",
101 | "relu",
102 | "resampler",
103 | "resid",
104 | "roneneldan",
105 | "rtol",
106 | "runcap",
107 | "sharded",
108 | "snapshottest",
109 | "solu",
110 | "tinystories",
111 | "torchmetrics",
112 | "tqdm",
113 | "transformer_lens",
114 | "typecheck",
115 | "ultralow",
116 | "uncopyright",
117 | "uncopyrighted",
118 | "ungraphed",
119 | "unsqueeze",
120 | "unsync",
121 | "venv",
122 | "virtualenv",
123 | "virtualenvs",
124 | "wandb",
125 | "zoadam"
126 | ]
127 | }
128 |
--------------------------------------------------------------------------------
/.vscode/extensions.json:
--------------------------------------------------------------------------------
1 | {
2 | "recommendations": [
3 | "davidanson.vscode-markdownlint",
4 | "donjayamanne.githistory",
5 | "github.copilot",
6 | "github.vscode-github-actions",
7 | "github.vscode-pull-request-github",
8 | "ionutvmi.path-autocomplete",
9 | "littlefoxteam.vscode-python-test-adapter",
10 | "ms-python.python",
11 | "ms-python.vscode-pylance",
12 | "ms-toolsai.jupyter-keymap",
13 | "ms-toolsai.jupyter-renderers",
14 | "ms-toolsai.jupyter",
15 | "richie5um2.vscode-sort-json",
16 | "stkb.rewrap",
17 | "streetsidesoftware.code-spell-checker-british-english",
18 | "streetsidesoftware.code-spell-checker",
19 | "yzhang.markdown-all-in-one",
20 | "kevinrose.vsc-python-indent",
21 | "donjayamanne.python-environment-manager",
22 | "tamasfe.even-better-toml"
23 | ]
24 | }
25 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "[jsonc]": {
3 | "editor.defaultFormatter": "vscode.json-language-features"
4 | },
5 | "[python]": {
6 | "editor.defaultFormatter": "charliermarsh.ruff"
7 | },
8 | "[toml]": {
9 | "editor.defaultFormatter": "tamasfe.even-better-toml"
10 | },
11 | "editor.codeActionsOnSave": {
12 | "source.fixAll.eslint": "explicit",
13 | "source.organizeImports": "explicit"
14 | },
15 | "editor.formatOnSave": true,
16 | "evenBetterToml.formatter.allowedBlankLines": 1,
17 | "evenBetterToml.formatter.arrayAutoCollapse": true,
18 | "evenBetterToml.formatter.arrayAutoExpand": true,
19 | "evenBetterToml.formatter.arrayTrailingComma": true,
20 | "evenBetterToml.formatter.columnWidth": 100,
21 | "evenBetterToml.formatter.compactArrays": true,
22 | "evenBetterToml.formatter.compactEntries": true,
23 | "evenBetterToml.formatter.compactInlineTables": true,
24 | "evenBetterToml.formatter.indentEntries": true,
25 | "evenBetterToml.formatter.indentString": " ",
26 | "evenBetterToml.formatter.indentTables": true,
27 | "evenBetterToml.formatter.inlineTableExpand": true,
28 | "evenBetterToml.formatter.reorderArrays": true,
29 | "evenBetterToml.formatter.reorderKeys": true,
30 | "evenBetterToml.formatter.trailingNewline": true,
31 | "evenBetterToml.schema.enabled": true,
32 | "evenBetterToml.schema.links": true,
33 | "evenBetterToml.syntax.semanticTokens": false,
34 | "notebook.formatOnCellExecution": true,
35 | "notebook.formatOnSave.enabled": true,
36 | "python.analysis.autoFormatStrings": true,
37 | "python.analysis.inlayHints.functionReturnTypes": false,
38 | "python.analysis.typeCheckingMode": "basic",
39 | "python.languageServer": "Pylance",
40 | "python.terminal.activateEnvInCurrentTerminal": true,
41 | "python.testing.pytestEnabled": true,
42 | "rewrap.autoWrap.enabled": true,
43 | "rewrap.wrappingColumn": 100,
44 | "pylint.ignorePatterns": [
45 | "*"
46 | ],
47 | "pythonTestExplorer.testFramework": "pytest"
48 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Alan Cooney
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Sparse Autoencoder
2 |
3 | [PyPI](https://pypi.org/project/transformer-lens/)
5 | [Checks](https://github.com/alan-cooney/sparse_autoencoder/actions/workflows/checks.yml)
6 | [Release](https://github.com/alan-cooney/sparse_autoencoder/actions/workflows/release.yml)
7 |
8 | A sparse autoencoder for mechanistic interpretability research.
9 |
10 | [Documentation](https://ai-safety-foundation.github.io/sparse_autoencoder/)
12 |
13 | Train a Sparse Autoencoder [in colab](https://colab.research.google.com/github/ai-safety-foundation/sparse_autoencoder/blob/main/docs/content/demo.ipynb), or install for your project:
14 |
15 | ```shell
16 | pip install sparse_autoencoder
17 | ```
18 |
19 | ## Features
20 |
21 | This library contains:
22 |
23 | 1. **A sparse autoencoder model**, along with all the underlying PyTorch components you need to
24 | customise and/or build your own:
25 | - Encoder, constrained unit norm decoder and tied bias PyTorch modules in `autoencoder`.
26 | - L1 and L2 loss modules in `loss`.
27 | - Adam module with helper method to reset state in `optimizer`.
28 | 2. **Activations data generator** using TransformerLens, with the underlying steps in case you
29 | want to customise the approach:
30 | - Activation store options (in-memory or on disk) in `activation_store`.
31 | - Hook to get the activations from TransformerLens in an efficient way in `source_model`.
32 | - Source dataset (i.e. prompts to generate these activations) utils in `source_data`, that
33 | stream data from HuggingFace and pre-process (tokenize & shuffle).
34 | 3. **Activation resampler** to help reduce the number of dead neurons.
35 | 4. **Metrics** that log at various stages of training (e.g. during training, resampling and
36 | validation), and integrate with wandb.
37 | 5. **Training pipeline** that combines everything together, allowing you to run hyperparameter
38 | sweeps and view progress on wandb.
39 |
40 | ## Designed for Research
41 |
42 | The library is designed to be modular. By default it takes the approach from [Towards
43 | Monosemanticity: Decomposing Language Models With Dictionary Learning
44 | ](https://transformer-circuits.pub/2023/monosemantic-features/index.html), so you can pip install
45 | the library and get started quickly. Then when you need to customise something, you can just extend
46 | the class for that component (e.g. you can extend `SparseAutoencoder` if you want to customise the
47 | model, and then drop it back into the training pipeline). Every component is fully documented, so
48 | it's easy to do this, as the sketch below illustrates.
49 |
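A minimal sketch of this pattern (the config arguments shown are illustrative, so check the
reference docs for the real signatures):

```python
from sparse_autoencoder import SparseAutoencoder, SparseAutoencoderConfig


class MyAutoencoder(SparseAutoencoder):
    """Sparse autoencoder with a customised forward pass."""

    def forward(self, x):
        # Reuse the standard encode/decode path, then add any custom
        # behaviour here (e.g. extra logging or a different nonlinearity).
        return super().forward(x)


# Hyperparameter values below are illustrative only
config = SparseAutoencoderConfig(n_input_features=768, n_learned_features=768 * 4)
autoencoder = MyAutoencoder(config)
```

The subclass can then be dropped back into the training pipeline in place of the default model.
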
50 | ## Demo
51 |
52 | Check out the demo notebook [docs/content/demo.ipynb](https://github.com/ai-safety-foundation/sparse_autoencoder/blob/main/docs/content/demo.ipynb) for a guide to using this library.
53 |
54 | ## Contributing
55 |
56 | This project uses [Poetry](https://python-poetry.org) for dependency management, and
57 | [PoeThePoet](https://poethepoet.natn.io/installation.html) for scripts. After checking out the repo,
58 | we recommend setting poetry's config to create the `.venv` in the root directory (note this is a
59 | global setting) and then installing with the dev and demos dependencies.
60 |
61 | ```shell
62 | poetry config virtualenvs.in-project true
63 | poetry install --with dev,demos
64 | ```
65 |
66 | ### Checks
67 |
68 | For a full list of available commands (e.g. `test` or `typecheck`), run this in your terminal
69 | (assumes the venv is active already).
70 |
71 | ```shell
72 | poe
73 | ```
74 |
--------------------------------------------------------------------------------
/docs/content/SUMMARY.md:
--------------------------------------------------------------------------------
1 |
2 | * [Home](index.md)
3 | * [Demo](demo.ipynb)
4 | * [Source dataset pre-processing](pre-process-datasets.ipynb)
5 | * [Reference](reference/)
6 | * [Contributing](contributing.md)
7 | * [Citation](citation.md)
--------------------------------------------------------------------------------
/docs/content/citation.md:
--------------------------------------------------------------------------------
1 |
2 | # Citation
3 |
4 | Please cite this library as:
5 |
6 | ```BibTeX
7 | @misc{cooney2023SparseAutoencoder,
8 | title = {Sparse Autoencoder Library},
9 | author = {Alan Cooney},
10 | year = {2023},
11 | howpublished = {\url{https://github.com/ai-safety-foundation/sparse_autoencoder}},
12 | }
13 | ```
14 |
--------------------------------------------------------------------------------
/docs/content/contributing.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 |
3 | ## Setup
4 |
5 | This project uses [Poetry](https://python-poetry.org) for dependency management, and
6 | [PoeThePoet](https://poethepoet.natn.io/installation.html) for scripts. After checking out the repo,
7 | we recommend setting poetry's config to create the `.venv` in the root directory (note this is a
8 | global setting) and then installing with the dev and demos dependencies.
9 |
10 | ```shell
11 | poetry config virtualenvs.in-project true
12 | poetry install --with dev,demos
13 | ```
14 |
15 | If you are using VSCode, we highly recommend installing the recommended extensions as well (it
16 | will prompt you to do this when you check out the repo).
17 |
18 | ## Checks
19 |
20 | For a full list of available commands (e.g. `test` or `typecheck`), run this in your terminal
21 | (assumes the venv is active already).
22 |
23 | ```shell
24 | poe
25 | ```
26 |
27 | ## Documentation
28 |
29 | Please make sure to add thorough documentation for any features you add. You should do this directly
30 | in the docstring, and this will then automatically generate the API docs when merged into `main`.
31 | They will also be automatically checked with [pytest](https://docs.pytest.org/) (via
32 | [doctest](https://docs.python.org/3/library/doctest.html)).
33 |
34 | If you want to view your documentation changes, run `poe docs-hot-reload`. This will give you
35 | hot-reloading docs (they change in real time as you edit docstrings).
36 |
37 | ### Docstring Style Guide
38 |
39 | We follow the [Google Python Docstring Style](https://google.github.io/styleguide/pyguide.html) for
40 | writing docstrings. Some important details below:
41 |
42 | #### Sections and Order
43 |
44 | You should follow this order:
45 |
46 | ```python
47 | """Title In Title Case.
48 |
49 | A description of what the function/class does, including as much detail as is necessary to fully understand it.
50 |
51 | Warning:
52 |
53 | Any warnings to the user (e.g. common pitfalls).
54 |
55 | Examples:
56 |
57 | Include any examples here. They will be checked with doctest.
58 |
59 | >>> print(1 + 2)
60 | 3
61 |
62 | Args:
63 | param_without_type_signature:
64 | Each description should be indented once more.
65 | param_2:
66 | Another example parameter.
67 |
68 | Returns:
69 | Returns description without type signature.
70 |
71 | Raises:
72 | Information about the error it may raise (if any).
73 | """
74 | ```
75 |
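For example, a toy function documented in this style (illustrative only); the doctest in its
`Examples` section is run as part of the test suite:

```python
def square(x: float) -> float:
    """Square A Number.

    Multiplies the input by itself.

    Examples:
        >>> square(3)
        9

    Args:
        x:
            The number to square.

    Returns:
        The square of the input.
    """
    return x * x
```
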
76 | #### LaTeX support
77 |
78 | You can use LaTeX, inside `$$` for blocks or `$` for inline
79 |
80 | ```markdown
81 | Some text $(a + b)^2 = a^2 + 2ab + b^2$
82 | ```
83 |
84 | ```markdown
85 | Some text:
86 |
87 | $$
88 | \begin{aligned} y &= ax^2 + bx + c \\
89 | f(x) &= x^2 + 2xy + y^2 \end{aligned}
90 | $$
91 | ```
92 |
93 | #### Markup
94 |
95 | - Italics - `*text*`
96 | - Bold - `**text**`
97 | - Code - ` ``code`` `
98 | - List items - `* item`
99 | - Numbered items - `1. Item`
100 | - Quotes - indent one level
101 | - External links - `` `Link text <https://example.com>`_ ``
102 |
--------------------------------------------------------------------------------
/docs/content/css/custom_formatting.css:
--------------------------------------------------------------------------------
1 | /* Show the full equation without vertical scroll */
2 | .md-typeset div.arithmatex {
3 | overflow: visible !important;
4 | }
5 |
6 | /* Preserve new lines in code example blocks */
7 | .example blockquote blockquote blockquote p {
8 | white-space: pre-wrap;
9 | font-family: monospace;
10 | }
--------------------------------------------------------------------------------
/docs/content/css/material_extra.css:
--------------------------------------------------------------------------------
1 | /* Indentation. */
2 | div.doc-contents:not(.first) {
3 | padding-left: 25px;
4 | border-left: .05rem solid var(--md-typeset-table-color);
5 | }
6 |
7 | /* Mark external links as such. */
8 | a.external::after,
9 | a.autorefs-external::after {
10 | /* https://primer.style/octicons/arrow-up-right-24 */
11 | mask-image: url('data:image/svg+xml,');
12 | -webkit-mask-image: url('data:image/svg+xml,');
13 | content: ' ';
14 |
15 | display: inline-block;
16 | vertical-align: middle;
17 | position: relative;
18 |
19 | height: 1em;
20 | width: 1em;
21 | background-color: var(--md-typeset-a-color);
22 | }
23 |
24 | a.external:hover::after,
25 | a.autorefs-external:hover::after {
26 | background-color: var(--md-accent-fg-color);
27 | }
--------------------------------------------------------------------------------
/docs/content/demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "\n",
8 | "
\n",
9 | ""
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "# Training Demo\n",
17 | "\n",
18 | "This is a quick start demo to get training a SAE right away. All you need to do is choose a few\n",
19 | "hyperparameters (like the model to train on), and then set it off.\n",
20 | "\n",
21 | "In this demo we'll train a sparse autoencoder on all MLP layer outputs in GPT-2 small (effectively\n",
22 | "training an SAE on each layer in parallel)."
23 | ]
24 | },
25 | {
26 | "cell_type": "markdown",
27 | "metadata": {},
28 | "source": [
29 | "## Setup"
30 | ]
31 | },
32 | {
33 | "cell_type": "markdown",
34 | "metadata": {},
35 | "source": [
36 | "### Imports"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": 1,
42 | "metadata": {},
43 | "outputs": [],
44 | "source": [
45 | "# Check if we're in Colab\n",
46 | "try:\n",
47 | " import google.colab # noqa: F401 # type: ignore\n",
48 | "\n",
49 | " in_colab = True\n",
50 | "except ImportError:\n",
51 | " in_colab = False\n",
52 | "\n",
53 | "# Install if in Colab\n",
54 | "if in_colab:\n",
55 | " %pip install sparse_autoencoder transformer_lens transformers wandb\n",
56 | "\n",
57 | "# Otherwise enable hot reloading in dev mode\n",
58 | "if not in_colab:\n",
59 | " %load_ext autoreload\n",
60 | " %autoreload 2"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": 2,
66 | "metadata": {},
67 | "outputs": [],
68 | "source": [
69 | "import os\n",
70 | "\n",
71 | "from sparse_autoencoder import (\n",
72 | " ActivationResamplerHyperparameters,\n",
73 | " AutoencoderHyperparameters,\n",
74 | " Hyperparameters,\n",
75 | " LossHyperparameters,\n",
76 | " Method,\n",
77 | " OptimizerHyperparameters,\n",
78 | " Parameter,\n",
79 | " PipelineHyperparameters,\n",
80 | " SourceDataHyperparameters,\n",
81 | " SourceModelHyperparameters,\n",
82 | " SweepConfig,\n",
83 | " sweep,\n",
84 | ")\n",
85 | "\n",
86 | "os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"demo.ipynb\""
87 | ]
88 | },
89 | {
90 | "cell_type": "markdown",
91 | "metadata": {},
92 | "source": [
93 | "### Hyperparameters"
94 | ]
95 | },
96 | {
97 | "cell_type": "markdown",
98 | "metadata": {},
99 | "source": [
100 | "Customize any hyperparameters you want below (by default we're sweeping over l1 coefficient and\n",
101 | "learning rate).\n",
102 | "\n",
103 | "Note we are using the RANDOM sweep approach (try random combinations of hyperparameters), which\n",
104 | "works surprisingly well but will need to be stopped at some point (as otherwise it will continue\n",
105 | "forever). If you want to run pre-defined runs consider using `Parameter(values=[0.01, 0.05...])` for\n",
106 | "example rather than `Parameter(max=0.03, min=0.008)` for each parameter you are sweeping over. You\n",
107 | "can then set the strategy to `Method.GRID`."
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": 3,
113 | "metadata": {},
114 | "outputs": [],
115 | "source": [
116 | "def train_gpt_small_mlp_layers(\n",
117 | " expansion_factor: int = 4,\n",
118 | " n_layers: int = 12,\n",
119 | ") -> None:\n",
120 | " \"\"\"Run a new sweep experiment on GPT 2 Small's MLP layers.\n",
121 | "\n",
122 | " Args:\n",
123 | " expansion_factor: Expansion factor for the autoencoder.\n",
124 | " n_layers: Number of layers to train on. Max is 12.\n",
125 | "\n",
126 | " \"\"\"\n",
127 | " sweep_config = SweepConfig(\n",
128 | " parameters=Hyperparameters(\n",
129 | " loss=LossHyperparameters(\n",
130 | " l1_coefficient=Parameter(max=0.03, min=0.008),\n",
131 | " ),\n",
132 | " optimizer=OptimizerHyperparameters(\n",
133 | " lr=Parameter(max=0.001, min=0.00001),\n",
134 | " ),\n",
135 | " source_model=SourceModelHyperparameters(\n",
136 | " name=Parameter(\"gpt2\"),\n",
137 | " cache_names=Parameter(\n",
138 | " [f\"blocks.{layer}.hook_mlp_out\" for layer in range(n_layers)]\n",
139 | " ),\n",
140 | " hook_dimension=Parameter(768),\n",
141 | " ),\n",
142 | " source_data=SourceDataHyperparameters(\n",
143 | " dataset_path=Parameter(\"alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2\"),\n",
144 | " context_size=Parameter(256),\n",
145 | " pre_tokenized=Parameter(value=True),\n",
146 | " pre_download=Parameter(value=False), # Default to streaming the dataset\n",
147 | " ),\n",
148 | " autoencoder=AutoencoderHyperparameters(\n",
149 | " expansion_factor=Parameter(value=expansion_factor)\n",
150 | " ),\n",
151 | " pipeline=PipelineHyperparameters(\n",
152 | " max_activations=Parameter(1_000_000_000),\n",
153 | " checkpoint_frequency=Parameter(100_000_000),\n",
154 | " validation_frequency=Parameter(100_000_000),\n",
155 | " max_store_size=Parameter(1_000_000),\n",
156 | " ),\n",
157 | " activation_resampler=ActivationResamplerHyperparameters(\n",
158 | " resample_interval=Parameter(200_000_000),\n",
159 | " n_activations_activity_collate=Parameter(100_000_000),\n",
160 | " threshold_is_dead_portion_fires=Parameter(1e-6),\n",
161 | " max_n_resamples=Parameter(4),\n",
162 | " ),\n",
163 | " ),\n",
164 | " method=Method.RANDOM,\n",
165 | " )\n",
166 | "\n",
167 | " sweep(sweep_config=sweep_config)"
168 | ]
169 | },
170 | {
171 | "cell_type": "markdown",
172 | "metadata": {},
173 | "source": [
174 | "### Run the sweep"
175 | ]
176 | },
177 | {
178 | "cell_type": "markdown",
179 | "metadata": {},
180 | "source": [
181 | "This will start a sweep with just one agent (the current machine). If you have multiple GPUs, it\n",
182 | "will use them automatically. Similarly it will work on Apple silicon devices by automatically using MPS."
183 | ]
184 | },
185 | {
186 | "cell_type": "code",
187 | "execution_count": 4,
188 | "metadata": {},
189 | "outputs": [],
190 | "source": [
191 | "train_gpt_small_mlp_layers()"
192 | ]
193 | },
194 | {
195 | "cell_type": "markdown",
196 | "metadata": {},
197 | "source": [
198 | "Want to speed things up? You can trivially add extra machines to the sweep, each of which will peel\n",
199 | "of some runs from the sweep agent (stored on Wandb). To do this, on another machine simply run:\n",
200 | "\n",
201 | "```bash\n",
202 | "pip install sparse_autoencoder\n",
203 | "join-sae-sweep --id=SWEEP_ID_SHOWN_ON_WANDB\n",
204 | "```"
205 | ]
206 | }
207 | ],
208 | "metadata": {
209 | "kernelspec": {
210 | "display_name": ".venv",
211 | "language": "python",
212 | "name": "python3"
213 | },
214 | "language_info": {
215 | "codemirror_mode": {
216 | "name": "ipython",
217 | "version": 3
218 | },
219 | "file_extension": ".py",
220 | "mimetype": "text/x-python",
221 | "name": "python",
222 | "nbconvert_exporter": "python",
223 | "pygments_lexer": "ipython3",
224 | "version": "3.11.6"
225 | },
226 | "vscode": {
227 | "interpreter": {
228 | "hash": "31186ba1239ad81afeb3c631b4833e71f34259d3b92eebb37a9091b916e08620"
229 | }
230 | }
231 | },
232 | "nbformat": 4,
233 | "nbformat_minor": 2
234 | }
235 |
--------------------------------------------------------------------------------
/docs/content/index.md:
--------------------------------------------------------------------------------
1 | # Sparse Autoencoder
2 |
3 | [PyPI](https://pypi.org/project/transformer-lens/)
5 | [Checks](https://github.com/alan-cooney/sparse_autoencoder/actions/workflows/checks.yml)
6 | [Release](https://github.com/alan-cooney/sparse_autoencoder/actions/workflows/release.yml)
7 |
8 | A sparse autoencoder for mechanistic interpretability research.
9 |
10 | ```shell
11 | pip install sparse_autoencoder
12 | ```
13 |
14 | ## Quick Start
15 |
16 | Check out the [demo notebook](demo) for a guide to using this library.
17 |
18 | We also highly recommend skimming the reference docs to see all the features that are available.
19 |
20 | ## Features
21 |
22 | This library contains:
23 |
24 | 1. **A sparse autoencoder model**, along with all the underlying PyTorch components you need to
25 | customise and/or build your own:
26 | - Encoder, constrained unit norm decoder and tied bias PyTorch modules in
27 | [sparse_autoencoder.autoencoder][].
28 | - Adam module with helper method to reset state in [sparse_autoencoder.optimizer][].
29 | 2. **Activations data generator** using TransformerLens, with the underlying steps in case you
30 | want to customise the approach:
31 | - Activation store options (in-memory or on disk) in [sparse_autoencoder.activation_store][].
32 | - Hook to get the activations from TransformerLens in an efficient way in
33 | [sparse_autoencoder.source_model][].
34 | - Source dataset (i.e. prompts to generate these activations) utils in
35 | [sparse_autoencoder.source_data][], that stream data from HuggingFace and pre-process
36 | (tokenize & shuffle).
37 | 3. **Activation resampler** to help reduce the number of dead neurons.
38 | 4. **Metrics** that log at various stages of training (loss, train metrics and validation
39 | metrics), based on torchmetrics.
40 | 5. **Training pipeline** that combines everything together, allowing you to run hyperparameter
41 | sweeps and view progress on wandb.
42 |
43 | ## Designed for Research
44 |
45 | The library is designed to be modular. By default it takes the approach from [Towards
46 | Monosemanticity: Decomposing Language Models With Dictionary Learning
47 | ](https://transformer-circuits.pub/2023/monosemantic-features/index.html), so you can pip install
48 | the library and get started quickly. Then when you need to customise something, you can just extend
49 | the abstract class for that component (every component is documented so that it's easy to do this).
50 |
--------------------------------------------------------------------------------
/docs/content/javascript/custom_formatting.js:
--------------------------------------------------------------------------------
1 | /**
2 | * Expand the 'Reference' section of the navigation
3 | *
4 | * @param {boolean} remove_icon - Whether to remove the '>' icon
5 | */
6 | function expandReferenceSection(remove_icon = false) {
7 | // Find all labels in the navigation menu
8 | const navLabels = document.querySelectorAll('.md-nav__item label');
9 |
10 | // Search for the label with the text 'Reference'
11 | Array.from(navLabels).forEach(label => {
12 | if (label.textContent.trim() === 'Reference') {
13 | // Find the associated checkbox to expand the section
14 | const toggleInput = label.previousElementSibling;
15 | if (toggleInput && toggleInput.tagName === 'INPUT') {
16 | toggleInput.checked = true;
17 | }
18 |
19 | // Find the '>' icon and hide it
20 | if (remove_icon) {
21 | const icon = label.querySelector('.md-nav__icon');
22 | if (icon) {
23 | icon.style.display = 'none';
24 | }
25 | }
26 | }
27 | });
28 | }
29 |
30 | /**
31 | * Hides the Table of Contents (TOC) section if it only contains one link.
32 | */
33 | function hideSingleItemTOC() {
34 | // Find the TOC list
35 | const tocList = document.querySelector('.md-nav--secondary .md-nav__list');
36 |
37 | if (tocList) {
38 | // Count the number of list items (links) in the TOC
39 | const itemCount = tocList.querySelectorAll('li').length;
40 |
41 | // If there is only one item, hide the entire TOC section
42 | if (itemCount === 1) {
43 | console.log("only one")
44 | const tocSection = document.querySelector('.md-sidebar--secondary[data-md-component="sidebar"][data-md-type="toc"]');
45 | if (tocSection) {
46 | tocSection.style.display = 'none';
47 | }
48 | }
49 | }
50 | }
51 |
52 | function main() {
53 | document.addEventListener("DOMContentLoaded", function () {
54 | // Expand the 'Reference' section of the navigation
55 | expandReferenceSection();
56 |
57 | // Hide the Table of Contents (TOC) section if it only contains one link
58 | hideSingleItemTOC();
59 | })
60 | }
61 |
62 | main();
--------------------------------------------------------------------------------
/docs/content/javascript/mathjax.js:
--------------------------------------------------------------------------------
1 | window.MathJax = {
2 | tex: {
3 | inlineMath: [["\\(", "\\)"]],
4 | displayMath: [["\\[", "\\]"]],
5 | processEscapes: true,
6 | processEnvironments: true
7 | },
8 | options: {
9 | ignoreHtmlClass: ".*|",
10 | processHtmlClass: "arithmatex"
11 | }
12 | };
13 |
14 | document$.subscribe(() => {
15 | MathJax.typesetPromise()
16 | })
--------------------------------------------------------------------------------
/docs/gen_ref_pages.py:
--------------------------------------------------------------------------------
1 | """Auto-generate reference documentation based on docstrings."""
2 | from pathlib import Path
3 | import shutil
4 |
5 | import mkdocs_gen_files
6 |
7 |
8 | CURRENT_DIR = Path(__file__).parent
9 | REPO_ROOT = CURRENT_DIR.parent
10 | PROJECT_ROOT = REPO_ROOT / "sparse_autoencoder"
11 | REFERENCE_DIR = CURRENT_DIR / "content/reference"
12 |
13 |
14 | def reset_reference_dir() -> None:
15 | """Reset the reference directory to its initial state."""
16 | # Unlink the directory including all files
17 | shutil.rmtree(REFERENCE_DIR, ignore_errors=True)
18 | REFERENCE_DIR.mkdir(parents=True, exist_ok=True)
19 |
20 |
21 | def is_source_file(file: Path) -> bool:
22 | """Check if the provided file is a source file for Sparse Encoder.
23 |
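Examples:
    >>> is_source_file(Path("sparse_autoencoder/train/pipeline.py"))
    True

    >>> is_source_file(Path("sparse_autoencoder/train/tests/test_pipeline.py"))
    False
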
24 | Args:
25 | file: The file path to check.
26 |
27 | Returns:
28 | bool: True if the file is a source file, False otherwise.
29 | """
30 | return "test" not in str(file)
31 |
32 |
33 | def process_path(path: Path) -> tuple[Path, Path, Path]:
34 | """Process the given path for documentation generation.
35 |
36 | Args:
37 | path: The file path to process.
38 |
39 | Returns:
40 | A tuple containing module path, documentation path, and full documentation path.
41 | """
42 | module_path = path.relative_to(PROJECT_ROOT).with_suffix("")
43 | doc_path = path.relative_to(PROJECT_ROOT).with_suffix(".md")
44 | full_doc_path = Path(REFERENCE_DIR, doc_path)
45 |
46 | if module_path.name == "__init__":
47 | module_path = module_path.parent
48 | doc_path = doc_path.with_name("index.md")
49 | full_doc_path = full_doc_path.with_name("index.md")
50 |
51 | return module_path, doc_path, full_doc_path
52 |
53 |
54 | def generate_documentation(path: Path, module_path: Path, full_doc_path: Path) -> None:
55 | """Generate documentation for the given source file.
56 |
57 | Args:
58 | path: The source file path.
59 | module_path: The module path.
60 | full_doc_path: The full documentation file path.
61 | """
62 | if module_path.name == "__main__":
63 | return
64 |
65 | # Get the mkdocstrings identifier for the module
66 | parts = list(module_path.parts)
67 | parts.insert(0, "sparse_autoencoder")
68 | identifier = ".".join(parts)
69 |
70 | # Read the first line of the file docstring, and set as the header
71 | with path.open() as fd:
72 | first_line = fd.readline()
73 | first_line_without_docstring = first_line.replace('"""', "").strip()
74 | first_line_without_last_dot = first_line_without_docstring.rstrip(".")
75 | title = first_line_without_last_dot or module_path.name
76 |
77 | with mkdocs_gen_files.open(full_doc_path, "w") as fd:
78 | fd.write(f"# {title}" + "\n\n" + f"::: {identifier}")
79 |
80 | mkdocs_gen_files.set_edit_path(full_doc_path, path)
81 |
82 |
83 | def generate_nav_file(nav: mkdocs_gen_files.nav.Nav, reference_dir: Path) -> None:
84 | """Generate the navigation file for the documentation.
85 |
86 | Args:
87 | nav: The navigation object.
88 | reference_dir: The directory to write the navigation file.
89 | """
90 | with mkdocs_gen_files.open(reference_dir / "SUMMARY.md", "w") as nav_file:
91 | nav_file.writelines(nav.build_literate_nav())
92 |
93 |
94 | def run() -> None:
95 | """Handle the generation of reference documentation for Sparse Encoder."""
96 | reset_reference_dir()
97 | nav = mkdocs_gen_files.Nav() # type: ignore
98 |
99 | python_files = PROJECT_ROOT.rglob("*.py")
100 | source_files = filter(is_source_file, python_files)
101 |
102 | for path in sorted(source_files):
103 | module_path, doc_path, full_doc_path = process_path(path)
104 | generate_documentation(path, module_path, full_doc_path)
105 |
106 | url_slug_parts = list(module_path.parts)
107 |
108 | # Don't create a page for the main __init__.py file (as this includes most exports).
109 | if not url_slug_parts:
110 | continue
111 |
112 | nav[url_slug_parts] = doc_path.as_posix() # type: ignore
113 |
114 | generate_nav_file(nav, REFERENCE_DIR)
115 |
116 |
117 | if __name__ == "__main__":
118 | run()
119 |
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: Sparse Autoencoder
2 | site_description: Sparse Autoencoder for Mechanistic Interpretability
3 | docs_dir: docs/content
4 | site_dir: docs/generated
5 | repo_url: https://github.com/ai-safety-foundation/sparse_autoencoder
6 | repo_name: ai-safety-foundation/sparse_autoencoder
7 | edit_uri: "" # Disabled as we use mkdocstrings which auto-generates some pages
8 | strict: true
9 |
10 | theme:
11 | name: material
12 | palette:
13 | - scheme: default
14 | primary: teal
15 | accent: amber
16 | toggle:
17 | icon: material/weather-night
18 | name: Switch to dark mode
19 | - scheme: slate
20 | primary: black
21 | accent: amber
22 | toggle:
23 | icon: material/weather-sunny
24 | name: Switch to light mode
25 | icon:
26 | repo: fontawesome/brands/github # GitHub logo in top right
27 | features:
28 | - content.action.edit
29 |
30 | extra_javascript:
31 | - javascript/custom_formatting.js
32 | # The below three make MathJax work, see https://squidfunk.github.io/mkdocs-material/reference/mathjax/
33 | - javascript/mathjax.js
34 | - https://polyfill.io/v3/polyfill.min.js?features=es6
35 | - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js
36 |
37 | extra_css:
38 | - "css/material_extra.css"
39 | - "css/custom_formatting.css"
40 |
41 | markdown_extensions:
42 | - pymdownx.arithmatex: # Render LaTeX via MathJax
43 | generic: true
44 | - pymdownx.superfences # Seems to enable syntax highlighting when used with the Material theme.
45 | - pymdownx.magiclink
46 | - pymdownx.saneheaders
47 | - pymdownx.details # Allowing hidden expandable regions denoted by ???
48 | - pymdownx.snippets: # Include one Markdown file into another
49 | base_path: docs/content
50 | - admonition # Adds admonition blocks (e.g. warning, note, tip, etc.)
51 | - toc:
52 | permalink: "¤" # Adds a clickable permalink to each section heading
53 | toc_depth: 4
54 |
55 | plugins:
56 | - search
57 | - autorefs
58 | - section-index
59 | - literate-nav:
60 | nav_file: SUMMARY.md
61 | - mknotebooks
62 | - mkdocstrings: # https://mkdocstrings.github.io
63 | handlers:
64 | python:
65 | setup_commands:
66 | - import pytkdocs_tweaks
67 | - pytkdocs_tweaks.main()
68 | - import jaxtyping
69 | - jaxtyping.set_array_name_format("array")
70 | options:
71 | docstring_style: google
72 | line_length: 100
73 | show_symbol_type_heading: true
74 | edit_uri: ""
75 | - htmlproofer:
76 | raise_error: True
77 |
--------------------------------------------------------------------------------
/sparse_autoencoder/__init__.py:
--------------------------------------------------------------------------------
1 | """Sparse Autoencoder Library."""
2 | from sparse_autoencoder.activation_resampler.activation_resampler import ActivationResampler
3 | from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore
4 | from sparse_autoencoder.autoencoder.model import SparseAutoencoder, SparseAutoencoderConfig
5 | from sparse_autoencoder.metrics.loss.l1_absolute_loss import L1AbsoluteLoss
6 | from sparse_autoencoder.metrics.loss.l2_reconstruction_loss import L2ReconstructionLoss
7 | from sparse_autoencoder.metrics.loss.sae_loss import SparseAutoencoderLoss
8 | from sparse_autoencoder.metrics.train.capacity import CapacityMetric
9 | from sparse_autoencoder.metrics.train.feature_density import FeatureDensityMetric
10 | from sparse_autoencoder.metrics.train.l0_norm import L0NormMetric
11 | from sparse_autoencoder.metrics.train.neuron_activity import NeuronActivityMetric
12 | from sparse_autoencoder.metrics.train.neuron_fired_count import NeuronFiredCountMetric
13 | from sparse_autoencoder.metrics.validate.reconstruction_score import ReconstructionScoreMetric
14 | from sparse_autoencoder.optimizer.adam_with_reset import AdamWithReset
15 | from sparse_autoencoder.source_data.pretokenized_dataset import PreTokenizedDataset
16 | from sparse_autoencoder.source_data.text_dataset import TextDataset
17 | from sparse_autoencoder.train.pipeline import Pipeline
18 | from sparse_autoencoder.train.sweep import sweep
19 | from sparse_autoencoder.train.sweep_config import (
20 | ActivationResamplerHyperparameters,
21 | AutoencoderHyperparameters,
22 | Hyperparameters,
23 | LossHyperparameters,
24 | OptimizerHyperparameters,
25 | PipelineHyperparameters,
26 | SourceDataHyperparameters,
27 | SourceModelHyperparameters,
28 | SourceModelRuntimeHyperparameters,
29 | SweepConfig,
30 | )
31 | from sparse_autoencoder.train.utils.wandb_sweep_types import (
32 | Controller,
33 | ControllerType,
34 | Distribution,
35 | Goal,
36 | HyperbandStopping,
37 | HyperbandStoppingType,
38 | Impute,
39 | ImputeWhileRunning,
40 | Kind,
41 | Method,
42 | Metric,
43 | NestedParameter,
44 | Parameter,
45 | )
46 |
47 |
48 | __all__ = [
49 | "ActivationResampler",
50 | "ActivationResamplerHyperparameters",
51 | "AdamWithReset",
52 | "AutoencoderHyperparameters",
53 | "CapacityMetric",
54 | "Controller",
55 | "ControllerType",
56 | "Distribution",
57 | "FeatureDensityMetric",
58 | "Goal",
59 | "HyperbandStopping",
60 | "HyperbandStoppingType",
61 | "Hyperparameters",
62 | "Impute",
63 | "ImputeWhileRunning",
64 | "Kind",
65 | "L0NormMetric",
66 | "L1AbsoluteLoss",
67 | "L2ReconstructionLoss",
68 | "LossHyperparameters",
69 | "Method",
70 | "Metric",
71 | "NestedParameter",
72 | "NeuronActivityMetric",
73 | "NeuronFiredCountMetric",
74 | "OptimizerHyperparameters",
75 | "Parameter",
76 | "Pipeline",
77 | "PipelineHyperparameters",
78 | "PreTokenizedDataset",
79 | "ReconstructionScoreMetric",
80 | "SourceDataHyperparameters",
81 | "SourceModelHyperparameters",
82 | "SourceModelRuntimeHyperparameters",
83 | "SparseAutoencoder",
84 | "SparseAutoencoderConfig",
85 | "SparseAutoencoderLoss",
86 | "sweep",
87 | "SweepConfig",
88 | "TensorActivationStore",
89 | "TextDataset",
90 | ]
91 |
--------------------------------------------------------------------------------
/sparse_autoencoder/activation_resampler/__init__.py:
--------------------------------------------------------------------------------
1 | """Activation Resampler."""
2 |
--------------------------------------------------------------------------------
/sparse_autoencoder/activation_resampler/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Activation resampler utils."""
2 |
--------------------------------------------------------------------------------
/sparse_autoencoder/activation_resampler/utils/component_slice_tensor.py:
--------------------------------------------------------------------------------
1 | """Component slice tensor utils."""
2 | from torch import Tensor
3 |
4 |
5 | def get_component_slice_tensor(
6 | input_tensor: Tensor,
7 | n_dim_with_component: int,
8 | component_dim: int,
9 | component_idx: int,
10 | ) -> Tensor:
11 | """Get a slice of a tensor for a specific component.
12 |
13 | Examples:
14 | >>> import torch
15 | >>> input_tensor = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
16 | >>> get_component_slice_tensor(input_tensor, 2, 1, 0)
17 | tensor([1, 3, 5, 7])
18 |
19 | >>> input_tensor = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
20 | >>> get_component_slice_tensor(input_tensor, 3, 1, 0)
21 | tensor([[1, 2],
22 | [3, 4],
23 | [5, 6],
24 | [7, 8]])
25 |
26 | Args:
27 | input_tensor: Input tensor.
28 | n_dim_with_component: Number of dimensions in the input tensor with the component axis.
29 | component_dim: Dimension of the component axis.
30 | component_idx: Index of the component to get the slice for.
31 |
32 | Returns:
33 | Tensor slice.
34 |
35 | Raises:
36 | ValueError: If the input tensor does not have the expected number of dimensions.
37 | """
38 | if n_dim_with_component - 1 == input_tensor.ndim:
39 | return input_tensor
40 |
41 | if n_dim_with_component != input_tensor.ndim:
42 | error_message = (
43 | f"Cannot get component slice for tensor with {input_tensor.ndim} dimensions "
44 | f"and {n_dim_with_component} dimensions with component."
45 | )
46 | raise ValueError(error_message)
47 |
48 | # Create a tuple of slices for each dimension
49 | slice_tuple = tuple(
50 | component_idx if i == component_dim else slice(None) for i in range(input_tensor.ndim)
51 | )
52 |
53 | return input_tensor[slice_tuple]
54 |
--------------------------------------------------------------------------------
/sparse_autoencoder/activation_store/__init__.py:
--------------------------------------------------------------------------------
1 | """Activation Stores."""
2 |
--------------------------------------------------------------------------------
/sparse_autoencoder/activation_store/base_store.py:
--------------------------------------------------------------------------------
1 | """Activation Store Base Class."""
2 | from abc import ABC, abstractmethod
3 | from concurrent.futures import Future
4 | from typing import final
5 |
6 | from jaxtyping import Float
7 | from pydantic import PositiveInt, validate_call
8 | import torch
9 | from torch import Tensor
10 | from torch.utils.data import Dataset
11 |
12 | from sparse_autoencoder.tensor_types import Axis
13 |
14 |
15 | class ActivationStore(
16 | Dataset[Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)]], ABC
17 | ):
18 | """Activation Store Abstract Class.
19 |
20 | Extends the `torch.utils.data.Dataset` class to provide an activation store, with additional
21 | :meth:`append` and :meth:`extend` methods (the latter of which should typically be
22 | non-blocking). The resulting activation store can be used with a `torch.utils.data.DataLoader`
23 | to iterate over the dataset.
24 |
25 | Extend this class if you want to create a new activation store (noting you also need to
26 | implement the `__getitem__` and `__len__` methods required by `torch.utils.data.Dataset`).
27 |
28 | Example:
29 | >>> import torch
30 | >>> class MyActivationStore(ActivationStore):
31 | ...
32 | ... @property
33 | ... def current_activations_stored_per_component(self):
34 | ... raise NotImplementedError
35 | ...
36 | ... @property
37 | ... def n_components(self):
38 | ... raise NotImplementedError
39 | ...
40 | ... def __init__(self):
41 | ... super().__init__()
42 | ... self._data = [] # In this example, we just store in a list
43 | ...
44 | ... def append(self, item) -> None:
45 | ... self._data.append(item)
46 | ...
47 | ... def extend(self, batch):
48 | ... self._data.extend(batch)
49 | ...
50 | ... def empty(self):
51 | ... self._data = []
52 | ...
53 | ... def __getitem__(self, index: int):
54 | ... return self._data[index]
55 | ...
56 | ... def __len__(self) -> int:
57 | ... return len(self._data)
58 | ...
59 | >>> store = MyActivationStore()
60 | >>> store.append(torch.randn(100))
61 | >>> print(len(store))
62 | 1
63 | """
64 |
65 | @abstractmethod
66 | def append(
67 | self,
68 | item: Float[Tensor, Axis.names(Axis.INPUT_OUTPUT_FEATURE)],
69 | component_idx: int,
70 | ) -> Future | None:
71 | """Add a Single Item to the Store."""
72 |
73 | @abstractmethod
74 | def extend(
75 | self,
76 | batch: Float[Tensor, Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE)],
77 | component_idx: int,
78 | ) -> Future | None:
79 | """Add a Batch to the Store."""
80 |
81 | @abstractmethod
82 | def empty(self) -> None:
83 | """Empty the Store."""
84 |
85 | @property
86 | @abstractmethod
87 | def n_components(self) -> int:
88 | """Number of components."""
89 |
90 | @property
91 | @abstractmethod
92 | def current_activations_stored_per_component(self) -> list[int]:
93 | """Current activations stored per component."""
94 |
95 | @abstractmethod
96 | def __len__(self) -> int:
97 | """Get the Length of the Store."""
98 |
99 | @abstractmethod
100 | def __getitem__(
101 | self, index: tuple[int, ...] | slice | int
102 | ) -> Float[Tensor, Axis.names(Axis.ANY)]:
103 | """Get an Item from the Store."""
104 |
105 | def shuffle(self) -> None:
106 | """Optional shuffle method."""
107 |
108 | @final
109 | @validate_call
110 | def fill_with_test_data(
111 | self,
112 | n_batches: PositiveInt = 1,
113 | batch_size: PositiveInt = 16,
114 | n_components: PositiveInt = 1,
115 | input_features: PositiveInt = 256,
116 | ) -> None:
117 | """Fill the store with test data.
118 |
119 | For use when testing your code, to ensure it works with a real activation store.
120 |
121 | Warning:
122 | You may want to use `torch.manual_seed(0)` to make the random data deterministic, if your test
123 | requires inspecting the data itself.
124 |
125 | Example:
126 | >>> from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore
127 | >>> store = TensorActivationStore(max_items=100, n_neurons=256, n_components=1)
128 | >>> store.fill_with_test_data(batch_size=100)
129 | >>> len(store)
130 | 100
131 |
132 | Args:
133 | n_batches: Number of batches to fill the store with.
134 | batch_size: Number of items per batch.
135 | n_components: Number of source model components the SAE is trained on.
136 | input_features: Number of input features per item.
137 | """
138 | for _ in range(n_batches):
139 | for component_idx in range(n_components):
140 | sample = torch.rand(batch_size, input_features)
141 | self.extend(sample, component_idx)
142 |
143 |
144 | class StoreFullError(IndexError):
145 | """Exception raised when the activation store is full."""
146 |
147 | def __init__(self, message: str = "Activation store is full"):
148 | """Initialise the exception.
149 |
150 | Args:
151 | message: Override the default message.
152 | """
153 | super().__init__(message)
154 |
--------------------------------------------------------------------------------
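Note: as the docstring above says, a concrete store plugs straight into a `torch.utils.data.DataLoader`. A minimal sketch using the `TensorActivationStore` from `tensor_store.py` (the sizes and batch size here are arbitrary):

    from torch.utils.data import DataLoader

    from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore

    store = TensorActivationStore(max_items=64, n_neurons=256, n_components=1)
    store.fill_with_test_data(n_batches=4, batch_size=16, n_components=1, input_features=256)

    # The store is a torch Dataset, so it can be iterated with a DataLoader.
    dataloader = DataLoader(store, batch_size=8)
    for activations_batch in dataloader:
        ...  # Train on the batch of cached activations.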
/sparse_autoencoder/activation_store/tests/test_tensor_store.py:
--------------------------------------------------------------------------------
1 | """Tests for the TensorActivationStore."""
2 | import pytest
3 | import torch
4 |
5 | from sparse_autoencoder.activation_store.base_store import StoreFullError
6 | from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore
7 |
8 |
9 | def test_works_with_2_components() -> None:
10 | """Test that it works with 2 components."""
11 | n_neurons: int = 128
12 | n_batches: int = 10
13 | batch_size: int = 16
14 | n_components: int = 2
15 |
16 | store = TensorActivationStore(
17 | max_items=int(n_batches * batch_size), n_neurons=n_neurons, n_components=n_components
18 | )
19 |
20 | # Fill the store
21 | for component_idx in range(n_components):
22 | for _ in range(n_batches):
23 | store.extend(torch.rand(batch_size, n_neurons), component_idx=component_idx)
24 |
25 | # Check the size
26 | assert len(store) == int(n_batches * batch_size)
27 |
28 |
29 | def test_appending_more_than_max_items_raises() -> None:
30 | """Test that appending more than the max items raises an error."""
31 | n_neurons: int = 128
32 | store = TensorActivationStore(max_items=1, n_neurons=n_neurons, n_components=1)
33 | store.append(torch.rand(n_neurons), component_idx=0)
34 |
35 | with pytest.raises(StoreFullError):
36 | store.append(torch.rand(n_neurons), component_idx=0)
37 |
38 |
39 | def test_extending_more_than_max_items_raises() -> None:
40 | """Test that extending more than the max items raises an error."""
41 | n_neurons: int = 128
42 | store = TensorActivationStore(max_items=6, n_neurons=n_neurons, n_components=1)
43 | store.extend(torch.rand(4, n_neurons), component_idx=0)
44 |
45 | with pytest.raises(StoreFullError):
46 | store.extend(torch.rand(4, n_neurons), component_idx=0)
47 |
48 |
49 | def test_getting_out_of_range_raises() -> None:
50 | """Test that getting an out of range index raises an error."""
51 | n_neurons: int = 128
52 | store = TensorActivationStore(max_items=1, n_neurons=n_neurons, n_components=1)
53 | store.append(torch.rand(n_neurons), component_idx=0)
54 |
55 | with pytest.raises(IndexError):
56 | store[1, 0]
57 |
--------------------------------------------------------------------------------
/sparse_autoencoder/autoencoder/__init__.py:
--------------------------------------------------------------------------------
1 | """Sparse autoencoder model & components."""
2 |
--------------------------------------------------------------------------------
/sparse_autoencoder/autoencoder/components/__init__.py:
--------------------------------------------------------------------------------
1 | """Sparse autoencoder components."""
2 |
--------------------------------------------------------------------------------
/sparse_autoencoder/autoencoder/components/tests/__snapshots__/test_linear_encoder.ambr:
--------------------------------------------------------------------------------
1 | # serializer version: 1
2 | # name: test_extra_repr
3 | '''
4 | LinearEncoder(
5 | input_features=2, learnt_features=4, n_components=2
6 | (activation_function): ReLU()
7 | )
8 | '''
9 | # ---
10 | # name: test_forward_pass_result_matches_the_snapshot[1]
11 | tensor([[[0.0000, 0.0166, 0.2664, 0.1434]],
12 |
13 | [[0.2154, 0.0000, 0.2007, 0.0000]],
14 |
15 | [[0.3107, 0.0000, 0.1137, 0.0000]]], grad_fn=<ReluBackward0>)
16 | # ---
17 | # name: test_forward_pass_result_matches_the_snapshot[3]
18 | tensor([[[1.0238, 0.0000, 0.0000, 0.0583],
19 | [0.2885, 0.2002, 0.5567, 0.0000],
20 | [0.0000, 0.0000, 0.0000, 0.0000]],
21 |
22 | [[0.8789, 0.0000, 0.0000, 0.1781],
23 | [0.2351, 0.0000, 0.3539, 0.0000],
24 | [0.0000, 0.1716, 0.2969, 0.2289]],
25 |
26 | [[0.9740, 0.0000, 0.0000, 0.0802],
27 | [0.1881, 0.0000, 0.3415, 0.0000],
28 | [0.0000, 0.1251, 0.2831, 0.1824]]], grad_fn=<ReluBackward0>)
29 | # ---
30 | # name: test_forward_pass_result_matches_the_snapshot[None]
31 | tensor([[0.0000, 0.0595, 0.4903, 0.2879],
32 | [0.4522, 0.0000, 0.3589, 0.0000],
33 | [0.6428, 0.0000, 0.1849, 0.0000]], grad_fn=<ReluBackward0>)
34 | # ---
35 |
--------------------------------------------------------------------------------
/sparse_autoencoder/autoencoder/components/tests/test_linear_encoder.py:
--------------------------------------------------------------------------------
1 | """Linear encoder tests."""
2 | from jaxtyping import Float, Int64
3 | import pytest
4 | from syrupy.session import SnapshotSession
5 | import torch
6 | from torch import Tensor
7 |
8 | from sparse_autoencoder.autoencoder.components.linear_encoder import LinearEncoder
9 | from sparse_autoencoder.tensor_types import Axis
10 | from sparse_autoencoder.utils.tensor_shape import shape_with_optional_dimensions
11 |
12 |
13 | # Constants for testing
14 | INPUT_FEATURES = 2
15 | LEARNT_FEATURES = 4
16 | N_COMPONENTS = 2
17 | BATCH_SIZE = 3
18 |
19 |
20 | @pytest.fixture()
21 | def encoder() -> LinearEncoder:
22 | """Fixture to create a LinearEncoder instance."""
23 | torch.manual_seed(0)
24 | return LinearEncoder(
25 | input_features=INPUT_FEATURES, learnt_features=LEARNT_FEATURES, n_components=N_COMPONENTS
26 | )
27 |
28 |
29 | def test_reset_parameters(encoder: LinearEncoder) -> None:
30 | """Test resetting of parameters."""
31 | old_weight = encoder.weight.clone()
32 | old_bias = encoder.bias.clone()
33 | encoder.reset_parameters()
34 | assert not torch.equal(encoder.weight, old_weight)
35 | assert not torch.equal(encoder.bias, old_bias)
36 |
37 |
38 | def test_forward_pass(encoder: LinearEncoder) -> None:
39 | """Test the forward pass of the LinearEncoder."""
40 | input_tensor = torch.randn(BATCH_SIZE, N_COMPONENTS, INPUT_FEATURES)
41 | output = encoder.forward(input_tensor)
42 | assert output.shape == (BATCH_SIZE, N_COMPONENTS, LEARNT_FEATURES)
43 |
44 |
45 | def test_extra_repr(encoder: LinearEncoder, snapshot: SnapshotSession) -> None:
46 | """Test the string representation of the LinearEncoder."""
47 | assert snapshot == str(encoder), "Model string representation has changed."
48 |
49 |
50 | @pytest.mark.parametrize("n_components", [None, 1, 3])
51 | def test_forward_pass_result_matches_the_snapshot(
52 | n_components: int | None, snapshot: SnapshotSession
53 | ) -> None:
54 | """Test the forward pass of the LinearEncoder."""
55 | torch.manual_seed(1)
56 | input_tensor = torch.rand(
57 | shape_with_optional_dimensions(BATCH_SIZE, n_components, INPUT_FEATURES)
58 | )
59 | encoder = LinearEncoder(
60 | input_features=INPUT_FEATURES, learnt_features=LEARNT_FEATURES, n_components=n_components
61 | )
62 | output = encoder.forward(input_tensor)
63 | assert snapshot == output
64 |
65 |
66 | def test_output_same_without_component_dim_vs_with_1_component() -> None:
67 | """Test the forward pass gives identical results for None and 1 component."""
68 | # Create the layers to compare
69 | encoder_without_components_dim = LinearEncoder(
70 | input_features=INPUT_FEATURES, learnt_features=LEARNT_FEATURES, n_components=None
71 | )
72 | encoder_with_1_component = LinearEncoder(
73 | input_features=INPUT_FEATURES, learnt_features=LEARNT_FEATURES, n_components=1
74 | )
75 |
76 | # Set the weight and bias parameters to be the same
77 | encoder_with_1_component.weight = torch.nn.Parameter(
78 | encoder_without_components_dim.weight.unsqueeze(0)
79 | )
80 | encoder_with_1_component.bias = torch.nn.Parameter(
81 | encoder_without_components_dim.bias.unsqueeze(0)
82 | )
83 |
84 | # Create the input
85 | input_tensor = torch.rand(BATCH_SIZE, INPUT_FEATURES)
86 | input_with_components_dim = input_tensor.unsqueeze(1)
87 |
88 | # Check the output is the same
89 | output_without_components_dim = encoder_without_components_dim(input_tensor)
90 | output_with_1_component = encoder_with_1_component(input_with_components_dim)
91 |
92 | assert torch.allclose(output_without_components_dim, output_with_1_component.squeeze(1))
93 |
94 |
95 | def test_update_dictionary_vectors_with_no_neurons(encoder: LinearEncoder) -> None:
96 | """Test update_dictionary_vectors with 0 neurons to update."""
97 | torch.random.manual_seed(0)
98 | original_weight = encoder.weight.clone() # Save original weight for comparison
99 |
100 | dictionary_vector_indices: Int64[Tensor, Axis.names(Axis.LEARNT_FEATURE_IDX)] = torch.empty(
101 | 0, dtype=torch.int64
102 | )
103 |
104 | updates: Float[
105 | Tensor, Axis.names(Axis.LEARNT_FEATURE_IDX, Axis.INPUT_OUTPUT_FEATURE)
106 | ] = torch.empty((0, INPUT_FEATURES), dtype=torch.float)
107 |
108 | encoder.update_dictionary_vectors(dictionary_vector_indices, updates, component_idx=0)
109 |
110 | # Ensure weight did not change when no indices were provided
111 | assert torch.equal(
112 | encoder.weight, original_weight
113 | ), "Weights should not change when no indices are provided."
114 |
115 |
116 | @pytest.mark.parametrize(
117 | ("dictionary_vector_indices", "updates"),
118 | [
119 | pytest.param(torch.tensor([1]), torch.rand((1, INPUT_FEATURES)), id="update 1 neuron"),
120 | pytest.param(
121 | torch.tensor([0, 1]),
122 | torch.rand((2, INPUT_FEATURES)),
123 | id="update 2 neurons with different values",
124 | ),
125 | ],
126 | )
127 | def test_update_dictionary_vectors_with_neurons(
128 | encoder: LinearEncoder,
129 | dictionary_vector_indices: Int64[Tensor, Axis.names(Axis.LEARNT_FEATURE_IDX)],
130 | updates: Float[Tensor, Axis.names(Axis.LEARNT_FEATURE_IDX, Axis.INPUT_OUTPUT_FEATURE)],
131 | ) -> None:
132 | """Test update_dictionary_vectors with 1 or 2 neurons to update."""
133 | with torch.no_grad():
134 | component_idx = 0
135 | encoder.update_dictionary_vectors(
136 | dictionary_vector_indices, updates, component_idx=component_idx
137 | )
138 |
139 | # Check if the specified neurons are updated correctly
140 | assert torch.allclose(
141 | encoder.weight[component_idx, dictionary_vector_indices, :], updates
142 | ), "update_dictionary_vectors should update the weights correctly."
143 |
--------------------------------------------------------------------------------
/sparse_autoencoder/autoencoder/components/tests/test_tied_bias.py:
--------------------------------------------------------------------------------
1 | """Tied Bias Tests."""
2 | from jaxtyping import Float
3 | import torch
4 | from torch import Tensor
5 | from torch.nn import Parameter
6 |
7 | from sparse_autoencoder.autoencoder.components.tied_bias import TiedBias, TiedBiasPosition
8 | from sparse_autoencoder.tensor_types import Axis
9 |
10 |
11 | def test_pre_encoder_subtracts_bias() -> None:
12 | """Check that the pre-encoder bias subtracts the bias."""
13 | encoder_input: Float[Tensor, Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE)] = torch.tensor(
14 | [[5.0, 3.0, 1.0]]
15 | )
16 | bias = Parameter(torch.tensor([2.0, 4.0, 6.0]))
17 | expected = encoder_input - bias
18 |
19 | pre_encoder = TiedBias(bias, TiedBiasPosition.PRE_ENCODER)
20 | output = pre_encoder(encoder_input)
21 |
22 | assert torch.allclose(output, expected)
23 |
24 |
25 | def test_post_encoder_adds_bias() -> None:
26 | """Check that the post-encoder bias adds the bias."""
27 | decoder_output: Float[Tensor, Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE)] = torch.tensor(
28 | [[5.0, 3.0, 1.0]]
29 | )
30 | bias = Parameter(torch.tensor([2.0, 4.0, 6.0]))
31 | expected = decoder_output + bias
32 |
33 | post_decoder = TiedBias(bias, TiedBiasPosition.POST_DECODER)
34 | output = post_decoder(decoder_output)
35 |
36 | assert torch.allclose(output, expected)
37 |
--------------------------------------------------------------------------------
/sparse_autoencoder/autoencoder/components/tests/test_unit_norm_decoder.py:
--------------------------------------------------------------------------------
1 | """Tests for the constrained unit norm linear layer."""
2 | from jaxtyping import Float, Int64
3 | import pytest
4 | import torch
5 | from torch import Tensor
6 |
7 | from sparse_autoencoder.autoencoder.components.unit_norm_decoder import UnitNormDecoder
8 | from sparse_autoencoder.tensor_types import Axis
9 |
10 |
11 | DEFAULT_N_LEARNT_FEATURES = 3
12 | DEFAULT_N_DECODED_FEATURES = 4
13 | DEFAULT_N_COMPONENTS = 2
14 |
15 |
16 | @pytest.fixture()
17 | def decoder() -> UnitNormDecoder:
18 | """Pytest fixture to provide a MockDecoder instance."""
19 | return UnitNormDecoder(
20 | learnt_features=DEFAULT_N_LEARNT_FEATURES,
21 | decoded_features=DEFAULT_N_DECODED_FEATURES,
22 | n_components=DEFAULT_N_COMPONENTS,
23 | )
24 |
25 |
26 | def test_initialization() -> None:
27 | """Test that the weights are initialized with unit norm."""
28 | layer = UnitNormDecoder(learnt_features=3, decoded_features=4, n_components=None)
29 | weight_norms = torch.norm(layer.weight, dim=0)
30 | assert torch.allclose(weight_norms, torch.ones_like(weight_norms))
31 |
32 |
33 | def test_forward_pass() -> None:
34 | """Test the forward pass of the layer."""
35 | layer = UnitNormDecoder(learnt_features=3, decoded_features=4, n_components=None)
36 | input_tensor = torch.randn(5, 3) # Batch size of 5, learnt features of 3
37 | output = layer(input_tensor)
38 | assert output.shape == (5, 4) # Batch size of 5, decoded features of 4
39 |
40 |
41 | def test_multiple_training_steps() -> None:
42 | """Test the unit norm constraint over multiple training iterations."""
43 | torch.random.manual_seed(0)
44 | layer = UnitNormDecoder(learnt_features=3, decoded_features=4, n_components=None)
45 | optimizer = torch.optim.Adam(layer.parameters(), lr=0.01)
46 | for _ in range(4):
47 | data = torch.randn(5, 3)
48 | optimizer.zero_grad()
49 | logits = layer(data)
50 |
51 | loss = torch.mean(logits**2)
52 | loss.backward()
53 | optimizer.step()
54 | layer.constrain_weights_unit_norm()
55 |
56 | columns_norms = torch.norm(layer.weight, dim=0)
57 | assert torch.allclose(columns_norms, torch.ones_like(columns_norms))
58 |
59 |
60 | def test_unit_norm_decreases() -> None:
61 | """Check that the unit norm is applied after each gradient step."""
62 | for _ in range(4):
63 | data = torch.randn((1, 3), requires_grad=True)
64 |
65 | # run with the backward hook
66 | layer = UnitNormDecoder(learnt_features=3, decoded_features=4, n_components=None)
67 | layer_weights = torch.nn.Parameter(layer.weight.clone())
68 | optimizer = torch.optim.SGD(layer.parameters(), lr=0.1, momentum=0)
69 | logits = layer(data)
70 | loss = torch.mean(logits**2)
71 | loss.backward()
72 | optimizer.step()
73 | weight_norms_with_hook = torch.sum(layer.weight**2, dim=0).clone()
74 |
75 | # Run without the hook
76 | layer_without_hook = UnitNormDecoder(
77 | learnt_features=3, decoded_features=4, n_components=None, enable_gradient_hook=False
78 | )
79 | layer_without_hook.weight = layer_weights
80 | optimizer_without_hook = torch.optim.SGD(
81 | layer_without_hook.parameters(), lr=0.1, momentum=0
82 | )
83 | logits_without_hook = layer_without_hook(data)
84 | loss_without_hook = torch.mean(logits_without_hook**2)
85 | loss_without_hook.backward()
86 | optimizer_without_hook.step()
87 | weight_norms_without_hook = torch.sum(layer_without_hook.weight**2, dim=0).clone()
88 |
89 | # Check that the norm with the hook is closer to 1 than without the hook
90 | target_norms = torch.ones_like(weight_norms_with_hook)
91 | absolute_diff_with_hook = torch.abs(weight_norms_with_hook - target_norms)
92 | absolute_diff_without_hook = torch.abs(weight_norms_without_hook - target_norms)
93 | assert torch.all(absolute_diff_with_hook < absolute_diff_without_hook)
94 |
95 |
96 | def test_output_same_without_component_dim_vs_with_1_component() -> None:
97 | """Test the forward pass gives identical results for None and 1 component."""
98 | decoded_features = 2
99 | learnt_features = 4
100 | batch_size = 1
101 |
102 | # Create the layers to compare
103 | torch.manual_seed(1)
104 | decoder_without_components_dim = UnitNormDecoder(
105 | decoded_features=decoded_features, learnt_features=learnt_features, n_components=None
106 | )
107 | torch.manual_seed(1)
108 | decoder_with_1_component = UnitNormDecoder(
109 | decoded_features=decoded_features, learnt_features=learnt_features, n_components=1
110 | )
111 |
112 | # Create the input
113 | input_tensor = torch.randn(batch_size, learnt_features)
114 | input_with_components_dim = input_tensor.unsqueeze(1)
115 |
116 | # Check the output is the same
117 | output_without_components_dim = decoder_without_components_dim(input_tensor)
118 | output_with_1_component = decoder_with_1_component(input_with_components_dim)
119 | assert torch.allclose(output_without_components_dim, output_with_1_component.squeeze(1))
120 |
121 |
122 | def test_update_dictionary_vectors_with_no_neurons(decoder: UnitNormDecoder) -> None:
123 | """Test update_dictionary_vectors with 0 neurons to update."""
124 | original_weight = decoder.weight.clone() # Save original weight for comparison
125 |
126 | dictionary_vector_indices: Int64[
127 | Tensor, Axis.names(Axis.COMPONENT, Axis.INPUT_OUTPUT_FEATURE)
128 | ] = torch.empty((0, 0), dtype=torch.int64)
129 |
130 | updates: Float[
131 | Tensor, Axis.names(Axis.COMPONENT, Axis.INPUT_OUTPUT_FEATURE, Axis.DEAD_FEATURE)
132 | ] = torch.empty((0, 0, 0), dtype=torch.float)
133 |
134 | decoder.update_dictionary_vectors(dictionary_vector_indices, updates)
135 |
136 | # Ensure weight did not change when no indices were provided
137 | assert torch.equal(
138 | decoder.weight, original_weight
139 | ), "Weights should not change when no indices are provided."
140 |
141 |
142 | @pytest.mark.parametrize(
143 | ("dictionary_vector_indices", "updates"),
144 | [
145 | pytest.param(torch.tensor([1]), torch.rand(4, 1), id="One neuron to update"),
146 | pytest.param(
147 | torch.tensor([0, 2]),
148 | torch.rand(4, 2),
149 | id="Two neurons to update",
150 | ),
151 | ],
152 | )
153 | def test_update_dictionary_vectors_with_neurons(
154 | decoder: UnitNormDecoder,
155 | dictionary_vector_indices: Int64[Tensor, Axis.INPUT_OUTPUT_FEATURE],
156 | updates: Float[Tensor, Axis.names(Axis.INPUT_OUTPUT_FEATURE, Axis.DEAD_FEATURE)],
157 | ) -> None:
158 | """Test update_dictionary_vectors with 1 or 2 neurons to update."""
159 | decoder.update_dictionary_vectors(dictionary_vector_indices, updates, component_idx=0)
160 |
161 | # Check if the specified neurons are updated correctly
162 | assert torch.allclose(
163 | decoder.weight[0, :, dictionary_vector_indices], updates
164 | ), "update_dictionary_vectors should update the weights correctly."
165 |
--------------------------------------------------------------------------------
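Note: these tests hinge on the decoder renormalising its dictionary vectors after each optimizer step. The core operation is a normalisation over the decoded-feature axis; a minimal sketch of the idea (not the library's implementation):

    import torch

    weight = torch.randn(4, 3)  # (decoded_features, learnt_features)

    # Rescale each dictionary vector (column) to unit L2 norm, as
    # constrain_weights_unit_norm does after a gradient step.
    with torch.no_grad():
        weight /= torch.norm(weight, dim=0, keepdim=True)

    assert torch.allclose(torch.norm(weight, dim=0), torch.ones(3))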
/sparse_autoencoder/autoencoder/components/tied_bias.py:
--------------------------------------------------------------------------------
1 | """Tied Biases (Pre-Encoder and Post-Decoder)."""
2 | from enum import Enum
3 | from typing import final
4 |
5 | from jaxtyping import Float
6 | from torch import Tensor
7 | from torch.nn import Module, Parameter
8 |
9 | from sparse_autoencoder.tensor_types import Axis
10 |
11 |
12 | class TiedBiasPosition(str, Enum):
13 | """Tied Bias Position."""
14 |
15 | PRE_ENCODER = "pre_encoder"
16 | POST_DECODER = "post_decoder"
17 |
18 |
19 | @final
20 | class TiedBias(Module):
21 | """Tied Bias Layer.
22 |
23 | The tied pre-encoder bias is a learned bias term that is subtracted from the input before
24 | encoding, and added back after decoding.
25 |
26 | The bias parameter must be initialised in the parent module, and then passed to this layer.
27 |
28 | https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder-bias
29 | """
30 |
31 | _bias_position: TiedBiasPosition
32 |
33 | _bias_reference: Float[
34 | Parameter, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
35 | ]
36 |
37 | @property
38 | def bias(
39 | self,
40 | ) -> Float[Parameter, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)]:
41 | """Bias."""
42 | return self._bias_reference
43 |
44 | def __init__(
45 | self,
46 | bias_reference: Float[
47 | Parameter, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
48 | ],
49 | position: TiedBiasPosition,
50 | ) -> None:
51 | """Initialize the bias layer.
52 |
53 | Args:
54 | bias_reference: Tied bias parameter (initialised in the parent module), used for both
55 | the pre-encoder and post-decoder bias. The original paper initialised this using the
56 | geometric median of the dataset.
57 | position: Whether this is the pre-encoder or post-decoder bias.
58 | """
59 | super().__init__()
60 |
61 | self._bias_reference = bias_reference
62 |
63 | # Support string literals as well as enums
64 | self._bias_position = position
65 |
66 | def forward(
67 | self,
68 | x: Float[
69 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
70 | ],
71 | ) -> Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)]:
72 | """Forward Pass.
73 |
74 | Args:
75 | x: Input tensor.
76 |
77 | Returns:
78 | Output of the forward pass.
79 | """
80 | # If this is the pre-encoder bias, we subtract the bias from the input.
81 | if self._bias_position == TiedBiasPosition.PRE_ENCODER:
82 | return x - self.bias
83 |
84 | # If it's the post-decoder bias, we add the bias back.
85 | return x + self.bias
86 |
87 | def extra_repr(self) -> str:
88 | """String extra representation of the module."""
89 | return f"position={self._bias_position.value}"
90 |
--------------------------------------------------------------------------------
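Note: because the pre-encoder position subtracts the same parameter that the post-decoder position adds, composing the two layers returns the input unchanged. A quick sketch:

    import torch
    from torch.nn import Parameter

    from sparse_autoencoder.autoencoder.components.tied_bias import TiedBias, TiedBiasPosition

    bias = Parameter(torch.tensor([2.0, 4.0, 6.0]))
    pre_encoder = TiedBias(bias, TiedBiasPosition.PRE_ENCODER)
    post_decoder = TiedBias(bias, TiedBiasPosition.POST_DECODER)

    x = torch.tensor([[5.0, 3.0, 1.0]])
    # (x - bias) + bias == x, since both layers reference the same parameter.
    assert torch.allclose(post_decoder(pre_encoder(x)), x)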
/sparse_autoencoder/autoencoder/tests/__snapshots__/test_model.ambr:
--------------------------------------------------------------------------------
1 | # serializer version: 1
2 | # name: test_representation
3 | '''
4 | SparseAutoencoder(
5 | (pre_encoder_bias): TiedBias(position=pre_encoder)
6 | (encoder): LinearEncoder(
7 | input_features=3, learnt_features=6, n_components=None
8 | (activation_function): ReLU()
9 | )
10 | (decoder): UnitNormDecoder(learnt_features=6, decoded_features=3, n_components=None)
11 | (post_decoder_bias): TiedBias(position=post_decoder)
12 | )
13 | '''
14 | # ---
15 |
--------------------------------------------------------------------------------
/sparse_autoencoder/autoencoder/types.py:
--------------------------------------------------------------------------------
1 | """Autoencoder types."""
2 | from typing import NamedTuple
3 |
4 | from torch.nn import Parameter
5 |
6 |
7 | class ResetOptimizerParameterDetails(NamedTuple):
8 | """Reset Optimizer Parameter Details.
9 |
10 | Details of a parameter that should be reset in the optimizer, when resetting
11 | its corresponding dictionary vectors.
12 | """
13 |
14 | parameter: Parameter
15 | """Parameter to reset."""
16 |
17 | axis: int
18 | """Axis of the parameter to reset."""
19 |
--------------------------------------------------------------------------------
/sparse_autoencoder/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | """Metrics.
2 |
3 | All metrics are based on torchmetrics, which means they support distributed training and can be
4 | combined with other metrics easily.
5 | """
6 |
--------------------------------------------------------------------------------
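Note: a sketch of what "combined with other metrics easily" means in practice. The loss metrics below all accept extra keyword arguments in `update`, so they can share a single `torchmetrics.MetricCollection` call (tensor shapes here are illustrative):

    import torch
    from torchmetrics import MetricCollection

    from sparse_autoencoder.metrics.loss.l1_absolute_loss import L1AbsoluteLoss
    from sparse_autoencoder.metrics.loss.l2_reconstruction_loss import L2ReconstructionLoss

    metrics = MetricCollection({
        "l1": L1AbsoluteLoss(num_components=1),
        "l2": L2ReconstructionLoss(num_components=1),
    })

    # Each metric picks out the keyword arguments it needs and ignores the rest.
    results = metrics.forward(
        learned_activations=torch.rand(2, 1, 4),
        source_activations=torch.rand(2, 1, 3),
        decoded_activations=torch.rand(2, 1, 3),
    )
    # results is a dict, e.g. {"l1": tensor([...]), "l2": tensor([...])}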
/sparse_autoencoder/metrics/loss/__init__.py:
--------------------------------------------------------------------------------
1 | """Loss metrics."""
2 |
--------------------------------------------------------------------------------
/sparse_autoencoder/metrics/loss/l1_absolute_loss.py:
--------------------------------------------------------------------------------
1 | """L1 (absolute error) loss."""
2 | from typing import Any
3 |
4 | from jaxtyping import Float, Int64
5 | from pydantic import PositiveInt, validate_call
6 | import torch
7 | from torch import Tensor
8 | from torchmetrics import Metric
9 |
10 | from sparse_autoencoder.tensor_types import Axis
11 |
12 |
13 | class L1AbsoluteLoss(Metric):
14 | """L1 (absolute error) loss.
15 |
16 | L1 loss penalty is the absolute sum of the learned activations, averaged over the number of
17 | activation vectors.
18 |
19 | Example:
20 | >>> l1_loss = L1AbsoluteLoss()
21 | >>> learned_activations = torch.tensor([
22 | ... [ # Batch 1
23 | ... [1., 0., 1.] # Component 1: learned features (L1 of 2)
24 | ... ],
25 | ... [ # Batch 2
26 | ... [0., 1., 0.] # Component 1: learned features (L1 of 1)
27 | ... ]
28 | ... ])
29 | >>> l1_loss.forward(learned_activations=learned_activations)
30 | tensor(1.5000)
31 | """
32 |
33 | # Torchmetrics settings
34 | is_differentiable: bool | None = True
35 | full_state_update: bool | None = False
36 | plot_lower_bound: float | None = 0.0
37 |
38 | # Settings
39 | _num_components: int
40 | _keep_batch_dim: bool
41 |
42 | @property
43 | def keep_batch_dim(self) -> bool:
44 | """Whether to keep the batch dimension in the loss output."""
45 | return self._keep_batch_dim
46 |
47 | @keep_batch_dim.setter
48 | def keep_batch_dim(self, keep_batch_dim: bool) -> None:
49 | """Set whether to keep the batch dimension in the loss output.
50 |
51 | When setting this we need to change the state to either a list if keeping the batch
52 | dimension (so we can accumulate all the losses and concatenate them at the end along this
53 | dimension). Alternatively it should be a tensor if not keeping the batch dimension (so we
54 | can sum the losses over the batch dimension during update and then take the mean).
55 |
56 | By doing this in a setter we allow changing of this setting after the metric is initialised.
57 | """
58 | self._keep_batch_dim = keep_batch_dim
59 | self.reset() # Reset the metric to update the state
60 | if keep_batch_dim and not isinstance(self.absolute_loss, list):
61 | self.add_state(
62 | "absolute_loss",
63 | default=[],
64 | dist_reduce_fx="sum",
65 | )
66 | elif not isinstance(self.absolute_loss, Tensor):
67 | self.add_state(
68 | "absolute_loss",
69 | default=torch.zeros(self._num_components),
70 | dist_reduce_fx="sum",
71 | )
72 |
73 | # State
74 | absolute_loss: Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL)] | list[
75 | Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL)]
76 | ] | None = None
77 | num_activation_vectors: Int64[Tensor, Axis.SINGLE_ITEM]
78 |
79 | @validate_call
80 | def __init__(
81 | self,
82 | num_components: PositiveInt = 1,
83 | *,
84 | keep_batch_dim: bool = False,
85 | ) -> None:
86 | """Initialize the metric.
87 |
88 | Args:
89 | num_components: Number of components.
90 | keep_batch_dim: Whether to keep the batch dimension in the loss output.
91 | """
92 | super().__init__()
93 | self._num_components = num_components
94 | self.keep_batch_dim = keep_batch_dim
95 | self.add_state(
96 | "num_activation_vectors",
97 | default=torch.tensor(0, dtype=torch.int64),
98 | dist_reduce_fx="sum",
99 | )
100 |
101 | @staticmethod
102 | def calculate_abs_sum(
103 | learned_activations: Float[
104 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
105 | ],
106 | ) -> Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL)]:
107 | """Calculate the absolute sum of the learned activations.
108 |
109 | Args:
110 | learned_activations: Learned activations (intermediate activations in the autoencoder).
111 |
112 | Returns:
113 | Absolute sum of the learned activations (keeping the batch and component axis).
114 | """
115 | return torch.abs(learned_activations).sum(dim=-1)
116 |
117 | def update(
118 | self,
119 | learned_activations: Float[
120 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
121 | ],
122 | **kwargs: Any, # type: ignore # noqa: ARG002, ANN401 (allows combining with other metrics)
123 | ) -> None:
124 | """Update the metric state.
125 |
126 | If we're keeping the batch dimension, we simply take the absolute sum of the activations
127 | (over the features dimension) and then append this tensor to a list. Then during compute we
128 | just concatenate and return this list. This is useful for e.g. getting L1 loss by batch item
129 | when resampling neurons (see the neuron resampler for details).
130 |
131 | By contrast if we're averaging over the batch dimension, we sum the activations over the
132 | batch dimension during update (on each process), and then divide by the number of activation
133 | vectors on compute to get the mean.
134 |
135 | Args:
136 | learned_activations: Learned activations (intermediate activations in the autoencoder).
137 | **kwargs: Ignored keyword arguments (to allow use with other metrics in a collection).
138 | """
139 | absolute_loss = self.calculate_abs_sum(learned_activations)
140 |
141 | if self.keep_batch_dim:
142 | self.absolute_loss.append(absolute_loss) # type: ignore
143 | else:
144 | self.absolute_loss += absolute_loss.sum(dim=0)
145 | self.num_activation_vectors += learned_activations.shape[0]
146 |
147 | def compute(self) -> Tensor:
148 | """Compute the metric."""
149 | return (
150 | torch.cat(self.absolute_loss) # type: ignore
151 | if self.keep_batch_dim
152 | else self.absolute_loss / self.num_activation_vectors
153 | )
154 |
--------------------------------------------------------------------------------
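Note: the `keep_batch_dim` behaviour described in the setter docstring, shown concretely (values taken from the class docstring example):

    import torch

    from sparse_autoencoder.metrics.loss.l1_absolute_loss import L1AbsoluteLoss

    learned_activations = torch.tensor([[[1.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]])

    # Default: sum over the batch during update, divide on compute -> mean L1 of 1.5.
    mean_loss = L1AbsoluteLoss(num_components=1)
    mean_loss.forward(learned_activations=learned_activations)

    # keep_batch_dim=True: one loss per batch item (2.0 and 1.0), as used by the
    # neuron resampler.
    per_item_loss = L1AbsoluteLoss(num_components=1, keep_batch_dim=True)
    per_item_loss.forward(learned_activations=learned_activations)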
/sparse_autoencoder/metrics/loss/l2_reconstruction_loss.py:
--------------------------------------------------------------------------------
1 | """L2 Reconstruction loss."""
2 | from typing import Any
3 |
4 | from jaxtyping import Float, Int64
5 | from pydantic import PositiveInt, validate_call
6 | import torch
7 | from torch import Tensor
8 | from torchmetrics import Metric
9 |
10 | from sparse_autoencoder.tensor_types import Axis
11 |
12 |
13 | class L2ReconstructionLoss(Metric):
14 | """L2 Reconstruction loss (MSE).
15 |
16 | L2 reconstruction loss is calculated as the mean squared error between each input vector
17 | and its corresponding decoded vector. The original paper found that models trained with some
18 | loss functions such as cross-entropy loss generally prefer to represent features
19 | polysemantically, whereas models trained with L2 may achieve the same loss for both
20 | polysemantic and monosemantic representations of true features.
21 |
22 | Example:
23 | >>> import torch
24 | >>> loss = L2ReconstructionLoss(num_components=1)
25 | >>> source_activations = torch.tensor([
26 | ... [ # Batch 1
27 | ... [4., 2.] # Component 1
28 | ... ],
29 | ... [ # Batch 2
30 | ... [2., 0.] # Component 1
31 | ... ]
32 | ... ])
33 | >>> decoded_activations = torch.tensor([
34 | ... [ # Batch 1
35 | ... [2., 0.] # Component 1 (MSE of 4)
36 | ... ],
37 | ... [ # Batch 2
38 | ... [0., 0.] # Component 1 (MSE of 2)
39 | ... ]
40 | ... ])
41 | >>> loss.forward(
42 | ... decoded_activations=decoded_activations, source_activations=source_activations
43 | ... )
44 | tensor(3.)
45 | """
46 |
47 | # Torchmetrics settings
48 | is_differentiable: bool | None = True
49 | higher_is_better = False
50 | full_state_update: bool | None = False
51 | plot_lower_bound: float | None = 0.0
52 |
53 | # Settings
54 | _num_components: int
55 | _keep_batch_dim: bool
56 |
57 | @property
58 | def keep_batch_dim(self) -> bool:
59 | """Whether to keep the batch dimension in the loss output."""
60 | return self._keep_batch_dim
61 |
62 | @keep_batch_dim.setter
63 | def keep_batch_dim(self, keep_batch_dim: bool) -> None:
64 | """Set whether to keep the batch dimension in the loss output.
65 |
66 | When setting this we need to change the state to either a list if keeping the batch
67 | dimension (so we can accumulate all the losses and concatenate them at the end along this
68 | dimension). Alternatively it should be a tensor if not keeping the batch dimension (so we
69 | can sum the losses over the batch dimension during update and then take the mean).
70 |
71 | By doing this in a setter we allow changing of this setting after the metric is initialised.
72 | """
73 | self._keep_batch_dim = keep_batch_dim
74 | self.reset() # Reset the metric to update the state
75 | if keep_batch_dim and not isinstance(self.mse, list):
76 | self.add_state(
77 | "mse",
78 | default=[],
79 | dist_reduce_fx="sum",
80 | )
81 | elif not isinstance(self.mse, Tensor):
82 | self.add_state(
83 | "mse",
84 | default=torch.zeros(self._num_components),
85 | dist_reduce_fx="sum",
86 | )
87 |
88 | # State
89 | mse: Float[Tensor, Axis.COMPONENT_OPTIONAL] | list[
90 | Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL)]
91 | ] | None = None
92 | num_activation_vectors: Int64[Tensor, Axis.SINGLE_ITEM]
93 |
94 | @validate_call
95 | def __init__(
96 | self,
97 | num_components: PositiveInt = 1,
98 | *,
99 | keep_batch_dim: bool = False,
100 | ) -> None:
101 | """Initialise the L2 reconstruction loss."""
102 | super().__init__()
103 | self._num_components = num_components
104 | self.keep_batch_dim = keep_batch_dim
105 | self.add_state(
106 | "num_activation_vectors",
107 | default=torch.tensor(0, dtype=torch.int64),
108 | dist_reduce_fx="sum",
109 | )
110 |
111 | @staticmethod
112 | def calculate_mse(
113 | decoded_activations: Float[
114 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
115 | ],
116 | source_activations: Float[
117 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
118 | ],
119 | ) -> Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL)]:
120 | """Calculate the MSE."""
121 | return (decoded_activations - source_activations).pow(2).mean(dim=-1)
122 |
123 | def update(
124 | self,
125 | decoded_activations: Float[
126 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
127 | ],
128 | source_activations: Float[
129 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
130 | ],
131 | **kwargs: Any, # type: ignore # noqa: ARG002, ANN401 (allows combining with other metrics)
132 | ) -> None:
133 | """Update the metric state.
134 |
135 | If we're keeping the batch dimension, we simply take the mse of the activations
136 | (over the features dimension) and then append this tensor to a list. Then during compute we
137 | just concatenate and return this list. This is useful for e.g. getting L2 loss by batch item
138 | when resampling neurons (see the neuron resampler for details).
139 |
140 | By contrast if we're averaging over the batch dimension, we sum the activations over the
141 | batch dimension during update (on each process), and then divide by the number of activation
142 | vectors on compute to get the mean.
143 |
144 | Args:
145 | decoded_activations: The decoded activations from the autoencoder.
146 | source_activations: The source activations from the autoencoder.
147 | **kwargs: Ignored keyword arguments (to allow use with other metrics in a collection).
148 | """
149 | mse = self.calculate_mse(decoded_activations, source_activations)
150 |
151 | if self.keep_batch_dim:
152 | self.mse.append(mse) # type: ignore
153 | else:
154 | self.mse += mse.sum(dim=0)
155 | self.num_activation_vectors += source_activations.shape[0]
156 |
157 | def compute(self) -> Float[Tensor, Axis.COMPONENT_OPTIONAL]:
158 | """Compute the metric."""
159 | return (
160 | torch.cat(self.mse) # type: ignore
161 | if self.keep_batch_dim
162 | else self.mse / self.num_activation_vectors
163 | )
164 |
--------------------------------------------------------------------------------
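Note: the class docstring example worked through by hand, verifying `calculate_mse` directly:

    import torch

    from sparse_autoencoder.metrics.loss.l2_reconstruction_loss import L2ReconstructionLoss

    source = torch.tensor([[[4.0, 2.0]], [[2.0, 0.0]]])
    decoded = torch.tensor([[[2.0, 0.0]], [[0.0, 0.0]]])

    # Per item: mean((4-2)^2, (2-0)^2) = 4 and mean((2-0)^2, 0^2) = 2.
    per_item_mse = L2ReconstructionLoss.calculate_mse(decoded, source)
    assert torch.allclose(per_item_mse, torch.tensor([[4.0], [2.0]]))

    # Averaging over the batch gives the 3.0 shown in the docstring.
    assert torch.allclose(per_item_mse.mean(dim=0), torch.tensor([3.0]))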
/sparse_autoencoder/metrics/loss/sae_loss.py:
--------------------------------------------------------------------------------
1 | """Sparse Autoencoder loss."""
2 | from typing import Any
3 |
4 | from jaxtyping import Float, Int64
5 | from pydantic import PositiveFloat, PositiveInt, validate_call
6 | import torch
7 | from torch import Tensor
8 | from torchmetrics import Metric
9 |
10 | from sparse_autoencoder.metrics.loss.l1_absolute_loss import L1AbsoluteLoss
11 | from sparse_autoencoder.metrics.loss.l2_reconstruction_loss import L2ReconstructionLoss
12 | from sparse_autoencoder.tensor_types import Axis
13 |
14 |
15 | class SparseAutoencoderLoss(Metric):
16 | """Sparse Autoencoder loss.
17 |
18 | This is the same as composing `L1AbsoluteLoss() * l1_coefficient + L2ReconstructionLoss()`. It
19 | is separated out so that you can use all three metrics (l1, l2, total loss) in the same
20 | `MetricCollection` and they will then share state (to avoid calculating the same thing twice).
21 | """
22 |
23 | # Torchmetrics settings
24 | is_differentiable: bool | None = True
25 | higher_is_better = False
26 | full_state_update: bool | None = False
27 | plot_lower_bound: float | None = 0.0
28 |
29 | # Settings
30 | _num_components: int
31 | _keep_batch_dim: bool
32 | _l1_coefficient: float
33 |
34 | @property
35 | def keep_batch_dim(self) -> bool:
36 | """Whether to keep the batch dimension in the loss output."""
37 | return self._keep_batch_dim
38 |
39 | @keep_batch_dim.setter
40 | def keep_batch_dim(self, keep_batch_dim: bool) -> None:
41 | """Set whether to keep the batch dimension in the loss output.
42 |
43 | When setting this we need to change the state to either a list if keeping the batch
44 | dimension (so we can accumulate all the losses and concatenate them at the end along this
45 | dimension). Alternatively it should be a tensor if not keeping the batch dimension (so we
46 | can sum the losses over the batch dimension during update and then take the mean).
47 |
48 | By doing this in a setter we allow changing of this setting after the metric is initialised.
49 | """
50 | self._keep_batch_dim = keep_batch_dim
51 | self.reset() # Reset the metric to update the state
52 | if keep_batch_dim and not isinstance(self.mse, list):
53 | self.add_state(
54 | "mse",
55 | default=[],
56 | dist_reduce_fx="sum",
57 | )
58 | self.add_state(
59 | "absolute_loss",
60 | default=[],
61 | dist_reduce_fx="sum",
62 | )
63 | elif not isinstance(self.mse, Tensor):
64 | self.add_state(
65 | "mse",
66 | default=torch.zeros(self._num_components),
67 | dist_reduce_fx="sum",
68 | )
69 | self.add_state(
70 | "absolute_loss",
71 | default=torch.zeros(self._num_components),
72 | dist_reduce_fx="sum",
73 | )
74 |
75 | # State
76 | absolute_loss: Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL)] | list[
77 | Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL)]
78 | ] | None = None
79 | mse: Float[Tensor, Axis.COMPONENT_OPTIONAL] | list[
80 | Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL)]
81 | ] | None = None
82 | num_activation_vectors: Int64[Tensor, Axis.SINGLE_ITEM]
83 |
84 | @validate_call
85 | def __init__(
86 | self,
87 | num_components: PositiveInt = 1,
88 | l1_coefficient: PositiveFloat = 0.001,
89 | *,
90 | keep_batch_dim: bool = False,
91 | ):
92 | """Initialise the metric."""
93 | super().__init__()
94 | self._num_components = num_components
95 | self.keep_batch_dim = keep_batch_dim
96 | self._l1_coefficient = l1_coefficient
97 |
98 | # Add the state
99 | self.add_state(
100 | "num_activation_vectors",
101 | default=torch.tensor(0, dtype=torch.int64),
102 | dist_reduce_fx="sum",
103 | )
104 |
105 | def update(
106 | self,
107 | source_activations: Float[
108 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
109 | ],
110 | learned_activations: Float[
111 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
112 | ],
113 | decoded_activations: Float[
114 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
115 | ],
116 | **kwargs: Any, # type: ignore # noqa: ARG002, ANN401 (allows combining with other metrics))
117 | ) -> None:
118 | """Update the metric."""
119 | absolute_loss = L1AbsoluteLoss.calculate_abs_sum(learned_activations)
120 | mse = L2ReconstructionLoss.calculate_mse(decoded_activations, source_activations)
121 |
122 | if self.keep_batch_dim:
123 | self.absolute_loss.append(absolute_loss) # type: ignore
124 | self.mse.append(mse) # type: ignore
125 | else:
126 | self.absolute_loss += absolute_loss.sum(dim=0)
127 | self.mse += mse.sum(dim=0)
128 | self.num_activation_vectors += learned_activations.shape[0]
129 |
130 | def compute(self) -> Tensor:
131 | """Compute the metric."""
132 | l1 = (
133 | torch.cat(self.absolute_loss) # type: ignore
134 | if self.keep_batch_dim
135 | else self.absolute_loss / self.num_activation_vectors
136 | )
137 |
138 | l2 = (
139 | torch.cat(self.mse) # type: ignore
140 | if self.keep_batch_dim
141 | else self.mse / self.num_activation_vectors
142 | )
143 |
144 | return l1 * self._l1_coefficient + l2
145 |
146 | def forward( # type: ignore[override] (narrowing)
147 | self,
148 | source_activations: Float[
149 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
150 | ],
151 | learned_activations: Float[
152 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
153 | ],
154 | decoded_activations: Float[
155 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
156 | ],
157 | ) -> Tensor:
158 | """Forward pass."""
159 | return super().forward(
160 | source_activations=source_activations,
161 | learned_activations=learned_activations,
162 | decoded_activations=decoded_activations,
163 | )
164 |
--------------------------------------------------------------------------------
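Note: a sketch of the shared-state usage the class docstring describes, with all three losses in one `MetricCollection` so they share state rather than computing the same quantities twice (tensor shapes are illustrative):

    import torch
    from torchmetrics import MetricCollection

    from sparse_autoencoder.metrics.loss.l1_absolute_loss import L1AbsoluteLoss
    from sparse_autoencoder.metrics.loss.l2_reconstruction_loss import L2ReconstructionLoss
    from sparse_autoencoder.metrics.loss.sae_loss import SparseAutoencoderLoss

    losses = MetricCollection({
        "l1": L1AbsoluteLoss(),
        "l2": L2ReconstructionLoss(),
        "total": SparseAutoencoderLoss(l1_coefficient=0.001),
    })

    results = losses.forward(
        source_activations=torch.rand(2, 1, 3),
        learned_activations=torch.rand(2, 1, 4),
        decoded_activations=torch.rand(2, 1, 3),
    )
    # results["total"] should equal results["l1"] * 0.001 + results["l2"]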
/sparse_autoencoder/metrics/loss/tests/test_l1_absolute_loss.py:
--------------------------------------------------------------------------------
1 | """Test the L1 absolute loss metric."""
2 | from jaxtyping import Float
3 | import pytest
4 | from torch import Tensor, allclose, ones, tensor, zeros
5 |
6 | from sparse_autoencoder.metrics.loss.l1_absolute_loss import L1AbsoluteLoss
7 | from sparse_autoencoder.tensor_types import Axis
8 |
9 |
10 | @pytest.mark.parametrize(
11 | # Each learned activations tensor is of the form (batch_size, num_components, num_features)
12 | ("learned_activations", "expected_loss"),
13 | [
14 | pytest.param(
15 | zeros(2, 3),
16 | tensor(0.0),
17 | id="All zero activations -> zero loss (single component)",
18 | ),
19 | pytest.param(
20 | zeros(2, 2, 3),
21 | tensor([0.0, 0.0]),
22 | id="All zero activations -> zero loss (2 components)",
23 | ),
24 | pytest.param(
25 | ones(2, 3), # 3 features -> 3.0 loss
26 | tensor(3.0),
27 | id="All 1.0 activations -> 3.0 loss (single component)",
28 | ),
29 | pytest.param(
30 | ones(2, 2, 3),
31 | tensor([3.0, 3.0]),
32 | id="All 1.0 activations -> 3.0 loss (2 components)",
33 | ),
34 | pytest.param(
35 | ones(2, 2, 3) * -1, # Loss is absolute so the same as +ve 1s
36 | tensor([3.0, 3.0]),
37 | id="All -ve 1.0 activations -> 3.0 loss (2 components)",
38 | ),
39 | ],
40 | )
41 | def test_l1_absolute_loss(
42 | learned_activations: Float[
43 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
44 | ],
45 | expected_loss: Float[Tensor, Axis.COMPONENT_OPTIONAL],
46 | ) -> None:
47 | """Test the L1 absolute loss."""
48 | num_components: int = learned_activations.shape[1] if learned_activations.ndim == 3 else 1 # noqa: PLR2004
49 | l1 = L1AbsoluteLoss(num_components)
50 |
51 | res = l1.forward(learned_activations=learned_activations)
52 |
53 | assert allclose(res, expected_loss)
54 |
--------------------------------------------------------------------------------
/sparse_autoencoder/metrics/loss/tests/test_l2_reconstruction_loss.py:
--------------------------------------------------------------------------------
1 | """Test the L2 reconstruction loss metric."""
2 | from jaxtyping import Float
3 | import pytest
4 | from torch import Tensor, allclose, ones, tensor, zeros
5 |
6 | from sparse_autoencoder.metrics.loss.l2_reconstruction_loss import L2ReconstructionLoss
7 | from sparse_autoencoder.tensor_types import Axis
8 |
9 |
10 | @pytest.mark.parametrize(
11 | # Each source/decoded tensor is of the form (batch_size, num_components, num_features)
12 | ("source_activations", "decoded_activations", "expected_loss"),
13 | [
14 | pytest.param(
15 | ones(2, 3),
16 | ones(2, 3),
17 | tensor(0.0),
18 | id="Perfect reconstruction -> zero loss (single component)",
19 | ),
20 | pytest.param(
21 | ones(2, 2, 3),
22 | ones(2, 2, 3),
23 | tensor([0.0, 0.0]),
24 | id="Perfect reconstruction -> zero loss (2 components)",
25 | ),
26 | pytest.param(
27 | zeros(2, 3),
28 | ones(2, 3),
29 | tensor(1.0),
30 | id="All errors 1.0 -> 1.0 loss (single component)",
31 | ),
32 | pytest.param(
33 | zeros(2, 2, 3),
34 | ones(2, 2, 3),
35 | tensor([1.0, 1.0]),
36 | id="All errors 1.0 -> 1.0 loss (2 components)",
37 | ),
38 | ],
39 | )
40 | def test_l2_reconstruction_loss(
41 | source_activations: Float[
42 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
43 | ],
44 | decoded_activations: Float[
45 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
46 | ],
47 | expected_loss: Float[Tensor, Axis.COMPONENT_OPTIONAL],
48 | ) -> None:
49 | """Test the L2 reconstruction loss."""
50 | num_components: int = source_activations.shape[1] if source_activations.ndim == 3 else 1 # noqa: PLR2004
51 | l2 = L2ReconstructionLoss(num_components)
52 |
53 | res = l2.forward(decoded_activations=decoded_activations, source_activations=source_activations)
54 |
55 | assert allclose(res, expected_loss)
56 |
--------------------------------------------------------------------------------
/sparse_autoencoder/metrics/loss/tests/test_sae_loss.py:
--------------------------------------------------------------------------------
1 | """Test the sparse autoencoder loss metric."""
2 | from jaxtyping import Float
3 | import pytest
4 | from torch import Tensor, allclose, ones, rand, tensor, zeros
5 |
6 | from sparse_autoencoder.metrics.loss.l1_absolute_loss import L1AbsoluteLoss
7 | from sparse_autoencoder.metrics.loss.l2_reconstruction_loss import L2ReconstructionLoss
8 | from sparse_autoencoder.metrics.loss.sae_loss import SparseAutoencoderLoss
9 | from sparse_autoencoder.tensor_types import Axis
10 |
11 |
12 | @pytest.mark.parametrize(
13 | # Each source/decoded tensor is of the form (batch_size, num_components, num_features)
14 | (
15 | "source_activations",
16 | "learned_activations",
17 | "decoded_activations",
18 | "l1_coefficient",
19 | "expected_loss",
20 | ),
21 | [
22 | pytest.param(
23 | ones(2, 3),
24 | zeros(2, 4), # Fully sparse = no activity
25 | ones(2, 3),
26 | 0.01,
27 | tensor(0.0),
28 | id="Perfect reconstruction & perfect sparsity -> zero loss (single component)",
29 | ),
30 | pytest.param(
31 | ones(2, 2, 3),
32 | zeros(2, 2, 4),
33 | ones(2, 2, 3),
34 | 0.01,
35 | tensor([0.0, 0.0]),
36 | id="Perfect reconstruction & perfect sparsity -> zero loss (2 components)",
37 | ),
38 | pytest.param(
39 | ones(2, 3),
40 | ones(2, 4), # L1 sum of 4.0 per activation vector, times 0.01 coefficient = 0.04 loss
41 | ones(2, 3),
42 | 0.01,
43 | tensor(0.04),
44 | id="Just sparsity error (single component)",
45 | ),
46 | pytest.param(
47 | ones(2, 2, 3),
48 | ones(2, 2, 4),
49 | ones(2, 2, 3),
50 | 0.01,
51 | tensor([0.04, 0.04]),
52 | id="Just sparsity error (2 components)",
53 | ),
54 | pytest.param(
55 | zeros(2, 3),
56 | zeros(2, 4),
57 | ones(2, 3),
58 | 0.01,
59 | tensor(1.0),
60 | id="Just reconstruction error (single component)",
61 | ),
62 | pytest.param(
63 | zeros(2, 2, 3),
64 | zeros(2, 2, 4),
65 | ones(2, 2, 3),
66 | 0.01,
67 | tensor([1.0, 1.0]),
68 | id="Just reconstruction error (2 components)",
69 | ),
70 | pytest.param(
71 | zeros(2, 3),
72 | ones(2, 4),
73 | ones(2, 3),
74 | 0.01,
75 | tensor(1.04),
76 | id="Sparsity and reconstruction error (single component)",
77 | ),
78 | pytest.param(
79 | zeros(2, 2, 3),
80 | ones(2, 2, 4),
81 | ones(2, 2, 3),
82 | 0.01,
83 | tensor([1.04, 1.04]),
84 | id="Sparsity and reconstruction error (2 components)",
85 | ),
86 | ],
87 | )
88 | def test_sae_loss(
89 | source_activations: Float[
90 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
91 | ],
92 | learned_activations: Float[
93 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
94 | ],
95 | decoded_activations: Float[
96 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
97 | ],
98 | l1_coefficient: float,
99 | expected_loss: Float[Tensor, Axis.COMPONENT_OPTIONAL],
100 | ) -> None:
101 | """Test the SAE loss."""
102 | num_components: int = source_activations.shape[1] if source_activations.ndim == 3 else 1 # noqa: PLR2004
103 | metric = SparseAutoencoderLoss(num_components, l1_coefficient)
104 |
105 | res = metric.forward(
106 | source_activations=source_activations,
107 | learned_activations=learned_activations,
108 | decoded_activations=decoded_activations,
109 | )
110 |
111 | assert allclose(res, expected_loss)
112 |
113 |
114 | def test_compare_sae_loss_to_composition() -> None:
115 | """Test the SAE loss metric against composition of l1 and l2."""
116 | num_components = 3
117 | l1_coefficient = 0.01
118 | l1 = L1AbsoluteLoss(num_components)
119 | l2 = L2ReconstructionLoss(num_components)
120 | composition_loss = l1 * l1_coefficient + l2
121 |
122 | sae_loss = SparseAutoencoderLoss(num_components, l1_coefficient)
123 |
124 | source_activations = rand(2, num_components, 3)
125 | learned_activations = rand(2, num_components, 4)
126 | decoded_activations = rand(2, num_components, 3)
127 |
128 | composition_res = composition_loss.forward(
129 | source_activations=source_activations,
130 | learned_activations=learned_activations,
131 | decoded_activations=decoded_activations,
132 | )
133 |
134 | sae_res = sae_loss.forward(
135 | source_activations=source_activations,
136 | learned_activations=learned_activations,
137 | decoded_activations=decoded_activations,
138 | )
139 |
140 | assert allclose(composition_res, sae_res)
141 |
--------------------------------------------------------------------------------
/sparse_autoencoder/metrics/train/__init__.py:
--------------------------------------------------------------------------------
1 | """Train step metrics."""
2 |
--------------------------------------------------------------------------------
/sparse_autoencoder/metrics/train/capacity.py:
--------------------------------------------------------------------------------
1 | """Capacity Metrics."""
2 | from typing import Any
3 |
4 | import einops
5 | from jaxtyping import Float
6 | import torch
7 | from torch import Tensor
8 | from torchmetrics import Metric
9 |
10 | from sparse_autoencoder.tensor_types import Axis
11 |
12 |
13 | class CapacityMetric(Metric):
14 | """Capacities metric.
15 |
16 | Measure the capacity of a set of features as defined in [Polysemanticity and Capacity in Neural
17 | Networks](https://arxiv.org/pdf/2210.01892.pdf).
18 |
19 | Capacity intuitively measures the 'proportion of a dimension' assigned to a feature.
20 | Formally it's the ratio of the squared dot product of a feature with itself to the sum of its
21 | squared dot products with all features (including itself).
22 |
23 | Warning:
24 | This is memory intensive as it requires caching all learned activations for a batch.
25 |
26 | Examples:
27 | If the features are orthogonal, the capacity is 1.
28 |
29 | >>> metric = CapacityMetric()
30 | >>> learned_activations = torch.tensor([
31 | ... [ # Batch 1
32 | ... [1., 0., 1.] # Component 1: learned features
33 | ... ],
34 | ... [ # Batch 2
35 | ... [0., 1., 0.] # Component 1: learned features (orthogonal)
36 | ... ]
37 | ... ])
38 | >>> metric.forward(learned_activations)
39 | tensor([[1., 1.]])
40 |
41 | If they are all the same, the capacity is 1/n.
42 |
43 | >>> learned_activations = torch.tensor([
44 | ... [ # Batch 1
45 | ... [1., 1., 1.] # Component 1: learned features
46 | ... ],
47 | ... [ # Batch 2
48 | ... [1., 1., 1.] # Component 1: learned features (same)
49 | ... ]
50 | ... ])
51 | >>> metric.forward(learned_activations)
52 | tensor([[0.5000, 0.5000]])
53 | """
54 |
55 | # Torchmetrics settings
56 | is_differentiable: bool | None = False
57 | full_state_update: bool | None = False
58 | plot_lower_bound: float | None = 0.0
59 | plot_upper_bound: float | None = 1.0
60 |
61 | # State
62 | learned_activations: list[
63 | Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)]
64 | ]
65 |
66 | def __init__(self) -> None:
67 | """Initialize the metric."""
68 | super().__init__()
69 | self.add_state("learned_activations", default=[])
70 |
71 | def update(
72 | self,
73 | learned_activations: Float[
74 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
75 | ],
76 | **kwargs: Any, # type: ignore # noqa: ARG002, ANN401 (allows combining with other metrics)
77 | ) -> None:
78 | """Update the metric state.
79 |
80 | Args:
81 | learned_activations: The learned activations.
82 | **kwargs: Ignored keyword arguments (to allow use with other metrics in a collection).
83 | """
84 | self.learned_activations.append(learned_activations)
85 |
86 | @staticmethod
87 | def capacities(
88 | features: Float[
89 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
90 | ],
91 | ) -> Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.BATCH)]:
92 | r"""Calculate capacities.
93 |
94 | Example:
95 | >>> import torch
96 | >>> orthogonal_features = torch.tensor([[[1., 0., 0.]], [[0., 1., 0.]], [[0., 0., 1.]]])
97 | >>> orthogonal_caps = CapacityMetric.capacities(orthogonal_features)
98 | >>> orthogonal_caps
99 | tensor([[1., 1., 1.]])
100 |
101 | Args:
102 | features: A collection of features.
103 |
104 | Returns:
105 |             A tensor of capacities of shape (batch,), with an optional leading component axis,
106 |             where each element is the capacity of the corresponding feature vector.
107 | """
108 | squared_dot_products: Float[
109 |             Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.BATCH, Axis.BATCH)
110 | ] = (
111 | einops.einsum(
112 | features,
113 | features,
114 | f"batch_1 ... {Axis.LEARNT_FEATURE}, \
115 | batch_2 ... {Axis.LEARNT_FEATURE} \
116 | -> ... batch_1 batch_2",
117 | )
118 | ** 2
119 | )
120 |
121 | sum_of_sq_dot: Float[
122 | Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.BATCH)
123 | ] = squared_dot_products.sum(dim=-1)
124 |
125 | diagonal: Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.BATCH)] = torch.diagonal(
126 |             squared_dot_products, dim1=-2, dim2=-1
127 | )
128 |
129 | return diagonal / sum_of_sq_dot
130 |
131 | def compute(
132 | self,
133 | ) -> Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.BATCH)]:
134 | """Compute the metric."""
135 | batch_learned_activations: Float[
136 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
137 | ] = torch.cat(self.learned_activations)
138 |
139 | return self.capacities(batch_learned_activations)
140 |
--------------------------------------------------------------------------------
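
Note: the capacity formula is easy to verify outside the class. For feature vector f_i in a batch,
capacity is (f_i . f_i)^2 / sum_j (f_i . f_j)^2, so mutually orthogonal features each get capacity
1. A minimal sketch, assuming a plain (batch, feature) tensor with no component axis:

import torch

# Three mutually orthogonal feature vectors
features = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])

# All pairwise dot products, squared: entry (i, j) is (f_i . f_j) ** 2
squared_dots = (features @ features.T) ** 2

# Capacity of item i: its squared self dot product over its row sum
capacities = squared_dots.diagonal() / squared_dots.sum(dim=-1)
print(capacities)  # tensor([1., 1., 1.])
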
/sparse_autoencoder/metrics/train/feature_density.py:
--------------------------------------------------------------------------------
1 | """Train batch feature density."""
2 | from typing import Any
3 |
4 | from jaxtyping import Bool, Float, Int64
5 | from pydantic import PositiveInt, validate_call
6 | import torch
7 | from torch import Tensor
8 | from torchmetrics import Metric
9 |
10 | from sparse_autoencoder.tensor_types import Axis
11 | from sparse_autoencoder.utils.tensor_shape import shape_with_optional_dimensions
12 |
13 |
14 | class FeatureDensityMetric(Metric):
15 | """Feature density metric.
16 |
17 |     Fraction of samples in which each feature was active (i.e. the neuron "fired"), in a
18 |     training batch.
19 |
20 |     Generally we want only a small number of features to be active for each sample, so the average
21 |     feature density should be low. By contrast, a high average feature density means that the
22 |     features are not sparse enough.
23 |
24 | Example:
25 | >>> metric = FeatureDensityMetric(num_learned_features=3, num_components=1)
26 | >>> learned_activations = torch.tensor([
27 | ... [ # Batch 1
28 | ... [1., 0., 1.] # Component 1: learned features (2 active neurons)
29 | ... ],
30 | ... [ # Batch 2
31 |         ...         [0., 0., 0.] # Component 1: learned features (0 active neurons)
32 | ... ]
33 | ... ])
34 | >>> metric.forward(learned_activations)
35 | tensor([[0.5000, 0.0000, 0.5000]])
36 | """
37 |
38 | # Torchmetrics settings
39 | is_differentiable: bool | None = False
40 | full_state_update: bool | None = True
41 | plot_lower_bound: float | None = 0.0
42 | plot_upper_bound: float | None = 1.0
43 |
44 | # State
45 | neuron_fired_count: Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)]
46 | num_activation_vectors: Int64[Tensor, Axis.SINGLE_ITEM]
47 |
48 | @validate_call
49 | def __init__(
50 | self, num_learned_features: PositiveInt, num_components: PositiveInt | None = None
51 | ) -> None:
52 | """Initialise the metric."""
53 | super().__init__()
54 |
55 | self.add_state(
56 | "neuron_fired_count",
57 | default=torch.zeros(
58 | size=shape_with_optional_dimensions(num_components, num_learned_features),
59 | dtype=torch.float, # Float is needed for dist reduce to work
60 | ),
61 | dist_reduce_fx="sum",
62 | )
63 |
64 | self.add_state(
65 | "num_activation_vectors",
66 | default=torch.tensor(0, dtype=torch.int64),
67 | dist_reduce_fx="sum",
68 | )
69 |
70 | def update(
71 | self,
72 | learned_activations: Float[
73 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
74 | ],
75 | **kwargs: Any, # type: ignore # noqa: ARG002, ANN401 (allows combining with other metrics)
76 | ) -> None:
77 | """Update the metric state.
78 |
79 | Args:
80 | learned_activations: The learned activations.
81 | **kwargs: Ignored keyword arguments (to allow use with other metrics in a collection).
82 | """
83 | # Increment the counter of activations seen since the last compute step
84 | self.num_activation_vectors += learned_activations.shape[0]
85 |
86 | # Count the number of active neurons in the batch
87 | neuron_has_fired: Bool[
88 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
89 | ] = torch.gt(learned_activations, 0)
90 |
91 |         self.neuron_fired_count += neuron_has_fired.sum(dim=0, dtype=torch.float)
92 |
93 | def compute(
94 | self,
95 | ) -> Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)]:
96 | """Compute the metric."""
97 | return self.neuron_fired_count / self.num_activation_vectors
98 |
--------------------------------------------------------------------------------
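
Note: on a plain (batch, feature) tensor the same quantity is a one-liner: the fraction of batch
items for which each neuron's activation is strictly positive. A minimal sketch reproducing the
docstring example, without the component axis:

import torch

learned_activations = torch.tensor([[1.0, 0.0, 1.0], [0.0, 0.0, 0.0]])

# Fraction of batch items in which each neuron fired (activation > 0)
feature_density = (learned_activations > 0).float().mean(dim=0)
print(feature_density)  # tensor([0.5000, 0.0000, 0.5000])
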
/sparse_autoencoder/metrics/train/l0_norm.py:
--------------------------------------------------------------------------------
1 | """L0 norm sparsity metric."""
2 | from typing import Any
3 |
4 | from jaxtyping import Float, Int64
5 | from pydantic import PositiveInt, validate_call
6 | import torch
7 | from torch import Tensor
8 | from torchmetrics import Metric
9 |
10 | from sparse_autoencoder.tensor_types import Axis
11 | from sparse_autoencoder.utils.tensor_shape import shape_with_optional_dimensions
12 |
13 |
14 | class L0NormMetric(Metric):
15 | """Learned activations L0 norm metric.
16 |
17 |     The L0 norm is the number of non-zero elements in a learned activation vector, averaged over
18 |     the activation vectors in the batch.
19 |
20 | Examples:
21 | >>> metric = L0NormMetric()
22 | >>> learned_activations = torch.tensor([
23 | ... [1., 0., 1.], # Batch 1 (single component): learned features (2 active neurons)
24 | ... [0., 1., 0.] # Batch 2 (single component): learned features (1 active neuron)
25 | ... ])
26 | >>> metric.forward(learned_activations)
27 | tensor(1.5000)
28 |
29 | With 2 components, the metric will return the average number of active (non-zero)
30 | neurons as a 1d tensor.
31 |
32 | >>> metric = L0NormMetric(num_components=2)
33 | >>> learned_activations = torch.tensor([
34 | ... [ # Batch 1
35 | ... [1., 0., 1.], # Component 1: learned features (2 active neurons)
36 | ... [1., 0., 1.] # Component 2: learned features (2 active neurons)
37 | ... ],
38 | ... [ # Batch 2
39 | ... [0., 1., 0.], # Component 1: learned features (1 active neuron)
40 | ... [1., 0., 1.] # Component 2: learned features (2 active neurons)
41 | ... ]
42 | ... ])
43 | >>> metric.forward(learned_activations)
44 | tensor([1.5000, 2.0000])
45 | """
46 |
47 | # Torchmetrics settings
48 | is_differentiable: bool | None = False
49 | full_state_update: bool | None = False
50 | plot_lower_bound: float | None = 0.0
51 |
52 | # State
53 | active_neurons_count: Float[Tensor, Axis.COMPONENT_OPTIONAL]
54 | num_activation_vectors: Int64[Tensor, Axis.SINGLE_ITEM]
55 |
56 | @validate_call
57 | def __init__(self, num_components: PositiveInt | None = None) -> None:
58 | """Initialize the metric."""
59 | super().__init__()
60 |
61 | self.add_state(
62 | "active_neurons_count",
63 | default=torch.zeros(shape_with_optional_dimensions(num_components), dtype=torch.float),
64 | dist_reduce_fx="sum", # Float is needed for dist reduce to work
65 | )
66 |
67 | self.add_state(
68 | "num_activation_vectors",
69 | default=torch.tensor(0, dtype=torch.int64),
70 | dist_reduce_fx="sum",
71 | )
72 |
73 | def update(
74 | self,
75 | learned_activations: Float[
76 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
77 | ],
78 | **kwargs: Any, # type: ignore # noqa: ARG002, ANN401 (allows combining with other metrics)
79 | ) -> None:
80 | """Update the metric state.
81 |
82 | Args:
83 | learned_activations: The learned activations.
84 | **kwargs: Ignored keyword arguments (to allow use with other metrics in a collection).
85 | """
86 | self.num_activation_vectors += learned_activations.shape[0]
87 |
88 | self.active_neurons_count += torch.count_nonzero(learned_activations, dim=-1).sum(
89 |             dim=0, dtype=torch.float
90 | )
91 |
92 | def compute(
93 | self,
94 | ) -> Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL)]:
95 | """Compute the metric.
96 |
97 | Note that torchmetrics converts shape `[0]` tensors into scalars (shape `0`).
98 | """
99 | return self.active_neurons_count / self.num_activation_vectors
100 |
--------------------------------------------------------------------------------
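
Note: without the metric-state machinery, the batch-averaged L0 norm reduces to counting the
non-zero activations per item and taking the mean. A minimal sketch matching the single-component
docstring example:

import torch

learned_activations = torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])

# Non-zero activations per item ([2, 1]), averaged over the batch
l0_norm = torch.count_nonzero(learned_activations, dim=-1).float().mean()
print(l0_norm)  # tensor(1.5000)
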
/sparse_autoencoder/metrics/train/neuron_activity.py:
--------------------------------------------------------------------------------
1 | """Neuron activity metric."""
2 | from typing import Annotated, Any
3 |
4 | from jaxtyping import Bool, Float, Int64
5 | from pydantic import Field, NonNegativeFloat, PositiveInt, validate_call
6 | import torch
7 | from torch import Tensor
8 | from torchmetrics import Metric
9 |
10 | from sparse_autoencoder.tensor_types import Axis
11 | from sparse_autoencoder.utils.tensor_shape import shape_with_optional_dimensions
12 |
13 |
14 | class NeuronActivityMetric(Metric):
15 | """Neuron activity metric.
16 |
17 | Example:
18 |         The metric counts the number of dead neurons: those that fired for at most the threshold
19 |         portion of the activation vectors seen. Here the middle neuron never fires across the
20 |         batch, so one neuron is dead. The breakdown by component isn't shown here as there is
21 |         just one component.
22 |
23 | >>> metric = NeuronActivityMetric(num_learned_features=3)
24 | >>> learned_activations = torch.tensor([
25 | ... [1., 0., 1.], # Batch 1 (single component): learned features (2 active neurons)
26 |         ...     [0., 0., 0.] # Batch 2 (single component): learned features (0 active neurons)
27 | ... ])
28 | >>> metric.forward(learned_activations)
29 | tensor(1)
30 | """
31 |
32 | # Torchmetrics settings
33 | is_differentiable: bool | None = False
34 | full_state_update: bool | None = True
35 | plot_lower_bound: float | None = 0.0
36 |
37 | # Metric settings
38 | _threshold_is_dead_portion_fires: NonNegativeFloat
39 |
40 | # State
41 | neuron_fired_count: Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)]
42 | num_activation_vectors: Int64[Tensor, Axis.SINGLE_ITEM]
43 |
44 | @validate_call
45 | def __init__(
46 | self,
47 | num_learned_features: PositiveInt,
48 | num_components: PositiveInt | None = None,
49 | threshold_is_dead_portion_fires: Annotated[float, Field(strict=True, ge=0, le=1)] = 0.0,
50 | ) -> None:
51 | """Initialise the metric.
52 |
53 | Args:
54 | num_learned_features: Number of learned features.
55 | num_components: Number of components.
56 |             threshold_is_dead_portion_fires: Threshold for counting a neuron as dead (portion of
57 | activation vectors that it fires for must be less than or equal to this number).
58 | Commonly used values are 0.0, 1e-5 and 1e-6.
59 | """
60 | super().__init__()
61 | self._threshold_is_dead_portion_fires = threshold_is_dead_portion_fires
62 |
63 | self.add_state(
64 | "neuron_fired_count",
65 | default=torch.zeros(
66 | shape_with_optional_dimensions(num_components, num_learned_features),
67 | dtype=torch.float, # Float is needed for dist reduce to work
68 | ),
69 | dist_reduce_fx="sum",
70 | )
71 |
72 | self.add_state(
73 | "num_activation_vectors",
74 | default=torch.tensor(0, dtype=torch.int64),
75 | dist_reduce_fx="sum",
76 | )
77 |
78 | def update(
79 | self,
80 | learned_activations: Float[
81 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
82 | ],
83 | **kwargs: Any, # type: ignore # noqa: ARG002, ANN401 (allows combining with other metrics)
84 | ) -> None:
85 | """Update the metric state.
86 |
87 | Args:
88 | learned_activations: The learned activations.
89 | **kwargs: Ignored keyword arguments (to allow use with other metrics in a collection).
90 | """
91 | # Increment the counter of activations seen since the last compute step
92 | self.num_activation_vectors += learned_activations.shape[0]
93 |
94 | # Count the number of active neurons in the batch
95 | neuron_has_fired: Bool[
96 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
97 | ] = torch.gt(learned_activations, 0)
98 |
99 | self.neuron_fired_count += neuron_has_fired.sum(dim=0, dtype=torch.float)
100 |
101 | def compute(self) -> Int64[Tensor, Axis.COMPONENT_OPTIONAL]:
102 | """Compute the metric.
103 |
104 | Note that torchmetrics converts shape `[0]` tensors into scalars (shape `0`).
105 | """
106 | threshold_activations: Float[Tensor, Axis.SINGLE_ITEM] = (
107 | self._threshold_is_dead_portion_fires * self.num_activation_vectors
108 | )
109 |
110 | return torch.sum(
111 | self.neuron_fired_count <= threshold_activations, dim=-1, dtype=torch.int64
112 | )
113 |
--------------------------------------------------------------------------------
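
Note: stripped of the accumulation state, the dead-neuron count is: a neuron is dead if it fired
for at most threshold_is_dead_portion_fires of the activation vectors seen. A minimal sketch with
the default threshold of 0.0:

import torch

learned_activations = torch.tensor([[1.0, 0.0, 1.0], [0.0, 0.0, 0.0]])
threshold_is_dead_portion_fires = 0.0

# Number of times each neuron fired across the batch ([1, 0, 1])
fired_count = (learned_activations > 0).sum(dim=0)

# Dead if fired for at most the threshold portion of vectors seen
num_vectors = learned_activations.shape[0]
dead_neurons = (fired_count <= threshold_is_dead_portion_fires * num_vectors).sum()
print(dead_neurons)  # tensor(1): the middle neuron never fired
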
/sparse_autoencoder/metrics/train/neuron_fired_count.py:
--------------------------------------------------------------------------------
1 | """Neuron fired count metric."""
2 | from typing import Any
3 |
4 | from jaxtyping import Bool, Float, Int
5 | from pydantic import PositiveInt, validate_call
6 | import torch
7 | from torch import Tensor
8 | from torchmetrics import Metric
9 |
10 | from sparse_autoencoder.tensor_types import Axis
11 | from sparse_autoencoder.utils.tensor_shape import shape_with_optional_dimensions
12 |
13 |
14 | class NeuronFiredCountMetric(Metric):
15 | """Neuron activity metric.
16 |
17 | Example:
18 | >>> metric = NeuronFiredCountMetric(num_learned_features=3)
19 | >>> learned_activations = torch.tensor([
20 | ... [1., 0., 1.], # Batch 1 (single component): learned features (2 active neurons)
21 |     ...     [0., 0., 0.] # Batch 2 (single component): learned features (0 active neurons)
22 | ... ])
23 | >>> metric.forward(learned_activations)
24 | tensor([1, 0, 1])
25 | """
26 |
27 | # Torchmetrics settings
28 | is_differentiable: bool | None = True
29 | full_state_update: bool | None = True
30 | plot_lower_bound: float | None = 0.0
31 |
32 | # State
33 | neuron_fired_count: Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)]
34 |
35 | @validate_call
36 | def __init__(
37 | self,
38 | num_learned_features: PositiveInt,
39 | num_components: PositiveInt | None = None,
40 | ) -> None:
41 | """Initialise the metric.
42 |
43 | Args:
44 | num_learned_features: Number of learned features.
45 | num_components: Number of components.
46 | """
47 | super().__init__()
48 | self.add_state(
49 | "neuron_fired_count",
50 | default=torch.zeros(
51 | shape_with_optional_dimensions(num_components, num_learned_features),
52 | dtype=torch.float, # Float is needed for dist reduce to work
53 | ),
54 | dist_reduce_fx="sum",
55 | )
56 |
57 | def update(
58 | self,
59 | learned_activations: Float[
60 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
61 | ],
62 | **kwargs: Any, # type: ignore # noqa: ARG002, ANN401 (allows combining with other metrics)
63 | ) -> None:
64 | """Update the metric state.
65 |
66 | Args:
67 | learned_activations: The learned activations.
68 | **kwargs: Ignored keyword arguments (to allow use with other metrics in a collection).
69 | """
70 | neuron_has_fired: Bool[
71 | Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
72 | ] = torch.gt(learned_activations, 0)
73 |
74 | self.neuron_fired_count += neuron_has_fired.sum(dim=0, dtype=torch.float)
75 |
76 | def compute(self) -> Int[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)]:
77 | """Compute the metric."""
78 | return self.neuron_fired_count.to(dtype=torch.int64)
79 |
--------------------------------------------------------------------------------
/sparse_autoencoder/metrics/train/tests/test_feature_density.py:
--------------------------------------------------------------------------------
1 | """Test the feature density metric."""
2 | import pytest
3 | import torch
4 |
5 | from sparse_autoencoder.metrics.train.feature_density import FeatureDensityMetric
6 |
7 |
8 | @pytest.mark.parametrize(
9 | ("num_learned_features", "num_components", "learned_activations", "expected_output"),
10 | [
11 | pytest.param(
12 | 3,
13 | 1,
14 | torch.tensor([[[1.0, 1.0, 1.0]]]),
15 | torch.tensor([[1.0, 1.0, 1.0]]),
16 | id="Single component axis, all neurons active",
17 | ),
18 | pytest.param(
19 | 3,
20 | None,
21 | torch.tensor([[1.0, 1.0, 1.0]]),
22 | torch.tensor([1.0, 1.0, 1.0]),
23 | id="No component axis, all neurons active",
24 | ),
25 | pytest.param(
26 | 3,
27 | 1,
28 | torch.tensor([[[0.0, 0.0, 0.0]]]),
29 | torch.tensor([[0.0, 0.0, 0.0]]),
30 | id="Single component, no neurons active",
31 | ),
32 | pytest.param(
33 | 3,
34 | 2,
35 | torch.tensor(
36 | [
37 | [ # Batch 1
38 | [1.0, 0.0, 1.0], # Component 1: learned features
39 | [0.0, 1.0, 0.0], # Component 2: learned features
40 | ],
41 | [ # Batch 2
42 | [0.0, 1.0, 0.0], # Component 1: learned features
43 | [1.0, 0.0, 1.0], # Component 2: learned features
44 | ],
45 | ],
46 | ),
47 | torch.tensor(
48 | [
49 | [0.5, 0.5, 0.5], # Component 1: learned features
50 | [0.5, 0.5, 0.5], # Component 2: learned features
51 | ]
52 | ),
53 | id="Multiple components, mixed activity",
54 | ),
55 | ],
56 | )
57 | def test_feature_density_metric(
58 | num_learned_features: int,
59 |     num_components: int | None,
60 | learned_activations: torch.Tensor,
61 | expected_output: torch.Tensor,
62 | ) -> None:
63 | """Test the FeatureDensityMetric for different scenarios.
64 |
65 | Args:
66 | num_learned_features: Number of learned features.
67 | num_components: Number of components.
68 | learned_activations: Learned activations tensor.
69 | expected_output: Expected output tensor.
70 | """
71 | metric = FeatureDensityMetric(num_learned_features, num_components)
72 | result = metric.forward(learned_activations)
73 | assert result.shape == expected_output.shape
74 | assert torch.allclose(result, expected_output)
75 |
--------------------------------------------------------------------------------
/sparse_autoencoder/metrics/train/tests/test_l0_norm.py:
--------------------------------------------------------------------------------
1 | """Test the L0 norm sparsity metric."""
2 | import pytest
3 | import torch
4 |
5 | from sparse_autoencoder.metrics.train.l0_norm import (
6 |     L0NormMetric,
7 | )
8 |
9 |
10 | @pytest.mark.parametrize(
11 | ("num_components", "learned_activations", "expected_output"),
12 | [
13 | pytest.param(
14 | 1,
15 | torch.tensor([[[1.0, 0.0, 1.0]]]),
16 | torch.tensor(2.0),
17 | id="Single component, mixed activity",
18 | ),
19 | pytest.param(
20 | None,
21 | torch.tensor([[1.0, 0.0, 1.0]]),
22 | torch.tensor(2.0),
23 | id="No component axis, mixed activity",
24 | ),
25 | pytest.param(
26 | 1,
27 | torch.tensor([[[0.0, 0.0, 0.0]]]),
28 | torch.tensor(0.0),
29 | id="Single component, no neurons active",
30 | ),
31 | pytest.param(
32 | 2,
33 | torch.tensor(
34 | [
35 | [[1.0, 0.0, 1.0], [1.0, 0.0, 1.0]],
36 | [[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]],
37 | ]
38 | ),
39 | torch.tensor([1.5, 2.0]),
40 | id="Multiple components, mixed activity",
41 | ),
42 | ],
43 | )
44 | def test_l0_norm_metric(
45 |     num_components: int | None,
46 | learned_activations: torch.Tensor,
47 | expected_output: torch.Tensor,
48 | ) -> None:
49 | """Test the L0NormMetric for different scenarios.
50 |
51 | Args:
52 | num_components: Number of components.
53 | learned_activations: Learned activations tensor.
54 | expected_output: Expected output tensor.
55 | """
56 | metric = L0NormMetric(num_components)
57 | result = metric.forward(learned_activations)
58 |
59 | assert result.shape == expected_output.shape
60 | assert torch.allclose(result, expected_output)
61 |
--------------------------------------------------------------------------------
/sparse_autoencoder/metrics/train/tests/test_neuron_activity.py:
--------------------------------------------------------------------------------
1 | """Test the Neuron Activity Metric."""
2 | import pytest
3 | import torch
4 |
5 | from sparse_autoencoder.metrics.train.neuron_activity import NeuronActivityMetric
6 |
7 |
8 | @pytest.mark.parametrize(
9 | (
10 | "num_learned_features",
11 | "num_components",
12 | "threshold_is_dead_portion_fires",
13 | "learned_activations",
14 | "expected_output",
15 | ),
16 | [
17 | pytest.param(
18 | 3,
19 | 1,
20 | 0,
21 | torch.tensor(
22 | [
23 | [ # Batch 1
24 | [1.0, 0.0, 1.0] # Component 1: learned features (2 active neurons)
25 | ],
26 | [ # Batch 2
27 |                         [0.0, 0.0, 0.0] # Component 1: learned features (0 active neurons)
28 | ],
29 | ]
30 | ),
31 | torch.tensor(1),
32 | id="Single component, one dead neuron",
33 | ),
34 | pytest.param(
35 | 3,
36 | None,
37 | 0,
38 | torch.tensor([[1.0, 0.0, 1.0], [0.0, 0.0, 0.0]]),
39 | torch.tensor(1),
40 | id="No component axis, one dead neuron",
41 | ),
42 | pytest.param(
43 | 3,
44 | 1,
45 | 0.0,
46 | torch.tensor([[[1.0, 1.0, 1.0]], [[0.0, 0.0, 0.0]]]),
47 | torch.tensor(0),
48 | id="Single component, no dead neurons",
49 | ),
50 | pytest.param(
51 | 3,
52 | 2,
53 | 0,
54 | torch.tensor(
55 | [
56 | [ # Batch 1
57 | [1.0, 0.0, 1.0], # Component 1: learned features
58 | [1.0, 0.0, 1.0], # Component 2: learned features
59 | ],
60 | [ # Batch 2
61 | [0.0, 1.0, 0.0], # Component 1: learned features
62 | [1.0, 0.0, 1.0], # Component 2: learned features
63 | ],
64 | ]
65 | ),
66 | torch.tensor([0, 1]),
67 | id="Multiple components, mixed dead neurons",
68 | ),
69 | ],
70 | )
71 | def test_neuron_activity_metric(
72 | num_learned_features: int,
73 |     num_components: int | None,
74 | threshold_is_dead_portion_fires: float,
75 | learned_activations: torch.Tensor,
76 | expected_output: torch.Tensor,
77 | ) -> None:
78 | """Test the NeuronActivityMetric for different scenarios.
79 |
80 | Args:
81 | num_learned_features: Number of learned features.
82 | num_components: Number of components.
83 | threshold_is_dead_portion_fires: Threshold for counting a neuron as dead.
84 | learned_activations: Learned activations tensor.
85 | expected_output: Expected number of dead neurons.
86 | """
87 | metric = NeuronActivityMetric(
88 | num_learned_features, num_components, threshold_is_dead_portion_fires
89 | )
90 | result = metric.forward(learned_activations)
91 | assert result.shape == expected_output.shape
92 | assert torch.equal(result, expected_output)
93 |
--------------------------------------------------------------------------------
/sparse_autoencoder/metrics/train/tests/test_neuron_fired_count.py:
--------------------------------------------------------------------------------
1 | """Test the Neuron Fired Count Metric."""
2 | import pytest
3 | import torch
4 |
5 | from sparse_autoencoder.metrics.train.neuron_fired_count import NeuronFiredCountMetric
6 |
7 |
8 | @pytest.mark.parametrize(
9 |     (
10 |         "num_learned_features",
11 |         "num_components",
12 |         "learned_activations",
13 |         "expected_output",
14 |     ),
15 |     [
16 |         pytest.param(
17 |             3,
18 |             1,
19 |             torch.tensor(
20 |                 [
21 |                     [ # Batch 1
22 |                         [1.0, 0.0, 1.0] # Component 1: learned features (2 active neurons)
23 |                     ],
24 |                     [ # Batch 2
25 |                         [0.0, 0.0, 0.0] # Component 1: learned features (0 active neurons)
26 |                     ],
27 |                 ]
28 |             ),
29 |             torch.tensor([[1, 0, 1]]),
30 |             id="Single component, mixed activity",
31 |         ),
32 |         pytest.param(
33 |             3,
34 |             None,
35 |             torch.tensor([[1.0, 0.0, 1.0], [0.0, 0.0, 0.0]]),
36 |             torch.tensor([1, 0, 1]),
37 |             id="No component axis, mixed activity",
38 |         ),
39 |         pytest.param(
40 |             3,
41 |             1,
42 |             torch.tensor([[[1.0, 1.0, 1.0]], [[0.0, 0.0, 0.0]]]),
43 |             torch.tensor([[1, 1, 1]]),
44 |             id="Single component, each neuron fired once",
45 |         ),
46 |         pytest.param(
47 |             3,
48 |             2,
49 |             torch.tensor(
50 |                 [
51 |                     [ # Batch 1
52 |                         [1.0, 0.0, 1.0], # Component 1: learned features
53 |                         [1.0, 0.0, 1.0], # Component 2: learned features
54 |                     ],
55 |                     [ # Batch 2
56 |                         [0.0, 1.0, 0.0], # Component 1: learned features
57 |                         [1.0, 0.0, 1.0], # Component 2: learned features
58 |                     ],
59 |                 ]
60 |             ),
61 |             torch.tensor([[1, 1, 1], [2, 0, 2]]),
62 |             id="Multiple components, mixed activity",
63 |         ),
64 |     ],
65 | )
66 | def test_neuron_fired_count_metric(
67 |     num_learned_features: int,
68 |     num_components: int | None,
69 |     learned_activations: torch.Tensor,
70 |     expected_output: torch.Tensor,
71 | ) -> None:
72 |     """Test the NeuronFiredCountMetric for different scenarios.
73 |
74 |     Args:
75 |         num_learned_features: Number of learned features.
76 |         num_components: Number of components.
77 |         learned_activations: Learned activations tensor.
78 |         expected_output: Expected neuron fired counts.
79 |     """
80 |     metric = NeuronFiredCountMetric(num_learned_features, num_components)
81 |     result = metric.forward(learned_activations)
82 |     assert result.shape == expected_output.shape
83 |     assert torch.equal(result, expected_output)
84 |
--------------------------------------------------------------------------------
/sparse_autoencoder/metrics/validate/__init__.py:
--------------------------------------------------------------------------------
1 | """Validate step metrics."""
2 |
--------------------------------------------------------------------------------
/sparse_autoencoder/metrics/validate/reconstruction_score.py:
--------------------------------------------------------------------------------
1 | """Reconstruction score metric."""
2 | from jaxtyping import Float, Int64
3 | from pydantic import PositiveInt, validate_call
4 | import torch
5 | from torch import Tensor
6 | from torchmetrics import Metric
7 |
8 | from sparse_autoencoder.tensor_types import Axis
9 |
10 |
11 | class ReconstructionScoreMetric(Metric):
12 | r"""Model reconstruction score.
13 |
14 | Creates a score that measures how well the model can reconstruct the data.
15 |
16 | $$
17 | \begin{align*}
18 | v &= \text{number of validation items} \\
19 | l \in{\mathbb{R}^v} &= \text{loss with no changes to the source model} \\
20 | l_\text{recon} \in{\mathbb{R}^v} &= \text{loss with reconstruction} \\
21 | l_\text{zero} \in{\mathbb{R}^v} &= \text{loss with zero ablation} \\
22 | s &= \text{reconstruction score} \\
23 | s_\text{itemwise} &= \frac{l_\text{zero} - l_\text{recon}}{l_\text{zero} - l} \\
24 | s &= \sum_{i=1}^v s_\text{itemwise} / v
25 | \end{align*}
26 | $$
27 |
28 | Example:
29 | >>> metric = ReconstructionScoreMetric(num_components=1)
30 |         >>> source_model_loss = torch.tensor([2.0, 2.0, 2.0])
31 |         >>> source_model_loss_with_reconstruction = torch.tensor([3.0, 3.0, 3.0])
32 |         >>> source_model_loss_with_zero_ablation = torch.tensor([5.0, 5.0, 5.0])
33 | >>> metric.forward(
34 | ... source_model_loss=source_model_loss,
35 | ... source_model_loss_with_reconstruction=source_model_loss_with_reconstruction,
36 | ... source_model_loss_with_zero_ablation=source_model_loss_with_zero_ablation
37 | ... )
38 | tensor(0.6667)
39 | """
40 |
41 | # Torchmetrics settings
42 | is_differentiable: bool | None = False
43 | full_state_update: bool | None = True
44 |
45 | # State
46 | source_model_loss: Float[Tensor, Axis.COMPONENT_OPTIONAL]
47 | source_model_loss_with_zero_ablation: Float[Tensor, Axis.COMPONENT_OPTIONAL]
48 | source_model_loss_with_reconstruction: Float[Tensor, Axis.COMPONENT_OPTIONAL]
49 | num_activation_vectors: Int64[Tensor, Axis.SINGLE_ITEM]
50 |
51 | @validate_call
52 | def __init__(self, num_components: PositiveInt = 1) -> None:
53 | """Initialise the metric."""
54 | super().__init__()
55 |
56 | self.add_state(
57 | "source_model_loss", default=torch.zeros(num_components), dist_reduce_fx="sum"
58 | )
59 | self.add_state(
60 | "source_model_loss_with_zero_ablation",
61 | default=torch.zeros(num_components),
62 | dist_reduce_fx="sum",
63 | )
64 | self.add_state(
65 | "source_model_loss_with_reconstruction",
66 | default=torch.zeros(num_components),
67 | dist_reduce_fx="sum",
68 | )
69 |
70 | def update(
71 | self,
72 | source_model_loss: Float[Tensor, Axis.COMPONENT_OPTIONAL],
73 | source_model_loss_with_reconstruction: Float[Tensor, Axis.COMPONENT_OPTIONAL],
74 | source_model_loss_with_zero_ablation: Float[Tensor, Axis.COMPONENT_OPTIONAL],
75 | component_idx: int = 0,
76 | ) -> None:
77 | """Update the metric state.
78 |
79 | Args:
80 | source_model_loss: Loss with no changes to the source model.
81 | source_model_loss_with_reconstruction: Loss with SAE reconstruction.
82 | source_model_loss_with_zero_ablation: Loss with zero ablation.
83 | component_idx: Component idx.
84 | """
85 | self.source_model_loss[component_idx] += source_model_loss.sum()
86 | self.source_model_loss_with_zero_ablation[
87 | component_idx
88 | ] += source_model_loss_with_zero_ablation.sum()
89 | self.source_model_loss_with_reconstruction[
90 | component_idx
91 | ] += source_model_loss_with_reconstruction.sum()
92 |
93 | def compute(
94 | self,
95 | ) -> Float[Tensor, Axis.COMPONENT_OPTIONAL]:
96 | """Compute the metric."""
97 | zero_ablate_loss_minus_reconstruction_loss: Float[Tensor, Axis.COMPONENT_OPTIONAL] = (
98 | self.source_model_loss_with_zero_ablation - self.source_model_loss_with_reconstruction
99 | )
100 |
101 | zero_ablate_loss_minus_default_loss: Float[Tensor, Axis.COMPONENT_OPTIONAL] = (
102 | self.source_model_loss_with_zero_ablation - self.source_model_loss
103 | )
104 |
105 | return zero_ablate_loss_minus_reconstruction_loss / zero_ablate_loss_minus_default_loss
106 |
--------------------------------------------------------------------------------
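
Note: plugging the docstring's numbers into the formula by hand gives the same score: with a
baseline loss of 2, a reconstruction loss of 3 and a zero-ablation loss of 5, each itemwise score
is (5 - 3) / (5 - 2). A minimal sketch following the itemwise formula (the metric itself
accumulates loss sums across update calls, which agrees with the mean of itemwise scores when the
losses are constant, as here):

import torch

loss = torch.tensor([2.0, 2.0, 2.0])        # source model, unchanged
loss_recon = torch.tensor([3.0, 3.0, 3.0])  # with the SAE reconstruction spliced in
loss_zero = torch.tensor([5.0, 5.0, 5.0])   # with zero ablation

# Itemwise reconstruction score, then the mean over validation items
score = ((loss_zero - loss_recon) / (loss_zero - loss)).mean()
print(score)  # tensor(0.6667)
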
/sparse_autoencoder/metrics/validate/tests/__snapshots__/test_model_reconstruction_score.ambr:
--------------------------------------------------------------------------------
1 | # serializer version: 1
2 | # name: test_weights_biases_log_matches_snapshot
3 | list([
4 | dict({
5 | 'component_0/validate/reconstruction_score/baseline_loss': tensor(0.3800),
6 | 'component_1/validate/reconstruction_score/baseline_loss': tensor(0.5251),
7 | 'component_2/validate/reconstruction_score/baseline_loss': tensor(0.4923),
8 | 'component_3/validate/reconstruction_score/baseline_loss': tensor(0.4598),
9 | 'component_4/validate/reconstruction_score/baseline_loss': tensor(0.4281),
10 | 'component_5/validate/reconstruction_score/baseline_loss': tensor(0.4961),
11 | 'validate/reconstruction_score/baseline_loss/component_mean': tensor(0.4636),
12 | }),
13 | dict({
14 | 'component_0/validate/reconstruction_score/loss_with_reconstruction': tensor(0.6111),
15 | 'component_1/validate/reconstruction_score/loss_with_reconstruction': tensor(0.5219),
16 | 'component_2/validate/reconstruction_score/loss_with_reconstruction': tensor(0.4063),
17 | 'component_3/validate/reconstruction_score/loss_with_reconstruction': tensor(0.6497),
18 | 'component_4/validate/reconstruction_score/loss_with_reconstruction': tensor(0.4929),
19 | 'component_5/validate/reconstruction_score/loss_with_reconstruction': tensor(0.3723),
20 | 'validate/reconstruction_score/loss_with_reconstruction/component_mean': tensor(0.5090),
21 | }),
22 | dict({
23 | 'component_0/validate/reconstruction_score/loss_with_zero_ablation': tensor(0.2891),
24 | 'component_1/validate/reconstruction_score/loss_with_zero_ablation': tensor(0.3879),
25 | 'component_2/validate/reconstruction_score/loss_with_zero_ablation': tensor(0.5850),
26 | 'component_3/validate/reconstruction_score/loss_with_zero_ablation': tensor(0.4740),
27 | 'component_4/validate/reconstruction_score/loss_with_zero_ablation': tensor(0.5452),
28 | 'component_5/validate/reconstruction_score/loss_with_zero_ablation': tensor(0.3733),
29 | 'validate/reconstruction_score/loss_with_zero_ablation/component_mean': tensor(0.4424),
30 | }),
31 | dict({
32 | 'component_0/validate/reconstruction_score': tensor(3.5422),
33 | 'component_1/validate/reconstruction_score': tensor(0.9767),
34 | 'component_2/validate/reconstruction_score': tensor(1.9278),
35 | 'component_3/validate/reconstruction_score': tensor(-12.3338),
36 | 'component_4/validate/reconstruction_score': tensor(0.4468),
37 | 'component_5/validate/reconstruction_score': tensor(-0.0081),
38 | 'validate/reconstruction_score/component_mean': tensor(-0.9081),
39 | }),
40 | ])
41 | # ---
42 |
--------------------------------------------------------------------------------
/sparse_autoencoder/metrics/wrappers/__init__.py:
--------------------------------------------------------------------------------
1 | """Metric wrappers."""
2 |
--------------------------------------------------------------------------------
/sparse_autoencoder/metrics/wrappers/classwise.py:
--------------------------------------------------------------------------------
1 | """Classwise metrics wrapper."""
2 | import torch
3 | from torch import Tensor
4 | from torchmetrics import ClasswiseWrapper, Metric
5 |
6 |
7 | class ClasswiseWrapperWithMean(ClasswiseWrapper):
8 | """Classwise wrapper with mean.
9 |
10 |     This wrapper works together with metrics that return multiple values (one value per class) so
11 |     that label information can be automatically included in the output. It extends the standard
12 |     torchmetrics ClasswiseWrapper that does this, adding an additional mean value (across all
13 |     classes).
14 | """
15 |
16 | _prefix: str
17 |
18 | labels: list[str]
19 |
20 | def __init__(
21 | self,
22 | metric: Metric,
23 | component_names: list[str] | None = None,
24 | prefix: str | None = None,
25 | ) -> None:
26 | """Initialise the classwise wrapper.
27 |
28 | Args:
29 | metric: Metric to wrap.
30 | component_names: Component names.
31 | prefix: Prefix for the name (will replace the default of the class name).
32 | """
33 | super().__init__(metric, component_names, prefix)
34 |
35 | # Default prefix
36 | if not self._prefix:
37 | self._prefix = f"{self.metric.__class__.__name__.lower()}"
38 |
39 | def _convert(self, x: Tensor) -> dict[str, Tensor]:
40 | """Convert the input tensor to a dictionary of metrics.
41 |
42 | Args:
43 | x: The input tensor.
44 |
45 | Returns:
46 | A dictionary of metric results.
47 | """
48 | # Add a component axis if not present (as Metric squeezes it out)
49 | if x.ndim == 0:
50 | x = x.unsqueeze(dim=0)
51 |
52 | # Same splitting as the original classwise wrapper
53 | res = {f"{self._prefix}/{lab}": val for lab, val in zip(self.labels, x)}
54 |
55 | # Add in the mean
56 | res[f"{self._prefix}/mean"] = x.mean(0, dtype=torch.float)
57 |
58 | return res
59 |
--------------------------------------------------------------------------------
/sparse_autoencoder/metrics/wrappers/tests/test_classwise.py:
--------------------------------------------------------------------------------
1 | """Test the classwise wrapper."""
2 | import pytest
3 | import torch
4 |
5 | from sparse_autoencoder.metrics.train.feature_density import FeatureDensityMetric
6 | from sparse_autoencoder.metrics.train.l0_norm import L0NormMetric
7 | from sparse_autoencoder.metrics.wrappers.classwise import ClasswiseWrapperWithMean
8 |
9 |
10 | @pytest.mark.parametrize(
11 | ("num_components"),
12 | [
13 | pytest.param(1, id="Single component"),
14 | pytest.param(2, id="Multiple components"),
15 | ],
16 | )
17 | def test_feature_density_classwise_wrapper(num_components: int) -> None:
18 | """Test the classwise wrapper."""
19 | metric = FeatureDensityMetric(3, num_components)
20 | component_names = [f"mlp_{n}" for n in range(num_components)]
21 | wrapped_metric = ClasswiseWrapperWithMean(metric, component_names, prefix="feature_density")
22 |
23 | learned_activations = torch.tensor([[[1.0, 1.0, 1.0]] * num_components])
24 | expected_output = torch.tensor([[1.0, 1.0, 1.0]])
25 | res = wrapped_metric.forward(learned_activations=learned_activations)
26 |
27 | for component in component_names:
28 | assert torch.allclose(res[f"feature_density/{component}"], expected_output)
29 |
30 | assert torch.allclose(res["feature_density/mean"], expected_output.mean(0))
31 |
32 |
33 | @pytest.mark.parametrize(
34 | ("num_components"),
35 | [
36 | pytest.param(1, id="Single component"),
37 | pytest.param(2, id="Multiple components"),
38 | ],
39 | )
40 | def test_l0_norm_classwise_wrapper(num_components: int) -> None:
41 | """Test the classwise wrapper."""
42 | metric = L0NormMetric(num_components)
43 | component_names = [f"mlp_{n}" for n in range(num_components)]
44 | wrapped_metric = ClasswiseWrapperWithMean(metric, component_names, prefix="l0")
45 |
46 | learned_activations = torch.tensor([[[1.0, 0.0, 1.0]] * num_components])
47 | expected_output = torch.tensor([2.0])
48 | res = wrapped_metric.forward(learned_activations=learned_activations)
49 |
50 | for component in component_names:
51 | assert torch.allclose(res[f"l0/{component}"], expected_output)
52 |
53 | assert torch.allclose(res["l0/mean"], expected_output.mean(0))
54 |
--------------------------------------------------------------------------------
/sparse_autoencoder/optimizer/__init__.py:
--------------------------------------------------------------------------------
1 | """Optimizers for Sparse Autoencoders."""
2 |
--------------------------------------------------------------------------------
/sparse_autoencoder/optimizer/tests/test_adam_with_reset.py:
--------------------------------------------------------------------------------
1 | """Tests for AdamWithReset optimizer."""
2 | import pytest
3 | import torch
4 |
5 | from sparse_autoencoder.autoencoder.model import SparseAutoencoder, SparseAutoencoderConfig
6 | from sparse_autoencoder.optimizer.adam_with_reset import AdamWithReset
7 |
8 |
9 | @pytest.fixture()
10 | def model_and_optimizer() -> tuple[torch.nn.Module, AdamWithReset]:
11 | """Model and optimizer fixture."""
12 | torch.random.manual_seed(0)
13 | model = SparseAutoencoder(SparseAutoencoderConfig(n_input_features=5, n_learned_features=10))
14 | optimizer = AdamWithReset(
15 | model.parameters(),
16 | named_parameters=model.named_parameters(),
17 | lr=0.0001,
18 | has_components_dim=False,
19 | )
20 |
21 | # Initialise adam state by doing some steps
22 | for _ in range(3):
23 | source_activations = torch.rand((10, 5))
24 | optimizer.zero_grad()
25 | _, decoded_activations = model.forward(source_activations)
26 | dummy_loss = torch.nn.functional.mse_loss(decoded_activations, source_activations)
27 | dummy_loss.backward()
28 | optimizer.step()
29 |
30 | # Force all state values to be non-zero
31 | optimizer.state[model.encoder.weight]["exp_avg"] = (
32 | optimizer.state[model.encoder.weight]["exp_avg"] + 1.0
33 | )
34 | optimizer.state[model.encoder.weight]["exp_avg_sq"] = (
35 | optimizer.state[model.encoder.weight]["exp_avg_sq"] + 1.0
36 | )
37 |
38 | return model, optimizer
39 |
40 |
41 | def test_initialization(model_and_optimizer: tuple[torch.nn.Module, AdamWithReset]) -> None:
42 | """Test initialization of AdamWithReset optimizer."""
43 | model, optimizer = model_and_optimizer
44 | assert len(optimizer.parameter_names) == len(list(model.named_parameters()))
45 |
46 |
47 | def test_reset_state_all_parameters(
48 | model_and_optimizer: tuple[torch.nn.Module, AdamWithReset],
49 | ) -> None:
50 | """Test reset_state_all_parameters method."""
51 | _, optimizer = model_and_optimizer
52 | optimizer.reset_state_all_parameters()
53 |
54 | for group in optimizer.param_groups:
55 | for parameter in group["params"]:
56 | # Get the state
57 | parameter_state = optimizer.state[parameter]
58 | for state_name in parameter_state:
59 | if state_name in ["exp_avg", "exp_avg_sq", "max_exp_avg_sq"]:
60 | # Check all state values are reset to zero
61 | state = parameter_state[state_name]
62 | assert torch.all(state == 0)
63 |
64 |
65 | def test_reset_neurons_state(model_and_optimizer: tuple[torch.nn.Module, AdamWithReset]) -> None:
66 | """Test reset_neurons_state method."""
67 | model, optimizer = model_and_optimizer
68 | optimizer.reset_neurons_state(model.encoder.weight, torch.tensor([1]), axis=0)
69 |
70 | res = optimizer.state[model.encoder.weight]
71 |
72 | assert torch.all(res["exp_avg"][1, :] == 0)
73 | assert not torch.any(res["exp_avg"][2, :] == 0)
74 | assert not torch.any(res["exp_avg"][2:, 1] == 0)
75 |
76 |
77 | def test_reset_neurons_state_no_dead_neurons(
78 | model_and_optimizer: tuple[torch.nn.Module, AdamWithReset],
79 | ) -> None:
80 | """Test reset_neurons_state method with 0 dead neurons."""
81 | model, optimizer = model_and_optimizer
82 |
83 |     # Reset with an empty dead-neuron index tensor (i.e. no neurons to reset)
84 |     optimizer.reset_neurons_state(
85 |         model.encoder.weight, torch.tensor([], dtype=torch.int64), axis=0
86 |     )
87 |
88 |     res = optimizer.state[model.encoder.weight]
89 |
90 | assert not torch.any(res["exp_avg"] == 0)
91 | assert not torch.any(res["exp_avg_sq"] == 0)
92 |
--------------------------------------------------------------------------------
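
Note: the reset behaviour the tests above exercise amounts to zeroing the Adam moment estimates
for the resampled (dead) neurons while leaving every other row untouched. A minimal sketch of the
underlying idea with a stock torch.optim.Adam (the real AdamWithReset additionally tracks
parameter names and an optional components dimension):

import torch

model = torch.nn.Linear(5, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Take one step so the optimizer has exp_avg / exp_avg_sq state
model(torch.rand(8, 5)).sum().backward()
optimizer.step()

# Zero the first and second moment estimates for dead neuron row 1
dead_neuron_indices = torch.tensor([1])
state = optimizer.state[model.weight]
state["exp_avg"][dead_neuron_indices, :] = 0.0
state["exp_avg_sq"][dead_neuron_indices, :] = 0.0
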
/sparse_autoencoder/source_data/__init__.py:
--------------------------------------------------------------------------------
1 | """Source Data."""
2 |
--------------------------------------------------------------------------------
/sparse_autoencoder/source_data/mock_dataset.py:
--------------------------------------------------------------------------------
1 | """Mock dataset.
2 |
3 | For use with tests and simple examples.
4 | """
5 | from collections.abc import Iterator
6 | from typing import Literal, final
7 |
8 | from datasets import IterableDataset
9 | from jaxtyping import Int
10 | from pydantic import PositiveInt, validate_call
11 | import torch
12 | from torch import Tensor
13 | from transformers import PreTrainedTokenizerFast
14 |
15 | from sparse_autoencoder.source_data.abstract_dataset import (
16 | SourceDataset,
17 | TokenizedPrompts,
18 | TorchTokenizedPrompts,
19 | )
20 |
21 |
22 | class ConsecutiveIntHuggingFaceDataset(IterableDataset):
23 | """Consecutive integers Hugging Face dataset for testing.
24 |
25 | Creates a dataset where the first item is [0,1,2...], and the second item is [1,2,3...] and so
26 | on.
27 | """
28 |
29 | _data: Int[Tensor, "items context_size"]
30 | """Generated data."""
31 |
32 | _length: int
33 | """Size of the dataset."""
34 |
35 | _format: Literal["torch", "list"] = "list"
36 | """Format of the data."""
37 |
38 | def create_data(self, n_items: int, context_size: int) -> Int[Tensor, "items context_size"]:
39 | """Create the data.
40 |
41 | Args:
42 | n_items: The number of items in the dataset.
43 | context_size: The number of tokens in the context window.
44 |
45 | Returns:
46 | The generated data.
47 | """
48 | rows = torch.arange(n_items).unsqueeze(1)
49 | columns = torch.arange(context_size).unsqueeze(0)
50 | return rows + columns
51 |
52 | def __init__(self, context_size: int, vocab_size: int = 50_000, n_items: int = 10_000) -> None:
53 | """Initialize the mock HF dataset.
54 |
55 | Args:
56 | context_size: The number of tokens in the context window
57 | vocab_size: The size of the vocabulary to use.
58 | n_items: The number of items in the dataset.
59 |
60 | Raises:
61 | ValueError: If more items are requested than we can create with the vocab size (given
62 | that each item is a consecutive list of integers and unique).
63 | """
64 | self._length = n_items
65 |
66 | # Check we can create the data
67 | if n_items + context_size > vocab_size:
68 | error_message = (
69 | f"n_items ({n_items}) + context_size ({context_size}) must be less than "
70 | f"vocab_size ({vocab_size})"
71 | )
72 | raise ValueError(error_message)
73 |
74 | # Initialise the data
75 | self._data = self.create_data(n_items, context_size)
76 |
77 | def __iter__(self) -> Iterator: # type: ignore (HF typing is incorrect)
78 | """Initialize the iterator.
79 |
80 | Returns:
81 | Iterator.
82 | """
83 | self._index = 0
84 | return self
85 |
86 | def __next__(self) -> TokenizedPrompts | TorchTokenizedPrompts:
87 | """Return the next item in the dataset.
88 |
89 | Returns:
90 | TokenizedPrompts: The next item in the dataset.
91 |
92 | Raises:
93 | StopIteration: If the end of the dataset is reached.
94 | """
95 | if self._index < self._length:
96 | item = self[self._index]
97 | self._index += 1
98 | return item
99 |
100 | raise StopIteration
101 |
102 | def __len__(self) -> int:
103 | """Len Dunder Method."""
104 | return self._length
105 |
106 | def __getitem__(self, index: int) -> TokenizedPrompts | TorchTokenizedPrompts:
107 | """Get Item."""
108 | item = self._data[index]
109 |
110 | if self._format == "torch":
111 | return {"input_ids": item}
112 |
113 | return {"input_ids": item.tolist()}
114 |
115 | def with_format( # type: ignore (only support 2 types)
116 | self,
117 | type: Literal["torch", "list"], # noqa: A002
118 | ) -> "ConsecutiveIntHuggingFaceDataset":
119 | """With Format."""
120 | self._format = type
121 | return self
122 |
123 |
124 | @final
125 | class MockDataset(SourceDataset[TokenizedPrompts]):
126 | """Mock dataset for testing.
127 |
128 | For use with tests and simple examples.
129 | """
130 |
131 | tokenizer: PreTrainedTokenizerFast
132 |
133 | def preprocess(
134 | self,
135 | source_batch: TokenizedPrompts,
136 | *,
137 | context_size: int, # noqa: ARG002
138 | ) -> TokenizedPrompts:
139 | """Preprocess a batch of prompts."""
140 | # Nothing to do here
141 | return source_batch
142 |
143 | @validate_call
144 | def __init__(
145 | self,
146 | context_size: PositiveInt = 250,
147 | buffer_size: PositiveInt = 1000, # noqa: ARG002
148 | preprocess_batch_size: PositiveInt = 1000, # noqa: ARG002
149 | dataset_path: str = "dummy", # noqa: ARG002
150 | dataset_split: str = "train", # noqa: ARG002
151 | ):
152 | """Initialize the Random Int Dummy dataset.
153 |
154 | Example:
155 | >>> data = MockDataset()
156 | >>> first_item = next(iter(data))
157 | >>> len(first_item["input_ids"])
158 | 250
159 |
160 | Args:
161 | context_size: The context size to use when returning a list of tokenized prompts.
162 | *Towards Monosemanticity: Decomposing Language Models With Dictionary Learning* used
163 | a context size of 250.
164 | buffer_size: The buffer size to use when shuffling the dataset. As the dataset is
165 | streamed, this just pre-downloads at least `buffer_size` items and then shuffles
166 | just that buffer. Note that the generated activations should also be shuffled before
167 | training the sparse autoencoder, so a large buffer may not be strictly necessary
168 | here. Note also that this is the number of items in the dataset (e.g. number of
169 | prompts) and is typically significantly less than the number of tokenized prompts
170 | once the preprocessing function has been applied.
171 | preprocess_batch_size: The batch size to use just for preprocessing the dataset (e.g.
172 | tokenizing prompts).
173 | dataset_path: The path to the dataset on Hugging Face.
174 | dataset_split: Dataset split (e.g. `train`).
175 | """
176 | self.dataset = ConsecutiveIntHuggingFaceDataset(context_size=context_size) # type: ignore
177 | self.context_size = context_size
178 |
--------------------------------------------------------------------------------
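
Note: the consecutive-integer construction above is plain broadcasting: a column of row offsets
plus a row of positions, so item i is [i, i+1, ..., i+context_size-1]. A minimal sketch:

import torch

n_items, context_size = 4, 5
rows = torch.arange(n_items).unsqueeze(1)          # shape (n_items, 1)
columns = torch.arange(context_size).unsqueeze(0)  # shape (1, context_size)
data = rows + columns                              # shape (n_items, context_size)
print(data[1])  # tensor([1, 2, 3, 4, 5])
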
/sparse_autoencoder/source_data/pretokenized_dataset.py:
--------------------------------------------------------------------------------
1 | """Pre-Tokenized Dataset from Hugging Face.
2 |
3 | PreTokenizedDataset should work with any of the following tokenized datasets:
4 | - NeelNanda/pile-small-tokenized-2b
5 | - NeelNanda/pile-tokenized-10b
6 | - NeelNanda/openwebtext-tokenized-9b
7 | - NeelNanda/c4-tokenized-2b
8 | - NeelNanda/code-tokenized
9 | - NeelNanda/c4-code-tokenized-2b
10 | - NeelNanda/pile-old-tokenized-2b
11 | - alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2
12 |
13 | """
14 | from collections.abc import Mapping, Sequence
15 | from typing import final
16 |
17 | from pydantic import PositiveInt, validate_call
18 |
19 | from sparse_autoencoder.source_data.abstract_dataset import SourceDataset, TokenizedPrompts
20 |
21 |
22 | @final
23 | class PreTokenizedDataset(SourceDataset[dict]):
24 | """General Pre-Tokenized Dataset from Hugging Face.
25 |
26 | Can be used for various datasets available on Hugging Face.
27 | """
28 |
29 | def preprocess(
30 | self,
31 | source_batch: dict,
32 | *,
33 | context_size: int,
34 | ) -> TokenizedPrompts:
35 | """Preprocess a batch of prompts.
36 |
37 | The method splits each pre-tokenized item based on the context size.
38 |
39 | Args:
40 | source_batch: A batch of source data.
41 | context_size: The context size to use for tokenized prompts.
42 |
43 | Returns:
44 | Tokenized prompts.
45 |
46 | Raises:
47 | ValueError: If the context size is larger than the tokenized prompt size.
48 | """
49 | tokenized_prompts: list[list[int]] = source_batch[self._dataset_column_name]
50 |
51 | # Check the context size is not too large
52 | if context_size > len(tokenized_prompts[0]):
53 | error_message = (
54 | f"The context size ({context_size}) is larger than the "
55 | f"tokenized prompt size ({len(tokenized_prompts[0])})."
56 | )
57 | raise ValueError(error_message)
58 |
59 | # Chunk each tokenized prompt into blocks of context_size,
60 | # discarding the last block if too small.
61 | context_size_prompts = []
62 | for encoding in tokenized_prompts:
63 | chunks = [
64 | encoding[i : i + context_size]
65 | for i in range(0, len(encoding), context_size)
66 | if len(encoding[i : i + context_size]) == context_size
67 | ]
68 | context_size_prompts.extend(chunks)
69 |
70 | return {"input_ids": context_size_prompts}
71 |
72 | @validate_call
73 | def __init__(
74 | self,
75 | dataset_path: str,
76 | context_size: PositiveInt = 256,
77 | buffer_size: PositiveInt = 1000,
78 | dataset_dir: str | None = None,
79 | dataset_files: str | Sequence[str] | Mapping[str, str | Sequence[str]] | None = None,
80 | dataset_split: str = "train",
81 | dataset_column_name: str = "input_ids",
82 | preprocess_batch_size: PositiveInt = 1000,
83 | *,
84 | pre_download: bool = False,
85 | ):
86 | """Initialize a pre-tokenized dataset from Hugging Face.
87 |
88 | Args:
89 | dataset_path: The path to the dataset on Hugging Face (e.g.
90 |                 `alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2`).
91 | context_size: The context size for tokenized prompts.
92 | buffer_size: The buffer size to use when shuffling the dataset when streaming. When
93 | streaming a dataset, this just pre-downloads at least `buffer_size` items and then
94 | shuffles just that buffer. Note that the generated activations should also be
95 | shuffled before training the sparse autoencoder, so a large buffer may not be
96 | strictly necessary here. Note also that this is the number of items in the dataset
97 | (e.g. number of prompts) and is typically significantly less than the number of
98 | tokenized prompts once the preprocessing function has been applied.
99 | dataset_dir: Defining the `data_dir` of the dataset configuration.
100 | dataset_files: Path(s) to source data file(s).
101 | dataset_split: Dataset split (e.g. `train`).
102 | dataset_column_name: The column name for the tokenized prompts.
103 | preprocess_batch_size: The batch size to use just for preprocessing the dataset (e.g.
104 | tokenizing prompts).
105 | pre_download: Whether to pre-download the whole dataset.
106 | """
107 | super().__init__(
108 | buffer_size=buffer_size,
109 | context_size=context_size,
110 | dataset_dir=dataset_dir,
111 | dataset_files=dataset_files,
112 | dataset_path=dataset_path,
113 | dataset_split=dataset_split,
114 | dataset_column_name=dataset_column_name,
115 | pre_download=pre_download,
116 | preprocess_batch_size=preprocess_batch_size,
117 | )
118 |
--------------------------------------------------------------------------------
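
Note: the chunking in preprocess is easy to check in isolation: each pre-tokenized prompt is cut
into consecutive blocks of context_size tokens, and a trailing block shorter than context_size is
discarded. A minimal sketch with a toy token list:

context_size = 3
encoding = list(range(10))  # a toy pre-tokenized prompt: [0, 1, ..., 9]

chunks = [
    encoding[i : i + context_size]
    for i in range(0, len(encoding), context_size)
    if len(encoding[i : i + context_size]) == context_size
]
print(chunks)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]] - the trailing [9] is dropped
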
/sparse_autoencoder/source_data/tests/test_abstract_dataset.py:
--------------------------------------------------------------------------------
1 | """Test the abstract dataset."""
2 |
3 | import pytest
4 |
5 | from sparse_autoencoder.source_data.abstract_dataset import SourceDataset
6 | from sparse_autoencoder.source_data.mock_dataset import MockDataset
7 |
8 |
9 | @pytest.fixture()
10 | def mock_dataset() -> MockDataset:
11 | """Fixture to create a default ConsecutiveIntHuggingFaceDataset for testing.
12 |
13 | Returns:
14 | ConsecutiveIntHuggingFaceDataset: An instance of the dataset for testing.
15 | """
16 | return MockDataset(context_size=10, buffer_size=100)
17 |
18 |
19 | def test_extended_dataset_initialization(mock_dataset: MockDataset) -> None:
20 | """Test the initialization of the extended dataset."""
21 | assert mock_dataset is not None
22 | assert isinstance(mock_dataset, SourceDataset)
23 |
24 |
25 | def test_extended_dataset_iterator(mock_dataset: MockDataset) -> None:
26 | """Test the iterator of the extended dataset."""
27 | iterator = iter(mock_dataset)
28 | assert iterator is not None
29 |
30 |
31 | def test_get_dataloader(mock_dataset: MockDataset) -> None:
32 | """Test the get_dataloader method of the extended dataset."""
33 | batch_size = 3
34 | dataloader = mock_dataset.get_dataloader(batch_size=batch_size)
35 | first_item = next(iter(dataloader))["input_ids"]
36 | assert first_item.shape[0] == batch_size
37 |
--------------------------------------------------------------------------------
/sparse_autoencoder/source_data/tests/test_mock_dataset.py:
--------------------------------------------------------------------------------
1 | """Tests for the mock dataset."""
2 | import pytest
3 | import torch
4 | from torch import Tensor
5 |
6 | from sparse_autoencoder.source_data.mock_dataset import ConsecutiveIntHuggingFaceDataset
7 |
8 |
9 | class TestConsecutiveIntHuggingFaceDataset:
10 | """Tests for the ConsecutiveIntHuggingFaceDataset."""
11 |
12 | @pytest.fixture(scope="class")
13 | def create_dataset(self) -> ConsecutiveIntHuggingFaceDataset:
14 | """Fixture to create a default ConsecutiveIntHuggingFaceDataset for testing.
15 |
16 | Returns:
17 | ConsecutiveIntHuggingFaceDataset: An instance of the dataset for testing.
18 | """
19 | return ConsecutiveIntHuggingFaceDataset(context_size=10, vocab_size=1000, n_items=100)
20 |
21 | def test_dataset_initialization_failure(self) -> None:
22 | """Test invalid initialization failure of the ConsecutiveIntHuggingFaceDataset."""
23 | with pytest.raises(
24 | ValueError,
25 | match=r"n_items \(\d+\) \+ context_size \(\d+\) must be less than vocab_size \(\d+\)",
26 | ):
27 | ConsecutiveIntHuggingFaceDataset(context_size=40, vocab_size=50, n_items=20)
28 |
29 | def test_dataset_len(self, create_dataset: ConsecutiveIntHuggingFaceDataset) -> None:
30 | """Test the __len__ method of the dataset.
31 |
32 | Args:
33 | create_dataset: Fixture to create a test dataset instance.
34 | """
35 | expected_length = 100
36 | assert len(create_dataset) == expected_length, "Dataset length is not as expected."
37 |
38 | def test_dataset_getitem(self, create_dataset: ConsecutiveIntHuggingFaceDataset) -> None:
39 | """Test the __getitem__ method of the dataset.
40 |
41 | Args:
42 | create_dataset: Fixture to create a test dataset instance.
43 | """
44 | item = create_dataset[0]
45 | assert isinstance(item, dict), "Item should be a dictionary."
46 | assert "input_ids" in item, "Item should have 'input_ids' key."
47 | assert isinstance(item["input_ids"], list), "input_ids should be a list."
48 |
49 | def test_create_data(self, create_dataset: ConsecutiveIntHuggingFaceDataset) -> None:
50 | """Test the create_data method of the dataset.
51 |
52 | Args:
53 | create_dataset: Fixture to create a test dataset instance.
54 | """
55 | data: Tensor = create_dataset.create_data(n_items=10, context_size=5)
56 | assert data.shape == (10, 5), "Data shape is not as expected."
57 |
58 | def test_dataset_iteration(self, create_dataset: ConsecutiveIntHuggingFaceDataset) -> None:
59 | """Test the iteration functionality of the dataset.
60 |
61 | Args:
62 | create_dataset: Fixture to create a test dataset instance.
63 | """
64 | items = [item["input_ids"] for item in create_dataset]
65 |
66 | # Check they are all unique
67 | items_tensor = torch.tensor(items)
68 | unique_items = torch.unique(items_tensor, dim=0)
69 | assert items_tensor.shape == unique_items.shape, "Items are not unique."
70 |
--------------------------------------------------------------------------------
/sparse_autoencoder/source_data/tests/test_pretokenized_dataset.py:
--------------------------------------------------------------------------------
1 | """Tests for General Pre-Tokenized Dataset."""
2 | import pytest
3 |
4 | from sparse_autoencoder.source_data.pretokenized_dataset import PreTokenizedDataset
5 |
6 |
7 | TEST_DATASET = "alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2"
8 |
9 |
10 | @pytest.mark.integration_test()
11 | @pytest.mark.parametrize("context_size", [128, 256])
12 | def test_tokenized_prompts_correct_size(context_size: int) -> None:
13 | """Test that the tokenized prompts have the correct context size."""
14 | data = PreTokenizedDataset(dataset_path=TEST_DATASET, context_size=context_size)
15 |
16 |     # Check the first two items
17 | iterable = iter(data.dataset)
18 | for _ in range(2):
19 | item = next(iterable)
20 | assert len(item["input_ids"]) == context_size
21 |
22 | # Check the tokens are integers
23 | for token in item["input_ids"]:
24 | assert isinstance(token, int)
25 |
26 |
27 | @pytest.mark.integration_test()
28 | def test_fails_context_size_too_large() -> None:
29 | """Test that it fails if the context size is set as larger than the source dataset on HF."""
30 | data = PreTokenizedDataset(dataset_path=TEST_DATASET, context_size=512)
31 | with pytest.raises(ValueError, match=r"larger than the tokenized prompt size"):
32 | next(iter(data))
33 |
--------------------------------------------------------------------------------
/sparse_autoencoder/source_data/tests/test_text_dataset.py:
--------------------------------------------------------------------------------
1 | """Pile Uncopyrighted Dataset Tests."""
2 | import pytest
3 | from transformers import GPT2Tokenizer
4 |
5 | from sparse_autoencoder.source_data.text_dataset import TextDataset
6 |
7 |
8 | @pytest.mark.integration_test()
9 | @pytest.mark.parametrize("context_size", [50, 250])
10 | def test_tokenized_prompts_correct_size(context_size: int) -> None:
11 | """Test that the tokenized prompts have the correct context size."""
12 | tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
13 |
14 | data = TextDataset(
15 | tokenizer=tokenizer, context_size=context_size, dataset_path="monology/pile-uncopyrighted"
16 | )
17 |
18 | # Check the first 100 items
19 | iterable = iter(data.dataset)
20 | for _ in range(100):
21 | item = next(iterable)
22 | assert len(item["input_ids"]) == context_size
23 |
24 | # Check the tokens are integers
25 | for token in item["input_ids"]:
26 | assert isinstance(token, int)
27 |
28 |
29 | @pytest.mark.integration_test()
30 | def test_dataloader_correct_size_items() -> None:
31 | """Test the dataloader returns the correct number & sized items."""
32 | batch_size = 10
33 | context_size = 250
34 | tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
35 | data = TextDataset(
36 | tokenizer=tokenizer, context_size=context_size, dataset_path="monology/pile-uncopyrighted"
37 | )
38 | dataloader = data.get_dataloader(batch_size=batch_size)
39 |
40 | checks = 100
41 | for item in dataloader:
42 | checks -= 1
43 | if checks == 0:
44 | break
45 |
46 | tokens = item["input_ids"]
47 | assert tokens.shape[0] == batch_size
48 | assert tokens.shape[1] == context_size
49 |
--------------------------------------------------------------------------------
/sparse_autoencoder/source_data/text_dataset.py:
--------------------------------------------------------------------------------
1 | """Generic Text Dataset Module for Hugging Face Datasets.
2 |
3 | TextDataset should work with the following datasets:
4 | - monology/pile-uncopyrighted
5 | - the_pile_openwebtext2
6 | - roneneldan/TinyStories
7 | """
8 | from collections.abc import Mapping, Sequence
9 | from typing import TypedDict, final
10 |
11 | from datasets import IterableDataset
12 | from pydantic import PositiveInt, validate_call
13 | from transformers import PreTrainedTokenizerBase
14 |
15 | from sparse_autoencoder.source_data.abstract_dataset import SourceDataset, TokenizedPrompts
16 |
17 |
18 | class GenericTextDataBatch(TypedDict):
19 | """Generic Text Dataset Batch.
20 |
21 | Assumes the dataset provides a 'text' field with a list of strings.
22 | """
23 |
24 | text: list[str]
25 | meta: list[dict[str, dict[str, str]]] # Optional, depending on the dataset structure.
26 |
27 |
28 | @final
29 | class TextDataset(SourceDataset[GenericTextDataBatch]):
30 | """Generic Text Dataset for any text-based dataset from Hugging Face."""
31 |
32 | tokenizer: PreTrainedTokenizerBase
33 |
34 | def preprocess(
35 | self,
36 | source_batch: GenericTextDataBatch,
37 | *,
38 | context_size: int,
39 | ) -> TokenizedPrompts:
40 | """Preprocess a batch of prompts.
41 |
42 | Tokenizes and chunks text data into lists of tokenized prompts with specified context size.
43 |
44 | Args:
45 | source_batch: A batch of source data, including 'text' with a list of strings.
46 | context_size: Context size for tokenized prompts.
47 |
48 | Returns:
49 | Tokenized prompts.
50 | """
51 | prompts: list[str] = source_batch["text"]
52 |
53 | tokenized_prompts = self.tokenizer(prompts, truncation=True, padding=False)
54 |
55 | # Chunk each tokenized prompt into blocks of context_size, discarding incomplete blocks.
56 | context_size_prompts = []
57 | for encoding in list(tokenized_prompts[self._dataset_column_name]): # type: ignore
58 | chunks = [
59 | encoding[i : i + context_size]
60 | for i in range(0, len(encoding), context_size)
61 | if len(encoding[i : i + context_size]) == context_size
62 | ]
63 | context_size_prompts.extend(chunks)
64 |
65 | return {"input_ids": context_size_prompts}
66 |
67 | @validate_call(config={"arbitrary_types_allowed": True})
68 | def __init__(
69 | self,
70 | dataset_path: str,
71 | tokenizer: PreTrainedTokenizerBase,
72 | buffer_size: PositiveInt = 1000,
73 | context_size: PositiveInt = 256,
74 | dataset_dir: str | None = None,
75 | dataset_files: str | Sequence[str] | Mapping[str, str | Sequence[str]] | None = None,
76 | dataset_split: str = "train",
77 | dataset_column_name: str = "input_ids",
78 | n_processes_preprocessing: PositiveInt | None = None,
79 | preprocess_batch_size: PositiveInt = 1000,
80 | *,
81 | pre_download: bool = False,
82 | ):
83 | """Initialize a generic text dataset from Hugging Face.
84 |
85 | Args:
86 |             dataset_path: Path to the dataset on Hugging Face (e.g. `'monology/pile-uncopyrighted'`).
87 | tokenizer: Tokenizer to process text data.
88 | buffer_size: The buffer size to use when shuffling the dataset when streaming. When
89 | streaming a dataset, this just pre-downloads at least `buffer_size` items and then
90 | shuffles just that buffer. Note that the generated activations should also be
91 | shuffled before training the sparse autoencoder, so a large buffer may not be
92 | strictly necessary here. Note also that this is the number of items in the dataset
93 | (e.g. number of prompts) and is typically significantly less than the number of
94 | tokenized prompts once the preprocessing function has been applied.
95 | context_size: The context size to use when returning a list of tokenized prompts.
96 | *Towards Monosemanticity: Decomposing Language Models With Dictionary Learning* used
97 | a context size of 250.
98 |             dataset_dir: The `data_dir` of the dataset configuration.
99 | dataset_files: Path(s) to source data file(s).
100 | dataset_split: Dataset split (e.g., 'train').
101 | dataset_column_name: The column name for the prompts.
102 | n_processes_preprocessing: Number of processes to use for preprocessing.
103 | preprocess_batch_size: Batch size for preprocessing (tokenizing prompts).
104 | pre_download: Whether to pre-download the whole dataset.
105 | """
106 | self.tokenizer = tokenizer
107 |
108 | super().__init__(
109 | buffer_size=buffer_size,
110 | context_size=context_size,
111 | dataset_dir=dataset_dir,
112 | dataset_files=dataset_files,
113 | dataset_path=dataset_path,
114 | dataset_split=dataset_split,
115 | dataset_column_name=dataset_column_name,
116 | n_processes_preprocessing=n_processes_preprocessing,
117 | pre_download=pre_download,
118 | preprocess_batch_size=preprocess_batch_size,
119 | )
120 |
121 | @validate_call
122 | def push_to_hugging_face_hub(
123 | self,
124 | repo_id: str,
125 | commit_message: str = "Upload preprocessed dataset using sparse_autoencoder.",
126 | max_shard_size: str | None = None,
127 | n_shards: PositiveInt = 64,
128 | revision: str = "main",
129 | *,
130 | private: bool = False,
131 | ) -> None:
132 | """Share preprocessed dataset to Hugging Face hub.
133 |
134 | Motivation:
135 | Pre-processing a dataset can be time-consuming, so it is useful to be able to share the
136 | pre-processed dataset with others. This function allows you to do that by pushing the
137 | pre-processed dataset to the Hugging Face hub.
138 |
139 | Warning:
140 |             You must be logged into Hugging Face (e.g. with `huggingface-cli login` from the terminal)
141 | to use this.
142 |
143 | Warning:
144 | This will only work if the dataset is not streamed (i.e. if `pre_download=True` when
145 | initializing the dataset).
146 |
147 | Args:
148 | repo_id: Hugging Face repo ID to save the dataset to (e.g. `username/dataset_name`).
149 | commit_message: Commit message.
150 | max_shard_size: Maximum shard size (e.g. `'500MB'`). Should not be set if `n_shards`
151 | is set.
152 | n_shards: Number of shards to split the dataset into. A high number is recommended
153 | here to allow for flexible distributed training of SAEs across nodes (where e.g.
154 | each node fetches its own shard).
155 | revision: Branch to push to.
156 | private: Whether to save the dataset privately.
157 |
158 | Raises:
159 | TypeError: If the dataset is streamed.
160 | """
161 | if isinstance(self.dataset, IterableDataset):
162 | error_message = (
163 | "Cannot share a streamed dataset to Hugging Face. "
164 | "Please use `pre_download=True` when initializing the dataset."
165 | )
166 | raise TypeError(error_message)
167 |
168 | self.dataset.push_to_hub(
169 | repo_id=repo_id,
170 | commit_message=commit_message,
171 | max_shard_size=max_shard_size,
172 | num_shards=n_shards,
173 | private=private,
174 | revision=revision,
175 | )
176 |
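A minimal usage sketch combining the initializer and `push_to_hugging_face_hub` above, assuming you are already logged in to Hugging Face; the repo ID is a placeholder:

```python
from transformers import GPT2Tokenizer

from sparse_autoencoder.source_data.text_dataset import TextDataset

# Tokenize and chunk the dataset. `pre_download=True` materialises the full
# dataset (rather than streaming it), which `push_to_hugging_face_hub` requires.
dataset = TextDataset(
    dataset_path="monology/pile-uncopyrighted",
    tokenizer=GPT2Tokenizer.from_pretrained("gpt2"),
    context_size=256,
    pre_download=True,
)

# The repo ID below is a placeholder.
dataset.push_to_hugging_face_hub(repo_id="username/dataset_name")
```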
--------------------------------------------------------------------------------
/sparse_autoencoder/source_model/__init__.py:
--------------------------------------------------------------------------------
1 | """Source Model."""
2 |
--------------------------------------------------------------------------------
/sparse_autoencoder/source_model/replace_activations_hook.py:
--------------------------------------------------------------------------------
1 | """Replace activations hook."""
2 | from typing import TYPE_CHECKING
3 |
4 | from jaxtyping import Float
5 | from torch import Tensor
6 | from torch.nn import Module
7 | from torch.nn.parallel import DataParallel
8 | from transformer_lens.hook_points import HookPoint
9 |
10 | from sparse_autoencoder.autoencoder.lightning import LitSparseAutoencoder
11 | from sparse_autoencoder.autoencoder.model import SparseAutoencoder
12 |
13 |
14 | if TYPE_CHECKING:
15 | from sparse_autoencoder.tensor_types import Axis
16 |
17 |
18 | def replace_activations_hook(
19 | value: Tensor,
20 | hook: HookPoint, # noqa: ARG001
21 | sparse_autoencoder: SparseAutoencoder
22 | | DataParallel[SparseAutoencoder]
23 | | LitSparseAutoencoder
24 | | Module,
25 | component_idx: int | None = None,
26 | n_components: int | None = None,
27 | ) -> Tensor:
28 | """Replace activations hook.
29 |
30 | This should be pre-initialised with `functools.partial`.
31 |
32 | Args:
33 | value: The activations to replace.
34 | hook: The hook point.
35 | sparse_autoencoder: The sparse autoencoder.
36 | component_idx: The component index to replace the activations with, if just replacing
37 | activations for a single component. Requires the model to have a component axis.
38 | n_components: The number of components that the SAE is trained on.
39 |
40 | Returns:
41 | Replaced activations.
42 |
43 | Raises:
44 |         RuntimeError: If `component_idx` is specified but `n_components` is not set.
45 | """
46 | # Squash to just have a "*items" and a "batch" dimension
47 | original_shape = value.shape
48 |
49 | squashed_value: Float[Tensor, Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE)] = value.view(
50 | -1, value.size(-1)
51 | )
52 |
53 | if component_idx is not None:
54 | if n_components is None:
55 | error_message = "The number of model components must be set if component_idx is set."
56 | raise RuntimeError(error_message)
57 |
58 |         # The approach here is to run a forward pass with dummy values for all components other
59 |         # than the one we want to replace. This is done by expanding the input activations
60 |         # across the component axis (so every component receives the same input), and then
61 |         # discarding the outputs from all components other than the target one.
62 | expanded_shape = [
63 | squashed_value.shape[0],
64 | n_components,
65 | squashed_value.shape[-1],
66 | ]
67 | expanded = squashed_value.unsqueeze(1).expand(*expanded_shape)
68 |
69 | _learned_activations, output_activations = sparse_autoencoder.forward(expanded)
70 | component_output_activations = output_activations[:, component_idx]
71 |
72 | return component_output_activations.view(*original_shape)
73 |
74 | # Get the output activations from a forward pass of the SAE
75 | _learned_activations, output_activations = sparse_autoencoder.forward(squashed_value)
76 |
77 | # Reshape to the original shape
78 | return output_activations.view(*original_shape)
79 |
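To make the expand-then-select step above concrete, here is a minimal sketch with hypothetical sizes, using an identity stand-in for the SAE forward pass:

```python
import torch

# Hypothetical sizes: 6 squashed items (batch * pos), 3 components, 16 features.
squashed_value = torch.randn(6, 16)
n_components, component_idx = 3, 1

# Expand the same input across the component axis (a view, not a copy).
expanded = squashed_value.unsqueeze(1).expand(6, n_components, 16)

# Stand-in for `sparse_autoencoder.forward`, which returns the output
# activations; here the output is just the input, unchanged.
output_activations = expanded

# Keep only the target component's output.
component_output_activations = output_activations[:, component_idx]
assert component_output_activations.shape == squashed_value.shape
```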
--------------------------------------------------------------------------------
/sparse_autoencoder/source_model/reshape_activations.py:
--------------------------------------------------------------------------------
1 | """Methods to reshape activation tensors."""
2 | from collections.abc import Callable
3 | from functools import reduce
4 | from typing import TypeAlias
5 |
6 | from einops import rearrange
7 | from jaxtyping import Float
8 | from torch import Tensor
9 |
10 | from sparse_autoencoder.tensor_types import Axis
11 |
12 |
13 | ReshapeActivationsFunction: TypeAlias = Callable[
14 | [Float[Tensor, Axis.names(Axis.ANY)]],
15 | Float[Tensor, Axis.names(Axis.STORE_BATCH, Axis.INPUT_OUTPUT_FEATURE)],
16 | ]
17 | """Reshape Activations Function.
18 |
19 | Used within hooks to e.g. reshape activations before storing them in the activation store.
20 | """
21 |
22 |
23 | def reshape_to_last_dimension(
24 | batch_activations: Float[Tensor, Axis.names(Axis.ANY)],
25 | ) -> Float[Tensor, Axis.names(Axis.STORE_BATCH, Axis.INPUT_OUTPUT_FEATURE)]:
26 | """Reshape to Last Dimension.
27 |
28 |     Takes a tensor of activation vectors, with an arbitrary number of dimensions (the last of which is
29 | the neurons dimension), and returns a single tensor of size [item, neurons].
30 |
31 | Examples:
32 |         With 2 axes (e.g. pos, neuron):
33 |
34 | >>> import torch
35 | >>> input = torch.rand(3, 100)
36 | >>> res = reshape_to_last_dimension(input)
37 | >>> res.shape
38 | torch.Size([3, 100])
39 |
40 |         With 3 axes (e.g. batch, pos, neuron):
41 |
42 | >>> input = torch.randn(3, 3, 100)
43 | >>> res = reshape_to_last_dimension(input)
44 | >>> res.shape
45 | torch.Size([9, 100])
46 |
47 |         With 4 axes (e.g. batch, pos, head_idx, neuron):
48 |
49 | >>> input = torch.rand(3, 3, 3, 100)
50 | >>> res = reshape_to_last_dimension(input)
51 | >>> res.shape
52 | torch.Size([27, 100])
53 |
54 | Args:
55 | batch_activations: Input Activation Store Batch
56 |
57 | Returns:
58 | Single Tensor of Activation Store Items
59 | """
60 | return rearrange(batch_activations, "... input_output_feature -> (...) input_output_feature")
61 |
62 |
63 | def reshape_concat_last_dimensions(
64 | batch_activations: Float[Tensor, Axis.names(Axis.ANY)],
65 | concat_dims: int,
66 | ) -> Float[Tensor, Axis.names(Axis.STORE_BATCH, Axis.INPUT_OUTPUT_FEATURE)]:
67 | """Reshape to Last Dimension, Concatenating the Specified Dimensions.
68 |
69 |     Takes a tensor of activation vectors, with an arbitrary number of dimensions (the last
70 | `concat_dims` of which are the neuron dimensions), and returns a single tensor of size
71 | [item, neurons].
72 |
73 | Examples:
74 |         With 3 axes (e.g. batch, pos, neuron), concatenating the last 2 dimensions:
75 |
76 | >>> import torch
77 | >>> input = torch.randn(3, 4, 5)
78 | >>> res = reshape_concat_last_dimensions(input, 2)
79 | >>> res.shape
80 | torch.Size([3, 20])
81 |
82 |         With 4 axes (e.g. batch, pos, head_idx, neuron), concatenating the last 3 dimensions:
83 |
84 | >>> input = torch.rand(2, 3, 4, 5)
85 | >>> res = reshape_concat_last_dimensions(input, 3)
86 | >>> res.shape
87 | torch.Size([2, 60])
88 |
89 | Args:
90 | batch_activations: Input Activation Store Batch
91 | concat_dims: Number of dimensions to concatenate
92 |
93 | Returns:
94 | Single Tensor of Activation Store Items
95 | """
96 | neurons = reduce(lambda x, y: x * y, batch_activations.shape[-concat_dims:])
97 | items = reduce(lambda x, y: x * y, batch_activations.shape[:-concat_dims])
98 |
99 | return batch_activations.reshape(items, neurons)
100 |
--------------------------------------------------------------------------------
/sparse_autoencoder/source_model/store_activations_hook.py:
--------------------------------------------------------------------------------
1 | """TransformerLens Hook for storing activations."""
2 | from jaxtyping import Float
3 | from torch import Tensor
4 | from transformer_lens.hook_points import HookPoint
5 |
6 | from sparse_autoencoder.activation_store.base_store import ActivationStore
7 | from sparse_autoencoder.source_model.reshape_activations import (
8 | ReshapeActivationsFunction,
9 | reshape_to_last_dimension,
10 | )
11 | from sparse_autoencoder.tensor_types import Axis
12 |
13 |
14 | def store_activations_hook(
15 | value: Float[Tensor, Axis.names(Axis.ANY)],
16 | hook: HookPoint, # noqa: ARG001
17 | store: ActivationStore,
18 | reshape_method: ReshapeActivationsFunction = reshape_to_last_dimension,
19 | component_idx: int = 0,
20 | ) -> Float[Tensor, Axis.names(Axis.ANY)]:
21 | """Store Activations Hook.
22 |
23 | Useful for getting just the specific activations wanted, rather than the full cache.
24 |
25 | Example:
26 | First we'll need a source model from TransformerLens and an activation store.
27 |
28 | >>> from functools import partial
29 | >>> from transformer_lens import HookedTransformer
30 | >>> from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore
31 | >>> store = TensorActivationStore(max_items=1000, n_neurons=64, n_components=1)
32 | >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
33 | Loaded pretrained model tiny-stories-1M into HookedTransformer
34 |
35 | Next we can add the hook to specific neurons (in this case the first MLP neurons), and
36 | create the tokens for a forward pass.
37 |
38 | >>> model.add_hook(
39 | ... "blocks.0.hook_mlp_out", partial(store_activations_hook, store=store)
40 | ... )
41 | >>> tokens = model.to_tokens("Hello world")
42 | >>> tokens.shape
43 | torch.Size([1, 3])
44 |
45 | Then when we run the model, we should get one activation vector for each token (as we just
46 | have one batch item). Note we also set `stop_at_layer=1` as we don't need the logits or any
47 | other activations after the hook point that we've specified (in this case the first MLP
48 | layer).
49 |
50 | >>> _output = model.forward("Hello world", stop_at_layer=1) # Change this layer as required
51 | >>> len(store)
52 | 3
53 |
54 | Args:
55 | value: The activations to store.
56 | hook: The hook point.
57 |         store: The activation store. Bind this to the hook with `functools.partial` (as above).
58 | reshape_method: The method to reshape the activations before storing them.
59 | component_idx: The component index of the activations to store.
60 |
61 | Returns:
62 | Unmodified activations.
63 | """
64 | reshaped: Float[
65 | Tensor, Axis.names(Axis.STORE_BATCH, Axis.INPUT_OUTPUT_FEATURE)
66 | ] = reshape_method(value)
67 |
68 | store.extend(reshaped, component_idx=component_idx)
69 |
70 | # Return the unmodified value
71 | return value
72 |
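A minimal sketch of wiring a custom reshape method into this hook, for a hypothetical hook point of shape [batch, pos, head_idx, d_head] (16 heads of 4 features each, flattened to 64 neurons):

```python
from functools import partial

from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore
from sparse_autoencoder.source_model.reshape_activations import reshape_concat_last_dimensions
from sparse_autoencoder.source_model.store_activations_hook import store_activations_hook

# Flatten the last two axes (head_idx and d_head) into a single neurons axis,
# so the activations fit the store's [item, neurons] layout.
store = TensorActivationStore(max_items=1000, n_neurons=64, n_components=1)
hook = partial(
    store_activations_hook,
    store=store,
    reshape_method=partial(reshape_concat_last_dimensions, concat_dims=2),
)
```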
--------------------------------------------------------------------------------
/sparse_autoencoder/source_model/tests/test_replace_activations_hook.py:
--------------------------------------------------------------------------------
1 | """Replace activations hook tests."""
2 | from functools import partial
3 |
4 | from jaxtyping import Int
5 | import pytest
6 | import torch
7 | from torch import Tensor
8 | from transformer_lens import HookedTransformer
9 |
10 | from sparse_autoencoder.autoencoder.model import SparseAutoencoder, SparseAutoencoderConfig
11 | from sparse_autoencoder.source_model.replace_activations_hook import replace_activations_hook
12 | from sparse_autoencoder.tensor_types import Axis
13 |
14 |
15 | @pytest.mark.integration_test()
16 | def test_hook_replaces_activations() -> None:
17 | """Test that the hook replaces activations."""
18 | torch.random.manual_seed(0)
19 | source_model = HookedTransformer.from_pretrained("tiny-stories-1M", device="cpu")
20 | autoencoder = SparseAutoencoder(
21 | SparseAutoencoderConfig(
22 | n_input_features=source_model.cfg.d_model,
23 | n_learned_features=source_model.cfg.d_model * 2,
24 | )
25 | )
26 |
27 | tokens: Int[Tensor, Axis.names(Axis.SOURCE_DATA_BATCH, Axis.POSITION)] = source_model.to_tokens(
28 | "Hello world"
29 | )
30 | loss_without_hook = source_model.forward(tokens, return_type="loss")
31 | loss_with_hook = source_model.run_with_hooks(
32 | tokens,
33 | return_type="loss",
34 | fwd_hooks=[
35 | (
36 | "blocks.0.hook_mlp_out",
37 | partial(replace_activations_hook, sparse_autoencoder=autoencoder),
38 | )
39 | ],
40 | )
41 |
42 |     # Check it decreases performance (as the SAE is untrained, it will output nonsense).
43 | assert torch.all(torch.gt(loss_with_hook, loss_without_hook))
44 |
45 |
46 | @pytest.mark.integration_test()
47 | def test_hook_replaces_activations_2_components() -> None:
48 | """Test that the hook replaces activations."""
49 | torch.random.manual_seed(0)
50 | source_model = HookedTransformer.from_pretrained("tiny-stories-1M", device="cpu")
51 | autoencoder = SparseAutoencoder(
52 | SparseAutoencoderConfig(
53 | n_input_features=source_model.cfg.d_model,
54 | n_learned_features=source_model.cfg.d_model * 2,
55 | n_components=2,
56 | )
57 | )
58 |
59 | tokens: Int[Tensor, Axis.names(Axis.SOURCE_DATA_BATCH, Axis.POSITION)] = source_model.to_tokens(
60 | "Hello world"
61 | )
62 | loss_without_hook = source_model.forward(tokens, return_type="loss")
63 | loss_with_hook = source_model.run_with_hooks(
64 | tokens,
65 | return_type="loss",
66 | fwd_hooks=[
67 | (
68 | "blocks.0.hook_mlp_out",
69 | partial(
70 | replace_activations_hook,
71 | sparse_autoencoder=autoencoder,
72 | component_idx=1,
73 | n_components=2,
74 | ),
75 | )
76 | ],
77 | )
78 |
79 |     # Check it decreases performance (as the SAE is untrained, it will output nonsense).
80 | assert torch.all(torch.gt(loss_with_hook, loss_without_hook))
81 |
--------------------------------------------------------------------------------
/sparse_autoencoder/source_model/tests/test_store_activations_hook.py:
--------------------------------------------------------------------------------
1 | """Store Activations Hook Tests."""
2 | from functools import partial
3 |
4 | from jaxtyping import Int
5 | import pytest
6 | import torch
7 | from torch import Tensor
8 | from transformer_lens import HookedTransformer
9 |
10 | from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore
11 | from sparse_autoencoder.source_model.store_activations_hook import store_activations_hook
12 | from sparse_autoencoder.tensor_types import Axis
13 |
14 |
15 | @pytest.mark.integration_test()
16 | def test_hook_stores_activations() -> None:
17 | """Test that the hook stores activations correctly."""
18 | store = TensorActivationStore(max_items=100, n_neurons=256, n_components=1)
19 |
20 | model = HookedTransformer.from_pretrained("tiny-stories-1M")
21 |
22 | model.add_hook(
23 | "blocks.0.mlp.hook_post",
24 | partial(store_activations_hook, store=store),
25 | )
26 |
27 | tokens: Int[Tensor, Axis.names(Axis.SOURCE_DATA_BATCH, Axis.POSITION)] = model.to_tokens(
28 | "Hello world"
29 | )
30 | logits = model.forward(tokens, stop_at_layer=2) # type: ignore
31 |
32 | n_of_tokens = tokens.numel()
33 | mlp_size: int = model.cfg.d_mlp # type: ignore
34 |
35 | assert len(store) == n_of_tokens
36 | assert store[0, 0].shape[0] == mlp_size
37 | assert torch.is_tensor(logits) # Check the forward pass completed
38 |
--------------------------------------------------------------------------------
/sparse_autoencoder/source_model/tests/test_zero_ablate_hook.py:
--------------------------------------------------------------------------------
1 | """Test the zero ablate hook."""
2 | import pytest
3 | import torch
4 | from transformer_lens.hook_points import HookPoint
5 |
6 | from sparse_autoencoder.source_model.zero_ablate_hook import zero_ablate_hook
7 |
8 |
9 | class MockHookPoint(HookPoint):
10 | """Mock HookPoint class."""
11 |
12 |
13 | @pytest.fixture()
14 | def mock_hook_point() -> MockHookPoint:
15 | """Fixture to provide a mock HookPoint instance."""
16 | return MockHookPoint()
17 |
18 |
19 | def test_zero_ablate_hook_with_standard_tensor(mock_hook_point: MockHookPoint) -> None:
20 | """Test zero_ablate_hook with a standard tensor.
21 |
22 | Args:
23 | mock_hook_point: A mock HookPoint instance.
24 | """
25 | value = torch.ones(3, 4)
26 | expected = torch.zeros(3, 4)
27 | result = zero_ablate_hook(value, mock_hook_point)
28 | assert torch.equal(result, expected), "The output tensor should contain only zeros."
29 |
30 |
31 | @pytest.mark.parametrize("shape", [(10,), (5, 5), (2, 3, 4)])
32 | def test_zero_ablate_hook_with_various_shapes(
33 | mock_hook_point: MockHookPoint, shape: tuple[int, ...]
34 | ) -> None:
35 | """Test zero_ablate_hook with tensors of various shapes.
36 |
37 | Args:
38 | mock_hook_point: A mock HookPoint instance.
39 | shape: A tuple representing the shape of the tensor.
40 | """
41 | value = torch.ones(*shape)
42 | expected = torch.zeros(*shape)
43 | result = zero_ablate_hook(value, mock_hook_point)
44 | assert torch.equal(
45 | result, expected
46 | ), f"The output tensor should be of shape {shape} with zeros."
47 |
48 |
49 | def test_float_dtype_maintained(mock_hook_point: MockHookPoint) -> None:
50 | """Test that the float dtype is maintained.
51 |
52 | Args:
53 | mock_hook_point: A mock HookPoint instance.
54 | """
55 | value = torch.ones(3, 4, dtype=torch.float)
56 | result = zero_ablate_hook(value, mock_hook_point)
57 | assert result.dtype == torch.float, "The output tensor should be of dtype float."
58 |
--------------------------------------------------------------------------------
/sparse_autoencoder/source_model/zero_ablate_hook.py:
--------------------------------------------------------------------------------
1 | """Zero ablate hook."""
2 | import torch
3 | from torch import Tensor
4 | from transformer_lens.hook_points import HookPoint
5 |
6 |
7 | def zero_ablate_hook(
8 | value: Tensor,
9 | hook: HookPoint, # noqa: ARG001
10 | ) -> Tensor:
11 | """Zero ablate hook.
12 |
13 | Args:
14 | value: The activations to store.
15 | hook: The hook point.
16 |
17 | Example:
18 | >>> dummy_hook_point = HookPoint()
19 | >>> value = torch.ones(2, 3)
20 | >>> zero_ablate_hook(value, dummy_hook_point)
21 | tensor([[0., 0., 0.],
22 | [0., 0., 0.]])
23 |
24 | Returns:
25 | Replaced activations.
26 | """
27 | return torch.zeros_like(value)
28 |
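A minimal sketch of using this hook as a zero-ablation baseline, following the `run_with_hooks` pattern used in the tests above (no `functools.partial` needed, as the hook takes no extra arguments):

```python
from transformer_lens import HookedTransformer

from sparse_autoencoder.source_model.zero_ablate_hook import zero_ablate_hook

model = HookedTransformer.from_pretrained("tiny-stories-1M")
tokens = model.to_tokens("Hello world")

# Compare the loss with and without the first MLP output zero-ablated.
loss = model.forward(tokens, return_type="loss")
ablated_loss = model.run_with_hooks(
    tokens,
    return_type="loss",
    fwd_hooks=[("blocks.0.hook_mlp_out", zero_ablate_hook)],
)
```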
--------------------------------------------------------------------------------
/sparse_autoencoder/tensor_types.py:
--------------------------------------------------------------------------------
1 | """Tensor Axis Types."""
2 | from enum import auto
3 |
4 | from strenum import LowercaseStrEnum
5 |
6 |
7 | class Axis(LowercaseStrEnum):
8 | """Tensor axis names.
9 |
10 | Used to annotate tensor types.
11 |
12 | Example:
13 | When used directly it prints a string:
14 |
15 | >>> print(Axis.INPUT_OUTPUT_FEATURE)
16 | input_output_feature
17 |
18 | The primary use is to annotate tensor types:
19 |
20 | >>> from jaxtyping import Float
21 | >>> from torch import Tensor
22 | >>> from typing import TypeAlias
23 | >>> batch: TypeAlias = Float[Tensor, Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE)]
24 | >>> print(batch)
25 |
26 |
27 |         You can also join multiple axes together to represent the dimensions of a tensor:
28 |
29 | >>> print(Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE))
30 | batch input_output_feature
31 | """
32 |
33 | # Component idx
34 | COMPONENT = auto()
35 | """Component index."""
36 |
37 | COMPONENT_OPTIONAL = "*component"
38 | """Optional component index."""
39 |
40 | # Batches
41 | SOURCE_DATA_BATCH = auto()
42 | """Batch of prompts used to generate source model activations."""
43 |
44 | BATCH = auto()
45 | """Batch of items that the SAE is being trained on."""
46 |
47 | STORE_BATCH = auto()
48 | """Batch of items to be written to the store."""
49 |
50 | ITEMS = auto()
51 | """Arbitrary number of items."""
52 |
53 | # Features
54 | INPUT_OUTPUT_FEATURE = auto()
55 | """Input or output feature (e.g. feature in activation vector from source model)."""
56 |
57 | LEARNT_FEATURE = auto()
58 | """Learn feature (e.g. feature in learnt activation vector)."""
59 |
60 | DEAD_FEATURE = auto()
61 | """Dead feature."""
62 |
63 | ALIVE_FEATURE = auto()
64 | """Alive feature."""
65 |
66 | # Feature indices
67 | INPUT_OUTPUT_FEATURE_IDX = auto()
68 | """Input or output feature index."""
69 |
70 | LEARNT_FEATURE_IDX = auto()
71 | """Learn feature index."""
72 |
73 | # Other
74 | POSITION = auto()
75 | """Token position."""
76 |
77 | SINGLE_ITEM = ""
78 | """Single item axis."""
79 |
80 | ANY = "..."
81 | """Any number of axis."""
82 |
83 | @staticmethod
84 | def names(*axis: "Axis") -> str:
85 | """Join multiple axis together, to represent the dimensions of a tensor.
86 |
87 | Example:
88 | >>> print(Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE))
89 | batch input_output_feature
90 |
91 | Args:
92 |             *axis: Axes to join.
93 |
94 | Returns:
95 | Joined axis string.
96 | """
97 | return " ".join(a.value for a in axis)
98 |
--------------------------------------------------------------------------------
/sparse_autoencoder/train/__init__.py:
--------------------------------------------------------------------------------
1 | """Train."""
2 |
--------------------------------------------------------------------------------
/sparse_autoencoder/train/join_sweep.py:
--------------------------------------------------------------------------------
1 | """Join an existing Weights and Biases sweep, as a new agent."""
2 | import argparse
3 |
4 | from sparse_autoencoder.train.sweep import sweep
5 |
6 |
7 | def parse_arguments() -> argparse.Namespace:
8 | """Parse command line arguments.
9 |
10 | Returns:
11 | argparse.Namespace: Parsed command line arguments.
12 | """
13 | parser = argparse.ArgumentParser(description="Join an existing W&B sweep.")
14 | parser.add_argument(
15 | "--id", type=str, default=None, help="Sweep ID for the existing sweep.", required=True
16 | )
17 | return parser.parse_args()
18 |
19 |
20 | def run() -> None:
21 | """Run the join_sweep script."""
22 | args = parse_arguments()
23 |
24 | sweep(sweep_id=args.id)
25 |
26 |
27 | if __name__ == "__main__":
28 | run()
29 |
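For reference, joining an existing sweep from the command line (the sweep ID is a placeholder):

```bash
poetry run python sparse_autoencoder/train/join_sweep.py --id <sweep-id>
```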
--------------------------------------------------------------------------------
/sparse_autoencoder/train/tests/__snapshots__/test_sweep.ambr:
--------------------------------------------------------------------------------
1 | # serializer version: 1
2 | # name: test_setup_autoencoder
3 | '''
4 | LitSparseAutoencoder(
5 | (sparse_autoencoder): SparseAutoencoder(
6 | (pre_encoder_bias): TiedBias(position=pre_encoder)
7 | (encoder): LinearEncoder(
8 | input_features=512, learnt_features=2048, n_components=1
9 | (activation_function): ReLU()
10 | )
11 | (decoder): UnitNormDecoder(learnt_features=2048, decoded_features=512, n_components=1)
12 | (post_decoder_bias): TiedBias(position=post_decoder)
13 | )
14 | (loss_fn): SparseAutoencoderLoss()
15 | (train_metrics): MetricCollection(
16 | (activity): ClasswiseWrapperWithMean(
17 | (metric): NeuronActivityMetric()
18 | )
19 | (l0): ClasswiseWrapperWithMean(
20 | (metric): L0NormMetric()
21 | )
22 | (l1): ClasswiseWrapperWithMean(
23 | (metric): L1AbsoluteLoss()
24 | )
25 | (l2): ClasswiseWrapperWithMean(
26 | (metric): L2ReconstructionLoss()
27 | )
28 | (loss): ClasswiseWrapperWithMean(
29 | (metric): SparseAutoencoderLoss()
30 | )
31 | )
32 | (activation_resampler): ActivationResampler()
33 | )
34 | '''
35 | # ---
36 |
--------------------------------------------------------------------------------
/sparse_autoencoder/train/tests/test_sweep.py:
--------------------------------------------------------------------------------
1 | """Tests for sweep functionality."""
2 | import pytest
3 | from syrupy.session import SnapshotSession
4 |
5 | from sparse_autoencoder.train.sweep import setup_autoencoder
6 | from sparse_autoencoder.train.sweep_config import (
7 | RuntimeHyperparameters,
8 | )
9 |
10 |
11 | @pytest.fixture()
12 | def dummy_hyperparameters() -> RuntimeHyperparameters:
13 | """Sweep config dummy fixture."""
14 | return {
15 | "activation_resampler": {
16 | "threshold_is_dead_portion_fires": 0.0,
17 | "max_n_resamples": 4,
18 | "n_activations_activity_collate": 100_000_000,
19 | "resample_dataset_size": 819_200,
20 | "resample_interval": 200_000_000,
21 | },
22 | "autoencoder": {"expansion_factor": 4},
23 | "loss": {"l1_coefficient": 0.0001},
24 | "optimizer": {
25 | "adam_beta_1": 0.9,
26 | "adam_beta_2": 0.99,
27 | "adam_weight_decay": 0.0,
28 | "amsgrad": False,
29 | "fused": False,
30 | "lr": 1e-05,
31 | "lr_scheduler": None,
32 | },
33 | "pipeline": {
34 | "checkpoint_frequency": 100000000,
35 | "log_frequency": 100,
36 | "max_activations": 2000000000,
37 | "max_store_size": 3145728,
38 | "source_data_batch_size": 12,
39 | "train_batch_size": 4096,
40 | "validation_frequency": 314572800,
41 | "validation_n_activations": 1024,
42 | "num_workers_data_loading": 0,
43 | },
44 | "random_seed": 49,
45 | "source_data": {
46 | "context_size": 128,
47 | "dataset_column_name": "input_ids",
48 | "dataset_dir": None,
49 | "dataset_files": None,
50 | "dataset_path": "NeelNanda/c4-code-tokenized-2b",
51 | "pre_download": False,
52 | "pre_tokenized": True,
53 | "tokenizer_name": None,
54 | },
55 | "source_model": {
56 | "dtype": "float32",
57 | "hook_dimension": 512,
58 | "cache_names": ["mlp_out"],
59 | "name": "gelu-2l",
60 | },
61 | }
62 |
63 |
64 | def test_setup_autoencoder(
65 | dummy_hyperparameters: RuntimeHyperparameters, snapshot: SnapshotSession
66 | ) -> None:
67 | """Test the setup_autoencoder function."""
68 | autoencoder = setup_autoencoder(dummy_hyperparameters)
69 | assert snapshot == str(autoencoder), "Autoencoder string representation has changed."
70 |
--------------------------------------------------------------------------------
/sparse_autoencoder/train/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Train Utils."""
2 |
--------------------------------------------------------------------------------
/sparse_autoencoder/train/utils/get_model_device.py:
--------------------------------------------------------------------------------
1 | """Get the device that the model is on."""
2 | from lightning import LightningModule
3 | import torch
4 | from torch.nn import Module
5 | from torch.nn.parallel import DataParallel
6 |
7 |
8 | def get_model_device(model: Module | DataParallel | LightningModule) -> torch.device | None:
9 | """Get the device on which a PyTorch model is on.
10 |
11 | Args:
12 | model: The PyTorch model.
13 |
14 | Returns:
15 | The device ('cuda' or 'cpu') where the model is located.
16 |
17 | Raises:
18 | ValueError: If the model has no parameters.
19 | """
20 | # Deepspeed models already have a device property, so just return that
21 | if hasattr(model, "device"):
22 | return model.device
23 |
24 | # Tensors for lightning should not have device set (as lightning will handle this)
25 | if isinstance(model, LightningModule):
26 | return None
27 |
28 | # Check if the model has parameters
29 | if len(list(model.parameters())) == 0:
30 | exception_message = "The model has no parameters."
31 | raise ValueError(exception_message)
32 |
33 | # Return the device of the first parameter
34 | return next(model.parameters()).device
35 |
--------------------------------------------------------------------------------
/sparse_autoencoder/train/utils/round_down.py:
--------------------------------------------------------------------------------
1 | """Round down to the nearest multiple."""
2 |
3 |
4 | def round_to_multiple(value: int | float, multiple: int) -> int: # noqa: PYI041
5 | """Round down to the nearest multiple.
6 |
7 | Helper function for creating default values.
8 |
9 | Example:
10 | >>> round_to_multiple(1023, 100)
11 | 1000
12 |
13 | Args:
14 | value: The value to round down.
15 | multiple: The multiple to round down to.
16 |
17 | Returns:
18 | The value rounded down to the nearest multiple.
19 |
20 | Raises:
21 | ValueError: If `value` is less than `multiple`.
22 | """
23 | int_value = int(value)
24 |
25 | if int_value < multiple:
26 | error_message = f"{value=} must be greater than or equal to {multiple=}"
27 | raise ValueError(error_message)
28 |
29 | return int_value - int_value % multiple
30 |
--------------------------------------------------------------------------------
/sparse_autoencoder/train/utils/tests/test_get_model_device.py:
--------------------------------------------------------------------------------
1 | """Test get_model_device.py."""
2 | import pytest
3 | import torch
4 | from torch import Tensor
5 | from torch.nn import Linear, Module
6 |
7 | from sparse_autoencoder.train.utils.get_model_device import get_model_device
8 |
9 |
10 | def test_model_on_cpu() -> None:
11 | """Test that it returns CPU when the model is on CPU."""
12 |
13 | class TestModel(Module):
14 | """Test model."""
15 |
16 | def __init__(self) -> None:
17 | """Initialize the model."""
18 | super().__init__()
19 | self.fc = Linear(10, 5)
20 |
21 | def forward(self, x: Tensor) -> Tensor:
22 | """Forward pass."""
23 | return self.fc(x)
24 |
25 | model = TestModel()
26 | model.to("cpu")
27 | assert get_model_device(model) == torch.device("cpu"), "Device should be CPU"
28 |
29 |
30 | # Test with a model that has no parameters
31 | def test_model_no_parameters() -> None:
32 | """Test that it raises a ValueError when the model has no parameters."""
33 |
34 | class EmptyModel(Module):
35 |         def forward(self, x: Tensor) -> Tensor:  # Identity forward; the model has no parameters.
36 | return x
37 |
38 | model = EmptyModel()
39 | with pytest.raises(ValueError, match="The model has no parameters."):
40 | _ = get_model_device(model)
41 |
--------------------------------------------------------------------------------
/sparse_autoencoder/train/utils/tests/test_wandb_sweep_types.py:
--------------------------------------------------------------------------------
1 | """Test wandb sweep types."""
2 | from dataclasses import dataclass, field
3 |
4 | from sparse_autoencoder.train.utils.wandb_sweep_types import (
5 | Method,
6 | Metric,
7 | NestedParameter,
8 | Parameter,
9 | Parameters,
10 | WandbSweepConfig,
11 | )
12 |
13 |
14 | class TestNestedParameter:
15 | """NestedParameter tests."""
16 |
17 | def test_to_dict(self) -> None:
18 | """Test to_dict method."""
19 |
20 | @dataclass(frozen=True)
21 | class DummyNestedParameter(NestedParameter):
22 | nested_property: Parameter[float] = field(default=Parameter(1.0))
23 |
24 | dummy = DummyNestedParameter()
25 |
26 | # It should be in the nested "parameters" key.
27 | assert dummy.to_dict() == {"parameters": {"nested_property": {"value": 1.0}}}
28 |
29 |
30 | class TestWandbSweepConfig:
31 | """WandbSweepConfig tests."""
32 |
33 | def test_to_dict(self) -> None:
34 | """Test to_dict method."""
35 |
36 | @dataclass(frozen=True)
37 | class DummyNestedParameter(NestedParameter):
38 | nested_property: Parameter[float] = field(default=Parameter(1.0))
39 |
40 | @dataclass
41 | class DummyParameters(Parameters):
42 | nested: DummyNestedParameter = field(default=DummyNestedParameter())
43 | top_level: Parameter[float] = field(default=Parameter(1.0))
44 |
45 | dummy = WandbSweepConfig(
46 | parameters=DummyParameters(), method=Method.GRID, metric=Metric(name="total_loss")
47 | )
48 |
49 | assert dummy.to_dict() == {
50 | "method": "grid",
51 | "metric": {"goal": "minimize", "name": "total_loss"},
52 | "parameters": {
53 | "nested": {
54 | "parameters": {"nested_property": {"value": 1.0}},
55 | },
56 | "top_level": {"value": 1.0},
57 | },
58 | }
59 |
--------------------------------------------------------------------------------
/sparse_autoencoder/training_runs/__init__.py:
--------------------------------------------------------------------------------
1 | """Training runs."""
2 |
--------------------------------------------------------------------------------
/sparse_autoencoder/training_runs/gpt2.py:
--------------------------------------------------------------------------------
1 | """Run an sweep on all layers of GPT2 Small.
2 |
3 | Command:
4 |
5 | ```bash
6 | git clone https://github.com/ai-safety-foundation/sparse_autoencoder.git && cd sparse_autoencoder &&
7 | poetry env use python3.11 && poetry install &&
8 | poetry run python sparse_autoencoder/training_runs/gpt2.py
9 | ```
10 | """
11 | import os
12 |
13 | from sparse_autoencoder import (
14 | ActivationResamplerHyperparameters,
15 | AutoencoderHyperparameters,
16 | Hyperparameters,
17 | LossHyperparameters,
18 | Method,
19 | OptimizerHyperparameters,
20 | Parameter,
21 | PipelineHyperparameters,
22 | SourceDataHyperparameters,
23 | SourceModelHyperparameters,
24 | SweepConfig,
25 | sweep,
26 | )
27 |
28 |
29 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
30 |
31 |
32 | def train() -> None:
33 | """Train."""
34 | sweep_config = SweepConfig(
35 | parameters=Hyperparameters(
36 | loss=LossHyperparameters(
37 | l1_coefficient=Parameter(values=[0.0001]),
38 | ),
39 | optimizer=OptimizerHyperparameters(
40 | lr=Parameter(value=0.0001),
41 | ),
42 | source_model=SourceModelHyperparameters(
43 | name=Parameter("gpt2"),
44 | cache_names=Parameter(
45 | value=[f"blocks.{layer}.hook_mlp_out" for layer in range(12)]
46 | ),
47 | hook_dimension=Parameter(768),
48 | ),
49 | source_data=SourceDataHyperparameters(
50 | dataset_path=Parameter("alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2"),
51 | context_size=Parameter(256),
52 | pre_tokenized=Parameter(value=True),
53 | pre_download=Parameter(value=True),
54 | # Total dataset is c.7bn activations (64 files)
55 | # C. 1.5TB needed to store all activations
56 | dataset_files=Parameter(
57 | [f"data/train-{str(i).zfill(5)}-of-00064.parquet" for i in range(20)]
58 | ),
59 | ),
60 | autoencoder=AutoencoderHyperparameters(expansion_factor=Parameter(values=[32, 64])),
61 | pipeline=PipelineHyperparameters(),
62 | activation_resampler=ActivationResamplerHyperparameters(
63 | threshold_is_dead_portion_fires=Parameter(1e-5),
64 | ),
65 | ),
66 | method=Method.GRID,
67 | )
68 |
69 | sweep(sweep_config=sweep_config)
70 |
71 |
72 | if __name__ == "__main__":
73 | train()
74 |
--------------------------------------------------------------------------------
/sparse_autoencoder/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Shared utils."""
2 |
--------------------------------------------------------------------------------
/sparse_autoencoder/utils/data_parallel.py:
--------------------------------------------------------------------------------
1 | """Data parallel utils."""
2 | from typing import Any, Generic, TypeVar
3 |
4 | from torch.nn import DataParallel, Module
5 |
6 |
7 | T = TypeVar("T", bound=Module)
8 |
9 |
10 | class DataParallelWithModelAttributes(DataParallel[T], Generic[T]):
11 | """Data parallel with access to underlying model attributes/methods.
12 |
13 | Allows access to underlying model attributes/methods, which is not possible with the default
14 | `DataParallel` class. Based on:
15 | https://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html
16 |
17 | Example:
18 | >>> from sparse_autoencoder import SparseAutoencoder, SparseAutoencoderConfig
19 | >>> model = SparseAutoencoder(SparseAutoencoderConfig(
20 | ... n_input_features=2,
21 | ... n_learned_features=4,
22 | ... ))
23 | >>> distributed_model = DataParallelWithModelAttributes(model)
24 | >>> distributed_model.config.n_learned_features
25 | 4
26 | """
27 |
28 | def __getattr__(self, name: str) -> Any: # noqa: ANN401
29 | """Allow access to underlying model attributes/methods.
30 |
31 | Args:
32 | name: Attribute/method name.
33 |
34 | Returns:
35 | Attribute value/method.
36 | """
37 | try:
38 | return super().__getattr__(name)
39 | except AttributeError:
40 | return getattr(self.module, name)
41 |
--------------------------------------------------------------------------------
/sparse_autoencoder/utils/tensor_shape.py:
--------------------------------------------------------------------------------
1 | """Tensor shape utilities."""
2 |
3 |
4 | def shape_with_optional_dimensions(*shape: int | None) -> tuple[int, ...]:
5 | """Create a shape from a tuple of optional dimensions.
6 |
7 | Motivation:
8 | By default PyTorch tensor shapes will error if you set an axis to `None`. This allows
9 |         you to pass `None` for an axis, and the resulting shape simply omits that axis.
10 |
11 | Examples:
12 | >>> shape_with_optional_dimensions(1, 2, 3)
13 | (1, 2, 3)
14 |
15 | >>> shape_with_optional_dimensions(1, None, 3)
16 | (1, 3)
17 |
18 | >>> shape_with_optional_dimensions(1, None, None)
19 | (1,)
20 |
21 | >>> shape_with_optional_dimensions(None, None, None)
22 | ()
23 |
24 | Args:
25 | *shape: Axis sizes, with `None` representing an optional axis.
26 |
27 | Returns:
28 | Axis sizes.
29 | """
30 | return tuple(dimension for dimension in shape if dimension is not None)
31 |
--------------------------------------------------------------------------------