├── .cursorrules ├── .github └── workflows │ ├── checks.yml │ └── publish.yml ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── README.md ├── examples ├── generate_pythia_acts.py ├── loading_llamascope_saes.ipynb ├── replicate_llamascope_sae.sh ├── train_pythia.py └── train_pythia_pre_generated_acts.py ├── pyproject.toml ├── server ├── .env.example ├── __init__.py └── app.py ├── src └── lm_saes │ ├── __init__.py │ ├── activation │ ├── __init__.py │ ├── factory.py │ ├── processors │ │ ├── __init__.py │ │ ├── activation.py │ │ ├── cached_activation.py │ │ ├── core.py │ │ ├── huggingface.py │ │ └── token.py │ └── writer.py │ ├── analysis │ ├── __init__.py │ ├── auto_interp.py │ ├── feature_analyzer.py │ └── features_to_logits.py │ ├── circuit │ ├── __init__.py │ └── context.py │ ├── config.py │ ├── crosscoder.py │ ├── database.py │ ├── entrypoint.py │ ├── evaluator.py │ ├── initializer.py │ ├── kernels.py │ ├── mixcoder.py │ ├── optim.py │ ├── resource_loaders.py │ ├── runner.py │ ├── sae.py │ ├── trainer.py │ └── utils │ ├── __init__.py │ ├── bytes.py │ ├── concurrent.py │ ├── config.py │ ├── convert_pre_enc_bias.py │ ├── discrete.py │ ├── hooks.py │ ├── huggingface.py │ ├── math.py │ ├── misc.py │ └── tensor_dict.py ├── tests ├── __init__.py ├── integration │ ├── test_activation_factory.py │ └── test_train_sae.py └── unit │ ├── test_activation_processors.py │ ├── test_activation_writer.py │ ├── test_cached_activation.py │ ├── test_concurrent.py │ ├── test_database.py │ ├── test_discrete_mapper.py │ ├── test_evaluator.py │ ├── test_example.py │ ├── test_feature_analyzer.py │ ├── test_initializer.py │ ├── test_misc.py │ ├── test_mixcoder.py │ └── test_sae.py ├── ui ├── .env.example ├── .eslintrc.cjs ├── .gitignore ├── .prettierrc.cjs ├── README.md ├── bun.lockb ├── components.json ├── index.html ├── package.json ├── postcss.config.js ├── public │ ├── openmoss.ico │ └── vite.svg ├── src │ ├── components │ │ ├── app │ │ │ ├── feature-preview.tsx │ │ │ ├── navbar.tsx │ │ │ ├── sample.tsx │ │ │ ├── section-navigator.tsx │ │ │ └── token.tsx │ │ ├── attn-head │ │ │ └── attn-head-card.tsx │ │ ├── dictionary │ │ │ ├── dictionary-card.tsx │ │ │ └── sample.tsx │ │ ├── feature │ │ │ ├── feature-card.tsx │ │ │ ├── interpret.tsx │ │ │ └── sample.tsx │ │ ├── model │ │ │ ├── circuit.tsx │ │ │ └── model-card.tsx │ │ └── ui │ │ │ ├── accordion.tsx │ │ │ ├── badge.tsx │ │ │ ├── button.tsx │ │ │ ├── card.tsx │ │ │ ├── combobox.tsx │ │ │ ├── command.tsx │ │ │ ├── context-menu.tsx │ │ │ ├── data-table.tsx │ │ │ ├── dialog.tsx │ │ │ ├── dropdown-menu.tsx │ │ │ ├── hover-card.tsx │ │ │ ├── input.tsx │ │ │ ├── label.tsx │ │ │ ├── multiple-selector.tsx │ │ │ ├── pagination.tsx │ │ │ ├── popover.tsx │ │ │ ├── select.tsx │ │ │ ├── separator.tsx │ │ │ ├── switch.tsx │ │ │ ├── table.tsx │ │ │ ├── tabs.tsx │ │ │ ├── textarea.tsx │ │ │ ├── toggle.tsx │ │ │ └── tooltip.tsx │ ├── globals.css │ ├── lib │ │ └── utils.ts │ ├── main.tsx │ ├── routes │ │ ├── attn-heads │ │ │ └── page.tsx │ │ ├── dictionaries │ │ │ └── page.tsx │ │ ├── features │ │ │ └── page.tsx │ │ ├── models │ │ │ └── page.tsx │ │ └── page.tsx │ ├── tanstack.d.ts │ ├── types │ │ ├── attn-head.ts │ │ ├── dictionary.ts │ │ ├── feature.ts │ │ └── model.ts │ ├── utils │ │ ├── array.ts │ │ ├── style.ts │ │ └── token.ts │ └── vite-env.d.ts ├── tailwind.config.js ├── tsconfig.json ├── tsconfig.node.json └── vite.config.ts └── uv.lock /.cursorrules: -------------------------------------------------------------------------------- 1 | You are an expert in 
developing deep learning models in PyTorch. 2 | 3 | Key principles: 4 | 5 | - Add precise type hints and docstrings to all functions and classes. Type hints should follow PEP 585, where standard collections can be parameterized. In other words, use `list[int]` instead of `List[int]`. 6 | - When writing tests, use `pytest` and `pytest-mock` to write tests. Use `mocker` for mocking. 7 | - Current project use `uv` for dependency management. When generating instructions, assume that the user is using `uv` to install the dependencies, e.g. `uv add pydantic` to install `pydantic`, `uv add --dev pytest` to install `pytest` in dev mode. 8 | - When writing docstrings of `dataclass` or pydantic models, write field descriptions right after the field. -------------------------------------------------------------------------------- /.github/workflows/checks.yml: -------------------------------------------------------------------------------- 1 | name: Checks 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - dev 8 | paths: 9 | - "**" # Include all files by default 10 | - "!.devcontainer/**" 11 | - "!.vscode/**" 12 | - "!.git*" 13 | - "!*.md" 14 | - "!.github/**" 15 | - ".github/workflows/checks.yml" # Still include current workflow 16 | pull_request: 17 | branches: 18 | - main 19 | - dev 20 | paths: 21 | - "**" 22 | - "!.devcontainer/**" 23 | - "!.vscode/**" 24 | - "!.git*" 25 | - "!*.md" 26 | - "!.github/**" 27 | - ".github/workflows/checks.yml" 28 | # Allow this workflow to be called from other workflows 29 | workflow_call: 30 | inputs: 31 | # Requires at least one input to be valid, but in practice we don't need any 32 | dummy: 33 | type: string 34 | required: false 35 | 36 | permissions: 37 | actions: write 38 | contents: write 39 | 40 | jobs: 41 | code-checks: 42 | name: Code Checks 43 | runs-on: ubuntu-latest 44 | steps: 45 | - name: Checkout repository 46 | uses: actions/checkout@v4 47 | with: 48 | submodules: "true" 49 | - name: Install uv 50 | uses: astral-sh/setup-uv@v5 51 | with: 52 | enable-cache: true 53 | cache-dependency-glob: "uv.lock" 54 | 55 | - name: Install the project 56 | run: uv sync --all-extras --dev 57 | 58 | - name: Type check 59 | run: uv run basedpyright . 60 | 61 | - name: Unit tests 62 | run: uv run pytest tests 63 | 64 | ruff: 65 | runs-on: ubuntu-latest 66 | steps: 67 | - uses: actions/checkout@v4 68 | - uses: astral-sh/ruff-action@v1 69 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: "Publish" 2 | 3 | on: 4 | release: 5 | types: ["published"] 6 | 7 | jobs: 8 | run: 9 | name: "Build and publish release" 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - uses: actions/checkout@v4 14 | 15 | - name: Install uv 16 | uses: astral-sh/setup-uv@v3 17 | with: 18 | enable-cache: true 19 | cache-dependency-glob: uv.lock 20 | 21 | - name: Set up Python 22 | run: uv python install 3.12 # Or whatever version I want to use. 
23 | 24 | - name: Build 25 | run: uv build 26 | 27 | - name: Publish 28 | run: uv publish -t ${{ secrets.PYPI_TOKEN }} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/python 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=python 3 | 4 | ### Python ### 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/#use-with-ide 114 | .pdm.toml 115 | .pdm-python 116 | 117 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 165 | #.idea/ 166 | 167 | ### Python Patch ### 168 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 169 | poetry.toml 170 | 171 | # ruff 172 | .ruff_cache/ 173 | 174 | # LSP config files 175 | pyrightconfig.json 176 | 177 | # End of https://www.toptal.com/developers/gitignore/api/python 178 | 179 | # UI 180 | !ui/**/* 181 | 182 | # Custom 183 | /sync.sh 184 | /connect.sh 185 | /activations 186 | /data 187 | /legacy 188 | /run.py 189 | /checkpoints 190 | /wandb 191 | /exp 192 | /analysis-results 193 | /analysis 194 | /results -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "TransformerLens"] 2 | path = TransformerLens 3 | url = git@github.com:OpenMOSS/TransformerLens.git -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | # Ruff version. 4 | rev: v0.7.1 5 | hooks: 6 | # Run the linter. 7 | - id: ruff 8 | args: [--fix] 9 | # Run the formatter. 10 | - id: ruff-format 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Language-Model-SAEs 2 | 3 | ## News 4 | 5 | - 2024.10.29 We introduce **Llama Scope**, our first contribution to the open-source Sparse Autoencoder ecosystem. Stay tuned! Link: [Llama Scope: Extracting Millions of Features from Llama-3.1-8B with Sparse Autoencoders](http://arxiv.org/abs/2410.20526) 6 | 7 | - 2024.10.9 Transformers and Mambas are mechanistically similar in both feature and circuit level. Can we follow this line and find **universal motifs and fundamental differences between language model architectures**? Link: [Towards Universality: Studying Mechanistic Similarity Across Language Model Architectures](https://arxiv.org/pdf/2410.06672) 8 | 9 | - 2024.5.22 We propose hierarchical tracing, a promising method to **scale up sparse feature circuit analysis** to industrial size language models! 
Link: [Automatically Identifying Local and Global Circuits with Linear Computation Graphs](https://arxiv.org/pdf/2405.13868) 10 | 11 | - 2024.2.19 Our first attempt on SAE-based circuit analysis for Othello-GPT leads us to **an example of Attention Superposition in the wild**! Link: [Dictionary learning improves patch-free circuit discovery in mechanistic interpretability: A case study on othello-gpt](https://arxiv.org/pdf/2402.12201). 12 | 13 | ## Installation 14 | 15 | Currently, the codebase use [pdm](https://pdm-project.org/) to manage the dependencies, which is an alternative to [poetry](https://python-poetry.org/). To install the required packages, just install `pdm`, and run the following command: 16 | 17 | ```bash 18 | pdm install 19 | ``` 20 | 21 | This will install all the required packages for the core codebase. Note that if you're in a conda environment, `pdm` will directly take the current environment as the virtual environment for current project, and remove all the packages that are not in the `pyproject.toml` file. So make sure to create a new conda environment (or just deactivate conda, this will use virtualenv by default) before running the above command. A forked version of `TransformerLens` is also included in the dependencies to provide the necessary tools for analyzing features. 22 | 23 | If you want to use the visualization tools, you also need to install the required packages for the frontend, which uses [bun](https://bun.sh/) for dependency management. Follow the instructions on the website to install it, and then run the following command: 24 | 25 | ```bash 26 | cd ui 27 | bun install 28 | ``` 29 | 30 | `bun` is not well-supported on Windows, so you may need to use WSL or other Linux-based solutions to run the frontend, or consider using a different package manager, such as `pnpm` or `yarn`. 31 | 32 | ## Launch an Experiment 33 | 34 | We provide both a programmatic and a configuration-based way to launch an experiment. The configuration-based way is more flexible and recommended for most users. You can find the configuration files in the [examples/configuration](https://github.com/OpenMOSS/Language-Model-SAEs/tree/main/examples/configuration) directory, and modify them to fit your needs. The programmatic way is more suitable for advanced users who want to customize the training process, and you can find the example scripts in the [examples/programmatic](https://github.com/OpenMOSS/Language-Model-SAEs/tree/main/examples/programmatic) directory. 35 | 36 | To simply begin a training process, you can run the following command: 37 | 38 | ```bash 39 | lm-saes train examples/configuration/train.toml 40 | ``` 41 | 42 | which will start the training process using the configuration file [examples/configuration/train.toml](https://github.com/OpenMOSS/Language-Model-SAEs/tree/main/examples/configuration/train.toml). 43 | 44 | To analyze a trained dictionary, you can run the following command: 45 | 46 | ```bash 47 | lm-saes analyze examples/configuration/analyze.toml --sae 48 | ``` 49 | 50 | which will start the analysis process using the configuration file [examples/configuration/analyze.toml](https://github.com/OpenMOSS/Language-Model-SAEs/tree/main/examples/configuration/analyze.toml). The analysis process requires a trained SAE model, which can be obtained from the training process. You may need launch a MongoDB server to store the analysis results, and you can modify the MongoDB settings in the configuration file. 
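If you don't already have a MongoDB instance available, a quick way to bring one up locally is via Docker (a minimal sketch; the container name, port, and data directory are arbitrary and should match the MongoDB settings in your configuration file):

```bash
# Start a local MongoDB instance on the default port 27017
docker run -d --name lm-saes-mongo -p 27017:27017 -v "$(pwd)/mongo-data:/data/db" mongo:7
```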
51 | 52 | Generally, our configuration-based pipeline uses outer layer settings as default of the inner layer settings. This is beneficial for easily building deeply nested configurations, where sub-configurations can be reused (such as device and dtype settings). More detail will be provided future. 53 | 54 | ## Visualizing the Learned Dictionary 55 | 56 | The analysis results will be saved using MongoDB, and you can use the provided visualization tools to visualize the learned dictionary. First, start the FastAPI server by running the following command: 57 | 58 | ```bash 59 | uvicorn server.app:app --port 24577 --env-file server/.env 60 | ``` 61 | 62 | Then, copy the `ui/.env.example` file to `ui/.env` and modify the `VITE_BACKEND_URL` to fit your server settings (by default, it's `http://localhost:24577`), and start the frontend by running the following command: 63 | 64 | ```bash 65 | cd ui 66 | bun dev --port 24576 67 | ``` 68 | 69 | That's it! You can now go to `http://localhost:24576` to visualize the learned dictionary and its features. 70 | 71 | ## Development 72 | 73 | We highly welcome contributions to this project. If you have any questions or suggestions, feel free to open an issue or a pull request. We are looking forward to hearing from you! 74 | 75 | TODO: Add development guidelines 76 | 77 | ## Acknowledgement 78 | 79 | The design of the pipeline (including the configuration and some training details) is highly inspired by the [mats_sae_training 80 | ](https://github.com/jbloomAus/mats_sae_training) project (now known as [SAELens](https://github.com/jbloomAus/SAELens)) and heavily relies on the [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens) library. We thank the authors for their great work. 81 | 82 | ## Citation 83 | 84 | Please cite this library as: 85 | 86 | ``` 87 | @misc{Ge2024OpenMossSAEs, 88 | title = {OpenMoss Language Model Sparse Autoencoders}, 89 | author = {Xuyang Ge, Fukang Zhu, Junxuan Wang, Wentao Shu, Lingjie Chen, Zhengfu He}, 90 | url = {https://github.com/OpenMOSS/Language-Model-SAEs}, 91 | year = {2024} 92 | } 93 | ``` 94 | -------------------------------------------------------------------------------- /examples/generate_pythia_acts.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from lm_saes import ( 4 | ActivationFactoryTarget, 5 | DatasetConfig, 6 | GenerateActivationsSettings, 7 | LanguageModelConfig, 8 | generate_activations, 9 | ) 10 | 11 | if __name__ == "__main__": 12 | settings = GenerateActivationsSettings( 13 | model=LanguageModelConfig( 14 | model_name="EleutherAI/pythia-160m", 15 | device="cuda", 16 | dtype="torch.float32", 17 | ), 18 | model_name="pythia-160m", 19 | dataset=DatasetConfig( 20 | dataset_name_or_path="Skylion007/openwebtext", 21 | is_dataset_on_disk=True, 22 | ), 23 | dataset_name="openwebtext", 24 | hook_points=[f"blocks.{layer}.ln1.hook_normalized" for layer in range(12)], 25 | output_dir="activations", 26 | total_tokens=800_000_000, 27 | context_size=1024, 28 | n_samples_per_chunk=1, 29 | model_batch_size=32, 30 | target=ActivationFactoryTarget.BATCHED_ACTIVATIONS_1D, 31 | batch_size=2048 * 16, 32 | buffer_size=2048 * 200, 33 | ) 34 | generate_activations(settings) 35 | torch.distributed.destroy_process_group() 36 | -------------------------------------------------------------------------------- /examples/replicate_llamascope_sae.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | 
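# Usage:
#   bash examples/replicate_llamascope_sae.sh <exp_factor> <tc_in_abbr> <layer>
# The expansion factor (8, 32, or 128) selects the total training tokens, learning rate,
# and tensor-parallel size below; <tc_in_abbr> and <layer> are forwarded to
# ./examples/programmatic/train_llama_scope.py, which defines the accepted values.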
3 | 4 | 5 | exp_factor=$1 6 | tc_in_abbr=$2 7 | layer=$3 8 | 9 | tc_out_abbr=$tc_in_abbr 10 | 11 | if [ "$exp_factor" -eq 128 ]; then 12 | tp_size=2 13 | else 14 | tp_size=1 15 | fi 16 | 17 | k=50 18 | 19 | 20 | if [ "$exp_factor" -eq 8 ]; then 21 | total_training_tokens=800000000 22 | lr=8e-4 23 | elif [ "$exp_factor" -eq 32 ]; then 24 | total_training_tokens=1600000000 25 | lr=8e-4 26 | elif [ "$exp_factor" -eq 128 ]; then 27 | total_training_tokens=3200000000 28 | lr=2e-4 29 | fi 30 | 31 | if [ "$tc_in_abbr" = "TC" ]; then 32 | tc_out_abbr="M" 33 | fi 34 | 35 | WANDB_CONSOLE=off WANDB_MODE=offline torchrun --nproc_per_node=$tp_size --master_port=10110 ./examples/programmatic/train_llama_scope.py --total_training_tokens $total_training_tokens --layer $layer --lr $lr --clip_grad_norm 0.001 --exp_factor $exp_factor --batch_size 2048 --tp_size $tp_size --buffer_size 500000 --log_to_wandb false --store_batch_size 32 --k $k --tc_in_abbr $tc_in_abbr --tc_out_abbr $tc_out_abbr 36 | -------------------------------------------------------------------------------- /examples/train_pythia.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from lm_saes import ( 4 | ActivationFactoryConfig, 5 | ActivationFactoryDatasetSource, 6 | ActivationFactoryTarget, 7 | DatasetConfig, 8 | InitializerConfig, 9 | LanguageModelConfig, 10 | SAEConfig, 11 | TrainerConfig, 12 | TrainSAESettings, 13 | WandbConfig, 14 | train_sae, 15 | ) 16 | 17 | if __name__ == "__main__": 18 | settings = TrainSAESettings( 19 | sae=SAEConfig( 20 | hook_point_in="blocks.3.ln1.hook_normalized", 21 | hook_point_out="blocks.3.ln1.hook_normalized", 22 | d_model=768, 23 | expansion_factor=8, 24 | act_fn="topk", 25 | norm_activation="token-wise", 26 | sparsity_include_decoder_norm=True, 27 | top_k=50, 28 | dtype=torch.float32, 29 | device="cuda", 30 | ), 31 | initializer=InitializerConfig( 32 | init_search=True, 33 | state="training", 34 | ), 35 | trainer=TrainerConfig( 36 | lp=1, 37 | initial_k=768 / 2, 38 | lr=4e-4, 39 | lr_scheduler_name="constantwithwarmup", 40 | total_training_tokens=600_000_000, 41 | log_frequency=1000, 42 | eval_frequency=1000000, 43 | n_checkpoints=5, 44 | check_point_save_mode="linear", 45 | exp_result_path="results", 46 | ), 47 | wandb=WandbConfig( 48 | wandb_project="pythia-160m-test", 49 | exp_name="pythia-160m-test", 50 | ), 51 | activation_factory=ActivationFactoryConfig( 52 | sources=[ 53 | ActivationFactoryDatasetSource( 54 | name="openwebtext", 55 | ) 56 | ], 57 | target=ActivationFactoryTarget.BATCHED_ACTIVATIONS_1D, 58 | hook_points=["blocks.3.ln1.hook_normalized"], 59 | batch_size=2048, 60 | buffer_size=None, 61 | ignore_token_ids=[], 62 | ), 63 | sae_name="pythia-160m-test-L3", 64 | sae_series="pythia-160m-test", 65 | model=LanguageModelConfig( 66 | model_name="EleutherAI/pythia-160m", 67 | device="cuda", 68 | dtype="torch.float32", 69 | ), 70 | model_name="pythia-160m", 71 | datasets={ 72 | "openwebtext": DatasetConfig( 73 | dataset_name_or_path="Skylion007/openwebtext", 74 | ) 75 | }, 76 | ) 77 | train_sae(settings) 78 | -------------------------------------------------------------------------------- /examples/train_pythia_pre_generated_acts.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from lm_saes import ( 4 | ActivationFactoryActivationsSource, 5 | ActivationFactoryConfig, 6 | ActivationFactoryTarget, 7 | InitializerConfig, 8 | SAEConfig, 9 | TrainerConfig, 10 | TrainSAESettings, 
11 | WandbConfig, 12 | train_sae, 13 | ) 14 | 15 | if __name__ == "__main__": 16 | settings = TrainSAESettings( 17 | sae=SAEConfig( 18 | hook_point_in="blocks.3.ln1.hook_normalized", 19 | hook_point_out="blocks.3.ln1.hook_normalized", 20 | d_model=768, 21 | expansion_factor=8, 22 | act_fn="topk", 23 | norm_activation="token-wise", 24 | sparsity_include_decoder_norm=True, 25 | top_k=50, 26 | dtype=torch.float32, 27 | device="cuda", 28 | ), 29 | initializer=InitializerConfig( 30 | init_search=True, 31 | state="training", 32 | ), 33 | trainer=TrainerConfig( 34 | lp=1, 35 | initial_k=768 / 2, 36 | lr=4e-4, 37 | lr_scheduler_name="constantwithwarmup", 38 | total_training_tokens=600_000_000, 39 | log_frequency=1000, 40 | eval_frequency=1000000, 41 | n_checkpoints=5, 42 | check_point_save_mode="linear", 43 | exp_result_path="results", 44 | ), 45 | wandb=WandbConfig( 46 | wandb_project="pythia-160m-test", 47 | exp_name="pythia-160m-test", 48 | ), 49 | activation_factory=ActivationFactoryConfig( 50 | sources=[ 51 | ActivationFactoryActivationsSource( 52 | name="openwebtext", 53 | path="activations", 54 | sample_weights=1.0, 55 | device="cuda", 56 | ) 57 | ], 58 | target=ActivationFactoryTarget.BATCHED_ACTIVATIONS_1D, 59 | hook_points=["blocks.3.ln1.hook_normalized"], 60 | batch_size=2048, 61 | buffer_size=None, 62 | ignore_token_ids=[], 63 | ), 64 | sae_name="pythia-160m-test-L3", 65 | sae_series="pythia-160m-test", 66 | ) 67 | train_sae(settings) 68 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "lm-saes" 3 | version = "0.1.0" 4 | description = "For OpenMOSS Mechanistic Interpretability Team's Sparse Autoencoder (SAE) research. Open-sourced and constantly updated." 
5 | dependencies = [ 6 | "datasets>=3.0.2", 7 | "transformers>=4.46.0", 8 | "einops>=0.8.0", 9 | "fastapi>=0.115.4", 10 | "matplotlib>=3.9.2", 11 | "numpy>=2.1.2", 12 | "pandas>=2.2.3", 13 | "pymongo>=4.10.1", 14 | "tensorboardX>=2.6.2.2", 15 | "torch>=2.5.0", 16 | "tqdm>=4.66.5", 17 | "transformer-lens", 18 | "uvicorn>=0.32.0", 19 | "wandb>=0.18.5", 20 | "msgpack>=1.1.0", 21 | "plotly>=5.24.1", 22 | "openai>=1.52.2", 23 | "tiktoken>=0.8.0", 24 | "python-dotenv>=1.0.1", 25 | "jaxtyping>=0.2.34", 26 | "safetensors>=0.4.5", 27 | "pydantic>=2.10.6", 28 | "argparse>=1.4.0", 29 | "pyyaml>=6.0.2", 30 | "tomlkit>=0.13.2", 31 | "torchvision>=0.20.1", 32 | "pydantic-settings>=2.7.1", 33 | ] 34 | requires-python = "==3.12.*" 35 | readme = "README.md" 36 | 37 | [[project.authors]] 38 | name = "Xuyang Ge" 39 | email = "xyge20@fudan.edu.cn" 40 | 41 | [[project.authors]] 42 | name = "Zhengfu He" 43 | email = "zfhe19@fudan.edu.cn" 44 | 45 | [[project.authors]] 46 | name = "Wentao Shu" 47 | email = "wtshu20@fudan.edu.cn" 48 | 49 | [[project.authors]] 50 | name = "Fukang Zhu" 51 | email = "fkzhu21@m.fudan.edu.cn" 52 | 53 | [[project.authors]] 54 | name = "Lingjie Chen" 55 | email = "ljchen21@m.fudan.edu.cn" 56 | 57 | [[project.authors]] 58 | name = "Junxuan Wang" 59 | email = "junxuanwang21@m.fudan.edu.cn" 60 | 61 | [project.license] 62 | text = "MIT" 63 | 64 | [project.scripts] 65 | lm-saes = "lm_saes.entrypoint:entrypoint" 66 | 67 | [dependency-groups] 68 | dev = [ 69 | "transformer-lens", 70 | "jupyter>=1.1.1", 71 | "ipywidgets>=8.1.5", 72 | "pytest>=8.3.3", 73 | "ipykernel>=6.29.5", 74 | "nbformat>=5.10.4", 75 | "kaleido==0.2.1", 76 | "pre-commit>=4.0.1", 77 | "ruff>=0.7.1", 78 | "basedpyright>=1.21.0", 79 | "scikit-learn>=1.6.0", 80 | "plotly>=5.24.1", 81 | "pandas>=2.2.3", 82 | "pytest-mock>=3.14.0", 83 | "typeguard>=4.4.1", 84 | "pyfakefs>=5.7.3", 85 | "mongomock>=4.3.0", 86 | "qwen-vl-utils>=0.0.10", 87 | ] 88 | triton = [ 89 | "triton>=3.1.0", 90 | ] 91 | 92 | [tool.ruff] 93 | exclude = [".bzr", ".direnv", ".eggs", ".git", ".git-rewrite", ".hg", ".ipynb_checkpoints", ".mypy_cache", ".nox", ".pants.d", ".pyenv", ".pytest_cache", ".pytype", ".ruff_cache", ".svn", ".tox", ".venv", ".vscode", "__pypackages__", "_build", "buck-out", "build", "dist", "node_modules", "site-packages", "venv", "TransformerLens", "ui"] 94 | line-length = 120 95 | indent-width = 4 96 | target-version = "py310" 97 | 98 | [tool.ruff.lint] 99 | select = ["E4", "E7", "E9", "F", "I"] 100 | ignore = ["E741", "F722"] 101 | fixable = ["ALL"] 102 | unfixable = [] 103 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 104 | 105 | [tool.ruff.format] 106 | quote-style = "double" 107 | indent-style = "space" 108 | skip-magic-trailing-comma = false 109 | line-ending = "auto" 110 | docstring-code-format = false 111 | docstring-code-line-length = "dynamic" 112 | 113 | [tool.pyright] 114 | ignore = [".venv/", "examples", "TransformerLens", "tests", "exp"] 115 | typeCheckingMode = "standard" 116 | reportRedeclaration = false 117 | 118 | [tool.uv] 119 | package = true 120 | 121 | [tool.uv.sources.transformer-lens] 122 | path = "./TransformerLens" 123 | editable = true 124 | -------------------------------------------------------------------------------- /server/.env.example: -------------------------------------------------------------------------------- 1 | MONGO_URI= # Must fill in 2 | SAE_SERIES= # Must fill in -------------------------------------------------------------------------------- /server/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenMOSS/Language-Model-SAEs/6638b73672bae0873a3ac185e468d9adcb8e2282/server/__init__.py -------------------------------------------------------------------------------- /src/lm_saes/__init__.py: -------------------------------------------------------------------------------- 1 | from .activation import ActivationFactory, ActivationWriter 2 | from .analysis import FeatureAnalyzer 3 | from .config import ( 4 | ActivationFactoryActivationsSource, 5 | ActivationFactoryConfig, 6 | ActivationFactoryDatasetSource, 7 | ActivationFactoryTarget, 8 | ActivationWriterConfig, 9 | BufferShuffleConfig, 10 | CrossCoderConfig, 11 | DatasetConfig, 12 | FeatureAnalyzerConfig, 13 | InitializerConfig, 14 | LanguageModelConfig, 15 | MixCoderConfig, 16 | MongoDBConfig, 17 | SAEConfig, 18 | TrainerConfig, 19 | WandbConfig, 20 | ) 21 | from .crosscoder import CrossCoder 22 | from .database import MongoClient 23 | from .resource_loaders import load_dataset, load_model 24 | from .runner import ( 25 | AnalyzeSAESettings, 26 | GenerateActivationsSettings, 27 | TrainSAESettings, 28 | analyze_sae, 29 | generate_activations, 30 | train_sae, 31 | ) 32 | from .sae import SparseAutoEncoder 33 | 34 | __all__ = [ 35 | "ActivationFactory", 36 | "ActivationWriter", 37 | "CrossCoderConfig", 38 | "CrossCoder", 39 | "SparseAutoEncoder", 40 | "LanguageModelConfig", 41 | "DatasetConfig", 42 | "ActivationFactoryActivationsSource", 43 | "ActivationFactoryDatasetSource", 44 | "ActivationFactoryConfig", 45 | "ActivationWriterConfig", 46 | "BufferShuffleConfig", 47 | "ActivationFactoryTarget", 48 | "load_dataset", 49 | "load_model", 50 | "FeatureAnalyzer", 51 | "GenerateActivationsSettings", 52 | "generate_activations", 53 | "InitializerConfig", 54 | "SAEConfig", 55 | "TrainerConfig", 56 | "WandbConfig", 57 | "train_sae", 58 | "TrainSAESettings", 59 | "AnalyzeSAESettings", 60 | "analyze_sae", 61 | "FeatureAnalyzerConfig", 62 | "MongoDBConfig", 63 | "MongoClient", 64 | "MixCoderConfig", 65 | ] 66 | -------------------------------------------------------------------------------- /src/lm_saes/activation/__init__.py: -------------------------------------------------------------------------------- 1 | from .factory import ActivationFactory 2 | from .processors import ( 3 | ActivationBatchler, 4 | BaseActivationProcessor, 5 | HuggingFaceDatasetLoader, 6 | PadAndTruncateTokensProcessor, 7 | RawDatasetTokenProcessor, 8 | ) 9 | from .writer import ActivationWriter 10 | 11 | __all__ = [ 12 | "ActivationFactory", 13 | "BaseActivationProcessor", 14 | "ActivationBatchler", 15 | "HuggingFaceDatasetLoader", 16 | "PadAndTruncateTokensProcessor", 17 | "RawDatasetTokenProcessor", 18 | "ActivationWriter", 19 | ] 20 | -------------------------------------------------------------------------------- /src/lm_saes/activation/processors/__init__.py: -------------------------------------------------------------------------------- 1 | from .activation import ActivationBatchler 2 | from .cached_activation import CachedActivationLoader 3 | from .core import BaseActivationProcessor 4 | from .huggingface import HuggingFaceDatasetLoader 5 | from .token import PadAndTruncateTokensProcessor, RawDatasetTokenProcessor 6 | 7 | __all__ = [ 8 | "BaseActivationProcessor", 9 | "ActivationBatchler", 10 | "HuggingFaceDatasetLoader", 11 | "PadAndTruncateTokensProcessor", 12 | "RawDatasetTokenProcessor", 13 | "CachedActivationLoader", 14 | ] 15 | 
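The processors exported above are meant to be chained: each stage consumes the iterable produced by the previous one through the shared `process`/`__call__` interface defined in `core.py` below. A minimal sketch of such a chain, assuming `model` is a `HookedTransformer` already loaded (e.g. via `lm_saes.load_model`); the dataset name and hyperparameters are placeholders, and in practice the higher-level `ActivationFactory` assembles this wiring from a configuration:

```python
from datasets import load_dataset

from lm_saes.activation.processors import (
    HuggingFaceDatasetLoader,
    PadAndTruncateTokensProcessor,
    RawDatasetTokenProcessor,
)

dataset = load_dataset("Skylion007/openwebtext", split="train")

# Stream raw records from the dataset, attaching a context index to each record.
records = HuggingFaceDatasetLoader(batch_size=32, with_info=True)(dataset, dataset_name="openwebtext")
# Convert each raw record into (non-padded, non-truncated) tokens with the model's tokenizer.
tokens = RawDatasetTokenProcessor()(records, model=model)
# Pad or truncate every token sequence to a fixed context size.
padded = PadAndTruncateTokensProcessor(seq_len=1024)(tokens, model=model)
```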
-------------------------------------------------------------------------------- /src/lm_saes/activation/processors/core.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Generic, Iterable, TypeVar 3 | 4 | # Define type variables for input and output iterable containers 5 | InputT = TypeVar("InputT", bound=Iterable | None) 6 | OutputT = TypeVar("OutputT", bound=Iterable) 7 | 8 | 9 | class BaseActivationProcessor(Generic[InputT, OutputT], ABC): 10 | """Base class for activation processors. 11 | 12 | An activation processor transforms a stream of input data into a stream of output data. 13 | Common use cases include batching, filtering, and transforming activation data from language models. 14 | 15 | The processor can be used either by calling its `process()` method directly or by using the instance 16 | as a callable via `__call__()`. 17 | 18 | TypeVars: 19 | InputT: The type of input iterable container (e.g. List, Generator, Dataset) 20 | OutputT: The type of output iterable container (e.g. DataLoader, Generator) 21 | """ 22 | 23 | @abstractmethod 24 | def process(self, data: InputT, **kwargs) -> OutputT: 25 | """Process the input data stream and return transformed output stream. 26 | 27 | Args: 28 | data: Input data stream to process 29 | **kwargs: Additional keyword arguments for processing 30 | 31 | Returns: 32 | Processed output data stream 33 | 34 | Raises: 35 | NotImplementedError: This is an abstract method that must be implemented by subclasses 36 | """ 37 | raise NotImplementedError 38 | 39 | def __call__(self, data: InputT, **kwargs) -> OutputT: 40 | """Process data by calling the processor instance directly. 41 | 42 | This is a convenience wrapper around the process() method. 43 | 44 | Args: 45 | data: Input data stream to process 46 | **kwargs: Additional keyword arguments for processing 47 | Returns: 48 | Processed output data stream 49 | """ 50 | return self.process(data, **kwargs) 51 | -------------------------------------------------------------------------------- /src/lm_saes/activation/processors/huggingface.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from typing import Any, Iterable, Optional, cast 3 | 4 | import torch 5 | from datasets import Dataset 6 | from torch.utils.data import DataLoader 7 | from tqdm import tqdm 8 | 9 | from lm_saes.activation.processors.core import BaseActivationProcessor 10 | 11 | 12 | def identity_collate_fn(x): 13 | """Identity collate function that returns the input as is. 14 | 15 | This should be defined in the global scope so it can be pickled when multiprocessing with `num_workers > 0`. 16 | """ 17 | return x 18 | 19 | 20 | class HuggingFaceDatasetLoader(BaseActivationProcessor[Dataset, Iterable[dict[str, Any]]]): 21 | """A processor that directly loads raw dataset from HuggingFace datasets. 22 | 23 | This processor takes a HuggingFace dataset and converts it into an iterable of raw data records. 24 | It can optionally add context index information to each data record and show a progress bar. 25 | 26 | Args: 27 | batch_size (int): Number of samples per batch. The batch is only used for loading the dataset. 28 | The returned data is always a flattened list of raw data records. 29 | num_workers (int, optional): Number of workers to use for loading the dataset. 30 | Defaults to 0. Use a larger `num_workers` if you have a large dataset and want to speed up loading. 
31 | with_info (bool, optional): Whether to include context index information with each batch. 32 | If True, returns tuples of (batch, info) where info contains context indices. This context 33 | index is used to identify the original sample in the dataset. 34 | Defaults to False. 35 | show_progress (bool, optional): Whether to display a progress bar while loading. 36 | Defaults to True. 37 | 38 | Returns: 39 | Iterable: An iterator over batches from the dataset. If with_info=True, yields tuples of 40 | (batch, info) where info contains context indices for each sample in the batch. 41 | """ 42 | 43 | def __init__(self, batch_size: int, num_workers: int = 0, with_info: bool = False, show_progress: bool = True): 44 | self.batch_size = batch_size 45 | self.num_workers = num_workers 46 | self.with_info = with_info 47 | self.show_progress = show_progress 48 | 49 | def process( 50 | self, 51 | data: Dataset, 52 | *, 53 | dataset_name: str | None = None, 54 | metadata: Optional[dict[str, Any]] = None, 55 | **kwargs, 56 | ) -> Iterable[dict[str, Any]]: 57 | """Process the input dataset into batches. 58 | 59 | Args: 60 | data (Dataset): Input HuggingFace dataset to process 61 | dataset_name (str, optional): Name of the dataset. If provided, it will be added to the info field. 62 | Defaults to None. 63 | metadata (dict[str, Any], optional): Metadata to add to each batch. Defaults to None. 64 | **kwargs: Additional keyword arguments for processing. Not used by this processor. 65 | 66 | Returns: 67 | Iterable: Iterator over batches, optionally with context info if with_info=True 68 | """ 69 | 70 | dataloader = cast( 71 | Iterable[list[dict[str, Any]]], 72 | DataLoader( 73 | cast(torch.utils.data.Dataset, data), 74 | batch_size=self.batch_size, 75 | shuffle=False, 76 | pin_memory=True, 77 | collate_fn=identity_collate_fn, 78 | num_workers=self.num_workers, 79 | ), 80 | ) 81 | 82 | if self.show_progress: 83 | dataloader = tqdm(dataloader, desc="Loading dataset") 84 | 85 | flattened = itertools.chain.from_iterable(dataloader) 86 | 87 | if self.with_info: 88 | flattened = map( 89 | lambda x: x[1] 90 | | { 91 | "meta": { 92 | "context_idx": x[0], 93 | **({"dataset_name": dataset_name} if dataset_name else {}), 94 | **(metadata if metadata else {}), 95 | } 96 | }, 97 | enumerate(flattened), 98 | ) 99 | 100 | return flattened 101 | -------------------------------------------------------------------------------- /src/lm_saes/activation/processors/token.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Iterable, Optional, cast 2 | 3 | import torch 4 | from transformer_lens import HookedTransformer 5 | 6 | from lm_saes.activation.processors.core import BaseActivationProcessor 7 | 8 | 9 | def pad_and_truncate_tokens( 10 | tokens: torch.Tensor, 11 | seq_len: int, 12 | pad_token_id: int = 0, 13 | ) -> torch.Tensor: 14 | """Pad tokens to desired sequence length. 
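
    Sequences longer than `seq_len` are truncated and shorter ones are right-padded: for example,
    with `seq_len=4` and `pad_token_id=0`, the tokens `[5, 6]` become `[5, 6, 0, 0]`, while
    `[1, 2, 3, 4, 5]` is truncated to `[1, 2, 3, 4]`.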
15 | 16 | Args: 17 | tokens: Input tokens tensor or list of token tensors to pad 18 | seq_len: Desired sequence length after padding 19 | pad_token_id: Token ID to use for padding (default: 0) 20 | 21 | Returns: 22 | torch.Tensor: Padded token tensor with shape (batch_size, seq_len) 23 | """ 24 | if tokens.size(-1) > seq_len: 25 | return tokens[..., :seq_len] 26 | 27 | pad_len = seq_len - tokens.size(-1) 28 | 29 | padding = torch.full( 30 | (*tokens.shape[:-1], pad_len), 31 | pad_token_id, 32 | dtype=torch.long, 33 | device=tokens.device, 34 | ) 35 | return torch.cat([tokens, padding], dim=-1) 36 | 37 | 38 | class RawDatasetTokenProcessor(BaseActivationProcessor[Iterable[dict[str, Any]], Iterable[dict[str, Any]]]): 39 | """Processor for converting raw token datasets into model-ready token format. 40 | 41 | This processor takes an iterable of dictionaries containing raw data (e.g. text and images) and converts 42 | them into a tokens. The output is a dictionary with a "tokens" key, which contains the (non-padded and non-truncated) 43 | tokens. The "meta" key is preserved if it exists in the input. 44 | 45 | Args: 46 | prepend_bos: Whether to prepend beginning-of-sequence token. If None, uses model default. 47 | """ 48 | 49 | def __init__(self, prepend_bos: bool | None = None): 50 | self.prepend_bos = prepend_bos 51 | 52 | def process( 53 | self, data: Iterable[dict[str, Any]], *, model: HookedTransformer, **kwargs 54 | ) -> Iterable[dict[str, Any]]: 55 | """Process raw data into tokens. 56 | 57 | Args: 58 | data: Iterable of dictionaries containing raw data (e.g. text and images) 59 | model: HookedTransformer model to use for producing tokens 60 | **kwargs: Additional keyword arguments. Not used by this processor. 61 | 62 | Yields: 63 | dict: Processed token data with optional info field 64 | """ 65 | for d in data: 66 | tokens = model.to_tokens_with_origins(d, tokens_only=True, prepend_bos=self.prepend_bos) 67 | ret = {"tokens": tokens[0]} 68 | if "meta" in d: 69 | ret = ret | {"meta": d["meta"]} 70 | yield ret 71 | 72 | 73 | class PadAndTruncateTokensProcessor(BaseActivationProcessor[Iterable[dict[str, Any]], Iterable[dict[str, Any]]]): 74 | """Processor for padding and truncating tokens to a desired sequence length. 75 | 76 | This processor takes an iterable of dictionaries containing tokens and pads them to a desired sequence length. 77 | The output is a dictionary with a "tokens" key, which contains the padded tokens. 78 | 79 | Args: 80 | seq_len (int): The desired sequence length to pad/truncate to 81 | pad_token_id (int, optional): The token ID to use for padding. Defaults to 0. 82 | """ 83 | 84 | def __init__(self, seq_len: int): 85 | self.seq_len = seq_len 86 | 87 | def process( 88 | self, 89 | data: Iterable[dict[str, Any]], 90 | *, 91 | pad_token_id: Optional[int] = None, 92 | model: Optional[HookedTransformer] = None, 93 | **kwargs, 94 | ) -> Iterable[dict[str, Any]]: 95 | """Process tokens by padding or truncating to desired sequence length. 96 | 97 | Args: 98 | data (Iterable[dict[str, Any]]): Input data containing tokens to process 99 | pad_token_id (int, optional): The token ID to use for padding. Defaults to None. 100 | If not specified, the pad_token_id will be inferred from the model's tokenizer. 101 | If neither is provided, the pad_token_id will be 0. 102 | model (HookedTransformer, optional): The model to use for padding. Defaults to None. 103 | If provided, the pad_token_id will be inferred from the model's tokenizer. 104 | **kwargs: Additional keyword arguments. 
Not used by this processor. 105 | 106 | Yields: 107 | dict[str, Any]: Dictionary containing processed tokens padded/truncated to seq_len, 108 | and original info field if present 109 | """ 110 | 111 | # Infer pad_token_id if not provided 112 | if pad_token_id is None: 113 | if model is not None: 114 | tokenizer = model.tokenizer 115 | assert tokenizer is not None, "model must have a tokenizer" 116 | pad_token_id = cast(int, tokenizer.pad_token_id) 117 | else: 118 | pad_token_id = 0 119 | 120 | for d in data: 121 | assert "tokens" in d and isinstance(d["tokens"], torch.Tensor) 122 | tokens = pad_and_truncate_tokens(d["tokens"], seq_len=self.seq_len, pad_token_id=pad_token_id) 123 | ret = {"tokens": tokens} 124 | if "meta" in d: 125 | ret = ret | {"meta": d["meta"]} 126 | yield ret 127 | -------------------------------------------------------------------------------- /src/lm_saes/analysis/__init__.py: -------------------------------------------------------------------------------- 1 | from .feature_analyzer import FeatureAnalyzer 2 | 3 | __all__ = ["FeatureAnalyzer"] 4 | -------------------------------------------------------------------------------- /src/lm_saes/analysis/features_to_logits.py: -------------------------------------------------------------------------------- 1 | # import torch 2 | # from transformer_lens import HookedTransformer 3 | 4 | # from ..config import FeaturesDecoderConfig 5 | # from ..sae import SparseAutoEncoder 6 | 7 | 8 | # @torch.no_grad() 9 | # def features_to_logits(sae: SparseAutoEncoder, model: HookedTransformer, cfg: FeaturesDecoderConfig): 10 | # num_ones = int(torch.sum(sae.feature_act_mask).item()) 11 | 12 | # feature_acts = torch.zeros(num_ones, cfg.sae.d_sae).to(cfg.sae.device) 13 | 14 | # index = 0 15 | # for i in range(len(sae.feature_act_mask)): 16 | # if sae.feature_act_mask[i] == 1: 17 | # feature_acts[index, i] = 1 18 | # index += 1 19 | 20 | # feature_acts = torch.unsqueeze(feature_acts, dim=1) 21 | 22 | # residual = sae.decode(feature_acts) 23 | 24 | # if model.cfg.normalization_type is not None: 25 | # residual = model.ln_final(residual) # [batch, pos, d_model] 26 | # logits = model.unembed(residual) # [batch, pos, d_vocab] 27 | 28 | # active_indices = [i for i, val in enumerate(sae.feature_act_mask) if val == 1] 29 | # result_dict = {str(feature_index): logits[idx][0] for idx, feature_index in enumerate(active_indices)} 30 | 31 | # return result_dict 32 | -------------------------------------------------------------------------------- /src/lm_saes/circuit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenMOSS/Language-Model-SAEs/6638b73672bae0873a3ac185e468d9adcb8e2282/src/lm_saes/circuit/__init__.py -------------------------------------------------------------------------------- /src/lm_saes/circuit/context.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from typing import Callable, Tuple, Union 3 | 4 | import torch 5 | from transformer_lens.hook_points import HookedRootModule, HookPoint 6 | 7 | from ..sae import SparseAutoEncoder 8 | 9 | 10 | @contextmanager 11 | def apply_sae(model: HookedRootModule, saes: list[SparseAutoEncoder]): 12 | """ 13 | Apply the sparse autoencoders to the model. 
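
    A minimal usage sketch (assuming `model` is a HookedTransformer and `saes` is a list of loaded
    SparseAutoEncoder instances):

        with apply_sae(model, saes):
            logits = model("Hello, world!")

    Note that the hooks keep the forward values numerically unchanged (the reconstruction error is
    added back with gradients detached), so only the gradient path is rerouted through the SAEs.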
14 | """ 15 | fwd_hooks: list[Tuple[Union[str, Callable], Callable]] = [] 16 | 17 | def get_fwd_hooks(sae: SparseAutoEncoder) -> list[Tuple[Union[str, Callable], Callable]]: 18 | if sae.cfg.hook_point_in == sae.cfg.hook_point_out: 19 | 20 | def hook(tensor: torch.Tensor, hook: HookPoint): 21 | reconstructed = sae.forward(tensor) 22 | return reconstructed + (tensor - reconstructed).detach() 23 | 24 | return [(sae.cfg.hook_point_in, hook)] 25 | else: 26 | x = None 27 | 28 | def hook_in(tensor: torch.Tensor, hook: HookPoint): 29 | nonlocal x 30 | x = tensor 31 | return tensor 32 | 33 | def hook_out(tensor: torch.Tensor, hook: HookPoint): 34 | nonlocal x 35 | assert x is not None, "hook_in must be called before hook_out." 36 | reconstructed = sae.forward(x) 37 | x = None 38 | return reconstructed + (tensor - reconstructed).detach() 39 | 40 | return [(sae.cfg.hook_point_in, hook_in), (sae.cfg.hook_point_out, hook_out)] 41 | 42 | for sae in saes: 43 | hooks = get_fwd_hooks(sae) 44 | fwd_hooks.extend(hooks) 45 | with model.mount_hooked_modules([(sae.cfg.hook_point_out, "sae", sae) for sae in saes]): 46 | with model.hooks(fwd_hooks): 47 | yield model 48 | 49 | 50 | @contextmanager 51 | def detach_at( 52 | model: HookedRootModule, 53 | hook_points: list[str], 54 | ): 55 | """ 56 | Detach the gradients on the given hook points. 57 | """ 58 | 59 | def generate_hook(): 60 | hook_pre = HookPoint() 61 | hook_post = HookPoint() 62 | 63 | def hook(tensor: torch.Tensor, hook: HookPoint): 64 | return hook_post(hook_pre(tensor).detach().requires_grad_()) 65 | 66 | return hook_pre, hook_post, hook 67 | 68 | hooks = {hook_point: generate_hook() for hook_point in hook_points} 69 | fwd_hooks = [(hook_point, hook) for hook_point, (_, _, hook) in hooks.items()] 70 | with model.mount_hooked_modules( 71 | [(hook_point, "pre", hook) for hook_point, (hook, _, _) in hooks.items()] 72 | + [(hook_point, "post", hook) for hook_point, (_, hook, _) in hooks.items()] 73 | ): 74 | with model.hooks(fwd_hooks): 75 | yield model 76 | -------------------------------------------------------------------------------- /src/lm_saes/optim.py: -------------------------------------------------------------------------------- 1 | """ 2 | Took the LR scheduler from: https://github.com/jbloomAus/DecisionTransformerInterpretability/blob/ee55df35cdb92e81d689c72fb9dd5a7252893363/src/decision_transformer/utils.py#L425 3 | """ 4 | 5 | import math 6 | from typing import Any, Optional 7 | 8 | import torch.optim as optim 9 | import torch.optim.lr_scheduler as lr_scheduler 10 | 11 | 12 | # None 13 | # Linear Warmup and decay 14 | # Cosine Annealing with Warmup 15 | # Cosine Annealing with Warmup / Restarts 16 | def get_scheduler(scheduler_name: Optional[str], optimizer: optim.Optimizer, **kwargs: Any): 17 | """ 18 | Loosely based on this, seemed simpler write this than import 19 | transformers: https://huggingface.co/docs/transformers/main_classes/optimizer_schedules 20 | 21 | Args: 22 | scheduler_name (Optional[str]): Name of the scheduler to use. If None, returns a constant scheduler 23 | optimizer (optim.Optimizer): Optimizer to use 24 | **kwargs: Additional arguments to pass to the scheduler including warm_up_steps, 25 | training_steps, num_cycles, lr_end. 
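
    Example (a minimal sketch; the optimizer, model, and step counts are placeholders):
        >>> optimizer = optim.Adam(model.parameters(), lr=4e-4)
        >>> scheduler = get_scheduler(
        ...     "constantwithwarmup",
        ...     optimizer,
        ...     warm_up_steps=1_000,
        ...     cool_down_steps=10_000,
        ...     training_steps=100_000,
        ... )
        >>> for _ in range(100_000):
        ...     optimizer.step()
        ...     scheduler.step()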
26 | """ 27 | 28 | def get_smoothing_lambda( 29 | training_steps: int, warm_up_steps: int, gamma: float, cool_down_steps: int, lr_end: float 30 | ): 31 | smooth_steps = gamma * warm_up_steps 32 | 33 | def lr_lambda(steps: int): 34 | if steps < smooth_steps: 35 | return 2 * (steps + 1) / (warm_up_steps * (1 + gamma)) 36 | elif steps < warm_up_steps: 37 | return 1 - ((steps / warm_up_steps - 1) ** 2) / (1 - gamma**2) 38 | elif steps < cool_down_steps: 39 | return 1.0 40 | else: 41 | progress = (steps - cool_down_steps) / (training_steps - cool_down_steps) 42 | return lr_end + 0.5 * (1 - lr_end) * (1 + math.cos(math.pi * progress)) 43 | 44 | return lr_lambda 45 | 46 | def get_warmup_lambda(warm_up_steps: int, training_steps: int): 47 | def lr_lambda(steps: int): 48 | if steps < warm_up_steps: 49 | return (steps + 1) / warm_up_steps 50 | else: 51 | return (training_steps - steps) / (training_steps - warm_up_steps) 52 | 53 | return lr_lambda 54 | 55 | # heavily derived from hugging face although copilot helped. 56 | def get_warmup_cosine_lambda(warm_up_steps: int, training_steps: int, lr_end: float): 57 | def lr_lambda(steps: int): 58 | if steps < warm_up_steps: 59 | return (steps + 1) / warm_up_steps 60 | else: 61 | progress = (steps - warm_up_steps) / (training_steps - warm_up_steps) 62 | return lr_end + 0.5 * (1 - lr_end) * (1 + math.cos(math.pi * progress)) 63 | 64 | return lr_lambda 65 | 66 | def get_warmup_exp_lambda(warm_up_steps: int, training_steps: int, lr_end: float): 67 | def lr_lambda(steps: int): 68 | if steps < warm_up_steps: 69 | return (steps + 1) / warm_up_steps 70 | else: 71 | return math.pow(lr_end, (steps - warm_up_steps) / (training_steps - warm_up_steps)) 72 | 73 | return lr_lambda 74 | 75 | if scheduler_name is None or scheduler_name.lower() == "constant": 76 | return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda steps: 1.0) 77 | elif scheduler_name.lower() == "constantwithwarmup": 78 | warm_up_steps = kwargs.get("warm_up_steps", 0) 79 | cool_down_steps = kwargs.get("cool_down_steps", 0) 80 | training_steps = kwargs.get("training_steps") 81 | lr_end_ratio = kwargs.get("lr_end_ratio", 0.0) 82 | 83 | assert training_steps is not None, "training_steps must be provided" 84 | return lr_scheduler.LambdaLR( 85 | optimizer, 86 | lr_lambda=lambda steps: min( 87 | 1.0, 88 | (steps + 1) / warm_up_steps, 89 | lr_end_ratio + (1 - lr_end_ratio) / cool_down_steps * max(training_steps - steps, 1), # type: ignore 90 | ), 91 | ) 92 | elif scheduler_name.lower() == "constantwithwarmupsmooth": 93 | warm_up_steps = kwargs.get("warm_up_steps", 0) 94 | training_steps = kwargs.get("training_steps") 95 | assert training_steps is not None, "training_steps must be provided" 96 | cool_down_steps = training_steps - int(1.5 * warm_up_steps) 97 | assert training_steps is not None, "training_steps must be provided" 98 | lr_lambda = get_smoothing_lambda(training_steps, warm_up_steps, 0.5, cool_down_steps, 0.0) 99 | return lr_scheduler.LambdaLR(optimizer, lr_lambda) 100 | elif scheduler_name.lower() == "linearwarmupdecay": 101 | warm_up_steps = kwargs.get("warm_up_steps", 0) 102 | training_steps = kwargs.get("training_steps") 103 | assert training_steps is not None, "training_steps must be provided" 104 | lr_lambda = get_warmup_lambda(warm_up_steps, training_steps) 105 | return lr_scheduler.LambdaLR(optimizer, lr_lambda) 106 | elif scheduler_name.lower() == "cosineannealing": 107 | training_steps = kwargs.get("training_steps") 108 | assert training_steps is not None, "training_steps must be provided" 
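        # CosineAnnealingLR decays the learning rate as
        #   lr(t) = eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * t / T_max)),
        # so with T_max=training_steps the rate falls from the base LR to eta_min over the full run.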
109 | eta_min = kwargs.get("lr_end", 0) 110 | return lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_steps, eta_min=eta_min) 111 | elif scheduler_name.lower() == "cosineannealingwarmup": 112 | warm_up_steps = kwargs.get("warm_up_steps", 0) 113 | training_steps = kwargs.get("training_steps") 114 | assert training_steps is not None, "training_steps must be provided" 115 | eta_min = kwargs.get("lr_end", 0) 116 | lr_lambda = get_warmup_cosine_lambda(warm_up_steps, training_steps, eta_min) 117 | return lr_scheduler.LambdaLR(optimizer, lr_lambda) 118 | elif scheduler_name.lower() == "cosineannealingwarmrestarts": 119 | training_steps = kwargs.get("training_steps") 120 | eta_min = kwargs.get("lr_end", 0) 121 | num_cycles = kwargs.get("num_cycles", 1) 122 | T_0 = training_steps // num_cycles 123 | return lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=T_0, eta_min=eta_min) 124 | elif scheduler_name.lower() == "exponentialwarmup": 125 | warm_up_steps = kwargs.get("warm_up_steps", 0) 126 | training_steps = kwargs.get("training_steps") 127 | assert training_steps is not None, "training_steps must be provided" 128 | eta_min = kwargs.get("lr_end", 1 / 32) 129 | lr_lambda = get_warmup_exp_lambda(warm_up_steps, training_steps, eta_min) 130 | return lr_scheduler.LambdaLR(optimizer, lr_lambda) 131 | else: 132 | raise ValueError(f"Unsupported scheduler: {scheduler_name}") 133 | -------------------------------------------------------------------------------- /src/lm_saes/resource_loaders.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, cast 2 | 3 | import datasets 4 | import torch 5 | from torch.distributed.device_mesh import DeviceMesh 6 | from transformer_lens import HookedTransformer 7 | from transformers import ( 8 | AutoModelForCausalLM, 9 | AutoProcessor, 10 | AutoTokenizer, 11 | ChameleonForConditionalGeneration, 12 | PreTrainedModel, 13 | ) 14 | 15 | from lm_saes.config import DatasetConfig, LanguageModelConfig 16 | 17 | 18 | def load_dataset_shard( 19 | cfg: DatasetConfig, 20 | shard_idx: int, 21 | n_shards: int, 22 | ) -> datasets.Dataset: 23 | if not cfg.is_dataset_on_disk: 24 | dataset = datasets.load_dataset(cfg.dataset_name_or_path, split="train", cache_dir=cfg.cache_dir) 25 | else: 26 | dataset = datasets.load_from_disk(cfg.dataset_name_or_path) 27 | dataset = cast(datasets.Dataset, dataset) 28 | dataset = dataset.shard(num_shards=n_shards, index=shard_idx, contiguous=True) 29 | dataset = dataset.with_format("torch") 30 | return dataset 31 | 32 | 33 | def load_dataset( 34 | cfg: DatasetConfig, 35 | device_mesh: Optional[DeviceMesh] = None, 36 | n_shards: Optional[int] = None, 37 | start_shard: int = 0, 38 | ) -> tuple[datasets.Dataset, Optional[dict[str, Any]]]: 39 | if not cfg.is_dataset_on_disk: 40 | dataset = datasets.load_dataset( 41 | cfg.dataset_name_or_path, split="train", cache_dir=cfg.cache_dir, trust_remote_code=True 42 | ) 43 | else: 44 | dataset = datasets.load_from_disk(cfg.dataset_name_or_path) 45 | dataset = cast(datasets.Dataset, dataset) 46 | if device_mesh is not None: 47 | shard = dataset.shard( 48 | num_shards=n_shards or device_mesh.get_group("data").size(), 49 | index=start_shard + device_mesh.get_group("data").rank(), 50 | contiguous=True, 51 | ) 52 | shard_metadata = { 53 | "shard_idx": start_shard + device_mesh.get_group("data").rank(), 54 | "n_shards": n_shards or device_mesh.get_group("data").size(), 55 | } 56 | else: 57 | shard = dataset 58 | shard_metadata = None 59 | shard = 
shard.with_format("torch") 60 | return shard, shard_metadata 61 | 62 | 63 | def load_model(cfg: LanguageModelConfig): 64 | if cfg.device == "cuda": 65 | device = torch.device(f"cuda:{torch.cuda.current_device()}") 66 | else: 67 | device = torch.device(cfg.device) 68 | 69 | if "chameleon" in cfg.model_name: 70 | hf_model = ChameleonForConditionalGeneration.from_pretrained( 71 | (cfg.model_name if cfg.model_from_pretrained_path is None else cfg.model_from_pretrained_path), 72 | cache_dir=cfg.cache_dir, 73 | local_files_only=cfg.local_files_only, 74 | torch_dtype=cfg.dtype, 75 | ).to(device) # type: ignore 76 | else: 77 | hf_model: PreTrainedModel = AutoModelForCausalLM.from_pretrained( 78 | (cfg.model_name if cfg.model_from_pretrained_path is None else cfg.model_from_pretrained_path), 79 | cache_dir=cfg.cache_dir, 80 | local_files_only=cfg.local_files_only, 81 | torch_dtype=cfg.dtype, 82 | ).to(device) 83 | if "chameleon" in cfg.model_name: 84 | hf_processor = AutoProcessor.from_pretrained( 85 | (cfg.model_name if cfg.model_from_pretrained_path is None else cfg.model_from_pretrained_path), 86 | trust_remote_code=True, 87 | use_fast=True, 88 | add_bos_token=True, 89 | local_files_only=cfg.local_files_only, 90 | ) 91 | hf_tokenizer = None 92 | else: 93 | hf_tokenizer = AutoTokenizer.from_pretrained( 94 | (cfg.model_name if cfg.model_from_pretrained_path is None else cfg.model_from_pretrained_path), 95 | trust_remote_code=True, 96 | use_fast=True, 97 | add_bos_token=True, 98 | local_files_only=cfg.local_files_only, 99 | ) 100 | hf_processor = None 101 | 102 | model = HookedTransformer.from_pretrained_no_processing( 103 | cfg.model_name, 104 | use_flash_attn=cfg.use_flash_attn, 105 | device=device, 106 | cache_dir=cfg.cache_dir, 107 | hf_model=hf_model, 108 | hf_config=hf_model.config, 109 | tokenizer=hf_tokenizer, 110 | processor=hf_processor, 111 | dtype=cfg.dtype, 112 | ) 113 | model.eval() 114 | return model 115 | -------------------------------------------------------------------------------- /src/lm_saes/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenMOSS/Language-Model-SAEs/6638b73672bae0873a3ac185e468d9adcb8e2282/src/lm_saes/utils/__init__.py -------------------------------------------------------------------------------- /src/lm_saes/utils/bytes.py: -------------------------------------------------------------------------------- 1 | import io 2 | from functools import lru_cache 3 | 4 | import numpy as np 5 | 6 | 7 | def np_to_bytes(arr): 8 | with io.BytesIO() as buffer: 9 | np.save(buffer, arr) 10 | return buffer.getvalue() 11 | 12 | 13 | def bytes_to_np(b): 14 | with io.BytesIO(b) as buffer: 15 | return np.load(buffer) 16 | 17 | 18 | @lru_cache() 19 | def bytes_to_unicode(): 20 | """ 21 | Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control 22 | characters the bpe code barfs on. 23 | 24 | The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab 25 | if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for 26 | decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup 27 | tables between utf-8 bytes and unicode strings. 
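
    For example, the space byte (0x20) is one of the excluded bytes, so it is remapped to 'Ġ'
    (U+0120); this is why GPT-2-style BPE tokens that begin with a space are displayed with a
    leading 'Ġ'.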
28 | """ 29 | bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) 30 | cs = bs[:] 31 | n = 0 32 | for b in range(2**8): 33 | if b not in bs: 34 | bs.append(b) 35 | cs.append(2**8 + n) 36 | n += 1 37 | cs = [chr(n) for n in cs] 38 | return dict(zip(bs, cs)) 39 | -------------------------------------------------------------------------------- /src/lm_saes/utils/concurrent.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import Future, ThreadPoolExecutor 2 | from queue import Queue 3 | from typing import Iterable, Optional, TypeVar 4 | 5 | from tqdm import tqdm 6 | 7 | T = TypeVar("T") 8 | 9 | 10 | class BackgroundGenerator(Iterable[T]): 11 | """Compute elements of a generator in a background thread pool. 12 | 13 | This class is optimized for scenarios where either the generator or the main thread 14 | performs GIL-releasing work (e.g., File I/O, network requests). Best used when the 15 | generator has no side effects on the program state. 16 | """ 17 | 18 | def __init__( 19 | self, 20 | generator: Iterable[T], 21 | max_prefetch: int = 1, 22 | executor: Optional[ThreadPoolExecutor] = None, 23 | name: Optional[str] = None, 24 | ) -> None: 25 | """Initialize the background generator. 26 | 27 | Args: 28 | generator: The generator to compute elements from in a background thread. 29 | max_prefetch: Maximum number of elements to precompute ahead. If <= 0, 30 | the queue size is infinite. When max_prefetch elements have been 31 | computed, the background thread will wait for consumption. 32 | executor: Optional ThreadPoolExecutor to use. If None, creates a single-thread executor. 33 | """ 34 | self.queue = Queue(max_prefetch) 35 | self.generator = generator 36 | self._executor = executor or ThreadPoolExecutor(max_workers=1) 37 | self._owned_executor = executor is None 38 | self._future: Optional[Future] = None 39 | self.continue_iteration = True 40 | self._started = False 41 | self.pbar = tqdm(total=max_prefetch, desc=f"Background Processing {name}", smoothing=0.001, miniters=1) 42 | self._start() 43 | 44 | def _process_generator(self) -> None: 45 | """Process the generator items in the background thread.""" 46 | try: 47 | for item in self.generator: 48 | if not self.continue_iteration: 49 | break 50 | self.queue.put((True, item)) 51 | self.pbar.update(1) 52 | except Exception as e: 53 | self.queue.put((False, e)) 54 | self.pbar.update(1) 55 | finally: 56 | self.queue.put((False, StopIteration)) 57 | self.pbar.update(1) 58 | 59 | def _start(self) -> None: 60 | """Start the background processing.""" 61 | self._future = self._executor.submit(self._process_generator) 62 | 63 | def __next__(self) -> T: 64 | """Get the next item from the generator. 65 | 66 | Returns: 67 | The next item from the generator. 68 | 69 | Raises: 70 | StopIteration: When the generator is exhausted. 71 | Exception: Any exception raised by the generator. 72 | """ 73 | if self.continue_iteration: 74 | success, next_item = self.queue.get() 75 | self.pbar.update(-1) 76 | if success: 77 | return next_item 78 | else: 79 | self.continue_iteration = False 80 | raise next_item 81 | else: 82 | raise StopIteration 83 | 84 | def __iter__(self) -> "BackgroundGenerator[T]": 85 | """Return self as iterator. 86 | 87 | Returns: 88 | Self as iterator. 
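        Example (illustrative sketch; any iterable works as the wrapped producer, and the
        items below are just placeholder values):

            >>> bg = BackgroundGenerator((i * i for i in range(3)), max_prefetch=2)
            >>> list(bg)  # elements are computed in the background thread as we consume
            [0, 1, 4]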
89 | """ 90 | return self 91 | 92 | def close(self) -> None: 93 | """Close the generator and clean up resources.""" 94 | self.continue_iteration = False 95 | if self._future is not None: 96 | self._future.cancel() 97 | # Clear the queue to unblock any waiting threads 98 | while not self.queue.empty(): 99 | try: 100 | self.queue.get_nowait() 101 | self.pbar.update(-1) 102 | except Exception: 103 | pass 104 | if self._owned_executor: 105 | self._executor.shutdown(wait=False, cancel_futures=True) 106 | self.pbar.close() 107 | 108 | def __del__(self) -> None: 109 | """Clean up resources when the object is deleted.""" 110 | self.close() 111 | -------------------------------------------------------------------------------- /src/lm_saes/utils/convert_pre_enc_bias.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..sae import SparseAutoEncoder 4 | 5 | 6 | @torch.no_grad() 7 | def merge_pre_enc_bias_to_enc_bias(sae: SparseAutoEncoder): 8 | assert sae.cfg.apply_decoder_bias_to_pre_encoder 9 | 10 | sae.cfg.apply_decoder_bias_to_pre_encoder = False 11 | sae.encoder.bias.data = sae.encoder.bias.data - sae.encoder.weight.data @ sae.decoder.bias.data 12 | 13 | return sae 14 | -------------------------------------------------------------------------------- /src/lm_saes/utils/discrete.py: -------------------------------------------------------------------------------- 1 | class DiscreteMapper: 2 | def __init__(self) -> None: 3 | """Initialize a new PythonDiscreteMapper with empty mappings.""" 4 | self.value_to_int: dict[str, int] = {} 5 | self.int_to_value: list[str] = [] 6 | self.counter: int = 0 7 | 8 | def encode(self, values: list[str]) -> list[int]: 9 | """Encode a list of strings to their corresponding integer indices. 10 | 11 | Args: 12 | values: List of strings to encode 13 | 14 | Returns: 15 | List of integer indices 16 | """ 17 | result = [] 18 | for value in values: 19 | if value not in self.value_to_int: 20 | self.value_to_int[value] = self.counter 21 | self.int_to_value.append(value) 22 | self.counter += 1 23 | result.append(self.value_to_int[value]) 24 | return result 25 | 26 | def decode(self, integers: list[int]) -> list[str]: 27 | """Decode a list of integers back to their corresponding strings. 28 | 29 | Args: 30 | integers: List of integer indices to decode 31 | 32 | Returns: 33 | List of decoded strings 34 | 35 | Raises: 36 | IndexError: If any integer is out of range 37 | """ 38 | return [self.int_to_value[i] for i in integers] 39 | 40 | def get_mapping(self) -> dict[str, int]: 41 | """Get the current mapping from strings to integers. 42 | 43 | Returns: 44 | Dictionary mapping strings to their integer indices 45 | """ 46 | return self.value_to_int.copy() 47 | 48 | 49 | class KeyedDiscreteMapper: 50 | def __init__(self) -> None: 51 | """Initialize a new PythonKeyedDiscreteMapper with empty mappers.""" 52 | self.mappers: dict[str, DiscreteMapper] = {} 53 | 54 | def encode(self, key: str, values: list[str]) -> list[int]: 55 | """Encode a list of strings using the mapper associated with the given key. 
56 | 57 | Args: 58 | key: The key identifying which mapper to use 59 | values: List of strings to encode 60 | 61 | Returns: 62 | List of integer indices 63 | """ 64 | if key not in self.mappers: 65 | self.mappers[key] = DiscreteMapper() 66 | return self.mappers[key].encode(values) 67 | 68 | def decode(self, key: str, integers: list[int]) -> list[str]: 69 | """Decode a list of integers using the mapper associated with the given key. 70 | 71 | Args: 72 | key: The key identifying which mapper to use 73 | integers: List of integer indices to decode 74 | 75 | Returns: 76 | List of decoded strings 77 | 78 | Raises: 79 | KeyError: If the key doesn't exist 80 | IndexError: If any integer is out of range 81 | """ 82 | if key not in self.mappers: 83 | raise KeyError("Key not found") 84 | return self.mappers[key].decode(integers) 85 | 86 | def keys(self) -> list[str]: 87 | """Get all keys currently in use. 88 | 89 | Returns: 90 | List of keys 91 | """ 92 | return list(self.mappers.keys()) 93 | -------------------------------------------------------------------------------- /src/lm_saes/utils/hooks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformer_lens.hook_points import HookPoint 3 | 4 | 5 | def compose_hooks(*hooks): 6 | """ 7 | Compose multiple hooks into a single hook by executing them in order. 8 | """ 9 | 10 | def composed_hook(tensor: torch.Tensor, hook: HookPoint): 11 | for hook_fn in hooks: 12 | tensor = hook_fn(tensor, hook) 13 | return tensor 14 | 15 | return composed_hook 16 | 17 | 18 | def retain_grad_hook(tensor: torch.Tensor, hook: HookPoint): 19 | """ 20 | Retain the gradient of the tensor at the given hook point. 21 | """ 22 | tensor.retain_grad() 23 | return tensor 24 | 25 | 26 | def detach_hook(tensor: torch.Tensor, hook: HookPoint): 27 | """ 28 | Detach the tensor at the given hook point. 
29 | """ 30 | return tensor.detach().requires_grad_(True) 31 | -------------------------------------------------------------------------------- /src/lm_saes/utils/huggingface.py: -------------------------------------------------------------------------------- 1 | # Some helper function to interact with huggingface API 2 | 3 | import os 4 | import re 5 | import shutil 6 | 7 | from huggingface_hub import create_repo, snapshot_download, upload_folder 8 | 9 | from .misc import print_once 10 | 11 | 12 | def upload_pretrained_sae_to_hf(sae_path: str, repo_id: str, private: bool = False): 13 | """Upload a pretrained SAE model to huggingface model hub 14 | 15 | Args: 16 | sae_path (str): path to the local SAE model 17 | """ 18 | 19 | from lm_saes.config import LanguageModelConfig 20 | from lm_saes.sae import SparseAutoEncoder 21 | 22 | # Load the model 23 | sae = SparseAutoEncoder.from_pretrained(sae_path) 24 | lm_config = LanguageModelConfig.from_pretrained_sae(sae_path) 25 | 26 | # Create local temporary directory for uploading 27 | folder_name = ( 28 | sae.cfg.hook_point_in 29 | if sae.cfg.hook_point_in == sae.cfg.hook_point_out 30 | else f"{sae.cfg.hook_point_in}-{sae.cfg.hook_point_out}" 31 | ) 32 | os.makedirs(f"{sae_path}/{folder_name}", exist_ok=True) 33 | 34 | try: 35 | # Save the model 36 | create_repo(repo_id=repo_id, private=private, exist_ok=True) 37 | 38 | sae.save_pretrained(f"{sae_path}/{folder_name}") 39 | sae.cfg.save_hyperparameters(f"{sae_path}/{folder_name}") 40 | lm_config.save_lm_config(f"{sae_path}/{folder_name}") 41 | 42 | # Upload the model 43 | upload_folder( 44 | folder_path=f"{sae_path}/{folder_name}", 45 | repo_id=repo_id, 46 | path_in_repo=folder_name, 47 | commit_message=f"Upload pretrained SAE model. Hook point: {folder_name}. Language Model Name: {lm_config.model_name}", 48 | ) 49 | 50 | finally: 51 | # Remove the temporary directory 52 | shutil.rmtree(f"{sae_path}/{folder_name}") 53 | 54 | 55 | def download_pretrained_sae_from_hf(repo_id: str, hook_point: str): 56 | """Download a pretrained SAE model from huggingface model hub 57 | 58 | Args: 59 | repo_id (str): id of the repo 60 | hook_point (str): hook point 61 | """ 62 | 63 | snapshot_path = snapshot_download(repo_id=repo_id, allow_patterns=[f"{hook_point}/*"]) 64 | return os.path.join(snapshot_path, hook_point) 65 | 66 | 67 | def _parse_repo_id(pretrained_name_or_path): 68 | pattern = r"L(\d{1,2})([RAMTC])-(8|32)x" 69 | 70 | def replace_match(match): 71 | sublayer = match.group(2) 72 | exp_factor = match.group(3) 73 | 74 | return f"LX{sublayer}-{exp_factor}x" 75 | 76 | output_string = re.sub(pattern, replace_match, pretrained_name_or_path) 77 | 78 | return output_string 79 | 80 | 81 | def parse_pretrained_name_or_path(pretrained_name_or_path: str): 82 | if os.path.exists(pretrained_name_or_path): 83 | return pretrained_name_or_path 84 | else: 85 | print_once(f"Local path `{pretrained_name_or_path}` not found. 
Downloading from huggingface model hub.") 86 | if pretrained_name_or_path.startswith("fnlp"): 87 | print_once("Downloading Llama Scope SAEs.") 88 | repo_id = _parse_repo_id(pretrained_name_or_path) 89 | hook_point = pretrained_name_or_path.split("/")[1] 90 | else: 91 | repo_id = "/".join(pretrained_name_or_path.split("/")[:2]) 92 | hook_point = "/".join(pretrained_name_or_path.split("/")[2:]) 93 | return download_pretrained_sae_from_hf(repo_id, hook_point) 94 | -------------------------------------------------------------------------------- /src/lm_saes/utils/math.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def compute_geometric_median(x: torch.Tensor, max_iter=1000) -> torch.Tensor: 5 | """ 6 | Compute the geometric median of a point cloud x. 7 | The geometric median is the point that minimizes the sum of distances to the other points. 8 | This function uses Weiszfeld's algorithm to compute the geometric median. 9 | 10 | Args: 11 | x: Input point cloud. Shape (n_points, n_dims) 12 | max_iter: Maximum number of iterations 13 | 14 | Returns: 15 | The geometric median of the point cloud. Shape (n_dims,) 16 | """ 17 | 18 | # Initialize the geometric median as the mean of the points 19 | y = x.mean(dim=0) 20 | 21 | for _ in range(max_iter): 22 | # Compute the weights 23 | w = 1 / (x - y.unsqueeze(0)).norm(dim=-1) 24 | 25 | # Update the geometric median 26 | y = (w.unsqueeze(-1) * x).sum(dim=0) / w.sum() 27 | 28 | return y 29 | 30 | 31 | def norm_ratio(a, b): 32 | a_norm = torch.norm(a, 2, dim=0).mean() 33 | b_norm = torch.norm(b, 2, dim=0).mean() 34 | return a_norm / b_norm 35 | -------------------------------------------------------------------------------- /src/lm_saes/utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | from typing import Iterable 4 | 5 | import torch 6 | import torch.distributed as dist 7 | from torch.distributed.nn.functional import all_reduce 8 | from transformers import PreTrainedTokenizerBase 9 | 10 | 11 | def is_master() -> bool: 12 | return not dist.is_initialized() or dist.get_rank() == 0 13 | 14 | 15 | def print_once( 16 | *values: object, 17 | sep: str | None = " ", 18 | end: str | None = "\n", 19 | ) -> None: 20 | if is_master(): 21 | print(*values, sep=sep, end=end) 22 | 23 | 24 | def check_file_path_unused(file_path): 25 | # Check if the file path is None 26 | if file_path is None: 27 | print("Error: File path is empty.") 28 | exit() 29 | 30 | # Check if the file already exists 31 | if os.path.exists(file_path): 32 | print(f"Error: File {file_path} already exists. 
Please choose a different file path.") 33 | exit() 34 | 35 | 36 | str_dtype_map = { 37 | "float16": torch.float16, 38 | "float32": torch.float32, 39 | "float64": torch.float64, 40 | "int8": torch.int8, 41 | "int16": torch.int16, 42 | "int32": torch.int32, 43 | "int64": torch.int64, 44 | "bool": torch.bool, 45 | "bfloat16": torch.bfloat16, 46 | "bf16": torch.bfloat16, 47 | "float": torch.float, 48 | "fp16": torch.float16, 49 | "fp32": torch.float32, 50 | "fp64": torch.float64, 51 | "int": torch.int, 52 | "torch.float16": torch.float16, 53 | "torch.float32": torch.float32, 54 | "torch.float64": torch.float64, 55 | "torch.int8": torch.int8, 56 | "torch.int16": torch.int16, 57 | "torch.int32": torch.int32, 58 | "torch.int64": torch.int64, 59 | "torch.bool": torch.bool, 60 | "torch.bfloat16": torch.bfloat16, 61 | "torch.float": torch.float, 62 | "torch.int": torch.int, 63 | } 64 | 65 | 66 | def convert_str_to_torch_dtype(str_dtype: str) -> torch.dtype: 67 | if str_dtype in str_dtype_map: 68 | return str_dtype_map[str_dtype] 69 | else: 70 | raise ValueError(f"Unsupported data type: {str_dtype}. Supported data types: {list(str_dtype_map.keys())}.") 71 | 72 | 73 | def convert_torch_dtype_to_str(dtype: torch.dtype) -> str: 74 | dtype_str_map = {v: k for k, v in str_dtype_map.items()} 75 | if dtype in dtype_str_map: 76 | return dtype_str_map[dtype] 77 | else: 78 | raise ValueError(f"Unsupported data type: {dtype}. Supported data types: {list(dtype_str_map.values())}.") 79 | 80 | 81 | def gather_tensors_from_specific_rank(tensor, dst=0): 82 | world_size = dist.get_world_size() 83 | gathered_tensors = [torch.zeros_like(tensor) for _ in range(world_size)] if dist.get_rank() == dst else None 84 | dist.gather(tensor, gather_list=gathered_tensors, dst=dst) 85 | return gathered_tensors if dist.get_rank() == dst else None 86 | 87 | 88 | def get_tensor_from_specific_rank(tensor, src=0): 89 | dist.broadcast(tensor, src=src) 90 | return tensor 91 | 92 | 93 | def all_reduce_tensor(tensor, aggregate="none"): 94 | _OP_MAP = { 95 | "sum": dist.ReduceOp.SUM, 96 | "mean": dist.ReduceOp.SUM, # Use SUM for mean, but will need to divide by world size 97 | "min": dist.ReduceOp.MIN, 98 | "max": dist.ReduceOp.MAX, 99 | "product": dist.ReduceOp.PRODUCT, 100 | } 101 | 102 | # gathered_tensors = [torch.zeros_like(tensor) for _ in range(world_size)] 103 | tensor = all_reduce(tensor, op=_OP_MAP[aggregate]) 104 | assert tensor is not None, "All reduce failed" 105 | if aggregate == "mean": 106 | tensor = tensor / dist.get_world_size() 107 | return tensor 108 | 109 | 110 | def assert_tensor_consistency(tensor): 111 | flat_tensor = tensor.flatten() 112 | 113 | local_checksum = flat_tensor.sum().item() 114 | checksum_tensor = torch.tensor(local_checksum).to(tensor.device) 115 | 116 | dist.all_reduce(checksum_tensor, op=dist.ReduceOp.SUM) 117 | 118 | world_size = dist.get_world_size() 119 | expected_checksum = local_checksum * world_size 120 | 121 | # Step 5: Assert that the checksums match across all ranks 122 | assert checksum_tensor.item() == expected_checksum, "Inconsistent tensor data across ranks. Checksum mismatch." 
123 | 124 | 125 | def calculate_activation_norm( 126 | activation_stream: Iterable[dict[str, torch.Tensor]], hook_points: list[str], batch_num: int = 8 127 | ) -> dict[str, float]: 128 | activation_norm = {} 129 | stream_iter = iter(activation_stream) 130 | hook_points = list(set(hook_points)) 131 | assert len(hook_points) > 0, "No hook points provided" 132 | while batch_num > 0: 133 | try: 134 | batch = next(stream_iter) 135 | except StopIteration: 136 | warnings.warn(f"Activation stream ended prematurely. {batch_num} batches not processed.") 137 | break 138 | for key in hook_points: 139 | if key not in activation_norm: 140 | activation_norm[key] = batch[key].norm(p=2, dim=1) 141 | else: 142 | activation_norm[key] = torch.cat((activation_norm[key], batch[key].norm(p=2, dim=1)), dim=0) 143 | batch_num -= 1 144 | for key in activation_norm: 145 | activation_norm[key] = activation_norm[key].mean().item() 146 | return activation_norm 147 | 148 | 149 | def get_modality_indices(tokenizer: PreTrainedTokenizerBase, model_name: str) -> dict[str, torch.Tensor]: 150 | modality_indices = {} 151 | if model_name == "facebook/chameleon-7b": 152 | for token_name, token_id in tokenizer.get_vocab().items(): 153 | if token_name.startswith("IMGIMG"): 154 | modality_indices["image"] = ( 155 | [token_id] if "image" not in modality_indices else modality_indices["image"] + [token_id] 156 | ) 157 | else: 158 | modality_indices["text"] = ( 159 | [token_id] if "text" not in modality_indices else modality_indices["text"] + [token_id] 160 | ) 161 | else: 162 | raise ValueError(f"Unsupported model: {model_name}") 163 | for modality in modality_indices: 164 | modality_indices[modality] = torch.tensor(modality_indices[modality], dtype=torch.long) 165 | return modality_indices 166 | -------------------------------------------------------------------------------- /src/lm_saes/utils/tensor_dict.py: -------------------------------------------------------------------------------- 1 | from typing import Any, TypeVar, overload 2 | 3 | import torch 4 | 5 | T = TypeVar("T") 6 | 7 | 8 | def sort_dict_of_tensor( 9 | tensor_dict: dict[str, torch.Tensor], 10 | sort_key: str, 11 | sort_dim: int = 0, 12 | descending: bool = True, 13 | ): 14 | """ 15 | Sort a dictionary of tensors by the values of a tensor. 16 | 17 | Args: 18 | tensor_dict: Dictionary of tensors 19 | sort_key: Key of the tensor to sort by 20 | sort_dim: Dimension to sort along 21 | descending: Whether to sort in descending order 22 | 23 | Returns: 24 | A dictionary of tensors sorted by the values of the specified tensor 25 | """ 26 | sorted_idx = tensor_dict[sort_key].argsort(dim=sort_dim, descending=descending) 27 | return { 28 | k: v.gather(sort_dim, sorted_idx.unsqueeze(-1).expand_as(v.reshape(*sorted_idx.shape, -1)).reshape_as(v)) 29 | for k, v in tensor_dict.items() 30 | } 31 | 32 | 33 | def concat_dict_of_tensor(*dicts: dict[str, torch.Tensor], dim: int = 0) -> dict[str, torch.Tensor]: 34 | """ 35 | Concatenate a dictionary of tensors along a specified dimension. 36 | 37 | Args: 38 | *dicts: Dictionaries of tensors 39 | dim: Dimension to concatenate along 40 | 41 | Returns: 42 | A dictionary of tensors concatenated along the specified dimension 43 | """ 44 | return {k: torch.cat([d[k] for d in dicts], dim=dim) for k in dicts[0].keys()} 45 | 46 | 47 | @overload 48 | def move_dict_of_tensor_to_device( 49 | tensor_dict: dict[str, torch.Tensor], device: torch.device | str 50 | ) -> dict[str, torch.Tensor]: ... 
51 | 52 | 53 | @overload 54 | def move_dict_of_tensor_to_device(tensor_dict: dict[str, Any], device: torch.device | str) -> dict[str, Any]: ... 55 | 56 | 57 | def move_dict_of_tensor_to_device(tensor_dict: dict[str, Any], device: torch.device | str) -> dict[str, Any]: 58 | """ 59 | Move tensors in a dictionary to specified device, leaving non-tensor values unchanged. 60 | 61 | Args: 62 | tensor_dict: Dictionary containing tensors and possibly other types 63 | device: Target device to move tensors to 64 | 65 | Returns: 66 | Dictionary with tensors moved to specified device 67 | """ 68 | return {k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v for k, v in tensor_dict.items()} 69 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenMOSS/Language-Model-SAEs/6638b73672bae0873a3ac185e468d9adcb8e2282/tests/__init__.py -------------------------------------------------------------------------------- /tests/integration/test_train_sae.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | import torch 5 | from pytest_mock import MockerFixture 6 | from torch.distributed.device_mesh import init_device_mesh 7 | 8 | # if not torch.cuda.is_available(): 9 | # pytest.skip("CUDA is not available", allow_module_level=True) 10 | from lm_saes.config import InitializerConfig, MixCoderConfig, SAEConfig, TrainerConfig 11 | from lm_saes.initializer import Initializer 12 | from lm_saes.trainer import Trainer 13 | 14 | 15 | @pytest.fixture 16 | def sae_config() -> SAEConfig: 17 | return SAEConfig( 18 | hook_point_in="in", 19 | hook_point_out="out", 20 | d_model=2, 21 | expansion_factor=2, 22 | device="cpu", 23 | dtype=torch.bfloat16, # the precision of bfloat16 is not enough for the tests 24 | act_fn="topk", 25 | norm_activation="dataset-wise", 26 | sparsity_include_decoder_norm=True, 27 | top_k=2, 28 | ) 29 | 30 | 31 | @pytest.fixture 32 | def mixcoder_config() -> MixCoderConfig: 33 | return MixCoderConfig( 34 | hook_point_in="in", 35 | hook_point_out="out", 36 | d_model=2, 37 | expansion_factor=2, 38 | device="cpu", 39 | dtype=torch.bfloat16, # the precision of bfloat16 is not enough for the tests 40 | act_fn="topk", 41 | norm_activation="dataset-wise", 42 | top_k=2, 43 | modalities={"image": 2, "text": 2, "shared": 2}, 44 | ) 45 | 46 | 47 | @pytest.fixture 48 | def initializer_config() -> InitializerConfig: 49 | return InitializerConfig( 50 | state="training", 51 | init_search=True, 52 | l1_coefficient=0.00008, 53 | ) 54 | 55 | 56 | @pytest.fixture 57 | def trainer_config(tmp_path) -> TrainerConfig: 58 | # Remove tmp path 59 | os.rmdir(tmp_path) 60 | return TrainerConfig( 61 | initial_k=3, 62 | total_training_tokens=400, 63 | log_frequency=10, 64 | eval_frequency=10, 65 | n_checkpoints=0, 66 | exp_result_path=str(tmp_path), 67 | ) 68 | 69 | 70 | def test_train_sae( 71 | sae_config: SAEConfig, 72 | initializer_config: InitializerConfig, 73 | trainer_config: TrainerConfig, 74 | mocker: MockerFixture, 75 | tmp_path, 76 | ) -> None: 77 | wandb_runner = mocker.Mock() 78 | wandb_runner.log = lambda *args, **kwargs: None 79 | device_mesh = ( 80 | init_device_mesh( 81 | device_type="cuda", 82 | mesh_shape=(int(os.environ.get("WORLD_SIZE", 1)), 1), 83 | mesh_dim_names=("data", "model"), 84 | ) 85 | if os.environ.get("WORLD_SIZE") is not None 86 | else None 87 | ) 
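    # Synthetic activation stream: 200 small batches of random "in"/"out" activations
    # (matching the hook points declared in sae_config) plus dummy token ids.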
88 | activation_stream = [ 89 | { 90 | "in": torch.randn(4, 2, dtype=sae_config.dtype, device=sae_config.device), 91 | "out": torch.randn(4, 2, dtype=sae_config.dtype, device=sae_config.device), 92 | "tokens": torch.tensor([2, 3, 4, 5], dtype=torch.long, device=sae_config.device), 93 | } 94 | for _ in range(200) 95 | ] 96 | initializer = Initializer(initializer_config) 97 | sae = initializer.initialize_sae_from_config( 98 | sae_config, 99 | device_mesh=device_mesh, 100 | activation_stream=activation_stream, 101 | ) 102 | trainer = Trainer(trainer_config) 103 | trainer.fit( 104 | sae=sae, 105 | activation_stream=activation_stream, 106 | eval_fn=lambda x: None, 107 | wandb_logger=wandb_runner, 108 | ) 109 | 110 | 111 | def test_train_mixcoder( 112 | mixcoder_config: MixCoderConfig, 113 | initializer_config: InitializerConfig, 114 | trainer_config: TrainerConfig, 115 | mocker: MockerFixture, 116 | tmp_path, 117 | ) -> None: 118 | wandb_runner = mocker.Mock() 119 | wandb_runner.log = lambda *args, **kwargs: None 120 | device_mesh = ( 121 | init_device_mesh( 122 | device_type="cuda", 123 | mesh_shape=(int(os.environ.get("WORLD_SIZE", 1)), 1), 124 | mesh_dim_names=("data", "model"), 125 | ) 126 | if os.environ.get("WORLD_SIZE") is not None 127 | else None 128 | ) 129 | activation_stream = [ 130 | { 131 | "in": torch.randn(4, 2, dtype=mixcoder_config.dtype, device=mixcoder_config.device), 132 | "out": torch.randn(4, 2, dtype=mixcoder_config.dtype, device=mixcoder_config.device), 133 | "tokens": torch.tensor([2, 3, 4, 5], dtype=torch.long, device=mixcoder_config.device), 134 | } 135 | for _ in range(200) 136 | ] 137 | initializer = Initializer(initializer_config) 138 | tokenizer = mocker.Mock() 139 | tokenizer.get_vocab.return_value = { 140 | "IMGIMG1": 1, 141 | "IMGIMG2": 2, 142 | "IMGIMG3": 3, 143 | "IMGIMG4": 4, 144 | "TEXT1": 5, 145 | "TEXT2": 6, 146 | "TEXT3": 7, 147 | "TEXT4": 8, 148 | } 149 | model_name = "facebook/chameleon-7b" 150 | 151 | mixcoder_settings = {"tokenizer": tokenizer, "model_name": model_name} 152 | mixcoder = initializer.initialize_sae_from_config( 153 | mixcoder_config, 154 | device_mesh=device_mesh, 155 | activation_stream=activation_stream, 156 | mixcoder_settings=mixcoder_settings, 157 | ) 158 | trainer = Trainer(trainer_config) 159 | trainer.fit( 160 | sae=mixcoder, 161 | activation_stream=activation_stream, 162 | eval_fn=lambda x: None, 163 | wandb_logger=wandb_runner, 164 | ) 165 | -------------------------------------------------------------------------------- /tests/unit/test_cached_activation.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | import torch 5 | from pytest_mock import MockerFixture 6 | 7 | from lm_saes.activation.processors.cached_activation import ( 8 | CachedActivationLoader, 9 | ChunkInfo, 10 | ) 11 | 12 | 13 | @pytest.fixture 14 | def sample_activation(): 15 | """Create a sample activation tensor.""" 16 | return torch.randn(2, 3, 4) # (n_samples, n_context, d_model) 17 | 18 | 19 | @pytest.fixture 20 | def sample_tokens(): 21 | """Create sample token indices.""" 22 | return torch.randint(0, 1000, (2, 3)) # (n_samples, n_context) 23 | 24 | 25 | @pytest.fixture 26 | def sample_info(): 27 | """Create sample info list.""" 28 | return [{"context_id": f"ctx_{i}"} for i in range(2)] 29 | 30 | 31 | def create_fake_pt_file(fs, path: Path, activation, tokens, info): 32 | """Helper to create a fake .pt file with test data.""" 33 | fs.create_file( 34 | path, 35 | 
contents="", # Contents don't matter as we'll mock torch.load 36 | ) 37 | 38 | 39 | def test_chunk_info_from_path(): 40 | """Test ChunkInfo.from_path with different filename formats.""" 41 | # Test sharded format 42 | sharded = ChunkInfo.from_path(Path("shard-1-chunk-2.pt")) 43 | assert sharded.shard_id == 1 44 | assert sharded.chunk_id == 2 45 | 46 | # Test non-sharded format 47 | non_sharded = ChunkInfo.from_path(Path("chunk-3.pt")) 48 | assert non_sharded.shard_id == 0 49 | assert non_sharded.chunk_id == 3 50 | 51 | # Test invalid format 52 | with pytest.raises(ValueError): 53 | ChunkInfo.from_path(Path("invalid.pt")) 54 | 55 | 56 | def test_cached_activation_loader(fs, mocker: MockerFixture, sample_activation, sample_tokens, sample_info): 57 | """Test CachedActivationLoader with a fake filesystem.""" 58 | # Setup test directory structure 59 | cache_dir = Path("/cache") 60 | hook_points = ["hook1", "hook2"] 61 | 62 | for hook in hook_points: 63 | hook_dir = cache_dir / hook 64 | fs.create_dir(hook_dir) 65 | 66 | # Create both sharded and non-sharded files 67 | files = [ 68 | hook_dir / "shard-0-chunk-0.pt", 69 | hook_dir / "shard-0-chunk-1.pt", 70 | hook_dir / "chunk-2.pt", 71 | ] 72 | 73 | for file in files: 74 | create_fake_pt_file(fs, file, sample_activation, sample_tokens, sample_info) 75 | 76 | # Mock torch.load to return test data 77 | def mock_torch_load(path, **kwargs): 78 | return { 79 | "activation": sample_activation, 80 | "tokens": sample_tokens, 81 | "meta": sample_info, 82 | } 83 | 84 | mocker.patch("torch.load", side_effect=mock_torch_load) 85 | 86 | # Initialize loader and process data 87 | loader = CachedActivationLoader(cache_dir, hook_points) 88 | results = list(loader.process()) 89 | 90 | # Verify results 91 | assert len(results) == 3 # 3 chunks 92 | 93 | for i, result in enumerate(results): 94 | # Check if all hook points are present 95 | for hook in hook_points: 96 | assert hook in result 97 | assert torch.allclose(result[hook], sample_activation) 98 | 99 | # Check tokens and info 100 | assert torch.equal(result["tokens"], sample_tokens) 101 | assert result["meta"] == sample_info 102 | 103 | 104 | def test_cached_activation_loader_missing_dir(fs): 105 | """Test CachedActivationLoader with missing directory.""" 106 | with pytest.raises(FileNotFoundError): 107 | loader = CachedActivationLoader("/nonexistent", ["hook1"]) 108 | list(loader.process()) 109 | 110 | 111 | def test_cached_activation_loader_mismatched_chunks( 112 | fs, mocker: MockerFixture, sample_activation, sample_tokens, sample_info 113 | ): 114 | """Test CachedActivationLoader with mismatched chunk counts.""" 115 | # Setup directories with different numbers of chunks 116 | cache_dir = Path("/cache") 117 | 118 | # hook1 has 2 chunks 119 | hook1_dir = cache_dir / "hook1" 120 | fs.create_dir(hook1_dir) 121 | create_fake_pt_file(fs, hook1_dir / "chunk-0.pt", sample_activation, sample_tokens, sample_info) 122 | create_fake_pt_file(fs, hook1_dir / "chunk-1.pt", sample_activation, sample_tokens, sample_info) 123 | 124 | # hook2 has 1 chunk 125 | hook2_dir = cache_dir / "hook2" 126 | fs.create_dir(hook2_dir) 127 | create_fake_pt_file(fs, hook2_dir / "chunk-0.pt", sample_activation, sample_tokens, sample_info) 128 | 129 | # Mock torch.load 130 | mocker.patch( 131 | "torch.load", 132 | return_value={ 133 | "activation": sample_activation, 134 | "tokens": sample_tokens, 135 | "meta": sample_info, 136 | }, 137 | ) 138 | 139 | # Should raise ValueError due to mismatched chunk counts 140 | with pytest.raises( 141 | 
ValueError, 142 | match="Hook points have different numbers of chunks: {'hook1': 2, 'hook2': 1}. All hook points must have the same number of chunks.", 143 | ): 144 | loader = CachedActivationLoader(cache_dir, ["hook1", "hook2"]) 145 | list(loader.process()) 146 | 147 | 148 | def test_cached_activation_loader_invalid_data(fs, mocker: MockerFixture): 149 | """Test CachedActivationLoader with invalid data format.""" 150 | cache_dir = Path("/cache") 151 | hook_dir = cache_dir / "hook1" 152 | fs.create_dir(hook_dir) 153 | create_fake_pt_file(fs, hook_dir / "chunk-0.pt", None, None, None) 154 | 155 | # Mock torch.load to return invalid data 156 | mocker.patch("torch.load", return_value={"invalid": "data"}) 157 | 158 | loader = CachedActivationLoader(cache_dir, ["hook1"]) 159 | with pytest.raises( 160 | AssertionError, 161 | match="Loading cached activation /cache/hook1/chunk-0.pt error: missing 'activation' field", 162 | ): 163 | list(loader.process()) 164 | -------------------------------------------------------------------------------- /tests/unit/test_concurrent.py: -------------------------------------------------------------------------------- 1 | import time 2 | from concurrent.futures import ThreadPoolExecutor 3 | 4 | import pytest 5 | 6 | from lm_saes.utils.concurrent import BackgroundGenerator 7 | 8 | 9 | def test_basic_iteration(): 10 | """Test basic iteration through a simple list.""" 11 | data = [1, 2, 3, 4, 5] 12 | bg = BackgroundGenerator(data) 13 | assert list(bg) == data 14 | 15 | 16 | def test_empty_generator(): 17 | """Test behavior with an empty generator.""" 18 | bg = BackgroundGenerator([]) 19 | with pytest.raises(StopIteration): 20 | next(bg) 21 | 22 | 23 | def test_max_prefetch(): 24 | """Test that max_prefetch limits the queue size.""" 25 | 26 | def slow_generator(): 27 | for i in range(3): 28 | time.sleep(0.1) # Simulate slow generation 29 | yield i 30 | 31 | bg = BackgroundGenerator(slow_generator(), max_prefetch=1) 32 | assert bg.queue.maxsize == 1 33 | 34 | # Consume all items 35 | assert list(bg) == [0, 1, 2] 36 | 37 | 38 | def test_exception_propagation(): 39 | """Test that exceptions from the generator are properly propagated.""" 40 | 41 | def failing_generator(): 42 | yield 1 43 | raise ValueError("Test error") 44 | yield 2 # This will never be reached 45 | 46 | bg = BackgroundGenerator(failing_generator()) 47 | 48 | # First item should be retrieved normally 49 | assert next(bg) == 1 50 | 51 | # Next call should raise the ValueError 52 | with pytest.raises(ValueError, match="Test error"): 53 | next(bg) 54 | 55 | 56 | def test_custom_executor(): 57 | """Test using a custom executor.""" 58 | data = [1, 2, 3] 59 | with ThreadPoolExecutor(max_workers=2) as executor: 60 | bg = BackgroundGenerator(data, executor=executor) 61 | assert list(bg) == data 62 | 63 | 64 | def test_generator_cleanup(mocker): 65 | """Test proper cleanup of resources.""" 66 | # Mock ThreadPoolExecutor 67 | mock_executor = mocker.Mock(spec=ThreadPoolExecutor) 68 | mock_executor.submit.return_value = mocker.Mock() 69 | 70 | data = [1, 2, 3] 71 | bg = BackgroundGenerator(data, executor=mock_executor) 72 | 73 | # Simulate deletion 74 | bg.__del__() 75 | 76 | # Verify that shutdown wasn't called since it's not an owned executor 77 | mock_executor.shutdown.assert_not_called() 78 | 79 | 80 | def test_large_dataset(): 81 | """Test handling of a larger dataset.""" 82 | large_data = range(1000) 83 | bg = BackgroundGenerator(large_data, max_prefetch=10) 84 | assert list(bg) == list(large_data) 85 | 86 | 87 
| def test_multiple_iterations(): 88 | """Test that the generator can only be consumed once.""" 89 | data = [1, 2, 3] 90 | bg = BackgroundGenerator(data) 91 | 92 | # First iteration should work 93 | assert list(bg) == data 94 | 95 | # Second iteration should yield no items 96 | assert list(bg) == [] 97 | -------------------------------------------------------------------------------- /tests/unit/test_database.py: -------------------------------------------------------------------------------- 1 | import mongomock 2 | import pytest 3 | 4 | from lm_saes.config import MongoDBConfig, SAEConfig 5 | from lm_saes.database import MongoClient, SAERecord 6 | 7 | 8 | @pytest.fixture 9 | def mongo_client() -> MongoClient: 10 | """ 11 | Creates a MongoClient instance with an in-memory MongoDB. 12 | 13 | Returns: 14 | MongoClient: A configured MongoClient instance for testing 15 | """ 16 | with mongomock.patch(servers=(("fake", 27017),)): 17 | client = MongoClient(MongoDBConfig(mongo_uri="mongodb://fake", mongo_db="test_db")) 18 | yield client 19 | # Clear all collections after each test 20 | client.db.drop_collection("sae") 21 | client.db.drop_collection("analysis") 22 | client.db.drop_collection("feature") 23 | 24 | 25 | def test_create_and_get_sae(mongo_client: MongoClient) -> None: 26 | """Test creating and retrieving an SAE record.""" 27 | # Arrange 28 | name = "test_sae" 29 | series = "test_series" 30 | path = "test_path" 31 | cfg = SAEConfig(hook_point_in="test_hook_point_in", d_sae=10, d_model=10, expansion_factor=1) 32 | 33 | # Act 34 | mongo_client.create_sae(name, series, path, cfg) 35 | result = mongo_client.get_sae(name, series) 36 | 37 | # Assert 38 | assert isinstance(result, SAERecord) 39 | assert result.name == name 40 | assert result.series == series 41 | assert result.path == path 42 | assert result.cfg.d_sae == cfg.d_sae 43 | 44 | 45 | def test_list_saes(mongo_client: MongoClient) -> None: 46 | """Test listing SAE records.""" 47 | # Arrange 48 | cfg = SAEConfig(hook_point_in="test_hook_point_in", d_sae=10, d_model=10, expansion_factor=1) 49 | print(mongo_client.sae_collection.find_one()) 50 | mongo_client.create_sae("sae1", "series1", "test_path", cfg) 51 | mongo_client.create_sae("sae2", "series1", "test_path", cfg) 52 | mongo_client.create_sae("sae3", "series2", "test_path", cfg) 53 | 54 | # Act & Assert 55 | assert set(mongo_client.list_saes()) == {"sae1", "sae2", "sae3"} 56 | assert set(mongo_client.list_saes("series1")) == {"sae1", "sae2"} 57 | assert set(mongo_client.list_saes("series2")) == {"sae3"} 58 | 59 | 60 | def test_remove_sae(mongo_client: MongoClient) -> None: 61 | """Test removing an SAE record.""" 62 | # Arrange 63 | name = "test_sae" 64 | series = "test_series" 65 | path = "test_path" 66 | cfg = SAEConfig(hook_point_in="test_hook_point_in", d_sae=10, d_model=10, expansion_factor=1) 67 | mongo_client.create_sae(name, series, path, cfg) 68 | 69 | # Act 70 | mongo_client.remove_sae(name, series) 71 | 72 | # Assert 73 | assert mongo_client.get_sae(name, series) is None 74 | assert mongo_client.list_saes(series) == [] 75 | 76 | 77 | def test_create_and_get_analysis(mongo_client: MongoClient) -> None: 78 | """Test creating and retrieving analysis records.""" 79 | # Arrange 80 | cfg = SAEConfig(hook_point_in="test_hook_point_in", d_sae=10, d_model=10, expansion_factor=1) 81 | mongo_client.create_sae("test_sae", "test_series", "test_path", cfg) 82 | 83 | # Act 84 | mongo_client.create_analysis("test_analysis", "test_sae", "test_series") 85 | analyses = 
mongo_client.list_analyses("test_sae", "test_series") 86 | 87 | # Assert 88 | assert analyses == ["test_analysis"] 89 | 90 | 91 | def test_get_nonexistent_sae(mongo_client: MongoClient) -> None: 92 | """Test retrieving a non-existent SAE record.""" 93 | assert mongo_client.get_sae("nonexistent", "series") is None 94 | 95 | 96 | def test_get_feature(mongo_client: MongoClient) -> None: 97 | """Test retrieving feature records.""" 98 | # Arrange 99 | cfg = SAEConfig(hook_point_in="test_hook_point_in", d_sae=2, d_model=2, expansion_factor=1) 100 | mongo_client.create_sae("test_sae", "test_series", "test_path", cfg) 101 | 102 | # Act 103 | feature = mongo_client.get_feature("test_sae", "test_series", 0) 104 | 105 | # Assert 106 | assert feature is not None 107 | assert feature.sae_name == "test_sae" 108 | assert feature.sae_series == "test_series" 109 | assert feature.index == 0 110 | -------------------------------------------------------------------------------- /tests/unit/test_discrete_mapper.py: -------------------------------------------------------------------------------- 1 | from lm_saes.utils.discrete import DiscreteMapper, KeyedDiscreteMapper 2 | 3 | 4 | def test_discrete_mapper(): 5 | mapper = DiscreteMapper() 6 | assert mapper.encode(["a", "b", "a", "c"]) == [0, 1, 0, 2] 7 | assert mapper.decode([0, 1, 0, 2]) == ["a", "b", "a", "c"] 8 | 9 | assert mapper.get_mapping() == {"a": 0, "b": 1, "c": 2} 10 | 11 | assert mapper.encode(["a", "c", "d"]) == [0, 2, 3] 12 | assert mapper.decode([0, 2, 3]) == ["a", "c", "d"] 13 | assert mapper.get_mapping() == {"a": 0, "b": 1, "c": 2, "d": 3} 14 | 15 | assert mapper.encode(["a", "c", "d", "a"]) == [0, 2, 3, 0] 16 | assert mapper.decode([0, 2, 3, 0]) == ["a", "c", "d", "a"] 17 | assert mapper.get_mapping() == {"a": 0, "b": 1, "c": 2, "d": 3} 18 | 19 | 20 | def test_keyed_discrete_mapper(): 21 | mapper = KeyedDiscreteMapper() 22 | assert mapper.encode("foo", ["a", "b", "a", "c"]) == [0, 1, 0, 2] 23 | assert mapper.decode("foo", [0, 1, 0, 2]) == ["a", "b", "a", "c"] 24 | assert mapper.keys() == ["foo"] 25 | 26 | assert mapper.encode("bar", ["a", "c", "d"]) == [0, 1, 2] 27 | assert mapper.decode("bar", [0, 1, 2]) == ["a", "c", "d"] 28 | assert sorted(mapper.keys()) == ["bar", "foo"] 29 | 30 | 31 | if __name__ == "__main__": 32 | import timeit 33 | 34 | print(timeit.timeit("test_discrete_mapper()", globals=globals(), number=1000)) 35 | print(timeit.timeit("test_keyed_discrete_mapper()", globals=globals(), number=1000)) 36 | -------------------------------------------------------------------------------- /tests/unit/test_example.py: -------------------------------------------------------------------------------- 1 | def func(x): 2 | return x + 1 3 | 4 | 5 | def test_answer(): 6 | assert func(4) == 5 7 | -------------------------------------------------------------------------------- /tests/unit/test_misc.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Iterator 3 | 4 | import pytest 5 | import torch 6 | 7 | from lm_saes.utils.misc import calculate_activation_norm 8 | 9 | 10 | class TestCalculateActivationNorm: 11 | @pytest.fixture 12 | def mock_activation_stream(self) -> Iterator[dict[str, torch.Tensor]]: 13 | """Creates a mock activation stream with known values for testing.""" 14 | 15 | def stream_generator(): 16 | # Create 10 batches of activations 17 | for _ in range(10): 18 | yield { 19 | "layer1": torch.ones(4, 16), # norm will be sqrt(16) 20 | "layer2": torch.ones(4, 16) * 
2, # norm will be sqrt(16) * 2 21 | } 22 | 23 | return stream_generator() 24 | 25 | def test_basic_functionality(self, mock_activation_stream): 26 | """Test basic functionality with default batch_num.""" 27 | result = calculate_activation_norm(mock_activation_stream, ["layer1", "layer2"]) 28 | 29 | assert isinstance(result, dict) 30 | assert "layer1" in result 31 | assert "layer2" in result 32 | 33 | # Expected norms: 34 | # layer1: sqrt(16) ≈ 4 35 | # layer2: sqrt(16) * 2 ≈ 8.0 36 | assert pytest.approx(result["layer1"], rel=1e-4) == 4.0 37 | assert pytest.approx(result["layer2"], rel=1e-4) == 8.0 38 | 39 | def test_custom_batch_num(self, mock_activation_stream): 40 | """Test with custom batch_num parameter.""" 41 | result = calculate_activation_norm(mock_activation_stream, batch_num=3, hook_points=["layer1", "layer2"]) 42 | 43 | # Should still give same results as we're averaging 44 | assert pytest.approx(result["layer1"], rel=1e-4) == 4.0 45 | assert pytest.approx(result["layer2"], rel=1e-4) == 8.0 46 | 47 | def test_empty_stream(self): 48 | """Test behavior with empty activation stream.""" 49 | empty_stream = iter([]) 50 | with pytest.warns(UserWarning): 51 | calculate_activation_norm(empty_stream, hook_points=[""]) 52 | 53 | def test_single_batch(self): 54 | """Test with a single batch of activations.""" 55 | 56 | def single_batch_stream(): 57 | yield {"single": torch.ones(2, 4)} # norm will be 2.0 58 | 59 | result = calculate_activation_norm(single_batch_stream(), hook_points=["single", "single"], batch_num=1) 60 | assert pytest.approx(result["single"], rel=1e-4) == 2.0 61 | 62 | def test_zero_tensors(self): 63 | """Test with zero tensors.""" 64 | 65 | def zero_stream(): 66 | for _ in range(10): 67 | yield {"zeros": torch.zeros(2, 4)} 68 | 69 | result = calculate_activation_norm(zero_stream(), hook_points=["zeros"]) 70 | assert result["zeros"] == 0.0 71 | 72 | def test_mixed_values(self): 73 | """Test with mixed positive/negative values.""" 74 | 75 | def mixed_stream(): 76 | for i in range(10): 77 | yield {"mixed": torch.tensor([[1.0, -2.0], [3.0, -4.0], [3.0, -2.0], [9.0, -4.0]]) * (i + 1)} 78 | 79 | result = calculate_activation_norm(mixed_stream(), hook_points=["mixed"], batch_num=10) 80 | assert ( 81 | pytest.approx(result["mixed"], rel=1e-4) == ((math.sqrt(5) + 5 + math.sqrt(13) + math.sqrt(97)) / 4) * 5.5 82 | ) 83 | -------------------------------------------------------------------------------- /tests/unit/test_mixcoder.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from lm_saes.config import MixCoderConfig 5 | from lm_saes.mixcoder import MixCoder 6 | 7 | 8 | @pytest.fixture 9 | def config(): 10 | return MixCoderConfig( 11 | d_model=3, 12 | modalities={"text": 2, "image": 3, "shared": 4}, 13 | device="cpu", 14 | dtype=torch.float32, 15 | use_glu_encoder=True, 16 | use_decoder_bias=True, 17 | hook_point_in="hook_point_in", 18 | hook_point_out="hook_point_out", 19 | expansion_factor=1.0, 20 | top_k=2, 21 | act_fn="topk", 22 | ) 23 | 24 | 25 | @pytest.fixture 26 | def modality_indices(): 27 | return { 28 | "text": torch.tensor([1, 2, 3, 4]), 29 | "image": torch.tensor([5, 6, 7, 8]), 30 | } 31 | 32 | 33 | @pytest.fixture 34 | def mixcoder(config, modality_indices): 35 | model = MixCoder(config) 36 | model.init_parameters(modality_indices=modality_indices) 37 | model.decoder["text"].bias.data = torch.rand_like(model.decoder["text"].bias.data) 38 | model.decoder["image"].bias.data = 
torch.rand_like(model.decoder["image"].bias.data) 39 | model.decoder["shared"].bias.data = torch.rand_like(model.decoder["shared"].bias.data) 40 | model.encoder["text"].bias.data = torch.rand_like(model.encoder["text"].bias.data) 41 | model.encoder["image"].bias.data = torch.rand_like(model.encoder["image"].bias.data) 42 | model.encoder["shared"].bias.data = torch.rand_like(model.encoder["shared"].bias.data) 43 | return model 44 | 45 | 46 | def test_init_parameters(mixcoder, config): 47 | assert mixcoder.modality_index == {"text": (0, 2), "image": (2, 5), "shared": (5, 9)} 48 | assert torch.allclose(mixcoder.modality_indices["text"], torch.tensor([1, 2, 3, 4])) 49 | assert torch.allclose(mixcoder.modality_indices["image"], torch.tensor([5, 6, 7, 8])) 50 | 51 | 52 | def test_encode_decode(mixcoder, config): 53 | """Test the encoding and decoding process.""" 54 | mixcoder.set_dataset_average_activation_norm({"hook_point_in": 1.0, "hook_point_out": 1.0}) 55 | batch_size = 8 56 | x = torch.randn(batch_size, config.d_model) # batch, d_model 57 | tokens = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) 58 | x_text = torch.cat([x[:4, :], torch.zeros(4, config.d_model)], dim=0) 59 | x_image = torch.cat([torch.zeros(4, config.d_model), x[4:, :]], dim=0) 60 | tokens_text = torch.tensor([1, 2, 3, 4, 0, 0, 0, 0]) 61 | tokens_image = torch.tensor([0, 0, 0, 0, 5, 6, 7, 8]) 62 | # Test encode 63 | feature_acts = mixcoder.encode(x, tokens=tokens) 64 | assert feature_acts.shape == (batch_size, config.d_sae) # batch, d_sae 65 | 66 | feature_acts_text = mixcoder.encode(x_text, tokens=tokens_text) 67 | assert feature_acts_text.shape == (batch_size, config.d_sae) 68 | feature_acts_image = mixcoder.encode(x_image, tokens=tokens_image) 69 | assert feature_acts_image.shape == (batch_size, config.d_sae) 70 | modality_index = mixcoder.get_modality_index() 71 | 72 | assert torch.allclose( 73 | feature_acts_text[:4, slice(*modality_index["text"])], 74 | feature_acts[:4, slice(*modality_index["text"])], 75 | ) 76 | assert torch.allclose( 77 | feature_acts_image[4:, slice(*modality_index["image"])], 78 | feature_acts[4:, slice(*modality_index["image"])], 79 | ) 80 | 81 | assert torch.allclose( 82 | torch.cat( 83 | [ 84 | feature_acts_text[:4, slice(*modality_index["shared"])], 85 | feature_acts_image[4:, slice(*modality_index["shared"])], 86 | ], 87 | dim=0, 88 | ), 89 | feature_acts[:, slice(*modality_index["shared"])], 90 | ) 91 | print(feature_acts) 92 | 93 | # Test decode 94 | reconstructed = mixcoder.decode(feature_acts, tokens=tokens) 95 | assert reconstructed.shape == (batch_size, config.d_model) 96 | 97 | reconstructed_text = mixcoder.decode(feature_acts_text, tokens=tokens_text) 98 | assert reconstructed_text.shape == (batch_size, config.d_model) 99 | 100 | reconstructed_image = mixcoder.decode(feature_acts_image, tokens=tokens_image) 101 | assert reconstructed_image.shape == (batch_size, config.d_model) 102 | 103 | assert torch.allclose(reconstructed_text[:4, :], reconstructed[:4, :]) 104 | assert torch.allclose(reconstructed_image[4:, :], reconstructed[4:, :]) 105 | 106 | 107 | def test_get_modality_activation_mask(mixcoder, config): 108 | """Test the _get_modality_activation method.""" 109 | tokens = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) 110 | 111 | # Test text modality 112 | text_activation_mask = mixcoder.get_modality_token_mask(tokens, "text") 113 | assert torch.all(text_activation_mask[:4] == 1) # First 4 positions should be 1 114 | assert torch.all(text_activation_mask[4:] == 0) # Last 4 positions should be 0 
115 | 116 | # Test image modality 117 | image_activation_mask = mixcoder.get_modality_token_mask(tokens, "image") 118 | assert torch.all(image_activation_mask[:4] == 0) # First 4 positions should be 0 119 | assert torch.all(image_activation_mask[4:] == 1) # Last 4 positions should be 1 120 | -------------------------------------------------------------------------------- /ui/.env.example: -------------------------------------------------------------------------------- 1 | VITE_BACKEND_URL=http://localhost:24577 -------------------------------------------------------------------------------- /ui/.eslintrc.cjs: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | root: true, 3 | env: { browser: true, es2020: true }, 4 | extends: ["eslint:recommended", "plugin:@typescript-eslint/recommended", "plugin:react-hooks/recommended"], 5 | ignorePatterns: ["dist", ".eslintrc.cjs"], 6 | parser: "@typescript-eslint/parser", 7 | plugins: ["react-refresh"], 8 | rules: { 9 | "react-refresh/only-export-components": ["warn", { allowConstantExport: true }], 10 | "@typescript-eslint/no-unused-vars": [ 11 | "error", 12 | { 13 | args: "all", 14 | argsIgnorePattern: "^_", 15 | caughtErrors: "all", 16 | caughtErrorsIgnorePattern: "^_", 17 | destructuredArrayIgnorePattern: "^_", 18 | varsIgnorePattern: "^_", 19 | ignoreRestSiblings: true, 20 | }, 21 | ], 22 | }, 23 | }; 24 | -------------------------------------------------------------------------------- /ui/.gitignore: -------------------------------------------------------------------------------- 1 | # Logs 2 | logs 3 | *.log 4 | npm-debug.log* 5 | yarn-debug.log* 6 | yarn-error.log* 7 | pnpm-debug.log* 8 | lerna-debug.log* 9 | 10 | node_modules 11 | dist 12 | dist-ssr 13 | *.local 14 | 15 | # Editor directories and files 16 | .vscode/* 17 | !.vscode/extensions.json 18 | .idea 19 | .DS_Store 20 | *.suo 21 | *.ntvs* 22 | *.njsproj 23 | *.sln 24 | *.sw? 25 | 26 | # Environment variables 27 | .env 28 | .env.local 29 | .env.development 30 | .env.test 31 | .env.production -------------------------------------------------------------------------------- /ui/.prettierrc.cjs: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | printWidth: 120, // max 120 chars in line, code is easy to read 3 | useTabs: false, // use spaces instead of tabs 4 | tabWidth: 2, // "visual width" of of the "tab" 5 | trailingComma: "es5", // add trailing commas in objects, arrays, etc. 6 | semi: true, // add ; when needed 7 | singleQuote: false, // '' for stings instead of "" 8 | bracketSpacing: true, // import { some } ... instead of import {some} ... 9 | arrowParens: "always", // braces even for single param in arrow functions (a) => { } 10 | jsxSingleQuote: false, // "" for react props, like in html 11 | bracketSameLine: false, // pretty JSX 12 | endOfLine: "lf", // 'lf' for linux, 'crlf' for windows, we need to use 'lf' for git 13 | }; 14 | -------------------------------------------------------------------------------- /ui/README.md: -------------------------------------------------------------------------------- 1 | # React + TypeScript + Vite 2 | 3 | This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules. 
4 | 5 | Currently, two official plugins are available: 6 | 7 | - [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react/README.md) uses [Babel](https://babeljs.io/) for Fast Refresh 8 | - [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh 9 | 10 | ## Expanding the ESLint configuration 11 | 12 | If you are developing a production application, we recommend updating the configuration to enable type aware lint rules: 13 | 14 | - Configure the top-level `parserOptions` property like this: 15 | 16 | ```js 17 | export default { 18 | // other rules... 19 | parserOptions: { 20 | ecmaVersion: "latest", 21 | sourceType: "module", 22 | project: ["./tsconfig.json", "./tsconfig.node.json"], 23 | tsconfigRootDir: __dirname, 24 | }, 25 | }; 26 | ``` 27 | 28 | - Replace `plugin:@typescript-eslint/recommended` to `plugin:@typescript-eslint/recommended-type-checked` or `plugin:@typescript-eslint/strict-type-checked` 29 | - Optionally add `plugin:@typescript-eslint/stylistic-type-checked` 30 | - Install [eslint-plugin-react](https://github.com/jsx-eslint/eslint-plugin-react) and add `plugin:react/recommended` & `plugin:react/jsx-runtime` to the `extends` list 31 | -------------------------------------------------------------------------------- /ui/bun.lockb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenMOSS/Language-Model-SAEs/6638b73672bae0873a3ac185e468d9adcb8e2282/ui/bun.lockb -------------------------------------------------------------------------------- /ui/components.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://ui.shadcn.com/schema.json", 3 | "style": "default", 4 | "rsc": false, 5 | "tsx": true, 6 | "tailwind": { 7 | "config": "tailwind.config.js", 8 | "css": "src/globals.css", 9 | "baseColor": "slate", 10 | "cssVariables": true, 11 | "prefix": "" 12 | }, 13 | "aliases": { 14 | "components": "@/components", 15 | "utils": "@/lib/utils" 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /ui/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | Sparse Auto Encoder Visualizer 8 | 9 | 10 |
11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /ui/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "ui", 3 | "private": true, 4 | "version": "0.0.0", 5 | "type": "module", 6 | "scripts": { 7 | "dev": "vite", 8 | "build": "tsc && vite build", 9 | "lint": "eslint . --ext ts,tsx --report-unused-disable-directives --max-warnings 0", 10 | "preview": "vite preview" 11 | }, 12 | "dependencies": { 13 | "@msgpack/msgpack": "^3.0.0-beta2", 14 | "@radix-ui/react-accordion": "^1.1.2", 15 | "@radix-ui/react-context-menu": "^2.2.1", 16 | "@radix-ui/react-dialog": "^1.1.1", 17 | "@radix-ui/react-dropdown-menu": "^2.0.6", 18 | "@radix-ui/react-hover-card": "^1.0.7", 19 | "@radix-ui/react-label": "^2.0.2", 20 | "@radix-ui/react-popover": "^1.1.1", 21 | "@radix-ui/react-select": "^2.0.0", 22 | "@radix-ui/react-separator": "^1.0.3", 23 | "@radix-ui/react-slot": "^1.0.2", 24 | "@radix-ui/react-switch": "^1.1.0", 25 | "@radix-ui/react-tabs": "^1.0.4", 26 | "@radix-ui/react-toggle": "^1.0.3", 27 | "@radix-ui/react-tooltip": "^1.1.2", 28 | "@tanstack/react-table": "^8.15.3", 29 | "@xyflow/react": "^12.0.4", 30 | "camelcase-keys": "^9.1.3", 31 | "class-variance-authority": "^0.7.0", 32 | "clsx": "^2.1.0", 33 | "cmdk": "1.0.0", 34 | "dagre": "^0.8.5", 35 | "lucide-react": "^0.358.0", 36 | "plotly.js": "^2.30.1", 37 | "react": "^18.2.0", 38 | "react-dom": "^18.2.0", 39 | "react-plotly.js": "^2.6.0", 40 | "react-router-dom": "^6.22.3", 41 | "react-use": "^17.5.0", 42 | "recharts": "^2.12.7", 43 | "snakecase-keys": "^8.0.1", 44 | "tailwind-merge": "^2.2.1", 45 | "tailwindcss-animate": "^1.0.7", 46 | "zod": "^3.22.4", 47 | "zustand": "^4.5.2" 48 | }, 49 | "devDependencies": { 50 | "@types/dagre": "^0.7.52", 51 | "@types/node": "^20.11.28", 52 | "@types/react": "^18.2.64", 53 | "@types/react-dom": "^18.2.21", 54 | "@types/react-plotly.js": "^2.6.3", 55 | "@typescript-eslint/eslint-plugin": "^7.1.1", 56 | "@typescript-eslint/parser": "^7.1.1", 57 | "@vitejs/plugin-react": "^4.2.1", 58 | "autoprefixer": "^10.4.18", 59 | "eslint": "^8.57.0", 60 | "eslint-plugin-react-hooks": "^4.6.0", 61 | "eslint-plugin-react-refresh": "^0.4.5", 62 | "postcss": "^8.4.35", 63 | "tailwindcss": "^3.4.1", 64 | "typescript": "^5.2.2", 65 | "vite": "^5.1.6" 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /ui/postcss.config.js: -------------------------------------------------------------------------------- 1 | export default { 2 | plugins: { 3 | tailwindcss: {}, 4 | autoprefixer: {}, 5 | }, 6 | }; 7 | -------------------------------------------------------------------------------- /ui/public/openmoss.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenMOSS/Language-Model-SAEs/6638b73672bae0873a3ac185e468d9adcb8e2282/ui/public/openmoss.ico -------------------------------------------------------------------------------- /ui/public/vite.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ui/src/components/app/feature-preview.tsx: -------------------------------------------------------------------------------- 1 | import { HoverCard, HoverCardContent, HoverCardTrigger } from "../ui/hover-card"; 2 | import { useAsync } from "react-use"; 3 | import { Feature, FeatureSchema } from 
"@/types/feature"; 4 | import { decode } from "@msgpack/msgpack"; 5 | import camelcaseKeys from "camelcase-keys"; 6 | import { create } from "zustand"; 7 | import { Link } from "react-router-dom"; 8 | 9 | const useFeaturePreviewStore = create<{ 10 | features: Record; 11 | addFeature: (feature: Feature) => void; 12 | }>((set) => ({ 13 | features: {}, 14 | addFeature: (feature: Feature) => 15 | set((state) => ({ 16 | features: { 17 | ...state.features, 18 | [`${feature.dictionaryName}---${feature.featureIndex}`]: feature, 19 | }, 20 | })), 21 | })); 22 | 23 | export const FeaturePreview = ({ dictionaryName, featureIndex }: { dictionaryName: string; featureIndex: number }) => { 24 | const featureInStore: Feature | null = useFeaturePreviewStore( 25 | (state) => state.features[`${dictionaryName}---${featureIndex}`] || null 26 | ); 27 | const addFeature = useFeaturePreviewStore((state) => state.addFeature); 28 | 29 | const state = useAsync(async () => { 30 | if (featureInStore) { 31 | return featureInStore; 32 | } 33 | const feature = await fetch( 34 | `${import.meta.env.VITE_BACKEND_URL}/dictionaries/${dictionaryName}/features/${featureIndex}`, 35 | { 36 | method: "GET", 37 | headers: { 38 | Accept: "application/x-msgpack", 39 | }, 40 | } 41 | ) 42 | .then(async (res) => { 43 | if (!res.ok) { 44 | throw new Error(await res.text()); 45 | } 46 | return res; 47 | }) 48 | .then(async (res) => await res.arrayBuffer()) 49 | // eslint-disable-next-line @typescript-eslint/no-explicit-any 50 | .then((res) => decode(new Uint8Array(res)) as any) 51 | .then((res) => 52 | camelcaseKeys(res, { 53 | deep: true, 54 | stopPaths: ["sample_groups.samples.context"], 55 | }) 56 | ) 57 | .then((res) => FeatureSchema.parse(res)); 58 | addFeature(feature); 59 | return feature; 60 | }); 61 | 62 | return ( 63 |
64 | {state.loading &&

Loading...

} 65 | {state.error &&

Error: {state.error.message}

} 66 | {state.value && ( 67 |
68 |
Feature:
69 |
#{featureIndex}
70 |
Interpretation:
71 |
{state.value.interpretation?.text || "N/A"}
72 |
73 | )} 74 |
75 | ); 76 | }; 77 | 78 | export const FeatureLinkWithPreview = ({ 79 | dictionaryName, 80 | featureIndex, 81 | }: { 82 | dictionaryName: string; 83 | featureIndex: number; 84 | }) => { 85 | return ( 86 | 87 | 88 | 89 | #{featureIndex} 90 | 91 | 92 | 93 | 94 | 95 | 96 | ); 97 | }; 98 | -------------------------------------------------------------------------------- /ui/src/components/app/navbar.tsx: -------------------------------------------------------------------------------- 1 | import { cn } from "@/lib/utils"; 2 | import { Link, useLocation } from "react-router-dom"; 3 | 4 | export const AppNavbar = () => { 5 | const location = useLocation(); 6 | 7 | return ( 8 | 43 | ); 44 | }; 45 | -------------------------------------------------------------------------------- /ui/src/components/app/sample.tsx: -------------------------------------------------------------------------------- 1 | import { useState } from "react"; 2 | import { TokenGroup } from "./token"; 3 | import { cn } from "@/lib/utils"; 4 | 5 | export type SampleProps = { 6 | sampleName?: string; 7 | tokenGroups: T[][]; 8 | tokenGroupClassName?: (tokenGroup: T[], i: number) => string; 9 | tokenGroupProps?: (tokenGroup: T[], i: number) => React.HTMLProps; 10 | tokenInfoContent?: (tokenGroup: T[], i: number) => (token: T, i: number) => React.ReactNode; 11 | tokenGroupInfoContent?: (tokenGroup: T[], i: number) => React.ReactNode; 12 | customTokenGroup?: (tokens: T[], i: number) => React.ReactNode; 13 | foldedStart?: number; 14 | }; 15 | 16 | export const Sample = ({ 17 | sampleName, 18 | tokenGroups, 19 | tokenGroupClassName, 20 | tokenGroupProps, 21 | tokenInfoContent, 22 | tokenGroupInfoContent, 23 | customTokenGroup, 24 | foldedStart, 25 | }: SampleProps) => { 26 | const [folded, setFolded] = useState(true); 27 | 28 | return ( 29 |
30 |
setFolded(!folded) : undefined} 33 | > 34 | {sampleName && {sampleName}: } 35 | {folded && !!foldedStart && ...} 36 | {tokenGroups 37 | .slice((folded && foldedStart) || 0) 38 | .map((tokens, i) => 39 | customTokenGroup ? ( 40 | customTokenGroup(tokens, i) 41 | ) : ( 42 | 50 | ) 51 | )} 52 |
53 |
54 | ); 55 | }; 56 | -------------------------------------------------------------------------------- /ui/src/components/app/section-navigator.tsx: -------------------------------------------------------------------------------- 1 | import { cn } from "@/lib/utils"; 2 | import { Card, CardContent, CardHeader, CardTitle } from "../ui/card"; 3 | import { useEffect, useState } from "react"; 4 | 5 | export const SectionNavigator = ({ sections }: { sections: { title: string; id: string }[] }) => { 6 | const [activeSection, setActiveSection] = useState<{ title: string; id: string } | null>(null); 7 | 8 | const handleScroll = () => { 9 | // Use reduce instead of find for obtaining the last section that is in view 10 | const currentSection = sections.reduce((result: { title: string; id: string } | null, section) => { 11 | const secElement = document.getElementById(section.id); 12 | if (!secElement) return result; 13 | const rect = secElement.getBoundingClientRect(); 14 | if (rect.top <= window.innerHeight / 2) { 15 | return section; 16 | } 17 | return result; 18 | }, null); 19 | 20 | setActiveSection(currentSection); 21 | }; 22 | 23 | useEffect(() => { 24 | window.addEventListener("scroll", handleScroll); 25 | 26 | // Run the handler to set the initial active section 27 | handleScroll(); 28 | 29 | return () => { 30 | window.removeEventListener("scroll", handleScroll); 31 | }; 32 | }); 33 | 34 | return ( 35 | 36 | 37 | 38 | CONTENTS 39 | 40 | 41 | 42 |
43 | 56 |
57 |
58 |
59 | ); 60 | }; 61 | -------------------------------------------------------------------------------- /ui/src/components/app/token.tsx: -------------------------------------------------------------------------------- 1 | import { cn } from "@/lib/utils"; 2 | import { mergeUint8Arrays } from "@/utils/array"; 3 | import { HoverCard, HoverCardContent, HoverCardTrigger } from "../ui/hover-card"; 4 | import { Fragment } from "react/jsx-runtime"; 5 | import { Separator } from "../ui/separator"; 6 | 7 | export type PlainTokenGroupProps = { 8 | tokenGroup: T[]; 9 | tokenGroupClassName?: string; 10 | tokenGroupProps?: React.HTMLProps; 11 | }; 12 | 13 | export const PlainTokenGroup = ({ 14 | tokenGroup, 15 | tokenGroupClassName, 16 | tokenGroupProps, 17 | }: PlainTokenGroupProps) => { 18 | const decoder = new TextDecoder("utf-8", { fatal: true }); 19 | 20 | return ( 21 | 28 | {decoder 29 | .decode(mergeUint8Arrays(tokenGroup.map((t) => t.token))) 30 | .replace("\n", "⏎") 31 | .replace("\t", "⇥") 32 | .replace("\r", "↵")} 33 | 34 | ); 35 | }; 36 | 37 | export type TokenGroupProps = PlainTokenGroupProps & { 38 | tokenInfoContent?: (token: T, i: number) => React.ReactNode; 39 | tokenGroupInfoContent?: React.ReactNode; 40 | }; 41 | 42 | export const TokenGroup = ({ 43 | tokenGroup, 44 | tokenGroupClassName, 45 | tokenGroupProps, 46 | tokenInfoContent, 47 | tokenGroupInfoContent, 48 | }: TokenGroupProps) => { 49 | if (!tokenInfoContent && !tokenGroupInfoContent) { 50 | return ( 51 | 56 | ); 57 | } 58 | 59 | return ( 60 | 61 | 62 | 67 | 68 | 69 | {tokenGroupInfoContent ? ( 70 | {tokenGroupInfoContent} 71 | ) : ( 72 | tokenGroup.map((token, i) => ( 73 | 74 | {tokenInfoContent?.(token, i)} 75 | {i < tokenGroup.length - 1 && } 76 | 77 | )) 78 | )} 79 | 80 | 81 | ); 82 | }; 83 | -------------------------------------------------------------------------------- /ui/src/components/attn-head/attn-head-card.tsx: -------------------------------------------------------------------------------- 1 | import { AttentionHead } from "@/types/attn-head"; 2 | import { Card, CardHeader, CardTitle, CardContent } from "../ui/card"; 3 | import { Table, TableBody, TableCell, TableHead, TableHeader, TableRow } from "../ui/table"; 4 | import { FeatureLinkWithPreview } from "../app/feature-preview"; 5 | 6 | export const AttentionHeadCard = ({ attnHead }: { attnHead: AttentionHead }) => { 7 | return ( 8 | 9 | 10 | 11 | Attention Head {attnHead.layer}.{attnHead.head} 12 | 13 | 14 | 15 |
16 | {attnHead.attnScores.map((attnScoreGroup, idx) => ( 17 |
18 |

19 | {attnScoreGroup.dictionary1Name} {" -> "} 20 | {attnScoreGroup.dictionary2Name} 21 |

22 | 23 | 24 | 25 | Feature 26 | Feature Attended 27 | Attention Score 28 | 29 | 30 | 31 | {attnScoreGroup.topAttnScores.map((attnScore, idx) => ( 32 | 33 | 34 | 38 | 39 | 40 | 44 | 45 | {attnScore.attnScore.toFixed(3)} 46 | 47 | ))} 48 | 49 |
50 |
51 | ))} 52 |
53 |
54 |
55 | ); 56 | }; 57 | -------------------------------------------------------------------------------- /ui/src/components/dictionary/dictionary-card.tsx: -------------------------------------------------------------------------------- 1 | import { useState } from "react"; 2 | import { Button } from "../ui/button"; 3 | import { Card, CardContent, CardHeader, CardTitle } from "../ui/card"; 4 | import { Dictionary, DictionarySampleCompact, DictionarySampleCompactSchema } from "@/types/dictionary"; 5 | import Plot from "react-plotly.js"; 6 | import { useAsyncFn } from "react-use"; 7 | import { decode } from "@msgpack/msgpack"; 8 | import camelcaseKeys from "camelcase-keys"; 9 | import { Textarea } from "../ui/textarea"; 10 | import { DictionarySample } from "./sample"; 11 | 12 | const DictionaryCustomInputArea = ({ dictionary }: { dictionary: Dictionary }) => { 13 | const [customInput, setCustomInput] = useState(""); 14 | const [samples, setSamples] = useState([]); 15 | const [state, submit] = useAsyncFn(async () => { 16 | if (!customInput) { 17 | alert("Please enter your input."); 18 | return; 19 | } 20 | const sample = await fetch( 21 | `${import.meta.env.VITE_BACKEND_URL}/dictionaries/${ 22 | dictionary.dictionaryName 23 | }/custom?input_text=${encodeURIComponent(customInput)}`, 24 | { 25 | method: "POST", 26 | headers: { 27 | Accept: "application/x-msgpack", 28 | }, 29 | } 30 | ) 31 | .then(async (res) => { 32 | if (!res.ok) { 33 | throw new Error(await res.text()); 34 | } 35 | return res; 36 | }) 37 | .then(async (res) => await res.arrayBuffer()) 38 | // eslint-disable-next-line @typescript-eslint/no-explicit-any 39 | .then((res) => decode(new Uint8Array(res)) as any) 40 | .then((res) => 41 | camelcaseKeys(res, { 42 | deep: true, 43 | stopPaths: ["context"], 44 | }) 45 | ) 46 | .then((res) => DictionarySampleCompactSchema.parse(res)); 47 | setSamples((prev) => [...prev, sample]); 48 | }, [customInput]); 49 | 50 | return ( 51 |
52 |

Custom Input

53 |