├── tests ├── __init__.py ├── test_decode.py └── test_encode.py ├── sparsify ├── sparsify │ └── sparse_coder.py ├── __init__.py ├── sign_sgd.py ├── fused_encoder.py ├── utils.py ├── config.py ├── data.py ├── __main__.py ├── muon.py ├── xformers.py ├── sparse_coder.py └── trainer.py ├── .pre-commit-config.yaml ├── LICENSE ├── .github └── workflows │ └── build.yml ├── CHANGELOG.md ├── pyproject.toml ├── .gitignore └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sparsify/sparsify/sparse_coder.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sparsify/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.3.0" 2 | 3 | from .config import SaeConfig, SparseCoderConfig, TrainConfig, TranscoderConfig 4 | from .sparse_coder import Sae, SparseCoder 5 | from .trainer import SaeTrainer, Trainer 6 | 7 | __all__ = [ 8 | "Sae", 9 | "SaeConfig", 10 | "SaeTrainer", 11 | "SparseCoder", 12 | "SparseCoderConfig", 13 | "Trainer", 14 | "TrainConfig", 15 | "TranscoderConfig", 16 | ] 17 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v6.0.0 6 | hooks: 7 | - id: trailing-whitespace 8 | - id: end-of-file-fixer 9 | - id: check-added-large-files 10 | - repo: https://github.com/psf/black-pre-commit-mirror 11 | rev: 25.11.0 12 | hooks: 13 | - id: black 14 | - repo: https://github.com/astral-sh/ruff-pre-commit 15 | rev: 'v0.14.5' 16 | hooks: 17 | - id: ruff 18 | args: [--fix, --exit-non-zero-on-fix] 19 | -------------------------------------------------------------------------------- /tests/test_decode.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from sparsify.utils import eager_decode, triton_decode 5 | 6 | 7 | @pytest.mark.parametrize("d_in", [48, 64]) # Power of 2 and not 8 | def test_decode(d_in: int): 9 | batch = 2 10 | d_sae = 128 11 | k = 10 12 | 13 | # Fake data 14 | latents = torch.rand(batch, d_sae, device="cuda") 15 | W_dec = torch.randn(d_sae, d_in, device="cuda") 16 | 17 | top_vals, top_idx = latents.topk(k) 18 | eager_res = eager_decode(top_idx, top_vals, W_dec.mT) 19 | triton_res = triton_decode(top_idx, top_vals, W_dec.mT) 20 | 21 | torch.testing.assert_close(eager_res, triton_res) 22 | -------------------------------------------------------------------------------- /sparsify/sign_sgd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Optimizer 3 | 4 | 5 | class SignSGD(Optimizer): 6 | """Steepest descent in the L-infty norm. From """ 7 | 8 | def __init__(self, params, lr: float = 1e-3): 9 | if lr <= 0.0: 10 | raise ValueError(f"Invalid learning rate: {lr}") 11 | 12 | defaults = {"lr": lr} 13 | super(SignSGD, self).__init__(params, defaults) 14 | 15 | @torch.no_grad() 16 | def step(self, closure: None = None) -> None: 17 | assert closure is None, "Closure is not supported." 
18 | 19 | for group in self.param_groups: 20 | lr = group["lr"] 21 | for p in group["params"]: 22 | if p.grad is not None: 23 | p.add_(p.grad.sign(), alpha=-lr) 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 EleutherAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /tests/test_encode.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from sparsify.fused_encoder import fused_encoder 5 | 6 | 7 | def test_fused_encoder(): 8 | torch.manual_seed(42) 9 | 10 | N, D, M = 8192, 1024, 131_072 # batch_size, input_dim, output_dim 11 | k = 32 12 | 13 | # Example inputs 14 | device = "cuda" 15 | x = torch.randn(N, D, requires_grad=True, device=device) 16 | W = torch.randn(M, D, requires_grad=True, device=device) 17 | b = torch.randn(M, requires_grad=True, device=device) 18 | 19 | from time import monotonic 20 | 21 | start = monotonic() 22 | 23 | output = F.relu(F.linear(x, W, b)) 24 | values_naive, indices_naive = torch.topk(output, k, dim=1, sorted=False) 25 | loss_naive = values_naive.sum() 26 | loss_naive.backward() 27 | 28 | torch.cuda.synchronize() 29 | print("Naive time:", monotonic() - start) 30 | 31 | x_grad_naive = x.grad.clone() 32 | W_grad_naive = W.grad.clone() 33 | b_grad_naive = b.grad.clone() 34 | 35 | # Zero out gradient buffers 36 | x.grad = None 37 | W.grad = None 38 | b.grad = None 39 | 40 | start = monotonic() 41 | 42 | # Forward pass 43 | values, indices, _ = fused_encoder(x, W, b, k, "topk") 44 | 45 | # Dummy loss (sum of top-k values) 46 | loss = values.sum() 47 | loss.backward() 48 | 49 | torch.cuda.synchronize() 50 | print("Fused time:", monotonic() - start) 51 | 52 | torch.testing.assert_close(values, values_naive) 53 | torch.testing.assert_close(indices, indices_naive) 54 | torch.testing.assert_close(loss, loss_naive) 55 | torch.testing.assert_close(x.grad, x_grad_naive) 56 | torch.testing.assert_close(W.grad, W_grad_naive) 57 | torch.testing.assert_close(b.grad, b_grad_naive) 58 | 59 | 60 | if __name__ == "__main__": 61 | test_fused_encoder() 62 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: 
-------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v4 16 | - uses: actions/setup-python@v5 17 | with: 18 | python-version: "3.10" 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install -e ".[dev]" 23 | - name: build 24 | run: pip wheel --no-deps -w dist . 25 | release: 26 | needs: build 27 | permissions: 28 | contents: write 29 | id-token: write 30 | if: github.event_name == 'push' && github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message, 'chore(release):') 31 | runs-on: ubuntu-latest 32 | concurrency: release 33 | steps: 34 | - uses: actions/checkout@v4 35 | with: 36 | fetch-depth: 0 37 | - uses: actions/setup-python@v5 38 | with: 39 | python-version: "3.11" 40 | - name: Install dependencies 41 | run: pip install build twine 42 | - name: Semantic Release 43 | id: release 44 | uses: python-semantic-release/python-semantic-release@v9.19.1 45 | with: 46 | github_token: ${{ secrets.GITHUB_TOKEN }} 47 | - name: Build package 48 | run: python -m build 49 | if: steps.release.outputs.released == 'true' 50 | - name: Publish package distributions to PyPI 51 | uses: pypa/gh-action-pypi-publish@release/v1 52 | if: steps.release.outputs.released == 'true' 53 | - name: Publish package distributions to GitHub Releases 54 | uses: python-semantic-release/publish-action@main 55 | if: steps.release.outputs.released == 'true' 56 | with: 57 | github_token: ${{ secrets.GITHUB_TOKEN }} 58 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # CHANGELOG 2 | 3 | 4 | ## v1.3.0 (2025-11-17) 5 | 6 | ### Features 7 | 8 | - Pass arbitrary ds loading arguments 9 | ([`80ebed4`](https://github.com/EleutherAI/sparsify/commit/80ebed43f511601e73b71a9213d14a93e6ec686d)) 10 | 11 | 12 | ## v1.2.2 (2025-09-30) 13 | 14 | ### Bug Fixes 15 | 16 | - Revert "Merge pull request #118 from EleutherAI/exclude-follow-up" 17 | ([`304502c`](https://github.com/EleutherAI/sparsify/commit/304502c1a48b1aa6ae734fc3e8893d4743df6006)) 18 | 19 | This reverts commit 5c9c6fb89448feb4b87b23254a52ad02e60c10db, reversing changes made to 20 | bb792438e62b1ebe727720026fa7cd1d136752d1. 
21 | 22 | 23 | ## v1.2.1 (2025-09-25) 24 | 25 | ### Bug Fixes 26 | 27 | - Address code review comments - move import to top, restore TODO 28 | ([`f86219a`](https://github.com/EleutherAI/sparsify/commit/f86219a9a84f163c002159683c5917b11ea03c54)) 29 | 30 | - Auto-detect dtype from safetensors file to resolve loading mismatch 31 | ([`8fc921f`](https://github.com/EleutherAI/sparsify/commit/8fc921ffb8e938e78d954140a4caceaba952986c)) 32 | 33 | 34 | ## v1.2.0 (2025-09-22) 35 | 36 | ### Features 37 | 38 | - Exclude tokens defined by user from training 39 | ([`f346038`](https://github.com/EleutherAI/sparsify/commit/f3460385efc42fac6e760357f3c34562783b515c)) 40 | 41 | 42 | ## v1.1.3 (2025-04-17) 43 | 44 | ### Bug Fixes 45 | 46 | - Save best in dist mode 47 | ([`885bcd5`](https://github.com/EleutherAI/sparsify/commit/885bcd5c1e94d6b4d82b200545fe0ee1f830068e)) 48 | 49 | 50 | ## v1.1.2 (2025-04-17) 51 | 52 | ### Bug Fixes 53 | 54 | - Only drop in dist 55 | ([`b18bd0e`](https://github.com/EleutherAI/sparsify/commit/b18bd0e272044e42dce816583214e9d099484575)) 56 | 57 | 58 | ## v1.1.1 (2025-04-16) 59 | 60 | ### Bug Fixes 61 | 62 | - Hang when the number of examples is indivisible across processes 63 | ([`0f662ad`](https://github.com/EleutherAI/sparsify/commit/0f662adf7705a836aa8c910b39b33546a8cf7975)) 64 | 65 | 66 | ## v1.1.0 (2025-04-16) 67 | 68 | ### Features 69 | 70 | - Update from deprecated publish action 71 | ([`aafc9fc`](https://github.com/EleutherAI/sparsify/commit/aafc9fc049e7e8f18a017e99f122343f0bcb4006)) 72 | 73 | fix: deprecated initial release build 74 | 75 | 76 | ## v1.0.0 (2025-04-16) 77 | 78 | ### Features 79 | 80 | - Empty commit for initial release 81 | ([`877ef6f`](https://github.com/EleutherAI/sparsify/commit/877ef6f7219e9a4424bf9cb51be5bef5ac2adca4)) 82 | 83 | BREAKING CHANGE: non-breaking inital major commit 84 | 85 | ### Breaking Changes 86 | 87 | - Non-breaking inital major commit 88 | 89 | 90 | ## v0.0.1 (2025-04-16) 91 | 92 | ### Bug Fixes 93 | 94 | - Change topk dim selection from 1 to -1 95 | ([`1cdb4a5`](https://github.com/EleutherAI/sparsify/commit/1cdb4a5bbe723b0ee0a0015f834d142e83facafe)) 96 | 97 | - Remove lint from CI, remove environment from CI, trigger release 98 | ([`f6dca80`](https://github.com/EleutherAI/sparsify/commit/f6dca80d4575fa1a58eac55cc7b24f802fa669db)) 99 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "eai-sparsify" 7 | description = "Sparsify transformers with SAEs and transcoders" 8 | readme = "README.md" 9 | requires-python = ">=3.10" 10 | keywords = ["interpretability", "explainable-ai"] 11 | license = {text = "MIT License"} 12 | dependencies = [ 13 | "accelerate", # For device_map in from_pretrained 14 | "datasets", 15 | "einops", 16 | "huggingface-hub", 17 | "natsort", # For sorting module names 18 | "safetensors", 19 | "schedulefree", 20 | "simple-parsing", 21 | "torch", 22 | "transformers", 23 | ] 24 | version = "1.3.0" 25 | [project.optional-dependencies] 26 | dev = [ 27 | "pre-commit", 28 | ] 29 | 30 | [project.scripts] 31 | sparsify = "sparsify.__main__:run" 32 | 33 | [tool.pyright] 34 | include = ["sparsify*"] 35 | reportPrivateImportUsage = false 36 | 37 | [tool.setuptools.packages.find] 38 | include = ["sparsify*"] 39 | 40 | [tool.ruff] 41 | # Enable pycodestyle (`E`), Pyflakes (`F`), and 
isort (`I`) codes 42 | # See https://beta.ruff.rs/docs/rules/ for more possible rules 43 | select = ["E", "F", "I"] 44 | # Same as Black. 45 | line-length = 88 46 | # Avoid automatically removing unused imports in __init__.py files. 47 | # Such imports will be flagged with a dedicated message suggesting 48 | # that the import is either added to the module's __all__ symbol 49 | ignore-init-module-imports = true 50 | 51 | [tool.semantic_release] 52 | version_variables = ["sparsify/__init__.py:__version__"] 53 | version_toml = ["pyproject.toml:project.version"] 54 | assets = [] 55 | build_command_env = [] 56 | commit_message = "{version}\n\nAutomatically generated by python-semantic-release" 57 | commit_parser = "conventional" 58 | logging_use_named_masks = false 59 | major_on_zero = true 60 | allow_zero_version = true 61 | no_git_verify = false 62 | tag_format = "v{version}" 63 | 64 | [tool.semantic_release.branches.main] 65 | match = "(main|master)" 66 | prerelease_token = "rc" 67 | prerelease = false 68 | 69 | [tool.semantic_release.changelog] 70 | changelog_file = "" 71 | exclude_commit_patterns = [] 72 | mode = "init" 73 | insertion_flag = "" 74 | template_dir = "templates" 75 | 76 | [tool.semantic_release.changelog.default_templates] 77 | changelog_file = "CHANGELOG.md" 78 | output_format = "md" 79 | mask_initial_release = false 80 | 81 | [tool.semantic_release.changelog.environment] 82 | block_start_string = "{%" 83 | block_end_string = "%}" 84 | variable_start_string = "{{" 85 | variable_end_string = "}}" 86 | comment_start_string = "{#" 87 | comment_end_string = "#}" 88 | trim_blocks = false 89 | lstrip_blocks = false 90 | newline_sequence = "\n" 91 | keep_trailing_newline = false 92 | extensions = [] 93 | autoescape = false 94 | 95 | [tool.semantic_release.commit_author] 96 | env = "GIT_COMMIT_AUTHOR" 97 | default = "semantic-release " 98 | 99 | [tool.semantic_release.commit_parser_options] 100 | minor_tags = ["feat"] 101 | patch_tags = ["fix", "perf"] 102 | other_allowed_tags = ["build", "chore", "ci", "docs", "style", "refactor", "test"] 103 | allowed_tags = ["feat", "fix", "perf", "build", "chore", "ci", "docs", "style", "refactor", "test"] 104 | default_bump_level = 0 105 | parse_squash_commits = false 106 | ignore_merge_commits = false 107 | 108 | [tool.semantic_release.remote] 109 | name = "origin" 110 | type = "github" 111 | ignore_token_for_push = false 112 | insecure = false 113 | 114 | [tool.semantic_release.publish] 115 | dist_glob_patterns = ["dist/*"] 116 | upload_to_vcs_release = true 117 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | checkpoints/ 163 | wandb/ 164 | statistics/ 165 | results/ 166 | images/ 167 | -------------------------------------------------------------------------------- /sparsify/fused_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, NamedTuple 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | class EncoderOutput(NamedTuple): 8 | top_acts: torch.Tensor 9 | """Activations of the top-k latents.""" 10 | 11 | top_indices: torch.Tensor 12 | """Indices of the top-k features.""" 13 | 14 | pre_acts: torch.Tensor 15 | """Activations before the top-k selection.""" 16 | 17 | 18 | class FusedEncoder(torch.autograd.Function): 19 | @staticmethod 20 | def forward( 21 | ctx, input, weight, bias, k: int, activation: Literal["groupmax", "topk"] 22 | ): 23 | """ 24 | input: (N, D) 25 | weight: (M, D) 26 | bias: (M,) 27 | k: int (number of top elements to select along dim=1) 28 | """ 29 | preacts = F.relu(F.linear(input, weight, bias)) 30 | 31 | # Get top-k values and indices for each row 32 | if activation == "topk": 33 | values, indices = torch.topk(preacts, k, dim=-1, sorted=False) 34 | elif activation == "groupmax": 35 | values, indices = preacts.unflatten(-1, (k, -1)).max(dim=-1) 36 | 37 | # torch.max gives us indices into each group, but we want indices into the 38 | # flattened tensor. Add the offsets to get the correct indices. 39 | num_latents = preacts.shape[1] 40 | offsets = torch.arange( 41 | 0, num_latents, num_latents // k, device=preacts.device 42 | ) 43 | indices = offsets + indices 44 | else: 45 | raise ValueError(f"Unknown activation: {activation}") 46 | 47 | # Save tensors needed for the backward pass 48 | ctx.save_for_backward(input, weight, bias, indices) 49 | ctx.k = k 50 | return values, indices, preacts 51 | 52 | @staticmethod 53 | def backward(ctx, grad_values, grad_indices, grad_preacts): 54 | input, weight, bias, indices = ctx.saved_tensors 55 | grad_input = grad_weight = grad_bias = None 56 | 57 | # --- Grad w.r.t. input --- 58 | if ctx.needs_input_grad[0]: 59 | grad_input = F.embedding_bag( 60 | indices, 61 | weight, 62 | mode="sum", 63 | per_sample_weights=grad_values.type_as(weight), 64 | ) 65 | 66 | # --- Grad w.r.t. weight --- 67 | if ctx.needs_input_grad[1]: 68 | grad_weight = torch.zeros_like(weight) 69 | # Compute contributions from each top-k element: 70 | # computed as grad_values * input for each top-k location. 71 | contributions = grad_values.unsqueeze(2) * input.unsqueeze(1) 72 | _, _, D = contributions.shape 73 | # Flatten contributions to shape (N*k, D) 74 | contributions = contributions.reshape(-1, D) 75 | 76 | # Accumulate contributions into the correct rows of grad_weight. 77 | grad_weight.index_add_(0, indices.flatten(), contributions.type_as(weight)) 78 | 79 | # --- Grad w.r.t. bias --- 80 | if bias is not None and ctx.needs_input_grad[2]: 81 | grad_bias = torch.zeros_like(bias) 82 | grad_bias.index_add_( 83 | 0, indices.flatten(), grad_values.flatten().type_as(bias) 84 | ) 85 | 86 | # The k parameter is an int, so return None for its gradient. 87 | return grad_input, grad_weight, grad_bias, None, None 88 | 89 | 90 | def fused_encoder( 91 | input, 92 | weight, 93 | bias, 94 | k: int, 95 | activation: Literal["groupmax", "topk"], 96 | ) -> EncoderOutput: 97 | """ 98 | Convenience wrapper that performs an nn.Linear followed by `activation` with 99 | a backward pass optimized using index_add. 
100 | """ 101 | return EncoderOutput( 102 | *FusedEncoder.apply(input, weight, bias, k, activation) # type: ignore 103 | ) 104 | -------------------------------------------------------------------------------- /sparsify/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Type, TypeVar, cast 3 | 4 | import torch 5 | from accelerate.utils import send_to_device 6 | from torch import Tensor, nn 7 | from transformers import PreTrainedModel 8 | 9 | T = TypeVar("T") 10 | 11 | 12 | def assert_type(typ: Type[T], obj: Any) -> T: 13 | """Assert that an object is of a given type at runtime and return it.""" 14 | if not isinstance(obj, typ): 15 | raise TypeError(f"Expected {typ.__name__}, got {type(obj).__name__}") 16 | 17 | return cast(typ, obj) 18 | 19 | 20 | def get_layer_list(model: PreTrainedModel) -> tuple[str, nn.ModuleList]: 21 | """Get the list of layers to train SAEs on.""" 22 | N = assert_type(int, model.config.num_hidden_layers) 23 | candidates = [ 24 | (name, mod) 25 | for (name, mod) in model.base_model.named_modules() 26 | if isinstance(mod, nn.ModuleList) and len(mod) == N 27 | ] 28 | assert len(candidates) == 1, "Could not find the list of layers." 29 | 30 | return candidates[0] 31 | 32 | 33 | @torch.inference_mode() 34 | def resolve_widths( 35 | model: PreTrainedModel, 36 | module_names: list[str], 37 | dim: int = -1, 38 | ) -> dict[str, int]: 39 | """Find number of output dimensions for the specified modules.""" 40 | module_to_name = { 41 | model.base_model.get_submodule(name): name for name in module_names 42 | } 43 | shapes: dict[str, int] = {} 44 | 45 | def hook(module, _, output): 46 | # Unpack tuples if needed 47 | if isinstance(output, tuple): 48 | output, *_ = output 49 | 50 | name = module_to_name[module] 51 | shapes[name] = output.shape[dim] 52 | 53 | handles = [mod.register_forward_hook(hook) for mod in module_to_name] 54 | dummy = send_to_device(model.dummy_inputs, model.device) 55 | try: 56 | model(**dummy) 57 | finally: 58 | for handle in handles: 59 | handle.remove() 60 | 61 | return shapes 62 | 63 | 64 | def set_submodule(model: nn.Module, submodule_path: str, new_submodule: nn.Module): 65 | """ 66 | Replaces a submodule in a PyTorch model dynamically. 67 | 68 | Args: 69 | model (nn.Module): The root model containing the submodule. 70 | submodule_path (str): Dotted path to the submodule. 71 | new_submodule (nn.Module): The new module to replace the existing one. 
72 | 73 | Example: 74 | set_submodule(model, "encoder.layer.0.attention.self", nn.Identity()) 75 | """ 76 | parent_path, _, last_name = submodule_path.rpartition(".") 77 | parent_module = model.get_submodule(parent_path) if parent_path else model 78 | setattr(parent_module, last_name, new_submodule) 79 | 80 | 81 | # Fallback implementation of SAE decoder 82 | def eager_decode(top_indices: Tensor, top_acts: Tensor, W_dec: Tensor): 83 | return nn.functional.embedding_bag( 84 | top_indices, W_dec.mT, per_sample_weights=top_acts, mode="sum" 85 | ) 86 | 87 | 88 | # Triton implementation of SAE decoder 89 | def triton_decode(top_indices: Tensor, top_acts: Tensor, W_dec: Tensor): 90 | return xformers_embedding_bag(top_indices, W_dec.mT, top_acts) 91 | 92 | 93 | try: 94 | from .xformers import xformers_embedding_bag 95 | except ImportError: 96 | decoder_impl = eager_decode 97 | print("Triton not installed, using eager implementation of sparse decoder.") 98 | else: 99 | if os.environ.get("SPARSIFY_DISABLE_TRITON") == "1": 100 | print("Triton disabled, using eager implementation of sparse decoder.") 101 | decoder_impl = eager_decode 102 | else: 103 | decoder_impl = triton_decode 104 | 105 | 106 | def handle_arg_string(arg): 107 | if arg.lower() == "true": 108 | return True 109 | elif arg.lower() == "false": 110 | return False 111 | elif arg.isnumeric(): 112 | return int(arg) 113 | try: 114 | return float(arg) 115 | except ValueError: 116 | return arg 117 | 118 | 119 | def simple_parse_args_string(args_string: str) -> dict: 120 | """ 121 | Parses something like 122 | args1=val1,arg2=val2 123 | into a dictionary. 124 | """ 125 | args_string = args_string.strip() 126 | if not args_string: 127 | return {} 128 | arg_list = [arg for arg in args_string.split(",") if arg] 129 | args_dict = { 130 | kv[0]: handle_arg_string("=".join(kv[1:])) 131 | for kv in [arg.split("=") for arg in arg_list] 132 | } 133 | return args_dict 134 | -------------------------------------------------------------------------------- /sparsify/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from functools import partial 3 | from typing import Literal 4 | 5 | from simple_parsing import Serializable, list_field 6 | 7 | 8 | @dataclass 9 | class SparseCoderConfig(Serializable): 10 | """ 11 | Configuration for training a sparse coder on a language model. 12 | """ 13 | 14 | activation: Literal["groupmax", "topk"] = "topk" 15 | """Activation function to use.""" 16 | 17 | expansion_factor: int = 32 18 | """Multiple of the input dimension to use as the sparse coder dimension.""" 19 | 20 | normalize_decoder: bool = True 21 | """Normalize the decoder weights to have unit norm.""" 22 | 23 | num_latents: int = 0 24 | """Number of latents to use. 
If 0, use `expansion_factor`.""" 25 | 26 | k: int = 32 27 | """Number of nonzero features.""" 28 | 29 | multi_topk: bool = False 30 | """Use Multi-TopK loss.""" 31 | 32 | skip_connection: bool = False 33 | """Include a linear skip connection.""" 34 | 35 | transcode: bool = False 36 | """Whether we want to predict the output of a module given its input.""" 37 | 38 | 39 | # Support different naming conventions for the same configuration 40 | SaeConfig = SparseCoderConfig 41 | TranscoderConfig = partial(SparseCoderConfig, transcode=True) 42 | 43 | 44 | @dataclass 45 | class TrainConfig(Serializable): 46 | sae: SparseCoderConfig 47 | 48 | batch_size: int = 32 49 | """Batch size measured in sequences.""" 50 | 51 | grad_acc_steps: int = 1 52 | """Number of steps over which to accumulate gradients.""" 53 | 54 | micro_acc_steps: int = 1 55 | """Chunk the activations into this number of microbatches for training.""" 56 | 57 | loss_fn: Literal["ce", "fvu", "kl"] = "fvu" 58 | """Loss function to use for training the sparse coders. 59 | 60 | - `ce`: Cross-entropy loss of the final model logits. 61 | - `fvu`: Fraction of variance explained. 62 | - `kl`: KL divergence of the final model logits w.r.t. the original logits. 63 | """ 64 | 65 | optimizer: Literal["adam", "muon", "signum"] = "signum" 66 | """Optimizer to use.""" 67 | 68 | lr: float | None = None 69 | """Base LR. If None, it is automatically chosen based on the number of latents.""" 70 | 71 | lr_warmup_steps: int = 1000 72 | """Number of steps over which to warm up the learning rate. Only used if 73 | `optimizer` is `adam`.""" 74 | 75 | k_decay_steps: int = 0 76 | """Number of steps over which to decay the number of active latents. Starts at 77 | input width * 10 and decays to k. Experimental feature.""" 78 | 79 | auxk_alpha: float = 0.0 80 | """Weight of the auxiliary loss term.""" 81 | 82 | dead_feature_threshold: int = 10_000_000 83 | """Number of tokens after which a feature is considered dead.""" 84 | 85 | exclude_tokens: list[int] = list_field() 86 | """List of tokens to ignore during sparse coders training.""" 87 | 88 | hookpoints: list[str] = list_field() 89 | """List of hookpoints to train sparse coders on.""" 90 | 91 | init_seeds: list[int] = list_field(0) 92 | """List of random seeds to use for initialization. 
If more than one, train a sparse 93 | coder for each seed.""" 94 | 95 | layers: list[int] = list_field() 96 | """List of layer indices to train sparse coders on.""" 97 | 98 | layer_stride: int = 1 99 | """Stride between layers to train sparse coders on.""" 100 | 101 | distribute_modules: bool = False 102 | """Store one copy of each sparse coder, instead of copying them across devices.""" 103 | 104 | save_every: int = 1000 105 | """Save sparse coders every `save_every` steps.""" 106 | 107 | save_best: bool = False 108 | """Save the best checkpoint found for each hookpoint.""" 109 | 110 | finetune: str | None = None 111 | """Finetune the sparse coders from a pretrained checkpoint.""" 112 | 113 | log_to_wandb: bool = True 114 | run_name: str | None = None 115 | wandb_log_frequency: int = 1 116 | 117 | save_dir: str = "checkpoints" 118 | 119 | def __post_init__(self): 120 | """Validate the configuration.""" 121 | if self.layers and self.layer_stride != 1: 122 | raise ValueError("Cannot specify both `layers` and `layer_stride`.") 123 | 124 | if self.distribute_modules and self.loss_fn in ("ce", "kl"): 125 | raise ValueError( 126 | "Distributing modules across ranks is not compatible with the " 127 | "cross-entropy or KL divergence losses." 128 | ) 129 | 130 | if not self.init_seeds: 131 | raise ValueError("Must specify at least one random seed.") 132 | -------------------------------------------------------------------------------- /sparsify/data.py: -------------------------------------------------------------------------------- 1 | """Tools for tokenizing and manipulating text datasets.""" 2 | 3 | import math 4 | from multiprocessing import cpu_count 5 | from typing import TypeVar, Union 6 | 7 | import numpy as np 8 | import torch 9 | from datasets import Dataset, DatasetDict 10 | from torch.utils.data import Dataset as TorchDataset 11 | from transformers import PreTrainedTokenizerBase 12 | 13 | T = TypeVar("T", bound=Union[Dataset, DatasetDict]) 14 | 15 | 16 | def chunk_and_tokenize( 17 | data: T, 18 | tokenizer: PreTrainedTokenizerBase, 19 | *, 20 | format: str = "torch", 21 | num_proc: int = cpu_count() // 2, 22 | text_key: str = "text", 23 | max_seq_len: int = 2048, 24 | return_final_batch: bool = False, 25 | load_from_cache_file: bool = True, 26 | ) -> T: 27 | """Perform GPT-style chunking and tokenization on a dataset. 28 | 29 | The resulting dataset will consist entirely of chunks exactly `max_seq_len` tokens 30 | long. Long sequences will be split into multiple chunks, and short sequences will 31 | be merged with their neighbors, using `eos_token` as a separator. The first token 32 | will also always be an `eos_token`. 33 | 34 | Args: 35 | data: The dataset to chunk and tokenize. 36 | tokenizer: The tokenizer to use. 37 | format: The format to return the dataset in, passed to `Dataset.with_format`. 38 | num_proc: The number of processes to use for tokenization. 39 | text_key: The key in the dataset to use as the text to tokenize. 40 | max_seq_len: The maximum length of a batch of input ids. 41 | return_final_batch: Whether to return the final batch, which may be smaller 42 | than the others. 43 | load_from_cache_file: Whether to load from the cache file. 44 | 45 | Returns: 46 | The chunked and tokenized dataset.
47 | """ 48 | 49 | def _tokenize_fn(x: dict[str, list]): 50 | chunk_size = min(tokenizer.model_max_length, max_seq_len) 51 | sep = tokenizer.eos_token or "<|endoftext|>" 52 | joined_text = sep.join([""] + x[text_key]) 53 | output = tokenizer( 54 | # Concatenate all the samples together, separated by the EOS token. 55 | joined_text, # start with an eos token 56 | max_length=chunk_size, 57 | return_attention_mask=False, 58 | return_overflowing_tokens=True, 59 | truncation=True, 60 | ) 61 | 62 | if overflow := output.pop("overflowing_tokens", None): 63 | # Slow Tokenizers return unnested lists of ints 64 | assert isinstance(output.input_ids[0], int) 65 | 66 | # Chunk the overflow into batches of size `chunk_size` 67 | chunks = [output["input_ids"]] + [ 68 | overflow[i * chunk_size : (i + 1) * chunk_size] 69 | for i in range(math.ceil(len(overflow) / chunk_size)) 70 | ] 71 | output = {"input_ids": chunks} 72 | 73 | if not return_final_batch: 74 | # We know that the last sample will almost always be less than the max 75 | # number of tokens, and we don't want to pad, so we just drop it. 76 | output = {k: v[:-1] for k, v in output.items()} 77 | 78 | output_batch_size = len(output["input_ids"]) 79 | 80 | if output_batch_size == 0: 81 | raise ValueError( 82 | "Not enough data to create a single complete batch." 83 | " Either allow the final batch to be returned," 84 | " or supply more data." 85 | ) 86 | 87 | return output 88 | 89 | data = data.map( 90 | _tokenize_fn, 91 | # Batching is important for ensuring that we don't waste tokens: 92 | # since we always throw away the last element of the batch, we 93 | # want to keep the batch size as large as possible. 94 | batched=True, 95 | batch_size=2048, 96 | num_proc=num_proc, 97 | remove_columns=get_columns_all_equal(data), 98 | load_from_cache_file=load_from_cache_file, 99 | ) 100 | return data.with_format(format, columns=["input_ids"]) 101 | 102 | 103 | def get_columns_all_equal(dataset: Union[Dataset, DatasetDict]) -> list[str]: 104 | """Get a single list of columns in a `Dataset` or `DatasetDict`. 105 | 106 | We assert the columns are the same across splits if it's a `DatasetDict`. 107 | 108 | Args: 109 | dataset: The dataset to get the columns from. 110 | 111 | Returns: 112 | A list of columns.
113 | """ 114 | if isinstance(dataset, DatasetDict): 115 | cols_by_split = dataset.column_names.values() 116 | columns = next(iter(cols_by_split)) 117 | if not all(cols == columns for cols in cols_by_split): 118 | raise ValueError("All splits must have the same columns") 119 | 120 | return columns 121 | 122 | return dataset.column_names 123 | 124 | 125 | class MemmapDataset(TorchDataset): 126 | """Torch Dataset backed by a memory-mapped numpy array.""" 127 | 128 | def __init__( 129 | self, 130 | data_path: str, 131 | ctx_len: int, 132 | max_examples: int | None = None, 133 | dtype=np.uint16, 134 | ): 135 | mmap = np.memmap(data_path, dtype=dtype, mode="r").reshape(-1, ctx_len) 136 | self.mmap = mmap[:max_examples] 137 | 138 | def __len__(self): 139 | return len(self.mmap) 140 | 141 | def __getitem__(self, idx): 142 | return dict(input_ids=torch.from_numpy(self.mmap[idx].astype(np.int64))) 143 | 144 | def select(self, rng: range) -> "MemmapDataset": 145 | """Select a subset of the dataset.""" 146 | mmap = MemmapDataset.__new__(MemmapDataset) 147 | mmap.mmap = self.mmap[rng.start : rng.stop] 148 | return mmap 149 | 150 | def shard(self, num_shards: int, shard_id: int) -> "MemmapDataset": 151 | """Split the dataset into `num_shards` and return the `shard_id`-th shard.""" 152 | mmap = MemmapDataset.__new__(MemmapDataset) 153 | 154 | # Split the mmap array into `num_shards` and return the `shard_id`-th shard 155 | shards = np.array_split(self.mmap, num_shards) 156 | mmap.mmap = shards[shard_id] 157 | return mmap 158 | -------------------------------------------------------------------------------- /sparsify/__main__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from contextlib import nullcontext, redirect_stdout 3 | from dataclasses import dataclass 4 | from datetime import timedelta 5 | from multiprocessing import cpu_count 6 | 7 | import torch 8 | import torch.distributed as dist 9 | from datasets import Dataset, load_dataset 10 | from safetensors.torch import load_model 11 | from simple_parsing import field, parse 12 | from transformers import ( 13 | AutoModel, 14 | AutoModelForCausalLM, 15 | AutoTokenizer, 16 | BitsAndBytesConfig, 17 | PreTrainedModel, 18 | ) 19 | 20 | from .data import MemmapDataset, chunk_and_tokenize 21 | from .trainer import TrainConfig, Trainer 22 | from .utils import simple_parse_args_string 23 | 24 | 25 | @dataclass 26 | class RunConfig(TrainConfig): 27 | model: str = field( 28 | default="HuggingFaceTB/SmolLM2-135M", 29 | positional=True, 30 | ) 31 | """Name of the model to train.""" 32 | 33 | dataset: str = field( 34 | default="EleutherAI/SmolLM2-135M-10B", 35 | positional=True, 36 | ) 37 | """Path to the dataset to use for training.""" 38 | 39 | split: str = "train" 40 | """Dataset split to use for training.""" 41 | 42 | ctx_len: int = 2048 43 | """Context length to use for training.""" 44 | 45 | # Use a dummy encoding function to prevent the token from being saved 46 | # to disk in plain text 47 | hf_token: str | None = field(default=None, encoding_fn=lambda _: None) 48 | """Huggingface API token for downloading models.""" 49 | 50 | revision: str | None = None 51 | """Model revision to use for training.""" 52 | 53 | load_in_8bit: bool = False 54 | """Load the model in 8-bit mode.""" 55 | 56 | max_examples: int | None = None 57 | """Maximum number of examples to use for training.""" 58 | 59 | resume: bool = False 60 | """Whether to try resuming from the checkpoint present at `checkpoints/run_name`.""" 61 | 62 
| text_column: str = "text" 63 | """Column name to use for text data.""" 64 | 65 | shuffle_seed: int = 42 66 | """Random seed for shuffling the dataset.""" 67 | 68 | data_preprocessing_num_proc: int = field( 69 | default_factory=lambda: cpu_count() // 2, 70 | ) 71 | """Number of processes to use for preprocessing data""" 72 | 73 | data_args: str = field( 74 | default="", 75 | ) 76 | """Arguments to pass to the HuggingFace dataset constructor in the 77 | format 'arg1=val1,arg2=val2'.""" 78 | 79 | 80 | def load_artifacts( 81 | args: RunConfig, rank: int 82 | ) -> tuple[PreTrainedModel, Dataset | MemmapDataset]: 83 | if args.load_in_8bit: 84 | dtype = torch.float16 85 | elif torch.cuda.is_bf16_supported(): 86 | dtype = torch.bfloat16 87 | else: 88 | dtype = "auto" 89 | 90 | # End-to-end training requires a model with a causal LM head 91 | model_cls = AutoModel if args.loss_fn == "fvu" else AutoModelForCausalLM 92 | model = model_cls.from_pretrained( 93 | args.model, 94 | device_map={"": f"cuda:{rank}"}, 95 | quantization_config=( 96 | BitsAndBytesConfig(load_in_8bit=args.load_in_8bit) 97 | if args.load_in_8bit 98 | else None 99 | ), 100 | revision=args.revision, 101 | torch_dtype=dtype, 102 | token=args.hf_token, 103 | ) 104 | 105 | # For memmap-style datasets 106 | if args.dataset.endswith(".bin"): 107 | dataset = MemmapDataset(args.dataset, args.ctx_len, args.max_examples) 108 | else: 109 | # For Huggingface datasets 110 | try: 111 | kwargs = simple_parse_args_string(args.data_args) 112 | dataset = load_dataset(args.dataset, split=args.split, **kwargs) 113 | except ValueError as e: 114 | # Automatically use load_from_disk if appropriate 115 | if "load_from_disk" in str(e): 116 | dataset = Dataset.load_from_disk(args.dataset, keep_in_memory=False) 117 | else: 118 | raise e 119 | 120 | assert isinstance(dataset, Dataset) 121 | if "input_ids" not in dataset.column_names: 122 | tokenizer = AutoTokenizer.from_pretrained(args.model, token=args.hf_token) 123 | dataset = chunk_and_tokenize( 124 | dataset, 125 | tokenizer, 126 | max_seq_len=args.ctx_len, 127 | num_proc=args.data_preprocessing_num_proc, 128 | text_key=args.text_column, 129 | ) 130 | else: 131 | print("Dataset already tokenized; skipping tokenization.") 132 | 133 | print(f"Shuffling dataset with seed {args.shuffle_seed}") 134 | dataset = dataset.shuffle(args.shuffle_seed) 135 | 136 | dataset = dataset.with_format("torch") 137 | if limit := args.max_examples: 138 | dataset = dataset.select(range(limit)) 139 | 140 | return model, dataset 141 | 142 | 143 | def run(): 144 | local_rank = os.environ.get("LOCAL_RANK") 145 | ddp = local_rank is not None 146 | rank = int(local_rank) if ddp else 0 147 | 148 | if ddp: 149 | torch.cuda.set_device(int(local_rank)) 150 | 151 | # Increase the default timeout in order to account for slow downloads 152 | # and data preprocessing on the main rank 153 | dist.init_process_group( 154 | "nccl", device_id=torch.device(rank), timeout=timedelta(weeks=1) 155 | ) 156 | 157 | if rank == 0: 158 | print(f"Using DDP across {dist.get_world_size()} GPUs.") 159 | 160 | args = parse(RunConfig) 161 | 162 | # Prevent ranks other than 0 from printing 163 | with nullcontext() if rank == 0 else redirect_stdout(None): 164 | # Awkward hack to prevent other ranks from duplicating data preprocessing 165 | if not ddp or rank == 0: 166 | model, dataset = load_artifacts(args, rank) 167 | if ddp: 168 | dist.barrier() 169 | if rank != 0: 170 | model, dataset = load_artifacts(args, rank) 171 | 172 | # Drop examples that are 
indivisible across processes to prevent deadlock 173 | remainder_examples = len(dataset) % dist.get_world_size() 174 | dataset = dataset.select(range(len(dataset) - remainder_examples)) 175 | 176 | dataset = dataset.shard(dist.get_world_size(), rank) 177 | 178 | # Drop examples that are indivisible across processes to prevent deadlock 179 | remainder_examples = len(dataset) % dist.get_world_size() 180 | dataset = dataset.select(range(len(dataset) - remainder_examples)) 181 | 182 | print(f"Training on '{args.dataset}' (split '{args.split}')") 183 | print(f"Storing model weights in {model.dtype}") 184 | 185 | trainer = Trainer(args, dataset, model) 186 | if args.resume: 187 | trainer.load_state(f"checkpoints/{args.run_name}" or "checkpoints/unnamed") 188 | elif args.finetune: 189 | for name, sae in trainer.saes.items(): 190 | load_model( 191 | sae, 192 | f"{args.finetune}/{name}/sae.safetensors", 193 | device=str(model.device), 194 | ) 195 | 196 | trainer.fit() 197 | 198 | 199 | if __name__ == "__main__": 200 | run() 201 | -------------------------------------------------------------------------------- /sparsify/muon.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from https://github.com/KellerJordan/Muon/blob/master/muon.py 3 | Modifications by Nora Belrose 4 | """ 5 | 6 | import torch 7 | import torch.distributed as dist 8 | from torch import Tensor 9 | 10 | 11 | def quintic_newtonschulz(G: Tensor, steps: int) -> Tensor: 12 | """ 13 | Newton-Schulz iteration to compute the orthogonalization of G. We opt to use a 14 | quintic iteration whose coefficients are selected to maximize the slope at zero. 15 | For the purpose of minimizing steps, it turns out to be empirically effective to 16 | keep increasing the slope at zero even beyond the point where the iteration no 17 | longer converges all the way to one everywhere on the interval. This iteration 18 | therefore does not produce UV^T but rather something like US'V^T where S' is 19 | diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model 20 | performance at all relative to UV^T, where USV^T = G is the SVD. 21 | """ 22 | # batched implementation by @scottjmaddox, put into practice by @YouJiacheng 23 | assert G.ndim >= 2 24 | a, b, c = (3.4445, -4.7750, 2.0315) 25 | X = G.bfloat16() 26 | if G.size(-2) > G.size(-1): 27 | X = X.mT 28 | 29 | # Ensure spectral norm is at most 1 30 | X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) 31 | # Perform the NS iterations 32 | for _ in range(steps): 33 | # quintic strategy adapted from suggestion by @jxbz, @leloykun, @YouJiacheng 34 | A = X @ X.mT 35 | B = b * A + c * A @ A 36 | X = a * X + B @ X 37 | 38 | if G.size(-2) > G.size(-1): 39 | X = X.mT 40 | return X 41 | 42 | 43 | class Muon(torch.optim.Optimizer): 44 | """ 45 | Muon - MomentUm Orthogonalized by Newton-schulz 46 | 47 | Muon is a generalized steepest descent optimizer using the spectral norm on the 48 | matrix-valued parameters. This means it always updates in the direction which 49 | locally reduces the loss as much as possible, while constraining the update to have 50 | a spectral norm given by the learning rate. It achieves this using a Newton-Schulz 51 | iteration to orthogonalize the stochastic gradient (or momentum buffer) for each 52 | matrix in the model before taking a step. 
53 | 54 | The spectral norm is an intuitive heuristic because, roughly speaking, it measures 55 | the maximum change to the activations of a layer that can be caused by a change to 56 | its weights. By constraining the worst-case change to the activations, we ensure 57 | that we do not destabilize the network. 58 | 59 | This optimizer is unlikely to work well with small batch sizes, since it strongly 60 | magnifies small singular values, which will be noisy given a small minibatch. 61 | 62 | Arguments: 63 | lr: The learning rate used by the internal SGD. 64 | momentum: The momentum used by the internal SGD. 65 | nesterov: Whether to use Nesterov-style momentum in the internal SGD. 66 | ns_steps: The number of Newton-Schulz iteration steps to use. 67 | """ 68 | 69 | def __init__( 70 | self, 71 | params, 72 | lr: float = 1e-3, 73 | momentum: float = 0.95, 74 | nesterov: bool = True, 75 | weight_decay: float = 0.1, 76 | ns_steps: int = 5, 77 | ddp: bool = True, 78 | ): 79 | defaults = dict( 80 | lr=lr, 81 | momentum=momentum, 82 | nesterov=nesterov, 83 | ns_steps=ns_steps, 84 | weight_decay=weight_decay, 85 | ) 86 | self.rank = dist.get_rank() if dist.is_initialized() and ddp else 0 87 | self.world_size = dist.get_world_size() if dist.is_initialized() and ddp else 1 88 | 89 | # Distributed Data Parallel (DDP) setup 90 | if dist.is_initialized() and ddp: 91 | param_groups = [] 92 | 93 | # Check that the user isn't doing some weird model parallelism 94 | devices = {p.device for p in params} 95 | device = next(iter(devices)) 96 | assert len(devices) == 1, "Muon does not support model parallelism." 97 | 98 | # Group parameters by their device and number of elements. For each group, 99 | # we pre-allocate a buffer to store the updates from all ranks. 100 | for size in {p.numel() for p in params}: 101 | b = torch.empty( 102 | self.world_size, size, dtype=torch.bfloat16, device=device 103 | ) 104 | group = dict( 105 | params=[p for p in params if p.numel() == size], 106 | update_buffer=b, 107 | update_buffer_views=[b[i] for i in range(self.world_size)], 108 | ) 109 | param_groups.append(group) 110 | 111 | super().__init__(param_groups, defaults) 112 | else: 113 | super().__init__(params, defaults) 114 | 115 | @torch.no_grad() 116 | def step(self): 117 | for group in self.param_groups: 118 | params: list[Tensor] = group["params"] 119 | 120 | # Apply decoupled weight decay to all parameters. This doesn't require any 121 | # communication, since it's a simple element-wise operation.
122 | if group["weight_decay"] > 0.0: 123 | for p in params: 124 | p.mul_(1 - group["lr"] * group["weight_decay"]) 125 | 126 | # These will be None / empty list if we're not using DDP 127 | update_buffer: Tensor | None = group.get("update_buffer", None) 128 | update_buffer_views: list[Tensor] = group.get("update_buffer_views", []) 129 | 130 | beta = group["momentum"] 131 | handle = None 132 | params_world = None 133 | 134 | def update_prev(): # optimized implementation contributed by @YouJiacheng 135 | assert handle is not None and params_world is not None 136 | handle.wait() 137 | 138 | for p_world, g_world in zip(params_world, update_buffer_views): 139 | # Heuristic from 140 | scale = 0.2 * max(p_world.shape) ** 0.5 141 | p_world.add_(g_world.view_as(p_world), alpha=-group["lr"] * scale) 142 | 143 | for i in range(0, len(params), self.world_size): 144 | # Compute Muon update 145 | if i + self.rank < len(params): 146 | p = params[i + self.rank] 147 | state = self.state[p] 148 | 149 | g = p.grad 150 | assert g is not None 151 | 152 | # Apply momentum 153 | if beta > 0.0: 154 | if "exp_avg" not in state: 155 | state["exp_avg"] = torch.zeros_like(g) 156 | 157 | buf: Tensor = state["exp_avg"].lerp_(g, 1 - beta) 158 | g = g.lerp_(buf, beta) if group["nesterov"] else buf 159 | 160 | if g.ndim == 4: # for the case of conv filters 161 | g = g.view(len(g), -1) 162 | 163 | g = quintic_newtonschulz(g, steps=group["ns_steps"]) 164 | else: 165 | g = update_buffer_views[self.rank] 166 | 167 | if self.world_size > 1: 168 | # async all_gather instead of sync all_reduce by @YouJiacheng 169 | if i > 0: 170 | update_prev() 171 | 172 | handle = dist.all_gather_into_tensor( 173 | update_buffer, g.flatten(), async_op=True 174 | ) 175 | params_world = params[i : i + self.world_size] 176 | else: 177 | scale = 0.2 * max(params[i].shape) ** 0.5 178 | params[i].add_(g, alpha=-group["lr"] * scale) 179 | 180 | if self.world_size > 1: 181 | update_prev() 182 | -------------------------------------------------------------------------------- /sparsify/xformers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # Modifications by Stepan Shabalin and Nora Belrose 3 | import torch 4 | import triton 5 | from torch import Tensor 6 | from triton import language as tl 7 | 8 | 9 | @triton.jit 10 | def embedding_bag_k( 11 | out_ptr, # [B, dim] 12 | indices_ptr, # [B, bag_size] 13 | weight_ptr, # [n_keys**2, dim] 14 | per_sample_weights, # [B, bag_size] 15 | dim: tl.constexpr, 16 | dim_padded: tl.constexpr, 17 | bag_size: tl.constexpr, 18 | ): 19 | out_idx = tl.program_id(axis=0).to(tl.int64) 20 | out_value = tl.zeros([dim_padded], dtype=tl.float32) 21 | dim_mask = tl.arange(0, dim_padded) < dim 22 | for bag in range(0, bag_size): 23 | my_index = tl.load(indices_ptr + out_idx * bag_size + bag).to(tl.int64) 24 | my_scaling = tl.load(per_sample_weights + out_idx * bag_size + bag) 25 | my_weight = tl.load( 26 | weight_ptr + tl.arange(0, dim_padded) + my_index * dim, mask=dim_mask 27 | ) 28 | out_value = out_value + my_weight.to(tl.float32) * my_scaling 29 | tl.store( 30 | out_ptr + out_idx * dim + tl.arange(0, dim_padded), out_value, mask=dim_mask 31 | ) 32 | 33 | 34 | def embedding_bag_triton( 35 | indices: Tensor, weight: Tensor, per_sample_weights: Tensor 36 | ) -> Tensor: 37 | trt_out = torch.empty( 38 | [indices.shape[0], weight.shape[1]], dtype=weight.dtype, device=weight.device 39 | ) 40 | grid = (indices.shape[0],) 41 | 42 | embedding_bag_k[grid]( 43 | trt_out, 44 | indices, 45 | weight, 46 | per_sample_weights, 47 | dim=weight.shape[-1], 48 | dim_padded=triton.next_power_of_2(weight.shape[-1]), 49 | bag_size=indices.shape[1], 50 | num_warps=1, 51 | num_stages=1, 52 | ) 53 | return trt_out 54 | 55 | 56 | @triton.jit 57 | def count_per_embedding_k( 58 | count_per_emb_ptr, # [K+1] (out) 59 | indices_ptr, # [B, bag_size] 60 | bag_size: tl.constexpr, 61 | ): 62 | batch_id = tl.program_id(axis=0).to(tl.int64) 63 | for i in range(bag_size): 64 | embedding_id = tl.load(indices_ptr + batch_id * bag_size + i) 65 | tl.atomic_add( 66 | count_per_emb_ptr + embedding_id + 1, 67 | 1, 68 | sem="relaxed", 69 | ) 70 | 71 | 72 | @triton.jit 73 | def map_embeddings_and_outputs_k( 74 | reverse_mapping_ptr, # [B*bag_size] (out) 75 | mapping_write_pos_ptr, # [K] (tmp) 76 | indices_ptr, # [B, bag_size] 77 | bag_size: tl.constexpr, 78 | ): 79 | batch_id = tl.program_id(axis=0).to(tl.int64) 80 | for bag_id in range(bag_size): 81 | embedding_id = tl.load(indices_ptr + batch_id * bag_size + bag_id) 82 | write_pos = tl.atomic_add( 83 | mapping_write_pos_ptr + embedding_id, 1, sem="relaxed" 84 | ) 85 | tl.store(reverse_mapping_ptr + write_pos, batch_id * bag_size + bag_id) 86 | 87 | 88 | @triton.jit 89 | def aggregate_gradient_for_embedding_k( 90 | weight_grad_ptr, # [K, dim] (out) 91 | per_sample_weights_grad_ptr, # [B, bag_size] (out) 92 | emb_argsorted_ptr, # [K+1] 93 | weight_ptr, # [K, dim] (out) 94 | emb_begin_pos_ptr, # [K+1] 95 | reverse_mapping_ptr, # [B*bag_size] 96 | per_sample_weights_ptr, # [B, bag_size] 97 | gradient_ptr, # [B, dim] 98 | dim: tl.constexpr, 99 | dim_padded: tl.constexpr, 100 | bag_size: tl.constexpr, 101 | B: tl.constexpr, 102 | K: tl.constexpr, 103 | BLOCK_SIZE: tl.constexpr, 104 | ): 105 | first_embedding_id = tl.program_id(axis=0).to(tl.int64) 106 | for k in range(0, BLOCK_SIZE): 107 | embedding_id = first_embedding_id + (K // BLOCK_SIZE) * k 108 | # embedding_id = first_embedding_id * BLOCK_SIZE + k 109 | embedding_id = tl.load(emb_argsorted_ptr + embedding_id).to(tl.int64) 110 | weight_grad = tl.zeros([dim_padded], dtype=tl.float32) 111 | begin = tl.load(emb_begin_pos_ptr + embedding_id) 
112 | end = tl.load(emb_begin_pos_ptr + embedding_id + 1) 113 | dim_mask = tl.arange(0, dim_padded) < dim 114 | weight = tl.load( 115 | weight_ptr + embedding_id * dim + tl.arange(0, dim_padded), 116 | mask=dim_mask, 117 | ).to(tl.float32) 118 | for idx in range(begin, end): 119 | output_indice_id = tl.load(reverse_mapping_ptr + idx).to(tl.int64) 120 | batch_id = output_indice_id // bag_size 121 | output_indice_id % bag_size 122 | per_sample_w = tl.load(per_sample_weights_ptr + output_indice_id) 123 | gradient = tl.load( 124 | gradient_ptr + batch_id * dim + tl.arange(0, dim_padded), mask=dim_mask 125 | ).to(tl.float32) 126 | weight_grad = weight_grad + per_sample_w * gradient 127 | per_sample_weights_grad = gradient * weight 128 | per_sample_weights_grad = tl.sum(per_sample_weights_grad) 129 | tl.store( 130 | per_sample_weights_grad_ptr + output_indice_id, per_sample_weights_grad 131 | ) 132 | tl.store( 133 | weight_grad_ptr + embedding_id * dim + tl.arange(0, dim_padded), 134 | weight_grad, 135 | mask=dim_mask, 136 | ) 137 | 138 | 139 | def embedding_bag_bw_rev_indices( 140 | indices: Tensor, 141 | weight: Tensor, 142 | per_sample_weights: Tensor, 143 | gradient: Tensor, 144 | ) -> tuple[Tensor, Tensor]: 145 | # Returns: [weight.grad, per_sample_weights.grad] 146 | 147 | K, dim = weight.shape 148 | B, bag_size = indices.shape 149 | count_per_emb = torch.zeros((K + 1,), dtype=torch.uint32, device=indices.device) 150 | count_per_embedding_k[(B,)](count_per_emb, indices, bag_size=bag_size, num_warps=1) 151 | emb_argsorted = count_per_emb[1:].int().argsort(descending=True) 152 | emb_begin_pos = count_per_emb.cumsum(0) 153 | reverse_mapping = torch.empty( 154 | [B * bag_size], dtype=torch.uint32, device=indices.device 155 | ) 156 | assert B * bag_size < 2 ** (reverse_mapping.dtype.itemsize * 8 - 1) 157 | map_embeddings_and_outputs_k[(B,)]( 158 | reverse_mapping_ptr=reverse_mapping, 159 | mapping_write_pos_ptr=emb_begin_pos.clone(), 160 | indices_ptr=indices, 161 | bag_size=bag_size, 162 | num_warps=1, 163 | ) 164 | weight_grad = torch.empty_like(weight) 165 | per_sample_weights_grad = torch.empty_like(per_sample_weights) 166 | BLOCK_SIZE = 8 167 | assert (K % BLOCK_SIZE) == 0 168 | aggregate_gradient_for_embedding_k[(K // BLOCK_SIZE,)]( 169 | weight_grad_ptr=weight_grad, 170 | emb_begin_pos_ptr=emb_begin_pos, 171 | emb_argsorted_ptr=emb_argsorted, 172 | per_sample_weights_grad_ptr=per_sample_weights_grad, 173 | weight_ptr=weight, 174 | reverse_mapping_ptr=reverse_mapping, 175 | per_sample_weights_ptr=per_sample_weights, 176 | gradient_ptr=gradient, 177 | dim=dim, 178 | dim_padded=triton.next_power_of_2(dim), 179 | bag_size=bag_size, 180 | B=B, 181 | K=K, 182 | BLOCK_SIZE=BLOCK_SIZE, 183 | num_warps=1, 184 | ) 185 | return weight_grad, per_sample_weights_grad 186 | 187 | 188 | class xFormersEmbeddingBag(torch.autograd.Function): 189 | @staticmethod 190 | def forward( 191 | ctx, 192 | indices: Tensor, 193 | weight: Tensor, 194 | per_sample_weights: Tensor, 195 | ) -> Tensor: 196 | ctx.save_for_backward(indices, weight, per_sample_weights) 197 | return embedding_bag_triton(indices, weight, per_sample_weights) 198 | 199 | @staticmethod 200 | def backward(ctx, gradient): 201 | indices, weight, per_sample_weights = ctx.saved_tensors 202 | 203 | weight_g, per_sample_weights_g = embedding_bag_bw_rev_indices( 204 | indices, 205 | weight, 206 | per_sample_weights, 207 | gradient, 208 | ) 209 | return None, weight_g, per_sample_weights_g, None 210 | 211 | 212 | def xformers_embedding_bag( 213 | indices: 
Tensor, 214 | weight: Tensor, 215 | per_sample_weights: Tensor, 216 | ) -> Tensor: 217 | return xFormersEmbeddingBag.apply(indices, weight, per_sample_weights) 218 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Introduction 2 | This library trains _k_-sparse autoencoders (SAEs) and transcoders on the activations of HuggingFace language models, roughly following the recipe detailed in [Scaling and evaluating sparse autoencoders](https://arxiv.org/abs/2406.04093v1) (Gao et al. 2024). 3 | 4 | This is a lean, simple library with few configuration options. Unlike most other SAE libraries (e.g. [SAELens](https://github.com/jbloomAus/SAELens)), it does not cache activations on disk, but rather computes them on-the-fly. This allows us to scale to very large models and datasets with zero storage overhead, but has the downside that trying different hyperparameters for the same model and dataset will be slower than if we cached activations (since activations will be re-computed). We may add caching as an option in the future. 5 | 6 | Following Gao et al., we use a TopK activation function which directly enforces a desired level of sparsity in the activations. This is in contrast to other libraries which use an L1 penalty in the loss function. We believe TopK is a Pareto improvement over the L1 approach, and hence do not plan on supporting it. 7 | 8 | ## Loading pretrained SAEs 9 | 10 | To load a pretrained SAE from the HuggingFace Hub, you can use the `Sae.load_from_hub` method as follows: 11 | 12 | ```python 13 | from sparsify import Sae 14 | 15 | sae = Sae.load_from_hub("EleutherAI/sae-llama-3-8b-32x", hookpoint="layers.10") 16 | ``` 17 | 18 | This will load the SAE for residual stream layer 10 of Llama 3 8B, which was trained with an expansion factor of 32. You can also load the SAEs for all layers at once using `Sae.load_many`: 19 | 20 | ```python 21 | saes = Sae.load_many("EleutherAI/sae-llama-3-8b-32x") 22 | saes["layers.10"] 23 | ``` 24 | 25 | The dictionary returned by `load_many` is guaranteed to be [naturally sorted](https://en.wikipedia.org/wiki/Natural_sort_order) by the name of the hook point. For the common case where the hook points are named `embed_tokens`, `layers.0`, ..., `layers.n`, this means that the SAEs will be sorted by layer number. We can then gather the SAE activations for a model forward pass as follows: 26 | 27 | ```python 28 | from transformers import AutoModelForCausalLM, AutoTokenizer 29 | import torch 30 | 31 | tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B") 32 | inputs = tokenizer("Hello, world!", return_tensors="pt") 33 | 34 | with torch.inference_mode(): 35 | model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B") 36 | outputs = model(**inputs, output_hidden_states=True) 37 | 38 | latent_acts = [] 39 | for sae, hidden_state in zip(saes.values(), outputs.hidden_states): 40 | # (N, D) input shape expected 41 | hidden_state = hidden_state.flatten(0, 1) 42 | latent_acts.append(sae.encode(hidden_state)) 43 | 44 | # Do stuff with the latent activations 45 | ``` 46 | 47 | For use cases beyond collecting residual stream SAE activations, we recommend PyTorch hooks ([see examples](https://gist.github.com/luciaquirke/7105708dac0cfc632d68f33c79b59e5c).) 
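As a rough sketch of that hook-based approach (the checkpoint directory and hookpoint below are placeholders, and we assume a sparse coder was trained on GPT-2's `h.5.mlp.act` hookpoint; move the model and sparse coder to GPU as needed):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from sparsify import SparseCoder

# Placeholder checkpoint directory and hookpoint, for illustration only
sae = SparseCoder.load_from_disk("checkpoints/my-run/h.5.mlp.act")
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

latent_acts = []

def hook(module, inputs, outputs):
    # Some modules return tuples; the hidden state comes first
    if isinstance(outputs, tuple):
        outputs = outputs[0]
    # encode() expects a flat (N, D) batch of activations
    latent_acts.append(sae.encode(outputs.flatten(0, 1)))

# The trainer names hookpoints relative to the base model, so prepend "transformer." for GPT-2
handle = model.get_submodule("transformer.h.5.mlp.act").register_forward_hook(hook)
try:
    inputs = tokenizer("Hello, world!", return_tensors="pt")
    with torch.inference_mode():
        model(**inputs)
finally:
    handle.remove()
```

For transcoders trained with `--transcode`, you would encode the module's inputs rather than its outputs.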
48 | 49 | ## Training SAEs and transcoders 50 | 51 | To train SAEs from the command line, you can use the following command: 52 | 53 | ```bash 54 | python -m sparsify EleutherAI/pythia-160m [optional dataset] [--transcode] 55 | ``` 56 | By default, we use the `EleutherAI/SmolLM2-135M-10B` dataset for training, but you can use any dataset from the HuggingFace Hub, or any local dataset in HuggingFace format (the string is passed to `load_dataset` from the `datasets` library). You can pass arbitrary dataset loading arguments to HuggingFace using `--data_args "arg1=x,arg2=y"`. 57 | 58 | The CLI supports all of the config options provided by the `TrainConfig` class. You can see them by running `python -m sparsify --help`. 59 | 60 | Programmatic usage is simple. Here is an example: 61 | 62 | ```python 63 | import torch 64 | from datasets import load_dataset 65 | from transformers import AutoModelForCausalLM, AutoTokenizer 66 | 67 | from sparsify import SaeConfig, Trainer, TrainConfig 68 | from sparsify.data import chunk_and_tokenize 69 | 70 | MODEL = "HuggingFaceTB/SmolLM2-135M" 71 | dataset = load_dataset( 72 | "EleutherAI/SmolLM2-135M-10B", split="train", 73 | ) 74 | tokenizer = AutoTokenizer.from_pretrained(MODEL) 75 | tokenized = chunk_and_tokenize(dataset, tokenizer) 76 | 77 | 78 | gpt = AutoModelForCausalLM.from_pretrained( 79 | MODEL, 80 | device_map={"": "cuda"}, 81 | torch_dtype=torch.bfloat16, 82 | ) 83 | 84 | cfg = TrainConfig(SaeConfig(), batch_size=16) 85 | trainer = Trainer(cfg, tokenized, gpt) 86 | 87 | trainer.fit() 88 | ``` 89 | 90 | ## Finetuning SAEs 91 | 92 | To finetune a pretrained SAE, pass its path to the `--finetune` argument. 93 | 94 | ```bash 95 | python -m sparsify EleutherAI/pythia-160m togethercomputer/RedPajama-Data-1T-Sample --finetune EleutherAI/sae-pythia-160m-32x 96 | ``` 97 | 98 | ## Custom hookpoints 99 | 100 | By default, the SAEs are trained on the residual stream activations of the model. However, you can also train SAEs on the activations of any other submodule(s) by specifying custom hookpoint patterns. These patterns are like standard PyTorch module names (e.g. `h.0.ln_1`) but also allow [Unix pattern matching syntax](https://docs.python.org/3/library/fnmatch.html), including wildcards and character sets. For example, to train SAEs on the output of every attention module and the inner activations of every MLP in GPT-2, you can use the following code: 101 | 102 | ```bash 103 | python -m sparsify gpt2 --hookpoints "h.*.attn" "h.*.mlp.act" 104 | ``` 105 | 106 | To restrict to the first three layers: 107 | 108 | ```bash 109 | python -m sparsify gpt2 --hookpoints "h.[012].attn" "h.[012].mlp.act" 110 | ``` 111 | 112 | We currently don't support fine-grained manual control over the learning rate, number of latents, or other hyperparameters on a hookpoint-by-hookpoint basis. By default, the `expansion_factor` option is used to select the appropriate number of latents for each hookpoint based on the width of that hookpoint's output. The default learning rate for each hookpoint is then set using an inverse square root scaling law based on the number of latents. If you manually set the number of latents or the learning rate, it will be applied to all hookpoints. 113 | 114 | ## Distributed training 115 | 116 | We support distributed training via PyTorch's `torchrun` command. By default we use the Distributed Data Parallel method, which means that the weights of each SAE are replicated on every GPU.
117 | 118 | ```bash 119 | torchrun --nproc_per_node gpu -m sparsify meta-llama/Meta-Llama-3-8B --batch_size 1 --layers 16 24 --k 192 --grad_acc_steps 8 --ctx_len 2048 120 | ``` 121 | 122 | This is simple, but very memory inefficient. If you want to train SAEs for many layers of a model, we recommend using the `--distribute_modules` flag, which allocates the SAEs for different layers to different GPUs. Currently, we require that the number of GPUs evenly divides the number of layers you're training SAEs for. 123 | 124 | ```bash 125 | torchrun --nproc_per_node gpu -m sparsify meta-llama/Meta-Llama-3-8B --distribute_modules --batch_size 1 --layer_stride 2 --grad_acc_steps 8 --ctx_len 2048 --k 192 --load_in_8bit --micro_acc_steps 2 126 | ``` 127 | 128 | The above command trains an SAE for every _even_ layer of Llama 3 8B, using all available GPUs. It accumulates gradients over 8 minibatches, and splits each minibatch into 2 microbatches before feeding them into the SAE encoder, thus saving a lot of memory. It also loads the model in 8-bit precision using `bitsandbytes`. This command requires no more than 48GB of memory per GPU on an 8 GPU node. 129 | 130 | ## TODO 131 | 132 | There are several features that we'd like to add in the near future: 133 | - [ ] Support for caching activations 134 | - [ ] Evaluate SAEs with KL divergence when grafted into the model 135 | 136 | If you'd like to help out with any of these, please feel free to open a PR! You can collaborate with us in the sparse-autoencoders channel of the EleutherAI Discord, or email lucia@eleuther.ai. 137 | 138 | ## Installation 139 | 140 | `pip install eai-sparsify` 141 | 142 | ## Development 143 | 144 | Run `pip install -e .[dev]` from the sparsify directory. 145 | 146 | We use [conventional commits](https://www.conventionalcommits.org/en/v1.0.0/) for releases. 
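For example, commit messages follow the `<type>: <description>` pattern from that spec (the subjects below are made up for illustration); under semantic versioning, `feat:` corresponds to a minor release and `fix:` to a patch:

```bash
git commit -m "feat: add activation caching"
git commit -m "fix: handle empty hookpoint list"
```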
147 | 148 | ## Experimental features 149 | 150 | Linear k decay schedule: 151 | 152 |
```bash
python -m sparsify gpt2 --hookpoints "h.*.attn" "h.*.mlp.act" --k_decay_steps 10_000
```
153 | 154 | GroupMax activation function: 155 | 156 |
```bash
python -m sparsify gpt2 --hookpoints "h.*.attn" "h.*.mlp.act" --activation groupmax
```
157 | 158 | End-to-end training: 159 | 160 |
```bash
python -m sparsify gpt2 --hookpoints "h.*.attn" "h.*.mlp.act" --loss_fn ce
```
161 | 162 | or 163 | 164 |
```bash
python -m sparsify gpt2 --hookpoints "h.*.attn" "h.*.mlp.act" --loss_fn kl
```
165 | -------------------------------------------------------------------------------- /sparsify/sparse_coder.py: -------------------------------------------------------------------------------- 1 | import json 2 | from fnmatch import fnmatch 3 | from pathlib import Path 4 | from typing import NamedTuple 5 | 6 | import einops 7 | import torch 8 | from huggingface_hub import snapshot_download 9 | from natsort import natsorted 10 | from safetensors import safe_open 11 | from safetensors.torch import load_model, save_model 12 | from torch import Tensor, nn 13 | 14 | from .config import SparseCoderConfig 15 | from .fused_encoder import EncoderOutput, fused_encoder 16 | from .utils import decoder_impl 17 | 18 | 19 | class ForwardOutput(NamedTuple): 20 | sae_out: Tensor 21 | 22 | latent_acts: Tensor 23 | """Activations of the top-k latents.""" 24 | 25 | latent_indices: Tensor 26 | """Indices of the top-k features.""" 27 | 28 | fvu: Tensor 29 | """Fraction of variance unexplained.""" 30 | 31 | auxk_loss: Tensor 32 | """AuxK loss, if applicable.""" 33 | 34 | multi_topk_fvu: Tensor 35 | """Multi-TopK FVU, if applicable.""" 36 | 37 | 38 | class SparseCoder(nn.Module): 39 | def __init__( 40 | self, 41 | d_in: int, 42 | cfg: SparseCoderConfig, 43 | device: str | torch.device = "cpu", 44 | dtype: torch.dtype | None = None, 45 | *, 46 | decoder: bool = True, 47 | ): 48 | super().__init__() 49 | self.cfg = cfg 50 | self.d_in = d_in 51 | self.num_latents = cfg.num_latents or d_in * cfg.expansion_factor 52 | 53 | self.encoder = nn.Linear(d_in, self.num_latents, device=device, dtype=dtype) 54 | self.encoder.bias.data.zero_() 55 | 56 | if decoder: 57 | # Transcoder initialization: use zeros 58 | if cfg.transcode: 59 | self.W_dec = nn.Parameter(torch.zeros_like(self.encoder.weight.data)) 60 | 61 | # Sparse autoencoder initialization: use the transpose of encoder weights 62 | else: 63 | self.W_dec = nn.Parameter(self.encoder.weight.data.clone()) 64 | if self.cfg.normalize_decoder: 65 | self.set_decoder_norm_to_unit_norm() 66 | else: 67 | self.W_dec = None 68 | 69 | self.b_dec = nn.Parameter(torch.zeros(d_in, dtype=dtype, device=device)) 70 | self.W_skip = ( 71 | nn.Parameter(torch.zeros(d_in, d_in, device=device, dtype=dtype)) 72 | if cfg.skip_connection 73 | else None 74 | ) 75 | 76 | @staticmethod 77 | def load_many( 78 | name: str, 79 | local: bool = False, 80 | layers: list[str] | None = None, 81 | device: str | torch.device = "cpu", 82 | *, 83 | decoder: bool = True, 84 | pattern: str | None = None, 85 | ) -> dict[str, "SparseCoder"]: 86 | """Load sparse coders for multiple hookpoints on a single model and dataset.""" 87 | pattern = pattern + "/*" if pattern is not None else None 88 | if local: 89 | repo_path = Path(name) 90 | else: 91 | repo_path = Path(snapshot_download(name, allow_patterns=pattern)) 92 | 93 | if layers is not None: 94 | return { 95 | layer: SparseCoder.load_from_disk( 96 | repo_path / layer, device=device, decoder=decoder
97 | ) 98 | for layer in natsorted(layers) 99 | } 100 | files = [ 101 | f 102 | for f in repo_path.iterdir() 103 | if f.is_dir() and (pattern is None or fnmatch(f.name, pattern)) 104 | ] 105 | return { 106 | f.name: SparseCoder.load_from_disk(f, device=device, decoder=decoder) 107 | for f in natsorted(files, key=lambda f: f.name) 108 | } 109 | 110 | @staticmethod 111 | def load_from_hub( 112 | name: str, 113 | hookpoint: str | None = None, 114 | device: str | torch.device = "cpu", 115 | *, 116 | decoder: bool = True, 117 | ) -> "SparseCoder": 118 | # Download from the HuggingFace Hub 119 | repo_path = Path( 120 | snapshot_download( 121 | name, 122 | allow_patterns=f"{hookpoint}/*" if hookpoint is not None else None, 123 | ) 124 | ) 125 | if hookpoint is not None: 126 | repo_path = repo_path / hookpoint 127 | 128 | # No layer specified, and there are multiple layers 129 | elif not repo_path.joinpath("cfg.json").exists(): 130 | raise FileNotFoundError("No config file found; try specifying a layer.") 131 | 132 | return SparseCoder.load_from_disk(repo_path, device=device, decoder=decoder) 133 | 134 | @staticmethod 135 | def load_from_disk( 136 | path: Path | str, 137 | device: str | torch.device = "cpu", 138 | *, 139 | decoder: bool = True, 140 | ) -> "SparseCoder": 141 | path = Path(path) 142 | 143 | with open(path / "cfg.json", "r") as f: 144 | cfg_dict = json.load(f) 145 | d_in = cfg_dict.pop("d_in") 146 | cfg = SparseCoderConfig.from_dict(cfg_dict, drop_extra_fields=True) 147 | 148 | safetensors_path = str(path / "sae.safetensors") 149 | 150 | with safe_open(safetensors_path, framework="pt", device="cpu") as f: 151 | first_key = next(iter(f.keys())) 152 | reference_dtype = f.get_tensor(first_key).dtype 153 | 154 | sae = SparseCoder( 155 | d_in, cfg, device=device, decoder=decoder, dtype=reference_dtype 156 | ) 157 | 158 | load_model( 159 | model=sae, 160 | filename=safetensors_path, 161 | device=str(device), 162 | # TODO: Maybe be more fine-grained about this in the future? 163 | strict=decoder, 164 | ) 165 | return sae 166 | 167 | def save_to_disk(self, path: Path | str): 168 | path = Path(path) 169 | path.mkdir(parents=True, exist_ok=True) 170 | 171 | save_model(self, str(path / "sae.safetensors")) 172 | with open(path / "cfg.json", "w") as f: 173 | json.dump( 174 | { 175 | **self.cfg.to_dict(), 176 | "d_in": self.d_in, 177 | }, 178 | f, 179 | ) 180 | 181 | @property 182 | def device(self): 183 | return self.encoder.weight.device 184 | 185 | @property 186 | def dtype(self): 187 | return self.encoder.weight.dtype 188 | 189 | def encode(self, x: Tensor) -> EncoderOutput: 190 | """Encode the input and select the top-k latents.""" 191 | if not self.cfg.transcode: 192 | x = x - self.b_dec 193 | 194 | return fused_encoder( 195 | x, self.encoder.weight, self.encoder.bias, self.cfg.k, self.cfg.activation 196 | ) 197 | 198 | def decode(self, top_acts: Tensor, top_indices: Tensor) -> Tensor: 199 | assert self.W_dec is not None, "Decoder weight was not initialized." 
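# Scatter the k active latents back into the full latent space and multiply by the decoder weights, then add the decoder bias.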
200 | 201 | y = decoder_impl(top_indices, top_acts.to(self.dtype), self.W_dec.mT) 202 | return y + self.b_dec 203 | 204 | # Wrapping the forward in bf16 autocast improves performance by almost 2x 205 | @torch.autocast( 206 | "cuda", 207 | dtype=torch.bfloat16, 208 | enabled=torch.cuda.is_bf16_supported(), 209 | ) 210 | def forward( 211 | self, x: Tensor, y: Tensor | None = None, *, dead_mask: Tensor | None = None 212 | ) -> ForwardOutput: 213 | top_acts, top_indices, pre_acts = self.encode(x) 214 | 215 | # If we aren't given a distinct target, we're autoencoding 216 | if y is None: 217 | y = x 218 | 219 | # Decode 220 | sae_out = self.decode(top_acts, top_indices) 221 | if self.W_skip is not None: 222 | sae_out += x.to(self.dtype) @ self.W_skip.mT 223 | 224 | # Compute the residual 225 | e = y - sae_out 226 | 227 | # Used as a denominator for putting everything on a reasonable scale 228 | total_variance = (y - y.mean(0)).pow(2).sum() 229 | 230 | # Second decoder pass for AuxK loss 231 | if dead_mask is not None and (num_dead := int(dead_mask.sum())) > 0: 232 | # Heuristic from Appendix B.1 in the paper 233 | k_aux = y.shape[-1] // 2 234 | 235 | # Reduce the scale of the loss if there are a small number of dead latents 236 | scale = min(num_dead / k_aux, 1.0) 237 | k_aux = min(k_aux, num_dead) 238 | 239 | # Don't include living latents in this loss 240 | auxk_latents = torch.where(dead_mask[None], pre_acts, -torch.inf) 241 | 242 | # Top-k dead latents 243 | auxk_acts, auxk_indices = auxk_latents.topk(k_aux, sorted=False) 244 | 245 | # Encourage the top ~50% of dead latents to predict the residual of the 246 | # top k living latents 247 | e_hat = self.decode(auxk_acts, auxk_indices) 248 | auxk_loss = (e_hat - e.detach()).pow(2).sum() 249 | auxk_loss = scale * auxk_loss / total_variance 250 | else: 251 | auxk_loss = sae_out.new_tensor(0.0) 252 | 253 | l2_loss = e.pow(2).sum() 254 | fvu = l2_loss / total_variance 255 | 256 | if self.cfg.multi_topk: 257 | top_acts, top_indices = pre_acts.topk(4 * self.cfg.k, sorted=False) 258 | sae_out = self.decode(top_acts, top_indices) 259 | 260 | multi_topk_fvu = (sae_out - y).pow(2).sum() / total_variance 261 | else: 262 | multi_topk_fvu = sae_out.new_tensor(0.0) 263 | 264 | return ForwardOutput( 265 | sae_out, 266 | top_acts, 267 | top_indices, 268 | fvu, 269 | auxk_loss, 270 | multi_topk_fvu, 271 | ) 272 | 273 | @torch.no_grad() 274 | def set_decoder_norm_to_unit_norm(self): 275 | assert self.W_dec is not None, "Decoder weight was not initialized." 276 | 277 | eps = torch.finfo(self.W_dec.dtype).eps 278 | norm = torch.norm(self.W_dec.data, dim=1, keepdim=True) 279 | self.W_dec.data /= norm + eps 280 | 281 | @torch.no_grad() 282 | def remove_gradient_parallel_to_decoder_directions(self): 283 | assert self.W_dec is not None, "Decoder weight was not initialized." 
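# Remove the component of each row's gradient that is parallel to the row itself, so optimizer steps (approximately) preserve the unit-norm decoder constraint.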
284 | assert self.W_dec.grad is not None # keep pyright happy 285 | 286 | parallel_component = einops.einsum( 287 | self.W_dec.grad, 288 | self.W_dec.data, 289 | "d_sae d_in, d_sae d_in -> d_sae", 290 | ) 291 | self.W_dec.grad -= einops.einsum( 292 | parallel_component, 293 | self.W_dec.data, 294 | "d_sae, d_sae d_in -> d_sae d_in", 295 | ) 296 | 297 | 298 | # Allow for alternate naming conventions 299 | Sae = SparseCoder 300 | -------------------------------------------------------------------------------- /sparsify/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | from dataclasses import asdict 4 | from fnmatch import fnmatchcase 5 | from glob import glob 6 | from typing import Sized 7 | 8 | import torch 9 | import torch.distributed as dist 10 | from datasets import Dataset as HfDataset 11 | from natsort import natsorted 12 | from safetensors.torch import load_model 13 | from schedulefree import ScheduleFreeWrapper 14 | from torch import Tensor, nn 15 | from torch.nn.parallel import DistributedDataParallel as DDP 16 | from torch.utils.data import DataLoader 17 | from tqdm.auto import tqdm 18 | from transformers import PreTrainedModel, get_linear_schedule_with_warmup 19 | 20 | from .config import TrainConfig 21 | from .data import MemmapDataset 22 | from .muon import Muon 23 | from .sign_sgd import SignSGD 24 | from .sparse_coder import SparseCoder 25 | from .utils import get_layer_list, resolve_widths, set_submodule 26 | 27 | 28 | class Trainer: 29 | def __init__( 30 | self, 31 | cfg: TrainConfig, 32 | dataset: HfDataset | MemmapDataset, 33 | model: PreTrainedModel, 34 | ): 35 | # Store the whole model, including any potential causal LM wrapper 36 | self.model = model 37 | 38 | if cfg.hookpoints: 39 | assert not cfg.layers, "Cannot specify both `hookpoints` and `layers`." 
40 | 41 | # Replace wildcard patterns 42 | raw_hookpoints = [] 43 | for name, _ in model.base_model.named_modules(): 44 | if any(fnmatchcase(name, pat) for pat in cfg.hookpoints): 45 | raw_hookpoints.append(name) 46 | 47 | # Natural sort to impose a consistent order 48 | cfg.hookpoints = natsorted(raw_hookpoints) 49 | else: 50 | # If no layers are specified, train on all of them 51 | if not cfg.layers: 52 | N = model.config.num_hidden_layers 53 | cfg.layers = list(range(0, N)) 54 | 55 | # Now convert layers to hookpoints 56 | layers_name, _ = get_layer_list(model) 57 | cfg.hookpoints = [f"{layers_name}.{i}" for i in cfg.layers] 58 | 59 | cfg.hookpoints = cfg.hookpoints[:: cfg.layer_stride] 60 | 61 | self.cfg = cfg 62 | self.dataset = dataset 63 | self.distribute_modules() 64 | 65 | device = model.device 66 | input_widths = resolve_widths(model, cfg.hookpoints) 67 | unique_widths = set(input_widths.values()) 68 | 69 | if cfg.distribute_modules and len(unique_widths) > 1: 70 | # dist.all_to_all requires tensors to have the same shape across ranks 71 | raise ValueError( 72 | f"All modules must output tensors of the same shape when using " 73 | f"`distribute_modules=True`, got {unique_widths}" 74 | ) 75 | 76 | # Initialize all the SAEs 77 | print(f"Initializing SAEs with random seed(s) {cfg.init_seeds}") 78 | self.saes = {} 79 | for hook in self.local_hookpoints(): 80 | for seed in cfg.init_seeds: 81 | torch.manual_seed(seed) 82 | 83 | # Add suffix to the name to disambiguate multiple seeds 84 | name = f"{hook}/seed{seed}" if len(cfg.init_seeds) > 1 else hook 85 | self.saes[name] = SparseCoder( 86 | input_widths[hook], cfg.sae, device, dtype=torch.float32 87 | ) 88 | 89 | assert isinstance(dataset, Sized) 90 | num_batches = len(dataset) // cfg.batch_size 91 | 92 | match cfg.optimizer: 93 | case "adam": 94 | try: 95 | from bitsandbytes.optim import Adam8bit as Adam 96 | 97 | print("Using 8-bit Adam from bitsandbytes") 98 | except ImportError: 99 | from torch.optim import Adam 100 | 101 | print( 102 | "bitsandbytes 8-bit Adam not available, using torch.optim.Adam" 103 | ) 104 | print("Run `pip install bitsandbytes` for less memory usage.") 105 | 106 | pgs = [ 107 | dict( 108 | params=sae.parameters(), 109 | lr=cfg.lr or 2e-4 / (sae.num_latents / (2**14)) ** 0.5, 110 | ) 111 | for sae in self.saes.values() 112 | ] 113 | # For logging purposes 114 | lrs = [f"{lr:.2e}" for lr in sorted(set(pg["lr"] for pg in pgs))] 115 | 116 | adam = Adam(pgs) 117 | self.optimizers = [adam] 118 | self.lr_schedulers = [ 119 | get_linear_schedule_with_warmup( 120 | adam, cfg.lr_warmup_steps, num_batches 121 | ) 122 | ] 123 | case "muon": 124 | params = {p for sae in self.saes.values() for p in sae.parameters()} 125 | muon_params = {p for p in params if p.ndim >= 2} 126 | lrs = [f"{cfg.lr or 2e-3:.2e}"] 127 | 128 | self.optimizers = [ 129 | Muon( 130 | muon_params, 131 | # Muon LR is independent of the number of latents 132 | lr=cfg.lr or 2e-3, 133 | # Muon distributes the work of the Newton-Schulz iterations 134 | # across all ranks for DDP but this doesn't make sense when 135 | # we're distributing modules across ranks 136 | ddp=not cfg.distribute_modules, 137 | ), 138 | torch.optim.Adam(params - muon_params, lr=cfg.lr or 2e-3), 139 | ] 140 | self.lr_schedulers = [ 141 | get_linear_schedule_with_warmup(self.optimizers[0], 0, num_batches), 142 | get_linear_schedule_with_warmup( 143 | self.optimizers[1], cfg.lr_warmup_steps, num_batches 144 | ), 145 | ] 146 | case "signum": 147 | from schedulefree import 
ScheduleFreeWrapper 148 | 149 | pgs = [ 150 | dict( 151 | params=sae.parameters(), 152 | lr=cfg.lr or 5e-3 / (sae.num_latents / (2**14)) ** 0.5, 153 | ) 154 | for sae in self.saes.values() 155 | ] 156 | lrs = [f"{lr:.2e}" for lr in sorted(set(pg["lr"] for pg in pgs))] 157 | 158 | opt = ScheduleFreeWrapper(SignSGD(pgs), momentum=0.95) 159 | opt.train() 160 | 161 | self.optimizers = [opt] 162 | self.lr_schedulers = [] 163 | case other: 164 | raise ValueError(f"Unknown optimizer '{other}'") 165 | 166 | print(f"Learning rates: {lrs}" if len(lrs) > 1 else f"Learning rate: {lrs[0]}") 167 | self.global_step = 0 168 | self.num_tokens_since_fired = { 169 | name: torch.zeros(sae.num_latents, device=device, dtype=torch.long) 170 | for name, sae in self.saes.items() 171 | } 172 | self.exclude_tokens = torch.tensor( 173 | self.cfg.exclude_tokens, device=device, dtype=torch.long 174 | ) 175 | 176 | num_latents = list(self.saes.values())[0].num_latents 177 | self.initial_k = min(num_latents, round(list(input_widths.values())[0] * 10)) 178 | self.final_k = self.cfg.sae.k 179 | 180 | self.best_loss = ( 181 | {name: float("inf") for name in self.local_hookpoints()} 182 | if self.cfg.loss_fn == "fvu" 183 | else float("inf") 184 | ) 185 | 186 | def load_state(self, path: str): 187 | """Load the trainer state from disk.""" 188 | device = self.model.device 189 | 190 | # Load the train state first so we can print the step number 191 | train_state = torch.load( 192 | f"{path}/state.pt", map_location=device, weights_only=True 193 | ) 194 | self.global_step = train_state["global_step"] 195 | 196 | for file in glob(f"{path}/rank_*_state.pt"): 197 | rank_state = torch.load(file, map_location=device, weights_only=True) 198 | 199 | for k in self.local_hookpoints(): 200 | if k in rank_state["num_tokens_since_fired"]: 201 | self.num_tokens_since_fired[k] = rank_state[ 202 | "num_tokens_since_fired" 203 | ][k] 204 | 205 | if not isinstance(rank_state["best_loss"], dict): 206 | self.best_loss = rank_state["best_loss"] 207 | elif k in rank_state["best_loss"]: 208 | self.best_loss[k] = rank_state["best_loss"][k] # type: ignore 209 | 210 | print( 211 | f"\033[92mResuming training at step {self.global_step} from '{path}'\033[0m" 212 | ) 213 | 214 | for i, scheduler in enumerate(self.lr_schedulers): 215 | lr_state = torch.load( 216 | f"{path}/lr_scheduler_{i}.pt", map_location=device, weights_only=True 217 | ) 218 | scheduler.load_state_dict(lr_state) 219 | 220 | for i, optimizer in enumerate(self.optimizers): 221 | opt_state = torch.load( 222 | f"{path}/optimizer_{i}.pt", map_location=device, weights_only=True 223 | ) 224 | optimizer.load_state_dict(opt_state) 225 | 226 | for name, sae in self.saes.items(): 227 | load_model(sae, f"{path}/{name}/sae.safetensors", device=str(device)) 228 | 229 | def get_current_k(self) -> int: 230 | """Get the current k value based on a linear decay schedule.""" 231 | if self.global_step >= self.cfg.k_decay_steps: 232 | return self.final_k 233 | 234 | progress = self.global_step / self.cfg.k_decay_steps 235 | return round(self.initial_k * (1 - progress) + self.final_k * progress) 236 | 237 | def fit(self): 238 | # Use Tensor Cores even for fp32 matmuls 239 | torch.set_float32_matmul_precision("high") 240 | 241 | # Make sure the model is frozen 242 | self.model.requires_grad_(False) 243 | 244 | rank_zero = not dist.is_initialized() or dist.get_rank() == 0 245 | ddp = dist.is_initialized() and not self.cfg.distribute_modules 246 | 247 | wandb = None 248 | if self.cfg.log_to_wandb and rank_zero: 
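# Import wandb lazily so training still works when it is not installed.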
249 | try: 250 | import wandb 251 | 252 | wandb.init( 253 | entity=os.environ.get("WANDB_ENTITY", None), 254 | name=self.cfg.run_name, 255 | project=os.environ.get("WANDB_PROJECT", "sparsify"), 256 | config=asdict(self.cfg), 257 | save_code=True, 258 | ) 259 | except (AttributeError, ImportError): 260 | print("Weights & Biases not available, skipping logging.") 261 | print("Run `pip install -U wandb` if you want to use it.") 262 | self.cfg.log_to_wandb = False 263 | 264 | num_sae_params = sum( 265 | p.numel() for s in self.saes.values() for p in s.parameters() 266 | ) 267 | num_model_params = sum(p.numel() for p in self.model.parameters()) 268 | print(f"Number of SAE parameters: {num_sae_params:_}") 269 | print(f"Number of model parameters: {num_model_params:_}") 270 | 271 | num_batches = len(self.dataset) // self.cfg.batch_size 272 | if self.global_step > 0: 273 | assert hasattr(self.dataset, "select"), "Dataset must implement `select`" 274 | 275 | n = self.global_step * self.cfg.batch_size 276 | ds = self.dataset.select(range(n, len(self.dataset))) # type: ignore 277 | else: 278 | ds = self.dataset 279 | 280 | device = self.model.device 281 | dl = DataLoader( 282 | ds, # type: ignore 283 | batch_size=self.cfg.batch_size, 284 | # NOTE: We do not shuffle here for reproducibility; the dataset should 285 | # be shuffled before passing it to the trainer. 286 | shuffle=False, 287 | ) 288 | pbar = tqdm( 289 | desc="Training", 290 | disable=not rank_zero, 291 | initial=self.global_step, 292 | total=num_batches, 293 | ) 294 | 295 | did_fire = { 296 | name: torch.zeros(sae.num_latents, device=device, dtype=torch.bool) 297 | for name, sae in self.saes.items() 298 | } 299 | 300 | tokens_mask: torch.Tensor 301 | 302 | acc_steps = self.cfg.grad_acc_steps * self.cfg.micro_acc_steps 303 | denom = acc_steps * self.cfg.wandb_log_frequency 304 | num_tokens_in_step = 0 305 | 306 | # For logging purposes 307 | avg_auxk_loss = defaultdict(float) 308 | avg_fvu = defaultdict(float) 309 | avg_multi_topk_fvu = defaultdict(float) 310 | avg_ce = 0.0 311 | avg_kl = 0.0 312 | avg_losses = ( 313 | {name: float("inf") for name in self.local_hookpoints()} 314 | if self.cfg.loss_fn == "fvu" 315 | else float("inf") 316 | ) 317 | 318 | if self.cfg.loss_fn == "ce": 319 | batch = next(iter(dl)) 320 | x = batch["input_ids"].to(device) 321 | 322 | clean_loss = self.model(x, labels=x).loss 323 | self.maybe_all_reduce(clean_loss) 324 | if rank_zero: 325 | print(f"Initial CE loss: {clean_loss.item():.4f}") 326 | 327 | # If doing end-to-end transcoders, then we don't actually want to run the 328 | # modules that we're replacing 329 | if self.cfg.sae.transcode: 330 | for point in self.cfg.hookpoints: 331 | set_submodule(self.model.base_model, point, nn.Identity()) 332 | 333 | name_to_module = { 334 | name: self.model.base_model.get_submodule(name) 335 | for name in self.cfg.hookpoints 336 | } 337 | maybe_wrapped: dict[str, DDP] | dict[str, SparseCoder] = {} 338 | module_to_name = {v: k for k, v in name_to_module.items()} 339 | 340 | def hook(module: nn.Module, inputs, outputs): 341 | aux_out = None 342 | 343 | # Maybe unpack tuple inputs and outputs 344 | if isinstance(inputs, tuple): 345 | inputs = inputs[0] 346 | if isinstance(outputs, tuple): 347 | outputs, *aux_out = outputs 348 | mask = tokens_mask 349 | 350 | # Name may optionally contain a suffix of the form /seedN where N is an 351 | # integer. We only care about the part before the slash. 
352 | name, _, _ = module_to_name[module].partition("/") 353 | 354 | # Remember the original output shape since we'll need it for e2e training 355 | out_shape = outputs.shape 356 | 357 | # Scatter and gather the hidden states across ranks if necessary 358 | if self.cfg.distribute_modules: 359 | world_outputs = outputs.new_empty( 360 | outputs.shape[0] * dist.get_world_size(), *outputs.shape[1:] 361 | ) 362 | dist.all_gather_into_tensor(world_outputs, outputs) 363 | outputs = world_outputs 364 | 365 | # Don't bother with the communication overhead if we're autoencoding 366 | if self.cfg.sae.transcode: 367 | world_inputs = inputs.new_empty( 368 | inputs.shape[0] * dist.get_world_size(), *inputs.shape[1:] 369 | ) 370 | dist.all_gather_into_tensor(world_inputs, inputs) 371 | inputs = world_inputs 372 | 373 | world_mask = mask.new_empty( 374 | mask.shape[0] * dist.get_world_size(), *mask.shape[1:] 375 | ) 376 | dist.all_gather_into_tensor(world_mask, mask) 377 | mask = world_mask.bool() 378 | 379 | if name not in self.module_plan[dist.get_rank()]: 380 | return 381 | 382 | # Flatten the batch and sequence dimensions 383 | outputs = outputs.flatten(0, 1) 384 | inputs = inputs.flatten(0, 1) if self.cfg.sae.transcode else outputs 385 | mask = mask.flatten(0, 1) 386 | 387 | # Remove tokens not used for training 388 | all_outputs = outputs.detach().clone() 389 | outputs = outputs[mask] 390 | inputs = inputs[mask] 391 | 392 | # On the first iteration, initialize the encoder and decoder biases 393 | raw = self.saes[name] 394 | if self.global_step == 0 and not self.cfg.finetune: 395 | # Ensure the preactivations are centered at initialization 396 | # This is mathematically equivalent to Anthropic's proposal of 397 | # subtracting the decoder bias 398 | if self.cfg.sae.transcode: 399 | mean = self.maybe_all_reduce(inputs.mean(0)).to(raw.dtype) 400 | mean_image = -mean @ raw.encoder.weight.data.T 401 | raw.encoder.bias.data = mean_image 402 | 403 | mean = self.maybe_all_reduce(outputs.mean(0)) 404 | raw.b_dec.data = mean.to(raw.dtype) 405 | 406 | # Make sure the W_dec is still unit-norm if we're autoencoding 407 | if raw.cfg.normalize_decoder and not self.cfg.sae.transcode: 408 | raw.set_decoder_norm_to_unit_norm() 409 | 410 | wrapped = maybe_wrapped[name] 411 | out = wrapped( 412 | x=inputs, 413 | y=outputs, 414 | dead_mask=( 415 | self.num_tokens_since_fired[name] > self.cfg.dead_feature_threshold 416 | if self.cfg.auxk_alpha > 0 417 | else None 418 | ), 419 | ) 420 | 421 | # Update the did_fire mask 422 | did_fire[name][out.latent_indices.flatten()] = True 423 | self.maybe_all_reduce(did_fire[name], "max") # max is boolean "any" 424 | 425 | if self.cfg.loss_fn in ("ce", "kl"): 426 | # Replace the normal output with the SAE output 427 | output = all_outputs.clone() 428 | output[mask] = out.sae_out.type_as(output) 429 | output = output.reshape(out_shape) 430 | return (output, *aux_out) if aux_out is not None else output 431 | 432 | # Metrics that only make sense for local 433 | avg_fvu[name] += float(self.maybe_all_reduce(out.fvu.detach()) / denom) 434 | if self.cfg.auxk_alpha > 0: 435 | avg_auxk_loss[name] += float( 436 | self.maybe_all_reduce(out.auxk_loss.detach()) / denom 437 | ) 438 | if self.cfg.sae.multi_topk: 439 | avg_multi_topk_fvu[name] += float( 440 | self.maybe_all_reduce(out.multi_topk_fvu.detach()) / denom 441 | ) 442 | 443 | # Do a "local" backward pass if we're not training end-to-end 444 | loss = ( 445 | out.fvu + self.cfg.auxk_alpha * out.auxk_loss + out.multi_topk_fvu / 8 446 | ) 447 
| loss.div(acc_steps).backward() 448 | 449 | k = self.get_current_k() 450 | for name, sae in self.saes.items(): 451 | sae.cfg.k = k 452 | 453 | for batch in dl: 454 | x = batch["input_ids"].to(device) 455 | tokens_mask = torch.isin(x, self.exclude_tokens, invert=True) 456 | 457 | if not maybe_wrapped: 458 | # Wrap the SAEs with Distributed Data Parallel. We have to do this 459 | # after we set the decoder bias, otherwise DDP will not register 460 | # gradients flowing to the bias after the first step. 461 | maybe_wrapped = ( 462 | { 463 | name: DDP(sae, device_ids=[dist.get_rank()]) 464 | for name, sae in self.saes.items() 465 | } 466 | if ddp 467 | else self.saes 468 | ) 469 | 470 | # Bookkeeping for dead feature detection 471 | N = tokens_mask.sum().item() 472 | num_tokens_in_step += N 473 | 474 | # Compute clean logits if using KL loss 475 | clean_probs = ( 476 | self.model(x).logits.softmax(dim=-1) 477 | if self.cfg.loss_fn == "kl" 478 | else None 479 | ) 480 | 481 | # Forward pass on the model to get the next batch of activations 482 | handles = [ 483 | mod.register_forward_hook(hook) for mod in name_to_module.values() 484 | ] 485 | try: 486 | match self.cfg.loss_fn: 487 | case "ce": 488 | ce = self.model(x, labels=x).loss 489 | ce.div(acc_steps).backward() 490 | 491 | avg_ce += float(self.maybe_all_reduce(ce.detach()) / denom) 492 | 493 | avg_losses = avg_ce 494 | case "kl": 495 | dirty_lps = self.model(x).logits.log_softmax(dim=-1) 496 | kl = -torch.sum(clean_probs * dirty_lps, dim=-1).mean() 497 | kl.div(acc_steps).backward() 498 | 499 | avg_kl += float(self.maybe_all_reduce(kl) / denom) 500 | avg_losses = avg_kl 501 | case "fvu": 502 | self.model(x) 503 | avg_losses = dict(avg_fvu) 504 | case other: 505 | raise ValueError(f"Unknown loss function '{other}'") 506 | finally: 507 | for handle in handles: 508 | handle.remove() 509 | 510 | # Check if we need to actually do a training step 511 | step, substep = divmod(self.global_step + 1, self.cfg.grad_acc_steps) 512 | if substep == 0: 513 | if self.cfg.sae.normalize_decoder and not self.cfg.sae.transcode: 514 | for sae in self.saes.values(): 515 | sae.remove_gradient_parallel_to_decoder_directions() 516 | 517 | for optimizer in self.optimizers: 518 | optimizer.step() 519 | optimizer.zero_grad() 520 | 521 | for scheduler in self.lr_schedulers: 522 | scheduler.step() 523 | 524 | k = self.get_current_k() 525 | for name, sae in self.saes.items(): 526 | sae.cfg.k = k 527 | 528 | ############### 529 | with torch.no_grad(): 530 | # Update the dead feature mask 531 | for name, counts in self.num_tokens_since_fired.items(): 532 | counts += num_tokens_in_step 533 | counts[did_fire[name]] = 0 534 | 535 | # Reset stats for this step 536 | num_tokens_in_step = 0 537 | for mask in did_fire.values(): 538 | mask.zero_() 539 | 540 | if (step + 1) % self.cfg.save_every == 0: 541 | self.save() 542 | 543 | if self.cfg.save_best: 544 | self.save_best(avg_losses) 545 | 546 | if ( 547 | self.cfg.log_to_wandb 548 | and (step + 1) % self.cfg.wandb_log_frequency == 0 549 | ): 550 | info = {} 551 | if self.cfg.loss_fn == "ce": 552 | info["ce_loss"] = avg_ce 553 | elif self.cfg.loss_fn == "kl": 554 | info["kl_loss"] = avg_kl 555 | 556 | for name in self.saes: 557 | mask = ( 558 | self.num_tokens_since_fired[name] 559 | > self.cfg.dead_feature_threshold 560 | ) 561 | 562 | ratio = mask.mean(dtype=torch.float32).item() 563 | info.update({f"dead_pct/{name}": ratio}) 564 | if self.cfg.loss_fn == "fvu": 565 | info[f"fvu/{name}"] = avg_fvu[name] 566 | 567 | if 
self.cfg.auxk_alpha > 0: 568 | info[f"auxk/{name}"] = avg_auxk_loss[name] 569 | if self.cfg.sae.multi_topk: 570 | info[f"multi_topk_fvu/{name}"] = avg_multi_topk_fvu[name] 571 | 572 | if self.cfg.distribute_modules: 573 | outputs = [{} for _ in range(dist.get_world_size())] 574 | dist.gather_object(info, outputs if rank_zero else None) 575 | info.update({k: v for out in outputs for k, v in out.items()}) 576 | 577 | if rank_zero: 578 | info["k"] = k 579 | 580 | if wandb is not None: 581 | wandb.log(info, step=step) 582 | 583 | avg_auxk_loss.clear() 584 | avg_fvu.clear() 585 | avg_multi_topk_fvu.clear() 586 | avg_ce = 0.0 587 | avg_kl = 0.0 588 | 589 | self.global_step += 1 590 | pbar.update() 591 | 592 | self.save() 593 | if self.cfg.save_best: 594 | self.save_best(avg_losses) 595 | 596 | pbar.close() 597 | 598 | def local_hookpoints(self) -> list[str]: 599 | return ( 600 | self.module_plan[dist.get_rank()] 601 | if self.module_plan 602 | else self.cfg.hookpoints 603 | ) 604 | 605 | def maybe_all_cat(self, x: Tensor) -> Tensor: 606 | """Concatenate a tensor across all processes.""" 607 | if not dist.is_initialized() or self.cfg.distribute_modules: 608 | return x 609 | 610 | buffer = x.new_empty([dist.get_world_size() * x.shape[0], *x.shape[1:]]) 611 | dist.all_gather_into_tensor(buffer, x) 612 | return buffer 613 | 614 | def maybe_all_reduce(self, x: Tensor, op: str = "mean") -> Tensor: 615 | if not dist.is_initialized() or self.cfg.distribute_modules: 616 | return x 617 | 618 | if op == "sum": 619 | dist.all_reduce(x, op=dist.ReduceOp.SUM) 620 | elif op == "mean": 621 | dist.all_reduce(x, op=dist.ReduceOp.SUM) 622 | x /= dist.get_world_size() 623 | elif op == "max": 624 | dist.all_reduce(x, op=dist.ReduceOp.MAX) 625 | else: 626 | raise ValueError(f"Unknown reduction op '{op}'") 627 | 628 | return x 629 | 630 | def distribute_modules(self): 631 | """Prepare a plan for distributing modules across ranks.""" 632 | if not self.cfg.distribute_modules: 633 | self.module_plan = [] 634 | print(f"Training on modules: {self.cfg.hookpoints}") 635 | return 636 | 637 | layers_per_rank, rem = divmod(len(self.cfg.hookpoints), dist.get_world_size()) 638 | assert rem == 0, "Number of modules must be divisible by world size" 639 | 640 | # Each rank gets a subset of the layers 641 | self.module_plan = [ 642 | self.cfg.hookpoints[start : start + layers_per_rank] 643 | for start in range(0, len(self.cfg.hookpoints), layers_per_rank) 644 | ] 645 | for rank, modules in enumerate(self.module_plan): 646 | print(f"Rank {rank} modules: {modules}") 647 | 648 | def _checkpoint(self, saes: dict[str, SparseCoder], path: str, rank_zero: bool): 649 | """Save SAEs and training state to disk.""" 650 | print("Saving checkpoint") 651 | 652 | for optimizer in self.optimizers: 653 | if isinstance(optimizer, ScheduleFreeWrapper): 654 | optimizer.eval() 655 | 656 | for name, sae in saes.items(): 657 | assert isinstance(sae, SparseCoder) 658 | 659 | sae.save_to_disk(f"{path}/{name}") 660 | 661 | if rank_zero: 662 | for i, scheduler in enumerate(self.lr_schedulers): 663 | torch.save(scheduler.state_dict(), f"{path}/lr_scheduler_{i}.pt") 664 | 665 | for i, optimizer in enumerate(self.optimizers): 666 | torch.save(optimizer.state_dict(), f"{path}/optimizer_{i}.pt") 667 | 668 | torch.save( 669 | {"global_step": self.global_step}, 670 | f"{path}/state.pt", 671 | ) 672 | 673 | self.cfg.save_json(f"{path}/config.json") 674 | 675 | for optimizer in self.optimizers: 676 | if isinstance(optimizer, ScheduleFreeWrapper): 677 | 
optimizer.train() 678 | 679 | rank = 0 if rank_zero else dist.get_rank() 680 | torch.save( 681 | { 682 | "num_tokens_since_fired": self.num_tokens_since_fired, 683 | "best_loss": self.best_loss, 684 | }, 685 | f"{path}/rank_{rank}_state.pt", 686 | ) 687 | 688 | def save(self): 689 | """Save the SAEs and training state to disk.""" 690 | path = f'{self.cfg.save_dir}/{self.cfg.run_name or "unnamed"}' 691 | 692 | rank_zero = not dist.is_initialized() or dist.get_rank() == 0 693 | 694 | if rank_zero or self.cfg.distribute_modules: 695 | self._checkpoint(self.saes, path, rank_zero) 696 | 697 | # Barrier to ensure all ranks have saved before continuing 698 | if dist.is_initialized(): 699 | dist.barrier() 700 | 701 | def save_best(self, avg_loss: float | dict[str, float]): 702 | """Save individual sparse coders to disk if they have the lowest loss.""" 703 | base_path = f'{self.cfg.save_dir}/{self.cfg.run_name or "unnamed"}/best' 704 | rank_zero = not dist.is_initialized() or dist.get_rank() == 0 705 | 706 | if isinstance(avg_loss, dict): 707 | for name in self.saes: 708 | if avg_loss[name] < self.best_loss[name]: # type: ignore 709 | self.best_loss[name] = avg_loss[name] # type: ignore 710 | 711 | if rank_zero or self.cfg.distribute_modules: 712 | self._checkpoint( 713 | {name: self.saes[name]}, f"{base_path}/{name}", rank_zero 714 | ) 715 | else: 716 | if avg_loss < self.best_loss: # type: ignore 717 | self.best_loss = avg_loss # type: ignore 718 | 719 | if rank_zero or self.cfg.distribute_modules: 720 | self._checkpoint(self.saes, base_path, rank_zero) 721 | 722 | # Barrier to ensure all ranks have saved before continuing 723 | if dist.is_initialized(): 724 | dist.barrier() 725 | 726 | 727 | # Support old name for compatibility 728 | SaeTrainer = Trainer 729 | --------------------------------------------------------------------------------